FLASHATTENTION: Fast and Memory-Efficient Exact Attention with IO-Awareness (2022)
Let $N$ be the sequence length. In a standard Transformer, self-attention takes $O(N^2)$ time and $O(N^2)$ memory. Approximate methods such as Reformer reduce FLOPs by approximating attention, but they often fail to deliver significant wall-clock speedups and are therefore not widely used. FLASHATTENTION instead targets memory traffic: it reduces data movement between GPU high-bandwidth memory (HBM) and on-chip SRAM, which can be a major bottleneck. HBM is large but has relatively low bandwidth; SRAM is fast but small. FLASHATTENTION tiles the computation: it loads blocks of $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ into SRAM, updates a block of the output $\mathbf{O}$, and then writes that block back to HBM. Because softmax normalizes over an entire row of scores, the algorithm maintains running row maxima and normalizers (the online softmax trick) and rescales partial outputs as each new block is processed. By repeating this over all blocks, FLASHATTENTION computes exact attention without ever materializing the full $N \times N$ score matrix.
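The tiling idea can be illustrated with a minimal NumPy sketch. This is a simplification of the actual algorithm (which also tiles the queries and runs on-chip): here all queries are kept resident and only the $\mathbf{K}$/$\mathbf{V}$ blocks are streamed, with running maxima and normalizers making the result exact.

```python
import numpy as np

def reference_attention(Q, K, V):
    """Standard attention that materializes the full N x N matrices S and P."""
    S = (Q @ K.T) / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def tiled_attention(Q, K, V, block=4):
    """Exact attention over K/V blocks using the online softmax trick."""
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)   # running row maximum of the scores
    l = np.zeros(N)           # running softmax normalizer
    for j in range(0, N, block):
        Kj, Vj = K[j:j + block], V[j:j + block]   # "load" one K/V block
        S = (Q @ Kj.T) * scale                    # scores for this block only
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)                 # rescale previous partial sums
        P = np.exp(S - m_new[:, None])
        l = alpha * l + P.sum(axis=1)
        O = alpha[:, None] * O + P @ Vj
        m = m_new
    return O / l[:, None]                          # final normalization
```

The partial output `O` and statistics `m`, `l` are all $O(N \cdot d)$ or $O(N)$ in size; the $O(N^2)$ matrices exist only one block at a time.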
FLASHATTENTION also reduces HBM–SRAM traffic in the backward pass. In a typical implementation, one materializes the attention score matrix $\mathbf{S}$ and the softmax output $\mathbf{P}$ in the forward pass so they can be reused to compute the gradients with respect to $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$ from the gradient of $\mathbf{O}$. When $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}\in \mathbb{R}^{N\times d}$, both $\mathbf{S}$ and $\mathbf{P}$ are $O(N^2)$ in size. FLASHATTENTION avoids storing these $N\times N$ matrices: during backpropagation it recomputes $\mathbf{S}$ and $\mathbf{P}$ block by block from $\mathbf{Q}$, $\mathbf{K}$, and $\mathbf{V}$, using $\mathbf{O}$ and the softmax normalization statistics saved from the forward pass.
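The recomputation idea can be sketched for one of the gradients. Assuming the per-row logsumexp statistics $L_i = m_i + \log \ell_i$ were saved from the forward pass, the softmax block $\mathbf{P}$ is cheaply rebuilt and $d\mathbf{V} = \mathbf{P}^\top \, d\mathbf{O}$ accumulated block by block; $d\mathbf{Q}$ and $d\mathbf{K}$ are handled analogously in the full algorithm. The function name is illustrative, not from the paper.

```python
import numpy as np

def backward_dV(Q, K, dO, L, block=4):
    """Accumulate dV without storing S or P.

    L[i] is the logsumexp of row i of the score matrix, saved in the
    forward pass; with it, P can be recomputed per block on the fly.
    """
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    dV = np.zeros((N, d))
    for j in range(0, N, block):
        S = (Q @ K[j:j + block].T) * scale   # recompute one score block
        P = np.exp(S - L[:, None])           # recompute one softmax block
        dV[j:j + block] = P.T @ dO           # gradient for this V block
    return dV
```

The extra FLOPs from recomputation are outweighed in practice by the HBM reads and writes saved by never touching an $N \times N$ matrix.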