Brief Review — MultiMix: Sparingly Supervised, Extreme Multitask Learning From Medical Images
MultiMix, Multi-Task Learning (MTL), With Pseudo-Labeling and Saliency Map
MultiMix: Sparingly Supervised, Extreme Multitask Learning From Medical Images
MultiMix, by Saratoga High School, Stanford University, University of California, and VoxelCloud, Inc.
2021 ISBI (Sik-Ho Tsang @ Medium)
- MultiMix is proposed. It jointly learns disease classification and anatomical segmentation in a sparingly supervised manner, and it preserves explainability through a saliency bridge between the two tasks.
- A U-Net-like encoder-decoder architecture with skip connections is used for image deconstruction and reconstruction.
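The data flow of a U-Net-like multitask network can be sketched in NumPy. This is a toy, not the MultiMix architecture. Several details are assumptions made only for illustration: the two-level depth, the 1x1 convolutions, the random weights, and the placement of a classification head on the encoder bottleneck (global average pooling followed by a linear layer). The function names are hypothetical.

```python
import numpy as np

def avg_pool2(x):
    """2x2 average pooling on a (C, H, W) array; H and W must be even."""
    c, h, w = x.shape
    return x.reshape(c, h // 2, 2, w // 2, 2).mean(axis=(2, 4))

def upsample2(x):
    """Nearest-neighbour 2x upsampling on a (C, H, W) array."""
    return x.repeat(2, axis=1).repeat(2, axis=2)

def conv1x1(x, weight):
    """1x1 convolution: mixes channels with a weight of shape (C_out, C_in)."""
    return np.einsum("oc,chw->ohw", weight, x)

def relu(x):
    return np.maximum(x, 0.0)

def tiny_multitask_unet(image, n_classes=2, seed=0):
    """Forward pass of a toy two-level U-Net with an extra classification head.

    Weights are random (hypothetical); only the data flow is illustrated.
    """
    rng = np.random.default_rng(seed)
    c_in = image.shape[0]
    # Encoder: features at full resolution, then at half resolution.
    enc1 = relu(conv1x1(image, rng.standard_normal((8, c_in))))                 # (8, H, W)
    bottleneck = relu(conv1x1(avg_pool2(enc1), rng.standard_normal((16, 8))))  # (16, H/2, W/2)
    # Assumed classification head: global average pooling + linear layer.
    pooled = bottleneck.mean(axis=(1, 2))                                       # (16,)
    class_logits = rng.standard_normal((n_classes, 16)) @ pooled                # (n_classes,)
    # Decoder: upsample, then concatenate the skip connection from enc1.
    dec1 = np.concatenate([upsample2(bottleneck), enc1], axis=0)                # (24, H, W)
    seg_logits = conv1x1(dec1, rng.standard_normal((1, 24)))                    # (1, H, W)
    return class_logits, seg_logits

cls_logits, seg_logits = tiny_multitask_unet(np.full((1, 8, 8), 0.5))
print(cls_logits.shape, seg_logits.shape)  # (2,) (1, 8, 8)
```

The skip connection is the concatenation in `dec1`. It lets the decoder combine upsampled coarse features with full-resolution encoder features. In this sketch, both task outputs come from one shared encoder.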
- For sparingly supervised classification, MultiMix leverages data augmentation and pseudo-labeling.
- Inspired by FixMatch, two separate augmentations are applied to each unlabeled image.
- First, the unlabeled image is weakly augmented. The current model's prediction on this weakly augmented version is taken as the image's pseudo-label.
- Second, the same unlabeled image is strongly augmented. A classification loss is then computed between the model's prediction on the strongly augmented image and the pseudo-label from the weakly augmented image.
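- Finally, the classification loss is:
$$\mathcal{L}_c = \mathcal{L}_l(\hat{y}, y) + \lambda\,\mathbb{1}\big[\max(\hat{y}^{w}) \ge t\big]\,\mathcal{L}_u\big(\arg\max(\hat{y}^{w}),\ \hat{y}^{s}\big)$$
- where Ll is the supervised cross-entropy loss between the prediction ŷ and the label y of a labeled image. The terms ŷ^w and ŷ^s are the model's predictions on the weakly and strongly augmented versions of an unlabeled image, and λ weights the unsupervised term.
- Lu is the unsupervised loss. An image-pseudo-label pair is retained only if the confidence with which the model generates the pseudo-label reaches a tuned threshold t. This prevents the model from learning from incorrect or poor-quality labels.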
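Below is a minimal NumPy sketch of this FixMatch-style loss, with classifier logits assumed as inputs. The names `classification_loss`, `t` and `lam` are hypothetical. The default `t=0.95` is FixMatch's published threshold and is not necessarily the value tuned for MultiMix. As in FixMatch, the masked unsupervised term is averaged over the whole unlabeled batch, not only over the retained images.

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def cross_entropy(logits, targets):
    """Mean cross-entropy between logits (N, K) and integer targets (N,)."""
    return -log_softmax(logits)[np.arange(len(targets)), targets].mean()

def classification_loss(lab_logits, labels, weak_logits, strong_logits,
                        t=0.95, lam=1.0):
    """FixMatch-style Lc = Ll + lam * Lu (t and lam are hypothetical defaults)."""
    # Supervised term Ll: cross-entropy on the labeled batch.
    l_sup = cross_entropy(lab_logits, labels)
    # Pseudo-labels come from the weakly augmented view; in a real framework
    # this branch would be detached from the gradient.
    weak_probs = np.exp(log_softmax(weak_logits))
    pseudo = weak_probs.argmax(axis=1)
    keep = weak_probs.max(axis=1) >= t  # confidence threshold t
    # Unsupervised term Lu: cross-entropy of the strong view against the
    # pseudo-labels, zeroed for low-confidence images, averaged over the batch.
    per_image = -log_softmax(strong_logits)[np.arange(len(pseudo)), pseudo]
    l_unsup = (per_image * keep).mean()
    return l_sup + lam * l_unsup
```

In the following example, uniform weak predictions never pass the threshold, so only the supervised term remains: `classification_loss(np.zeros((2, 2)), np.array([0, 1]), np.zeros((3, 2)), np.zeros((3, 2)))` equals log 2.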
- Saliency maps are also generated for the predicted classes from the gradients of the encoder. The segmentation images do not necessarily show pneumonia, which is the target of the classification task. Even so, the generated maps highlight the lungs and are produced at the final segmentation resolution. The authors hypothesize that these saliency maps can guide segmentation during the decoder phase and improve it while learning from limited labeled data.
- The generated saliency maps are concatenated with the input images, downsampled, and added to the feature maps fed into the first decoder stage.
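The bridge can be illustrated with a toy NumPy sketch. Everything in it is hypothetical:

- a one-layer encoder whose gradient is written out analytically, standing in for autograd on the real encoder;
- plain gradient magnitude as the saliency method;
- average pooling as the downsampling step;
- a 1x1 projection `proj` that makes channel counts match before the addition.

MultiMix's actual saliency method and fusion layer may differ.

```python
import numpy as np

def class_scores(image, a, b, w_cls):
    """Toy encoder + classifier (hypothetical).

    Per-pixel features f_k = relu(a_k * x + b_k) act as a 1x1 convolution on a
    single-channel image; global average pooling and a linear layer give logits.
    """
    pre = a[:, None, None] * image[None] + b[:, None, None]   # (K, H, W)
    return w_cls @ np.maximum(pre, 0.0).mean(axis=(1, 2))     # (n_classes,)

def saliency_map(image, a, b, w_cls):
    """Vanilla gradient saliency |d score_c / d image| for the predicted class c."""
    h, w = image.shape
    c = int(class_scores(image, a, b, w_cls).argmax())
    pre = a[:, None, None] * image[None] + b[:, None, None]
    active = (pre > 0).astype(float)                           # ReLU derivative
    # Chain rule: sum over channels k of w_cls[c, k] * a_k * relu'(pre) / (H * W).
    grad = np.einsum("k,khw->hw", w_cls[c] * a, active) / (h * w)
    return np.abs(grad), c

def saliency_bridge(image, saliency, decoder_in, proj):
    """Concatenate image and saliency, downsample, and add to decoder features.

    proj (C_dec, 2) is a hypothetical 1x1 projection that matches channel counts.
    """
    x = np.stack([image, saliency])                            # (2, H, W)
    _, hd, wd = decoder_in.shape
    fh, fw = image.shape[0] // hd, image.shape[1] // wd
    pooled = x.reshape(2, hd, fh, wd, fw).mean(axis=(2, 4))    # (2, hd, wd)
    return decoder_in + np.einsum("oc,chw->ohw", proj, pooled)
```

The saliency map has the same resolution as the input. Only after concatenation is it pooled down to the first decoder stage's spatial size.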
- To ensure consistency, the KL divergence is also computed between segmentation predictions for labeled and unlabeled examples.
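- The segmentation loss is:
$$\mathcal{L}_s = \alpha\,\mathcal{L}_l(\hat{y}, y) + \beta\,\mathcal{L}_u(\hat{y}_l, \hat{y}_u)$$
- where Ll is the supervised Dice loss between the predicted mask ŷ and the reference mask y. The terms ŷ_l and ŷ_u are the segmentation predictions for labeled and unlabeled examples, and α and β weight the two terms.
- Lu is the unsupervised segmentation loss, presumably the KL divergence described above.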
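Both terms can be sketched in NumPy under stated assumptions. The soft Dice loss follows its standard formulation. This review does not describe how labeled and unlabeled predictions are paired for the KL term. The sketch therefore assumes predicted foreground-probability maps of equal shape, compared pixel by pixel as Bernoulli distributions. The weights `alpha` and `beta` and all function names are hypothetical.

```python
import numpy as np

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss 1 - (2|P*G| + eps) / (|P| + |G| + eps) on foreground probabilities."""
    inter = (pred * target).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def bernoulli_kl(p, q, eps=1e-7):
    """Mean per-pixel KL(p || q) between two foreground-probability maps."""
    p = np.clip(p, eps, 1 - eps)  # avoid log(0)
    q = np.clip(q, eps, 1 - eps)
    return (p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))).mean()

def segmentation_loss(pred_lab, mask_lab, pred_unlab, alpha=1.0, beta=1.0):
    """Ls = alpha * Dice(labeled) + beta * KL(labeled preds || unlabeled preds)."""
    return alpha * dice_loss(pred_lab, mask_lab) + beta * bernoulli_kl(pred_lab, pred_unlab)
```

The Dice term only needs the few labeled masks. The KL term needs no labels at all, which is why it can use the unlabeled images.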
- Training: All models (single-task or multitask) were trained with varying segmentation dataset sizes |Dsl| (10, 50, full) and classification dataset sizes |Dcl| (100, 1000, full). Each experiment was repeated 5 times.
- The models were trained on the CheX and JSRT datasets.
- In Domain: For classification, the proposed semi-supervised algorithm significantly improves performance over the baseline model.
- For segmentation, the saliency bridge is the primary addition and yields large improvements over the baseline U-Net and U-MTL. With the smallest |Dsl|, MultiMix achieves a 30% performance gain over its counterparts, demonstrating the effectiveness of the model.
- Cross Domain: The trends are similar to the in-domain results.
- The above figure shows that the proposed model is more consistent than the baselines.
- The visualizations of segmented lung boundaries also show that MultiMix agrees better with the reference masks than the other models do.