Review — CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification
CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification,
CrossViT, by MIT-IBM Watson AI Lab,
2021 ICCV, Over 500 Citations (Sik-Ho Tsang @ Medium)
- CrossViT is proposed, which uses a dual-branch Transformer to process small-patch and large-patch tokens.
- A simple yet effective token fusion module based on cross-attention is proposed, which uses a single token of each branch as a query to exchange information with the other branch.
Outline
1. CrossViT Overall Architecture
2. Multi-Scale Feature Fusion Strategies
3. Model Variants
4. Experimental Results
1. CrossViT Overall Architecture
1.1. Self-Attention in ViT or Transformer
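As a reminder of what each CrossViT branch computes internally, ViT-style self-attention can be sketched in a few lines of NumPy. This is a minimal single-head sketch under stated assumptions: random weights stand in for learned ones, and positional embeddings, LayerNorm, the MLP and multiple heads are omitted. A CLS token is prepended to the N patch tokens and every token attends to every other token, so the attention map is (N + 1) x (N + 1).

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over all tokens.

    x: (N + 1, C) token sequence, i.e. one CLS token followed by N patch tokens.
    Returns the attended tokens (N + 1, C) and the attention map (N + 1, N + 1).
    """
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return attn @ v, attn

rng = np.random.default_rng(0)
num_patches, dim = 196, 64          # e.g. a 14 x 14 grid of patches; dim is arbitrary here
cls_token = rng.normal(size=(1, dim))
patch_tokens = rng.normal(size=(num_patches, dim))
x = np.concatenate([cls_token, patch_tokens], axis=0)   # prepend CLS, as in ViT
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) / np.sqrt(dim) for _ in range(3))

out, attn = self_attention(x, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (197, 64) (197, 197)
```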
1.2. CrossViT Overall Architecture
- The model is primarily composed of K multi-scale Transformer encoders, where each encoder consists of two branches: L-branch and S-branch.
- L-Branch: a large (primary) branch that uses a coarse-grained patch size ($P_l$) with more Transformer encoders and wider embedding dimensions.
- S-Branch: a small (complementary) branch that operates at a fine-grained patch size ($P_s$) with fewer encoders and smaller embedding dimensions.
- Within each multi-scale encoder, the two branches are fused together L times, and the CLS tokens of the two branches at the end are used for prediction.
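Counting tokens shows why the S-branch is kept lighter: with non-overlapping patches, a square image of side H yields (H / P)^2 patch tokens, so halving the patch size quadruples the number of tokens each S-branch encoder must process. The sketch below assumes square inputs whose side is divisible by the patch size; the 240-pixel input paired with the 12-pixel patch is an illustrative assumption, since 224 is not divisible by 12. The function name `num_patch_tokens` is hypothetical.

```python
def num_patch_tokens(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch tokens for a square image (CLS not counted)."""
    if image_size % patch_size != 0:
        raise ValueError(f"{image_size} is not divisible by patch size {patch_size}")
    side = image_size // patch_size
    return side * side

# L-branch: coarse patches -> few tokens; S-branch: fine patches -> many tokens.
large = num_patch_tokens(224, 16)     # 14 x 14 grid
small = num_patch_tokens(224, 8)      # 28 x 28 grid
small_12 = num_patch_tokens(240, 12)  # 20 x 20 grid (assumed 240-pixel input)
print(large, small, small_12)  # 196 784 400
```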
2. Multi-Scale Feature Fusion Strategies
- Let $x^i$ be the token sequence (both patch and CLS tokens) at branch $i$, where $i$ can be $l$ or $s$ for the large (primary) or small (complementary) branch. $x^i_{cls}$ and $x^i_{patch}$ represent the CLS and patch tokens of branch $i$, respectively.
2.1. (a) All-Attention
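A straightforward approach is to simply concatenate all the tokens from both branches and then fuse information via self-attention.
- Since every token attends to every token of both branches, the computation and memory cost of this approach grow quadratically with the total number of tokens.
- The output $z^i$ of the all-attention fusion scheme can be expressed as:
$$y = \left[ f^l(x^l) \,\|\, f^s(x^s) \right], \quad o = y + \mathrm{MSA}(\mathrm{LN}(y))$$
$$o = \left[ o^l \,\|\, o^s \right], \quad z^i = g^i(o^i)$$
- where $f^i(\cdot)$ and $g^i(\cdot)$ are the projection and back-projection functions that align the embedding dimensions of the two branches, and $\|$ denotes concatenation.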
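All-attention fusion can be sketched in NumPy as follows, under stated assumptions: a single attention head instead of MSA, LayerNorm without learnable parameters, and random matrices standing in for the learned projections $f^i$, $g^i$ and the query/key/value weights. The function name and dictionary keys are hypothetical. Note that the attention map spans the concatenation of both branches, which is where the quadratic cost comes from.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-6):
    # LayerNorm without learnable scale/shift, for brevity.
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def all_attention_fusion(x_l, x_s, p, dim):
    """Concatenate both branches' tokens and run one joint (single-head) self-attention.

    x_l: (1 + N_l, C_l), x_s: (1 + N_s, C_s).
    p["f_l"], p["f_s"] project each branch to a shared width `dim`;
    p["g_l"], p["g_s"] project back; p["w_q"/"w_k"/"w_v"] are (dim, dim).
    """
    y = np.concatenate([x_l @ p["f_l"], x_s @ p["f_s"]], axis=0)
    h = layer_norm(y)
    q, k, v = h @ p["w_q"], h @ p["w_k"], h @ p["w_v"]
    attn = softmax(q @ k.T / np.sqrt(dim))      # (T, T) with T = total tokens
    o = y + attn @ v                            # residual connection
    o_l, o_s = o[: len(x_l)], o[len(x_l):]      # split back per branch
    return o_l @ p["g_l"], o_s @ p["g_s"], attn

rng = np.random.default_rng(0)
c_l, c_s, dim = 192, 96, 128                    # illustrative widths, not the paper's
x_l = rng.normal(size=(1 + 196, c_l))
x_s = rng.normal(size=(1 + 784, c_s))
shapes = {"f_l": (c_l, dim), "f_s": (c_s, dim), "g_l": (dim, c_l), "g_s": (dim, c_s),
          "w_q": (dim, dim), "w_k": (dim, dim), "w_v": (dim, dim)}
p = {name: rng.normal(size=s) / np.sqrt(s[0]) for name, s in shapes.items()}
z_l, z_s, attn = all_attention_fusion(x_l, x_s, p, dim)
print(z_l.shape, z_s.shape, attn.shape)  # (197, 192) (785, 96) (982, 982)
```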
2.2. (b) Class Token Fusion
- The CLS token can be considered as an abstract global feature representation of a branch.
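Thus, a simple approach is to sum the CLS tokens of the two branches. More formally, the output $z^i$ of this fusion module can be represented as:
$$z^i = \left[ g^i\Big( \sum_{j \in \{l, s\}} f^j(x^j_{cls}) \Big) \,\|\, x^i_{patch} \right]$$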
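Class token fusion takes only a few lines of NumPy. In this sketch, `f` and `g` are hypothetical dictionaries of random linear maps standing in for the learned projection and back-projection; only the CLS tokens (row 0) are mixed, and each branch keeps its own patch tokens unchanged.

```python
import numpy as np

def class_token_fusion(x_l, x_s, f, g):
    """Sum the projected CLS tokens of both branches and give each branch
    the back-projected sum, keeping its own patch tokens unchanged.

    x_l, x_s: (1 + N_i, C_i) with the CLS token at row 0.
    f[i]: (C_i, D) projection; g[i]: (D, C_i) back-projection.
    """
    fused = x_l[:1] @ f["l"] + x_s[:1] @ f["s"]             # (1, D)
    z_l = np.concatenate([fused @ g["l"], x_l[1:]], axis=0)
    z_s = np.concatenate([fused @ g["s"], x_s[1:]], axis=0)
    return z_l, z_s

rng = np.random.default_rng(0)
c_l, c_s, d = 192, 96, 128                                   # illustrative widths
x_l = rng.normal(size=(197, c_l))
x_s = rng.normal(size=(785, c_s))
f = {"l": rng.normal(size=(c_l, d)), "s": rng.normal(size=(c_s, d))}
g = {"l": rng.normal(size=(d, c_l)), "s": rng.normal(size=(d, c_s))}
z_l, z_s = class_token_fusion(x_l, x_s, f, g)
print(z_l.shape, z_s.shape)  # (197, 192) (785, 96)
```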
2.3. (c) Pairwise Fusion
- Since each patch token corresponds to its own spatial location in the image, a simple heuristic is to fuse the tokens of the two branches according to their spatial locations.
However, the two branches process patches of different sizes and thus have different numbers of patch tokens. This method therefore first interpolates the patch tokens to align their spatial sizes and then fuses them pairwise, while the two CLS tokens are fused separately.
- The output $z^i$ of pairwise fusion of branches $l$ and $s$ can be expressed as:
$$z^i = \left[ g^i\Big( \sum_{j \in \{l, s\}} f^j(x^j_{cls}) \Big) \,\|\, g^i\Big( \sum_{j \in \{l, s\}} f^j(x^j_{patch}) \Big) \right]$$
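Pairwise fusion can be sketched as below, assuming square patch grids and nearest-neighbour interpolation (a stand-in chosen here for brevity; the text does not specify which interpolation is used). The function names and the `f`/`g` dictionaries of random linear maps are hypothetical. Because nearest-neighbour selection and a per-token linear projection commute, resizing before or after $f^j$ gives the same result.

```python
import math
import numpy as np

def resize_grid_nearest(tokens, out_side):
    """Nearest-neighbour resize of a (side*side, C) token grid to (out_side*out_side, C)."""
    side = math.isqrt(len(tokens))
    grid = tokens.reshape(side, side, -1)
    idx = (np.arange(out_side) * side) // out_side   # source row/col for each target cell
    return grid[idx][:, idx].reshape(out_side * out_side, -1)

def pairwise_fusion(x_l, x_s, f, g):
    """Fuse the CLS tokens together and the patch tokens by spatial location."""
    side_l, side_s = math.isqrt(len(x_l) - 1), math.isqrt(len(x_s) - 1)
    cls = x_l[:1] @ f["l"] + x_s[:1] @ f["s"]                       # (1, D)
    # Align the S-branch grid to the L-branch grid for branch l, and vice versa.
    patch_for_l = x_l[1:] @ f["l"] + resize_grid_nearest(x_s[1:], side_l) @ f["s"]
    patch_for_s = resize_grid_nearest(x_l[1:], side_s) @ f["l"] + x_s[1:] @ f["s"]
    # Back-project every fused token to each branch's own width.
    z_l = np.concatenate([cls, patch_for_l], axis=0) @ g["l"]
    z_s = np.concatenate([cls, patch_for_s], axis=0) @ g["s"]
    return z_l, z_s

rng = np.random.default_rng(0)
c_l, c_s, d = 192, 96, 128                                          # illustrative widths
x_l = rng.normal(size=(1 + 14 * 14, c_l))
x_s = rng.normal(size=(1 + 28 * 28, c_s))
f = {"l": rng.normal(size=(c_l, d)), "s": rng.normal(size=(c_s, d))}
g = {"l": rng.normal(size=(d, c_l)), "s": rng.normal(size=(d, c_s))}
z_l, z_s = pairwise_fusion(x_l, x_s, f, g)
print(z_l.shape, z_s.shape)  # (197, 192) (785, 96)
```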
2.4. (d) Proposed Cross-Attention Fusion:
The fusion involves the CLS token of one branch and patch tokens of the other branch.
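- The CLS token of each branch serves as an agent that exchanges information with the patch tokens of the other branch and then projects the result back to its own branch.
- Specifically, for branch $l$, it first collects the patch tokens from the S-Branch and concatenates its own CLS token to them, as shown above.
- Mathematically, the Cross-Attention (CA) can be expressed as:
$$x'^l = \left[ f^l(x^l_{cls}) \,\|\, x^s_{patch} \right]$$
$$q = x'^l_{cls} W_q, \quad k = x'^l W_k, \quad v = x'^l W_v$$
$$A = \mathrm{softmax}\left( q k^{T} / \sqrt{C/h} \right), \quad \mathrm{CA}(x'^l) = A v$$
- where $W_q, W_k, W_v \in \mathbb{R}^{C \times (C/h)}$ are learnable parameters, and $C$ and $h$ are the embedding dimension and the number of heads.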
Since only the CLS token is used as the query, the computation and memory complexity of generating the attention map $A$ in cross-attention are linear in the number of tokens rather than quadratic as in all-attention.
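To make the linear-versus-quadratic claim concrete, this pure-Python sketch counts attention-map entries for one fusion step. It assumes one CLS token per branch and, as an illustration, 196 L-branch and 784 S-branch patch tokens (a 224 x 224 image with patch sizes 16 and 8). For branch $l$, the single CLS query attends to its own CLS plus the $N_s$ S-branch patch tokens, and symmetrically for branch $s$. The function name is hypothetical.

```python
def attention_map_entries(n_l: int, n_s: int) -> dict:
    """Entries in the attention map(s) per fusion step, with one CLS token per branch.

    n_l, n_s: number of patch tokens in the L- and S-branch.
    """
    total = (1 + n_l) + (1 + n_s)
    # All-attention: every token of both branches attends to every token.
    all_attention = total * total
    # Cross-attention: per branch, a single CLS query attends to its own CLS
    # plus the other branch's patch tokens, i.e. a 1 x (1 + N_other) map.
    cross_attention = (1 + n_s) + (1 + n_l)
    return {"all_attention": all_attention, "cross_attention": cross_attention}

print(attention_map_entries(196, 784))
# {'all_attention': 964324, 'cross_attention': 982}
```

Doubling the number of patch tokens in both branches roughly quadruples the all-attention count but only doubles the cross-attention count.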
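- Multiple heads are used, i.e. multi-head cross-attention (MCA). With layer normalization and a residual shortcut, the output of the cross-attention module for branch $l$ is:
$$y^l_{cls} = f^l(x^l_{cls}) + \mathrm{MCA}\left( \mathrm{LN}\left( \left[ f^l(x^l_{cls}) \,\|\, x^s_{patch} \right] \right) \right)$$
$$z^l = \left[ g^l(y^l_{cls}) \,\|\, x^l_{patch} \right]$$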
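Multi-head cross-attention fusion for the L-branch can be sketched in NumPy under stated assumptions: random weights stand in for learned ones, LayerNorm has no learnable parameters, and an output projection `w_o` is applied after concatenating the heads, as in standard multi-head attention (an assumption, not something stated here). The function name and dictionary keys are hypothetical. Calling it with the branches and projection shapes swapped gives the S-branch update.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x, eps=1e-6):
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)

def cross_attention_fusion(x_l, x_s, p, num_heads):
    """Cross-attention fusion for the L-branch (CLS token at row 0 of each input).

    x_l: (1 + N_l, C_l), x_s: (1 + N_s, C_s).
    p["f"]: (C_l, C_s) projects the L-branch CLS to the S-branch width;
    p["g"]: (C_s, C_l) projects it back; p["w_q"/"w_k"/"w_v"/"w_o"]: (C_s, C_s).
    """
    cls = x_l[:1] @ p["f"]                                   # f^l(x^l_cls): (1, C_s)
    x_cat = np.concatenate([cls, x_s[1:]], axis=0)           # x'^l: (1 + N_s, C_s)
    h = layer_norm(x_cat)
    c = h.shape[-1]
    d = c // num_heads                                       # per-head width C / h
    # Only the CLS row is used as the query; all rows provide keys and values.
    q = (h[:1] @ p["w_q"]).reshape(1, num_heads, d).transpose(1, 0, 2)   # (H, 1, d)
    k = (h @ p["w_k"]).reshape(-1, num_heads, d).transpose(1, 0, 2)      # (H, 1+N_s, d)
    v = (h @ p["w_v"]).reshape(-1, num_heads, d).transpose(1, 0, 2)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))    # (H, 1, 1 + N_s): linear size
    heads = (attn @ v).transpose(1, 0, 2).reshape(1, c)      # concatenate heads: (1, C_s)
    y_cls = cls + heads @ p["w_o"]                           # residual on the CLS token
    z_l = np.concatenate([y_cls @ p["g"], x_l[1:]], axis=0)  # back-project, keep own patches
    return z_l, attn

rng = np.random.default_rng(0)
c_l, c_s, heads = 192, 96, 3                                 # illustrative sizes
x_l = rng.normal(size=(197, c_l))
x_s = rng.normal(size=(785, c_s))
shapes = {"f": (c_l, c_s), "g": (c_s, c_l), "w_q": (c_s, c_s), "w_k": (c_s, c_s),
          "w_v": (c_s, c_s), "w_o": (c_s, c_s)}
p = {n: rng.normal(size=s) / np.sqrt(s[0]) for n, s in shapes.items()}
z_l, attn = cross_attention_fusion(x_l, x_s, p, heads)
print(z_l.shape, attn.shape)  # (197, 192) (3, 1, 785)
```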
3. Model Variants
- CrossViT-Ti, CrossViT-S and CrossViT-B set their large (primary) branches identical to the tiny (DeiT-Ti), small (DeiT-S) and base (DeiT-B) models introduced in DeiT, respectively.
- The other variants differ in the FFN expansion ratio (r), depths and embedding dimensions.
- For example, CrossViT-15 has 3 multi-scale encoders, each of which includes 5 regular Transformers, resulting in a total of 15 Transformer encoders.
- A suffix † means that the linear patch embedding is replaced by three convolutional layers as the patch tokenizer, similar to ViT.
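A convolutional tokenizer keeps the same token grid as the linear patch embedding as long as its strides multiply to the patch size. The kernel sizes, strides and paddings below are hypothetical, chosen only to illustrate the check; the sketch applies the standard convolution output-size formula floor((H + 2p - k) / s) + 1 layer by layer. Function names are hypothetical.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output spatial size of a 2-D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def tokenizer_grid(image_size, layers):
    """Apply a stack of (kernel, stride, padding) convolutions and return the final grid side."""
    size = image_size
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size

# Hypothetical 3-layer tokenizer whose strides multiply to 16 (4 * 2 * 2),
# so it yields the same 14 x 14 grid as a linear 16 x 16 patch embedding.
three_convs = [(7, 4, 3), (3, 2, 1), (3, 2, 1)]
print(tokenizer_grid(224, three_convs), 224 // 16)  # 14 14
```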
4. Experimental Results
4.1. Comparison with DeiT
4.2. SOTA Comparisons
Left: When compared with ViT-B, CrossViT-18† significantly outperforms it by 4.9% in accuracy (82.8% vs 77.9%) while requiring about 50% fewer FLOPs and parameters. Furthermore, CrossViT-18† performs as well as TNT-B and better than the others, while also having fewer FLOPs and parameters.
Right: When compared to the ResNet family, including ResNet, ResNeXt, SENet, ECA-ResNet and RegNet, CrossViT-15 outperforms all of them in accuracy while being smaller and running more efficiently (except ResNet-101, which is slightly faster). In addition, the best models such as CrossViT-15† and CrossViT-18†, when evaluated at higher image resolution, are encouragingly competitive against EfficientNet.
4.3. Transfer Learning
Although it outperforms DeiT on ImageNet1K, CrossViT is on par with DeiT models on all the downstream classification tasks. This result suggests that the proposed models retain good generalization ability rather than merely fitting ImageNet1K.
4.4. Ablation Studies
- Left: The proposed cross-attention fusion achieves the best accuracy with only a minor increase in FLOPs and parameters.
- Right: Two pairs of patch sizes, (8, 16) and (12, 16), are compared, and the (12, 16) pair achieves better accuracy with fewer FLOPs, as shown in Model (A).
- Different channel widths and depths for the S-branch are also tried: Models (B) and (C) increase FLOPs and parameters without any improvement in accuracy.
- Stacking more cross-attention modules (L) or more multi-scale Transformer encoders (K) is tested as Models (D) and (E). Fusing the branches too frequently does not improve performance but adds FLOPs and parameters. Likewise, using more multi-scale Transformer encoders does not help performance either.
4.5. Further Studies
- Without CLS tokens, the model achieves 80.0% accuracy, which is 1% worse than CrossViT-S (81.0%) on ImageNet1K, showing the effectiveness of the CLS token in summarizing information.
- Using the T2T module to replace the linear patch embedding in both branches of CrossViT-18, the resulting CrossViT-18+T2T achieves a top-1 accuracy of 83.0% on ImageNet1K, an additional 0.5% improvement over CrossViT-18.