Brief Review — sMLP: Efficient Language Modeling with Sparse all-MLP

sMLP, Chunk Vector to Each Expert

Sik-Ho Tsang
Nov 9, 2023
sMLP-Deterministic Obtains Large Speedup

Efficient Language Modeling with Sparse all-MLP
sMLP, by State University of New York at Buffalo, and Meta AI
2022 arXiv (Sik-Ho Tsang @ Medium)

Language Model (LM)
2007 … 2022 [GLM] [Switch Transformers] [WideNet] [MoEBERT] [X-MoE]

  • Sparsely activated MLPs with mixture-of-experts (MoEs) significantly increase model capacity and expressiveness while keeping the compute constant.
  • All-MLP models such as gMLP can match Transformers in language modeling, but they still lag behind in downstream tasks.
  • This paper proposes sMLP, a sparsely activated all-MLP, to address these challenges, using two routing strategies in the feature space: deterministic routing and partial prediction.

Outline

  1. sMLP
  2. Results

1. sMLP

sMLP Architecture
  • The sMLP model contains N1 dense blocks and N2 sparse blocks.
  • Each sparse block contains two modules:
  1. tMoE module: an MoE module adopted from Base Layers (Lewis et al., 2021) that replaces the FFN module of dense Transformers.
  2. sMoE module: a module proposed in this paper that replaces the self-attention module of the Transformer and the Spatial Gating Unit of gMLP.

1.1. Sparsely-activated all-MLP

Proposed sMoE Gating (Right)
  • Left: An example of the gating function in existing Transformer-based MoEs (tMoE). tMoE sends four tokens to three experts at the FFN layer using a learnt gating function.
  • Right: In the sparse all-MLP architecture, sMLP instead chunks the hidden representation along the hidden dimension and sends the chunked vectors to different experts.
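To make the difference between the two gating styles concrete, here is a minimal sketch in plain Python. It is not the paper's code: the tMoE side uses a simplified top-1 argmax gate (Base Layers actually solve a balanced token-to-expert assignment), and the function names `tmoe_route_tokens` and `smoe_chunk_features` are hypothetical. The point is the shape of what gets routed: whole tokens in tMoE versus hidden-dimension slices of all tokens in sMoE.

```python
def tmoe_route_tokens(tokens, w_gate):
    """tMoE-style routing: each token (a length-H vector) goes to ONE expert.

    Simplified top-1 argmax gate (an assumption); Base Layers itself uses a
    balanced token-to-expert assignment instead.
    """
    routes = []
    for x in tokens:
        # one gate score per expert: dot product of the token with that expert's gate row
        scores = [sum(w * v for w, v in zip(row, x)) for row in w_gate]
        routes.append(max(range(len(scores)), key=lambda e: scores[e]))
    return routes


def smoe_chunk_features(tokens, num_experts):
    """sMoE-style routing: split every token's hidden vector along the hidden
    dimension; chunk e of ALL tokens goes to expert e (a T x H/E block)."""
    hidden = len(tokens[0])
    assert hidden % num_experts == 0, "hidden size must be divisible by num_experts"
    size = hidden // num_experts
    return [[x[e * size:(e + 1) * size] for x in tokens] for e in range(num_experts)]


tokens = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 0.0]]  # T=4, H=3
w_gate = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]  # 3 experts
print(tmoe_route_tokens(tokens, w_gate))  # → [0, 1, 2, 0]
chunks = smoe_chunk_features(tokens, 3)
print(chunks[0])  # expert 0 sees dimension 0 of all four tokens
```

Note that in the sMoE case every expert receives all T positions, which is exactly why the routing decision has to be handled carefully in an autoregressive model.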

1.2. Routing in Feature Space

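  • Compared with routing tokens, routing hidden dimensions poses a unique challenge in autoregressive models: if done naively, information leaks from future tokens, because each hidden-dimension chunk spans every position in the sequence, so a gate that looks at the chunk also looks ahead.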
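A toy illustration of the leak, assuming a hypothetical naive gate (`naive_route`, not from the paper) that routes each hidden dimension by the sign of its mean over the whole sequence. Changing only the last token changes where an earlier dimension is routed, so outputs at earlier positions would depend on a token they are not allowed to see.

```python
def naive_route(x):
    """Hypothetical naive gate over the FULL sequence (leaks in an LM).

    x is a T x H list of token vectors. Hidden dimension d goes to expert 0 if
    its mean over all T tokens is non-negative, else to expert 1.
    """
    T, hidden = len(x), len(x[0])
    routes = []
    for d in range(hidden):
        mean = sum(x[t][d] for t in range(T)) / T
        routes.append(0 if mean >= 0 else 1)
    return routes


x = [[1.0, -1.0], [1.0, -1.0], [1.0, -1.0]]
print(naive_route(x))  # → [0, 1]

# Change only the LAST (future) token: dimension 0 is now routed differently,
# so outputs at earlier positions depend on a token they should not see.
x_future_changed = x[:-1] + [[-5.0, -1.0]]
print(naive_route(x_future_changed))  # → [1, 1]
```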

1.2.1. Deterministic Routing

With deterministic routing, the hidden vector is chunked along the hidden dimension, and the chunks are sent to experts deterministically rather than by a learned gating function.
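A minimal sketch of a deterministic sMoE layer, under two assumptions not stated in this review: chunk i always goes to expert i, and each expert is a causal token-mixing matrix (similar in spirit to gMLP's Spatial Gating Unit), modeled here by the hypothetical `causal_mix`. Because the assignment is fixed, routing cannot depend on future tokens; only the experts themselves need to respect token order.

```python
def causal_mix(chunk, weights):
    """Hypothetical expert: causal token mixing over a T x w chunk.

    Output position t is a weighted sum over positions s <= t only,
    so an expert never reads a future token.
    """
    T, width = len(chunk), len(chunk[0])
    return [
        [sum(weights[t][s] * chunk[s][j] for s in range(t + 1)) for j in range(width)]
        for t in range(T)
    ]


def deterministic_smoe(x, expert_weights):
    """Deterministic routing: hidden chunk i always goes to expert i (no learned gate)."""
    num_experts = len(expert_weights)
    hidden = len(x[0])
    assert hidden % num_experts == 0, "hidden size must be divisible by num_experts"
    size = hidden // num_experts
    outputs = []
    for i, w in enumerate(expert_weights):
        chunk = [row[i * size:(i + 1) * size] for row in x]  # T x size slice
        outputs.append(causal_mix(chunk, w))
    # concatenate expert outputs back along the hidden dimension
    return [[v for out in outputs for v in out[t]] for t in range(len(x))]


x = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]  # T=3, H=4
ones = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]  # expert that computes a causal prefix sum
y = deterministic_smoe(x, [ones, ones])  # 2 experts, chunk width 2
print(y)  # → [[1, 2, 3, 4], [6, 8, 10, 12], [15, 18, 21, 24]]
```

Replacing the last token of `x` leaves the first two output rows unchanged, which is the property the naive gate lacks.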

1.2.2. Partial Prediction

  • The first 20% of tokens X1 are used to decide the routing, and the remaining 80% of tokens X2 are used for prediction.

Instead of training the language model to predict the whole sequence of length T, it is trained to predict only X2, while X1 is used to learn the gating weights Wr.
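The split can be sketched as follows. The 20%/80% ratio comes from the paper; everything else is an illustrative assumption. In particular, the gate form in the hypothetical `route_dimensions` (score each hidden dimension from its values over the X1 tokens with a weight matrix standing in for Wr) is a guess. In training, Wr would be learned; here it is fixed so that the output is easy to check.

```python
import math


def split_sequence(x, route_frac=0.2):
    """Split a T x H sequence into X1 (routing prefix) and X2 (prediction part)."""
    n1 = max(1, math.floor(len(x) * route_frac))
    return x[:n1], x[n1:]


def route_dimensions(x1, w_r):
    """Route each hidden dimension to an expert using ONLY the X1 tokens.

    Hypothetical gate: dimension d is described by its values over the n1
    tokens of X1; w_r has shape n1 x E and d goes to the highest-scoring expert.
    """
    n1, hidden, num_experts = len(x1), len(x1[0]), len(w_r[0])
    routes = []
    for d in range(hidden):
        column = [x1[t][d] for t in range(n1)]
        scores = [sum(column[t] * w_r[t][e] for t in range(n1)) for e in range(num_experts)]
        routes.append(max(range(num_experts), key=lambda e: scores[e]))
    return routes


# T=10 tokens, H=4 hidden dims: x[t][d] = (t + 1) * (d - 1.5)
x = [[(t + 1) * (d - 1.5) for d in range(4)] for t in range(10)]
x1, x2 = split_sequence(x)  # 2 routing tokens, 8 prediction tokens
w_r = [[1.0, -1.0], [1.0, -1.0]]  # n1 x E with E=2 experts
print(len(x1), len(x2), route_dimensions(x1, w_r))  # → 2 8 [1, 1, 0, 0]

# The LM loss is computed only on the X2 positions.
loss_positions = list(range(len(x1), len(x)))
```

Since the routes are a function of X1 alone, and X1 is never a prediction target, no predicted token can influence its own routing.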

2. Results

sMLP dramatically improves the performance of the all-MLP-based model and also outperforms the Transformer model.

sMLP achieves the best generalization at 25k training steps while also achieving the highest training speed.

Scaling Up sMLP
  • The model size is increased to a 2.0 TFLOPs training configuration: the embedding dimension goes from 1024 to 2048, and the hidden dimension from 4096 to 8192.
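As a quick sanity check on what doubling both dimensions means, the arithmetic below counts the weights of one feed-forward layer, assuming a standard two-matrix FFN (embedding to hidden, then back) with biases ignored. This is only an illustration. It does not by itself account for the quoted 2.0 TFLOPs, which also depends on sparsity, sequence length and the rest of the model.

```python
def ffn_weight_count(d_model, d_hidden):
    """Weights of a two-matrix FFN (d_model -> d_hidden -> d_model), biases ignored."""
    return 2 * d_model * d_hidden


base = ffn_weight_count(1024, 4096)
large = ffn_weight_count(2048, 8192)
print(base, large, large / base)  # → 8388608 33554432 4.0
```

Doubling both the embedding and hidden dimensions therefore roughly quadruples the weights in each such layer.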

sMLP still outperforms Switch Transformer-Enlarge despite the latter having more FLOPs.

Zero-Shot

sMLP outperforms all Sparse Transformers in terms of average accuracy. Notable improvements come from commonsense reasoning tasks such as COPA, StoryCloze and HellaSwag.

  • GPT-3 is also included for comparison, although GPT-3 was trained on more pre-training data.
