Brief Review — GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding

GShard to Scale MoE Model to 600B Parameters Across TPUs

Sik-Ho Tsang
5 min read · Aug 13, 2024

GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
GShard, by Google Inc.
2021 ICLR, Over 830 Citations (Sik-Ho Tsang @ Medium)

Large Language Model (LLM)
2020 … 2023
[GPT-4] [LLaMA] [Koala] [BloombergGPT] [GLM-130B] [UL2] [PaLM 2] [Llama 2] [MultiMedQA, HealthSearchQA, Med-PaLM] [Med-PaLM 2] [Flan 2022, Flan-T5] [AlphaCode 2] [Mistral 7B]
==== My Other Paper Readings Are Also Over Here ====

  • When I read NLP papers, GShard often comes up as a method used for scaling language models.
  • GShard, a module composed of a set of lightweight annotation APIs and an extension to the XLA compiler, is proposed to enable large-scale models with up to trillions of parameters.
  • GShard and conditional computation make it possible to scale up a multilingual neural machine translation (MNMT) Transformer model with Sparsely-Gated Mixture-of-Experts (MoE).
  • A giant MNMT MoE model with 600 billion parameters can efficiently be trained on 2048 TPU v3 cores in 4 days.

Outline

  1. Modified MoE in GShard
  2. Highly Parallel Implementation in GShard
  3. Results

1. Modified MoE in GShard

(a) Transformer, (b) MoE, (c) MoE With GShard

1.1. MoE

The MoE layer output is the gated combination of the expert outputs:

y_s = Σ_e G_s,E · FFN_e(x_s), with FFN_e(x_s) = wo_e · ReLU(wi_e · x_s)

  • where the vector G_s,E is computed by a gating function GATE(.). G_s,E represents how much an expert contributes to the final network output.
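To make this combination concrete, here is a minimal NumPy sketch (illustrative only, not the paper's TPU implementation; the shapes and the names wi/wo are assumptions):

```python
import numpy as np

def moe_layer(x, wi, wo, gates):
    """Combine per-expert FFN outputs, weighted by the gating values.

    x:     [S, M]    tokens
    wi:    [E, M, H] per-expert input projections
    wo:    [E, H, M] per-expert output projections
    gates: [S, E]    G_s,e (mostly zero for a sparse gate)
    """
    y = np.zeros_like(x)
    for e in range(gates.shape[1]):
        h = np.maximum(x @ wi[e], 0.0)      # ReLU(wi_e . x_s)
        y += gates[:, [e]] * (h @ wo[e])    # G_s,e * FFN_e(x_s)
    return y
```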

The sparse gating function and the auxiliary loss used in this paper differ from those in the conventional MoE.

1.2. Modified MoE

1.2.1. Load balancing

  • Naively picking the top-k experts from the softmax probability distribution leads to a load imbalance problem during training. The MoE layer in this paper enforces that the number of tokens processed by one expert stays below a uniform threshold called the expert capacity.
  • k=2 is used in this paper.
  • When both experts selected by a token have already exceeded their capacity, the token is considered an overflowed token, and G_s,E degenerates into a zero vector. Such tokens are passed on to the next layer via the residual connection (see the sketch below).
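A minimal sketch of top-2 gating with an expert capacity, assuming a plain NumPy setting (illustrative, not the paper's exact Top2Gating algorithm):

```python
import numpy as np

def top2_gating_with_capacity(logits, capacity):
    """Top-2 gating with a per-expert capacity limit (illustrative sketch).

    logits: [S, E] router logits for S tokens and E experts.
    Returns gates [S, E]; a token whose chosen experts are both full gets an
    all-zero row (an "overflowed" token that passes through the residual path).
    """
    S, E = logits.shape
    probs = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs /= probs.sum(axis=1, keepdims=True)        # softmax over experts

    gates = np.zeros((S, E))
    load = np.zeros(E, dtype=int)                    # tokens assigned per expert
    for s in range(S):
        top2 = np.argsort(-probs[s])[:2]             # two highest-probability experts
        for e in top2:
            if load[e] < capacity:                   # respect the expert capacity
                gates[s, e] = probs[s, e]
                load[e] += 1
        # if both experts were full, gates[s] stays all zero (overflow token)
    return gates
```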

1.2.2. Local Dispatching for Parallel Gating

  • A new GATE(.) function partitions all tokens in a training batch evenly into G local groups, i.e., each group contains S=N/G tokens for local dispatching. All local groups are processed independently in parallel. Each group is given a fractional capacity of each expert, C=2N/(G·E), to ensure that at most this many tokens are dispatched to an expert.

With fixed expert capacity and local dispatching, GShard is able to speed up the gating function by O(G) times.
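A quick worked example of the group size and per-group capacity; the batch, group, and expert counts below are hypothetical, only the formulas S=N/G and C=2N/(G·E) come from the paper:

```python
# Hypothetical sizes; only S = N/G and C = 2N/(G*E) come from the paper.
N = 65_536            # tokens in the global training batch
G = 128               # local groups (e.g., one per device)
E = 128               # experts
S = N // G            # 512 tokens per group, processed independently
C = 2 * N // (G * E)  # 8 tokens: per-group capacity of each expert (k = 2)
print(S, C)
```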

1.2.3. Auxiliary Loss

  • A new differentiable auxiliary loss term l_aux is added to enforce load balancing. The overall loss becomes L = l_nll + k·l_aux, where l_nll is the translation (negative log-likelihood) loss, k is a constant multiplier, and l_aux = (1/E) Σ_e (c_e/S)·m_e, with c_e the number of tokens in a group whose top-1 choice is expert e and m_e the mean gating probability of expert e over the group (a differentiable approximation of c_e/S).
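A small NumPy sketch of this auxiliary loss for one group, under the assumptions above (function and variable names are illustrative):

```python
import numpy as np

def aux_loss(probs):
    """Load-balancing auxiliary loss for one group (illustrative sketch).

    probs: [S, E] softmax gate probabilities.
    l_aux = (1/E) * sum_e (c_e / S) * m_e, where c_e counts tokens whose top-1
    choice is expert e and m_e is the mean gate probability of expert e
    (the differentiable proxy for c_e / S).
    """
    S, E = probs.shape
    top1 = probs.argmax(axis=1)            # top-1 expert per token
    c = np.bincount(top1, minlength=E)     # c_e: tokens routed to each expert
    m = probs.mean(axis=0)                 # m_e: mean gates per expert
    return ((c / S) * m).sum() / E
```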

2. Highly Parallel Implementation in GShard

Highly Parallel Implementation in GShard
  • D number of devices are used.
  • Top2Gating computes the union of all group-local G_S,E.
  • combine_weights is a 4-D tensor with shape [G, S, E, C], whose element value is non-zero when the input token s in group g is sent to expert e at capacity buffer position c. For a specific g and s, the slice combine_weights[g, s, :, :] contains at most two non-zero values (see the sketch after this list).
  • The binary dispatch_mask is produced from combine_weights by simply setting all non-zero values to 1.
  • To scale the computation to a cluster with D devices, the number of groups G and the number of experts E are chosen to be proportional to D.
  • The per-device FLOPS for the gating softmax is therefore proportional to D.
  • Since D ≤ 2H for configurations of up to 16K devices, this cost stays below that of the FFN. Consequently, the total per-device FLOPS can be considered independent of D, satisfying the sublinear scaling design requirement.
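A NumPy sketch of how combine_weights and dispatch_mask could be built from the sparse gates; the shapes follow the paper, but the construction below is an illustrative assumption rather than the paper's Top2Gating code:

```python
import numpy as np

def build_combine_weights(gates, C):
    """Turn per-group sparse gates into combine_weights and dispatch_mask.

    gates: [G, S, E] gate values with at most two non-zeros per token.
    Returns combine_weights [G, S, E, C] and the binary dispatch_mask.
    """
    G, S, E = gates.shape
    combine_weights = np.zeros((G, S, E, C))
    for g in range(G):
        position = np.zeros(E, dtype=int)               # next free capacity slot per expert
        for s in range(S):
            for e in np.nonzero(gates[g, s])[0]:
                if position[e] < C:                     # drop if the expert buffer is full
                    combine_weights[g, s, e, position[e]] = gates[g, s, e]
                    position[e] += 1
    dispatch_mask = (combine_weights > 0).astype(np.float32)  # 1 where a token occupies a slot
    return combine_weights, dispatch_mask
```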

To express parallelism, tensors in the linear algebra computation are annotated with sharding information using GShard APIs to selectively specify how they should be partitioned across a cluster of devices.

This sharding information is propagated to the compiler so that the compiler can automatically apply transformations for parallel execution.

  • The input tensor is split along the first dimension and the gating weight tensor is replicated.
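A sketch of what this annotation could look like; split() and replicate() follow the annotation calls in the paper's listing, but the Python stand-ins below only record the sharding intent and do not partition anything themselves (tensor sizes are hypothetical):

```python
import numpy as np

def split(tensor, split_dimension, num_partitions):
    """Stand-in: mark `tensor` to be partitioned along `split_dimension` across devices."""
    return {"tensor": tensor, "sharding": ("split", split_dimension, num_partitions)}

def replicate(tensor):
    """Stand-in: mark `tensor` to be fully replicated on every device."""
    return {"tensor": tensor, "sharding": ("replicate",)}

D = 4                                      # number of devices (hypothetical)
token_groups = np.zeros((8, 512, 1024))    # [G, S, M] input tokens, hypothetical sizes
gating_weights = np.zeros((1024, 128))     # [M, E] gating weights

inputs = split(token_groups, 0, D)   # shard along the first (group) dimension
wg = replicate(gating_weights)       # every device keeps a full copy
```

In the real system, the XLA compiler propagates these annotations and automatically inserts the necessary cross-device communication, such as AllToAll for dispatching tokens to experts.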

A 600B GShard model for M4 processes 1T tokens (source-side tokens after sub-word segmentation) in 250k training steps in under 4 days.

3. Results

MoE with GShard Scaling Performance
  • The depth of the Transformer network (L) and the number of experts (E) are varied to scale the model. For depth, 12 layers (the original Transformer depth, consisting of 6 encoder and 6 decoder layers), 36 layers, and 60 layers are tested. For the number of experts that replace every other feed-forward layer, three options are tested: 128, 512, and 2048 experts.
  • T(96L) is the dense model without any MoE layers, which is treated as the baseline.

Deeper Models Bring Consistent Quality Gains Across the Board.

MoE with GShard Scaling Performance

Deeper models converge faster with fewer examples.

  • One of the largest models, MoE(2048E, 36L) with 600 billion parameters, utilized 2048 TPU cores for 4 days. This model achieves the best translation quality in terms of average BLEU, but also takes a total of 22.4 TPU core years to train.

Scaling with conditional computation is way more practical compared to dense scaling.

  • Given the same number of TPU cores used by MoE(2048E, 36L), the dense scaling variant, T(96L), appears to take more than ten times as long to train (235 TPU core years).
