
Transformer Math

Module 0 · The Transformer

🏗️ High-Level Overview

GPT-3 predicts each token in 6ms — but processes the entire 96-layer forward pass to do it. Why can’t it just skip layers it already ‘knows’?


The Complete Transformer Pipeline

This interactive diagram shows every stage a token passes through in a decoder-only Transformer. Click any stage to jump to its dedicated module. Hover for a quick summary.

Input: "The cat sat"raw text stringTokenizer (BPE)[n]"The" → 464, "cat" → 3797, "sat" → 3332Embedding Lookup[n, d]Token IDs → d-dim vectors+ Positional Encoding[n, d]sin/cos or RoPE positional signalTransformer Blockx N layersLayerNorm[n, d]Multi-Head Attention[n, d]Q, K, V → Scaled Dot-Product+residualLayerNorm[n, d]FFN (SwiGLU)[n, d]d → 4d → d (or 8d/3 for SwiGLU)+residualGPT-2: N=12, d=768 | Llama-3 8B: N=32, d=4096Final LayerNorm[n, d]Linear (Unembedding)[n, |V|]Project to vocab size (50k-128k)Softmax[n, |V|]Logits → probabilitiesNext token: "on" (p=0.34)argmax or sample from distribution[464, 3797, 3332]Thecatsatd=768Thecatsatcausal maskon (34%)down (21%)there (9%)TokenizerEmbedPosEncAttnFFNNormLinearSoftmax
📐

Architecture at a Glance

What you’re seeing: the data flow from raw tokens through one (or N stacked) Transformer blocks to output logits. Dashed accent lines are the residual bypasses — information flows around each sublayer, not just through it.

What to try: trace a single token from top to bottom. Notice that the shape of the vector never changes — every sublayer returns a vector of size d_model, added back to the stream. The model scales by increasing N (layers) and d_model (width), not by changing the shape.

[Block diagram] Input Tokens ("the cat sat") → Token Embedding (lookup + pos enc) → [× N layers: Multi-Head Attention (Q, K, V projections) → Add & Norm (residual + LayerNorm) → Feed-Forward FFN (d → 4d → d, GELU) → Add & Norm (residual + LayerNorm)] → Output Logits (Linear → Softmax), with residual bypasses around each sublayer.
💡

What Breaks Without Each Component?

The Transformer is a sequence-to-sequence machine that converts a list of tokens into a probability distribution over the next token. Every modern LLM — GPT-4, Claude, Llama, Gemini — is built from repeated Transformer blocks (with variations like MoE or GQA). But why this particular combination of components? Each one fixes a specific failure mode.

Attention = Routing

Self-attention decides which tokens talk to which. Each token computes a query (“what am I looking for?”) and attends to keys from all previous tokens (“who has it?”). Without attention, the model is blind to context — “bank” in “river bank” and “bank account” would be indistinguishable after embedding.
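
A minimal single-head causal attention sketch (illustrative only: toy dimensions, random weights, and hypothetical names like W_q; the real model uses the multi-head module shown further down):

Scaled dot-product attention with a causal mask (minimal sketch)

python
import torch
import torch.nn.functional as F

def causal_attention(x, W_q, W_k, W_v):
    # x: (seq_len, d) token vectors; each position may attend only to itself
    # and earlier positions (the causal mask blocks the future).
    q, k, v = x @ W_q, x @ W_k, x @ W_v               # each (seq_len, d_k)
    scores = q @ k.T / k.shape[-1] ** 0.5             # (seq_len, seq_len)
    future = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))
    weights = F.softmax(scores, dim=-1)               # rows sum to 1: "who do I read from?"
    return weights @ v                                # weighted mix of value vectors

d, d_k = 768, 64
x = torch.randn(3, d)                                 # "The cat sat" -> 3 positions
W_q, W_k, W_v = (torch.randn(d, d_k) * 0.02 for _ in range(3))
print(causal_attention(x, W_q, W_k, W_v).shape)       # torch.Size([3, 64])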

FFN = Memory

The feed-forward network applies the same MLP to each token independently. Research (Geva et al., 2021) shows FFNs behave like key-value memories: the first layer detects patterns, the second retrieves associated facts. “Paris is the capital of ___” is answered by FFN neurons that activate on “capital of France” patterns.
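
A toy reading of that key-value view (a sketch with random weights, not the paper's exact setup): each row of the first FFN matrix acts as a pattern detector ("key"), and the corresponding column of the second matrix is the vector ("value") written back to the stream when that pattern fires.

FFN as key-value memory (toy sketch)

python
import torch
import torch.nn.functional as F

d_model, d_ff = 768, 3072
W_in  = torch.randn(d_ff, d_model) * 0.02   # rows act like "keys" (pattern detectors)
W_out = torch.randn(d_model, d_ff) * 0.02   # columns act like "values" (what to write back)

x = torch.randn(d_model)                    # one token's residual-stream vector
match = F.gelu(W_in @ x)                    # how strongly each of the 3072 keys fires
ffn_out = W_out @ match                     # weighted sum of the corresponding value vectors

# Same computation, written as an explicit memory lookup:
# ffn_out = sum_i gelu(key_i · x) * value_i
print(ffn_out.shape)                        # torch.Size([768])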

Residual = Highway

Residual connections (He et al., 2016) create a direct gradient highway from the loss back to every layer. Without them, gradients must flow through 96+ matrix multiplications — any eigenvalue slightly below 1 causes exponential vanishing. With residuals, early layers remain trainable even at GPT-3 scale.
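
A back-of-envelope check on that claim (assuming the gradient norm simply multiplies by each layer's Jacobian spectral norm, with no residual bypass):

Gradient shrinkage through 96 layers without residuals

python
# If every layer's Jacobian has spectral norm s < 1, the gradient reaching
# layer 1 is at most s**96 times the gradient at layer 96.
for s in (0.95, 0.9, 0.8):
    print(f"s = {s}: gradient shrinks to ~{s**96:.1e} of its original size")
# s = 0.95: ~7.3e-03   s = 0.9: ~4.1e-05   s = 0.8: ~4.9e-10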

LayerNorm = Stabilizer

LayerNorm rescales each token’s activation vector to mean 0, std 1 before each sublayer. Without it, activations amplify with depth — attention logits become unbounded, softmax collapses to one-hot, and gradients vanish. Pre-Norm (normalize before the sublayer, not after) keeps the residual path clean, enabling stable training without learning-rate warmup.
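
LayerNorm written out by hand (a sketch of the same computation nn.LayerNorm performs over the last dimension):

LayerNorm from scratch

python
import torch

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize each token's d_model values to mean 0 and unit variance,
    # then apply the learned per-dimension scale (gamma) and shift (beta).
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * gamma + beta

x = torch.randn(1, 16, 768) * 50                     # deliberately oversized activations
gamma, beta = torch.ones(768), torch.zeros(768)
out = layer_norm(x, gamma, beta)
print(out.mean(-1).abs().max(), out.std(-1).mean())  # ≈ 0 and ≈ 1 for every token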

✨ Insight · The entire architecture can be summarized as: tokenize → embed → (normalize → attend → add → normalize → FFN → add) × N → project → softmax. Nearly all of the compute is matrix multiplication: the Transformer is fundamentally a composition of linear maps interleaved with simple nonlinearities (softmax, GELU, LayerNorm).

The residual stream (Elhage et al., 2021) is the backbone: each attention head and FFN sublayer writes additively into a shared d-dimensional vector. This means individual layers can be pruned with graceful degradation — the stream carries redundant information.

Worked Example: tracing “cat” through one Transformer block

Model: GPT-2 small — d_model=768, 12 heads, d_k=d_v=64 per head.

  1. Embed. The token ID for “cat” is looked up in the embedding table (vocab × 768). Add the positional embedding for position t. Result: a single vector x ∈ ℝ⁷⁶⁸ — this is the residual stream state entering the block.
  2. Pre-LayerNorm. Compute x̂ = LayerNorm(x). This normalizes the 768 values to mean≈0, std≈1, then applies learned scale γ and bias β. The original x is untouched — LayerNorm only affects the copy passed into attention.
  3. Multi-Head Attention. Project into 12 separate Q/K/V triplets (each 64-dim via a 768→64 linear). Each head computes scaled dot-product attention over all previous positions. The 12 outputs (each 64-dim) are concatenated → 768-dim → projected through W_O (768×768). Result: attn_out ∈ ℝ⁷⁶⁸.
  4. Residual add. x ← x + attn_out. The stream for “cat” is updated in place. If attention learned nothing useful, attn_out ≈ 0 and x passes through unchanged — the residual is the “default path”.
  5. Pre-LayerNorm → FFN. Normalize again: x̂ = LayerNorm(x). Pass through the FFN: 768 → 3072 → 768 with GELU activation. This is where most per-token computation happens — the FFN expands to 4× width to mix features.
  6. Residual add. x ← x + ffn_out. After 12 blocks of this, the final x ∈ ℝ⁷⁶⁸ is projected to vocab size (50257 for GPT-2) and softmaxed. The probability assigned to the next token is the output.

Key shape invariant: x stays (batch, seq_len, 768) throughout all N blocks — residuals guarantee this.
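
The same six steps as a shape trace (a sketch with random weights; nn.MultiheadAttention stands in for the 12-head projection in step 3, and the causal mask is omitted since it does not affect shapes):

Shape trace through one GPT-2-sized block

python
import torch
import torch.nn as nn

d_model, n_heads, d_ff, vocab = 768, 12, 3072, 50257
x = torch.randn(1, 3, d_model)                    # "The cat sat": (batch, seq, d_model)

ln1, ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))
unembed = nn.Linear(d_model, vocab, bias=False)

attn_out, _ = attn(ln1(x), ln1(x), ln1(x))        # steps 2-3: (1, 3, 768)
x = x + attn_out                                  # step 4: residual add, still (1, 3, 768)
x = x + ffn(ln2(x))                               # steps 5-6: still (1, 3, 768)

logits = unembed(x)                               # after the final block: (1, 3, 50257)
probs = logits[:, -1].softmax(dim=-1)             # distribution over the next token
print(x.shape, probs.shape)                       # (1, 3, 768) and (1, 50257)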

Quick check

Trade-off

Geva et al. (2021) show FFN layers act as key-value memories. Which component plays which role?

Quick Check

What is the primary role of the FFN layers in a transformer?

Parameter Counting

Where do all those billions of parameters actually live? The formula is simpler than it looks — everything reduces to a few matrix-multiplication shapes.

1. Token Embedding Table

One vector of size d_model per vocabulary token. For GPT-2: 50 257 × 768 ≈ 38.6M params.

2. Attention (per layer)

4 weight matrices, each d×d

Each of the four projection matrices (Q, K, V, output) is d × d, giving 4d² attention parameters per layer. With h heads, each per-head projection is d × (d/h), but the total stays 4d².

3. FFN (per layer)

Two matrices, d × d_ff and d_ff × d: 2 · d · d_ff = 8d² parameters per layer with the standard d_ff = 4d.

4. Total

Total ≈ V·d + L · (4d² + 8d²) = V·d + 12Ld², ignoring biases and LayerNorm (< 0.1% of params).

For GPT-2 Small (L=12, d=768, V=50 257):

GPT-2 Small worked example: V·d + 12Ld² = 50 257 × 768 + 12 × 12 × 768² ≈ 38.6M + 84.9M ≈ 124M parameters.

This figure closely matches the parameter count reported in the official GPT-2 paper (the simplified formula omits positional embeddings and biases, which account for <1% of total params).

The minimal Pre-Norm block below matches GPT-2, Llama, and every modern decoder-only model. Note the two residual adds and that LayerNorm always precedes the sublayer.

One Transformer block — Pre-Norm with residual connections

python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    def __init__(self, d_model=768, n_heads=12, d_ff=3072):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x, attn_mask=None):
        # Pre-Norm: LayerNorm BEFORE sublayer, residual wraps around it
        normed = self.ln1(x)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=attn_mask)
        x = x + attn_out          # residual add

        normed = self.ln2(x)
        x = x + self.ffn(normed)  # residual add
        return x

# GPT-2 small: d_model=768, 12 heads → d_k=64 per head
block = TransformerBlock(d_model=768, n_heads=12, d_ff=3072)
x = torch.randn(1, 16, 768)   # batch=1, seq_len=16, d_model=768
causal_mask = nn.Transformer.generate_square_subsequent_mask(16)
out = block(x, attn_mask=causal_mask)  # shape: (1, 16, 768)

Verify your formula against real model weights:

Parameter counting — formula vs real model

python
import torch
import torch.nn as nn

def count_transformer_params(vocab_size, d_model, n_layers, d_ff=None):
    """
    Parameter breakdown for a standard decoder-only Transformer.
    Formula: V*d + L*(4d^2 + 8d^2) = V*d + L*12d^2
    """
    if d_ff is None:
        d_ff = 4 * d_model  # standard FFN expansion ratio

    embedding  = vocab_size * d_model            # token embeddings only
    per_layer  = (
        4 * d_model**2 +   # attention: W_Q, W_K, W_V, W_O each d×d
        2 * d_model * d_ff  # FFN: d→d_ff and d_ff→d
    )
    total = embedding + n_layers * per_layer
    return total

# GPT-2 Small: L=12, d=768, V=50257
params = count_transformer_params(
    vocab_size=50257, d_model=768, n_layers=12
)
print(f"GPT-2 Small: {params / 1e6:.1f}M params")  # → ~124M

# GPT-3: L=96, d=12288, V=50257
params_gpt3 = count_transformer_params(
    vocab_size=50257, d_model=12288, n_layers=96
)
print(f"GPT-3: {params_gpt3 / 1e9:.1f}B params")  # → ~175B

# Verify against a real model
from transformers import GPT2Model
model = GPT2Model.from_pretrained("gpt2")
real = sum(p.numel() for p in model.parameters())
print(f"GPT-2 actual: {real / 1e6:.1f}M")  # → 124.4M
Advanced: SwiGLU FFN changes the formula

Llama 2/3 replace the standard 2-matrix FFN with SwiGLU, which uses three matrices (gate, up, down) at reduced width to preserve parameter count:

SwiGLU FFN parameters per layer: 3 · d · d_ff with d_ff ≈ 8d/3, so 3 · d · (8d/3) = 8d², the same total as the standard two-matrix FFN, just spread across three matrices.

The parameter-count formula holds approximately regardless: SwiGLU shrinks d_ff (to roughly 8d/3, as noted in the pipeline diagram above) to compensate for the extra matrix.
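
A minimal SwiGLU FFN sketch (assuming Llama-style gate/up/down names; the hidden width d_ff shrinks to roughly 8d/3 so the three matrices total about 8d², like the standard two-matrix FFN; real configs then round d_ff up to a hardware-friendly multiple):

SwiGLU FFN (parameter-matched sketch)

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        if d_ff is None:
            d_ff = int(8 * d_model / 3)          # reduced width keeps 3*d*d_ff ≈ 8d²
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up   = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        # SiLU-gated branch multiplied elementwise with a plain linear branch
        return self.down(F.silu(self.gate(x)) * self.up(x))

d = 4096
standard = 2 * d * (4 * d)                                      # 2-matrix FFN: 8d²
swiglu = sum(p.numel() for p in SwiGLUFFN(d).parameters())      # 3 matrices, reduced d_ff
print(f"standard: {standard/1e6:.0f}M, SwiGLU: {swiglu/1e6:.0f}M")  # both ≈ 134M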

Quick check

Derivation

For a decoder-only model with L layers and d_model = d, which component dominates parameter count as d scales (assuming V is fixed and d_ff = 4d)?

🔥

Break It

Toggle these ablations to see what each component actually contributes. These aren’t hypothetical — real training runs with these ablations fail within the first thousand steps.

Remove Residual Connections
Remove LayerNorm
Remove Attention (FFN-only model)

Quick check

Derivation

A 96-layer model is trained without residual connections. Gradients must pass through all 96 Jacobians sequentially. If each layer’s Jacobian has spectral norm 0.99, what is the approximate gradient magnitude at layer 1 relative to layer 96?

📊

Real Numbers

Every number below comes from the original papers or official technical reports (except GPT-4, whose figures are community estimates). The scaling trend is clear: parameter counts grow by orders of magnitude across model generations, and training-token counts grow even faster.

| Model | Layers (L) | d_model | Heads | Params | Training Tokens |
| --- | --- | --- | --- | --- | --- |
| GPT-2 Small | 12 | 768 | 12 | 124M | not reported |
| GPT-2 XL | 48 | 1600 | 25 | 1.5B | not reported |
| GPT-3 | 96 | 12 288 | 96 | 175B | 300B |
| Llama 2 7B | 32 | 4096 | 32 | 7B | 2T |
| Llama 2 70B | 80 | 8192 | 64 | 70B | 2T |
| Llama 3 8B | 32 | 4096 | 32 | 8B | 15T |
| Llama 3 70B | 80 | 8192 | 64 | 70B | 15T |
| GPT-4 (est.) | ~120 | ~12 288 | ~96 | ~1.8T (MoE) | unknown |

GPT-4 numbers are community estimates — not confirmed by OpenAI. All other numbers are from official papers or technical reports.

✨ Insight · Llama 3 trains 8B params on 15T tokens — nearly 1 900× its parameter count. The Chinchilla scaling law (Hoffmann et al., 2022) says the optimal token count is roughly 20× the parameter count. At 8B params that is ~160B tokens — Llama 3 goes far beyond that, trading extra training compute for a smaller, higher-quality model that runs cheaply at inference.
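
The arithmetic behind that trade-off (a quick sanity check, using the common reading of Chinchilla as roughly 20 training tokens per parameter):

Chinchilla-optimal vs actual Llama 3 8B training tokens

python
params = 8e9                          # Llama 3 8B parameters
chinchilla_tokens = 20 * params       # ≈ 160B tokens for compute-optimal training
actual_tokens = 15e12                 # 15T tokens actually used (Llama 3 report)

print(f"Chinchilla-optimal: {chinchilla_tokens / 1e9:.0f}B tokens")
print(f"Actually trained:   {actual_tokens / 1e12:.0f}T tokens, "
      f"about {actual_tokens / chinchilla_tokens:.0f}x beyond the optimum")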

Quick check

Trade-off

Llama 3 8B trains on 15T tokens, vastly beyond the Chinchilla-optimal ~160B. An interviewer asks: “Is this wasteful compute?” What is the correct response?

🧠

Key Takeaways

What to remember for interviews

  1. Every modern LLM is a stack of identical Transformer blocks: LayerNorm → Multi-Head Attention → residual add → LayerNorm → FFN → residual add.
  2. Attention routes information between token positions; FFN stores factual knowledge per-token; residuals are gradient highways; LayerNorm stabilizes activation scale.
  3. The residual stream is the backbone — each sublayer writes additively into a shared d-dimensional vector, never replacing it, preserving earlier information.
  4. Parameter count follows V·d + 12Ld²: GPT-2 Small (L=12, d=768, V=50 257) = 124M, GPT-3 (L=96, d=12288) = 175B.
  5. Pre-Norm (LayerNorm before each sublayer) keeps the residual path clean and enables stable training at 96+ layers without learning-rate warmup.
📚

Further Reading

🎯

Interview Questions


Walk through the full forward pass of a decoder-only Transformer. What happens at each stage from raw text input to next-token probability?

★★☆
Google · Anthropic · OpenAI

Why do modern Transformers use Pre-Norm (LayerNorm before sublayer) instead of Post-Norm (after)?

★★☆
Google · Anthropic

What is the causal mask in self-attention and why is it necessary for autoregressive generation?

★☆☆
Google · OpenAI · Meta

Compare the parameter count and compute distribution across the main components of a Transformer. Where do most parameters live?

★★★
Google · Meta · Anthropic
🧠

Recap Quiz

Derivation

GPT-2 Small has L=12, d_model=768, V=50 257. Using the formula Total = V·d + 12Ld², what is the approximate parameter count?

Trade-off

Llama 3 8B was trained on 15T tokens — far beyond the Chinchilla-optimal ~160B tokens. What is the primary engineering reason to over-train this way?

Derivation

At what sequence length does attention compute (O(n²d)) overtake FFN compute (O(nd²)) for GPT-3 (d=12 288)?

Trade-off

Why does Pre-Norm (LayerNorm before each sublayer) enable stable training without learning-rate warmup, while Post-Norm requires warmup?

Derivation

In a standard Transformer, what fraction of parameters live in FFN layers (vs attention) for GPT-3 style config (d_ff = 4d, full attention)?

Trade-off

Self-attention without positional encoding treats all token positions equivalently. What is the precise term for this property?

Derivation

A model has the same L=12, d=768 config as GPT-2 Small but uses Grouped Query Attention (GQA) with 1 key-value head per 12 query heads. How does the attention parameter count change?
