How integrating BatchNorm into a standard Vision Transformer architecture leads to faster convergence and a more stable network
Introduction
The Vision Transformer (ViT) is the first purely self-attention-based architecture for image classification tasks. While ViTs do perform better than CNN-based architectures, they require pre-training on very large datasets. In an attempt to look for modifications of the ViT that can lead to faster training and inference, especially in the context of medium-to-small input data sizes, I started exploring in a previous article ViT-type models that integrate Batch Normalization (BatchNorm) in their architecture. BatchNorm is known to make a deep neural network converge faster: a network with BatchNorm achieves higher accuracy compared to the baseline model when trained over the same number of epochs, which in turn speeds up training. BatchNorm also acts as an efficient regularizer for the network and allows a model to be trained with a higher learning rate. The main goal of this article is to investigate whether introducing BatchNorm can lead to similar effects in a Vision Transformer.
For the sake of concreteness, I’ll focus on a model where a BatchNorm layer is introduced in the Feedforward Network (FFN) inside the transformer encoder of the ViT, and the LayerNorm preceding the FFN is omitted. Everywhere else in the transformer, including the self-attention module, one continues to use LayerNorm. I’ll refer to this version of ViT as ViTBNFFN (Vision Transformer with BatchNorm in the Feedforward Network). I’ll train and test this model on the MNIST dataset with image augmentations and compare the Top-1 accuracy of the model with that of the standard ViT over a number of epochs. I’ll choose identical architectural configurations for the two models (i.e. identical width, depth, patch size and so on) so that one can effectively isolate the effect of the BatchNorm layer.
Here’s a quick summary of the main findings:
- For a reasonable choice of hyperparameters (learning rate and batch size), ViTBNFFN converges faster than ViT, provided the transformer depth (i.e. the number of layers in the encoder) is sufficiently large.
- As one increases the learning rate, ViTBNFFN seems to be more stable than ViT, especially at larger depths.
I’ll open with a brief discussion of BatchNorm in a deep neural network, illustrating some of the properties mentioned above with a concrete example. I’ll then discuss in detail the architecture of the model ViTBNFFN. Finally, I’ll take a deep dive into the numerical experiments that study the effects of BatchNorm in the Vision Transformer.
The Dataset: MNIST with Image Augmentation
Let us begin by introducing the augmented MNIST dataset which I’ll use for all the numerical experiments described in this article. The training and test datasets are given by the function get_datasets_mnist() shown in Code Block 1.
The important lines of code are lines 5–10, which list the details of the image augmentations I’ll use. I have introduced three different transformations (a sketch of the full pipeline follows the list):
- RandomRotation(degrees=20) : A random rotation of the image with the range of rotation in degrees being (-20, 20).
- RandomAffine(degrees=0, translate=(0.2, 0.2)) : A random affine transformation, where the specification translate=(a, b) means that the horizontal and vertical shifts are sampled randomly in the intervals [-image_width × a, image_width × a] and [-image_height × b, image_height × b] respectively. Setting degrees=0 deactivates rotation, since we have already taken it into account via the random rotation above. One could also include a scale transformation here, but we implement it using the zoom-out operation instead.
- RandomZoomOut(0, (2.0, 3.0), p=0.2) : A random zoom-out transformation, which randomly samples a float r from the interval (2.0, 3.0) and outputs an image with output_width = input_width × r and output_height = input_height × r. The float p is the probability that the zoom operation is performed. This transformation is followed by a Resize transformation so that the final image is again 28 × 28.
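As a rough illustration of how these three transformations could be composed, here is a minimal sketch of my own (assuming torchvision’s transforms v2 API; the dataset root path is a placeholder and the actual Code Block 1 may differ):

```python
import torch
from torchvision import datasets
from torchvision.transforms import v2

def get_datasets_mnist():
    # Augmentations for the training set: rotation, translation, zoom out, resize back to 28 x 28
    train_transform = v2.Compose([
        v2.RandomRotation(degrees=20),
        v2.RandomAffine(degrees=0, translate=(0.2, 0.2)),
        v2.RandomZoomOut(0, side_range=(2.0, 3.0), p=0.2),
        v2.Resize((28, 28)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    # No augmentation for the test set
    test_transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
    ])
    train_data = datasets.MNIST(root="./data", train=True, download=True, transform=train_transform)
    test_data = datasets.MNIST(root="./data", train=False, download=True, transform=test_transform)
    return train_data, test_data
```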
Batch Normalization in a Deep Neural Network
Let us give a quick review of how BatchNorm improves the performance of a deep neural network. Suppose zᵃᵢ denotes the input for a given layer of a deep neural network, where a is the batch index running from a=1,…, Nₛ and i is the feature index running from i=1,…, C. The BatchNorm operation then involves the following steps:
1. For a given feature index i, one first computes the mean and the variance over the batch of size Nₛ, i.e.

μᵢ = (1/Nₛ) Σₐ zᵃᵢ ,   σᵢ² = (1/Nₛ) Σₐ (zᵃᵢ − μᵢ)²
2. One normalizes the input using the mean and variance computed above (with ϵ being a small positive number):

ẑᵃᵢ = (zᵃᵢ − μᵢ) / √(σᵢ² + ϵ)
3. Finally, one shifts and rescales the normalized input for every feature i:

yᵃᵢ = γᵢ ẑᵃᵢ + βᵢ
where there is no summation over the index i, and the parameters (γᵢ, βᵢ) are trainable.
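In code, these three steps amount to the following (a small illustration of my own in PyTorch; the function name batchnorm_manual is a placeholder):

```python
import torch

def batchnorm_manual(z, gamma, beta, eps=1e-5):
    """BatchNorm over a batch: z has shape (N_s, C), gamma and beta have shape (C,)."""
    mu = z.mean(dim=0)                          # step 1: per-feature mean over the batch
    var = z.var(dim=0, unbiased=False)          # step 1: per-feature (biased) variance over the batch
    z_hat = (z - mu) / torch.sqrt(var + eps)    # step 2: normalize
    return gamma * z_hat + beta                 # step 3: shift and rescale per feature

# Sanity check against PyTorch's built-in layer (in training mode)
z = torch.randn(100, 784)
bn = torch.nn.BatchNorm1d(784)
y_manual = batchnorm_manual(z, bn.weight, bn.bias)
print(torch.allclose(y_manual, bn(z), atol=1e-5))  # True
```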
Consider a deep neural network for classifying the MNIST dataset. I’ll choose a network consisting of 3 fully-connected hidden layers, with 100 activations each, where each hidden layer is endowed with a sigmoid activation function. The last hidden layer feeds into a classification layer with 10 activations, corresponding to the 10 classes of the MNIST dataset. The input to this neural network is a 2d tensor of shape b × 28², where b is the batch size and each 28 × 28 MNIST image is reshaped into a 28²-dimensional vector. In this case, the feature index runs from i=1, …, 28².
This model is similar to the one discussed in the original BatchNorm paper. I’ll refer to this model as DNN_d3. One may then consider a version of this model where one adds a BatchNorm layer before the sigmoid activation function in every hidden layer. Let us call the resultant model DNNBN_d3. The idea is to understand how the introduction of the BatchNorm layer affects the performance of the network.
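A minimal sketch of the two models, reconstructed from the description above (not the author’s exact code; the builder make_mnist_dnn is my own naming):

```python
import torch.nn as nn

def make_mnist_dnn(use_batchnorm: bool) -> nn.Sequential:
    """DNN_d3 / DNNBN_d3: three fully-connected hidden layers of width 100 with sigmoid activations."""
    layers, in_dim = [], 28 * 28
    for _ in range(3):
        layers.append(nn.Linear(in_dim, 100))
        if use_batchnorm:
            layers.append(nn.BatchNorm1d(100))  # BatchNorm inserted before the sigmoid activation
        layers.append(nn.Sigmoid())
        in_dim = 100
    layers.append(nn.Linear(100, 10))           # classification layer for the 10 MNIST classes
    return nn.Sequential(*layers)

dnn_d3 = make_mnist_dnn(use_batchnorm=False)
dnnbn_d3 = make_mnist_dnn(use_batchnorm=True)
```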
To do this, let us now train and test the two models on the MNIST dataset described above, with CrossEntropyLoss() as the loss function and the Adam optimizer, for 15 epochs. For a learning rate lr=0.01 and a training batch size of 100 (we choose a test batch size of 5000), the test accuracy and the training loss for the two models are given in Figure 1.
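A bare-bones training and evaluation loop along these lines might look as follows (again a sketch of my own; train_loader and test_loader are assumed to be DataLoaders built from the augmented MNIST datasets above):

```python
import torch
import torch.nn as nn

def train_and_evaluate(model, train_loader, test_loader, lr=0.01, epochs=15, device="cuda"):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images = images.view(images.size(0), -1).to(device)   # flatten each image to a 28^2 vector
            labels = labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
        # Top-1 accuracy on the test set
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.view(images.size(0), -1).to(device)
                preds = model(images).argmax(dim=1).cpu()
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        print(f"epoch {epoch + 1}: test accuracy = {correct / total:.4f}")
```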
Evidently, the introduction of BatchNorm makes the network converge faster: DNNBN achieves a higher test accuracy and a lower training loss. BatchNorm can therefore speed up training.
What happens if one increases the learning rate? Generally speaking, a high learning rate might lead to gradients blowing up or vanishing, which would render the training unstable. In particular, larger learning rates lead to larger layer parameters, which in turn give larger gradients during backpropagation. BatchNorm, however, ensures that backpropagation through a layer is not affected by a scaling transformation of the layer parameters (see Section 3.3 of this paper for more details). This makes the network significantly more resistant to instabilities arising from a high learning rate.
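One can see this scale invariance directly in a small experiment of my own: rescaling the parameters of a linear layer leaves the output of the following BatchNorm layer (and hence everything downstream of it) essentially unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(100, 28 * 28)       # a batch of flattened MNIST-sized inputs
linear = nn.Linear(28 * 28, 100)
bn = nn.BatchNorm1d(100)            # training mode by default, so batch statistics are used

with torch.no_grad():
    y1 = bn(linear(x))
    linear.weight.mul_(10.0)        # rescale the layer parameters by a factor of 10
    linear.bias.mul_(10.0)
    y2 = bn(linear(x))

print(torch.allclose(y1, y2, atol=1e-3))  # True: the BatchNorm output is (numerically) scale-invariant
```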
To demonstrate this explicitly for the models at hand, let us train them at a much higher learning rate lr=0.1. The test accuracy and the training losses for the models in this case are given in Figure 2.
The high learning rate manifestly renders the DNN unstable. The model with BatchNorm, however, is perfectly well-behaved! A more instructive way to visualize this behavior is to plot the accuracy curves for the two learning rates in a single graph, as shown in Figure 3.
While the model DNN_d3 stops training at the high learning rate, the impact on the performance of DNNBN_d3 is significantly milder. BatchNorm therefore allows one to train a model at a higher learning rate, providing yet another way to speed up training.
The Model ViTBNFFN: BatchNorm in the FeedForward Network
Let us begin by briefly reviewing the architecture of the standard Vision Transformer for image classification tasks, as shown in the schematic diagram of Figure 4. For more details, I refer the reader to my previous article or one of the many excellent reviews of the topic in Towards Data Science.
Functionally, the architecture of the Vision Transformer may be divided into three main components:
1. Embedding layer: This layer maps an image to a “sentence”, i.e. a sequence of tokens, where each token is a vector of dimension dₑ (the embedding dimension). Given an image of size h × w with c color channels, one first splits it into patches of size p × p and flattens them. This gives (h × w)/p² flattened patches (or tokens) of dimension dₚ = p² × c, which are then mapped to vectors of dimension dₑ using a learnable linear transformation. To this sequence of tokens, one adds a learnable token, the CLS token, which is isolated at the end for the classification task. Schematically, the embedding layer therefore implements a map of the form: image → sequence of flattened patches → sequence of dₑ-dimensional tokens with the CLS token added.
Finally, to this sequence of tokens, one adds a learnable tensor of the same shape which encodes the positional embedding information. The resultant sequence of tokens is fed into the transformer encoder. The input to the encoder is therefore a 3d tensor of shape b × N × dₑ, where b is the batch size, N is the number of tokens including the CLS token, and dₑ is the embedding dimension.
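A self-contained sketch of such an embedding layer is given below (illustrative only; the class name PatchEmbedding and the patch-extraction details are my own, and the author’s implementation may differ):

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Map an image to a sequence of tokens: patches -> linear projection -> CLS token -> positional embedding."""
    def __init__(self, image_size=28, patch_size=7, channels=1, dim=64):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2           # (h x w) / p^2 patch tokens
        patch_dim = channels * patch_size ** 2                   # flattened patch dimension d_p
        self.patch_size = patch_size
        self.to_token = nn.Linear(patch_dim, dim)                # learnable linear map d_p -> d_e
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))    # learnable CLS token
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

    def forward(self, img):
        b, c, h, w = img.shape
        p = self.patch_size
        # split the image into p x p patches and flatten each patch
        patches = img.unfold(2, p, p).unfold(3, p, p)                     # (b, c, h/p, w/p, p, p)
        patches = patches.contiguous().view(b, c, -1, p * p)              # (b, c, N, p*p)
        patches = patches.permute(0, 2, 1, 3).reshape(b, -1, c * p * p)   # (b, N, d_p)
        tokens = self.to_token(patches)                                   # (b, N, d_e)
        cls = self.cls_token.expand(b, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)                          # add the CLS token -> (b, N+1, d_e)
        return tokens + self.pos_embedding                                # add positional embeddings
```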
2. Transformer encoder: The transformer encoder maps the sequence of tokens to another sequence of tokens with the same number and the same shape. In other words, it maps the input 3d tensor of shape b × N × dₑ to another 3d tensor of the same shape. The encoder can have L distinct layers (defined as the depth of the transformer), where each layer is made up of two sub-modules as shown in Figure 5: the multi-headed self-attention (MHSA) module and the FeedForward Network (FFN).
The MHSA module implements a non-linear map from the 3d tensor of shape b × N × dₑ to a 3d tensor of the same shape, which is then fed into the FFN as shown in Figure 5. This is where information from different tokens gets mixed via the self-attention map. The configuration of the MHSA module is fixed by the number of heads nₕ and the head dimension dₕ.
The FFN is a deep neural network with two linear layers and a GELU activation in between, as shown in Figure 6.
The input to this sub-module is a 3d tensor of shape b × N × dₑ. The linear layer on the left transforms it into a 3d tensor of shape b × N × d_mlp, where d_mlp is the hidden dimension of the network. Following the non-linear GELU activation, the tensor is mapped back to a tensor of the original shape by the second linear layer.
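For reference, a minimal version of this standard FFN in PyTorch (a sketch of my own, omitting the dropout layers that ViT implementations typically include):

```python
import torch.nn as nn

class FeedForward(nn.Module):
    """Standard ViT FFN: Linear -> GELU -> Linear, acting on the last dimension."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),   # (b, N, d_e) -> (b, N, d_mlp)
            nn.GELU(),
            nn.Linear(hidden_dim, dim),   # back to (b, N, d_e)
        )

    def forward(self, x):
        return self.net(x)
```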
3. MLP Head: The MLP Head is a fully-connected network that maps the output of the transformer encoder (a 3d tensor of shape b × N × dₑ) to a 2d tensor of shape b × d_num, where d_num is the number of classes in the given image classification task. This is done by first isolating the CLS token from the input tensor and then putting it through the fully-connected network.
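A sketch of a simple head that isolates the CLS token (assuming, as in the experiments below, that the head is just a single linear layer; the class name MLPHead is my own):

```python
import torch.nn as nn

class MLPHead(nn.Module):
    """Map the encoder output (b, N, d_e) to class scores (b, d_num) via the CLS token."""
    def __init__(self, dim: int, num_classes: int):
        super().__init__()
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        cls_token = x[:, 0]        # isolate the CLS token (prepended, as in the embedding sketch above): (b, d_e)
        return self.fc(cls_token)  # (b, d_num)
```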
The model ViTBNFFN has the same architecture as described above, with two differences. Firstly, one introduces a BatchNorm layer in the FFN of the encoder, between the first linear layer and the GELU activation, as shown in Figure 7. Secondly, one removes the LayerNorm preceding the FFN in the standard ViT encoder (see Figure 5 above).
Since the linear layers act on the last dimension of the input tensor of shape b × N × dₑ, we should identify this last dimension as the feature dimension of the BatchNorm. The PyTorch implementation of the new feedforward network is given in Code Block 2.
The built-in BatchNorm class in PyTorch always takes the first index of a tensor as the batch index and the second index as the feature index. Therefore, one needs to transform a 3d tensor of shape b × N × dₑ to a tensor of shape b × dₑ × N before applying BatchNorm, and transform it back to b × N × dₑ afterwards. In addition, I have used the 2d BatchNorm class (since it is slightly faster than the 1d BatchNorm). This requires promoting the 3d tensor to a 4d tensor of shape b × dₑ × N × 1 (line 16) and transforming it back (line 18) to a 3d tensor of shape b × N × dₑ. One can use the 1d BatchNorm class without changing any of the results presented in this section.
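The following is my own sketch of how such a feedforward network could be implemented (not necessarily identical to Code Block 2); the reshaping mirrors the description above, with the feature dimension of the BatchNorm being the last dimension of the tensor it receives, i.e. the hidden dimension after the first linear layer:

```python
import torch
import torch.nn as nn

class FeedForwardBN(nn.Module):
    """FFN with a BatchNorm layer between the first linear layer and the GELU activation."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.bn = nn.BatchNorm2d(hidden_dim)     # 2d BatchNorm; nn.BatchNorm1d(hidden_dim) works equally well
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        x = self.linear1(x)                      # (b, N, d_e) -> (b, N, d_mlp)
        x = x.permute(0, 2, 1).unsqueeze(-1)     # move the feature index to position 1: (b, d_mlp, N, 1)
        x = self.bn(x)                           # BatchNorm over the feature dimension
        x = x.squeeze(-1).permute(0, 2, 1)       # transform back to (b, N, d_mlp)
        x = self.gelu(x)
        return self.linear2(x)                   # (b, N, d_mlp) -> (b, N, d_e)
```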
The Experiment
With a fixed learning rate and batch size, I’ll train and test the two models, ViT and ViTBNFFN, on the augmented MNIST dataset for 10 epochs and compare the Top-1 accuracies on the validation dataset. Since we are interested in understanding the effects of BatchNorm, we will have to compare the two models with identical configurations. The experiment will be repeated at different depths of the transformer encoder, keeping the rest of the model configuration unchanged. The precise configuration used for the two models in this experiment is listed below, followed by a sketch of how it could be instantiated:
- Embedding layer: An MNIST image is a grey-scale image of size 28 × 28. The patch size is p=7, which means that the number of tokens is 16 + 1 = 17, including the CLS token. The embedding dimension is dₑ = 64.
- Transformer encoder: The MHSA submodule has nₕ = 8 heads with head dimension dₕ = 64. The hidden dimension of the FFN is d_mlp = 128. The depth of the encoder will be the only variable parameter in this architecture.
- MLP head: The MLP head will simply consist of a linear layer.
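For concreteness, the baseline ViT with this configuration could be instantiated for instance with the vit-pytorch package (a hypothetical illustration of the configuration; the author’s actual implementation may differ):

```python
from vit_pytorch import ViT

# Baseline ViT with the configuration above; depth is the only parameter varied in the experiment
vit_d6 = ViT(
    image_size=28,
    patch_size=7,     # 16 patch tokens + 1 CLS token
    num_classes=10,
    dim=64,           # embedding dimension d_e
    depth=6,          # encoder depth d
    heads=8,          # n_h
    dim_head=64,      # d_h
    mlp_dim=128,      # hidden dimension of the FFN, d_mlp
    channels=1,       # grey-scale MNIST
    dropout=0.0,
    emb_dropout=0.0,
)
```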
The training and testing batch sizes will be fixed at 100 and 5000 respectively for all the epochs, with CrossEntropyLoss() as the loss function and the Adam optimizer. The dropout parameters are set to zero in both the embedding layer and the encoder. I have used the NVIDIA L4 Tensor Core GPU available at Google Colab for all the runs, which have been recorded using the tracking feature of MLFlow.
Let us start by training and testing the models at the learning rate lr=0.003. Figure 8 below summarizes the four graphs plotting the accuracy curves of the two models at depths d=4, 5, 6 and 7 respectively. In these graphs, the notation ViT_dn (ViTBNFFN_dn) denotes ViT (ViTBNFFN) with an encoder of depth d=n, the rest of the model configuration being the same as specified above.
For d=4 and d=5 (the top row of graphs), the accuracies of the two models are comparable: for d=4 (top left) ViT does somewhat better, while for d=5 (top right) ViTBNFFN surpasses ViT marginally. For d < 4, the accuracies remain comparable. However, for d=6 and d=7 (the bottom row of graphs), ViTBNFFN does significantly better than ViT. One can check that this qualitative feature remains the same for any depth d ≥ 6.
Let us repeat the experiment at a slightly higher learning rate lr=0.005. The accuracy curves of the two models at depths d=1, 2, 3 and 4 respectively are summarized in Figure 9.
For d=1 and d=2 (the top row of graphs), the accuracies of the two models are comparable: for d=1 ViT does somewhat better, while for d=2 they are almost indistinguishable. For d=3 (bottom left), ViTBNFFN achieves a slightly higher accuracy than ViT. For d=4 (bottom right), however, ViTBNFFN does significantly better than ViT, and this qualitative feature remains the same for any depth d ≥ 4.
Therefore, for a reasonable choice of learning rate and batch size, ViTBNFFN converges significantly faster than ViT beyond a critical depth of the transformer encoder. For the range of hyperparameters I consider, it seems that this critical depth gets smaller with increasing learning rate at a fixed batch size.
For the deep neural network example, we saw that the impact of a high learning rate is significantly milder on the network with BatchNorm. Is there something analogous that happens for the Vision Transformer? This is addressed in Figure 10. Here each graph plots the accuracy curves of a given model at a given depth for the two learning rates lr=0.003 and lr=0.005. The first column of graphs corresponds to ViT for d=2, 3 and 4 (top to bottom), while the second column corresponds to ViTBNFFN for the same depths.
Consider d=2, given by the top row of graphs: ViT and ViTBNFFN are comparably impacted as one increases the learning rate. For d=3, given by the second row of graphs, the difference is significant. ViT achieves a much lower accuracy at the higher learning rate: its accuracy drops from about 91% to around 78% at the end of epoch 10. On the other hand, for ViTBNFFN, the accuracy at the end of epoch 10 drops from about 92% to about 90%. This qualitative feature remains the same at higher depths too; see the bottom row of graphs, which corresponds to d=4. Therefore, the impact of the higher learning rate on ViTBNFFN looks significantly milder for a sufficiently deep transformer encoder.
Conclusion
In this article, I have studied the effects of introducing a BatchNorm layer inside the FeedForward Network of the transformer encoder of a Vision Transformer. Comparing the models on an augmented MNIST dataset, there are two main lessons one may draw. Firstly, for a transformer of sufficient depth and a reasonable choice of hyperparameters, the model with BatchNorm achieves significantly higher accuracy compared to the standard ViT over the same number of epochs. This faster convergence can greatly speed up training. Secondly, similar to our intuition for deep neural networks, the Vision Transformer with BatchNorm is more resilient to a higher learning rate, provided the encoder is sufficiently deep.