REALM: Retrieval-Augmented Language Model Pre-Training (2020)
May 27, 2024

REALM is a language model for open-domain question answering. It is composed of two components: a neural knowledge retriever and a knowledge-augmented encoder. The neural knowledge retriever finds documents related to the input, and the knowledge-augmented encoder generates responses from the retrieved documents and the input. In pre-training, REALM uses masked language modeling: it is trained to predict the original tokens that are masked in the given sentences. In fine-tuning, the model is trained to produce the answer to the given input. The authors evaluated the model on open-domain QA tasks under the assumption that the answer can be found as a contiguous sequence of tokens in some document.
Let \(\mathcal{Z}\) be a textual knowledge corpus. Given some input \(x\), REALM learns a distribution \(p(y|x)\) over possible outputs \(y\). REALM decomposes \(p(y|x)\) into two steps, retrieve \(p(z|x)\) and then predict \(p(y|z, x)\): $$ p(y|x) = \sum_{z\in \mathcal{Z}}p(y|z, x)p(z|x) $$
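As a minimal sketch of this marginalization (the document scores, the per-document output distributions, and all tensor shapes below are made up for illustration; they are not taken from REALM's code):

```python
import torch

# Hypothetical example: 3 retrieved documents, 5 candidate outputs y.
# retriever_logits[i] stands in for f(x, z_i); encoder_probs[i] for p(y | z_i, x).
retriever_logits = torch.tensor([2.1, 0.3, -1.0])           # f(x, z) for each z
p_z_given_x = torch.softmax(retriever_logits, dim=0)        # p(z | x)
encoder_probs = torch.tensor([[0.70, 0.10, 0.10, 0.05, 0.05],
                              [0.20, 0.50, 0.10, 0.10, 0.10],
                              [0.25, 0.25, 0.20, 0.20, 0.10]])  # p(y | z, x)

# Marginalize over documents: p(y | x) = sum_z p(y | z, x) p(z | x)
p_y_given_x = (p_z_given_x.unsqueeze(1) * encoder_probs).sum(dim=0)
print(p_y_given_x)  # distribution over the 5 candidate outputs
```

In the paper, summing over all of \(\mathcal{Z}\) is intractable, so the sum is approximated by marginalizing over only the top-\(k\) documents under \(p(z|x)\), found with maximum inner product search.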
The knowledge retriever models \(p(z|x)\), which is defined as a dense inner product model:
$$
p(z|x)=\frac{\exp f(x, z)}{\sum_{z'}\exp f(x, z')}
$$
\(f(x, z)\) is the inner product of the vector embeddings of \(x\) and \(z\), and denotes their relevance score.
The embedding functions are implemented with BERT-style Transformers, and the embedding of the \(\texttt{[CLS]}\) token is used as the representation of each \(x\) and \(z\).
$$
\begin{align*}
f(x, z) &= \texttt{Embed}_{\texttt{input}}(x)^\top \texttt{Embed}_{\texttt{doc}}(z)\\
\texttt{Embed}_{\texttt{input}}(x)&={\rm \bf{W}}_{\texttt{input}}\texttt{BERT}_{\texttt{CLS}}(\texttt{[CLS]}x\texttt{[SEP]} )\\
\texttt{Embed}_{\texttt{doc}}(z)
&={\rm \bf{W}}_{\texttt{doc}}\texttt{BERT}_{\texttt{CLS}}(\texttt{join}_{\texttt{BERT}}(z_\text{title}, z_{\text{body}}))\\
&={\rm \bf{W}}_{\texttt{doc}}\texttt{BERT}_{\texttt{CLS}}(\texttt{[CLS]}z_{\text{title}}\texttt{[SEP]}z_\text{body}\texttt{[SEP]})\\
\end{align*}
$$
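A rough sketch of these embedding functions and the resulting retrieval distribution, assuming a stock bert-base-uncased encoder from the Hugging Face transformers library and arbitrary 128-dimensional projections (the toy corpus, projection size, and variable names are illustrative, not REALM's actual checkpoints or corpus):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
proj_input = torch.nn.Linear(768, 128, bias=False)  # W_input (128-d is an assumption)
proj_doc = torch.nn.Linear(768, 128, bias=False)    # W_doc

def embed_input(x: str) -> torch.Tensor:
    # [CLS] x [SEP] -> BERT -> take the [CLS] vector -> project
    enc = tokenizer(x, return_tensors="pt")
    cls = bert(**enc).last_hidden_state[:, 0]        # BERT_CLS(...)
    return proj_input(cls)

def embed_doc(title: str, body: str) -> torch.Tensor:
    # [CLS] z_title [SEP] z_body [SEP] (the tokenizer builds this for sentence pairs)
    enc = tokenizer(title, body, return_tensors="pt")
    cls = bert(**enc).last_hidden_state[:, 0]
    return proj_doc(cls)

# Relevance scores f(x, z) and retrieval distribution p(z | x) over a toy corpus
x = "The [MASK] is the currency of the United Kingdom."
docs = [("Pound sterling", "The pound is the official currency of the UK."),
        ("Euro", "The euro is the official currency of the eurozone.")]
q = embed_input(x)                                        # (1, 128)
d = torch.cat([embed_doc(t, b) for t, b in docs], dim=0)  # (num_docs, 128)
f = q @ d.T                                               # inner products f(x, z)
p_z_given_x = torch.softmax(f, dim=-1)
print(p_z_given_x)
```

Projecting the \(\texttt{[CLS]}\) vectors down to a fixed, lower dimension lets the document embeddings be precomputed once and searched efficiently at retrieval time.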
The knowledge-augmented encoder concatenates \(x\) and \(z\) into a single sequence and feeds it into a Transformer.
In pre-training, the encoder is trained to predict the original value of each \(\texttt{[MASK]}\) token in \(x\).
Let \(J_x\) be the total number of \(\texttt{[MASK]}\) tokens in \(x\) and \(w_j\) a learned word embedding for token \(y_j\); the encoder models \(p(y|z, x)\) as:
$$
\begin{align*}
p(y|z, x) &= \prod^{J_x}_{j=1}p(y_j|z, x)\\
p(y_j|z, x) &\propto\exp (w_j^\top \texttt{BERT}_{\texttt{MASK}(j)}(\texttt{join}_{\texttt{BERT}}(x, z_{\text{body}})))
\end{align*}
$$
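A hedged sketch of this masked-token scoring rule, again assuming a stock BERT encoder and reusing its input word embeddings as the \(w_j\) vectors (this mirrors the formula above rather than BERT's full MLM head, so the actual predictions will be rough):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
word_emb = bert.get_input_embeddings().weight  # stand-in for the w_j vectors, (vocab, 768)

x = "The [MASK] is the currency of the United Kingdom."
z_body = "The pound is the official currency of the UK."
enc = tokenizer(x, z_body, return_tensors="pt")   # join_BERT(x, z_body)
hidden = bert(**enc).last_hidden_state[0]         # (seq_len, 768)

# Locate the [MASK] positions and score each against the word embeddings:
# p(y_j | z, x) ∝ exp(w_j^T BERT_MASK(j)(join_BERT(x, z_body)))
mask_positions = (enc["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
for j in mask_positions:
    logits = hidden[j] @ word_emb.T               # one logit per vocabulary word
    probs = torch.softmax(logits, dim=-1)
    print(tokenizer.decode([probs.argmax().item()]))  # most likely original token
```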
For Open-QA fine-tuning, let \(S(z, y)\) be the set of spans in \(z\) that match the answer \(y\); then \(p(y|z, x)\) is defined as: $$ \begin{align*} p(y|z, x)&\propto \sum_{s\in S(z, y)}\exp (\texttt{MLP}([h_{\texttt{START}(s)}; h_{\texttt{END}(s)}]))\\ h_{\texttt{START}(s)}&=\texttt{BERT}_{\texttt{START}(s)}(\texttt{join}_{\texttt{BERT}}(x, z_{\text{body}}))\\ h_{\texttt{END}(s)}&=\texttt{BERT}_{\texttt{END}(s)}(\texttt{join}_{\texttt{BERT}}(x, z_{\text{body}})) \end{align*} $$
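A sketch of this span-scoring model under the same assumptions (a stock BERT encoder plus a freshly initialized MLP; finding \(S(z, y)\) is simplified to exact token-level matching of the answer string):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")
mlp = torch.nn.Sequential(torch.nn.Linear(2 * 768, 768), torch.nn.ReLU(),
                          torch.nn.Linear(768, 1))  # MLP([h_START; h_END]) -> scalar

x = "What is the currency of the United Kingdom?"
z_body = "The pound sterling is the currency of the United Kingdom."
y = "pound sterling"

enc = tokenizer(x, z_body, return_tensors="pt")     # join_BERT(x, z_body)
hidden = bert(**enc).last_hidden_state[0]           # (seq_len, 768)

# S(z, y): spans in the concatenated sequence whose tokens match the answer tokens
answer_ids = tokenizer(y, add_special_tokens=False)["input_ids"]
ids = enc["input_ids"][0].tolist()
spans = [(s, s + len(answer_ids) - 1)
         for s in range(len(ids) - len(answer_ids) + 1)
         if ids[s:s + len(answer_ids)] == answer_ids]

# p(y | z, x) ∝ sum over matching spans of exp(MLP([h_START(s); h_END(s)]))
span_scores = torch.stack([mlp(torch.cat([hidden[s], hidden[e]])) for s, e in spans])
unnormalized_p_y = torch.logsumexp(span_scores, dim=0).exp()
print(len(spans), unnormalized_p_y)
```

Because the score sums over every span that matches \(y\), a document that mentions the answer several times contributes correspondingly more probability mass.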