Review — CoTr: Efficiently Bridging CNN and Transformer for 3D Medical Image Segmentation
CoTr, Deformable Transformer Used In Between Encoder & Decoder
CoTr: Efficiently Bridging CNN and Transformer for 3D Medical Image Segmentation,
CoTr, by Northwestern Polytechnical University, and The University of Adelaide,
2021 MICCAI, Over 150 Citations (Sik-Ho Tsang @ Medium)
- CoTr, a hybrid of a Convolutional neural network and a Transformer, is proposed: the CNN extracts feature representations, and an efficient deformable Transformer (DeTrans) models the long-range dependency on the extracted feature maps.
- Unlike the vanilla Transformer, which treats all image positions equally, DeTrans attends only to a small set of key positions by introducing the deformable self-attention mechanism.
Outline
- CoTr Design Difference
- CoTr Model Architecture
- Results
1. CoTr Design Difference
- (a) CNN: encoder is composed of multiple stacked convolutional layers.
- (b) SETR: encoder is purely formed from self-attention layers, i.e., Transformer.
- (c) TransUNet and (d) CoTr encoders are both the hybrid of CNN and Transformer.
- In contrast, TransUNet only processes the low-resolution feature maps from the last stage because of the high computational and spatial complexity of self-attention.
CoTr, on the other hand, is able to process multi-scale, high-resolution feature maps.
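To make the complexity argument concrete, the sketch below counts the tokens produced by flattening 3D feature maps and the number of query-key pairs that vanilla self-attention would score. The three feature-map sizes are hypothetical examples, not the paper's exact configuration; the deformable count uses H = 6 and K = 4 from the details section and assumes every query samples K points per level per head.

```python
# Rough cost comparison: full self-attention vs. deformable attention on 3D maps.
# Feature-map sizes below are hypothetical examples.

def num_tokens(shape):
    """Number of tokens after flattening a (D, H, W) feature map."""
    d, h, w = shape
    return d * h * w

def full_attention_pairs(n_tokens):
    """Query-key pairs scored by vanilla self-attention over n_tokens positions."""
    return n_tokens * n_tokens

def deformable_samples(n_tokens, n_levels, n_heads, n_points):
    """Sampled key locations when each query attends to n_points per level per head."""
    return n_tokens * n_levels * n_heads * n_points

# Hypothetical multi-scale feature maps (depth, height, width), finest first.
levels = [(24, 48, 48), (12, 24, 24), (6, 12, 12)]

n_multi = sum(num_tokens(s) for s in levels)
n_last = num_tokens(levels[-1])

print("last stage only :", n_last, "tokens,", full_attention_pairs(n_last), "pairs")
print("all three stages:", n_multi, "tokens,", full_attention_pairs(n_multi), "pairs")
print("deformable (H=6, K=4):", deformable_samples(n_multi, len(levels), 6, 4), "samples")
```

With these sizes, full attention over all three levels scores roughly four billion pairs, about 5,000 times more than attention on the last stage alone, while the deformable variant samples only a few million locations.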
2. CoTr Model Architecture
2.1. CNN-encoder
- The CNN-encoder, FCNN(·), contains a Conv-IN-LeakyReLU block and three stages of 3D residual blocks.
- The Conv-IN-LeakyReLU block contains a 3D convolutional layer followed by an instance normalization (IN) [18] and leaky rectified linear unit (LeakyReLU) activation.
- The numbers of 3D residual blocks in three stages are three, three, and two, respectively.
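Given an input image x with a height of H, a width of W, and a depth (i.e., number of slices) of D, the feature maps produced by FCNN(·) can be formally expressed as:

$$\{f_l\}_{l=1}^{L} = \mathcal{F}^{CNN}(x; \Theta) \in \mathbb{R}^{C \times \frac{D}{2^{l}} \times \frac{H}{2^{l+1}} \times \frac{W}{2^{l+1}}}$$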
- where L indicates the number of feature levels, Θ denotes the parameters of the CNN-encoder, and C denotes the number of channels.
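The stage layout above can be sketched as simple shape bookkeeping. This is a hypothetical sketch: the text does not state the strides, so it assumes the Conv-IN-LeakyReLU stem halves only H and W and that each residual stage halves D, H, and W. The 48 × 192 × 192 input size is also just an example.

```python
# Sketch of CNN-encoder bookkeeping: residual blocks per stage and output shapes.
# ASSUMPTION: the stem halves H and W only, and each stage halves D, H and W.
from typing import List, Tuple

BLOCKS_PER_STAGE = (3, 3, 2)  # CoTr: three, three and two 3D residual blocks

def stem_shape(d: int, h: int, w: int) -> Tuple[int, int, int]:
    """Conv-IN-LeakyReLU stem with an assumed stride of (1, 2, 2)."""
    return d, h // 2, w // 2

def encoder_shapes(d: int, h: int, w: int) -> List[Tuple[int, int, int]]:
    """Spatial size (D_l, H_l, W_l) of feature map f_l after each stage l."""
    shapes = []
    d, h, w = stem_shape(d, h, w)
    for _ in BLOCKS_PER_STAGE:
        # Assumed: the first residual block of every stage uses stride 2.
        d, h, w = d // 2, h // 2, w // 2
        shapes.append((d, h, w))
    return shapes

for level, shape in enumerate(encoder_shapes(48, 192, 192), start=1):
    print(f"f_{level}: {shape}")
```

Under these assumptions, level l has size D/2^l × H/2^(l+1) × W/2^(l+1); only the number of stages, not the number of blocks inside each stage, affects the output shape.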
2.2. DeTrans-Encoder
- Since the CNN-encoder cannot effectively capture the long-range dependency between pixels, the DeTrans-encoder introduces the multi-scale deformable self-attention (MS-DMSA) mechanism for efficient long-range contextual modeling.
- The DeTrans-encoder is a composition of an input-to-sequence layer and LD stacked deformable Transformer (DeTrans) layers.
2.2.1. Input-to-Sequence Transformation
- The feature maps $\{f_l\}$, for $l = 1$ to $L$, are flattened into a 1D sequence. However, flattening the features discards their spatial information.
To retain it, a 3D positional encoding built from sine and cosine functions of different frequencies is added to the flattened features. For each dimension, the encoding is computed from the positional coordinate $pos$ as:

$$PE_{\#}(pos, 2k) = \sin(pos \cdot v_k), \qquad PE_{\#}(pos, 2k+1) = \cos(pos \cdot v_k)$$

- where $\# \in \{D, H, W\}$ indicates each of the three dimensions, and:

$$v_k = \frac{1}{10000^{2k/(C/3)}}$$

For each feature level $l$, $PE_D$, $PE_H$, and $PE_W$ are concatenated to form the 3D positional encoding $p_l$, which is combined with the flattened $f_l$ via element-wise summation to form the input sequence of the DeTrans-encoder.
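A minimal NumPy sketch of the input-to-sequence step, assuming the usual Transformer sinusoidal recipe: each of PE_D, PE_H, and PE_W uses C/3 channels, even channels hold sin(pos · v_k), odd channels hold cos(pos · v_k), and v_k = 1/10000^(2k/(C/3)). The feature-map size is a hypothetical example and the features are random.

```python
import numpy as np

def sincos_1d(positions: np.ndarray, dim: int) -> np.ndarray:
    """1D encoding: even channels sin(pos * v_k), odd channels cos(pos * v_k)."""
    assert dim % 2 == 0, "each axis needs an even number of channels"
    k = np.arange(dim // 2)
    v = 1.0 / (10000 ** (2 * k / dim))          # v_k = 1 / 10000^(2k / dim)
    angles = positions[:, None] * v[None, :]    # (num_positions, dim / 2)
    pe = np.zeros((len(positions), dim))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

def positional_encoding_3d(d: int, h: int, w: int, channels: int) -> np.ndarray:
    """Concatenate PE_D, PE_H, PE_W (channels/3 each) per voxel; shape (d*h*w, channels)."""
    assert channels % 3 == 0
    c = channels // 3
    pe_d = sincos_1d(np.arange(d, dtype=float), c)
    pe_h = sincos_1d(np.arange(h, dtype=float), c)
    pe_w = sincos_1d(np.arange(w, dtype=float), c)
    # Broadcast each axis encoding over the (d, h, w) grid, then concatenate channels.
    grid = np.concatenate([
        np.broadcast_to(pe_d[:, None, None, :], (d, h, w, c)),
        np.broadcast_to(pe_h[None, :, None, :], (d, h, w, c)),
        np.broadcast_to(pe_w[None, None, :, :], (d, h, w, c)),
    ], axis=-1)
    return grid.reshape(d * h * w, channels)  # same raster order as flattening f_l

# Flatten a (C, D, H, W) feature map into a (D*H*W, C) sequence and add the encoding.
C, D, H, W = 384, 6, 12, 12
f_l = np.random.randn(C, D, H, W)
seq = f_l.reshape(C, -1).T + positional_encoding_3d(D, H, W, C)
print(seq.shape)  # (864, 384)
```

The key detail is that the encoding is reshaped in the same depth-height-width order used to flatten the feature map, so every token receives the encoding of its own voxel.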
2.2.2. MS-DMSA Layer
- The self-attention layer in the conventional Transformer attends to all possible locations in the feature map, which leads to slow convergence and high computational complexity.
- (It is worth reading Deformable DETR before this part.)
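Following Deformable DETR, the MS-DMSA layer focuses only on a small set of key sampling locations around a reference location, instead of all locations. Given a query feature $z_q$ and its normalized 3D reference point $\hat{p}_q$, the $i$-th attention head is computed as:

$$\text{head}_i = \sum_{l=1}^{L} \sum_{k=1}^{K} \Lambda(z_q)_{ilqk} \cdot \Psi(f_l)\big(\sigma_l(\hat{p}_q) + \Delta p_{ilqk}\big)$$

- where $K$ is the number of sampled key points, $\Psi(\cdot)$ is a linear projection layer, $\Lambda(z_q)_{ilqk} \in [0, 1]$ is the attention weight, $\Delta p_{ilqk} \in \mathbb{R}^3$ is the sampling offset of the $k$-th sampling point in the $l$-th feature level, and $\sigma_l(\cdot)$ re-scales $\hat{p}_q$ to the $l$-th level feature.
- $\Lambda(z_q)_{ilqk}$ and $\Delta p_{ilqk}$ are obtained via linear projection over the query feature $z_q$.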
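- Then, the MS-DMSA layer can be formulated as:

$$\text{MS-DMSA}\big(z_q, \hat{p}_q, \{f_l\}_{l=1}^{L}\big) = \Phi\big(\text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_H)\big)$$

- where $\text{head}_i$ is the output of the $i$-th attention head, $H$ is the number of attention heads, and $\Phi(\cdot)$ is a linear projection layer that weights and aggregates the feature representations of all attention heads.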
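A minimal NumPy sketch of a single-query MS-DMSA step, following the per-head weighted sampling and the final Φ projection described above. Several details are assumptions, not taken from the review: attention weights are normalized with a softmax over all L × K samples of a head (as in Deformable DETR), offsets are in voxel units of each level, σ_l maps [0, 1] coordinates onto the level's voxel grid, Ψ and Φ are plain matrices, and sampling is trilinear with zeros outside the volume. The weight names W_off, W_attn, W_val, and W_out are hypothetical.

```python
import numpy as np

def softmax(a):
    """Numerically stable softmax over a 1D array."""
    e = np.exp(a - a.max())
    return e / e.sum()

def trilinear_sample(vol, p):
    """Sample vol (C, D, H, W) at fractional voxel coordinate p = (z, y, x); zeros outside."""
    _, D, H, W = vol.shape
    z, y, x = p
    z0, y0, x0 = int(np.floor(z)), int(np.floor(y)), int(np.floor(x))
    out = np.zeros(vol.shape[0])
    for zi in (z0, z0 + 1):
        for yi in (y0, y0 + 1):
            for xi in (x0, x0 + 1):
                if 0 <= zi < D and 0 <= yi < H and 0 <= xi < W:
                    wgt = (1 - abs(z - zi)) * (1 - abs(y - yi)) * (1 - abs(x - xi))
                    out += wgt * vol[:, zi, yi, xi]
    return out

def ms_dmsa(z_q, p_hat, feats, params, n_heads, n_points):
    """Multi-scale deformable self-attention for one query (hypothetical parameter names).

    z_q: (C,) query feature; p_hat: (3,) normalized reference point in [0, 1]^3;
    feats: list of L feature maps, each (C, D_l, H_l, W_l).
    """
    C, L = z_q.shape[0], len(feats)
    c_h = C // n_heads
    # Sampling offsets and attention logits are linear projections of z_q.
    offsets = (params["W_off"] @ z_q).reshape(n_heads, L, n_points, 3)
    logits = (params["W_attn"] @ z_q).reshape(n_heads, L * n_points)
    # Psi(f_l): channel-wise linear projection of every level's feature map.
    values = [np.einsum("oc,cdhw->odhw", params["W_val"], f) for f in feats]
    heads = []
    for i in range(n_heads):
        lam = softmax(logits[i]).reshape(L, n_points)  # attention weights of head i
        head = np.zeros(c_h)
        for l, v in enumerate(values):
            # sigma_l: map the normalized reference point onto level l's voxel grid.
            ref = p_hat * (np.array(v.shape[1:]) - 1)
            v_i = v[i * c_h:(i + 1) * c_h]  # channels handled by head i
            for k in range(n_points):
                head += lam[l, k] * trilinear_sample(v_i, ref + offsets[i, l, k])
        heads.append(head)
    # Phi: linear projection over the concatenated heads.
    return params["W_out"] @ np.concatenate(heads)

rng = np.random.default_rng(0)
C, n_heads, n_points = 24, 6, 4
feats = [rng.standard_normal((C, 4, 8, 8)), rng.standard_normal((C, 2, 4, 4))]
L = len(feats)
params = {
    "W_off": 0.5 * rng.standard_normal((n_heads * L * n_points * 3, C)),
    "W_attn": rng.standard_normal((n_heads * L * n_points, C)),
    "W_val": rng.standard_normal((C, C)) / np.sqrt(C),
    "W_out": rng.standard_normal((C, C)) / np.sqrt(C),
}
out = ms_dmsa(rng.standard_normal(C), np.array([0.5, 0.5, 0.5]), feats, params, n_heads, n_points)
print(out.shape)
```

Each query touches only H × L × K interpolated samples instead of every voxel of every level, which is what makes attending over multi-scale, high-resolution maps affordable.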
2.2.3. DeTrans Layer
- The DeTrans layer is composed of an MS-DMSA layer and a feed-forward network, each followed by layer normalization.
- A skip connection is employed around each sub-layer to avoid vanishing gradients.
- The DeTrans-encoder is constructed by repeatedly stacking DeTrans layers.
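Reading "each sub-layer followed by layer norm, with a skip connection" as a post-norm residual block, a DeTrans layer and the stacked encoder can be sketched as below. The attention callable is a hypothetical placeholder for MS-DMSA (a fixed random linear map that ignores reference points), LayerNorm has no learnable scale or shift, and the ReLU in the feed-forward network is an assumed activation; only the 384/1536 sizes and L_D = 6 come from the details section.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """LayerNorm over the channel axis (no learnable scale/shift, for brevity)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def feed_forward(x, w1, b1, w2, b2):
    """Two-layer FFN (384 -> 1536 -> 384 in CoTr); ReLU is an assumed activation."""
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2

def detrans_layer(x, attention, ffn_params):
    """Post-norm block: skip connection around each sub-layer, then LayerNorm."""
    x = layer_norm(x + attention(x))
    x = layer_norm(x + feed_forward(x, *ffn_params))
    return x

def detrans_encoder(x, layers):
    """Stack L_D DeTrans layers; each entry is (attention_fn, ffn_params)."""
    for attention, ffn_params in layers:
        x = detrans_layer(x, attention, ffn_params)
    return x

def make_placeholder_attention(rng, c):
    """Hypothetical stand-in for MS-DMSA: a fixed random token-wise linear map."""
    w = rng.standard_normal((c, c)) / np.sqrt(c)
    return lambda x: x @ w

rng = np.random.default_rng(0)
C, HIDDEN, L_D, N_TOKENS = 384, 1536, 6, 10
layers = [
    (
        make_placeholder_attention(rng, C),
        (rng.standard_normal((C, HIDDEN)) / np.sqrt(C), np.zeros(HIDDEN),
         rng.standard_normal((HIDDEN, C)) / np.sqrt(HIDDEN), np.zeros(C)),
    )
    for _ in range(L_D)
]
out = detrans_encoder(rng.standard_normal((N_TOKENS, C)), layers)
print(out.shape)  # (10, 384)
```

Because the normalization comes after each residual sum, every token leaving a layer has zero mean and roughly unit variance across its channels.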
2.3. Decoder
- The decoder is a pure CNN architecture, which progressively upsamples the feature maps to the input resolution.
- In addition, skip connections between the encoder and decoder are added to keep more low-level details for better segmentation.
- A deep supervision strategy is used by adding auxiliary losses to the decoder outputs at different scales.
- The loss function is the sum of the Dice loss and cross-entropy loss.
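A minimal NumPy sketch of the training objective, assuming a soft Dice loss averaged over classes plus voxel-wise cross-entropy, summed over decoder outputs. The per-scale weights (1.0, 0.5, 0.25) and the strided label downsampling are hypothetical choices for illustration, not values reported in the review.

```python
import numpy as np

def softmax(logits, axis=0):
    """Numerically stable softmax along the class axis."""
    e = np.exp(logits - logits.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dice_ce_loss(logits, labels, eps=1e-5):
    """Soft Dice loss + cross-entropy for logits (K, D, H, W) and int labels (D, H, W)."""
    n_classes = logits.shape[0]
    probs = softmax(logits, axis=0)
    onehot = np.eye(n_classes)[labels].transpose(3, 0, 1, 2)  # (K, D, H, W)
    ce = -np.mean(np.sum(onehot * np.log(probs + 1e-12), axis=0))
    inter = np.sum(probs * onehot, axis=(1, 2, 3))
    denom = np.sum(probs, axis=(1, 2, 3)) + np.sum(onehot, axis=(1, 2, 3))
    dice = np.mean((2 * inter + eps) / (denom + eps))  # mean Dice over classes
    return (1.0 - dice) + ce

def deep_supervision_loss(outputs, labels, weights):
    """Weighted sum of Dice + CE over decoder outputs at several scales (finest first)."""
    total = 0.0
    for out, w in zip(outputs, weights):
        # Hypothetical: nearest-neighbour label downsampling by striding.
        fz, fy, fx = (ls // os for ls, os in zip(labels.shape, out.shape[1:]))
        total += w * dice_ce_loss(out, labels[::fz, ::fy, ::fx])
    return total

rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=(8, 16, 16))
outputs = [rng.standard_normal((3, 8, 16, 16)),
           rng.standard_normal((3, 4, 8, 8)),
           rng.standard_normal((3, 2, 4, 4))]
loss = deep_supervision_loss(outputs, labels, weights=(1.0, 0.5, 0.25))
print(loss)
```

The auxiliary terms only add gradient signal to the coarser decoder stages; at inference time only the full-resolution output is used.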
2.4. Details
- The hidden sizes of the MS-DMSA layer and the feed-forward network are set to 384 and 1536, respectively.
- LD = 6, H = 6, and K = 4.
- In addition, two variants of CoTr with smaller CNN-encoders are built, denoted CoTr∗ and CoTr†.
- In CoTr∗, there is only one 3D residual block in each stage of CNN-encoder.
- In CoTr†, the number of 3D residual blocks in each stage of CNN-encoder is two.
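The reported hyper-parameters and the two smaller variants can be collected in a small configuration object. Field names are illustrative, not those of the official code; the head_dim property assumes channels are split evenly across heads, as in Deformable DETR.

```python
from dataclasses import dataclass, replace
from typing import Tuple

@dataclass(frozen=True)
class CoTrConfig:
    """Hyper-parameters reported for CoTr; field names are hypothetical."""
    blocks_per_stage: Tuple[int, int, int] = (3, 3, 2)
    hidden_size: int = 384   # MS-DMSA hidden size
    ffn_size: int = 1536     # feed-forward hidden size
    num_layers: int = 6      # L_D, stacked DeTrans layers
    num_heads: int = 6       # H, attention heads
    num_points: int = 4      # K, sampled key points

    @property
    def head_dim(self) -> int:
        # Assumes channels are split evenly across heads.
        assert self.hidden_size % self.num_heads == 0
        return self.hidden_size // self.num_heads

COTR = CoTrConfig()
COTR_STAR = replace(COTR, blocks_per_stage=(1, 1, 1))    # CoTr*: one block per stage
COTR_DAGGER = replace(COTR, blocks_per_stage=(2, 2, 2))  # CoTr†: two blocks per stage

for name, cfg in [("CoTr", COTR), ("CoTr*", COTR_STAR), ("CoTr†", COTR_DAGGER)]:
    print(name, cfg.blocks_per_stage, "head_dim =", cfg.head_dim)
```

The variants differ only in CNN-encoder depth; the DeTrans-encoder settings are shared, which is why CoTr∗ can be compared with 'CoTr w/o CNN-encoder' at a similar parameter count.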
3. Results
3.1. SOTA Comparison
- First, although the Transformer architecture is not limited by the type of input images, the ViT-B/16 pre-trained on 2D natural images does not work well on 3D medical images. The suboptimal performance may be attributed to the domain shift between 2D natural images and 3D medical images.
- Second, ‘CoTr w/o CNN-encoder’, with about 22M parameters, outperforms SETR, which has about 100M parameters. The authors believe that a lightweight Transformer may be better suited to medical image segmentation tasks, where training datasets are usually small.
- Third, CoTr∗, with a comparable number of parameters, significantly outperforms ‘CoTr w/o CNN-encoder’, improving the average Dice over 11 organs by 4%.
It suggests that the hybrid CNN-Transformer encoder has distinct advantages over the pure Transformer encoder in medical image segmentation.
- Keeping the same CNN-encoder and decoder, the DeTrans-encoder is replaced with ASPP (DeepLabv3), PP (PSPNet), and Non-Local modules, respectively.
CoTr consistently improves the segmentation performance over ‘CoTr w/o DeTrans’ on all organs, raising the average Dice by 1.4%.
- The original 2D TransUNet is extended to a 3D version by using a 3D CNN-encoder and decoder, as done in CoTr.
CoTr steadily beats TransUNet in the segmentation of all organs, particularly for the gallbladder and pancreas segmentation.
3.2. Ablation Study
- (a-c): In the DeTrans-encoder, there are three hyper-parameters, i.e., K, H, and LD, which represent the number of sampled key points, heads, and stacked DeTrans layers, respectively.
- To investigate the impact of their settings on the segmentation, K is set to 1, 2, and 4, H is set to 2, 4, and 6, and LD is set to 2, 4, and 6.
It shows that increasing K, H, or LD can improve the segmentation performance.
- (d): To demonstrate the performance gain resulting from the multi-scale strategy, CoTr is also trained with only the single-scale feature maps from the last stage.
- Using multi-scale feature maps instead of single-scale feature maps can effectively improve the average Dice by 1.2%.
The authors mention that CoTr could be extended to other tasks (e.g., brain structure or tumor segmentation) in the future.