FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022)
December 29, 2025TransformerのQKV注意機構の時間、空間計算量は、系列長が\(N\)のときには\(O(N^2)\)になる。 Reformer: The Efficient TransformerなどのFLOP数を減らす手法はあるが、実際の処理時間を十分に短縮できず、普及していない。 FLASHATTENTIONは、FLOP数ではなく、GPUのHBM (high bandwidth memory) と SRAM間の転送量を減らし、処理時間を短縮する。 HBMの記憶領域はSRAMよりも数十倍大きいが、SRAMよりも数十倍遅い。 注意 $\mathbf{O}$ を計算するとき、FLASHATTENTIONは、$\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$, $\mathbf{O}$それぞれの一部の行ベクトルからなるブロックだけをSRAMに読み、$\mathbf{O}$の一部を更新し、計算した一部をHBMに書き込む。 更新を繰り返すと最終的に自己注意機構と等しい注意を計算できる。
FLASHATTENTIONは逆伝搬時のHBMとSRAM間の転送量も減らす。 通常のQKV注意機構であれば、$\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$, $\mathbf{O}$の勾配を計算するためにSoftmax関数に与える行列$\mathbf{S}$とその結果$\mathbf{P}$を記録する。 $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$が\(\mathbb{R}^{N\times d}\)であれば$\mathbf{S}$と$\mathbf{P}$のサイズは$\textrm{O}(N^2)$になる。 FLASHATTENTIONは、順伝搬時に計算した中間結果と$\mathbf{O}$から$\mathbf{S}$と$\mathbf{P}$を計算し直すことで$\mathbf{S}$と$\mathbf{P}$の記録を省き、空間計算量を減らす。
FLASHATTENTIONでは、注意$\mathbf{O}$の計算にベクトルの最大値でスケールしたSoftmax関数が採用されている。 softmax関数は\(\mathbf{S}\)の行ベクトル\(x\in \mathbb{R}^d\)ごとに適用される。
$$ \begin{align*} \mathbf{S} &= \mathbf{Q}\mathbf{K}^\top \in \mathbb{R}^{N\times N}\\\\ \mathbf{P} &= \textrm{softmax}(\mathbf{S})\in \mathbb{R}^{N\times N}\\\\ \mathbf{O} &=\mathbf{P}\mathbf{V}\in\mathbb{R}^{N\times d}\\\\ m(x) &:= \max_{{i}} x_{i}\\\\ f(x) &:= [e^{x_1-m(x)} \dots e^{x_d-m(x)}]\\\\ \ell(x) &:= \sum_i {f(x)}_i\\\\ \text{softmax}(x) &:= \frac{f(x)}{\ell(x)} \end{align*} $$softmax関数の計算を安定させるために最大値$m(x)$でスケールされている。 Attention Is All You Needでは、$d$が大きくなればドット積の値も大きくなり勾配が消失するので、ベクトルの要素を$\sqrt{d}$でスケールしてからSoftmax関数が適用されている。
2つのベクトルとその統計量\(m\), \(\ell\)から、連結したベクトルにsoftmax関数を適用した結果を計算できる。
$$ \begin{align*} m(x) &= m([x^{(1)}x^{(2)}])=\max(m(x^{(1)}), m(x^{(2)})) \\\\ f(x) &= \left[ e^{m(x^{(1)})-m(x)}f(x^{(1)})\ \ \ \ e^{m(x^{(2)})-m(x)}f(x^{(2)}) \right]\\\\ \ell(x) &= \ell(\left[x^{(1)}\ x^{(2)}\right])=e^{m(x^{(1)})-m(x)}\ell(x^{(1)})+e^{m(x^{(2)})-m(x)}\ell (x^{(2)}) \\\\ \text{softmax}(x)&=\frac{f(x)}{\ell(x)} \end{align*} $$FLASHATTENTIONは、統計量から連結ベクトルのSoftmax関数の適用結果を求めあれる性質をもちいて、SRAMとHBM間の転送量を減らす。
具体的なアルゴリズムは以下の文献から引用したアルゴリズムになる。

