Optimizing Deep Learning Models with SAM

Overparameterization, Generalizability, and SAM

The dramatic success of recent deep learning, especially in Computer Vision and Natural Language Processing, is built on “overparameterized” models: models with enough parameters to memorize the training data perfectly. Operationally, a model can be identified as overparameterized when it can easily achieve near-perfect training accuracy (close to 100%) with near-zero training loss on a given task.

Nonetheless, the usefulness of such a model is determined by whether it performs well on held-out test data drawn from the same distribution as the training set but unseen during training. This property is called “generalizability”, the ability of a model to maintain its performance on new examples, and it is crucial for any deep learning model to be practically useful.

Classical Machine Learning theory tells us that overparameterized models should catastrophically overfit and consequently generalize poorly. Nonetheless, one of the most surprising discoveries of the past decade is that models in this class often generalize remarkably well.

This highly counterintuitive phenomenon has been investigated in a series of papers, starting with the seminal works of Belkin et al. (2018) and Nakkiran et al. (2019), which demonstrated that there exists a “double descent” curve for generalizability: as model size increases, generalization first worsens (as classical theory predicts), then improves again beyond a critical threshold, provided the model is trained with suitable optimization methods.

Figure 1: A schematic representation of the double descent behavior. Image generated by the author with Gemini.

Figure 1 shows a cartoon of a double descent curve. The y-axis plots the test error (a measure of generalizability, where lower error indicates better generalization), while the x-axis shows the number of model parameters. As model size increases, the training error (dashed blue line) rapidly approaches zero, as expected.

The test error (solid blue line) exhibits a more interesting behavior: it initially decreases with model size (the first descent, highlighted by the left red circle), then rises to a peak at the interpolation threshold marked by the vertical dashed line, where the model has the worst generalization. Beyond this threshold, however, in the overparameterized regime, the test error decreases again (the second descent, highlighted by the right red circle) and continues to decline as more parameters are added. This is the regime of interest for contemporary deep learning models.

In Machine Learning, one finds the parameters of a model by minimizing a loss function on the training dataset. But does simply minimizing our favourite loss function, such as cross-entropy, on the training dataset guarantee satisfactory generalization properties for the class of overparameterized models? The answer is, generally speaking, no! Whether one is interested in fine-tuning a pre-trained model or training a model from scratch, it is important to choose the training algorithm carefully to ensure that the resulting model is sufficiently generalizable. This is what makes the choice of the optimizer a crucial design choice.

Sharpness-Aware Minimization (SAM), introduced in a paper by Foret et al. (2020), is an optimizer designed to enhance the generalizability of an overparameterized model. In this article, I present a pedagogical review of SAM that includes:

  1. An intuitive understanding of how SAM works and why it improves generalization.
  2. A deep dive into the algorithm, explaining the key mathematical steps involved.
  3. A PyTorch implementation of the optimizer class in a training loop, including an important caveat for models with BatchNorm layers.
  4. A quick demonstration of the effectiveness of the optimizer in improving generalization on an image classification task with a ResNet-18 model.

The complete code used in this article can be found in this GitHub repo; feel free to play around with it!

The Notion of Sharpness

To start with, let us try to get an intuitive sense of why simply minimizing the loss function may not be enough for optimal generalization.

A useful picture to keep in mind is that of the loss landscape. For a large overparameterized model, the loss landscape has multiple local and global minima. The local geometries around such minima can vary significantly across the landscape. For instance, two minima could have nearly identical loss values, yet differ dramatically in their local geometry: one may be sharp (a narrow valley) while the other is flat (a wide valley).

One formal measure for comparing these local geometries is “sharpness”. At any given point w in the loss landscape with loss function L(w), the sharpness S(w) is defined as:

$$S(w) = \max_{\|\epsilon\|_2 \le \rho} \big[ L(w + \epsilon) - L(w) \big]$$

Let me unpack the definition. Imagine you are at some point w in the loss landscape and you perturb the parameters such that the new parameters always lie inside a ball of radius ρ centered at w. Sharpness is then defined as the maximal change in the loss function within this family of perturbations. In the literature, it is also known as the worst-direction sharpness, since it is set by the perturbation direction along which the loss increases the most.

One can readily see that for a sharp minimum (a steep, narrow valley) the value of the loss function will change dramatically with small perturbations in certain directions, resulting in a high value for sharpness. For a flat minimum (a wide valley), on the other hand, the value of the loss function will change relatively slowly with small perturbations, resulting in a lower value for sharpness. Therefore, sharpness gives a measure of flatness for a given minimum in the loss landscape.
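To make the comparison concrete, here is a minimal sketch in plain Python that estimates the worst-direction sharpness by brute force for two one-dimensional quadratic losses. The two loss functions, the grid search, and the value ρ = 0.1 are illustrative assumptions and are not taken from any experiment in this article.

```python
def sharpness(loss, w, rho, n_grid=2001):
    """Estimate max over |eps| <= rho of loss(w + eps) - loss(w) on a uniform grid."""
    base = loss(w)
    worst = 0.0  # eps = 0 is always admissible, so sharpness is never negative
    for i in range(n_grid):
        eps = -rho + 2.0 * rho * i / (n_grid - 1)
        worst = max(worst, loss(w + eps) - base)
    return worst


def sharp_loss(w):
    # Hypothetical narrow valley: large curvature around the minimum at w = 0
    return 50.0 * w ** 2


def flat_loss(w):
    # Hypothetical wide valley: small curvature around the same minimum
    return 0.5 * w ** 2


rho = 0.1
s_sharp = sharpness(sharp_loss, 0.0, rho)
s_flat = sharpness(flat_loss, 0.0, rho)
print(f"sharp minimum: S = {s_sharp:.4f}")
print(f"flat minimum:  S = {s_flat:.4f}")
```

Both minima sit at the same loss value (zero), yet the narrow valley has a sharpness about a hundred times larger. A grid search like this is only feasible in one or two dimensions, which is why SAM relies on a gradient-based approximation instead.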

There exists a deep connection between the local geometry of a minimum, especially its sharpness, and the generalization properties of the resulting model. Over the last decade, a significant amount of theoretical and empirical research has gone into clarifying this connection. For instance, as the paper by Keskar et al. (2016) points out, global minima with similar values of the loss function can have significantly different generalization properties depending on their sharpness.

The main lesson that seems to emerge from these studies is that flatter (less sharp) minima are positively correlated with better generalization. In particular, the model should avoid getting stuck in a sharp minimum during training if it is to generalize well. Therefore, to train a model with good generalization, one must ensure that the optimization procedure not only minimizes the loss function but also seeks to maximize the flatness (or, equivalently, minimize the sharpness) of the minimum.

This is precisely the problem that the SAM optimizer is designed to solve, and it is what we turn to in the next section.

A quick aside: note that the above picture gives a conceptual explanation of why an overparameterized model can potentially avoid the problem of overfitting. This is because a large model has a rich loss landscape, which provides a multiplicity of flat global minima with excellent generalization properties.

The Sharpness-Aware Minimization (SAM) Algorithm

Let us recall the standard optimization of a model. It involves finding model parameters that minimize a given loss function L computed over a mini-batch B. At every time-step, one computes the gradient of the loss with respect to the parameters and updates the parameters according to the rule:

$$w_{t+1} = w_t - \eta \, \nabla_w L(w) \big|_{w = w_t}$$

where η is the learning rate.

Unlike SGD or Adam, SAM does not minimize L directly. Instead, at a given point in the loss landscape, it first scans a neighborhood of a given size ρ and finds the perturbation that maximizes the loss function. In the second step, it minimizes this maximum loss. This allows the optimizer to find parameters that lie in neighborhoods with uniformly low loss values, which leads to smaller sharpness values and flatter minima.

Let’s discuss the procedure in a little more detail. The loss function for the SAM optimizer is:

$$L_{\mathrm{SAM}}(w) = \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon)$$

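where ρ denotes the upper bound on the size of the perturbations. The perturbation that maximizes the function L (often called the adversarial perturbation, since it maximizes the standard loss) can be found by noting that:

$$\epsilon^*(w) = \arg\max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon) \approx \arg\max_{\|\epsilon\|_2 \le \rho} \big[ L(w) + \epsilon^{\top} \nabla_w L(w) \big] = \arg\max_{\|\epsilon\|_2 \le \rho} \epsilon^{\top} \nabla_w L(w)$$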

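where the second equality is an approximation obtained by Taylor-expanding the perturbed function in the first step to first order in ϵ, and the last equality follows from the ϵ-independence of the first term in square brackets in the previous step. This last expression is maximized when ϵ points along the gradient and has the largest allowed norm, so it can be solved for the adversarial perturbation as follows:

$$\epsilon^*(w) = \rho \, \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2}$$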

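Plugging this back into the equation for the SAM loss, one can compute the gradient of the SAM loss to leading order, that is, neglecting terms that involve derivatives of ϵ with respect to w:

$$\nabla_w L_{\mathrm{SAM}}(w) \approx \nabla_w L(w) \big|_{w + \epsilon^*(w)}$$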

This is the most crucial equation for the optimization procedure. To leading order, neglecting derivatives of ϵ, the gradient of the SAM loss function can be approximated by the gradient of the standard loss function evaluated at the adversarially perturbed point. Using the above formula for the gradient, one can now execute the standard optimizer step, with η the learning rate of the base optimizer:

$$w_{t+1} = w_t - \eta \, \nabla_w L(w) \big|_{w = w_t + \epsilon^*(w_t)}$$

This completes one full SAM iteration. Next, let us translate the algorithm from English to PyTorch.
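Before moving to PyTorch, it can help to check a single SAM iteration by hand. The sketch below uses plain Python on a toy quadratic loss L(w) = ½ Σ c_i w_i², whose gradient is simply c_i w_i. The loss, the starting point, the curvatures, and the small constant added to the gradient norm are all illustrative assumptions.

```python
import math


def grad(w, curvatures):
    """Gradient of the toy loss L(w) = 0.5 * sum(c_i * w_i**2)."""
    return [c * wi for c, wi in zip(curvatures, w)]


def adversarial_perturbation(g, rho):
    """eps = rho * g / ||g||_2; the tiny constant guards against a zero gradient."""
    norm = math.sqrt(sum(gi * gi for gi in g))
    return [rho * gi / (norm + 1e-12) for gi in g]


def sam_step(w, curvatures, lr, rho):
    g = grad(w, curvatures)                       # gradient at w_t
    eps = adversarial_perturbation(g, rho)        # worst-case direction
    w_adv = [wi + ei for wi, ei in zip(w, eps)]   # perturbed point w_t + eps
    g_adv = grad(w_adv, curvatures)               # gradient at the perturbed point
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]  # update applied to w_t


w_t = [1.0, 0.0]
curvatures = [1.0, 10.0]
w_sam = sam_step(w_t, curvatures, lr=0.1, rho=0.05)
w_sgd = [wi - 0.1 * gi for wi, gi in zip(w_t, grad(w_t, curvatures))]
print("SAM step:", w_sam)
print("SGD step:", w_sgd)
```

Here the gradient at w_t is (1, 0), so the perturbation is approximately (0.05, 0) and the gradient at the perturbed point is (1.05, 0). The SAM update therefore moves the first coordinate to about 0.895, compared with 0.9 for a plain gradient step with the same learning rate.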

PyTorch Implementation in a Training Loop

An illustrative example of a training loop with a SAM optimizer is given in the code block sam_training_loop.py. For concreteness, we have chosen a generic image classification problem, but the same structure broadly holds for a wide variety of Computer Vision and NLP tasks. The SAM optimizer class is shown in the code block sam_optimizer_class.py.

Note that defining a SAM optimizer requires specifying two pieces of information:

  1. A base optimizer (like SGD or Adam), since SAM involves a standard optimizer step at the end.
  2. A hyperparameter ρ, which puts an upper bound on the size of the admissible perturbations.

A single iteration of the optimizer involves two forward passes and two backward passes. Let’s trace out the key steps of the code in sam_training_loop.py:

  1. Line 5 computes the loss function L(w, B) for the current mini-batch B: the first forward pass.
  2. Line 6 computes the gradients of the loss function L(w, B): the first backward pass.
  3. Line 7 calls the function sam_optimizer.first_step from the SAM optimizer class (see below), which computes the adversarial perturbation using the formula discussed above and perturbs the weights of the model accordingly.
  4. Line 10 computes the loss function for the perturbed model: the second forward pass.
  5. Line 11 computes the gradients of the loss function for the perturbed model: the second backward pass.
  6. Line 12 calls the function sam_optimizer.second_step from the optimizer class (see below), which restores the weights to w_t and then uses the base optimizer to update the weights w_t using the gradients computed at the perturbed point.
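The six steps above can be sketched as follows. This is a minimal illustration written for this walkthrough, not the contents of sam_optimizer_class.py or sam_training_loop.py, so the line numbers in the list do not apply to it. The class layout, the helper sam_train_step, and the toy model are assumptions; only the method names first_step and second_step follow the text.

```python
import torch
from torch import nn


class SAM:
    """Illustrative SAM wrapper around an already-constructed base optimizer."""

    def __init__(self, params, base_optimizer, rho=0.05):
        self.params = list(params)
        self.base_optimizer = base_optimizer
        self.rho = rho
        self._backup = {}  # parameter -> copy of its unperturbed value w_t

    @torch.no_grad()
    def first_step(self):
        """Move the weights from w_t to w_t + eps, with eps = rho * g / ||g||_2."""
        grads = [p.grad for p in self.params if p.grad is not None]
        if not grads:
            return
        grad_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in grads])
        )
        scale = self.rho / (grad_norm + 1e-12)  # small constant avoids division by zero
        for p in self.params:
            if p.grad is None:
                continue
            self._backup[p] = p.detach().clone()
            p.add_(p.grad * scale)
        self.base_optimizer.zero_grad()  # second backward pass starts from clean gradients

    @torch.no_grad()
    def second_step(self):
        """Restore w_t, then let the base optimizer step with the perturbed-point gradients."""
        for p, old in self._backup.items():
            p.copy_(old)
        self._backup.clear()
        self.base_optimizer.step()
        self.base_optimizer.zero_grad()


def sam_train_step(model, sam, loss_fn, inputs, targets):
    loss = loss_fn(model(inputs), targets)      # first forward pass at w_t
    loss.backward()                             # first backward pass
    sam.first_step()                            # perturb: w_t -> w_t + eps
    loss_fn(model(inputs), targets).backward()  # second forward and backward pass
    sam.second_step()                           # restore w_t and update
    return loss.item()


# Toy check: a single weight w = 2 with loss 0.5 * w**2, so the gradient equals w.
model = nn.Linear(1, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(2.0)
base = torch.optim.SGD(model.parameters(), lr=0.1)
sam = SAM(model.parameters(), base, rho=0.05)
x, y = torch.ones(1, 1), torch.zeros(1, 1)
half_squared_error = lambda out, target: 0.5 * ((out - target) ** 2).sum()
first_loss = sam_train_step(model, sam, half_squared_error, x, y)
new_weight = model.weight.item()
print(first_loss, new_weight)
```

In the toy check the perturbed weight is 2.05, so the update is 2 − 0.1 × 2.05 and the weight lands near 1.795 rather than the 1.8 a plain SGD step would give. Wrapping an existing optimizer instance, rather than subclassing torch.optim.Optimizer, is a simplification chosen here to keep the sketch short.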

A Caveat: SAM with BatchNorm

There is an important point to keep in mind when deploying SAM in a training loop if the model has any module that contains batch-normalization layers. During training, BatchNorm normalizes using the current batch statistics and updates the running statistics at every forward pass. For evaluation, it uses the running statistics.

Now, as we saw above, SAM involves two forward passes per iteration. For the first pass, BatchNorm works in the standard fashion. During the second pass, however, we are using perturbed weights to compute the loss, and the naive training function in the code block sam_training_loop.py will allow the BatchNorm layers to update the running statistics during the second pass as well. This is undesirable because the running statistics should only reflect the behavior of the actual model, not the perturbed model, which is simply an intermediate step for computing gradients. Therefore, one has to explicitly disable the running statistics update during the second pass and enable it again before the next iteration.

For this purpose, we’ll use two explicit functions, disable_bn_stats and enable_bn_stats, in the training loop. Simple examples of such functions are shown in the code block running_stat.py; they toggle the track_running_stats attribute (line 4 and line 9) of the BatchNorm modules in PyTorch. The modified training loop is given in the code block mod_train.py.
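A minimal sketch of what such helpers could look like is given below. It is not the running_stat.py listing; the function bodies, the internal helper _bn_layers, and the small demonstration are assumptions written to match the description above. In training mode, a PyTorch BatchNorm module with track_running_stats set to False normalizes with batch statistics and leaves its running buffers untouched.

```python
import torch
from torch import nn


def _bn_layers(model):
    """Collect every BatchNorm-type module (1d, 2d, 3d, Sync) inside the model."""
    return [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]


def disable_bn_stats(model):
    """Stop BatchNorm layers from updating running_mean and running_var."""
    for m in _bn_layers(model):
        m.track_running_stats = False


def enable_bn_stats(model):
    """Resume the running-statistics updates.

    Assumption: every BatchNorm layer in the model tracked statistics to begin with.
    """
    for m in _bn_layers(model):
        m.track_running_stats = True


# Demonstration on a single layer in training mode.
torch.manual_seed(0)
bn = nn.BatchNorm1d(3)
bn.train()
batch = torch.randn(8, 3) + 5.0

bn(batch)                                  # plays the role of the first forward pass
after_first_pass = bn.running_mean.clone()

disable_bn_stats(bn)
bn(batch + 10.0)                           # plays the role of the perturbed second pass
after_second_pass = bn.running_mean.clone()

enable_bn_stats(bn)
bn(batch + 10.0)                           # the next iteration updates the buffers again
after_reenabling = bn.running_mean.clone()

print(torch.equal(after_first_pass, after_second_pass))
print(torch.equal(after_second_pass, after_reenabling))
```

In a SAM training step, disable_bn_stats would be called right after the weights are perturbed and enable_bn_stats right after the base optimizer update, so that only the first forward pass of each iteration contributes to the running statistics.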

Demo: Image classification with ResNet-18

Finally, let’s demonstrate how SAM optimization improves the generalization of a model in a concrete example. We’ll consider an image classification problem using the Fashion-MNIST dataset (MIT License): it consists of 60,000 training images and 10,000 test images across 10 distinct, mutually exclusive classes, where each image is grayscale with 28×28 pixels.

As the classifier model, we’ll choose a PreAct ResNet-18 without any pre-training. While a discussion of the precise ResNet-18 architecture is not very relevant for our purpose, let us recall that the model consists of a sequence of building blocks, each of which is made up of convolutional layers, BatchNorm layers, and ReLU activations, with skip connections. The PreAct (pre-activation) label indicates that the activation function (ReLU) comes before the convolutional layer in each block. For a standard ResNet-18, it is the other way around. I would refer the reader to the paper by He et al. (2015) for more details on the architecture.

What is important to note, however, is that this model has about 11.2 million parameters, and therefore, from the perspective of classical Machine Learning, it is an overparameterized model with a parameter-to-sample ratio of about 186:1. Also, since the model includes BatchNorm layers, we have to be careful to disable the running statistics update for the second pass when using SAM.

We are now ready to perform the following experiment. We train the model on the Fashion-MNIST dataset first with the standard SGD optimizer and then with the SAM optimizer, using the same SGD as the base optimizer. We’ll consider a simple setup with a fixed learning rate lr=0.05 and with the momentum and the weight decay both set to zero. The hyperparameter ρ in SAM is set to 0.05. All runs are performed on a single A100 GPU.

Since each SAM weight update requires two backpropagation steps (one to compute the perturbation and another to compute the final gradients), for a fair comparison each non-SAM training run must execute twice as many epochs as each SAM training run. We’ll therefore compare a metric from one epoch of a SAM training run with a metric from two epochs of a non-SAM training run. We’ll call this a “standardized epoch”, and a metric recorded at standardized epochs will be labelled metric_st. We’ll restrict the experiment to 150 standardized epochs, which means the SAM training runs for 150 epochs and the non-SAM training runs for 300 epochs. We’ll train the SAM-optimized model for an additional 50 epochs to get an idea of how the model behaves under longer training.
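The bookkeeping behind the standardized epoch can be spelled out with a few lines of arithmetic. In the sketch below, the batch size of 128 is a hypothetical value (the setup above does not state one), and the function names are introduced only for this illustration.

```python
import math

TRAIN_SAMPLES = 60_000  # Fashion-MNIST training set size, as quoted in the text
BATCH_SIZE = 128        # hypothetical: the batch size is not stated in the article


def raw_epochs(standardized_epochs, uses_sam):
    """One standardized epoch is one SAM epoch or two non-SAM epochs."""
    return standardized_epochs if uses_sam else 2 * standardized_epochs


def backward_passes(standardized_epochs, uses_sam):
    """Total number of backward passes spent after the given standardized epochs."""
    batches_per_epoch = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)
    passes_per_batch = 2 if uses_sam else 1  # SAM backpropagates twice per batch
    return raw_epochs(standardized_epochs, uses_sam) * batches_per_epoch * passes_per_batch


sam_budget = backward_passes(150, uses_sam=True)
sgd_budget = backward_passes(150, uses_sam=False)
print(raw_epochs(150, uses_sam=True), raw_epochs(150, uses_sam=False))
print(sam_budget, sgd_budget)
```

Whatever batch size is used, the two totals agree, which is the sense in which 150 SAM epochs and 300 SGD epochs represent the same backpropagation budget.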

To establish which optimizer gives better generalization, we’ll compare the following two metrics after each standardized epoch of training:

  1. Test accuracy: Performance of the model on the test dataset.
  2. Generalizability gap: Difference between the training accuracy and test accuracy.

The test accuracy is an absolute measure of how well the model generalizes after a certain number of training epochs. The generalizability gap, on the other hand, is a diagnostic that tells you how much a model is overfitting at a given stage of training.

Let us begin by comparing the training_loss_st and training_accuracy_st graphs, as shown in Figure 3. The model trained with SGD reaches near-zero loss and close to 99% training accuracy within 150 epochs, as expected of an overparameterized model. SAM clearly trains more slowly than SGD and takes more standardized epochs to reach near-perfect training accuracy. This is evident from the fact that both the training loss and the training accuracy continue to improve as one trains the SAM-optimized model beyond the stipulated 150 epochs.

Figure 3. Comparison of the standardized training losses and training accuracies.

Test accuracy. The graphs in Figure 4 compare the test accuracies for the two cases after each standardized epoch.

Figure 4. Comparison of the standardized test accuracies.

The SGD-optimized model reaches 92% test accuracy around epoch 50 and plateaus around that value for the next 100 epochs. The SAM-optimized model generalizes less well in the initial phase of training, until around epoch 80, as is evident from its lower test accuracies in this phase compared with the SGD curve. However, around epoch 80 it catches up with the SGD curve and eventually surpasses it by a thin margin.

For this specific run, at the end of 150 epochs, the test accuracy for SAM stands at test_SAM = 92.5%, while that for SGD is test_SGD = 92.0%. Note that this is despite the fact that the SAM-trained model has a much lower training accuracy and a higher training loss at this stage. If one trains the SAM model for another 50 epochs, the test accuracy improves slightly to 92.7%.

Generalization Gap. The evolution of the generalization gap after each standardized epoch over the course of training is shown in Figure 5.

Figure 5: Comparison of the generalization gap.

The gap for the SGD model grows steadily with training and after 150 epochs reaches gap_SGD = 6.8%, while for SAM it grows much more slowly and reaches gap_SAM = 2.3%. On further training for another 50 epochs, the gap for SAM climbs to around 3%, but it is still much lower than the SGD value.

While the difference in test accuracies between the two optimizers is small for the Fashion-MNIST dataset, there is a non-trivial difference in the generalization gaps, which demonstrates that optimizing with SAM leads to better generalization.
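Since the gap is defined as training accuracy minus test accuracy, the figures quoted above can be cross-checked against each other with a line of arithmetic. The sketch below uses only the numbers reported at 150 standardized epochs; the dictionary layout and the function name are illustrative.

```python
# Values reported in the text at 150 standardized epochs, in percent
reported = {
    "SGD": {"test_acc": 92.0, "gap": 6.8},
    "SAM": {"test_acc": 92.5, "gap": 2.3},
}


def implied_train_accuracy(test_acc, gap):
    """Invert gap = train_acc - test_acc to recover the training accuracy."""
    return test_acc + gap


implied = {
    name: round(implied_train_accuracy(**values), 1)
    for name, values in reported.items()
}
for name, train_acc in implied.items():
    print(f"{name}: implied training accuracy = {train_acc}%")
```

The implied training accuracy for SGD comes out just under 99%, in line with the training curves discussed earlier, while the SAM model sits several points lower on the training set and still achieves the higher test accuracy.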

Concluding Remarks

In this article, I presented a pedagogical review of SAM as an optimizer that significantly improves the generalization of overparameterized deep learning models. We discussed the motivation and intuition behind SAM, walked through a step-by-step breakdown of the algorithm, and studied a simple example demonstrating its effectiveness compared with a standard SGD optimizer.

There are several interesting aspects of SAM that I did not have a chance to cover here. Let me briefly mention two of them. First, as a practical tool, SAM is especially useful for fine-tuning pre-trained models on small datasets, something explored in detail by Foret et al. (2020) for CNN-type architectures and in many subsequent works for more general architectures. Second, since we opened our discussion with the connection between flat minima in the loss landscape and generalization, it is natural to ask whether a SAM-trained model, which demonstrably improves generalizability, does indeed converge to a flatter minimum. This is a non-trivial question, requiring a careful analysis of the Hessian spectrum of the trained model and a comparison with its SGD-trained counterpart. But that’s a story for another day!


Thanks for reading! If you have enjoyed the article and would be interested in reading more pedagogical articles on deep learning, do follow me on Medium and LinkedIn. Unless otherwise stated, all images and graphs used in this article were generated by the author.
