💾 KV Cache & Memory
Without caching, generating one extra token costs as much as reprocessing the entire prompt. Here's how to fix it.
During autoregressive generation, each new token attends to all previous tokens. Without caching, the model recomputes every Key and Value vector at every step: at step t, it projects all t tokens to get K and V (O(t·d²) work), then attends over them (O(t·d)). Summed over n steps, that is O(n²·d²) of projection work alone. KV Cache stores previously computed K, V vectors and reuses them, so each step only projects the new token: O(d²) instead of O(t·d²). The attention dot-product itself is still O(t·d) per step (O(n²·d) total, treating d as the head dimension); the cache eliminates redundant K/V recomputation, not the attention. The tradeoff: memory.
Token-by-Token Generation: With vs Without KV Cache
The table below steps through generating "The cat sat on the mat" one token at a time. Without a cache, the model recomputes all previous K, V at every step, for a total of 1+2+3+4+5+6 = 21 KV computes. With a cache, only the new token's K, V are computed: 1+1+1+1+1+1 = 6.
| Step (token) | Without cache | With cache | Running total (no cache) |
|---|---|---|---|
| 1 — The | 1 KV projection | 1 KV projection | 1 |
| 2 — cat | 2 KV projections | 1 KV projection | 3 |
| 3 — sat | 3 KV projections | 1 KV projection | 6 |
| 4 — on | 4 KV projections | 1 KV projection | 10 |
| 5 — the | 5 KV projections | 1 KV projection | 15 |
| 6 — mat | 6 KV projections | 1 KV projection | 21 |
| Total | 21 — O(n²) | 6 — O(n) | 3.5× more at n=6; 1000× more at n=2048 |
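To make the counting concrete, here is a tiny sketch that reproduces the totals in the table (the helper name kv_projections is ours):

```python
def kv_projections(n):
    """KV projection counts for generating n tokens."""
    without_cache = n * (n + 1) // 2  # step t recomputes all t tokens: O(n^2)
    with_cache = n                    # step t projects only the new token: O(n)
    return without_cache, with_cache

for n in (6, 2048):
    w, c = kv_projections(n)
    print(f"n={n}: {w} without cache vs {c} with cache ({w / c:.1f}x)")
# n=6: 21 vs 6 (3.5x)    n=2048: 2,098,176 vs 2,048 (~1024x)
```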
KV Cache Memory Calculator
Adjust model parameters below to see how KV cache memory scales with layers, heads, sequence length, and batch size. Compare MHA vs GQA to see the savings from sharing KV heads.
Example configuration: Llama-2 70B (layers=80, heads=64, kv_heads=8, d_k=128) at seq=4096, batch=1, FP16:

- Model params: 140.0 GB
- KV cache: 1.3 GB
- Total: 141.3 GB, OOM on a single 80 GB GPU

KV cache = 2 × 80 × 8 × 128 × 4,096 × 1 × 2 / 1e9 ≈ 1.3 GB
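The same breakdown as a function (a minimal sketch mirroring the formula above; kv_cache_gb is our name):

```python
def kv_cache_gb(n_layers, n_kv_heads, d_k, seq_len, batch=1, bytes_per_elem=2):
    """KV cache in GB: 2 (K and V) x layers x kv_heads x d_k x seq x batch x bytes."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_elem / 1e9

print(kv_cache_gb(80, 8, 128, 4096))   # Llama-2 70B, GQA: ~1.34 GB
print(kv_cache_gb(80, 64, 128, 4096))  # same model, MHA:  ~10.7 GB
```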
Why Generation Is Memory-Bound
Why is generation memory-bound, not compute-bound? During generation, each step produces just one token. The GPU does a single matrix-vector multiply (not matrix-matrix), giving an arithmetic intensity of roughly 1 FLOP per byte loaded. The bottleneck is loading model weights and KV cache from HBM, not doing math.
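A back-of-the-envelope sketch of why the math loses to the memory traffic, assuming A100-class peak numbers (~312 TFLOP/s FP16, ~2 TB/s HBM bandwidth); these are rough estimates, not benchmarks:

```python
params = 70e9                  # Llama-2 70B
flops = 2 * params             # one multiply-add per weight per decode step
bytes_moved = 2 * params       # every FP16 weight must be read once per step

compute_time = flops / 312e12        # ~0.45 ms of math
memory_time = bytes_moved / 2.0e12   # ~70 ms of HBM reads  <- the bottleneck
print(flops / bytes_moved)           # arithmetic intensity ~1 FLOP/byte
```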
The memory cost of caching: KV cache grows linearly with sequence length, batch size, and number of layers. For large models at long contexts, it can exceed the model parameters themselves.
Three techniques reduce KV cache memory and waste (see the sketch after this list for the paging idea):
- GQA / MQA: share K, V heads across multiple Q heads.
- PagedAttention (vLLM): manage KV cache like OS virtual memory. Non-contiguous pages eliminate fragmentation.
- Continuous batching: when a request finishes, immediately insert a new one, so there is no head-of-line blocking.
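The sketch below illustrates the paging idea behind PagedAttention. It is a toy allocator (names like BlockAllocator are ours, not the vLLM API): each sequence maps logical token positions to fixed-size physical blocks, so KV memory need not be contiguous and freed blocks are recycled immediately, which is also what makes continuous batching cheap.

```python
BLOCK_TOKENS = 16  # tokens per physical KV block

class BlockAllocator:
    """Toy block-table allocator in the spirit of PagedAttention."""
    def __init__(self, n_blocks):
        self.free = list(range(n_blocks))  # pool of physical block ids
        self.tables = {}                   # seq_id -> list of block ids
        self.lengths = {}                  # seq_id -> tokens written so far

    def append_token(self, seq_id):
        n = self.lengths.get(seq_id, 0)
        if n % BLOCK_TOKENS == 0:          # current block full: grab a new one
            self.tables.setdefault(seq_id, []).append(self.free.pop())
        self.lengths[seq_id] = n + 1

    def finish(self, seq_id):
        # Request done: its blocks go straight back to the pool, so a waiting
        # request can start immediately (continuous batching, no pre-allocation)
        self.free.extend(self.tables.pop(seq_id))
        del self.lengths[seq_id]
```

Waste is bounded by one partially filled block per sequence, instead of a full max_seq pre-allocation per request.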
Per-token KV cache size (FP16)
Each token stored in the cache costs:
bytes_per_token = 2 × n_layers × n_kv_heads × d_head × 2 bytes
# Llama-2 70B with GQA (8 KV heads, d_head=128, 80 layers)
bytes_per_token = 2 × 80 × 8 × 128 × 2 = 327,680 bytes ≈ 0.33 MB / token
Scale to batch_size=32, seq_len=4096:
total = 32 × 4096 × 327,680 B ≈ 42.9 GB ← GQA (8 KV heads)
total = 32 × 4096 × 2,621,440 B ≈ 343 GB ← MHA (64 KV heads) 🚫 exceeds A100 80GB
The MHA number exceeds an A100 80GB by 4×. This is why GQA, MQA, and KV quantization are not optional at production batch sizes — they determine whether the workload is physically possible on available hardware.
KV cache quantization is a fourth lever that works orthogonally to GQA and PagedAttention. Rather than reducing the number of KV heads (GQA) or eliminating fragmentation (PagedAttention), it shrinks the per-element precision of cached tensors. KIVI (2024) showed that K and V tensors have very different statistical properties: Keys have per-channel outliers, while Values are smoother, so they benefit from different quantization strategies. Quantizing K per-channel and V per-token to INT2 cuts peak memory by roughly 2.6×, enabling up to 4× larger batch sizes (and 2.35×–3.47× higher throughput) with negligible perplexity impact. FP8 KV cache (supported natively on Hopper-class GPUs) is widely used in production inference stacks, offering a simpler 2× reduction with virtually no quality loss.
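A minimal sketch of per-token KV quantization, using symmetric INT8 for simplicity (KIVI itself uses asymmetric INT2, per-channel for K and per-token for V; the function names here are ours):

```python
import torch

def quantize_per_token(v):
    # v: [seq, d]; one scale per token (row)
    scale = v.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = (v / scale).round().clamp(-127, 127).to(torch.int8)
    return q, scale.half()  # cache int8 values + fp16 scales: ~2x smaller

def dequantize(q, scale):
    return q.float() * scale.float()  # recover values at attention time

v = torch.randn(4096, 128)                 # a cached V tensor for one head/layer
q, s = quantize_per_token(v)
print((dequantize(q, s) - v).abs().max())  # small reconstruction error
```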
Quick check
A Llama-2 70B request uses 1.34 GB of KV cache at seq=4096. At batch=32, how much GPU memory is consumed by KV cache alone — and does it fit on an A100 80 GB?
Why is autoregressive generation memory-bound rather than compute-bound?
Step-by-Step Derivation
KV Cache Memory Formula
Each layer stores a Key and a Value tensor for every token generated so far. The total memory:

KV cache bytes = 2 (K and V) × n_layers × n_kv_heads × d_k × seq_len × batch_size × bytes_per_element
Worked Example: Llama-2 70B (FP16, seq=4096, batch=1)
n_layers = 80, n_kv_heads = 8 (GQA), d_k = 128
KV Cache = 2 × 80 × 8 × 128 × 4096 × 1 × 2 bytes
= 1,342,177,280 bytes
≈ 1.34 GB per request
With MHA (64 KV heads instead of 8):
= 10.7 GB per request (8x larger!)
GQA Memory Savings
GQA groups the n_h query heads into g groups, each sharing one KV head, so the KV cache shrinks by a factor of n_h / g. Llama-2 70B: 64 / 8 = 8× reduction. Mistral 7B: 32 / 8 = 4× reduction.
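At attention time, the shared KV heads are expanded to match the query heads. A minimal sketch (repeat_kv is the conventional name for this helper in Llama-style codebases):

```python
import torch

def repeat_kv(x, n_rep):
    # x: [batch, n_kv_heads, seq, d_k] -> [batch, n_kv_heads * n_rep, seq, d_k]
    return x.repeat_interleave(n_rep, dim=1)

k_cached = torch.randn(1, 8, 4096, 128)  # what the cache stores: 8 KV heads
k_full = repeat_kv(k_cached, 32 // 8)    # what attention sees: 32 heads
print(k_full.shape)                      # torch.Size([1, 32, 4096, 128])
```

Only the 8-head tensors ever live in the cache; the expansion happens on the fly.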
PyTorch: Simple KV Cache
```python
import torch
import torch.nn.functional as F

class KVCache:
    """Simple KV cache for autoregressive generation."""

    def __init__(self, n_layers, n_kv_heads, d_k, max_seq, dtype=torch.float16):
        # Pre-allocate for max sequence length
        self.k = torch.zeros(n_layers, 1, n_kv_heads, max_seq, d_k, dtype=dtype)
        self.v = torch.zeros(n_layers, 1, n_kv_heads, max_seq, d_k, dtype=dtype)
        self.seq_len = 0

    def update(self, layer_idx, k_new, v_new):
        """Append new K, V for one token at one layer.

        Call order per step: update() for every layer, then get(), then step().
        """
        # k_new, v_new: [batch, n_kv_heads, 1, d_k]
        self.k[layer_idx, :, :, self.seq_len:self.seq_len+1, :] = k_new
        self.v[layer_idx, :, :, self.seq_len:self.seq_len+1, :] = v_new

    def get(self, layer_idx):
        """Return cached K, V up to and including the current position."""
        return (
            self.k[layer_idx, :, :, :self.seq_len+1, :],
            self.v[layer_idx, :, :, :self.seq_len+1, :],
        )

    def step(self):
        self.seq_len += 1

def cached_attention(q, K_cached, V_cached, d_k):
    """Single-token attention with KV cache.

    q:        [B, n_heads, 1, d_k] (just the new token)
    K_cached: [B, n_kv_heads, seq, d_k]
    V_cached: [B, n_kv_heads, seq, d_k]
    """
    # GQA: repeat KV heads to match Q heads
    # (omitted for clarity; see the repeat_kv sketch above)
    attn = (q @ K_cached.transpose(-2, -1)) / (d_k ** 0.5)
    attn = F.softmax(attn, dim=-1)
    return attn @ V_cached  # [B, n_heads, 1, d_k]

# Memory: 2 * 80 * 8 * 128 * 4096 * 1 * 2 = 1.34 GB (Llama-2 70B)
```

PyTorch implementation
```python
# KV cache: accumulate past_key_values during autoregressive generation
import torch

def generate_with_kv_cache(model, input_ids, max_new_tokens=50):
    past_key_values = None  # grows each step: tuple of (k, v) per layer
    for _ in range(max_new_tokens):
        outputs = model(
            input_ids=input_ids,
            past_key_values=past_key_values,
            use_cache=True,
        )
        # outputs.past_key_values: tuple of (key, value) per layer
        # shape: (batch, n_kv_heads, seq_len_so_far, d_k)
        past_key_values = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        input_ids = next_token  # only pass the NEW token; the cache handles the rest
        yield next_token
```

Quick check
Mistral 7B uses GQA with 32 Q heads and 8 KV heads (d_k=128, 32 layers). What is the KV cache memory per token in bytes at FP16?
Quick check
Without PagedAttention, serving 100 requests with max_seq=4096 on Llama-2 70B (GQA) pre-allocates how much KV cache memory, and what fraction is typically wasted?
Real-World Numbers
| Model | Attention | KV Cache (FP16, 4K, B=1) | Notes |
|---|---|---|---|
| Llama-2 70B | GQA (8 KV) | 1.34 GB | Would be 10.7 GB with MHA |
| Llama-2 7B | MHA (32 KV) | 2.1 GB | KV cache is larger than 70B with GQA! |
| GPT-4 (estimated) | GQA | ~2-4 GB | Estimated from serving behavior; likely uses GQA + quantized KV |
| PagedAttention (vLLM) | -- | -- | Cuts waste from ~60-80% to <4%; enables continuous batching |
Quick check
At seq=4096, FP16, batch=1: Llama-2 7B (MHA, 32 layers, 32 KV heads, d_k=128) has a larger KV cache than Llama-2 70B (GQA, 80 layers, 8 KV heads). Why?
SOTA 2024: Multi-Head Latent Attention (MLA)
DeepSeek-V2 (May 2024, arxiv:2405.04434) introduced Multi-Head Latent Attention (MLA) — a fundamentally different approach to KV compression that goes beyond GQA. Instead of sharing K/V heads across query groups, MLA projects both K and V into a shared low-rank latent vector per token.
Core idea
For each token, rather than storing n_h separate K vectors and n_h separate V vectors (one per head), MLA stores a single compressed latent c_KV of dimension d_c, where d_c ≪ n_h × d_h. At attention time, the per-head K and V are recovered from c_KV via learned projections.
MLA math + PyTorch sketch
Down-projection (cached): c_KV = W_DKV · x, where W_DKV has shape d_c × d_model and c_KV has dimension d_c.
Up-projection (at attention time): k_i = W_UK_i · c_KV and v_i = W_UV_i · c_KV, per head i.
Only c_KV is stored in the KV cache. The up-projections run on-the-fly during attention, adding compute but slashing memory.
```python
import torch
import torch.nn as nn

class MultiHeadLatentAttention(nn.Module):
    """Simplified MLA as in DeepSeek-V2 (arxiv:2405.04434).

    Causal masking and DeepSeek-V2's decoupled RoPE path are omitted for clarity.
    """

    def __init__(self, d_model=5120, n_heads=128, d_c=512, d_head=128):
        super().__init__()
        self.n_heads = n_heads
        self.d_c = d_c
        self.d_head = d_head
        # Q projection (can also have a low-rank path for Q)
        self.wq = nn.Linear(d_model, n_heads * d_head, bias=False)
        # Down-project to shared latent (what gets cached)
        self.w_dkv = nn.Linear(d_model, d_c, bias=False)
        # Up-project latent -> per-head K and V
        self.w_uk = nn.Linear(d_c, n_heads * d_head, bias=False)
        self.w_uv = nn.Linear(d_c, n_heads * d_head, bias=False)
        self.wo = nn.Linear(n_heads * d_head, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        # Compress to latent; only this gets stored in the KV cache
        c_kv = self.w_dkv(x)  # (B, T, d_c)
        if kv_cache is not None:
            c_kv = torch.cat([kv_cache, c_kv], dim=1)
        new_cache = c_kv  # store the compressed latent, not full K/V
        # Expand to per-head K, V at attention time
        k = self.w_uk(c_kv).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = self.w_uv(c_kv).view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        attn = (q @ k.transpose(-2, -1)) / self.d_head ** 0.5
        out = (attn.softmax(-1) @ v).transpose(1, 2).reshape(B, T, -1)
        return self.wo(out), new_cache
```

MHA vs GQA vs MQA vs MLA — KV cache comparison
| Variant | What is cached per token | Cache size (relative) | Trade-off |
|---|---|---|---|
| MHA | n_h × d_h for K, n_h × d_h for V | 1× (baseline) | Full expressivity, max memory |
| GQA (g=8) | (n_h/8) × d_h for K+V | 0.125× MHA | <0.5% quality loss, no extra compute |
| MQA | 1 × d_h for K+V (shared) | 1/n_h × MHA | Noticeable quality loss at scale |
| MLA (DeepSeek-V2) | d_c latent (d_c ≪ n_h × d_h) | ~0.07× MHA (93.3% reduction reported) | Extra up-projection FLOPs per layer; enables 128K+ context on commodity hardware |
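To put the relative column in absolute terms, a quick per-token calculation with Llama-2-70B-like dimensions and DeepSeek-V2's d_c=512 (illustrative only; real MLA also caches a small decoupled RoPE key, and DeepSeek-V2's 93.3% figure is relative to its own 128-head MHA config):

```python
n_layers, n_heads, d_head, d_c = 80, 64, 128, 512  # FP16: 2 bytes/element

mha = 2 * n_layers * n_heads * d_head * 2  # K and V for every head
gqa = 2 * n_layers * 8 * d_head * 2        # 8 shared KV heads
mqa = 2 * n_layers * 1 * d_head * 2        # one shared KV head
mla = n_layers * d_c * 2                   # one latent per token per layer

for name, b in [("MHA", mha), ("GQA", gqa), ("MQA", mqa), ("MLA", mla)]:
    print(f"{name}: {b / 1e6:.3f} MB/token ({mha / b:.0f}x vs MHA)")
# MHA: 2.621  GQA: 0.328 (8x)  MQA: 0.041 (64x)  MLA: 0.082 (32x)
```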
Cross-reference: Multi-Head Attention module covers MHA vs GQA vs MQA in detail. MLA is the 2024 evolution that replaces the K/V projection matrices entirely with a shared latent bottleneck.
Key Takeaways
What to remember for interviews
1. Without KV cache, generating n tokens costs O(n²) KV recomputations. With cache, each new token computes exactly one new K/V pair: O(n) total.
2. Autoregressive generation is memory-bandwidth bound, not compute-bound: each step does a matrix-vector multiply with near-zero arithmetic intensity.
3. GQA shares K/V heads across multiple query heads, giving up to 8× KV cache reduction (Llama-2 70B: 64 Q heads → 8 KV heads = 8× savings) with negligible quality loss.
4. PagedAttention (vLLM) manages KV cache like OS virtual memory in non-contiguous pages, cutting waste from ~60-80% to under 4% and enabling continuous batching.
5. MLA (DeepSeek-V2/V3, 2024) compresses K+V into a shared low-rank latent per token: a 5–13× smaller KV cache vs MHA, enabling 128K+ context on commodity hardware at the cost of extra per-layer up-projection compute.
Recap quiz
Without KV cache, generating n tokens costs O(n²) total KV projections. What is the exact count for a 6-token sequence, and where does the savings come from?
Why is autoregressive decode memory-bandwidth bound even when the model has trillions of parameters?
Llama-2 70B uses GQA with 64 Q heads and 8 KV heads. At FP16, seq=4096, batch=32, how large is the KV cache — and what would it be with standard MHA?
PagedAttention reduces KV cache waste from ~60–80% to under 4%. What mechanism produces this improvement, and what does it trade off?
FP8 KV cache (vs FP16) halves memory per token. Why can KV values be quantized more aggressively than model weights without similar quality loss?
Further Reading
- Efficient Memory Management for Large Language Model Serving with PagedAttention — The vLLM paper — virtual memory paging for KV cache, eliminating fragmentation and enabling continuous batching.
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — Grouped-query attention — interpolates between MHA and MQA to reduce KV cache size with minimal quality loss.
- KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache — Per-channel INT2 KV cache quantization — 2.35x memory reduction with negligible quality loss, enabling longer contexts on the same GPU.
- Fast Transformer Decoding: One Write-Head is All You Need (MQA) — Shazeer 2019 — multi-query attention reduces KV cache size by sharing one KV head across all query heads. Direct precursor to GQA.
- Lilian Weng's Blog — In-depth posts on KV cache optimization, memory management, and efficient LLM serving.
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — DeepSeek-AI, 2024 — introduces Multi-Head Latent Attention (MLA), compressing KV cache to a shared low-rank latent with 93.3% reduction vs MHA.
Interview Questions
- ★★★ Calculate KV cache memory for Llama-2 70B at 4K context, FP16, batch=1.
- ★★☆ Why is autoregressive generation memory-bound rather than compute-bound?
- ★★★ Explain PagedAttention (vLLM). What problem does it solve and how?
- ★★☆ Compare MHA vs MQA vs GQA. When would you choose each?
- ★★★ What is continuous batching and why is it critical for serving?
- ★★★ How does KV cache quantization work and what are the tradeoffs?