Review — BYOL Works Even Without Batch Statistics
BYOL+GN+WS, by DeepMind and Imperial College
2020 arXiv v1, Over 20 Citations (Sik-Ho Tsang @ Medium)
Self-Supervised Learning, Teacher Student, Image Classification, Batch Normalization, BN, Group Normalization, GN
- It has been hypothesized that BN is critical to preventing collapse in BYOL: because BN computes statistics across batch elements, it lets gradients flow across the batch and could implicitly leak information about the other (negative) views in the batch.
- This tech report shows that BYOL also performs well without BN, by replacing it with GN+WS, which uses no batch statistics at all.
1. BYOL Brief Review
- BYOL trains its representation using both an online network (parameterized by θ) and a target network (parameterized by ξ).
- As a part of the online network, it further defines a predictor network qθ that is used to predict target projections zξ′ using online projections zθ as inputs.
- Accordingly, the parameters of the online projection are updated following the gradients of the prediction loss.
- In turn, the target network weights ξ are updated as an exponential moving average (EMA) of the online network's weights θ: ξ ← ηξ + (1 − η)θ, with η being a decay parameter.
- As qθ(zθ) is a function of v and zξ′ is a function of v′, the BYOL loss can be seen as a measure of similarity between the views v and v′, and therefore resembles the positive term of the contrastive InfoNCE loss: L = ||q̄θ(zθ) − z̄ξ′||² = 2 − 2·⟨qθ(zθ), zξ′⟩ / (||qθ(zθ)||₂ ||zξ′||₂), where q̄ and z̄ denote ℓ₂-normalized vectors.
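The EMA update and the normalized-MSE loss above can be sketched in NumPy (a minimal sketch; the decay value and the toy vectors are illustrative, not the paper's settings):

```python
import numpy as np

def ema_update(target_params, online_params, eta=0.99):
    """Target-network EMA update: xi <- eta * xi + (1 - eta) * theta."""
    return [eta * xi + (1.0 - eta) * theta
            for xi, theta in zip(target_params, online_params)]

def byol_loss(q, z):
    """Normalized MSE between the prediction q = q_theta(z_theta) and the
    target projection z = z_xi'; equals 2 - 2 * cosine_similarity(q, z)."""
    q_hat = q / np.linalg.norm(q)
    z_hat = z / np.linalg.norm(z)
    return float(np.sum((q_hat - z_hat) ** 2))

# toy EMA step: target moves a small fraction toward the online weights
target = ema_update([np.zeros(3)], [np.ones(3)], eta=0.9)

# the loss only rewards agreement between the two views' representations
q = np.array([1.0, 2.0, 3.0])
z = np.array([0.5, 1.0, -1.0])
loss = byol_loss(q, z)
```

Note that the loss contains no negative pairs, which is why the role of BN in preventing collapse is in question.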
2.1. Group Normalization (GN)
- For an activation tensor X of dimensions (N, H, W, C), GN first splits channels into G equally-sized groups, then normalizes activations with the mean and standard deviation computed over disjoint slices of size (1, H, W, C/G).
- If G=1, it is equivalent to Layer Norm (LN).
- When each group contains a single channel, i.e. G=C, it is equivalent to Instance Norm (IN).
- GN operates independently on each batch element and therefore does not rely on batch statistics.
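The normalization just described can be sketched in NumPy (a minimal sketch without the trainable scale and offset; the tensor shape is illustrative):

```python
import numpy as np

def group_norm(x, G, eps=1e-5):
    """Group Norm over an (N, H, W, C) tensor: channels are split into G
    equally-sized groups, and mean/variance are computed per (sample, group)
    slice of size (1, H, W, C // G) -- no batch statistics are involved."""
    N, H, W, C = x.shape
    x = x.reshape(N, H, W, G, C // G)
    mean = x.mean(axis=(1, 2, 4), keepdims=True)
    var = x.var(axis=(1, 2, 4), keepdims=True)
    x = (x - mean) / np.sqrt(var + eps)
    return x.reshape(N, H, W, C)

x = np.random.randn(2, 4, 4, 8)
y_gn = group_norm(x, G=2)   # Group Norm with 2 groups
y_ln = group_norm(x, G=1)   # G = 1  -> Layer Norm
y_in = group_norm(x, G=8)   # G = C -> Instance Norm
```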
2.2. Weight Standardization (WS)
- WS normalizes the weights corresponding to each activation using weight statistics.
- Each row of the weight matrix W is standardized to obtain a new weight matrix Ŵ, which is directly used in place of W during training.
- Only the normalized weights Ŵ are used to compute the convolution outputs, but the loss is differentiated with respect to the non-normalized weights W: Ŵᵢⱼ = (Wᵢⱼ − μᵢ) / σᵢ, with μᵢ = (1/I) Σⱼ Wᵢⱼ and σᵢ² = (1/I) Σⱼ (Wᵢⱼ − μᵢ)²,
- where I is the input dimension (i.e. the product of the input channel dimension and the kernel spatial dimensions).
- Contrary to BN, LN, and GN, WS does not create additional trainable weights.
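The row-wise standardization above can be sketched in NumPy (a minimal sketch; the filter shape is illustrative, and in a real framework the gradient would flow through the standardization to the raw W):

```python
import numpy as np

def weight_standardize(W, eps=1e-5):
    """Standardize each row of W (shape O x I, where I = in_channels *
    kernel_h * kernel_w). W_hat replaces W when computing the convolution,
    while the loss is differentiated w.r.t. the raw W; no new trainable
    weights are created."""
    mu = W.mean(axis=1, keepdims=True)
    sigma = W.std(axis=1, keepdims=True)
    return (W - mu) / (sigma + eps)

W = np.random.randn(16, 27)      # e.g. 16 filters of shape 3x3x3, flattened
W_hat = weight_standardize(W)    # each row now has mean ~0 and std ~1
```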
3. Experimental Results
3.1. Removing BN Causes Collapse
- First, it is observed that removing all instances of BN in BYOL leads to an accuracy (0.1%) no better than random. This is specific to BYOL, as SimCLR still performs reasonably well in this regime.
- Nevertheless, solely applying BN to the ResNet encoder is enough for BYOL to achieve high performance (72.1%). It is hypothesized that the main contribution of BN in BYOL is to compensate for improper initialization.
3.2. Proper Initialization Allows Working Without BN
- To confirm this hypothesis, a protocol is designed to mimic the effect of BN on initial scalings and training dynamics, without using or backpropagating through batch statistics.
- Before training, per-activation BN statistics are computed for each layer by running a single forward pass of the network, with BN, on a batch of augmented data. The BN layers are then removed, but their scale γ and offset β parameters are retained and kept trainable.
- Despite its comparatively low performance (65.7%), the trained representation still provides considerably better classification results than a random ResNet-50 backbone, and is therefore not collapsed.
This confirms that BYOL does not need BN to prevent collapse. The authors therefore explore other, refined element-wise normalization procedures.
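The initialization protocol can be sketched for a single linear layer (a simplified illustration; the actual experiment applies it to every layer of a ResNet-50 and keeps the trainable γ and β):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(64, 32)) * 3.0   # a (possibly badly scaled) layer
b = np.zeros(64)
x = rng.normal(size=(256, 32))        # one batch of augmented data

# 1) single forward pass with BN to measure per-activation statistics
h = x @ W.T + b
mu, sigma = h.mean(axis=0), h.std(axis=0) + 1e-5

# 2) remove BN, folding its initial normalization into the layer itself
#    (BN's trainable gamma and beta would be kept; here gamma=1, beta=0)
W_init = W / sigma[:, None]
b_init = (b - mu) / sigma

h_no_bn = x @ W_init.T + b_init  # matches the BN output at initialization
```

After this one-time rescaling, training proceeds with no batch statistics anywhere in the forward or backward pass.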
3.3. Using GN with WS Leads to Competitive Performance
- More precisely, all convolutional and linear layers are replaced by weight-standardized alternatives, and all BN layers are replaced by GN layers (with G=16 groups).
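Combining the two ingredients, a BN-free block can be sketched as a weight-standardized layer followed by GN with 16 groups (a minimal 1-D sketch; the layer shapes are illustrative):

```python
import numpy as np

def ws_linear(x, W, eps=1e-5):
    """Linear layer with Weight Standardization applied to W's rows."""
    W_hat = (W - W.mean(axis=1, keepdims=True)) / (W.std(axis=1, keepdims=True) + eps)
    return x @ W_hat.T

def group_norm_1d(h, G, eps=1e-5):
    """GN over (N, C) activations: per-sample, per-group normalization,
    with no batch statistics."""
    N, C = h.shape
    h = h.reshape(N, G, C // G)
    h = (h - h.mean(axis=2, keepdims=True)) / np.sqrt(h.var(axis=2, keepdims=True) + eps)
    return h.reshape(N, C)

rng = np.random.default_rng(1)
x = rng.normal(size=(4, 32))
W = rng.normal(size=(64, 32))
out = group_norm_1d(ws_linear(x, W), G=16)  # GN with G=16, as in the report
```

Since neither WS nor GN touches the batch dimension, nothing in this block can leak information across batch elements.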