Review — Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis

Swin UNETR with Self-Supervised Learning

Sik-Ho Tsang
5 min read · Mar 20, 2023

Self-Supervised Pre-Training of Swin Transformers for 3D Medical Image Analysis,
Self-Supervised Swin UNETR, by Vanderbilt University, and NVIDIA,
2022 CVPR, Over 70 Citations (Sik-Ho Tsang @ Medium)
Medical Imaging, Medical Image Analysis, Image Segmentation, U-Net, Transformer, Vision Transformer, ViT


  • 3D Swin UNEt TRansformers (Swin UNETR), a model with a hierarchical encoder, is proposed for self-supervised pre-training.
  • Tailored proxy tasks are designed for learning the underlying patterns of human anatomy.

Outline

  1. Brief Review of Swin UNETR
  2. Self-Supervised Swin UNETR
  3. Results

1. Brief Review of Swin UNETR

Overview of the Swin UNETR architecture.

1.1. Input

  • Swin UNETR splits the input data into non-overlapping patches and uses a patch partition layer to create windows of a desired size for computing self-attention.
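The patch partition step can be sketched in NumPy as below. This is a minimal illustration, assuming a channels-last volume of shape (H, W, D, C) and a cubic patch size p that divides every spatial dimension. The 96³ crop and p = 2 are illustrative values, and the helper name `patch_partition` is hypothetical. Each non-overlapping p×p×p patch becomes one flattened token. In the real model a linear embedding layer would then project these tokens.

```python
import numpy as np

def patch_partition(volume: np.ndarray, p: int) -> np.ndarray:
    """Split an (H, W, D, C) volume into non-overlapping p x p x p patches.

    Returns an array of shape (H//p, W//p, D//p, p*p*p*C): one flattened
    token per patch, laid out on the coarse 3D grid.
    """
    H, W, D, C = volume.shape
    assert H % p == 0 and W % p == 0 and D % p == 0, "p must divide each spatial dim"
    # Split each spatial axis into (number of patches, extent within a patch).
    x = volume.reshape(H // p, p, W // p, p, D // p, p, C)
    # Move the three patch-index axes to the front and the within-patch axes after them.
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)
    return x.reshape(H // p, W // p, D // p, p * p * p * C)

# Illustrative example: a 96^3 single-channel sub-volume with 2x2x2 patches.
vol = np.random.rand(96, 96, 96, 1).astype(np.float32)
tokens = patch_partition(vol, p=2)
print(tokens.shape)  # (48, 48, 48, 8)
```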

1.2. Encoder

  • The Swin UNETR encoder has 4 stages, each comprising 2 Transformer blocks.
  • The two Transformer blocks use W-MSA and SW-MSA, i.e., multi-head self-attention with regular and shifted window partitioning, respectively.
  • (Please feel free to read about Swin Transformer for the shifted window self-attention.)
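The sketch below makes the W-MSA/SW-MSA distinction concrete in NumPy. It shows how a 3D token grid could be split into non-overlapping M×M×M windows (regular partitioning). It also shows the shifted variant, obtained by cyclically rolling the grid by M//2 tokens before partitioning, following the Swin Transformer's cyclic-shift trick. This is a sketch under assumptions: the grid size, window size and channel count are illustrative, and the helper names are hypothetical. The attention computation itself, and the mask that Swin applies to wrapped-around tokens, are omitted.

```python
import numpy as np

def window_partition(x: np.ndarray, M: int) -> np.ndarray:
    """Split an (H, W, D, C) token grid into non-overlapping M x M x M windows.

    Returns (num_windows, M*M*M, C): each window is a token sequence within
    which self-attention would be computed.
    """
    H, W, D, C = x.shape
    x = x.reshape(H // M, M, W // M, M, D // M, M, C)
    x = x.transpose(0, 2, 4, 1, 3, 5, 6)
    return x.reshape(-1, M * M * M, C)

def shifted_window_partition(x: np.ndarray, M: int) -> np.ndarray:
    """SW-MSA style partitioning: cyclically shift the grid by M//2, then partition."""
    s = M // 2
    shifted = np.roll(x, shift=(-s, -s, -s), axis=(0, 1, 2))
    return window_partition(shifted, M)

grid = np.random.rand(8, 8, 8, 16)  # illustrative 8^3 token grid with 16 channels
regular = window_partition(grid, M=4)
shifted = shifted_window_partition(grid, M=4)
print(regular.shape, shifted.shape)  # (8, 64, 16) (8, 64, 16)
```

After the shift, the first window starts at token (2, 2, 2) instead of (0, 0, 0). Its window therefore straddles the boundaries of the regular windows, which is what lets information flow between neighbouring windows across consecutive blocks.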

1.3. Decoder

  • The encoded feature representations from the Swin Transformer encoder are fed to a CNN-based decoder via skip connections at multiple resolutions.
  • At each stage, the output feature representations are reshaped and fed into a residual block comprising two 3×3×3 convolutional layers normalized by instance normalization.
  • The resolution of the feature maps is increased by a factor of 2 using a deconvolutional layer, and the outputs are concatenated with the outputs of the previous stage.
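The resolution-doubling and concatenation step can be sketched in NumPy as below. The deeper feature map is upsampled by 2 and concatenated along the channel axis with the features of the previous (higher-resolution) stage. As a labeled simplification, nearest-neighbour upsampling stands in for the learned deconvolution, and the residual convolution block is omitted. The channel counts and spatial sizes are illustrative, and the helper names are hypothetical.

```python
import numpy as np

def upsample2x_nearest(x: np.ndarray) -> np.ndarray:
    """(H, W, D, C) -> (2H, 2W, 2D, C) by repeating each voxel twice per axis.

    Stand-in for the learned deconvolution of the real decoder: it only
    reproduces the change in resolution, not the learned filters.
    """
    for axis in range(3):
        x = np.repeat(x, 2, axis=axis)
    return x

def decoder_step(deep: np.ndarray, skip: np.ndarray) -> np.ndarray:
    """Upsample the deeper features and concatenate them with the skip features."""
    up = upsample2x_nearest(deep)
    assert up.shape[:3] == skip.shape[:3], "spatial sizes must match after upsampling"
    return np.concatenate([up, skip], axis=-1)

deep = np.random.rand(6, 6, 6, 384)     # illustrative deeper-stage features
skip = np.random.rand(12, 12, 12, 192)  # illustrative previous-stage features (2x resolution)
out = decoder_step(deep, skip)
print(out.shape)  # (12, 12, 12, 576)
```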

1.4. Output

  • The final segmentation outputs are computed using a 1×1×1 convolutional layer and a sigmoid activation function.
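A 1×1×1 convolution is a per-voxel linear map across channels. The output head can therefore be sketched with a single `einsum` followed by an element-wise sigmoid. The channel counts and random weights below are illustrative placeholders, not trained parameters, and `output_head` is a hypothetical name.

```python
import numpy as np

def output_head(features: np.ndarray, weight: np.ndarray, bias: np.ndarray) -> np.ndarray:
    """1x1x1 convolution followed by a sigmoid.

    features: (H, W, D, C_in); weight: (C_in, C_out); bias: (C_out,).
    Returns per-voxel, per-channel probabilities in (0, 1), shape (H, W, D, C_out).
    """
    logits = np.einsum("hwdc,ck->hwdk", features, weight) + bias
    return 1.0 / (1.0 + np.exp(-logits))

rng = np.random.default_rng(0)
feats = rng.standard_normal((8, 8, 8, 48))  # illustrative decoder output with 48 channels
w = rng.standard_normal((48, 3)) * 0.1      # 3 output channels (hypothetical)
b = np.zeros(3)
probs = output_head(feats, w, b)
print(probs.shape)  # (8, 8, 8, 3)
```

Because the sigmoid is applied per channel, each output channel is an independent per-voxel probability. Thresholding, e.g. at 0.5, then yields one binary mask per channel.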

2. Self-Supervised Swin UNETR

An animation of original images (left) and their reconstructions (right) (Figure from Authors’ GitHub)
Overview of the proposed pre-training framework.
  • Input CT images are randomly cropped into sub-volumes, augmented with random inner cutout and rotation, and then fed to the Swin UNETR encoder.
  • Masked volume inpainting, contrastive learning, and rotation prediction are used as proxy tasks for learning contextual representations of input images.

2.1. Masked Volume Inpainting

  • The cutout augmentation randomly masks out ROIs in the sub-volume X with a volume ratio of s.
  • A transposed convolution layer is attached to the encoder as the reconstruction head, and its output is denoted X̂_M.
  • The reconstruction objective is defined by an L1 loss between X and X̂_M:
    $$\mathcal{L}_{\text{inpaint}} = \lVert X - \hat{X}_M \rVert_1$$
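The inpainting objective can be sketched in NumPy as below. A random box is zeroed out so that it covers a fraction s of the sub-volume, and the reconstruction is scored with an L1 loss against the original X. Assumptions to note:

- The code uses a single cubic cutout box, whereas the paper masks out ROIs whose exact shapes and counts may differ.
- The L1 norm is averaged over voxels (mean absolute error) instead of summed.
- The "reconstruction" in the example is a placeholder for the output of the transposed-convolution head.
- The helper names are hypothetical.

```python
import numpy as np

def random_cutout(x: np.ndarray, s: float, rng: np.random.Generator):
    """Zero out one random cube covering roughly a fraction s of the volume.

    Returns the masked volume and a boolean mask (True = masked voxel).
    """
    H, W, D = x.shape
    side = max(1, round((s * H * W * D) ** (1 / 3)))  # cube side for the target ratio
    side = min(side, H, W, D)
    x0 = rng.integers(0, H - side + 1)
    y0 = rng.integers(0, W - side + 1)
    z0 = rng.integers(0, D - side + 1)
    mask = np.zeros(x.shape, dtype=bool)
    mask[x0:x0 + side, y0:y0 + side, z0:z0 + side] = True
    return np.where(mask, 0.0, x), mask

def inpainting_loss(x: np.ndarray, x_hat: np.ndarray) -> float:
    """L1 loss between the sub-volume X and its reconstruction, averaged over voxels."""
    return float(np.mean(np.abs(x - x_hat)))

rng = np.random.default_rng(0)
X = rng.random((32, 32, 32))
X_masked, mask = random_cutout(X, s=0.125, rng=rng)
print(mask.mean())  # 0.125 (a 16^3 cube inside a 32^3 volume)
# Placeholder "reconstruction" that simply returns the masked input:
print(inpainting_loss(X, X_masked) > 0)  # True
```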

2.2. Image Rotation

  • For simplicity, R = 4 rotation classes (0°, 90°, 180°, 270°) about the z-axis are used.
  • An MLP classification head predicts the softmax probabilities ŷ_r of the rotation categories.
  • Given the ground truth y_r, a cross-entropy loss is used for the rotation prediction task:
    $$\mathcal{L}_{\text{rot}} = -\sum_{r=1}^{R} y_r \log(\hat{y}_r)$$
  • The 3D rotation and cutout also serve simultaneously as augmentation transformations for contrastive learning.
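The rotation proxy task can be sketched in NumPy as below. A sub-volume is rotated by k×90° about the z-axis, where k ∈ {0, 1, 2, 3} is its class label. The classifier's softmax output is then scored with cross-entropy against that label. Assumptions to note:

- The volume is indexed as (x, y, z), so a rotation about z acts on axes 0 and 1.
- The random logits stand in for the MLP head's output.
- The cross-entropy is averaged over the batch.
- The helper names are hypothetical.

```python
import numpy as np

ANGLES = (0, 90, 180, 270)  # R = 4 rotation classes

def rotate_about_z(x: np.ndarray, k: int) -> np.ndarray:
    """Rotate an (x, y, z) volume by k * 90 degrees in the x-y plane."""
    return np.rot90(x, k=k, axes=(0, 1))

def softmax(logits: np.ndarray) -> np.ndarray:
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def rotation_loss(logits: np.ndarray, labels: np.ndarray) -> float:
    """Batch-mean cross-entropy -sum_r y_r log(y_hat_r) with one-hot y.

    logits: (N, R) raw scores; labels: (N,) integer rotation classes.
    """
    probs = softmax(logits)
    return float(-np.mean(np.log(probs[np.arange(len(labels)), labels])))

rng = np.random.default_rng(0)
vol = rng.random((16, 16, 8))
labels = rng.integers(0, len(ANGLES), size=4)          # one rotation class per sample
views = [rotate_about_z(vol, int(k)) for k in labels]  # rotated inputs for the encoder
logits = rng.standard_normal((4, len(ANGLES)))         # placeholder head output
print(rotation_loss(logits, labels))
```

With uninformative (all-equal) logits the loss equals log R = log 4. This is a handy sanity check when wiring up such a head.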

2.3. Contrastive Coding

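  • Given a batch of augmented sub-volumes, contrastive coding enables better representation learning by maximizing the mutual information between positive pairs (augmented views of the same sub-volume) while minimizing it between negative pairs (views of different sub-volumes).
  • Similar to SimCLR, the 3D contrastive coding loss between a pair v_i and v_j is defined as:
    $$\mathcal{L}_{\text{contrast}} = -\log \frac{\exp(\mathrm{sim}(v_i, v_j)/t)}{\sum_{k=1}^{2N} \mathbb{1}_{k \neq i} \exp(\mathrm{sim}(v_i, v_k)/t)}$$
    where sim(·,·) is the dot product of normalized embeddings (cosine similarity), t is the temperature, 𝟙_{k≠i} is an indicator that is 1 iff k ≠ i, and the sum runs over the 2N views of a batch of N sub-volumes.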
  • The contrastive learning loss function strengthens the intra-class compactness as well as the inter-class separability.
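The SimCLR-style contrastive loss described here can be sketched in NumPy as the NT-Xent loss over 2N embeddings, i.e. two augmented views of each of N sub-volumes. Assumptions to note:

- The rows are projection-head outputs.
- Rows 2m and 2m+1 form a positive pair.
- sim is cosine similarity.
- The temperature value is illustrative.
- The loss is averaged over all 2N anchors, as in SimCLR.
- `nt_xent_loss` is a hypothetical name.

```python
import numpy as np

def nt_xent_loss(z: np.ndarray, t: float = 0.5) -> float:
    """NT-Xent (SimCLR) contrastive loss.

    z: (2N, d) embeddings; rows 2m and 2m+1 are two views of sub-volume m.
    t: temperature. Returns the loss averaged over all 2N anchors.
    """
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # dot product = cosine similarity
    sim = z @ z.T / t
    n2 = z.shape[0]
    np.fill_diagonal(sim, -np.inf)  # indicator 1[k != i]: exclude self-similarity
    pos = np.arange(n2) ^ 1         # positive partner index: 0<->1, 2<->3, ...
    # Stable log of the denominator sum over k != i.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(n2), pos] - log_denom
    return float(-log_prob_pos.mean())

rng = np.random.default_rng(0)
base = rng.standard_normal((4, 8))                # 4 "sub-volumes", 8-dim embeddings
views = np.repeat(base, 2, axis=0)                # rows 2m, 2m+1: two views of sub-volume m
views += 0.05 * rng.standard_normal(views.shape)  # small per-view perturbation
print(nt_xent_loss(views, t=0.5))
```

Each anchor's positive pair must score higher than every negative to drive the loss down. This is how the loss pulls views of the same sub-volume together and pushes different sub-volumes apart.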

2.4. Total Loss

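  • The total loss is a weighted sum of the inpainting, contrastive, and rotation losses:
    $$\mathcal{L}_{\text{tot}} = \lambda_1 \mathcal{L}_{\text{inpaint}} + \lambda_2 \mathcal{L}_{\text{contrast}} + \lambda_3 \mathcal{L}_{\text{rot}}$$
  • A grid search over the loss weights found λ1 = λ2 = λ3 = 1 to be optimal.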

3. Results

3.1. BTCV Multi-organ Segmentation Challenge

Leaderboard Dice results of BTCV challenge on multi-organ segmentation.

Compared with other top submissions, the proposed Swin UNETR achieves the best performance.
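The results in this section are reported as Dice scores. For reference, here is a minimal NumPy sketch of the Dice similarity coefficient between a binary prediction P and ground truth G, Dice = 2|P ∩ G| / (|P| + |G|). Returning 1.0 when both masks are empty is an assumed convention, and the challenge leaderboards may handle that case differently.

```python
import numpy as np

def dice(pred, gt) -> float:
    """Dice similarity coefficient between two binary masks of equal shape."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    denom = pred.sum() + gt.sum()
    if denom == 0:
        return 1.0  # both masks empty: treated as perfect agreement (assumed convention)
    return float(2.0 * np.logical_and(pred, gt).sum() / denom)

pred = np.zeros((4, 4, 4), dtype=bool)
pred[:2] = True  # 32 voxels
gt = np.zeros((4, 4, 4), dtype=bool)
gt[1:3] = True   # 32 voxels, 16 of which overlap with pred
print(dice(pred, gt))  # 0.5
```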

Qualitative visualizations of the proposed Swin UNETR and baseline methods.

The representative samples show that Swin UNETR successfully identifies organ details.

3.2. Segmentation Results on MSD

Overall performance of top-ranking methods on all 10 segmentation tasks in the MSD public test leaderboard.

Overall, Swin UNETR achieves the best average Dice of 78.68% across all ten tasks and ranks first on the MSD leaderboard.

MSD test dataset performance comparison in Dice and NSD. Benchmarks are obtained from the MSD test leaderboard.
  • Detailed per-task results are shown above.

The proposed Swin UNETR achieves state-of-the-art performance in Task01 BrainTumour, Task06 Lung, Task07 Pancreas, and Task10 Colon.

The results are comparable for Task02 Heart, Task03 Liver, Task04 Hippocampus, Task05 Prostate, Task08 HepaticVessel and Task09 Spleen.

Qualitative results of representative MSD CT tasks.

Swin UNETR with self-supervised pre-training demonstrates visually better segmentation results in the CT tasks.

3.3. Ablation Study

Dice gap between the pre-trained model (green) and the model trained from scratch (blue) on the MSD CT tasks validation set.
  • The figure above compares the pre-trained model against training from scratch on all MSD CT tasks.

Clear improvements are observed for Task03 Liver, with a Dice of 77.77% compared to 75.27%, and for Task08 Hepatic Vessel, with 68.52% versus 64.63%. Task10 Colon shows the largest improvement, from 34.83% to 43.38%.
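The reported gains can be checked with simple arithmetic. The Dice values below are copied from the text. The sketch computes the absolute improvement from pre-training per task, in Dice percentage points.

```python
# Validation Dice (%) reported in the text: (trained from scratch, with pre-training).
reported = {
    "Task03 Liver": (75.27, 77.77),
    "Task08 Hepatic Vessel": (64.63, 68.52),
    "Task10 Colon": (34.83, 43.38),
}

# Absolute gain in Dice percentage points, rounded to 2 decimals.
gains = {task: round(pre - scratch, 2) for task, (scratch, pre) in reported.items()}
for task, g in gains.items():
    print(f"{task}: +{g:.2f} Dice points")
print("Largest gain:", max(gains, key=gains.get))
# Task03 Liver: +2.50 Dice points
# Task08 Hepatic Vessel: +3.89 Dice points
# Task10 Colon: +8.55 Dice points
# Largest gain: Task10 Colon
```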

Data-efficient performance on the BTCV test dataset. Significance under the Wilcoxon signed-rank test: p<0.001.

The proposed approach can reduce the annotation effort by at least 40% for the BTCV task.

Per-organ fine-tuning results on the BTCV dataset using weights pre-trained on 100, 3,000, and 5,000 scans.
  • The fine-tuning results are compared across the three pre-training dataset sizes (100, 3,000, and 5,000 scans).

The proposed model benefits from larger pre-training datasets, i.e., from increasing amounts of unlabeled data.

Ablation study of the effectiveness of each objective function in the proposed pre-training loss.
  • On the BTCV test set, among the individual proxy tasks, pre-training with inpainting alone yields the largest improvement.

Overall, employing all proxy tasks achieves the best Dice of 84.72%.


Sik-Ho Tsang

PhD, Researcher. I share what I learn. :) Linktree: https://linktr.ee/shtsang for Twitter, LinkedIn, etc.