Why FlashAttention was a breakthrough
Same math, same exact outputs, same asymptotic compute — and yet it made attention several times faster and unlocked long context. The trick was noticing attention was a memory problem, not a compute problem.
Why it exists
If you read the FlashAttention paper expecting a clever new approximation to attention, you’ll be confused. There isn’t one. The output is the same softmax attention the original transformer paper described in 2017 — exact, up to floating-point reordering. The asymptotic FLOP count is the same (the backward pass actually does more arithmetic, because it recomputes intermediates instead of storing them — more on that below). The model isn’t changed. And yet it ran several times faster on real GPUs and made it practical to train and serve transformers at sequence lengths that previously OOM’d.
That mismatch — same math, dramatically different wall-clock — is the entire point. Before FlashAttention, the prevailing intuition was that attention was compute-bound: it does an N×N matmul, and matmuls are what GPUs are for, so what’s left to optimize? The answer turned out to be everything. The bottleneck was never the multiplications. It was the N×N matrix being shuffled in and out of HBM three or four times per attention layer, with the GPU’s actual compute units sitting idle waiting for memory.
The breakthrough was noticing this and writing a kernel that respected the memory hierarchy. The technique — tiling plus an online softmax trick — is not novel in computer science; it’s the standard playbook for memory-bound numerical kernels. The novelty was applying that playbook to attention, demonstrating that attention had been a memory-bandwidth problem all along, and shipping a kernel everyone could use. The 2022 paper by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré (arXiv:2205.14135) made that case, and it was unusually consequential — the FA-2/FA-3 lineage continues today, and the underlying kernel (or a derivative) is widely available through major training and inference stacks. (My read on adoption breadth, not a sourced market study.)
Why it matters now
The reason FlashAttention is still a load-bearing piece of infrastructure four years later, instead of a paper everyone moved on from, is that the bandwidth gap it exploits has only gotten wider:
- Long context is normal. Frontier models advertise 200k–1M token windows. Materializing an N×N scores matrix in HBM at N=1,000,000 is impossible — it’s a trillion entries per layer. FlashAttention is one of the reasons the activation memory of attention is linear in N in current stacks; without something equivalent, very long context wouldn’t fit on the device at all.
- Attention is more of the work than it used to be. As sequences get longer, the quadratic attention term overtakes the linear-in-N feed-forward term. Optimizing attention used to be a small win on short sequences; now it’s most of the budget on long ones.
- Hardware kept getting more memory-skewed. Each new GPU generation adds compute throughput faster than it adds memory bandwidth. Algorithms that ignore the memory hierarchy fall further behind every generation, not closer. (See why VRAM is the bottleneck.)
- It set the template. The “treat it as an IO problem” reframe turned into a generation of follow-up work — FlashAttention-2 (2023), FlashAttention-3 (2024) targeting Hopper-class H100s, and a wave of fused kernels for everything else in the transformer stack.
The short answer
FlashAttention = exact softmax attention + tiling + online softmax + kernel fusion
It’s the same attention as before — same inputs, same exact outputs, same asymptotic FLOP count — restructured so the N×N intermediate matrix never gets written to slow GPU memory. The work is done in small blocks that fit in fast on-chip memory and are processed in a single fused kernel. The math is unchanged; only the data movement is. (The backward pass does add some recomputation FLOPs as the price of not storing the attention matrix.)
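For concreteness, here is the baseline being reproduced: a minimal single-head NumPy sketch of exact softmax attention, with no mask or batching. The function name and shapes are illustrative, not from the paper; the point is that this is the output FlashAttention matches exactly, while avoiding the N×N intermediate this version materializes.

import numpy as np

def reference_attention(Q, K, V):
    # Q, K, V: (N, d). Vanilla exact softmax attention.
    scores = Q @ K.T / np.sqrt(Q.shape[-1])        # N x N -- the matrix FlashAttention never writes to HBM
    scores -= scores.max(axis=-1, keepdims=True)   # standard numerical-stability shift
    P = np.exp(scores)
    P /= P.sum(axis=-1, keepdims=True)             # row-wise softmax
    return P @ V                                   # N x d output, same size as the input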
How it works
Three ingredients, each well understood on its own, plus the hardware observation that motivates them. Their combination is what hadn't been done for attention.
1. The memory hierarchy nobody was respecting
A GPU has two memory tiers that matter here. SRAM is on-chip, sits next to the compute units, and is fast — but tiny, on the order of tens of MB total across all streaming multiprocessors. HBM is off-chip DRAM, much slower per byte but much bigger (tens of GB).
The original FlashAttention paper cites an A100 example where SRAM bandwidth is roughly 19 TB/s versus 1.5 TB/s for HBM — about 13× the throughput in roughly 1/2000th the capacity. Standard attention writes the full N×N scores matrix out to HBM, reads it back to apply softmax, writes the result back, and reads it again to multiply by V. Those round-trips are the bottleneck. The matmul units finish early and wait.
This is the core empirical claim of the paper: profiling shows attention is memory-bound, not compute-bound. Once you see that, the fix is structural: don't materialize the N×N matrix in HBM at all.
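A back-of-envelope roofline makes the imbalance concrete. The figures below are illustrative A100-class numbers (the paper's ~1.5 TB/s HBM bandwidth, a ~312 TFLOP/s fp16 tensor-core peak), not a benchmark:

# Rough roofline for one attention head, illustrative A100-class numbers.
N, d = 4096, 64
flops_qkT = 2 * N * N * d                 # FLOPs for the Q @ K^T matmul
bytes_nxn = N * N * 2                     # one pass over the fp16 N x N scores matrix

peak_flops = 312e12                       # ~A100 fp16 tensor-core peak
hbm_bw = 1.5e12                           # ~A100 HBM bandwidth (the paper's figure)

t_compute = flops_qkT / peak_flops        # ~7 microseconds of matmul work
t_one_pass = bytes_nxn / hbm_bw           # ~22 microseconds per HBM round-trip
# Standard attention makes several such passes (write scores, read for softmax,
# write probs, read for P @ V), so memory traffic dominates by well over 10x.
print(f"compute: {t_compute*1e6:.0f} us, one HBM pass over N x N: {t_one_pass*1e6:.0f} us")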
2. Tiling: do the work in blocks that fit in SRAM
Tiling is the textbook technique for memory-bound numerical work — it’s how good BLAS matmul kernels have always worked. FlashAttention applies it to attention.
Conceptually:
for each block of K, V (loaded into SRAM):
    for each block of Q (loaded into SRAM):
        compute partial scores = Q_block · K_blockᵀ       # in SRAM
        compute partial softmax + partial output          # in SRAM
        update running output and softmax statistics      # tiny
The N×N scores matrix is never assembled in HBM. Each block is computed, used, and discarded inside SRAM. The kernel writes back the final output (shape N×d, same size as the input) and a small O(N) per-row softmax statistic that the backward pass needs (the running max m and the running sum-of-exps l, sometimes combined and stored as m + log l); intermediate scores and probabilities are gone. Total HBM traffic drops from O(N²) to roughly O(N²·d²/M), where M is the SRAM size — a big concrete win because M is large enough that the d²/M factor is small.
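To get a feel for what "never assembled in HBM" saves, here is the same accounting with illustrative sizes (N = 8192, d = 64, fp16); the numbers are mine, not the paper's:

# Back-of-envelope HBM footprint for one attention head, fp16, illustrative sizes.
N, d, bytes_per = 8192, 64, 2

scores_matrix = N * N * bytes_per     # what standard attention round-trips through HBM
output = N * d * bytes_per            # what FlashAttention actually writes back
stats = N * 4                         # O(N) per-row softmax statistics

print(f"N x N scores: {scores_matrix / 2**20:.0f} MiB per head")   # ~128 MiB
print(f"N x d output: {output / 2**20:.1f} MiB per head")          # ~1 MiB
print(f"O(N) stats:   {stats / 2**10:.0f} KiB per head")           # ~32 KiB

The matrix standard attention shuttles back and forth is about two orders of magnitude larger than everything FlashAttention writes back.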
3. Online softmax: the part that took some thought
Tiling is only “trivial textbook stuff” if your operation is associative across blocks. Matmul is. Softmax isn’t, naively — softmax(x_i) = exp(x_i) / Σ exp(x_j) requires knowing the whole row to compute any entry, because of the denominator. That global dependency is what stops you from just tiling attention the obvious way.
The fix, often called the online softmax trick, predates FlashAttention — Milakov and Gimelshein described an online normalizer in 2018 (arXiv:1805.02867) — but FlashAttention is what plumbed it through to a full attention kernel. The idea: as you stream through blocks of a row, keep two running scalars per row — the running max m (for numerical stability) and the running sum-of-exps l. When a new block arrives with its own local max and local sum, you rescale the running quantities and the partial output, then fold in the new block. After the last block, the result equals what a one-shot softmax would have produced, up to floating-point reordering. No approximation, no drift.
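Here is a minimal NumPy sketch of the trick for a single query row, streaming over key blocks. Block size and shapes are arbitrary toy values; the names m, l, and acc mirror the running max, running sum-of-exps, and unnormalized output described above.

import numpy as np

rng = np.random.default_rng(0)
N, d, B = 256, 16, 64           # row length, value dim, block size (toy numbers)
scores = rng.normal(size=N)     # one row of Q @ K^T
V = rng.normal(size=(N, d))

# One-shot reference: softmax over the whole row, then weight V.
p = np.exp(scores - scores.max())
p /= p.sum()
ref = p @ V

# Online version: stream over blocks, keeping a running max m, running
# sum-of-exps l, and a running unnormalized output acc.
m, l, acc = -np.inf, 0.0, np.zeros(d)
for start in range(0, N, B):
    s = scores[start:start + B]
    v = V[start:start + B]
    m_new = max(m, s.max())
    scale = np.exp(m - m_new)           # rescale old stats to the new max
    p_blk = np.exp(s - m_new)           # this block's unnormalized probabilities
    l = l * scale + p_blk.sum()
    acc = acc * scale + p_blk @ v
    m = m_new
out = acc / l                           # normalize once at the end

print(np.allclose(out, ref))            # True: exact up to floating-point reordering

The rescale factor exp(m - m_new) is the whole trick: it retroactively corrects every exponential already folded into l and acc whenever a later block raises the row max.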
This is the part of FlashAttention that took genuine engineering taste to land — getting the rescaling exactly right while staying numerically stable, doing it inside a single CUDA kernel, and making the backward pass also work. The forward pass keeps only O(N) softmax statistics; the backward pass uses recomputation (re-deriving the attention matrix on the fly during backprop instead of storing it) to keep memory linear too.
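FlashAttention implements that recomputation inside its own backward kernel, but the trade is the same one PyTorch exposes generically as gradient checkpointing. A rough analogy, using a naive attention function wrapped in torch.utils.checkpoint (sizes are arbitrary; this is not the FlashAttention kernel):

import torch
from torch.utils.checkpoint import checkpoint

def naive_attention(q, k, v):
    # The N x N scores/probs tensors created here would normally be saved for backward.
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return torch.softmax(scores, dim=-1) @ v

q = torch.randn(1024, 64, requires_grad=True)
k = torch.randn(1024, 64, requires_grad=True)
v = torch.randn(1024, 64, requires_grad=True)

# Checkpointing drops the intermediates after the forward pass and recomputes
# them during backward: extra FLOPs in exchange for linear activation memory,
# the same trade FlashAttention's backward makes inside its kernel.
out = checkpoint(naive_attention, q, k, v, use_reentrant=False)
out.sum().backward()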
4. Kernel fusion: don’t leave SRAM until you’re done
The final ingredient is fusing all of attention — score matmul, softmax, output matmul — into one CUDA kernel. Standard PyTorch attention does each step as a separate kernel launch, which means each step has to read its inputs from HBM and write its outputs back. FlashAttention does the whole thing inside one kernel, so blocks live in SRAM for the duration. This is the same fusion idea behind a lot of hand-tuned GPU work; FlashAttention’s contribution wasn’t inventing fusion, it was finding the version that fused all of attention without sacrificing exactness.
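In practice you rarely call the kernel directly; you reach it through a fused entry point. The sketch below contrasts the unfused three-step version with torch.nn.functional.scaled_dot_product_attention, which on recent PyTorch (2.x), a supported GPU, and fp16/bf16 inputs can dispatch to a FlashAttention-style fused backend. Whether it actually does depends on your hardware and version; on CPU it falls back to the unfused math path, so treat this as a sketch, not a guarantee.

import torch
import torch.nn.functional as F

# Illustrative sizes: (batch, heads, seq_len, head_dim).
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 1024, 64, device=device, dtype=dtype)

# Unfused: three separate kernels, with the 1024 x 1024 scores and probs
# tensors round-tripping through HBM between them.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
probs = torch.softmax(scores, dim=-1)
out_unfused = probs @ v

# Fused: one call, one kernel. On a supported GPU and dtype, PyTorch can
# dispatch this to a FlashAttention-style backend; no N x N tensor hits HBM.
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_unfused, out_fused, atol=1e-2))  # same output up to fp reordering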
What the numbers actually showed
The reported speedups depend a lot on what you measure. The 2022 paper highlights:
- ~7.6× speedup on the attention computation itself on GPT-2.
- ~3× end-to-end speedup on GPT-2 training (sequence length 1k).
- ~15% end-to-end speedup on BERT-large training (sequence length 512), beating the then-current MLPerf record.
End-to-end gains are smaller than attention-only gains because attention isn’t 100% of the work, especially at short sequence lengths. The longer the sequence, the bigger the fraction of total time attention takes, and the bigger FlashAttention’s relative win — which is part of why long-context training became practical around the time this kernel and its successors became standard. (I’m not claiming sole causation; that’s an oversimplification — sequence-parallel training, ring attention, and architectural changes all contributed.)
What it did not change
It’s worth being precise about the limits, because the reframing only goes so far.
- Attention is still O(N²) in compute. The quadratic in N is unchanged. The asymptotic FLOP count is the same; the backward pass actually does more arithmetic because it recomputes the attention matrix instead of storing it. (The 2022 paper’s own benchmark shows a higher GFLOP count for FlashAttention than standard attention on one of its tests — and still wins, because it spends them on hot data instead of HBM round-trips.) FlashAttention changes the bandwidth, not the arithmetic.
- It is exact. No approximation, unlike sparse / linear / sliding-window attention. The output is identical to vanilla attention up to floating-point reordering.
- The bottleneck moved. After FlashAttention, attention is much closer to compute-bound. Follow-up work (FA-2, FA-3) focuses on extracting more compute utilization (parallelism, async pipelines, low precision) rather than further bandwidth reduction — my read is that the easy bandwidth wins were taken by the original kernel, but I haven’t seen a head-to-head bandwidth attribution study to confirm that.
The follow-ups, briefly
FlashAttention-2 (2023) reorganized the parallelism (different work partitioning across thread blocks, fewer non-matmul operations) and roughly doubled throughput on Ampere-class GPUs. FlashAttention-3 (2024) targets Hopper (H100) specifically — exploiting TMA, warp-specialized async pipelines, and FP8 / BF16 with block-quantized FP8 — and reports lifting H100 attention utilization from roughly 35% (FA-2) to ~85% in BF16 (~1.3 PFLOPs) in the NeurIPS 2024 paper; the July 2024 preprint reported ~75% / FP16, so the exact headline number depends on which version you read. My read, not consensus: each follow-up is a constant-factor win that gets harder to find as the easy bandwidth wins have already been taken.
Famous related terms
- Why attention is quadratic — attention cost ∝ N² · d. The constraint FlashAttention works around, not the one it removes. Useful background for understanding what didn't change.
- Memory bandwidth — the actual bottleneck for most LLM operations, not FLOPs. The general principle FlashAttention is one famous instance of.
- Tiling / blocking — split a big matrix op into small blocks that fit in fast cache + reuse each block while it's hot. The cornerstone trick for memory-bound numerical kernels since long before deep learning.
- Online softmax — streaming softmax + a running max and a running sum that get rescaled as new blocks arrive. The piece that lets softmax be tiled exactly. Predates FlashAttention; Milakov & Gimelshein 2018.
- Kernel fusion — collapse multiple GPU ops into one kernel + keep intermediates in registers/SRAM. The reason all of attention can stay on-chip in one pass.
- Recomputation (gradient checkpointing) — trade FLOPs for memory by recomputing activations in the backward pass instead of storing them. How FlashAttention's backward pass keeps memory linear in N.
- KV cache — past tokens' K and V tensors stored across decode steps. Orthogonal to FlashAttention but related: the KV cache addresses decode, FlashAttention addresses prefill and training.
Going deeper
- Dao, Fu, Ermon, Rudra, Ré, FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (NeurIPS 2022) — the canonical paper. The introduction’s framing of attention as IO-bound is more important than any specific number in the rest of the paper.
- Dao, FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (2023) — reorganized parallelism, big constant-factor win on Ampere.
- Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao, FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (NeurIPS 2024) — Hopper-specific (TMA, warp specialization, FP8).
- Milakov & Gimelshein, Online normalizer calculation for softmax (2018) — the prior art for the streaming softmax trick.
- Dao-AILab/flash-attention — the reference CUDA implementation, used directly or via wrappers by most major training stacks.