
Transformer Math

Module 11 · Training

📐 Optimizers

AdamW fixes a 5-year-old bug in Adam that silently hurts generalization


Optimizer Evolution

Each optimizer in this lineage solves one specific failure mode of its predecessor. The timeline below shows the single key idea each one adds — understand these and you understand why AdamW is the common default for decoder-only LLM training recipes.

| # | Optimizer | Year | Adds | Limitation | Memory (per param) |
|---|-----------|------|------|------------|--------------------|
| 1 | SGD | 1951 | θ = θ − lr·∇L | zigzags in ravines | 1× (θ) |
| 2 | Momentum | 1964 | + velocity (EMA of grads) | still slow in sparse dims | 2× (θ, v) |
| 3 | RMSProp | 2012 | + per-param LR scaling | no momentum | 2× (θ, v²) |
| 4 | Adam | 2014 | momentum + per-param LR | wrong weight decay | 3× (θ, m, v) |
| 5 | AdamW | 2019 | decoupled weight decay | none — current standard | 3× (θ, m, v) |
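
To make the memory column concrete, here is a quick back-of-envelope sketch; the 7B parameter count and fp32 optimizer state are illustrative assumptions (mixed-precision recipes differ).

Python — optimizer-state memory back-of-envelope
python
# Back-of-envelope: bytes held per optimizer, assuming fp32 (4-byte) values.
# The 7B parameter count is an illustrative assumption.
n_params = 7e9
bytes_per_val = 4  # fp32

state_multiplier = {  # from the timeline above: params + optimizer state
    "SGD": 1,       # θ only
    "Momentum": 2,  # θ, v
    "RMSProp": 2,   # θ, v²
    "Adam": 3,      # θ, m, v
    "AdamW": 3,     # θ, m, v
}

for name, k in state_multiplier.items():
    gib = n_params * bytes_per_val * k / 2**30
    print(f"{name:>8}: {k}x ~= {gib:.0f} GiB")
# AdamW on 7B params: ~78 GiB of weights + optimizer state, before
# gradients and activations are even counted.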
🎯 Interview · Hook: The interviewer says “just use Adam”. You say “actually we use AdamW — Adam's weight decay is broken because the adaptive scaling shrinks it inconsistently across parameters. AdamW applies weight decay directly to weights, outside the adaptive term.” That's the answer that gets you the job.
🧭

Optimizer Trajectories

All three optimizers start at the same point on an elliptical loss landscape (steep in one axis, shallow in the other — the classic “narrow canyon” problem). Watch how each navigates toward the minimum.

What to notice: SGD wastes steps bouncing off canyon walls. Momentum smooths it but overshoots. Adam rescales each dimension independently and drives straight to the star.

[Figure: trajectories of SGD, Momentum, and Adam from start to min on the elliptical landscape. Axes: w₁ (shallow curvature), w₂ (steep).]
💡

Why SGD Fails — and What Each Fix Does

Imagine a loss landscape shaped like a long narrow canyon — steep walls on the sides, shallow slope along the bottom toward the minimum. This is common in deep networks because some parameters affect the loss much more strongly than others.

SGD — zigzags

Takes big steps perpendicular to the canyon (the steep direction) and tiny steps along the canyon floor (the progress direction). Oscillates wildly, converges slowly.

Momentum — smoother

Accumulates a velocity vector. The perpendicular oscillations cancel out (positive then negative gradients average to zero). The canyon-floor gradients accumulate, accelerating in the right direction.

Adam — direct

Divides each gradient by its historical magnitude. Steep dimensions get automatically smaller steps; shallow dimensions get larger steps. The effective step size is equalized across the landscape.

[Figure: SGD's zigzag path from start to minimum. It overshoots along the steep dimension θ₂ and crawls along the shallow θ₁ floor. Axes: parameter θ₁ (shallow gradient), θ₂ (steep).]
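
To put numbers on the canyon, here is a minimal sketch; the quadratic loss, starting point, and learning rate are illustrative assumptions. SGD's first step inherits the 20:1 gradient imbalance, while Adam's first bias-corrected step has the same magnitude in both dimensions.

Python — first-step size in a toy canyon
python
import numpy as np

# Toy canyon loss L(θ) = 0.5·(θ₁² + 50·θ₂²): 50× steeper in θ₂.
theta = np.array([-5.0, 2.0])
g = np.array([theta[0], 50.0 * theta[1]])  # ∇L = (θ₁, 50·θ₂) = (-5, 100)
lr = 1e-2

# SGD: step is proportional to the raw gradient, 20× larger in the steep dim
print("SGD update :", -lr * g)  # [ 0.05 -1.  ]

# Adam's first step: m̂₁ = g₁ and v̂₁ = g₁², so each dim moves by ≈ lr·sign(g)
m_hat, v_hat = g, g * g
print("Adam update:", -lr * m_hat / (np.sqrt(v_hat) + 1e-8))  # ≈ [ 0.01 -0.01]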
Quick check

Why does momentum help in the narrow canyon scenario but not fully solve the problem?

📐

Update Rules

Each optimizer is defined entirely by its parameter update rule. Here are the exact equations, followed by the single thing each one adds.

1. Vanilla SGD

SGD update: θₜ₊₁ = θₜ − η·∇L(θₜ)

One hyperparameter: the learning rate η. Problem: the same step size is applied to every parameter regardless of its gradient history.

2. SGD + Momentum

Momentum update (β typically 0.9): vₜ = β·vₜ₋₁ + (1−β)·∇L(θₜ);  θₜ₊₁ = θₜ − η·vₜ

Adds a velocity vector — an exponential moving average of past gradients. β = 0.9 means the current gradient contributes 10% and all past gradients together contribute 90%.
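
A tiny sketch of the cancelation claim; the alternating ±1 gradient stands in for the steep canyon wall and the constant gradient for the canyon floor (both illustrative).

Python — EMA cancels oscillating gradients
python
# Alternating gradients (canyon wall) average toward 0; a consistent
# gradient (canyon floor) accumulates to its full value.
beta = 0.9
v_zigzag, v_steady = 0.0, 0.0
for t in range(200):
    v_zigzag = beta * v_zigzag + (1 - beta) * (-1) ** t  # +1, -1, +1, ...
    v_steady = beta * v_steady + (1 - beta) * 1.0
print(round(v_zigzag, 3), round(v_steady, 3))  # ≈ ±0.05 vs ≈ 1.0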

3. RMSProp

RMSProp (ρ typically 0.99): vₜ = ρ·vₜ₋₁ + (1−ρ)·gₜ²;  θₜ₊₁ = θₜ − η·gₜ / (√vₜ + ε)

Divides each gradient by the root mean square of recent gradients for that parameter — effectively scaling the learning rate per dimension. Parameters with high gradient variance (frequently updated) get smaller steps.

4. Adam (momentum + RMSProp + bias correction)

First and second moment estimates: mₜ = β₁·mₜ₋₁ + (1−β₁)·gₜ;  vₜ = β₂·vₜ₋₁ + (1−β₂)·gₜ²

Bias-corrected estimates: m̂ₜ = mₜ / (1−β₁ᵗ);  v̂ₜ = vₜ / (1−β₂ᵗ)

Adam parameter update (ε ≈ 1e-8): θₜ₊₁ = θₜ − η·m̂ₜ / (√v̂ₜ + ε)

Defaults: β₁ = 0.9, β₂ = 0.999 (or 0.95 for LLMs), ε = 1e-8. The bias correction divides by (1 − βᵗ), which is small early in training, correcting the zero-initialization bias.
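
A quick numeric check of the bias correction at step 1 (the gradient value is an illustrative assumption):

Python — bias correction at t = 1
python
beta1, beta2 = 0.9, 0.999
g1 = 4.0                       # first gradient (illustrative value)

m1 = (1 - beta1) * g1          # 0.4: raw first moment, biased toward 0
v1 = (1 - beta2) * g1**2       # 0.016: raw second moment, biased even more

m1_hat = m1 / (1 - beta1**1)   # 0.4 / 0.1     = 4.0, exactly g₁
v1_hat = v1 / (1 - beta2**1)   # 0.016 / 0.001 = 16.0, exactly g₁²
print(m1_hat, v1_hat)          # 4.0 16.0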

5. AdamW — the fix that matters

Adam weight decay (WRONG): gₜ ← ∇L(θₜ) + λ·θₜ, then the standard Adam update above

λθ is added to the gradient, then scaled by the adaptive term. Weight decay strength varies per parameter — inconsistent.

AdamW weight decay (CORRECT): θₜ₊₁ = θₜ − η·( m̂ₜ / (√v̂ₜ + ε) + λ·θₜ )

Weight decay applied directly to θ, outside the adaptive scaling. Every parameter decays at the same rate λ.
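
The difference is easiest to see as code. This is a minimal single-tensor sketch, not a production optimizer; the hyperparameter defaults mirror the LLM values used in this section.

Python — Adam+L2 vs AdamW, decay placement
python
import numpy as np

def adam_l2_step(theta, g, m, v, t, lr=3e-4, b1=0.9, b2=0.95, eps=1e-8, wd=0.1):
    g = g + wd * theta  # WRONG coupling: decay enters g, so 1/√v̂ rescales it
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

def adamw_step(theta, g, m, v, t, lr=3e-4, b1=0.9, b2=0.95, eps=1e-8, wd=0.1):
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat, v_hat = m / (1 - b1**t), v / (1 - b2**t)
    # CORRECT decoupling: decay hits θ directly, outside the adaptive term
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v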

💡 Tip · LLM defaults: β₁ = 0.9, β₂ = 0.95 with weight decay 0.1 is standard for LLMs. Lower β₂ makes the second moment adapt faster to gradient changes — important when training on diverse token distributions where gradient statistics shift rapidly.
PyTorch — AdamW + cosine schedule with warmup
python
import math
import torch
import torch.nn as nn

model = nn.Transformer(d_model=512, nhead=8)
criterion = nn.CrossEntropyLoss()

# AdamW: lr=3e-4, β₁=0.9, β₂=0.95, weight_decay=0.1
# β₂=0.95 (not 0.999) is key for LLMs — faster adaptation to gradient changes
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.1,
    eps=1e-8,
)

# Cosine decay with linear warmup
warmup_steps = 2000
total_steps = 100_000

def lr_lambda(step: int) -> float:
    if step < warmup_steps:
        return step / warmup_steps  # linear warmup
    # cosine decay from 1.0 to 0.1 of peak LR
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Training loop
for step, (x, y) in enumerate(dataloader):
    optimizer.zero_grad()
    out = model(src=x, tgt=y)       # returns hidden states (tgt_seq, batch, d_model), not a loss
    # a d_model→vocab projection head is omitted here for brevity
    loss = criterion(out.view(-1, out.size(-1)), y.view(-1))
    loss.backward()
    # Gradient clipping — essential for transformer stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()

Quick check

Derivation

In Adam+L2, weight decay is applied inside the adaptive term. In AdamW, it is applied outside. For a parameter with √v̂ = 10, how does the effective per-step weight decay compare between the two?

📊

What Real Models Use

These are representative published hyperparameters from well-known papers, condensed for comparison. They are useful defaults to remember, but they are paper-specific rather than universal rules.

| Model | Optimizer | Peak LR | β₁, β₂ | Weight Decay | Schedule |
|-------|-----------|---------|--------|--------------|----------|
| GPT-3 175B | Adam | 6e-5 | 0.9, 0.95 | 0.1 | cosine, 375M-token warmup |
| Llama 2 70B | AdamW | 3e-4 | 0.9, 0.95 | 0.1 | cosine to 10% of peak, 2000-step warmup |
| Chinchilla 70B | AdamW | 1e-4 | 0.9, 0.95 | 0.1 | cosine, linear warmup |
| PaLM 540B | Adafactor | 0.01 → 1/√k | 0.9, 1−k⁻⁰·⁸ | lr² (dynamic) | inverse-sqrt decay; no factorization (memory-efficient) |
| GPT-2 (fine-tune) | AdamW | 5e-5 | 0.9, 0.999 | 0.01 | linear decay |
✨ Insight · Pattern to memorize: Many modern decoder-only LLM training recipes use β₂ = 0.95 instead of the Adam default 0.999, because a lower β₂ lets the second moment adapt more quickly to changing gradient statistics. PaLM is the notable contrast here: it uses Adafactor to cut optimizer-state memory at very large scale.
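
One way to internalize the β₂ choice: an EMA averages over roughly 1/(1−β₂) recent steps (a standard rule of thumb, sketched below).

Python — effective averaging horizon of β₂
python
# Rule of thumb: an EMA with decay β averages over roughly 1/(1−β) steps.
for beta2 in (0.999, 0.95):
    horizon = 1 / (1 - beta2)
    print(f"β₂={beta2}: v averages over ~{horizon:.0f} recent steps")
# β₂=0.999 -> ~1000 steps (slow to adapt); β₂=0.95 -> ~20 steps (fast)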

Quick check

Trade-off

GPT-3 175B trains at peak LR 6e-5, Llama-2 70B at 3e-4 — a 5× difference. Given that larger models generally use smaller LRs, what explains Llama-2 using a higher peak LR despite being similar in scale?

📈

Learning Rate Schedules

Learning rate schedules matter a lot in practice. A very common modern recipe is linear warmup → cosine decay, even though the exact schedule still varies by paper and training setup.

[Figure: LR schedule. Linear warmup ramp to the peak LR (e.g. 3e-4), then cosine decay to a min LR of 10% of peak: lr = lr_min + ½(lr_max − lr_min)(1 + cos(πt)). Axes: training steps vs normalized LR.]

Why warmup?

Early in training, Adam's second moment v is estimated from very few gradient samples — it's noisy and unreliable. Warmup keeps the LR small until v accumulates stable statistics (~2000 steps), preventing large erratic updates at the start.
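
A numeric sketch of how unreliable v̂ is at step 1 (the gradient value is an illustrative assumption): the bias-corrected estimate equals a single squared gradient, so one sample sets the entire per-parameter scale.

Python — v̂ at step 1 is a single sample
python
beta2 = 0.95
g1_sq = 49.0                  # first squared gradient (illustrative value)

v1 = (1 - beta2) * g1_sq      # 2.45
v1_hat = v1 / (1 - beta2**1)  # 49.0: exactly g₁², no averaging at all
print(v1_hat)
# After ~1/(1−β₂) ≈ 20 steps, v̂ blends many samples and one outlier
# moves it far less. Warmup keeps the LR small until then.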

Why cosine?

Cosine decay is smooth and naturally slows near the end of training — the curve is steep initially (fast descent from the peak) and flat at the end (fine-grained convergence near the minimum). Step decay creates sharp jumps that can hurt.

Why decay to 10%, not 0?

Decaying to 0 can cause underfitting — the model stops learning before reaching its potential. A floor of ~10% of peak keeps the model making progress without instability. This is the Llama-2 convention.

Quick check

Trade-off

At the end of cosine decay (t → T), the rate of LR change approaches zero. Step decay halves the LR abruptly at fixed intervals. Which effect does this difference produce near convergence?

🔥

Break It

Toggle these to see what breaks when you remove key design decisions from the optimizer stack.

Use SGD instead of Adam
Skip learning rate warmup

Quick check

Derivation

You skip warmup entirely on a 7B transformer. At step 1, β₂=0.95, g₁²=100. What is v̂₁?

🧠

Key Takeaways

What to remember for interviews

  1. Adam's per-parameter learning rate scaling is what makes it the default for transformers: embedding rows for rare tokens can have very small gradient magnitudes while dense attention weights see large ones.
  2. Adam automatically scales each parameter's step size by its gradient history, so all parameters make meaningful progress regardless of how often they are updated.
  3. SGD treats every parameter identically — this is fatal for sparse input distributions like token embeddings.
🧠

Recap quiz

Trade-off

In standard Adam with L2 regularization, a parameter with a large historical gradient variance gets weight decay that is effectively _____ compared to a parameter with small gradient variance.

Trade-off

GPT-3, Llama-2, and Chinchilla all use β₂ = 0.95 instead of Adam's default 0.999. Which effect does this produce?

Derivation

A colleague removes the 2000-step warmup from a Llama-2 fine-tune run and sees loss diverge in the first 200 steps. The root cause is that Adam's second moment v starts at zero, making the bias-corrected v̂ very _____ early in training.

Trade-off

Cosine decay reduces the learning rate gradually from peak to 10% of peak over training. Compared to step-decay (halve LR every N steps), what is cosine's key advantage near the minimum?

Trade-off

Training a ResNet-50 on ImageNet, a researcher finds SGD with momentum achieves 77.0% top-1 accuracy while AdamW achieves 76.1%. The most cited explanation for SGD's edge in vision CNNs is:

Derivation

Adam initializes m₀ = 0 and v₀ = 0. After one gradient step with g₁, what is the bias-corrected m̂₁ for β₁ = 0.9?

Trade-off

PaLM 540B uses Adafactor instead of AdamW. The primary reason at very large scale is:


🎯

Interview Questions


Why is AdamW preferred over Adam for training large language models? What bug does it fix?

★★☆
OpenAI · Google · Anthropic

What is the learning rate warmup phase and why is it necessary for transformer training?

★★☆
Google · Anthropic

Derive the bias correction terms in Adam. Why are m̂ and v̂ needed?

★★★
Meta · Google

When would you prefer SGD with momentum over Adam for training? What are the tradeoffs?

★★☆
Meta · OpenAI

What determines the optimal learning rate? How do practitioners find it?

★★☆
OpenAI · Anthropic · Google