Review — Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
Medical Transformer (MedT), Modifies the Axial-DeepLab Axial-Attention Block
Medical Transformer: Gated Axial-Attention for Medical Image Segmentation,
Medical Transformer (MedT), by Johns Hopkins University, and The State University of New Jersey
2021 MICCAI, Over 400 Citations (Sik-Ho Tsang @ Medium)
- Convolutional architectures lack an understanding of long-range dependencies in the image, while Transformers are difficult to train effectively on medical imaging tasks because the number of data samples is relatively low.
- In this paper, a gated axial-attention model, Medical Transformer (MedT), is proposed, which introduces an additional control mechanism in the self-attention module.
- A Local-Global training strategy (LoGo) is also proposed to learn global and local features using the whole image and image patches, respectively.
Outline
- Medical Transformer (MedT)
- Results
1. Medical Transformer (MedT)
1.1. Conventional Self-Attention
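- Given an input feature map x with height H and width W, the output y of a self-attention layer at position (i, j) is:
$$y_{ij} = \sum_{h=1}^{H} \sum_{w=1}^{W} \mathrm{softmax}\left(q_{ij}^{T} k_{hw}\right) v_{hw}$$
- where queries q = W_Q x, keys k = W_K x and values v = W_V x are all projections computed from the input x, and W_Q, W_K, W_V are learnable weight matrices.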
However, the self-attention layer does not use any positional information, even though positional information is often useful in vision models for capturing the structure of an object.
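To make the formula concrete, here is a minimal NumPy sketch of this position-agnostic self-attention on a single (C, H, W) feature map. The function names and weight shapes are illustrative assumptions; batching, multiple heads and output projections are omitted.

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax along one axis
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_2d(x, W_q, W_k, W_v):
    """Plain (position-agnostic) self-attention over a feature map.

    x: (C_in, H, W); W_q, W_k: (C_qk, C_in); W_v: (C_out, C_in).
    Returns y: (C_out, H, W) with y_ij = sum_{h,w} softmax(q_ij^T k_hw) v_hw.
    """
    C_in, H, W = x.shape
    flat = x.reshape(C_in, H * W)        # every pixel becomes one token
    q = W_q @ flat                       # (C_qk, HW)
    k = W_k @ flat                       # (C_qk, HW)
    v = W_v @ flat                       # (C_out, HW)
    affinity = q.T @ k                   # (HW, HW): one score per (query, key) pair
    attn = softmax(affinity, axis=-1)    # normalize over all key positions (h, w)
    y = v @ attn.T                       # weighted sum of values for each query
    return y.reshape(-1, H, W)
```

Note the (HW)×(HW) affinity matrix: every pixel scores every other pixel. Also, because nothing in the computation depends on where a pixel sits, shuffling the spatial positions of x just shuffles the output in the same way, which is exactly the lack of positional information described in the paragraph.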
1.2. Axial-Attention
- As in [6], axial-attention decomposes self-attention into two self-attention modules that are applied consecutively, first along the height axis and then along the width axis. Each position then attends to H + W positions instead of H × W, which effectively models the original self-attention mechanism with much better computational efficiency.
Though [6] was rejected at ICLR 2020, axial-attention reduces the number of parameters, which can help when training the Transformer.
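A minimal NumPy sketch of the decomposition, assuming single-head attention without positional terms: `attend_along_width` lets each position attend only to the positions in its own row, and the height pass reuses it on the transposed feature map. `affinity_counts` compares how many query-key scores one layer computes. All names here are illustrative, not taken from any released code.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attend_along_width(x, W_q, W_k, W_v):
    """Self-attention restricted to rows: position (i, j) attends only to (i, w)."""
    q = np.einsum("oc,chw->ohw", W_q, x)        # (C_qk, H, W)
    k = np.einsum("oc,chw->ohw", W_k, x)        # (C_qk, H, W)
    v = np.einsum("oc,chw->ohw", W_v, x)        # (C_out, H, W)
    affinity = np.einsum("cij,ciw->ijw", q, k)  # one W x W score matrix per row i
    attn = softmax(affinity, axis=-1)           # normalize over key columns w
    return np.einsum("ijw,ciw->cij", attn, v)

def axial_attention(x, height_weights, width_weights):
    """Height-axis attention followed by width-axis attention.

    The height pass swaps the H and W axes so the row-wise routine can be reused;
    each *_weights argument is a (W_q, W_k, W_v) tuple.
    """
    y = attend_along_width(x.transpose(0, 2, 1), *height_weights).transpose(0, 2, 1)
    return attend_along_width(y, *width_weights)

def affinity_counts(H, W):
    full = (H * W) ** 2          # every position scores every position
    axial = H * W * (H + W)      # H keys in the height pass + W keys in the width pass
    return full, axial
```

For a 64×64 feature map, `affinity_counts(64, 64)` gives 16,777,216 scores for full self-attention versus 524,288 for the two axial passes, a 32× reduction. Stacking the two passes still lets information travel between any two pixels (first along a column, then along a row).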
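- Moreover, as in Shaw NAACL'18, relative positional encoding terms (r_{iw}) are added to make the affinities sensitive to positional information. Following Axial-DeepLab, the axial-attention along the width axis becomes:
$$y_{ij} = \sum_{w=1}^{W} \mathrm{softmax}\left(q_{ij}^{T} k_{iw} + q_{ij}^{T} r^{q}_{iw} + k_{iw}^{T} r^{k}_{iw}\right)\left(v_{iw} + r^{v}_{iw}\right)$$
- where r^q, r^k and r^v are the relative positional encodings for queries, keys and values; attention along the height axis is defined analogously.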
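A minimal NumPy sketch of width-axis axial-attention with relative positional terms for queries, keys and values. It assumes the encodings are stored in a table indexed by the offset w − j between key column w and query column j (the Axial-DeepLab convention); reading r_{iw} this way, the table layout, and the names `relative_table_lookup` and `width_attention_relpos` are illustrative assumptions.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def relative_table_lookup(table, W):
    """Expand a (C, 2W-1) per-offset table into (C, W, W); [:, j, w] is offset w - j."""
    offsets = np.arange(W)[None, :] - np.arange(W)[:, None]   # offsets[j, w] = w - j
    return table[:, offsets + W - 1]                          # shift into 0 .. 2W-2

def width_attention_relpos(x, W_q, W_k, W_v, rq_table, rk_table, rv_table):
    """Axial attention along W with relative positional encodings (single head).

    x: (C_in, H, W); rq_table, rk_table: (C_qk, 2W-1); rv_table: (C_out, 2W-1).
    """
    W = x.shape[2]
    q = np.einsum("oc,chw->ohw", W_q, x)        # (C_qk, H, W)
    k = np.einsum("oc,chw->ohw", W_k, x)        # (C_qk, H, W)
    v = np.einsum("oc,chw->ohw", W_v, x)        # (C_out, H, W)
    rq = relative_table_lookup(rq_table, W)     # (C_qk, W, W)
    rk = relative_table_lookup(rk_table, W)     # (C_qk, W, W)
    rv = relative_table_lookup(rv_table, W)     # (C_out, W, W)
    logits = (np.einsum("cij,ciw->ijw", q, k)   # content term  q_ij^T k_iw
              + np.einsum("cij,cjw->ijw", q, rq)  # query-position term q_ij^T r^q
              + np.einsum("ciw,cjw->ijw", k, rk)) # key-position term   k_iw^T r^k
    attn = softmax(logits, axis=-1)             # normalize over key columns w
    # Values carry their own positional term: sum_w attn * (v_iw + r^v)
    return (np.einsum("ijw,ciw->cij", attn, v)
            + np.einsum("ijw,cjw->cij", attn, rv))
```

Setting all three tables to zero recovers plain width-axis attention, so the positional terms are purely additive corrections to the content-based affinities and values.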
1.3. Proposed Gated Axial-Attention
- The axial-attention proposed in Axial-DeepLab computes non-local context with good computational efficiency and is able to encode positional bias into the mechanism.
The authors argue that, on small-scale datasets, the positional bias is difficult to learn and hence will not always be accurate in encoding long-range interactions.
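- A modified axial-attention block is proposed that can control the influence the positional bias exerts when encoding non-local context. Along the width axis:
$$y_{ij} = \sum_{w=1}^{W} \mathrm{softmax}\left(q_{ij}^{T} k_{iw} + G_{Q}\, q_{ij}^{T} r^{q}_{iw} + G_{K}\, k_{iw}^{T} r^{k}_{iw}\right)\left(G_{V1}\, v_{iw} + G_{V2}\, r^{v}_{iw}\right)$$
- where r^q, r^k and r^v are the relative positional encodings for queries, keys and values, and G_Q, G_K, G_V1, G_V2 are learnable parameters that together form a gating mechanism controlling how much influence these encodings have.
Typically, if a relative positional encoding is learned accurately, the gating mechanism assigns it a higher weight than the encodings that are not learned accurately.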
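A minimal NumPy sketch of gated axial-attention along the width axis, adding four scalar gates to the relative-position formulation. Treating the gates as plain scalar arguments (rather than trained parameters), the offset-indexed encoding tables, and all function names are illustrative assumptions; this is not the authors' implementation.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def relative_table_lookup(table, W):
    # (C, 2W-1) per-offset table -> (C, W, W); entry [:, j, w] is offset w - j
    offsets = np.arange(W)[None, :] - np.arange(W)[:, None]
    return table[:, offsets + W - 1]

def gated_width_attention(x, W_q, W_k, W_v, rq_table, rk_table, rv_table,
                          g_q, g_k, g_v1, g_v2):
    """Gated axial attention along W (single head).

    g_q, g_k, g_v1, g_v2 play the role of G_Q, G_K, G_V1, G_V2; in training
    they would be learnable, here they are fixed scalars for illustration.
    """
    W = x.shape[2]
    q = np.einsum("oc,chw->ohw", W_q, x)
    k = np.einsum("oc,chw->ohw", W_k, x)
    v = np.einsum("oc,chw->ohw", W_v, x)
    rq = relative_table_lookup(rq_table, W)
    rk = relative_table_lookup(rk_table, W)
    rv = relative_table_lookup(rv_table, W)
    logits = (np.einsum("cij,ciw->ijw", q, k)              # content term, not gated
              + g_q * np.einsum("cij,cjw->ijw", q, rq)     # G_Q * q_ij^T r^q
              + g_k * np.einsum("ciw,cjw->ijw", k, rk))    # G_K * k_iw^T r^k
    attn = softmax(logits, axis=-1)
    # Gated values: sum_w attn * (G_V1 * v_iw + G_V2 * r^v)
    return (g_v1 * np.einsum("ijw,ciw->cij", attn, v)
            + g_v2 * np.einsum("ijw,cjw->cij", attn, rv))
```

With g_q = g_k = g_v2 = 0 and g_v1 = 1 every positional term disappears and the block reduces to plain width-axis attention; intermediate gate values let the block scale down encodings that were learned poorly instead of trusting them fully.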
- According to authors’ GitHub, GQ, GK, GV1, GV2 are defined as:
1.4. Local-Global training strategy (LoGo)
- To improve the overall understanding of the image, two branches are used in the network, i.e., a global branch which works on the original resolution of the image, and a local branch which operates on patches of the image.
- In the global branch, fewer gated axial transformer layers are used, as the authors observe that the first few blocks of the proposed transformer model are sufficient to model long-range dependencies.
- In the local branch, the image is split into 16 patches of size I/4×I/4, where I is the dimension of the original image; each patch is passed through the local branch, and the resulting feature maps are placed back according to the patches' locations.
- The output feature maps of both branches are then added and passed through a 1×1 convolution layer to produce the output segmentation mask.
This strategy improves performance because the global branch focuses on high-level information while the local branch can focus on finer details.
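The LoGo data flow can be sketched in NumPy as follows. The two branches are hypothetical placeholder callables (any function mapping a (C, h, w) array to a (C_f, h, w) feature map of the same spatial size), standing in for the gated axial transformer branches; the 1×1 convolution is written as a per-pixel linear map over channels with no bias, which is also an assumption. All names are illustrative.

```python
import numpy as np

def split_into_patches(img, grid=4):
    """Split a (C, I, I) image into grid*grid non-overlapping patches of size I/grid."""
    C, H, W = img.shape
    assert H % grid == 0 and W % grid == 0, "image size must be divisible by grid"
    ph, pw = H // grid, W // grid
    # Row-major order: patch index = r * grid + c
    return [img[:, r * ph:(r + 1) * ph, c * pw:(c + 1) * pw]
            for r in range(grid) for c in range(grid)]

def stitch_patches(patches, grid=4):
    """Inverse of split_into_patches: put each patch (output) back at its location."""
    rows = [np.concatenate(patches[r * grid:(r + 1) * grid], axis=2) for r in range(grid)]
    return np.concatenate(rows, axis=1)

def logo_forward(img, global_branch, local_branch, W_1x1, grid=4):
    """Hypothetical LoGo fusion: global branch on the whole image, local branch
    on each of the grid*grid patches, element-wise sum, then a 1x1 convolution."""
    g = global_branch(img)                                            # (C_f, I, I)
    l = stitch_patches([local_branch(p) for p in split_into_patches(img, grid)], grid)
    fused = g + l                                                     # add both branches
    return np.einsum("oc,chw->ohw", W_1x1, fused)                    # 1x1 conv = channel mixing
```

With grid=4 this yields the 16 patches of size I/4×I/4. Because the local branch only ever sees one patch, its features cannot mix information across patch borders; that is left to the global branch.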
1.5. Loss Function
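- Binary cross-entropy (CE) loss between the prediction and the ground truth is used:
$$\mathcal{L}_{CE}(p, \hat{p}) = -\frac{1}{wh} \sum_{x=0}^{w-1} \sum_{y=0}^{h-1} \left( p(x,y) \log \hat{p}(x,y) + \left(1 - p(x,y)\right) \log\left(1 - \hat{p}(x,y)\right) \right)$$
- where w and h are the dimensions of the image, p(x, y) is the ground-truth label of pixel (x, y), and p̂(x, y) is the predicted probability at that pixel.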
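A minimal NumPy sketch of this pixel-wise binary cross-entropy for one w×h mask. The clipping constant `eps` is an assumption added to avoid log(0); it is not part of the formula.

```python
import numpy as np

def bce_loss(p, p_hat, eps=1e-7):
    """Binary cross-entropy averaged over all pixels of a w x h mask.

    p: ground-truth mask with values in {0, 1}; p_hat: predicted probabilities.
    """
    p_hat = np.clip(p_hat, eps, 1.0 - eps)            # keep log() finite
    per_pixel = p * np.log(p_hat) + (1.0 - p) * np.log(1.0 - p_hat)
    return -per_pixel.mean()                          # mean = (1 / (w h)) * double sum
```

For example, with `p = np.array([[1.0, 0.0], [0.0, 1.0]])` and `p_hat = np.array([[0.9, 0.2], [0.1, 0.6]])`, the loss is -(ln 0.9 + ln 0.8 + ln 0.9 + ln 0.6) / 4, about 0.236; the pixel predicted at 0.6 contributes the most.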
2. Results
- With the help of gated axial attention and LoGo, the proposed method is able to overcome this issue; each of the two components individually performs better than the other methods.
The final architecture, MedT, performs better than gated axial attention alone, LoGo alone, and all the previous methods.
As MedT takes into account pixel-wise dependencies that are encoded with the gating mechanism, it is able to learn those dependencies better than the axial attention U-Net. This makes the predictions more precise, as they do not misclassify pixels near the segmentation mask.