📉 Pre-training & Loss
GPT-3 trained on 300B tokens and never saw a single labeled example — yet learned grammar, facts, math, and reasoning from one loss function.
Every large language model — GPT-4, Claude, Llama — was trained with the same objective: predict the next token. This single loss function, applied to trillions of tokens, is what produces emergent capabilities like reasoning, translation, and code generation. This module covers how that works: cross-entropy loss, perplexity, and teacher forcing.
The Pre-training Loop
What you are seeing
The full pre-training pipeline: raw text enters on the left, gets tokenized into integer IDs, embedded into vectors, processed by the Transformer, and the final hidden state predicts the next token. Cross-entropy loss compares that prediction against the true next token. The teacher forcing arrow (orange) shows that during training, the ground-truth token — not the model's prediction — is fed back as the next input, keeping the training signal clean.
Next-Token Prediction
What you are seeing
The model predicts the next token for The cat sat on the ___. Each bar shows the probability the model assigns to that vocabulary word. The green bar is the ground truth — the loss measures how much probability mass the model puts on the correct answer.
What to try: Drag the temperature slider to see how it reshapes the sampling distribution. At low temperature the model becomes very confident; at high temperature the distribution flattens. Note: temperature is a sampling-time parameter — it does not affect the model's actual cross-entropy loss or perplexity, which are evaluated from the unscaled logits.
Example readout: -log P("mat") = -log(0.354) ≈ 1.04 (temperature-scaled, not the eval loss), giving e^1.04 ≈ 2.8 effective choices.
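A quick sketch of what the temperature slider does: temperature rescales the logits before softmax at sampling time, while loss and perplexity are computed from the unscaled logits. The token set and logit values below are made up for illustration.

import torch
import torch.nn.functional as F

# Hypothetical logits over four candidate tokens {"mat", "rug", "floor", "dog"}
logits = torch.tensor([2.0, 1.2, 0.4, -0.5])

for temp in (0.5, 1.0, 2.0):
    probs = F.softmax(logits / temp, dim=-1)              # temperature reshapes the sampling distribution
    print(temp, [round(p, 3) for p in probs.tolist()])    # low temp -> peaked, high temp -> flat

# Cross-entropy and perplexity are evaluated from the unscaled logits (temp = 1.0),
# so moving the slider changes sampling behaviour but not the training loss.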
The Intuition
Next-token prediction is deceptively simple: given a sequence of tokens, predict what comes next. But to do this well, the model must learn grammar, facts, reasoning, and even common sense. The sentence "The Eiffel Tower is in ___" requires geographic knowledge. "If x > 5 and x < 3, then ___" requires logical reasoning.
Cross-entropy loss measures how surprised the model is by the true answer. If the model assigns high probability to the correct token, loss is low. If it spreads probability across many wrong tokens, loss is high. Minimizing cross-entropy is equivalent to maximizing the log-likelihood of the training data.
Perplexity = e^(loss). A perplexity of 20 means the model is, on average, as uncertain as choosing uniformly among 20 options. Lower is better — a perfect model has perplexity 1.
Teacher forcing: during training, the model always receives the ground truth previous token, not its own prediction. This prevents error compounding but creates a train/inference mismatch called exposure bias.
Loss Spikes and Training Stability: Real large-scale pretraining runs are not monotonically smooth. Google's PaLM paper (Chowdhery et al., 2022) documented roughly 20 loss spikes during training of the largest model. The paper explicitly notes they do not believe bad data batches caused the spikes — replaying the same batches from an earlier checkpoint did not reproduce them. The documented mitigation is to restart from a checkpoint roughly 100 steps before the spike and skip the next ~200–500 data batches; this resolved every spike they encountered. Gradient clipping was already active during these runs. Spikes are not bugs — they are a normal feature of training at scale. The PaLM team also introduced z-loss, a regularizer that penalizes large logit magnitudes: L_z = 10⁻⁴ · log²(Z), where Z is the softmax partition function. This keeps the softmax numerically stable and reduced spike frequency significantly.
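A minimal sketch of how a z-loss term can be added to the standard cross-entropy objective. The function name and tensor shapes are assumptions; the 10⁻⁴ coefficient follows the formula above.

import torch
import torch.nn.functional as F

def lm_loss_with_z_loss(logits, targets, z_coeff=1e-4):
    """Cross-entropy plus a PaLM-style z-loss term; logits (B, T, V), targets (B, T)."""
    ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    log_z = torch.logsumexp(logits, dim=-1)        # log of the softmax partition function Z, per position
    z_loss = z_coeff * (log_z ** 2).mean()         # 1e-4 * log^2(Z), averaged over positions
    return ce + z_loss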
Quick check
A language model has PPL 20 on one benchmark and PPL 5 on another. What is the correct interpretation of this 4× difference?
What does a perplexity of 10 mean intuitively?
Step-by-Step Derivation
Cross-Entropy Loss
For a sequence of tokens x_1, …, x_T, the loss is the negative log-likelihood of the correct token at each position:

L = -(1/T) · Σ_{t=1}^{T} log P(x_t | x_{<t})
Each term is the log-probability the model assigns to the correct token given all preceding tokens. The model outputs a probability distribution over the entire vocabulary at each position via softmax.
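A small sketch (toy shapes, random values) showing the per-position negative log-likelihood computed by hand, which matches what F.cross_entropy returns:

import torch
import torch.nn.functional as F

B, T, V = 2, 5, 100                                    # toy batch size, sequence length, vocab size
logits = torch.randn(B, T, V)                          # model's per-position outputs
targets = torch.randint(0, V, (B, T))                  # the true next tokens

log_probs = F.log_softmax(logits, dim=-1)              # log P(token | prefix) over the vocabulary
nll = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)   # -log P of the correct token
manual_loss = nll.mean()                               # average over all positions

builtin = F.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))
assert torch.allclose(manual_loss, builtin)            # identical by construction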
Perplexity
Perplexity exponentiates the average loss, converting nats into an interpretable "effective vocabulary size":

PPL = e^L = exp(-(1/T) · Σ_{t=1}^{T} log P(x_t | x_{<t}))
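A quick numeric check of the "effective vocabulary size" reading (the loss values below are illustrative): a model that guesses uniformly over V tokens has loss log V and therefore perplexity exactly V.

import math

V = 50_000
uniform_loss = math.log(V)            # a uniform model assigns P = 1/V, so loss = log V nats
print(math.exp(uniform_loss))         # perplexity = V: no better than guessing uniformly

mean_loss = 2.3                       # e.g. an average cross-entropy of 2.3 nats
print(math.exp(mean_loss))            # ≈ 10: as uncertain as a uniform choice among ~10 tokens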
Softmax Output Layer
The final layer projects hidden states to vocabulary-sized logits, then applies softmax to get probabilities:

z_t = h_t W_out,   P(x_{t+1} = i | x_{≤t}) = softmax(z_t)_i = e^{z_{t,i}} / Σ_j e^{z_{t,j}}
W_out is often tied to the embedding matrix E (weight tying), so W_out = Eᵀ. This reduces parameters and improves performance.
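A minimal sketch of weight tying in PyTorch (the class name and dimensions are illustrative):

import torch
import torch.nn as nn

class TiedLMHead(nn.Module):
    """Output projection that shares its weight with the input embedding."""
    def __init__(self, vocab_size=50_000, d_model=768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)        # input embedding E: (V, d)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight               # weight tying: logits = h @ E^T

    def forward(self, hidden):                                # hidden: (B, T, d_model)
        return self.lm_head(hidden)                           # logits: (B, T, vocab_size)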
PyTorch implementation
import torch
import torch.nn.functional as F
def train_step(model, batch, optimizer):
    """Causal LM pretraining step with teacher forcing."""
    input_ids = batch["input_ids"]                    # (B, T)

    # Teacher forcing: input = [0..T-1], target = [1..T]
    x = input_ids[:, :-1]                             # (B, T-1)
    targets = input_ids[:, 1:]                        # (B, T-1)

    # Forward pass
    logits = model(x)                                 # (B, T-1, V)

    # Cross-entropy: predict next token at every position
    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),          # (B*(T-1), V)
        targets.reshape(-1),                          # (B*(T-1),)
        ignore_index=-100,                            # skip positions labeled -100 (e.g. padding)
    )

    # Backward + gradient clip + step
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return {"loss": loss.item(), "ppl": torch.exp(loss).item()}

Quick check
The PyTorch training step shifts input_ids to create x = input_ids[:, :-1] and targets = input_ids[:, 1:]. What does this implement and why is it equivalent to causal masking?
Break It — See What Happens
Quick check
Label smoothing with epsilon=0.1 changes the target from a one-hot [0,…,1,…,0] to a mixture [(0.1/V), …, (0.9), …, (0.1/V)]. What is the correct tradeoff when epsilon is raised to 0.3?
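For reference, PyTorch exposes this directly through cross_entropy's label_smoothing argument; a quick sketch with toy shapes and random values:

import torch
import torch.nn.functional as F

logits = torch.randn(8, 100)                   # toy (batch, vocab) predictions
targets = torch.randint(0, 100, (8,))

for eps in (0.0, 0.1, 0.3):
    loss = F.cross_entropy(logits, targets, label_smoothing=eps)
    print(eps, round(loss.item(), 3))          # larger eps softens targets and puts a floor under the loss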
Real-World Numbers
Selected published or commonly cited figures, with estimates labeled explicitly. These are useful for scale intuition rather than as a single canonical benchmark table.
| Model | Perplexity | Training FLOPs | Tokens |
|---|---|---|---|
| GPT-2 (1.5B) | 8.63 (LAMBADA, 0-shot) | ~1.5e20 | Not public |
| GPT-3 (175B) | 3.00 (LAMBADA, 0-shot) | 3.1e23 | 300B |
| Llama-2 (70B) | ~4.1 (C4) | ~1e24 | 2T |
| Llama-3 (70B) | Not public | ~7e24 (est.) | 15T |
| GPT-4 (est.) | Not public | ~2e25 (est.) | ~13T (est.) |
Quick check
Chinchilla says 20 tokens per parameter is compute-optimal. For a 70B model, that is 1.4T tokens. Llama-2 70B trained on 2T tokens. What do the extra 0.6T tokens trade off?
Key Takeaways
What to remember for interviews
1. Next-token prediction forces the model to learn grammar, world knowledge, reasoning, and common sense — all implicitly compressed into one loss function.
2. Cross-entropy loss measures how surprised the model is by the correct token; minimizing it is equivalent to maximizing the log-likelihood of the training data.
3. Perplexity = e^(loss): PPL 10 means the model is as uncertain as choosing uniformly from 10 options. GPT-2 achieves 8.63 on LAMBADA (zero-shot); GPT-3 achieves 3.00.
4. Teacher forcing feeds the ground-truth token at each training step, enabling parallel computation but creating exposure bias — a train/inference mismatch where errors compound at generation time.
5. Loss spikes during large-scale pretraining are normal; PaLM (~20 spikes) mitigated them by restarting from a checkpoint ~100 steps before the spike and skipping ~200–500 batches, plus z-loss regularization.
Recap quiz
Why does next-token prediction on raw text produce emergent capabilities like code generation and multi-step reasoning, even though neither was supervised during training?
A model predicts token A with probability 0.001 when the correct answer is A. Cross-entropy penalizes this as -log(0.001) ≈ 6.9 nats. Why would MSE over the one-hot vector fail to give an equivalent gradient signal?
GPT-2 achieves 8.63 PPL on LAMBADA and GPT-3 achieves 3.00. How many bits-per-token does GPT-3 use, and what does the PPL gap imply about scale?
At inference, an autoregressive model generates token t from its own prediction at t-1. During training, it always sees the ground-truth t-1. What failure mode does this mismatch cause and which training strategy directly addresses it?
A causal (lower-triangular) attention mask is applied during pretraining. If you removed it and used full bidirectional attention, what specifically breaks in next-token prediction?
Chinchilla says compute-optimal training uses ~20 tokens per parameter. Llama-3 8B was trained on 15T tokens — roughly 1875× its parameter count. What is the engineering rationale for this extreme over-training?
PaLM introduced z-loss to reduce training instabilities. The loss is L_z = 10⁻⁴ · log²(Z), where Z is the softmax partition function. What does this penalize and why does it reduce loss spikes?
Further Reading
- Language Models are Unsupervised Multitask Learners (GPT-2) — Radford et al. 2019 — showed large-scale autoregressive pretraining produces strong zero-shot performance.
- Training Compute-Optimal Large Language Models (Chinchilla) — Hoffmann et al. 2022 — proved most LLMs were undertrained. Optimal ratio: ~20 tokens per parameter.
- Scaling Laws for Neural Language Models — Kaplan et al. 2020 — empirical power laws relating compute, data, and parameters to loss. Foundation of modern training budgets.
- Andrej Karpathy — Let's Build GPT — Implements GPT pretraining end-to-end on Shakespeare — best hands-on companion to understanding the training loop.
- Lilian Weng's Blog — Deep technical posts on LLM pretraining, scaling, and optimization — essential reference for training infrastructure.
- Language Models are Few-Shot Learners (GPT-3) — Brown et al. 2020 — the GPT-3 paper showing that large-scale pretraining enables strong few-shot task performance without fine-tuning.
- The Llama-3 Herd of Models — Meta 2024 — detailed pretraining recipe for Llama-3, including data curation, 15T tokens, and the multi-stage training pipeline.
Interview Questions
- ★★☆ Why does next-token prediction work as a universal training objective? What linguistic and world knowledge does it force the model to learn?
- ★★☆ Why do we use cross-entropy loss instead of MSE (mean squared error) for language modeling?
- ★★☆ A model has perplexity 10 on a test set. What does this mean intuitively? How does perplexity relate to bits-per-character?
- ★★☆ Explain teacher forcing. What is the exposure bias problem, and what are alternatives?
- ★★★ What is curriculum learning in the context of pre-training? Does the order of training data matter?
- ★★★ What causes loss spikes during pre-training and how are they handled in practice?