Review — Shape-Aware Semi-Supervised 3D Semantic Segmentation for Medical Images
Shape-Aware Semi-Supervised 3D Semantic Segmentation for Medical Images,
SASSNet, by ShanghaiTech University, and Shanghai Engineering Research Center of Intelligent Vision and Imaging,
2020 MICCAI, Over 110 Citations (Sik-Ho Tsang @ Medium)
Medical Imaging, Medical Image Analysis, Semantic Segmentation, Multi-Task Learning, V-Net
- A shape-aware semi-supervised segmentation strategy is proposed to leverage abundant unlabeled data and to enforce a geometric shape constraint on the segmentation output.
- A multi-task deep network is used that jointly predicts the semantic segmentation and the signed distance map (SDM) of object surfaces.
- During training, an adversarial loss between the predicted SDMs of labeled and unlabeled data is introduced so that the network is able to capture shape-aware features more effectively.
Outline
- Shape-Aware Semi-supervised Segmentation Network (SASSNet)
- Results
1. Shape-Aware Semi-supervised Segmentation Network (SASSNet)
1.1. Model Architecture
- A V-Net backbone is used, with the addition of a lightweight SDM head in parallel with the original segmentation head.
- The SDM head is composed of a 3D convolution block followed by a tanh activation.
- Given an input image X of size H×W×D, the segmentation head generates a confidence score map M ∈ [0, 1] of size H×W×D, and the SDM head predicts an SDM S ∈ [−1, 1] of size H×W×D, where each element of S indicates the signed distance of the corresponding voxel to its closest surface point after normalization (a minimal sketch of the two heads is given after this list).
- (Please read about SDM for more details if interested.)
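To make the two-head design concrete, here is a minimal PyTorch sketch (not the authors' code); the backbone, channel width, and single-foreground-class setting are assumptions.

```python
import torch
import torch.nn as nn

class DualHeadSegmenter(nn.Module):
    """Sketch of SASSNet's two output heads on a shared backbone.

    `backbone` is any 3D encoder-decoder (e.g., a V-Net) that returns
    decoder features of shape (B, C, H, W, D); C = 16 is illustrative.
    """
    def __init__(self, backbone: nn.Module, in_channels: int = 16):
        super().__init__()
        self.backbone = backbone
        # Segmentation head: confidence score map M in [0, 1]
        self.seg_head = nn.Sequential(
            nn.Conv3d(in_channels, 1, kernel_size=1),
            nn.Sigmoid(),
        )
        # SDM head: normalized signed distance map S in [-1, 1]
        self.sdm_head = nn.Sequential(
            nn.Conv3d(in_channels, 1, kernel_size=1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor):
        feats = self.backbone(x)   # (B, C, H, W, D) decoder features
        m = self.seg_head(feats)   # confidence score map M
        s = self.sdm_head(feats)   # signed distance map S
        return m, s
```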
1.2. Shape-Aware Semi-supervised Learning
- A multi-task loss is developed, which consists of a supervised loss Ls on the labeled set and an adversarial loss La on the entire set to enforce consistency of the model predictions.
- The training set contains N labeled and M unlabeled volumes, where N << M. The labeled set is Dl = {(Xn, Yn, Zn)} for n = 1, …, N, and the unlabeled set is Du = {Xm} for m = N+1, …, N+M.
- Xn is the input volume, Yn is the segmentation annotation, and Zn is the ground-truth SDM derived from Yn (one possible derivation is sketched below).
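As an illustration of how a ground-truth SDM Zn can be derived from a binary mask Yn, here is a sketch using SciPy's Euclidean distance transform; the sign convention (negative inside, positive outside) and the per-volume normalization are assumptions, so the paper's exact recipe may differ.

```python
import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_sdm(mask: np.ndarray) -> np.ndarray:
    """Derive a normalized SDM in [-1, 1] from a binary 3D mask."""
    mask = mask.astype(bool)
    if not mask.any() or mask.all():
        # No surface exists: return an all-zero map as a safe fallback.
        return np.zeros(mask.shape, dtype=np.float32)
    dist_out = distance_transform_edt(~mask)  # distance to object (outside voxels)
    dist_in = distance_transform_edt(mask)    # distance to background (inside voxels)
    # Positive outside, negative inside, ~0 near the surface.
    sdm = dist_out / dist_out.max() - dist_in / dist_in.max()
    return sdm.astype(np.float32)
```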
- Supervised Loss Ls: A Dice loss ldice and a mean-square-error loss lmse are employed for the segmentation output and the SDM output of the multi-task segmentation network, respectively.
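A plausible reconstruction of this supervised loss from the definitions above (α is an assumed balancing weight between the two terms):

$$L_s(\theta) = \sum_{n=1}^{N} \Big[ \ell_{dice}(M_n, Y_n) + \alpha \, \ell_{mse}(S_n, Z_n) \Big]$$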
- Adversarial Loss La: A GAN-style discriminator network is used to tell apart the predicted SDMs of the labeled set, which should be of high quality due to the supervision, from those of the unlabeled set; the predicted SDMs of both sets are produced by the SDM head of the segmentation network.
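Based on this description, the adversarial loss plausibly takes the standard minimax form, where D(·, ·; ζ) is the discriminator with parameters ζ and f_sdm(·; θ) is the SDM head with parameters θ (the paper's exact notation may differ):

$$L_a(\theta, \zeta) = \sum_{n=1}^{N} \log D(X_n, S_n; \zeta) + \sum_{m=N+1}^{N+M} \log\big(1 - D(X_m, S_m; \zeta)\big), \qquad S_i = f_{sdm}(X_i; \theta)$$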
- The discriminator consists of 5 convolution layers followed by an MLP. The network takes an SDM and the corresponding input volume as input, fuses them through the convolution layers, and predicts the probability of the pair coming from the labeled data.
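A rough PyTorch sketch of such a discriminator; only the overall structure (5 convolution layers plus an MLP over the fused SDM/volume pair) follows the text, while the channel widths, strides, and MLP sizes are assumptions.

```python
import torch
import torch.nn as nn

class SDMDiscriminator(nn.Module):
    """Discriminator over (SDM, volume) pairs: 5 conv layers + an MLP."""
    def __init__(self):
        super().__init__()
        chans = [2, 32, 64, 128, 256, 256]  # illustrative channel widths
        layers = []
        for c_in, c_out in zip(chans[:-1], chans[1:]):
            layers += [nn.Conv3d(c_in, c_out, kernel_size=3, stride=2, padding=1),
                       nn.LeakyReLU(0.2, inplace=True)]
        self.convs = nn.Sequential(*layers)
        self.mlp = nn.Sequential(
            nn.AdaptiveAvgPool3d(1), nn.Flatten(),
            nn.Linear(256, 64), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(64, 1), nn.Sigmoid(),
        )

    def forward(self, sdm: torch.Tensor, volume: torch.Tensor) -> torch.Tensor:
        x = torch.cat([sdm, volume], dim=1)  # fuse SDM and input volume
        return self.mlp(self.convs(x))       # probability of being labeled data
```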
- Overall Training Pipeline: The overall training objective V(θ, ζ) combines the supervised and adversarial losses defined above.
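The learning task can then plausibly be written as the minimax problem below, with β weighting the adversarial term (consistent with the warm-up on β described later):

$$\min_{\theta} \max_{\zeta} \; V(\theta, \zeta) = L_s(\theta) + \beta \, L_a(\theta, \zeta)$$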
- Alternating training is used: given a fixed discriminator, the segmentation network is trained. To simplify training, the first loss term in La (the one over the labeled set) is ignored.
- A surrogate loss, as in GANs, is used for the generator; the resulting learning problem for the segmentation network is given below.
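A plausible form, with the labeled term of La dropped and −log D used as the non-saturating surrogate for log(1 − D):

$$\min_{\theta} \; L_s(\theta) - \beta \sum_{m=N+1}^{N+M} \log D(X_m, S_m; \zeta)$$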
- Then, given a fixed segmentation network, the binary cross-entropy loss is used to train the discriminator (a minimal sketch follows).
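A minimal sketch of this discriminator step, assuming labeled pairs are assigned target 1 and unlabeled pairs target 0 (function and tensor names are hypothetical):

```python
import torch
import torch.nn.functional as F

def discriminator_loss(d_labeled: torch.Tensor, d_unlabeled: torch.Tensor) -> torch.Tensor:
    """BCE loss for D: labeled pairs are 'real' (1), unlabeled pairs 'fake' (0)."""
    real = F.binary_cross_entropy(d_labeled, torch.ones_like(d_labeled))
    fake = F.binary_cross_entropy(d_unlabeled, torch.zeros_like(d_unlabeled))
    return real + fake
```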
- An annealing strategy based on a time-dependent Gaussian warm-up function is used to slowly increase the loss weight β (a sketch is given below).
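A common form of such a warm-up, as used in temporal-ensembling-style semi-supervised methods, is sketched below; beta_max and the ramp length are assumed hyper-parameters, not the paper's reported values.

```python
import numpy as np

def beta_rampup(step: int, ramp_steps: int, beta_max: float = 0.1) -> float:
    """Gaussian warm-up: beta rises smoothly from ~0 to beta_max."""
    t = float(np.clip(step / ramp_steps, 0.0, 1.0))
    return beta_max * float(np.exp(-5.0 * (1.0 - t) ** 2))
```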
2. Results
2.1. SOTA Comparisons
- The first setting (Top) takes 20% of the training data as labeled data (16 labeled scans) and the rest as unlabeled data for semi-supervised training.
The proposed SASSNet outperforms all the other semi-supervised networks in both Dice (89.54%) and Jaccard (81.24%), and achieves competitive results on the other metrics. In particular, SASSNet surpasses UA-MT in Dice without resorting to a complex multi-network architecture.
- The second setting (Bottom): a more challenging setting in which only 8 labeled images are available for training.
SASSNet outperforms UA-MT by a large margin (Dice: +2.56% without NMS and +3.07% with NMS).
2.2. Qualitative Results
SASSNet tends to generate more foreground regions, which leads to slightly worse performance on ASD and 95HD. However, it also produces better segmentations that preserve the original object shape. By contrast, UA-MT often misses inner regions of target objects and generates irregular shapes.
2.3. Ablation Study
- The first row is a V-Net trained with only the labeled data, which is the base model.
- An SDM head is first added, denoted as V-Net+SDM; as shown in the second row, such joint learning improves segmentation results by 1.1% in Dice.
Then, unlabeled data and the adversarial loss are added, denoted as V-Net+SDM+GAN, which significantly improves the performance (+5.7% in Dice).
- The Mean Teacher (MT) framework (last two rows) is also evaluated. One is the original UA-MT and the other is the proposed segmentation network with the MT consistency loss.
The proposed SASSNet outperforms both methods with higher Dice and Jaccard scores.
Reference
[2020 MICCAI] [SASSNet]
Shape-Aware Semi-Supervised 3D Semantic Segmentation for Medical Images
4.5. Biomedical Image Semi-Supervised Learning
2019 [UA-MT] 2020 [SASSNet]