Review — FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence

FixMatch Greatly Simplifies UDA & ReMixMatch

FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence, FixMatch, by Google Research
2020 NeurIPS, Over 800 Citations (Sik-Ho Tsang @ Medium)
Semi-Supervised Learning, Image Classification, CNN, UDA, ReMixMatch

  • FixMatch, an algorithm that is a significant simplification of existing Semi-Supervised Learning (SSL) methods, is proposed.
  • FixMatch first generates pseudo-labels using the model’s predictions on weakly-augmented unlabeled images.
  • For a given image, the pseudo-label is only retained if the model produces a high-confidence prediction.
  • The model is then trained to predict the pseudo-label when fed a strongly-augmented version of the same image.

Outline

  1. Background
  2. FixMatch
  3. Experimental Results

1. Background

1.1. Notations

  • For an L-class classification problem, let X={(xb, pb):b∈(1, …, B)} be a batch of B labeled examples, where xb are the training examples and pb are one-hot labels.
  • Let U={ub:b∈(1, …, μB)} be a batch of μB unlabeled examples where μ is a hyperparameter that determines the relative sizes of X and U.
  • Let pm(y|x) be the predicted class distribution produced by the model for input x, and let H(p, q) denote the cross-entropy between two probability distributions p and q.
  • FixMatch uses two types of augmentation, strong and weak, denoted by A(·) and α(·) respectively.

1.2. Consistency Regularization

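  • Consistency regularization utilizes unlabeled data by relying on the assumption that the model should output similar predictions when fed perturbed versions of the same image. The loss function is:
    $$\sum_{b=1}^{\mu B} \left\| p_m(y \mid \alpha(u_b)) - p_m(y \mid \alpha(u_b)) \right\|_2^2$$
  • Both α and pm are stochastic functions, so the two terms inside the norm take different values.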
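To make the idea concrete, here is a minimal sketch of a consistency loss that sums the squared L2 distance between the predictions for two random views of each unlabeled example. It uses only the Python standard library; `toy_model` (which reads its input as class logits) and `make_jitter` (random noise standing in for image augmentation) are hypothetical stand-ins, not part of FixMatch.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_model(x):
    # Hypothetical stand-in for p_m(y | x): the input is read as class logits.
    return softmax(x)

def make_jitter(rng, scale=0.1):
    # Hypothetical stochastic perturbation standing in for image augmentation.
    def jitter(x):
        return [v + rng.uniform(-scale, scale) for v in x]
    return jitter

def consistency_loss(model, augment, unlabeled_batch):
    """Sum over the batch of squared L2 distances between two stochastic passes."""
    total = 0.0
    for u in unlabeled_batch:
        p1 = model(augment(u))  # first random view
        p2 = model(augment(u))  # second, independently drawn view
        total += sum((a - b) ** 2 for a, b in zip(p1, p2))
    return total

batch = [[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]]
loss = consistency_loss(toy_model, make_jitter(random.Random(0)), batch)
print(loss)  # small positive value: the two views disagree slightly
```

With a deterministic augmentation the two passes coincide and the loss is exactly zero, which is why the randomness of the perturbation matters.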

1.3. Pseudo-Label (PL)

  • Pseudo-Label (PL) leverages the idea of using the model itself to obtain artificial labels for unlabeled data.
  • Only artificial labels whose largest class probability falls above a predefined threshold are retained.
  • Letting qb=pm(y|ub), Pseudo-Label (PL) uses the following loss function:
    $$\frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}(\max(q_b) \ge \tau)\, H(\hat{q}_b, q_b)$$
  • where q̂b=arg max(qb) is the one-hot pseudo-label and τ is the threshold.
  • Taking the arg max of the model’s output produces a “hard” (one-hot) label. The use of a hard label makes Pseudo-Label (PL) closely related to entropy minimization, where the model’s predictions are encouraged to be low-entropy (i.e., high-confidence) on unlabeled data.
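The thresholded pseudo-label loss described above can be sketched in a few lines of plain Python. The function names and the example batch are illustrative assumptions; the predictions are given directly as probability lists rather than produced by a network.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i * log(q_i), skipping terms where p_i is zero."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def one_hot(index, num_classes):
    return [1.0 if i == index else 0.0 for i in range(num_classes)]

def pseudo_label_loss(predictions, tau):
    """Batch mean of 1(max(q_b) >= tau) * H(argmax(q_b), q_b)."""
    total = 0.0
    for q in predictions:
        confidence = max(q)
        if confidence >= tau:  # keep only high-confidence predictions
            hard = one_hot(q.index(confidence), len(q))
            total += cross_entropy(hard, q)
    return total / len(predictions)

# Illustrative predictions: the first is confident, the second is not.
batch = [[0.97, 0.02, 0.01], [0.40, 0.35, 0.25]]
loss = pseudo_label_loss(batch, tau=0.95)
print(loss)  # about 0.0152: only the first example contributes
```

Low-confidence examples are masked out but still counted in the denominator, so the loss shrinks when few predictions pass the threshold.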

2. FixMatch

FixMatch
  • The loss function for FixMatch consists of two cross-entropy loss terms: a supervised loss ls applied to labeled data and an unsupervised loss lu.

2.1. Supervised Loss

  • Specifically, ls is just the standard cross-entropy loss on weakly augmented labeled examples:
    $$\ell_s = \frac{1}{B} \sum_{b=1}^{B} H(p_b, p_m(y \mid \alpha(x_b)))$$

2.2. Unsupervised Loss

  • FixMatch computes an artificial label for each unlabeled example, which is then used in a standard cross-entropy loss. (When constructing U, all labeled examples are also included, without their labels.)
  • To obtain an artificial label, the model’s predicted class distribution is first computed given a weakly-augmented version of a given unlabeled image:
    $$q_b = p_m(y \mid \alpha(u_b))$$
  • Then, the pseudo-label is selected using the arg max of qb:
    $$\hat{q}_b = \arg\max(q_b)$$
  • The unsupervised loss follows the Pseudo-Label (PL) loss, except that the cross-entropy loss is enforced against the model’s output for a strongly-augmented version of ub:
    $$\ell_u = \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}(\max(q_b) \ge \tau)\, H(\hat{q}_b, p_m(y \mid A(u_b)))$$
  • The loss minimized by FixMatch is simply the weighted sum of ls and lu:
    $$\ell_s + \lambda_u \ell_u$$
  • where λu is a fixed scalar hyperparameter denoting the relative weight of the unlabeled loss.
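As a rough illustration of how the two terms combine, the sketch below computes the supervised loss, the masked unsupervised loss, and their weighted sum from already-computed predictions. It is a plain-Python sketch under stated assumptions: the function and argument names are hypothetical, the default values of `tau` and `lambda_u` are placeholders, and a real implementation would operate on tensors inside a training loop.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i * log(q_i), skipping terms where p_i is zero."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def one_hot(index, num_classes):
    return [1.0 if i == index else 0.0 for i in range(num_classes)]

def fixmatch_loss(labels, labeled_preds, weak_preds, strong_preds,
                  tau=0.95, lambda_u=1.0):
    """Return (total, l_s, l_u) for one batch.

    labels        : one-hot labels p_b of the labeled examples
    labeled_preds : model outputs on weakly-augmented labeled examples
    weak_preds    : model outputs q_b on weakly-augmented unlabeled examples
    strong_preds  : model outputs on strongly-augmented unlabeled examples
    """
    # Supervised term: mean cross-entropy against the true labels.
    l_s = sum(cross_entropy(p, q) for p, q in zip(labels, labeled_preds))
    l_s /= len(labels)

    # Unsupervised term: the pseudo-label comes from the weak view,
    # the cross-entropy is taken against the strong view.
    l_u = 0.0
    for q_weak, q_strong in zip(weak_preds, strong_preds):
        confidence = max(q_weak)
        if confidence >= tau:
            pseudo = one_hot(q_weak.index(confidence), len(q_weak))
            l_u += cross_entropy(pseudo, q_strong)
    l_u /= len(weak_preds)

    return l_s + lambda_u * l_u, l_s, l_u

# Illustrative numbers: one labeled example, two unlabeled examples.
labels = [[1.0, 0.0]]
labeled_preds = [[0.8, 0.2]]
weak_preds = [[0.98, 0.02], [0.60, 0.40]]    # only the first passes tau
strong_preds = [[0.70, 0.30], [0.50, 0.50]]
total, l_s, l_u = fixmatch_loss(labels, labeled_preds, weak_preds, strong_preds)
print(total, l_s, l_u)
```

Note that the confidence check uses the weak-view prediction only; the strong-view prediction never decides whether an example is kept.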

The weak and strong augmentations are the key to the success of FixMatch.

2.3. Augmentation in FixMatch

  • Weak augmentation is a standard flip-and-shift augmentation strategy. Specifically, images are randomly flipped horizontally with a probability of 50% on all datasets except SVHN, and images are randomly translated by up to 12.5% vertically and horizontally.
  • For strong augmentation, variants of AutoAugment which do not require the augmentation strategy to be learned ahead of time with labeled data, such as RandAugment, and CTAugment (in ReMixMatch), are considered. Both RandAugment and CTAugment randomly select transformations for each sample. Then, Cutout is further used after augmentation.
  • For RandAugment, the magnitude that controls the severity of all distortions is randomly sampled from a pre-defined range (RandAugment with random magnitude was also used for UDA), whereas the magnitudes of individual transformations are learned on-the-fly for CTAugment.
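The flip-and-shift weak augmentation and the Cutout step are simple enough to sketch on a toy image stored as a list of rows. This is an illustrative sketch only: filling vacated pixels with a constant, the Cutout patch size, and all function names are assumptions made here, and the RandAugment or CTAugment transformations that precede Cutout in the strong pipeline are not shown.

```python
import random

def hflip(img):
    """Mirror each row left to right."""
    return [row[::-1] for row in img]

def translate(img, dx, dy, fill=0):
    """Shift the image by (dx, dy) pixels; vacated pixels get `fill` (assumed)."""
    h, w = len(img), len(img[0])
    out = [[fill] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            sy, sx = y - dy, x - dx
            if 0 <= sy < h and 0 <= sx < w:
                out[y][x] = img[sy][sx]
    return out

def weak_augment(img, rng, flip=True, max_shift_frac=0.125):
    """Random horizontal flip (p = 0.5) and translation by up to 12.5%."""
    h, w = len(img), len(img[0])
    if flip and rng.random() < 0.5:   # pass flip=False for SVHN
        img = hflip(img)
    max_dx, max_dy = int(w * max_shift_frac), int(h * max_shift_frac)
    dx = rng.randint(-max_dx, max_dx)
    dy = rng.randint(-max_dy, max_dy)
    return translate(img, dx, dy)

def cutout(img, rng, size, fill=0):
    """Blank a square patch around a random center, clipped at the border."""
    h, w = len(img), len(img[0])
    cy, cx = rng.randrange(h), rng.randrange(w)
    half = size // 2
    y0, y1 = max(0, cy - half), min(h, cy + half)
    x0, x1 = max(0, cx - half), min(w, cx + half)
    out = [row[:] for row in img]     # copy so the input is left untouched
    for y in range(y0, y1):
        for x in range(x0, x1):
            out[y][x] = fill
    return out

# Toy 8x8 "image" with pixel values 1..64.
img = [[r * 8 + c + 1 for c in range(8)] for r in range(8)]
weak = weak_augment(img, random.Random(1))
masked = cutout(img, random.Random(2), size=4)
```

On an 8x8 image, a 12.5% translation corresponds to a shift of at most one pixel in each direction.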

2.4. Overall

FixMatch Algorithm

A weakly-augmented image is fed into the model to obtain predictions (red box in figure). When the model assigns a probability to any class which is above a threshold (dotted line), the prediction is converted to a one-hot pseudo-label.

Then, we compute the model’s prediction for a strong augmentation of the same image. The model is trained to make its prediction on the strongly-augmented version match the pseudo-label via a cross-entropy loss.

  • The same hyperparameters (λu=1, learning rate η=0.03, momentum β=0.9, τ=0.95, μ=7, B=64, K=2²⁰ training steps) are used across all amounts of labeled examples and datasets other than ImageNet.

2.5. Differences from SOTA Approaches

Comparison of SSL algorithms which include a form of consistency regularization and which (optionally) apply some form of post-processing to the artificial labels
  • The above table compares the details of SOTA methods such as Π-Model, Temporal Ensembling, Mean Teacher, VAT, UDA, MixMatch, & ReMixMatch.
  • FixMatch bears the closest resemblance to two recent methods: UDA and ReMixMatch. Neither of them uses Pseudo-Labeling (PL), but both approaches “sharpen” the artificial label to encourage the model to produce high-confidence predictions.
  • UDA in particular also only enforces consistency when the highest probability in the predicted class distribution for the artificial label is above a threshold. The thresholded Pseudo-Labeling of FixMatch has a similar effect to sharpening.
  • ReMixMatch anneals the weight of the unlabeled data loss.
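To see why a hard pseudo-label behaves like an extreme form of sharpening, the sketch below uses the common temperature form of sharpening (each probability raised to the power 1/T and renormalized, as defined in MixMatch) and compares it with the arg max label. The function names and the example distribution are illustrative assumptions.

```python
def sharpen(p, temperature):
    """Temperature sharpening: p_i^(1/T) / sum_j p_j^(1/T)."""
    powered = [pi ** (1.0 / temperature) for pi in p]
    total = sum(powered)
    return [v / total for v in powered]

def hard_label(p):
    """One-hot label at the arg max, as used for a pseudo-label."""
    top = p.index(max(p))
    return [1.0 if i == top else 0.0 for i in range(len(p))]

p = [0.6, 0.3, 0.1]
for t in (1.0, 0.5, 0.1):
    print(t, [round(v, 4) for v in sharpen(p, t)])
print("hard", hard_label(p))
```

As T decreases, the sharpened distribution puts more and more mass on the top class and approaches the one-hot label in the limit.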

FixMatch can be viewed as a substantially simplified version of UDA and ReMixMatch, with many components removed (sharpening, training signal annealing from UDA, distribution alignment and the rotation loss from ReMixMatch, etc.).

3. Experimental Results

3.1. CIFAR, SVHN, & STL-10

Error rates for CIFAR-10, CIFAR-100, SVHN and STL-10 on 5 different folds
  • WRN-28-2 with 1.5M parameters for CIFAR-10 and SVHN, WRN-28-8 for CIFAR-100, and WRN-37-2 for STL-10, are used.
  • Performing better with less supervision is the central goal of SSL in practice since it alleviates the need for labeled data. FixMatch is the first to run any experiments at 4 labels per class on CIFAR-100.

FixMatch substantially outperforms each of these methods while nevertheless being simpler.

  • For example, FixMatch achieves an average error rate of 11.39% on CIFAR-10 with 4 labels per class. As a point of reference, the lowest error rate that earlier methods achieved on CIFAR-10 with 400 labels per class was 13.13%.
  • FixMatch’s results are state-of-the-art on all datasets except for CIFAR-100 where ReMixMatch performs a bit better.
  • On STL-10, FixMatch matches the state-of-the-art performance of ReMixMatch despite being significantly simpler.

3.2. ImageNet

  • 10% of the training data is used as labeled examples and the rest is treated as unlabeled examples. ResNet-50 is used, with RandAugment as the strong augmentation for this experiment.

FixMatch achieves a top-1 error rate of 28.54±0.52%, which is 2.68% better than UDA. FixMatch’s top-5 error rate is 10.87±0.28%.

  • While S⁴L holds state-of-the-art on semi-supervised ImageNet with a 26.79% error rate, it leverages 2 additional training phases (Pseudo-Label re-training and supervised fine-tuning) to significantly lower the error rate from 30.27% after the first phase.

FixMatch outperforms S⁴L after its first phase, and it is possible that a similar performance gain could be achieved by incorporating these techniques into FixMatch.

3.3. Barely Supervised Learning

FixMatch reaches 78% CIFAR-10 accuracy using only 10 labeled images
  • Only 1 labeled example per class is used, and a median accuracy of 78% is obtained.

3.4. Ablation Study

Plots of ablation studies on FixMatch
  • (a): Varying the confidence threshold for Pseudo-Label (PL).
  • (b): Measuring the effect of “sharpening” the predicted label distribution.
Ablation study with different strong data augmentation of FixMatch
  • Both Cutout and CTAugment (in ReMixMatch) are required to obtain the best performance; removing either results in a significant increase in error rate.
  • (There are still many results in the Appendix of the paper, please feel free to read if interested.)
