
Why FlashAttention was a breakthrough

Same math, same exact outputs, same asymptotic compute — and yet it made attention several times faster and unlocked long context. The trick was noticing attention was a memory problem, not a compute problem.


Why it exists

If you read the FlashAttention paper expecting a clever new approximation to attention, you’ll be confused. There isn’t one. The output is the same softmax attention the original transformer paper described in 2017 — exact, up to floating-point reordering. The asymptotic FLOP count is the same (the backward pass actually does more arithmetic, because it recomputes intermediates instead of storing them — more on that below). The model isn’t changed. And yet it ran several times faster on real GPUs and made it practical to train and serve transformers at sequence lengths that previously OOM’d.

That mismatch — same math, dramatically different wall-clock — is the entire point. Before FlashAttention, the prevailing intuition was that attention was compute-bound: it does an N×N matmul, and matmuls are what GPUs are for, so what’s left to optimize? The answer turned out to be everything. The bottleneck was never the multiplications. It was the N×N matrix being shuffled in and out of HBM three or four times per attention layer, with the GPU’s actual compute units sitting idle waiting for memory.

The breakthrough was noticing this and writing a kernel that respected the memory hierarchy. The technique — tiling plus an online softmax trick — is not novel in computer science; it’s the standard playbook for memory-bound numerical kernels. The novelty was applying that playbook to attention, demonstrating that attention had been a memory-bandwidth problem all along, and shipping a kernel everyone could use. The 2022 paper by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré (arXiv:2205.14135) made that case, and it was unusually consequential — the FA-2/FA-3 lineage continues today, and the underlying kernel (or a derivative) is widely available through major training and inference stacks. (My read on adoption breadth, not a sourced market study.)

Why it matters now

The reason FlashAttention is still a load-bearing piece of infrastructure four years later, instead of a paper everyone moved on from, is that the bandwidth gap it exploits has only gotten wider: each new GPU generation has added matmul throughput faster than it has added HBM bandwidth, so the penalty for bouncing an N×N matrix through off-chip memory keeps growing.
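
To make that concrete, here is a rough back-of-the-envelope in Python, using approximate datasheet peaks (dense fp16/bf16 tensor-core throughput versus HBM bandwidth). The exact figures vary by SKU and clock, so treat the ratios as order-of-magnitude, not measurements:

# Approximate datasheet peaks: (FLOP/s, HBM bytes/s). Illustrative only;
# exact numbers depend on the SKU.
gpus = {
    "V100 (2017)": (125e12, 0.90e12),
    "A100 (2020)": (312e12, 1.50e12),
    "H100 (2022)": (989e12, 3.35e12),
}
for name, (flops, bandwidth) in gpus.items():
    print(f"{name}: ~{flops / bandwidth:.0f} FLOPs per byte of HBM traffic")

Each generation, a byte fetched from HBM has to feed more arithmetic to keep the chip busy, which is exactly the regime where a kernel that avoids redundant HBM round-trips keeps paying off.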

The short answer

FlashAttention = exact softmax attention + tiling + online softmax + kernel fusion

It’s the same attention as before — same inputs, same exact outputs, same asymptotic FLOP count — restructured so the N×N intermediate matrix never gets written to slow GPU memory. The work is done in small blocks that fit in fast on-chip memory and are processed in a single fused kernel. The math is unchanged; only the data movement is. (The backward pass does add some recomputation FLOPs as the price of not storing the attention matrix.)

How it works

Three ingredients, each well-understood on its own, plus the observation about the memory hierarchy that motivates them. Their combination is the thing that hadn’t been done for attention.

1. The memory hierarchy nobody was respecting

A GPU has two memory tiers that matter here. SRAM is on-chip, sits next to the compute units, and is fast — but tiny, on the order of tens of MB total across all streaming multiprocessors. HBM is off-chip DRAM, much slower per byte but much bigger (tens of GB).

The original FlashAttention paper cites an A100 example where SRAM bandwidth is roughly 19 TB/s versus 1.5 TB/s for HBM — about 13× the throughput in roughly 1/2000th the capacity. Standard attention writes the full N×N scores matrix out to HBM, reads it back to apply softmax, writes the result back, and reads it again to multiply by V. Those round-trips are the bottleneck. The matmul units finish early and wait.

This is the core empirical claim of the paper: profiling shows attention is memory-bound, not compute-bound. Once you see that, the fix is structural: don’t materialize the N×N matrix in HBM at all.
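
Here is the shape of that profiling argument as a back-of-the-envelope estimate rather than a measurement. The sequence length and head dimension are illustrative, the byte counting is rough (it ignores softmax FLOPs and any caching), and the peaks are the approximate A100 numbers cited above:

# Per attention head, fp16 (2 bytes/element), unfused attention:
# S = Q @ K.T, P = softmax(S), O = P @ V, with S and P round-tripped
# through HBM between the separate kernels.
N, d = 4096, 64                          # sequence length, head dim (illustrative)
flops = 4 * N * N * d                    # the two N x N x d matmuls dominate
elements_moved = 4 * N * N + 4 * N * d   # S and P written then re-read, plus Q, K, V, O
hbm_bytes = 2 * elements_moved

peak_flops = 312e12                      # approx. A100 fp16 tensor-core peak
peak_bandwidth = 1.5e12                  # approx. A100 HBM bandwidth, as cited above

print(f"compute-limited time  : {flops / peak_flops * 1e6:5.1f} us")
print(f"bandwidth-limited time: {hbm_bytes / peak_bandwidth * 1e6:5.1f} us")

The bandwidth-limited estimate comes out several times larger than the compute-limited one, which is the memory-bound signature: the tensor cores could finish the matmuls long before HBM can deliver and absorb the N×N intermediates.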

2. Tiling: do the work in blocks that fit in SRAM

Tiling is the textbook technique for memory-bound numerical work — it’s how good BLAS matmul kernels have always worked. FlashAttention applies it to attention.

Conceptually:

for each block of K, V (loaded into SRAM):
    for each block of Q (loaded into SRAM):
        compute partial scores = Q_block · K_blockᵀ   # in SRAM
        compute partial softmax + partial output       # in SRAM
        update running output and softmax statistics   # tiny

The N×N scores matrix is never assembled in HBM. Each block is computed, used, and discarded inside SRAM. The kernel writes back the final output (shape N×d, same size as the input) and a small O(N) per-row softmax statistic that the backward pass needs (the running max m and the running sum-of-exps l, sometimes combined and stored as m + log l); intermediate scores and probabilities are gone. Total HBM traffic drops from O(N²) to roughly O(N²·d²/M), where M is the SRAM size — a big concrete win because M is large enough that the d²/M factor is small.
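
Plugging illustrative numbers into those asymptotic counts makes the factor concrete. The constants are dropped, so read the ratio as order-of-magnitude; M here is the SRAM size in elements, assumed to be around 10^5, roughly an A100 SM’s 192 KB at two bytes per element:

# Asymptotic HBM-access counts from the paper, constants dropped.
N, d, M = 4096, 64, 100_000
standard = N * d + N * N     # Theta(N*d + N^2): scores materialized in HBM
flash = N * N * d * d // M   # Theta(N^2 * d^2 / M): tiled, nothing materialized
print(f"standard: ~{standard / 1e6:.1f}M accesses")
print(f"flash   : ~{flash / 1e6:.2f}M accesses  (~{standard / flash:.0f}x less HBM traffic)")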

3. Online softmax: the part that took some thought

Tiling is only “trivial textbook stuff” if your operation is associative across blocks. Matmul is. Softmax isn’t, naively — softmax(x_i) = exp(x_i) / Σ exp(x_j) requires knowing the whole row to compute any entry, because of the denominator. That global dependency is what stops you from just tiling attention the obvious way.

The fix, often called the online softmax trick, predates FlashAttention — Milakov and Gimelshein described an online normalizer in 2018 (arXiv:1805.02867) — but FlashAttention is what plumbed it through to a full attention kernel. The idea: as you stream through blocks of a row, keep two running scalars per row — the running max m (for numerical stability) and the running sum-of-exps l. When a new block arrives with its own local max and local sum, you rescale the running quantities and the partial output, then fold in the new block. After the last block, the result equals what a one-shot softmax would have produced, up to floating-point reordering. No approximation, no drift.
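
Below is a minimal NumPy sketch of the whole forward pass, combining the blocked loop from the previous section with these running (m, l) statistics, and checking the result against a one-shot softmax attention. It is a sketch under simplifying assumptions: a single head, no masking, float64 throughout, arbitrary illustrative block sizes, and the division by l deferred to the very end (closer to FlashAttention-2’s bookkeeping than the 2022 paper’s, though the math is the same):

import numpy as np

def naive_attention(Q, K, V):
    # Reference: materialize the full N x N score matrix.
    S = (Q @ K.T) / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def flash_attention_forward(Q, K, V, block_q=64, block_k=64):
    N, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((N, d))        # running (unnormalized) output
    m = np.full(N, -np.inf)     # running row max of the scores
    l = np.zeros(N)             # running row sum of exp(score - m)
    for ks in range(0, N, block_k):          # stream over K/V blocks
        Kb, Vb = K[ks:ks + block_k], V[ks:ks + block_k]
        for qs in range(0, N, block_q):      # stream over Q blocks
            rows = slice(qs, qs + block_q)
            S = (Q[rows] @ Kb.T) * scale     # block of scores ("in SRAM")
            m_new = np.maximum(m[rows], S.max(axis=-1))
            P = np.exp(S - m_new[:, None])   # block of unnormalized probabilities
            c = np.exp(m[rows] - m_new)      # rescale factor for everything seen so far
            l[rows] = c * l[rows] + P.sum(axis=-1)
            O[rows] = c[:, None] * O[rows] + P @ Vb
            m[rows] = m_new
    return O / l[:, None]                    # normalize once at the end

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((256, 64)) for _ in range(3))
assert np.allclose(flash_attention_forward(Q, K, V), naive_attention(Q, K, V))

The correction factor exp(m_old - m_new), applied to both the running sum and the running output, is the whole trick; everything else is ordinary blocked matrix arithmetic.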

This is the part of FlashAttention that took genuine engineering taste to land — getting the rescaling exactly right while staying numerically stable, doing it inside a single CUDA kernel, and making the backward pass also work. The forward pass keeps only O(N) softmax statistics; the backward pass uses recomputation (re-deriving the attention matrix on the fly during backprop instead of storing it) to keep memory linear too.

4. Kernel fusion: don’t leave SRAM until you’re done

The final ingredient is fusing all of attention — score matmul, softmax, output matmul — into one CUDA kernel. Standard PyTorch attention does each step as a separate kernel launch, which means each step has to read its inputs from HBM and write its outputs back. FlashAttention does the whole thing inside one kernel, so blocks live in SRAM for the duration. This is the same fusion idea behind a lot of hand-tuned GPU work; FlashAttention’s contribution wasn’t inventing fusion, it was finding the version that fused all of attention without sacrificing exactness.
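
You rarely need to write any of this yourself. In recent PyTorch releases, torch.nn.functional.scaled_dot_product_attention dispatches to a FlashAttention-style fused kernel when the inputs qualify (CUDA tensors in fp16/bf16 with supported head dimensions), and you can pin the backend to make sure that is what you are getting. A sketch, assuming PyTorch 2.2 or newer, where the context manager lives under torch.nn.attention; earlier 2.x releases expose similar controls under different names:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim), half precision on a CUDA device so the
# fused flash backend is eligible.
q, k, v = (torch.randn(2, 8, 4096, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

# Restrict dispatch to the FlashAttention backend; PyTorch raises instead of
# silently falling back if that kernel can't handle these inputs.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([2, 8, 4096, 64])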

What the numbers actually showed

The reported speedups depend a lot on what you measure. Among the headline numbers in the 2022 paper:

- the attention operation itself runs several times faster than the standard PyTorch implementation, with the gap growing at longer sequence lengths;
- end to end, BERT-large (sequence length 512) trains about 15% faster than the MLPerf 1.1 training speed record, and GPT-2 (sequence length 1K) trains roughly 3× faster than common baselines;
- long-range arena (sequence lengths 1K to 4K) sees a 2.4× speedup;
- attention activation memory drops from quadratic to linear in sequence length.

End-to-end gains are smaller than attention-only gains because attention isn’t 100% of the work, especially at short sequence lengths. The longer the sequence, the bigger the fraction of total time attention takes, and the bigger FlashAttention’s relative win — which is part of why long-context training became practical around the time this kernel and its successors became standard. (I’m not claiming sole causation; that’s an oversimplification — sequence-parallel training, ring attention, and architectural changes all contributed.)

What it did not change

It’s worth being precise about the limits, because the reframing only goes so far. The FLOP count is still quadratic in sequence length: FlashAttention moves less data, it doesn’t do less arithmetic, so at some context length compute becomes the wall again. The model and its outputs are unchanged, so there is no quality gain beyond whatever longer context buys you. And it is a kernel-level, hardware-specific win: the constants depend on the exact memory hierarchy, which is why each GPU generation has needed a re-tuned successor.

The follow-ups, briefly

FlashAttention-2 (2023) reorganized the parallelism (different work partitioning across thread blocks, fewer non-matmul operations) and roughly doubled throughput on Ampere-class GPUs. FlashAttention-3 (2024) targets Hopper (H100) specifically, exploiting TMA, warp-specialized asynchronous pipelines, and low-precision FP8 with block quantization. It reports lifting H100 attention utilization from roughly 35% (FA-2) to around 75-85% in FP16/BF16 (on the order of 740-840 TFLOPs/s), with FP8 approaching 1.2-1.3 PFLOPs/s; the exact headline figures differ between the July 2024 preprint and the NeurIPS 2024 version, so it depends on which you read. My read, not consensus: each follow-up is a constant-factor win that gets harder to find now that the easy bandwidth wins have been taken.

Going deeper