Review — There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average
SWA & Fast SWA, Improving Π-Model & Mean Teacher
There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average
SWA & Fast SWA, by Cornell University
2019 ICLR, Over 100 Citations (Sik-Ho Tsang @ Medium)
Semi-Supervised Learning, Image Classification
- Stochastic Weight Averaging (SWA) is proposed, which averages weights along the trajectory of SGD with a modified learning rate schedule.
- Fast-SWA is further proposed, which accelerates convergence by averaging multiple points within each cycle of a cyclical learning rate schedule.
Outline
- Conventional Π-Model & Mean Teacher (MT)
- Proposed Stochastic Weight Averaging (SWA) & Fast SWA
- Experimental Results
1. Conventional Π-Model & Mean Teacher (MT)
1.1. Consistency Loss
- In the semi-supervised setting, we have access to labeled data D_L = {(x_i^L, y_i^L)}, i = 1, …, N_L, and unlabeled data D_U = {x_i^U}, i = 1, …, N_U.
- Given two perturbed inputs x′, x′′ of x and the perturbed weights w′_f and w′_g, the consistency loss penalizes the difference between the student’s predicted probabilities f(x′, w′_f) and the teacher’s g(x′′, w′_g).
- This loss is typically the Mean Squared Error (MSE) or the KL divergence (see the reconstruction after this list).
- The total loss used to train the model combines the supervised term and the consistency term (also reconstructed after this list).
- For classification, L_CE is the cross entropy between the model predictions and the supervised training labels. The parameter λ > 0 controls the relative importance of the consistency term in the overall loss.
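A reconstruction of these two losses in LaTeX, following the notation above (it may differ slightly from the paper’s exact formulation; d denotes the chosen distance, MSE or KL divergence):

```latex
% Consistency loss: expected distance between the student's and teacher's
% outputs under input perturbations x', x'' and weight perturbations w'_f, w'_g
\ell_{\mathrm{cons}}(w_f, x)
  = \mathbb{E}_{x',\,x''}\, d\big( f(x', w'_f),\; g(x'', w'_g) \big),
  \qquad d \in \{\mathrm{MSE},\ \mathrm{KL}\}

% Total loss: supervised cross entropy on the labeled data plus the
% consistency term, weighted by \lambda, over labeled and unlabeled inputs
L(w_f) = \sum_{i=1}^{N_L} L_{\mathrm{CE}}\big( f(x_i^L, w_f),\, y_i^L \big)
       + \lambda \sum_{i=1}^{N_L + N_U} \ell_{\mathrm{cons}}(w_f, x_i)
```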
1.2. Π-Model
- Π-Model uses the student model f as its own teacher.
- The input data are perturbed by random translations, crops, flips and additive Gaussian noise, and binary dropout is used for weight perturbation (a minimal sketch of this consistency term follows this list).
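As an illustration only (not the authors’ code), a minimal PyTorch-style sketch of the Π-Model consistency term, assuming a hypothetical augment function for the random translations/crops/flips/noise and a classifier containing dropout, could look like this:

```python
import torch
import torch.nn.functional as F

def pi_model_consistency(model, x, augment):
    """Π-Model consistency: the student acts as its own teacher.

    Two stochastic forward passes (different augmentations of the same batch
    and different dropout masks) are penalized for disagreeing.
    `augment` is a hypothetical perturbation function (translations, crops,
    flips, additive Gaussian noise); `model` is any classifier with dropout.
    """
    model.train()                            # keep dropout active in both passes
    student_logits = model(augment(x))       # first perturbed view (with gradient)
    with torch.no_grad():                    # teacher branch gets no gradient
        teacher_logits = model(augment(x))   # second perturbed view
    return F.mse_loss(F.softmax(student_logits, dim=1),
                      F.softmax(teacher_logits, dim=1))
```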
1.3. Mean Teacher (MT)
- The teacher weights w_g are an exponential moving average (EMA) of the student weights w_f, where the decay rate α is usually set between 0.9 and 0.999 (the EMA update is reconstructed below).
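For completeness, the standard Mean Teacher EMA update at training step t can be written as:

```latex
% Teacher weights: exponential moving average of the student weights
w_g^{(t)} = \alpha\, w_g^{(t-1)} + (1 - \alpha)\, w_f^{(t)}
```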
2. Proposed Stochastic Weight Averaging (SWA) & Fast SWA
2.1. Cyclical Schedule
- Stochastic Weight Averaging (SWA) (Izmailov et al., 2018) is a recent approach that averages the weights traversed by SGD under a modified learning rate schedule.
- For the first l ≤ l0 epochs, the network is pre-trained using a cosine annealing schedule.
- After epoch l, a cyclical schedule is used that repeats the learning rates of epochs [l − c, l], where c is the cycle length (a sketch of this schedule follows this list).
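A rough Python sketch of such a schedule (only mirroring the description above, not the paper’s exact formula; `eta_max`, `l`, and `c` are assumed names for the peak learning rate, the end of the annealing phase, and the cycle length):

```python
import math

def cyclical_cosine_lr(epoch, eta_max, l, c):
    """Sketch of the cyclical cosine schedule described above.

    - Epochs 0..l: cosine annealing from eta_max toward zero.
    - Epochs > l: repeat the learning rates of epochs [l - c, l], giving
      cycles of length c whose minima are the SWA collection points.
    """
    if epoch > l:
        # fold the epoch index back into the last annealing segment [l - c, l]
        epoch = (l - c) + ((epoch - l - 1) % c) + 1
    return 0.5 * eta_max * (1.0 + math.cos(math.pi * epoch / l))
```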
2.2. SWA
- SWA collects the networks at the minimum values of the learning rate in each cycle and averages their weights. The model with the averaged weights w_SWA is then used to make predictions.
- SWA is applied to the student network for both the Π-Model and the Mean Teacher model.
- However, SWA updates the average weights only once per cycle, which means that many additional training epochs are needed in order to collect enough weights for averaging (a sketch of the running average follows this list).
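A minimal sketch of the running weight average (assuming plain parameter tensors; `n_averaged` counts the snapshots already included):

```python
import torch

@torch.no_grad()
def update_average(avg_model, model, n_averaged):
    """Add the current weights to the (fast-)SWA running average.

    SWA calls this once per cycle, at the learning-rate minimum; fast-SWA
    calls it more often (see the next subsection). Returns the new count.
    """
    for p_avg, p in zip(avg_model.parameters(), model.parameters()):
        p_avg.mul_(n_averaged / (n_averaged + 1.0)).add_(p / (n_averaged + 1.0))
    return n_averaged + 1
```

In practice, batch-normalization statistics are recomputed for the averaged weights with one extra pass over the training data, as in the original SWA procedure; PyTorch’s torch.optim.swa_utils provides an AveragedModel helper that implements a similar running average.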
2.3. Fast SWA
- Fast-SWA is a modification of SWA that averages the networks from every k < c epochs, starting from epoch l − c. Multiple weights can even be averaged within a single epoch by setting k < 1 (a sketch of the collection rule follows).
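A sketch of when snapshots enter the average under the two schemes (parameter names `l`, `c`, `k` follow the notation above; this only mirrors the description, not the authors’ code):

```python
def should_collect(epoch, l, c, k, fast_swa=True):
    """Decide whether this epoch's weights are added to the average.

    SWA:      once per cycle, at the cycle's learning-rate minimum
              (epochs l, l + c, l + 2c, ...).
    fast-SWA: every k epochs (k < c) starting from epoch l - c, so several
              snapshots per cycle enter the average, not only the minima.
    """
    if fast_swa:
        return epoch >= l - c and (epoch - (l - c)) % k == 0
    return epoch >= l and (epoch - l) % c == 0
```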
3. Experimental Results
- For all quantities of labeled data, fast-SWA substantially improves test accuracy for both architectures evaluated in the paper (a 13-layer CNN and a Shake-Shake ResNet).
- In summary, fast-SWA significantly improves the performance of both the Π-Model and the Mean Teacher model.
Please feel free to read the paper for more detailed results if interested.
Reference
[2019 ICLR] [SWA & Fast SWA]
There Are Many Consistent Explanations of Unlabeled Data: Why You Should Average
Pretraining or Weakly/Semi-Supervised Learning
2004 [Entropy Minimization, EntMin] 2013 [Pseudo-Label (PL)] 2015 [Ladder Network, Γ-Model] 2016 [Sajjadi NIPS’16] 2017 [Mean Teacher] [PATE & PATE-G] [Π-Model, Temporal Ensembling] 2018 [WSL] [Oliver NeurIPS’18] 2019 [VAT] [Billion-Scale] [Label Propagation] [Rethinking ImageNet Pre-training] [MixMatch] [SWA & Fast SWA] 2020 [BiT] [Noisy Student] [SimCLRv2]