Review — XLNet: Generalized Autoregressive Pretraining for Language Understanding

XLNet, Reduce Pretrain-Finetune Discrepancy

Sik-Ho Tsang
Jul 2, 2022

XLNet: Generalized Autoregressive Pretraining for Language Understanding
XLNet, by Carnegie Mellon University and Google AI Brain Team
2019 NeurIPS, Over 4900 Citations (Sik-Ho Tsang @ Medium)
Natural Language Processing, NLP, Language Model, BERT, Transformer

  • BERT neglects the dependency between the masked positions and suffers from a pretrain-finetune discrepancy.
  • XLNet, a generalized autoregressive pretraining method, is proposed:
  1. It enables learning bidirectional contexts by maximizing the expected likelihood over all permutations of the factorization order.
  2. It overcomes the limitations of BERT thanks to its autoregressive formulation. Furthermore, it integrates ideas from Transformer-XL.

Outline

  1. Motivations
  2. XLNet: Permutation Language Modeling
  3. XLNet: Partial Prediction
  4. XLNet: Incorporating Ideas from Transformer-XL
  5. Experimental Results

1. Motivations

1.1. Preliminaries

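  • Autoregressive (AR) language modeling seeks to estimate the probability distribution of a text corpus with an autoregressive model, such as GPT-1, GPT-2, and GPT-3.
  • Given a text sequence x = (x1, …, xT), AR language modeling factorizes the likelihood into a forward product:
    $$p(\mathbf{x}) = \prod_{t=1}^{T} p(x_t \mid \mathbf{x}_{<t})$$
  • or a backward product:
    $$p(\mathbf{x}) = \prod_{t=T}^{1} p(x_t \mid \mathbf{x}_{>t})$$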
  • A parametric model (e.g. a neural network) is trained to model each conditional distribution.
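As a quick sanity check of the two factorizations above, here is a minimal Python sketch. It assumes a tiny, made-up joint distribution over three binary tokens (the weights are arbitrary), computes every conditional exactly from that table, and checks that the forward and backward products both recover p(x). A real AR model would instead learn each conditional with a neural network.

```python
import itertools
import math

# Hypothetical toy joint distribution over length-3 sequences of binary tokens.
VOCAB = (0, 1)
T = 3
weights = [1, 2, 3, 4, 5, 6, 7, 8]          # arbitrary positive weights
total = sum(weights)
joint = {seq: w / total for seq, w in zip(itertools.product(VOCAB, repeat=T), weights)}

def marginal(fixed):
    """P(x_i = v for every (i, v) in `fixed`), by summing the joint table."""
    return sum(p for seq, p in joint.items() if all(seq[i] == v for i, v in fixed.items()))

def conditional(x, t, context_positions):
    """Exact p(x_t | x_context) derived from the joint table."""
    ctx = {i: x[i] for i in context_positions}
    return marginal({**ctx, t: x[t]}) / marginal(ctx)

def forward_product(x):
    # p(x) = prod_{t=1}^{T} p(x_t | x_{<t})
    return math.prod(conditional(x, t, range(t)) for t in range(T))

def backward_product(x):
    # p(x) = prod_{t=T}^{1} p(x_t | x_{>t})
    return math.prod(conditional(x, t, range(t + 1, T)) for t in reversed(range(T)))

x = (1, 0, 1)
print(joint[x], forward_product(x), backward_product(x))
```

All three printed values agree up to floating-point rounding: the chain rule holds in either direction, so the choice of direction only changes which conditionals the model has to learn.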

1.2. Autoregressive (AR) Model Problem

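  • An AR language model is only trained to encode a uni-directional context (either forward or backward). AR language modeling performs pretraining by maximizing the likelihood under the forward autoregressive factorization:
    $$\max_{\theta}\ \log p_{\theta}(\mathbf{x}) = \sum_{t=1}^{T} \log p_{\theta}(x_t \mid \mathbf{x}_{<t}) = \sum_{t=1}^{T} \log \frac{\exp\left(h_{\theta}(\mathbf{x}_{1:t-1})^{\top} e(x_t)\right)}{\sum_{x'} \exp\left(h_{\theta}(\mathbf{x}_{1:t-1})^{\top} e(x')\right)}$$
  • where hθ(x1:t-1) is a context representation produced by neural models, such as RNNs or Transformers, and e(x) denotes the embedding of x.
  • Such a model is not effective at modeling deep bidirectional contexts. In contrast, downstream language understanding tasks often require bidirectional context information.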
  • This results in a gap between AR language modeling and effective pretraining.
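The forward objective can be sketched in plain Python as below. The context function `toy_context` is a hypothetical stand-in for hθ (it just averages the embeddings of the prefix, which no real model does), and the embedding table is random; only the log-softmax over hθ⊤e(x) and the sum over t follow the equation.

```python
import math
import random

DIM, VOCAB_SIZE = 4, 6
rng = random.Random(0)
# Random embedding table e(x); a trained model would learn these.
embeddings = [[rng.gauss(0, 1) for _ in range(DIM)] for _ in range(VOCAB_SIZE)]

def log_softmax_score(h, embeddings, target):
    """log[ exp(h^T e(target)) / sum_{x'} exp(h^T e(x')) ], computed stably."""
    logits = [sum(hi * ei for hi, ei in zip(h, e)) for e in embeddings]
    top = max(logits)
    log_z = top + math.log(sum(math.exp(v - top) for v in logits))
    return logits[target] - log_z

def toy_context(prefix):
    # Hypothetical stand-in for h_theta(x_{1:t-1}): mean embedding of the prefix,
    # or a zero vector for an empty prefix. A real model uses an RNN or Transformer.
    if not prefix:
        return [0.0] * DIM
    return [sum(embeddings[tok][d] for tok in prefix) / len(prefix) for d in range(DIM)]

def ar_log_likelihood(tokens, context_fn, embeddings):
    """sum_t log p(x_t | x_{<t}) under the forward factorization."""
    return sum(
        log_softmax_score(context_fn(tokens[:t]), embeddings, tokens[t])
        for t in range(len(tokens))
    )

tokens = [2, 5, 1, 3]
print(ar_log_likelihood(tokens, toy_context, embeddings))
```

Note that every prediction only sees `tokens[:t]`, i.e. the left context, which is exactly the uni-directional limitation described above.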

1.3. Autoencoding (AE) Model Problem

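  • AE models such as BERT construct a corrupted version x̂ of x by randomly setting a portion of the tokens (15% in BERT) to a special [MASK] symbol. Denoting the masked tokens as x̄, the training objective is to reconstruct x̄ from x̂:
    $$\max_{\theta}\ \log p_{\theta}(\bar{\mathbf{x}} \mid \hat{\mathbf{x}}) \approx \sum_{t=1}^{T} m_t \log p_{\theta}(x_t \mid \hat{\mathbf{x}}) = \sum_{t=1}^{T} m_t \log \frac{\exp\left(H_{\theta}(\hat{\mathbf{x}})_t^{\top} e(x_t)\right)}{\sum_{x'} \exp\left(H_{\theta}(\hat{\mathbf{x}})_t^{\top} e(x')\right)}$$
  • where mt = 1 indicates xt is masked, and Hθ is a Transformer that maps a length-T text sequence x into a sequence of hidden vectors Hθ(x) = [Hθ(x)1, Hθ(x)2, …, Hθ(x)T].
  • The artificial symbols like [MASK] used by BERT during pretraining are absent from real data at finetuning time, resulting in a pretrain-finetune discrepancy.
  • Moreover, since the predicted tokens are masked in the input, BERT is not able to model the joint probability using the product rule as in AR language modeling. The ≈ in the objective reflects an independence assumption: all masked tokens are reconstructed separately given x̂.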
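A minimal sketch of the AE objective, assuming a hypothetical `uniform_log_prob` placeholder in place of the Transformer Hθ. It uses a toy "new york is a city" sequence; the 40% masking ratio is only chosen so that two of the five tokens get masked, and BERT's other corruption details are omitted.

```python
import math
import random

MASK = "[MASK]"

def corrupt(tokens, ratio, rng):
    """Build x_hat by replacing a random subset of tokens with [MASK].
    Returns x_hat and the indicators m, where m[t] == 1 means x_t was masked."""
    n_mask = max(1, round(ratio * len(tokens)))
    masked = set(rng.sample(range(len(tokens)), n_mask))
    x_hat = [MASK if t in masked else tok for t, tok in enumerate(tokens)]
    m = [1 if t in masked else 0 for t in range(len(tokens))]
    return x_hat, m

def ae_objective(tokens, x_hat, m, log_prob_fn):
    """sum_t m_t * log p(x_t | x_hat): each masked token is scored separately,
    conditioned only on x_hat and never on the other masked tokens."""
    return sum(m_t * log_prob_fn(x_hat, t, tokens[t]) for t, m_t in enumerate(m))

vocab = ["new", "york", "is", "a", "city"]

def uniform_log_prob(x_hat, t, token):
    # Hypothetical placeholder for log p_theta(x_t | x_hat): uniform over the vocab.
    return -math.log(len(vocab))

tokens = ["new", "york", "is", "a", "city"]
x_hat, m = corrupt(tokens, ratio=0.4, rng=random.Random(1))
print(x_hat, m, ae_objective(tokens, x_hat, m, uniform_log_prob))
```

Because each masked term conditions only on x̂, when both "new" and "york" are masked the objective never contains a term like log p(york | new, …). That is the dependency between masked positions that BERT neglects.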

A natural question to ask is whether there exists a pretraining objective that brings the advantages of both while avoiding their weaknesses.

2. XLNet: Permutation Language Modeling

2.1. Permutations

  • For a sequence x of length T, there are T! different orders to perform a valid autoregressive factorization. Intuitively, if model parameters are shared across all factorization orders, in expectation, the model will learn to gather information from all positions on both sides.
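  • Let ZT be the set of all possible permutations of the length-T index sequence [1, 2, …, T], and let zt and z<t denote the t-th element and the first t-1 elements of a permutation z ∈ ZT. The Permutation Language Modeling objective is:
    $$\max_{\theta}\ \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\sum_{t=1}^{T} \log p_{\theta}(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}})\right]$$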
  • Note that the proposed objective only permutes the factorization order, NOT the sequence order.
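To see what "gathering information from all positions on both sides" means, the sketch below enumerates all T! orders for T = 4 and records which positions precede the target x3 (index 2 in 0-based Python). The helper `contexts_for_target` is hypothetical; only the factorization order is permuted, and the token positions themselves never move.

```python
import itertools

def contexts_for_target(T, target):
    """For every factorization order z of [0, ..., T-1], return the set of
    positions z_{<t} that precede `target` in z, i.e. what x_target is conditioned on."""
    seen = []
    for z in itertools.permutations(range(T)):
        t = z.index(target)
        seen.append(frozenset(z[:t]))
    return seen

T = 4
ctx = contexts_for_target(T, target=2)   # x3 in the 1-based notation of the text
print(len(ctx))                          # T! = 24 factorization orders
print(sorted(set(ctx), key=lambda s: (len(s), sorted(s))))
```

Across the 24 orders, x3 is conditioned on every one of the 8 subsets of the other positions, including contexts that contain x4 on its right. In training, the expectation would be estimated by sampling orders rather than enumerating all T! of them.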
Illustration of the permutation language modeling objective for predicting x3 given the same input sequence x but with different factorization orders.
  • The above figure shows an example of predicting the token x3 given the same input sequence x but under different factorization orders.
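  • However, the standard Softmax formulation below cannot be used directly:
    $$p_{\theta}(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}) = \frac{\exp\left(e(x)^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}})\right)}{\sum_{x'} \exp\left(e(x')^{\top} h_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}})\right)}$$
  • This is because hθ(xz<t) does not depend on which position zt it is predicting: two factorization orders that share the same z<t but have different targets would receive exactly the same distribution.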

2.2. Content and Query Representations

(a) Content representation, and (b) Query Representation
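  • To avoid this problem, XLNet proposes to re-parameterize the next-token distribution to be target position aware:
    $$p_{\theta}(X_{z_t} = x \mid \mathbf{x}_{\mathbf{z}_{<t}}) = \frac{\exp\left(e(x)^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}{\sum_{x'} \exp\left(e(x')^{\top} g_{\theta}(\mathbf{x}_{\mathbf{z}_{<t}}, z_t)\right)}$$
  • where gθ(xz<t, zt) denotes a new type of representation which additionally takes the target position zt as input.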
  • Thus, two sets of hidden representations are used instead of one:
  1. The content representation hθ(xz≤t), abbreviated as hzt, which serves a similar role to the standard hidden states in Transformer. It encodes both the context and xzt itself.
  2. The query representation gθ(xz<t, zt), abbreviated as gzt, which only has access to the contextual information xz<t and the position zt, but not the content xzt.
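The difference between the two parameterizations can be sketched with toy numbers. Everything here is hypothetical: `h` ignores the target position entirely, and `g` simply adds a made-up embedding of the target position to it, whereas XLNet computes gθ with the query stream. Positions are 0-based; the orders [0, 1, 2] and [0, 2, 1] share z<t = [0] at their second step but predict different positions.

```python
import math

def softmax(logits):
    top = max(logits)
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]

# Arbitrary toy parameters, for illustration only (3-word vocab, T = 3, dim 2).
E = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]      # token embeddings e(x)
P = [[0.0, 0.0], [2.0, -1.0], [-1.0, 2.0]]    # hypothetical target-position embeddings

def h(context):
    """Stand-in for h_theta(x_{z<t}): sums embeddings of visible tokens; no target position."""
    return [sum(E[tok][d] for _, tok in context) for d in range(2)]

def g(context, target_pos):
    """Stand-in for g_theta(x_{z<t}, z_t): the same summary plus the target's position embedding."""
    return [a + b for a, b in zip(h(context), P[target_pos])]

def next_token_dist(vec):
    # p(X_{z_t} = x | ...) = softmax over x of e(x)^T vec
    return softmax([sum(a * b for a, b in zip(E[x], vec)) for x in range(len(E))])

x = [0, 2, 1]              # token ids at positions 0, 1, 2
context = [(0, x[0])]      # z_{<t} = [0] in both orders [0, 1, 2] and [0, 2, 1]

std_pos1 = next_token_dist(h(context))        # order [0, 1, 2]: target is position 1
std_pos2 = next_token_dist(h(context))        # order [0, 2, 1]: target is position 2
aware_pos1 = next_token_dist(g(context, 1))
aware_pos2 = next_token_dist(g(context, 2))
print(std_pos1 == std_pos2, aware_pos1 == aware_pos2)   # True False
```

With `h`, the target position is simply not an input, so both orders get one distribution; with `g`, predicting position 1 and position 2 yields different distributions.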
Overview of Two-Stream Attention, i.e. Content Stream & Query Stream
  • The first-layer query stream is initialized with a trainable vector, i.e. g_i^(0) = w, while the content stream is set to the corresponding word embedding, i.e. h_i^(0) = e(x_i).
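  • For each self-attention layer m = 1, …, M, the two streams of representations are schematically updated with a shared set of parameters as follows:
    $$g_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = g_{z_t}^{(m-1)},\ KV = \mathbf{h}_{\mathbf{z}_{<t}}^{(m-1)};\ \theta\right) \quad \text{(query stream: uses } z_t \text{ but cannot see } x_{z_t}\text{)}$$
    $$h_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = h_{z_t}^{(m-1)},\ KV = \mathbf{h}_{\mathbf{z}_{\le t}}^{(m-1)};\ \theta\right) \quad \text{(content stream: uses both } z_t \text{ and } x_{z_t}\text{)}$$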
  • where Q, K, V denote the query, key, and value in an attention operation.
  • During finetuning, we can simply drop the query stream and use the content stream as a normal Transformer(-XL).
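The two update rules can be sketched as one simplified layer. This is an illustration under several assumptions, not XLNet's implementation: single-head attention, random projections Wq, Wk, Wv shared by both streams, no memory, feed-forward, residual connections, or relative positional encodings, and a first position in z that keeps its query vector unchanged because it has nothing to attend to. The order [3, 2, 4, 1] is written 0-based as [2, 1, 3, 0].

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, v):
    return [dot(row, v) for row in W]

def softmax(xs):
    top = max(xs)
    es = [math.exp(x - top) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attention(query, kv, Wq, Wk, Wv):
    """Single-head scaled dot-product attention; Wq, Wk, Wv are shared by both streams."""
    q = matvec(Wq, query)
    keys = [matvec(Wk, x) for x in kv]
    values = [matvec(Wv, x) for x in kv]
    weights = softmax([dot(q, k) / math.sqrt(len(q)) for k in keys])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]

def two_stream_layer(h, g, z, Wq, Wk, Wv):
    """One schematic layer for factorization order z (0-based positions).
    Content stream: Q = h[z_t], KV = h[z_<=t]  (sees its own token).
    Query stream:   Q = g[z_t], KV = h[z_<t]   (never sees its own token)."""
    new_h, new_g = list(h), list(g)
    for t, pos in enumerate(z):
        new_h[pos] = attention(h[pos], [h[p] for p in z[: t + 1]], Wq, Wk, Wv)
        earlier = [h[p] for p in z[:t]]
        # Without memory, the first position in z has nothing to attend to;
        # this sketch simply leaves its query vector unchanged.
        new_g[pos] = attention(g[pos], earlier, Wq, Wk, Wv) if earlier else g[pos]
    return new_h, new_g

D = 4
rng = random.Random(0)

def rand_matrix():
    return [[rng.gauss(0, 0.5) for _ in range(D)] for _ in range(D)]

Wq, Wk, Wv = rand_matrix(), rand_matrix(), rand_matrix()
embed = {tok: [rng.gauss(0, 1) for _ in range(D)] for tok in range(5)}
w = [rng.gauss(0, 1) for _ in range(D)]     # shared trainable query initialization

x = [3, 1, 4, 0]                            # toy token ids
h0 = [list(embed[tok]) for tok in x]        # h_i^(0) = e(x_i)
g0 = [list(w) for _ in x]                   # g_i^(0) = w
z = [2, 1, 3, 0]                            # order [3, 2, 4, 1] in 0-based form
h1, g1 = two_stream_layer(h0, g0, z, Wq, Wk, Wv)
```

In this sketch, perturbing the embedding of position 4 (index 3) changes its content-stream output but leaves its query-stream output untouched, which is why gzt can be used to predict xzt without seeing it.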

3. XLNet: Partial Prediction

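  • The permutation objective is a challenging optimization problem that leads to slow convergence. To reduce the optimization difficulty, only the last tokens in a factorization order are predicted.
  • Formally, z is split into a non-target subsequence z≤c and a target subsequence z>c, where c is the cutting point. The objective is to maximize the log-likelihood of the target subsequence conditioned on the non-target subsequence:
    $$\max_{\theta}\ \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\log p_{\theta}(\mathbf{x}_{\mathbf{z}_{>c}} \mid \mathbf{x}_{\mathbf{z}_{\le c}})\right] = \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T}\left[\sum_{t=c+1}^{|\mathbf{z}|} \log p_{\theta}(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}})\right]$$
  • A hyperparameter K is introduced to adjust the subsequence length, so that roughly 1/K of the tokens are selected as prediction targets:
    $$\frac{|\mathbf{z}|}{|\mathbf{z}| - c} \approx K$$
  • For the unselected tokens, the query representations need not be computed, which saves both speed and memory.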
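A minimal sketch of the cut, assuming one sampled order per sequence and a hypothetical helper `split_targets`; K = 4 is just an example value, not a setting taken from the paper.

```python
import random

def split_targets(T, K, rng):
    """Hypothetical helper: sample a factorization order z of positions 0..T-1 and
    cut it at c so that |z| / (|z| - c) is approximately K."""
    z = list(range(T))
    rng.shuffle(z)
    num_targets = max(1, round(T / K))
    c = T - num_targets
    return z, z[:c], z[c:]   # full order, non-target z_{<=c}, targets z_{>c}

z, non_target, targets = split_targets(T=12, K=4, rng=random.Random(0))
for t in range(len(non_target), len(z)):
    # Each target x_{z_t} is predicted from all positions before it in z.
    print(f"predict position {z[t]} from positions {sorted(z[:t])}")
```

Each target keeps the full term log pθ(xzt | xz<t), so later targets also condition on earlier targets; only the first c positions are never predicted.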

4. XLNet: Incorporating Ideas from Transformer-XL

  • Two important techniques in Transformer-XL are integrated, namely the relative positional encoding scheme and the segment recurrence mechanism.
A detailed illustration of the content stream of the proposed objective with both the joint view and split views based on a length-4 sequence under the factorization order [3, 2, 4, 1]. Note that if we ignore the query representation, the computation in this figure is simply the standard self-attention, though with a particular attention mask.
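  • The “mem” part corresponds to the segment recurrence mechanism of Transformer-XL: hidden states from the previous segment are cached and reused as extra context, which helps the model handle long sequences. (Please feel free to read Transformer-XL if interested.)
  • For example, under the order [3, 2, 4, 1], the content representation of each position can only attend to the positions that come no later than it in the factorization order (plus the memory): position 3 sees only itself, position 2 sees positions 3 and 2, position 4 sees positions 3, 2, and 4, and position 1 sees all four positions.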
A detailed illustration of the query stream of the proposed objective with both the joint view and split views based on a length-4 sequence under the factorization order [3, 2, 4, 1]. The dash arrows indicate that the query stream cannot access the token (content) at the same position, but only the location information.
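  • The query representation works similarly, except that a position cannot attend to its own content: it sees only the positions strictly earlier in the factorization order (plus the memory), together with its own location.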
  • (This section has too many details, so I don’t go deep here. Please feel free to read the paper directly.)
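The attention patterns for the order [3, 2, 4, 1] can be written as two masks. This is a hypothetical helper, not XLNet's code: positions are 1-based as in the figures, memory is left out, and mask[i][j] = 1 means position i+1 may attend to position j+1.

```python
def two_stream_masks(order):
    """Build content and query attention masks for a factorization order.
    order: 1-based positions, e.g. [3, 2, 4, 1].
    mask[i][j] == 1 means position i+1 may attend to position j+1 (memory ignored)."""
    T = len(order)
    rank = {pos: r for r, pos in enumerate(order)}   # step at which each position appears
    content = [[1 if rank[j + 1] <= rank[i + 1] else 0 for j in range(T)] for i in range(T)]
    query = [[1 if rank[j + 1] < rank[i + 1] else 0 for j in range(T)] for i in range(T)]
    return content, query

content, query = two_stream_masks([3, 2, 4, 1])
for name, mask in (("content", content), ("query", query)):
    print(name)
    for row in mask:
        print(row)
# content rows: [1, 1, 1, 1], [0, 1, 1, 0], [0, 0, 1, 0], [0, 1, 1, 1]
# query rows:   [0, 1, 1, 1], [0, 0, 1, 0], [0, 0, 0, 0], [0, 1, 1, 0]
```

The query mask is the content mask with its diagonal removed, which is the only difference between the two streams' attention patterns.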

5. Experimental Results

5.1. Fair Comparison with BERT

  • All models are trained using the same data and hyperparameters as BERT. The best of three BERT variants is used for comparison, i.e., the original BERT, BERT with whole word masking, and BERT without next sentence prediction.

Trained on the same data with an almost identical training recipe, XLNet outperforms BERT by a sizable margin on all the considered datasets.

5.2. SOTA Comparisons Using Multiple Datasets

Comparison with state-of-the-art results on the test set of RACE, a reading comprehension task, and on ClueWeb09-B, a document ranking task
Results on SQuAD, a reading comprehension dataset
Comparison with state-of-the-art error rates on the test sets of several text classification datasets
Results on GLUE. * indicates using ensembles
  • Numerous datasets are evaluated, including RACE, SQuAD, IMDB, and GLUE.

XLNet generally outperforms BERT and RoBERTa.


Sik-Ho Tsang

PhD, Researcher. I share what I learn. :) Linktree: https://linktr.ee/shtsang for Twitter, LinkedIn, etc.