🧠 Multi-Head Attention
One head looks at the current word, another at the sentence 12 positions back, a third at syntactic structure. Why do you need 32 of them in parallel?
Interactive Sandbox
What you’re seeing: multiple attention heads running in parallel — each head projects Q/K/V into a different subspace and produces its own attention pattern, then all heads are concatenated and mixed through W^O. What to try: toggle “Force all heads to same weights” below to collapse head diversity and see why multi-head attention outperforms single-head.
Multi-head attention runs several attention operations in parallel, each with its own learned Q/K/V projections. Instead of one head trying to capture everything, each head specializes — one might track syntax (subject-verb agreement), another coreference ("it" refers to "cat"), another positional proximity. Their outputs are concatenated and mixed through a final projection matrix W^O.
What to try: compare the 4 heads below — notice how each produces a different attention pattern for the same input. Then toggle "Force all heads to same weights" to see the diversity collapse into duplicated identical heads, eliminating the benefit of multiple heads.
Sentence: "The cat sat on the mat" -- click a head for details
The Intuition
A single head's attention pattern is too limited. Multiple heads each focus on different aspects: some capture syntactic dependencies (subject→verb), some handle coreference resolution (it→cat), and some track positional adjacency.
Different heads learn different attention patterns. Research shows some heads specialize in syntactic dependencies (adjacent words), others in semantic relationships (co-reference), and others in positional patterns. This division of labor is why multi-head attention outperforms a single wide attention head — each head becomes an expert in a different type of relationship.
Quick check
A syntactic head and a coreference head both attend to the same input sentence. The syntactic head attends strongly to adjacent tokens; the coreference head links “it” to “cat” three positions earlier. What architectural property makes this possible?
Diagram: 8 Heads Running in Parallel
Each head receives the full input X but projects it to its own smaller Q/K/V subspace (d_k = d_model/h). All heads compute attention independently and simultaneously.
Diagram: Concat → W^O Projection
Each head outputs a d_k-dimensional vector. They are concatenated into a d_model-wide block, then W^O (d_model × d_model) mixes information across all heads back into d_model.
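As a concrete shape check, here is a minimal sketch of the concat-then-project step (assuming the original paper's sizes, d_model = 512 and h = 8; the tensor names and the 6-token sequence are illustrative, not from any specific codebase):

import torch
import torch.nn as nn

d_model, h = 512, 8
d_k = d_model // h                              # 64 dims per head

# Suppose each of the 8 heads has produced its own (T, d_k) output for T = 6 tokens
head_outputs = [torch.randn(6, d_k) for _ in range(h)]

concat = torch.cat(head_outputs, dim=-1)        # (6, 512): heads sit side by side
W_O = nn.Linear(d_model, d_model, bias=False)   # the output projection
mixed = W_O(concat)                             # (6, 512): information mixed across heads

print(concat.shape, mixed.shape)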
Diagram: Head Specialization — Same Sentence, Different Patterns
Three heads attending to "The cat sat on the mat". Each head learns a completely different relationship — same input, radically different attention weights. Darker cell = more attention.
If d_model = 768 and n_heads = 12, what is d_k (dimension per head)?
Step-by-Step Derivation
Multi-Head Attention
Split the d_model-dimensional space into h subspaces, each head with dimension d_k = d_model / h:

MultiHead(Q, K, V) = Concat(head_1, …, head_h) W^O

where head_i = Attention(X W_i^Q, X W_i^K, X W_i^V), with W_i^Q, W_i^K, W_i^V ∈ R^(d_model × d_k) and W^O ∈ R^(d_model × d_model)
PyTorch implementation
# Multi-Head Attention
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=768, n_heads=12):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.W_qkv(x).reshape(B, T, 3, self.n_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each [B, h, T, d_k]
        att = (q @ k.transpose(-2, -1)) / self.d_k**0.5
        att = att.softmax(dim=-1)
        out = (att @ v).transpose(1, 2).reshape(B, T, C)  # [B, T, d_model]
        return self.W_o(out)

Feed-Forward Network (FFN)
Classic version with a ReLU activation and d_ff = 4 · d_model:

FFN(x) = max(0, x W_1 + b_1) W_2 + b_2

Modern version using SwiGLU (Llama, PaLM, Gemma):

FFN_SwiGLU(x) = (SiLU(x W_1) ⊙ x W_3) W_2

where W_1, W_3 ∈ R^(d_model × d_ff) and W_2 ∈ R^(d_ff × d_model). SwiGLU uses three matrices instead of two, but adjusts the hidden dimension to keep total parameter count similar.
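To see why the parameter budget stays roughly constant, here is a small plain-Python sketch. The 2/3 scaling of the hidden size is the usual compensation from the GLU-variants literature, quoted here as an assumption rather than from this article; real models round the result to a hardware-friendly multiple.

d_model = 4096

# Classic FFN: two matrices with hidden size 4 * d_model
d_ff_classic = 4 * d_model
classic_params = 2 * d_model * d_ff_classic        # W_1 + W_2

# SwiGLU FFN: three matrices, hidden size scaled by 2/3 to compensate
d_ff_swiglu = int(2 / 3 * 4 * d_model)             # ~10922 before rounding
swiglu_params = 3 * d_model * d_ff_swiglu          # W_1 + W_3 + W_2

print(f"classic: {classic_params / 1e6:.1f}M, SwiGLU: {swiglu_params / 1e6:.1f}M")
# classic: 134.2M, SwiGLU: 134.2M (roughly the same budget)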
PyTorch implementation
# SwiGLU FFN (used in Llama)
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU_FFN(nn.Module):
    def __init__(self, d_model=768, d_ff=2048):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection (passed through SiLU)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

Quick check
The original Transformer (Vaswani 2017) sets d_k = d_model/h = 64 with h=8, d_model=512. Why is this constraint (d_k = d_model/h) imposed rather than choosing d_k and h independently?
PyTorch: Multi-Head Attention
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        # Project and reshape: (B, T, C) -> (B, n_heads, T, d_k)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        # Combine heads: (B, n_heads, T, d_k) -> (B, T, C)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)

Break It — See What Happens
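The interactive "Force all heads to same weights" toggle can be reproduced offline. Below is a minimal sketch (assuming the MultiHeadAttention class defined just above; the token count and variable names are illustrative): it copies head 0's rows of W_q, W_k, and W_v into every other head's slice, so all heads compute identical attention patterns and concatenation just repeats the same d_k-dimensional block eight times.

import torch

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=512, n_heads=8)
d_k = mha.d_k  # 64

# Collapse head diversity: every head's projection slice becomes a copy of head 0's
with torch.no_grad():
    for W in (mha.W_q, mha.W_k, mha.W_v):
        w0, b0 = W.weight[:d_k].clone(), W.bias[:d_k].clone()  # rows 0..63 -> head 0
        for h in range(1, mha.n_heads):
            W.weight[h * d_k:(h + 1) * d_k] = w0
            W.bias[h * d_k:(h + 1) * d_k] = b0

# Recompute per-head attention maps for a random 6-token input
x = torch.randn(1, 6, 512)
Q = mha.W_q(x).view(1, 6, mha.n_heads, d_k).transpose(1, 2)
K = mha.W_k(x).view(1, 6, mha.n_heads, d_k).transpose(1, 2)
attn = ((Q @ K.transpose(-2, -1)) / d_k ** 0.5).softmax(dim=-1)

print(torch.allclose(attn[0, 0], attn[0, 7]))  # True: every head now attends identically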
Real-World Numbers
Head counts scale with model size: GPT-2 Small uses 12 heads (d_model=768, d_k=64), while GPT-3 175B uses 96 heads (d_model=12288, d_k=128). Modern models like Llama-2 70B further add GQA to compress the KV cache:
Llama-2 70B GQA Configuration (actual architecture)
| Parameter | Value |
|---|---|
| Query heads (h) | 64 |
| KV groups (g) | 8 |
| d_model | 8192 |
| d_k per head | 128 |
| KV cache savings | 8x smaller vs MHA |
Parameter distribution per layer — dense transformer approximation (d=8192)
Note: these are standard dense-transformer formulas. Llama-2 70B uses SwiGLU FFN (intermediate_size=28672), making FFN ≈704M/layer and ~56B total across 80 layers.
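A quick back-of-the-envelope check of that FFN figure, as a plain-Python sketch (the numbers are the ones quoted in the note above):

d_model, d_ff, n_layers = 8192, 28672, 80

# SwiGLU FFN per layer: three matrices, gate (d_model x d_ff), up (d_model x d_ff), down (d_ff x d_model)
ffn_per_layer = 3 * d_model * d_ff
print(f"FFN params per layer:  {ffn_per_layer / 1e6:.1f}M")             # ~704.6M
print(f"FFN params, 80 layers: {ffn_per_layer * n_layers / 1e9:.1f}B")  # ~56.4B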
MHA vs GQA vs MQA — KV Cache Impact
At inference time, every generated token appends one K vector and one V vector per layer to the KV cache. The total KV cache size is:

KV cache = 2 × n_layers × n_kv × d_k × seq_len × batch_size × bytes_per_value

The factor of 2 is for K and V. With MHA, n_kv = h. With GQA, n_kv = g for group size h/g. With MQA (Shazeer 2019), n_kv = 1.
| Variant | KV heads | KV cache (Llama-2 70B, 4K ctx, batch=8, FP16) | Quality vs MHA |
|---|---|---|---|
| MHA | 64 | ~80 GB | Baseline |
| GQA (g=8) | 8 | ~10 GB | <0.5% degradation |
| MQA | 1 | ~1.25 GB | Noticeable degradation |
The practical consequence: Llama-2 70B with MHA would require ~80 GB just for the KV cache at 4K context (batch=8, FP16) — a significant fraction of an 80 GB A100's memory budget on top of the model weights. GQA cuts this to ~10 GB, making 70B inference viable on a single GPU. The insight from Ainslie et al. is that K and V patterns are far more redundant across heads than Q patterns — sharing KV across groups of queries loses almost nothing, while sharing Q would lose significant expressivity.
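Those cache sizes follow directly from the formula above. Here is a small sketch to reproduce them (plain Python, using the Llama-2 70B shapes already quoted in this section):

n_layers, d_k, ctx_len, batch, bytes_fp16 = 80, 128, 4096, 8, 2

def kv_cache_gib(n_kv_heads):
    # 2 (K and V) x layers x KV heads x d_k x tokens x batch x bytes per value
    return 2 * n_layers * n_kv_heads * d_k * ctx_len * batch * bytes_fp16 / 2**30

print(f"MHA (64 KV heads): {kv_cache_gib(64):.1f} GiB")   # ~80
print(f"GQA  (8 KV heads): {kv_cache_gib(8):.1f} GiB")    # ~10
print(f"MQA  (1 KV head):  {kv_cache_gib(1):.2f} GiB")    # ~1.25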
PyTorch: GQA implementation
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model=4096, n_q_heads=32, n_kv_heads=8):
        super().__init__()
        assert n_q_heads % n_kv_heads == 0
        self.n_q, self.n_kv = n_q_heads, n_kv_heads
        self.d_k = d_model // n_q_heads
        self.groups = n_q_heads // n_kv_heads  # queries per KV group
        self.wq = nn.Linear(d_model, n_q_heads * self.d_k, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_q, self.d_k).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv, self.d_k).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv, self.d_k).transpose(1, 2)
        # Expand K,V to match Q heads: repeat each KV head 'groups' times
        k = k.repeat_interleave(self.groups, dim=1)  # [B, n_q, T, d_k]
        v = v.repeat_interleave(self.groups, dim=1)
        scores = (q @ k.transpose(-2, -1)) / self.d_k ** 0.5
        out = (scores.softmax(-1) @ v).transpose(1, 2).reshape(B, T, -1)
        return self.wo(out)

KV-cache sizes and quality deltas here are rough model-specific comparisons based on the cited Llama-2 and GQA papers, not a claim about the current 2026 leaderboard.
Quick check
GQA with g=8 in Llama-2 70B cuts KV cache 8× vs MHA, with <0.5% quality degradation. MQA (g=1) cuts it 64× but shows noticeable degradation. What does this asymmetry reveal about information redundancy in attention?
Key Takeaways
What to remember for interviews
1. h heads × d_k = d_model — same total parameters as single-head, but parallel specialization.
2. Each head learns its own Q/K/V projections, so different heads attend to different relationship types.
3. W^O is the only place heads can share information — it mixes all head outputs back into d_model.
4. Removing head diversity (all heads same weights) collapses MHA into duplicated identical heads, eliminating the benefit of multiple heads.
Recap quiz
Vaswani et al. use h=8 heads with d_k=64 in a d_model=512 Transformer. If you replaced MHA with a single-head attention keeping d_model=512 and d_k=512, how does the total QKV parameter count change?
A single attention head with d_k = d_model can represent any linear function of input tokens. Why does splitting into h smaller heads with d_k = d_model/h empirically outperform it?
Llama-2 70B uses GQA with 64 Q-heads and 8 KV-groups (g=8). At 4K context, batch=8, FP16 (2 bytes), with 80 layers: what is the approximate KV cache size with GQA vs full MHA?
GPT-2 Small uses 12 attention heads with d_model=768 (d_k=64). GPT-3 175B uses 96 heads with d_model=12288 (d_k=128). Why does d_k double even though head count increases 8×?
In multi-head attention, what happens to the output if you remove W^O entirely (i.e., just concatenate head outputs without any final linear projection)?
Michel et al. (2019) found that pruning most attention heads at test time causes minimal quality loss. What does this imply about the structure of multi-head attention in practice?
Further Reading
- Attention Is All You Need — Vaswani et al. 2017 — introduced multi-head attention with 8 heads, each projecting to d_k = 64.
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints — Ainslie et al. 2023 — grouped-query attention used by LLaMA 2 70B. Balances MHA quality with MQA speed.
- Fast Transformer Decoding: One Write-Head is All You Need (MQA) — Shazeer 2019 — multi-query attention shares a single KV head across all query heads. Used by PaLM and Falcon.
- The Illustrated Transformer — Jay Alammar — Visual walkthrough of multi-head attention — shows how heads split dimensions and recombine outputs.
- 3Blue1Brown — Attention in Transformers — Grant Sanderson's animated breakdown of how multiple attention heads capture different relationship types.
- Transformer Explainer (Georgia Tech) — Interactive visualization — see all attention heads running simultaneously on real text.
- Are Sixteen Heads Really Better than One? — Michel et al. 2019 — shows most attention heads are redundant and can be pruned with minimal quality loss. Challenges assumptions about head count.
- In-context Learning and Induction Heads (Olsson et al.) — Anthropic research showing a specific two-head circuit is the mechanism behind in-context learning — the clearest example of head specialization.
Interview Questions
MHA vs single-head attention with the same total parameters — what changes? ★★☆
What is W^O's role in MHA? Can you remove it? ★★☆
Grouped Query Attention (GQA) — why is it the new standard in Llama-2 70B and beyond? ★★★
Why is the output projection W^O necessary? What would happen without it? ★★☆
How does GQA reduce memory bandwidth while preserving quality? ★★☆
What patterns do different attention heads learn to specialize in? ★★☆
Derive the memory savings of MQA vs MHA for a model with 32 heads. ★★★