Review: Virtual Adversarial Training (VAT)

VAT for Semi-Supervised Learning, Outperforms Ladder Network, Γ-Model & Π-Model

Figure From Author’s GitHub

Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning
VAT, by Preferred Networks, Inc., ATR Cognitive Mechanisms Laboratories, Ritsumeikan University, and Kyoto University
2019 TPAMI, Over 1500 Citations (Sik-Ho Tsang @ Medium)
This paper extends “Distributional Smoothing with Virtual Adversarial Training” (2016 ICLR, over 400 citations).

  • A new measure of local smoothness of the conditional label distribution given input is proposed.
  • Virtual adversarial loss is defined as the robustness of the conditional label distribution around each input data point against local perturbation.

Outline

  1. Virtual Adversarial Training (VAT)
  2. Experimental Results

1. Virtual Adversarial Training (VAT)

Virtual Adversarial Training (VAT) (Figure from Amit Chaudhary)
  • In Temporal Ensembling and Mean Teacher, MSE is used to measure the discrepancy between two predictions.
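  • In contrast, Virtual Adversarial Training (VAT) uses the KL divergence between the prediction on the clean input and the prediction on a perturbed input:
$$D_{\mathrm{KL}}\big[p(y\mid x,\hat{\theta})\,\|\,p(y\mid x+r,\theta)\big]=\sum_{y\in Q}p(y\mid x,\hat{\theta})\log\frac{p(y\mid x,\hat{\theta})}{p(y\mid x+r,\theta)}$$
  • where x is the input, r is a small perturbation on x, y is the output label, Q is the set of labels, θ denotes the model parameters, and θ̂ is the current value of θ, treated as a constant (no gradient flows through it).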
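To make the contrast with MSE concrete, here is a minimal NumPy sketch that computes both discrepancies for the same pair of predictions. The model is a hypothetical 3-class linear softmax classifier on a 2-D input, and the helper names (`softmax`, `kl_divergence`, `mse`) are illustrative; the MSE is averaged over classes, which is an assumption about normalization.

```python
import numpy as np

def softmax(z):
    # Subtract the max for numerical stability before exponentiating.
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def kl_divergence(p, q, eps=1e-12):
    # KL[p || q] = sum over labels y in Q of p(y) * log(p(y) / q(y)).
    # eps guards against log(0) if a probability underflows.
    return np.sum(p * (np.log(p + eps) - np.log(q + eps)), axis=-1)

def mse(p, q):
    # Squared-error consistency (Temporal Ensembling / Mean Teacher style),
    # averaged over classes.
    return np.mean((p - q) ** 2, axis=-1)

# Hypothetical toy model: logits = W @ x + b, 3 classes, 2-D input.
rng = np.random.default_rng(0)
W = rng.normal(size=(3, 2))
b = np.zeros(3)
x = np.array([0.5, -1.0])
r = 0.1 * rng.normal(size=2)        # small random (not yet adversarial) perturbation

p_clean = softmax(W @ x + b)        # p(y | x)
p_pert = softmax(W @ (x + r) + b)   # p(y | x + r)
print("KL :", kl_divergence(p_clean, p_pert))
print("MSE:", mse(p_clean, p_pert))
```

Both quantities are zero when the two predictions coincide; the KL divergence additionally weights each label's discrepancy by the clean prediction p(y | x).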
Adversarial Direction (Solid Arrow) (Figure from Divam Gupta)
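  • The perturbation r is taken in the adversarial direction, i.e. the direction within a small ball that changes the prediction the most, so that the KL divergence between the two output distributions is maximized:
$$r_{\mathrm{vadv}}:=\arg\max_{r;\,\|r\|_2\le\epsilon}D_{\mathrm{KL}}\big[p(y\mid x,\hat{\theta})\,\|\,p(y\mid x+r,\hat{\theta})\big]$$
  • where ε bounds the L2 norm of the perturbation.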
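  • The Local Distributional Smoothness (LDS) loss is then defined as:
$$\mathrm{LDS}(x,\theta):=D_{\mathrm{KL}}\big[p(y\mid x,\hat{\theta})\,\|\,p(y\mid x+r_{\mathrm{vadv}},\theta)\big]$$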
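For intuition, r_vadv and LDS can be found by brute force when the input is only 2-D: scan many directions of length ε and keep the one with the largest KL divergence. This is a sketch on a hypothetical linear softmax model (the function name `lds_brute_force` is illustrative), not the paper's method, which uses the gradient-based approximation described in the algorithm below. Like that approximation, the scan only searches the sphere ‖r‖₂ = ε.

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

def kl(p, q):
    # KL[p || q] summed over labels.
    return float(np.sum(p * (np.log(p) - np.log(q))))

def lds_brute_force(W, b, x, eps, n_dirs=3600):
    """Approximate r_vadv and LDS(x) for a 2-D input by scanning
    n_dirs directions on the circle of radius eps."""
    p_x = softmax(W @ x + b)  # clean prediction, held fixed (plays the role of theta_hat)
    angles = np.linspace(0.0, 2 * np.pi, n_dirs, endpoint=False)
    best_r, best_kl = None, -np.inf
    for a in angles:
        r = eps * np.array([np.cos(a), np.sin(a)])
        d = kl(p_x, softmax(W @ (x + r) + b))
        if d > best_kl:
            best_r, best_kl = r, d
    return best_r, best_kl  # (r_vadv, LDS)

# Hypothetical 2-class model whose logits depend only on x[0].
W = np.array([[1.0, 0.0], [-1.0, 0.0]])
b = np.zeros(2)
x = np.array([0.1, 0.5])
r, lds = lds_brute_force(W, b, x, eps=0.5)
print(r, lds)
```

Because the prediction here depends only on the first coordinate, the adversarial direction found by the scan lies along that axis: moving along the second coordinate does not change the output at all.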

LDS(x,θ) can be viewed as a negative measure of the local smoothness of the current model at each input data point x: the smaller it is, the smoother the model is around x.
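The paper motivates its approximation of r_vadv with a second-order expansion; the derivation below sketches that reasoning and shows why LDS acts as a negative smoothness measure. Here H is the Hessian of the KL divergence with respect to r at r = 0, u is its unit dominant eigenvector and λ₁ its largest eigenvalue. The first-order term vanishes because the divergence attains its minimum of zero at r = 0.

$$
\begin{aligned}
D(r,x,\hat{\theta}) &:= D_{\mathrm{KL}}\big[p(y\mid x,\hat{\theta})\,\|\,p(y\mid x+r,\hat{\theta})\big]\approx\tfrac{1}{2}\,r^{\top}H(x,\hat{\theta})\,r,\qquad H(x,\hat{\theta}):=\nabla\nabla_{r}D(r,x,\hat{\theta})\big|_{r=0}\\
r_{\mathrm{vadv}} &\approx \arg\max_{r;\,\|r\|_2\le\epsilon}\tfrac{1}{2}\,r^{\top}H(x,\hat{\theta})\,r=\pm\,\epsilon\,u(x,\hat{\theta})\\
\mathrm{LDS}(x,\hat{\theta}) &\approx \tfrac{1}{2}\,\epsilon^{2}\,\lambda_{1}(x,\hat{\theta})
\end{aligned}
$$

So LDS is large exactly where the model's prediction is most sensitive to small input changes. The same expansion explains the gradient trick in the algorithm: for small ξ, the gradient at r = ξd is approximately ξHd, so normalizing it is one step of power iteration toward u.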

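  • The regularization term proposed in this paper is the average of LDS(x*,θ) over all input data points, labeled and unlabeled:
$$\mathcal{R}_{\mathrm{vadv}}(\mathcal{D}_l,\mathcal{D}_{ul},\theta):=\frac{1}{N_l+N_{ul}}\sum_{x_*\in\mathcal{D}_l,\mathcal{D}_{ul}}\mathrm{LDS}(x_*,\theta)$$
  • where N_l is the number of labeled samples, N_ul is the number of unlabeled samples, D_l is the labeled set, and D_ul is the unlabeled set.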
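  • The full objective function is:
$$\ell(\mathcal{D}_l,\theta)+\alpha\,\mathcal{R}_{\mathrm{vadv}}(\mathcal{D}_l,\mathcal{D}_{ul},\theta)$$
  • where ℓ(D_l,θ) is the negative log-likelihood on the labeled dataset and α is a non-negative coefficient weighting the regularizer. VAT is a training method that uses the regularizer R_vadv.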
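Once per-sample LDS values are available, the objective is simple arithmetic. The sketch below assumes precomputed log-probabilities and LDS values (the function name `vat_objective` and its arguments are hypothetical) and averages the negative log-likelihood over the labeled samples, which is an assumption: summing it instead only rescales α.

```python
import numpy as np

def vat_objective(log_probs_l, labels_l, lds_l, lds_ul, alpha=1.0):
    """l(D_l, theta) + alpha * R_vadv(D_l, D_ul, theta).

    log_probs_l : (N_l, K) log p(y | x, theta) for labeled inputs
    labels_l    : (N_l,)   integer class labels
    lds_l       : (N_l,)   LDS(x, theta) for labeled inputs
    lds_ul      : (N_ul,)  LDS(x, theta) for unlabeled inputs
    """
    # Negative log-likelihood of the true labels on the labeled set (averaged).
    nll = -np.mean(log_probs_l[np.arange(len(labels_l)), labels_l])
    # R_vadv averages LDS over ALL N_l + N_ul points, labeled and unlabeled alike.
    r_vadv = np.concatenate([lds_l, lds_ul]).sum() / (len(lds_l) + len(lds_ul))
    return nll + alpha * r_vadv

log_probs = np.log(np.array([[0.9, 0.1], [0.2, 0.8]]))
labels = np.array([0, 1])
lds_l = np.array([0.01, 0.03])
lds_ul = np.array([0.02, 0.04, 0.10])
obj = vat_objective(log_probs, labels, lds_l, lds_ul, alpha=1.0)
print(obj)
```

Note that the labels are only used by the likelihood term; the regularizer needs no labels, which is what makes VAT applicable to semi-supervised learning.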
VAT Algorithm
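  • To perform VAT, first randomly select M samples x^(i) from the dataset, labeled or unlabeled.
  • For each sample, generate a random unit vector d^(i) (drawn from an i.i.d. Gaussian and normalized), then calculate r_vadv by taking the gradient of the KL divergence with respect to r at r = ξd^(i), where ξ is a small constant: g^(i) = ∇_r D_KL[p(y|x^(i),θ̂) ‖ p(y|x^(i)+r,θ̂)] evaluated at r = ξd^(i), and r_vadv^(i) = ε g^(i) / ‖g^(i)‖_2.
  • Finally, update θ using the gradient of ℓ(D_l,θ) + α·(1/M)·Σ_i D_KL[p(y|x^(i),θ̂) ‖ p(y|x^(i)+r_vadv^(i),θ)].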
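The r_vadv steps above can be sketched in NumPy without an autodiff library by using a hypothetical linear softmax model p(y|x) = softmax(Wx + b), for which the gradient of KL[p(y|x) ‖ p(y|x+r)] with respect to r is exactly Wᵀ(p(y|x+r) − p(y|x)). The function names are illustrative, and the final parameter update (which needs gradients with respect to θ) is omitted; in practice a framework such as TensorFlow or JAX would compute these gradients.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def kl_rows(p, q):
    # Row-wise KL[p || q] over the label axis.
    return np.sum(p * (np.log(p) - np.log(q)), axis=-1)

def r_vadv_power_iteration(W, b, X, eps, xi=1e-6, rng=None):
    """One gradient (power-iteration) step per sample for the hypothetical
    linear softmax model p(y|x) = softmax(W x + b). X has shape (M, D)."""
    rng = np.random.default_rng() if rng is None else rng
    p_x = softmax(X @ W.T + b)                    # p(y|x, theta_hat), held fixed
    d = rng.normal(size=X.shape)                  # i.i.d. Gaussian direction per sample
    d /= np.linalg.norm(d, axis=1, keepdims=True) # ... normalized to unit length
    p_xi = softmax((X + xi * d) @ W.T + b)        # p(y|x + xi*d, theta_hat)
    g = (p_xi - p_x) @ W                          # grad_r KL at r = xi*d, one row per sample
    g_norm = np.linalg.norm(g, axis=1, keepdims=True)
    return eps * g / np.maximum(g_norm, 1e-30)    # r_vadv = eps * g / ||g||_2

def lds(W, b, X, r):
    # LDS(x) = KL[p(y|x) || p(y|x + r_vadv)], one value per sample.
    return kl_rows(softmax(X @ W.T + b), softmax((X + r) @ W.T + b))

W = np.array([[2.0, 1.0], [-1.0, 0.5]])
b = np.zeros(2)
X = np.array([[0.1, 0.2], [1.0, -0.5], [-0.3, 0.4]])
r = r_vadv_power_iteration(W, b, X, eps=0.3, rng=np.random.default_rng(0))
lds_values = lds(W, b, X, r)
print(r, lds_values)
```

With two classes, only the logit difference matters, so every r_vadv comes out parallel to W[0] − W[1] after a single step; with more classes, one step only approximates the dominant direction.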
Codes From Author’s GitHub
  • The code is from the author; the variable d in the code corresponds to r in the paper.
  • LDS values are large for points near the class boundary and become smaller after each update.
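A quick check of the boundary effect on a hypothetical 2-class linear softmax model: there the prediction changes only along w = W[0] − W[1], so r_vadv is ±εw/‖w‖, and the sketch evaluates both signs and keeps the larger KL. The function name `lds_binary_linear` and the chosen points are illustrative.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def lds_binary_linear(W, b, x, eps):
    """LDS for a hypothetical 2-class linear softmax model: the logit
    difference only changes along w = W[0] - W[1], so r_vadv = +/- eps * w/||w||."""
    w = W[0] - W[1]
    u = w / np.linalg.norm(w)
    p_x = softmax(W @ x + b)
    kls = []
    for s in (1.0, -1.0):
        p_r = softmax(W @ (x + s * eps * u) + b)
        kls.append(np.sum(p_x * (np.log(p_x) - np.log(p_r))))
    return max(kls)

W = np.array([[1.0, 1.0], [-1.0, -1.0]])  # decision boundary: x[0] + x[1] = 0
b = np.zeros(2)
lds_near = lds_binary_linear(W, b, np.array([0.5, -0.5]), eps=0.5)  # on the boundary
lds_far = lds_binary_linear(W, b, np.array([3.0, 3.0]), eps=0.5)    # deep inside a class
print(lds_near, lds_far)
```

On the boundary the prediction is 50/50 and a step of size ε moves it substantially, while far from the boundary the prediction is saturated and the same step barely changes it, so VAT's penalty concentrates near the boundary.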

2. Experimental Results

2.1. MNIST

Test Performance of Semi-Supervised Learning Methods on MNIST with the Permutation Invariant Setting
  • An NN with four hidden layers of 1200, 600, 300 and 150 units is used.
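A minimal NumPy sketch of this architecture, mainly to make the layer sizes concrete. Several details are assumptions not stated above: 784 flattened input pixels (permutation-invariant MNIST), 10 output classes, ReLU activations, He-style initialization and no batch normalization; `init_mlp` and `forward` are hypothetical helpers.

```python
import numpy as np

# Assumed: 784 inputs (28x28 flattened), hidden sizes from the paper, 10 classes.
LAYER_SIZES = [784, 1200, 600, 300, 150, 10]

def init_mlp(sizes, rng):
    # He-style initialization; one (W, b) pair per layer.
    return [(rng.normal(0.0, np.sqrt(2.0 / m), size=(m, n)), np.zeros(n))
            for m, n in zip(sizes[:-1], sizes[1:])]

def forward(params, x):
    h = x
    for i, (W, b) in enumerate(params):
        h = h @ W + b
        if i < len(params) - 1:          # ReLU on hidden layers only
            h = np.maximum(h, 0.0)
    # Softmax over the classes gives p(y | x, theta).
    h = h - h.max(axis=1, keepdims=True)
    e = np.exp(h)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
params = init_mlp(LAYER_SIZES, rng)
n_params = sum(W.size + b.size for W, b in params)
probs = forward(params, rng.random((4, 784)))
print(n_params, probs.shape)
```

Under these assumptions the network has 1,889,560 weights and biases; the input is treated as a flat vector, so no spatial structure is exploited.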

VAT outperforms many other semi-supervised methods, except the Ladder Network and GANs.

2.2. SVHN & CIFAR-10

Test Performance of Semi-Supervised Learning Methods on SVHN and CIFAR-10 without Image Data Augmentation
  • Two CNNs, Conv-Small and Conv-Large, are used.

VAT achieved a test error rate of 14.82%, outperforming the state-of-the-art methods for semi-supervised learning on CIFAR-10.

  • Combined with entropy minimization (EntMin), ‘VAT+EntMin’ outperformed the state-of-the-art methods for semi-supervised learning on both SVHN and CIFAR-10.
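Entropy minimization penalizes the conditional entropy of the model's predictions, pushing it toward confident outputs. The sketch below computes only that term, −Σ_y p(y|x) log p(y|x) averaged over a batch of inputs; how it is weighted and combined with the VAT objective is not shown, and `conditional_entropy` is a hypothetical helper.

```python
import numpy as np

def conditional_entropy(probs, eps=1e-12):
    """Average over inputs of -sum_y p(y|x) * log p(y|x).
    probs has shape (N, K); eps guards against log(0)."""
    return float(-np.mean(np.sum(probs * np.log(probs + eps), axis=1)))

confident = np.array([[0.98, 0.01, 0.01]])
uncertain = np.array([[1 / 3, 1 / 3, 1 / 3]])
print(conditional_entropy(confident), conditional_entropy(uncertain))
```

A uniform prediction over K classes gives the maximum value log K, while a confident one gives a value close to zero, so minimizing this term sharpens predictions on unlabeled data.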

2.3. Ablation of ε and α

Ablation of ε and α
  • α is fixed to 1, so ε is the only hyperparameter to be tuned.

2.4. Virtual Adversarial Examples

Virtual Adversarial Examples using Different Values of ε
