
Transformer Math

Module 9 · The Transformer

🏗️ The Full Forward Pass

What happens to the word "cat" in 0.003 seconds?


You know the individual components -- embeddings, attention, normalization, FFN. Now see how they connect. The forward pass is the complete journey a token takes from input to predicted next token. Understanding this end-to-end flow is essential for debugging, optimizing, and explaining Transformer behavior in interviews.

🔄

The Residual Stream

What you're seeing: The residual stream view from Elhage et al. 2021 (Anthropic). The thick horizontal line is the shared communication highway — a vector of dimension d_model that flows left to right through the entire network. At each Transformer block, Attention and FFN each branch off, do their computation, and add their result back to the stream. The stream itself is never replaced — only incremented.

Key insight: Every component is a residual writer. Removing any one component degrades the model slightly, but the stream still flows. This is why layers can be pruned and why the model is robust to individual component failures.

What to try: Trace the stream from input embedding to final output and count how many additive writes happen per layer.

Diagram — the residual stream (d_model): Token Embed + Position → Block 1 [Multi-Head Attn (+), FFN SwiGLU (+)] → Block 2 [Multi-Head Attn (+), FFN SwiGLU (+)] → · · · (N blocks) → Final LN → Logits. Each component reads from and writes to the stream additively.
🎮

Token Journey Through the Transformer

What you're seeing: The complete path a token takes through a decoder-only Transformer. Each box is a processing stage. The dashed bracket marks the block that repeats N times (32 for Llama-2 7B, 80 for 70B).

What to try: Click any stage to see what happens there. Notice the two residual connections (orange) inside each block -- these are the gradient highways. The final LayerNorm and linear projection convert the last hidden state into vocabulary logits.

Interactive diagram — token journey: embedding → positional encoding → Transformer Block (repeated N times: Attention + residual, FFN + residual) → final LayerNorm → linear projection → logits.
💡

The Intuition

The Residual Stream View: Instead of thinking of layers as sequential processing stages, think of the model as a shared residual stream. Each attention head and FFN sublayer reads from this stream, processes information, and writes its contribution back. The stream accumulates features across all layers.

Diagram 1 — One token's journey through the full forward pass

"cat"raw textToken ID2845Embed[d=4096]+ Pos EncRoPEBlock × NAttn + FFNFinal LNRMSNormLinear→ vocabSoftmax→ P(next)"mat" = 0.31Token "cat" traveling through the complete Transformer forward pass

Training (Teacher Forcing): All target tokens are known upfront. The model processes the entire sequence in parallel using a causal mask to prevent future peeking. Loss is computed at every position simultaneously: "given tokens 1..t, predict token t+1." This is why training is compute-bound (massive parallel matmuls).
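
A minimal sketch of the teacher-forcing objective (stand-in tensors in place of a real model, so it runs as-is): every position's next-token prediction is scored in one parallel pass.

import torch
import torch.nn.functional as F

B, T, V = 2, 8, 32000                        # batch, sequence length, vocab size
tokens = torch.randint(0, V, (B, T))         # ground-truth token IDs [B, T]
logits = torch.randn(B, T, V)                # stand-in for model(tokens) -> [B, T, V]

# Teacher forcing: positions 0..T-2 each predict the token at the next position.
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, V),           # predictions, flattened to [B*(T-1), V]
    tokens[:, 1:].reshape(-1),               # targets: the next token at each position
)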

Inference (Autoregressive): Tokens are generated one at a time. Each new token requires a full forward pass through all layers. The KV cache stores previous keys/values to avoid recomputation.

Diagram 2 — Autoregressive generation: each output feeds back as the next input

"The" → Fwd Pass (all N layers) → "cat" (appended) · "The cat" → Fwd Pass (all N layers) → "sat" (appended) · "The cat sat" → Fwd Pass (all N layers) → "on". Input context → Transformer → output token → next input.
✨ Insight · Training sees the whole movie at once (parallel). Inference watches one frame at a time (sequential). The causal mask is what makes parallel training equivalent to sequential generation -- each position can only see the past.
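
A minimal sketch of the autoregressive loop (a hypothetical stand-in model returning random logits, so it runs as-is; without a KV cache, every step re-runs the full forward pass over the whole context):

import torch

V = 32000
model = lambda ids: torch.randn(1, ids.shape[1], V)   # stand-in: [1, T] token IDs -> [1, T, V] logits

ids = torch.tensor([[1, 2845]])                       # prompt token IDs (e.g. "<s> cat")
for _ in range(5):                                     # generate 5 tokens, one at a time
    logits = model(ids)                                # full forward pass over all tokens so far
    next_id = logits[0, -1].argmax()                   # greedy: most probable next token
    ids = torch.cat([ids, next_id.view(1, 1)], dim=1)  # append and feed back in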

The Residual Stream in Detail: Elhage et al. (Anthropic, 2021) formalized the view that the residual connection is not just a training trick — it is the primary communication channel of the network. At each layer, attention heads and FFN sublayers are additive writers to a shared vector. Because the residual is preserved end-to-end, early layers can communicate directly to late layers by writing features that later layers read without modification. This explains why layer pruning often hurts less than expected: each layer makes a small contribution to the stream rather than performing a critical sequential transformation. It also explains superposition — the stream encodes far more features than its dimensionality would suggest, because different features destructively interfere only rarely in practice.

Diagram 3 — Final softmax output: top-5 next-token probabilities for “The cat sat on the”

Softmax output for context "The cat sat on the" (top 5 of ~32k tokens): "mat" 0.31 (top prediction), "floor" 0.18, "chair" 0.12, "roof" 0.07, "table" 0.05.
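
A small sketch of the logits → softmax step with temperature scaling (hypothetical logit values for the five tokens above; probabilities renormalize over just these five):

import torch

logits = torch.tensor([4.0, 3.45, 3.05, 2.5, 2.2])     # hypothetical logits: mat, floor, chair, roof, table
for temp in (1.0, 0.5, 2.0):
    probs = torch.softmax(logits / temp, dim=-1)
    print(temp, [round(p, 2) for p in probs.tolist()])  # lower temp sharpens, higher temp flattens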

Quick check

Trade-off

During autoregressive inference at batch size 1, which resource is the binding constraint on a modern GPU?

Quick check

What is the key difference between the training and inference forward pass?

📐

Step-by-Step Derivation

Complete Forward Pass (One Block)

Pre-Norm decoder block. Input $x_\ell \in \mathbb{R}^{T \times d_{\text{model}}}$:

$$h_\ell = x_\ell + \mathrm{Attn}(\mathrm{RMSNorm}(x_\ell)), \qquad x_{\ell+1} = h_\ell + \mathrm{FFN}(\mathrm{RMSNorm}(h_\ell))$$

After $N$ blocks:

$$\text{logits} = \mathrm{RMSNorm}(x_N)\, W_U$$

where $W_U \in \mathbb{R}^{d_{\text{model}} \times |V|}$ projects hidden states to vocabulary logits.

Training Loss (Cross-Entropy)

For a sequence of $T$ tokens, the model predicts each token given the preceding context:

$$\mathcal{L} = -\frac{1}{T} \sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$

Perplexity is the exponentiated loss:

$$\mathrm{PPL} = e^{\mathcal{L}}$$

PPL = 1 means perfect prediction. PPL = $|V|$ means random guessing over the vocabulary. State-of-the-art models typically achieve PPL in the low single digits on standard benchmarks.
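
A quick numeric check of the loss → perplexity relationship (hypothetical loss values):

import math

for loss in (0.0, 2.0, math.log(32000)):               # perfect, decent, and uniform guessing over a 32k vocab
    print(f"loss={loss:.2f}  PPL={math.exp(loss):,.1f}")
# loss=0.00 -> PPL 1.0; loss=2.00 -> PPL ~7.4; loss=ln(32000) -> PPL 32,000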

FLOPs per Forward Pass (2N Rule)

For a dense transformer with $N$ non-embedding parameters, a single token's forward pass costs approximately:

$$C_{\text{forward}} \approx 2N \ \text{FLOPs per token}$$

The factor of 2 comes from the multiply-accumulate structure of matrix multiplication: each weight contributes one multiply and one add per token. Derived in Kaplan et al. (2020). For GPT-3 ($N \approx 175\text{B}$), one token costs roughly $3.5 \times 10^{11}$ FLOPs (350 GFLOPs).

At training time with sequence length $T$ and batch size $B$:

$$C_{\text{train}} \approx 6NBT \ \text{FLOPs per step}$$

The 6N factor (= 2N forward + 4N backward) is derived in Kaplan et al. and used by Hoffmann et al. (Chinchilla, 2022) to compute compute-optimal training runs. The Chinchilla result: compute-optimal token count ≈ 20× parameter count.
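
A back-of-the-envelope helper for the 2N / 6N rules (plain arithmetic; parameter counts are the only inputs):

def forward_flops_per_token(n_params):
    """2N rule: each non-embedding weight contributes one multiply and one add."""
    return 2 * n_params

def train_flops_per_step(n_params, batch_size, seq_len):
    """6N rule: 2N forward + 4N backward, applied to every token in the batch."""
    return 6 * n_params * batch_size * seq_len

print(f"{forward_flops_per_token(175e9):.2e}")      # GPT-3: ~3.5e11 FLOPs (350 GFLOPs) per token
print(f"{train_flops_per_step(7e9, 4, 4096):.2e}")  # 7B model, batch 4, seq 4096: ~6.9e14 FLOPs per step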

PyTorch implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
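
# Minimal sketches of the two sublayers used below, assumed from the earlier
# attention and FFN modules (RoPE omitted; kv_cache accepted but unused here).
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        # Causal mask: each position attends only to itself and earlier positions
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.wo(out.transpose(1, 2).reshape(B, T, -1))

class SwiGLU_FFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        # SwiGLU: silu(gate) elementwise-multiplied with the up projection
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))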

class TransformerBlock(nn.Module):
    def __init__(self, d_model=4096, n_heads=32, d_ff=11008):
        super().__init__()
        self.ln1 = nn.RMSNorm(d_model)
        self.attn = MultiHeadAttention(d_model, n_heads)
        self.ln2 = nn.RMSNorm(d_model)
        self.ffn = SwiGLU_FFN(d_model, d_ff)

    def forward(self, x, kv_cache=None):
        # Pre-Norm + Attention + Residual
        h = x + self.attn(self.ln1(x), kv_cache=kv_cache)
        # Pre-Norm + FFN + Residual
        return h + self.ffn(self.ln2(h))

# Full decoder language model
class DecoderLM(nn.Module):
    def __init__(self, vocab_size=32000, d_model=4096, n_layers=32, n_heads=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])
        self.final_ln = nn.RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        x = self.embed(input_ids)        # [B, T, d]
        for layer in self.layers:
            x = layer(x)                 # N transformer blocks
        logits = self.lm_head(self.final_ln(x))  # [B, T, vocab]
        return logits
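
# Smoke test (a sketch with small dimensions so it runs on CPU; the defaults above
# match Llama-2 7B). Uses the sublayer sketches defined earlier; nn.RMSNorm needs PyTorch >= 2.4.
model = DecoderLM(vocab_size=1000, d_model=256, n_layers=2, n_heads=8)
input_ids = torch.randint(0, 1000, (1, 16))        # [batch=1, seq_len=16]
logits = model(input_ids)                          # [1, 16, 1000]
next_id = logits[0, -1].softmax(dim=-1).argmax()   # greedy pick of the next token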

Quick check

Derivation

The 2N FLOPs-per-token rule (Kaplan et al. 2020) counts operations for a dense transformer. Where does the factor of 2 come from?

🔧

Break It -- See What Happens

Swap attention and FFN order
Remove final LayerNorm

Quick check

Trade-off

If you remove one transformer block from the middle of a 32-layer model, what does the residual stream view predict about the quality degradation?

📊

Real-World Numbers

Model | Layers | d_model | Heads | Parameters
GPT-2 (small) | 12 | 768 | 12 | 124M
GPT-3 | 96 | 12288 | 96 | 175B
Llama-2 7B | 32 | 4096 | 32 | 7B
Llama-2 70B | 80 | 8192 | 64 | 70B
GPT-4 * (est.) | ~120 | ~12k-16k | ~96-128 | ~1.8T (MoE)
Llama-3 70B | 80 | 8192 | 64 | 70.6B

* GPT-4 architecture details not officially confirmed by OpenAI — figures are community estimates from leaks and benchmarking.

Training vs Inference

Aspect | Training | Inference
Parallelism | All tokens in parallel (causal mask) | Sequential (autoregressive)
Input | Ground-truth tokens (teacher forcing) | Model's own previous output
KV Cache | Not needed | Critical -- avoids recomputing all previous tokens
Bottleneck | Compute-bound (matmuls) | Memory-bound (KV cache bandwidth)
🎯 Interview · Common interview question: "Why can training run in parallel but inference must be sequential?" Key answer: During training, all target tokens are known, and the causal mask simulates autoregressive constraints while allowing parallel computation. During inference, each token depends on the previous output.
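
A rough KV-cache sizing sketch that shows why batch-1 decode is bandwidth-bound (illustrative figures, assuming Llama-2 7B-like shapes with standard multi-head attention and fp16 storage):

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    """Two tensors (K and V) per layer, each of shape [batch, n_kv_heads, seq_len, head_dim]."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Llama-2 7B-like: 32 layers, 32 KV heads, head_dim 128, fp16
print(kv_cache_bytes(32, 32, 128, seq_len=4096) / 1e9)   # ~2.1 GB of cache per 4k-token sequence
# Each decode step also streams all ~14 GB of fp16 weights from HBM to produce one token,
# so the GPU spends most of its time moving memory rather than doing math.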

Quick check

Trade-off

A team wants to maximize tokens-per-second throughput on a 7B-parameter model at fixed GPU budget. Latency SLO is relaxed. What is the most effective lever?

🧠

Key Takeaways

What to remember for interviews

  1. Training uses teacher forcing: all target tokens are known upfront, so the entire sequence is processed in parallel using a causal mask — making training compute-bound.
  2. Inference is autoregressive: each token is generated one at a time, feeding the previous output as the next input — making inference memory-bound and sequential.
  3. The causal mask (preventing future peeking) is what makes parallel training mathematically equivalent to sequential generation.
  4. The residual stream (Elhage et al., Anthropic 2021) is the primary communication channel: attention heads and FFN sublayers are additive writers to a shared vector, never replacing it.
  5. KV cache stores previously computed key/value vectors to avoid redundant recomputation at each inference step, at the cost of memory growing linearly with sequence length.
🧠

Recap quiz

Derivation

A dense transformer has 7B non-embedding parameters. Using the 2N rule, how many FLOPs does one forward pass through a single token require?

Derivation

A serving cluster must deliver 100 tokens/s on a 175B-parameter dense model. Approximately how many TFLOP/s of compute is needed just for the forward passes?

Trade-off

During autoregressive decode, why is the bottleneck memory bandwidth rather than compute, even on a high-FLOP GPU?

Trade-off

Chinchilla (Hoffmann et al. 2022) shows the compute-optimal token-to-parameter ratio is ~20×. A team trains Llama 3 8B on 15T tokens. What does this imply?

Trade-off

The residual stream (Elhage et al. 2021) frames layers as additive writers to a shared bus. Which consequence does this explain best?

Trade-off

logit → softmax → sampling: an interviewer asks why you would use temperature scaling before softmax instead of top-k sampling alone. What is the core tradeoff?

📚

Further Reading

🎯

Interview Questions


Training is parallel but inference is sequential -- why?

★★★
Google · OpenAI · Anthropic · Meta

Walk through the complete forward pass of a single token through one Transformer block.

★★☆
Google · OpenAI · Anthropic · Meta

What is teacher forcing and what is its main drawback?

★★☆
Google · OpenAI · Anthropic

Explain KV cache and why it matters for inference performance.

★★☆
OpenAI · Anthropic · Meta

What is the residual stream view of Transformers and why is it useful?

★★★
Anthropic · OpenAI