Review — A New Ensemble Learning Framework for 3D Biomedical Image Segmentation
Semi-Supervised Learning, Improved Results Using Unlabeled Data With Pseudo-Labels
In this story, A New Ensemble Learning Framework for 3D Biomedical Image Segmentation (NN-Fit), by the University of Notre Dame, is reviewed. In this paper:
- A fully convolutional network (FCN) based meta-learner is designed to learn how to improve the results of 2D and 3D models (base-learners).
- New training methods, i.e., random-fit and nearest-neighbor-fit (NN-fit), are proposed to reduce overfitting of the meta-learner.
This is a paper in 2019 AAAI. (Sik-Ho Tsang @ Medium)
Outline
- Problems in 3D Biomedical Image Segmentation
- Framework Overview
- Meta-Learner Training Using Pseudo-Labels
- Experimental Results
1. Problems in 3D Biomedical Image Segmentation
- Due to the limitations of both GPU memory and computing power, when designing 2D/3D CNNs for 3D biomedical image segmentation, the trade-off between the field of view and the utilization of inter-slice information in 3D images remains a major concern.
- In addition, datasets for 3D biomedical image segmentation are usually small, which easily leads to overfitting.
2. Framework Overview
- The proposed approach has two main components, as shown above:
- A group of 2D and 3D base-learners that are trained to explore the training data from different geometric perspectives.
- An ensemble learning framework that uses a deep learning based meta-learner to combine the results from the base-learners.
2.1. 2D and 3D Base-Learners
2.1.1. 2D Models
- The 2D model closely follows the architecture used in Suggestive Annotation (SA).
- It integrates recent advances in deep network design, such as batch normalization, residual connections and bottleneck design.
2.1.2. 3D Models
- As for the 3D model, DenseVoxNet is used.
- It is a 3D FCN architecture that fully incorporates 3D image cues and geometric cues.
- It uses dense connectivity to accelerate training, improve parameter and computational efficiency, and maintain abundant (both low- and high-complexity) features.
- It takes advantage of auxiliary side paths for deep supervision to improve the gradient flow.
2D models can have large fields of view in 2D slices while 3D models can better utilize 3D image information in a smaller field of view.
The mix of 2D and 3D base-learners creates the first level of diversity.
- To further boost diversity, multiple 2D views (representations) of the 3D images (e.g., xy, xz, and yz views) can be created.
Thus, in the framework, four base-learners are used: a 3D DenseVoxNet for utilizing full 3D information, and three 2D FCNs for large fields of view in the xy, xz, and yz planes.
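As a rough sketch of how the three 2D views can be produced from one volume, the NumPy snippet below re-slices a 3D array along each plane and maps per-slice results back to the original axis order. It assumes volumes are stored as volume[z, y, x]; make_views and restore_view are hypothetical helper names, not code from the paper.

```python
import numpy as np

# Assumption: a 3D volume is indexed as volume[z, y, x].

def make_views(volume):
    """Return stacks of 2D slices in the xy, xz and yz planes.

    The first axis of each returned stack iterates over the slices,
    so a 2D FCN can be applied to stack[k] for every k.
    """
    return {
        "xy": volume,                      # stack[z] = volume[z, :, :]  -> (Z, Y, X)
        "xz": volume.transpose(1, 0, 2),   # stack[y] = volume[:, y, :]  -> (Y, Z, X)
        "yz": volume.transpose(2, 0, 1),   # stack[x] = volume[:, :, x]  -> (X, Z, Y)
    }

def restore_view(stack, plane):
    """Map a per-slice result stack back to (Z, Y, X) order."""
    if plane == "xy":
        return stack
    if plane == "xz":
        return stack.transpose(1, 0, 2)    # (Y, Z, X) -> (Z, Y, X)
    if plane == "yz":
        return stack.transpose(1, 2, 0)    # (X, Z, Y) -> (Z, Y, X)
    raise ValueError(f"unknown plane: {plane}")

if __name__ == "__main__":
    vol = np.arange(24).reshape(2, 3, 4)
    views = make_views(vol)
    for plane, stack in views.items():
        print(plane, stack.shape, np.array_equal(restore_view(stack, plane), vol))
```

Restoring every view to the same (Z, Y, X) order means the outputs of the three 2D base-learners and the 3D base-learner are voxel-aligned, which is what a meta-learner combining them would need.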
2.2. Meta-Learner
- Given a set of image samples, X={x1, x2, …, xn}, and a set of base-learners, F={f1, f2, …, fm}, a pseudo-label set for each xi can be obtained as PLi={f1(xi), f2(xi), …, fm(xi)}.
- That is, the pseudo-labels are simply the outputs of the base-learners.
- The input of the meta-learner H has two parts: xi and S(PLi), where S is a function that forms a representation of PLi.
- Separate encoding blocks (i.e., DenseBlock 1.1 and DenseBlock 1.2) are used for extracting information from S(PLi) and xi, respectively, before the information fusion.
- The auxiliary loss in the side path can improve the gradient flow within the network.
- There are multiple design choices for constructing S, for example concatenation or averaging; averaging was shown to give slightly better results.
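To make PLi and the two choices of S concrete, here is a minimal NumPy sketch. It assumes each base-learner returns a per-voxel class-probability map of shape (C, Z, Y, X); the function names and the toy base-learners are hypothetical illustrations, not the paper's code.

```python
import numpy as np

# Assumption: each base-learner f maps an image x of shape (Z, Y, X)
# to a class-probability map of shape (C, Z, Y, X).

def pseudo_label_set(x, base_learners):
    """PL_i = {f_1(x_i), ..., f_m(x_i)}: one probability map per base-learner."""
    return [f(x) for f in base_learners]

def s_average(pl):
    """S by averaging: voxel-wise mean of the m maps, shape (C, Z, Y, X)."""
    return np.mean(np.stack(pl, axis=0), axis=0)

def s_concat(pl):
    """S by concatenation: the m maps stacked along channels, shape (m*C, Z, Y, X)."""
    return np.concatenate(pl, axis=0)

def toy_learner_a(x):
    """Hypothetical base-learner: constant 2-class map favoring class 0."""
    return np.stack([np.full(x.shape, 0.8), np.full(x.shape, 0.2)])

def toy_learner_b(x):
    """Hypothetical base-learner: constant 2-class map favoring class 1."""
    return np.stack([np.full(x.shape, 0.4), np.full(x.shape, 0.6)])

if __name__ == "__main__":
    x = np.zeros((4, 4, 4))
    pl = pseudo_label_set(x, [toy_learner_a, toy_learner_b])
    print(s_average(pl)[:, 0, 0, 0])
    print(s_concat(pl).shape)  # → (4, 4, 4, 4)
```

Averaging keeps the meta-learner's input size independent of the number of base-learners m, while concatenation grows the channel count with m.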
3. Meta-Learner Training Using Pseudo-Labels
- Instead of using the manually labelled ground truth to supervise the meta-learner training, the pseudo-labels produced by the base-learners are treated as ground truth.
- Because there are multiple possible targets (pseudo-labels) for the meta-learner to fit, the meta-learner is unlikely to overfit any fixed target.
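- Hence, unlabeled data can also be used: the trained base-learners can produce pseudo-labels for any image, whether or not it has a manual annotation.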
- The meta-learner training consists of two phases: (1) random-fit, and (2) nearest-neighbor-fit.
3.1. Random-Fit
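- In the first training phase (which aims to train the meta-learner H to reach a near-optimal solution), the following cross-entropy loss over all pseudo-labels is minimized:
$$\min_{\theta_H} \sum_{i=1}^{n} \sum_{j=1}^{m} l_{mce}\big(H(x_i, S(PL_i); \theta_H),\ f_j(x_i)\big)$$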
- where θH denotes the meta-learner's model parameters and lmce is a multi-class cross-entropy criterion.
- In the SGD-based optimization, for one image sample xi, the random-fit algorithm randomly chooses a pseudo-label from PLi and sets it as the current “ground truth” for xi (see Algorithm 1).
- This ensures that the supervision signal is not biased towards any particular base-learner.
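A minimal NumPy sketch of one random-fit step for a single sample follows. It assumes the meta-learner's prediction and each pseudo-label are class-probability maps of shape (C, ...) and that the chosen pseudo-label is turned into a hard label map with argmax; both are assumptions, and a real implementation would compute this loss inside a deep learning framework with automatic differentiation.

```python
import numpy as np

def cross_entropy(prob, label, eps=1e-7):
    """Mean multi-class cross-entropy between a probability map (C, ...)
    and an integer label map (...)."""
    picked = np.take_along_axis(prob, label[None, ...], axis=0)[0]
    return float(-np.mean(np.log(picked + eps)))

def random_fit_target(pl, rng):
    """Random-fit: pick one pseudo-label uniformly at random and use its
    hard (argmax) label map as the current 'ground truth'."""
    j = int(rng.integers(len(pl)))
    return pl[j].argmax(axis=0)

def random_fit_loss(pred, pl, rng):
    """Loss for one sample in one SGD step.

    pred is the meta-learner output H(x_i, S(PL_i)) as a (C, ...) map;
    a fresh pseudo-label is drawn every time this is called.
    """
    return cross_entropy(pred, random_fit_target(pl, rng))

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = np.array([[0.9, 0.9, 0.9], [0.1, 0.1, 0.1]])  # favors class 0
    b = 1.0 - a                                      # favors class 1
    pred = np.array([[0.6, 0.6, 0.6], [0.4, 0.4, 0.4]])
    print([random_fit_loss(pred, [a, b], rng) for _ in range(4)])
```

Because the target is drawn uniformly, averaging the loss over the random choice gives the mean of the losses against all m pseudo-labels, so no single base-learner is favored.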
3.2. Nearest-Neighbor-Fit (NN-Fit)
- In the second training phase, the meta-learner is trained to fit, for each sample, the pseudo-label nearest to its current prediction, which helps the training process converge.
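The NN-fit target choice can be sketched in the same NumPy style. Here "nearest" is assumed to mean the pseudo-label (as a hard argmax label map) with the lowest cross-entropy against the meta-learner's current prediction; the distance measure and the function name are assumptions for illustration.

```python
import numpy as np

def cross_entropy(prob, label, eps=1e-7):
    """Mean multi-class cross-entropy between a probability map (C, ...)
    and an integer label map (...)."""
    picked = np.take_along_axis(prob, label[None, ...], axis=0)[0]
    return float(-np.mean(np.log(picked + eps)))

def nn_fit_target(pred, pl):
    """NN-fit (assumed distance): among the hard label maps of the
    pseudo-labels, return the one with the lowest cross-entropy
    with respect to the meta-learner's current prediction pred."""
    targets = [p.argmax(axis=0) for p in pl]
    losses = [cross_entropy(pred, t) for t in targets]
    return targets[int(np.argmin(losses))]

if __name__ == "__main__":
    p0 = np.array([[0.9, 0.9, 0.9], [0.1, 0.1, 0.1]])  # labels [0, 0, 0]
    p1 = np.array([[0.9, 0.9, 0.1], [0.1, 0.1, 0.9]])  # labels [0, 0, 1]
    p2 = np.array([[0.1, 0.1, 0.1], [0.9, 0.9, 0.9]])  # labels [1, 1, 1]
    pred = np.array([[0.7, 0.7, 0.3], [0.3, 0.3, 0.7]])
    print(nn_fit_target(pred, [p0, p1, p2]))  # → [0 0 1]
```

Unlike random-fit, the target here presumably stops switching between base-learners once the prediction settles near one of them, which is one way to read the claim that this phase helps convergence.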
4. Experimental Results
- Since unlabeled data can be used for training, three settings are evaluated: supervised learning, semi-supervised learning, and the transductive setting.
4.1. Supervised Learning (Only Training Data)
- Without using unlabeled data, the proposed meta-learner outperforms the compared methods above on nearly all metrics and achieves the highest overall score: 0.215 (proposed) vs. -0.161 (DenseVoxNet), -0.036 (tri-planar), and 0.108 (VFN).
- As shown above, 2D and 3D base-learners already achieve better results.
- The proposed meta-learner further improves the accuracy of the base-learners, and also achieves a result that is considerably better than the known state-of-the-art methods (0.9967 vs. 0.9866).
4.2. Semi-Supervised Learning
- The training set of HVSMR 2016 is randomly split into two equal groups, Sa and Sb.
- Sa is used as labeled data, and Sb is used as unlabeled data.
- By leveraging unlabeled images, the proposed approach can improve the model accuracy and generalize well to unseen test data.
4.3. Transductive Setting
- The full training set is used to train the base-learners, while both the training and testing images are used to train the meta-learner (the testing images contribute only through their pseudo-labels).
- The transductive setting plays an important role in many biomedical image segmentation tasks (e.g., for making biomedical discoveries). For example, after biological experiments are finished, one may have all the raw images available and the sole remaining goal is to train a model to attain the best possible segmentation results.
- Improved results are obtained, as shown in the tables in Section 4.1.
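The three settings differ mainly in which images train which component. The sketch below summarizes that under stated assumptions: base-learners need ground truth, so in the semi-supervised case they are assumed to be trained on Sa only, while the meta-learner only needs pseudo-labels and can also see Sb or the test images. The function name, the integer image ids and the seeded shuffle are hypothetical.

```python
import random

def training_pools(train_ids, test_ids, setting, seed=0):
    """Return (base_ids, meta_ids): images used to train the base-learners
    (which need ground truth) and images used to train the meta-learner
    (which only needs pseudo-labels produced by the base-learners)."""
    train_ids, test_ids = list(train_ids), list(test_ids)
    if setting == "supervised":
        return train_ids, train_ids
    if setting == "semi-supervised":
        ids = train_ids[:]
        random.Random(seed).shuffle(ids)  # random, even split of the training set
        half = len(ids) // 2
        s_a, s_b = ids[:half], ids[half:]  # S_a: labeled, S_b: labels ignored
        return s_a, s_a + s_b
    if setting == "transductive":
        return train_ids, train_ids + test_ids  # test images have no ground truth
    raise ValueError(f"unknown setting: {setting}")

if __name__ == "__main__":
    for name in ("supervised", "semi-supervised", "transductive"):
        base, meta = training_pools(range(10), range(10, 20), name)
        print(name, len(base), len(meta))
```

In every setting the meta-learner's pool is a superset of the base-learners' pool; the extra images are exactly the ones the pseudo-label training makes usable.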
4.4. Ablation Study
- Different combinations of experimental settings are compared, e.g., training with GT data or PL data, using the test data for training (transductive learning), and using random-fit alone or both random-fit and NN-fit.
- For transductive learning, S9, which uses PL data only, obtains the best results, even better than S5, which uses GT+PL data.
- (There are also other findings for this table, if interested, please feel free to read the paper.)
Reference
[2019 AAAI] [NN-Fit]
A New Ensemble Learning Framework for 3D Biomedical Image Segmentation
Biomedical Image Segmentation
2015: [U-Net]
2016: [CUMedVision1] [CUMedVision2 / DCAN] [CFS-FCN] [U-Net+ResNet] [MultiChannel] [V-Net] [3D U-Net]
2017: [M²FCN] [Suggestive Annotation (SA)] [3D U-Net+ResNet] [Cascaded 3D U-Net] [DenseVoxNet]
2018: [QSA+QNT] [Attention U-Net] [RU-Net & R2U-Net] [VoxResNet] [UNet++] [H-DenseUNet]
2019: [DUNet] [NN-Fit]
2020: [MultiResUNet] [UNet 3+] [VGGNet for COVID-19] [Dense-Gated U-Net (DGNet)]