⚡ Flash Attention
Same FLOPs, 2-4x faster — by never writing the N² attention matrix
Standard attention materializes the full attention matrix in GPU main memory (HBM). For a 4K sequence, that is 16M entries — for 128K, it is 16B entries. Flash Attention computes the exact same result but never writes that matrix to HBM. It tiles the computation into blocks that fit in fast on-chip SRAM, reducing memory from O(N²) to O(N).
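The arithmetic behind those entry counts, as a quick sanity check (plain Python; sizes assume 2 bytes per FP16 element):
# Size of the N x N attention matrix that standard attention materializes (per head, FP16)
for seq_len in (4_096, 131_072):                      # 4K and 128K tokens
    entries = seq_len ** 2                            # one score per (query, key) pair
    print(f"N={seq_len:>6}: {entries:>14,} entries, {entries * 2 / 2**20:>8,.0f} MiB")
# N=  4096:     16,777,216 entries,       32 MiB
# N=131072: 17,179,869,184 entries,   32,768 MiB (~32 GiB per head)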
GPU Memory Hierarchy and Tiling
The key insight: GPU compute is fast, but memory bandwidth is the bottleneck. Flash Attention keeps data in fast SRAM instead of writing to slow HBM.
SRAM (On-chip)
~20 MB (shared mem + L1)
~19 TB/s bandwidth
Small but blazing fast
HBM (GPU Main Memory)
40-80 GB (A100/H100)
~2 TB/s bandwidth
Large but ~10x slower than SRAM
Standard Attention
1. S = Q @ K.T <- write N x N to HBM
2. P = softmax(S) <- read N x N, write N x N
3. O = P @ V <- read N x N from HBM
HBM access: O(N^2) reads + writes
Flash Attention (Tiled)
for each K_block, V_block:
for each Q_block:
S_tile = Q_blk @ K_blk.T in SRAM
online softmax + accumulate in SRAM
HBM access: Θ(N²d²/M) — much less
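As a reference point, the three standard-attention steps above map directly onto this unoptimized PyTorch sketch; every intermediate really is a full N x N tensor held in GPU memory:
import math
import torch
def standard_attention(q, k, v):
    # q, k, v: (batch, n_heads, seq_len, d_k)
    d_k = q.shape[-1]
    s = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # step 1: N x N scores written out to HBM
    p = torch.softmax(s, dim=-1)                   # step 2: read N x N, write N x N
    return p @ v                                   # step 3: read N x N again to form the output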
Flash Attention Tiling Simulation
What you're seeing: tiled Q, K, V blocks loaded into fast SRAM one at a time — the full N×N attention matrix is never written to slow HBM, which is why FlashAttention cuts memory from O(N²) to O(N). What to try: step forward tile by tile and notice that the running softmax denominator is updated incrementally — that's the online normalization trick that makes the tiling numerically exact.
Key insight: Flash Attention computes the exact same result as standard attention. It is not an approximation — it just changes the order of computation to avoid storing the full N x N matrix. The trick: online softmax maintains a running maximum so partial results can be corrected when a larger value is encountered.
Quick check
SRAM is ~19 TB/s and HBM is ~2 TB/s on A100. Why can't Flash Attention just use HBM with a faster GPU?
The Intuition
Think of standard attention like a chef who puts every ingredient on the counter (HBM), works on one thing, puts it back, grabs the next. Flash Attention is a chef who grabs a small batch of ingredients at once, does all the work on the cutting board (SRAM), and only puts the final dish on the counter.
Three key ideas make this work:
- Tiling — split Q, K, V into blocks that fit in SRAM. Compute attention tile-by-tile
- Online softmax — maintain a running max and running sum so you never need the full row of scores at once. Rescale partial results when a new max is found
- IO-awareness — design the algorithm around the memory hierarchy (SRAM vs HBM), not just FLOP count. Same FLOPs, fewer memory reads = faster. The first two ideas are sketched in code right after this list.
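Here is a minimal sketch of the first two ideas in plain PyTorch. It is a readability-first simulation, not the fused CUDA kernel: the block size, single-head layout, and omission of the causal mask are choices made purely for illustration.
import math
import torch
def flash_attention_sketch(q, k, v, block_size=256):
    # q, k, v: (seq_len, d_k) for a single head; K/V are processed in tiles of block_size rows
    n, d_k = q.shape
    scale = 1.0 / math.sqrt(d_k)
    o = torch.zeros_like(q)                               # running unnormalized output
    m = torch.full((n, 1), float("-inf"))                 # running row max
    l = torch.zeros(n, 1)                                 # running softmax denominator
    for start in range(0, n, block_size):
        k_blk = k[start:start + block_size]               # one K/V tile ("loaded into SRAM")
        v_blk = v[start:start + block_size]
        s = (q @ k_blk.T) * scale                         # scores for this tile only, never N x N
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)
        rescale = torch.exp(m - m_new)                     # correct old stats for the new max
        l = rescale * l + p.sum(dim=-1, keepdim=True)
        o = rescale * o + p @ v_blk
        m = m_new
    return o / l                                           # normalize once at the end
q, k, v = (torch.randn(1024, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / math.sqrt(64), dim=-1) @ v
print(torch.allclose(flash_attention_sketch(q, k, v), ref, atol=1e-5))  # True: exact, not approximate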
Flash Attention 3 (2024): Shah et al. targeted H100-specific hardware capabilities that FA2 didn't exploit. H100 Tensor Cores support asynchronous data movement (via TMA — the Tensor Memory Accelerator) and native FP8 computation. FA3 overlaps GEMM (matrix multiplication) with softmax computation using warp-level pipelining: while one warp computes softmax on a completed tile, another warp is already issuing the next GEMM. This async block-level pipelining achieves a 1.5-2x speedup over FA2, and FP8 support roughly doubles throughput further for inference. FA3 is H100-specific — FA2 remains the standard on A100 and older GPUs.
Flash-Decoding (2023): FA2 parallelizes across batch size and attention heads, but during autoregressive decode the batch size may be 1 and the query length is 1 — leaving most of the GPU's parallelism unused. Flash-Decoding (Dao et al.) adds a third parallelism dimension: the KV sequence. It splits the KV cache into chunks, processes each chunk in a separate thread block in parallel, then reduces the partial softmax outputs. For long-context inference (e.g., 128K tokens), this delivers a large speedup on the decode step — which is the dominant cost when generating long responses against a large KV cache.
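The split-then-reduce idea can be sketched in a few lines. This is a single-query, single-head simulation with an arbitrary chunk count; in the real kernel each chunk runs in its own thread block in parallel:
import math
import torch
def flash_decoding_sketch(q, k_cache, v_cache, n_chunks=4):
    # q: (d_k,) the single decode-step query; k_cache, v_cache: (kv_len, d_k)
    d_k = q.shape[-1]
    partials = []
    for k_chunk, v_chunk in zip(k_cache.chunk(n_chunks), v_cache.chunk(n_chunks)):
        s = (k_chunk @ q) / math.sqrt(d_k)                # scores against this chunk only
        m = s.max()                                       # per-chunk max for a stable softmax
        p = torch.exp(s - m)
        partials.append((m, p.sum(), p @ v_chunk))        # (max, denominator, unnormalized out)
    m_all = max(m for m, _, _ in partials)                # reduce: merge the chunk statistics
    denom = sum(l * torch.exp(m - m_all) for m, l, _ in partials)
    out = sum(o * torch.exp(m - m_all) for m, _, o in partials)
    return out / denom
q = torch.randn(128)
k_cache, v_cache = torch.randn(8192, 128), torch.randn(8192, 128)
ref = torch.softmax((k_cache @ q) / math.sqrt(128), dim=0) @ v_cache
print(torch.allclose(flash_decoding_sketch(q, k_cache, v_cache), ref, atol=1e-5))  # True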
Ring Attention extends Flash Attention's tiling idea across multiple GPUs. In Flash Attention, tiles of K and V are staged through one GPU's SRAM. Ring Attention (Liu et al., 2023) distributes Q, K, V across devices and passes K/V blocks around a logical ring — each device processes its local Q block against the incoming K/V block while the next block is in flight, so computation and communication overlap. The result: attention over sequences of theoretically unbounded length, scaling linearly with the number of GPUs. This is one approach for scaling long-context attention across devices when a single GPU's SRAM/HBM is insufficient — proprietary systems like Gemini don't publicly disclose whether they use Ring Attention specifically.
Why is Flash Attention faster if it does the same number of FLOPs?
Step-by-Step Derivation
Memory: Standard vs Flash
Standard attention must store the full attention matrix, while Flash only ever holds Q, K, V tiles and the output:
Memory_standard = Θ(N²)   vs   Memory_flash = Θ(N·d)
For N=128K, d=128: standard needs N² ≈ 16B entries (~32 GB in FP16), Flash needs N·d ≈ 16M entries (~32 MB).
HBM Access (What Determines Wallclock)
FLOPs are identical. The difference is memory access:
HBM_standard = Θ(N·d + N²)   vs   HBM_flash = Θ(N²d²/M)
where M is the SRAM size (in elements). Since M ≫ d² for typical head dimensions (d=64 or 128), Flash requires far fewer HBM accesses.
Quick check
Flash HBM access is Θ(N²d²/M). Standard is Θ(Nd + N²). For N=4096, d=128, M=98304 (A100 SRAM in FP16 elements), which requires fewer HBM reads?
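Plugging those numbers into the two expressions (a rough count in matrix elements, ignoring constant factors):
N, d, M = 4096, 128, 98_304            # seq len, head dim, usable SRAM in FP16 elements
standard = N * d + N * N               # Θ(Nd + N²): stream Q/K/V plus the N x N matrix
flash = N**2 * d**2 / M                # Θ(N²d²/M)
print(f"standard ≈ {standard / 1e6:.1f}M elements, flash ≈ {flash / 1e6:.1f}M elements")
# standard ≈ 17.3M, flash ≈ 2.8M -> roughly 6x fewer HBM accesses (the ratio is about M/d²)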
Worked Example — A100 SRAM vs HBM, seq_len=2048, d=128
The A100 has 80 GB HBM2e at ~2 TB/s bandwidth and ~20 MB SRAM at ~19 TB/s — roughly 10× faster on-chip. Here is exactly what fits where:
| Tensor | Shape | Bytes (FP16) | Fits in SRAM? |
|---|---|---|---|
| Q, K, V (each) | 2048 × 128 | 512 KB | Yes (1.5 MB total) |
| QKᵀ (1 head) | 2048 × 2048 | 8 MB | Barely (32 heads = 256 MB ≫ 20 MB SRAM) |
| Flash tile (256 × 128) | 256 × 128 | 64 KB | Yes, comfortably |
Standard attention path: To compute one head, you write QKᵀ (8 MB) to HBM, read it back for softmax, write softmax(QKᵀ) (8 MB) to HBM, read it again to multiply by V. That is 4 × 8 MB = 32 MB of HBM traffic per head just for intermediate results — at ~2 TB/s this takes ~16 µs per head, not counting compute.
Flash Attention path: Process 256-token tiles. Each Q tile is 256 × 128 × 2 = 64 KB; each K/V tile is another 64 KB. Total tile size ≈ 192 KB — fits in SRAM with room to spare. The full N×N attention matrix is never written to HBM. Only the output O (2048 × 128 = 512 KB) goes back to HBM when done.
Result: the same arithmetic with a small fraction of the HBM traffic (roughly 2-4x faster in practice), and the memory footprint drops from O(N²) to O(N).
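The traffic and timing arithmetic from this example, spelled out (the ~2 TB/s figure is the approximate A100 HBM bandwidth used above):
seq_len, bytes_fp16 = 2048, 2
attn_matrix = seq_len * seq_len * bytes_fp16       # the 8 MiB score matrix for one head
traffic = 4 * attn_matrix                          # write S, read S, write P, read P
hbm_bw = 2e12                                      # ~2 TB/s
print(f"{traffic / 2**20:.0f} MiB of traffic -> ~{traffic / hbm_bw * 1e6:.1f} us per head")
# 32 MiB -> ~16.8 us per head before any compute; Flash's 64 KiB tiles never make this round trip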
Online Softmax (The Key Trick)
Standard softmax needs two passes (find the max, then normalize). Online softmax maintains running statistics (a running max m, a running denominator ℓ, and an unnormalized output accumulator O), updated as each tile of scores S^(j+1) arrives:
m^(j+1) = max(m^(j), rowmax(S^(j+1)))
ℓ^(j+1) = exp(m^(j) − m^(j+1)) · ℓ^(j) + rowsum(exp(S^(j+1) − m^(j+1)))
O^(j+1) = exp(m^(j) − m^(j+1)) · O^(j) + exp(S^(j+1) − m^(j+1)) · V^(j+1)
The old output is rescaled by exp(m^(j) − m^(j+1)) to account for the updated running max, and after the final tile the accumulator is divided by the final denominator ℓ to give correctly normalized attention. This is mathematically exact — no approximation.
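A minimal numeric check of the recurrence on a single row of scores, processed in two chunks (the values are arbitrary; the point is that the chunked result matches the full softmax exactly):
import torch
scores = torch.tensor([1.0, 3.0, 0.5, 4.0, 2.0, -1.0])      # one row of attention scores
m = torch.tensor(float("-inf"))                              # running max m
l = torch.tensor(0.0)                                        # running denominator l
for tile in scores.split(3):                                 # process the row in two "tiles"
    m_new = torch.maximum(m, tile.max())
    l = l * torch.exp(m - m_new) + torch.exp(tile - m_new).sum()   # rescale old denom, add new tile
    m = m_new
online = torch.exp(scores - m) / l                           # normalize with the final statistics
print(torch.allclose(online, torch.softmax(scores, dim=0)))  # True: exact, no approximation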
PyTorch: Using Flash Attention
import torch
import torch.nn.functional as F
# Method 1: PyTorch native (auto-selects Flash Attention if available)
# Requires: PyTorch >= 2.0, CUDA, head_dim <= 256
def attention_with_flash(q, k, v, is_causal=True):
"""
q, k, v: [batch, n_heads, seq_len, d_k]
Uses Flash Attention kernel automatically when possible.
"""
return F.scaled_dot_product_attention(
q, k, v,
is_causal=is_causal,
# PyTorch auto-dispatches to flash_attn kernel
)
# Method 2: flash-attn library (Tri Dao's implementation)
# pip install flash-attn
from flash_attn import flash_attn_func
def attention_with_flash_v2(q, k, v, causal=True):
"""
q, k, v: [batch, seq_len, n_heads, d_k] (note: different layout!)
Returns: [batch, seq_len, n_heads, d_k]
"""
return flash_attn_func(q, k, v, causal=causal)
# Check which backend PyTorch selects:
# torch.backends.cuda.flash_sdp_enabled() # True if Flash available
# torch.backends.cuda.mem_efficient_sdp_enabled() # xformers fallback
PyTorch implementation
# Flash Attention: use F.scaled_dot_product_attention (PyTorch 2.0+)
# Auto-selects Flash Attention kernel when on CUDA with head_dim <= 256
import torch
import torch.nn.functional as F
def causal_self_attention(q, k, v):
# q, k, v: (batch, n_heads, seq_len, head_dim)
# is_causal=True applies the causal mask without materializing it
return F.scaled_dot_product_attention(q, k, v, is_causal=True)
# Verify Flash Attention is enabled:
# torch.backends.cuda.flash_sdp_enabled() # True on Ampere (sm80) or newer GPUs
# torch.backends.cuda.math_sdp_enabled() # fallback: standard O(N²)
# Flash Attention saves memory: O(N) vs O(N²) for seq_len=128K
# N=128K: standard = 128K² × 2B = 32GB; flash = 128K × 128 × 2B = 32MB
Break It — See What Happens
Quick check
If the Flash Attention tile size is set larger than SRAM capacity (~192 KB per SM on A100), what happens?
Real-World Numbers
| Version | GPU | Speedup vs Standard | Key Innovation |
|---|---|---|---|
| Flash Attention v1 | A100 | 2-4x | Tiling + online softmax, O(N) memory |
| Flash Attention v2 | A100 | up to 2x over v1 | Reversed loop order, better warp partitioning, fewer non-matmul FLOPs |
| Flash Attention v3 | H100 | 1.5-2x over v2 | Async TMA, FP8 tensor cores, warp specialization |
| PyTorch SDPA (native) | A100/H100 | matches Flash v2 (auto-dispatch) | F.scaled_dot_product_attention, zero code changes |
PyTorch now dispatches F.scaled_dot_product_attention to the Flash kernel automatically. Flash Attention enabled the jump from 4K to 128K+ context lengths — without it, 128K attention would need ~32 GB per head just for the attention matrix.
Quick check
FA2 reaches 50–73% of A100 theoretical peak FLOPS. What FA2 change over FA1 primarily explains this jump?
SOTA 2024: FlashAttention-3 (NeurIPS Spotlight)
FlashAttention-3 (Shah et al., NeurIPS 2024 spotlight, arxiv:2407.08608) reaches ~740 TFLOPs/s in FP16 on H100 (~75% of peak) — roughly 1.5-2x faster than FA2 on the same hardware. The key: it targets H100-specific features unavailable on A100.
Three H100-specific innovations in FA3
- Warp specialization via TMA (Tensor Memory Accelerator): H100 has dedicated async data-movement hardware (TMA) that can transfer tiles to SRAM without occupying CUDA cores. FA3 pipelines TMA loads of the next K/V tile while WGMMA computes on the current tile — hiding memory latency behind compute.
- WGMMA async pipeline: H100's warpgroup matrix-multiply-accumulate (WGMMA) instruction issues asynchronously. FA3 keeps the pipeline full by overlapping the softmax rescaling of one tile with the matrix multiply of the next — unlike FA2, which stalls between the two.
- FP8 tensor core path: FA3 adds an FP8 forward pass that reaches roughly 1.2 PFLOPs/s on H100 FP8 tensor cores, vs ~740 TFLOPs/s in FP16.
H100 throughput comparison (causal attention)
| Version | GPU | TFLOPs/s | % of peak | Key mechanism |
|---|---|---|---|---|
| FA2 | H100 | ~280 (FP16) | ~29% | Reversed loop order, warp partitioning |
| FA3 | H100 | ~740 (FP16) | ~75% | TMA async + WGMMA pipeline + warp specialization |
| FA3 FP8 | H100 | ~1200 (FP8) | ~62% of FP8 peak | FP8 tensor cores, E4M3 format |
How to use FA3 in practice
# FA3 is available via flash-attn >= 2.6 on H100
# pip install flash-attn --no-build-isolation
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func # dispatches FA3 on H100 automatically
# Standard usage — FA3 is transparent to the caller
q = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.float16) # (B, H, T, d_k)
k = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 32, 4096, 128, device="cuda", dtype=torch.float16)
# flash_attn_func expects (B, T, H, d_k) layout
q = q.transpose(1, 2) # -> (B, T, H, d_k)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = flash_attn_func(q, k, v, causal=True) # FA3 on H100, FA2 on A100
# PyTorch SDPA auto-dispatches to FA2 (not FA3 yet as of PyTorch 2.4)
out_pt = F.scaled_dot_product_attention(
q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
is_causal=True
) # ~FA2 speed on both A100 and H100
Attention Variants Frontier (2024–2025)
Flash Attention, MLA, and sparse attention are orthogonal axes of optimization. The 2024–2025 frontier extended all three dimensions:
| Technique | Axis | Key claim | Source |
|---|---|---|---|
| Flash Attention 3 | Memory I/O | 740 TFLOPs/s on H100 (75% MFU) via TMA+WGMMA | tridao.me, 2024 |
| Differential Attention | Noise canceling | Subtracts two softmax maps → cancels attention noise, reduces hallucinations, improves long-context retrieval | arxiv:2410.05258, Oct 2024 |
| Native Sparse Attention (NSA) | Compute sparsity | Hardware-aligned sparse attention; large training and decoding speedups over dense attention at long context (DeepSeek, ACL 2025 Best Paper) | arxiv:2502.11089, Feb 2025 |
Deep dive: Differential Attention (DIFF Transformer, Oct 2024)
Standard self-attention is noisy: most attention heads spread weight across irrelevant tokens, with only a few high-value entries. DIFF Transformer (MSR + Tsinghua, arxiv:2410.05258) computes two independent softmax attention maps (using two separate {Q,K} pairs per head) and subtracts one from the other, scaled by a learned λ parameter:
Attn = softmax(Q₁K₁ᵀ/√d)·V − λ · softmax(Q₂K₂ᵀ/√d)·V
The subtraction cancels distributed “noise” attention, leaving only the high-signal spikes. Empirically: fewer hallucinations on long-context retrieval tasks, stronger in-context learning, and better key-information extraction vs the standard transformer at matched parameter count. The head dimension is halved to keep total params equal.
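A sketch of a single differential-attention head (the fixed λ, the head dimensions, and the omission of the paper's GroupNorm and λ re-parameterization are simplifications made here):
import math
import torch
def diff_attention_head(x, wq1, wk1, wq2, wk2, wv, lam=0.5):
    # x: (seq_len, d_model); lam is a learned parameter in the paper, fixed here for simplicity
    q1, k1, q2, k2, v = x @ wq1, x @ wk1, x @ wq2, x @ wk2, x @ wv
    d_k = q1.shape[-1]
    a1 = torch.softmax(q1 @ k1.T / math.sqrt(d_k), dim=-1)   # first attention map
    a2 = torch.softmax(q2 @ k2.T / math.sqrt(d_k), dim=-1)   # second, independent attention map
    return (a1 - lam * a2) @ v       # the subtraction cancels "noise" weight shared by both maps
d_model, d_k, seq_len = 256, 32, 128
w = [torch.randn(d_model, d_k) / math.sqrt(d_model) for _ in range(5)]   # wq1, wk1, wq2, wk2, wv
out = diff_attention_head(torch.randn(seq_len, d_model), *w)
print(out.shape)    # torch.Size([128, 32])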
Deep dive: Native Sparse Attention (NSA, DeepSeek, Feb 2025)
NSA (arxiv:2502.11089) addresses the core cost of long-context training: O(n²) compute. Rather than attending to all tokens, NSA uses a three-branch hardware-aligned sparse pattern: (1) compressed tokens — block-mean pooled global summary, (2) selected tokens — top-k important positions, (3) sliding window — local context. All three branches are implemented with custom CUDA kernels that map onto Tensor Core tile sizes, making the sparsity actually faster rather than conditionally slower (the common failure mode of software-sparse attention).
In the paper's 64K-context benchmarks against dense FlashAttention-2, NSA reports substantial speedups for both training throughput and decoding. Named ACL 2025 Best Paper. Used in DeepSeek V3.1+ as the production long-context recipe.
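To make the sparsity pattern concrete, here is a conceptual sketch that only builds the boolean mask of allowed (query, key) pairs; random block scores stand in for the compressed-token branch, and it ignores the custom kernels that make NSA fast in practice:
import torch
def nsa_style_mask(seq_len, block=16, top_k_blocks=2, window=32):
    # Which keys each query may attend to under the selected-block and sliding-window branches.
    # Block, top-k, and window sizes here are arbitrary illustrative choices.
    n_blocks = seq_len // block
    idx = torch.arange(seq_len)
    block_scores = torch.randn(seq_len, n_blocks)            # stand-in for learned importance scores
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    # Sliding-window branch: each query sees its most recent `window` keys
    mask |= (idx[:, None] - idx[None, :] >= 0) & (idx[:, None] - idx[None, :] < window)
    # Selected-tokens branch: each query also sees its top-k most "important" key blocks
    top_blocks = block_scores.topk(top_k_blocks, dim=-1).indices          # (seq_len, top_k)
    cols = (top_blocks[..., None] * block + torch.arange(block)).reshape(seq_len, -1)
    mask[idx[:, None], cols] = True
    # Compressed-tokens branch (block-mean global summary) is omitted from this sketch
    return mask & (idx[None, :] <= idx[:, None])             # keep the whole pattern causal
m = nsa_style_mask(seq_len=256)
print(f"{m.float().mean().item():.1%} of the dense score matrix would actually be computed")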
Key Takeaways
What to remember for interviews
1. Flash Attention computes the exact same result as standard attention but never writes the N×N matrix to HBM — tiles fit entirely in fast on-chip SRAM (~19 TB/s vs ~2 TB/s for HBM).
2. Online softmax maintains a running max and running sum per block, rescaling partial results when a larger max is found — this is what makes single-pass tiling mathematically exact.
3. Memory drops from O(N²) to O(N); HBM access drops from Θ(Nd + N²) to Θ(N²d²/M). For 128K context, standard attention needs ~32 GB per head; Flash needs ~32 MB.
4. FA3 (H100 only): TMA async data movement + WGMMA pipeline achieves ~740 TFLOPs/s FP16 (~75% utilization) vs FA2's ~280 TFLOPs/s. On A100, FA3 ≈ FA2 — no H100-specific hardware available.
5. PyTorch 2.0+ auto-dispatches F.scaled_dot_product_attention to Flash Attention — zero code changes needed to get the speedup.
Recap Quiz
A colleague says Flash Attention is faster because it approximates the softmax. How do you correct them?
For a 128K-token sequence with d=128 in FP16, how much memory does standard attention need per head vs Flash Attention?
During online softmax, why must the accumulated output O^(j) be multiplied by exp(m^(j) − m^(j+1)) before adding the new tile?
FA3 achieves 1.5–2× speedup over FA2 on H100 but not on A100. What H100-specific capabilities does FA3 exploit that A100 lacks?
FA2 parallelizes across batch and heads. Why does this leave GPU utilization near zero during single-token decode with batch=1, and what does Flash-Decoding add?
Standard attention on an A100 is IO-bound, not compute-bound. What is the core arithmetic intensity argument?
Further Reading
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness — The original Flash Attention paper — tiling attention computation to exploit GPU SRAM, achieving 2-4x speedup.
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning — Flash Attention v2 — improved work partitioning across warps and thread blocks for up to 2x additional speedup.
- Ring Attention with Blockwise Transformers for Near-Infinite Context — Liu et al. 2023 — extends Flash Attention's tiling idea across multiple GPUs by passing K/V blocks in a ring, enabling theoretically unbounded context length.
- Lilian Weng's Blog — Technical posts on efficient attention, memory optimization, and inference acceleration for LLMs.
- FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision — Shah et al. 2024 — Flash Attention v3 for H100, using async TMA and FP8 tensor cores to achieve 1.5-2x speedup over v2.
- Online normalizer calculation for softmax (Milakov & Gimelshein) — The online softmax algorithm that makes Flash Attention's single-pass tiling mathematically correct — understand this and the whole trick clicks.
Interview Questions
Flash Attention does the same number of FLOPs as standard attention. Why is it faster? ★★★
How does online softmax work and why is it essential for Flash Attention? ★★★
What is IO-awareness and why does it matter for GPU kernels? ★★☆
Compare Flash Attention v1, v2, and v3. What changed in each version? ★★★
Is Flash Attention exact or approximate? What are the implications? ★★☆