
Transformer Math

Module 12 · Training

📉 Pre-training & Loss

GPT-3 trained on 300B tokens and never saw a single labeled example — yet learned grammar, facts, math, and reasoning from one loss function.


Every large language model — GPT-4, Claude, Llama — was trained with the same objective: predict the next token. This single loss function, applied to trillions of tokens, is what produces emergent capabilities like reasoning, translation, and code generation. This module covers how that works: cross-entropy loss, perplexity, and teacher forcing.

📚

The Pre-training Loop

What you are seeing

The full pre-training pipeline: raw text enters on the left, gets tokenized into integer IDs, embedded into vectors, processed by the Transformer, and the final hidden state predicts the next token. Cross-entropy loss compares that prediction against the true next token. The teacher forcing arrow (orange) shows that during training, the ground-truth token — not the model's prediction — is fed back as the next input, keeping the training signal clean.

[Diagram: Raw Text → Tokenizer → Token IDs → Embedding → Transformer (N layers) → Next-Token Prediction (softmax) → Cross-Entropy Loss vs. Ground Truth "on" (id 319); teacher forcing feeds the ground truth back as the next input. Input: "The cat sat", Predict: "on"]
🎮

Next-Token Prediction

What you are seeing

The model predicts the next token for The cat sat on the ___. Each bar shows the probability the model assigns to that vocabulary word. The green bar is the ground truth — the loss measures how much probability mass the model puts on the correct answer.

What to try: Drag the temperature slider to see how it reshapes the sampling distribution. At low temperature the model becomes very confident; at high temperature the distribution flattens. Note: temperature is a sampling-time parameter — it does not affect the model's actual cross-entropy loss or perplexity, which are evaluated from the unscaled logits.

[Interactive demo: Input "The cat sat on the ___". With P("mat") = 0.354, the per-token loss is -log(0.354) ≈ 1.04 nats (temperature-scaled display, not the eval loss), giving e^1.04 ≈ 2.8 effective choices.]
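A minimal sketch of what the temperature slider does, assuming a toy five-word vocabulary and made-up logits (everything here is hypothetical): temperature divides the logits before softmax and reshapes the sampling distribution, while the evaluation loss always comes from the unscaled logits.

import torch
import torch.nn.functional as F

# Toy logits over a made-up 5-word vocabulary (illustration only)
vocab = ["mat", "floor", "sofa", "table", "moon"]
logits = torch.tensor([1.2, 0.8, 0.3, -0.1, -2.0])

# Evaluation loss always uses the unscaled logits
eval_probs = F.softmax(logits, dim=-1)
eval_loss = -torch.log(eval_probs[0])            # -log P("mat")

# Temperature only reshapes the *sampling* distribution
for temp in (0.5, 1.0, 2.0):
    sample_probs = F.softmax(logits / temp, dim=-1)
    print(temp, [round(p, 3) for p in sample_probs.tolist()])

At temperature 0.5 the distribution sharpens toward "mat"; at 2.0 it flattens. eval_loss is identical in every case.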

💡

The Intuition

Next-token prediction is deceptively simple: given a sequence of tokens, predict what comes next. But to do this well, the model must learn grammar, facts, reasoning, and even common sense. The sentence "The Eiffel Tower is in ___" requires geographic knowledge. "If x > 5 and x < 3, then ___" requires logical reasoning.

Cross-entropy loss measures how surprised the model is by the true answer. If the model assigns high probability to the correct token, loss is low. If it spreads probability across many wrong tokens, loss is high. Minimizing cross-entropy is equivalent to maximizing the log-likelihood of the training data.

Perplexity = e^(loss). A perplexity of 20 means the model is, on average, as uncertain as choosing uniformly among 20 options. Lower is better — a perfect model has perplexity 1.

Teacher forcing: during training, the model always receives the ground truth previous token, not its own prediction. This prevents error compounding but creates a train/inference mismatch called exposure bias.
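To make the cross-entropy and perplexity definitions concrete, here is a tiny worked example with made-up per-token probabilities (not from any real model):

import math

# Probability the model assigns to the *correct* token at four positions
p_true = [0.354, 0.60, 0.05, 0.90]               # made-up numbers

nll = [-math.log(p) for p in p_true]             # per-token surprise, in nats
loss = sum(nll) / len(nll)                       # average cross-entropy ≈ 1.16
ppl = math.exp(loss)                             # perplexity = e^loss ≈ 3.2

print([round(x, 2) for x in nll])                # [1.04, 0.51, 3.0, 0.11]

The one badly predicted token (p = 0.05) dominates the average; cross-entropy punishes low-probability misses hard.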

✨ Insight · Why does predicting the next token teach everything? Because language is a compression of human knowledge. To predict well, you must model the data-generating process — which includes physics, psychology, logic, and every other pattern humans write about.

Loss Spikes and Training Stability: Real large-scale pretraining runs are not monotonically smooth. Google's PaLM paper (Chowdhery et al., 2022) documented roughly 20 loss spikes during training of the largest model. The paper explicitly notes they do not believe bad data batches caused the spikes — replaying the same batches from an earlier checkpoint did not reproduce them. The documented mitigation is to restart from a checkpoint roughly 100 steps before the spike and skip the next ~200–500 data batches; this resolved every spike they encountered. Gradient clipping was already active during these runs. Spikes are not bugs — they are a normal feature of training at scale. The PaLM team also introduced z-loss, a regularizer that penalizes large logit magnitudes: L_z = 10⁻⁴ · log²(Z), where Z is the softmax partition function. This keeps the softmax numerically stable and reduced spike frequency significantly.
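A sketch of how a z-loss term could be added on top of the standard cross-entropy, using the 10⁻⁴ coefficient quoted above; this is an illustrative reimplementation, not the PaLM codebase:

import torch
import torch.nn.functional as F

def loss_with_z_loss(logits, targets, z_coef=1e-4):
    """Cross-entropy plus z-loss, which penalizes a large log-partition log(Z)."""
    ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                         targets.reshape(-1))
    # log Z = logsumexp over the vocabulary at each position; pushing it toward 0
    # keeps logit magnitudes bounded and the softmax numerically stable.
    log_z = torch.logsumexp(logits, dim=-1)        # (B, T)
    z_loss = z_coef * (log_z ** 2).mean()
    return ce + z_loss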

Quick check

Derivation

A language model has PPL 20 on one benchmark and PPL 5 on another. What is the correct interpretation of this 4× difference?

Quick Check

What does a perplexity of 10 mean intuitively?

📐

Step-by-Step Derivation

Cross-Entropy Loss

For a sequence of T tokens x₁, …, x_T, the loss is the negative log-likelihood of the correct token at each position:

L = −(1/T) · Σₜ log P(xₜ | x₁, …, xₜ₋₁)

Each term is the log-probability the model assigns to the correct token given all preceding tokens. The model outputs a probability distribution over the entire vocabulary at each position via softmax.
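As a sanity check, a short sketch showing that averaging the per-position -log P terms matches what F.cross_entropy computes (random toy tensors; shapes are arbitrary):

import torch
import torch.nn.functional as F

B, T, V = 2, 4, 10                                  # toy batch, length, vocab
logits = torch.randn(B, T, V)
targets = torch.randint(0, V, (B, T))

# Per-position log-probability of the correct token
log_probs = F.log_softmax(logits, dim=-1)                              # (B, T, V)
token_logp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # (B, T)
manual_loss = -token_logp.mean()

builtin = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))
assert torch.allclose(manual_loss, builtin, atol=1e-6)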

Perplexity

Perplexity exponentiates the average loss, converting nats into an interpretable "effective vocabulary size":

PPL = e^L = exp(−(1/T) · Σₜ log P(xₜ | x₁, …, xₜ₋₁))

💡 Tip · A loss of about 2.15 nats gives PPL = e^2.15 ≈ 8.6. This means at each token, the model is effectively choosing from ~8–9 plausible candidates. Human perplexity on English text is estimated around 12–20 (varies by task and methodology).
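A small conversion helper relating average loss (nats), bits per token, and perplexity; the 3.00 figure is GPT-3's LAMBADA perplexity from the table below:

import math

def loss_to_ppl(loss_nats):
    """Perplexity is the exponentiated average loss."""
    return math.exp(loss_nats)

def ppl_to_bits(ppl):
    """Bits per token: log base 2 of the perplexity."""
    return math.log2(ppl)

# GPT-3's LAMBADA PPL of 3.00: about 1.10 nats, or ~1.58 bits per token
print(round(math.log(3.00), 2), round(ppl_to_bits(3.00), 2))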

Softmax Output Layer

The final layer projects hidden states to vocabulary-sized logits, then applies softmax to get probabilities:

P(i) = softmax(z)ᵢ = e^(zᵢ) / Σⱼ e^(zⱼ), where z = h W_out are the vocabulary-sized logits

W_out is often tied to the embedding matrix E (weight tying), so W_out = Eᵀ. This reduces parameters and improves performance.
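A minimal sketch of weight tying in PyTorch, assuming a toy module with just an embedding and an output head (hypothetical names):

import torch.nn as nn

class TiedLMHead(nn.Module):
    """Toy module: one weight tensor shared by the embedding and output projection."""
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight    # weight tying: share the (V, d) matrix

    def forward(self, hidden):                     # hidden: (B, T, d_model)
        return self.lm_head(hidden)                # logits: (B, T, vocab_size)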

PyTorch implementation
import torch
import torch.nn.functional as F

def train_step(model, batch, optimizer):
    """Causal LM pretraining step with teacher forcing."""
    input_ids = batch["input_ids"]          # (B, T)

    # Teacher forcing: input = [0..T-1], target = [1..T]
    x = input_ids[:, :-1]                   # (B, T-1)
    targets = input_ids[:, 1:]              # (B, T-1)

    # Forward pass
    logits = model(x)                       # (B, T-1, V)

    # Cross-entropy: predict next token at every position
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # (B*(T-1), V)
        targets.reshape(-1),                  # (B*(T-1),)
        ignore_index=-100,                    # skip padding
    )

    # Backward + gradient clip + step
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return {"loss": loss.item(), "ppl": torch.exp(loss).item()}

Quick check

Derivation

The PyTorch training step shifts input_ids to create x = input_ids[:, :-1] and targets = input_ids[:, 1:]. What does this implement and why is it equivalent to causal masking?

🔧

Break It — See What Happens

No teacher forcing (use model's own predictions)
Label smoothing too aggressive (epsilon = 0.3)
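For the label-smoothing experiment, a self-contained sketch using toy shapes and PyTorch's built-in label_smoothing argument on F.cross_entropy:

import torch
import torch.nn.functional as F

B, T, V = 2, 8, 100                                   # toy shapes
targets = torch.randint(0, V, (B, T))

# A maximally confident "one-hot" model: huge logit on the correct class
perfect_logits = torch.full((B, T, V), -20.0)
perfect_logits.scatter_(-1, targets.unsqueeze(-1), 20.0)

standard = F.cross_entropy(perfect_logits.reshape(-1, V), targets.reshape(-1))
smoothed = F.cross_entropy(perfect_logits.reshape(-1, V), targets.reshape(-1),
                           label_smoothing=0.3)       # epsilon = 0.3

# Standard CE is ~0, but with epsilon=0.3 the target keeps 0.3 of its mass
# spread over the vocabulary, so extreme confidence is actively penalized.
print(standard.item(), smoothed.item())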

Quick check

Trade-off

Label smoothing with epsilon=0.1 changes the target from a one-hot [0,…,1,…,0] to a mixture [(0.1/V), …, (0.9), …, (0.1/V)]. What is the correct tradeoff when epsilon is raised to 0.3?

📊

Real-World Numbers

Selected published or commonly cited figures, with estimates labeled explicitly. These are useful for scale intuition rather than as a single canonical benchmark table.

Model | Perplexity | Training FLOPs | Tokens
GPT-2 (1.5B) | 8.63 (LAMBADA, 0-shot) | ~1.5e20 | —
GPT-3 (175B) | 3.00 (LAMBADA, 0-shot) | 3.1e23 | 300B
Llama-2 (70B) | ~4.1 (C4) | ~1e24 | 2T
Llama-3 (70B) | Not public | ~7e24 (est.) | 15T
GPT-4 (est.) | Not public | ~2e25 (est.) | ~13T (est.)
✨ Insight · Perplexity keeps improving with scale, but the gains are logarithmic. Going from PPL 100 to 20 is "easy" (10x more compute). Going from 20 to 5 takes orders of magnitude more. This is the scaling laws story (Module 11). Chinchilla's compute-optimal recipe is roughly 20 training tokens per parameter; models like Llama-3 deliberately exceed this to reduce per-query inference cost. Training compute is well approximated by C ≈ 6·N·D (6 × parameters × tokens), giving a closed-form handle on total training compute.
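A back-of-envelope helper using the C ≈ 6·N·D approximation mentioned above (an approximation that ignores attention FLOPs and other overheads):

def train_flops(n_params, n_tokens):
    """Approximate training compute: ~6 FLOPs per parameter per token (fwd + bwd)."""
    return 6 * n_params * n_tokens

# Llama-2 70B on 2T tokens: 6 * 70e9 * 2e12 ≈ 8.4e23 FLOPs,
# the same order of magnitude as the ~1e24 in the table above.
print(f"{train_flops(70e9, 2e12):.1e}")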

Quick check

Trade-off

Chinchilla says 20 tokens per parameter is compute-optimal. For a 70B model, that is 1.4T tokens. Llama-2 70B trained on 2T tokens. What does the extra 0.6T tokens trade?

🧠

Key Takeaways

What to remember for interviews

  1. Next-token prediction forces the model to learn grammar, world knowledge, reasoning, and common sense — all implicitly compressed into one loss function.
  2. Cross-entropy loss measures how surprised the model is by the correct token; minimizing it is equivalent to maximizing the log-likelihood of the training data.
  3. Perplexity = e^(loss): PPL 10 means the model is as uncertain as choosing uniformly from 10 options. GPT-2 achieves 8.63 on LAMBADA (zero-shot); GPT-3 achieves 3.00.
  4. Teacher forcing feeds the ground-truth token at each training step, enabling parallel computation but creating exposure bias — a train/inference mismatch where errors compound at generation time.
  5. Loss spikes during large-scale pretraining are normal; PaLM (~20 spikes) mitigated them by restarting from a checkpoint ~100 steps before the spike and skipping ~200–500 batches, plus z-loss regularization.
🧠

Recap quiz

Trade-off

Why does next-token prediction on raw text produce emergent capabilities like code generation and multi-step reasoning, even though neither was supervised during training?

Derivation

A model predicts token A with probability 0.001 when the correct answer is A. Cross-entropy penalizes this as -log(0.001) ≈ 6.9 nats. Why would MSE over the one-hot vector fail to give an equivalent gradient signal?

Derivation

GPT-2 achieves 8.63 PPL on LAMBADA and GPT-3 achieves 3.00. How many bits-per-token does GPT-3 use, and what does the PPL gap imply about scale?

Trade-off

At inference, an autoregressive model generates token t from its own prediction at t-1. During training, it always sees the ground-truth t-1. What failure mode does this mismatch cause and which training strategy directly addresses it?

Derivation

A causal (lower-triangular) attention mask is applied during pretraining. If you removed it and used full bidirectional attention, what specifically breaks in next-token prediction?

Trade-off

Chinchilla says compute-optimal training uses ~20 tokens per parameter. Llama-3 8B was trained on 15T tokens — roughly 1875× its parameter count. What is the engineering rationale for this extreme over-training?

Derivation

PaLM introduced z-loss to reduce training instabilities. The loss is L_z = 10⁻⁴ · log²(Z), where Z is the softmax partition function. What does this penalize and why does it reduce loss spikes?

📚

Further Reading

🎯

Interview Questions


Why does next-token prediction work as a universal training objective? What linguistic and world knowledge does it force the model to learn?

★★☆
Google · OpenAI

Why do we use cross-entropy loss instead of MSE (mean squared error) for language modeling?

★★☆
Google · Meta

A model has perplexity 10 on a test set. What does this mean intuitively? How does perplexity relate to bits-per-character?

★★☆
OpenAI · Anthropic

Explain teacher forcing. What is the exposure bias problem, and what are alternatives?

★★☆
Google · Meta

What is curriculum learning in the context of pre-training? Does the order of training data matter?

★★★
Google · Anthropic

What causes loss spikes during pre-training and how are they handled in practice?

★★★
Google · OpenAI