Review — BYOL Works Even Without Batch Statistics
BYOL+GN+WS, by DeepMind, Imperial College
2020 arXiv v1, Over 20 Citations (Sik-Ho Tsang @ Medium)
Self-Supervised Learning, Teacher Student, Image Classification, Batch Normalization, BN, Group Normalization, GN
- It has been hypothesized that BN is critical to prevent collapse in BYOL: because BN computes statistics over the batch, gradients flow across batch elements, and information about the other (negative) views in the batch could be leaked.
- This tech report shows that BYOL still performs well without BN, by replacing it with GN+WS, which does not rely on batch statistics.
1. BYOL Brief Review
- BYOL trains its representation using both an online network (parameterized by θ) and a target network (parameterized by ξ).
- As a part of the online network, it further defines a predictor network qθ that is used to predict target projections zξ′ using online projections zθ as inputs.
- Accordingly, the parameters θ of the online network are updated by following the gradients of the prediction loss.
- In turn, the target network weights are updated as an exponential moving average (EMA) of the online network’s weights, ξ ← ηξ + (1 − η)θ, with η being a decay parameter.
- As qθ(zθ) is a function of the view v and zξ′ is a function of the other view v′, the BYOL loss can be seen as a measure of similarity between the views v and v′, and therefore resembles the positive (alignment) term of the contrastive InfoNCE loss.
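To make the roles of the online network, the predictor and the EMA target concrete, here is a minimal PyTorch-style sketch of one BYOL update step. The module names, the decay value eta=0.99 and the optimizer are illustrative placeholders rather than the report's exact setup; the loss is the squared L2 distance between the l2-normalized prediction and target, i.e. 2 − 2·cosine similarity, as in the original BYOL paper.

```python
import torch
import torch.nn.functional as F

def byol_loss(q_online, z_target):
    # Squared L2 distance between l2-normalized vectors = 2 - 2 * cosine similarity.
    q = F.normalize(q_online, dim=-1)
    z = F.normalize(z_target, dim=-1)
    return 2.0 - 2.0 * (q * z).sum(dim=-1).mean()

@torch.no_grad()
def ema_update(target_net, online_net, eta=0.99):
    # Target weights: xi <- eta * xi + (1 - eta) * theta (no gradient flows here).
    for p_t, p_o in zip(target_net.parameters(), online_net.parameters()):
        p_t.mul_(eta).add_((1.0 - eta) * p_o)

def byol_step(online, target, predictor, optimizer, v, v_prime, eta=0.99):
    """One BYOL update on a pair of augmented views (v, v')."""
    z_online = online(v)                      # z_theta
    with torch.no_grad():
        z_target = target(v_prime)            # z_xi', produced by the EMA target
    loss = byol_loss(predictor(z_online), z_target)   # q_theta(z_theta) vs z_xi'

    optimizer.zero_grad()
    loss.backward()                           # gradients only reach theta (online + predictor)
    optimizer.step()
    ema_update(target, online, eta)           # then refresh xi by EMA
    return loss.item()
```

In the full method the loss is symmetrized by also feeding v′ to the online network and v to the target, but the asymmetric step above is enough to show where the prediction loss and the EMA update enter.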
2. GN and WS Brief Review
- For an activation tensor X of dimensions (N, H, W, C), GN first splits channels into G equally-sized groups, then normalizes activations with the mean and standard deviation computed over disjoint slices of size (1, H, W, C/G).
- If G=1, it is equivalent to Layer Norm (LN).
- When each group contains a single channel, i.e. G=C, it is equivalent to Instance Norm (IN).
- GN operates independently on each batch element and therefore it does not rely on batch statistics.
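A minimal PyTorch-style sketch of the GN computation described above; it uses the (N, C, H, W) layout rather than the (N, H, W, C) layout of the text, and the epsilon value is an illustrative choice.

```python
import torch

def group_norm(x, G, gamma, beta, eps=1e-5):
    """Group Normalization for an activation tensor x of shape (N, C, H, W).

    Channels are split into G equally-sized groups; mean and variance are computed
    per batch element over each group's (C/G, H, W) slice, so no batch statistics
    are involved.
    """
    N, C, H, W = x.shape
    assert C % G == 0
    x = x.view(N, G, C // G, H, W)
    mean = x.mean(dim=(2, 3, 4), keepdim=True)
    var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
    x = (x - mean) / torch.sqrt(var + eps)
    x = x.view(N, C, H, W)
    # Learnable per-channel scale and offset, as in BN.
    return x * gamma.view(1, C, 1, 1) + beta.view(1, C, 1, 1)

# G = 1 recovers Layer Norm over (C, H, W); G = C recovers Instance Norm.
```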
- WS normalizes the weights corresponding to each activation using weight statistics.
- Each row of the weight matrix W is normalized, using the mean and standard deviation of its I entries, to obtain a new weight matrix Ŵ, which is directly used in place of W during training; here I is the input dimension (i.e. the product of the input channel dimension and the kernel spatial dimensions).
- Only the normalized weights Ŵ are used to compute the convolution outputs, but the loss is differentiated with respect to the non-normalized weights W.
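A minimal PyTorch-style sketch of a weight-standardized convolution following the standard WS formulation (the class name and epsilon are illustrative, and the report's exact variant may differ slightly): the standardized weights Ŵ are used in the forward pass, while autograd differentiates the loss with respect to the raw weights W.

```python
import torch.nn as nn
import torch.nn.functional as F

class WSConv2d(nn.Conv2d):
    """Conv2d whose weights are standardized per output channel before each use."""

    def forward(self, x):
        w = self.weight                           # shape (C_out, C_in, kH, kW)
        flat = w.view(w.size(0), -1)              # each row has I = C_in * kH * kW entries
        mean = flat.mean(dim=1, keepdim=True)
        std = flat.std(dim=1, keepdim=True, unbiased=False) + 1e-5
        w_hat = ((flat - mean) / std).view_as(w)  # standardized weights W-hat
        # Only W-hat enters the convolution, but gradients flow back to self.weight.
        return F.conv2d(x, w_hat, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)
```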
3. Experimental Results
3.1. Removing BN Causes Collapse
- First, it is observed that removing all instances of BN in BYOL leads to performance (0.1%) that is no better than random. This is specific to BYOL as SimCLR still performs reasonably well in this regime.
- Nevertheless, applying BN to the ResNet encoder alone is enough for BYOL to achieve high performance (72.1%). It is hypothesized that the main contribution of BN in BYOL is to compensate for improper initialization.
3.2. Proper Initialization Allows Working Without BN
- To confirm this hypothesis, a protocol is designed to mimic the effect of BN on initial scalings and training dynamics, without using or backpropagating through batch statistics.
- Before training, per-activation BN statistics are computed for each layer by running a single forward pass of the network with BN on a batch of augmented data. The BN layers are then removed, but their scale γ and offset β parameters are retained and kept trainable.
- Despite its comparatively low performance (65.7%), the trained representation still provides considerably better classification results than a random ResNet-50 backbone, and has therefore not collapsed.
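A rough PyTorch-style sketch of one way to implement this protocol, under the assumption that each BatchNorm2d layer is replaced by a plain per-channel affine layer whose initial scale and offset fold in the statistics captured during the single forward pass; the class and function names are hypothetical and the report's exact procedure may differ.

```python
import torch
import torch.nn as nn

class FrozenStatAffine(nn.Module):
    """Per-channel scale/offset that reproduces a BN layer at initialization,
    using statistics captured once before training; both parameters stay trainable."""

    def __init__(self, bn: nn.BatchNorm2d):
        super().__init__()
        std = torch.sqrt(bn.running_var + bn.eps)
        # At init: y = gamma * (x - mu) / std + beta, exactly as the BN layer would compute.
        self.scale = nn.Parameter(bn.weight.detach() / std)
        self.offset = nn.Parameter(bn.bias.detach() - bn.weight.detach() * bn.running_mean / std)

    def forward(self, x):
        return x * self.scale.view(1, -1, 1, 1) + self.offset.view(1, -1, 1, 1)

def swap_bn(module):
    """Recursively replace every BatchNorm2d by a FrozenStatAffine built from its stats."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, FrozenStatAffine(child))
        else:
            swap_bn(child)

def remove_bn_keep_init(model, augmented_batch):
    """Capture per-activation BN statistics with a single forward pass, then remove BN."""
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = 1.0            # running stats become exactly this batch's stats
    model.train()
    with torch.no_grad():
        model(augmented_batch)          # single forward pass with BN on augmented data
    swap_bn(model)
    return model
```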
3.3. GN+WS Performs Well Without Batch Statistics
- More precisely, convolutional and linear layers are replaced by weight-standardized alternatives, and all BN layers are replaced by GN layers (with G=16).
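To illustrate this replacement, a minimal sketch of a BN-free building block combining the two components above; it reuses the hypothetical WSConv2d class from the earlier sketch, and the layer sizes are placeholders (the number of output channels must be divisible by the 16 groups).

```python
import torch.nn as nn

def ws_gn_block(in_channels, out_channels, groups=16):
    """Conv block without batch statistics: weight-standardized conv + GN + ReLU,
    used in place of the usual Conv + BN + ReLU."""
    return nn.Sequential(
        WSConv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
        nn.GroupNorm(num_groups=groups, num_channels=out_channels),
        nn.ReLU(inplace=True),
    )

# Example: ws_gn_block(64, 128) builds a 3x3 block with GN over 16 groups of 8 channels.
```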