[Paper] DARTS: Differentiable Architecture Search (Image Classification)

The normal convolutional cell learned on CIFAR-10, shown over the search epochs.

In this story, DARTS: Differentiable Architecture Search (DARTS), by CMU and DeepMind, is presented. In this paper:

  • The architecture search method is based on a continuous relaxation of the architecture representation, which allows the architecture to be searched efficiently using gradient descent.
  • This is unlike conventional approaches, which apply evolution or reinforcement learning over a discrete search space, and it is orders of magnitude faster than state-of-the-art non-differentiable techniques.

This is a 2019 ICLR paper with over 900 citations. (Sik-Ho Tsang @ Medium)

Outline

  1. Search Space
  2. Continuous Relaxation and Optimization
  3. Approximate Architecture Gradient
  4. Experimental Results

1. Search Space

1.1. Fixed Outer Structure

Fixed Outer Structure (From NASNet)
  • Similar to NASNet, PNASNet and AmoebaNet, fixed outer structures are used for different datasets.
  • The learned cell could either be stacked to form a convolutional network or recursively connected to form a recurrent network.

1.2. Directed Acyclic Graph

Directed Acyclic Graph (From ENAS)
  • A cell is a directed acyclic graph consisting of an ordered sequence of N nodes.
  • Each node x(i) is a latent representation (e.g. a feature map in convolutional networks).
  • Each directed edge (i, j) is associated with some operation o(i, j) that transforms x(i).
  • The cell is assumed to have two input nodes and a single output node.
  • For convolutional cells, the input nodes are defined as the cell outputs in the previous two layers.
  • For recurrent cells, these are defined as the input at the current step and the state carried from the previous step.
  • The output of the cell is obtained by applying a reduction operation (e.g. concatenation) to all the intermediate nodes.
  • A special zero operation is also included to indicate a lack of connection between two nodes.
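
The cell structure above can be sketched in a few lines of NumPy. This is only a minimal illustration, not the authors' implementation: the function name cell_forward, the edge_ops dictionary and the toy operations are assumptions made for the example. It follows the paper's convention that each intermediate node is the sum of the transformed outputs of all earlier nodes, and it treats a missing edge as the zero operation.

```python
import numpy as np

def cell_forward(inputs, edge_ops, num_intermediate):
    """Run one cell. `inputs` are the input-node states and
    `edge_ops[(i, j)]` is the operation on edge (i, j).
    A missing edge behaves like the zero operation (no connection)."""
    states = list(inputs)
    num_inputs = len(states)
    for j in range(num_inputs, num_inputs + num_intermediate):
        # Sum the transformed states of all earlier nodes i < j.
        total = np.zeros_like(states[0])
        for i in range(j):
            op = edge_ops.get((i, j))
            if op is not None:
                total = total + op(states[i])
        states.append(total)
    # Reduce the intermediate nodes (here: concatenation) into the output.
    return np.concatenate(states[num_inputs:], axis=-1)

# Toy example (hypothetical ops): two input nodes, two intermediate nodes.
identity = lambda x: x
x0 = np.array([1.0, 2.0])
x1 = np.array([10.0, 20.0])
edge_ops = {
    (0, 2): identity,
    (1, 2): identity,
    (0, 3): identity,
    (2, 3): lambda x: 2.0 * x,
}
out = cell_forward([x0, x1], edge_ops, num_intermediate=2)
print(out)
```

Here node 2 receives both inputs, node 3 receives input 0 and a doubled copy of node 2, and edge (1, 3) is absent, which plays the role of the zero operation.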

2. Continuous Relaxation and Optimization

2.1. Continuous Relaxation

DARTS Overview: (a) Directed Acyclic Graph, (b) All Possible Paths, (c) Paths During Searching Using Softmax, (d) Confirmed Paths After Softmax
  • Let O be a set of candidate operations (e.g., convolution, max pooling, zero) where each operation represents some function o() to be applied to x(i).
  • (b) & (c): The categorical choice of a particular operation is relaxed to a softmax over all possible operations:
      ō(i, j)(x) = Σ_{o∈O} [ exp(α_o(i, j)) / Σ_{o'∈O} exp(α_o'(i, j)) ] · o(x)
  • where the operation mixing weights for a pair of nodes (i, j) are parameterized by a vector α(i, j) of dimension |O|.
  • (d): At the end of search, a discrete architecture can be obtained by replacing each mixed operation ō(i, j) with the most likely operation:
      o(i, j) = argmax_{o∈O} α_o(i, j)
  • Thus, α is treated as the encoding of the architecture.
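
The relaxation and the final discretization described above can be sketched as follows. This is a toy illustration under stated assumptions: the three candidate operations ("identity", "zero", "double") are hypothetical stand-ins for real convolution and pooling layers, and the names mixed_op and discretize are made up for this example.

```python
import numpy as np

def softmax(alpha):
    # Subtract the max for numerical stability.
    e = np.exp(alpha - np.max(alpha))
    return e / e.sum()

# Hypothetical toy candidate set O; real DARTS uses conv/pooling layers.
CANDIDATE_OPS = {
    "identity": lambda x: x,
    "zero": lambda x: np.zeros_like(x),
    "double": lambda x: 2.0 * x,
}

def mixed_op(x, alpha):
    """Softmax-weighted sum of all candidate operations on one edge."""
    weights = softmax(alpha)
    return sum(w * op(x) for w, op in zip(weights, CANDIDATE_OPS.values()))

def discretize(alpha):
    """Pick the most likely operation (argmax over alpha) for one edge."""
    names = list(CANDIDATE_OPS)
    return names[int(np.argmax(alpha))]

x = np.array([1.0, 2.0])
uniform = mixed_op(x, np.zeros(3))            # equal weight on every op
chosen = discretize(np.array([0.1, -1.0, 2.0]))
print(uniform, chosen)
```

With all-zero α every operation gets weight 1/3, so the mixed output is (x + 0 + 2x)/3 = x. Because the mixed output is a smooth function of α, gradients can flow to the architecture parameters.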

After relaxation, the goal is to jointly learn the architecture α and the weights w within all the mixed operations. Like search methods based on evolution or reinforcement learning, DARTS aims to optimize the validation loss, but it does so using gradient descent.

2.2. Optimization

  • Ltrain and Lval denote the training loss and the validation loss, respectively.
  • The goal for architecture search is to find α* that minimizes the validation loss Lval(w*,α*), where the weights w* associated with the architecture are obtained by minimizing the training loss w* = argmin w Ltrain(w,α*).
  • This implies a bilevel optimization problem with α as the upper-level variable and w as the lower-level variable:
      min_α Lval(w*(α), α)   subject to   w*(α) = argmin_w Ltrain(w, α)

3. Approximate Architecture Gradient

3.1. Approximation

DARTS
  • Evaluating the architecture gradient exactly can be prohibitive due to the expensive inner optimization.
  • A simple approximation scheme is as follows:
      ∇α Lval(w*(α), α) ≈ ∇α Lval(w − ξ∇w Ltrain(w, α), α)
  • where w denotes the current weights maintained by the algorithm, and ξ is the learning rate for a step of inner optimization. The idea is to approximate w*(α) by adapting w using only a single training step, without solving the inner optimization.
  • The above equation will reduce to ∇αLval(w,α) if w is already a local optimum for the inner optimization and thus ∇wLtrain(w,α)=0.
  • Setting ξ=0 also reduces the equation to ∇αLval(w,α), a first-order approximation that is cheaper to compute but leads to worse performance.
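
To make the one-step approximation concrete, here is a deliberately tiny scalar sketch. Everything in it is an assumption chosen for illustration: the losses Ltrain(w, α) = ½(w − α)² and Lval(w, α) = ½(w − 1)², the function names, and the learning rates. In this toy, Lval depends on α only through w, so the first-order gradient (ξ=0) is exactly zero; real DARTS is not this extreme, because its validation loss also depends on α directly.

```python
def train_grad_w(w, alpha):
    # Gradient of the toy training loss 0.5 * (w - alpha)**2 w.r.t. w.
    return w - alpha

def arch_grad(w, alpha, xi):
    """One-step approximation of the architecture gradient for the toy
    validation loss 0.5 * (w - 1)**2."""
    w_step = w - xi * train_grad_w(w, alpha)  # single inner training step
    dval_dw = w_step - 1.0                    # dLval/dw evaluated at w_step
    dwstep_dalpha = xi                        # d(w_step)/d(alpha)
    # Chain rule; this toy Lval has no direct dependence on alpha.
    return dval_dw * dwstep_dalpha

def search(xi, steps=2000, lr_w=0.1, lr_alpha=0.1):
    """Alternate one architecture step and one weight step."""
    w, alpha = 0.0, 0.0
    for _ in range(steps):
        alpha -= lr_alpha * arch_grad(w, alpha, xi)
        w -= lr_w * train_grad_w(w, alpha)
    return w, alpha

w_second, alpha_second = search(xi=0.1)  # one-step (second-order) variant
w_first, alpha_first = search(xi=0.0)    # first-order variant
print(alpha_second, alpha_first)
```

In this toy the optimum is α = 1 (since w*(α) = α). The one-step variant moves α toward it, while the first-order variant never updates α at all, which is an exaggerated picture of why dropping the inner step can hurt.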

3.2. Deriving Architecture

  • To form each node in the discrete architecture, the top-k strongest operations (from distinct nodes) are retained among all non-zero candidate operations collected from all the previous nodes.
  • k=2 for convolutional cells and k=1 for recurrent cells. (The zero operations are excluded.)
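
The selection rule above could be sketched like this. It is an illustrative reading, not the official code: it assumes the strength of an edge is the largest softmax weight among its non-zero operations, and the operation names and the function derive_node are hypothetical.

```python
import numpy as np

def softmax_rows(alpha):
    # Row-wise softmax with max-subtraction for numerical stability.
    e = np.exp(alpha - alpha.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def derive_node(alpha, op_names, k=2, zero_name="zero"):
    """alpha has shape (num_previous_nodes, num_ops): one row per incoming
    edge. Returns k (previous_node, operation) pairs from distinct nodes."""
    probs = softmax_rows(alpha)
    probs[:, op_names.index(zero_name)] = -np.inf  # exclude the zero op
    best_op = probs.argmax(axis=1)      # strongest non-zero op per edge
    strength = probs.max(axis=1)        # strength of that op
    top_edges = np.argsort(-strength)[:k]
    return sorted((int(i), op_names[best_op[i]]) for i in top_edges)

op_names = ["zero", "sep_conv_3x3", "max_pool_3x3", "skip_connect"]
alpha = np.array([
    [3.0, 0.1, 0.2, 0.0],   # edge from node 0: dominated by zero
    [0.0, 2.0, 0.1, 0.0],   # edge from node 1: sep_conv_3x3 is strongest
    [0.0, 0.0, 0.0, 1.5],   # edge from node 2: skip_connect is strongest
])
selected = derive_node(alpha, op_names, k=2)
print(selected)
```

Because each row contributes at most one operation, the retained operations automatically come from distinct previous nodes. The edge from node 0 is dropped here because its best non-zero operation is weak once the softmax mass taken by zero is accounted for.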

4. Experimental Results

4.1. CIFAR-10

The reduction cell learned on CIFAR-10 over the search epochs. (The normal cell is shown at the top of the story.)
The Final Normal cell learned on CIFAR-10.
The Final Reduction cell learned on CIFAR-10.
  • O: 3×3 and 5×5 separable convolutions, 3×3 and 5×5 dilated separable convolutions, 3×3 max pooling, 3×3 average pooling, identity, and zero.
  • ReLU-Conv-BN order is used.
  • Each separable convolution is always applied twice.
  • N=7 nodes.
  • The output node is defined as the depthwise concatenation of all the intermediate nodes.
  • The architecture encoding therefore is (αnormal; αreduce), where αnormal is shared by all the normal cells and αreduce is shared by all the reduction cells.
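
As a rough sanity check on the size of this encoding, the snippet below counts the architecture parameters for the convolutional search space. It assumes that the 7 nodes split into 2 input nodes, 4 intermediate nodes and 1 output node, that every intermediate node receives an edge from each earlier node, and that the output concatenation has nothing to learn. The operation labels and the function name are hypothetical.

```python
# The eight candidate operations listed above (labels are illustrative).
CANDIDATE_OPS = [
    "sep_conv_3x3", "sep_conv_5x5", "dil_conv_3x3", "dil_conv_5x5",
    "max_pool_3x3", "avg_pool_3x3", "identity", "zero",
]

def count_searchable_edges(num_nodes, num_inputs=2, num_outputs=1):
    """Number of edges entering intermediate nodes, assuming each one is
    connected to all input nodes and all earlier intermediate nodes."""
    num_intermediate = num_nodes - num_inputs - num_outputs
    return sum(num_inputs + k for k in range(num_intermediate))

edges = count_searchable_edges(7)            # 2 + 3 + 4 + 5
alpha_shape = (edges, len(CANDIDATE_OPS))    # shape of one of the two alphas
total_params = 2 * edges * len(CANDIDATE_OPS)  # alpha_normal + alpha_reduce
print(edges, alpha_shape, total_params)
```

Under these assumptions each of αnormal and αreduce is a 14×8 matrix, so the whole architecture is encoded by only 224 continuous values.
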
Comparison with state-of-the-art image classifiers on CIFAR-10
  • DARTS achieved results comparable with the SOTA approaches while using three orders of magnitude fewer computational resources (i.e. 1.5 GPU days for first-order or 4 GPU days for second-order DARTS, vs 2000 GPU days for NASNet and 3150 GPU days for AmoebaNet).
  • Moreover, with a slightly longer search time, DARTS outperformed ENAS by discovering cells with comparable error rates but fewer parameters.
  • The longer search time is due to the fact that the search process was repeated four times for cell selection.

4.2. Penn Treebank

The recurrent cell learned on PTB over the search epochs.
The Final Recurrent cell learned on PTB.
  • N=12 nodes.
  • The very first intermediate node is obtained by linearly transforming the two input nodes, adding up the results and then passing through a tanh. The rest of the cell is learned.
  • The recurrent network consists of only a single cell; no repetitive patterns are assumed within the recurrent architecture.
Comparison with state-of-the-art language models on PTB (lower perplexity is better).
  • A cell discovered by DARTS achieved a test perplexity of 55.7, on par with the state-of-the-art model enhanced by a mixture of softmaxes, and better than all the other architectures, whether manually or automatically discovered.
  • In terms of efficiency, the overall cost (4 runs in total) is within 1 GPU day, which is comparable to ENAS and significantly faster than NAS.
  • With comparable or lower search cost, DARTS also significantly improves upon random search on both tasks (test error 2.76±0.09% vs 3.29±0.15% on CIFAR-10; test perplexity 55.7 vs 59.4 on PTB).

4.3. ImageNet

Comparison with state-of-the-art image classifiers on ImageNet in the mobile setting
  • The cell learned on CIFAR-10 is indeed transferable to ImageNet.
  • It is worth noting that DARTS achieves competitive performance with the state-of-the-art RL method while using three orders of magnitude fewer computational resources.

4.4. WT2

Comparison with state-of-the-art language models on WT2.
  • The cell identified by DARTS transfers to WikiText-2 (WT2) better than the cell found by ENAS.
  • The weaker transferability between PTB and WT2 (compared with that between CIFAR-10 and ImageNet) could be explained by the relatively small size of the source dataset (PTB) used for architecture search.
