🔥 GPU & Mixed Precision
bf16 training uses half the memory with zero accuracy loss — why wasn't this the default?
Training a frontier-scale model on a single A100 would be utterly impractical — rough estimates put it at tens of thousands of years, though exact figures depend on undisclosed training compute. The bottleneck isn't arithmetic — it's memory bandwidth. Understanding where data lives on the GPU and how number formats trade precision for speed is the difference between a 2× speedup and a 10× one.
GPU Memory Hierarchy
What you're seeing: the five memory tiers on a modern GPU (e.g., A100). Each level closer to the compute cores is faster but smaller. The key insight: attention spends most of its time moving data between HBM and the compute cores, not doing arithmetic.
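A rough back-of-envelope sketch supports this (assuming ~312 TFLOP/s bf16 compute and ~2 TB/s HBM bandwidth, A100-class, with deliberately simplified byte accounting):

```python
# Rough illustration: time to *compute* one head's attention vs. time to *move*
# the N×N score matrix through HBM. All figures below are assumptions.
N, d = 4096, 128                     # sequence length, head dimension
peak_flops = 312e12                  # assumed bf16 tensor-core peak, FLOP/s
hbm_bw = 2.0e12                      # assumed HBM bandwidth, bytes/s

flops = 2 * N * N * d * 2            # QK^T plus scores @ V
bytes_moved = 4 * N * N * 2          # ~4 passes over the N×N scores, 2 bytes each (bf16)

print(f"compute time ≈ {flops / peak_flops * 1e6:.0f} µs")   # ≈ 28 µs
print(f"HBM traffic  ≈ {bytes_moved / hbm_bw * 1e6:.0f} µs") # ≈ 67 µs, memory dominates
```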
Number Format Bit Layouts
Each format divides 32, 16, or 8 bits into sign, exponent, and mantissa. The exponent controls range; the mantissa controls precision.
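One quick way to see the range-versus-precision tradeoff is to query each dtype with torch.finfo; a minimal sketch:

```python
import torch

# max and tiny are set by the exponent bits (range); eps by the mantissa bits (precision).
for dtype in (torch.float32, torch.bfloat16, torch.float16):
    info = torch.finfo(dtype)
    print(f"{str(dtype):>15}  max={info.max:.2e}  smallest normal={info.tiny:.2e}  eps={info.eps:.2e}")

# float32:  max≈3.40e+38  tiny≈1.18e-38  eps≈1.19e-07   (8 exponent / 23 mantissa bits)
# bfloat16: max≈3.39e+38  tiny≈1.18e-38  eps≈7.81e-03   (8 exponent /  7 mantissa bits)
# float16:  max≈6.55e+04  tiny≈6.10e-05  eps≈9.77e-04   (5 exponent / 10 mantissa bits)
```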
Quick check
bf16 has fewer mantissa bits than fp16 (7 vs 10), yet all major LLMs are trained in bf16. What property makes bf16 the safer choice?
Mixed Precision Training
What breaks without it?
Training purely in fp32 wastes 2× memory and 2× bandwidth — tensor cores (the fast SIMD units on modern GPUs) are optimized for bf16/fp16 matmuls. Training purely in fp16 causes gradients to underflow to zero or overflow to NaN. Mixed precision gets the best of both.
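A tiny sketch makes both failure modes concrete (the specific values are chosen only for illustration):

```python
import torch

# Underflow: a small gradient value vanishes in fp16 but survives in bf16/fp32
g = 1e-8
print(torch.tensor(g, dtype=torch.float16))    # 0.0, flushed to zero
print(torch.tensor(g, dtype=torch.bfloat16))   # ≈ 1e-08, still representable

# Overflow: fp16 tops out at 65504, so large activations become inf (and then a NaN loss)
a = 70000.0
print(torch.tensor(a, dtype=torch.float16))    # inf
print(torch.tensor(a, dtype=torch.bfloat16))   # ≈ 70144, still finite
```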
The 3-step recipe
1. Master weights in fp32: stored in full precision. Small gradient updates (e.g., 1e-7) need fp32 to avoid rounding to zero when added to weights. Updated once per optimizer step.
2. Forward/backward in bf16: activations, intermediate tensors, and gradients computed in bf16, giving roughly 2× faster tensor-core matmuls and 2× smaller activation memory.
3. Optimizer step in fp32: gradients cast to fp32 before AdamW moment accumulation. fp32 master weights updated. Downcast copy sent back to bf16 for the next forward pass (see the sketch below).
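A minimal hand-rolled sketch of the recipe, assuming a single weight matrix and plain SGD for clarity (real training uses torch.autocast and an fp32 optimizer instead):

```python
import torch

master_w = torch.randn(1024, 1024)       # step 1: fp32 master weights
x = torch.randn(32, 1024)
lr = 3e-4

for step in range(100):
    # step 2: forward/backward on a bf16 copy of the weights
    w_bf16 = master_w.to(torch.bfloat16).requires_grad_()
    loss = (x.to(torch.bfloat16) @ w_bf16).float().pow(2).mean()
    loss.backward()

    # step 3: cast the gradient up to fp32 and update the fp32 master weights,
    # so small updates aren't rounded away by bf16's coarse mantissa
    with torch.no_grad():
        master_w -= lr * w_bf16.grad.to(torch.float32)
```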
A model trained with fp16 mixed precision suddenly produces NaN losses on step 1000. What is the most likely cause?
Quick check
An A100 80GB delivers 312 TFLOP/s BF16 but only 19.5 TFLOP/s FP32 on tensor cores. How does mixed-precision training exploit this gap without sacrificing convergence?
Loss Scaling (fp16 Only)
fp16's normal range bottoms out around 6e-5 (subnormals reach only ~6e-8). Gradients in deep networks are often tiny (1e-5 to 1e-8). Without scaling, they flush to zero and the model stops learning.
How loss scaling works:
- Multiply the loss by a large scalar S (e.g., 2¹⁵ = 32768) before calling .backward()
- All gradients are scaled by S, which lifts them into fp16's safe range
- Before the optimizer step, divide all gradients by S to recover the true gradient magnitudes
- If any gradient is NaN/Inf (overflow still occurred), skip the optimizer step and halve S
Scaled loss and gradient: L_scaled = S · L, so every gradient becomes g_scaled = S · g
Optimizer step with unscaled gradient: θ ← θ − lr · (g_scaled / S)
PyTorch implementation:
torch.autocast + GradScaler (fp16)

```python
import torch
from torch.cuda.amp import autocast, GradScaler

model = MyTransformer().cuda()            # MyTransformer / dataloader are placeholders
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# GradScaler handles loss scaling automatically for fp16
# Not needed for bf16 (same exponent range as fp32)
scaler = GradScaler()

for batch in dataloader:
    inputs, labels = batch
    optimizer.zero_grad()

    # Forward + backward in fp16 automatically
    with autocast(dtype=torch.float16):
        logits = model(inputs)
        loss = criterion(logits, labels)

    # Scale loss → backward → unscale → optimizer step
    scaler.scale(loss).backward()                             # grads are scaled
    scaler.unscale_(optimizer)                                # unscale before clip
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                                    # optimizer updates fp32 master weights
    scaler.update()                                           # adjust scale factor for next step
```
torch.autocast bf16 — simpler, no GradScaler needed

```python
# bf16 rarely needs GradScaler — fp32-range exponent makes overflow very unlikely
# (reuses model, criterion, optimizer, inputs, labels from the fp16 example above)
with autocast(dtype=torch.bfloat16):
    logits = model(inputs)
    loss = criterion(logits, labels)

loss.backward()    # backward runs in bf16; param .grad tensors stay fp32
optimizer.step()   # weights (and optimizer state) stay fp32
```

Quick check
GradScaler multiplies loss by S = 2¹⁵ = 32,768 before backward(). A gradient that would have been 1e-6 without scaling is now 32,768 × 1e-6 = 0.033 in fp16. Why is 0.033 safe but 1e-6 was not?
Real GPU Specs
All numbers from NVIDIA datasheets. Arithmetic intensity at ridge point = peak TFLOP/s ÷ peak HBM bandwidth.
| GPU | HBM | HBM BW | FP32 | TF32 | BF16 | FP8 | Ridge (BF16) |
|---|---|---|---|---|---|---|---|
| A100 SXM 80GB | 80 GB | 2.0 TB/s | 19.5 TF | 156 TF | 312 TF | — | 156 FLOPs/B |
| H100 SXM 80GB | 80 GB | 3.35 TB/s | 67 TF | 756 TF | 989 TF | 1,979 TF | 295 FLOPs/B |
| H200 SXM 141GB | 141 GB | 4.8 TB/s | 67 TF | 756 TF | 989 TF | 1,979 TF | 206 FLOPs/B |
BF16 TFLOP/s comparison (tensor core peak)
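The ridge points in the last column follow directly from that formula; a quick sketch using the table's figures:

```python
# ridge point (FLOPs/byte) = peak BF16 FLOP/s ÷ peak HBM bytes/s
specs = {
    "A100 SXM 80GB":  (312e12, 2.0e12),
    "H100 SXM 80GB":  (989e12, 3.35e12),
    "H200 SXM 141GB": (989e12, 4.8e12),
}
for gpu, (bf16_flops, hbm_bw) in specs.items():
    print(f"{gpu}: ridge ≈ {bf16_flops / hbm_bw:.0f} FLOPs/byte")
# A100 ≈ 156, H100 ≈ 295, H200 ≈ 206, matching the table's last column
```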
Quick check
The H200 has the same bf16 TFLOP/s as the H100 (989 TF) but its ridge point drops from 295 to 206 FLOP/byte due to higher bandwidth. What does a lower ridge point mean operationally?
Key Takeaways
What to remember for interviews
1. Mixed precision uses bf16 for forward/backward (2× faster tensor cores, 2× less memory) and fp32 master weights + optimizer states to prevent gradient underflow.
2. bf16 is preferred over fp16 for training: same 8-bit exponent as fp32 gives the same dynamic range (~3.4e38 max), making overflow extremely rare in practice and loss scaling unnecessary.
3. The roofline model bounds performance by min(peak FLOPS, bandwidth × arithmetic intensity). Standard attention has intensity ~1 — far below the ridge — making it memory-bound.
4. Flash Attention tiles computation into SRAM blocks to avoid materializing the N×N matrix in HBM, moving attention from memory-bound toward compute-bound.
5. Activation checkpointing (gradient checkpointing) trades ~33% extra compute for O(sqrt(L)) activation memory instead of O(L) — essential for training large models (see the sketch below).
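For the last point, a minimal sketch of activation checkpointing using torch.utils.checkpoint; the 8-block MLP stack here is a hypothetical stand-in for transformer layers:

```python
import torch
from torch.utils.checkpoint import checkpoint

blocks = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU()) for _ in range(8)]
)

def forward_with_checkpointing(x):
    # Each block's internal activations are dropped after the forward pass and
    # recomputed during backward, so only block-boundary tensors stay in memory.
    for block in blocks:
        x = checkpoint(block, x, use_reentrant=False)
    return x

x = torch.randn(4, 512, requires_grad=True)
forward_with_checkpointing(x).sum().backward()   # each block's forward reruns here
```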
Further Reading
- Andrej Karpathy — Zero To Hero: Building GPT (Device + Precision chapters) — Karpathy's hands-on walkthrough of moving training to GPU and adding mixed precision — the practical foundation for these concepts.
- NVIDIA Mixed Precision Training Guide — Official NVIDIA guide covering loss scaling, tensor cores, and the AMP workflow with benchmarks.
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — Dao et al. 2022 — the paper that introduced tiled SRAM attention and made the roofline model central to ML systems work.
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — Dao 2023 — extends Flash Attention with improved thread block partitioning, achieving ~2× additional speedup on H100.
- Lilian Weng — Large Transformer Model Inference Optimization — Detailed breakdown of memory hierarchy, arithmetic intensity, and precision tradeoffs with worked numerical examples.
Interview Questions
- Why is bf16 preferred over fp16 for training large language models? (★☆☆)
- Explain the roofline model and what it predicts about attention computation. (★★☆)
- What is arithmetic intensity, and how does it govern GPU utilization? (★★☆)
- Why does Flash Attention help with the memory bandwidth bottleneck? (★★☆)
- How does gradient accumulation help with limited GPU memory, and what are its tradeoffs? (★☆☆)
Recap quiz
An A100 80GB runs bf16 tensor-core matmuls at 312 TFLOP/s but fp32 at only 19.5 TFLOP/s. What is the primary reason mixed-precision training does NOT lose accuracy despite running forward+backward in bf16?
Training in fp16 with GradScaler disabled fails within ~100 steps. What is the precise failure chain?
Standard self-attention has arithmetic intensity ≈ 1 FLOP/byte. The A100's ridge point is ~156 FLOP/byte. What does this imply about standard attention performance relative to the GPU's theoretical peak?
An engineer proposes replacing bf16 mixed precision with fp16 mixed precision to gain extra mantissa precision (10 vs 7 bits). What is the main operational cost?
The H200 SXM has the same BF16 TFLOP/s as the H100 (989 TF) but 4.8 TB/s vs 3.35 TB/s HBM bandwidth. Which workload sees the largest real-world speedup moving from H100 to H200?