FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022)
TransformerのQKV注意機構の時間、空間計算量は、系列長が\(N\)のときには\(O(N^2)\)になる。 Reformer: The Efficient TransformerなどのFLOP数を減らす手法はあるが、実際の処理時間を十分に短縮できず、普及していない。 FLASHATTENTIONは、FLOP数ではなく、GPUのHBM (high bandwidth memory) と SRAM間の転送量を減らし、処理時間を短縮する。 HBMの記憶領域はSRAMよりも数十倍大きいが、SRAMよりも数十倍遅い。 注意 $\mathbf{O}$ を計算するとき、FLASHATTENTIONは、$\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$, $\mathbf{O}$それぞれの一部の行ベクトルからなるブロックだけをSRAMに読み、$\mathbf{O}$の一部を更新し、計算した一部をHBMに書き込む。 更新を繰り返すと最終的に自己注意機構と等しい注意を計算できる。
FLASHATTENTIONは逆伝搬時のHBMとSRAM間の転送量も減らす。 通常のQKV注意機構であれば、$\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$, $\mathbf{O}$の勾配を計算するためにSoftmax関数に与える行列$\mathbf{S}$とその結果$\mathbf{P}$を記録する。 $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$が\(\mathbb{R}^{N\times d}\)であれば$\mathbf{S}$と$\mathbf{P}$のサイズは$\textrm{O}(N^2)$になる。 FLASHATTENTIONは、順伝搬時に計算した中間結果と$\mathbf{O}$から$\mathbf{S}$と$\mathbf{P}$を計算し直すことで$\mathbf{S}$と$\mathbf{P}$の記録を省き、空間計算量を減らす。