
Transformer Math

Module 24 · Inference

Flash Attention

Same FLOPs, 2-4x faster — by never writing the N² attention matrix


Standard attention materializes the full attention matrix in GPU main memory (HBM). For a 4K sequence, that is 16M entries — for 128K, it is 16B entries. Flash Attention computes the exact same result but never writes that matrix to HBM. It tiles the computation into blocks that fit in fast on-chip SRAM, reducing memory from O(N²) to O(N).
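Those memory figures are just N² bookkeeping. A quick back-of-the-envelope check (a minimal sketch, assuming FP16 and counting a single attention head):

for N in (4_096, 131_072):                      # 4K and 128K tokens
    scores = N * N                              # standard attention materializes N x N scores
    flash = N * 128                             # Flash keeps only O(N * d) state (d = 128 here)
    print(f"N={N:>7,}: matrix {scores * 2 / 2**30:6.2f} GiB, flash state {flash * 2 / 2**20:5.1f} MiB")

# N=  4,096: matrix   0.03 GiB, flash state   1.0 MiB
# N=131,072: matrix  32.00 GiB, flash state  32.0 MiB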

🎮

GPU Memory Hierarchy and Tiling

The key insight: GPU compute is fast, but memory bandwidth is the bottleneck. Flash Attention keeps data in fast SRAM instead of writing to slow HBM.

SRAM (On-chip)
~20 MB (shared mem + L1)
~19 TB/s bandwidth
Small but blazing fast

HBM (GPU Main Memory)
40-80 GB (A100/H100)
~2 TB/s bandwidth
Large but ~10x slower than SRAM

Standard Attention

1. S = Q @ K.T        <- write N x N to HBM
2. P = softmax(S)     <- read N x N, write N x N
3. O = P @ V          <- read N x N from HBM

HBM access: Θ(Nd + N²) reads + writes

Flash Attention (Tiled)

for each K_block, V_block:
    for each Q_block:
        S_tile = Q_blk @ K_blk.T      # in SRAM
        online softmax + accumulate   # in SRAM

HBM access: Θ(N²d²/M) — much less

✨ Insight · Flash Attention does the same math (same FLOPs), but restructures WHERE that math happens. By keeping intermediate results in fast SRAM instead of writing to slow HBM, it achieves a 2-4x speedup with zero approximation.

Flash Attention Tiling Simulation

What you’re seeing: tiled Q, K, V blocks loaded into fast SRAM one at a time — the full N×N attention matrix is never written to slow HBM, which is why FlashAttention cuts memory from O(N²) to O(N). What to try: step forward tile by tile and notice that the running softmax denominator is updated incrementally — that’s the online normalization trick that makes the tiling numerically exact.

Simulation: standard attention keeps the full N x N matrix in HBM (O(N²) memory); Flash Attention processes 16 tiles one at a time in SRAM (O(N) memory), tracking a running max and running sum.

Key insight: Flash Attention computes the exact same result as standard attention. It is not an approximation — it just changes the order of computation to avoid storing the full N x N matrix. The trick: online softmax maintains a running maximum so partial results can be corrected when a larger value is encountered.

Quick check

Trade-off

SRAM is ~19 TB/s and HBM is ~2 TB/s on A100. Why can't Flash Attention just use HBM with a faster GPU?

💡

The Intuition

Think of standard attention like a chef who puts every ingredient on the counter (HBM), works on one thing, puts it back, grabs the next. Flash Attention is a chef who grabs a small batch of ingredients at once, does all the work on the cutting board (SRAM), and only puts the final dish on the counter.

Three key ideas make this work:

  • Tiling — split Q, K, V into blocks that fit in SRAM. Compute attention tile-by-tile
  • Online softmax — maintain a running max and running sum so you never need the full row of scores at once. Rescale partial results when a new max is found
  • IO-awareness — design the algorithm around memory hierarchy (SRAM vs HBM), not just FLOP count. Same FLOPs, fewer memory reads = faster
✨ Insight · Flash Attention doesn't change the math — the output is mathematically identical to standard attention (any differences are floating-point rounding from the reordered summation). It only changes the order of operations to be IO-aware. This is a pure systems optimization, not an approximation. Unlike Linformer or Performer (which approximate the attention matrix), Flash Attention is an exact drop-in replacement with zero quality loss.

Flash Attention 3 (2024): Shah et al. targeted H100-specific hardware capabilities that FA2 didn't exploit. H100 Tensor Cores support asynchronous data movement (via TMA — Tensor Memory Accelerator) and native FP8 computation. FA3 overlaps GEMM (matrix multiplication) with softmax computation using warp-level pipelining: while one warp computes softmax on a completed tile, another warp is already issuing the next GEMM. This async block-level pipelining achieves a 1.5-2x speedup over FA2, and FP8 support doubles throughput further for inference. FA3 is H100-specific — FA2 remains the standard on A100 and older GPUs.

Flash-Decoding (2023): FA2 parallelizes across batch size and attention heads, but during autoregressive decode the batch size is 1 and the query sequence length is 1 — leaving most parallelism unused. Flash-Decoding (Dao et al.) adds a third parallelism dimension: the KV sequence. It splits the KV cache into chunks, each processed by a separate thread block in parallel, then reduces the partial softmax outputs. For long-context inference (e.g., 128K tokens), this delivers a large speedup on the decode step — which is the dominant cost when generating long responses against a large KV cache.
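A minimal single-head sketch of that split-and-reduce idea in plain PyTorch (hypothetical helper name; real Flash-Decoding runs the chunks in parallel thread blocks inside a fused kernel):

import torch

def split_kv_decode_attention(q, K, V, n_chunks=8):
    # q: (d,) one decode-step query; K, V: (T, d) cached keys/values
    scale = q.shape[0] ** -0.5
    partials = []                                  # per-chunk (max, sum of exp, unnormalized output)
    for Kc, Vc in zip(K.chunk(n_chunks), V.chunk(n_chunks)):
        s = (Kc @ q) * scale                       # scores against this chunk only
        m = s.max()
        p = torch.exp(s - m)
        partials.append((m, p.sum(), p @ Vc))
    # Reduction: rescale every chunk's partial result to the global max, normalize once
    m_glob = max(m for m, _, _ in partials)
    denom = sum(l * torch.exp(m - m_glob) for m, l, _ in partials)
    numer = sum(o * torch.exp(m - m_glob) for m, _, o in partials)
    return numer / denom

# Matches a plain softmax over the whole KV cache (up to float rounding)
T, d = 4096, 128
q, K, V = torch.randn(d), torch.randn(T, d), torch.randn(T, d)
ref = torch.softmax((K @ q) * d ** -0.5, dim=0) @ V
assert torch.allclose(split_kv_decode_attention(q, K, V), ref, atol=1e-4)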

Ring Attention extends Flash Attention's tiling idea across multiple GPUs. In Flash Attention, tiles of K and V are kept in one GPU's SRAM. Ring Attention (Liu et al. 2023) distributes Q, K, V across devices and passes K/V blocks around a logical ring — each device processes its local Q block against the incoming K/V block simultaneously, so computation and communication overlap. The result: attention over sequences of theoretically unbounded length, scaling linearly with the number of GPUs. This is one approach for scaling long-context attention across devices when a single GPU's SRAM/HBM is insufficient — proprietary systems like Gemini don't publicly disclose whether they use Ring Attention specifically.

Quick Check

Why is Flash Attention faster if it does the same number of FLOPs?

📐

Step-by-Step Derivation

Memory: Standard vs Flash

Standard attention must store the full attention matrix:

Memory_standard = O(N²)   (the N × N matrices S = QKᵀ and P = softmax(S))
Memory_flash = O(N·d)     (only Q, K, V, O tiles plus per-row running statistics)

For N=128K, d=128: standard needs N² ≈ 1.7 × 10¹⁰ entries (~32 GB in FP16), Flash needs N·d ≈ 1.7 × 10⁷ entries (~32 MB).

HBM Access (What Determines Wallclock)

FLOPs are identical. The difference is memory access:

HBM_standard = Θ(Nd + N²)
HBM_flash = Θ(N²d²/M)

where M is the SRAM size in elements. Since d² ≪ M for typical head dimensions (d=64 or 128), Flash requires far fewer HBM accesses.

💡 Tip · On an A100 with M = 98,304 FP16 elements of SRAM per SM (192 KB) and d=128, plug in N=4096 and compare both sides in elements: N²d²/M ≈ 2.8M vs Nd + N² ≈ 17.3M. Equivalently, in bytes (FP16 = 2 B): ~5.6 MB vs ~34.6 MB — Flash's tiling fits comfortably and HBM access drops by a large factor for long sequences.

Quick check

Derivation

Flash HBM access is Θ(N²d²/M). Standard is Θ(Nd + N²). For N=4096, d=128, M=98304 (A100 SRAM in FP16 elements), which requires fewer HBM reads?


Worked Example — A100 SRAM vs HBM, seq_len=2048, d=128

The A100 has 80 GB HBM2e at ~2 TB/s bandwidth and ~20 MB SRAM at ~19 TB/s — roughly 10× faster on-chip. Here is exactly what fits where:

Tensor | Shape | Bytes (FP16) | Fits in SRAM?
Q, K, V (each) | 2048 × 128 | 512 KB | Yes (1.5 MB total)
QKᵀ (1 head) | 2048 × 2048 | 8 MB | Tight for one head; all 32 heads = 256 MB, far beyond ~20 MB
Flash tile (256 × 128) | 256 × 128 | 64 KB | Yes, comfortably

Standard attention path: To compute one head, you write QKᵀ (8 MB) to HBM, read it back for softmax, write softmax(QKᵀ) (8 MB) to HBM, read it again to multiply by V. That is 4 × 8 MB = 32 MB of HBM traffic per head just for intermediate results — at ~2 TB/s this takes ~16 µs per head, not counting compute.

Flash Attention path: Process 256-token tiles. Each Q tile is 256 × 128 × 2 = 64 KB; each K/V tile is another 64 KB. Total tile size ≈ 192 KB — fits in SRAM with room to spare. The full N×N attention matrix is never written to HBM. Only the output O (2048 × 128 = 512 KB) goes back to HBM when done.

Result: the ~32 MB of intermediate HBM traffic per head disappears entirely (only the inputs and the 512 KB output touch HBM), and the memory footprint drops from O(N²) to O(N).
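The traffic and timing claims above are simple arithmetic; a sketch using the approximate A100 bandwidth from this section:

HBM_BW = 2e12                                     # ~2 TB/s
seq, d, fp16 = 2048, 128, 2

score_matrix = seq * seq * fp16                   # 8 MB: one head's QK^T
intermediate_traffic = 4 * score_matrix           # write S, read S, write P, read P
print(intermediate_traffic / 2**20, "MiB")        # 32.0 MiB of intermediate HBM traffic per head
print(intermediate_traffic / HBM_BW * 1e6, "us")  # ~16.8 us per head just moving scores

tile = 256 * d * fp16                             # one 256-token tile of Q, K, or V
print(tile / 2**10, "KiB")                        # 64.0 KiB, fits easily in SRAM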

Online Softmax (The Key Trick)

Standard softmax needs two passes (find max, then normalize). Online softmax maintains running statistics (a running row max m⁽ʲ⁾, a running denominator ℓ⁽ʲ⁾, and an unnormalized output O⁽ʲ⁾), updated as each new tile of scores S⁽ʲ⁺¹⁾ arrives:

m⁽ʲ⁺¹⁾ = max(m⁽ʲ⁾, rowmax(S⁽ʲ⁺¹⁾))
ℓ⁽ʲ⁺¹⁾ = exp(m⁽ʲ⁾ − m⁽ʲ⁺¹⁾) · ℓ⁽ʲ⁾ + rowsum(exp(S⁽ʲ⁺¹⁾ − m⁽ʲ⁺¹⁾))
O⁽ʲ⁺¹⁾ = exp(m⁽ʲ⁾ − m⁽ʲ⁺¹⁾) · O⁽ʲ⁾ + exp(S⁽ʲ⁺¹⁾ − m⁽ʲ⁺¹⁾) · V⁽ʲ⁺¹⁾

The old output is rescaled by exp(m⁽ʲ⁾ − m⁽ʲ⁺¹⁾) to account for the updated running max, then divided by the final denominator ℓ to maintain correct normalization. This is mathematically exact — no approximation.
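A minimal single-head sketch of those updates in plain PyTorch (Q is kept whole here for brevity; real Flash Attention also tiles Q and fuses everything into one GPU kernel, but the arithmetic is the same):

import torch

def flash_style_attention(Q, K, V, tile=256):
    # Q, K, V: (N, d) for one head; stream K/V tiles and keep running statistics
    N, d = Q.shape
    scale = d ** -0.5
    O = torch.zeros_like(Q)                       # unnormalized running output
    m = torch.full((N, 1), float("-inf"))         # running row max
    l = torch.zeros(N, 1)                         # running softmax denominator
    for j in range(0, N, tile):
        Kj, Vj = K[j:j+tile], V[j:j+tile]
        S = (Q @ Kj.T) * scale                    # scores for this tile only
        m_new = torch.maximum(m, S.max(dim=-1, keepdim=True).values)
        P = torch.exp(S - m_new)
        alpha = torch.exp(m - m_new)              # rescale factor for the old statistics
        l = alpha * l + P.sum(dim=-1, keepdim=True)
        O = alpha * O + P @ Vj
        m = m_new
    return O / l                                  # normalize once at the end

# Agrees with standard attention up to float rounding
N, d = 1024, 64
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
ref = torch.softmax((Q @ K.T) * d ** -0.5, dim=-1) @ V
assert torch.allclose(flash_style_attention(Q, K, V), ref, atol=1e-4)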

PyTorch: Using Flash Attention

import torch
import torch.nn.functional as F

# Method 1: PyTorch native (auto-selects Flash Attention if available)
# Requires: PyTorch >= 2.0, CUDA, head_dim <= 256
def attention_with_flash(q, k, v, is_causal=True):
    """
    q, k, v: [batch, n_heads, seq_len, d_k]
    Uses Flash Attention kernel automatically when possible.
    """
    return F.scaled_dot_product_attention(
        q, k, v,
        is_causal=is_causal,
        # PyTorch auto-dispatches to flash_attn kernel
    )

# Method 2: flash-attn library (Tri Dao's implementation)
# pip install flash-attn
from flash_attn import flash_attn_func

def attention_with_flash_v2(q, k, v, causal=True):
    """
    q, k, v: [batch, seq_len, n_heads, d_k]  (note: different layout!)
    Returns: [batch, seq_len, n_heads, d_k]
    """
    return flash_attn_func(q, k, v, causal=causal)

# Check which backend PyTorch selects:
# torch.backends.cuda.flash_sdp_enabled()  # True if Flash available
# torch.backends.cuda.mem_efficient_sdp_enabled()  # xformers fallback
PyTorch implementation
# Flash Attention: use F.scaled_dot_product_attention (PyTorch 2.0+)
# Auto-selects Flash Attention kernel when on CUDA with head_dim <= 256
import torch
import torch.nn.functional as F

def causal_self_attention(q, k, v):
    # q, k, v: (batch, n_heads, seq_len, head_dim)
    # is_causal=True applies the causal mask without materializing it
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Verify Flash Attention is enabled:
# torch.backends.cuda.flash_sdp_enabled()   # True on CUDA >= sm80
# torch.backends.cuda.math_sdp_enabled()    # fallback: standard O(N²)

# Flash Attention saves memory: O(N) vs O(N²) for seq_len=128K
# N=128K: standard = 128K² × 2B = 32GB; flash = 128K × 128 × 2B = 32MB
🔧

Break It — See What Happens

  • Materialize full N x N attention matrix
  • Tile size larger than SRAM capacity
  • Materialize full attention matrix at seq_len=32K (long context)

Quick check

Trade-off

If the Flash Attention tile size is set larger than SRAM capacity (~192 KB per SM on A100), what happens?

📊

Real-World Numbers

Version | GPU | Speedup vs Standard | Key Innovation
Flash Attention v1 | A100 | 2-4x | Tiling + online softmax, O(N) memory
Flash Attention v2 | A100 | ~2x over v1 | Reversed loop order, better warp partitioning, fewer non-matmul FLOPs
Flash Attention v3 | H100 | 1.5-2x over v2 on H100 | Async TMA, FP8 tensor cores, warp specialization
PyTorch SDPA (native) | A100/H100 | Same as Flash v2 (auto-dispatch) | F.scaled_dot_product_attention, zero code changes
✨ Insight · Flash Attention is now the default in every major framework. PyTorch 2.0+ automatically uses it via F.scaled_dot_product_attention. It enabled the jump from 4K to 128K+ context lengths — without it, 128K attention would need ~32 GB per head just for the attention matrix.

Quick check

Trade-off

FA2 reaches 50–73% of A100 theoretical peak FLOPS. What FA2 change over FA1 primarily explains this jump?

🚀

SOTA 2024: FlashAttention-3 (NeurIPS Spotlight)

FlashAttention-3 (Shah et al., NeurIPS 2024 spotlight, arxiv:2407.08608) achieves ~740 TFLOPs/s FP16 on H100 (~75% of peak) — roughly 1.5–2× FA2 on the same hardware. The key: it targets H100-specific features unavailable on A100.

Three H100-specific innovations in FA3

  1. Warp specialization via TMA (Tensor Memory Accelerator): H100 has dedicated async data-movement hardware (TMA) that can transfer tiles to SRAM without occupying CUDA cores. FA3 pipelines TMA loads of the next K/V tile while WGMMA computes on the current tile — hiding memory latency behind compute.
  2. WGMMA async pipeline: H100's warpgroup matrix-multiply-accumulate (WGMMA) instruction issues asynchronously. FA3 keeps the pipeline full by overlapping the softmax rescaling of one tile with the matrix multiply of the next — unlike FA2, which stalls between the two.
  3. FP8 tensor core path: FA3 adds an FP8 forward pass that reaches ~1.2 PFLOPs/s on H100 FP8 tensor cores, vs ~740 TFLOPs/s in FP16.

H100 throughput comparison (FP16, causal attention)

Version | GPU | TFLOPs/s | % of peak | Key mechanism
FA2 | H100 | ~280 (FP16) | ~29% | Reversed loop order, warp partitioning
FA3 | H100 | ~740 (FP16) | ~75% | TMA async + WGMMA pipeline + warp specialization
FA3 FP8 | H100 | ~1,200 (FP8) | ~62% of FP8 peak | FP8 tensor cores, E4M3 format
⚠ Warning · Critical interview trap: FA3 specifically targets H100 hardware (TMA + WGMMA are H100-only instructions). On A100 — which has neither TMA nor WGMMA — FA3 provides no meaningful speedup over FA2. Always qualify: “FA3 is a 1.5–2× win on H100, but FA2 is still the right choice on A100.” Blog post: tridao.me/blog/2024/flash3/
How to use FA3 in practice
# FA3 targets H100 and ships with the flash-attention project; check your installed
# flash-attn version's release notes for whether FA3 kernels are included and dispatched
# pip install flash-attn --no-build-isolation

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # FA3 on H100 when the installed build provides it, FA2 otherwise

# Standard usage — FA3 is transparent to the caller
q = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.float16)  # (B, H, T, d_k)
k = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.float16)

# flash_attn_func expects (B, T, H, d_k) layout
q = q.transpose(1, 2)  # -> (B, T, H, d_k)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(q, k, v, causal=True)  # FA3 on H100, FA2 on A100

# PyTorch SDPA auto-dispatches to FA2 (not FA3 yet as of PyTorch 2.4)
out_pt = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    is_causal=True
)  # ~FA2 speed on both A100 and H100
🔭

Attention Variants Frontier (2024–2025)

Flash Attention, MLA, and sparse attention are orthogonal axes of optimization. The 2024–2025 frontier extended all three dimensions:

Technique | Axis | Key claim | Source
Flash Attention 3 | Memory I/O | 740 TFLOPs/s on H100 (75% MFU) via TMA+WGMMA | tridao.me, 2024
Differential Attention | Noise canceling | Subtracts two softmax maps → cancels attention noise, reduces hallucinations, improves long-context retrieval | arxiv:2410.05258, Oct 2024
Native Sparse Attention (NSA) | Compute sparsity | Hardware-aligned sparse attention; substantially faster training and decoding at 64K context vs dense attention (DeepSeek, ACL 2025 Best Paper) | arxiv:2502.11089, Feb 2025
✨ Insight · Orthogonal stack, not competing choices. MLA (DeepSeek V2) compresses the KV cache at rest; Flash Attention reduces HBM traffic during the computation; Differential Attention reduces attention noise in the output; NSA eliminates most compute entirely. Production systems can combine all four. See also Long Context for how NSA enables 64K+ training at reduced cost.
Deep dive: Differential Attention (DIFF Transformer, Oct 2024)

Standard self-attention is noisy: most attention heads spread weight across irrelevant tokens, with only a few high-value entries. DIFF Transformer (MSR + Tsinghua, arxiv:2410.05258) computes two independent softmax attention maps (using two separate {Q,K} pairs per head) and subtracts one from the other, scaled by a learned λ parameter:

Attn = softmax(Q₁K₁ᵀ/√d)·V − λ · softmax(Q₂K₂ᵀ/√d)·V

The subtraction cancels distributed “noise” attention, leaving only the high-signal spikes. Empirically: fewer hallucinations on long-context retrieval tasks, stronger in-context learning, and better key-information extraction vs the standard transformer at matched parameter count. The head dimension is halved to keep total params equal.
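A minimal sketch of one differential-attention head (hypothetical projection names; the real DIFF Transformer parameterizes λ with learned vectors and halves each map's head dimension to keep parameters matched):

import torch
import torch.nn.functional as F

def diff_attention_head(x, Wq1, Wk1, Wq2, Wk2, Wv, lam=0.5):
    # x: (T, d_model); two independent {Q, K} projections share a single V
    scale = Wk1.shape[1] ** -0.5
    A1 = F.softmax((x @ Wq1) @ (x @ Wk1).T * scale, dim=-1)
    A2 = F.softmax((x @ Wq2) @ (x @ Wk2).T * scale, dim=-1)
    return (A1 - lam * A2) @ (x @ Wv)             # subtraction cancels shared "noise" attention

T, d_model, d_k = 128, 256, 64
x = torch.randn(T, d_model)
Wq1, Wk1, Wq2, Wk2, Wv = (torch.randn(d_model, d_k) * d_model ** -0.5 for _ in range(5))
out = diff_attention_head(x, Wq1, Wk1, Wq2, Wk2, Wv)   # (T, d_k)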

Deep dive: Native Sparse Attention (NSA, DeepSeek, Feb 2025)

NSA (arxiv:2502.11089) addresses the core cost of long-context training: O(n²) compute. Rather than attending to all tokens, NSA uses a three-branch hardware-aligned sparse pattern: (1) compressed tokens — block-mean pooled global summary, (2) selected tokens — top-k important positions, (3) sliding window — local context. All three branches are implemented with custom CUDA kernels that map onto Tensor Core tile sizes, making the sparsity actually faster rather than conditionally slower (the common failure mode of software-sparse attention).

Benchmarks at 64K context vs dense FlashAttention-2 show substantially higher training throughput and faster decoding. Named ACL 2025 Best Paper. Used in DeepSeek V3.1+ as the production long-context recipe.
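A toy single-query sketch of the three branches (heavily simplified and all names hypothetical: the real NSA scores each branch separately with learned gating and runs on fused Tensor Core kernels, whereas this sketch merges everything into one softmax):

import torch
import torch.nn.functional as F

def nsa_like_attention(q, K, V, block=64, top_k=4, window=256):
    # q: (d,) single query; K, V: (T, d) with T divisible by block
    T, d = K.shape
    scale = d ** -0.5
    nb = T // block
    # Branch 1: compressed tokens (block-mean pooled global summary)
    Kc = K.reshape(nb, block, d).mean(dim=1)
    Vc = V.reshape(nb, block, d).mean(dim=1)
    # Branch 2: selected tokens (top-k blocks ranked by compressed scores)
    block_scores = (Kc @ q) * scale
    sel = torch.topk(block_scores, k=min(top_k, nb)).indices
    selected = torch.cat([torch.arange(b * block, (b + 1) * block) for b in sel.tolist()])
    # Branch 3: sliding window (local context)
    local = torch.arange(max(0, T - window), T)
    keep = torch.unique(torch.cat([selected, local]))
    keys = torch.cat([Kc, K[keep]], dim=0)
    vals = torch.cat([Vc, V[keep]], dim=0)
    attn = F.softmax((keys @ q) * scale, dim=0)
    return attn @ vals

T, d = 4096, 64
q, K, V = torch.randn(d), torch.randn(T, d), torch.randn(T, d)
out = nsa_like_attention(q, K, V)   # attends to a few hundred of the 4096 cached tokens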

🧠

Key Takeaways

What to remember for interviews

  1. Flash Attention computes the exact same result as standard attention but never writes the N×N matrix to HBM — tiles fit entirely in fast on-chip SRAM (~19 TB/s vs ~2 TB/s for HBM).
  2. Online softmax maintains a running max and running sum per block, rescaling partial results when a larger max is found — this is what makes single-pass tiling mathematically exact.
  3. Memory drops from O(N²) to O(N); HBM access drops from Θ(Nd + N²) to Θ(N²d²/M). For 128K context, standard attention needs ~32 GB per head; Flash needs ~32 MB.
  4. FA3 (H100 only): TMA async data movement + WGMMA pipeline achieves ~740 TFLOPs/s FP16 (~75% utilization) vs FA2's ~280 TFLOPs/s. On A100, FA3 ≈ FA2 — no H100-specific hardware available.
  5. PyTorch 2.0+ auto-dispatches F.scaled_dot_product_attention to Flash Attention — zero code changes needed to get the speedup.
🧠

Recap Quiz

Trade-off

A colleague says Flash Attention is faster because it approximates the softmax. How do you correct them?

Derivation

For a 128K-token sequence with d=128 in FP16, how much memory does standard attention need per head vs Flash Attention?

Derivation

During online softmax, why must the accumulated output O^(j) be multiplied by exp(m^(j) − m^(j+1)) before adding the new tile?

Trade-off

FA3 achieves 1.5–2× speedup over FA2 on H100 but not on A100. What H100-specific capabilities does FA3 exploit that A100 lacks?

Derivation

FA2 parallelizes across batch and heads. Why does this leave GPU utilization near zero during single-token decode with batch=1, and what does Flash-Decoding add?

Derivation

Standard attention on an A100 is IO-bound, not compute-bound. What is the core arithmetic intensity argument?

📚

Further Reading

🎯

Interview Questions


Flash Attention does the same number of FLOPs as standard attention. Why is it faster?

★★★
OpenAI · Anthropic

How does online softmax work and why is it essential for Flash Attention?

★★★
OpenAI · Anthropic

What is IO-awareness and why does it matter for GPU kernels?

★★☆
Google · Meta

Compare Flash Attention v1, v2, and v3. What changed in each version?

★★★
OpenAI · Anthropic

Is Flash Attention exact or approximate? What are the implications?

★★☆
OpenAI · Google