OUTRAGEOUSLY LARGE NEURAL NETWORKS THE SPARSELY-GATED MIXTURE-OF-EXPERTS LAYER (2017)
August 1, 2025パラメタ数を増やせば多くの情報をモデルに学習させられるが、計算量も増える。 OUTRAGEOUSLY LARGE NEURAL NETWORKS: THE SPARSELY-GATED MIXTURE-OF-EXPERTS LAYER (MoE) は、ゲートと数千規模の全結合層からなる層であり、ゲートの後に全結合層を並列に配置する。 ゲートは、サンプルごとに疎なベクトルを出力する。 各サンプルの推論において、ベクトルの0でない要素に対応する全結合層だけを計算対象に限定し、パラメタ数の増加と計算量の抑制を両立する。
サンプルを\(x\), ゲートを\(G\), \(n\)個の\(i\)番目の全結合層を\(E_i\)をおくと、MoE層の\(x\)に対する出力は、ゲートと全結合層の出力の内積になる。 $$ y=\sum^n_{i=1}G(x)_iE_i(x) $$
\(G\)の出力の次元は\(n\)で、ハイパーパラメタ\(k\)の要素が1未満の正の値に、それ以外の要素は\(0\)になる。 \(G(x)\)の出力が分かれば、\(y\)の計算に必要な全結合層は\(k\)個のみになる。 \(G\)の重みは2つある。 ひとつは全結合層を選ぶための重み\(W_g\)であり、残りはノイズを調整する重み\(W_{\text{noise}}\)である。 $$ \begin{align*} G(x)&=\text{Softmax}(\text{KeepTopK}(H(x), k))\\ H(x)_i&=(x\cdot W_g)_i + \text{StandardNormal()}\cdot \text{Softplus}((x\cdot W_{\text{noise}})_i)\\ \text{KeepTopK}(v,k)_i&= \begin{cases} v_i&\text{if }v_i \text{ is in the top }k\text{ elements of }v\\ -\infty&\text{otherwise} \end{cases} \end{align*} $$
選ばれる全結合層が偏ると計算を分散できなくなる。そこで、偏りに対する2つの損失関数\(L_{\text{importance}}(X)\)と\(L_{\text{load}}(X)\)を導入する。以下の式の\(\textit{CV}\)を変動係数とすると、\(L_{\text{importance}}(X)\)は $$ \begin{align*} \text{Importance}(X)&=\sum_{x\in X}G(x)\\ L_{\text{importance}}(X)&=w_{\text{importance}}\cdot \textit{CV}(\text{Importance}(X))^2 \end{align*} $$ と定義される。
ノイズ\(\text{StandardNormal()}\)の役割は、選ばれる全結合層に偏りをなくすことにある。 全結合層に渡されるサンプル数は、離散値であり、誤差逆伝播法を使えない。 そこで、全結合層に割り当てられるサンプル数の不均衡さを示す滑らかな関数\(L_{\text{load}}(X)\)を定義し、損失の一部として計算する。
まず、\(\text{kth\_excluding}(v, k, i)\)を\(i\)を除いて\(k\)番目に値の大きい\(v\), 正規分布の累積分布関数を\(\Phi\), \(G(x)_i\)が\(0\)でない確率を\(P(x, i)\)とする。 $$ \begin{align*} P(x, i) &=\textit{Pr}\left((x\cdot W_g)_i + \text{StandardNormal}()\cdot \text{Softplus}((x\cdot W_{\text{noise}})_i) > \text{kth\_excluding}(H(x), k, i)\right)\\ &=\Phi\left(\frac{(x\cdot W_g)_i - \text{kth\_excluding}(H(x), k, i)}{\text{Softplus}((x\cdot W_{\text{noise}})_i)}\right)\\ \end{align*} $$ このとき、 \(w_{\text{load}}\)をハイパーパラメタとすると、\(L_{\text{load}}(X)\)は $$ \begin{align*} \text{Load}(X)_i&= \sum_{x\in X} P(x,i)\\ L_{\text{load}}(X)&=w_{\text{load}}\cdot\text{CV}(\text{Load}(X))^2 \end{align*} $$ になる。