アルゴリズムは$j$について帰納的に証明できる。 $\mathbf{K}, \mathbf{V}$の最初の$jB_c$行のブロックを$\mathbf{K}_{:j}\in\mathbb{R}^{jB_c\times d}$, $\mathbf{V}_{:j}\in\mathbb{R}^{jB_c\times d}$とおく。 また、$\mathbf{S}_{:,:j}=\mathbf{Q}\mathbf{K}^{\top}_{:j}\in\mathbb{R}^{N\times jB_c}$, $\mathbf{P}_{::j}=\text{softmax}(\mathbf{S}_{::j})\in\mathbb{R}^{N\times jB_c}$とする。 これらの変数をもとに、$j$番目の外側の繰り返しの後の$m, \ell, \mathbf{O}$を、$m^j, \ell^{(j)}, \mathbf{O}^{(j)}$として$j$番目の外側の繰り返しの後に、以下が成り立つことを帰納的に示す。
$$ \begin{align*} m^{(j)}&=\textrm{rowmax}(\mathbf{S}_{:,:j})\in\mathbb{R}^N\\\\ \ell^{(j)}&=\textrm{rowsum}(\exp(\mathbf{S}_{:,:j}-m^{(j)}))\in\mathbb{R}^N\\\\ \mathbf{O}^{(j)}&=\mathbf{P}_{:,:j}\mathbf{V}_{:j}\in\mathbb{R}^{N\times d} \end{align*} $$アルゴリズムの2行目で変数を初期したとき、上の式は成立する。 以降、$j=1,\dots,T_c-1$のときに上の3つの式が成り立つと仮定し、$j+1$でも成立することを示す。
アルゴリズムは$\mathbf{S}$をブロック単位で計算する。 $\mathbf{S}$の$B_C(j-1)+1$から$B_Cj$列のブロックを$\mathbf{S}_{:,j-1:j}\in \mathbb{R}^{N\times B_C}$とすると、$j$回目の外側の繰り返しは$\mathbf{S}_{:,j-1:j}$を計算する。 内側の繰り返しは$\mathbf{S}_{:,j-1:j}$を上から順番に$\mathbb{R}^{B_r\times B_c}$ずつ計算する。 $j$回目の繰り返しの終わりには、$\mathbf{S}_{:,j-1:j}$の各行の最大値$\textrm{rowmax}(\mathbf{S}_{:,j-1:j})$を計算できている。
以上の計算手順と11行目の$m^{\text{new}}_i$の更新方法より
$$ \begin{align*} m^{(j+1)}&=\text{max}(\textrm{rowmax}(\mathbf{S}_{:,:j}), \textrm{rowmax}(\mathbf{S}_{:,j:j+1}))\\\\ m^{(j+1)}&=\text{rowmax}(\mathbf{S}_{:,:j+1}) \end{align*} $$となり$j+1$でも成立する。
また、11行目において、仮定より$e^{m_i}\ell_i$は$\text{rowsum}(\exp(\mathbf{S}_{i,:j-1}))$, $e^{\tilde{m}_{ij}}\tilde{\ell}_{ij}$は$\text{rowsum}(\exp(\mathbf{S}_{ij}))$になるので、
$$ \ell^{(j+1)}=\text{rowsum}(\exp(\mathbf{S}_{:,:j+1}-m^{(j+1)}))\in\mathbb{R}^N $$も成り立つ。
$\mathbf{V}_{:j}$を$\mathbf{V}$の1列から$B_Cj$列からなるブロックとして、アルゴリズムの12行目の計算を外側の繰り返しにおける1回の計算に直し、$\tilde{m}$を$\textrm{rowmax}(\mathbf{S}_{:,j:j+1})$とすると
$$ \begin{align*} \mathbf{O}^{(j+1)} &=\textrm{diag}(\ell^{(j+1)})^{-1}(\textrm{diag}(\ell^{(j)})e^{m^{(j)}-m^{(j+1)}}\mathbf{O}^{(j)} + e^{\tilde{m}-m^{(j+1)}}\exp(\mathbf{S}_{:,j:j+1}-\tilde{m})\mathbf{V}_{j+1})\\\\ &=\textrm{diag}(\ell^{(j+1)})^{-1}(\textrm{diag}(\ell^{(j)})e^{m^{(j)}-m^{(j+1)}}\mathbf{P}_{:,:j}\mathbf{V}_{:j} + e^{-m^{(j+1)}}\exp(\mathbf{S}_{:,j:j+1})\mathbf{V}_{j+1})\\\\ &=\textrm{diag}(\ell^{(j+1)})^{-1}(\textrm{diag}(\ell^{(j)})e^{m^{(j)}-m^{(j+1)}}\textrm{diag}(\ell^{(j)})\exp (\mathbf{S}_{:,:j}-m^{(j)})\mathbf{V}_{:j}+ e^{-m^{(j+1)}}\exp(\mathbf{S}_{:,j:j+1})\mathbf{V}_{j+1})\\\\ &=\textrm{diag}(\ell^{(j+1)})^{-1}(e^{-m^{(j+1)}}\exp (\mathbf{S}_{:,:j})\mathbf{V}_{:j}+e^{-m^{(j+1)}}\exp(\mathbf{S}_{:,:j:j+1})\mathbf{V}_{j+1})\\\\ &=\textrm{diag}(\ell^{(j+1)})^{-1}(\exp(\mathbf{S}_{:,:j}-m^{(j+1)})\mathbf{V}_{:j}+\exp (\mathbf{S}_{:,j:j+1}-m^{(j+1)})\mathbf{V}_{j+1})\\\\ &=\textrm{softmax}(\mathbf{S}_{:,:j+1})\mathbf{V}_{j+1} \end{align*} $$となり、$j+1$でも成りたつ。 以上からFLASHATTENTIONで自己注意機構と等しい注意を計算できる。