Sharpness-Aware Minimization (SAM)

Description

SAM seeks parameters that lie in neighborhoods with uniformly low loss values, leading to better generalization.

Its key features are the following:

  • Looks for "flat" minima instead of sharp ones
  • Improves robustness against input perturbations
  • Generally provides better generalization than standard SGD
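These properties follow from SAM's min-max objective: rather than minimizing the loss L(w) at a single point, SAM minimizes the worst-case loss within an L2 ball of radius rho around the current weights. A sketch of the formulation (notation follows the standard SAM paper; the inner maximizer is approximated by a single scaled gradient step):

```latex
\min_{w} \; \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon),
\qquad
\hat{\epsilon}(w) \approx \rho \, \frac{\nabla L(w)}{\|\nabla L(w)\|_2}
```

This approximate perturbation is exactly the weight update performed in the first pass of the implementation in the Example section.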

Example

import torch

class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        # Wrap a base optimizer (e.g. torch.optim.SGD) around the same parameter groups
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self):
        # After the first forward-backward pass: perturb weights toward the
        # worst case within an L2 ball of radius rho
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale
                p.add_(e_w)                   # climb to w + e(w)
                self.state[p]['e_w'] = e_w    # remember the perturbation

    @torch.no_grad()
    def second_step(self):
        # After the second forward-backward pass: restore the original weights,
        # then update them with the gradients computed at the perturbed point
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                p.sub_(self.state[p]['e_w'])  # back to w
        self.base_optimizer.step()

    def _grad_norm(self):
        # Global L2 norm over all parameter gradients
        return torch.norm(torch.stack([
            p.grad.norm(p=2)
            for group in self.param_groups
            for p in group['params'] if p.grad is not None
        ]), p=2)
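The two forward-backward passes can also be sketched end to end without the class wrapper, using a plain SGD base optimizer. This is a minimal, self-contained sketch on a toy regression task; the model, data, rho=0.05, and learning rate are illustrative assumptions, not prescribed by SAM:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
w0 = model.weight.detach().clone()          # snapshot for comparison
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(32, 4), torch.randn(32, 1)
rho = 0.05

def loss_fn():
    return torch.nn.functional.mse_loss(model(x), y)

# First forward-backward pass: gradient at w
loss_fn().backward()

with torch.no_grad():
    grads = [p.grad for p in model.parameters()]
    grad_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
    scale = rho / (grad_norm + 1e-12)
    e_ws = []
    for p in model.parameters():
        e_w = p.grad * scale
        p.add_(e_w)                         # move to w + e(w)
        e_ws.append(e_w)

opt.zero_grad()

# Second forward-backward pass: gradient at the perturbed point
loss_fn().backward()

with torch.no_grad():
    for p, e_w in zip(model.parameters(), e_ws):
        p.sub_(e_w)                         # restore w

opt.step()                                  # update w with the sharpness-aware gradient
opt.zero_grad()
```

Note that each SAM update costs two forward-backward passes, roughly doubling the per-step compute of the base optimizer.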