# Review — ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring

## ReMixMatch, Improving MixMatch by Introducing Two New Techniques

ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring

ReMixMatch, by Google Research and Google Cloud AI

2020 ICLR, Over 400 Citations (Sik-Ho Tsang @ Medium)

Semi-Supervised Learning, Image Classification

**ReMixMatch improves MixMatch** by introducing two new techniques:

- **Distribution alignment** encourages the marginal distribution of predictions on unlabeled data to be close to the marginal distribution of ground-truth labels.
- **Augmentation anchoring** feeds multiple strongly augmented versions of an input into the model and encourages each output to be close to the prediction for a weakly-augmented version of the same input.

# Outline

1. **MixMatch Brief Review**
2. **Distribution Alignment**
3. **Augmentation Anchoring**
4. **Putting It All Together**
5. **Experimental Results**

# 1. MixMatch Brief Review

- MixMatch generates “guessed labels” for each unlabeled example, and then uses fully-supervised techniques to train on the original labeled data along with the guessed labels for the unlabeled data.
- MixMatch first produces $K$ **weakly augmented versions** of each unlabeled datapoint. A guessed label is generated by **averaging the $K$ predictions**. **Sharpening** is used to reduce the entropy of the label distribution.
- **Pairs of examples** $(x_1, p_1)$, $(x_2, p_2)$ from the combined set of labeled examples and unlabeled examples with label guesses are fed into **mixup**.
- A standard **cross-entropy loss** is used for **labeled data**, whereas the loss for **unlabeled data** is computed using a **mean squared error**.
- (Please feel free to read MixMatch if interested.)
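The label-guessing and mixup steps above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the MixMatch authors' code: the function names (`sharpen`, `guess_label`, `mixup`) are hypothetical, the toy probability vectors stand in for model outputs, and `max(lam, 1 - lam)` follows the MixMatch convention of keeping the mixed example closer to its first input. The values $T = 0.5$ and $\alpha = 0.75$ match those quoted later for ReMixMatch.

```python
import random


def sharpen(p, T=0.5):
    """Temperature sharpening: p_i^(1/T) / sum_j p_j^(1/T)."""
    powered = [pi ** (1.0 / T) for pi in p]
    total = sum(powered)
    return [v / total for v in powered]


def guess_label(predictions, T=0.5):
    """Average K per-augmentation class distributions, then sharpen the mean."""
    K = len(predictions)
    num_classes = len(predictions[0])
    mean = [sum(pred[c] for pred in predictions) / K for c in range(num_classes)]
    return sharpen(mean, T)


def mixup(x1, p1, x2, p2, alpha=0.75, rng=random):
    """MixMatch-style mixup; lam = max(lam, 1 - lam) keeps the result closer to (x1, p1)."""
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    p = [lam * a + (1.0 - lam) * b for a, b in zip(p1, p2)]
    return x, p


# Toy predictions (stand-ins for model outputs) for K = 2 weak augmentations.
preds = [[0.6, 0.3, 0.1], [0.4, 0.5, 0.1]]
q = guess_label(preds, T=0.5)
print([round(v, 3) for v in q])  # → [0.595, 0.381, 0.024]

# Mix a one-hot "class 0" example with a one-hot "class 1" example.
x_mix, p_mix = mixup([0.0, 1.0], [1.0, 0.0], [1.0, 0.0], [0.0, 1.0], rng=random.Random(0))
print(p_mix)
```

Sharpening turns the averaged guess `[0.5, 0.4, 0.1]` into a lower-entropy distribution that still favors the same class.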

# 2. Distribution Alignment

**Distribution alignment** enforces that the aggregate of predictions on unlabeled data matches the distribution of the provided labeled data.

- **A running average of the model’s predictions on unlabeled data** is maintained, which is referred to as $\tilde{p}(y)$.
- Given the model’s prediction $q = p_{\text{model}}(y \mid u; \theta)$ on an unlabeled example $u$, $q$ is scaled by the ratio $p(y)/\tilde{p}(y)$ and then **renormalized** to form a valid probability distribution:

$$\tilde{q} = \text{Normalize}\left(q \times \frac{p(y)}{\tilde{p}(y)}\right)$$

- where $\text{Normalize}(x)_i = x_i / \sum_j x_j$.
- Then, $\tilde{q}$ is used as the label guess for $u$.
- In practice, $\tilde{p}(y)$ is computed as **the moving average of the model’s predictions on unlabeled examples over the last 128 batches**, and the **marginal class distribution** $p(y)$ is estimated **based on the labeled examples seen** during training.
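Distribution alignment can be sketched as a small stateful helper. This is an illustrative sketch, not the official implementation: `DistributionAligner` is a hypothetical name, $\tilde{p}(y)$ is kept as the average of per-batch mean predictions over the last 128 batches (which equals the average over all predictions only when batches have equal size), and $p(y)$ is simply passed in rather than estimated.

```python
from collections import deque


class DistributionAligner:
    """Hypothetical helper for distribution alignment.

    Keeps ~p(y), the average of the model's predictions on unlabeled data over
    the last `window` batches, and rescales a prediction q by p(y) / ~p(y)."""

    def __init__(self, p_labeled, window=128):
        self.p_labeled = list(p_labeled)     # p(y): class marginal of the labeled data
        self.history = deque(maxlen=window)  # one mean prediction per unlabeled batch

    def update(self, batch_predictions):
        """Store the mean prediction of one unlabeled batch (the oldest batch drops out)."""
        n = len(batch_predictions)
        self.history.append(
            [sum(q[c] for q in batch_predictions) / n for c in range(len(self.p_labeled))]
        )

    def p_tilde(self):
        """~p(y): average of the stored per-batch means."""
        n = len(self.history)
        return [sum(h[c] for h in self.history) / n for c in range(len(self.p_labeled))]

    def align(self, q):
        """~q = Normalize(q * p(y) / ~p(y)); requires at least one update() call."""
        scaled = [qi * pi / ti for qi, pi, ti in zip(q, self.p_labeled, self.p_tilde())]
        total = sum(scaled)
        return [s / total for s in scaled]


aligner = DistributionAligner(p_labeled=[0.5, 0.5])
aligner.update([[0.9, 0.1], [0.7, 0.3]])  # the model currently over-predicts class 0
q_tilde = aligner.align([0.8, 0.2])
print([round(v, 6) for v in q_tilde])  # → [0.5, 0.5]
```

Because the model's recent predictions lean toward class 0 while the labeled marginal is balanced, the aligned guess is pushed back toward the labeled distribution.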

# 3. Augmentation Anchoring

- A variant of AutoAugment and RandAugment, called CTAugment, is proposed.
- It is hypothesized that **MixMatch with AutoAugment is unstable**: MixMatch averages the prediction across $K$ augmentations, and **stronger augmentation** can result in **disparate predictions**, so their average may not be a meaningful target.

Instead, given an unlabeled input, an “anchor” is first generated by **applying weak augmentation** to it. Then, $K$ **strongly-augmented versions** of the same unlabeled input are generated using CTAugment (described below).

The **guessed label**, computed from the model’s prediction on the anchor (after applying distribution alignment and sharpening), is used as the **target for all of the $K$ strongly-augmented versions** of the image.
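Augmentation anchoring changes where the target comes from: one prediction on the weakly augmented anchor is shared by all $K$ strong augmentations, instead of averaging $K$ predictions. The sketch below is illustrative only; it assumes a hypothetical two-class `toy_model` and toy `weak_augment`/`strong_augment` functions acting on short lists of numbers, whereas in the real method the strong augmentation would be CTAugment and the model a neural network.

```python
import math
import random


def sharpen(p, T=0.5):
    """Temperature sharpening: p_i^(1/T) / sum_j p_j^(1/T)."""
    powered = [pi ** (1.0 / T) for pi in p]
    total = sum(powered)
    return [v / total for v in powered]


def anchor_targets(u, model, weak_augment, strong_augment, K=8, T=0.5, align=None):
    """Return K (strongly augmented input, target) pairs for one unlabeled input u.

    The target is computed once from the weakly augmented "anchor" and shared
    by all K strong augmentations."""
    q = model(weak_augment(u))
    if align is not None:  # optional distribution-alignment step
        q = align(q)
    target = sharpen(q, T)
    return [(strong_augment(u), target) for _ in range(K)]


# Hypothetical stand-ins: a two-class toy "model" and two toy augmentations.
rng = random.Random(0)


def toy_model(x):
    s = sum(x)
    e0, e1 = math.exp(s), math.exp(-s)
    return [e0 / (e0 + e1), e1 / (e0 + e1)]


def weak_augment(x):  # stand-in for a mild perturbation (e.g. flip and shift)
    return [v + rng.uniform(-0.01, 0.01) for v in x]


def strong_augment(x):  # stand-in for CTAugment
    return [v * rng.uniform(0.5, 1.5) + rng.gauss(0.0, 0.5) for v in x]


pairs = anchor_targets([0.2, 0.3], toy_model, weak_augment, strong_augment, K=4)
for x_strong, target in pairs:
    print([round(v, 2) for v in x_strong], [round(t, 3) for t in target])
```

Every strongly augmented input gets the same target, so a heavily distorted view cannot drag the target around the way it could when predictions are averaged.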

## 3.1. Control Theory Augment (CTAugment)

CTAugment only samples augmentations that fall within the network tolerance.

- Let $m$ be the **vector of bin weights** for some **distortion parameter of some transformation**. At the beginning of training, all magnitude bins are initialized to a weight of 1.
- At each training step, **two transformations are sampled uniformly at random for each image**.
- A modified set of bin weights $\hat{m}$ is produced, where $\hat{m}_i = m_i$ if $m_i > 0.8$ and $\hat{m}_i = 0$ otherwise, and magnitude bins are sampled from $\text{Categorical}(\text{Normalize}(\hat{m}))$.
- The resulting transformations are applied to a labeled example $x$ with label $p$ to obtain an augmented version $\hat{x}$.
- Then, **the extent to which the model’s prediction matches the label** is measured as a score $\omega$, which is 1 when $p_{\text{model}}(y \mid \hat{x}; \theta)$ matches $p$ exactly and decreases as the absolute difference between the two grows.

**The weight for each sampled magnitude bin is updated** as:

$$m_i = \rho\, m_i + (1 - \rho)\, \omega$$

- where $\rho = 0.99$ is a fixed exponential decay hyperparameter.
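The CTAugment bookkeeping for one distortion parameter can be sketched as follows. Assumptions, stated loudly: `MagnitudeBins` and `match_score` are hypothetical names, the number of bins (10) and the transformation names are arbitrary, each transformation is given a single magnitude parameter, the uniform fallback when every bin is at or below 0.8 is our own guard, and the $\frac{1}{2}$ factor inside `match_score` is an assumed normalization that keeps $\omega$ in $[0, 1]$.

```python
import random


class MagnitudeBins:
    """Hypothetical sketch of CTAugment's bin weights m for one distortion parameter."""

    def __init__(self, num_bins=10, threshold=0.8, rho=0.99):
        self.m = [1.0] * num_bins  # every bin starts with weight 1
        self.threshold = threshold
        self.rho = rho

    def sample(self, rng=random):
        """Keep bins with m_i > threshold (m_hat), then draw from Categorical(Normalize(m_hat))."""
        m_hat = [mi if mi > self.threshold else 0.0 for mi in self.m]
        if sum(m_hat) == 0.0:  # our own guard, not from the paper: fall back to uniform
            m_hat = [1.0] * len(self.m)
        # random.choices normalizes the weights itself.
        return rng.choices(range(len(m_hat)), weights=m_hat, k=1)[0]

    def update(self, i, omega):
        """m_i <- rho * m_i + (1 - rho) * omega for the sampled bin i."""
        self.m[i] = self.rho * self.m[i] + (1.0 - self.rho) * omega


def match_score(pred, label):
    """omega = 1 - 0.5 * sum|pred - label| (assumed normalization, omega in [0, 1])."""
    return 1.0 - 0.5 * sum(abs(a - b) for a, b in zip(pred, label))


rng = random.Random(0)
bins = {name: MagnitudeBins() for name in ["rotate", "shear_x", "contrast"]}  # hypothetical subset
names = rng.sample(sorted(bins), 2)               # two transformations, uniformly at random
chosen = {n: bins[n].sample(rng) for n in names}  # one magnitude bin per transformation
omega = match_score([0.7, 0.3], [1.0, 0.0])       # prediction on x_hat vs. label p; ≈ 0.7
for n, i in chosen.items():
    bins[n].update(i, omega)
print(chosen, round(omega, 3))
```

Bins whose magnitudes repeatedly produce poor predictions decay below 0.8 and stop being sampled, which is how CTAugment keeps augmentations within the network's tolerance.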

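In brief, CTAugment is RandAugment-like, except that the magnitude of each transformation is sampled from bins whose weights are learned online from how well the model predicts the labels of augmented labeled examples.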

# 4. Putting It All Together

- The main purpose of this algorithm is to produce the collections $\mathcal{X}'$ and $\mathcal{U}'$, consisting of augmented labeled and unlabeled examples with mixup applied. **The labels and label guesses in $\mathcal{X}'$ and $\mathcal{U}'$** are fed into **standard cross-entropy loss** terms against the model’s predictions.
- The algorithm also outputs $\hat{\mathcal{U}}_1$, which consists of **a single heavily-augmented version** of each unlabeled image and **its label guesses without mixup**. $\hat{\mathcal{U}}_1$ is used in **two additional loss** terms, which provide a mild boost in performance in addition to improved stability: **pre-mixup unlabeled loss** and **rotation loss**.

- **Pre-mixup unlabeled loss**: the guessed labels and predictions for examples in $\hat{\mathcal{U}}_1$ are fed as-is into a **separate cross-entropy loss term**.
- **Rotation loss**: the idea from the self-supervised method **RotNet** is used, where **each $\hat{\mathcal{U}}_1$ image is rotated** and the model needs to **predict the rotation amount**.
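The rotation loss only needs a rotated image and the index of the rotation as a 4-way classification target. Below is a minimal sketch on 2-D Python lists; it assumes a hypothetical separate 4-way rotation output whose logits are passed to `rotation_loss`, and the function names are illustrative rather than taken from the paper.

```python
import math
import random


def rotate90(img):
    """Rotate a 2-D list (H x W) by 90 degrees counter-clockwise."""
    return [list(row) for row in zip(*img)][::-1]


def make_rotation_example(img, rng=random):
    """Pick k in {0, 1, 2, 3} (0, 90, 180, 270 degrees); return the rotated image and k."""
    k = rng.randrange(4)
    out = img
    for _ in range(k):
        out = rotate90(out)
    return out, k


def rotation_loss(logits, k):
    """Cross-entropy of the 4-way rotation logits against the true index k (stable log-sum-exp)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[k]


img = [[1, 2], [3, 4]]
rotated, k = make_rotation_example(img, rng=random.Random(0))
print(rotate90(img))  # → [[2, 4], [1, 3]]
print(k, rotated, round(rotation_loss([0.0, 0.0, 0.0, 0.0], k), 4))  # uniform logits give log 4 ≈ 1.3863
```

No labels are needed: the rotation index is generated from the image itself, which is what makes this a self-supervised signal on unlabeled data.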

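- In total, the **ReMixMatch loss** is:

$$\sum_{x,p \in \mathcal{X}'} \mathrm{H}\big(p, p_{\text{model}}(y \mid x; \theta)\big) + \lambda_{\mathcal{U}} \sum_{u,q \in \mathcal{U}'} \mathrm{H}\big(q, p_{\text{model}}(y \mid u; \theta)\big) + \lambda_{\hat{\mathcal{U}}_1} \sum_{u,q \in \hat{\mathcal{U}}_1} \mathrm{H}\big(q, p_{\text{model}}(y \mid u; \theta)\big) + \lambda_r \sum_{u \in \hat{\mathcal{U}}_1} \mathrm{H}\big(r, p_{\text{model}}(r \mid \text{Rotate}(u, r); \theta)\big)$$

- where $\mathrm{H}$ is the cross-entropy and $r$ is a rotation sampled from $\{0°, 90°, 180°, 270°\}$.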

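- In practice, $\lambda_r = \lambda_{\hat{\mathcal{U}}_1} = 0.5$ and $\lambda_{\mathcal{U}} = 1.5$. $T = 0.5$ is used for sharpening, $\alpha = 0.75$ is the Beta-distribution parameter for mixup, and $K = 8$. The final model is an exponential moving average over the trained model weights with a decay of 0.999.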
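The total loss is a weighted sum of four cross-entropy terms, and the evaluated model is an exponential moving average (EMA) of the weights. Below is a minimal sketch using the weights quoted above; the function names and the `(target, prediction)` pair format are hypothetical, weights are treated as a flat list of floats, and the terms are summed over examples rather than averaged per batch, which an implementation might do instead.

```python
import math


def cross_entropy(target, pred, eps=1e-12):
    """H(target, pred) = -sum_c target_c * log(pred_c)."""
    return -sum(t * math.log(max(p, eps)) for t, p in zip(target, pred))


def remixmatch_loss(x_prime, u_prime, u_hat1, rot, lambda_u=1.5, lambda_u1=0.5, lambda_r=0.5):
    """Each argument is a list of (target, prediction) pairs:
    x_prime: mixed labeled data; u_prime: mixed unlabeled data;
    u_hat1: pre-mixup unlabeled data; rot: one-hot rotation targets vs. rotation predictions."""

    def total(pairs):
        return sum(cross_entropy(t, p) for t, p in pairs)

    return (total(x_prime)
            + lambda_u * total(u_prime)
            + lambda_u1 * total(u_hat1)
            + lambda_r * total(rot))


def ema_update(ema_weights, weights, decay=0.999):
    """ema <- decay * ema + (1 - decay) * w, element-wise."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


pair = ([1.0, 0.0], [0.5, 0.5])  # one-hot target vs. an uncertain prediction
loss = remixmatch_loss([pair], [pair], [pair], [pair])
print(round(loss, 4))  # (1 + 1.5 + 0.5 + 0.5) * log 2
print(ema_update([1.0], [0.0]))
```

With the same per-example loss in every term, the total is $(1 + 1.5 + 0.5 + 0.5) \log 2$, which shows how strongly the mixed unlabeled term $\lambda_{\mathcal{U}} = 1.5$ is weighted relative to the two auxiliary terms.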

# 5. Experimental Results

The recommendations of Realistic Semi-Supervised Learning (Oliver NeurIPS’18) for performing semi-supervised learning (SSL) evaluations are followed.

## 5.1. CIFAR-10 & SVHN

On CIFAR-10, **ReMixMatch sets a new state of the art for all numbers of labeled examples**.

- Most importantly, **ReMixMatch is 16× more data efficient than MixMatch**.
- On SVHN, **ReMixMatch reaches state-of-the-art at 250 labeled examples**, and is within the margin of error of the state of the art otherwise.

## 5.2. STL-10

Using the same WRN-37-2 network (23.8 million parameters), the error rate is reduced by a factor of two compared to MixMatch.

## 5.3. Towards Few-Shot Learning

- Only four labeled examples per class are used here, so **high variance** is to be expected when choosing so few labeled examples at random.
- On CIFAR-10, ReMixMatch obtains a median-of-five error rate of 15.08%.
- On SVHN, ReMixMatch reaches 3.48% error.
- On SVHN with the “extra” dataset, ReMixMatch reaches 2.81% error.

ReMixMatch is able to work in extremely low-label settings.

## 5.4. Ablation Study

- **Removing the rotation loss reduces accuracy at 250 labels by only 0.14 percentage points**, but **in the 40-label setting, rotation loss is found to be necessary to prevent collapse**.
- **Changing the cross-entropy loss on unlabeled data to an L2 loss as used in MixMatch hurts performance dramatically**, as does removing either of the augmentation components.

This validates using augmentation anchoring in place of the consistency regularization mechanism of MixMatch.

(Later on, another approach called FixMatch was proposed, which outperforms ReMixMatch. I hope I can review it in the future.)

## Reference

[2020 ICLR] [ReMixMatch]

ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring
