
Transformer Math

Module 4 · The Transformer

🧮 MLP & Matmul

A 2-layer MLP can approximate any function — so why do we need 96 layers?


Why can a 2-layer network with 1,000 neurons approximate any continuous function — yet GPT-3 needs 96 layers? Because the Universal Approximation Theorem tells you what is possible, not how efficiently it can be achieved. Everything hinges on matrix multiplication, and one wrong initialization can make training impossible before it begins.

🧠

MLP Architecture

A Multi-Layer Perceptron stacks fully-connected layers: every neuron in one layer connects to every neuron in the next. The highlighted path (blue nodes and bold edges) traces a single forward pass signal from one input feature through both hidden layers to an output.

[Diagram: Input (3 features) → Hidden 1 (4 neurons, ReLU) → Hidden 2 (4 neurons, ReLU) → Output (2 neurons, softmax). Each arrow = one weight. Total weights = 3×4 + 4×4 + 4×2 = 36.]

Highlighted path shows one activation flowing through the network. Bias terms (one per neuron) are not shown — add 4 + 4 + 2 = 10 more parameters for a total of 46.
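
To make the count concrete, here is a minimal PyTorch sketch (layer sizes 3, 4, 4, 2 taken from the diagram above; the softmax is omitted since it adds no parameters) that builds the same MLP and confirms the totals:

import torch.nn as nn

# 3 → 4 → 4 → 2 MLP from the diagram above
mlp = nn.Sequential(
    nn.Linear(3, 4), nn.ReLU(),   # 3×4 weights + 4 biases
    nn.Linear(4, 4), nn.ReLU(),   # 4×4 weights + 4 biases
    nn.Linear(4, 2),              # 4×2 weights + 2 biases
)

weights = sum(p.numel() for n, p in mlp.named_parameters() if "weight" in n)
biases = sum(p.numel() for n, p in mlp.named_parameters() if "bias" in n)
print(weights, biases, weights + biases)  # 36 10 46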

🎮

Matrix Multiplication Visualized

Every forward pass is a chain of matmuls. Here: a 3×2 matrix A times a 2×4 matrix B produces a 3×4 output C. The highlighted cell C[1][2] (green) is the dot product of row 1 (blue) in A and column 2 (amber) in B.

A (3×2) = [1 2; 3 4; 5 6]   B (2×4) = [1 0 2 1; 0 1 3 2]   C = A×B (3×4) = [1 2 8 5; 3 4 18 11; 5 6 28 17]

C[1][2] = 3×2 + 4×3 = 18 — (row 1 of A) · (col 2 of B)

Each output element requires one dot product. For an (m×k) × (k×n) matmul: m·n dot products, each of length k → O(m·k·n) FLOPs.
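
A quick sketch reproducing the example above — A and B are the matrices from the figure, and the FLOP count follows the m·k·n rule (counting each multiply-add as 2 FLOPs):

import torch

A = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])         # (3, 2)
B = torch.tensor([[1., 0., 2., 1.], [0., 1., 3., 2.]])   # (2, 4)
C = A @ B                                                 # (3, 4)

print(C[1, 2])        # 18.0 = 3×2 + 4×3
m, k = A.shape
_, n = B.shape
print(2 * m * k * n)  # 48 FLOPs: m·n dot products, each with k multiplies and k adds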

A single neuron applies this matmul idea to one row: it takes a weighted sum of its inputs, adds a bias, then applies a nonlinearity (activation function).

[Diagram: inputs x₁, x₂, x₃ each multiplied by a weight, summed with a bias (Σ + b), passed through σ(·) to produce y.  y = σ(w₁x₁ + w₂x₂ + w₃x₃ + b)]
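
The same formula as a minimal sketch (the weights, bias, and inputs are made-up numbers for illustration):

import torch

x = torch.tensor([0.5, -1.0, 2.0])    # inputs x₁, x₂, x₃
w = torch.tensor([0.1, 0.4, -0.2])    # one weight per input
b = torch.tensor(0.3)

y = torch.sigmoid(w @ x + b)          # σ(w₁x₁ + w₂x₂ + w₃x₃ + b)
print(y)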
💡

Why Matrices?

A neuron is a weighted vote: each input gets a weight, the votes are summed, and the result is thresholded by an activation function. A layer is many neurons computing in parallel — each neuron is one row of the weight matrix. An MLP is votes of votes: the second layer votes on the first layer's outputs.

✨ Insight · Matrix multiplication is the single most important operation in deep learning. Everything — attention, FFN, embedding lookup, output projection — is matmul. Modern GPUs are essentially matmul engines with memory.

Why not just use for-loops over neurons? Three reasons (a short equivalence sketch follows the list):

  • Batched computation — process 512 examples at once with a single GEMM call instead of 512 sequential passes.
  • GPU parallelism — matrix multiply maps directly to tensor cores running thousands of multiply-add operations in parallel.
  • Gradient flow — the matmul Jacobian is another matrix multiply, making backprop expressible entirely in the same language as the forward pass.
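
A sketch of the equivalence — a per-neuron for-loop versus a single GEMM call; the shapes here (batch 64, 512 hidden neurons) are arbitrary but small enough to run instantly:

import torch

batch, in_dim, hidden = 64, 128, 512
x = torch.randn(batch, in_dim)
W = torch.randn(hidden, in_dim)
b = torch.randn(hidden)

# For-loop: one dot product per (example, neuron) pair — batch × hidden sequential steps
out_loop = torch.empty(batch, hidden)
for i in range(batch):
    for j in range(hidden):
        out_loop[i, j] = W[j] @ x[i] + b[j]

# One GEMM call: the whole batch and every neuron at once
out_gemm = x @ W.T + b

print(torch.allclose(out_loop, out_gemm, atol=1e-3))  # True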
🎯 Interview · A common pattern: "Why is batch size important?" — larger batches amortize fixed GPU kernel-launch overhead and improve arithmetic intensity (ratio of FLOPs to memory bytes moved). Below ~32, the GPU is mostly idle waiting on memory.

Quick check

Trade-off

A 512-neuron hidden layer processes a batch of 64 inputs. Using a for-loop over neurons takes O(512 × 64) sequential steps. A matmul does it in one GPU call. Beyond speed, what does batched matmul enable that sequential loops do not?

Quick check

A weight matrix W has shape (512, 2048). An input batch x has shape (32, 512). What is the shape of xW (ignoring bias), and how many FLOPs does this require?

⚖️

Weight Initialization

Before training begins, weights must be initialized carefully. Consider what happens in a 50-layer network: if each layer scales activations by even 1.01×, after 50 layers you get 1.01⁵⁰ ≈ 1.64× amplification. With 1.5×, you get 1.5⁵⁰ ≈ 6×10⁸× — explosion. With 0.99×, you get 0.99⁵⁰ ≈ 0.6× — slow vanishing.

⚠ Warning · Bad initialization can make learning literally impossible — not just slow. If pre-activations saturate (sigmoid outputs ≈ 1 or 0), gradients are zero and the network cannot learn from the first batch.

Activation distributions: good init vs. bad init

[Histogram of activation values for three cases: good init (σ ≈ 1), too large (explodes), too small (vanishes).]

With good init, activations stay in a stable range across layers. Too large → exploding activations; too small → vanishing signal.
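
A sketch of the compounding effect: first the bare arithmetic from above, then the same behavior with random linear layers (no activation, float64 to avoid overflow; the weight scales are arbitrary examples of too small / variance-preserving / too large):

import torch

# Per-layer scale factors compound over 50 layers
for scale in (1.01, 1.5, 0.99):
    print(scale, scale ** 50)   # ≈ 1.64, ≈ 6.4e8, ≈ 0.61

# Same effect with matmuls: push activations through 50 random linear layers
dim, depth = 512, 50
for std in (0.02, (1.0 / dim) ** 0.5, 0.1):   # too small, variance-preserving, too large
    x = torch.randn(1000, dim, dtype=torch.float64)
    for _ in range(depth):
        W = torch.randn(dim, dim, dtype=torch.float64) * std
        x = x @ W
    print(f"weight std {std:.4f} -> activation std after {depth} layers: {x.std():.3e}")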

Xavier / Glorot Init (Glorot & Bengio 2010)

Sample weights from a uniform distribution in the range ±√(6 / (fan_in + fan_out)). Keeps variance ≈ 1 through tanh / sigmoid layers. Derived by requiring Var(output) = Var(input) in both the forward and backward passes, which gives σ² = 2 / (fan_in + fan_out).

He / Kaiming Init (He et al. 2015)

Sample from a zero-mean normal distribution with std √(2 / fan_in). ReLU kills ~50% of activations (outputs zero for negative inputs), so the signal's variance is halved. Doubling the weight variance to σ² = 2 / fan_in compensates. Use for ReLU, LeakyReLU, and variants.
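
A small numerical check of this claim, assuming standard-normal inputs: ReLU cuts the second moment in half, and He scaling restores the next pre-activation variance:

import torch

z = torch.randn(1_000_000)
print(z.var().item())                      # ≈ 1.0
print(torch.relu(z).pow(2).mean().item())  # ≈ 0.5 — the second moment is halved

# Across a layer: He-scaled weights keep the next pre-activation variance ≈ 1
fan_in = 1024
h = torch.relu(torch.randn(4096, fan_in))                 # ReLU'd activations coming in
W = torch.randn(fan_in, fan_in) * (2.0 / fan_in) ** 0.5   # He/Kaiming normal
print((h @ W).var().item())                               # ≈ 1.0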

💡 Tip · PyTorch's nn.Linear uses Kaiming uniform by default. Call torch.nn.init.xavier_uniform_(layer.weight) when using tanh or sigmoid activations. Forget this and your 10-layer network may not train at all.

Quick check

Derivation

Xavier init sets sigma^2 = 2/(fan_in + fan_out). He init sets sigma^2 = 2/fan_in. Both derive from requiring Var(output) = Var(input). What does ReLU do to the variance that forces the extra factor of 2 in He init?

📐

MLP Formula & Universal Approximation

A two-layer MLP (one hidden layer) with hidden dimension d and nonlinearity σ computes:

Hidden layer: h = σ(W₁x + b₁)

Output layer: ŷ = W₂h + b₂

The Universal Approximation Theorem: given any continuous function from R^n to R^m on a compact domain and any error tolerance ε > 0, there exists a single-hidden-layer MLP that approximates the target function to within ε, provided the activation is non-polynomial (Leshno et al. 1993) and the hidden layer is wide enough.

⚠ Warning · Gradient descent may never find that solution, and the network needed could require exponentially many neurons. The UAT is a statement about expressive power, not a recipe for architecture design.

PyTorch implementation

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),   # W1, b1 — Kaiming uniform init
            nn.ReLU(),                        # σ: kills negatives, keeps positives
            nn.Linear(hidden_dim, out_dim),  # W2, b2
        )
        # For tanh activations, override with Xavier:
        # nn.init.xavier_uniform_(self.net[0].weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# h = ReLU(W1 x + b1),  y_hat = W2 h + b2
model = MLP(in_dim=512, hidden_dim=2048, out_dim=256)

# Count parameters
total = sum(p.numel() for p in model.parameters())
print(f"Parameters: {total:,}")  # 1,050,624 + 524,544 = 1,575,168
# Layer 0: 512*2048 + 2048 = 1,050,624
# Layer 2: 2048*256 + 256  = 524,544
# Total: 1,575,168

x = torch.randn(32, 512)  # batch of 32
y = model(x)
print(y.shape)  # (32, 256)
PyTorch: custom weight init + checking activation stats
import torch
import torch.nn as nn

def build_mlp_with_init(in_dim, hidden_dim, out_dim, activation="relu"):
    layer1 = nn.Linear(in_dim, hidden_dim)
    layer2 = nn.Linear(hidden_dim, out_dim)

    if activation == "relu":
        # He/Kaiming: std = sqrt(2 / fan_in)
        nn.init.kaiming_normal_(layer1.weight, nonlinearity="relu")
    else:
        # Xavier: std = sqrt(2 / (fan_in + fan_out))
        nn.init.xavier_uniform_(layer1.weight)

    nn.init.zeros_(layer1.bias)
    nn.init.zeros_(layer2.bias)

    act = nn.ReLU() if activation == "relu" else nn.Tanh()
    return nn.Sequential(layer1, act, layer2)

model = build_mlp_with_init(512, 2048, 256)

# Check activation stats after first layer
x = torch.randn(1000, 512)
with torch.no_grad():
    h = torch.relu(model[0](x))  # post-activation
    print(f"Mean: {h.mean():.4f}")   # should be ~0.5-0.8 with good init
    print(f"Std:  {h.std():.4f}")    # should be ~1.0
    print(f"Dead neurons: {(h == 0).float().mean():.1%}")  # ideally ~50%

Quick check

Trade-off

The UAT (Hornik 1989) proves a 2-layer MLP can approximate any continuous function. A candidate in an interview says: “So for any ML problem, a 2-layer network is sufficient — we just need enough neurons.” What is the most important thing they are missing?

🔧

Break It

  • Initialize every weight to zero
  • Use an excessively large init scale

Stable training needs both asymmetry and controlled variance. Good initialization gives each neuron a distinct starting direction without blowing up the signal.
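
A sketch of the zero-init failure with a small MLP (the shapes and the dummy all-zero labels are arbitrary; cross-entropy just provides a nonzero gradient at the output):

import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
for p in net.parameters():
    nn.init.zeros_(p)                      # every weight and bias starts at exactly 0

x = torch.randn(32, 8)
y = torch.zeros(32, dtype=torch.long)      # dummy labels
loss = F.cross_entropy(net(x), y)
loss.backward()

print(net[0].weight.grad.abs().max())      # 0 — no gradient reaches layer 1 (blocked by zero W2)
print(net[2].weight.grad.abs().max())      # 0 — hidden activations are all zero
print(net[2].bias.grad.abs().max())        # > 0 — only the output bias can move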

Quick check

Derivation

You zero-initialize all weights in a 3-hidden-layer MLP with ReLU activations. You train on MNIST for 10 epochs with SGD, lr=0.01. The training loss barely moves. Explain the failure chain from initialization to gradient update.

📊

Real Numbers

Model / Block    | Shape                       | Params        | Why it matters
GPT-2 FFN        | 768 → 3072 → 768            | ~4.7M / block | Most parameters sit in the MLP, not attention.
Llama-2 7B FFN   | 4096 → 11008 → 4096         | ~135M / block | SwiGLU uses 3 projections (gate, up, down) — dominates both memory and compute.
Toy MLP (above)  | 512 → 2048 → 256 (2 layers) | 1.58M         | Small enough to reason about by hand, large enough to show batching effects.

A Transformer FFN is just an MLP with a much wider hidden layer. That is why understanding MLP initialization and variance flow pays off directly in LLM training.
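
A sketch that reproduces the per-block numbers in the table, assuming a plain 2-layer FFN with biases for GPT-2 and a bias-free 3-projection SwiGLU FFN for Llama-2:

def ffn_params(d_model, d_ff, swiglu=False):
    # Parameter count for one Transformer FFN block
    if swiglu:
        # gate, up: d_model → d_ff each; down: d_ff → d_model; Llama uses no biases
        return 3 * d_model * d_ff
    # up projection + down projection, plus their biases
    return 2 * d_model * d_ff + d_ff + d_model

def ffn_flops(d_model, d_ff, batch, seq):
    # Forward-pass FLOPs for a plain 2-layer FFN: 2 FLOPs per weight per token
    return 2 * batch * seq * (2 * d_model * d_ff)

print(f"{ffn_params(768, 3072):,}")                      # 4,722,432 (~4.7M, GPT-2 block)
print(f"{ffn_params(4096, 11008, swiglu=True):,}")       # 135,266,304 (~135M, Llama-2 7B block)
print(f"{ffn_flops(768, 3072, batch=32, seq=512):.3e}")  # ≈ 1.5e11 FLOPs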

🧠

Key Takeaways

What to remember for interviews

  1. An MLP is a stack of linear transforms (Wx + b) with nonlinear activations between them — without activations, all layers collapse into one.
  2. The Universal Approximation Theorem guarantees a 2-layer MLP can approximate any continuous function, but depth gives exponentially more efficient representations.
  3. Weight initialization (Xavier/He) prevents vanishing or exploding activations — zero init fails to break symmetry, so every neuron in a layer learns the same thing.
  4. In Transformers, the FFN block is a 2-layer MLP (d→4d→d) containing roughly two-thirds of each block's parameters — it acts as a key-value memory storing factual knowledge.
🧠

Recap quiz

Trade-off

The Universal Approximation Theorem (Hornik et al. 1989) says a 2-layer MLP can approximate any continuous function. Why do Transformers like GPT-3 use 96 layers instead of 1 very wide layer?

Derivation

A 10-layer MLP uses tanh activations throughout. Your colleague switches all weight init to He/Kaiming. What will happen, and why?

Derivation

A Transformer FFN block has shape d_model=768, hidden_dim=3072 (4× expansion). For a batch of B=32 and sequence length T=512, how many FLOPs does one FFN forward pass cost?

Trade-off

A team initializes all weights in an MLP to 0 to ensure stable training. What is the precise failure mode?

Trade-off

The Transformer FFN block expands d_model to 4×d_model in the hidden layer (“4× expansion”). This is used in GPT-2 (768→3072), GPT-3 (12288→49152), and Llama. Why 4× specifically, and what would happen with 1× or 16×?

Derivation

GPT-2 Small has 12 layers. Each FFN block has shape 768→3072→768. How many parameters do all FFN blocks contribute to the total 124M?

Trade-off

A researcher claims: “We can replace GPT-3’s 96-layer FFN stack with a single FFN layer of 96× the hidden dimension, maintaining the same parameter count.” What is the fundamental flaw?


🎯

Interview Questions


Why does Xavier init use 1/√fan_in? What problem does it solve, and when do you use Kaiming instead?

★★☆
Google · Meta · Anthropic

What is the difference between nn.Linear and torch.matmul? When would you prefer one over the other?

★☆☆
Meta · OpenAI

Explain the dying ReLU problem. What causes it, how do you detect it, and what are the fixes?

★★☆
Google · Anthropic

What does the Universal Approximation Theorem actually guarantee? What does it NOT tell you?

★★★
Google · Anthropic · OpenAI

A 3-layer MLP has input dim 512, hidden dim 2048, output dim 256. How many parameters? How does batch size affect compute?

★★☆
Meta · Google