Review — TNT: Transformer in Transformer

Outperforms ViT, DeiT, PVT

Transformer in Transformer
TNT, by ISCAS & UCAS, Huawei Technologies, and University of Macau
2021 NeurIPS, Over 200 Citations (Sik-Ho Tsang @ Medium)
Image Classification, Transformer, ViT

  • Transformer iN Transformer (TNT) is proposed, where local patches (e.g., 16×16) are regarded as “visual sentences” and are further divided into smaller patches (e.g., 4×4) as “visual words”.
  • The attention of each word is computed with the other words in the same visual sentence, at negligible computational cost.
  • Features of both words and sentences are aggregated to enhance the representation ability.
  • (For a quick read, please read 1, 2, 5.1, 5.2.)

Outline

  1. Preliminaries for Transformer
  2. Transformer iN Transformer (TNT)
  3. Computational Complexity
  4. Model Variants
  5. Experimental Results

1. Preliminaries for Transformer

1.1. MSA (Multi-head Self-Attention)

  • In a single self-attention module, the inputs X are linearly transformed into three parts, i.e., queries Q, keys K, and values V. The scaled dot-product attention is applied on Q, K, and V:
  • MSA splits the queries, keys, and values into h parts (heads) and performs the attention function in parallel; the output values of each head are then concatenated and linearly projected to form the final output.
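
The two bullets above can be sketched in NumPy. The attention function is the usual softmax(QKᵀ/√d_k)V. The weight names (W_q, W_k, W_v, W_o), the random inputs, and the sizes are illustrative assumptions, not values from the paper.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    # Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.
    d_k = Q.shape[-1]
    scores = Q @ np.swapaxes(K, -1, -2) / np.sqrt(d_k)
    return softmax(scores, axis=-1) @ V

def multi_head_self_attention(X, W_q, W_k, W_v, W_o, h):
    # X: (n, d); W_q, W_k, W_v, W_o: (d, d); h must divide d.
    n, d = X.shape

    def split(T):
        # (n, d) -> (h, n, d // h): one slice of the features per head.
        return T.reshape(n, h, d // h).transpose(1, 0, 2)

    heads = attention(split(X @ W_q), split(X @ W_k), split(X @ W_v))
    # Concatenate the heads back to (n, d), then apply the output projection.
    return heads.transpose(1, 0, 2).reshape(n, d) @ W_o

rng = np.random.default_rng(0)
n, d, h = 5, 8, 2  # hypothetical toy sizes
X = rng.standard_normal((n, d))
W_q, W_k, W_v, W_o = (rng.standard_normal((d, d)) for _ in range(4))
out = multi_head_self_attention(X, W_q, W_k, W_v, W_o, h)
print(out.shape)  # (5, 8)
```

With h=1 this reduces to the single self-attention module followed by the output projection.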

1.2. MLP (Multi-Layer Perceptron)

  • MLP is applied between self-attention layers for feature transformation and non-linearity:
  • where W and b are the weight and bias. σ is GELU.
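
The formula these bullets refer to can be written as follows. This is a reconstruction from the definitions above (two fully-connected layers with GELU in between), so the typesetting may differ from the paper.

```latex
\[
\mathrm{MLP}(X) = \mathrm{FC}\big(\sigma(\mathrm{FC}(X))\big),
\qquad
\mathrm{FC}(X) = XW + b
\]
```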

1.3. LN (Layer Normalization)

  • LN is a key part in Transformer for stable training and faster convergence. LN is applied over each sample x:
  • where μ and δ are the mean and standard deviation of the feature respectively. γ and β are learnable affine transform parameters.
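
With the symbols above, LN can be read as LN(x) = (x − μ)/δ ∘ γ + β, where ∘ is the element-wise product. A minimal NumPy sketch follows. The small eps term is my own addition for numerical safety and is not part of the formula described here.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-6):
    # Normalize each sample over its feature (last) axis.
    mu = x.mean(axis=-1, keepdims=True)
    delta = x.std(axis=-1, keepdims=True)
    # eps is an assumption for numerical safety; the formula in the text has none.
    return (x - mu) / (delta + eps) * gamma + beta

x = np.array([[1.0, 2.0, 3.0, 4.0],
              [10.0, 10.0, 20.0, 20.0]])
y = layer_norm(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=-1), y.std(axis=-1))  # per-sample mean near 0, std near 1
```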

2. Transformer iN Transformer (TNT)

Transformer-iN-Transformer (TNT) Framework. The inner Transformer block is shared in the same layer. The word position encodings are shared across visual sentences.

2.1. Patch and Sub-Patch

  • Given a 2D image, it is uniformly split into n patches:
  • where (p, p) is the resolution of each image patch, similar to ViT.
  • However, ViT processes the sequence of patches, which corrupts the local structure of a patch.
  • Transformer-iN-Transformer (TNT) architecture learns both global and local information in an image.

In TNT, the patches are viewed as visual sentences that represent the image. Each patch is further divided into m sub-patches, i.e., a visual sentence is composed of a sequence of visual words:

  • where xi,j is the j-th visual word of the i-th visual sentence, (s, s) is the spatial size of sub-patches.
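
The two-level split can be sketched with plain reshapes. This only illustrates the indexing, assuming a 224×224 RGB image, p=16 and s=4; it is not the authors' implementation, and the function name is hypothetical.

```python
import numpy as np

def split_into_patches(image, p, s):
    # image: (H, W, C) with H and W divisible by p, and p divisible by s.
    H, W, C = image.shape
    # Visual sentences: (n, p, p, C) with n = (H / p) * (W / p).
    sentences = (image.reshape(H // p, p, W // p, p, C)
                      .transpose(0, 2, 1, 3, 4)
                      .reshape(-1, p, p, C))
    n = sentences.shape[0]
    # Visual words: (n, m, s, s, C) with m = (p / s) ** 2.
    words = (sentences.reshape(n, p // s, s, p // s, s, C)
                      .transpose(0, 1, 3, 2, 4, 5)
                      .reshape(n, -1, s, s, C))
    return sentences, words

image = np.arange(224 * 224 * 3, dtype=np.float32).reshape(224, 224, 3)
sentences, words = split_into_patches(image, p=16, s=4)
print(sentences.shape)  # (196, 16, 16, 3)
print(words.shape)      # (196, 16, 4, 4, 3)
```

So a 224×224 image gives n=196 visual sentences, each made of m=16 visual words.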

2.2. Inner Transformer Block

  • With a linear projection, the visual words are transformed into a sequence of word embeddings:
  • where yi,j is the j-th word embedding, c is the dimension of word embedding, and Vec() is the vectorization operation.
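
The word embedding step can be written as below. This is a reconstruction from the definitions above, with FC denoting the linear projection.

```latex
\[
Y^{i} = \big[\, y^{i,1},\, y^{i,2},\, \dots,\, y^{i,m} \,\big],
\qquad
y^{i,j} = \mathrm{FC}\big(\mathrm{Vec}(x^{i,j})\big) \in \mathbb{R}^{c}
\]
```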

In TNT, there are two data flows. One flow operates across the visual sentences and the other processes the visual words inside each sentence.

  • For the word embeddings, a Transformer block is utilized to explore the relation between visual words:
  • where l is the index of the l-th block, and L is the total number of stacked blocks. The input of the first block Yi,0 is just Yi.

This can be viewed as an inner Transformer block, denoted as Tin. This process builds the relationships among visual words by computing interactions between any two visual words. For example, in a patch of a human face, a word corresponding to an eye is more related to the other eye words and interacts less with the forehead part.
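
In equation form, the inner block acts on the word embeddings of each sentence i as below. I assume the usual pre-norm residual layout (LN before MSA and MLP), which is how I read the paper.

```latex
\[
\begin{aligned}
Y'^{\,i}_{l} &= Y^{i}_{l-1} + \mathrm{MSA}\big(\mathrm{LN}(Y^{i}_{l-1})\big), \\
Y^{i}_{l} &= Y'^{\,i}_{l} + \mathrm{MLP}\big(\mathrm{LN}(Y'^{\,i}_{l})\big),
\qquad l = 1, 2, \dots, L
\end{aligned}
\]
```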

2.3. Outer Transformer Block

  • For the sentence level, the sentence embedding memories are created to store the sequence of sentence-level representations:
  • where Zclass is the class token similar to ViT, and all of them are initialized as zero.
  • In each layer, the sequence of word embeddings is transformed into the domain of sentence embedding by linear projection and added into the sentence embedding:
  • With the above addition operation, the representation of sentence embedding is augmented by the word-level features. The standard Transformer block is used for transforming the sentence embeddings:

This outer Transformer block Tout is used for modeling relationships among sentence embeddings.
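
The sentence-level steps of this subsection can be written as below. This is a reconstruction in which d denotes the sentence embedding dimension and the same pre-norm layout as the inner block is assumed.

```latex
\[
\begin{aligned}
Z_{0} &= \big[\, Z_{\mathrm{class}},\, Z^{1}_{0},\, Z^{2}_{0},\, \dots,\, Z^{n}_{0} \,\big] \in \mathbb{R}^{(n+1)\times d}, \\
Z^{i}_{l-1} &\leftarrow Z^{i}_{l-1} + \mathrm{FC}\big(\mathrm{Vec}(Y^{i}_{l})\big), \\
Z'_{l} &= Z_{l-1} + \mathrm{MSA}\big(\mathrm{LN}(Z_{l-1})\big), \\
Z_{l} &= Z'_{l} + \mathrm{MLP}\big(\mathrm{LN}(Z'_{l})\big)
\end{aligned}
\]
```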

2.4. Summary

  • In summary, the inputs and outputs of the TNT block include the visual word embeddings and sentence embeddings:
  • By stacking the TNT blocks for L times, the Transformer-iN-Transformer (TNT) network is built.

Finally, the classification token serves as the image representation and a fully-connected layer is applied for classification.
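
The data flow of a TNT block, Y_l, Z_l = TNT(Y_{l−1}, Z_{l−1}), can be sketched in NumPy as follows. This is a simplified illustration, not the authors' implementation: attention is single-head, LN has no affine parameters, biases are dropped, weights are random, and only two blocks are stacked. The parameter container and function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-6):
    # Affine parameters gamma and beta are omitted for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / (x.std(axis=-1, keepdims=True) + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def init_transformer(dim, r=4):
    # Hypothetical container: four attention matrices and a two-layer MLP.
    w = lambda a, b: rng.standard_normal((a, b)) * 0.02
    return {"q": w(dim, dim), "k": w(dim, dim), "v": w(dim, dim), "o": w(dim, dim),
            "fc1": w(dim, r * dim), "fc2": w(r * dim, dim)}

def transformer_block(x, P):
    # Pre-norm block: x + MSA(LN(x)), then x + MLP(LN(x)); single head here.
    h = layer_norm(x)
    q, k, v = h @ P["q"], h @ P["k"], h @ P["v"]
    attn = softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1]))
    x = x + (attn @ v) @ P["o"]
    return x + gelu(layer_norm(x) @ P["fc1"]) @ P["fc2"]

def tnt_block(Y, Z, P_in, P_out, W_proj):
    # Y: (n, m, c) word embeddings; Z: (n + 1, d) sentence embeddings, class token first.
    n, m, c = Y.shape
    Y = transformer_block(Y, P_in)           # inner block, shared by all n sentences
    Z = Z.copy()
    Z[1:] += Y.reshape(n, m * c) @ W_proj    # Vec + linear projection, added to sentences
    Z = transformer_block(Z, P_out)          # outer block over the n + 1 tokens
    return Y, Z

n, m, c, d, L = 196, 16, 24, 384, 2          # toy depth L = 2
Y = rng.standard_normal((n, m, c))
Z = np.zeros((n + 1, d))                     # sentence embeddings initialized as zero
blocks = [(init_transformer(c), init_transformer(d),
           rng.standard_normal((m * c, d)) * 0.02) for _ in range(L)]
for P_in, P_out, W_proj in blocks:
    Y, Z = tnt_block(Y, Z, P_in, P_out, W_proj)
image_representation = Z[0]                  # class token, fed to the classifier
print(Y.shape, Z.shape, image_representation.shape)
```

The word-to-sentence addition skips index 0, so the class token only receives information through the outer attention.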

2.5. Positional Encoding

  • For sentence embeddings and word embeddings, the corresponding position encodings are added to retain spatial information. The standard learnable 1D position encodings are utilized here.
  • Specifically, each sentence is assigned a position encoding:
  • where Esentence are the sentence position encodings.
  • As for the visual words in a sentence, a word position encoding is added to each word embedding:
  • where Eword are the word position encodings which are shared across sentences.

In this way, sentence position encoding can maintain the global spatial information, while word position encoding is used for preserving the local relative position.
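
The two additions amount to Z_0 ← Z_0 + E_sentence and Y^i_0 ← Y^i_0 + E_word for every sentence i. A minimal NumPy sketch follows, with random placeholders standing in for the learnable encodings; the sizes and the variable names are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, c, d = 196, 16, 24, 384            # assumed sizes for illustration

Z0 = np.zeros((n + 1, d))                # sentence embeddings, class token first
Y0 = rng.standard_normal((n, m, c))      # word embeddings of every sentence

# Learnable in a real model; random placeholders here.
E_sentence = rng.standard_normal((n + 1, d)) * 0.02
E_word = rng.standard_normal((m, c)) * 0.02

Z0_pe = Z0 + E_sentence   # one encoding per sentence, plus one for the class token
Y0_pe = Y0 + E_word       # broadcast over axis 0: the same E_word for every sentence
print(Z0_pe.shape, Y0_pe.shape)  # (197, 384) (196, 16, 24)
```

Broadcasting over the sentence axis is what makes the word position encodings shared across sentences.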

3. Computational Complexity

  • The FLOPs of MSA are:
  • The FLOPs of MLP are:
  • where r is the dimension expansion ratio of hidden layer in MLP.
  • Overall, the FLOPs of a standard Transformer block are:
  • As r=4, it can be simplified as:
  • The number of parameters can be obtained as:
  • In TNT, the computation complexity of Tin is:
  • The computation complexity of Tout is:
  • The linear layer has FLOPs of nmcd.
  • In total, the FLOPs of the TNT block are:
  • Similarly, the parameter complexity of TNT block is calculated as:
  • Although two more components are added in TNT block, the increase of FLOPs is small.
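
The formulas referred to in this list can be written as below. They are reconstructed from the symbols used in this section and from my reading of the paper; the simplified form additionally assumes d_k = d_v = d.

```latex
\[
\begin{aligned}
\mathrm{FLOPs}_{\mathrm{MSA}} &= 2nd(d_k + d_v) + n^2(d_k + d_v), \\
\mathrm{FLOPs}_{\mathrm{MLP}} &= 2n\,d_v\,r\,d_v, \\
\mathrm{FLOPs}_{T} &= 2nd(d_k + d_v) + n^2(d_k + d_v) + 2nddr
  \;\overset{r=4}{=}\; 2nd(6d + n), \\
\mathrm{Params}_{T} &= 12dd, \\
\mathrm{FLOPs}_{T_{in}} &= 2nmc(6c + m), \qquad
\mathrm{FLOPs}_{T_{out}} = 2nd(6d + n), \\
\mathrm{FLOPs}_{\mathrm{TNT}} &= 2nmc(6c + m) + nmcd + 2nd(6d + n), \\
\mathrm{Params}_{\mathrm{TNT}} &= 12cc + mcd + 12dd
\end{aligned}
\]
```

The simplification is direct: with r=4 and d_k = d_v = d, the sum is 4nd² + 2n²d + 8nd² = 2nd(6d + n).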

With a small increase of computation and memory cost, TNT block can efficiently model the local structure information and achieve a much better trade-off between accuracy and complexity.

4. Model Variants

Variants of our TNT architecture. ‘Ti’ means tiny, ‘S’ means small, and ‘B’ means base. The FLOPs are calculated for images at resolution 224×224.
  • The patch size is set as 16×16. The number of sub-patches is set as m=4×4=16 by default.
  • As shown above, there are three variants of TNT networks with different model sizes, namely, TNT-Ti, TNT-S and TNT-B. They consist of 6.1M, 23.8M and 65.6M parameters respectively.
  • The corresponding FLOPs for processing a 224×224 image are 1.4B, 5.2B and 14.1B respectively.
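
As a rough sanity check, the block-level formulas 2nmc(6c+m) + nmcd + 2nd(6d+n) for FLOPs and 12cc + mcd + 12dd for parameters can be evaluated per variant. The (c, d) pairs and the depth of 12 are taken from the paper's variants table as I read it; they are not stated in the text above, so treat them as assumptions. Embedding layers and the classifier head are ignored.

```python
def transformer_flops(n, d):
    # Standard Transformer block with r = 4 and d_k = d_v = d: 2nd(6d + n).
    return 2 * n * d * (6 * d + n)

def tnt_block_flops(n, m, c, d):
    inner = n * transformer_flops(m, c)   # 2nmc(6c + m): inner block on each of n sentences
    linear = n * m * c * d                # word-to-sentence projection
    outer = transformer_flops(n, d)       # 2nd(6d + n)
    return inner + linear + outer

def tnt_block_params(m, c, d):
    return 12 * c * c + m * c * d + 12 * d * d

# Assumed (c, d) per variant; see the note above.
variants = {"TNT-Ti": (12, 192), "TNT-S": (24, 384), "TNT-B": (40, 640)}
n, m, depth = 196, 16, 12   # 224x224 image, 16x16 patches, 4x4 words per patch

totals = {}
for name, (c, d) in variants.items():
    flops = depth * tnt_block_flops(n, m, c, d)
    params = depth * tnt_block_params(m, c, d)
    totals[name] = (flops, params)
    print(f"{name}: {flops / 1e9:.2f}B FLOPs, {params / 1e6:.1f}M block parameters")
```

Under these assumptions the estimates land within a few percent of the reported FLOPs and parameter counts; the remainder comes from the layers the formulas leave out.
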
Other detailed Settings

5. Experimental Results

5.1. Datasets to be Evaluated

Details of used visual datasets

5.2. ImageNet

Results of TNT and other networks on ImageNet

TNT outperforms all other visual Transformer models.

  • In particular, TNT-S achieves 81.5% top-1 accuracy which is 1.7% higher than the baseline model DeiT-S, indicating the benefit of the introduced TNT framework to preserve local structure information inside the patch.

Compared to CNNs, TNT can outperform the widely-used ResNet and RegNet.

Performance comparison of the representative visual backbone networks on ImageNet
  • TNT models consistently outperform other Transformer-based models by a significant margin.

5.3. Inference Speed

GPU throughput comparison of vision transformer models

TNT is more efficient than DeiT and PVT by achieving higher accuracy with similar inference speed.

5.4. Ablation Study

Effect of position encoding

Removing the sentence or word position encoding results in a 0.8% or 0.7% accuracy drop, respectively, and removing all position encodings decreases the accuracy by 1.0%.

Effect of #heads in inner Transformer block in TNT-S
  • A head width of 64 is used in the outer Transformer block. A proper number of heads (e.g., 2 or 4) in the inner Transformer block achieves the best performance.
Effect of #words m
  • The value of m has only a slight influence on performance; m=16 is used by default for its efficiency.
Exploring SE module in TNT

The SE module from SENet, when inserted into TNT, further improves the accuracy slightly.

5.5. Visualization

Visualization of the features of DeiT-S and TNT-S.

In TNT, the local information is better preserved than in DeiT.

Attention maps of different queries in the inner Transformer. Red cross symbol denotes the query location

For a given query visual word, the attention values of visual words with similar appearance are higher, indicating that their features interact more strongly with the query. These interactions are missed in ViT and DeiT.

Visualization of the attention maps between all patches in outer Transformer block.
  • The attention of TNT can focus on the meaningful patches in Block-12, while DeiT still pays attention to the tree, which is not related to the pandas.
Example attention maps from the output token to the input space
  • The output feature mainly focuses on the patches related to the object to be recognized.

5.6. Transfer Learning

Results on downstream image classification tasks with ImageNet pre-training (384: larger 384 fine-tune size)

TNT outperforms DeiT on most datasets with fewer parameters, which shows the benefit of modeling pixel-level relations to obtain better feature representations.

Results of object detection on COCO2017 val set with ImageNet pre-training

DETR with TNT-S backbone outperforms the representative pure Transformer detector DETR+PVT-Small by 3.5 AP with similar parameters.

Results of semantic segmentation on ADE20K val set with ImageNet pre-training

With similar parameters, Trans2Seg with TNT-S backbone achieves 43.6% mIoU, which is 1.0% higher than that of PVT-small backbone and 2.8% higher than that of DeiT-S backbone.

Results of Faster R-CNN+FPN object detection on COCO minival set with ImageNet pre-training

TNT achieves much better performance than ResNet and DeiT backbones, indicating that it generalizes well to FPN-like frameworks.
