
Transformer Math

Module 6 · The Transformer

🧠 Multi-Head Attention

One head looks at the current word, another at the sentence 12 positions back, a third at syntactic structure. Why do you need 32 of them in parallel?

🎮

Interactive Sandbox

What you’re seeing: multiple attention heads running in parallel — each head projects Q/K/V into a different subspace and produces its own attention pattern, then all heads are concatenated and mixed through W^O. What to try: toggle “Force all heads to same weights” below to collapse head diversity and see why multi-head attention outperforms single-head.

[Diagram] Multi-head attention data flow: the input (d_model) is linearly projected into per-head Q_i, K_i, V_i with d_k = d_model / h, each head computes attention to produce Z_i, the head outputs are concatenated, and W^O (d_model × d_model) projects the result back to d_model.

Multi-head attention runs several attention operations in parallel, each with its own learned Q/K/V projections. Instead of one head trying to capture everything, each head specializes — one might track syntax (subject-verb agreement), another coreference ("it" refers to "cat"), another positional proximity. Their outputs are concatenated and mixed through a final projection matrix W^O.

What to try: compare the 4 heads below — notice how each produces a different attention pattern for the same input. Then toggle "Force all heads to same weights" to see the diversity collapse into duplicated identical heads, eliminating the benefit of multiple heads.

Sentence: "The cat sat on the mat" -- click a head for details

Tokens: 0: The · 1: cat · 2: sat · 3: on · 4: the · 5: mat
💡

The Intuition

A single head's attention pattern is too limited. Multiple heads each focus on different aspects: some capture syntactic dependencies (subject→verb), some handle coreference resolution (it→cat), and some track positional adjacency.

Different heads learn different attention patterns. Research shows some heads specialize in syntactic dependencies (adjacent words), others in semantic relationships (co-reference), and others in positional patterns. This division of labor is why multi-head attention outperforms a single wide attention head — each head becomes an expert in a different type of relationship.

Quick check

Trade-off

A syntactic head and a coreference head both attend to the same input sentence. The syntactic head attends strongly to adjacent tokens; the coreference head links “it” to “cat” three positions earlier. What architectural property makes this possible?

🖼️

Diagram: 8 Heads Running in Parallel

Each head receives the full input X but projects it to its own smaller Q/K/V subspace (d_k = d_model/h). All heads compute attention independently and simultaneously.

[Diagram] Each head applies its own W_Q, W_K, W_V to the shared input X and computes Attn(Q, K, V) in a d_k = d_model / h subspace; the head outputs are concatenated and projected through W^O. All 8 heads run simultaneously — parallel computation, each learning different patterns.
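As a concrete check on the d_k = d_model / h split, here is a minimal shape walkthrough (GPT-2 Small sizes, d_model=768 and h=12, chosen purely for illustration). It only shows the reshape; the real module applies the W_Q/W_K/W_V projections first, as in the code further down.

import torch

B, T, d_model, h = 2, 6, 768, 12     # batch, sequence length, model width, heads
d_k = d_model // h                   # 64 dimensions per head

x = torch.randn(B, T, d_model)                    # full-width input shared by every head
per_head = x.view(B, T, h, d_k).transpose(1, 2)   # [B, h, T, d_k]: one d_k slice per head

print(per_head.shape)       # torch.Size([2, 12, 6, 64])
print(h * d_k == d_model)   # True: the heads tile the model dimension exactly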
🖼️

Diagram: Concat → W^O Projection

Each head outputs a d_k-dimensional vector. They are concatenated into a d_model-wide block, then W^O (d_model × d_model) mixes information across all heads back into d_model.

[Diagram] Eight d_k-wide head outputs h1 … h8 are concatenated into Concat(head_1, ..., head_8) of width d_model = 8 × d_k, then W^O (d × d) produces the output of width d_model with all heads mixed. W^O is the only place heads can share information — removing it isolates them forever.
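To make that last point concrete, here is a tiny sketch (toy dimensions and random matrices, purely illustrative) showing that without W^O a change in head 2's output never reaches head 1's slice of the concatenated vector, while multiplying by W^O mixes every head into every output channel.

import torch

torch.manual_seed(0)
d_model, h = 8, 2
d_k = d_model // h
W_O = torch.randn(d_model, d_model)

head1   = torch.randn(d_k)
head2_a = torch.randn(d_k)
head2_b = torch.randn(d_k)           # a perturbed head-2 output

cat_a = torch.cat([head1, head2_a])  # Concat(head_1, head_2) with no projection
cat_b = torch.cat([head1, head2_b])

# Without W^O, head 1's slice of the output ignores the change in head 2:
print(torch.equal(cat_a[:d_k], cat_b[:d_k]))     # True
# With W^O, every output channel depends on every head:
print(torch.allclose(cat_a @ W_O, cat_b @ W_O))  # False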
🖼️

Diagram: Head Specialization — Same Sentence, Different Patterns

Three heads attending to "The cat sat on the mat". Each head learns a completely different relationship — same input, radically different attention weights. Darker cell = more attention.

[Diagram] Three attention matrices over "The cat sat on the mat": Head 1 (syntactic neighbors), Head 2 (semantic match), Head 3 (positional, +2 offset). Same input — each head learns a completely different attention distribution.
Quick Check

If d_model = 768 and n_heads = 12, what is d_k (dimension per head)?

📐

Step-by-Step Derivation

Multi-Head Attention

Split the $d_{\text{model}}$-dimensional space into $h$ subspaces, each head with dimension $d_k = d_{\text{model}} / h$:

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W^O, \qquad \text{head}_i = \text{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$$

where $W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$

✨ Insight · W^O is not just a dimension-reduction after concatenation — it is the only channel for cross-head information interaction. Without it, each head's output is confined to its own subspace with no way to merge.
PyTorch implementation
# Multi-Head Attention
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=768, n_heads=12):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.W_qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.W_qkv(x).reshape(B, T, 3, self.n_heads, self.d_k)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each [B, h, T, d_k]
        att = (q @ k.transpose(-2, -1)) / self.d_k**0.5
        att = att.softmax(dim=-1)
        out = (att @ v).transpose(1, 2).reshape(B, T, C)  # [B, T, d_model]
        return self.W_o(out)

Feed-Forward Network (FFN)

Classic version with $d_{\text{ff}} = 4\, d_{\text{model}}$:

$$\text{FFN}(x) = \max(0,\; x W_1 + b_1)\, W_2 + b_2$$

Modern version using SwiGLU (Llama, PaLM, Gemma):

$$\text{FFN}_{\text{SwiGLU}}(x) = \left(\text{SiLU}(x W_1) \odot x W_3\right) W_2$$

where $W_1, W_3 \in \mathbb{R}^{d_{\text{model}} \times d_{\text{ff}}}$, $W_2 \in \mathbb{R}^{d_{\text{ff}} \times d_{\text{model}}}$. SwiGLU uses three matrices instead of two, but adjusts the hidden dimension to keep total parameter count similar.

PyTorch implementation
# SwiGLU FFN (used in Llama)
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU_FFN(nn.Module):
    def __init__(self, d_model=768, d_ff=2048):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # gate
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
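The claim that SwiGLU "adjusts the hidden dimension to keep total parameter count similar" is easy to verify with a quick calculation: a classic FFN has 2·d_model·d_ff weights with d_ff = 4·d_model, while SwiGLU has 3·d_model·d_ff, so Llama-style models shrink d_ff to roughly (2/3)·4·d_model (real checkpoints round this to a hardware-friendly multiple). The d_model=4096 below is just an example value.

d_model = 4096

classic_ff = 2 * d_model * (4 * d_model)   # W1 [d, 4d] + W2 [4d, d]

d_ff = int(8 * d_model / 3)                # roughly 2.67 * d_model
swiglu_ff = 3 * d_model * d_ff             # W1, W3 [d, d_ff] + W2 [d_ff, d]

print(f"classic FFN: {classic_ff/1e6:.0f}M params")   # ~134M
print(f"SwiGLU FFN:  {swiglu_ff/1e6:.0f}M params")    # ~134M, matched by shrinking d_ff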

Quick check

Derivation

The original Transformer (Vaswani 2017) sets d_k = d_model/h = 64 with h=8, d_model=512. Why is this constraint (d_k = d_model/h) imposed rather than choosing d_k and h independently?


PyTorch: Multi-Head Attention

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, T, C = x.shape

        # Project and reshape: (B, T, C) -> (B, n_heads, T, d_k)
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention per head
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)

        # Combine heads: (B, n_heads, T, d_k) -> (B, T, C)
        out = (attn @ V).transpose(1, 2).contiguous().view(B, T, C)
        return self.W_o(out)
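A quick smoke test of the module above (reusing the imports and class just defined; values are random, only the shapes matter):

mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)               # [batch, seq_len, d_model]

causal = torch.tril(torch.ones(10, 10))   # lower-triangular causal mask
out = mha(x, mask=causal)

print(out.shape)   # torch.Size([2, 10, 512]): same shape in, same shape out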
🔧

Break It — See What Happens

Force all heads to same weights
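If you want to reproduce this toggle in code, here is a minimal sketch using the full MultiHeadAttention class above (the one with separate W_q, W_k, W_v). It relies on that implementation's layout, where head i owns output rows i·d_k to (i+1)·d_k of each projection matrix.

import torch

torch.manual_seed(0)
mha = MultiHeadAttention(d_model=512, n_heads=8)
d_k = mha.d_k

with torch.no_grad():
    for W in (mha.W_q, mha.W_k, mha.W_v):
        # Copy head 0's rows into every other head: all heads now compute
        # identical Q/K/V, so every attention pattern is an exact duplicate.
        W.weight.copy_(W.weight[:d_k].repeat(mha.n_heads, 1))
        W.bias.copy_(W.bias[:d_k].repeat(mha.n_heads))

x = torch.randn(1, 6, 512)
Q = mha.W_q(x).view(1, 6, mha.n_heads, d_k).transpose(1, 2)   # [1, h, 6, d_k]
print(torch.allclose(Q[:, 0], Q[:, 1]))   # True: head 1 duplicates head 0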
📊

Real-World Numbers

Head counts scale with model size: GPT-2 Small uses h=12 (d_model=768, d_k=64), while GPT-3 175B uses h=96 (d_model=12288, d_k=128). Modern models like Llama-2 70B further add GQA to compress the KV cache:

Llama-2 70B GQA Configuration (actual architecture)

Parameter            Value
Query heads (h)      64
KV groups (g)        8
d_model              8192
d_k per head         128
KV cache savings     8× smaller vs MHA

Parameter distribution per layer — dense transformer approximation (d=8192)

Attention: 4d² = 268M
FFN: 8d² = 537M

Note: these are standard dense-transformer formulas. Llama-2 70B uses SwiGLU FFN (intermediate_size=28672), making FFN ≈704M/layer and ~56B total across 80 layers.
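A quick back-of-envelope check of these numbers (pure arithmetic, using the layer count and intermediate_size quoted above):

d = 8192          # Llama-2 70B d_model
n_layers = 80
d_ff = 28672      # Llama-2 70B SwiGLU intermediate_size

attn_per_layer = 4 * d * d        # W_Q + W_K + W_V + W_O, dense-MHA approximation
ffn_dense      = 8 * d * d        # classic FFN with d_ff = 4d: 2 * d * 4d
ffn_swiglu     = 3 * d * d_ff     # W1, W2, W3 in SwiGLU

print(f"attention/layer:  {attn_per_layer/1e6:.0f}M")        # ~268M
print(f"dense FFN/layer:  {ffn_dense/1e6:.0f}M")             # ~537M
print(f"SwiGLU FFN/layer: {ffn_swiglu/1e6:.0f}M")            # ~705M
print(f"SwiGLU FFN total: {ffn_swiglu*n_layers/1e9:.0f}B")   # ~56B over 80 layers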

💡 Tip · FFN accounts for roughly 67% of each layer's parameters in a standard dense transformer. GQA shrinks the KV cache by 8x, which is key to making 70B model inference feasible on a single machine.
🔬

MHA vs GQA vs MQA — KV Cache Impact

At inference time, every generated token appends one K vector and one V vector per layer to the KV cache. The total KV cache size is:

$$\text{KV cache bytes} = 2 \cdot n_{\text{layers}} \cdot n_{\text{kv heads}} \cdot d_k \cdot \text{seq\_len} \cdot \text{batch} \cdot \text{bytes per value}$$

The factor of 2 is for K and V. With MHA, $n_{\text{kv heads}} = h$. With GQA, $n_{\text{kv heads}} = g$ for group size $g$. With MQA (Shazeer 2019), $n_{\text{kv heads}} = 1$.

Variant      KV heads   KV cache (Llama-2 70B, 4K ctx, batch=8, FP16)   Quality vs MHA
MHA          64         ~80 GB                                          Baseline
GQA (g=8)    8          ~10 GB                                          <0.5% degradation
MQA          1          ~1.25 GB                                        Noticeable degradation

The practical consequence: Llama-2 70B with MHA would require ~80 GB just for the KV cache at 4K context (batch=8, FP16) — a significant fraction of an 80 GB A100's memory budget on top of the model weights. GQA cuts this to ~10 GB, making 70B inference viable on a single GPU. The insight from Ainslie et al. is that K and V patterns are far more redundant across heads than Q patterns — sharing KV across groups of queries loses almost nothing, while sharing Q would lose significant expressivity.
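Plugging the formula above into the Llama-2 70B numbers reproduces the table (a sketch; GB here means 2^30 bytes):

def kv_cache_gb(n_kv_heads, d_k=128, n_layers=80, seq_len=4096, batch=8, bytes_per=2):
    # factor of 2: one K vector and one V vector per token, per layer, per KV head
    total_bytes = 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per
    return total_bytes / 2**30

print(f"MHA (64 KV heads): {kv_cache_gb(64):.1f} GB")   # ~80 GB
print(f"GQA  (8 KV heads): {kv_cache_gb(8):.1f} GB")    # ~10 GB
print(f"MQA  (1 KV head):  {kv_cache_gb(1):.2f} GB")    # ~1.25 GB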

PyTorch: GQA implementation
import torch, torch.nn as nn, torch.nn.functional as F

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model=4096, n_q_heads=32, n_kv_heads=8):
        super().__init__()
        assert n_q_heads % n_kv_heads == 0
        self.n_q, self.n_kv = n_q_heads, n_kv_heads
        self.d_k = d_model // n_q_heads
        self.groups = n_q_heads // n_kv_heads   # queries per KV group

        self.wq = nn.Linear(d_model, n_q_heads  * self.d_k, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_q,  self.d_k).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv, self.d_k).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv, self.d_k).transpose(1, 2)

        # Expand K,V to match Q heads: repeat each KV head 'groups' times
        k = k.repeat_interleave(self.groups, dim=1)  # [B, n_q, T, d_k]
        v = v.repeat_interleave(self.groups, dim=1)

        scores = (q @ k.transpose(-2, -1)) / self.d_k ** 0.5
        out = (scores.softmax(-1) @ v).transpose(1, 2).reshape(B, T, -1)
        return self.wo(out)
💡 Tip · GQA can be retrofitted from a trained MHA checkpoint by mean-pooling the KV head weights within each group, then fine-tuning briefly — this is the "GQA from MHA checkpoint" technique from Ainslie et al. 2023, which avoids training from scratch.
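A minimal sketch of that conversion (random tensors stand in for a trained checkpoint's K projections; the same pooling is applied to V):

import torch

n_q_heads, n_kv_heads, d_k, d_model = 32, 8, 128, 4096
group = n_q_heads // n_kv_heads        # 4 query heads share each KV head

# An MHA checkpoint has one K projection per query head: [n_q_heads, d_k, d_model]
W_k_mha = torch.randn(n_q_heads, d_k, d_model)

# GQA initialization: mean-pool the K projections within each group of query heads
W_k_gqa = W_k_mha.view(n_kv_heads, group, d_k, d_model).mean(dim=1)

print(W_k_gqa.shape)   # torch.Size([8, 128, 4096]), then fine-tune briefly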

KV-cache sizes and quality deltas here are rough model-specific comparisons based on the cited Llama-2 and GQA papers, not a claim about the current 2026 leaderboard.

💡 Tip · 2024 evolution — MLA (Multi-Head Latent Attention): DeepSeek-V2/V3 (2024) go one step further than GQA by replacing the per-head K/V projections entirely with a shared low-rank latent vector $c^{KV}_t \in \mathbb{R}^{d_c}$ where $d_c \ll h \cdot d_k$. This achieves a 93.3% KV cache reduction vs MHA — see the KV Cache module for the full MLA breakdown and MHA vs GQA vs MQA vs MLA comparison table.
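A deliberately simplified sketch of the latent-KV idea, ignoring MLA's decoupled RoPE path and other details covered in the KV Cache module; the dimensions below are illustrative, not DeepSeek's actual configuration.

import torch
import torch.nn as nn

d_model, h, d_k = 4096, 32, 128
d_c = 512                                      # latent width, far smaller than h*d_k = 4096

W_down = nn.Linear(d_model, d_c, bias=False)   # compress each token into a shared latent
W_up_k = nn.Linear(d_c, h * d_k, bias=False)   # expand latent into per-head K
W_up_v = nn.Linear(d_c, h * d_k, bias=False)   # expand latent into per-head V

x = torch.randn(1, 16, d_model)
c = W_down(x)      # [1, 16, 512]: this latent is all the KV cache has to store
k = W_up_k(c)      # per-head K recovered on the fly at attention time
v = W_up_v(c)

print(c.numel(), k.numel() + v.numel())   # 8192 cached values vs 131072 for full K + V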

Quick check

Trade-off

GQA with g=8 in Llama-2 70B cuts KV cache 8× vs MHA, with <0.5% quality degradation. MQA (g=1) cuts it 64× but shows noticeable degradation. What does this asymmetry reveal about information redundancy in attention?

🧠

Key Takeaways

What to remember for interviews

  1. h heads × d_k = d_model — same total parameters as single-head, but parallel specialization.
  2. Each head learns its own Q/K/V projections, so different heads attend to different relationship types.
  3. W^O is the only place heads can share information — it mixes all head outputs back into d_model.
  4. Removing head diversity (all heads same weights) collapses MHA into duplicated identical heads, eliminating the benefit of multiple heads.
🧠

Recap quiz

Derivation

Vaswani et al. use h=8 heads with d_k=64 in a d_model=512 Transformer. If you replaced MHA with a single-head attention keeping d_model=512 and d_k=512, how does the total QKV parameter count change?

Trade-off

A single attention head with d_k = d_model can represent any linear function of input tokens. Why does splitting into h smaller heads with d_k = d_model/h empirically outperform it?

Derivation

Llama-2 70B uses GQA with 64 Q-heads and 8 KV-groups (g=8). At 4K context, batch=8, FP16 (2 bytes), with 80 layers: what is the approximate KV cache size with GQA vs full MHA?

Derivation

GPT-2 Small uses 12 attention heads with d_model=768 (d_k=64). GPT-3 175B uses 96 heads with d_model=12288 (d_k=128). Why does d_k double even though head count increases 8×?

Trade-off

In multi-head attention, what happens to the output if you remove W^O entirely (i.e., just concatenate head outputs without any final linear projection)?

Trade-off

Michel et al. (2019) found that pruning most attention heads at test time causes minimal quality loss. What does this imply about the structure of multi-head attention in practice?

📚

Further Reading

🎯

Interview Questions


MHA vs single-head attention with the same total parameters — what changes?

★★☆
Google · Meta

What is W^O's role in MHA? Can you remove it?

★★☆
Google · OpenAI

Grouped Query Attention (GQA) — why is it the new standard in Llama-2 70B and beyond?

★★★
OpenAI · Anthropic · Databricks

Why is the output projection W^O necessary? What would happen without it?

★★☆
OpenAI · Google

How does GQA reduce memory bandwidth while preserving quality?

★★☆
Google

What patterns do different attention heads learn to specialize in?

★★☆
Anthropic

Derive the memory savings of MQA vs MHA for a model with 32 heads.

★★★
Meta