🎯 Self-Attention
GPT-4 reads ‘The trophy didn’t fit in the suitcase because it was too big.’ What does ‘it’ refer to? Humans know instantly. Without attention, models can’t.
Interactive Sandbox
What you’re seeing: the scaled dot-product attention computation — each token emits a Query, Key, and Value vector; the diagram traces how one token’s Query scores against every other token’s Key to produce attention weights.
Scaled dot-product attention is how each token decides which other tokens are relevant. Every token produces a Query ("what am I looking for?"), a Key ("what do I contain?"), and a Value ("what information do I carry?"). The attention weight between two tokens is the dot product of Query and Key, scaled by √d_k, then passed through softmax.
What to try: select different tokens to see their attention distribution across the sentence. Notice how "it" attends strongly to its antecedent. Then toggle "Remove sqrt(d_k) scaling" below to watch the weights collapse to near one-hot.
Calculation steps (Query token: "sat", d_k = 64):
Raw scores (Q·K): [-1.77, -1.89, 3.74, 1.11, -1.04, -0.36]
Scaled (÷√64 = 8.0): [-0.22, -0.24, 0.47, 0.14, -0.13, -0.04]
After softmax: [0.1299, 0.1280, 0.2586, 0.1862, 0.1423, 0.1550]
Why Every Token Spies on Every Other
Every token needs to "borrow" information from other tokens to understand itself. For example, the word "bank" means completely different things in "river bank" vs. "bank account" — the model needs to look at surrounding words to decide.
The essence of attention is: each token asks every other token "how relevant are you to me?", then computes a weighted sum based on relevance. The weight bars you saw in the visualization above represent exactly this "relevance score".
If d_k = 64, what is the scaling factor √d_k?
Worked Example: One Token at a Time
Before the matrix math, let's walk through attention for one specific word — step by step, with numbers, in the style of Jay Alammar's Illustrated Transformer walkthrough.
Step 1 — Create Q, K, V vectors
Take the sentence "The cat sat". Each word has an embedding vector (d_model = 4 for simplicity). We multiply each embedding by three weight matrices to get three new vectors per word, each of dimension d_k = 3:
//"The" embedding × W_Q → q₁ = [1.2, 0.8, 0.3]
//"The" embedding × W_K → k₁ = [0.5, 1.1, 0.9]
//"The" embedding × W_V → v₁ = [0.2, 0.7, 1.0]
//Same for "cat" → q₂, k₂, v₂
//Same for "sat" → q₃, k₃, v₃
Step 2 — Score: How much should "The" attend to each word?
To compute attention for "The" (position 1), we take its Query vector and dot it with every Key vector:
score("The", "The") = q₁ · k₁ = 1.2×0.5 + 0.8×1.1 + 0.3×0.9 = 1.75
score("The", "cat") = q₁ · k₂ = 2.81
score("The", "sat") = q₁ · k₃ = 1.03
Higher score = more relevant. "cat" is most relevant to "The" (makes sense — "The cat").
Step 3 — Scale by √d_k, then Softmax
Divide scores by √d_k = √3 ≈ 1.73 (our d_k is 3 in this example), then apply softmax to get probabilities that sum to 1:
scaled: [1.75/1.73, 2.81/1.73, 1.03/1.73] = [1.01, 1.62, 0.60]
softmax → [0.29, 0.53, 0.18]
"The" pays 53%attention to "cat", 29% to itself, 18% to "sat". Without scaling, these scores would be more extreme (near one-hot), killing gradients.
Step 4 — Weighted sum of Value vectors
Multiply each Value vector by its softmax weight, then sum:
output₁ = 0.29 × v₁ + 0.53 × v₂ + 0.18 × v₃
= 0.29 × [0.2, 0.7, 1.0]
+ 0.53 × [0.4, 0.2, 0.8]
+ 0.18 × [0.9, 0.5, 0.3]
= [0.43, 0.40, 0.77]
This output vector for "The" is a blend of all tokens, weighted by relevance. It's now "context-aware" — it knows about "cat" and "sat".
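If you want to check the arithmetic yourself, here is a minimal PyTorch sketch of Steps 3–4. It starts from the raw scores and Value vectors given above (the full q/k vectors for "cat" and "sat" were not listed, so the scores are taken as given); small differences from the hand-rounded numbers are expected.

# Minimal check of Steps 3-4 using the numbers from the walkthrough above.
import torch
import torch.nn.functional as F

scores = torch.tensor([1.75, 2.81, 1.03])        # q1·k1, q1·k2, q1·k3 from Step 2
d_k = 3
weights = F.softmax(scores / d_k ** 0.5, dim=-1) # ≈ [0.29, 0.53, 0.19]

V = torch.tensor([[0.2, 0.7, 1.0],               # v1 ("The")
                  [0.4, 0.2, 0.8],               # v2 ("cat")
                  [0.9, 0.5, 0.3]])              # v3 ("sat")
output_1 = weights @ V                           # ≈ [0.44, 0.40, 0.76]
print(weights, output_1)                         # matches the hand calculation up to rounding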
From vectors to matrices
We just computed attention for one token. In practice, we do all tokens simultaneously by stacking Q, K, V into matrices and using one formula:
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
Every row of the output matrix is exactly the weighted sum we computed above — just parallelized across all tokens at once. This is why Transformers are fast on GPUs.
Step-by-Step Derivation
Step 1: Project into Q, K, V
The input matrix X is projected into Query, Key, and Value through three trainable weight matrices:
Q = XW_Q,  K = XW_K,  V = XW_V
Step 2: Compute similarity (dot product)
scores = QKᵀ, an n×n matrix where entry (i, j) measures how well token i's Query matches token j's Key.
Step 3: Scale by √d_k
When d_k = 64, the scaling factor is √64 = 8.
Step 4: Softmax → weighted sum
weights = softmax(QKᵀ / √d_k); the output is weights·V, so every token's output is a weighted average of all Value vectors.
Full formula (the one to memorize)
Attention(Q, K, V) = softmax(QKᵀ / √d_k) V
Quick check
You increase d_k from 64 to 256 (keeping everything else fixed). By what factor does the standard deviation of an unscaled dot product change?
PyTorch implementation
# Scaled Dot-Product Attention
import torch
import torch.nn.functional as F

def attention(Q, K, V, d_k):
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # [n, n]
    weights = F.softmax(scores, dim=-1)             # [n, n]
    return weights @ V                              # [n, d_v]

# In practice: Q, K, V come from linear projections
d_model, d_k = 768, 64
W_q = torch.nn.Linear(d_model, d_k, bias=False)
W_k = torch.nn.Linear(d_model, d_k, bias=False)
W_v = torch.nn.Linear(d_model, d_k, bias=False)

x = torch.randn(10, d_model)       # 10 tokens, 768-dim
Q, K, V = W_q(x), W_k(x), W_v(x)   # each [10, 64]
out = attention(Q, K, V, d_k)      # [10, 64]

PyTorch: Scaled Dot-Product Attention
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (batch, seq_len, d_k)
    mask: broadcastable to (batch, seq_len, seq_len), or None
    """
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)
    return output, attn_weights

# Usage with PyTorch's built-in (may dispatch to a FlashAttention backend):
# output = F.scaled_dot_product_attention(Q, K, V, attn_mask=mask)

Flash Attention — IO-Aware Exact Attention
Standard attention materializes the full attention matrix in GPU HBM (high-bandwidth memory). At seq_len=4096 with float16, that is 4096² × 2 bytes ≈ 32 MB per head — and it must be read and written multiple times. This makes attention memory-bandwidth bound, not compute bound.
Flash Attention (Dao et al. 2022) rewrites the attention kernel to never materialize the full matrix. Instead, it tiles Q, K, V into SRAM-resident blocks and fuses the softmax + matmul into a single pass, using the online softmax trick to accumulate the correct result incrementally:
Online softmax (per block):
For each Q block, iterate over K/V blocks. Track a running row max m and a running normalizer ℓ, and update the output incrementally without ever storing the full score row.
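A minimal PyTorch sketch of that accumulation (one head, no batching, arbitrary block size). This is the math the kernel relies on, not the fused CUDA kernel itself, but it shows why the blockwise result equals full softmax attention exactly:

# Sketch: online-softmax (streaming) attention, mathematically equal to the naive version.
import torch

def attention_streaming(Q, K, V, block=128):
    n, d_k = Q.shape
    out = torch.zeros(n, V.shape[-1])
    m = torch.full((n, 1), float('-inf'))    # running row-wise max of scores
    l = torch.zeros(n, 1)                    # running softmax normalizer
    for start in range(0, n, block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / d_k ** 0.5            # scores against this K/V block
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        scale = torch.exp(m - m_new)         # rescale previously accumulated sums
        p = torch.exp(s - m_new)
        l = l * scale + p.sum(dim=-1, keepdim=True)
        out = out * scale + p @ Vb
        m = m_new
    return out / l                           # equals softmax(QKᵀ/√d_k) @ V

# Check against the naive computation:
Q, K, V = torch.randn(512, 64), torch.randn(512, 64), torch.randn(512, 64)
naive = torch.softmax(Q @ K.T / 64 ** 0.5, dim=-1) @ V
print(torch.allclose(attention_streaming(Q, K, V), naive, atol=1e-5))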
| Property | Standard Attention | Flash Attention |
|---|---|---|
| Memory (HBM) | O(n²) | O(n) |
| FLOPs | O(n²d) | O(n²d) — same |
| HBM reads/writes | O(nd + n²) | Θ(n²d²/M), M = SRAM size |
| Speedup in practice | — | 2–4× on A100 (reported) |
Quick check
FlashAttention reports 2–4× speedup on A100 GPUs without changing FLOPs. A colleague proposes: “Then we can train with double the batch size for the same wall-clock time.” Is this correct?
Scaling Attention Beyond a Single GPU
Flash Attention solves the memory problem within a single GPU. Two influential 2023 system papers pushed that idea further — one for multi-GPU training at million-token scale, and one for high-throughput inference serving.
Ring Attention (Liu et al., 2023): With very long sequences, even the KV cache no longer fits on one device. Ring Attention distributes the sequence across a ring of GPUs. Each device holds one chunk of tokens, computes local attention, and passes KV blocks around the ring until every device has seen every block. The result is still exact attention, but memory per device scales with the local chunk length (roughly n divided by the number of devices) instead of the full sequence length n.
Paged Attention (Kwon et al., 2023): During inference, preallocating one contiguous KV buffer per request wastes a lot of memory early in generation. PagedAttention splits the KV cache into fixed-size blocks and uses a page table to map logical sequence positions to physical pages on demand. That reduces both internal and external fragmentation and, in the vLLM paper, improved throughput by roughly 2–4× versus prior baselines at similar latency.
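To make the page-table idea concrete, here is a toy sketch (names, block size, and structure are illustrative assumptions, not vLLM's actual implementation): physical KV blocks are claimed from a shared pool only as tokens are generated, so memory tracks the actual sequence length rather than a preallocated maximum.

# Toy sketch of PagedAttention-style KV block management (illustrative only).
BLOCK_SIZE = 16                      # tokens per physical KV block (arbitrary choice)
free_blocks = list(range(1024))      # shared pool of physical block ids

class Request:
    def __init__(self):
        self.block_table = []        # logical block index -> physical block id

    def kv_slot(self, position):
        # A physical block is claimed only when a logical block fills up, so memory
        # grows with the tokens actually generated, not the maximum context length.
        if position % BLOCK_SIZE == 0 and position // BLOCK_SIZE == len(self.block_table):
            self.block_table.append(free_blocks.pop())
        return self.block_table[position // BLOCK_SIZE], position % BLOCK_SIZE

req_a, req_b = Request(), Request()
print(req_a.kv_slot(0), req_b.kv_slot(0))   # two requests draw different physical blocks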
Scaling notes above are historical snapshots of the cited 2022–2023 papers, not a claim about the current 2026 SOTA frontier.
Break It — See What Happens
Scaled dot-product attention has exactly two invariants that keep it numerically stable: the √d_k scaling and the softmax normalization. Remove either one and the mechanism fails, not gracefully but catastrophically, in ways that invalidate the entire training run. Understanding why each invariant is load-bearing is more useful for interviews than memorizing the formula.
The scaling failure is purely statistical. When q and k are each drawn from N(0, 1), the dot product q · k is a sum of d_k independent unit-variance terms, so its variance is d_k. At d_k = 64, the standard deviation is 8; at d_k = 128, it is √128 ≈ 11.3. Softmax exponentiates these scores: e^11 ≈ 60,000 while e^1 ≈ 2.7, so a single high-scoring key captures virtually all the probability mass. The output collapses to a near-hard lookup of the top key rather than a weighted average across all keys, and gradient flow to the other (near-zero weight) positions vanishes. Vaswani et al. 2017 identified this explicitly in §3.2.1 and introduced the 1/√d_k divisor as the fix.
The softmax failure is a normalization failure. Without it, raw attention weights are unnormalized dot products that can take any sign and any magnitude. The output at position i becomes Σ_j (q_i · k_j) · v_j, a sum with O(n) terms whose total weight grows with sequence length. At n = 4096 (a common context in 2023-era models), the output is thousands of times larger than at n = 1. Gradient magnitudes cascade through every subsequent layer: the feed-forward block, the residual add, and the next attention layer all see exploding activations. Training diverges long before loss meaningfully decreases.
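You can reproduce both failure modes in a few lines. This is an illustrative sketch (shapes and seed chosen arbitrarily), not a training run:

# Sketch: observe both failure modes numerically.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_k, n = 128, 4096
q, k, v = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

# 1. Scaling failure: unscaled dot products have std ≈ √d_k, so softmax saturates.
raw = q @ k.T
print(raw.std())                                     # ≈ sqrt(128) ≈ 11.3
print(F.softmax(raw[0], dim=-1).max())               # one or two keys hog the mass
print(F.softmax(raw[0] / d_k ** 0.5, dim=-1).max())  # far smaller: a soft distribution

# 2. Softmax-removal failure: output magnitude grows with sequence length n.
no_norm = (raw / d_k ** 0.5) @ v                     # unnormalized weights
normed = F.softmax(raw / d_k ** 0.5, dim=-1) @ v
print(no_norm.abs().mean(), normed.abs().mean())     # the first is vastly larger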
The sandbox defaults to d_k = 64; try d_k = 1 to see attention degrade to a simple unscaled similarity lookup.
Quick check
You remove softmax entirely from attention — outputs are now raw weighted sums with unnormalized weights. Beyond the obvious (weights don’t sum to 1), what is the most serious training-time consequence?
What this teaches
The √d_k scaling is not a cosmetic detail; it is an architectural constraint that every production transformer ships with, because removing it makes the model untrainable beyond small d_k. GPT-2 (d_k = 64), GPT-3 (d_k = 128), and Llama-2 70B (d_k = 128) all use the same divisor. FlashAttention (Dao et al. 2022) preserved exact numerical equivalence with vanilla attention precisely so the scaling and softmax semantics are unchanged: Flash is an IO optimization, not a change to the attention formula.
From a production standpoint, these two invariants explain a concrete operational risk: any low-precision quantization scheme that compresses the QK dot products before the scaling divisor is applied will amplify rounding error by up to √d_k relative to post-scaling quantization. This is why quantization toolkits (bitsandbytes, GPTQ) apply quantization to the projection weights W_Q, W_K rather than to the raw activations inside the attention kernel: the kernel's numerical contract depends on stable dot-product magnitudes, and the 1/√d_k divisor only restores that stability if the inputs are in roughly the right range. Pre-scale quantization breaks the contract; post-scale (or weight-only) quantization preserves it.
The 43 GB You Can’t Afford to Recompute
Every number in the table below encodes a deliberate engineering decision. The progression from the original Transformer (d_k=64, 8 heads, 6 layers) to GPT-3 (d_k=128, 96 heads, 96 layers) to Llama-3 70B (d_k=128, 64 query heads, 80 layers) is not arbitrary scaling; each choice reflects constraints in memory, compute, and serving economics that became apparent only at production scale. Reading the table as a progression tells you what the field learned between 2017 and 2023.
The most diagnostic column is n_heads. GPT-3's 96 heads with d_k=128 give d_model=12,288. That is an unusually wide model: the original Transformer used just 8 heads at d_model=512. The width expansion multiplies KV cache size proportionally, which is exactly why, by the time Llama-2 70B was designed, grouped-query attention (GQA) had become necessary: it reduces KV cache by 8× relative to full multi-head attention while retaining 64 query heads for expressive power. The shift from MHA → GQA was not a quality choice; it was a serving economics choice.
| Model | d_model | n_heads | d_k | n_layers |
|---|---|---|---|---|
| Llama-3 70B | 8192 | 64 | 128 | 80 |
| GPT-3 175B | 12288 | 96 | 128 | 96 |
| Llama-2 7B | 4096 | 32 | 128 | 32 |
| Mistral 7B | 4096 | 32 | 128 | 32 |
| Original Transformer | 512 | 8 | 64 | 6 |
What these numbers imply for production
KV cache memory at GPT-3 scale. Each forward pass for GPT-3 at seq_len=2,048 stores K and V tensors for every layer and every head. The arithmetic is exact: 2 (K and V) × 96 layers × 96 heads × 2,048 tokens × 128 dims × 2 bytes ≈ 9.7 GB per request. That is 9.7 GB of GPU memory consumed just to hold one user's conversation state, before any model weights are considered. At 350 GB of weights in bf16, a single A100 node (8×80 GB = 640 GB) can serve only ~30 concurrent users at max context before KV cache alone saturates memory. This is the precise reason FlashAttention (Dao et al. 2022) and PagedAttention (vLLM) were immediately adopted in production: FlashAttention avoids materializing the full n×n attention matrix in HBM, and PagedAttention manages KV cache as virtual memory pages so unused capacity can be reallocated across requests.
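The same arithmetic as a throwaway calculation (kv_cache_bytes is a hypothetical helper; the dimensions come from the table above):

# Sketch: KV-cache size for one GPT-3-scale request (fp16).
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, bytes_per_elem=2):
    # 2 tensors per layer (K and V), each of shape [n_kv_heads, seq_len, d_k]
    return 2 * n_layers * n_kv_heads * seq_len * d_k * bytes_per_elem

gpt3 = kv_cache_bytes(n_layers=96, n_kv_heads=96, d_k=128, seq_len=2048)
print(f"{gpt3 / 1e9:.1f} GB")   # ≈ 9.7 GB, matching the figure above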
Attention at 100K context. The quadratic cost is latent at short sequences but lethal at long ones. A 100K-token context produces an attention matrix of 100,000 × 100,000 = 10B elements per head per layer. At fp16, that is 20 GB per head, or nearly 2 TB per layer across 96 heads — impossible to hold in HBM even for a single layer. Standard dense attention was never a viable design for 100K+ contexts; sparse attention (Longformer, BigBird), linear approximations (Performer), and ultimately FlashAttention's tiling approach (which avoids storing the full matrix at all) were each direct engineering responses to this constraint. Any interview question about long-context serving is really asking you to reason from this number.
Per-head compute cost and why MQA/GQA were invented. Each attention head requires two O(n²d_k) matrix multiplies: QKᵀ, of shape (n×d_k)·(d_k×n), and the weighted sum over V. Across 96 heads and 96 layers at GPT-3 scale, the attention FLOPs alone scale as O(n² × n_heads × n_layers × d_k), and at long contexts this dominates FFN compute. The insight behind Multi-Query Attention (MQA, Shazeer 2019) was that sharing a single K and V head across all query heads reduces KV cache by a factor of n_heads with only a small quality regression. Grouped-Query Attention (GQA, Ainslie et al. 2023) generalized this to g groups, letting practitioners pick the quality/memory tradeoff: Llama-2 70B's g=8 achieves 8× KV savings while retaining 64 independent query heads for expressivity (the sketch below works through the numbers).
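A quick sketch of the GQA effect at Llama-2-70B-like dimensions (80 layers, d_k=128, 4K context, fp16; the helper name is again hypothetical):

# Sketch: KV cache with full MHA (64 KV heads) vs GQA with g=8 KV heads.
def kv_bytes(n_kv_heads, n_layers=80, d_k=128, seq_len=4096, bytes_per_elem=2):
    return 2 * n_layers * n_kv_heads * seq_len * d_k * bytes_per_elem

mha, gqa = kv_bytes(n_kv_heads=64), kv_bytes(n_kv_heads=8)
print(f"MHA: {mha / 1e9:.1f} GB   GQA: {gqa / 1e9:.1f} GB   ({mha / gqa:.0f}x smaller)")
# MHA: 10.7 GB   GQA: 1.3 GB   (8x smaller)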
What this means for an interview answer. When asked to design a serving system for a 70B model, you can anchor the conversation immediately: 9.7 GB of KV cache per GPT-3-scale request at 2K context, scaling linearly with context. At 100K context that becomes ~475 GB per request, impossible on any single GPU, which forces either chunked prefill (process the context in segments), offloading the KV cache to CPU/NVMe, or a sparse attention variant that avoids materializing the full matrix. Name the mechanism (FlashAttention for dense attention, PagedAttention for multi-user serving, GQA for KV compression) and state what constraint each one solves. Deriving these numbers from first principles, rather than reciting them from memory, signals to the interviewer that you understand the system, not just the buzzwords.
Quick check
GPT-3 has 96 attention heads, d_model=12288 (d_k=128), and 96 layers. At seq_len=2048 in float16, what is the total KV-cache size for a single request?
Key Takeaways
What to remember for interviews
1. Attention is a soft dictionary lookup: every token queries all keys and returns a weighted average of values
2. Scaling by √d_k prevents softmax saturation: without it, large d_k drives dot-product magnitudes up as √d_k, saturating the softmax and killing gradients
3. Time complexity is O(n²d): the QKᵀ matrix multiply is quadratic in sequence length
4. Causal masking (−∞ in the upper triangle) enforces the autoregressive property for decoder-only generation (see the sketch below)
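Causal masking (the last takeaway above) is easy to see in code. A minimal single-head sketch with toy sizes:

# Sketch: causal mask construction for decoder-only attention.
import torch
import torch.nn.functional as F

n, d_k = 5, 64
Q, K, V = torch.randn(n, d_k), torch.randn(n, d_k), torch.randn(n, d_k)

scores = Q @ K.T / d_k ** 0.5
causal_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float('-inf'))   # -inf above the diagonal
weights = F.softmax(scores, dim=-1)   # row i attends only to positions <= i
print(weights)                        # upper triangle is exactly 0 after softmax
out = weights @ V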
Recap Quiz
Self-attention computes scores = QKᵀ. If q,k ~ N(0,1) and d_k=128, what is the standard deviation of an unscaled dot product, and why does this break softmax?
Self-attention on a sequence of n=4096 tokens with d_k=128 computes the QKᵀ matrix. What is the memory footprint of that matrix in float16, and what is the FLOPs complexity class?
An interviewer asks: “What happens to attention patterns when you remove the √d_k divisor during inference but keep the trained weights?” What is the correct answer?
FlashAttention achieves 2–4× speedup on A100s without reducing FLOPs. What is the mechanism, and what does it reveal about the attention computation bottleneck?
In an encoder-decoder Transformer (e.g., T5), the decoder has both self-attention and cross-attention layers. What is the key structural difference in where Q, K, V come from in each?
GPT-3 175B uses 96 attention heads with d_model=12288, giving d_k=128 per head. What happens to each head’s representational capacity if you halve the number of heads to 48 (keeping d_model fixed)?
Linear attention replaces the softmax in Attention(Q,K,V) with a feature map φ(·) that satisfies K(q,k) ≈ φ(q)ᵀφ(k). This reduces complexity from O(n²d) to O(nd²). What is the fundamental quality tradeoff?
Further Reading
- Attention Is All You Need — Vaswani et al. 2017 — the paper that introduced scaled dot-product attention and the Transformer architecture.
- The Illustrated Transformer — Jay Alammar's visual walkthrough — the gold standard explanation of attention mechanics with step-by-step diagrams.
- The Illustrated BERT — Jay Alammar — how BERT reuses the Transformer encoder with bidirectional attention and masked language modeling for pretraining.
- 3Blue1Brown — Attention in Transformers — Grant Sanderson's visual deep-dive into attention — exceptional mathematical intuition with animations.
- Lilian Weng — Attention? Attention! — Lilian Weng's comprehensive survey of attention mechanisms from seq2seq to Transformers.
- Transformer Explainer (Georgia Tech) — Interactive visual explanation of GPT-2 running live in the browser — great for seeing attention weights.
- LLM Visualization — Brendan Bycroft — Step-by-step 3D walkthrough of a GPT model — trace every tensor through the forward pass.
- Andrej Karpathy — Let's Build GPT — Codes scaled dot-product self-attention from scratch in ~50 lines of PyTorch — essential companion for internalizing Q, K, V.
- A Mathematical Framework for Transformer Circuits (Elhage et al.) — Anthropic interpretability team's mechanistic view of attention heads as composable circuits — how heads actually specialize in practice.
- In-Context Learning and Induction Heads — Olsson et al. 2022 — induction heads as the mechanism behind in-context learning, emerging as a phase change during training.
- Chris Olah — Attention and Augmented Recurrent Neural Networks — Olah & Carter 2016 — the original visual explainer of attention mechanisms before transformers.
Interview Questions
- Derive the attention mechanism from scratch on a whiteboard.
- ★★★ Why do we scale by √d_k? Walk through the variance argument.
- ★★★ What happens to softmax output when the input values are very large? Think about the gradient...
- What is the time and space complexity of self-attention?
- ★★☆ If d_k = 1, what does attention degrade to?
- ★★★ Explain attention as a "soft dictionary lookup" analogy.
- ★★☆ What happens to attention when the softmax temperature approaches zero?
- What does each row of softmax(QKᵀ/√d_k) summing to 1 mean semantically?
- ★★☆ What are the limitations of standard attention? What alternatives exist?
- ★★★ How does causal masking work in decoder-only models? Why is it needed?
- ★★☆ What is the difference between additive attention (Bahdanau) and dot-product attention? Why did Transformers choose dot-product?
- ★★☆ Can you apply attention to non-sequence data? Give an example.
- ★★☆ What is cross-attention and how does it differ from self-attention?