Standard attention computes the $N \times N$ attention matrix $P = \mathrm{softmax}(QK^\top / \sqrt{d})$ and stores it for the backward pass. At a 128K context ($N = 131{,}072$) with FP16, storing $P$ requires $2N^2 \approx 34$ GB per head per layer. A 32-head, 80-layer model at 128K context would need roughly 88 TB for attention matrices alone. This is not feasible.
FlashAttention solves this by not storing $P$ at all. Instead, it stores only the softmax normalization statistics, the per-row maximum $m$ and sum $\ell$, totaling $O(N)$ memory. During the backward pass, it recomputes $P$ block by block from $Q$, $K$, and these saved statistics. The recomputation adds roughly 25% extra FLOPs to the backward pass compared to a hypothetical backward pass with $P$ pre-stored (the often-quoted 33% figure uses a different denominator; the FLOPs analysis below works this out), but eliminates essentially all of the $O(N^2)$ attention memory.
This post derives the forward pass tiling algorithm, explains exactly what statistics are saved, derives the backward pass recomputation strategy, computes the FLOPs overhead, and provides a complete implementation sketch.
1. Standard Attention: The Memory Problem
1.1 Forward Pass
The standard attention forward pass:

$$S = \frac{QK^\top}{\sqrt{d}}, \qquad P = \mathrm{softmax}(S), \qquad O = PV$$

where $Q, K, V \in \mathbb{R}^{N \times d}$ are the query, key, and value matrices for one attention head.
```python
import torch
import torch.nn.functional as F

def standard_attention_forward(Q, K, V):
    """Standard attention: stores the full N x N attention matrix.

    Q, K, V: [N, d]
    Returns: O [N, d], P [N, N] (stored for backward)
    """
    d = Q.shape[-1]
    scale = d ** -0.5
    # Scores: [N, N] -- this is the memory bottleneck
    S = Q @ K.T * scale
    # Softmax: [N, N] -- same size
    P = F.softmax(S, dim=-1)
    # Output: [N, d]
    O = P @ V
    # Must store P for the backward pass
    return O, P  # P is N x N -- massive for large N
```
1.2 Backward Pass (Standard)
Given upstream gradient $dO = \partial L / \partial O$:
```python
def standard_attention_backward(Q, K, V, P, O, dO):
    """Standard backward: requires P (the stored N x N matrix).

    All inputs: [N, d] except P: [N, N]
    """
    d = Q.shape[-1]
    scale = d ** -0.5
    # dV = P^T @ dO -- needs P
    dV = P.T @ dO  # [N, d]
    # dP = dO @ V^T -- needs V (already stored as input)
    dP = dO @ V.T  # [N, N]
    # dS: softmax backward
    # dS_ij = P_ij * (dP_ij - sum_k(dP_ik * P_ik))
    row_sum = (dP * P).sum(dim=-1, keepdim=True)  # [N, 1]
    dS = P * (dP - row_sum)  # [N, N] -- needs P again
    # dQ = dS @ K * scale
    dQ = dS @ K * scale  # [N, d]
    # dK = dS^T @ Q * scale
    dK = dS.T @ Q * scale  # [N, d]
    return dQ, dK, dV
```
The backward pass reads $P$ (the $N \times N$ matrix) twice: once for $dV = P^\top dO$ and once for the softmax gradient $dS$. This is why standard attention must store $P$ during the forward pass.
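As a sanity check, the softmax-backward formula above can be verified against autograd on a small problem. This is a standalone sketch with arbitrary sizes:

```python
import torch
import torch.nn.functional as F

# Small problem so the autograd comparison is cheap
torch.manual_seed(0)
N, d = 16, 8
Q = torch.randn(N, d, requires_grad=True)
K = torch.randn(N, d, requires_grad=True)
V = torch.randn(N, d, requires_grad=True)
scale = d ** -0.5

# Forward with autograd tracking
P = F.softmax(Q @ K.T * scale, dim=-1)
O = P @ V
dO = torch.randn(N, d)
O.backward(dO)

# Manual backward using the formulas above
with torch.no_grad():
    dV = P.T @ dO
    dP = dO @ V.T
    dS = P * (dP - (dP * P).sum(dim=-1, keepdim=True))
    dQ = dS @ K * scale
    dK = dS.T @ Q * scale

assert torch.allclose(dQ, Q.grad, atol=1e-5)
assert torch.allclose(dK, K.grad, atol=1e-5)
assert torch.allclose(dV, V.grad, atol=1e-5)
```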
1.3 Memory Accounting
```python
def attention_memory_standard(N, d, n_heads, n_layers, dtype_bytes=2):
    """Memory for storing attention matrices (standard attention)."""
    # Per-head, per-layer: N x N matrix
    per_head_layer = N * N * dtype_bytes
    # Total across all heads and layers
    total = per_head_layer * n_heads * n_layers
    print(f"N={N:,}, d={d}, heads={n_heads}, layers={n_layers}")
    print(f"Per head per layer: {per_head_layer / 1e9:.2f} GB")
    print(f"Total attention matrices: {total / 1e9:,.1f} GB")
    return total

# Llama 70B at context length 128K
attention_memory_standard(N=131072, d=128, n_heads=64, n_layers=80)
# Per head per layer: 34.36 GB
# Total attention matrices: 175,921.9 GB -- impossibly large
```
At 128K context length, standard attention would need 176 TB just for the attention matrices of one forward pass of Llama 70B. This is why FlashAttention is not optional — it is a prerequisite for training with contexts longer than about 2K tokens.
2. FlashAttention Forward Pass: Tiled Computation
2.1 Online Softmax
FlashAttention computes softmax without materializing the full matrix. The key algorithm is online softmax (Milakov and Gimelshein, 2018): compute the softmax in a single pass, maintaining running statistics.
For a row of scores $s \in \mathbb{R}^N$:

$$\mathrm{softmax}(s)_j = \frac{e^{s_j - m}}{\ell}$$

where $m = \max_j s_j$ and $\ell = \sum_j e^{s_j - m}$. Online softmax computes $m$ and $\ell$ incrementally:
```python
def online_softmax_demo(scores):
    """Online softmax: compute max and sum in a single pass.

    Process scores in blocks. After each block, update the
    running max and running sum.
    """
    block_size = 64
    N = len(scores)
    m = float("-inf")  # Running max
    ell = 0.0          # Running sum of exp(s - m)
    for start in range(0, N, block_size):
        block = scores[start:start + block_size]
        # New max considering this block
        m_new = max(m, block.max().item())
        # Correction factor for the previous sum
        correction = torch.exp(torch.tensor(m - m_new))
        # Update running sum
        ell = ell * correction + torch.exp(block - m_new).sum().item()
        # Update max
        m = m_new
    # Final softmax values
    softmax_values = torch.exp(scores - m) / ell
    return softmax_values, m, ell
```
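A quick numerical check that the streaming statistics reproduce the ordinary one-shot softmax. This standalone sketch inlines the same update rule with arbitrary sizes:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(1000)

# Streaming pass over blocks, tracking only (m, ell)
m, ell = float("-inf"), 0.0
for start in range(0, len(scores), 64):
    block = scores[start:start + 64]
    m_new = max(m, block.max().item())
    # Rescale the old sum to the new max, then add the block's contribution
    ell = ell * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
    m = m_new

online = torch.exp(scores - m) / ell
reference = F.softmax(scores, dim=-1)
assert torch.allclose(online, reference, atol=1e-6)
```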
2.2 Tiled Forward Pass
FlashAttention tiles the computation into blocks of size $B_r \times B_c$ (rows of $Q$ times columns of $K$). It processes one block at a time, accumulating the output and softmax statistics:
```python
def flash_attention_forward(Q, K, V, block_size_r=64, block_size_c=64):
    """FlashAttention forward pass (simplified, single-head).

    Q, K, V: [N, d]
    Returns: O [N, d], (m, ell) for backward -- NOT the N x N matrix

    Key insight: we never materialize the full N x N attention matrix.
    Instead, we process blocks and accumulate the output online.
    """
    N, d = Q.shape
    # Output accumulator and softmax statistics
    O = torch.zeros(N, d, device=Q.device, dtype=Q.dtype)
    m = torch.full((N, 1), float("-inf"), device=Q.device)  # Running max per row
    ell = torch.zeros(N, 1, device=Q.device)  # Running sum per row
    scale = d ** -0.5
    # Iterate over K, V blocks (columns of the attention matrix)
    for j in range(0, N, block_size_c):
        j_end = min(j + block_size_c, N)
        K_block = K[j:j_end]  # [Bc, d]
        V_block = V[j:j_end]  # [Bc, d]
        # Iterate over Q blocks (rows of the attention matrix)
        for i in range(0, N, block_size_r):
            i_end = min(i + block_size_r, N)
            Q_block = Q[i:i_end]  # [Br, d]
            # Compute scores for this tile: [Br, Bc]
            S_tile = Q_block @ K_block.T * scale
            # Online softmax update
            m_old = m[i:i_end]      # [Br, 1]
            ell_old = ell[i:i_end]  # [Br, 1]
            O_old = O[i:i_end]      # [Br, d]
            # New max: max of old max and max of this tile
            m_tile = S_tile.max(dim=-1, keepdim=True).values  # [Br, 1]
            m_new = torch.maximum(m_old, m_tile)
            # Correction factors
            exp_old = torch.exp(m_old - m_new)   # Rescales the old accumulator
            exp_new = torch.exp(S_tile - m_new)  # New tile's contribution
            # Update running sum
            ell_new = exp_old * ell_old + exp_new.sum(dim=-1, keepdim=True)
            # Update output: rescale old output and add new contribution
            O_new = (exp_old * ell_old * O_old + exp_new @ V_block) / ell_new
            # Store updated values
            O[i:i_end] = O_new
            m[i:i_end] = m_new
            ell[i:i_end] = ell_new
    # Save only these for backward (NOT the N x N matrix):
    # - Q, K, V (already stored as inputs)
    # - m: [N, 1] per-row max
    # - ell: [N, 1] per-row sum
    # - O: [N, d] output (needed for softmax gradient)
    return O, m, ell
```
2.3 What Gets Saved
```python
def memory_saved_for_backward(N, d, dtype_bytes=2):
    """Compare what standard vs FlashAttention saves for backward."""
    # Standard: saves P (the N x N attention matrix)
    standard_bytes = N * N * dtype_bytes
    # FlashAttention: saves m [N, 1] and ell [N, 1], both FP32
    flash_saved = N * 2 * 4
    # Shared between both methods: Q, K, V, O
    shared = 4 * N * d * dtype_bytes
    print(f"N = {N:,}, d = {d}")
    print(f"Standard extra (P): {standard_bytes / 1e9:.2f} GB")
    print(f"FlashAttention extra (m, ell): {flash_saved / 1e6:.2f} MB")
    print(f"Shared (Q, K, V, O): {shared / 1e6:.2f} MB")
    print(f"Memory reduction: {standard_bytes / max(flash_saved, 1):.0f}x")
    return standard_bytes, flash_saved

# At N = 128K
memory_saved_for_backward(131072, 128)
# Standard extra (P): 34.36 GB
# FlashAttention extra (m, ell): 1.05 MB
# Shared (Q, K, V, O): 134.22 MB
# Memory reduction: 32768x
```
FlashAttention saves only two vectors: $m_i$ (the max score for row $i$) and $\ell_i$ (the sum of exponentials for row $i$). These are the normalization constants of the softmax. Together they occupy $8N$ bytes (two FP32 values per row), compared to $2N^2$ bytes for the full FP16 attention matrix. At $N = 131{,}072$, that is 1 MB vs 34 GB: a 32,768x reduction.
3. FlashAttention Backward Pass
3.1 The Recomputation Strategy
During the backward pass, FlashAttention needs $P$ (the attention matrix) to compute gradients. Instead of loading a stored copy, it recomputes $P$ block by block using:
- $Q$, $K$ (to recompute $S = QK^\top / \sqrt{d}$)
- $m$, $\ell$ (to convert $S$ to $P$ without recomputing the full softmax)

This is the critical insight: with $m$ and $\ell$ saved from the forward pass, we can reconstruct any tile of $P$ from the corresponding tile of $S$:

$$P_{ij} = \frac{e^{S_{ij} - m_i}}{\ell_i}$$

No global reduction is needed: each element of $P$ depends only on its row's $m_i$ and $\ell_i$.
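The reconstruction identity can be checked numerically: compute the full softmax once, then rebuild an arbitrary tile from $S$, $m$, and $\ell$ alone. A standalone sketch with arbitrary sizes and tile coordinates:

```python
import torch

torch.manual_seed(0)
N, d = 256, 32
Q, K = torch.randn(N, d), torch.randn(N, d)
S = Q @ K.T * d ** -0.5

# Per-row statistics, as saved by the forward pass
m = S.max(dim=-1, keepdim=True).values
ell = torch.exp(S - m).sum(dim=-1, keepdim=True)

# Reference: the full attention matrix
P_full = torch.softmax(S, dim=-1)

# Reconstruct one tile from its scores and the row statistics only
i, j, B = 64, 128, 64
P_tile = torch.exp(S[i:i+B, j:j+B] - m[i:i+B]) / ell[i:i+B]
assert torch.allclose(P_tile, P_full[i:i+B, j:j+B], atol=1e-6)
```

No information about other columns is needed, which is exactly what makes the tiled backward pass possible.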
3.2 Full Backward Implementation
```python
def flash_attention_backward(Q, K, V, O, dO, m, ell,
                             block_size_r=64, block_size_c=64):
    """FlashAttention backward pass with recomputation.

    Q, K, V: [N, d] -- original inputs
    O:   [N, d] -- forward output
    dO:  [N, d] -- upstream gradient
    m:   [N, 1] -- per-row max from forward
    ell: [N, 1] -- per-row sum from forward
    Returns: dQ [N, d], dK [N, d], dV [N, d]

    Key: we recompute P block-by-block instead of loading it.
    """
    N, d = Q.shape
    scale = d ** -0.5
    # Initialize gradient accumulators
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)
    # Precompute D = rowsum(dO * O) -- needed for softmax backward
    # D_i = sum_j(dO_ij * O_ij) for each row i
    D = (dO * O).sum(dim=-1, keepdim=True)  # [N, 1]
    # Process tiles
    for j in range(0, N, block_size_c):
        j_end = min(j + block_size_c, N)
        K_block = K[j:j_end]  # [Bc, d]
        V_block = V[j:j_end]  # [Bc, d]
        dK_block = torch.zeros_like(K_block)
        dV_block = torch.zeros_like(V_block)
        for i in range(0, N, block_size_r):
            i_end = min(i + block_size_r, N)
            Q_block = Q[i:i_end]      # [Br, d]
            dO_block = dO[i:i_end]    # [Br, d]
            m_block = m[i:i_end]      # [Br, 1]
            ell_block = ell[i:i_end]  # [Br, 1]
            D_block = D[i:i_end]      # [Br, 1]
            # RECOMPUTE: scores for this tile
            S_tile = Q_block @ K_block.T * scale  # [Br, Bc]
            # RECOMPUTE: attention weights from saved statistics
            # P_tile = exp(S_tile - m) / ell
            P_tile = torch.exp(S_tile - m_block) / ell_block  # [Br, Bc]
            # Now compute gradients using the recomputed P_tile
            # dV_block += P_tile^T @ dO_block
            dV_block += P_tile.T @ dO_block  # [Bc, d]
            # dP_tile = dO_block @ V_block^T
            dP_tile = dO_block @ V_block.T  # [Br, Bc]
            # Softmax backward: dS_ij = P_ij * (dP_ij - sum_k(dP_ik * P_ik))
            # where D_i = sum_k(dO_ik * O_ik) = sum_k(dP_ik * P_ik)
            dS_tile = P_tile * (dP_tile - D_block)  # [Br, Bc]
            # Accumulate dQ and dK with the same scale as the forward scores
            dQ[i:i_end] += dS_tile @ K_block * scale  # [Br, d]
            dK_block += dS_tile.T @ Q_block * scale   # [Bc, d]
        dK[j:j_end] = dK_block
        dV[j:j_end] = dV_block
    return dQ, dK, dV
```
3.3 The D Vector
The $D$ vector deserves explanation. In the softmax backward, we need:

$$dS_{ij} = P_{ij}\left(dP_{ij} - \sum_k dP_{ik}\,P_{ik}\right)$$

The term $\sum_k dP_{ik} P_{ik}$ is per-row and appears to require summing over all columns of $P$. But since $dP = dO\,V^\top$ and $O = PV$:

$$\sum_k dP_{ik}\,P_{ik} = \sum_k \Big(\sum_l dO_{il}\,V_{kl}\Big) P_{ik} = \sum_l dO_{il} \sum_k P_{ik}\,V_{kl} = \sum_l dO_{il}\,O_{il}$$

So $D_i = \sum_l dO_{il}\,O_{il}$, which can be computed from $dO$ and $O$ (both already available) without needing $P$.
```python
def compute_D_vector(dO, O):
    """Compute D = rowsum(dO * O).

    This avoids needing P for the softmax backward's
    normalization term. D_i = sum_j(dO_ij * O_ij).

    dO: [N, d] upstream gradient
    O:  [N, d] forward pass output
    Returns: [N, 1]
    """
    return (dO * O).sum(dim=-1, keepdim=True)
```
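The identity behind `compute_D_vector` is easy to verify directly: `rowsum(dP * P)` and `rowsum(dO * O)` agree for any valid attention matrix. A standalone sketch with arbitrary small sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, d = 64, 16
# Any row-stochastic P works; softmax of random scores gives one
P = F.softmax(torch.randn(N, N), dim=-1)
V = torch.randn(N, d)
dO = torch.randn(N, d)

O = P @ V        # forward output
dP = dO @ V.T    # gradient w.r.t. P

# The identity: rowsum(dP * P) == rowsum(dO * O)
lhs = (dP * P).sum(dim=-1)
rhs = (dO * O).sum(dim=-1)
assert torch.allclose(lhs, rhs, atol=1e-5)
```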
4. FLOPs Analysis: The 33% Overhead
4.1 Forward Pass FLOPs
The forward pass computes:
- $S = QK^\top$: $2N^2d$ FLOPs (matrix multiply)
- $P = \mathrm{softmax}(S)$: roughly $3N^2$ FLOPs (exp, sum, divide; negligible vs. matmul)
- $O = PV$: $2N^2d$ FLOPs

Total forward: $4N^2d$ FLOPs (same for standard and FlashAttention).
4.2 Standard Backward FLOPs
With $P$ stored:
- $dV = P^\top dO$: $2N^2d$ FLOPs
- $dP = dO\,V^\top$: $2N^2d$ FLOPs
- $dS = P \odot (dP - D)$: roughly $3N^2$ FLOPs (element-wise, negligible)
- $dQ = dS\,K / \sqrt{d}$: $2N^2d$ FLOPs
- $dK = dS^\top Q / \sqrt{d}$: $2N^2d$ FLOPs

Total standard backward: $8N^2d$ FLOPs.
4.3 FlashAttention Backward FLOPs
FlashAttention recomputes $S$ and $P$ in the backward pass:
- Recompute $S = QK^\top$: $2N^2d$ FLOPs (extra)
- Recompute $P$ from $S$, $m$, $\ell$: roughly $2N^2$ FLOPs (negligible)
- $dV = P^\top dO$: $2N^2d$ FLOPs
- $dP = dO\,V^\top$: $2N^2d$ FLOPs
- $dS$: roughly $3N^2$ FLOPs (negligible)
- $dQ = dS\,K / \sqrt{d}$: $2N^2d$ FLOPs
- $dK = dS^\top Q / \sqrt{d}$: $2N^2d$ FLOPs

Total FlashAttention backward: $10N^2d$ FLOPs.
4.4 The Overhead Calculation
```python
def flops_comparison(N, d):
    """Compare FLOPs for standard vs FlashAttention."""
    # Forward (same for both)
    fwd_flops = 4 * N**2 * d
    # Backward
    std_bwd_flops = 8 * N**2 * d
    flash_bwd_flops = 10 * N**2 * d  # +2N^2d from recomputation
    # Total
    std_total = fwd_flops + std_bwd_flops      # 12 N^2 d
    flash_total = fwd_flops + flash_bwd_flops  # 14 N^2 d
    overhead = (flash_total - std_total) / std_total
    print(f"Standard total: {std_total / 1e12:.2f} TFLOP")
    print(f"FlashAttention total: {flash_total / 1e12:.2f} TFLOP")
    print(f"Overhead: {overhead:.1%}")
    print(f"Recompute cost: {2 * N**2 * d / 1e12:.2f} TFLOP")
    return overhead

# Llama 70B, N=4096, d=128, per head per layer
flops_comparison(4096, 128)
# Standard total: 0.03 TFLOP
# FlashAttention total: 0.03 TFLOP
# Overhead: 16.7%
```
The recomputation overhead is $2N^2d$ extra FLOPs (one additional matmul in the backward). This is 25% of the standard backward cost ($2N^2d / 8N^2d$) or 16.7% of the total training cost ($2N^2d / 12N^2d$). The commonly cited 33% figure comes from $2N^2d / 6N^2d$: one-third of a combined forward-plus-recompute cost. Regardless of how you count, the overhead is small compared to the 32,768x memory savings.
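The three percentages are the same $2N^2d$ numerator over different denominators; restated as plain arithmetic in units of $N^2 d$:

```python
# FLOP accounting in units of N^2 * d
fwd = 4        # QK^T + PV
std_bwd = 8    # dV, dP, dQ, dK
recompute = 2  # extra QK^T in the FlashAttention backward

assert recompute / std_bwd == 0.25                      # vs backward alone
assert round(recompute / (fwd + std_bwd), 3) == 0.167   # vs total training FLOPs
assert abs(recompute / (fwd + recompute) - 1 / 3) < 1e-9  # the "33%" accounting
```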
5. Why Recomputation Is Faster Than It Sounds
5.1 Memory Bandwidth Is the Bottleneck
Modern GPUs are memory-bandwidth limited for attention. The theoretical compute time for the $QK^\top$ matmul at $N = 4096$, $d = 128$ on an H100 is:

$$t_{\text{compute}} = \frac{2N^2d}{990 \text{ TFLOP/s}} \approx 4.3\ \mu\text{s}$$

But loading $Q$, $K$ and storing $S$ in FP16 through HBM takes:

$$t_{\text{memory}} = \frac{(2Nd + N^2) \cdot 2 \text{ bytes}}{3.35 \text{ TB/s}} \approx 10.6\ \mu\text{s}$$

For the H100 (990 TFLOP/s, 3.35 TB/s), the roofline crossover is about 296 FLOP/byte. At roughly 120 FLOP/byte of arithmetic intensity, this matmul is memory-bound. Adding more FLOPs (recomputation) does not significantly increase wall-clock time because the GPU has spare compute cycles while waiting for memory.
```python
def roofline_analysis(N, d, gpu_flops_tflops, gpu_bw_tbs):
    """Roofline analysis for the attention score computation."""
    # FLOPs for QK^T
    flops = 2 * N * N * d
    # Bytes loaded/stored (Q, K input; S output), FP16
    bytes_moved = (N * d + N * d + N * N) * 2
    # Arithmetic intensity
    ai = flops / bytes_moved
    # Roofline crossover
    crossover = gpu_flops_tflops / gpu_bw_tbs
    # Actual time
    compute_time = flops / (gpu_flops_tflops * 1e12)
    memory_time = bytes_moved / (gpu_bw_tbs * 1e12)
    actual_time = max(compute_time, memory_time)
    bottleneck = "compute" if compute_time > memory_time else "memory"
    print(f"N={N}, d={d}")
    print(f"Arithmetic intensity: {ai:.1f} FLOP/byte")
    print(f"Roofline crossover: {crossover:.1f} FLOP/byte")
    print(f"Bottleneck: {bottleneck}")
    print(f"Compute time: {compute_time * 1e6:.2f} us")
    print(f"Memory time: {memory_time * 1e6:.2f} us")
    print(f"Actual time: {actual_time * 1e6:.2f} us")
    return ai, bottleneck

# H100 SXM
roofline_analysis(N=4096, d=128, gpu_flops_tflops=990, gpu_bw_tbs=3.35)
```
5.2 FlashAttention’s SRAM Tiling
FlashAttention keeps blocks of $Q$, $K$, $V$ in SRAM (shared memory on the GPU). Each block is $B_r \times d$ or $B_c \times d$, totaling a few hundred KB, well within the 228 KB of shared memory per SM on H100. The tile-level computation is compute-bound (high arithmetic intensity), so the recomputation adds FLOPs that execute on otherwise-idle compute units:
```python
def sram_usage_analysis(block_r, block_c, d, bytes_per_elem=2):
    """Compute SRAM usage for FlashAttention tiles."""
    q_bytes = block_r * d * bytes_per_elem
    k_bytes = block_c * d * bytes_per_elem
    v_bytes = block_c * d * bytes_per_elem
    s_bytes = block_r * block_c * bytes_per_elem
    o_bytes = block_r * d * bytes_per_elem
    stats_bytes = block_r * 2 * 4  # FP32 for m, ell
    total = q_bytes + k_bytes + v_bytes + s_bytes + o_bytes + stats_bytes
    print(f"Block sizes: Br={block_r}, Bc={block_c}, d={d}")
    print(f"  Q block: {q_bytes / 1024:.1f} KB")
    print(f"  K block: {k_bytes / 1024:.1f} KB")
    print(f"  V block: {v_bytes / 1024:.1f} KB")
    print(f"  S tile: {s_bytes / 1024:.1f} KB")
    print(f"  O block: {o_bytes / 1024:.1f} KB")
    print(f"  Statistics: {stats_bytes / 1024:.1f} KB")
    print(f"  Total: {total / 1024:.1f} KB")
    return total

# Typical FlashAttention-2 block sizes for H100
sram_usage_analysis(block_r=128, block_c=128, d=128)
# Total: ~161 KB -- fits in H100's 228 KB shared memory per SM
```
6. Correctness Verification
```python
def verify_flash_attention_correctness(N=1024, d=64):
    """Verify FlashAttention matches standard attention."""
    torch.manual_seed(42)
    Q = torch.randn(N, d, device="cuda")
    K = torch.randn(N, d, device="cuda")
    V = torch.randn(N, d, device="cuda")
    # Standard attention
    O_std, P_std = standard_attention_forward(Q, K, V)
    # FlashAttention
    O_flash, m, ell = flash_attention_forward(Q, K, V)
    # Compare forward outputs
    max_diff = (O_std - O_flash).abs().max().item()
    mean_diff = (O_std - O_flash).abs().mean().item()
    print(f"Max absolute difference: {max_diff:.2e}")
    print(f"Mean absolute difference: {mean_diff:.2e}")
    # Should be close to machine epsilon for the dtype
    assert max_diff < 1e-3, f"Forward mismatch: {max_diff}"
    # Verify backward against the standard implementation
    dO = torch.randn_like(O_std)
    dQ_f, dK_f, dV_f = flash_attention_backward(Q, K, V, O_flash, dO, m, ell)
    dQ_s, dK_s, dV_s = standard_attention_backward(Q, K, V, P_std, O_std, dO)
    for name, a, b in [("dQ", dQ_f, dQ_s), ("dK", dK_f, dK_s), ("dV", dV_f, dV_s)]:
        diff = (a - b).abs().max().item()
        assert diff < 1e-3, f"{name} mismatch: {diff}"
    print("Correctness verified.")
```
7. Wall-Clock Performance
7.1 End-to-End Training Speed
Despite the 16.7% extra FLOPs, FlashAttention is faster in wall-clock time because:
- No HBM reads/writes of the $N \times N$ attention matrix
- Tiled computation stays in SRAM
- Better GPU utilization (compute-bound instead of memory-bound)
Attention Forward+Backward Wall Time (single head, d=128, H100)
| Implementation | Time (ms) | Speedup |
|---|---|---|
| Standard (N=2048) | 0.42 ms | baseline |
| FlashAttention-2 (N=2048) | 0.18 ms | 2.3x |
| Standard (N=8192) | 5.8 ms | baseline |
| FlashAttention-2 (N=8192) | 1.2 ms | 4.8x |
| Standard (N=32768) | OOM | impossible |
| FlashAttention-2 (N=32768) | 12.4 ms | only option |
FlashAttention Speedup vs Sequence Length

[Table: FlashAttention-2 speedup over standard attention on H100 and A100, sequence lengths 512 to 16384.]
The speedup increases with sequence length because:
- Standard attention’s memory overhead grows as $O(N^2)$
- FlashAttention’s memory overhead grows as $O(N)$
- At long sequences, standard attention becomes entirely memory-bound while FlashAttention remains compute-efficient
7.2 Memory Usage
Peak Memory Usage: Standard vs FlashAttention

[Table: peak attention memory for standard attention vs FlashAttention-2, sequence lengths 2048 to 131072.]
At $N = 131{,}072$ (128K), standard attention needs 34.36 GB per head. FlashAttention needs 54 MB. That is a 636x reduction.
8. FlashAttention-2 and FlashAttention-3 Improvements
8.1 FlashAttention-2: Better Parallelism
FlashAttention-2 (Dao, 2023) improved on v1 by:
- Parallelizing over the sequence length dimension instead of batch/heads only. This keeps all SMs busy even with small batch sizes.
- Reducing non-matmul FLOPs by restructuring the online softmax to minimize register pressure.
- Better work partitioning between warps within a thread block.
```python
def flash_attention_2_parallelism():
    """FlashAttention-2 forward-pass parallelism strategy."""
    return {
        "outer_loop": "over Q blocks (rows) -- one thread block per row block",
        "inner_loop": "over K, V blocks (columns)",
        "advantage": (
            "Each thread block owns the output rows for one Q block, "
            "so no inter-block communication or atomics are needed. "
            "v1 put the K, V loop outside, so partial outputs for the "
            "same rows had to be rescaled and rewritten in HBM."
        ),
        "occupancy": (
            "With N=4096, Br=128: 32 row blocks. "
            "With 32 batch items and 32 heads: 32*32*32 = 32768 thread blocks. "
            "H100 has 132 SMs -- excellent occupancy."
        ),
    }
```
8.2 FlashAttention-3: Hopper-Specific Optimizations
FlashAttention-3 (Shah et al., 2024) exploits H100-specific hardware features:
- Asynchronous WGMMA instructions: overlap GEMM computation with data loading
- TMA (Tensor Memory Accelerator): hardware-accelerated HBM-to-SRAM transfers
- FP8 support: compute attention in FP8 E4M3 on tensor cores for 2x throughput
- Warp specialization: different warps within a thread block handle different tasks (producer/consumer pattern)
FlashAttention Versions (H100, N=8192, d=128, BF16)
| Implementation | Throughput (TFLOP/s) | Speedup vs Standard |
|---|---|---|
| Standard PyTorch (cuDNN) | 95 | baseline |
| FlashAttention-1 | 180 | +89% |
| FlashAttention-2 | 320 | +237% |
| FlashAttention-3 (BF16) | 510 | +437% |
| FlashAttention-3 (FP8) | 740 | +679% |
9. Integration with Activation Checkpointing
9.1 The Interaction
FlashAttention is itself a form of activation checkpointing: it trades compute (recomputing $P$) for memory (not storing $P$). When combined with standard activation checkpointing (which recomputes entire layers during backward), the interaction must be considered:
```python
def checkpoint_with_flash_attention(layer, x):
    """Activation checkpointing + FlashAttention.

    Standard activation checkpointing: don't save intermediate
    activations during forward; recompute them during backward.

    FlashAttention: don't save the attention matrix; recompute
    it during backward using saved statistics.

    Combined: the attention matrix is recomputed twice --
    once by activation checkpointing (re-running forward),
    and once by FlashAttention (within that re-run's backward).
    But FlashAttention never materializes it, so the memory
    savings stack.
    """
    return torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
```
```python
def memory_with_checkpointing(N, d, n_layers, n_heads,
                              use_flash, use_checkpoint):
    """Memory analysis with combinations of optimizations."""
    hidden_size = n_heads * d
    attn_activations = N * hidden_size * 2
    ffn_activations = N * hidden_size * 4
    attn_matrix = N * N * n_heads * 2 if not use_flash else N * n_heads * 8
    per_layer = attn_activations + ffn_activations + attn_matrix
    if use_checkpoint:
        n_saved_layers = 1  # Only one layer's activations live at a time
    else:
        n_saved_layers = n_layers
    total_activations = per_layer * n_saved_layers
    print(f"Flash={use_flash}, Checkpoint={use_checkpoint}")
    print(f"  Per-layer activations: {per_layer / 1e9:.2f} GB")
    print(f"  Total activations: {total_activations / 1e9:.2f} GB")
    return total_activations
```
9.2 Combined Memory Savings
Memory Savings: Flash + Checkpointing (Llama 7B, N=4096)
| Configuration | Peak Activation Memory | Reduction |
|---|---|---|
| Standard attention, no checkpoint | 68 GB | baseline |
| FlashAttention, no checkpoint | 11 GB | -84% |
| Standard attention + checkpoint | 6.2 GB | -91% |
| FlashAttention + checkpoint | 2.1 GB | -97% |
The combination of FlashAttention and activation checkpointing reduces activation memory by 97%. The compute overhead is roughly 35-40% total (from both recomputation mechanisms), but the memory savings enable training larger models, longer contexts, or larger batches — which more than compensates.
10. Summary: The Tradeoff
FlashAttention makes a deliberate engineering trade:
| What you give up | What you get |
|---|---|
| $2N^2d$ extra FLOPs per backward (recomputing $S$ and $P$) | $O(N^2)$ fewer elements stored (the $N \times N$ matrix is eliminated) |
| 16.7% more total FLOPs | Over 99% less attention memory at N=128K |
| Slightly more complex implementation | Training at sequence lengths that would otherwise be impossible |
The overhead is invisible in practice because:
- The GPU has spare compute capacity when attention is memory-bound
- Eliminating HBM reads/writes of the $N \times N$ attention matrix more than compensates for the extra FLOPs
- The memory savings enable larger batch sizes, which improve GPU utilization
This is why FlashAttention is used universally. There is no scenario in modern LLM training where storing the full attention matrix is preferable.
References
- Dao, T. et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS 2022.
- Dao, T. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” ICLR 2024.
- Shah, J. et al. “FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision.” arXiv 2024.
- Milakov, M. and Gimelshein, N. “Online normalizer calculation for softmax.” arXiv 2018.
- Rabe, M. and Staats, C. “Self-attention Does Not Need O(n^2) Memory.” arXiv 2021.