Review — UNeXt: MLP-Based Rapid Medical Image Segmentation Network
UNeXt: MLP-Based Rapid Medical Image Segmentation Network,
UNeXt, by Johns Hopkins University,
2022 MICCAI, Over 50 Citations (Sik-Ho Tsang @ Medium)
- UNeXt is proposed with convolutions used in the early stages and MLPs in the latent stages.
- A tokenized MLP block is proposed to tokenize and project the convolutional features and MLPs are used to model the representation.
- The channels of the inputs are shifted so that the MLPs focus on learning local dependencies.
1.1. Overall Architecture
- UNeXt is an encoder-decoder architecture with two stages: 1) Convolutional stage, and a 2) Tokenized MLP stage.
- The input image is passed through the encoder where the first 3 blocks are convolutional and the next 2 are Tokenized MLP blocks.
- The decoder has 2 Tokenized MLP blocks followed by 3 convolutional blocks.
- Each encoder block reduces the feature resolution by a factor of 2 using a max-pooling layer with a 2×2 pool window, and each decoder block increases the feature resolution by a factor of 2 using bilinear interpolation.
- Skip connections are also included between the encoder and decoder. The number of channels across each block is a hyperparameter denoted as C1 to C5. In experiments, C1=32, C2=64, C3=128, C4=160, and C5=256 unless stated otherwise.
- Each convolutional block is equipped with a convolution layer, a batch normalization layer, and ReLU activation. A kernel size of 3×3 and a stride of 1 are used.
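To make the bookkeeping above concrete, here is a minimal sketch (assuming a hypothetical 256×256 input, chosen only for illustration) of how the spatial resolution and the channel widths C1–C5 evolve through the five encoder blocks:

```python
channels = [32, 64, 128, 160, 256]  # C1..C5 from the experiments above

def encoder_shapes(h, w, channels):
    """Return the (height, width, channels) of the feature map after each
    encoder block; every block halves the spatial resolution."""
    shapes = []
    for c in channels:
        h, w = h // 2, w // 2  # 2x2 max-pooling halves each spatial dim
        shapes.append((h, w, c))
    return shapes

for i, s in enumerate(encoder_shapes(256, 256, channels), start=1):
    print(f"after encoder block {i}: {s}")  # e.g. block 5 -> (8, 8, 256)
```

The decoder simply mirrors this sequence, doubling the resolution with bilinear interpolation at each of its five blocks.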
1.2. Shifted MLP
- In the shifted MLP, the channels of the conv features are first shifted along an axis before tokenizing. This helps the MLP focus on only certain locations of the conv features, thus inducing locality into the block. The authors note that the intuition here is similar to the window-based attention of the Swin Transformer.
- As the Tokenized MLP block has 2 MLPs, the features are shifted across the width in one and across the height in the other, like the axial attention in Axial-DeepLab.
- The features are split into h different partitions and shifted by j=5 locations along the specified axis.
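A minimal NumPy sketch of the shift idea: the channel dimension is split into partitions and each partition is rolled by a different offset along the chosen spatial axis. The partition count and the per-partition offsets here are illustrative stand-ins, not the paper's exact implementation.

```python
import numpy as np

def shift_features(x, partitions=5, axis=2):
    """Split the channel dim of x (shape (C, H, W)) into `partitions` chunks
    and roll each chunk by a different offset along `axis`
    (1 = height, 2 = width)."""
    chunks = np.array_split(x, partitions, axis=0)
    offsets = range(-(partitions // 2), partitions // 2 + 1)  # e.g. -2..2
    shifted = [np.roll(c, off, axis=axis) for c, off in zip(chunks, offsets)]
    return np.concatenate(shifted, axis=0)

x = np.arange(5 * 4 * 4, dtype=float).reshape(5, 4, 4)
y = shift_features(x, partitions=5, axis=2)  # shift across width
```

Because each partition sees the feature map at a different offset, the MLP that follows mixes information from neighboring spatial locations, which is what induces the locality described above.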
1.3. Tokenized MLP Stage
- To tokenize, a convolution with a kernel size of 3 is first used to change the number of channels to the embedding dimension E (the number of tokens).
- Then, these tokens are passed to a shifted MLP (across width) with an MLP hidden dimension of H=768 unless stated otherwise.
- Next, a depthwise convolutional layer (DWConv) is used. It helps encode positional information, as suggested in SegFormer, and performs better than ViT's positional encoding when the train and test resolutions differ. It also uses fewer parameters and hence increases efficiency.
- GELU is used as the activation, a smoother alternative that is found to perform better; ViT and BERT also use GELU.
- The features are passed through another shifted MLP (across height) that converts the dimensions from H to O (the output dimension).
- A residual connection adds the original tokens to the output. Layer normalization (LN) is then applied, and the output features are passed to the next block.
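The sequence above can be sketched in NumPy. To keep the sketch short, the shift and DWConv steps are replaced by simple stand-ins (a plain `np.roll` and an omitted conv), and O is taken equal to E so the residual connection type-checks — these are simplifications for illustration, not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def tokenized_mlp_block(tokens, W1, W2):
    """tokens: (N, E) token matrix. W1: (E, H_dim), W2: (H_dim, E).
    Follows the order described above, with stand-ins for shift/DWConv."""
    x = np.roll(tokens, 1, axis=0)   # stand-in for the shift across width
    x = x @ W1                       # first MLP: E -> H_dim
    # (a real block applies a depthwise conv here to encode position)
    x = gelu(x)
    x = np.roll(x, 1, axis=0)        # stand-in for the shift across height
    x = x @ W2                       # second MLP: H_dim -> E (O = E here)
    return layer_norm(tokens + x)    # residual, then LayerNorm

E, H_dim, N = 16, 32, 64
tokens = rng.standard_normal((N, E))
out = tokenized_mlp_block(tokens,
                          rng.standard_normal((E, H_dim)) * 0.1,
                          rng.standard_normal((H_dim, E)) * 0.1)
```

The output has the same (N, E) shape as the input tokens, so blocks of this form can be stacked directly.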
1.4. Loss Function
- A combination of binary cross entropy (BCE) and dice loss is used: L = 0.5 · BCE(ŷ, y) + Dice(ŷ, y).
2.1. SOTA Comparisons
- UNeXt is clearly the most efficient of all the compared networks in terms of computational complexity.
- Swin-Unet (Not shown in the figure) is heavy with 41.35 M parameters and also computationally complex with 11.46 GFLOPs.
- MLP-Mixer focuses on channel mixing and token mixing to learn a good representation.
- The authors also experimented with MLP-Mixer as the encoder and a normal convolutional decoder, but the performance was not optimal for segmentation, and the model was still heavy, with around 11M parameters.
2.2. Qualitative Results
UNeXt produces competitive segmentation predictions compared to the other methods.
2.3. Ablation Study
- When the depth is reduced and only a 3-level deep architecture (basically the Conv stage of UNeXt) is used, the number of parameters and the complexity drop significantly, but the performance also drops by 4%.
But when the tokenized MLP block is used, it improves the performance significantly.
Increasing the channels (UNeXt-L) further improves the performance while adding on to computational overhead.
- Decreasing the channels (UNeXt-S) reduces the performance, though not drastically, while giving a very lightweight model.