Transformer Math

Module 16 · Training

🖥️ Distributed Training

A 70B-param model needs 280 GB just to hold weights at fp32 — no single GPU exists with that memory. Here's how training across 1,000 GPUs stays in sync.

A single GPU can train small models, but modern LLMs have billions of parameters — they don't even fit in one GPU's memory. Distributed training splits the work across hundreds or thousands of GPUs. The challenge: keeping them all in sync while minimizing communication overhead.

🔀

Parallelism Strategies Compared

[Diagram: Data Parallel (each GPU holds layers L1–L4 and processes its own batch; communication = all-reduce of gradients) · Tensor Parallel (each GPU holds a 1/4 slice of every layer; communication = all-reduce of activations) · Pipeline Parallel (GPU 0 holds L1–L8, GPU 1 holds L9–L16, GPU 2 holds L17–L24, GPU 3 holds L25–L32; communication = activations between stages)]
🎮

Three Parallelism Strategies

What you're seeing: Three ways to distribute a 4-layer model across 4 GPUs. Each strategy splits a different dimension of the computation.

Data Parallel

Same model, split data

GPU 0: L1 L2 L3 L4 (Batch 0)
GPU 1: L1 L2 L3 L4 (Batch 1)
GPU 2: L1 L2 L3 L4 (Batch 2)
GPU 3: L1 L2 L3 L4 (Batch 3)

Each GPU has the full model.
Data is split across GPUs.
Gradients synced via all-reduce.

Tensor Parallel

Split layers across GPUs

GPU 0: L1/0 L2/0 L3/0 L4/0
GPU 1: L1/1 L2/1 L3/1 L4/1
GPU 2: L1/2 L2/2 L3/2 L4/2
GPU 3: L1/3 L2/3 L3/3 L4/3

Each layer is split across GPUs.
GPU k holds slice k of every layer.
All-reduce within each layer.

Pipeline Parallel

Different layers on different GPUs

GPU 0: L1
GPU 1: L2
GPU 2: L3
GPU 3: L4

Each GPU owns different layers.
Activations passed between stages.
Micro-batches reduce bubble.

💡 Tip · What to try: Think about which strategy works when the model doesn't fit on one GPU at all (tensor or pipeline), vs. when it fits but you want faster training (data parallel). Real systems combine all three — that's 3D parallelism.
💡

The Intuition

Why single GPU isn't enough: A 70B parameter model in fp16 needs ~140 GB just for parameters. An A100 has 80 GB. The model literally doesn't fit. Add optimizer states (2x parameters for Adam) and activations, and you need 500 GB+ for training.

DDP (DistributedDataParallel): The simplest approach — each GPU holds a full copy of the model. Each processes a different data batch, then gradients are synchronized via all-reduce. Fast and simple, but requires the model to fit on one GPU.

ZeRO (Zero Redundancy Optimizer): The key insight — in DDP, every GPU stores identical optimizer states, gradients, and parameters. ZeRO eliminates this redundancy progressively:

  • Stage 1: Partition optimizer states (~4× less model-state memory per GPU)
  • Stage 2: + Partition gradients (~8× less)
  • Stage 3: + Partition parameters (memory per GPU shrinks linearly with the number of GPUs, N)
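
The reductions above can be made concrete by counting bytes per parameter. A minimal sketch, assuming bf16 parameters and gradients with fp32 Adam states (master weights plus two moments, 12 bytes per parameter) and ignoring activation memory:

python
# Rough model-state memory per GPU under each ZeRO stage, in GB.
def zero_memory_gb(n_params: float, n_gpus: int, stage: int) -> float:
    p   = 2 * n_params          # bf16 parameters
    g   = 2 * n_params          # bf16 gradients
    opt = 12 * n_params         # fp32 master weights + Adam m and v
    if stage >= 1: opt /= n_gpus
    if stage >= 2: g   /= n_gpus
    if stage >= 3: p   /= n_gpus
    return (p + g + opt) / 1e9

for s in range(4):
    print(f"ZeRO-{s}: {zero_memory_gb(70e9, 64, s):.1f} GB per GPU")
# For a 70B model on 64 GPUs: plain DDP needs ~1120 GB, ZeRO-3 only ~17.5 GB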

FSDP (Fully Sharded Data Parallel): PyTorch's native implementation of ZeRO Stage 3. Each GPU stores only 1/N of the model. Parameters are gathered before each forward/backward computation, then discarded.

Pipeline Parallelism: Split the model by layers across GPUs. GPU 0 runs layers 1-24, GPU 1 runs layers 25-48, etc. Use micro-batches to keep all GPUs busy and reduce the pipeline bubble.

3D Parallelism: Combine all three — tensor parallelism within a node (fast NVLink), pipeline parallelism across nodes, data parallelism across replicas. This is how every frontier model is trained. Llama-3 405B used DP=128, TP=8, PP=16 across 16,384 H100s — each axis targets a different bottleneck.

Expert Parallelism: In MoE (Mixture-of-Experts) models, different experts are placed on different GPUs. When a token is routed to expert k, it must be sent via all-to-all communication to the GPU holding that expert, then the result sent back. The challenge is load imbalance — if certain experts are consistently chosen more often than others, their host GPUs become bottlenecks while others sit idle. DeepSeek-V3 addresses this with auxiliary-loss-free balancing, using a bias term on routing logits that is updated dynamically to steer tokens toward under-utilized experts without adding a separate balancing loss to the training objective.
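
As a concrete illustration of bias-adjusted routing, here is a minimal sketch. The function names, the sign-based update, and the step size are simplifications of the DeepSeek-V3 scheme rather than its exact implementation:

python
import torch

def route_tokens(gate_logits, bias, k=2):
    # The bias is added only when *choosing* experts; the mixture weights
    # themselves still come from the raw gate logits.
    topk_idx = torch.topk(gate_logits + bias, k, dim=-1).indices
    topk_weights = torch.softmax(gate_logits.gather(-1, topk_idx), dim=-1)
    return topk_idx, topk_weights

def update_bias(bias, tokens_per_expert, step=1e-3):
    # Nudge over-loaded experts down and under-loaded experts up,
    # steering future tokens toward idle GPUs without an extra loss term.
    target = tokens_per_expert.float().mean()
    return bias - step * torch.sign(tokens_per_expert.float() - target)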

Sequence Parallelism: Tensor parallelism splits the weight matrices inside each layer, but some operations — LayerNorm in particular — operate over the full hidden dimension and do not split naturally along it. Dropout is also handled here for efficiency (under SP it runs on only 1/t of the data), even though it is elementwise and not fundamentally impossible to tensor-parallelize. Sequence parallelism (Megatron-LM, Korthikanti et al., 2022) splits the sequence dimension across the same TP group for these operations. This reduces activation memory by TP_degree× for those layers and is applied in tandem with tensor parallelism so that no activation memory within the layer is left unsharded.

✨ Insight · Think of it as splitting along three axes: data parallel splits the batch, tensor parallel splits the layers horizontally (within each layer), pipeline parallel splits layers vertically (across layers). 3D parallelism slices along all three axes simultaneously.

Overlapping Communication with Computation: In naive DDP, every backward pass ends with an all-reduce that stalls all GPUs until gradients are synchronized. Modern frameworks eliminate most of this stall by bucketing parameters and launching all-reduce for each bucket as soon as its gradients are ready — while the backward pass continues computing gradients for earlier layers. PyTorch DDP's default bucket size is 25 MB. Because the backward pass proceeds from last layer to first, the last-layer gradients are ready earliest and their all-reduce completes while the first-layer gradients are still being computed. In practice this overlapping hides 60–80% of the communication latency. FSDP extends this further with prefetching: it issues the all-gather for the next layer's parameters while the current layer is executing its forward pass, so parameter reconstruction never stalls computation.
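
PyTorch exposes this behaviour directly: DDP's bucket_cap_mb argument controls bucket size, and no_sync() skips the all-reduce on gradient-accumulation steps. A minimal sketch, assuming model, optimizer, dataloader, and local_rank already exist:

python
from torch.nn.parallel import DistributedDataParallel as DDP

# Larger buckets = fewer all-reduce launches; smaller buckets = earlier overlap.
model = DDP(model, device_ids=[local_rank], bucket_cap_mb=25)

accum_steps = 4
for step, batch in enumerate(dataloader):
    if (step + 1) % accum_steps != 0:
        with model.no_sync():            # accumulate gradients locally, no all-reduce
            model(batch).loss.backward()
    else:
        model(batch).loss.backward()     # all-reduce overlaps with this backward pass
        optimizer.step()
        optimizer.zero_grad()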

Quick check

Derivation

In naive DDP each GPU holds full optimizer states, gradients, and parameters. ZeRO Stage 2 partitions optimizer states AND gradients. What does Stage 3 add that Stage 2 does not?

Quick check

What does ZeRO Stage 3 partition that Stage 1 doesn't?

📐

Step-by-Step Derivation

All-Reduce Communication Cost

In DDP, each GPU must synchronize gradients after every backward pass. The ring all-reduce algorithm sends 2(N-1)/N times the model size across GPUs:

T_comm = 2(N-1)/N × M / BW

Where M is the total parameter size, N is the number of GPUs, and BW is the interconnect bandwidth. As N → ∞, communication cost approaches 2M/BW — independent of GPU count. This is why ring all-reduce scales well.
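
Plugging in numbers shows why this scales. Below, a 70B-parameter model with bf16 gradients (140 GB to synchronize) over an assumed 400 GB/s effective all-reduce bandwidth:

python
def ring_allreduce_seconds(model_bytes, n_gpus, bw_bytes_per_s):
    return 2 * (n_gpus - 1) / n_gpus * model_bytes / bw_bytes_per_s

for n in (8, 64, 1024):
    t = ring_allreduce_seconds(140e9, n, 400e9)
    print(f"N={n:4d}: {t:.2f} s per sync")
# prints ~0.61 s, ~0.69 s, ~0.70 s: nearly flat as N grows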

Effective Batch Size

The total batch size seen by the optimizer per update step:

B_eff = b × N_DP × k

Where b is the per-GPU micro-batch size, N_DP is the number of data-parallel replicas, and k is the number of gradient accumulation steps. LLaMA-2 70B used a global batch of roughly 4M tokens per step.
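
The same formula in code, with illustrative numbers (not LLaMA-2's actual configuration):

python
def effective_batch(per_gpu_batch, dp_replicas, grad_accum_steps, seq_len=4096):
    sequences = per_gpu_batch * dp_replicas * grad_accum_steps
    return sequences, sequences * seq_len       # sequences and tokens per optimizer step

seqs, tokens = effective_batch(per_gpu_batch=2, dp_replicas=128, grad_accum_steps=4)
print(seqs, tokens)   # 1024 sequences ≈ 4.2M tokens at 4K context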

Pipeline Bubble Fraction

The fraction of time GPUs are idle in pipeline parallelism:

bubble = (P - 1) / (P - 1 + M)

Where P is the number of pipeline stages and M is the number of micro-batches. With P = 8 stages and M = 32 micro-batches, the bubble is only 7/39 ≈ 18%. More micro-batches = less bubble but more memory.
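
The same trade-off in code (stage and micro-batch counts are arbitrary examples):

python
def bubble_fraction(p_stages, m_microbatches):
    return (p_stages - 1) / (p_stages - 1 + m_microbatches)

print(bubble_fraction(8, 32))    # 0.179: ~18% of GPU time idle
print(bubble_fraction(8, 128))   # 0.052: 4x more micro-batches cuts the bubble to ~5%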

PyTorch: DDP Setup

python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Initialize process group (one process per GPU)
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Wrap model with DDP — handles gradient sync automatically
model = MyModel().cuda(local_rank)
model = DDP(model, device_ids=[local_rank])
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Training loop is identical to single-GPU
for batch in dataloader:
    loss = model(batch).loss
    loss.backward()       # DDP all-reduces gradients here
    optimizer.step()
    optimizer.zero_grad()

PyTorch: FSDP Wrapper

python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# FSDP shards parameters, gradients, and optimizer states
model = FSDP(
    MyModel().cuda(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3
    # SHARD_GRAD_OP = ZeRO-2, NO_SHARD = DDP
    auto_wrap_policy=size_based_auto_wrap_policy,
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    ),
)
# Parameters gathered on-demand for forward/backward, then freed
PyTorch: DDP Training Step (full loop)

python
# DDP training step — full end-to-end with gradient sync
import torch, os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group("nccl")
rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(rank)

model = MyTransformer().cuda(rank)
model = DDP(model, device_ids=[rank])      # wraps gradient all-reduce
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    input_ids = batch["input_ids"].cuda(rank)
    labels = batch["labels"].cuda(rank)

    loss = model(input_ids, labels=labels).loss
    loss.backward()                        # DDP all-reduces gradients here
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

Quick check

Derivation

A pipeline with P=16 stages uses M=48 micro-batches. What is the bubble fraction, and how does it compare to a run with P=4, M=12?
🔧

Break It — See What Happens

No gradient sync between GPUs
Batch size too small per GPU

Quick check

Trade-off

If gradient synchronization is disabled in DDP and each GPU trains independently for 100 steps, then sync is re-enabled, what happens to training loss?
📊

Real-World Numbers

Model | Hardware | Strategy | Training Time
Llama-2 70B | A100 80GB GPUs | FSDP (ZeRO-3), no PP | ~1.7M GPU-hours
DeepSeek-V3 | 2,048 H800s | 3D parallelism + expert parallelism (MoE) | ~2.79M GPU-hours
GPT-4 (estimated) | undisclosed | 3D parallelism (details undisclosed) | undisclosed
Llama-3.1 405B | 16,384 H100s | 3D parallelism (TP=8, PP=16, DP=128) | ~54 days, 30.8M GPU-hours
✨ Insight · The scale is staggering: Llama-3.1 405B used 16,384 GPUs for almost 2 months. At that scale, hardware failures happen every few hours — fault tolerance and checkpointing are as important as the training algorithm itself.

Quick check

Trade-off

Llama-3.1 405B uses TP=8, PP=16, DP=128 on 16,384 H100s. A reviewer proposes increasing TP from 8 to 64 to reduce per-GPU memory further. What is the main reason to reject this?

Modern Distributed Training (2024–2025)

The distributed training stack has changed significantly since the classic FSDP1 + ZeRO era. Four advances are now interview-relevant: FSDP2, Ring/Context Parallelism, ZeRO++, and Multi-Token Prediction.

FSDP2 — PyTorch 2.3+ default (2024)

FSDP1 shards optimizer states, gradients, and parameters at the module boundary — the entire module's parameters are gathered and scattered together. FSDP2 switches to per-parameter sharding: each individual parameter tensor is its own DTensor, enabling finer-grained memory control and eliminating the “flat parameter” bookkeeping that caused FSDP1 composability problems with TP and PP. The key interview point: FSDP2 is the recommended PyTorch primitive for combining data parallelism with tensor or pipeline parallelism in 3D-parallel runs (per pytorch.org/blog/fsdp2).

Ring Attention / Context Parallelism — torchtitan demo (2024)

Standard tensor parallelism shards the model; context parallelism shards the sequence. Each GPU holds all model weights but only a slice of the input tokens. Attention is computed in a “ring” pattern: each GPU computes its local QK^T block, then passes its K/V slice to the next GPU while computing with the received slice. This enables 1M-token sequence training on 64 GPUs (torchtitan demo, 2024) without blowing up per-GPU memory.

Source: github.com/pytorch/torchtitan
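
A single-process sketch of the idea: the K/V shards below stand in for what each GPU would hold, and iterating over them stands in for the dist.send/recv exchange in a real ring. The accumulation is the standard online-softmax (log-sum-exp) merge; shapes and shard counts are illustrative.

python
import math
import torch

def ring_attention_sim(q, kv_shards):
    """q: local queries [T_local, d]; kv_shards: (K, V) pairs a real run spreads over GPUs."""
    d = q.size(-1)
    m = torch.full((q.size(0), 1), float("-inf"))   # running max of scores
    l = torch.zeros(q.size(0), 1)                   # running softmax denominator
    acc = torch.zeros_like(q)                       # running weighted sum of V
    for k, v in kv_shards:                          # in a real ring: recv K/V from the neighbour
        s = q @ k.T / math.sqrt(d)
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        scale = torch.exp(m - m_new)                # rescale previous partial results
        p = torch.exp(s - m_new)
        l = l * scale + p.sum(dim=-1, keepdim=True)
        acc = acc * scale + p @ v
        m = m_new
    return acc / l                                  # exact softmax attention over all shards

q = torch.randn(16, 64)
shards = [(torch.randn(32, 64), torch.randn(32, 64)) for _ in range(4)]
out = ring_attention_sim(q, shards)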

Sequence Parallelism (Megatron-LM) — Korthikanti et al., 2022 — arxiv:2205.05198

Megatron-LM's sequence parallelism partitions activations along the sequence dimension during non-tensor-parallel operations (LayerNorm, Dropout). In tensor-parallel mode, the attention and MLP outputs are already sharded across TP ranks; the gap was the residual stream activations, which were replicated. Sequence parallelism fills that gap by keeping activations sharded in sequence through the full forward pass, reducing activation memory proportionally to the TP degree. Combined with selective activation recomputation, this enables training very long sequences without full activation checkpointing.

ZeRO++ — DeepSpeed (2023–2024)

Standard ZeRO-3 gathers full-precision parameters from remote nodes before each forward layer. Over slow cross-node interconnects (e.g. 200 Gb/s IB vs 3.2 Tb/s NVLink), this all-gather dominates step time. ZeRO++ uses quantized weight gathering (qwZ): parameters are compressed to INT8 before the cross-node all-gather, then dequantized locally; together with quantized gradient reduction (qgZ) this cuts cross-node communication volume by roughly 4×. A hierarchical partitioning scheme (hpZ) keeps a secondary copy of each shard within every node, so backward-pass all-gathers stay on fast intra-node links (per deepspeed.ai, 2023-10).
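
In DeepSpeed these features are switched on from the config dict. The key names below follow the ZeRO++ tutorial on deepspeed.ai; treat exact names and defaults as version-dependent assumptions:

python
# Sketch of a DeepSpeed config enabling ZeRO-3 plus the ZeRO++ communication optimizations.
ds_config = {
    "train_micro_batch_size_per_gpu": 2,
    "bf16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "zero_quantized_weights": True,      # qwZ: INT8 weights for cross-node all-gather
        "zero_hpz_partition_size": 8,        # hpZ: secondary shard copy per 8-GPU node
        "zero_quantized_gradients": True,    # qgZ: quantized gradient reduce-scatter
    },
}
# import deepspeed
# model_engine, optimizer, _, _ = deepspeed.initialize(model=model, config=ds_config)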

Multi-Token Prediction (MTP) — DeepSeek-V3 + Meta (2024)

Standard language model training predicts one next token per forward pass. MTP adds auxiliary output heads that predict several future tokens simultaneously. Each position produces one loss per predicted offset; the model learns richer representations to be predictive at multiple horizons. DeepSeek-V3 uses a single extra MTP module (predicting one additional future token), which is dropped at inference — giving a free training-time regulariser. At inference, the MTP head can also serve as a speculative decoding draft model (per arxiv:2412.19437, DeepSeek-V3 Tech Report). Meta's MTP paper (arxiv:2404.19737) showed consistent improvements across code and reasoning benchmarks.

Sources: arxiv:2412.19437 (DeepSeek-V3) + arxiv:2404.19737 (Meta MTP)
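
A minimal illustration of the loss structure, with one extra head predicting two tokens ahead. This is a simplification: DeepSeek-V3's MTP module is a small chained transformer block rather than the bare linear head used here, and the weighting constant is arbitrary.

python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mtp_loss(hidden, labels, lm_head, mtp_head, mtp_weight=0.3):
    """hidden: [B, T, d] final hidden states; labels: [B, T] token ids."""
    # Main next-token loss: position t predicts labels[:, t+1]
    main_logits = lm_head(hidden[:, :-1])
    main = F.cross_entropy(main_logits.flatten(0, 1), labels[:, 1:].flatten())
    # Auxiliary loss: position t also predicts labels[:, t+2]
    aux_logits = mtp_head(hidden[:, :-2])
    aux = F.cross_entropy(aux_logits.flatten(0, 1), labels[:, 2:].flatten())
    return main + mtp_weight * aux      # the auxiliary head is discarded at inference

d_model, vocab = 256, 1000
lm_head, mtp_head = nn.Linear(d_model, vocab), nn.Linear(d_model, vocab)
loss = mtp_loss(torch.randn(2, 16, d_model), torch.randint(0, vocab, (2, 16)),
                lm_head, mtp_head)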

PyTorch: FSDP2 per-parameter sharding (vs FSDP1 flat-param)
python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 API

# FSDP1 (legacy) — module-level sharding, flat params under the hood
# model = FSDP(model, ...)  # wraps entire module tree

# FSDP2 (PyTorch 2.3+) — per-parameter DTensor sharding
# Composable with TP/PP; no flat-param bookkeeping.
model = nn.Transformer(d_model=1024, nhead=8, num_encoder_layers=6)

# Apply FSDP2 per transformer layer for optimal granularity
for layer in model.encoder.layers:
    fully_shard(layer)         # each layer's params become DTensors

fully_shard(model)             # top-level wrapper for optimizer states

# Now combine with Tensor Parallelism (FSDP2 is composable)
# from torch.distributed.tensor.parallel import parallelize_module
# parallelize_module(model, device_mesh["tp"], ...)
✨ Insight · Interview pattern: When asked “how would you train a 70B model on 256 GPUs?” — the canonical 2025 answer is: FSDP2 for data + parameter sharding within a node, tensor parallelism (TP=8) across NVLink peers, pipeline parallelism (PP=4) across node boundaries, and context parallelism for sequences beyond 32K. ZeRO++ or quantized weight gathering for any slow cross-node links. MTP heads for free training-time regularisation. The composability story is why FSDP2 replaced FSDP1 as the default.
🧠

Key Takeaways

What to remember for interviews

  1. DDP replicates the full model on every GPU and syncs gradients via all-reduce — fast when the model fits in one GPU's memory.
  2. ZeRO eliminates redundancy by sharding optimizer states (Stage 1), gradients (Stage 2), and parameters (Stage 3) across GPUs. PyTorch FSDP implements Stage 3.
  3. Pipeline parallelism splits the model by layers; micro-batches reduce the bubble fraction to (P-1)/(P-1+M).
  4. 3D parallelism combines tensor (within node, fast NVLink), pipeline (across nodes), and data parallelism — used for every frontier model training run.
  5. Communication is overlapped with computation: DDP buckets gradients and all-reduces them while the backward pass continues; FSDP prefetches the next layer's parameters.
  6. FSDP2 (PyTorch 2.3+) uses per-parameter DTensor sharding instead of FSDP1 flat params — better composability with TP/PP. Ring Attention enables 1M-seq training on 64 GPUs. ZeRO++ cuts cross-node bandwidth ~4× via quantized weight gathering.
🧠

Recap quiz

Trade-off

A 13B-parameter model barely fits on an 80 GB A100 at bf16 with a batch size of 1. Which ZeRO stage should you enable first to scale to batch size 8 without adding GPUs?
Derivation

A training run uses P=4 pipeline stages and M=12 micro-batches. What fraction of GPU time is wasted in the pipeline bubble, and what is the fastest way to cut it in half?
Trade-off

You have 8 GPUs and a 7B-parameter model that fits on a single A100. Which parallelism strategy maximizes throughput, and why?
Derivation

In ring all-reduce with N GPUs, the communication cost formula is 2(N-1)/N × M/BW. What happens to communication cost as N grows from 8 to 1024 GPUs?
Trade-off

FSDP (ZeRO Stage 3) and DDP both train a 7B model on 8 GPUs. FSDP uses less peak memory per GPU. What is the communication overhead FSDP pays that DDP does not?
Trade-off

Llama-3.1 405B uses TP=8, PP=16, DP=128 on 16,384 H100s. Why is tensor parallelism (TP) constrained to within a single node (8 GPUs), while pipeline (PP) spans nodes?
Derivation

You have 8 GPUs, per-GPU batch size of 32, and run 4 gradient accumulation steps before each optimizer update. A colleague says the effective batch is 256. Are they right?
📚

Further Reading

🎯

Interview Questions


Compare DDP and FSDP. When would you choose one over the other?

★★☆
Google · Meta

Explain the pipeline bubble problem and how micro-batches help reduce it.

★★☆
Google · Meta

What is activation checkpointing and what tradeoff does it make?

★★☆
Google · OpenAI

Why does gradient accumulation help in distributed training, and how does it relate to effective batch size?

★☆☆
Google · Meta

Design a 3D parallelism strategy for training a 175B parameter model on 1024 GPUs. Explain your choices.

★★★
Google · OpenAI

How do you handle fault tolerance in large-scale distributed training?

★★★
Google · Meta

How would you debug a single slow rank in a 1,000-GPU training job?

★★★
Google · Anthropic