REALM: Retrieval-Augmented Language Model Pre-Training (2020)
May 27, 2024

REALM is a language model for open-domain question answering. It is composed of two components: a neural knowledge retriever and a knowledge-augmented encoder. The neural knowledge retriever finds documents relevant to the input, and the knowledge-augmented encoder generates a response from the retrieved documents and the input. In pre-training, REALM employs masked language modeling: it is trained to predict the original tokens that are masked in the given sentences. In fine-tuning, the model is trained to produce the answer to the given input. The authors evaluated the model on open-domain QA tasks under the assumption that the answer appears as a contiguous sequence of tokens in some document.
Let \(\mathcal{Z}\) be a textual knowledge corpus. Given some input \(x\), REALM learns a distribution \(p(y|x)\) over possible outputs \(y\), decomposed into two steps: retrieve \(p(z|x)\), then predict \(p(y|z, x)\):

$$ p(y|x) = \sum\_{z\in \mathcal{Z}}p(y|z, x)p(z|x) $$

The knowledge retriever models \(p(z|x)\) and is defined as a dense inner product model:
$$ p(z|x)=\frac{\exp f(x, z)}{\sum\_{z'}\exp f(x, z')} $$
\(f(x, z)\) is the inner product of the vector embeddings and serves as the relevance score between \(x\) and \(z\).
The embedding functions are implemented with a BERT-style Transformer.
The \(\texttt{[CLS]}\) embedding of each \(x\) and \(z\) is used as its representation.
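As a minimal sketch of this retrieval distribution (the random vectors below are placeholders for the trained \(\texttt{[CLS]}\) embeddings):

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins for the BERT-style [CLS] embeddings of the input x
# and of 5 candidate documents z (in REALM these come from trained encoders).
x_embed = rng.normal(size=128)
z_embeds = rng.normal(size=(5, 128))

# f(x, z): relevance score = inner product of the embeddings.
scores = z_embeds @ x_embed

# p(z|x): softmax of the relevance scores over the candidate documents
# (shifted by the max score for numerical stability).
p_z_given_x = np.exp(scores - scores.max())
p_z_given_x /= p_z_given_x.sum()
```

In practice the sum in \(p(y|x)\) is taken over only the top-\(k\) documents under \(p(z|x)\), since the softmax concentrates most of the mass there.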
The knowledge-augmented encoder concatenates \(x\) and \(z\) into a single sequence and feeds it into a Transformer.
In pre-training, the encoder is trained to predict the original value of each \(\texttt{[MASK]}\) token in \(x\).
Let \(J_x\) be the total number of \(\texttt{[MASK]}\) tokens in \(x\) and \(w_j\) a learned word embedding for token \(y_j\); the encoder models \(p(y|z, x)\) as:

$$ \begin{align*} p(y|z, x)&=\prod\_{j=1}^{J\_x}p(y\_j|z, x)\\\\ p(y\_j|z, x)&\propto \exp\left(w\_j^\top \texttt{BERT}\_{\texttt{MASK}(j)}(\texttt{join}\_{\texttt{BERT}}(x, z\_{\text{body}}))\right) \end{align*} $$
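A sketch of this factorized masked-token prediction (the hidden states, word embeddings, and token ids below are random or hypothetical stand-ins):

```python
import numpy as np

rng = np.random.default_rng(1)
vocab_size, hidden = 10, 16

# Random stand-ins (assumptions) for the Transformer hidden states at the
# J_x = 2 [MASK] positions of the joined sequence, and for the learned
# word-embedding vectors w_j.
h_mask = rng.normal(size=(2, hidden))
w = rng.normal(size=(vocab_size, hidden))

# p(y_j | z, x) ∝ exp(w_j^T h_MASK(j)): a softmax over the vocabulary,
# computed independently for each masked position.
logits = h_mask @ w.T
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)

# p(y | z, x) factorizes over the masked positions.
y = [3, 7]  # hypothetical original token ids for the two masks
p_y = float(np.prod([probs[j, y[j]] for j in range(len(y))]))
```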
For Open-QA fine-tuning, let \(S(z, y)\) be the set of spans matching \(y\) in \(z\); then \(p(y|z, x)\) is defined as:
$$ \begin{align*} p(y|z, x)&\propto \sum\_{s\in S(z, y)}\exp (\texttt{MLP}([h\_{\texttt{START}(s)}; h\_{\texttt{END}(s)}]))\\\\ h\_{\texttt{START}(s)}&=\texttt{BERT}\_{\texttt{START}(s)}(\texttt{join}\_{\texttt{BERT}}(x, z\_{\text{body}}))\\\\ h\_{\texttt{END}(s)}&=\texttt{BERT}\_{\texttt{END}(s)}(\texttt{join}\_{\texttt{BERT}}(x, z\_{\text{body}})) \end{align*} $$