Review — RAG Model: Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
The Recent LangChain Framework Is Also Based on the RAG Model for Dense Text Retrieval
Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
Retrieval-Augmented Generation (RAG), by Facebook AI Research; University College London; New York University
2020 NeurIPS, Over 870 Citations (Sik-Ho Tsang @ Medium)
==== My Other Paper Readings Are Also Over Here ====
- A general-purpose fine-tuning recipe is proposed for retrieval-augmented generation (RAG): models which combine pre-trained parametric and non-parametric memory for language generation.
- In RAG models, the parametric memory is a pre-trained seq2seq model and the non-parametric memory is a dense vector index of Wikipedia, accessed with a pre-trained neural retriever.
Outline
- Retrieval-Augmented Generation (RAG) Conceptual Idea
- RAG Model
- Results
1. Retrieval-Augmented Generation (RAG) Conceptual Idea
There are 2 RAG model variants: RAG-Sequence Model and RAG-Token Model.
1.1. RAG-Sequence Model
- The RAG-Sequence model uses the same retrieved document to generate the complete sequence.
Technically, it treats the retrieved document as a single latent variable that is marginalized to get the seq2seq probability p(y|x) via a top-K approximation.
- Concretely, the top K documents are retrieved using the retriever, and the generator produces the output sequence probability for each document, which are then marginalized:
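In LaTeX form, the RAG-Sequence marginalization from the paper is:

$$p_{\text{RAG-Sequence}}(y \mid x) \approx \sum_{z \in \text{top-}k(p_\eta(\cdot \mid x))} p_\eta(z \mid x)\, p_\theta(y \mid x, z) = \sum_{z \in \text{top-}k(p_\eta(\cdot \mid x))} p_\eta(z \mid x) \prod_{i=1}^{N} p_\theta(y_i \mid x, z, y_{1:i-1})$$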
1.2. RAG-Token Model
- In the RAG-Token model, a different latent document is drawn for each target token and marginalized accordingly. This allows the generator to choose content from several documents when producing an answer.
- Concretely, the top K documents are retrieved using the retriever, and the generator then produces a distribution for the next output token for each document before marginalizing; the process is repeated for each following output token. Formally:
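$$p_{\text{RAG-Token}}(y \mid x) \approx \prod_{i=1}^{N} \sum_{z \in \text{top-}k(p_\eta(\cdot \mid x))} p_\eta(z \mid x)\, p_\theta(y_i \mid x, z, y_{1:i-1})$$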
1.3. Sequence Classification Task
- RAG can be used for sequence classification tasks by considering the target class as a target sequence of length one, in which case RAG-Sequence and RAG-Token are equivalent.
2. RAG Model
2.1. Retriever: DPR
- The retriever is based on DPR [26]. DPR follows a bi-encoder architecture:
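In LaTeX form, the DPR bi-encoder scoring from the paper is:

$$p_\eta(z \mid x) \propto \exp\!\left(\mathbf{d}(z)^{\top} \mathbf{q}(x)\right), \qquad \mathbf{d}(z) = \mathrm{BERT}_d(z), \quad \mathbf{q}(x) = \mathrm{BERT}_q(x)$$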
- where d(z) is a dense representation of a document produced by a BERT_BASE document encoder, and q(x) is a query representation produced by a query encoder, also based on BERT_BASE.
- Calculating top-k(pη(·|x)), the list of k documents z with the highest prior probability pη(z|x), is a Maximum Inner Product Search (MIPS) problem, which can be solved approximately in sub-linear time using FAISS [23]; see the sketch at the end of this subsection.
A pre-trained bi-encoder from DPR is used to initialize the retriever and to build the document index. This retriever was trained to retrieve documents which contain answers to TriviaQA [24] questions and Natural Questions [29].
The document index is referred to as the non-parametric memory.
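As a minimal sketch of MIPS retrieval with FAISS (the embeddings here are random placeholders; the HNSW index type mirrors the setup described in Section 3):

```python
import numpy as np
import faiss  # pip install faiss-cpu

d = 768                                                # BERT_BASE embedding dimension
doc_embs = np.random.rand(21_000, d).astype("float32") # placeholder for d(z)
query_emb = np.random.rand(1, d).astype("float32")     # placeholder for q(x)

# HNSW index over inner-product scores, approximating exact MIPS.
index = faiss.IndexHNSWFlat(d, 32, faiss.METRIC_INNER_PRODUCT)
index.add(doc_embs)

k = 5
scores, ids = index.search(query_emb, k)  # ids of top-k(p_eta(.|x)) documents
print(ids[0], scores[0])
```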
2.2. Generator: BART
- The generator component pθ(yi|x, z, y1:i-1) could be modelled using any encoder-decoder; BART-large, a pre-trained seq2seq Transformer with 400M parameters, is used.
- BART was pre-trained using a denoising objective with a variety of different noising functions, and it outperforms comparably-sized T5 models.
- To combine the input x with the retrieved content z when generating from BART, they are simply concatenated (a sketch follows below).
- The BART generator parameters θ are referred to as the parametric memory.
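A minimal sketch of this concatenation using Hugging Face's BART; the separator and example strings are assumptions for illustration, not the paper's exact input format:

```python
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

query = "who holds the record in 100m freestyle"       # input x
passage = "Cesar Cielo set the 100m freestyle record"  # retrieved content z

# Concatenate retrieved content z with input x into one encoder input.
inputs = tokenizer(passage + " // " + query, return_tensors="pt", truncation=True)
ids = model.generate(**inputs, num_beams=4, max_length=32)
print(tokenizer.decode(ids[0], skip_special_tokens=True))
```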
2.3. Training
- The retriever and generator components can be trained jointly.
- Given a fine-tuning training corpus of input/output pairs (xj, yj), the RAG model can be minimized using the negative marginal log-likelihood:
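In LaTeX form, the objective minimized over the corpus is:

$$\min_{\eta,\, \theta} \sum_{j} -\log p(y_j \mid x_j)$$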
- Yet, updating the document encoder BERT_d during training is costly, since it would require periodically re-encoding and re-indexing all documents. Thus, the document encoder (and index) is kept fixed; only the query encoder BERT_q and the BART generator are fine-tuned (a sketch of one training step follows below).
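A rough sketch of one training step in the RAG-Sequence style, assuming a hypothetical query_encoder and a hypothetical generator_log_likelihood helper that scores the target under BART given one retrieved document:

```python
import torch
import torch.nn.functional as F

def rag_sequence_nll(query, target, doc_embs, docs, k=5):
    """Negative marginal log-likelihood -log p(y|x) for one example.

    query_encoder and generator_log_likelihood are hypothetical helpers;
    doc_embs is the fixed matrix of document embeddings d(z).
    """
    q = query_encoder(query)                       # q(x), shape (d,)
    scores = doc_embs @ q                          # inner-product scores
    top = scores.topk(k)                           # gradients flow to BERT_q
    log_p_z = F.log_softmax(top.values, dim=0)     # log p_eta(z|x) over top-k
    log_p_y = torch.stack([
        generator_log_likelihood(query, docs[i], target)  # log p_theta(y|x,z)
        for i in top.indices
    ])
    return -torch.logsumexp(log_p_z + log_p_y, dim=0)     # -log p(y|x)
```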
2.4. Decoding
- For RAG-Token model, it can be seen as a standard, autoregressive seq2seq generator with transition probability:
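In LaTeX form, the transition probability from the paper is:

$$p'_\theta(y_i \mid x, y_{1:i-1}) = \sum_{z \in \text{top-}k(p_\eta(\cdot \mid x))} p_\eta(z \mid x)\, p_\theta(y_i \mid x, z, y_{1:i-1})$$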
- To decode, p’θ(yi|x, y1:i-1) is plugged into a standard beam decoder.
- For RAG-Sequence, beam search is run for each document z, scoring each hypothesis using the per-document generator probability pθ(yi|x, z, y1:i-1).
- This yields a set of hypotheses Y, some of which may not have appeared in the beams of all documents.
- There are two decoding methods:
- Thorough Decoding: To estimate the probability of a hypothesis y, an additional forward pass is run for each document z whose beam did not produce y; the generator probability is multiplied by pη(z|x), and the probabilities are then summed across documents to obtain the marginal.
- Fast Decoding: For longer output sequences, |Y| can become large, requiring many forward passes. For more efficient decoding, a further approximation can be made: pθ(y|x, zi) ≈ 0 where y was not generated during beam search from x, zi. This avoids the additional forward passes once the candidate set has been generated (a sketch follows below).
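A rough sketch of Fast Decoding, assuming hypothetical beam_search and seq_log_prob helpers; beam_search(x, z) returns candidate output strings generated from (x, z), and seq_log_prob(y, x, z) returns log pθ(y|x, z) as a tensor:

```python
import torch

def rag_sequence_fast_decode(x, docs, doc_log_priors):
    """Pick the hypothesis with the highest approximate marginal p(y|x)."""
    per_doc_beams = [set(beam_search(x, z)) for z in docs]
    candidates = set().union(*per_doc_beams)   # the candidate set Y
    scores = {}
    for y in candidates:
        terms = [
            log_pz + seq_log_prob(y, x, z)     # log p_eta(z|x) p_theta(y|x,z)
            for z, log_pz, beams in zip(docs, doc_log_priors, per_doc_beams)
            if y in beams                      # Fast: p(y|x,z) ~ 0 otherwise
        ]
        scores[y] = torch.logsumexp(torch.stack(terms), dim=0)
    return max(scores, key=scores.get)
```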
3. Results
- The December 2018 Wikipedia dump is used as the single non-parametric knowledge source. Each Wikipedia article is split into disjoint 100-word chunks, yielding a total of 21M documents (a chunking sketch follows this list).
- The document encoder is used to compute an embedding for each document, and a single MIPS index is built using FAISS [23] with a Hierarchical Navigable Small World (HNSW) approximation for fast retrieval [37].
- During training, the top k documents are retrieved for each query. k ∈ {5, 10} is considered for training, and k for test time is set using dev data.
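A minimal sketch of the 100-word chunking; the whitespace splitting heuristic is an assumption, as the paper only states that articles are split into disjoint 100-word chunks:

```python
def chunk_words(text: str, size: int = 100) -> list[str]:
    """Split an article into disjoint chunks of `size` words."""
    words = text.split()
    return [" ".join(words[i:i + size]) for i in range(0, len(words), size)]

article = "..."  # one Wikipedia article's plain text
chunks = chunk_words(article)  # each chunk becomes one retrievable document
```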
3.1. Open-domain Question Answering
Table 1: Unlike REALM and T5+SSM, RAG enjoys strong results without expensive, specialized “salient span masking” pre-training [20].
3.2. Abstractive Question Answering
Table 2: RAG-Sequence outperforms BART on Open MS-MARCO NLG by 2.6 Bleu points and 2.6 Rouge-L points.
- Table 3 shows example generations from the models.
3.3. Jeopardy Question Generation
Table 2: RAG-Token performs better than RAG-Sequence on Jeopardy question generation, with both models outperforming BART on Q-BLEU-1.
Table 4: Evaluators indicated that BART was more factual than RAG in only 7.1% of cases, while RAG was more factual in 42.7% of cases, and both RAG and BART were factual in a further 17% of cases.
- Figure 2 visualizes the RAG-Token document posterior for each generated token on a Jeopardy generation example, showing which retrieved document contributes as each token is produced.
3.4. Fact Verification
Table 2: For 3-way classification, RAG scores are within 4.3% of state-of-the-art models, which are complex pipeline systems.
3.5. Further Study
Table 5: RAG-Sequence’s generations are more diverse than RAG-Token’s, and both are significantly more diverse than BART without needing any diversity-promoting decoding.
Table 6: Learned retrieval improves results for all tasks.
Figure 3 (left) shows that retrieving more documents at test time monotonically improves Open-domain QA results for RAG-Sequence, but performance peaks for RAG-Token at 10 retrieved documents.
Figure 3 (right) shows that retrieving more documents leads to higher Rouge-L for RAG-Token at the expense of Bleu-1, but the effect is less pronounced for RAG-Sequence.