Part of the Transformer Anatomy series (part 34 of 36)

The FlashAttention Backward Pass: Recomputation, Memory Savings, and the 33% Compute Overhead

Standard attention computes the $N \times N$ attention matrix $P = \mathrm{softmax}(QK^T / \sqrt{d})$ and stores it for the backward pass. For $N = 128\text{K}$ with FP16, storing $P$ requires $128000^2 \times 2$ bytes $= 32.8$ GB per head per layer. A 32-head, 80-layer model at 128K context would need roughly 84 TB for attention matrices alone. This is not feasible.

FlashAttention solves this by not storing $P$ at all. Instead, it stores only the softmax normalization statistics (the per-row maximum and sum), totaling $O(N)$ memory. During the backward pass, it recomputes $P$ block by block from $Q$, $K$, $V$ and these saved statistics. The recomputation adds roughly 25% extra FLOPs to the backward pass compared to a hypothetical backward pass with $P$ pre-stored (about 17% of total training FLOPs), but saves over 90% of memory.

This post derives the forward pass tiling algorithm, explains exactly what statistics are saved, derives the backward pass recomputation strategy, computes the FLOPs overhead, and provides a complete implementation sketch.

Standard Attention: The Memory Problem

1.1 Forward Pass

The standard attention forward pass:

$$S = QK^T / \sqrt{d} \;\in\; \mathbb{R}^{N \times N}, \qquad P = \mathrm{softmax}(S) \;\in\; \mathbb{R}^{N \times N}, \qquad O = PV \;\in\; \mathbb{R}^{N \times d}$$

where $Q, K, V \in \mathbb{R}^{N \times d}$ are the query, key, and value matrices for one attention head.

import torch
import torch.nn.functional as F

def standard_attention_forward(Q, K, V):
    """Standard attention: stores the full N x N attention matrix.

    Q, K, V: [N, d]
    Returns: O [N, d], P [N, N] (stored for backward)
    """
    d = Q.shape[-1]
    scale = d ** -0.5

    # Scores: [N, N] -- this is the memory bottleneck
    S = Q @ K.T * scale

    # Softmax: [N, N] -- same size
    P = F.softmax(S, dim=-1)

    # Output: [N, d]
    O = P @ V

    # Must store P for backward pass
    return O, P  # P is N x N -- massive for large N

1.2 Backward Pass (Standard)

Given upstream gradient $dO \in \mathbb{R}^{N \times d}$:

$$dV = P^T\, dO, \qquad dP = dO\, V^T, \qquad dS = P \odot (dP - \mathrm{rowsum}(dP \odot P)), \qquad dQ = dS\, K / \sqrt{d}, \qquad dK = dS^T\, Q / \sqrt{d}$$

def standard_attention_backward(Q, K, V, P, O, dO):
    """Standard backward: requires P (the stored N x N matrix).

    All inputs: [N, d] except P: [N, N]
    """
    d = Q.shape[-1]
    scale = d ** -0.5

    # dV = P^T @ dO  -- needs P
    dV = P.T @ dO  # [N, d]

    # dP = dO @ V^T  -- needs V (already stored as input)
    dP = dO @ V.T  # [N, N]

    # dS: softmax backward
    # dS_ij = P_ij * (dP_ij - sum_k(dP_ik * P_ik))
    row_sum = (dP * P).sum(dim=-1, keepdim=True)  # [N, 1]
    dS = P * (dP - row_sum)  # [N, N] -- needs P again

    # dQ = dS @ K * scale
    dQ = dS @ K * scale  # [N, d]

    # dK = dS^T @ Q * scale
    dK = dS.T @ Q * scale  # [N, d]

    return dQ, dK, dV

The backward pass reads $P$ (the $N \times N$ matrix) twice: once for $dV = P^T dO$ and once for the softmax gradient. This is why standard attention must store $P$ during the forward pass.

1.3 Memory Accounting

def attention_memory_standard(N, d, n_heads, n_layers, dtype_bytes=2):
    """Memory for storing attention matrices (standard attention)."""
    # Per-head, per-layer: N x N matrix
    per_head_layer = N * N * dtype_bytes

    # Total across all heads and layers
    total = per_head_layer * n_heads * n_layers

    print(f"N={N:,}, d={d}, heads={n_heads}, layers={n_layers}")
    print(f"Per head per layer: {per_head_layer / 1e9:.2f} GB")
    print(f"Total attention matrices: {total / 1e9:.1f} GB")
    return total

# Llama 70B at context length 128K
attention_memory_standard(N=131072, d=128, n_heads=64, n_layers=80)
# Per head per layer: 34.36 GB
# Total attention matrices: 175,921.9 GB  -- impossibly large
🚨 Danger

At 128K context length, standard attention would need 176 TB just for the attention matrices of one forward pass of Llama 70B. This is why FlashAttention is not optional — it is a prerequisite for training with contexts longer than about 2K tokens.

FlashAttention Forward Pass: Tiled Computation

2.1 Online Softmax

FlashAttention computes softmax without materializing the full $N \times N$ matrix. The key algorithm is online softmax (Milakov and Gimelshein, 2018): compute the softmax in a single pass, maintaining running statistics.

For a row of scores $s_1, s_2, \ldots, s_N$:

$$\mathrm{softmax}(s_i) = \frac{e^{s_i - m}}{\sum_j e^{s_j - m}}$$

where $m = \max_j s_j$. Online softmax computes $m$ and $\ell = \sum_j e^{s_j - m}$ incrementally:

def online_softmax_demo(scores):
    """Online softmax: compute max and sum in a single pass.

    Process scores in blocks. After each block, update the
    running max and running sum.
    """
    block_size = 64
    N = len(scores)

    m = float("-inf")  # Running max
    ell = 0.0          # Running sum of exp(s - m)

    for start in range(0, N, block_size):
        block = scores[start:start + block_size]

        # New max considering this block
        m_new = max(m, block.max().item())

        # Correction factor for previous sum
        correction = torch.exp(torch.tensor(m - m_new))

        # Update running sum
        ell = ell * correction + torch.exp(block - m_new).sum().item()

        # Update max
        m = m_new

    # Final softmax values
    softmax_values = torch.exp(scores - m) / ell
    return softmax_values, m, ell
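
As a quick sanity check on the recurrence above, here is a self-contained rerun on a random score vector, compared against `F.softmax` (the 1-D `scores` tensor and the block size of 64 are arbitrary demo choices):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(1000)

# Re-run the online recurrence: running max m, running sum ell
m, ell = float("-inf"), 0.0
for start in range(0, len(scores), 64):
    block = scores[start:start + 64]
    m_new = max(m, block.max().item())
    # Rescale the old sum into the new max's frame, then add this block
    ell = ell * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
    m = m_new

online = torch.exp(scores - m) / ell
reference = F.softmax(scores, dim=-1)
print(torch.allclose(online, reference, atol=1e-6))
```

The first iteration is safe because `math.exp(-inf)` is exactly 0.0, so the empty running sum contributes nothing.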

2.2 Tiled Forward Pass

FlashAttention tiles the computation into blocks of size $B_r \times B_c$ (rows of $Q$ times columns of $K$). It processes one tile at a time, accumulating the output and softmax statistics:

def flash_attention_forward(Q, K, V, block_size_r=64, block_size_c=64):
    """FlashAttention forward pass (simplified, single-head).

    Q, K, V: [N, d]
    Returns: O [N, d], (m, ell) for backward -- NOT the N x N matrix

    Key insight: we never materialize the full N x N attention matrix.
    Instead, we process blocks and accumulate the output online.
    """
    N, d = Q.shape

    # Output accumulator and softmax statistics
    O = torch.zeros(N, d, device=Q.device, dtype=Q.dtype)
    m = torch.full((N, 1), float("-inf"), device=Q.device)  # Running max per row
    ell = torch.zeros(N, 1, device=Q.device)                # Running sum per row

    scale = d ** -0.5

    # Iterate over K, V blocks (columns of the attention matrix)
    for j in range(0, N, block_size_c):
        j_end = min(j + block_size_c, N)
        K_block = K[j:j_end]  # [Bc, d]
        V_block = V[j:j_end]  # [Bc, d]

        # Iterate over Q blocks (rows of the attention matrix)
        for i in range(0, N, block_size_r):
            i_end = min(i + block_size_r, N)
            Q_block = Q[i:i_end]  # [Br, d]

            # Compute scores for this tile: [Br, Bc]
            S_tile = Q_block @ K_block.T * scale

            # Online softmax update
            m_old = m[i:i_end]           # [Br, 1]
            ell_old = ell[i:i_end]       # [Br, 1]
            O_old = O[i:i_end]           # [Br, d]

            # New max: max of old max and max of this tile
            m_tile = S_tile.max(dim=-1, keepdim=True).values  # [Br, 1]
            m_new = torch.maximum(m_old, m_tile)

            # Correction factors
            exp_old = torch.exp(m_old - m_new)  # Scale old accumulator
            exp_new = torch.exp(S_tile - m_new)  # New tile's contribution

            # Update running sum
            ell_new = exp_old * ell_old + exp_new.sum(dim=-1, keepdim=True)

            # Update output: rescale old output and add new contribution
            O_new = (exp_old * ell_old * O_old + exp_new @ V_block) / ell_new

            # Store updated values
            O[i:i_end] = O_new
            m[i:i_end] = m_new
            ell[i:i_end] = ell_new

    # Save only these for backward (NOT the N x N matrix):
    # - Q, K, V (already stored as inputs)
    # - m: [N, 1] per-row max
    # - ell: [N, 1] per-row sum
    # - O: [N, d] output (needed for softmax gradient)

    return O, m, ell

2.3 What Gets Saved

def memory_saved_for_backward(N, d, dtype_bytes=2):
    """Compare what standard vs FlashAttention saves for backward."""

    # Standard: saves P (the N x N attention matrix)
    standard_bytes = N * N * dtype_bytes

    # FlashAttention: saves m [N, 1] and ell [N, 1]
    # Plus Q, K, V, O which are [N, d] each -- these are needed by both
    flash_saved = N * 2 * 4  # m and ell in FP32, each [N, 1]

    # Shared between both methods: Q, K, V, O
    shared = 4 * N * d * dtype_bytes

    print(f"N = {N:,}, d = {d}")
    print(f"Standard extra (P):       {standard_bytes / 1e9:.2f} GB")
    print(f"FlashAttention extra (m, ell): {flash_saved / 1e6:.2f} MB")
    print(f"Shared (Q, K, V, O):      {shared / 1e6:.2f} MB")
    print(f"Memory reduction: {standard_bytes / max(flash_saved, 1):.0f}x")

    return standard_bytes, flash_saved

# At N = 128K
memory_saved_for_backward(131072, 128)
# Standard extra (P):       34.36 GB
# FlashAttention extra (m, ell): 1.05 MB
# Shared (Q, K, V, O):      134.22 MB
# Memory reduction: 32768x
ℹ️ Note

FlashAttention saves only two scalars per row: $m_i$ (the max score for row $i$) and $\ell_i$ (the sum of exponentials for row $i$). These are the normalization constants of the softmax. Together they occupy $2N \times 4$ bytes (FP32), compared to $N^2 \times 2$ bytes for the full attention matrix. At $N = 128\text{K}$, that is 1 MB vs 34 GB, a 32,768x reduction.

FlashAttention Backward Pass

3.1 The Recomputation Strategy

During the backward pass, FlashAttention needs $P$ (the attention matrix) to compute gradients. Instead of loading a stored copy, it recomputes $P$ block by block using:

  • $Q$, $K$ (to recompute $S = QK^T / \sqrt{d}$)
  • $m$, $\ell$ (to convert $S$ to $P = e^{S - m} / \ell$ without recomputing the full softmax)

This is the critical insight: with $m$ and $\ell$ saved from the forward pass, we can reconstruct any tile of $P$ from the corresponding tile of $S$:

$$P_{ij} = \frac{e^{S_{ij} - m_i}}{\ell_i}$$

No global reduction is needed: each element of $P$ depends only on its row's $m$ and $\ell$.
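
A quick numerical illustration of this identity, using a random score matrix: any column slice of $P$ can be rebuilt from the matching slice of $S$ plus the row statistics alone.

```python
import torch

torch.manual_seed(0)
S = torch.randn(64, 256)  # one block of rows, full score width

# Forward-pass statistics (what FlashAttention actually saves)
m = S.max(dim=-1, keepdim=True).values            # [64, 1] per-row max
ell = torch.exp(S - m).sum(dim=-1, keepdim=True)  # [64, 1] per-row sum

# Rebuild an arbitrary column tile without any global softmax call
j0, j1 = 128, 192
P_tile = torch.exp(S[:, j0:j1] - m) / ell

reference = torch.softmax(S, dim=-1)[:, j0:j1]
print(torch.allclose(P_tile, reference, atol=1e-6))
```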

3.2 Full Backward Implementation

def flash_attention_backward(Q, K, V, O, dO, m, ell,
                              block_size_r=64, block_size_c=64):
    """FlashAttention backward pass with recomputation.

    Q, K, V: [N, d] -- original inputs
    O: [N, d] -- forward output
    dO: [N, d] -- upstream gradient
    m: [N, 1] -- per-row max from forward
    ell: [N, 1] -- per-row sum from forward

    Returns: dQ [N, d], dK [N, d], dV [N, d]

    Key: we recompute P block-by-block instead of loading it.
    """
    N, d = Q.shape
    scale = d ** -0.5

    # Initialize gradient accumulators
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)

    # Precompute D = rowsum(dO * O) -- needed for softmax backward
    # D_i = sum_j(dO_ij * O_ij) for each row i
    D = (dO * O).sum(dim=-1, keepdim=True)  # [N, 1]

    # Process tiles
    for j in range(0, N, block_size_c):
        j_end = min(j + block_size_c, N)
        K_block = K[j:j_end]     # [Bc, d]
        V_block = V[j:j_end]     # [Bc, d]

        dK_block = torch.zeros_like(K_block)
        dV_block = torch.zeros_like(V_block)

        for i in range(0, N, block_size_r):
            i_end = min(i + block_size_r, N)
            Q_block = Q[i:i_end]     # [Br, d]
            dO_block = dO[i:i_end]   # [Br, d]
            m_block = m[i:i_end]     # [Br, 1]
            ell_block = ell[i:i_end] # [Br, 1]
            D_block = D[i:i_end]     # [Br, 1]

            # RECOMPUTE: scores for this tile
            S_tile = Q_block @ K_block.T * scale  # [Br, Bc]

            # RECOMPUTE: attention weights from saved statistics
            # P_tile = exp(S_tile - m) / ell
            P_tile = torch.exp(S_tile - m_block) / ell_block  # [Br, Bc]

            # Now compute gradients using recomputed P_tile

            # dV_block += P_tile^T @ dO_block
            dV_block += P_tile.T @ dO_block  # [Bc, d]

            # dP_tile = dO_block @ V_block^T
            dP_tile = dO_block @ V_block.T  # [Br, Bc]

            # dS_tile = P_tile * (dP_tile - D)
            # This is the softmax backward formula:
            # dS_ij = P_ij * (dP_ij - sum_k(dP_ik * P_ik))
            # where D_i = sum_k(dO_ik * O_ik) = sum_k(dP_ik * P_ik)
            dS_tile = P_tile * (dP_tile - D_block)  # [Br, Bc]

            # dQ_block += dS_tile @ K_block * scale
            dQ[i:i_end] += dS_tile @ K_block * scale  # [Br, d]

            # dK_block += dS_tile^T @ Q_block * scale
            dK_block += dS_tile.T @ Q_block * scale  # [Bc, d]

        dK[j:j_end] = dK_block
        dV[j:j_end] = dV_block

    return dQ, dK, dV

3.3 The D Vector

The $D$ vector deserves explanation. In the softmax backward, we need:

$$dS_{ij} = P_{ij} \left( dP_{ij} - \sum_k dP_{ik}\, P_{ik} \right)$$

The term $\sum_k dP_{ik}\, P_{ik}$ is per-row and requires summing over all $k$ columns. But we also know:

$$\sum_k dP_{ik}\, P_{ik} = \sum_j dO_{ij}\, O_{ij}$$

So $D_i = \sum_j dO_{ij}\, O_{ij}$, which can be computed from $dO$ and $O$ (both already available) without needing $P$.

def compute_D_vector(dO, O):
    """Compute D = rowsum(dO * O).

    This avoids needing P for the softmax backward's
    normalization term. D_i = sum_j(dO_ij * O_ij).

    dO: [N, d] upstream gradient
    O:  [N, d] forward pass output
    Returns: [N, 1]
    """
    return (dO * O).sum(dim=-1, keepdim=True)
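
The identity is easy to confirm numerically with random tensors; the shapes below are arbitrary demo values:

```python
import torch

torch.manual_seed(0)
N, d = 128, 64
P = torch.softmax(torch.randn(N, N), dim=-1)
V = torch.randn(N, d)
dO = torch.randn(N, d)
O = P @ V
dP = dO @ V.T

# Two expressions for the same per-row normalization term
lhs = (dP * P).sum(dim=-1, keepdim=True)   # needs the N x N matrices
rhs = (dO * O).sum(dim=-1, keepdim=True)   # needs only [N, d] tensors
print(torch.allclose(lhs, rhs, atol=1e-4))
```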

FLOPs Analysis: The 33% Overhead

4.1 Forward Pass FLOPs

The forward pass computes:

  1. $S = QK^T$: $2N^2 d$ FLOPs (matrix multiply)
  2. $P = \mathrm{softmax}(S)$: roughly $5N^2$ FLOPs (exp, sum, divide; negligible vs. matmul)
  3. $O = PV$: $2N^2 d$ FLOPs

Total forward: $4N^2 d$ FLOPs (same for standard and FlashAttention).

4.2 Standard Backward FLOPs

With $P$ stored:

  1. $dV = P^T dO$: $2N^2 d$ FLOPs
  2. $dP = dO\, V^T$: $2N^2 d$ FLOPs
  3. $dS = P \odot (dP - D)$: roughly $3N^2$ FLOPs (element-wise, negligible)
  4. $dQ = dS\, K$: $2N^2 d$ FLOPs
  5. $dK = dS^T Q$: $2N^2 d$ FLOPs

Total standard backward: $8N^2 d$ FLOPs.

4.3 FlashAttention Backward FLOPs

FlashAttention recomputes $S$ and $P$ in the backward pass:

  1. Recompute $S = QK^T$: $2N^2 d$ FLOPs (extra)
  2. Recompute $P$ from $S$, $m$, $\ell$: roughly $3N^2$ FLOPs (negligible)
  3. $dV = P^T dO$: $2N^2 d$ FLOPs
  4. $dP = dO\, V^T$: $2N^2 d$ FLOPs
  5. $dS = P \odot (dP - D)$: roughly $3N^2$ FLOPs (negligible)
  6. $dQ = dS\, K$: $2N^2 d$ FLOPs
  7. $dK = dS^T Q$: $2N^2 d$ FLOPs

Total FlashAttention backward: $10N^2 d$ FLOPs.

4.4 The Overhead Calculation

def flops_comparison(N, d):
    """Compare FLOPs for standard vs FlashAttention."""
    # Forward (same for both)
    fwd_flops = 4 * N**2 * d

    # Backward
    std_bwd_flops = 8 * N**2 * d
    flash_bwd_flops = 10 * N**2 * d  # +2N^2d from recomputation

    # Total
    std_total = fwd_flops + std_bwd_flops    # 12 N^2 d
    flash_total = fwd_flops + flash_bwd_flops  # 14 N^2 d

    overhead = (flash_total - std_total) / std_total

    print(f"Standard total:        {std_total / 1e12:.2f} TFLOP")
    print(f"FlashAttention total:  {flash_total / 1e12:.2f} TFLOP")
    print(f"Overhead:              {overhead:.1%}")
    print(f"Recompute cost:        {2 * N**2 * d / 1e12:.2f} TFLOP")

    return overhead

# Llama 70B, N=4096, d=128, per head per layer
flops_comparison(4096, 128)
# Standard total:        0.03 TFLOP
# FlashAttention total:  0.03 TFLOP
# Overhead:              16.7%
Performance

The recomputation overhead is $2N^2 d$ extra FLOPs (one additional $QK^T$ matmul in the backward). This is 25% of the standard backward cost ($8N^2 d$) or 16.7% of the total training cost ($12N^2 d$). The commonly cited 33% figure comes from $2N^2 d / (4N^2 d + 2N^2 d)$: one-third of a combined forward-plus-recompute cost. Regardless of how you count, the overhead is small compared to the 32,768x memory savings.
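
The three figures in the callout are just different denominators over the same $2N^2 d$ recompute cost; a hypothetical helper makes the bookkeeping explicit:

```python
def overhead_views(N, d):
    """The same 2*N^2*d recompute cost, expressed three ways."""
    recompute = 2 * N**2 * d
    vs_std_backward = recompute / (8 * N**2 * d)         # 25% of backward
    vs_total_train = recompute / (12 * N**2 * d)         # 16.7% of fwd+bwd
    vs_fwd_plus_recompute = recompute / (6 * N**2 * d)   # the "33%" figure
    return vs_std_backward, vs_total_train, vs_fwd_plus_recompute

print(overhead_views(4096, 128))
```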

Why Recomputation Is Faster Than It Sounds

5.1 Memory Bandwidth Is the Bottleneck

Modern GPUs are memory-bandwidth limited for attention. The theoretical compute time for a $4096 \times 4096$ score matmul ($S = QK^T$ with $d = 128$) on an H100 is:

$$\frac{2 \times 4096^2 \times 128}{990 \times 10^{12}} = 4.3\,\mu s$$

But moving the operands and result through HBM (loading $Q$ and $K$, and writing the $N \times N$ score matrix $S$ back) takes:

$$\frac{(2 \times 4096 \times 128 + 4096^2) \times 2}{3.35 \times 10^{12}} = 10.6\,\mu s$$

For the H100 (990 TFLOP/s, 3.35 TB/s), the roofline crossover is $990 / 3.35 \approx 295$ FLOP/byte. This matmul's arithmetic intensity is only about 120 FLOP/byte, so it is memory-bound. Adding more FLOPs (recomputation) does not significantly increase wall-clock time because the GPU has spare compute cycles while waiting for memory.

def roofline_analysis(N, d, gpu_flops_tflops, gpu_bw_tbs):
    """Roofline analysis for attention computation."""
    # FLOPs for QK^T
    flops = 2 * N * N * d

    # Bytes loaded/stored (Q, K input; S output)
    bytes_moved = (N * d + N * d + N * N) * 2  # FP16

    # Arithmetic intensity
    ai = flops / bytes_moved

    # Roofline crossover
    crossover = gpu_flops_tflops / gpu_bw_tbs

    # Actual time
    compute_time = flops / (gpu_flops_tflops * 1e12)
    memory_time = bytes_moved / (gpu_bw_tbs * 1e12)
    actual_time = max(compute_time, memory_time)

    bottleneck = "compute" if compute_time > memory_time else "memory"

    print(f"N={N}, d={d}")
    print(f"Arithmetic intensity: {ai:.1f} FLOP/byte")
    print(f"Roofline crossover:   {crossover:.1f} FLOP/byte")
    print(f"Bottleneck:           {bottleneck}")
    print(f"Compute time:         {compute_time * 1e6:.2f} us")
    print(f"Memory time:          {memory_time * 1e6:.2f} us")
    print(f"Actual time:          {actual_time * 1e6:.2f} us")

    return ai, bottleneck

# H100 SXM
roofline_analysis(N=4096, d=128, gpu_flops_tflops=990, gpu_bw_tbs=3.35)

5.2 FlashAttention’s SRAM Tiling

FlashAttention keeps blocks of $Q$, $K$, $V$ in SRAM (shared memory on the GPU). Each block is $B_r \times d$ or $B_c \times d$, totaling a couple hundred KB, which fits within the 228 KB of shared memory per SM on H100. The tile-level computation is compute-bound (high arithmetic intensity), so the recomputation adds FLOPs that execute on otherwise-idle compute units:

def sram_usage_analysis(block_r, block_c, d, bytes_per_elem=2):
    """Compute SRAM usage for FlashAttention tiles."""
    q_bytes = block_r * d * bytes_per_elem
    k_bytes = block_c * d * bytes_per_elem
    v_bytes = block_c * d * bytes_per_elem
    s_bytes = block_r * block_c * bytes_per_elem
    o_bytes = block_r * d * bytes_per_elem
    stats_bytes = block_r * 2 * 4  # FP32 for m, ell

    total = q_bytes + k_bytes + v_bytes + s_bytes + o_bytes + stats_bytes

    print(f"Block sizes: Br={block_r}, Bc={block_c}, d={d}")
    print(f"  Q block:    {q_bytes / 1024:.1f} KB")
    print(f"  K block:    {k_bytes / 1024:.1f} KB")
    print(f"  V block:    {v_bytes / 1024:.1f} KB")
    print(f"  S tile:     {s_bytes / 1024:.1f} KB")
    print(f"  O block:    {o_bytes / 1024:.1f} KB")
    print(f"  Statistics: {stats_bytes / 1024:.1f} KB")
    print(f"  Total:      {total / 1024:.1f} KB")

    return total

# Typical FlashAttention-2 block sizes for H100
sram_usage_analysis(block_r=128, block_c=128, d=128)
# Total: ~161 KB -- fits in H100's 228 KB shared memory per SM

Correctness Verification

def verify_flash_attention_correctness(N=1024, d=64):
    """Verify FlashAttention produces correct results."""
    torch.manual_seed(42)
    Q = torch.randn(N, d, device="cuda")
    K = torch.randn(N, d, device="cuda")
    V = torch.randn(N, d, device="cuda")

    # Standard attention
    O_std, P_std = standard_attention_forward(Q, K, V)

    # FlashAttention
    O_flash, m, ell = flash_attention_forward(Q, K, V)

    # Compare
    max_diff = (O_std - O_flash).abs().max().item()
    mean_diff = (O_std - O_flash).abs().mean().item()

    print(f"Max absolute difference:  {max_diff:.2e}")
    print(f"Mean absolute difference: {mean_diff:.2e}")

    # Should be close to machine epsilon for the dtype
    assert max_diff < 1e-3, f"Forward mismatch: {max_diff}"

    # Verify backward
    dO = torch.randn_like(O_std)

    dQ_flash, dK_flash, dV_flash = flash_attention_backward(
        Q.detach(), K.detach(), V.detach(),
        O_flash.detach(), dO, m, ell
    )

    print("Correctness verified.")
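
The sketch above checks only the forward output. The backward identities can be verified against autograd as well; the following standalone check (vectorized rather than tiled, so it exercises the recomputation math but not the tiling) rebuilds $P$ from the saved statistics and reproduces PyTorch's gradients:

```python
import torch

torch.manual_seed(0)
N, d = 256, 32
Q = torch.randn(N, d, requires_grad=True)
K = torch.randn(N, d, requires_grad=True)
V = torch.randn(N, d, requires_grad=True)
scale = d ** -0.5

# Reference: autograd through standard attention
S = Q @ K.T * scale
P = torch.softmax(S, dim=-1)
O = P @ V
dO = torch.randn_like(O)
O.backward(dO)

# Recomputed backward, using only Q, K, V, O, dO and the saved (m, ell)
with torch.no_grad():
    m = S.max(dim=-1, keepdim=True).values            # [N, 1] saved max
    ell = torch.exp(S - m).sum(dim=-1, keepdim=True)  # [N, 1] saved sum
    P_re = torch.exp(Q @ K.T * scale - m) / ell       # rebuilt, no softmax
    D = (dO * O).sum(dim=-1, keepdim=True)            # D = rowsum(dO * O)
    dP = dO @ V.T
    dS = P_re * (dP - D)
    dQ = dS @ K * scale
    dK = dS.T @ Q * scale
    dV = P_re.T @ dO

print(torch.allclose(dQ, Q.grad, atol=1e-4),
      torch.allclose(dK, K.grad, atol=1e-4),
      torch.allclose(dV, V.grad, atol=1e-4))
```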

Wall-Clock Performance

7.1 End-to-End Training Speed

Despite the 16.7% extra FLOPs, FlashAttention is faster in wall-clock time because:

  1. No HBM reads/writes of the $N \times N$ matrix
  2. Tiled computation stays in SRAM
  3. Better GPU utilization (compute-bound instead of memory-bound)
📊

Attention Forward+Backward Wall Time (single head, d=128, H100)

| Implementation | Time (ms) | Speedup |
|---|---|---|
| Standard (N=2048) | 0.42 | baseline |
| FlashAttention-2 (N=2048) | 0.18 | 2.3x |
| Standard (N=8192) | 5.8 | baseline |
| FlashAttention-2 (N=8192) | 1.2 | 4.8x |
| Standard (N=32768) | OOM | impossible |
| FlashAttention-2 (N=32768) | 12.4 | only option |

FlashAttention Speedup vs Sequence Length

| Sequence length | 512 | 1024 | 2048 | 4096 | 8192 | 16384 |
|---|---|---|---|---|---|---|
| FlashAttention-2 (H100) | 1.5x | 1.8x | 2.3x | 3.1x | 4.8x | 7.2x |
| FlashAttention-2 (A100) | 1.3x | 1.6x | 2.0x | 2.7x | 4.1x | 6.0x |

The speedup increases with sequence length because:

  • Standard attention's memory overhead grows as $O(N^2)$
  • FlashAttention's memory overhead grows as $O(N)$
  • At long sequences, standard attention becomes entirely memory-bound while FlashAttention remains compute-efficient

7.2 Memory Usage

Peak Memory Usage: Standard vs FlashAttention (GB per head)

| Sequence length | 2048 | 4096 | 8192 | 16384 | 32768 | 65536 | 131072 |
|---|---|---|---|---|---|---|---|
| Standard attention | 0.008 | 0.034 | 0.134 | 0.537 | 2.147 | 8.59 | 34.36 |
| FlashAttention-2 | 0.002 | 0.003 | 0.005 | 0.008 | 0.015 | 0.028 | 0.054 |

At $N = 131072$ (128K), standard attention needs 34.36 GB per head. FlashAttention needs 54 MB. That is a 636x reduction.

FlashAttention-2 and FlashAttention-3 Improvements

8.1 FlashAttention-2: Better Parallelism

FlashAttention-2 (Dao, 2023) improved on v1 by:

  1. Parallelizing over the sequence length dimension instead of batch/heads only. This keeps all SMs busy even with small batch sizes.

  2. Reducing non-matmul FLOPs by restructuring the online softmax to minimize register pressure.

  3. Better work partitioning between warps within a thread block.

def flash_attention_2_parallelism():
    """FlashAttention-2 forward-pass parallelism strategy (summary)."""
    return {
        "outer_loop": "over Q blocks (rows) -- one thread block per row block",
        "inner_loop": "over K, V blocks (columns), accumulated in registers/SRAM",
        "advantage": (
            "Each thread block owns one row block of O, so the output and "
            "softmax statistics are written to HBM exactly once. "
            "v1's outer loop over K, V blocks forced repeated reads and "
            "writes of O, m, ell for every column block."
        ),
        "occupancy": (
            "With N=4096, Br=128: 32 row blocks. "
            "With 32 batch items and 32 heads: 32*32*32 = 32768 thread blocks. "
            "H100 has 132 SMs -- excellent occupancy."
        ),
    }

8.2 FlashAttention-3: Hopper-Specific Optimizations

FlashAttention-3 (Shah et al., 2024) exploits H100-specific hardware features:

  1. Asynchronous WGMMA instructions: overlap GEMM computation with data loading
  2. TMA (Tensor Memory Accelerator): hardware-accelerated HBM-to-SRAM transfers
  3. FP8 support: compute attention in FP8 E4M3 on tensor cores for 2x throughput
  4. Warp specialization: different warps within a thread block handle different tasks (producer/consumer pattern)
📊

FlashAttention Versions (H100, N=8192, d=128, BF16)

| Implementation | Throughput (TFLOP/s) | Speedup vs Standard |
|---|---|---|
| Standard PyTorch (cuDNN) | 95 | baseline |
| FlashAttention-1 | 180 | +89% |
| FlashAttention-2 | 320 | +237% |
| FlashAttention-3 (BF16) | 510 | +437% |
| FlashAttention-3 (FP8) | 740 | +679% |

Integration with Activation Checkpointing

9.1 The Interaction

FlashAttention is itself a form of activation checkpointing: it trades compute (recomputing $P$) for memory (not storing $P$). When combined with standard activation checkpointing (which recomputes entire layers during backward), the interaction is worth spelling out:

def checkpoint_with_flash_attention(layer, x):
    """Activation checkpointing + FlashAttention.

    Standard activation checkpointing: don't save intermediate
    activations during forward; recompute them during backward.

    FlashAttention: don't save the attention matrix; recompute
    it during backward using saved statistics.

    Combined: the attention matrix is recomputed twice --
    once by activation checkpointing (re-running forward),
    and once by FlashAttention (within that re-run's backward).

    But FlashAttention never materializes it, so the memory
    savings stack.
    """
    return torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)

def memory_with_checkpointing(N, d, n_layers, n_heads,
                                use_flash, use_checkpoint):
    """Memory analysis with combinations of optimizations."""
    hidden_size = n_heads * d
    attn_activations = N * hidden_size * 2
    ffn_activations = N * hidden_size * 4
    attn_matrix = N * N * n_heads * 2 if not use_flash else N * n_heads * 8

    per_layer = attn_activations + ffn_activations + attn_matrix

    if use_checkpoint:
        n_saved_layers = 1
    else:
        n_saved_layers = n_layers

    total_activations = per_layer * n_saved_layers

    print(f"Flash={use_flash}, Checkpoint={use_checkpoint}")
    print(f"  Per-layer activations: {per_layer / 1e9:.2f} GB")
    print(f"  Total activations:     {total_activations / 1e9:.2f} GB")

    return total_activations

9.2 Combined Memory Savings

📊

Memory Savings: Flash + Checkpointing (Llama 7B, N=4096)

| Configuration | Peak Activation Memory | Reduction |
|---|---|---|
| Standard attention, no checkpoint | 68 GB | baseline |
| FlashAttention, no checkpoint | 11 GB | -84% |
| Standard attention + checkpoint | 6.2 GB | -91% |
| FlashAttention + checkpoint | 2.1 GB | -97% |

The combination of FlashAttention and activation checkpointing reduces activation memory by 97%. The compute overhead is roughly 35-40% total (from both recomputation mechanisms), but the memory savings enable training larger models, longer contexts, or larger batches — which more than compensates.
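
A back-of-envelope for that combined overhead, under the usual coarse assumptions (backward costs about 2x the forward; checkpointing adds one extra forward re-run; FlashAttention's recompute adds $2N^2 d$ on top of $12N^2 d$ of attention FLOPs); these are estimates, not measurements:

```python
def combined_overhead(fwd=1.0):
    """Rough compute overhead of checkpointing + FlashAttention recompute.

    Assumes backward ~= 2x forward and attention recompute ~= 2/12 of
    attention FLOPs. Coarse estimates only.
    """
    baseline = fwd + 2 * fwd             # forward + backward = 3F
    ckpt = baseline + fwd                # one extra forward re-run = 4F
    ckpt_overhead = ckpt / baseline - 1  # ~33% from checkpointing alone
    flash_overhead = 2 / 12              # ~16.7%, attention FLOPs only
    # Attention is only part of total model FLOPs, so the combined figure
    # lands below the naive sum -- in the 35-40% range quoted above.
    return ckpt_overhead, flash_overhead

print(combined_overhead())
```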

Summary: The Tradeoff

FlashAttention makes a deliberate engineering trade:

| What you give up | What you get |
|---|---|
| $2N^2 d$ extra FLOPs per backward (recomputing $QK^T$) | $N^2$ fewer elements stored ($P$ matrix eliminated) |
| 16.7% more total FLOPs | Over 99% less attention memory at N=128K |
| Slightly more complex implementation | Training at sequence lengths that would otherwise be impossible |

The overhead is invisible in practice because:

  1. The GPU has spare compute capacity when attention is memory-bound
  2. Eliminating HBM reads/writes of the $N \times N$ matrix more than compensates for the extra FLOPs
  3. The memory savings enable larger batch sizes, which improve GPU utilization

This is why FlashAttention is used universally. There is no scenario in modern LLM training where storing the full attention matrix is preferable.

References

  1. Dao, T. et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022.
  2. Dao, T. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” ICLR 2024.
  3. Shah, J. et al. “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.” arXiv 2024.
  4. Milakov, M. and Gimelshein, N. “Online normalizer calculation for softmax.” arXiv 2018.
  5. Rabe, M. and Staats, C. “Self-attention Does Not Need O(n^2) Memory.” arXiv 2021.