Brief Review — COVID-MTL: Multitask learning with Shift3D and random-weighted loss for COVID-19 diagnosis and severity assessment
COVID-MTL, Simultaneous Detection and Severity Assessment of COVID-19
COVID-MTL: Multitask Learning with Shift3D and Random-Weighted Loss for COVID-19 Diagnosis and Severity Assessment,
COVID-MTL, by The University of Sydney, Shanghai Jiao Tong University, Shandong First Medical University, and Shandong Academy of Medical Sciences
2022 J. Pattern Recognition (Sik-Ho Tsang @ Medium)
- COVID-MTL learns different COVID-19 tasks in parallel through a novel random-weighted loss function, which assigns learning weights drawn from a Dirichlet distribution to prevent any single task from dominating training.
- A new 3D real-time augmentation algorithm (Shift3D) is also proposed, which introduces space variances for 3D CNN components by shifting low-level feature representations of volumetric inputs in three dimensions.
- COVID-MTL consists of six major components for diagnosis and severity assessment on COVID-19 CT inputs.
- An unsupervised 3D lung segmentation module is first used to extract lung volumes from chest CT scans. (More details in Section 1.2 below.)
- Then, a feature extractor obtains high-throughput CT features, including intensity, texture, and wavelet features, from the segmented lung volumes. (More details in Section 1.3 below.)
- Next, the segmented lung volumes and the extracted lung features are fed into a ShiftNet3D and a feed-forward neural network (FNN), respectively. ShiftNet3D includes a Shift3D layer, which boosts network performance by introducing space variances on low-level feature representations of the volumetric inputs, and 8 consecutive 3D Fire modules (the building block of SqueezeNet), which serve as hidden layers shared by all tasks. (More details in Section 1.4 below.)
- To learn task-specific representations, each task has its own output layers (task-specific layers, TSL) and its own loss function.
- High-level feature representations obtained from the CT imaging features through the auxiliary FNN (AFNN) are concatenated with each TSL layer to enhance performance. As such, there are two backpropagation paths in the overall MTL network: one through ShiftNet3D and one through the AFNN.
- Last, a random-weighted loss function computes the combined task loss, so that different COVID-19 tasks can be trained simultaneously using the weighted total loss as guidance. (More details in Section 1.5 below.)
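The data flow above can be illustrated with a minimal NumPy sketch of the task-specific end of the network. It assumes the shared ShiftNet3D features and the selected CT features are already computed (random stand-ins here), uses an assumed one-layer ReLU AFNN and one linear softmax head per task, and the task names and class counts are hypothetical; the actual layer sizes are not given in this review.

```python
import numpy as np


def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def covid_mtl_heads(shared_feat, ct_feat, afnn_w, task_heads):
    """Hypothetical forward pass through the AFNN and the task-specific (TSL) heads.

    shared_feat: (N, H) output of the shared hidden layers (stand-in for ShiftNet3D).
    ct_feat:     (N, F) selected high-throughput CT features of the same cases.
    afnn_w:      (F, A) weights of an assumed one-layer auxiliary FNN.
    task_heads:  dict mapping task name -> (W, b), with W of shape (H + A, C_task).
    """
    aux = np.maximum(ct_feat @ afnn_w, 0.0)             # AFNN features (ReLU)
    fused = np.concatenate([shared_feat, aux], axis=1)  # same fused input for every head
    return {name: softmax(fused @ W + b) for name, (W, b) in task_heads.items()}


# Usage with random stand-in inputs (sizes and task names are illustrative only).
rng = np.random.default_rng(0)
N, H, F, A = 4, 16, 10, 8
heads = {
    "detection_vs_radiology": (rng.normal(size=(H + A, 2)), np.zeros(2)),
    "severity": (rng.normal(size=(H + A, 4)), np.zeros(4)),
}
probs = covid_mtl_heads(
    rng.normal(size=(N, H)), rng.normal(size=(N, F)), rng.normal(size=(F, A)), heads
)
```

Because the AFNN output is concatenated into every head, the gradient of each task's loss would flow back into both the image branch and the CT-feature branch, which corresponds to the two backpropagation paths mentioned above.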
1.2. Unsupervised 3D Lung Segmentation
- An active contour-based algorithm is proposed to refine the initial segmentation results produced by the classical method.
- The inflated contours of the initially segmented lungs are used as seeds for the refinement.
- (The refinement is not described in detail in the text, but as shown in the pseudocode above, a snake (active contour model) is used.)
- The proposed refinement algorithm can accurately detect white lung areas from COVID-19 CT studies in different scenarios.
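The review gives no details of the refinement beyond the inflated-contour seeding, so the NumPy sketch below only illustrates that seeding step: an initial 3D lung mask is inflated by an assumed margin using 6-connected dilation, and the outer surface voxels of the inflated mask are returned as seeds. The margin, the connectivity, and the function names are assumptions, and the subsequent snake evolution is not shown.

```python
import numpy as np


def dilate6(mask, iterations=1):
    """Binary dilation of a 3D mask with a 6-connected (face-neighbour) structuring element."""
    out = mask.astype(bool)
    for _ in range(iterations):
        grown = out.copy()
        for axis in range(3):
            for step in (1, -1):
                shifted = np.roll(out, step, axis=axis)
                # np.roll wraps around; clear the slice that wrapped in from the far border.
                edge = [slice(None)] * 3
                edge[axis] = 0 if step == 1 else -1
                shifted[tuple(edge)] = False
                grown |= shifted
        out = grown
    return out


def inflated_contour_seeds(initial_mask, margin=2):
    """Inflate an initial lung mask and return (inflated mask, its boundary voxels as seeds)."""
    inflated = dilate6(initial_mask, iterations=margin)
    eroded = ~dilate6(~inflated, iterations=1)  # erosion via dilation of the complement
    seeds = inflated & ~eroded                  # voxels on the outer surface of the inflated mask
    return inflated, seeds
```

In practice the same dilation could be done with `scipy.ndimage.binary_dilation` and a structure from `scipy.ndimage.generate_binary_structure(3, 1)`; the hand-rolled version just keeps the sketch dependency-free.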
1.3. High-Throughput Lung Features
- After the lung volumes were automatically segmented by the proposed unsupervised algorithm, a total of 375 high-throughput lung features, including First Order Statistics, Gray Level Co-occurrence Matrix (GLCM), Gray Level Run Length Matrix (GLRLM), Gray Level Size Zone Matrix (GLSZM), and Wavelet features, were extracted from the corresponding lung volumes for the cohort study.
- According to the results section, feature importance was computed for each COVID-19 task after training a machine learning model (LightGBM, or LGBM, in this work), and the top 10 most important imaging features were selected for each task.
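As a rough illustration of the two steps above, the sketch below computes a handful of simplified first-order statistics over the lung voxels and then keeps the top-k features by importance score. It is not the paper's 375-feature extractor (texture and wavelet features are omitted), the histogram bin count and function names are assumptions, and in practice the importance scores would come from the trained model, e.g. the `feature_importances_` attribute of a LightGBM scikit-learn estimator.

```python
import numpy as np


def first_order_features(volume, lung_mask, bins=32):
    """A few simplified first-order intensity statistics over the lung voxels."""
    x = volume[lung_mask].astype(float)
    mean, std = x.mean(), x.std()
    z = (x - mean) / std if std > 0 else np.zeros_like(x)
    hist, _ = np.histogram(x, bins=bins)
    p = hist[hist > 0] / hist.sum()  # non-empty histogram bins as probabilities
    return {
        "mean": float(mean),
        "std": float(std),
        "skewness": float(np.mean(z ** 3)),
        "kurtosis": float(np.mean(z ** 4)),
        "entropy": float(-(p * np.log2(p)).sum()),
    }


def top_k_features(names, importances, k=10):
    """Return the k feature names with the highest importance scores, best first."""
    order = np.argsort(importances)[::-1][:k]
    return [names[i] for i in order]
```

Selecting per task, as described above, would simply mean calling `top_k_features` once per task with that task's importance scores.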
1.4. Shift3D
- Shift3D is a 3D real-time augmentation method that introduces space variances by randomly shifting low-level feature representations of the volumetric inputs in three dimensions (i.e., 6 directions).
- It is motivated by the fact that the spatial location of human organs in CT scans varies from one case to another, even across different scans of the same patient.
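A minimal NumPy sketch of the Shift3D idea, under stated assumptions: a (N, C, D, H, W) low-level feature map is shifted by a random integer offset in [-max_shift, max_shift] along each of the three spatial axes (a positive or negative offset per axis gives the 6 directions), vacated voxels are zero-filled, one offset is shared by the whole batch, and the layer is the identity at inference. The shift range, the padding scheme, and the per-batch sampling are assumptions, not details confirmed by this review.

```python
import numpy as np


def shift3d(x, offsets):
    """Shift a (N, C, D, H, W) feature map by integer offsets along D, H, W, zero-filling."""
    out = np.zeros_like(x)
    src = [slice(None), slice(None)]
    dst = [slice(None), slice(None)]
    for size, off in zip(x.shape[2:], offsets):
        if abs(off) >= size:
            return out  # everything is shifted out of the volume
        if off >= 0:
            src.append(slice(0, size - off))
            dst.append(slice(off, size))
        else:
            src.append(slice(-off, size))
            dst.append(slice(0, size + off))
    out[tuple(dst)] = x[tuple(src)]
    return out


def random_shift3d(x, max_shift=2, rng=None, training=True):
    """Randomly shift low-level features in 3D during training; identity at inference."""
    if not training:
        return x
    rng = np.random.default_rng() if rng is None else rng
    offsets = rng.integers(-max_shift, max_shift + 1, size=3)  # one offset per spatial axis
    return shift3d(x, tuple(int(o) for o in offsets))
```

Applying the shift to feature maps inside the network, rather than to the raw CT volume, is what makes it "real-time": a new offset is drawn on every forward pass at no extra preprocessing cost.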
1.5. Random-Weighted Multitask Loss
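- The random weights are drawn from the Dirichlet distribution so that each of the K tasks has a chance to be prioritized, provided the number of joint-training iterations is large.
- The Dirichlet distribution has the probability density function (PDF):
$$f(w_1,\dots,w_K;\alpha_1,\dots,\alpha_K)=\frac{1}{B(\boldsymbol{\alpha})}\prod_{k=1}^{K}w_k^{\alpha_k-1},\qquad w_k\ge 0,\quad \sum_{k=1}^{K}w_k=1$$
- where B(α) is the multivariate Beta function, which is expressed through the gamma function Γ:
$$B(\boldsymbol{\alpha})=\frac{\prod_{k=1}^{K}\Gamma(\alpha_k)}{\Gamma\left(\sum_{k=1}^{K}\alpha_k\right)}$$
- The objective function for each task k is a cross-entropy loss, where $C_k$ is the number of classes of task k, $y_{k,c}$ the one-hot label, and $\hat{y}_{k,c}$ the predicted probability:
$$\mathcal{L}_k=-\sum_{c=1}^{C_k}y_{k,c}\log\hat{y}_{k,c}$$
- The total loss function of an MTL model with random-weighted loss is:
$$\mathcal{L}=\sum_{k=1}^{K}w_k\mathcal{L}_k,\qquad (w_1,\dots,w_K)\sim\mathrm{Dir}(\alpha_1,\dots,\alpha_K)$$
- For the special case K = 2, the Dirichlet distribution reduces to a Beta distribution, so the weights of the two learning tasks are $w_1\sim\mathrm{Beta}(\alpha_1,\alpha_2)$ and $w_2=1-w_1$:
$$\mathcal{L}=w_1\mathcal{L}_1+(1-w_1)\mathcal{L}_2$$
- To avoid the potentially heavy fluctuation of a single Dirichlet draw, one can draw the random weights n = 2 times and average the results, so the total loss function is finally modeled as:
$$\mathcal{L}=\frac{1}{n}\sum_{j=1}^{n}\sum_{k=1}^{K}w_k^{(j)}\mathcal{L}_k,\qquad n=2$$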
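The averaged random-weighted loss can be sketched in NumPy as follows. It assumes a symmetric Dirichlet with all concentration parameters equal to 1 by default (the review does not state the α values), a per-task mean cross-entropy, and n_draws = 2 as described above. In an actual training loop the sampled weights would be treated as constants and the loss computed on framework tensors so that gradients reach the network.

```python
import numpy as np


def cross_entropy(probs, labels, eps=1e-12):
    """Mean cross-entropy of predicted class probabilities (N, C) against integer labels (N,)."""
    p = np.clip(probs[np.arange(len(labels)), labels], eps, 1.0)
    return float(-np.log(p).mean())


def random_weighted_loss(task_losses, alpha=None, n_draws=2, rng=None):
    """Combine K task losses with Dirichlet-sampled weights, averaged over n_draws draws.

    Returns (total_loss, weights), where weights has shape (n_draws, K) and each row sums to 1.
    """
    losses = np.asarray(task_losses, dtype=float)
    alpha = np.ones(losses.size) if alpha is None else np.asarray(alpha, dtype=float)
    rng = np.random.default_rng() if rng is None else rng
    weights = rng.dirichlet(alpha, size=n_draws)  # one weight vector per draw
    total = (weights @ losses).mean()             # (1/n) * sum_j sum_k w_k^(j) * L_k
    return float(total), weights


# Usage: two tasks (K = 2), so each weight row is (w_1, 1 - w_1) with w_1 ~ Beta(alpha_1, alpha_2).
rng = np.random.default_rng(0)
l1 = cross_entropy(np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([0, 1]))
l2 = cross_entropy(np.array([[0.6, 0.3, 0.1]]), np.array([2]))
total, w = random_weighted_loss([l1, l2], rng=rng)
```

Since every weight vector lies on the probability simplex, the combined loss is always a convex combination of the task losses, so no task's loss can be scaled up beyond its own value; only its relative priority changes from iteration to iteration.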
2. Results
In comparison to single-task models, the MTL approach achieved around 1.5–6% higher accuracy for detection against radiology and significantly higher diagnosis performance against nucleic acid tests, e.g., an accuracy improvement of over 3% compared to ShiftNet3D, and of around 7–9% compared to ResNet3D and SqueezeNet3D.
- COVID-MTL with Shift3D consistently outperforms the models without Shift3D.
- Both the classical method and the proposed algorithm achieved high segmentation performance.
The proposed refinement method consistently improved on the state of the art, and the gains are expected to come from reduced under-segmentation in white lung areas.
- ShiftNet3D performed similarly to its backbone model (SqueezeNet3D).
In comparison, COVID-MTL achieved a slight performance boost, with an AUC of 0.800 ± 0.020 and an accuracy of 66.67%.
- (There are still a lot of experimental results that are not shown here. Please feel free to read the paper directly if interested.)