Heads up: posts on this site are drafted by Claude and fact-checked by Codex. Both can still get things wrong — read with care and verify anything load-bearing before relying on it.

Why MLA replaced MHA

DeepSeek-V2 cut its KV cache by 93% by attacking the bottleneck differently than GQA does — and it scored higher, not lower, on benchmarks.

AI & ML intermediate Apr 30, 2026

Why it exists

Imagine a really long ChatGPT conversation — hundreds of messages back and forth. To stay coherent, the model has to “remember” every previous word in the chat. It does that by stashing a small note about each token in a special on-GPU scratchpad called the KV cache. The longer the chat, the bigger the scratchpad. Eventually the scratchpad gets larger than the model itself — and that, not the model, is what runs out of GPU memory first. MLA is a trick that shrinks each note by roughly 14×, so the same GPU can hold a much longer conversation.

Long-context inference has a memory problem, and the memory problem has a name: the KV cache. Every token a model has already seen leaves behind a key vector and a value vector for every attention head, in every layer. Generate a long enough conversation and that cache, not the model weights, is what fills your VRAM.

The standard fix until 2024 was GQA: make groups of query heads share a single K/V head. It works, but you can feel the trade — fewer distinct K/V heads means less expressive attention, and teams generally accept a small quality dent to buy the memory back.
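
To make the sharing concrete, here’s a minimal PyTorch sketch of the GQA pattern. The head counts are illustrative assumptions of mine, not any particular model’s configuration.

```python
import torch

# A minimal sketch of GQA's K/V sharing: 32 query heads, 8 K/V heads,
# so 4 query heads per group (illustrative dims, not a specific model's).
batch, seq, n_q_heads, n_kv_heads, head_dim = 1, 16, 32, 8, 128

q = torch.randn(batch, n_q_heads, seq, head_dim)
k = torch.randn(batch, n_kv_heads, seq, head_dim)   # only these 8 heads are cached
v = torch.randn(batch, n_kv_heads, seq, head_dim)

group = n_q_heads // n_kv_heads                      # 4 query heads share each K/V head
k = k.repeat_interleave(group, dim=1)                # expand back to 32 heads at attention time
v = v.repeat_interleave(group, dim=1)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5
out = torch.softmax(scores, dim=-1) @ v              # same output shape as full MHA

# The cache holds 8 K/V heads instead of 32: a 4x smaller cache, paid for
# with less distinct attention per query head.
```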

DeepSeek-V2 shipped in May 2024 with a different bet: Multi-head Latent Attention. Same goal — shrink the KV cache. Different tactic — don’t share heads, compress what each token stores down to a tiny shared vector and reconstitute the full keys and values at attention time. DeepSeek reports a 93.3% KV-cache reduction versus their dense 67B baseline, and a quality improvement over plain MHA, not a regression (DeepSeek-V2 paper — KV compression is §2.1.2, decoupled RoPE is §2.1.3, the throughput numbers live in §3.2.3, and the “MLA beats MHA” comparison is in Appendix D.2 / Table 9). That combination — cheaper and better — is the read I have for why this architecture spread, though I haven’t seen a clean independent ablation isolating MLA from the rest of DeepSeek-V2’s choices.

Why it matters now

If you’ve used DeepSeek-V2, V3, or R1, you’ve used MLA. The frontier-cost crash everyone wrote about in 2025 is partly a story about MLA: a much smaller cache means much longer contexts and much larger inference batches on the same GPUs, which is the lever that drives per-token price down. There is now a small literature on retrofitting MLA onto already-trained MHA/GQA models (e.g. TransMLA, 2025), which is itself a tell — people want this attention shape badly enough to do surgery on existing checkpoints.

GQA is still everywhere. But for new dense-attention designs where the cache is the bottleneck, MLA is the architecture to beat.

The short answer

MLA = MHA + low-rank latent K/V cache + decoupled RoPE channel

Instead of caching every head’s full key and value per token, MLA caches one small “latent” vector per token and learns up-projection matrices that reconstruct each head’s K and V on demand. A separate small vector carries the rotary position signal, because position information refuses to compress cleanly. At inference, an algebra trick folds the up-projections into the query and output weights so the model never actually decompresses — it just attends in the compressed space.
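
One way to see what that formula buys is to count what each scheme stores per token, per layer. In the quick comparison below, the GQA group count is an illustrative assumption of mine; the MLA split (512 + 64) is DeepSeek-V2’s.

```python
# Cached elements per token per layer, using the 128-head, 128-dim example
# from the next section. GQA's 8 K/V heads are an illustrative assumption.
n_heads, head_dim = 128, 128

mha = 2 * n_heads * head_dim    # full K and V for every head  -> 32,768
gqa = 2 * 8 * head_dim          # K and V for 8 shared heads   ->  2,048
mla = 512 + 64                  # latent + shared RoPE key     ->    576

print(mha, gqa, mla)
```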

How it works

Plain MHA caches, per token per layer, n_heads × head_dim numbers for K and the same for V. For a model with 128 heads of dimension 128, that’s 32,768 numbers (16,384 for K, 16,384 for V) per token per layer. Multiply by layers and context length and the cache eats the GPU.
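
Here’s that multiplication written out. The layer count, context length, and fp16 storage are illustrative assumptions, not any specific model’s configuration.

```python
# Back-of-the-envelope KV-cache size for plain MHA with the 128-head example.
per_token_per_layer = 2 * 128 * 128                  # K + V = 32,768 elements
n_layers, context, bytes_per_elem = 60, 128_000, 2   # fp16; illustrative assumptions

cache_bytes = per_token_per_layer * n_layers * context * bytes_per_elem
print(f"{cache_bytes / 1e9:.0f} GB")                 # ~503 GB for a single 128k-token sequence
```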

MLA’s move, in three steps:

1. Compress to a latent. A learned matrix W_DKV projects the token’s hidden state down to a small vector c_KV — DeepSeek-V2 sets the compression dimension to 512, versus the 32,768 combined K+V elements you’d otherwise cache in the 128-head example. Only c_KV (plus a small RoPE side-channel, see the decoupled-RoPE discussion after these steps) lives in the cache. That compression is the core reason MLA shrinks the cache so much.

2. Up-project per head when you need K and V. Two more learned matrices, W_UK and W_UV, take c_KV back up to per-head keys and values at attention time. So far this looks like a strictly worse MHA — same parameter count plus an extra matmul. The win shows up next.

3. Absorb the up-projections into the query and output weights. This is the trick that makes MLA practical, not just memory-frugal. Because attention scores are Q · K^T and the output is (scores · V) · W_O, you can pre-multiply: fold W_UK into the query projection and W_UV into the output projection (DeepSeek-V2 paper §2.1.2 / Appendix C). At inference, the model never reconstructs full K or V — it computes attention directly against the small latent. You pay compute to train with the up-projections; you skip that compute at inference.
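
Here’s a single-head sketch of the three steps, ignoring RoPE (next paragraph) and DeepSeek’s separate query-side compression. The dimensions follow the numbers above (512-dim latent, 128-dim heads), but the weights are random placeholders, so treat it as an illustration of the algebra rather than the paper’s implementation.

```python
import torch

torch.manual_seed(0)
d_model, d_c, head_dim, seq = 1024, 512, 128, 16

W_DKV = torch.randn(d_c, d_model) / d_model ** 0.5   # step 1: down-project to the latent
W_UK  = torch.randn(head_dim, d_c) / d_c ** 0.5      # step 2: up-project latent -> per-head K
W_Q   = torch.randn(head_dim, d_model) / d_model ** 0.5

h = torch.randn(seq, d_model)                        # hidden states of the cached tokens
c_kv = h @ W_DKV.T                                   # (seq, 512): the only thing cached
q = h[-1] @ W_Q.T                                    # query for the newest token, (128,)

# Naive path: reconstruct full keys, then score.
k_full = c_kv @ W_UK.T                               # (seq, 128)
scores_naive = k_full @ q

# Step 3, absorbed path: fold W_UK into the query and attend directly in latent space.
q_absorbed = W_UK.T @ q                              # (512,)
scores_absorbed = c_kv @ q_absorbed

print(torch.allclose(scores_naive, scores_absorbed, atol=1e-4))  # True: same scores, no decompression
# W_UV absorbs into the output projection W_O the same way, so values are
# never reconstructed either.
```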

The seam — and the part awkward enough that the paper gives it its own subsection — is RoPE. RoPE multiplies Q and K by a position-dependent rotation matrix before the dot product. If your “K” is really W_UK · c_KV, you’d want to apply the rotation after up-projection — but that breaks the absorption trick, because the rotation depends on position and can’t be folded into a fixed weight. DeepSeek’s fix is decoupled RoPE (§2.1.3): split the key into two parts. The content part lives in the compressed latent and gets no RoPE. The position part is a small, shared-across-heads key vector (per-head dim 64 in DeepSeek-V2) that carries RoPE and is cached alongside c_KV. So the cache holds c_KV plus this small shared RoPE key per token — small enough that the 93% headline survives. It’s structurally inelegant — my read, not the paper’s framing — and it’s the price of keeping the absorption trick alive.
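
Here’s a sketch of what one cache entry holds under this scheme, using the 512 + 64 split described above. The rope() helper is a minimal stand-in for a standard rotary transform, and the weight names are mine, not the paper’s.

```python
import torch

def rope(x, pos):
    # Minimal rotary embedding (rotate-half style), for illustration only.
    half = x.shape[-1] // 2
    freqs = pos * 10000.0 ** (-torch.arange(half, dtype=torch.float32) / half)
    cos, sin = freqs.cos(), freqs.sin()
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

d_model, d_c, d_rope = 1024, 512, 64
W_DKV = torch.randn(d_c, d_model) / d_model ** 0.5     # content path, gets compressed
W_KR  = torch.randn(d_rope, d_model) / d_model ** 0.5  # position path, shared across heads

h_t, pos = torch.randn(d_model), 7                     # one token at position 7

cache_entry = {
    "c_kv":   h_t @ W_DKV.T,                           # 512 numbers, content only, no RoPE applied
    "k_rope": rope(h_t @ W_KR.T, pos),                 # 64 numbers, carries the position signal
}
print(sum(v.numel() for v in cache_entry.values()))    # 576, vs 32,768 for the MHA example

# At attention time each score is a sum of two terms: a content term computed
# against c_kv via the absorption trick, plus a position term q_rope . k_rope.
```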

The honest gap: the paper also claims a 5.76× boost in maximum generation throughput over the 67B baseline, and I have not tried to verify that number independently — it was measured on a single node of 8 H800 GPUs and reported alongside FP8 deployment plus ~6-bit KV-cache quantization (paper §3.2.3), so MLA is doing some but not all of that work. The 93.3% cache reduction is the more load-bearing claim and is reproduced across third-party explainers, but throughput in your stack will depend on batch size, context length, and serving software.

Going deeper