Review — How Do Vision Transformers Work?
- Multi-head Self-Attentions (MSAs) improve not only accuracy but also generalization by flattening the loss landscape.
- ViTs suffer from non-convex losses. Large datasets and loss-landscape smoothing methods alleviate this problem.
- MSAs are low-pass filters, while Convs are high-pass filters; the two are complementary.
- MSAs at the end of a stage play a key role in prediction.
- Based on these insights, AlterNet is proposed, a model in which Conv blocks at the end of a stage are replaced with MSA blocks. AlterNet outperforms CNNs not only in large data regimes but also in small data regimes.
1. ViT Analysis
- (Please skip this section for a quick read.)
- According to the loss-landscape visualization, the flatter the loss landscape, the better the performance and generalization.
Left: As seen above, MSAs flatten the loss landscape.
- Fourier analysis of feature maps shows that MSAs reduce high-frequency signals, while Convs, conversely, amplify high-frequency components.
Left: In other words, MSAs are low-pass filters, while Convs are high-pass filters.
Right: Convs are vulnerable to high-frequency noise, but MSAs are not. Therefore, MSAs and Convs are complementary.
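The low-pass vs high-pass contrast can be illustrated with a toy 1-D NumPy sketch (my own illustration, not the paper's code): a uniform averaging kernel stands in for attention's tendency to spread weight over many tokens, and a difference kernel stands in for a Conv's edge-amplifying behavior.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=256)  # toy 1-D "feature map" (white noise)

# Uniform averaging mimics the low-pass behavior attributed to MSAs.
k_low = np.ones(5) / 5.0
# A difference (edge) kernel mimics the high-pass behavior of Convs.
k_high = np.array([-1.0, 2.0, -1.0])

low = np.convolve(x, k_low, mode="same")
high = np.convolve(x, k_high, mode="same")

def high_freq_amplitude(sig, cutoff=0.25):
    """Mean spectral amplitude above `cutoff` (as a fraction of the sample rate)."""
    spec = np.abs(np.fft.rfft(sig))
    freqs = np.fft.rfftfreq(sig.size)
    return spec[freqs > cutoff].mean()

print(high_freq_amplitude(low) < high_freq_amplitude(x))   # low-pass attenuates high frequencies
print(high_freq_amplitude(high) > high_freq_amplitude(x))  # high-pass amplifies them
```

The paper measures the same effect on real feature maps; the toy kernels here only make the filter terminology concrete.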
- (a) Spatial Smoothing: Multi-stage NNs behave like a series connection of small individual models. Thus, applying spatial smoothing at the end of a stage improves accuracy by ensembling transformed feature map outputs from each stage.
- (b) Canonical Transformer: The standard Transformer module has one MSA block per MLP block.
- (c) Proposed Alternating Convs & MSAs: An alternating pattern of Convs and MSAs is proposed. NN stages using this design pattern consist of a number of Conv blocks and one (or a few) MSA blocks at the end of a stage.
1.2. ViT Analysis
(a): The error on the test dataset and the cross-entropy, or negative log-likelihood, on the training dataset (NLLtrain, the lower the better) are shown above. The stronger the inductive bias, the lower both the test error and the training NLL. This indicates that ViTs do not overfit training datasets.
Comparing Fig. 1(b) and Fig. 4, large datasets suppress negative Hessian eigenvalues of ViT in the early phase of training.
The GAP classifier suppresses negative Hessian max eigenvalues, suggesting that GAP convexifies the loss.
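One intuition for why GAP behaves differently from a CLS-token head can be shown with a toy NumPy sketch (my own illustration, not the paper's Hessian analysis): GAP averages over all patch tokens, so its input has much lower variance than a single token.

```python
import numpy as np

# Toy ViT output: batch of 4 images, 1 + 49 tokens (CLS + 7x7 patches), 8 dims.
tokens = np.random.default_rng(1).normal(size=(4, 50, 8))

cls_feature = tokens[:, 0]                # CLS-token head reads a single token
gap_feature = tokens[:, 1:].mean(axis=1)  # GAP head averages all patch tokens

# Averaging over 49 tokens shrinks variance by roughly 49x, one intuition
# for why GAP smooths the loss relative to the CLS token.
print(gap_feature.var() < cls_feature.var())
```

This is only a variance argument; the paper's evidence is the suppressed negative Hessian eigenvalues.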
- MSAs with 3×3, 5×5, and 8×8 (global) kernels are tested on CIFAR-100.
(a): The NLLtrain of the 3×3 kernel is worse than that of the 5×5 kernel, but better than that of the global MSA. Although the test errors of the 3×3 and 5×5 kernels are comparable, the robustness of the 5×5 kernel is significantly better than that of the 3×3 kernel on CIFAR-100-C.
(b): The 5×5 kernel has fewer negative eigenvalues than the global MSA because it restricts unnecessary degrees of freedom.
1.3. MSAs vs Convs
- MSAs almost always decrease the high-frequency amplitude, while MLPs, which correspond to Convs, increase it.
The only exception is the early stages of the model, where MSAs behave like Convs, i.e., they increase the amplitude. This could serve as evidence for a hybrid model that uses Convs in early stages and MSAs in late stages.
- MSAs in ViT tend to reduce the feature-map variance; conversely, Convs in ResNet and MLPs in ViT increase it. In conclusion, MSAs ensemble feature-map predictions, but Convs do not.
Reducing feature-map uncertainty helps optimization by ensembling and stabilizing the transformed feature maps. Since ResNet ends with high variance, performance can also be improved by using MSAs with a large number of heads in late stages.
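The ensembling argument is just the classic variance-reduction effect of averaging, which a small NumPy sketch can demonstrate (the "heads" here are a toy stand-in, not actual attention heads):

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy stand-in: 12 independent noisy "head" predictions per sample,
# analogous to an MSA averaging transformed feature maps across heads.
heads = rng.normal(loc=1.0, scale=1.0, size=(10_000, 12))

single = heads[:, 0]            # one prediction per sample
ensembled = heads.mean(axis=1)  # average of 12 predictions per sample

# Averaging k independent predictions divides the variance by about k.
print(single.var() / ensembled.var())  # close to 12
```

This is why more heads in late stages, where ResNet's variance is highest, is expected to help most.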
- (a) Mini-batch CKA (Nguyen et al., 2021) is used to measure feature-map similarities. The feature-map similarities of CNNs and multi-stage ViTs, such as PiT and Swin, have a block structure.
- (b) One NN unit is removed from an already-trained ResNet or Swin during the testing phase: In ResNet, removing an early-stage layer hurts accuracy more than removing a late-stage layer. In Swin, removing an MLP at the beginning of a stage hurts accuracy, while removing an MSA at the end of a stage seriously impairs it.
MSAs closer to the end of a stage significantly improve the performance.
2.1. Build-up Rule
- Considering all the insights, design rules are proposed:
- Alternately replace Conv blocks with MSA blocks from the end of a baseline CNN model.
- If the added MSA block does not improve predictive performance, replace a Conv block located at the end of an earlier stage with an MSA block.
- Use more heads and higher hidden dimensions for MSA blocks in late stages.
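The design rules above can be sketched in Python. This is a hypothetical illustration of the end state the rule converges to (one MSA block at the end of each stage); the block labels and the helper function are my own, not the authors' code.

```python
def apply_build_up_rule(stages, num_msa_blocks):
    """Replace the last Conv block of each stage with an MSA block,
    working backwards from the final stage of the model."""
    stages = [list(stage) for stage in stages]  # copy so the baseline is kept
    for stage in reversed(stages):
        if num_msa_blocks == 0:
            break
        for i in range(len(stage) - 1, -1, -1):
            if stage[i] == "conv":
                stage[i] = "msa"
                num_msa_blocks -= 1
                break  # one replacement per stage before moving to an earlier stage
    return stages

# ResNet-50-like layout: 3, 4, 6, 3 blocks per stage.
resnet50 = [["conv"] * n for n in (3, 4, 6, 3)]
alter = apply_build_up_rule(resnet50, num_msa_blocks=4)
print([stage[-1] for stage in alter])  # ['msa', 'msa', 'msa', 'msa']
```

In the paper, each replacement is kept only if it improves predictive performance; the sketch skips that validation step.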
2.2. Alter-ResNet-50 on CIFAR-100
Based on the above analysis, MSA blocks (grey) are inserted at the end of each stage in ResNet-50 to form Alter-ResNet-50. Following Swin, MSAs in stages 1 to 4 have 3, 6, 12, and 24 heads, respectively.
AlterNet outperforms CNNs and ViTs.
2.3. Alter-ResNet-50 on ImageNet (Appendix)
This model alternately replaces Conv blocks with MSA blocks from the end of a stage. Following Swin, MSAs in stages 1 to 4 use 3, 6, 12, and 24 heads, respectively.
MSAs improve the performance of CNNs on ImageNet.