Review — TNT: Transformer in Transformer
Transformer in Transformer
TNT, by ISCAS & UCAS, Huawei Technologies, and University of Macau
2021 NeurIPS, Over 200 Citations (Sik-Ho Tsang @ Medium)
Image Classification, Transformer, ViT
- Transformer iN Transformer (TNT) is designed, where the local patches (e.g., 16×16) are regarded as “visual sentences” and are further divided into smaller patches (e.g., 4×4) as “visual words”.
- The attention of each word is calculated with the other words in the given visual sentence at negligible computational cost.
- Features of both words and sentences are aggregated to enhance the representation ability.
- (For a quick read, please read Sections 1, 2, 5.1 and 5.2.)
Outline
- Preliminaries for Transformer
- Transformer iN Transformer (TNT)
- Computational Complexity
- Model Variants
- Experimental Results
1. Preliminaries for Transformer
- (If you know Transformer, please skip this part.)
- The basic components in Transformer include MSA (Multi-head Self-Attention), MLP (Multi-Layer Perceptron) and LN (Layer Normalization).
1.1. MSA (Multi-head Self-Attention)
- In a single self-attention module, the input X is linearly transformed into three parts, i.e. queries Q, keys K, and values V, and the scaled dot-product attention is applied on Q, K, and V, as shown below.
- MSA splits the queries, keys and values into h parts (heads), performs the attention function in parallel, and then the output values of each head are concatenated and linearly projected to form the final output.
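In LaTeX, the standard forms described above are:
```
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V

\mathrm{MSA}(X) = \mathrm{Concat}(\mathrm{head}_1, \ldots, \mathrm{head}_h)\,W^{O}, \qquad
\mathrm{head}_j = \mathrm{Attention}(XW_j^{Q},\, XW_j^{K},\, XW_j^{V})
```
where d_k is the dimension of the keys.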
1.2. MLP (Multi-Layer Perceptron)
- MLP is applied between self-attention layers for feature transformation and non-linearity:
- where W and b are the weight and bias of a fully-connected (FC) layer, and σ is the GELU activation.
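In the paper's FC-based notation, the MLP can be written as:
```
\mathrm{MLP}(X) = \mathrm{FC}(\sigma(\mathrm{FC}(X))), \qquad \mathrm{FC}(X) = XW + b
```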
1.3. LN (Layer Normalization)
- LN is a key part in Transformer for stable training and faster convergence. LN is applied over each sample x:
- where μ and δ are the mean and standard deviation of the feature respectively. γ and β are learnable affine transform parameters.
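Written out:
```
\mathrm{LN}(x) = \frac{x - \mu}{\delta} \circ \gamma + \beta
```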
2. Transformer iN Transformer (TNT)
2.1. Patch and Sub-Patch
- Given a 2D image, it is uniformly split into n patches:
- where (p, p) is the resolution of each image patch, similar to ViT.
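Following the paper's notation (an H×W input with 3 channels), the split should read roughly:
```
\mathcal{X} = [X^{1}, X^{2}, \ldots, X^{n}] \in \mathbb{R}^{n \times p \times p \times 3}, \qquad n = \frac{HW}{p^{2}}
```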
- However, ViT only processes the sequence of patches, which corrupts the local structure within a patch.
- Transformer-iN-Transformer (TNT) architecture learns both global and local information in an image.
In TNT, the patches are viewed as visual sentences that represent the image. Each patch is further divided into m sub-patches, i.e., a visual sentence is composed of a sequence of visual words:
- where xi,j is the j-th visual word of the i-th visual sentence, and (s, s) is the spatial size of the sub-patches.
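That is, roughly:
```
X^{i} \rightarrow [x^{i,1}, x^{i,2}, \ldots, x^{i,m}], \qquad x^{i,j} \in \mathbb{R}^{s \times s \times 3}
```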
2.2. Inner Transformer Block
- With a linear projection, the visual words are transformed into a sequence of word embeddings:
- where yi,j is the j-th word embedding, c is the dimension of word embedding, and Vec() is the vectorization operation.
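In the paper's notation, this projection should read roughly:
```
Y^{i} = [y^{i,1}, y^{i,2}, \ldots, y^{i,m}], \qquad y^{i,j} = \mathrm{FC}(\mathrm{Vec}(x^{i,j})) \in \mathbb{R}^{c}
```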
In TNT, there are two data flows. One flow operates across the visual sentences and the other processes the visual words inside each sentence.
- For the word embeddings, a Transformer block is utilized to explore the relation between visual words:
- where l = 1, 2, …, L is the index of the stacked blocks and L is their total number. The input of the first block Yi,0 is just Yi.
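The inner block follows the usual pre-norm residual form, roughly:
```
Y'^{\,i}_{l} = Y^{i}_{l-1} + \mathrm{MSA}(\mathrm{LN}(Y^{i}_{l-1}))

Y^{i}_{l} = Y'^{\,i}_{l} + \mathrm{MLP}(\mathrm{LN}(Y'^{\,i}_{l}))
```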
This can be viewed as an inner Transformer block, denoted as Tin. It builds the relationships among visual words by computing interactions between any two of them. For example, in a patch of a human face, a word corresponding to the eye is more related to the other words of the eyes, while it interacts less with the forehead.
2.3. Outer Transformer Block
- For the sentence level, the sentence embedding memories are created to store the sequence of sentence-level representations:
- where Zclass is the class token similar to ViT, and all of them are initialized as zero.
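The memories are therefore, roughly:
```
\mathcal{Z}_{0} = [Z_{\mathrm{class}}, Z^{1}_{0}, Z^{2}_{0}, \ldots, Z^{n}_{0}] \in \mathbb{R}^{(n+1) \times d}
```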
- In each layer, the sequence of word embeddings is transformed into the domain of the sentence embedding by linear projection and added into the sentence embedding:
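A simplified sketch of this projection-and-add step (writing the per-layer linear projection as a generic FC layer):
```
Z^{i}_{l-1} \leftarrow Z^{i}_{l-1} + \mathrm{FC}(\mathrm{Vec}(Y^{i}_{l}))
```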
- With the above addition operation, the representation of the sentence embedding is augmented by the word-level features. Then, the standard Transformer block is used for transforming the sentence embeddings:
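Again in pre-norm residual form, roughly:
```
\mathcal{Z}'_{l} = \mathcal{Z}_{l-1} + \mathrm{MSA}(\mathrm{LN}(\mathcal{Z}_{l-1}))

\mathcal{Z}_{l} = \mathcal{Z}'_{l} + \mathrm{MLP}(\mathrm{LN}(\mathcal{Z}'_{l}))
```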
This outer Transformer block Tout is used for modeling relationships among sentence embeddings.
2.4. Summary
- In summary, the inputs and outputs of the TNT block include the visual word embeddings and sentence embeddings:
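That is:
```
Y_{l},\, \mathcal{Z}_{l} = \mathrm{TNT}(Y_{l-1},\, \mathcal{Z}_{l-1})
```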
- By stacking the TNT blocks for L times, the Transformer-iN-Transformer (TNT) network is built.
Finally, the classification token serves as the image representation and a fully-connected layer is applied for classification.
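To make the two data flows concrete, below is a minimal PyTorch-style sketch of one TNT block under the notation above. The class name, head counts and the choice of folding the n sentences into the batch dimension are illustrative assumptions, not the official implementation.
```python
# Minimal sketch of one TNT block (illustrative, not the official implementation).
# words:     word embeddings of every visual sentence, shape (batch * n, m, c)
# sentences: class token + n sentence embeddings,      shape (batch, n + 1, d)
import torch
import torch.nn as nn


class MLP(nn.Module):
    """Two-layer feed-forward network with GELU and expansion ratio r."""
    def __init__(self, dim, ratio=4):
        super().__init__()
        self.fc1 = nn.Linear(dim, dim * ratio)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(dim * ratio, dim)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))


class TNTBlock(nn.Module):
    def __init__(self, word_dim, sentence_dim, num_words, inner_heads=4, outer_heads=6):
        super().__init__()
        # Inner Transformer block T_in: attention among the m words of each sentence.
        self.inner_norm1 = nn.LayerNorm(word_dim)
        self.inner_attn = nn.MultiheadAttention(word_dim, inner_heads, batch_first=True)
        self.inner_norm2 = nn.LayerNorm(word_dim)
        self.inner_mlp = MLP(word_dim)
        # Linear projection of the m*c word features into the sentence embedding space.
        self.proj_norm = nn.LayerNorm(num_words * word_dim)
        self.proj = nn.Linear(num_words * word_dim, sentence_dim)
        # Outer Transformer block T_out: attention among the n+1 sentence embeddings.
        self.outer_norm1 = nn.LayerNorm(sentence_dim)
        self.outer_attn = nn.MultiheadAttention(sentence_dim, outer_heads, batch_first=True)
        self.outer_norm2 = nn.LayerNorm(sentence_dim)
        self.outer_mlp = MLP(sentence_dim)

    def forward(self, words, sentences):
        b, n_plus_1, d = sentences.shape
        # Inner block: Y' = Y + MSA(LN(Y)); Y = Y' + MLP(LN(Y'))
        h = self.inner_norm1(words)
        words = words + self.inner_attn(h, h, h, need_weights=False)[0]
        words = words + self.inner_mlp(self.inner_norm2(words))
        # Add word-level features to the sentence embeddings (class token excluded).
        word_feat = self.proj(self.proj_norm(words.flatten(1)))          # (batch * n, d)
        sentences = torch.cat(
            [sentences[:, :1],
             sentences[:, 1:] + word_feat.reshape(b, n_plus_1 - 1, d)], dim=1)
        # Outer block: Z' = Z + MSA(LN(Z)); Z = Z' + MLP(LN(Z'))
        h = self.outer_norm1(sentences)
        sentences = sentences + self.outer_attn(h, h, h, need_weights=False)[0]
        sentences = sentences + self.outer_mlp(self.outer_norm2(sentences))
        return words, sentences


# Illustrative usage: 14*14 = 196 patches per 224x224 image, 16 words per patch.
# block = TNTBlock(word_dim=24, sentence_dim=384, num_words=16)
# words, sentences = block(torch.randn(2 * 196, 16, 24), torch.randn(2, 197, 384))
```
Stacking L such blocks and adding the word/sentence embedding layers, the position encodings and a classification head on the class token would give the full network.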
2.5. Positional Encoding
- For sentence embeddings and word embeddings, the corresponding position encodings are added to retain spatial information. The standard learnable 1D position encodings are utilized here.
- Specifically, each sentence is assigned a position encoding:
- where Esentence are the sentence position encodings.
- As for the visual words in a sentence, a word position encoding is added to each word embedding:
- where Eword are the word position encodings which are shared across sentences.
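A rough sketch of the two additions (indexing kept loose):
```
Z^{i}_{0} \leftarrow Z^{i}_{0} + E_{\mathrm{sentence}}, \qquad E_{\mathrm{sentence}} \in \mathbb{R}^{(n+1) \times d}

Y^{i}_{0} \leftarrow Y^{i}_{0} + E_{\mathrm{word}}, \qquad E_{\mathrm{word}} \in \mathbb{R}^{m \times c}
```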
In this way, sentence position encoding can maintain the global spatial information, while word position encoding is used for preserving the local relative position.
3. Computational Complexity
- The FLOPs of MSA are:
- The FLOPs of MLP are:
- where r is the dimension expansion ratio of hidden layer in MLP.
- Overall, the FLOPs of a standard Transformer block are:
- As r=4, it can be simplified as:
- The number of parameters can be obtained as:
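Reconstructed, the formulas referenced in the bullets above should read roughly as follows, for n tokens of dimension d with key/value dimensions d_k and d_v:
```
\mathrm{FLOPs}_{\mathrm{MSA}} = 2nd(d_k + d_v) + n^{2}(d_k + d_v)

\mathrm{FLOPs}_{\mathrm{MLP}} = 2n d_v r d_v

\mathrm{FLOPs}_{T} = 2nd(d_k + d_v) + n^{2}(d_k + d_v) + 2n d_v r d_v

\mathrm{FLOPs}_{T} = 12nd^{2} + 2n^{2}d = 2nd(6d + n) \quad (d_k = d_v = d,\ r = 4)

\mathrm{Params}_{T} = 12d^{2}
```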
- In TNT, the computation complexity of Tin is:
- The computation complexity of Tout is:
- The linear layer has FLOPs of nmcd.
- In total, the FLOPs of the TNT block are:
- Similarly, the parameter complexity of TNT block is calculated as:
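And for the TNT block (m words of dimension c per sentence, n sentences of dimension d), roughly:
```
\mathrm{FLOPs}_{T_{in}} = 2nmc(6c + m)

\mathrm{FLOPs}_{T_{out}} = 2nd(6d + n)

\mathrm{FLOPs}_{\mathrm{TNT}} = 2nmc(6c + m) + nmcd + 2nd(6d + n)

\mathrm{Params}_{\mathrm{TNT}} = 12c^{2} + mcd + 12d^{2}
```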
- Although two more components are added in the TNT block, the increase in FLOPs is small.
With a small increase in computation and memory cost, the TNT block can efficiently model the local structure information and achieve a much better trade-off between accuracy and complexity.
4. Model Variants
- The patch size is set as 16×16. The number of sub-patches is set as m=4×4=16 by default.
- There are three variants of TNT networks with different model sizes, namely TNT-Ti, TNT-S and TNT-B. They consist of 6.1M, 23.8M and 65.6M parameters respectively.
- The corresponding FLOPs for processing a 224×224 image are 1.4B, 5.2B and 14.1B respectively.
5. Experimental Results
5.1. Datasets to be Evaluated
5.2. ImageNet
TNT outperforms all other visual Transformer models.
- In particular, TNT-S achieves 81.5% top-1 accuracy which is 1.7% higher than the baseline model DeiT-S, indicating the benefit of the introduced TNT framework to preserve local structure information inside the patch.
Compared to CNNs, TNT can outperform the widely-used ResNet and RegNet.
- However, all the Transformer-based models are still inferior to EfficientNet which utilizes special depth-wise convolutions, so it is yet a challenge of how to beat EfficientNet using pure Transformer.
- TNT models consistently outperform other Transformer-based models by a significant margin.
5.3. Inference Speed
TNT is more efficient than DeiT and PVT by achieving higher accuracy with similar inference speed.
5.4. Ablation Study
Removing the sentence/word position encoding results in a 0.8%/0.7% accuracy drop respectively, and removing all position encodings decreases the accuracy by 1.0%.
- A head width of 64 is used in the outer Transformer block. A proper number of heads (e.g., 2 or 4) in the inner Transformer block achieves the best performance.
- The value of m has a slight influence on the performance; m = 16 is used by default for its efficiency.
The SE module from SENet, inserted into TNT, can further improve the accuracy slightly.
5.5. Visualization
In TNT, the local information is better preserved compared to DeiT.
For a given query visual word, the attention values of visual words with similar appearance are higher, indicating that their features interact more strongly with the query. Such interactions are missing in ViT and DeiT.
- The attention of TNT can focus on the meaningful patches in Block-12, while DeiT still pays attention to the tree which is not related to the pandas.
- The output features mainly focus on the patches related to the object to be recognized.
5.6. Transfer Learning
TNT outperforms DeiT on most datasets with fewer parameters, which shows the superiority of modeling pixel-level relations for better feature representations.
DETR with TNT-S backbone outperforms the representative pure Transformer detector DETR+PVT-Small by 3.5 AP with similar parameters.
With similar parameters, Trans2Seg with TNT-S backbone achieves 43.6% mIoU, which is 1.0% higher than that of PVT-small backbone and 2.8% higher than that of DeiT-S backbone.
TNT achieves much better performance than ResNet and DeiT backbones, indicating its generalization for FPN-like framework.
Reference
[2021 NeurIPS] [TNT]
Transformer in Transformer
Image Classification
1989 … 2021: [Learned Resizer] [Vision Transformer, ViT] [ResNet Strikes Back] [DeiT] [EfficientNetV2] [MLP-Mixer] [T2T-ViT] [Swin Transformer] [CaiT] [ResMLP] [ResNet-RS] [NFNet] [PVT, PVTv1] [CvT] [HaloNet] [TNT]