Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks (2021)
January 27, 2024

While large pre-trained language models (LLMs) can implicitly store factual knowledge in their parameters, their ability to access explicit encyclopedic and commonsense knowledge is still limited. Retrieval Augmented Generation (RAG) is a method to improve language generation by providing external explicit knowledge to LLMs.
In Retrieval Augmented Generation for Knowledge Intensive NLP Tasks, two RAG models are introduced: the RAG-Token model and the RAG-Sequence model. Both use a dense vector index of Wikipedia as external knowledge and combine a Dense Passage Retriever (DPR) with a seq2seq generator. DPR is a neural retriever that encodes documents and the input sequence with two BERT encoders and scores each document by the inner product of its embedding with the input-sequence embedding. The two RAG models run DPR to find the top-k documents that are most relevant to an input sequence. The text of each retrieved document is concatenated with the input sequence and passed to the seq2seq model.
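As a rough sketch of the retrieval step, the snippet below scores a toy corpus by inner product and returns the top-k documents. The encoder functions and the corpus are hypothetical stand-ins for the two BERT encoders and the Wikipedia index (the paper serves the real index with a MIPS library such as FAISS).

```python
# Minimal sketch of DPR-style retrieval, assuming pre-computed dense embeddings.
# embed_query / embed_document / corpus are hypothetical placeholders.
import numpy as np

def embed_query(x: str) -> np.ndarray:
    """Stand-in for the BERT query encoder q(x); returns a dense vector."""
    rng = np.random.default_rng(abs(hash(x)) % (2**32))
    return rng.standard_normal(768)

def embed_document(z: str) -> np.ndarray:
    """Stand-in for the BERT document encoder d(z); returns a dense vector."""
    rng = np.random.default_rng(abs(hash(z)) % (2**32))
    return rng.standard_normal(768)

corpus = ["passage about Paris", "passage about Rome", "passage about Tokyo"]
doc_index = np.stack([embed_document(z) for z in corpus])   # the "dense vector index"

def retrieve_top_k(x: str, k: int = 2):
    q = embed_query(x)
    scores = doc_index @ q                 # inner products d(z)^T q(x)
    top = np.argsort(-scores)[:k]          # top-k documents by score
    return [(corpus[i], float(scores[i])) for i in top]

print(retrieve_top_k("capital of France"))
```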
The RAG models combine the DPR component \(p_{\eta}(z|x)\) with the seq2seq component \(p_{\theta}(y_i|x, z, y_{1:i-1})\). The DPR component retrieves documents \(z\) that are relevant to the input sequence \(x\). The seq2seq component uses \(z\) as additional context together with \(x\) to generate the target sequence \(y\). \(\eta\) and \(\theta\) denote the retriever and generator parameters, and \(y_i\) denotes the \(i^{\text{th}}\) token of the target sequence.
In the RAG-Token model, the seq2seq generator produces a distribution over the next output token for each retrieved document and then marginalizes over the documents: $$ p_{\text{RAG-Token}}(y|x) \approx \prod_{i=1}^{N} \sum_{z\in \text{top-}k(p_\eta(\cdot|x))} p_\eta(z|x)\, p_{\theta}(y_i|x,z,y_{1:i-1}) $$ where \(N\) is the length of the target sequence. Decoding uses a standard beam search to approximate \(\arg\max_y p(y|x)\).
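To make the token-level marginalization concrete, here is a minimal NumPy sketch with made-up numbers: given retrieval probabilities over the top-k documents and one next-token distribution per document, the distribution the beam search decodes from at each step is their mixture.

```python
# Sketch of the per-token marginalization in RAG-Token, with toy numbers.
import numpy as np

vocab_size, k = 5, 3                              # toy vocabulary and top-k documents

# p_eta(z|x): retrieval probabilities over the top-k documents (assumed given).
p_doc = np.array([0.5, 0.3, 0.2])

# p_theta(y_i | x, z, y_{1:i-1}): one next-token distribution per retrieved document.
rng = np.random.default_rng(0)
logits = rng.standard_normal((k, vocab_size))
p_token_given_doc = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# Marginalize over documents to get the next-token distribution used by beam search.
p_next_token = p_doc @ p_token_given_doc          # shape: (vocab_size,)
print(p_next_token, p_next_token.sum())           # sums to 1
```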
The RAG-Sequence model uses the same retrieved document to generate the entire target sequence: $$ p_{\text{RAG-Sequence}}(y|x) \approx \sum_{z\in \text{top-}k(p_\eta(\cdot|x))} p_{\eta}(z|x)\, p_{\theta}(y|x,z) = \sum_{z\in \text{top-}k(p_\eta(\cdot|x))} p_{\eta}(z|x) \prod_{i=1}^{N} p_{\theta}(y_i|x,z,y_{1:i-1}) $$ Decoding runs a beam search for each document \(z\), yielding a set of hypotheses, some of which may not have appeared in the beams of all documents. In that case, \(p_{\theta}(y|x,z_i)\approx 0\) can be used for efficient decoding.
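The sketch below mimics this efficient ("fast") decoding with hypothetical beam outputs: hypotheses are pooled across the per-document beams, and a hypothesis that is missing from a document's beam contributes probability roughly zero for that document instead of triggering an extra forward pass.

```python
# Sketch of RAG-Sequence fast decoding, with hypothetical per-document beam results.
import numpy as np

p_doc = {"z1": 0.6, "z2": 0.4}                    # p_eta(z|x) for the top-k documents

# log p_theta(y|x,z) for hypotheses that appeared in each document's beam (made-up values).
beam_logprobs = {
    "z1": {"the capital is Paris": -1.2, "it is Paris": -2.0},
    "z2": {"the capital is Paris": -1.5},
}

# Union of hypotheses across all per-document beams.
hypotheses = set().union(*(h.keys() for h in beam_logprobs.values()))

scores = {}
for y in hypotheses:
    total = 0.0
    for z, pz in p_doc.items():
        # Fast decoding: if y never appeared in z's beam, treat p_theta(y|x,z) as ~0
        # rather than running another forward pass to score it.
        logp = beam_logprobs[z].get(y)
        total += pz * (np.exp(logp) if logp is not None else 0.0)
    scores[y] = total

print(max(scores, key=scores.get), scores)
```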
The DPR and seq2seq components can be trained jointly. Given a training dataset of pairs \((x_j, y_j)\), the parameters are updated to minimize the negative marginal log-likelihood of each target, \(\sum_j -\log p(y_j|x_j)\). In the paper, the document encoder and its index are kept fixed during fine-tuning, and only the query encoder and the generator are updated.
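A minimal PyTorch sketch of this objective (in its RAG-Sequence form), assuming the per-document retrieval log-probabilities and the generator's per-token log-probabilities for the gold target are already computed; all tensors below are random placeholders with hypothetical shapes.

```python
# Sketch of the negative marginal log-likelihood as a differentiable loss.
import torch

k, seq_len, vocab = 3, 4, 8

# log p_eta(z|x) over the top-k documents (from the query/document inner products).
doc_logprobs = torch.log_softmax(torch.randn(k, requires_grad=True), dim=0)

# log p_theta(y_i | x, z, y_{1:i-1}) for every vocabulary item, one row per document.
token_logprobs = torch.log_softmax(torch.randn(k, seq_len, vocab, requires_grad=True), dim=-1)
target = torch.randint(0, vocab, (seq_len,))
gold_logprobs = token_logprobs[:, torch.arange(seq_len), target]   # (k, seq_len)

# log p(y|x) = logsumexp_z [ log p_eta(z|x) + sum_i log p_theta(y_i|x,z,y_{1:i-1}) ]
seq_logprob = torch.logsumexp(doc_logprobs + gold_logprobs.sum(dim=1), dim=0)
loss = -seq_logprob                       # negative marginal log-likelihood
loss.backward()                           # gradients flow to both eta and theta
```

The logsumexp over documents keeps the marginalization numerically stable while letting gradients reach both the retriever scores and the generator log-probabilities.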