Transformer Math

Module 23 · Inference

💾 KV Cache & Memory

Without caching, generating one extra token costs as much as reprocessing the entire prompt. Here's how to fix it.

During autoregressive generation, each new token attends to all previous tokens. Without caching, the model recomputes every Key and Value vector at every step — at step t, it projects all t tokens to get K, V, then attends over them: O(t·d) projection plus O(t·d) attention. Summed over n steps, that is O(n²·d) total projection work alone. KV Cache stores previously computed K, V vectors and reuses them, so each step only projects the new token (O(d) instead of O(t·d)). The attention dot-product itself is still O(t·d) per step (O(n²·d) total, treating d as the head dimension) — the cache eliminates redundant K/V recomputation, not the attention. The tradeoff: memory.

🎮

Token-by-Token Generation: With vs Without KV Cache

Watch how the model generates tokens one at a time. Without cache, it recomputes all previous K, V at every step. With cache, only the new token's K, V are computed.

Without KV Cache (recompute all): step 1 computes 1 KV, step 2 recomputes 2, step 3 recomputes 3, and so on up to step 6, which recomputes all 6. Total: 1+2+3+4+5+6 = 21 KV computes.

With KV Cache (reuse previous): every step computes exactly 1 new KV for the newly generated token. Total: 1+1+1+1+1+1 = 6 KV computes.

| Step (token) | Without cache | With cache | Running total (no cache) |
|---|---|---|---|
| 1 — The | 1 KV projection | 1 KV projection | 1 |
| 2 — cat | 2 KV projections | 1 KV projection | 3 |
| 3 — sat | 3 KV projections | 1 KV projection | 6 |
| 4 — on | 4 KV projections | 1 KV projection | 10 |
| 5 — the | 5 KV projections | 1 KV projection | 15 |
| 6 — mat | 6 KV projections | 1 KV projection | 21 |
| Total | 21 — O(n²) | 6 — O(n) | 3.5× more at n=6; ~1000× more at n=2048 |
✨ Insight · Without cache: n(n+1)/2 total KV projections. With cache: n. For a 2048-token generation that is ~2.1M vs ~2K projections — a ~1000x reduction in KV projection work.
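
The counts above are easy to verify with a couple of lines of Python (n is the number of generated tokens):

def kv_projections(n, cached=True):
    """Total K/V projection count needed to generate n tokens."""
    return n if cached else n * (n + 1) // 2

print(kv_projections(6, cached=False), kv_projections(6))        # 21 vs 6
print(kv_projections(2048, cached=False), kv_projections(2048))  # 2,098,176 vs 2,048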

KV Cache Memory Calculator

Adjust model parameters below to see how KV cache memory scales with layers, heads, sequence length, and batch size. Compare MHA vs GQA to see the savings from sharing KV heads.

Example configuration (calculator defaults): Llama-2 70B, FP16, seq_len=4,096, batch=1, GQA with 8 KV heads (8x compression vs MHA).

  • Model Params: 140.0 GB
  • KV Cache: 1.3 GB
  • Total: 141.3 GB (OOM! Exceeds a single A100 40GB or A100/H100 80GB.)

Model details: Llama-2 70B

layers=80, heads=64, kv_heads=8, d_k=128

KV cache = 2 x 80 x 8 x 128 x 4,096 x 1 x 2 / 1e9 = 1.3 GB
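
The calculator's formula as a small helper function (a minimal sketch; the function name and defaults are illustrative):

def kv_cache_gb(n_layers, n_kv_heads, d_k, seq_len, batch=1, bytes_per_elem=2):
    """KV cache in GB: 2 (K and V) x layers x kv_heads x d_k x seq x batch x bytes per element."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_elem / 1e9

print(kv_cache_gb(80, 8, 128, 4096))       # Llama-2 70B, GQA      -> ~1.34 GB
print(kv_cache_gb(80, 64, 128, 4096))      # same model with MHA   -> ~10.7 GB
print(kv_cache_gb(80, 8, 128, 4096, 32))   # GQA at batch=32       -> ~42.9 GB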

💡

Why Generation Is Memory-Bound

Why is generation memory-bound, not compute-bound? During generation, each step produces just one token. The GPU does a single matrix-vector multiply (not matrix-matrix), giving an arithmetic intensity of roughly 1 FLOP per byte of weights loaded, far below what modern GPUs need to stay compute-bound. The bottleneck is loading model weights and KV cache from HBM, not doing math.
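
A back-of-the-envelope check makes this concrete. The snippet below assumes FP16 weights and approximate A100 80GB spec-sheet numbers (~312 TFLOP/s FP16, ~2 TB/s HBM bandwidth):

params        = 70e9                 # Llama-2 70B
bytes_moved   = params * 2           # every FP16 weight is read once per decode step
flops         = params * 2           # ~2 FLOPs per weight for a matrix-vector multiply
intensity    = flops / bytes_moved   # ~1 FLOP per byte loaded

peak_flops    = 312e12               # A100 FP16 tensor-core throughput (approx.)
hbm_bandwidth = 2.0e12               # A100 80GB memory bandwidth in bytes/s (approx.)
balance       = peak_flops / hbm_bandwidth   # ~156 FLOPs/byte needed to be compute-bound

print(intensity, balance)            # ~1 vs ~156 -> decode is memory-bandwidth bound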

The memory cost of caching: KV cache grows linearly with sequence length, batch size, and number of layers. For large models at long contexts, it can exceed the model parameters themselves.

Three techniques reduce KV cache memory:

  • GQA / MQA — share K, V heads across multiple Q heads.
  • PagedAttention (vLLM) — manage KV cache like OS virtual memory. Non-contiguous pages eliminate fragmentation (see the toy sketch after this list).
  • Continuous batching — when a request finishes, immediately insert a new one. No head-of-line blocking.
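
The paging idea fits in a few lines. Below is a toy, illustrative sketch of a block table (names like kv_pool and BLOCK_SIZE are invented for the example; real vLLM manages per-layer K/V tensors with dedicated GPU kernels):

import torch

BLOCK_SIZE  = 16                            # tokens per physical block
N_BLOCKS    = 1024                          # pool shared by all requests
D           = 128                           # per-token KV width (toy value)

kv_pool     = torch.zeros(N_BLOCKS, BLOCK_SIZE, D)   # physical KV memory
free_blocks = list(range(N_BLOCKS))

class Request:
    def __init__(self):
        self.block_table = []               # logical block index -> physical block
        self.n_tokens = 0

    def append_kv(self, kv_vec):
        """Store one token's KV; allocate a new block only when the last one is full."""
        if self.n_tokens % BLOCK_SIZE == 0:
            self.block_table.append(free_blocks.pop())
        kv_pool[self.block_table[-1], self.n_tokens % BLOCK_SIZE] = kv_vec
        self.n_tokens += 1

    def free(self):
        """On completion, return blocks to the pool so a waiting request can use them."""
        free_blocks.extend(self.block_table)
        self.block_table.clear()

# Worst-case waste per request: BLOCK_SIZE - 1 slots, instead of max_seq - actual_len.
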
💡 Tip · KV cache is why serving LLMs is fundamentally different from serving traditional ML models. A single Llama-2 70B request at 4K context needs ~1.3 GB just for the cache — multiply by batch size to see why GPU memory is the primary constraint.

Per-token KV cache size (FP16)

Each token stored in the cache costs:

bytes_per_token = 2 × n_layers × n_kv_heads × d_head × 2 bytes

# Llama-2 70B with GQA (8 KV heads, d_head=128, 80 layers)

bytes_per_token = 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 0.32 MB / token

Scale to batch_size=32, seq_len=4096

total = 32 × 4096 × 327,680 bytes ≈ 42.9 GB ← GQA (8 KV heads)

total = 32 × 4096 × 2.62 MB = 343 GB ← MHA (64 KV heads) 🚫 exceeds A100 80GB

The MHA number exceeds an A100 80GB by 4×. This is why GQA, MQA, and KV quantization are not optional at production batch sizes — they determine whether the workload is physically possible on available hardware.

KV cache quantization is a fourth lever that works orthogonally to GQA and PagedAttention. Rather than reducing the number of KV heads (GQA) or eliminating fragmentation (PagedAttention), it shrinks the per-element precision of cached tensors. KIVI (2024) showed that K and V tensors have very different statistical properties: Keys have per-channel outliers, while Values are smoother — so they benefit from different quantization strategies. Per-channel INT2 quantization of K and per-token INT2 of V achieves a 2.35x memory reduction with negligible perplexity impact, enabling a 2.35x increase in batch size or context length for the same GPU. FP8 KV cache (supported natively on Hopper-class GPUs) is widely used in production inference stacks, offering a simpler 2× reduction with virtually no quality loss.
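
The core mechanic is easy to sketch. Below is a minimal per-token symmetric quantizer (INT8 for simplicity; KIVI uses 2-bit with grouping, and FP8 paths rely on hardware support rather than this kind of integer rounding):

import torch

def quantize_per_token(x, n_bits=8):
    """Symmetric per-token quantization of a [seq, d] KV tensor."""
    qmax  = 2 ** (n_bits - 1) - 1
    scale = (x.float().abs().amax(dim=-1, keepdim=True) / qmax).clamp(min=1e-8)
    q     = torch.round(x.float() / scale).clamp(-qmax - 1, qmax).to(torch.int8)
    return q, scale

def dequantize(q, scale):
    return (q.float() * scale).to(torch.float16)

v = torch.randn(4096, 128, dtype=torch.float16)   # one head's V cache at seq=4096
q, scale = quantize_per_token(v)
print(v.element_size(), "->", q.element_size(), "bytes per element")   # 2 -> 1
print((dequantize(q, scale) - v).abs().mean())                         # small reconstruction error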

Quick check

Derivation

A Llama-2 70B request uses 1.34 GB of KV cache at seq=4096. At batch=32, how much GPU memory is consumed by KV cache alone — and does it fit on an A100 80 GB?

Quick check

Why is autoregressive generation memory-bound rather than compute-bound?

📐

Step-by-Step Derivation

KV Cache Memory Formula

Each layer stores a Key and a Value tensor for every token generated so far. The total memory:

KV cache bytes = 2 (K and V) × n_layers × n_kv_heads × d_k × seq_len × batch_size × bytes_per_elem

Worked Example: Llama-2 70B (FP16, seq=4096, batch=1)

n_layers = 80, n_kv_heads = 8 (GQA), d_k = 128

KV Cache = 2 x 80 x 8 x 128 x 4096 x 1 x 2 bytes

= 2 x 80 x 8 x 128 x 4096 x 2

= 1,342,177,280 bytes

= 1.34 GB per request

With MHA (64 KV heads instead of 8):

= 10.7 GB per request (8x larger!)

⚠ Warning · At batch=32 with GQA: 32 × 1.34 GB ≈ 43 GB. With MHA it would be 340+ GB. This is why GQA is not optional for production serving of 70B+ models.

GQA Memory Savings

GQA groups the query heads into g groups, each sharing one KV head:

KV cache reduction vs MHA = n_heads / n_kv_heads. Mistral 7B: 32 Q heads / 8 KV heads = 4x reduction.
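
Per-token arithmetic for a couple of real configurations, using the formula above (a quick sketch):

def kv_bytes_per_token(n_layers, n_kv_heads, d_k, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * d_k * bytes_per_elem

print(kv_bytes_per_token(32, 32, 128))   # Llama-2 7B, MHA   -> 524,288 bytes (~0.52 MB)
print(kv_bytes_per_token(80, 8, 128))    # Llama-2 70B, GQA  -> 327,680 bytes (~0.33 MB)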

PyTorch: Simple KV Cache

import torch
import torch.nn.functional as F

class KVCache:
    """Simple KV cache for autoregressive generation."""
    def __init__(self, n_layers, n_kv_heads, d_k, max_seq, dtype=torch.float16):
        # Pre-allocate for max sequence length
        self.k = torch.zeros(n_layers, 1, n_kv_heads, max_seq, d_k, dtype=dtype)
        self.v = torch.zeros(n_layers, 1, n_kv_heads, max_seq, d_k, dtype=dtype)
        self.seq_len = 0

    def update(self, layer_idx, k_new, v_new):
        """Append new K, V for one token at one layer."""
        # k_new, v_new: [batch, n_kv_heads, 1, d_k]
        self.k[layer_idx, :, :, self.seq_len:self.seq_len+1, :] = k_new
        self.v[layer_idx, :, :, self.seq_len:self.seq_len+1, :] = v_new

    def get(self, layer_idx):
        """Return cached K, V up to and including the current position.
        Call after update() and before step() so the newest token is included."""
        return (
            self.k[layer_idx, :, :, :self.seq_len+1, :],
            self.v[layer_idx, :, :, :self.seq_len+1, :],
        )

    def step(self):
        self.seq_len += 1

def cached_attention(q, K_cached, V_cached, d_k):
    """Single-token attention with KV cache.
    q: [B, n_heads, 1, d_k]  (just the new token)
    K_cached: [B, n_kv_heads, seq, d_k]
    V_cached: [B, n_kv_heads, seq, d_k]
    """
    # GQA: repeat KV heads to match Q heads
    # (omitted for clarity — expand n_kv_heads -> n_q_heads)
    attn = (q @ K_cached.transpose(-2, -1)) / (d_k ** 0.5)
    attn = F.softmax(attn, dim=-1)
    return attn @ V_cached  # [B, n_heads, 1, d_k]

# Memory: 2 * 80 * 8 * 128 * 4096 * 1 * 2 = 1.34 GB (Llama-2 70B)
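
A quick consistency check for the sketch above (assuming n_heads == n_kv_heads so the GQA expansion is not needed): token-by-token attention through the cache must match attention computed against the full K/V in one shot.

torch.manual_seed(0)
H, D, T = 4, 8, 5                                    # heads, head dim, tokens
cache = KVCache(n_layers=1, n_kv_heads=H, d_k=D, max_seq=16, dtype=torch.float32)
q = torch.randn(1, H, T, D)
k = torch.randn(1, H, T, D)
v = torch.randn(1, H, T, D)

last = None
for t in range(T):                                   # decode one token at a time
    cache.update(0, k[:, :, t:t+1, :], v[:, :, t:t+1, :])
    K, V = cache.get(0)                              # everything cached so far, incl. token t
    last = cached_attention(q[:, :, t:t+1, :], K, V, D)
    cache.step()

full = cached_attention(q[:, :, -1:, :], k, v, D)    # last token attending over the full K/V
assert torch.allclose(last, full, atol=1e-5)
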
PyTorch: Generation loop with past_key_values
# KV cache: accumulate past_key_values during autoregressive generation
import torch

def generate_with_kv_cache(model, input_ids, max_new_tokens=50):
    past_key_values = None  # grows each step: list of (k, v) per layer

    for _ in range(max_new_tokens):
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        # outputs.past_key_values: tuple of (key, value) per layer
        # shape: (batch, n_kv_heads, seq_len_so_far, d_k)
        past_key_values = outputs.past_key_values

        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = next_token  # only pass NEW token; cache handles the rest
        yield next_token
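
A usage sketch, assuming a Hugging Face-style causal LM that accepts past_key_values and use_cache (the model name below is only an example):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
lm  = AutoModelForCausalLM.from_pretrained("gpt2")
ids = tok("The cat sat on", return_tensors="pt").input_ids

new_tokens = list(generate_with_kv_cache(lm, ids, max_new_tokens=8))
print(tok.decode(torch.cat(new_tokens, dim=-1)[0]))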

Quick check

Derivation

Mistral 7B uses GQA with 32 Q heads and 8 KV heads (d_k=128, 32 layers). What is the KV cache memory per token in bytes at FP16?

🔧

Break It — See What Happens

  • No KV cache (recompute every step)
  • Unlimited cache (no eviction, no paging)

Quick check

Trade-off

Without PagedAttention, serving 100 requests with max_seq=4096 on Llama-2 70B (GQA) pre-allocates how much KV cache memory, and what fraction is typically wasted?

📊

Real-World Numbers

| Model | Attention | KV Cache (FP16, 4K, B=1) | Notes |
|---|---|---|---|
| Llama-2 70B | GQA (8 KV) | 1.34 GB | Would be 10.7 GB with MHA (64 KV heads) |
| Llama-2 7B | MHA (32 KV) | 2.1 GB | KV cache is larger than 70B with GQA! |
| GPT-4 (estimated) | GQA | ~2-4 GB | Estimated from serving behavior; likely uses GQA + quantized KV |
| PagedAttention (vLLM) | — | — | Memory manager, not a model; cuts KV cache waste to under 4% |
✨ Insight · A 7B model with MHA has a larger KV cache (2.1 GB) than a 70B model with GQA (1.34 GB) at the same sequence length. KV cache size depends on architecture choices, not just model size.

Quick check

Derivation

At seq=4096, FP16, batch=1: Llama-2 7B (MHA, 32 layers, 32 KV heads, d_k=128) has a larger KV cache than Llama-2 70B (GQA, 80 layers, 8 KV heads). Why?

🚀

SOTA 2024: Multi-Head Latent Attention (MLA)

DeepSeek-V2 (May 2024, arxiv:2405.04434) introduced Multi-Head Latent Attention (MLA) — a fundamentally different approach to KV compression that goes beyond GQA. Instead of sharing K/V heads across query groups, MLA projects both K and V into a shared low-rank latent vector per token.

Core idea

For each token, rather than storing n_h separate K vectors and n_h separate V vectors (one per head), MLA stores a single compressed latent c_KV of dimension d_c, where d_c ≪ n_h · d_h. At attention time, the per-head K and V are recovered from c_KV via learned projections.

MLA math + PyTorch sketch

Down-projection (cache): c_KV = W_DKV · x, where W_DKV ∈ R^(d_c × d_model) and d_c ≪ n_h · d_h.

Up-projection (at attention time): k_i = W_UK,i · c_KV and v_i = W_UV,i · c_KV, for each head i = 1 … n_h.

Only c_KV is stored in the KV cache. The up-projections run on-the-fly during attention, adding compute but slashing memory.

import torch
import torch.nn as nn

class MultiHeadLatentAttention(nn.Module):
    """Simplified MLA as in DeepSeek-V2 (arxiv:2405.04434)."""
    def __init__(self, d_model=5120, n_heads=128, d_c=512, d_head=128):
        super().__init__()
        self.n_heads = n_heads
        self.d_c = d_c
        self.d_head = d_head
        # Q projection (can also have a low-rank path for Q)
        self.wq = nn.Linear(d_model, n_heads * d_head, bias=False)
        # Down-project to shared latent (what gets cached)
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)
        # Up-project latent → per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_head, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_head, bias=False)
        self.wo  = nn.Linear(n_heads * d_head, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # Compress to latent — only this gets stored in KV cache
        c_kv = self.w_dkv(x)  # (B, T, d_c)
        if kv_cache is not None:
            c_kv = torch.cat([kv_cache, c_kv], dim=1)
        new_cache = c_kv  # store compressed latent, not full K/V

        # Expand to per-head K, V at attention time
        k = self.w_uk(c_kv).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_uv(c_kv).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) / self.d_head ** 0.5
        out = (attn.softmax(-1) @ v).transpose(1, 2).reshape(B, T, -1)
        return self.wo(out), new_cache

MHA vs GQA vs MQA vs MLA — KV cache comparison

| Variant | What is cached per token | Cache size (relative) | Trade-off |
|---|---|---|---|
| MHA | n_h × d_h for K, n_h × d_h for V | 1× (baseline) | Full expressivity, max memory |
| GQA (g=8) | (n_h/8) × d_h for K+V | 0.125× MHA | <0.5% quality loss, no extra compute |
| MQA | 1 × d_h for K+V (shared) | 1/n_h × MHA | Noticeable quality loss at scale |
| MLA (DeepSeek-V2) | d_c latent (d_c ≪ n_h × d_h) | ~0.07× MHA (see below) | Extra up-projection FLOPs at each layer; enables 128K+ context on commercial hardware |
✨ Insight · DeepSeek-V2 concrete numbers (arxiv:2405.04434): MLA shrinks the KV cache by roughly 93% versus vanilla MHA — from ~131 GB to ~8.7 GB per sequence at the model's maximum context. This is what makes 128K-context serving viable on a single node. DeepSeek-V3 and DeepSeek-R1 both use MLA. For interview purposes: MLA ≈ "the low-rank factorization trick applied to the KV cache itself."

Cross-reference: Multi-Head Attention module covers MHA vs GQA vs MQA in detail. MLA is the 2024 evolution that replaces the K/V projection matrices entirely with a shared latent bottleneck.

🧠

Key Takeaways

What to remember for interviews

  1. Without KV cache, generating n tokens costs O(n²) KV recomputations. With cache, each new token computes exactly one new K/V pair — O(n) total.
  2. Autoregressive generation is memory-bandwidth bound, not compute-bound: each step does a matrix-vector multiply with near-zero arithmetic intensity.
  3. GQA shares K/V heads across multiple query heads, giving up to 8x KV cache reduction (Llama-2 70B: 64 Q heads → 8 KV heads = 8x savings) with negligible quality loss.
  4. PagedAttention (vLLM) manages KV cache like OS virtual memory in non-contiguous pages, cutting waste from ~60-80% to under 4% and enabling continuous batching.
  5. MLA (DeepSeek-V2/V3, 2024) compresses K+V into a shared low-rank latent per token — 5–13× smaller KV cache vs MHA, enabling 128K+ context on commercial hardware at the cost of extra per-layer up-projection compute.
🧠

Recap quiz

Derivation

Without KV cache, generating n tokens costs O(n²) total KV projections. What is the exact count for a 6-token sequence, and where does the savings come from?

Derivation

Why is autoregressive decode memory-bandwidth bound even when the model has trillions of parameters?

Derivation

Llama-2 70B uses GQA with 64 Q heads and 8 KV heads. At FP16, seq=4096, batch=32, how large is the KV cache — and what would it be with standard MHA?

Trade-off

PagedAttention reduces KV cache waste from ~60–80% to under 4%. What mechanism produces this improvement, and what does it trade off?

Trade-off

FP8 KV cache (vs FP16) halves memory per token. Why can KV values be quantized more aggressively than model weights without similar quality loss?


🎯

Interview Questions


  • Calculate KV cache memory for Llama-2 70B at 4K context, FP16, batch=1. (★★★ · OpenAI, Databricks)
  • Why is autoregressive generation memory-bound rather than compute-bound? (★★☆ · Google, Anthropic)
  • Explain PagedAttention (vLLM). What problem does it solve and how? (★★★ · Databricks, OpenAI)
  • Compare MHA vs MQA vs GQA. When would you choose each? (★★☆ · Google, Meta, Anthropic)
  • What is continuous batching and why is it critical for serving? (★★★ · Databricks, Google)
  • How does KV cache quantization work and what are the tradeoffs? (★★★ · Databricks, Meta)