
Transformer Math

Module 15 · Training

🔥 GPU & Mixed Precision

bf16 training uses half the memory with zero accuracy loss — why wasn't this the default?


Training a frontier-scale model on a single A100 would be utterly impractical; rough estimates put it at tens of thousands of years, though exact figures depend on undisclosed training compute. The bottleneck isn't arithmetic, it's memory bandwidth. Understanding where data lives on the GPU, and how number formats trade precision for speed, is the difference between a 2× speedup and a 10× one.

🔺

GPU Memory Hierarchy

What you're seeing: The five memory tiers on a modern GPU (e.g., A100). Each level is faster but smaller. The key insight: attention spends most of its time moving data between HBM and compute cores — not doing arithmetic.

| Tier | Size | Bandwidth / latency |
| --- | --- | --- |
| Registers | ~256 KB per SM | thread-local, ~1-cycle access |
| Shared memory / L1 | ~192 KB per SM (A100) | ~19 TB/s · ~5 cycles |
| L2 cache | ~40 MB | ~6 TB/s · ~100 cycles |
| HBM / global memory | 80 GB | ~2–3.35 TB/s |
| CPU RAM (via PCIe) | up to ~TB | ~32–64 GB/s |

Tiers run from the fastest and smallest (registers) to the largest and slowest (CPU RAM over PCIe).
✨ Insight · Attention is memory-bandwidth-bound, not compute-bound. Standard self-attention reads and writes O(N²) elements for O(N²) ops, so its arithmetic intensity is roughly 1 FLOP/byte. An A100's ridge point is ~156 FLOP/byte. Flash Attention fixes this by keeping tiles in on-chip SRAM instead of round-tripping the N×N matrix through HBM, pushing attention from memory-bound toward compute-bound.
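To make the ridge-point argument concrete, here is a minimal roofline sketch in plain Python. It uses the A100 bf16 figures quoted in this module (the ~2.0 TB/s bandwidth is what the 156 FLOP/B ridge implies); the "large matmul" intensity is just an illustrative value.

python
# Roofline model: attainable throughput = min(peak FLOP/s, bandwidth * intensity).
PEAK_FLOPS = 312e12      # A100 bf16 tensor-core peak, FLOP/s
PEAK_BW    = 2.0e12      # A100 HBM bandwidth, bytes/s (implied by the 156 FLOP/B ridge)
RIDGE      = PEAK_FLOPS / PEAK_BW   # ~156 FLOP/byte

def attainable_tflops(intensity_flop_per_byte: float) -> float:
    """Performance is capped either by compute or by memory traffic."""
    return min(PEAK_FLOPS, PEAK_BW * intensity_flop_per_byte) / 1e12

# Standard attention moves O(N^2) bytes for O(N^2) FLOPs -> intensity ~1 FLOP/byte.
for name, intensity in [("standard attention", 1.0), ("large matmul", 300.0)]:
    bound = "memory-bound" if intensity < RIDGE else "compute-bound"
    print(f"{name:>18}: {attainable_tflops(intensity):7.1f} TFLOP/s attainable ({bound})")

At intensity 1 the model predicts ~2 TFLOP/s out of a 312 TFLOP/s peak, which is why standard attention leaves the tensor cores idle.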
🔢

Number Format Bit Layouts

Each format divides 32, 16, or 8 bits into sign, exponent, and mantissa. The exponent controls range; the mantissa controls precision.

| Format | Sign / exponent / mantissa | Range | Precision | Notes |
| --- | --- | --- | --- | --- |
| fp32 | 1 / 8 / 23 bits | ±3.4×10³⁸ | ~7 decimal digits | |
| fp16 | 1 / 5 / 10 bits | ⚠ max 65,504 | ~3.3 decimal digits | overflow risk → needs loss scaling |
| bf16 | 1 / 8 / 7 bits | ✓ same range as fp32 | ~2.3 decimal digits | no loss scaling needed |
| fp8 E4M3 | 1 / 4 / 3 bits | | | H100 Transformer Engine; requires per-tensor scaling; ~2× throughput vs bf16 |
🎯 Interview · Interview trap: bf16 has less mantissa precision than fp16 (7 vs 10 bits), yet it's preferred for training. Why? Because the 8-bit exponent gives it fp32's dynamic range (max ~3.4×10³⁸), so overflow is extremely rare for bf16. Precision matters far less than avoiding NaN explosions during training.
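A quick way to see the range difference is to cast a couple of values and watch what survives. A minimal PyTorch sketch (CPU-only; the two test values are chosen purely for illustration):

python
import torch

# fp16: 5-bit exponent, max 65,504, tiny values flush to zero.
# bf16: 8-bit exponent, fp32's dynamic range, fewer mantissa bits.
for value in (70000.0, 1e-8):
    fp16 = torch.tensor(value, dtype=torch.float16)
    bf16 = torch.tensor(value, dtype=torch.bfloat16)
    print(f"{value:g}: fp16 -> {fp16.item():g}, bf16 -> {bf16.item():g}")

# Expected behavior: 70000 overflows to inf in fp16 but stays finite (coarsely rounded) in bf16;
# 1e-8 flushes to 0 in fp16 but remains ~1e-8 in bf16.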

Quick check

Trade-off

bf16 has fewer mantissa bits than fp16 (7 vs 10), yet all major LLMs are trained in bf16. What property makes bf16 the safer choice?
💡

Mixed Precision Training

What breaks without it?

Training purely in fp32 wastes 2× memory and 2× bandwidth, and it leaves tensor cores (the dedicated matrix-multiply units on modern GPUs, optimized for bf16/fp16 matmuls) underused. Training purely in fp16 causes gradients to underflow to zero or overflow to NaN. Mixed precision gets the best of both.

The 3-step recipe

① Master weights: fp32

Stored in full precision. Small gradient updates (e.g., 1e-7) need fp32 to avoid rounding to zero when added to weights. Updated once per optimizer step.

② Forward + backward: bf16

Activations, intermediate tensors, and gradients are computed in bf16: roughly 2× faster matmuls on tensor cores and 2× smaller activation memory.

③ Optimizer step: fp32

Gradients cast to fp32 before AdamW moment accumulation. fp32 master weights updated. Downcast copy sent back to bf16 for next forward pass.

Forward pass (bf16/fp16) → loss → backward pass (bf16/fp16) → gradients → optimizer step (fp32 master weights)
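The recipe can also be written out by hand, which makes the role of the fp32 master copy explicit. The sketch below is illustrative only: a toy weight matrix and a hypothetical train_step, not the autocast path shown later in this section.

python
import torch

master_w = torch.randn(1024, 1024, requires_grad=True)   # ① master weights stay fp32
optimizer = torch.optim.AdamW([master_w], lr=3e-4)

def train_step(x, y):
    w_bf16 = master_w.to(torch.bfloat16)        # bf16 working copy (differentiable cast)
    pred = x.to(torch.bfloat16) @ w_bf16        # ② forward in bf16
    loss = torch.nn.functional.mse_loss(pred.float(), y)
    loss.backward()                             # ② backward; grad lands on the fp32 master
    optimizer.step()                            # ③ AdamW update in full precision
    optimizer.zero_grad()
    return loss.item()

# Hypothetical usage with random data:
print(train_step(torch.randn(32, 1024), torch.randn(32, 1024)))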
Quick check

A model trained with fp16 mixed precision suddenly produces NaN losses on step 1000. What is the most likely cause?

Quick check

Trade-off

An A100 80GB delivers 312 TFLOP/s in bf16 on tensor cores but only 19.5 TFLOP/s in standard fp32. How does mixed-precision training exploit this gap without sacrificing convergence?
📐

Loss Scaling (fp16 Only)

fp16 cannot represent very small magnitudes: its 5-bit exponent bottoms out around 6×10⁻⁸. Gradients in deep networks are often tiny (1e-5 to 1e-8). Without scaling, they flush to zero and the model stops learning.

⚠ Warning · Why bf16 doesn't need this: bf16 shares fp32's 8-bit exponent, so it can represent values down to ~1.2×10⁻³⁸. Gradients of 1e-8 are representable. fp16's 5-bit exponent only reaches ~6×10⁻⁸: borderline at best, zero at worst.

How loss scaling works:

  1. Multiply the loss by a large scale factor S (e.g., S = 2¹⁵ = 32,768) before calling .backward()
  2. By the chain rule, every gradient is scaled by S as well, pushing them into fp16's representable range
  3. Before the optimizer step, divide all gradients by S to recover the true gradient magnitudes
  4. If any gradient is NaN/Inf (overflow still occurred), skip the optimizer step and halve S for the next step

Scaled loss and gradient: L_scaled = S · L, so ∇θL_scaled = S · ∇θL

Optimizer step with unscaled gradient: θ ← θ − η · ∇θL_scaled / S = θ − η · ∇θL
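A small numeric sketch of the first three steps above (toy values only, not the GradScaler internals):

python
import torch

S = 2.0 ** 15                                   # scale factor S = 32,768
tiny_grad = 1e-8                                # a typical small gradient value

unscaled = torch.tensor(tiny_grad, dtype=torch.float16)
scaled   = torch.tensor(tiny_grad * S, dtype=torch.float16)

print(unscaled.item())                          # 0.0 -> flushed to zero, the update is lost
print(scaled.item())                            # ~3.3e-4 -> comfortably inside fp16's range
print(scaled.float().item() / S)                # ~1e-8 -> unscaling in fp32 recovers the gradient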

PyTorch implementation:

torch.autocast + GradScaler (fp16)

python
import torch
from torch.cuda.amp import autocast, GradScaler

model = MyTransformer().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# GradScaler handles loss scaling automatically for fp16
# Not needed for bf16 (same exponent range as fp32)
scaler = GradScaler()

for batch in dataloader:
    inputs, labels = batch
    optimizer.zero_grad()

    # Forward + backward in fp16 automatically
    with autocast(dtype=torch.float16):
        logits = model(inputs)
        loss = criterion(logits, labels)

    # Scale loss → backward → unscale → optimizer step
    scaler.scale(loss).backward()   # grads are scaled
    scaler.unscale_(optimizer)      # unscale before clip
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)          # optimizer updates fp32 master weights
    scaler.update()                 # adjust scale factor for next step
bf16 — no GradScaler needed

torch.autocast bf16 — simpler

python
# bf16 rarely needs GradScaler — fp32-range exponent makes overflow very unlikely
with autocast(dtype=torch.bfloat16):
    logits = model(inputs)
    loss = criterion(logits, labels)

loss.backward()  # gradients in bf16, accumulated in fp32
optimizer.step() # master weights always fp32

Quick check

Derivation

GradScaler multiplies loss by S = 2¹⁵ = 32,768 before backward(). A gradient that would have been 1e-6 without scaling is now 32,768 × 1e-6 = 0.033 in fp16. Why is 0.033 safe but 1e-6 was not?
🔧

Break It — See What Happens

Train in fp32 only (no mixed precision)
Train in fp16 without loss scaling (GradScaler disabled)
📊

Real GPU Specs

All numbers from NVIDIA datasheets. Arithmetic intensity at ridge point = peak TFLOP/s ÷ peak HBM bandwidth.

| GPU | HBM | HBM BW | FP32 | TF32 | BF16 | FP8 | Ridge (BF16) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| A100 SXM 80GB | 80 GB | ~2.0 TB/s | 19.5 TF | 156 TF | 312 TF | n/a | 156 FLOP/B |
| H100 SXM 80GB | 80 GB | 3.35 TB/s | 67 TF | 756 TF | 989 TF | 1,979 TF | 295 FLOP/B |
| H200 SXM 141GB | 141 GB | 4.8 TB/s | 67 TF | 756 TF | 989 TF | 1,979 TF | 206 FLOP/B |
✨ Insight · H200 vs H100: Same compute (identical die), but ~43% more HBM bandwidth (4.8 vs 3.35 TB/s) and 141 GB of HBM instead of 80 GB. This directly speeds up memory-bound workloads like attention and token generation. Compute-bound workloads (large matmuls) see no improvement.
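The ridge column can be reproduced from the table's own numbers. A small sketch (the A100 bandwidth of ~2.0 TB/s is the value implied by its 156 FLOP/B ridge):

python
# Ridge point = peak BF16 TFLOP/s ÷ HBM bandwidth (TB/s), i.e. FLOP per byte.
# Kernels whose arithmetic intensity falls below the ridge are memory-bound on that GPU.
specs = {
    "A100 SXM 80GB":  (312, 2.0),
    "H100 SXM 80GB":  (989, 3.35),
    "H200 SXM 141GB": (989, 4.8),
}

for gpu, (bf16_tflops, hbm_tb_s) in specs.items():
    print(f"{gpu}: ridge = {bf16_tflops / hbm_tb_s:.0f} FLOP/byte")

# Memory-bound kernels scale with bandwidth, so H100 -> H200 gives at most 4.8 / 3.35x.
print(f"bandwidth-bound speedup H100 -> H200: {4.8 / 3.35:.2f}x")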

BF16 TFLOP/s comparison (tensor core peak): A100 80GB: 312 TF · H100 80GB: 989 TF · H200 141GB: 989 TF

Quick check

Derivation

The H200 has the same bf16 TFLOP/s as the H100 (989 TF) but its ridge point drops from 295 to 206 FLOP/byte due to higher bandwidth. What does a lower ridge point mean operationally?
🧠

Key Takeaways

What to remember for interviews

  1. Mixed precision uses bf16 for forward/backward (2× faster tensor cores, 2× less memory) and fp32 master weights + optimizer states to prevent gradient underflow.
  2. bf16 is preferred over fp16 for training: the same 8-bit exponent as fp32 gives the same dynamic range (~3.4e38 max), making overflow extremely rare in practice and loss scaling unnecessary.
  3. The roofline model bounds performance by min(peak FLOPS, bandwidth × arithmetic intensity). Standard attention has intensity ~1, far below the ridge, making it memory-bound.
  4. Flash Attention tiles computation into SRAM blocks to avoid materializing the N×N matrix in HBM, moving attention from memory-bound toward compute-bound.
  5. Activation checkpointing (gradient checkpointing) trades ~33% extra compute for O(sqrt(L)) activation memory instead of O(L), essential for training large models (see the sketch below).
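For the checkpointing point above, a minimal sketch of what it looks like with torch.utils.checkpoint; the stack of linear "blocks" is purely hypothetical.

python
import torch
from torch.utils.checkpoint import checkpoint

# Gradient checkpointing: skip storing each block's activations during the forward
# pass and recompute them during backward, trading compute for activation memory.
def forward_with_checkpointing(blocks, x):
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x

# Hypothetical usage:
blocks = torch.nn.ModuleList(torch.nn.Linear(512, 512) for _ in range(8))
out = forward_with_checkpointing(blocks, torch.randn(4, 512, requires_grad=True))
out.sum().backward()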

🎯

Interview Questions


Why is bf16 preferred over fp16 for training large language models?

★☆☆
Google · Anthropic

Explain the roofline model and what it predicts about attention computation.

★★☆
Google · Meta

What is arithmetic intensity, and how does it govern GPU utilization?

★★☆
Meta · OpenAI

Why does Flash Attention help with the memory bandwidth bottleneck?

★★☆
Anthropic · OpenAI

How does gradient accumulation help with limited GPU memory, and what are its tradeoffs?

★☆☆
Google · Meta
🧠

Recap quiz

Trade-off

An A100 80GB runs bf16 tensor-core matmuls at 312 TFLOP/s but fp32 at only 19.5 TFLOP/s. What is the primary reason mixed-precision training does NOT lose accuracy despite running forward+backward in bf16?
Derivation

Training in fp16 with GradScaler disabled fails within ~100 steps. What is the precise failure chain?
Derivation

Standard self-attention has arithmetic intensity ≈ 1 FLOP/byte. The A100's ridge point is ~156 FLOP/byte. What does this imply about standard attention performance relative to the GPU's theoretical peak?
Trade-off

An engineer proposes replacing bf16 mixed precision with fp16 mixed precision to gain extra mantissa precision (10 vs 7 bits). What is the main operational cost?
Trade-off

The H200 SXM has the same BF16 TFLOP/s as the H100 (989 TF) but 4.8 TB/s vs 3.35 TB/s HBM bandwidth. Which workload sees the largest real-world speedup moving from H100 to H200?