Review — DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
DS-TransUNet, by Harbin Institute of Technology and The Chinese University of Hong Kong,
2022 TIM, Over 100 Citations (Sik-Ho Tsang @ Medium)
- The patch division used in existing Transformer-based models usually ignores the pixel-level intrinsic structural features inside each patch.
- Dual Swin Transformer U-Net (DS-TransUNet) is proposed, which incorporates the hierarchical Swin Transformer into U-Net.
- In the encoder in particular, dual-scale encoding is used to extract coarse-grained and fine-grained feature representations at different semantic scales.
- Meanwhile, a well-designed Transformer interactive fusion (TIF) module is proposed to effectively perform multi-scale information fusion.
1. Use of Swin Transformer
- (a) Conventional ViT: useful for modeling long-range dependency, but with quadratic computational complexity.
- (b) Swin Transformer: constrains self-attention within a local window so as to mitigate the complexity issue. Yet, cross-window correlation is lost due to the use of local windows.
- Thus, Swin Transformer uses shifted windows, so that every pixel can attend across window boundaries through cascaded self-attention modules.
- Swin Transformer consists of multiple stages; between stages, the features go through a patch merging layer that reduces the feature resolution and increases the dimension.
This performs 2× downsampling of resolution (reducing the number of tokens by 4×) and doubles the output dimension. So the output resolutions of the four stages are H/s × W/s, H/2s × W/2s, H/4s × W/4s, and H/8s × W/8s, and the dimensions are C, 2C, 4C, and 8C, respectively, where s is the patch size.
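The stage schedule above can be sketched as a small helper. This is a hypothetical illustration (not from the paper); `C = 96` is the base dimension of Swin-T and is an assumption here.

```python
# Hypothetical helper: output resolution and channel dimension of each Swin
# Transformer stage, for an H x W input, patch size s, and base dimension C.
def swin_stage_shapes(H, W, s=4, C=96, num_stages=4):
    shapes = []
    for i in range(num_stages):
        factor = s * (2 ** i)   # total spatial downsampling after stage i
        dim = C * (2 ** i)      # channel dimension doubles at each stage
        shapes.append((H // factor, W // factor, dim))
    return shapes

# For a 224x224 input with patch size 4 and C = 96:
print(swin_stage_shapes(224, 224))
# [(56, 56, 96), (28, 28, 192), (14, 14, 384), (7, 7, 768)]
```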
In this work, dual Swin Transformers are used in the encoding path.
2.1. Overall Architecture
- Given an input medical image, it is first split into non-overlapping patches at two scales and fed into the two branches of encoder separately, and then the output feature representations of different scales will be fused by the TIF module.
- Finally, the fused features are restored to the same resolution as the input image through an up-sampling process based on Swin Transformer blocks.
- As such, the final mask predictions are obtained.
2.2. Dual-Scale Encoding Mechanism
ViT and Swin Transformer treat an image as a sequence of non-overlapping patches, ignoring the pixel-level intrinsic structural features inside each patch, which leads to the loss of shallow features such as edge and line information.
- Specifically, two independent branches with patch sizes of s=4 (primary) and s=8 (complementary) are used for feature extraction at different spatial levels.
- The outputs of the two branches are denoted as Fi and Gi, respectively.
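As a toy illustration (numbers are mine, not the paper's), the two patch sizes yield token sequences of very different granularity for the same image:

```python
# Illustrative sketch: token counts of the two encoder branches.
def num_tokens(H, W, s):
    # An H x W image split into non-overlapping s x s patches
    # yields (H/s) * (W/s) tokens.
    return (H // s) * (W // s)

# For a 224x224 input: the fine-grained branch (s=4) produces 4x as many
# tokens as the coarse-grained branch (s=8), each covering a 4x smaller area.
print(num_tokens(224, 224, 4), num_tokens(224, 224, 8))  # 3136 784
```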
2.3. Transformer Interactive Fusion (TIF) Module
- Then, the transformed output ˆgi of Gi is obtained as ˆgi = LP(Avgpool(Gi)),
- where ˆgi has the size of 1×(i×C), LP(·) stands for the linear projection, and Avgpool(·) denotes the average-pooling layer followed by the flatten operation.
The token ˆgi represents the global abstract information of Gi to interact with Fi at the pixel level.
- Specifically, Fi is concatenated with ˆgi into a sequence of 1+h×w tokens, which are fed into the Transformer layer for computing global self-attention.
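The interaction above can be sketched in NumPy. This is a minimal single-head sketch under toy shapes of my own choosing; the real TIF module uses a standard multi-head Transformer layer with learned projections.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Toy shapes (not the paper's): h = w = 4 tokens in the small-scale branch, dim C = 8.
h, w, C = 4, 4, 8
F = rng.standard_normal((h * w, C))               # small-scale branch tokens
G = rng.standard_normal(((h // 2) * (w // 2), C)) # large-scale branch tokens

# Global token: average-pool G over its tokens, then a (random stand-in) linear projection LP.
W_lp = rng.standard_normal((C, C))
g_hat = G.mean(axis=0, keepdims=True) @ W_lp      # shape (1, C)

# Concatenate into a sequence of 1 + h*w tokens and compute global self-attention once.
tokens = np.concatenate([g_hat, F], axis=0)       # (1 + h*w, C)
Wq, Wk, Wv = (rng.standard_normal((C, C)) for _ in range(3))
Q, K, V = tokens @ Wq, tokens @ Wk, tokens @ Wv
attn = softmax(Q @ K.T / np.sqrt(C))              # every F token attends to g_hat too
out = attn @ V
F_out = out[1:]                                   # drop the global token: fused small-scale output
print(F_out.shape)  # (16, 8)
```

The point of the sketch: the single pooled token `g_hat` is cheap to attend to, yet it injects the large-scale branch's global context into every pixel-level token of F.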
- Fiout denotes the final output of the small-scale branch in TIF.
- Giout, which has the size of (h/2×w/2)×(i×C), is obtained analogously from the large-scale branch. Finally, the output feature representation is acquired as Ziout = Conv3×3(Concat(Fiout, Up(Giout))),
- where Conv3×3(·) is a 3×3 convolution layer and Up(·) denotes the 2× bilinear up-sampling process.
- The resulting features Ziout are passed to the decoder via skip connections.
2.4. Decoder
- Specifically, the output of stage 4 in the encoder is used as the initial input of the decoder.
- In each stage of the decoder, the input features are up-sampled by 2×.
- After that, the output is fed into the Swin Transformer block for self-attention computation. There are some advantages of such a design:
- It allows the model to make full use of the features from both the encoder and the up-sampling process, and
- it can build long-range dependencies and obtain global context information during the up-sampling process to achieve better decoding performance.
- There are 3 stages, where each stage increases the resolution of the feature maps by 2× and reduces the output dimension by 2×.
- Finally, the input image is passed through two cascaded convolutional blocks to obtain low-level features at resolutions of H×W and (H/2)×(W/2), where each block consists of a 3×3 convolutional layer, a group normalization layer, and a ReLU layer in succession.
- At the end, all of these features are used to predict the segmentation masks.
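The decoder's schedule mirrors the encoder in reverse. A hypothetical sketch, assuming the encoder's stage-4 output (H/32 × W/32, 8C, with C = 96 as in Swin-T) as the initial decoder input:

```python
# Hypothetical sketch (not the authors' code) of the decoder's
# resolution/dimension schedule over its 3 stages.
def decoder_stage_shapes(H, W, C=96, num_stages=3):
    h, w, dim = H // 32, W // 32, 8 * C   # encoder stage-4 output as input
    shapes = []
    for _ in range(num_stages):
        h, w, dim = h * 2, w * 2, dim // 2  # 2x upsampling, halve channels
        shapes.append((h, w, dim))
    return shapes

print(decoder_stage_shapes(224, 224))
# [(14, 14, 384), (28, 28, 192), (56, 56, 96)]
```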
2.5. Loss Function
- The loss function is composed of the weighted IoU loss LwIoU and the weighted binary cross-entropy loss LwBCE.
- Inspired by prior work, deep supervision helps the model training by additionally supervising the output S2 of stage 4 in the encoder and S3 of stage 1 in the decoder, which means the final loss function Ltotal can be written as Ltotal = α·L(G, S1) + β·L(G, S2) + γ·L(G, S3),
- where L = LwIoU + LwBCE, G is the ground truth, S1 is the final prediction, and α, β, and γ are empirically set to 0.6, 0.2, and 0.2, respectively.
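A minimal NumPy sketch of the deep-supervised total loss, using plain (unweighted) IoU + BCE as stand-ins for the weighted variants in the paper:

```python
import numpy as np

def seg_loss(pred, gt, eps=1e-7):
    # Stand-in for L = LwIoU + LwBCE (plain IoU + BCE here, not the weighted versions).
    pred = np.clip(pred, eps, 1 - eps)
    bce = -(gt * np.log(pred) + (1 - gt) * np.log(1 - pred)).mean()
    inter = (pred * gt).sum()
    union = (pred + gt - pred * gt).sum()
    iou = 1 - inter / (union + eps)
    return bce + iou

def total_loss(gt, s1, s2, s3, alpha=0.6, beta=0.2, gamma=0.2):
    # Deep supervision: weighted sum over the final prediction s1 and the
    # auxiliary outputs s2 (encoder stage 4) and s3 (decoder stage 1).
    return alpha * seg_loss(s1, gt) + beta * seg_loss(s2, gt) + gamma * seg_loss(s3, gt)
```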
3.1. Datasets
- Polyp segmentation task: 5 public endoscopic image datasets.
- 3 additional medical image segmentation tasks: 1) skin lesion segmentation on the ISIC 2018 dataset; 2) gland segmentation on the gland segmentation (GLAS) dataset; and 3) nuclei segmentation on the 2018 Data Science Bowl (Bowl) dataset.
3.2. Polyp Segmentation
DS-TransUNet-L achieves the highest scores on almost all evaluation metrics across the individual datasets, which indicates that Swin Transformer has tremendous potential to replace traditional CNNs.
DS-TransUNet produces high-quality segmentation masks on cross-study of the polyp segmentation task.
3.3. Three Additional Segmentation Tasks
DS-TransUNet consistently outperforms these Transformer-based competitors.
DS-TransUNet still outperforms the previous baselines and yields the highest mDice and mIoU scores of 0.878 and 0.791. DS-TransUNet is clearly superior to the recent Transformer-based work.
DS-TransUNet achieves the highest scores 0.922 and 0.943 in terms of F1 (mDice) and recall.
(a) Skin Lesions: DS-TransUNet can effectively capture the boundaries of skin lesions and generate better segmentation predictions.
(b) Gland: DS-TransUNet can bring excellent performance to distinguish the gland itself from the surrounding tissue.
(c) Data Science Bowl: DS-TransUNet can concurrently predict the boundaries of dozens of cell nuclei much more accurately than the existing baselines.
3.4. Ablation Study
- U-Net is considered as a vanilla baseline.
- “U w/ TE” denotes the U-shaped model with a standard Transformer-based encoder.
- “U w/ SE” denotes the U-shaped model with the Swin-Transformer-based encoder.
- “U w/ SE+SD” represents the U-shaped model with both the Swin-Transformer-based encoder and decoder.
- “U w/ DSE + SD” is the U-shaped model with the proposed Dual-Swin-Transformer-based encoder and Swin-Transformer-based decoder.
“U w/ DSE + SD + TIF” is the full DS-TransUNet architecture, which yields the best performance.
Authors claim that DS-TransUNet not only achieves a good trade-off between performance and complexity (parameters) but also the best segmentation performance.
- With a patch-size pair of (4, 4), DS-TransUNet struggles to achieve satisfactory results due to the lack of complementary feature information.
- An oversized patch size in the encoder provides inadequate fine-grained features for medical image segmentation, causing the pixel-level accuracy to decrease.
The patch-size pair of (4, 8) obtains the best performance with acceptable FLOPs.