Brief Review — Multi-task Learning for Left Atrial Segmentation on GE-MRI

Deep U-Net + Multi-Task Learning + Post Processing

Sik-Ho Tsang
Nov 4, 2022

Multi-task Learning for Left Atrial Segmentation on GE-MRI, Multi-Task Deep U-Net, by Imperial College London
2018 STACOM, Over 40 Citations (Sik-Ho Tsang @ Medium)
Medical Image Analysis, Multi-Task Learning, Image Classification, Image Segmentation, U-Net

  • A multi-task learning framework based on a Deep U-Net is proposed for left atrial (LA) segmentation in gadolinium-enhanced magnetic resonance images (GE-MRI). It performs atrial segmentation and pre/post-ablation classification at the same time.

Outline

  1. Multi-Task Deep U-Net
  2. Data
  3. Results

1. Multi-Task Deep U-Net

Multi-Task Deep U-Net

1.1. Segmentation

  • The Multi-Task Deep U-Net is derived from the 2D U-Net.
  • Since the largest images in the evaluation dataset are 640 × 640 in the x-y plane, the receptive field of the U-Net is enlarged by adding more pooling layers. The network consists of five down-sampling blocks and five up-sampling blocks, using batch normalization (BN) and ReLU.

By aggregating both coarse and fine features learned at different scales from the down-sampling path and up-sampling path, our network is supposed to achieve better segmentation performance than those networks without the aggregation operations.
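To see why extra pooling layers help with 640 × 640 inputs, the sketch below computes the theoretical receptive field at the encoder bottleneck and the bottleneck feature-map size. It assumes, as in the original U-Net, that each down-sampling block is two 3×3 stride-1 convolutions followed by 2×2 max pooling with stride 2. This layout and the function names are illustrative assumptions, not details taken from the paper.

```python
def encoder_receptive_field(n_pool: int, convs_per_block: int = 2, kernel: int = 3) -> int:
    """Theoretical receptive field (in input pixels) at the encoder bottleneck.

    Assumed layout: n_pool blocks of [kernel x kernel conv] * convs_per_block followed by
    2x2/stride-2 max pooling, then convs_per_block bottleneck convolutions.
    """
    rf, jump = 1, 1  # receptive field, and input-pixel distance between adjacent features

    def conv(rf, jump):
        # a stride-1 convolution grows the receptive field by (kernel - 1) * jump
        return rf + (kernel - 1) * jump, jump

    for _ in range(n_pool):
        for _ in range(convs_per_block):
            rf, jump = conv(rf, jump)
        rf, jump = rf + (2 - 1) * jump, jump * 2  # 2x2 max pooling with stride 2
    for _ in range(convs_per_block):  # bottleneck convolutions
        rf, jump = conv(rf, jump)
    return rf


def bottleneck_side(image_side: int, n_pool: int) -> int:
    """Spatial size after n_pool halvings; exact divisibility keeps skip connections aligned."""
    factor = 2 ** n_pool
    if image_side % factor:
        raise ValueError(f"{image_side} is not divisible by {factor}")
    return image_side // factor


for n in (4, 5):
    print(n, encoder_receptive_field(n), bottleneck_side(640, n), bottleneck_side(576, n))
# 4 140 40 36
# 5 284 20 18
```

Under these assumptions, a fifth pooling layer roughly doubles the bottleneck receptive field (140 to 284 pixels). Both image sizes in the dataset divide evenly by 2^5.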

1.2. Classification

  • The classification task is performed by utilizing image features learned from the down-sampling path. Specifically, features after the 4th max pooling layer are extracted.
  • To produce fixed-length feature vectors from input images of different sizes and scales, spatial pyramid pooling (SPP), which originated in SPPNet, is applied.
  • The vector is then passed through 2 fully connected (FC) layers followed by a softmax, which computes the class probabilities (pre/post-ablation) for each image. Dropout of 0.5 is applied to the FC layers.
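A minimal sketch of spatial pyramid pooling in plain Python shows how feature maps of any size become a vector of fixed length. The pyramid levels (1, 2, 4) and the bin-boundary rule (floor/ceil, as in adaptive pooling) are assumptions, since the paper's exact SPP configuration is not given here.

```python
import math
from typing import List, Sequence

FeatureMap = List[List[float]]  # one channel, H x W


def spatial_pyramid_pool(channels: Sequence[FeatureMap],
                         levels: Sequence[int] = (1, 2, 4)) -> List[float]:
    """Max-pool every channel over an n x n grid for each pyramid level n, then concatenate.

    The output length is len(channels) * sum(n * n for n in levels), independent of H and W.
    """
    out: List[float] = []
    for n in levels:
        for fmap in channels:
            h, w = len(fmap), len(fmap[0])
            for i in range(n):
                # rows [floor(i*h/n), ceil((i+1)*h/n)): bins cover the map and are never empty
                r0, r1 = (i * h) // n, math.ceil((i + 1) * h / n)
                for j in range(n):
                    c0, c1 = (j * w) // n, math.ceil((j + 1) * w / n)
                    out.append(max(fmap[r][c] for r in range(r0, r1) for c in range(c0, c1)))
    return out


small = [[float(r * 7 + c) for c in range(7)] for r in range(5)]  # 5 x 7 map
large = [[0.0] * 9 for _ in range(9)]                               # 9 x 9 map
print(len(spatial_pyramid_pool([small])), len(spatial_pyramid_pool([large])))  # 21 21
```

In the network, the input would be the channels after the 4th max pooling layer. The resulting vector, whose length does not depend on the crop size, feeds the FC layers.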

1.3. Loss Function

  • The loss function L of the multi-task network is:
    $L = L_S + \lambda L_C$
  • where $L_S$ is the segmentation loss, $L_C$ is the classification loss, and $\lambda = 1$.
  • For the segmentation part, pixel-wise cross-entropy loss is employed.
  • For the classification part, sigmoid cross-entropy is used.

The classification loss works as a regularization term, enabling the network to learn the high-level representation that generalizes well on both tasks.
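The combined objective $L = L_S + \lambda L_C$ with $\lambda = 1$ can be sketched in plain Python. This assumes the segmentation head outputs per-pixel softmax probabilities and the classification head outputs a single raw logit for "post-ablation". That single-logit setup is an assumption made to match the sigmoid cross-entropy, and the function names are hypothetical.

```python
import math
from typing import Sequence


def pixelwise_cross_entropy(probs: Sequence[Sequence[float]], labels: Sequence[int],
                            eps: float = 1e-7) -> float:
    """Mean over pixels of -log p(true class); probs[i] is pixel i's softmax output."""
    return -sum(math.log(max(p[y], eps)) for p, y in zip(probs, labels)) / len(labels)


def sigmoid_cross_entropy(logit: float, target: int) -> float:
    """Numerically stable binary cross-entropy on a raw logit, with target in {0, 1}."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def multitask_loss(seg_probs, seg_labels, cls_logit: float, cls_target: int,
                   lam: float = 1.0) -> float:
    """L = L_S + lam * L_C (lam = 1 in the review)."""
    return (pixelwise_cross_entropy(seg_probs, seg_labels)
            + lam * sigmoid_cross_entropy(cls_logit, cls_target))


# two pixels (background, LA) and one post-ablation image
loss = multitask_loss([[0.9, 0.1], [0.2, 0.8]], [0, 1], cls_logit=1.5, cls_target=1)
print(round(loss, 4))
```

Because both terms are backpropagated through the shared down-sampling path, the classification term acts as the regularizer described above.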

1.4. Post-Processing

  • At inference time, axial slices extracted from a 3D image are fed into the network one slice at a time.
  • Stacking the per-slice segmentation results produces a rough 3D mask for each patient. To refine the mask boundaries, 3D morphological dilation and erosion are applied, and only the largest connected component is kept in each volume.
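The post-processing steps above can be sketched in pure Python on sets of voxel coordinates. Several choices here are assumptions rather than details from the paper:

  • a 6-connected (face-neighbour) structuring element,
  • a single dilation followed by a single erosion (a morphological closing),
  • treating voxels outside the volume as background.

A real pipeline would more likely use an image-processing library.

```python
from collections import deque
from typing import Iterator, List, Set, Tuple

Voxel = Tuple[int, int, int]  # (z, y, x)
Shape = Tuple[int, int, int]
# 6-connectivity: face neighbours only (assumed structuring element)
OFFSETS = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]


def neighbours(v: Voxel, shape: Shape) -> Iterator[Voxel]:
    """Yield the in-volume 6-connected neighbours of voxel v."""
    for d in OFFSETS:
        n = (v[0] + d[0], v[1] + d[1], v[2] + d[2])
        if all(0 <= c < s for c, s in zip(n, shape)):
            yield n


def dilate(mask: Set[Voxel], shape: Shape) -> Set[Voxel]:
    return mask | {n for v in mask for n in neighbours(v, shape)}


def erode(mask: Set[Voxel]) -> Set[Voxel]:
    # a voxel survives only if all 6 neighbours are foreground; outside the volume is background
    return {v for v in mask
            if all((v[0] + d[0], v[1] + d[1], v[2] + d[2]) in mask for d in OFFSETS)}


def largest_component(mask: Set[Voxel], shape: Shape) -> Set[Voxel]:
    """Breadth-first search over 6-connected foreground voxels; keep the biggest component."""
    seen: Set[Voxel] = set()
    best: Set[Voxel] = set()
    for start in mask:
        if start in seen:
            continue
        component, queue = {start}, deque([start])
        seen.add(start)
        while queue:
            for n in neighbours(queue.popleft(), shape):
                if n in mask and n not in seen:
                    seen.add(n)
                    component.add(n)
                    queue.append(n)
        if len(component) > len(best):
            best = component
    return best


def postprocess(slices: List[List[List[int]]]) -> Set[Voxel]:
    """Stack per-slice 2D binary masks along z, close them in 3D, keep the largest component."""
    shape = (len(slices), len(slices[0]), len(slices[0][0]))
    mask = {(z, y, x) for z, sl in enumerate(slices)
            for y, row in enumerate(sl) for x, val in enumerate(row) if val}
    closed = erode(dilate(mask, shape))  # dilation followed by erosion = closing
    return largest_component(closed, shape)
```

The closing fills small holes inside the atrium mask. The connected-component step discards isolated false-positive blobs elsewhere in the volume.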

2. Data

  • (Please feel free to skip this section for a fast read.)
Original image slices from different views

2.1. Dataset

  • The 2018 Atrial Segmentation Challenge dataset is used. Its training set contains 100 3D gadolinium-enhanced magnetic resonance imaging scans (GE-MRIs), together with the corresponding manual LA segmentation masks and pre/post-ablation labels, for training and validation.
  • An additional set of 54 unlabeled images is provided for testing.
  • For model training and evaluation, the training set is randomly split 80:20.
  • There are two image sizes: 576×576 and 640×640.
  • The dataset exhibits large differences in image size and image contrast, as shown above.
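A minimal sketch of the 80:20 random split follows. It assumes the split is done per 3D scan (case) rather than per slice, so that slices from one patient never appear in both subsets. The case-level granularity and the fixed seed are assumptions for illustration.

```python
import random
from typing import Iterable, List, Tuple


def split_cases(case_ids: Iterable[int], train_frac: float = 0.8,
                seed: int = 0) -> Tuple[List[int], List[int]]:
    """Shuffle case IDs reproducibly and cut them into train/validation lists."""
    ids = list(case_ids)
    random.Random(seed).shuffle(ids)
    n_train = int(round(train_frac * len(ids)))
    return ids[:n_train], ids[n_train:]


train_ids, val_ids = split_cases(range(100))
print(len(train_ids), len(val_ids))  # 80 20
```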

2.2. Preprocessing & Augmentation

  • The image intensity is normalized to zero mean and unit variance.
  • Random horizontal/vertical flip with a probability of 50%.
  • Random rotation by an angle between −10° and +10°.
  • Random shifting along the X and Y axes by up to 10 percent of the original image size.
  • Random zooming with a factor between 0.7 and 1.3.
  • Random gamma correction is used as contrast augmentation:
    $G(x, y) = F(x, y)^{\gamma}$
  • where F(x, y) is the original value of each pixel in an image and G(x, y) is the transformed value. γ is sampled randomly from the range (0.8, 2.0) for each image.
  • With contrast augmentation, the proposed network does not need image contrast enhancement as a pre-processing step.
Segmentation accuracy using a single-task Deep U-net with different contrast processing strategies

Compared with contrast-enhancement pre-processing methods such as AGC and CLAHE, using gamma augmentation instead of pre-processing yields a higher Dice score.
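The normalization and augmentation ranges listed above can be sketched in plain Python. The sketch only samples the geometric parameters (flip, rotation, shift, zoom), since applying them would need an image-processing library. Gamma correction $G = F^{\gamma}$ needs non-negative intensities, so the sketch assumes each image is min-max rescaled to [0, 1] before gamma and z-score normalized afterwards. That ordering is an assumption, not stated in the paper.

```python
import random
import statistics
from typing import List

Image = List[List[float]]


def sample_aug_params(rng: random.Random, height: int, width: int) -> dict:
    """Draw one set of augmentation parameters using the ranges listed above."""
    return {
        "hflip": rng.random() < 0.5,                 # horizontal flip with probability 50%
        "vflip": rng.random() < 0.5,                 # vertical flip with probability 50%
        "angle_deg": rng.uniform(-10.0, 10.0),       # rotation angle in degrees
        "shift_y": rng.uniform(-0.1, 0.1) * height,  # shift of up to 10% of the image size
        "shift_x": rng.uniform(-0.1, 0.1) * width,
        "zoom": rng.uniform(0.7, 1.3),
        "gamma": rng.uniform(0.8, 2.0),              # contrast augmentation
    }


def gamma_correct(img: Image, gamma: float) -> Image:
    """G = F ** gamma, after min-max rescaling F to [0, 1] (assumed)."""
    flat = [v for row in img for v in row]
    lo, hi = min(flat), max(flat)
    span = (hi - lo) or 1.0
    return [[((v - lo) / span) ** gamma for v in row] for row in img]


def zscore(img: Image) -> Image:
    """Normalize intensities to zero mean and unit variance."""
    flat = [v for row in img for v in row]
    mu, sd = statistics.fmean(flat), statistics.pstdev(flat) or 1.0
    return [[(v - mu) / sd for v in row] for row in img]


rng = random.Random(0)
params = sample_aug_params(rng, 640, 640)
augmented = zscore(gamma_correct([[0.0, 50.0], [100.0, 100.0]], params["gamma"]))
```

Drawing a fresh γ per image exposes the network to a wide range of contrasts. This is why a fixed contrast-enhancement pre-processing step becomes unnecessary.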

  • Multi-scale cropping is used to increase data variety, so that the network learns to analyze images with different amounts of context. The crop sizes are 256×256, 384×384, 480×480, 512×512, 576×576 and 640×640.
  • The network is first trained on crops in which the left atrium occupies a large portion of the image, and the crop size is then gradually increased, so that the network learns to segment from easy scenarios to hard ones.
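The easy-to-hard cropping schedule can be sketched as follows. The crop sizes come from the text. The stage length (`epochs_per_stage`) is hypothetical, as is centring crops on the LA centroid of the manual mask, and crops larger than the image are clipped to the image size.

```python
from typing import List, Tuple

CROP_SIZES = [256, 384, 480, 512, 576, 640]  # crop sizes listed in the review


def allowed_sizes(epoch: int, epochs_per_stage: int = 10) -> List[int]:
    """Hypothetical curriculum: unlock the next larger crop size every epochs_per_stage epochs."""
    stage = min(epoch // epochs_per_stage, len(CROP_SIZES) - 1)
    return CROP_SIZES[: stage + 1]


def crop_box(center_yx: Tuple[int, int], size: int,
             image_hw: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Square crop (top, left, bottom, right) around the LA centre, kept inside the image."""
    h, w = image_hw
    size = min(size, h, w)  # e.g. a 640 crop of a 576 x 576 image becomes the whole image
    cy, cx = center_yx
    top = min(max(cy - size // 2, 0), h - size)
    left = min(max(cx - size // 2, 0), w - size)
    return top, left, top + size, left + size


print(allowed_sizes(0), allowed_sizes(25))  # [256] [256, 384, 480]
print(crop_box((300, 300), 640, (576, 576)))  # (0, 0, 576, 576)
```

During training, a crop size could be drawn from `allowed_sizes(epoch)` for each sample. Early epochs then see only tight crops dominated by the atrium.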

3. Results

3.1. Quantitative Results

Segmentation accuracy results based on different measurements for different networks and methods

The segmentation performance was greatly improved by increasing the depth of the network, especially in Mitral Valve (MV) planes.

  • Using the multi-task Deep U-Net followed by post-processing achieves a Dice score of 0.901.
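The Dice score measures overlap between a predicted mask A and the ground truth B: $\mathrm{Dice} = \frac{2|A \cap B|}{|A| + |B|}$. The sketch below computes it over sets of foreground voxel coordinates, assuming the metric is evaluated per 3D volume. Treating two empty masks as a perfect score is a convention chosen here.

```python
from typing import Set, Tuple

Voxel = Tuple[int, int, int]


def dice(pred: Set[Voxel], truth: Set[Voxel]) -> float:
    """Dice = 2|A ∩ B| / (|A| + |B|); returns 1.0 when both masks are empty (a convention)."""
    if not pred and not truth:
        return 1.0
    return 2 * len(pred & truth) / (len(pred) + len(truth))


a = {(0, 0, 0), (0, 0, 1), (0, 0, 2)}
b = {(0, 0, 1), (0, 0, 2), (0, 0, 3)}
print(round(dice(a, b), 4))  # 0.6667
```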

3.2. Qualitative Results

Example segmentations for axial slices from our MRI set using different methods.
  • The proposed multi-task U-Net is more robust than the other two networks, which have only a segmentation objective.

By sharing features with segmentation and related pre/post ablation classification, the network is forced to learn better representation to obtain better segmentation results.

3D visualization of three samples from the validation set. Blue objects are the ground truth; green ones are the predicted segmentation.
  • The 3D segmentation results show a high overlap with the ground truth across different subjects.
  • However, one significant failure mode is observed around the pulmonary veins. One possible reason is that the number and length of the pulmonary veins vary from person to person, making them hard for the network to learn from a limited number of cases.

3.3. Time

  • The total processing procedure took approximately 10s on average on one Nvidia Titan Xp GPU.

3.4. Atrial Segmentation Challenge 2018

  • An ensemble method called Bootstrap Aggregating (Bagging) is used.
  • The same model is trained 5 times, each time on a random subset of the training data, and the predicted class probabilities are then averaged.
  • On the 54 test cases provided by the organizers, this improved the average Dice score from 0.9197 to 0.9206.
  • Through parameter sharing, the features are enhanced by the classification objective, which improves segmentation performance.
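The bagging step can be sketched in two parts: drawing a random training subset for each of the 5 models, then averaging their per-voxel probabilities. The subsets are drawn with replacement, as in classical bootstrap aggregating, since the paper's exact sampling scheme is not described here. The 0.5 threshold and the function names are assumptions.

```python
import random
from typing import List, Sequence, Tuple


def bootstrap_subsets(case_ids: Sequence[int], n_models: int = 5,
                      seed: int = 0) -> List[List[int]]:
    """One bootstrap sample (same size, drawn with replacement) per ensemble member."""
    rng = random.Random(seed)
    ids = list(case_ids)
    return [[rng.choice(ids) for _ in ids] for _ in range(n_models)]


def bagged_prediction(prob_maps: Sequence[Sequence[float]],
                      threshold: float = 0.5) -> Tuple[List[float], List[int]]:
    """Average flattened per-voxel foreground probabilities across models, then threshold."""
    n = len(prob_maps)
    avg = [sum(vals) / n for vals in zip(*prob_maps)]
    return avg, [int(p >= threshold) for p in avg]


subsets = bootstrap_subsets(range(80))  # e.g. the 80 training cases
avg, mask = bagged_prediction([[0.2, 0.9], [0.6, 0.7], [0.4, 0.8]])
print(mask)  # [0, 1]
```

Averaging probabilities before thresholding smooths out errors that only some ensemble members make. This is consistent with the small Dice gain reported above.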

Reference

[2018 STACOM] [Multi-Task Deep U-Net]
Multi-task Learning for Left Atrial Segmentation on GE-MRI

1.11. Biomedical Multi-Task Learning

2018 [ResNet+Mask R-CNN] [cU-Net+PE] [Multi-Task Deep U-Net]

My Other Previous Paper Readings
