
Transformer Math

Module 10 · Training

🔙 Backpropagation

Naive finite differences would need 175 billion forward passes to compute GPT-3's gradients. Backprop does it in one. Here's the trick.


Gradient Flow Through a Computation Graph

What you're seeing: The computation graph for L = (wx + b − y)² — the mean squared error loss for a single linear neuron. Nodes are the intermediate values produced during the forward pass. Edges carry data forward (solid) and gradients backward (dashed).

What to look for: Each dashed arrow is labelled with the local gradient of that edge — the derivative of the node's output with respect to its input. During backprop, the upstream gradient (arriving from the right) is multiplied by the local gradient to produce the gradient that flows further left. That multiplication is the chain rule.

[Diagram: computation graph x (input), w (weight), b (bias), y (target) → z = wx → a = z + b → e = a − y → L = e². Dashed backward edges are labelled with the local gradients ∂L/∂e = 2e, ∂e/∂a = 1, ∂a/∂z = 1, ∂a/∂b = 1, ∂z/∂w = x, ∂z/∂x = w. Solid edges: forward pass; dashed edges: backward pass (gradient).]
✨ Insight · Key insight: Each node multiplies the upstream gradient by its local derivative — this IS the chain rule. No special “backprop algorithm” exists beyond applying the chain rule once per node, in reverse topological order.
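The same graph is easy to rebuild in PyTorch, and the edge gradients can be read off .grad after one backward pass. A minimal sanity-check sketch — the values of x, y, w, and b below are illustrative, since the figure doesn't fix any:

import torch

# Illustrative values — pick any numbers; the gradient formulas are what matter
x = torch.tensor(3.0)
y = torch.tensor(1.0)
w = torch.tensor(0.5, requires_grad=True)
b = torch.tensor(0.2, requires_grad=True)

z = w * x          # forward: z = wx
a = z + b          # forward: a = z + b
e = a - y          # forward: e = a − y
L = e ** 2         # forward: L = e²
L.backward()       # backward: chain rule applied once per node, right → left

# Along each backward path the local gradients multiply: dL/dw = 2e·x, dL/db = 2e·1
print(w.grad.item(), (2 * e * x).item())   # identical values
print(b.grad.item(), (2 * e).item())       # identical values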
🎮

Computation Graph Walkthrough

Consider the function L = (a × b + c)² with inputs a = 2, b = −3, c = 10. We'll build a computation graph, run the forward pass to get L = 16, then run one backward pass to get every gradient simultaneously.

💡 Tip · What you're seeing: Each node stores its output value (forward pass) and will receive a gradient (backward pass). Edges represent data flow — and gradient flow goes in the opposite direction.
[Diagram: computation graph a = 2, b = −3 → (×) → d = −6; d, c = 10 → (+) → e = 4; e → ( )² → L = 16 (loss). Each node shows its value (forward) and its gradient (backward): ∂L/∂a = −24, ∂L/∂b = 16, ∂L/∂d = 8, ∂L/∂c = 8, ∂L/∂e = 8.]
Forward pass: compute values left → right
d = a × b = 2 × (−3) = −6
e = d + c = −6 + 10 = 4
L = e² = 4² = 16
Backward pass: apply chain rule right → left

Start with ∂L/∂L = 1 and propagate backwards through each operation using its local derivative. Every gradient is computed exactly once:

Gradient | Computation                              | Value
dL/dL    | by definition                            | 1
dL/de    | d(e²)/de = 2e = 2 × 4                    | 8
dL/dd    | dL/de × de/dd = 8 × 1                    | 8
dL/dc    | dL/de × de/dc = 8 × 1                    | 8
dL/da    | dL/dd × dd/da = 8 × b = 8 × (−3)         | −24
dL/db    | dL/dd × dd/db = 8 × a = 8 × 2            | 16
✨ Insight · Key result: The gradients are sensitivities. Nudge a up by a tiny amount ε and L shifts by about −24ε; nudge c up by ε and L shifts by about +8ε. This is exactly what gradient descent uses — the negative gradient points toward lower loss.
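The same walkthrough fits in a few lines of plain Python with no autograd, which makes the "multiply upstream by local" bookkeeping explicit — a minimal sketch of the forward and backward passes for L = (a × b + c)²:

# Forward pass: compute and cache every intermediate value
a, b, c = 2.0, -3.0, 10.0
d = a * b                    # -6
e = d + c                    #  4
L = e ** 2                   # 16

# Backward pass: multiply the upstream gradient by each local derivative, right → left
dL_dL = 1.0                  # by definition
dL_de = dL_dL * (2 * e)      # local derivative of e² is 2e          →   8
dL_dd = dL_de * 1.0          # '+' passes gradients through unchanged →   8
dL_dc = dL_de * 1.0          #                                        →   8
dL_da = dL_dd * b            # local derivative of a·b w.r.t. a is b  → −24
dL_db = dL_dd * a            # local derivative of a·b w.r.t. b is a  →  16

print(dL_da, dL_db, dL_dc)   # -24.0 16.0 8.0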
💡

Why Backprop Works

Neural networks are just big computation graphs — millions of elementary operations chained together. To train them, we need ∂L/∂w for every weight w. Why not just perturb each weight?

Method             | How it works                                | Cost                 | Accuracy
Numerical gradient | Perturb each weight: (L(w + ε) − L(w)) / ε  | O(n) forward passes  | Approximate (finite-difference error)
Backpropagation    | One backward pass traverses the graph once  | O(1) backward passes | Exact (analytical)
✨ Insight · Backprop is just the chain rule applied to a computation graph — nothing more. The “magic” is that intermediate results computed during the forward pass can be reused to compute gradients in the backward pass, so each node is visited exactly once in each direction.

For a network with n = 175 billion parameters (GPT-3), numerical gradient estimation would require 175 billion forward passes per training step. Backprop needs one forward pass and one backward pass — O(1) passes, regardless of the number of parameters.
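To make the table concrete, here is a hedged sketch that estimates gradients by finite differences and checks them against autograd — the loop needs one extra forward pass per parameter, while .backward() is a single call (the toy loss and sizes are illustrative):

import torch

torch.manual_seed(0)
w = torch.randn(5, requires_grad=True)        # five "parameters" (toy example)
x = torch.randn(5)

def loss_fn(weights):
    return ((weights * x).sum() - 1.0) ** 2   # toy scalar loss

# Backprop: one forward + one backward yields every dL/dw_i at once
loss_fn(w).backward()

# Numerical estimate: one perturbed forward pass PER parameter — O(n) passes
eps = 1e-4
numerical = torch.zeros(5)
with torch.no_grad():
    base = loss_fn(w)
    for i in range(5):
        w_pert = w.clone()
        w_pert[i] += eps
        numerical[i] = (loss_fn(w_pert) - base) / eps

print(torch.allclose(w.grad, numerical, atol=1e-2))   # True, up to finite-difference error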

Quick check

Derivation

Numerical gradient estimation of a 7B-parameter model at 50 ms/forward would take how long per step, vs backprop at ~100 ms?

Quick check

In the expression L = (a*b + c)², which input has the largest-magnitude gradient? Why?

📐

The Chain Rule — Formal Derivation

For a composition of functions L = f(g(x)), the chain rule states:

dL/dx = (dL/dg) · (dg/dx)

Univariate chain rule

For the multivariate case, where a variable x feeds several intermediate nodes u₁, …, u_k, gradients accumulate from all downstream paths:

∂L/∂x = Σᵢ (∂L/∂uᵢ) · (∂uᵢ/∂x)

Multivariate chain rule (sum over all paths)

When a variable fans out to multiple downstream nodes (e.g., a weight used in multiple layers), its gradient is the sum of gradients from all downstream paths. This is why PyTorch accumulates gradients with += rather than overwriting.
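A minimal sketch of that fan-out rule — one tensor feeding two downstream paths, whose path gradients sum into a single .grad (toy values, not from the text):

import torch

w = torch.tensor(3.0, requires_grad=True)

path1 = 2 * w          # d(path1)/dw = 2
path2 = w ** 2         # d(path2)/dw = 2w = 6
loss = path1 + path2   # w fans out into both terms

loss.backward()
print(w.grad.item())   # 8.0 = 2 + 6 — the two path gradients were accumulated (+=)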

For vector/matrix operations, the gradient generalizes to a Jacobian matrix J ∈ R^{m×n}:

J_ij = ∂f_i / ∂x_j

Jacobian matrix for f: R^n → R^m

🎯 Interview · Interview insight: Reverse-mode AD (backprop) works from the output backward, accumulating a row vector times Jacobian at each step. This is efficient when outputs are scalar (losses). Forward-mode AD accumulates a Jacobian times column vector — efficient when inputs are few. Most ML frameworks use reverse-mode because loss functions map R^n → R.
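The two modes can be compared directly with torch.autograd.functional: vjp is the reverse-mode building block (vector–Jacobian product), jvp the forward-mode one (Jacobian–vector product). A small sketch, with a toy R^3 → R function chosen for illustration:

import torch
from torch.autograd.functional import vjp, jvp

def f(x):
    return (x ** 2).sum()              # f: R^3 → R, like a scalar loss

x = torch.tensor([1.0, 2.0, 3.0])

# Reverse mode: one vector–Jacobian product recovers the full gradient at once
_, grad = vjp(f, x, v=torch.tensor(1.0))
print(grad)                            # tensor([2., 4., 6.])

# Forward mode: one Jacobian–vector product gives a single directional derivative;
# the full gradient would need one jvp per input dimension
_, dderiv = jvp(f, x, v=torch.tensor([1.0, 0.0, 0.0]))
print(dderiv)                          # tensor(2.) — just ∂f/∂x₀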

PyTorch autograd: building the same graph automatically

PyTorch builds the computation graph dynamically as you execute operations on tensors with requires_grad=True. Calling loss.backward() traverses it in reverse.

PyTorch implementation
import torch

# Define inputs with gradient tracking
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(-3.0, requires_grad=True)
c = torch.tensor(10.0, requires_grad=True)

# Forward pass — PyTorch builds the computation graph dynamically
d = a * b        # d = -6
e = d + c        # e = 4
L = e ** 2       # L = 16

# Backward pass — one call computes ALL gradients via reverse-mode AD
L.backward()

print(f"dL/da = {a.grad.item()}")  # -24.0  (chain rule: 2e * b = 2*4*-3)
print(f"dL/db = {b.grad.item()}")  # 16.0   (2e * a = 2*4*2)
print(f"dL/dc = {c.grad.item()}")  # 8.0    (2e * 1 = 2*4)

# Training loop: zero → forward → loss → backward → step
# (toy model and data defined here so the loop is runnable — illustrative only)
model = torch.nn.Linear(10, 3)
dataloader = [(torch.randn(8, 10), torch.randint(0, 3, (8,))) for _ in range(4)]
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for x_batch, y_batch in dataloader:
    optimizer.zero_grad()           # MUST clear before each step
    logits = model(x_batch)
    loss = torch.nn.functional.cross_entropy(logits, y_batch)
    loss.backward()                 # populate .grad for every parameter
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()                # w = w - lr * grad
⚠ Warning · Common pitfall: Always call optimizer.zero_grad() before loss.backward(). PyTorch accumulates gradients (+=) into .grad — if you forget to zero them, gradients from previous batches corrupt the update. This is intentional design for gradient accumulation across micro-batches, but a common source of bugs.
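Because accumulation is the intended behaviour, the same += mechanism gives you gradient accumulation over micro-batches — a hedged sketch reusing the toy model, dataloader, and optimizer from the block above (accum_steps is illustrative):

accum_steps = 4   # micro-batches per optimizer step — illustrative

optimizer.zero_grad()
for step, (x_batch, y_batch) in enumerate(dataloader):
    logits = model(x_batch)
    loss = torch.nn.functional.cross_entropy(logits, y_batch) / accum_steps
    loss.backward()                      # gradients accumulate (+=) across micro-batches
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                 # one update built from accum_steps micro-batches
        optimizer.zero_grad()            # only now clear the accumulated gradients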

Quick check

Derivation

For z = x × y in a computation graph, what are the local gradients ∂z/∂x and ∂z/∂y, and what does this imply about gradient flow when one input is near zero?

🔧

Break It — See What Happens

Replace backprop with numerical gradient estimation
💡 Tip · Gradient clipping in practice:
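A hedged sketch of the usual pattern, reusing the model, optimizer, and batch names from the training loop above: run backward, rescale all gradients in place if their global norm exceeds a threshold, and log the pre-clipping norm to spot spikes (max_norm = 1.0 is illustrative):

max_norm = 1.0   # illustrative threshold; tune per model and optimizer

optimizer.zero_grad()
loss = torch.nn.functional.cross_entropy(model(x_batch), y_batch)
loss.backward()

# clip_grad_norm_ rescales every .grad in place so the global L2 norm ≤ max_norm,
# and returns the total norm measured *before* clipping — worth logging
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()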

Quick check

Trade-off

If you remove gradient clipping from a deep network, what failure mode appears first and why?

📊

Real Gradient Computation Costs

In typical training loops the backward pass costs roughly 2× the forward pass. It does not redo the whole model blindly; it walks the same graph while multiplying local Jacobians and accumulating upstream gradients. Gradient checkpointing trades a modest amount of recomputation for a large reduction in activation memory, making it the standard tool for fine-tuning large models.

Note: forward/backward latencies below are illustrative order-of-magnitude estimates (single sequence, FP32, A100 GPU) — not sourced benchmarks. Actual times vary with batch size, sequence length, precision, and hardware.

Model      | Parameters | Forward (ms) | Backward (ms) | Numerical (estimated)
GPT-2      | 124M       | ~5 ms        | ~10 ms        | ~7.2 days
Llama-2 7B | 7B         | ~50 ms       | ~100 ms       | ~11 years
GPT-3      | 175B       | ~500 ms      | ~1,000 ms     | ~2,800 years
✨ Insight · Why backward ≈ 2× forward: for every matrix multiply in the forward pass, the backward pass performs two — one for the gradient with respect to the layer's input (so the gradient can keep propagating) and one for the gradient with respect to the weights — so its FLOP count is roughly double.
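The "Numerical (estimated)" column is just parameters × forward latency — one perturbed forward pass per parameter. A quick arithmetic check, using the illustrative latencies from the table:

# parameters × forward latency (s) = wall-clock time for ONE numerical-gradient step
models = {"GPT-2": (124e6, 0.005), "Llama-2 7B": (7e9, 0.050), "GPT-3": (175e9, 0.500)}

for name, (n_params, fwd_s) in models.items():
    seconds = n_params * fwd_s           # one perturbed forward pass per parameter
    print(f"{name}: {seconds / 86400:,.1f} days = {seconds / (86400 * 365):,.1f} years")
# GPT-2: ~7.2 days · Llama-2 7B: ~11.1 years · GPT-3: ~2,774 years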

Quick check

Derivation

Fine-tuning a 32-layer model: standard training stores all 32 layer activations. With gradient checkpointing at √32 ≈ 5 checkpoints, how does memory change?

🧠

Key Takeaways

What to remember for interviews

  1. Backprop is just the chain rule applied to a computation graph — each node's gradient is computed once from downstream gradients, visiting every edge exactly once in reverse.
  2. Reverse-mode AD (backprop) computes all gradients in one backward pass regardless of parameter count — O(1) passes vs O(n) forward passes for numerical gradient estimation.
  3. The backward pass costs ~2× the forward pass: it traverses the same graph but multiplies local Jacobians and accumulates upstream gradients at each node.
  4. When a variable fans out to multiple downstream nodes, its gradient is the sum of all downstream paths — this is why PyTorch accumulates gradients with += rather than overwriting.
  5. Gradient checkpointing (Chen et al., 2016) reduces memory to O(√n) for an n-layer network by recomputing activations on the fly, at ~30% extra compute cost in their √n-checkpoint schedule (see the sketch after this list).
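A minimal sketch of gradient checkpointing in PyTorch, assuming a plain stack of feed-forward blocks (the module, sizes, and segment split below are illustrative, not from the text) — each checkpointed segment stores only its input and recomputes its interior activations during the backward pass:

import torch
from torch.utils.checkpoint import checkpoint

# A 30-layer stack split into 5 segments of 6 layers each (≈ √n-sized segments)
segments = [
    torch.nn.Sequential(*[torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU())
                          for _ in range(6)])
    for _ in range(5)
]

x = torch.randn(16, 256, requires_grad=True)

# Only segment-boundary activations are kept; interiors are recomputed in backward,
# so activation memory scales ~O(√n) at the cost of roughly one extra forward pass
out = x
for segment in segments:
    out = checkpoint(segment, out, use_reentrant=False)

loss = out.pow(2).mean()
loss.backward()
print(x.grad.shape)   # torch.Size([16, 256]) — gradients flow through recomputed segments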
🧠

Recap quiz

Derivation

A network has N=175B parameters. How many FLOPs does one training step cost, and why is backward more expensive than forward?

Derivation

Gradient checkpointing reduces activation memory from O(n) to O(√n) for an n-layer network. What is the compute overhead and why that specific fraction?

Derivation

Numerical gradient estimation for GPT-2 (124M params, ~5 ms per forward pass) would take how long per training step?

Trade-off

In PyTorch, a weight tensor is used in two different layers. How does backprop handle this fan-out, and what would happen if gradients were overwritten instead of accumulated?

Trade-off

Reverse-mode AD (backprop) is optimal for loss functions mapping R^n → R. What computation pattern would make forward-mode AD more efficient?

Trade-off

The backward pass costs ~2× the forward pass. If you halve the forward pass compute via mixed-precision inference, what happens to the backward/forward ratio during training?

Derivation

An add gate in a computation graph fans out to three downstream operations. What is the local gradient of the add gate with respect to each of its inputs, and why?

📚

Further Reading

🎯

Interview Questions


Why is backpropagation O(n) in the number of parameters, not O(n²)?

★★☆
Google · Meta · OpenAI

Think about how many times each weight is visited in the backward pass.

What is the vanishing gradient problem and when does it occur?

★★☆
Google · Anthropic · Meta · OpenAI

Compare forward-mode vs reverse-mode automatic differentiation. When would you use each?

★★★
Google · Anthropic · OpenAI

Think about the shape of the Jacobian and which direction is cheaper to compute.

Explain gradient checkpointing and the tradeoff it makes.

★★★
Meta · Anthropic · Google

What does it mean for a gradient to be numerically stable, and why do log-space operations help?

★★☆
Google · Anthropic