Retrieval Augmented Generation for Knowledge Intensive NLP Tasks (2021)
January 27, 2024問題文から推論できない知識を必要とする自然言語処理のタスクの場合、Large Language Model (LLM) の事前学習による効果は限定的になる。 LLMのパラメタに知識がなければ回答できず、また、パラメタは正確な知識を記録する仕組みではない。 Retrieval Augmented Generation (RAG) は、LLMの外部から提供された知識を利用し、入力から推論できない知識にもとづく文書を生成する。
Retrieval Augmented Generation for Knowledge Intensive NLP Tasksは、文書の埋め込みベクトルを外部の知識に使うRAGであり、入力文で埋め込みベクトルを検索し、検索結果の上位の文書と入力文から文書を生成する。 文書の検索にはDense Passage Retriever (DPR) を、後の文書の生成にはseq2seqを使う。 DPRは、質問と文書を別々のBERTで埋め込みベクトルに変換し、質問に関係する文書と質問のベクトル間の内積を最大にするように学習する。 Faissなどの近似最近傍探索を利用すれば、実用可能な速度で推論できる。
説明のために、DPRを\(p_{\eta}(z|x)\), seq2seqを\(p_{\theta}(y_i|x, z, y_{1:i-1})\)とおく。 \(x\)は入力系列、\(z\)は文書、\(\eta\)と\(\theta\)はパラメタ、\(y_i\)は出力系列の\(i\)番目のトークンを示す。
提案された2つのRAGモデルのひとつは、RAG-Token Modelといい、検索で見つかった文書ごとに、文書を条件にした次のトークンの条件付き確率を求め、文書について周辺化することで、次のトークンを予測する。 式に下すと次のようになる。 $$ p_{\text{RAG-Token}}(y|x) \approx\prod^N_i \sum_{z\in \text{top-k}(p(\cdot|x))}p_\eta (z|x)p_{\theta}(y_i|x,z,y_{1:i-1}) $$ \(\argmax_y p(y|x)\)は、ビームサーチで求められる。
もうひとつのモデルは、RAG-Sequence Modelといい、1つの文書から1つの出力系列を出力する。 式では次のように表せる。 $$ p_{\text{RAG-Sequence}}(y|x)\approx\sum_{z\in \text{top-k}(p(\cdot |x))}p_{\eta}(z|x)p_{\theta}(y|x,z) = \sum_{z\in\text{top-k}(p(\cdot|x))}p_{\eta}(z|x)\prod^N_i p_{\theta}(y_i|x,z,y_{1:i-1}) $$ \(\argmax_y p(y|x)\)を求めるには、はじめに、\(z\)ごとにビームサーチを適用し、\(z\)ごとの\(y\)の候補の集合を生成する。 \(z_i\)において生成されていない\(y\)については、\(p_\theta (y|x,z_i)\approx 0\)として周辺化する。
実験では、seq2seqの部分 \(p_{\theta}(y_i|x,z,y_{1:i-1})\) にはBART-largeが使われた。 入力系列\(x\)と埋め込みベクトル\(z\)は連結してあたえられる。
訓練時は、DPRとseq2seqを同時に学習する。 そのため、入力文に対して検索すべき文書は訓練データとしては不要であり、サンプルは入力系列と出力系列のペア\((x_j, y_j)\)からなる。 \(\sum_j-\log p(y_j|x_j)\)を最小化するようにパラメタを更新する。