Review — Recurrent U-Net for Resource-Constrained Segmentation
1. Recurrent U-Net (R-UNet)
The goal is to operate in resource-constrained environments and keep the model relatively simple.
1.1. (a) & (b) Overall Architecture
- U-Net is used as the base architecture.
- The first convolutional unit has 8 feature channels, and, following the original U-Net strategy, the channel number doubles after every pooling layer in the encoder.
- The decoder relies on transposed convolutions.
- Group normalization is used in all convolutional layers because the batch sizes are small.
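The two design choices above can be sketched in NumPy. The sketch shows the channel schedule, which assumes five encoder levels (a number the review does not state), and a group-normalization forward pass without its learned scale and shift. The group count of 4 in the usage lines is also an illustrative assumption. This is not the authors' code.

```python
import numpy as np


def encoder_channels(base: int = 8, num_levels: int = 5) -> list:
    """Channel width at each encoder level: `base`, doubled after every pooling layer.

    num_levels=5 is an illustrative assumption, not a value taken from the paper.
    """
    return [base * 2 ** level for level in range(num_levels)]


def group_norm(x: np.ndarray, num_groups: int, eps: float = 1e-5) -> np.ndarray:
    """Group normalization of an (N, C, H, W) array, without the learnable scale/shift.

    Mean and variance are computed per sample, over each group of C // num_groups
    channels and all spatial positions, so the result does not depend on batch size.
    """
    n, c, h, w = x.shape
    if c % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    return ((grouped - mean) / np.sqrt(var + eps)).reshape(n, c, h, w)


# Usage: a "batch" of one image whose 8 first-unit channels are split into 4 groups.
widths = encoder_channels()
x = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(1, widths[0], 16, 16))
y = group_norm(x, num_groups=4)
print(widths)
```

Under the five-level assumption, `encoder_channels()` gives [8, 16, 32, 64, 128]. Because the statistics never mix samples, a batch of one image is normalized exactly as it would be inside a larger batch. This is what makes group normalization preferable to batch normalization at small batch sizes.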
Recursion is applied both to 1) the predicted segmentation mask s and 2) multiple internal states of the network.
- The former is achieved by simply concatenating, at each recurrent iteration $t$, the previous segmentation mask $s_{t-1}$ with the input image and passing the resulting tensor through the network.
- For the latter, a subset of the encoding and decoding layers of the U-Net is replaced with a recurrent unit, whose internal mechanism comes in two variants, described below.
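The mask-level recursion can be sketched as a loop that appends the previous mask to the image as one extra channel. In this sketch, `toy_network` is a hypothetical stand-in for the U-Net, and the all-zero initial mask $s_0$ is an assumption. The sketch covers only the mask feedback, not the recurrent unit inside the network.

```python
import numpy as np


def run_mask_recursion(image: np.ndarray, network, num_iterations: int = 3) -> list:
    """Refine a segmentation mask by feeding the previous mask back as an input channel.

    image:   (C, H, W) array.
    network: callable mapping a (C + 1, H, W) array to an (H, W) mask in [0, 1].
    Returns the masks s_1, ..., s_N produced at each iteration.
    """
    _, height, width = image.shape
    mask = np.zeros((height, width))  # s_0: assumed here to be an all-zero mask
    masks = []
    for _ in range(num_iterations):
        # Concatenate s_{t-1} to the image along the channel axis.
        x = np.concatenate([image, mask[np.newaxis]], axis=0)
        mask = network(x)
        masks.append(mask)
    return masks


def toy_network(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the U-Net: sigmoid of the per-pixel channel mean."""
    return 1.0 / (1.0 + np.exp(-x.mean(axis=0)))


masks = run_mask_recursion(np.ones((3, 4, 4)), toy_network)
```

Each iteration sees the same image but a different last channel, so the prediction can change from one step to the next. All N masks are returned because every iteration is supervised.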
1.2. (c) Dual-gated Recurrent Unit (DRU)
- The recurrent unit used here, which replaces multiple encoding and decoding layers of the segmentation network, is inspired by the Gated Recurrent Unit (GRU) and closely follows its design in order to preserve the GRU's underlying motivation.
- Specifically, at iteration $t$, given the activations $e_t^l$ of the $l$-th encoding layer and the previous hidden state $h_{t-1}$, the aim is to produce a candidate update $\hat{h}$ for the hidden state and combine it with the previous one according to how reliable the different elements of this previous hidden state tensor are.
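- Left: To determine this reliability, an update gate defined by the tensor $z$ is used:
$$z = \sigma\left(f_z\left([e_t^l, h_{t-1}]\right)\right)$$
- where $\sigma(\cdot)$ is the sigmoid function, $[\cdot, \cdot]$ denotes concatenation, and $f_z(\cdot)$ denotes an encoder-decoder network with the same architecture as the portion of the U-Net that is replaced with the proposed recurrent unit.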
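- Bottom & Bottom Right: Similarly, the candidate update is obtained as:
$$\hat{h} = \tanh\left(f_h\left([r \odot h_{t-1}, e_t^l]\right)\right)$$
- where $f_h(\cdot)$ is a network with the same architecture as $f_z(\cdot)$ but a separate set of parameters, $\odot$ denotes the element-wise product, and $r$ is a reset tensor that masks parts of the input used to compute $\hat{h}$. It is computed as:
$$r = \sigma\left(f_r\left([e_t^l, h_{t-1}]\right)\right)$$
- where $f_r(\cdot)$ again has the architecture of $f_z(\cdot)$ with its own parameters. Where an element of $r$ is close to 1, the corresponding element of $h_{t-1}$ passes through unmasked; where it is close to 0, that element is suppressed (masked).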
- Given these tensors, the new hidden state $h_t$ is computed as:
$$h_t = z \odot h_{t-1} + (1 - z) \odot \hat{h}$$
- Finally, the output of the recurrent unit, which corresponds to the activations of the $l$-th decoding layer, is predicted as:
$$d_t^l = f_s(h_t)$$
- where $f_s(\cdot)$ is a simple convolutional block.
Since it relies on two gates, r and z, it is called Dual-gated Recurrent Unit (DRU).
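One DRU iteration can be sketched in NumPy following the GRU-style gate equations of the DRU: update gate z, reset gate r, candidate update, gated combination, and output block f_s. One simplification is a loud assumption. The encoder-decoder sub-networks $f_z$, $f_r$, $f_h$ and the block $f_s$ are replaced here by hypothetical random 1x1 convolutions (`make_conv1x1`), so only the gating logic is faithful.

```python
import numpy as np

rng = np.random.default_rng(0)


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def make_conv1x1(in_channels: int, out_channels: int):
    """Hypothetical stand-in for f_z, f_r, f_h and f_s: a random 1x1 convolution."""
    weight = rng.normal(scale=0.5, size=(out_channels, in_channels))
    return lambda x: np.einsum("oc,chw->ohw", weight, x)


def dru_step(e, h_prev, f_z, f_r, f_h, f_s):
    """One DRU iteration on (C, H, W) arrays; returns (h_t, d_t)."""
    z = sigmoid(f_z(np.concatenate([e, h_prev])))          # update gate
    r = sigmoid(f_r(np.concatenate([e, h_prev])))          # reset gate
    h_hat = np.tanh(f_h(np.concatenate([r * h_prev, e])))  # candidate update, h_{t-1} masked by r
    h = z * h_prev + (1.0 - z) * h_hat                     # keep reliable parts of h_{t-1}
    return h, f_s(h)                                       # f_s yields the decoder activations


channels = 4
f_z, f_r, f_h = (make_conv1x1(2 * channels, channels) for _ in range(3))
f_s = make_conv1x1(channels, channels)
e = rng.normal(size=(channels, 8, 8))  # encoder activations, kept fixed here for simplicity
h = np.zeros_like(e)                   # h_0 assumed to be zero
for t in range(3):
    h, d = dru_step(e, h, f_z, f_r, f_h, f_s)
```

Since $h_t$ is a convex combination of $h_{t-1}$ and a tanh output, the hidden state stays within [-1, 1] across iterations. In the real model, $e_t^l$ would also change from one iteration to the next, because the input includes the previous mask.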
1.3. (d) Single-gated Recurrent Unit (SRU)
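- The DRU can become memory-intensive, depending on the choice of $l$.
- The SRU has a structure similar to that of the DRU, but without the reset tensor $r$, so the candidate update becomes:
$$\hat{h} = \tanh\left(f_h\left([h_{t-1}, e_t^l]\right)\right)$$
- The SRU comes at very little loss in segmentation accuracy.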
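A matching sketch of one SRU iteration is shown below. Relative to the DRU, it needs only two copies ($f_z$ and $f_h$) of the replaced U-Net portion instead of three, which is presumably where the memory saving comes from. As before, the random 1x1 convolutions in `make_conv1x1` are hypothetical stand-ins for the real sub-networks.

```python
import numpy as np

rng = np.random.default_rng(1)


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def make_conv1x1(in_channels: int, out_channels: int):
    """Hypothetical stand-in for f_z, f_h and f_s: a random 1x1 convolution."""
    weight = rng.normal(scale=0.5, size=(out_channels, in_channels))
    return lambda x: np.einsum("oc,chw->ohw", weight, x)


def sru_step(e, h_prev, f_z, f_h, f_s):
    """One SRU iteration: the DRU update with the reset tensor r removed."""
    z = sigmoid(f_z(np.concatenate([e, h_prev])))      # update gate (the only gate)
    h_hat = np.tanh(f_h(np.concatenate([h_prev, e])))  # h_{t-1} enters unmasked
    h = z * h_prev + (1.0 - z) * h_hat
    return h, f_s(h)


channels = 4
f_z = make_conv1x1(2 * channels, channels)
f_h = make_conv1x1(2 * channels, channels)
f_s = make_conv1x1(channels, channels)
e = rng.normal(size=(channels, 8, 8))
h, d = sru_step(e, np.zeros_like(e), f_z, f_h, f_s)
```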
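1.4. Loss Function
- The cross-entropy loss is used.
- Supervision is applied at every iteration of the recurrence, so the overall loss is:
$$\mathcal{L} = \sum_{t=1}^{N} \alpha^{N-t} \, \mathcal{L}_t$$
- where $\mathcal{L}_t$ is the cross-entropy loss of the prediction at iteration $t$, and $N$ is the number of recursions, set to 3.
- With α = 1, all iterations have equal importance; with α = 0.4, more emphasis is put on the final prediction.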
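The effect of α can be made concrete by computing the per-iteration weights. This sketch assumes that the loss at iteration t is weighted by α^(N − t), so that the final iteration always has weight 1. The per-iteration cross-entropy values are taken as given inputs.

```python
def iteration_weights(num_iterations: int = 3, alpha: float = 0.4) -> list:
    """Weight alpha ** (N - t) for iterations t = 1..N; the last iteration always gets 1."""
    return [alpha ** (num_iterations - t) for t in range(1, num_iterations + 1)]


def total_loss(per_iteration_losses: list, alpha: float = 0.4) -> float:
    """Weighted sum of the per-iteration cross-entropy losses."""
    weights = iteration_weights(len(per_iteration_losses), alpha)
    return sum(w * loss for w, loss in zip(weights, per_iteration_losses))


print(iteration_weights(3, 0.4))  # approximately [0.16, 0.4, 1.0]
print(iteration_weights(3, 1.0))  # [1.0, 1.0, 1.0]
```

With N = 3 and α = 0.4, the first prediction counts about six times less than the last one. With α = 1, the loss reduces to a plain sum over iterations.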
2. Experimental Results
2.1. Datasets
- The hand-segmentation benchmarks shown above are used. However, they are relatively small, with at most 4,800 images in total.
- The authors therefore acquired a larger dataset of their own. Because this work was initially motivated by an augmented virtuality project whose goal is to let someone type on a keyboard while wearing a head-mounted display, 50 people were asked to type on 9 keyboards while wearing an HTC Vive, resulting in a total of 12,536 annotated frames (the KBH dataset referred to below), as shown above.
- The data are split 20/20/60% into training/validation/test sets, deliberately keeping the training set small to set up a challenging scenario.
- The Retina Vessels, Roads, and Cityscapes datasets are also used.
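A quick calculation shows roughly how many frames the 20/20/60% split leaves for each set. It assumes a frame-level split, with any rounding remainder assigned to the test set. The review does not say whether frames or people were split; a per-person split of the 50 participants would instead give 10/10/30 people.

```python
def split_counts(total: int, fractions=(0.2, 0.2, 0.6)) -> tuple:
    """Train/validation/test sizes; any rounding remainder goes to the test set."""
    train = int(total * fractions[0])
    val = int(total * fractions[1])
    return train, val, total - train - val


print(split_counts(12536))  # (2507, 2507, 7522)
```

Only about 2,500 frames are available for training, which is what makes the setting challenging.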
2.2. Hand Segmentation
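- Ours-SRU(l) and Ours-DRU(l) denote the proposed models whose recurrent unit takes the activations of the l-th encoding layer and outputs those of the l-th decoding layer, e.g., l = 3.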
- U-Net-B uses batch normalization and U-Net-G uses group normalization.
- Rec-Last adds a recurrent unit after a convolutional segmentation network to process sequential data.
- Rec-Middle uses the recurrent unit to replace the bottleneck between the U-Net encoder and decoder, instead of adding it at the end of the network.
- Rec-Simple performs a recursive refinement process, which concatenates the segmentation mask with the input image and feeds the result into the network.
- U-Net-VGG16 and DRU-VGG16 replace the encoder of U-Net and of DRU, respectively, with a pretrained VGG-16 backbone.
- U-Net-ResNet50 and DRU-ResNet50 do the same with a pretrained ResNet-50 backbone.
Overall, among the light models, the recurrent methods usually outperform the one-shot ones. Moreover, among the recurrent ones, Ours-DRU(4) and Ours-SRU(0) clearly dominate, with Ours-DRU(4) usually outperforming Ours-SRU(0) by a small margin.
Ours-DRU(4) outperforms the heavy RefineNet model on 4 out of the 5 datasets, even though RefineNet represents the current state of the art.
The DRU-VGG16 model, which uses a pretrained deep backbone, yields the best overall performance.
- DRU-VGG16 outperforms Ours-DRU, e.g., by 0.02 mIoU points on KBH. This, however, comes at a cost.
To be precise, DRU-VGG16 has 41.38M parameters, roughly 115 times more than Ours-DRU(4), which has only 0.36M parameters.
Moreover, DRU-VGG16 runs only at 18 fps, while Ours-DRU(4) reaches 61 fps.
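The cost comparison can be checked with a few lines of arithmetic. The sketch uses only the parameter counts and frame rates quoted above, and it converts frame rates to per-frame latency as a simple reciprocal.

```python
params_millions = {"DRU-VGG16": 41.38, "Ours-DRU(4)": 0.36}
frames_per_second = {"DRU-VGG16": 18, "Ours-DRU(4)": 61}

param_ratio = params_millions["DRU-VGG16"] / params_millions["Ours-DRU(4)"]
speedup = frames_per_second["Ours-DRU(4)"] / frames_per_second["DRU-VGG16"]
latency_ms = {name: 1000.0 / fps for name, fps in frames_per_second.items()}

print(f"parameter ratio: {param_ratio:.1f}x")  # parameter ratio: 114.9x
print(f"speedup: {speedup:.2f}x")               # speedup: 3.39x
for name, ms in latency_ms.items():
    # DRU-VGG16: 55.6 ms per frame; Ours-DRU(4): 16.4 ms per frame
    print(f"{name}: {ms:.1f} ms per frame")
```

For a gain of 0.02 mIoU on KBH, DRU-VGG16 therefore uses about 115 times more parameters and is about 3.4 times slower.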
2.3. Retina Vessel
- This may be due to the limited amount of available data, which causes such a very deep network to overfit.
Even tiny vessel branches in the retina that were missed by the human annotators are correctly segmented by the proposed algorithm (best viewed in color and zoomed in).
The proposed methods also outperform all the baselines by a clear margin on this task, with or without ImageNet pretraining.
- In particular, Ours-DRU(4) yields an mIoU 8 percentage points (pp) higher than U-Net-G, and DRU-VGG16 one 5pp higher than U-Net-VGG16. This confirms that the recurrent strategy helps.
Ours-DRU is consistently better than U-Net-G and than the best recurrent baseline, i.e., Rec-Last.
- Furthermore, doubling the number of channels of the U-Net backbone increases accuracy, and so does using a pretrained VGG-16 as encoder.
The model is practical for real-time applications, segmenting 230×306 images at 55 frames per second (fps) on an NVIDIA TITAN X with only 12 GB of memory.