[Paper] DARTS: Differentiable Architecture Search (Image Classification)

In this story, DARTS: Differentiable Architecture Search (DARTS), by CMU and DeepMind, is presented. In this paper:

  • The architecture search method is based on a continuous relaxation of the architecture representation, allowing efficient search over architectures using gradient descent.
  • This is unlike conventional approaches that apply evolution or reinforcement learning, and it is orders of magnitude faster than state-of-the-art non-differentiable techniques.

This is a paper published in ICLR 2019 with over 900 citations. (Sik-Ho Tsang @ Medium)

Outline

  1. Search Space
  2. Continuous Relaxation and Optimization
  3. Approximate Architecture Gradient
  4. Experimental Results

1. Search Space

1.1. Fixed Outer Structure

  • Similar to NASNet, PNASNet and AmoebaNet, fixed outer structures are used for different datasets.
  • The learned cell could either be stacked to form a convolutional network or recursively connected to form a recurrent network.

1.2. Directed Acyclic Graph

  • A cell is a directed acyclic graph consisting of an ordered sequence of N nodes.
  • Each node x(i) is a latent representation (e.g. a feature map in convolutional networks).
  • Each directed edge (i, j) is associated with some operation o(i, j) that transforms x(i); each intermediate node is computed from all of its predecessors (see the formula after this list).
  • The cell is assumed to have two input nodes and a single output node.
  • For convolutional cells, the input nodes are defined as the cell outputs in the previous two layers.
  • For recurrent cells, these are defined as the input at the current step and the state carried from the previous step.
  • The output of the cell is obtained by applying a reduction operation (e.g. concatenation) to all the intermediate nodes.
  • A special zero operation is also included to indicate a lack of connection between two nodes.
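
Concretely, each intermediate node is computed by applying the chosen edge operations to all of its predecessors and summing the results; in the paper's notation:

```latex
x^{(j)} = \sum_{i < j} o^{(i,j)}\left(x^{(i)}\right)
```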

2. Continuous Relaxation and Optimization

2.1. Continuous Relaxation

  • Let O be a set of candidate operations (e.g., convolution, max pooling, zero), where each operation represents some function o(·) to be applied to x(i).
  • (b) & (c): The categorical choice of a particular operation is relaxed to a softmax over all possible operations, where the operation mixing weights for a pair of nodes (i, j) are parameterized by a vector α(i, j) of dimension |O| (see the formulas after this list).
  • (d): At the end of search, a discrete architecture can be obtained by replacing each mixed operation ō(i, j) with the most likely operation, i.e. the one with the largest mixing weight (also shown below).
  • Thus, α is treated as the encoding of the architecture.
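
Written out, the relaxed (mixed) operation on edge (i, j) and the final discretization step are:

```latex
\bar{o}^{(i,j)}(x) = \sum_{o \in \mathcal{O}} \frac{\exp\big(\alpha_o^{(i,j)}\big)}{\sum_{o' \in \mathcal{O}} \exp\big(\alpha_{o'}^{(i,j)}\big)} \, o(x),
\qquad
o^{(i,j)} = \operatorname*{argmax}_{o \in \mathcal{O}} \; \alpha_o^{(i,j)}
```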

After relaxation, the goal is to jointly learn the architecture α and the weights w within all the mixed operations. DARTS optimizes the validation loss using gradient descent, rather than evolution or reinforcement learning.
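
As a rough illustration of the continuous relaxation, below is a minimal PyTorch-style sketch of a mixed operation. The class name, the reduced candidate set, and the tensor shapes are illustrative assumptions, not the paper's exact implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixedOp(nn.Module):
    """Weighted sum of all candidate operations on one edge (i, j)."""
    def __init__(self, channels):
        super().__init__()
        # Simplified candidate set; the real search space uses separable and
        # dilated convolutions, pooling, identity, and the zero operation.
        self.ops = nn.ModuleList([
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.MaxPool2d(3, stride=1, padding=1),
            nn.Identity(),
        ])

    def forward(self, x, alpha):
        # alpha: vector of length |O| holding the architecture parameters
        # for this edge; a softmax turns it into mixing weights.
        weights = F.softmax(alpha, dim=-1)
        return sum(w * op(x) for w, op in zip(weights, self.ops))

# Usage: alpha is learned jointly with the operation weights w.
alpha = nn.Parameter(torch.zeros(3))  # one entry per candidate operation
edge = MixedOp(channels=16)
out = edge(torch.randn(2, 16, 8, 8), alpha)
```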

2.2. Optimization

  • Ltrain and Lval denote the training loss and the validation loss, respectively.
  • The goal of architecture search is to find α* that minimizes the validation loss Lval(w*, α*), where the weights w* associated with the architecture are obtained by minimizing the training loss: w* = argmin_w Ltrain(w, α*).
  • This implies a bilevel optimization problem with α as the upper-level variable and w as the lower-level variable:
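
Written out, the bilevel problem is:

```latex
\min_{\alpha} \; \mathcal{L}_{val}\big(w^{*}(\alpha), \alpha\big)
\quad \text{s.t.} \quad
w^{*}(\alpha) = \operatorname*{argmin}_{w} \; \mathcal{L}_{train}(w, \alpha)
```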

3. Approximate Architecture Gradient

3.1. Approximation

  • Evaluating the architecture gradient exactly can be prohibitive due to the expensive inner optimization.
  • A simple approximation scheme is used instead (written out after this list): w*(α) is approximated by adapting w using only a single training step, without solving the inner optimization to convergence.
  • Here w denotes the current weights maintained by the algorithm, and ξ is the learning rate for one step of inner optimization.
  • This approximation reduces to ∇αLval(w, α) if w is already a local optimum of the inner optimization, since then ∇wLtrain(w, α) = 0.
  • Setting ξ = 0 yields the first-order approximation, which is faster but leads to worse performance than the second-order version (ξ > 0).
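
The one-step approximation referred to in this list can be written as:

```latex
\nabla_{\alpha} \mathcal{L}_{val}\big(w^{*}(\alpha), \alpha\big)
\;\approx\;
\nabla_{\alpha} \mathcal{L}_{val}\big(w - \xi \, \nabla_{w} \mathcal{L}_{train}(w, \alpha), \; \alpha\big)
```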

3.2. Deriving Architecture

  • To form each node in the discrete architecture, the top-k strongest operations (from distinct nodes) are retained among all non-zero candidate operations collected from all the previous nodes.
  • k = 2 for convolutional cells and k = 1 for recurrent cells; zero operations are excluded from the ranking (see the sketch after this list).
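
A minimal sketch of this selection step is shown below. The function name and the assumed layout of α (one array of shape (num_predecessors, |O|) per intermediate node, with the zero operation at a known index) are hypothetical conveniences, not the paper's code.

```python
import numpy as np

def derive_cell(alpha, k=2, zero_index=0):
    """For each intermediate node, keep the k strongest incoming edges
    and the best non-zero operation on each of them."""
    cell = []
    for a in alpha:                             # a: (num_predecessors, |O|) for one node
        w = np.exp(a) / np.exp(a).sum(axis=-1, keepdims=True)  # softmax per edge
        w[:, zero_index] = -np.inf              # exclude the zero operation from ranking
        best_op = w.argmax(axis=-1)             # strongest operation on each edge
        strength = w.max(axis=-1)               # and its mixing weight
        top_edges = np.argsort(-strength)[:k]   # k strongest predecessor edges
        cell.append([(int(i), int(best_op[i])) for i in top_edges])
    return cell  # per node: list of (predecessor index, operation index) pairs
```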

4. Experimental Results

4.1. CIFAR-10

  • O: 3×3 and 5×5 separable convolutions, 3×3 and 5×5 dilated separable convolutions, 3×3 max pooling, 3×3 average pooling, identity, and zero.
  • The ReLU-Conv-BN order is used for the convolutional operations.
  • Each separable convolution is always applied twice (a sketch is given after this list).
  • N=7 nodes.
  • The output node is defined as the depthwise concatenation of all the intermediate nodes.
  • The architecture encoding therefore is (αnormal; αreduce), where αnormal is shared by all the normal cells and αreduce is shared by all the reduction cells.
  • DARTS achieved results comparable with the SOTA approaches while using three orders of magnitude fewer computational resources (i.e. 1.5 GPU days for first-order DARTS or 4 GPU days for second-order DARTS, vs. 2000 GPU days for NASNet and 3150 GPU days for AmoebaNet).
  • Moreover, with slightly longer search time, DARTS outperformed ENAS by discovering cells with comparable error rates but fewer parameters.
  • The slightly longer search time is due to the fact that the search process is repeated four times for cell selection.
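
For the convolutional operations, a minimal sketch of the ReLU-Conv-BN ordering with a separable convolution applied twice might look as follows; the helper name, channel handling, and BatchNorm settings are simplifying assumptions rather than the released DARTS code.

```python
import torch.nn as nn

def sep_conv_3x3(channels, stride=1):
    """3x3 separable convolution in ReLU-Conv-BN order, applied twice."""
    def block(s):
        return nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(channels, channels, 3, stride=s, padding=1,
                      groups=channels, bias=False),       # depthwise 3x3
            nn.Conv2d(channels, channels, 1, bias=False),  # pointwise 1x1
            nn.BatchNorm2d(channels),
        )
    # "Each separable convolution is always applied twice": stack two blocks,
    # applying any stride only in the first one.
    return nn.Sequential(block(stride), block(1))
```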

4.2. Penn Treebank

  • N=12 nodes.
  • The very first intermediate node is obtained by linearly transforming the two input nodes, adding up the results and then passing them through a tanh activation; the rest of the cell is learned (a sketch of this computation follows this list).
  • The recurrent network consists of only a single cell; no repetitive patterns are assumed within the recurrent architecture.
  • A cell discovered by DARTS achieved a test perplexity of 55.7, on par with the state-of-the-art model enhanced by a mixture of softmaxes, and better than all other architectures, whether manually designed or automatically discovered.
  • In terms of efficiency, the overall cost (4 runs in total) is within 1 GPU day, which is comparable to ENAS and significantly faster than NAS.
  • Nevertheless, with comparable or lower search cost, DARTS significantly improves upon random search in both cases (2.76±0.09% vs 3.29±0.15% test error on CIFAR-10, and 55.7 vs 59.4 perplexity on PTB).
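
For the first intermediate node of the recurrent cell described above, one way to write the computation (with hypothetical weight names) is:

```latex
h^{(1)} = \tanh\big(W^{(x)} x_t + W^{(h)} h_{t-1}\big)
```

where x_t is the current input and h_{t-1} is the state carried from the previous step.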

4.3. ImageNet

  • The cell learned on CIFAR-10 is indeed transferable to ImageNet.
  • It is worth noting that DARTS achieves performance competitive with the state-of-the-art RL-based method while using three orders of magnitude fewer computational resources.

4.4. WT2

  • The cell identified by DARTS transfers to WT2 better than ENAS.
  • The weaker transferability between PTB and WT2 (as compared to that between CIFAR-10 and ImageNet) could be explained by the relatively small size of the source dataset (PTB) used for architecture search.
