Review — Bridging the Gap Between Vision Transformers and Convolutional Neural Networks on Small Datasets
Bridging the Gap Between Vision Transformers and Convolutional Neural Networks on Small Datasets,
DHVT, by University of Science and Technology of China,
2022 NeurIPS (Sik-Ho Tsang @ Medium)
- ViTs have two weaknesses in inductive bias: spatial relevance and diverse channel representation.
- Dynamic Hybrid Vision Transformer (DHVT) is proposed to strengthen these two inductive biases.
- On the spatial side, a hybrid structure is adopted in which convolution is integrated into the patch embedding and the multi-layer perceptron module. This forces the model to capture each token's features as well as its neighboring features.
- On the channel side, a dynamic feature aggregation module in the MLP and a new “head token” design in the multi-head self-attention module are introduced. They re-calibrate the channel representation and let the representations of different channel groups interact with each other. Fusing these weak channel representations forms a sufficiently strong representation.
1. Dynamic Hybrid Vision Transformer (DHVT)
1.1. Sequential Overlapping Patch Embedding (SOPE)
- The modified patch embedding, called Sequential Overlapping Patch Embedding (SOPE), consists of several successive 3×3 convolution layers with stride s=2, each followed by Batch Normalization and GELU activation. Since each stride-2 layer halves the spatial resolution, the number of convolution layers k and the patch size P are related by P=2^k (e.g., k=4 layers for P=16).
- Because neighboring 3×3 kernels overlap, SOPE eliminates the discontinuity between adjacent patches introduced by the vanilla non-overlapping patch embedding, preserving important low-level features. It also provides position information to some extent.
- Two affine transformations are applied before and after the series of convolution layers. They rescale and shift the input features, acting like normalization and making training more stable on small datasets.
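- The whole process of SOPE can be formulated as: x ← Aff(x); x ← GELU(BN(Conv3×3,s=2(x))), repeated k times; x ← Aff(x), where Aff(x) = Diag(α)x + β.
- Here α and β are learnable parameters, initialized as 1 and 0 respectively, so each affine transformation starts out as the identity.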
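The SOPE steps above can be sketched in NumPy as follows. This is a minimal, non-authoritative sketch: it assumes training-mode Batch Normalization without learnable scale/shift, the tanh approximation of GELU, the same channel width `dim` for every convolution layer, and randomly initialized weights. The names `affine`, `conv3x3_s2`, `batch_norm` and `sope` are hypothetical, not the authors' code. It also makes P=2^k concrete: a 32×32 input with P=4 passes through k=2 stride-2 layers and ends as an 8×8 grid.

```python
import numpy as np


def affine(x, alpha, beta):
    # Aff(x) = Diag(alpha) x + beta, applied per channel to an (N, C, H, W) tensor
    return x * alpha[None, :, None, None] + beta[None, :, None, None]


def conv3x3_s2(x, w):
    # 3x3 convolution with stride 2 and padding 1; x: (N, Cin, H, W), w: (Cout, Cin, 3, 3)
    n, _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (0, 0), (1, 1), (1, 1)))
    ho, wo = (h - 1) // 2 + 1, (wd - 1) // 2 + 1
    out = np.zeros((n, w.shape[0], ho, wo))
    for dy in range(3):
        for dx in range(3):
            # input pixels seen by kernel tap (dy, dx) at every output position
            tap = xp[:, :, dy:dy + 2 * ho:2, dx:dx + 2 * wo:2]
            out += np.einsum("nchw,oc->nohw", tap, w[:, :, dy, dx])
    return out


def batch_norm(x, eps=1e-5):
    # training-mode BN using batch statistics; learnable gamma/beta omitted (assumption)
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def sope(img, patch_size, dim, rng):
    k = patch_size.bit_length() - 1            # P = 2^k  ->  k stride-2 conv layers
    assert 1 << k == patch_size, "patch size must be a power of two"
    c = img.shape[1]
    x = affine(img, np.ones(c), np.zeros(c))   # first Aff: alpha = 1, beta = 0 at init
    for _ in range(k):
        # hypothetical He-style random init; a trained model would learn these weights
        w = rng.normal(0.0, np.sqrt(2.0 / (9 * x.shape[1])), size=(dim, x.shape[1], 3, 3))
        x = gelu(batch_norm(conv3x3_s2(x, w)))
    return affine(x, np.ones(dim), np.zeros(dim))  # second Aff after the conv stack


rng = np.random.default_rng(0)
feat = sope(rng.normal(size=(2, 3, 32, 32)), patch_size=4, dim=192, rng=rng)
print(feat.shape)  # (2, 192, 8, 8)
```

Computing each kernel tap as one strided slice keeps the convolution readable without a deep-learning framework; a real implementation would use a framework's 2D convolution layer instead.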
1.2. Encoder Overall Architecture
- The feature maps output by SOPE are then flattened into patch tokens, concatenated with a class token, and fed into the encoder layers.
- Each encoder layer contains Layer Normalization, multi-head self-attention (MHSA) and a feed forward network. The MHSA is replaced by Head-Interacted Multi-Head Self-Attention (HI-MHSA), and the feed forward network is replaced by Dynamic Aggregation Feed Forward (DAFF).
- After the final encoder layer, the output class token is fed into the linear head for the final prediction.
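The token flow described above can be sketched in NumPy as follows. This is a minimal sketch assuming a zero-initialized class token, a 192-channel 8×8 feature map (e.g., a 32×32 input with P=4) and 100 output classes; `to_token_sequence` and `classify` are hypothetical names, and the encoder layers themselves are left out.

```python
import numpy as np


def to_token_sequence(feat, cls_token):
    # feat: (C, H', W') feature map from the patch embedding; cls_token: (C,)
    c, h, w = feat.shape
    patch_tokens = feat.reshape(c, h * w).T                              # (N, C), N = H' * W'
    return np.concatenate([cls_token[None, :], patch_tokens], axis=0)   # (1 + N, C)


def classify(tokens, w_head, b_head):
    # only the class token (row 0) after the final encoder layer reaches the linear head
    return tokens[0] @ w_head + b_head


rng = np.random.default_rng(0)
dim, num_classes = 192, 100
feat = rng.normal(size=(dim, 8, 8))
tokens = to_token_sequence(feat, np.zeros(dim))
# ... the encoder layers (HI-MHSA + DAFF) would transform `tokens` here ...
logits = classify(tokens, rng.normal(size=(dim, num_classes)), np.zeros(num_classes))
print(tokens.shape, logits.shape)  # (65, 192) (100,)
```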
1.3. Dynamic Aggregation Feed Forward (DAFF)
- The vanilla feed forward network (FFN) in ViT is formed by two fully-connected layers and GELU.
- DAFF integrates depth-wise convolution (DWCONV), as used in MobileNetV1, into the FFN. The inductive bias of depth-wise convolution forces the model to capture neighboring features, addressing the weakness on the spatial side. This greatly reduces the performance gap when training from scratch on small datasets, and the model converges faster than standard CNNs.
- In addition, a mechanism similar to the SE module from SENet is used.
- Xc and Xp denote the class token and the patch tokens, respectively.
- Specifically, the class token is split from the sequence as Xc before the projection layers. The remaining patch tokens Xp go through a depth-wise-integrated multi-layer perceptron with a shortcut inside.
- The output patch tokens are then averaged into a weight vector W. After the squeeze-excitation operation, the resulting weight vector is multiplied with the class token channel-wise.
- The re-calibrated class token is then concatenated with the output patch tokens to restore the token sequence.
- In summary: Xc, Xp = Split(X); Xp ← DW-MLP(Xp), with the shortcut inside; W = SE(Avg(Xp)); Xc ← Xc ⊙ W; X = Concat(Xc, Xp).
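The DAFF steps above can be sketched in NumPy as follows. This is a minimal sketch, not the authors' implementation: Batch Normalization and bias terms are omitted, the placement of the GELU activation is an assumption, and the squeeze-excitation layers follow SENet's ReLU/sigmoid form with a hypothetical reduction ratio r=4. The names `dwconv3x3` and `daff` and the parameter keys are hypothetical.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def dwconv3x3(x, k):
    # depth-wise 3x3 conv (one filter per channel), stride 1, padding 1
    # x: (H, W, C), k: (3, 3, C)
    h, w, _ = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    return sum(xp[dy:dy + h, dx:dx + w] * k[dy, dx] for dy in range(3) for dx in range(3))


def daff(tokens, grid, p):
    # tokens: (1 + N, D) = [class token; patch tokens]; grid = (H, W) with H * W = N
    x_c, x_p = tokens[:1], tokens[1:]            # split off the class token before projection
    h, w = grid
    hid = gelu(x_p @ p["w1"])                    # first projection D -> D * mlp_ratio
    hid = hid.reshape(h, w, -1)
    hid = hid + dwconv3x3(hid, p["dw"])          # depth-wise conv with the inner shortcut
    x_p = hid.reshape(h * w, -1) @ p["w2"]       # second projection back to D
    weight = x_p.mean(axis=0)                    # average the output patch tokens into W
    weight = sigmoid(np.maximum(weight @ p["se1"], 0.0) @ p["se2"])  # squeeze and excite
    x_c = x_c * weight                           # channel-wise re-calibration of the class token
    return np.concatenate([x_c, x_p], axis=0)    # restore the token sequence


rng = np.random.default_rng(0)
D, hidden, r = 192, 4 * 192, 4                   # MLP ratio 4; r is a hypothetical choice
params = {
    "w1": rng.normal(0, 0.02, (D, hidden)), "w2": rng.normal(0, 0.02, (hidden, D)),
    "dw": rng.normal(0, 0.1, (3, 3, hidden)),
    "se1": rng.normal(0, 0.02, (D, D // r)), "se2": rng.normal(0, 0.02, (D // r, D)),
}
tokens = rng.normal(size=(1 + 8 * 8, D))
out = daff(tokens, (8, 8), params)
print(out.shape)  # (65, 192)
```

Because the sigmoid output lies in (0, 1), the re-calibration can only shrink each class-token channel, never flip its sign; the patch tokens carry the spatial (depth-wise) processing.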
1.4. Head-Interacted Multi-Head Self-Attention (HI-MHSA)
In the original MHSA module, attention heads do not interact with each other. As a result, the representation in each channel group is too weak for recognition, especially when training data is scarce.
- (a) Head Token Generation: In HI-MHSA, each D-dimensional token, including the class token, is reshaped into h parts. Each part contains d channels, where D=d×h.
- The separated tokens are averaged within their own part, giving h tokens in total, each d-dimensional.
- These intermediate tokens are projected back to D dimensions, resulting in h head tokens in total.
- Finally, the head tokens are concatenated with the patch tokens and the class token.
- (b) MHSA: The whole sequence goes through MHSA. Afterwards, the h output head tokens are averaged and added to the class token: Xc ← Xc + (1/h) Σ_i H_i, where H_i is the i-th output head token.
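The two HI-MHSA stages above can be sketched in NumPy as follows. This is a minimal sketch under stated assumptions: a single linear projection `w_head` shared by all h parts, no bias terms, and no normalization or activation in the head-token projection. `head_tokens`, `mhsa` and `hi_mhsa` are hypothetical names, not the authors' code.

```python
import numpy as np


def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def head_tokens(tokens, h, w_proj):
    # tokens: (T, D) including the class token; w_proj: (d, D), shared across parts (assumption)
    t, dim = tokens.shape
    parts = tokens.reshape(t, h, dim // h)   # split every token into h parts of d channels
    avg = parts.mean(axis=0)                 # average each part over all tokens -> (h, d)
    return avg @ w_proj                      # project back to D dims -> h head tokens (h, D)


def mhsa(x, h, p):
    # standard multi-head self-attention on x: (T, D)
    t, dim = x.shape
    d = dim // h
    q, k, v = (x @ p[name] for name in ("wq", "wk", "wv"))
    q, k, v = (m.reshape(t, h, d).transpose(1, 0, 2) for m in (q, k, v))  # (h, T, d)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))                 # (h, T, T)
    out = (attn @ v).transpose(1, 0, 2).reshape(t, dim)
    return out @ p["wo"]


def hi_mhsa(tokens, h, p):
    # tokens: (1 + N, D) = [class token; patch tokens]
    heads = head_tokens(tokens, h, p["w_head"])
    y = mhsa(np.concatenate([tokens, heads], axis=0), h, p)  # attend over all tokens together
    y_tokens, y_heads = y[:-h], y[-h:]
    y_tokens[0] = y_tokens[0] + y_heads.mean(axis=0)          # average head tokens into class token
    return y_tokens                                           # head tokens are then dropped


rng = np.random.default_rng(0)
D, h, N = 192, 4, 64
p = {name: rng.normal(0, 0.02, (D, D)) for name in ("wq", "wk", "wv", "wo")}
p["w_head"] = rng.normal(0, 0.02, (D // h, D))
tokens = rng.normal(size=(1 + N, D))
print(head_tokens(tokens, h, p["w_head"]).shape)  # (4, 192)
print(hi_mhsa(tokens, h, p).shape)                # (65, 192)
```

Since every head token is built from one channel group of all tokens but is projected back to the full D dimensions, attending to it lets each head see information summarized from the other channel groups.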
1.5. Datasets & Variants
- DHVT-T: 12 encoder layers, embedding dimension of 192, MLP ratio of 4, and 4 attention heads on CIFAR-100 and DomainNet, 3 on ImageNet-1K.
- DHVT-S: 12 encoder layers, embedding dimension of 384, MLP ratio of 4, and 8 attention heads on CIFAR-100, 6 on DomainNet and ImageNet-1K.
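The variant settings above can be written as a small config table. This sketch only restates the listed hyper-parameters; the `DHVTConfig` class and `VARIANTS` dictionary are hypothetical names. It derives the per-head channel count d = D/h used by the head tokens in HI-MHSA, and the hidden width of the MLP.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DHVTConfig:
    depth: int      # number of encoder layers
    dim: int        # embedding dimension D
    mlp_ratio: int
    heads: int      # number of attention heads h

    @property
    def head_dim(self) -> int:
        # D = d * h, so each head (and each head-token part) has d = D / h channels
        assert self.dim % self.heads == 0
        return self.dim // self.heads

    @property
    def mlp_hidden(self) -> int:
        return self.dim * self.mlp_ratio


VARIANTS = {
    ("DHVT-T", "CIFAR-100"): DHVTConfig(12, 192, 4, 4),
    ("DHVT-T", "DomainNet"): DHVTConfig(12, 192, 4, 4),
    ("DHVT-T", "ImageNet-1K"): DHVTConfig(12, 192, 4, 3),
    ("DHVT-S", "CIFAR-100"): DHVTConfig(12, 384, 4, 8),
    ("DHVT-S", "DomainNet"): DHVTConfig(12, 384, 4, 6),
    ("DHVT-S", "ImageNet-1K"): DHVTConfig(12, 384, 4, 6),
}

for (name, data), cfg in VARIANTS.items():
    print(f"{name} on {data}: d = {cfg.head_dim}, MLP hidden = {cfg.mlp_hidden}")
```

With these settings, d is 48 for DHVT-T on CIFAR-100/DomainNet and DHVT-S on CIFAR-100, and 64 in the remaining configurations.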
The main focus is training from scratch on small datasets.
2.1. DomainNet & ImageNet-1K
On DomainNet, DHVT outperforms a standard ResNet-50.
On ImageNet-1K, DHVT-T reaches 76.47% accuracy and DHVT-S reaches 82.3%. The authors state that this is the best performance among non-hierarchical vision transformers with a class token.
2.2. CIFAR-100
On CIFAR-100, DHVT-T reaches 83.54% with 5.8M parameters, and DHVT-S reaches 85.68% with only 22.8M parameters.
2.3. Ablation Studies
- Table 5: A baseline accuracy of 67.59% is obtained by DeiT-T with 4 heads, trained from scratch for 300 epochs.
- When the absolute positional embedding is removed, the accuracy drops drastically to 58.72%.
- When SOPE is adopted and the absolute positional embedding is removed, the accuracy drop is much smaller.
When both SOPE and DAFF are adopted, positional information is encoded comprehensively, and SOPE also addresses the non-overlapping problem of the vanilla patch embedding, preserving fine-grained low-level features in the early stage.
- Table 6: Head tokens bring a consistent performance gain across different model structures.
When all three modifications are adopted, an accuracy gain of +13.26 is obtained, successfully bridging the performance gap with CNNs.
Different head tokens activate on different patch tokens, exhibiting diverse representations.