REALM Retrieval Augmented Language Model Pre Training (2020)
May 27, 2024REALMは、オープンドメイン質問応答向けの言語モデルであり、入力文に関連する文書をみつけるknowledge retrieverと見つけた文書と入力文から応答を生成するknowledge-augmented encoderで構成されている。 事前学習はmasked language modelingによる教師なし学習であり、一部のトークンがマスクされた入力文をあたえたときに、もとのマスクされた単語を出力できるように訓練する。 オープンドメイン質問応答のファインチューニングでは、質問文を入力したときに、回答を出力できるようにモデルを訓練する。 評価実験では、期待する回答がknowledge retrieverの学習に使うコーパスに文字通りにあることが前提にされている。
Knowedge retrieverが参照するコーパスを\(\mathcal{x}\)とすると、REALMの学習する確率分布\(p(y|x)\)を、\(x\)に関連する文書\(z\)を収集する部分と、\(x\)と\(z\)から\(y\)を予測する部分に分けられる。 \(p(y|x)\)を分解すると次の式になる。 $$ p(y|x) = \sum_{z\in \mathcal{Z}}p(y|z, x)p(z|x) $$
Knowledge retriverは、入力\(x\)と文書の関連度\(f(x, z)\)をソフトマックス関数に入力したときの値を\(p(z|x)\)とみなす。
$$
p(z|x)=\frac{\exp f(x, z)}{\sum_{z’}\exp f(x, z’)}
$$
\(f(x, z)\)は\(x\)と\(z\)の次元\(d\)のエンベディング\(\texttt{Embed}_{\texttt{input}}(x), \texttt{Embed}_{\texttt{doc}}(z)\)の内積であり、BERTの[CLS]トークンのベクトルを入力や文書のエンベディングとしてあつかう。
$$
\begin{align*}
f(x, z) &= \texttt{Embed}_{\texttt{input}}(x)^\top \texttt{Embed}_{\texttt{doc}}(z)\\
\texttt{Embed}_{\texttt{input}}(x)&={\rm \bf{W}}_{\texttt{input}}\texttt{BERT}_{\texttt{CLS}}(\texttt{[CLS]}x\texttt{[SEP]} )\\
\texttt{Embed}_{\texttt{doc}}(z)
&={\rm \bf{W}}_{\texttt{doc}}\texttt{BERT}_{\texttt{CLS}}(\texttt{join}_{\texttt{BERT}}(z_\text{title}, z_{\text{body}}))\\
&={\rm \bf{W}}_{\texttt{doc}}\texttt{BERT}_{\texttt{CLS}}(\texttt{[CLS]}z_{\text{title}}\texttt{[SEP]}z_\text{body}\texttt{[SEP]})\\
\end{align*}
$$
Augmented encoderは、\(p(y|z, x)\)にあたり、\(z\)と\(x\)を連結した系列をTransformerに入力する。 事前学習では、\(x\)と文書の本文\(\text{z}_{\text{body}}\)から\(J_x\)個の\(\texttt{[MASK]}\)を含む入力\(x\)の系列のマスクの中身を推定できるように訓練する。 $$ \begin{align*} p(y|z, x) &= \prod^{J_x}_{j=1}p(y_j|z, x)\\ p(y_j|z, x) &\propto\exp (w_j^\top \texttt{BERT}_{\texttt{MASK}(j)}(\texttt{join}_{\texttt{BERT}}(x, z_{\text{body}}))) \end{align*} $$ ある文書\(z\)の連続するトークン列が正解\(y\)になることを前提にしたオープン質問応答のファインチューニングでは、\(S(z, y)\)を\(y\)にマッチする区間の集合、\(\texttt{BERT}_{\texttt{START}(s)}, \texttt{BERT}_{\texttt{END}(s)}\)を区間\(s\)の最初と最後の区間として、\(p(y|z, x)\)を以下のように定める。 $$ \begin{align*} p(y|z, x)&\propto \sum_{s\in S(z, y)}\exp (\texttt{MLP}([h_{\texttt{START}(s)}; h_{\texttt{END}(s)}]))\\ h_{\texttt{START}(s)}&=\texttt{BERT}_{\texttt{START}(s)}(\texttt{join}_{\texttt{BERT}}(x, z_{\text{body}}))\\ h_{\texttt{END}(s)}&=\texttt{BERT}_{\texttt{END}(s)}(\texttt{join}_{\texttt{BERT}}(x, z_{\text{body}})) \end{align*} $$