Review — DeepNet: Scaling Transformers to 1,000 Layers
Using DEEPNORM for Normalization, Up to 1000 Transformer Layers
DeepNet: Scaling Transformers to 1,000 Layers,
DeepNet, by Microsoft Research,
2022 arXiv v1, Over 30 Citations (Sik-Ho Tsang @ Medium)
NLP, NMT, Neural Machine Translation, Transformer
- A new normalization function, DEEPNORM, is proposed to modify the residual connection in Transformer, such that model updates can be bounded in a stable way.
- Finally, DeepNet is formed, with the Transformer successfully scaled up to 1,000 layers (i.e., 2,500 attention and feed-forward network sub-layers) without difficulty, which is one order of magnitude deeper than previous deep Transformers.
Outline
- Initialization Analysis
- DeepNet
- Results
1. Initialization Analysis
1.1. Post-LN-init
- It is hypothesized that better initialization methods stabilize the training of Transformer.
- Post-LN: the baseline Transformer using post-layer normalization, without any weight scaling.
- To obtain a better initialization, the weights of the l-th layer are downscaled by kl = N − l + 1, where l ∈ [1, N], after performing Xavier initialization.
- This simple scaling approach is named Post-LN-init. For example, the output projection Wlo of the FFN in the l-th layer is initialized as Wlo ~ N(0, 1/(d′·kl²)),
- where d′ is the average of the input and output dimensions.
Thus, the scale of the lower layers is narrowed down the most, since kl is largest for small l.
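A minimal PyTorch sketch of this downscaled initialization (the function and argument names are mine, for illustration only, assuming the scaling is applied to an nn.Linear output projection):

```python
import torch
import torch.nn as nn

def post_ln_init_(proj: nn.Linear, layer_idx: int, num_layers: int) -> None:
    """Post-LN-init sketch: Xavier-initialize a projection, then downscale its
    weights by k_l = N - l + 1, so lower layers (small l) are shrunk the most."""
    k_l = num_layers - layer_idx + 1          # layer_idx is 1-based, l in [1, N]
    nn.init.xavier_normal_(proj.weight)
    with torch.no_grad():
        proj.weight.mul_(1.0 / k_l)           # variance shrinks by a factor of k_l^2
    if proj.bias is not None:
        nn.init.zeros_(proj.bias)

# e.g. downscale the FFN output projection of the 1st of 18 layers by k_1 = 18
post_ln_init_(nn.Linear(2048, 512), layer_idx=1, num_layers=18)
```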
1.2. Analysis of Gradient Norm
- An 18L-18L Post-LN model (18-layer encoder, 18-layer decoder) and an 18L-18L Post-LN-init model are trained on the IWSLT-14 De-En machine translation dataset.
- (c) Validation Loss Curve: Post-LN-init converged while Post-LN did not.
- (a) & (b) Gradient Norm: The gradient norm of Post-LN-init in the last layers is still much larger than that of Post-LN, regardless of model depth.
This indicates that exploding gradients in the deep layers are not the root cause of the instability of Post-LN.
1.3. Instability Causes of Post-LN
- The instability of Post-LN comes from a chain of several issues, including gradient vanishing as well as too large model updates.
- (a): The norm of the model update ||ΔF|| at the early stage of training is measured as ||ΔF|| = ||F(x, θi) − F(x, θ0)||,
- where x denotes the input and θi the model parameters after the i-th update.
- Post-LN has an exploding update at the very beginning of training, and then nearly no update shortly afterwards. This indicates that the model has been stuck in a spurious local optimum.
- (b) & (c): When the update explodes, the inputs to LN become large.
- According to the theoretical analysis of the Pre-LN Transformer, the magnitude of the gradient through LN is inversely proportional to the magnitude of its input: ||∂LN(x)/∂x|| = O(√d/||x||) (a quick numerical check is given at the end of this subsection).
- Without warm-up or proper initialization, ||x|| is significantly larger than √d (d = 512), so the gradient through LN is heavily attenuated.
- (d): This explains the gradient vanishing problem that occurs in the training of Post-LN.
Above all, the instability starts from the large model update at the beginning of training. It traps the model in a bad local optimum, which in turn increases the magnitude of the inputs to each LN.
As training continues, the gradient through LN becomes increasingly small, resulting in severe gradient vanishing. The vanishing gradients make it difficult to escape from the local optimum and further destabilize the optimization.
On the contrary, Post-LN-init has relatively small updates, and the inputs to LN are stable. This alleviates the gradient vanishing problem and makes the optimization more stable.
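The inverse relationship between the LN input magnitude and its gradient is easy to check numerically. A small PyTorch experiment (dimension and scales chosen arbitrarily here, not taken from the paper) shows the gradient through LayerNorm shrinking roughly like √d/||x|| as the input grows:

```python
import torch
import torch.nn.functional as F

d = 512
torch.manual_seed(0)
g = torch.randn(d)                                        # fixed upstream gradient
for scale in (1.0, 10.0, 100.0):
    x = (scale * torch.randn(d)).requires_grad_(True)     # LN input; ||x|| grows with scale
    y = F.layer_norm(x, (d,))
    y.backward(g)
    # the gradient norm behaves roughly like sqrt(d) / ||x||: ~10x smaller per step
    print(f"||x|| = {x.norm().item():9.1f}   ||grad|| ~ {x.grad.norm().item():.4f}")
```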
2. DeepNet
- (For quick read, please read 2.1 then 2.3.)
2.1. DEEPNORM Idea & Formulation
- DeepNet uses DEEPNORM, instead of Post-LN, for each sub-layer: xl+1 = LN(α·xl + Gl(xl, θl)),
- where α is a constant, and Gl(xl, θl) is the function of the l-th Transformer sub-layer (i.e., attention or feed-forward network).
- Besides, DeepNet scales the weights θl inside the residual branches by β at initialization.
- Notably, both α and β are constants that only depend on the architecture.
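A minimal PyTorch-style sketch of this residual rule (the wrapper class and its names are mine, not the authors' released code):

```python
import torch
import torch.nn as nn

class DeepNormResidual(nn.Module):
    """DEEPNORM residual: x_{l+1} = LayerNorm(alpha * x_l + G_l(x_l))."""
    def __init__(self, sublayer: nn.Module, d_model: int, alpha: float):
        super().__init__()
        self.sublayer = sublayer        # G_l: a self-attention or feed-forward block
        self.alpha = alpha              # constant depending only on the architecture
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(self.alpha * x + self.sublayer(x))

# e.g. alpha = (2N)^(1/4) for an N-layer encoder-only model (see Section 2.3)
ffn = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 512))
block = DeepNormResidual(ffn, d_model=512, alpha=(2 * 18) ** 0.25)
out = block(torch.randn(4, 16, 512))
```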
2.2. Expected Magnitude of Model Update
- Considering the 1-head case in the attention module, let WQ, WK, WV be the input projection matrices and WO the output projection matrix. The attention module can then be formulated as: Attn(q, k, v) = softmax(qWQ(kWK)ᵀ/√d)·vWVWO.
Due to the softmax function, WQ and WK do not change the bound of the attention output's magnitude, so WQ and WK are not considered a source of the instability issue in this paper. In terms of magnitude, Attn(q, k, v) Θ= vWVWO,
- where “Θ=” stands for equal bound of magnitude.
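A quick single-head PyTorch check of this claim (random toy tensors, purely illustrative): rescaling WQ only sharpens the softmax, whose rows still sum to one, so the output magnitude stays in the same ballpark, whereas rescaling WV rescales the output directly.

```python
import torch

torch.manual_seed(0)
d = 64
q, k, v = (torch.randn(8, d) for _ in range(3))
WQ, WK, WV, WO = (torch.randn(d, d) / d ** 0.5 for _ in range(4))

def attn(WQ, WK, WV, WO):
    scores = (q @ WQ) @ (k @ WK).T / d ** 0.5
    return torch.softmax(scores, dim=-1) @ (v @ WV) @ WO   # softmax rows sum to 1

print(attn(WQ, WK, WV, WO).norm())        # baseline magnitude
print(attn(10 * WQ, WK, WV, WO).norm())   # same order of magnitude
print(attn(WQ, WK, 10 * WV, WO).norm())   # ~10x larger
```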
2.2.1. Encoder DeepNet
Given an N-layer DeepNet F(x, θ) with θ = {θ1, θ2, …, θ2N}, where θ2l−1 and θ2l denote the parameters of the self-attention and FFN in the l-th layer, and each sub-layer is normalized with DEEPNORM as in 2.1 above, ||ΔF|| satisfies a bound that accumulates, from each of the 2N sub-layers, a term proportional to ||θ*i − θi|| and growing with the initialization scales vi and wi.
- It is noted that Vanilla Post-LN can be regarded as a special case of DeepNet, where α=1 and vl=wl=1 at Xavier initialization.
For vanilla Post-LN, this bound shows that the model tends to accumulate the update of each sub-layer, which leads to an exploding magnitude of the model update and destabilizes the optimization at the early stage.
- It also explains why warm-up and smaller initialization can stabilize the training of Post-LN.
Warm-up reduces the magnitude of the model update by decreasing ||θ*i − θi||, while smaller initialization lowers √(vi² + wi²).
- The magnitude of the model update for DeepNet with an N-layer encoder and an M-layer decoder is also studied.
2.2.2. Encoder-Decoder DeepNet
- θd = {θd,1, θd,2, …, θd,3M} stands for the parameters of the decoder's self-attentions, cross-attentions, and FFNs. {αe, Gel} and {αd, Gdl} are used to distinguish the notation α and G between the encoder and the decoder.
Given an encoder-decoder DeepNet with N encoder layers and M decoder layers, where each encoder sub-layer is normalized with {αe, Gel} and each decoder sub-layer with {αd, Gdl}, ||ΔFed|| satisfies an analogous bound that accumulates terms from both the encoder and the decoder sub-layers.
For the vanilla encoder-decoder model, all of {αe, αd, vei, wei, vdi, wdi} equal 1, which indicates a similar accumulative effect and a fast growth of the update magnitude with model depth. Since the cross-attention propagates the encoder's contribution into the decoder, the decoder is even more unstable than the encoder.
2.2.3. α & β Initializations
- With the use of SGD, and using the result from the Pre-LN Transformer analysis that Post-LN decreases the magnitude of the back-propagating error signal, the second term of the above bound can be bounded accordingly; α and β are then chosen so that the expected model update stays constant, independent of depth.
- For the decoder, also balancing the residual connections and the initialization, this gives: αd = (3M)^(1/4), βd = (12M)^(−1/4).
- For the encoder: αe = 0.81·(N⁴M)^(1/16), βe = 0.87·(N⁴M)^(−1/16).
2.3. Encoder & Decoder Initializations
- In summary, the DEEPNORM constants for each architecture are:
- Encoder-only (N layers): α = (2N)^(1/4), β = (8N)^(−1/4).
- Decoder-only (M layers): α = (2M)^(1/4), β = (8M)^(−1/4).
- Encoder-decoder (N-layer encoder, M-layer decoder): αe = 0.81·(N⁴M)^(1/16), βe = 0.87·(N⁴M)^(−1/16) for the encoder, and αd = (3M)^(1/4), βd = (12M)^(−1/4) for the decoder.
- The paper also gives pseudocode for DEEPNORM; a sketch of it is shown below.
With this initialization, DeepNet can be formed. Its model updates are much smaller and more stable than those of Post-LN.
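The closed-form constants above are easy to wrap in a small helper. Below is a sketch that follows the values listed in this section (function names and the `scaled` flag are mine, not the paper's code); the init applies Xavier initialization with gain β to the FFN, value, and output projections, and gain 1 to the query and key projections.

```python
import torch.nn as nn

def deepnet_constants(arch: str, N: int = 0, M: int = 0):
    """Return the DEEPNORM (alpha, beta) constants for a given architecture.
    N = number of encoder layers, M = number of decoder layers."""
    if arch == "encoder-only":
        return (2 * N) ** 0.25, (8 * N) ** -0.25
    if arch == "decoder-only":
        return (2 * M) ** 0.25, (8 * M) ** -0.25
    if arch == "encoder-decoder":
        enc = (0.81 * (N ** 4 * M) ** (1 / 16), 0.87 * (N ** 4 * M) ** (-1 / 16))
        dec = ((3 * M) ** 0.25, (12 * M) ** -0.25)
        return enc, dec                      # ((alpha_e, beta_e), (alpha_d, beta_d))
    raise ValueError(f"unknown architecture: {arch}")

def deepnorm_init_(linear: nn.Linear, beta: float, scaled: bool = True) -> None:
    """Xavier init: gain=beta for weights inside the residual branch
    (FFN, W^V, W^O); gain=1 for W^Q and W^K (pass scaled=False)."""
    nn.init.xavier_normal_(linear.weight, gain=beta if scaled else 1.0)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

# e.g. a 100L-100L encoder-decoder DeepNet
(alpha_e, beta_e), (alpha_d, beta_d) = deepnet_constants("encoder-decoder", N=100, M=100)
```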
3. Results
3.1. Bilingual NMT
Compared with the models using Post-LN, DeepNet is more stable and can be successfully scaled to 100L-100L, reaching 28.9 BLEU on the test set.
- In contrast, the baselines with Post-LN lead to unstable optimization when the depth goes to 50L-50L.
- Besides, DeepNet achieves comparable performance with these baselines when the models are shallow.
3.2. Further Studies
Overall, DeepNet is stable from shallow to deep. It converges fast, achieving over 30 BLEU in only 8,000 steps.
- DeepNet is further scaled to a larger learning rate, batch size, and hidden dimension, respectively.
DeepNet can be trained without difficulty in all the largest settings.
- The validation loss of DeepNet with a hidden size of 1,024 increases after 10K steps because of overfitting. Apart from that, DeepNet benefits from the larger settings, resulting in faster convergence and lower validation loss.
3.3. Multilingual NMT
- DeepNet models with {12, 20, 100, 200, 1000} layers are trained on the OPUS-100 dataset.
Compared with bilingual NMT, multilingual NMT benefits more from scaling the model depth because of its hunger for model capacity.
Increasing the depth can significantly improve the translation quality of NMT.
- M2M-100 has a 24-layer encoder, a 24-layer decoder, and a hidden size of 4,096, resulting in up to 12B parameters.
Compared with M2M-100, DeepNet is deep and narrow, with only 3.2B parameters.