Brief Review — Sharpness-Aware Minimization for Efficiently Improving Generalization

SAM, Simultaneously Minimizing Training Loss and Sharpness Loss, Using Loss Landscape

Sik-Ho Tsang
Feb 1, 2023
(left) Error rate reduction obtained by switching to SAM. Each point is a different dataset / model / data augmentation. (middle) A sharp minimum to which a ResNet trained with SGD converged. (right) A wide minimum to which the same ResNet trained with SAM converged.

Sharpness-Aware Minimization for Efficiently Improving Generalization,
Sharpness-Aware Minimization (SAM), by Google Research, and Blueshift, Alphabet,
2021 ICLR, Over 400 Citations (Sik-Ho Tsang @ Medium)
Image Classification

  • Instead of minimizing only the training loss value, a novel, effective procedure is proposed that simultaneously minimizes the loss value and the loss sharpness.
  • Sharpness-Aware Minimization (SAM) seeks parameters that lie in neighborhoods having uniformly low loss.

Outline

  1. Sharpness-Aware Minimization (SAM)
  2. Results

1. Sharpness-Aware Minimization (SAM)

1.1. Motivations

  • L_D(w) is the population loss that we want to minimize, while L_S(w) is the training set loss, which is the only loss we can actually compute and minimize.
  • Typical optimization approaches that minimize L_S(w) alone can easily result in suboptimal performance at test time. In particular, for modern models, L_S(w) is typically non-convex in w, with multiple local and even global minima that may yield similar values of L_S(w) while having significantly different generalization performance (i.e., significantly different values of L_D(w)).

Motivated by the connection between the sharpness of the loss landscape and generalization, SAM is proposed to minimize the training loss value and the sharpness together.

1.2. SAM

1.2.1. Generalization Bound

  • Rather than seeking out parameter values w that simply have a low training loss value L_S(w), SAM seeks out parameter values whose entire neighborhoods have uniformly low training loss values (equivalently, neighborhoods having both low loss and low curvature).
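  • The generalization ability is bounded in terms of the neighborhood-wise training loss. For any ρ>0, with high probability over the training set S generated from distribution D:
    $$L_D(w) \le \max_{\|\epsilon\|_2 \le \rho} L_S(w+\epsilon) + h\left(\|w\|_2^2/\rho^2\right)$$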
  • where h: R+ → R+ is a strictly increasing function (under some technical conditions on L_D(w)).
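  • To make the sharpness term explicit, the right-hand side of the inequality above can be rewritten as:
    $$\left[\max_{\|\epsilon\|_2 \le \rho} L_S(w+\epsilon) - L_S(w)\right] + L_S(w) + h\left(\|w\|_2^2/\rho^2\right)$$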

The term in square brackets [] captures the sharpness of L_S at w by measuring how quickly the training loss can be increased by moving from w to a nearby parameter value. This sharpness term is then summed with the training loss value L_S(w) itself and with a regularizer h on the magnitude of w.
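The bracketed sharpness term can be made concrete with two toy losses. The minimal sketch below assumes NumPy and two hypothetical 1-D quadratic losses that share the same minimum value (0 at w = 0) but differ in curvature. The inner max is approximated by a grid search rather than solved exactly, and the helper name `sharpness` is illustrative.

```python
import numpy as np

def sharpness(loss, w, rho, num=2001):
    """Grid-search estimate of max_{|eps| <= rho} loss(w + eps) - loss(w) for a scalar w."""
    eps = np.linspace(-rho, rho, num)   # candidate perturbations, endpoints included
    return float(np.max(loss(w + eps)) - loss(w))

def sharp_loss(w):
    return 50.0 * w ** 2                # hypothetical loss: high curvature at w = 0

def flat_loss(w):
    return 0.5 * w ** 2                 # hypothetical loss: low curvature at w = 0

rho = 0.1
# Both losses equal 0.0 at w = 0, so the training loss alone cannot tell them apart.
print(sharpness(sharp_loss, 0.0, rho))
print(sharpness(flat_loss, 0.0, rho))
```

For a quadratic (a/2)w², the exact sharpness at w = 0 is aρ²/2. The two calls should therefore print values close to 0.5 and 0.005, even though L_S(0) is identical for both.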

  • In place of h, a standard L2 regularization term λ‖w‖²₂ is used, with λ a hyperparameter.
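  • The following Sharpness-Aware Minimization (SAM) problem is formulated:
    $$\min_w \; L_S^{SAM}(w) + \lambda \|w\|_2^2, \quad \text{where} \quad L_S^{SAM}(w) \triangleq \max_{\|\epsilon\|_p \le \rho} L_S(w+\epsilon) \qquad (1)$$
    where ρ ≥ 0 is a hyperparameter and p ∈ [1, ∞]; empirically, p = 2 is typically optimal.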
  • The figure at the top shows the loss landscapes of models that converged to minima found by minimizing either L_S(w) or L_S^SAM(w). It illustrates that the sharpness-aware loss prevents the model from converging to a sharp minimum.
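The effect of the SAM objective can also be illustrated on a toy 1-D landscape. This is a minimal sketch assuming NumPy. All of the following are hypothetical choices, not the paper's setup:

- the loss, with a deep but narrow well at w = -1 and a slightly shallower but wide well at w = +1;
- the values ρ = 0.2 and λ = 0.01;
- the grid-search approximation of the inner max.

```python
import numpy as np

def train_loss(w):
    """Hypothetical 1-D training loss with a sharp deep well and a wide shallower well."""
    sharp = 1.0 * np.exp(-(w + 1.0) ** 2 / (2 * 0.05 ** 2))   # narrow well at w = -1
    wide = 0.8 * np.exp(-(w - 1.0) ** 2 / (2 * 0.5 ** 2))     # wide well at w = +1
    return 1.0 - sharp - wide

def sam_objective(w_grid, rho, lam, num_eps=401):
    """L_S^SAM(w) + lam * w**2, with the max over |eps| <= rho done by grid search."""
    eps = np.linspace(-rho, rho, num_eps)
    # Rows index w, columns index eps; take the worst-case loss in each neighborhood.
    worst = train_loss(w_grid[:, None] + eps[None, :]).max(axis=1)
    return worst + lam * w_grid ** 2

w_grid = np.linspace(-3.0, 3.0, 6001)
w_plain = w_grid[np.argmin(train_loss(w_grid))]
w_sam = w_grid[np.argmin(sam_objective(w_grid, rho=0.2, lam=0.01))]
print(f"argmin of L_S(w):          {w_plain:.3f}")
print(f"argmin of SAM objective:   {w_sam:.3f}")
```

With these settings, the plain training loss is minimized in the sharp well near w = -1, while the SAM objective is minimized in the wide well near w = +1. This happens because every ρ-neighborhood of the sharp well contains points where the loss climbs back toward 1.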

1.2.2. Approximation for Efficient Computation

  • In order to minimize L_S^SAM(w), an efficient and effective approximation to ∇_w L_S^SAM(w) is derived.
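  • The inner maximization problem is first approximated via a first-order Taylor expansion of L_S(w+ε) w.r.t. ε around 0, obtaining:
    $$\epsilon^*(w) \triangleq \arg\max_{\|\epsilon\|_p \le \rho} L_S(w+\epsilon) \approx \arg\max_{\|\epsilon\|_p \le \rho} L_S(w) + \epsilon^T \nabla_w L_S(w) = \arg\max_{\|\epsilon\|_p \le \rho} \epsilon^T \nabla_w L_S(w)$$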
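  • In turn, the value ε̂(w) that solves this approximation is given by the solution to a classical dual norm problem:
    $$\hat\epsilon(w) = \rho \, \mathrm{sign}\left(\nabla_w L_S(w)\right) \frac{|\nabla_w L_S(w)|^{q-1}}{\left(\|\nabla_w L_S(w)\|_q^q\right)^{1/p}}$$
  • where 1/p + 1/q = 1. Substituting ε̂(w) back into the definition of L_S^SAM(w) and differentiating gives:
    $$\nabla_w L_S^{SAM}(w) \approx \nabla_w L_S(w+\hat\epsilon(w)) = \frac{d(w+\hat\epsilon(w))}{dw} \nabla_w L_S(w)\big|_{w+\hat\epsilon(w)} = \nabla_w L_S(w)\big|_{w+\hat\epsilon(w)} + \frac{d\hat\epsilon(w)}{dw} \nabla_w L_S(w)\big|_{w+\hat\epsilon(w)}$$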
  • This approximation to ∇_w L_S^SAM(w) can be straightforwardly computed via automatic differentiation, as implemented in common libraries such as JAX, TensorFlow, and PyTorch.
  • To further accelerate the computation, the second-order terms are dropped, obtaining the final gradient approximation:
    $$\nabla_w L_S^{SAM}(w) \approx \nabla_w L_S(w)\big|_{w+\hat\epsilon(w)}$$
  • Indeed, in the authors' initial experiment, including the second-order terms surprisingly degrades performance.
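The closed-form dual-norm solution for ε̂(w) can be checked numerically. The sketch below assumes NumPy and uses a random vector as a stand-in for ∇_w L_S(w); the function name `eps_hat` is hypothetical. It covers 1 < p ≤ ∞ and confirms two properties: ‖ε̂‖_p = ρ, and ε̂ᵀg = ρ‖g‖_q, the largest value the linearized objective can reach on the ρ-ball.

```python
import numpy as np

def eps_hat(grad, rho, p=2.0):
    """Closed-form argmax of eps^T grad subject to ||eps||_p <= rho (dual norm problem).

    Valid for 1 < p <= inf; q is the dual exponent with 1/p + 1/q = 1.
    """
    if np.isinf(p):                      # q = 1: every coordinate goes to +-rho
        return rho * np.sign(grad)
    q = p / (p - 1.0)
    g_abs = np.abs(grad)
    numer = np.sign(grad) * g_abs ** (q - 1.0)
    denom = np.sum(g_abs ** q) ** (1.0 / p)
    return rho * numer / denom

rng = np.random.default_rng(0)
g = rng.normal(size=5)                   # hypothetical stand-in for grad_w L_S(w)
rho = 0.05

# For p = 2 the solution is just the gradient rescaled to length rho.
e2 = eps_hat(g, rho, p=2.0)
print(np.allclose(e2, rho * g / np.linalg.norm(g)))

for p in (1.5, 2.0, 3.0, np.inf):
    e = eps_hat(g, rho, p)
    q = 1.0 if np.isinf(p) else p / (p - 1.0)
    # Expect ||eps_hat||_p == rho and eps_hat^T g == rho * ||g||_q (the dual norm of g).
    print(p, np.isclose(np.linalg.norm(e, ord=p), rho),
          np.isclose(e @ g, rho * np.linalg.norm(g, ord=q)))
```

The p = 2 case, ε̂(w) = ρ∇_w L_S(w)/‖∇_w L_S(w)‖₂, is the one most commonly used in practice.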
Left: SAM algorithm, Right: Schematic of the SAM parameter update.
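  • The algorithm above summarizes the overall SAM training procedure. It computes the gradient at w and builds ε̂(w) from it. It then computes the gradient again at w + ε̂(w) and uses that second gradient to update w.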
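The two-gradient update can be sketched on a toy least-squares problem. This is a minimal NumPy sketch under several assumptions:

- p = 2, so ε̂(w) = ρg/‖g‖₂;
- a linear model with a hand-written gradient, where in practice JAX, TensorFlow, or PyTorch autodiff would supply the gradients;
- hypothetical hyperparameters lr = 0.1, ρ = 0.05, and batch size 32.

The helper names `loss_and_grad` and `sam_step` are illustrative. The small constant added to the norm is a numerical safeguard added here, not part of the algorithm as stated.

```python
import numpy as np

def loss_and_grad(w, X, y):
    """Mean squared error (halved) of a linear model X @ w and its gradient w.r.t. w."""
    r = X @ w - y
    return 0.5 * np.mean(r ** 2), X.T @ r / len(y)

def sam_step(w, X, y, lr=0.1, rho=0.05):
    """One SAM update on a minibatch, with p = 2."""
    _, g = loss_and_grad(w, X, y)                   # 1) gradient at w
    eps = rho * g / (np.linalg.norm(g) + 1e-12)     # 2) eps_hat(w) for p = 2
    _, g_sam = loss_and_grad(w + eps, X, y)         # 3) gradient at w + eps_hat(w)
    return w - lr * g_sam                           # 4) step from the ORIGINAL w

rng = np.random.default_rng(0)
X = rng.normal(size=(256, 3))
w_true = np.array([1.0, -2.0, 0.5])
y = X @ w_true + 0.1 * rng.normal(size=256)         # synthetic noisy targets

w = np.zeros(3)
for step in range(200):
    idx = rng.choice(256, size=32, replace=False)   # sample a minibatch B
    w = sam_step(w, X[idx], y[idx])
print(loss_and_grad(w, X, y)[0], w)
```

The update is applied to the original w, not to w + ε̂(w). The perturbation only decides where the second gradient is evaluated, so each SAM step costs roughly two forward/backward passes instead of one.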

2. Results

2.1. Image Classification from Scratch

Results for SAM on state-of-the-art models on CIFAR-{10, 100}

SAM improves generalization across all settings evaluated for CIFAR-10 and CIFAR-100.

  • For instance, applying SAM to a PyramidNet with ShakeDrop regularization yields 10.3% error on CIFAR-100. To the authors' knowledge, this was a new state-of-the-art on this dataset without additional data at the time of publication.
Test error rates for ResNets trained on ImageNet, with and without SAM.

SAM again consistently improves performance.

  • Note that SAM enables increasing the number of training epochs while continuing to improve accuracy without overfitting.

2.2. Finetuning

Top-1 error rates for finetuning EfficientNet-b7 (left; ImageNet pretraining only) and EfficientNet-L2 (right; pretraining on ImageNet plus additional data, such as JFT) on various downstream tasks.
  • SAM uniformly improves performance relative to finetuning without SAM.

Furthermore, in many cases, SAM yields novel state-of-the-art performance, including 0.30% error on CIFAR-10, 3.92% error on CIFAR-100, and 11.39% error on ImageNet.

2.3. Robustness to Label Noise

Test accuracy on the clean test set for models trained on CIFAR-10 with noisy labels. Lower block: the authors' implementation; upper block: scores from the literature.
  • The fact that SAM seeks out model parameters that are robust to perturbations suggests its potential to provide robustness to noise in the training set.

SAM provides a high degree of robustness to label noise, on par with that provided by state-of-the-art procedures that specifically target learning with noisy labels.

  • Indeed, simply training a model with SAM outperforms all prior methods specifically targeting label noise robustness, with the exception of MentorMix (Jiang et al., 2019).

Reference

[2021 ICLR] [Sharpness-Aware Minimization (SAM)]
Sharpness-Aware Minimization for Efficiently Improving Generalization

