REALM Retrieval Augmented Language Model Pre Training (2020)
May 27, 2024REALMは、オープンドメイン質問応答向けの言語モデルであり、入力文に関連する文書をみつけるknowledge retrieverと見つけた文書と入力文から応答を生成するknowledge-augmented encoderで構成されている。 事前学習はmasked language modelingによる教師なし学習であり、一部のトークンがマスクされた入力文をあたえたときに、もとのマスクされた単語を出力できるように訓練する。 オープンドメイン質問応答のファインチューニングでは、質問文を入力したときに、回答を出力できるようにモデルを訓練する。 評価実験では、期待する回答がknowledge retrieverの学習に使うコーパスに文字通りにあることが前提にされている。
Knowedge retrieverが参照するコーパスを\(\mathcal{x}\)とすると、REALMの学習する確率分布\(p(y|x)\)を、\(x\)に関連する文書\(z\)を収集する部分と、\(x\)と\(z\)から\(y\)を予測する部分に分けられる。 \(p(y|x)\)を分解すると次の式になる。
$$ p(y|x) = \sum\_{z\in \mathcal{Z}}p(y|z, x)p(z|x) $$Knowledge retriverは、入力\(x\)と文書の関連度\(f(x, z)\)をソフトマックス関数に入力したときの値を\(p(z|x)\)とみなす。
$$ p(z|x)=\frac{\exp f(x, z)}{\sum\_{z'}\exp f(x, z')} $$
\(f(x, z)\)は\(x\)と\(z\)の次元\(d\)のエンベディング\(\texttt{Embed}_{\texttt{input}}(x), \texttt{Embed}_{\texttt{doc}}(z)\)の内積であり、BERTの[CLS]トークンのベクトルを入力や文書のエンベディングとしてあつかう。
Augmented encoderは、\(p(y|z, x)\)にあたり、\(z\)と\(x\)を連結した系列をTransformerに入力する。 事前学習では、\(x\)と文書の本文\(\text{z}_{\text{body}}\)から\(J_x\)個の\(\texttt{[MASK]}\)を含む入力\(x\)の系列のマスクの中身を推定できるように訓練する。
$$ \begin{align*} p(y|z, x) &= \prod^{J\_x}\_{j=1}p(y\_j|z, x)\\\\ p(y\_j|z, x) &\propto\exp (w\_j^\top \texttt{BERT}\_{\texttt{MASK}(j)}(\texttt{join}\_{\texttt{BERT}}(x, z\_{\text{body}}))) \end{align*} $$ある文書\(z\)の連続するトークン列が正解\(y\)になることを前提にしたオープン質問応答のファインチューニングでは、\(S(z, y)\)を\(y\)にマッチする区間の集合、\(\texttt{BERT}_{\texttt{START}(s)}, \texttt{BERT}_{\texttt{END}(s)}\)を区間\(s\)の最初と最後の区間として、\(p(y|z, x)\)を以下のように定める。
$$ \begin{align*} p(y|z, x)&\propto \sum\_{s\in S(z, y)}\exp (\texttt{MLP}([h\_{\texttt{START}(s)}; h\_{\texttt{END}(s)}]))\\\\ h\_{\texttt{START}(s)}&=\texttt{BERT}\_{\texttt{START}(s)}(\texttt{join}\_{\texttt{BERT}}(x, z\_{\text{body}}))\\\\ h\_{\texttt{END}(s)}&=\texttt{BERT}\_{\texttt{END}(s)}(\texttt{join}\_{\texttt{BERT}}(x, z\_{\text{body}})) \end{align*} $$