Review — The Reversible Residual Network: Backpropagation Without Storing Activations
RevNet, Skip Storing Activation Maps During Backprop, Less Memory Consumption
The Reversible Residual Network: Backpropagation Without Storing Activations, RevNet, by the University of Toronto and Uber Advanced Technologies Group, 2017 NeurIPS, Over 300 Citations (Sik-Ho Tsang @ Medium)
Image Classification, Convolutional Neural Network, CNN, Residual Network, ResNet
- Reversible Residual Network (RevNet), a variant of ResNets, is proposed, where each layer’s activations can be reconstructed exactly from the next layer’s.
- Therefore, the activations for most layers need not be stored in memory during backpropagation.
Outline
- ResNet & Backprop Brief Review
- Reversible Residual Network (RevNet)
- Experimental Results
1. ResNet & Backprop Brief Review
1.1. ResNet
- ResNet introduces the skip connection (Figure 1 left), which adds the block input to the block output, y = x + F(x), to mitigate the vanishing gradient problem.
- ResNet proposes two residual block architectures: the basic residual function (Figure 1 right-top) and the bottleneck residual function (Figure 1 right-bottom).
- (Please feel free to read ResNet if interested.)
1.2. Backprop
- Let v1, …, vK denote a topological ordering of the nodes in the network’s computation graph G, where vK denotes the cost function C. Each node is defined as a function fi of its parents in G.
- Backprop computes the total derivative dC/dvi for each node in the computation graph. The total derivative of node vi is denoted bar(vi).
- Backprop iterates over the nodes in the computation graph in reverse topological order. For each node vi, it computes the total derivative bar(vi) using the rule written out below, where Child(i) denotes the children of node vi in G and ∂fj/∂vi denotes the Jacobian matrix (the matrix of first-order partial derivatives).
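In symbols, the definition of the total derivative and the update rule applied at each node are:

```latex
\bar{v}_i \triangleq \frac{dC}{dv_i}, \qquad
\bar{v}_i = \sum_{j \in \mathrm{Child}(i)} \left( \frac{\partial f_j}{\partial v_i} \right)^{\!\top} \bar{v}_j ,
\qquad \text{with } \bar{v}_K = \frac{dC}{dC} = 1 .
```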
2. Reversible Residual Network (RevNet)
2.1. Forward & Reverse
- RevNets are composed of a series of reversible blocks. The channels in each layer are partitioned into two groups, denoted x1 and x2.
- Forward (Figure 2(a)): Each reversible block takes inputs (x1, x2) and produces outputs (y1, y2) using residual functions F and G analogous to those in standard ResNets: y1 = x1 + F(x2), y2 = x2 + G(y1).
- Reverse (Figure 2(b)): Each layer’s activations can be reconstructed exactly from the next layer’s activations: x2 = y2 − G(y1), x1 = y1 − F(x2).
- Unlike residual blocks, reversible blocks must have a stride of 1; otherwise the layer discards information and cannot be reversed. This restricts RevNet’s reversible blocks to volume-preserving mappings.
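A minimal PyTorch sketch of the forward and reverse computations (illustrative only: F and G here are single stride-1 convolutions rather than the paper's residual stacks):

```python
import torch
import torch.nn as nn

class ReversibleBlock(nn.Module):
    """One reversible block over a channel-split input (x1, x2)."""
    def __init__(self, f: nn.Module, g: nn.Module):
        super().__init__()
        self.f = f  # residual function F
        self.g = g  # residual function G

    def forward(self, x1, x2):
        # Forward: y1 = x1 + F(x2), y2 = x2 + G(y1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def inverse(self, y1, y2):
        # Reverse: x2 = y2 - G(y1), x1 = y1 - F(x2)
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

# Usage: split the channels into two groups, run the block, then reconstruct.
c = 32
block = ReversibleBlock(nn.Conv2d(c, c, 3, padding=1), nn.Conv2d(c, c, 3, padding=1))
x1, x2 = torch.randn(2, 2 * c, 8, 8).chunk(2, dim=1)
y1, y2 = block(x1, x2)
x1_rec, x2_rec = block.inverse(y1, y2)
print(torch.allclose(x1, x1_rec, atol=1e-5), torch.allclose(x2, x2_rec, atol=1e-5))
```

Because the inverse only re-evaluates F and G, nothing inside the block needs to be stored to recover (x1, x2) from (y1, y2).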
2.2. Backprop
- To derive the backprop procedure, it is helpful to rewrite the forward (left) and reverse (right) computations in the following way: z1 = x1 + F(x2), y2 = x2 + G(z1), y1 = z1 (forward); z1 = y1, x2 = y2 − G(z1), x1 = z1 − F(x2) (reverse).
- Even though z1 = y1, the two variables represent distinct nodes of the computation graph, so the total derivatives bar(z1) and bar(y1) are different. In particular, bar(z1) includes the indirect effect of z1 on the cost through y2 (via G), while bar(y1) does not.
- This splitting lets us implement the forward and backward passes for reversible blocks in a modular fashion.
- In the backward pass, we are given the activations (y1, y2) and their total derivatives (bar(y1), bar(y2)) and wish to compute the inputs (x1, x2), their total derivatives (bar(x1), bar(x2)), and the total derivatives of any parameters associated with F and G.
- Combining the reconstruction formulas above with the backprop rule (Section 1.2) gives the paper’s Algorithm 1; a sketch follows this list.
- By applying Algorithm 1 repeatedly, one can perform backprop on a sequence of reversible blocks given only the activations and their total derivatives for the top layer in the sequence.
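A PyTorch-style sketch of one such backprop step, using torch.autograd.grad for the Jacobian-transpose products (the function name and structure are illustrative, not the authors' implementation):

```python
import torch

def reversible_backward(f, g, y1, y2, y1_bar, y2_bar):
    """Given a block's outputs and their total derivatives, reconstruct its
    inputs and compute their total derivatives plus parameter gradients."""
    with torch.enable_grad():  # build a local graph even if called under no_grad()
        # Reconstruct the inputs: z1 = y1, x2 = y2 - G(z1), x1 = z1 - F(x2)
        z1 = y1.detach().requires_grad_(True)
        g_z1 = g(z1)
        x2 = (y2 - g_z1).detach().requires_grad_(True)
        f_x2 = f(x2)
        x1 = (z1 - f_x2).detach()

        # z1_bar = y1_bar + (dG/dz1)^T y2_bar; x2_bar = y2_bar + (dF/dx2)^T z1_bar; x1_bar = z1_bar
        z1_bar = y1_bar + torch.autograd.grad(g_z1, z1, grad_outputs=y2_bar, retain_graph=True)[0]
        x2_bar = y2_bar + torch.autograd.grad(f_x2, x2, grad_outputs=z1_bar, retain_graph=True)[0]
        x1_bar = z1_bar

        # Parameter gradients of F and G
        wF_bar = torch.autograd.grad(f_x2, tuple(f.parameters()), grad_outputs=z1_bar)
        wG_bar = torch.autograd.grad(g_z1, tuple(g.parameters()), grad_outputs=y2_bar)
    return (x1, x2), (x1_bar, x2_bar), (wF_bar, wG_bar)
```

A memory-efficient implementation would wrap this step in a custom autograd function and discard (y1, y2) as soon as (x1, x2) have been reconstructed for the block below.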
Only the activations of non-reversible layers, such as subsampling layers, need to be stored explicitly during backprop, which is handy.
In this case, the storage cost of the activations is small and independent of the depth of the network.
2.3. Computational Overhead
- In general, for a network with N connections, the forward and backward passes of backprop require approximately N and 2N add-multiply operations, respectively.
- For a RevNet, the residual functions must additionally be recomputed during the backward pass. The number of operations required for reversible backprop is therefore approximately 4N, or roughly 33% more than ordinary backprop.
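Spelling out the accounting behind these figures:

```latex
\underbrace{N}_{\text{forward}}
+ \underbrace{N}_{\text{recompute } F, G}
+ \underbrace{2N}_{\text{backward}}
= 4N,
\qquad
\frac{4N - 3N}{3N} \approx 33\%.
```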
2.4. Architecture Details
- Each reversible block has the computational depth of two original residual blocks. The total number of blocks is therefore reduced by roughly half, while the number of channels per block is roughly doubled, since the channels are partitioned into two groups. With these adjustments, RevNet architectures (which are far more memory efficient) were able to match the classification accuracy of ResNets of the same size.
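A purely illustrative sketch of this halve-the-blocks / double-the-channels rule (the numbers below are hypothetical, not the paper's published configurations):

```python
def revnet_stage(resnet_blocks: int, resnet_channels: int) -> dict:
    """Translate a ResNet stage into a RevNet stage of comparable computation."""
    return {
        # each reversible block packs the computation of ~2 residual blocks
        "reversible_blocks": max(1, resnet_blocks // 2),
        # channels are partitioned into two groups (x1, x2), so the total
        # width is roughly doubled to keep each group's width comparable
        "total_channels": 2 * resnet_channels,
    }

print(revnet_stage(resnet_blocks=6, resnet_channels=16))
# {'reversible_blocks': 3, 'total_channels': 32}
```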
3. Experimental Results
RevNets roughly matched the error rates of traditional ResNets (of roughly equal computational depth and number of parameters) on CIFAR-10 & 100 as well as ImageNet.
Reversibility did not lead to any noticeable per-iteration slowdown in training.
Reconstructing the activations over many layers causes numerical errors to accumulate.
- Left: Although the angle grows during training, no instability was observed.
- Middle & Right: Despite the numerical error from reconstructing activations, both methods performed almost indistinguishably in terms of the training efficiency and the final performance.
(Normally, the backward pass of a standard CNN is already handled automatically by frameworks such as TensorFlow or PyTorch. In RevNet, both the forward and backward passes are customized.)
Reference
[2017 NeurIPS] [RevNet]
The Reversible Residual Network: Backpropagation Without Storing Activations