
Transformer Math

Module 18 · Training

🎯 SFT & Post-Training Pipeline

InstructGPT’s SFT used 13K labeled examples — and that alone beat GPT-3 (175B base) on human preference. Why does so little labeled data do so much?


Pre-training gives you a language model. Post-training turns it into an assistant. The recipe: SFT on instructions, rejection sampling to boost quality, then preference optimization (DPO/RLHF) for alignment. Every production chat model — GPT-4, Claude, Llama — follows this pipeline.

🔧

The Post-Training Pipeline

Every production chat model follows the same six-stage recipe. The diagram shows the data format consumed at each stage and real model examples.

[Pipeline diagram: Base Model → SFT → RLHF / DPO → Safety Training → Deployed. Data consumed at each stage: instruction-response pairs (SFT), preference comparisons (RLHF/DPO), red-team prompts + constitutional AI (safety training). Examples: GPT-4 base → GPT-4-turbo; Claude base → Claude 3 follows the same path.]
🎮

Post-Training Pipeline

What you are seeing: The 6 stages of modern post-training. Each stage takes the model closer to a deployable assistant.

What to try: Click each stage to see what happens — what goes in, what comes out, and why it matters.

Post-Training Pipeline (click a stage to learn what happens at each step):

1. Base Model: pre-trained on next-token prediction
2. SFT: instruction fine-tuning
3. Rejection Sampling: generate N, keep best
4. DPO / RLHF: preference optimization
5. Eval: benchmark + human eval
6. Deploy: ship to production, then iterate with new data

Stage 2: SFT

Supervised Fine-Tuning on (instruction, response) pairs. Loss is only computed on assistant turns (loss masking). Teaches format, tone, and instruction-following. Typically 1K-50K high-quality examples.

💡

The Intuition

SFT teaches format and instruction-following. A pre-trained model completes text; an SFT model answers questions. The training data is simple: pairs of (instruction, ideal response). The model learns the style — concise, helpful, in the right format — not new knowledge.

Loss masking: In multi-turn data, you only compute loss on assistant tokens. User turns and system prompts are masked (labels set to -100). The model should learn to respond, not to generate user queries.

Chat templates define the exact token format. ChatML wraps each turn in <|im_start|>role … <|im_end|>. Llama-2 uses [INST] … [/INST] and <<SYS>> markers. The model is trained on a specific template — using the wrong one at inference degrades quality significantly.
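A minimal sketch of the two formats (the helper names are ours; exact whitespace, BOS/EOS handling, and multi-turn conventions vary by tokenizer, so treat this as illustrative):

def format_chatml(messages):
    """ChatML: each turn wrapped in <|im_start|>role ... <|im_end|>.
    messages: list of {"role": ..., "content": ...} dicts."""
    out = ""
    for m in messages:
        out += f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
    out += "<|im_start|>assistant\n"  # generation prompt for inference
    return out

def format_llama2(messages):
    """Single-turn Llama-2 chat format with an optional system prompt."""
    system = next((m["content"] for m in messages if m["role"] == "system"), None)
    user = next(m["content"] for m in messages if m["role"] == "user")
    if system:
        user = f"<<SYS>>\n{system}\n<</SYS>>\n\n{user}"
    return f"[INST] {user} [/INST]"

msgs = [{"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "What is SFT?"}]
print(format_chatml(msgs))
print(format_llama2(msgs))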

Rejection sampling amplifies quality: generate N responses per prompt (e.g., N=64), score with a reward model, keep only the best. This creates a dataset that is strictly better than the model's average output — cheap data augmentation.

Distillation transfers capability from a large teacher to a smaller student. The student trains to match the teacher's output distribution (not just the argmax), preserving the teacher's uncertainty via a KL-divergence loss. This is how many "small but capable" models are created. A related but distinct approach is synthetic data generation: generate instruction/response pairs from a larger model, then train on those as standard SFT (without the KL-divergence loss). Alpaca (7B) used this approach — fine-tuned on 52K examples generated by text-davinci-003, not KL-based distribution matching.

Sequence packing vs. padding. Naively padding every example to the maximum sequence length wastes compute proportional to the padding fraction. Instead, production SFT pipelines pack multiple short examples end-to-end into a single sequence up to the context limit, with a special separator token between them. The attention mask uses block-diagonal structure so examples within a packed sequence cannot attend to each other — they are causally independent. This removes padding waste entirely and can improve GPU utilization by 30–50% on datasets with variable-length examples (Phi-3 Technical Report, Abdin et al. 2024). The tradeoff: you must ensure the block-diagonal attention mask is implemented correctly; a bug here silently allows cross-example attention leakage, where the model conditions on a different training example mid-sequence — the worst kind of data leakage because it inflates training metrics while degrading real-world performance.
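A minimal sketch of both pieces, assuming pre-tokenized examples and an illustrative separator token (real pipelines typically pass segment IDs to a fused attention kernel rather than materializing the full mask):

import torch

def pack_examples(examples, max_len, sep_id=0):
    """Greedily pack token-id lists end-to-end into one sequence up to max_len.
    Returns (packed_ids, segment_ids); segment_ids labels which example each
    token came from, so a block-diagonal mask can be built."""
    packed, segments = [], []
    for seg, ex in enumerate(examples):
        if len(packed) + len(ex) + 1 > max_len:
            break  # a real pipeline would start the next packed sequence here
        packed.extend(ex + [sep_id])
        segments.extend([seg] * (len(ex) + 1))
    return torch.tensor(packed), torch.tensor(segments)

def block_diagonal_causal_mask(segment_ids):
    """True where attention is allowed: causal AND within the same example."""
    T = segment_ids.size(0)
    same_example = segment_ids[:, None] == segment_ids[None, :]
    causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
    return same_example & causal  # map False -> -inf for an additive mask

ids, segs = pack_examples([[5, 6, 7], [8, 9], [10, 11, 12, 13]], max_len=16)
mask = block_diagonal_causal_mask(segs)  # (12, 12) boolean mask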

✨ Insight · The complete modern recipe: Pretrain (trillions of tokens) → SFT (1K–50K examples) → Rejection Sampling (amplify quality) → DPO/RLHF (align to preferences) → Eval → Deploy. Each step is cheaper than the last but contributes disproportionately to user experience.

Quick check

Derivation

A pre-trained model generates plausible continuations. After SFT, it answers questions instead. What is the primary mechanism by which SFT changes this behavior?

Quick Check

Why do we mask the loss on user turns during SFT?

📐

Step-by-Step Derivation

SFT Loss with Masking

Standard cross-entropy, but only on assistant tokens. The mask $m_t$ is 1 for assistant tokens and 0 for user/system tokens:

$$\mathcal{L}_{\text{SFT}} = -\frac{1}{\sum_t m_t} \sum_t m_t \,\log p_\theta(x_t \mid x_{<t})$$

💡 Tip · In practice, set labels to -100 for masked positions — PyTorch's CrossEntropyLoss ignores these by default. No need for an explicit mask tensor.

Rejection Sampling (Best-of-N)

Generate $N$ completions $y_1, \dots, y_N \sim \pi_\theta(\cdot \mid x)$ for prompt $x$, score each with reward model $r(x, y_i)$, select the best:

$$y^{*} = \arg\max_{1 \le i \le N} r(x, y_i)$$

The expected reward improves as $O(\sqrt{\log N})$ — diminishing returns, but even N=4 gives a meaningful boost. Llama-2 used rejection sampling scored by its reward model between the SFT and RLHF stages.
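A minimal Best-of-N sketch, assuming a Hugging Face-style `policy.generate` and a hypothetical `reward_model.score(prompt, completion)` helper:

import torch

@torch.no_grad()
def best_of_n(policy, reward_model, prompt_ids, n=4, max_new_tokens=256):
    """Sample n completions for one prompt, return the highest-reward one.
    `reward_model.score` is a hypothetical stand-in for your scoring API."""
    candidates = [
        policy.generate(prompt_ids, do_sample=True, max_new_tokens=max_new_tokens)
        for _ in range(n)
    ]
    rewards = torch.tensor([reward_model.score(prompt_ids, c) for c in candidates])
    return candidates[int(rewards.argmax())]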

Distillation Loss (KL Divergence)

The student $q_\theta$ learns to match the teacher $p$'s output distribution, softened by temperature $T$:

$$\mathcal{L}_{\text{KD}} = T^2 \cdot \mathrm{KL}\!\left(\mathrm{softmax}(z_{\text{teacher}}/T) \,\big\|\, \mathrm{softmax}(z_{\text{student}}/T)\right)$$

where $z_{\text{teacher}}$ and $z_{\text{student}}$ are the logits.

Higher temperature softens the distributions, letting the student learn from the teacher's full probability mass — not just the top token. The $T^2$ factor compensates for the reduced gradient magnitude at high temperatures.
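A minimal PyTorch sketch of this loss, assuming the student and teacher share a vocabulary:

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """Temperature-softened KL(teacher || student), scaled by T^2.
    Both logits: (batch, seq_len, vocab)."""
    log_q = F.log_softmax(student_logits / T, dim=-1).flatten(0, -2)  # student
    p = F.softmax(teacher_logits / T, dim=-1).flatten(0, -2)          # teacher
    # F.kl_div takes log-probabilities as input and probabilities as target
    return F.kl_div(log_q, p, reduction="batchmean") * (T * T)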

PyTorch implementation
# SFT training loop with instruction masking
# Only compute loss on assistant tokens (labels=-100 for user/system turns)
import torch, torch.nn.functional as F

def sft_loss(model, input_ids, labels, attention_mask):
    """labels: -100 for user/system tokens, token_id for assistant tokens."""
    logits = model(input_ids, attention_mask=attention_mask).logits  # (B, T, V)
    shift_logits = logits[:, :-1].contiguous()      # predict next token
    shift_labels = labels[:, 1:].contiguous()
    # ignore_index=-100 automatically masks non-assistant positions
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

# Mask helper: only keep assistant-turn token ids in labels
def mask_user_turns(input_ids, assistant_ranges):
    labels = input_ids.clone()
    mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for start, end in assistant_ranges:
        mask[:, start:end] = True
    labels[~mask] = -100
    return labels

PyTorch: SFT Training Loop with Loss Masking

import torch
import torch.nn.functional as F

def sft_train_step(model, batch, optimizer):
    """SFT step with loss masking on assistant tokens only."""
    input_ids = batch["input_ids"]          # (B, T)
    labels = batch["labels"]                # (B, T) — user/system tokens = -100
    attention_mask = batch["attention_mask"] # (B, T)

    # Forward pass
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits  # (B, T, V)

    # Shift for next-token prediction
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()

    # Cross-entropy with ignore_index=-100 (auto-masks user turns)
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return loss.item()

# Label preparation: mask everything except assistant responses
def prepare_labels(input_ids, assistant_mask):
    """Set labels to -100 for non-assistant tokens."""
    labels = input_ids.clone()
    labels[~assistant_mask] = -100
    return labels

Quick check

Trade-off

In rejection sampling (Best-of-N), you generate N completions and keep the highest-scoring one. If the reward model is miscalibrated and ranks completions by verbosity rather than quality, what happens to the selected outputs over multiple SFT iterations?

🔧

Break It — See What Happens

No loss masking (train on user turns too)
SFT on low-quality data

Quick check

Trade-off

You remove loss masking from SFT training — the model now trains on all tokens (user, system, and assistant). After fine-tuning, the model is used in a chat application. What is the most likely observable failure mode?

📊

Real-World Numbers

| Model / Dataset | SFT Examples | Key Insight |
|---|---|---|
| InstructGPT | ~13K | Pioneered the SFT + RLHF pipeline: ~13K SFT demos + ~33K preference comparisons. |
| Llama-2 Chat | ~27K | Vendor-collected. Used rejection sampling (scored by its reward model) between SFT and RLHF. |
| LIMA | 1,000 | Expert-curated. Matched much larger datasets, proving quality beats quantity. |
| Alpaca | 52K | Self-Instruct from GPT-3.5. Popularized open-source SFT, but noisy data. |
| Dolly 2.0 | ~15K | Databricks employees. First commercially-licensed human-generated SFT dataset. |
✨ Insight · The number of SFT examples across production models ranges from 1K to 50K — orders of magnitude less than pre-training data. SFT teaches style, not knowledge. Quality and diversity of examples matter far more than raw count.

Quick check

Trade-off

Llama-2 Chat used ~27K SFT examples and applied rejection sampling before RLHF. LIMA achieved comparable instruction-following with 1K. Why did Meta collect 27K rather than stopping at 1K?

🧠

Key Takeaways

What to remember for interviews

  1. SFT teaches format and instruction-following — not new knowledge. The model learns the style of a helpful assistant from (instruction, response) pairs.
  2. Loss masking is critical: set labels to -100 for user/system tokens so gradients only flow through assistant turns.
  3. Chat templates define the exact token delimiters (ChatML, Llama format). Using the wrong template at inference degrades quality significantly.
  4. Rejection sampling (Best-of-N) amplifies quality cheaply: generate N responses, keep the best-scored one. E[max reward] scales as O(√log N) — even N=4 gives a meaningful boost.
  5. Quality beats quantity in SFT: LIMA showed 1,000 expert-curated examples match 52K noisy ones. Sequence packing (vs. padding) improves GPU utilization 30–50%.
🧠

Recap Quiz

Derivation

LIMA matched Alpaca (52K) with only 1,000 examples. Which claim best explains why fewer examples can suffice for SFT?

Trade-off

InstructGPT used ~13K SFT demos plus ~33K preference comparisons. Why are the preference pairs more than 2× the SFT data?

Derivation

Sequence packing improves GPU utilization by 30–50% on variable-length SFT datasets. What can go wrong if the attention mask is not correctly implemented as block-diagonal?

Trade-off

A team switches from SFT on 50K machine-generated examples to SFT on 2K human-expert examples. They observe lower training loss on the 50K set but higher win-rate on human evals for the 2K model. What best explains this?

Derivation

During SFT inference, a model was fine-tuned with the ChatML template but is called with the Llama chat template at deployment. What failure mode do you expect?

Derivation

Rejection sampling (Best-of-N) improves expected reward as O(√log N). If N=4 gives a meaningful boost, why not use N=1024 for maximum quality?

🎯

Interview Questions


What is the role of SFT vs RLHF in the post-training pipeline? Why do you need both?

★★☆
OpenAI · Anthropic

Explain loss masking in SFT. Why do we only compute loss on assistant turns, not user turns?

★★☆
Google · Meta

Compare ChatML and Llama chat template formats. Why do chat templates matter?

★☆☆
OpenAI · Meta

How does rejection sampling improve post-training? Walk through the process.

★★☆
Meta · Anthropic

Data quality vs quantity in SFT: LIMA used 1,000 examples and matched models trained on 52K. Why?

★★☆
Meta · Google

When should you SFT a model vs just use prompt engineering? What are the tradeoffs?

★★☆
Google · OpenAI