Brief Review — Do Vision Transformers See Like Convolutional Neural Networks?
Do Vision Transformers See Like Convolutional Neural Networks?, Raghu NeurIPS’21, by Google Research, Brain Team, 2021 NeurIPS, Over 380 Citations (Sik-Ho Tsang @ Medium)
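- Vision Transformers (ViTs) have been shown to achieve comparable or even superior performance to convolutional neural networks (CNNs) on image classification tasks. This paper raises a central question: how are ViTs solving these tasks? Are they acting like convolutional networks, or learning entirely different visual representations?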
1. Representation Similarity and CKA (Centered Kernel Alignment)
- (Please skip this part for a quick read, and just treat CKA as a tool to measure representation similarity.)
- Analyzing (hidden) layer representations of neural networks is challenging because their features are distributed across a large number of neurons. This distributed aspect also makes it difficult to meaningfully compare representations across neural networks.
Centered kernel alignment (CKA) [17, 10] addresses these challenges, enabling quantitative comparisons of representations within and across networks.
- Specifically, CKA takes as inputs X and Y, the representations (activation matrices) of two layers with p1 and p2 neurons respectively, evaluated on the same m examples. X is therefore an m×p1 matrix and Y an m×p2 matrix.
- Letting K = XX^T and L = YY^T denote the Gram matrices of the two layers (which measure the similarity of a pair of data points according to the layer representations), CKA computes:
$$\mathrm{CKA}(K, L) = \frac{\mathrm{HSIC}(K, L)}{\sqrt{\mathrm{HSIC}(K, K)\,\mathrm{HSIC}(L, L)}}$$
- where HSIC is the Hilbert-Schmidt independence criterion [15].
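- Given the centering matrix H and the centered Gram matrices K′, L′:
$$H = I_m - \tfrac{1}{m}\mathbf{1}\mathbf{1}^\top, \qquad K' = HKH, \qquad L' = HLH$$
- HSIC is computed as:
$$\mathrm{HSIC}(K, L) = \frac{\mathrm{vec}(K') \cdot \mathrm{vec}(L')}{(m-1)^2}$$
- which is the similarity between these centered Gram matrices.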
CKA is invariant to orthogonal transformation of representations (including permutation of neurons), and the normalization term ensures invariance to isotropic scaling. These properties enable meaningful comparison and analysis of neural network hidden representations.
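A minimal NumPy sketch of the CKA definition above, assuming plain linear Gram matrices K = XX^T and L = YY^T and the HSIC estimate with the (m−1)^2 normalization, which cancels in the CKA ratio anyway. The function names `hsic` and `linear_cka` are illustrative and not taken from the paper's code.

```python
import numpy as np


def hsic(K: np.ndarray, L: np.ndarray) -> float:
    """HSIC(K, L) = vec(K') . vec(L') / (m - 1)^2, with K' = HKH and L' = HLH."""
    m = K.shape[0]
    H = np.eye(m) - np.ones((m, m)) / m  # centering matrix H = I - (1/m) 1 1^T
    K_c = H @ K @ H
    L_c = H @ L @ H
    return float(np.sum(K_c * L_c)) / (m - 1) ** 2


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """CKA between activations X (m x p1) and Y (m x p2) taken on the same m examples."""
    K = X @ X.T  # m x m Gram matrix of the first layer
    L = Y @ Y.T  # m x m Gram matrix of the second layer
    return hsic(K, L) / np.sqrt(hsic(K, K) * hsic(L, L))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 64))                   # 100 examples, p1 = 64 neurons
    Q, _ = np.linalg.qr(rng.normal(size=(64, 64)))   # random orthogonal matrix
    print(linear_cka(X, X))                # 1.0 up to floating-point error
    print(linear_cka(X, 3.0 * X @ Q))      # also 1.0 up to floating-point error
    print(linear_cka(X, rng.normal(size=(100, 32))))  # unrelated features: lower
```

The second call illustrates the invariances stated above: rotating the neurons with an orthogonal matrix and scaling them isotropically leaves CKA unchanged.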
- To work at scale with models and tasks, the unbiased estimator of HSIC [39] is approximated using minibatches, as suggested in [29].
- (For a deeper understanding of CKA and HSIC, please read [10, 15, 17, 29, 39].)
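Building m×m Gram matrices over a whole dataset is expensive, so the minibatch idea can be sketched as follows. This is our reading of the approach: compute the standard unbiased HSIC estimator (which zeroes the Gram-matrix diagonals) on each minibatch, accumulate each of the three HSIC terms over minibatches, and only then form the CKA ratio. The helper names are hypothetical; check [29] and [39] for the exact formulation used in the paper.

```python
import numpy as np


def hsic_unbiased(K: np.ndarray, L: np.ndarray) -> float:
    """Unbiased HSIC estimator for n x n Gram matrices (requires n > 3)."""
    n = K.shape[0]
    K_t, L_t = K.copy(), L.copy()
    np.fill_diagonal(K_t, 0.0)  # the unbiased estimator uses zero-diagonal Gram matrices
    np.fill_diagonal(L_t, 0.0)
    ones = np.ones(n)
    trace_term = np.sum(K_t * L_t)  # tr(K_t L_t), valid because both are symmetric
    sum_term = (ones @ K_t @ ones) * (ones @ L_t @ ones) / ((n - 1) * (n - 2))
    cross_term = 2.0 / (n - 2) * (ones @ K_t @ L_t @ ones)
    return float((trace_term + sum_term - cross_term) / (n * (n - 3)))


def minibatch_cka(batches) -> float:
    """Minibatch CKA from an iterable of (X_i, Y_i) activation pairs.

    Each pair holds the two layers' activations on the same minibatch of examples.
    HSIC terms are accumulated over minibatches; the 1/k averaging factors cancel
    in the final ratio, so plain sums are used.
    """
    kl = kk = ll = 0.0
    for X, Y in batches:
        K, L = X @ X.T, Y @ Y.T
        kl += hsic_unbiased(K, L)
        kk += hsic_unbiased(K, K)
        ll += hsic_unbiased(L, L)
    return kl / np.sqrt(kk * ll)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    Q, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # random orthogonal matrix
    batches = [rng.normal(size=(32, 16)) for _ in range(4)]  # 4 minibatches of 32 examples
    print(minibatch_cka((X, X) for X in batches))            # 1.0 up to floating-point error
    print(minibatch_cka((X, 2.0 * X @ Q) for X in batches))  # invariances still hold
```

Because the unbiased estimator can be slightly negative on a small minibatch, the accumulated `kk` and `ll` should be checked to be positive before taking the square root.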
2. ViT and ResNet Analysis
2.1. Representation Structure of ViTs and Convolutional Networks
- The above figure shows, as heatmaps, the CKA similarity between every pair of layers within a model, for multiple ViTs and ResNets. There are clear differences between the internal representation structures of the two model architectures:
- ViTs show a much more uniform similarity structure, with a clear grid-like pattern.
- Lower and higher layers in ViT show much greater similarity than in the ResNet, where similarity is divided into different (lower/higher) stages.
- The lower half of the 60 ResNet layers is similar to approximately the lowest quarter of the ViT layers. In other words, many more ResNet layers are needed to compute representations similar to those of the lower ViT layers.
- The top half of the ResNet is approximately similar to the next third of the ViT layers.
- The final third of ViT layers is less similar to all ResNet layers, likely because this set of layers mainly manipulates the CLS token representation.
Therefore, (i) ViT lower layers compute representations differently from the lower layers of the ResNet, (ii) ViT propagates representations more strongly between lower and higher layers, and (iii) the highest layers of ViT have quite different representations from those of the ResNet.
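The heatmaps in this section can be sketched as a matrix of pairwise CKA values. This assumes each layer's activations have already been flattened to an m×p matrix with one row per example (for a ViT layer, for instance, by concatenating all tokens); that flattening choice is an assumption, not necessarily the paper's. Linear CKA is computed here in its equivalent feature-space form, ||Y_c^T X_c||_F^2 / (||X_c^T X_c||_F ||Y_c^T Y_c||_F) with column-centered X_c and Y_c, which avoids building m×m Gram matrices. Function names are hypothetical.

```python
import numpy as np


def linear_cka(X: np.ndarray, Y: np.ndarray) -> float:
    """Linear CKA in feature space; column-centering X and Y plays the role of H."""
    X = X - X.mean(axis=0, keepdims=True)
    Y = Y - Y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(Y.T @ X, "fro") ** 2
    return float(cross / (np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")))


def cka_heatmap(layers_a, layers_b) -> np.ndarray:
    """Entry [i, j] is the CKA between layer i of model A and layer j of model B.

    Every element of layers_a / layers_b is an (m x p) activation matrix computed on
    the same m examples. Pass the same list twice for a within-model heatmap.
    """
    return np.array([[linear_cka(A, B) for B in layers_b] for A in layers_a])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical toy "network": each layer is a random nonlinear map of the previous one.
    layers = [rng.normal(size=(200, 32))]
    for _ in range(5):
        layers.append(np.tanh(layers[-1] @ rng.normal(size=(32, 32)) / np.sqrt(32)))
    within = cka_heatmap(layers, layers)      # 6 x 6, ones on the diagonal
    across = cka_heatmap(layers[:3], layers)  # 3 x 6, like a "model A vs model B" grid
    print(np.round(within, 2))
    print(across.shape)
```

Plotting `within` (e.g. with matplotlib's `imshow`) gives a layer-by-layer similarity heatmap of the kind discussed above; passing two different models' layer lists gives the cross-model comparison.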
2.2. Local and Global Information in Layer Representations
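Even in the lowest layers of ViT, self-attention layers have a mix of local heads (small mean attention distances) and global heads (large mean attention distances), where a head's mean attention distance is the average distance between a query patch and the patches it attends to, weighted by the attention weights. This is in contrast to CNNs, which are hardcoded to attend only locally in the lower layers.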
At higher layers, all self-attention heads are global.
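The local versus global distinction can be made concrete with a small sketch of the attention-weighted mean distance per head. It assumes the attention weights of one image are given as a (heads × N × N) array over N = g×g patch tokens with the CLS token already removed, and that distances are measured between patch positions in pixels; in practice the result would also be averaged over many images. The function name is hypothetical.

```python
import numpy as np


def mean_attention_distance(attn: np.ndarray, grid_size: int, patch_size: int) -> np.ndarray:
    """Attention-weighted mean distance (in pixels) for each head.

    attn: (num_heads, N, N) attention weights, N = grid_size**2 patch tokens
          (CLS token removed), each row summing to 1.
    Returns an array of shape (num_heads,).
    """
    rows, cols = np.divmod(np.arange(grid_size * grid_size), grid_size)
    coords = np.stack([rows, cols], axis=1) * patch_size  # (N, 2) patch positions in pixels
    # Euclidean distance between every query patch i and key patch j: (N, N).
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    per_query = np.sum(attn * dist[None, :, :], axis=-1)  # (num_heads, N)
    return per_query.mean(axis=-1)  # average over query patches


if __name__ == "__main__":
    n = 14 * 14  # e.g. a 224x224 image split into 16x16 patches
    local_head = np.eye(n)                  # every patch attends only to itself
    global_head = np.full((n, n), 1.0 / n)  # every patch attends uniformly everywhere
    attn = np.stack([local_head, global_head])
    # First entry is 0.0 (purely local head); second is much larger (global head).
    print(mean_attention_distance(attn, grid_size=14, patch_size=16))
```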
Lower layer effective receptive fields for ViT are indeed larger than in ResNets, and while ResNet effective receptive fields grow gradually, ViT receptive fields become much more global midway through the network.
- ViT receptive fields also show strong dependence on their center patch due to their strong residual connections.
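A toy illustration of the two receptive-field behaviours, using one common definition of the effective receptive field: the absolute gradient of the center output with respect to the input. To keep it exact and dependency-free, both "networks" are hypothetical linear 1D maps y = Mx, so the gradient is simply row `center` of M. A stack of 3-tap averaging convolutions stands in for the ResNet, and a single residual layer with uniform attention, y = x + Ax, stands in for a global ViT block. This is not the paper's measurement on trained models.

```python
import numpy as np


def conv_matrix(n: int, kernel) -> np.ndarray:
    """Matrix of a 1D 'same' convolution with zero padding and an odd-length kernel."""
    k = len(kernel) // 2
    M = np.zeros((n, n))
    for i in range(n):
        for offset, w in zip(range(-k, k + 1), kernel):
            if 0 <= i + offset < n:
                M[i, i + offset] = w
    return M


def effective_receptive_field(M: np.ndarray, center: int) -> np.ndarray:
    """|d y[center] / d x| for the linear map y = M x, i.e. |row `center` of M|."""
    return np.abs(M[center])


if __name__ == "__main__":
    n, c = 33, 16
    conv = conv_matrix(n, [1 / 3, 1 / 3, 1 / 3])
    for depth in (1, 2, 4, 8):
        erf = effective_receptive_field(np.linalg.matrix_power(conv, depth), c)
        print("conv depth", depth, "-> nonzero inputs:", np.count_nonzero(erf))  # 2*depth + 1
    # Residual block with uniform attention: y = x + A x.
    residual_attn = np.eye(n) + np.full((n, n), 1.0 / n)
    erf = effective_receptive_field(residual_attn, c)
    print("residual attention -> nonzero inputs:", np.count_nonzero(erf),
          "peak at center:", np.argmax(erf) == c)
```

The convolutional receptive field widens by two inputs per layer, whereas the residual attention layer already covers every input while the identity path keeps the largest gradient on the center position, mirroring the center-patch dependence described above.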
3. Scaling & Transfer Learning
- Linear classifier probes are trained on ImageNet classes, using the intermediate-layer representations of models pretrained on JFT-300M vs. models pretrained only on ImageNet.
JFT-300M pretrained models achieve much higher accuracies, even with middle-layer representations.
Also, larger ViT models learn much stronger intermediate representations than the ResNets.
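Layer-wise linear probing can be sketched as follows. The paper trains linear classifiers on frozen representations; since the training details are not covered in this review, the sketch swaps in a closed-form ridge-regression probe on one-hot labels, which is simpler than, and not identical to, a trained softmax classifier. Function names and the `l2` value are hypothetical.

```python
import numpy as np


def _with_bias(features: np.ndarray) -> np.ndarray:
    """Append a constant column so the probe has a bias term."""
    return np.hstack([features, np.ones((features.shape[0], 1))])


def fit_linear_probe(features, labels, num_classes, l2=1e-3):
    """Ridge regression onto one-hot labels: W = (F^T F + l2 I)^-1 F^T Y."""
    F = _with_bias(features)
    Y = np.eye(num_classes)[labels]
    return np.linalg.solve(F.T @ F + l2 * np.eye(F.shape[1]), F.T @ Y)


def probe_accuracy(W, features, labels) -> float:
    preds = np.argmax(_with_bias(features) @ W, axis=1)
    return float(np.mean(preds == labels))


def layerwise_probe_accuracies(train_layers, y_train, test_layers, y_test, num_classes):
    """One probe per layer, trained on frozen train features, scored on test features."""
    return [probe_accuracy(fit_linear_probe(tr, y_train, num_classes), te, y_test)
            for tr, te in zip(train_layers, test_layers)]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    k, d = 3, 10
    means = rng.normal(scale=5.0, size=(k, d))
    y_tr, y_te = rng.integers(0, k, 300), rng.integers(0, k, 150)
    # Hypothetical "layers": one uninformative, one whose features cluster by class.
    train = [rng.normal(size=(300, d)), means[y_tr] + rng.normal(size=(300, d))]
    test = [rng.normal(size=(150, d)), means[y_te] + rng.normal(size=(150, d))]
    print(layerwise_probe_accuracies(train, y_tr, test, y_te, k))
```

Applied to every layer of a pretrained model, the resulting list of accuracies traces how linearly decodable the classes are at each depth, which is the kind of curve compared above.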
- (There are many more analyses; please read the paper directly.)
4. Limitations
- The analysis relies on CKA [17], which summarizes each comparison into only a single scalar to provide quantitative insights into representation similarity.
- More fine-grained methods are needed to reveal additional insights and variations in the representations.