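Standard attention requires O(N²) HBM reads/writes for sequence length N. FlashAttention restructures the algorithm so that the N×N intermediates never touch HBM, leaving only the O(N) reads of Q, K, V and the write of O in the idealized accounting below (the general bound is O(N²d²/M) HBM accesses for on-chip SRAM of size M), while performing the same computation. This isn't magic; it's careful exploitation of the GPU memory hierarchy.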
The Memory Bandwidth Problem
Consider attention for a single head with sequence length N=4096 and head dimension d=128:
Standard Attention Memory Traffic:
Q, K, V load:        3 × N × d × 2 bytes = 3 × 4096 × 128 × 2 = 3.15 MB
S = QK^T store:      N × N × 2 bytes = 4096 × 4096 × 2 = 33.5 MB
S load for softmax:  N × N × 2 bytes = 33.5 MB
P store:             N × N × 2 bytes = 33.5 MB
P load, V load:      N × N × 2 + N × d × 2 = 34.6 MB
O store:             N × d × 2 = 1.05 MB
─────────────────────────────────────────────────────
Total:               ~139 MB per head
For 32 heads: 4.4 GB of HBM traffic per layer.
FlashAttention Memory Traffic:
Q, K, V load:        3 × N × d × 2 bytes = 3.15 MB
O store:             N × d × 2 = 1.05 MB
─────────────────────────────────────────────────────
Total:               ~4.2 MB per head
For 32 heads: 134 MB of HBM traffic per layer.
[Figure: HBM traffic comparison (32 heads, seq_len=4096), in MB]
That's a 33x reduction in memory traffic.
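The accounting above can be reproduced in a few lines. This is a minimal sketch of the same idealized model (each listed load or store crosses HBM exactly once, FP16 elements); the function names are made up for illustration and nothing here measures real hardware.

```python
def standard_attention_bytes(n: int, d: int, bytes_per_elem: int = 2) -> int:
    """HBM bytes for one head when S and P are materialized, step by step."""
    matrix = n * n * bytes_per_elem      # one N x N intermediate (S or P)
    tensor = n * d * bytes_per_elem      # one N x d tensor (Q, K, V or O)
    return (3 * tensor          # Q, K, V load
            + matrix            # S store
            + matrix            # S load for softmax
            + matrix            # P store
            + matrix + tensor   # P load, V load
            + tensor)           # O store


def flash_attention_bytes(n: int, d: int, bytes_per_elem: int = 2) -> int:
    """HBM bytes for one head when only Q, K, V are read and O is written."""
    tensor = n * d * bytes_per_elem
    return 3 * tensor + tensor


for n in (1024, 4096, 16384):
    std = standard_attention_bytes(n, d=128)
    flash = flash_attention_bytes(n, d=128)
    print(f"N={n:6d}: standard {std / 1e6:9.1f} MB, "
          f"flash {flash / 1e6:6.1f} MB, ratio {std / flash:6.1f}x")
```

In this model the ratio simplifies to 1.25 + N/d, so the advantage grows with sequence length rather than staying fixed at 33x.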
GPU Memory Hierarchy Review
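A100 GPU Memory Hierarchy
| Level | Capacity |
|---|---|
| L0 (register file) | 256 KB/SM |
| L1 / shared memory | 192 KB/SM |
| L2 cache | 40 MB |
| HBM | 80 GB |
The key insight: SRAM (registers + shared memory) bandwidth is roughly 10x HBM bandwidth. FlashAttention restructures attention to maximize SRAM reuse.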
The Tiling Strategy
FlashAttention divides Q, K, V into blocks that fit in SRAM:
// Conceptual tiling (simplified)
// Block sizes chosen to fit in shared memory
constexpr int Br = 128; // Q block rows
constexpr int Bc = 128; // K, V block columns
constexpr int d = 128; // Head dimension
// SRAM usage per thread block:
// Q block: Br × d × 2 bytes = 32 KB
// K block: Bc × d × 2 bytes = 32 KB
// V block: Bc × d × 2 bytes = 32 KB
// O block: Br × d × 2 bytes = 32 KB
// S block: Br × Bc × 2 bytes = 32 KB
// Total: 160 KB (fits in the A100's 192 KB of L1/shared memory per SM,
// of which up to 164 KB can be configured as shared memory)
Optimal block sizes depend on head dimension and shared memory capacity. For A100 with d=128, Br=Bc=128 uses 160 KB of the 192 KB per SM, about 83% shared memory utilization.
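To see how the block sizes trade off against capacity, the budget can be tabulated for a few candidate tilings. This is a rough sketch using the same five-tile layout as the listing above; the helper name and the candidate (Br, Bc) pairs are illustrative assumptions, and a real kernel may keep some tiles in registers instead.

```python
SMEM_PER_SM = 192 * 1024  # bytes; the A100 figure quoted in the text


def tile_smem_bytes(br: int, bc: int, d: int, bytes_per_elem: int = 2) -> int:
    """Bytes needed for the Q, K, V, O and S tiles of one thread block."""
    elems = (br * d      # Q tile
             + bc * d    # K tile
             + bc * d    # V tile
             + br * d    # O tile
             + br * bc)  # S tile
    return elems * bytes_per_elem


for br, bc in [(64, 64), (128, 64), (128, 128), (256, 128)]:
    total = tile_smem_bytes(br, bc, d=128)
    print(f"Br={br:3d} Bc={bc:3d}: {total / 1024:6.1f} KB "
          f"({total / SMEM_PER_SM:6.1%} of 192 KB), fits={total <= SMEM_PER_SM}")
```

Larger tiles mean fewer passes over K and V, but the budget runs out quickly: doubling Br from 128 to 256 already overshoots the capacity in this model.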
The Online Softmax Trick
Standard softmax requires two passes over S:
- Find max: m = max(S)
- Compute: softmax(S) = exp(S - m) / sum(exp(S - m))
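The two passes can be sketched for a single row of scores in plain Python. This is only an illustration (the function name and sample scores are made up), but it shows why the whole row has to be available before the second pass can start, which is what forces S to be materialized.

```python
import math


def two_pass_softmax(row):
    """Numerically stable softmax over one row of scores."""
    m = max(row)                             # pass 1: row max
    exps = [math.exp(s - m) for s in row]    # pass 2: shifted exponentials
    total = sum(exps)
    return [e / total for e in exps]


# Without the shift, math.exp(1000.0) would raise OverflowError.
scores = [1000.0, 1001.0, 1002.0]
probs = two_pass_softmax(scores)
print(probs)
```

Subtracting the max leaves the result unchanged mathematically, because the factor exp(-m) cancels between numerator and denominator.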
FlashAttention uses online softmax to compute in a single pass:
# Online softmax accumulation
import torch

def online_softmax_attention_block(Q_block, K_block, V_block,
                                   O_prev, l_prev, m_prev):
    """
    Process one K,V block while maintaining running softmax statistics.

    Args:
        Q_block: [Br, d] query block
        K_block: [Bc, d] key block
        V_block: [Bc, d] value block
        O_prev: [Br, d] running output accumulator
        l_prev: [Br] running sum of exponentials
        m_prev: [Br] running max
    Returns:
        O_new, l_new, m_new: Updated accumulators
    """
    # Compute attention scores for this block
    # (the usual 1/sqrt(d) scaling is omitted for brevity)
    S = Q_block @ K_block.T  # [Br, Bc]

    # Block-wise max and new global max
    # (max(dim=...) returns (values, indices), so take .values)
    m_block = S.max(dim=-1).values  # [Br]
    m_new = torch.maximum(m_prev, m_block)

    # Rescale previous accumulator for new max
    scale_prev = torch.exp(m_prev - m_new)
    l_prev_scaled = l_prev * scale_prev
    O_prev_scaled = O_prev * scale_prev.unsqueeze(-1)

    # Compute new block contribution
    P_block = torch.exp(S - m_new.unsqueeze(-1))  # [Br, Bc]
    l_block = P_block.sum(dim=-1)  # [Br]

    # Accumulate
    l_new = l_prev_scaled + l_block
    O_new = O_prev_scaled + P_block @ V_block
    return O_new, l_new, m_new

# Final normalization, applied once after the last K,V block:
#   O_final = O_new / l_new.unsqueeze(-1)
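A quick way to convince yourself that the running (m, l, O) update is exact rather than approximate is to compare it against the materialize-everything version. The sketch below does this for a single query row using plain Python lists instead of tensors; the function names, sizes, and random inputs are illustrative assumptions.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attention_row_reference(q, keys, values):
    """Materialize every score for one query row, then softmax and weight V."""
    scores = [dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def attention_row_online(q, keys, values, block_size):
    """Same result, visiting K/V one block at a time with running (m, l, o)."""
    m = -math.inf                    # running max
    l = 0.0                          # running sum of exponentials
    o = [0.0] * len(values[0])       # running unnormalized output
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s_blk = [dot(q, k) for k in k_blk]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # exp(-inf) is 0.0 on the first block
        p_blk = [math.exp(s - m_new) for s in s_blk]
        l = l * scale + sum(p_blk)
        o = [oj * scale + sum(p * v[j] for p, v in zip(p_blk, v_blk))
             for j, oj in enumerate(o)]
        m = m_new
    return [oj / l for oj in o]      # single normalization at the end


rng = random.Random(0)
d, n = 4, 10
q = [rng.uniform(-1, 1) for _ in range(d)]
keys = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]

reference = attention_row_reference(q, keys, values)
for block_size in (1, 3, 4, 10):
    tiled = attention_row_online(q, keys, values, block_size)
    assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
               for a, b in zip(reference, tiled))
```

The block size only changes the order of the floating-point operations, so the outputs agree to rounding error for every tiling, including blocks that do not divide N evenly.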
CUDA Implementation Considerations
The actual CUDA kernel involves careful register and shared memory management:
#include <cuda_fp16.h>

// Simplified sketch: load_block_async, mma_sync, warp_reduce_max,
// warp_reduce_sum and store_block are placeholder helpers (not shown).
template<int Br, int Bc, int d, int WARPS_PER_BLOCK>
__global__ void flash_attention_forward(
    const half* __restrict__ Q,
    const half* __restrict__ K,
    const half* __restrict__ V,
    half* __restrict__ O,
    int N
) {
    // Shared memory allocation
    extern __shared__ char smem[];
    half* sQ = reinterpret_cast<half*>(smem);
    half* sK = sQ + Br * d;
    half* sV = sK + Bc * d;

    // Per-thread accumulators (in registers)
    float O_acc[Br / WARPS_PER_BLOCK][d / 32];  // Each thread handles a tile
    float l_acc[Br / WARPS_PER_BLOCK];          // Running sum
    float m_acc[Br / WARPS_PER_BLOCK];          // Running max

    // Initialize accumulators
    #pragma unroll
    for (int i = 0; i < Br / WARPS_PER_BLOCK; i++) {
        m_acc[i] = -INFINITY;
        l_acc[i] = 0.0f;
        #pragma unroll
        for (int j = 0; j < d / 32; j++) {
            O_acc[i][j] = 0.0f;
        }
    }

    // Load Q block once (reused across all K,V blocks)
    load_block_async<Br, d>(Q + blockIdx.x * Br * d, sQ, N, d);
    __syncthreads();

    // Iterate over K,V blocks
    for (int kv_block = 0; kv_block < (N + Bc - 1) / Bc; kv_block++) {
        // Load K, V blocks
        load_block_async<Bc, d>(K + kv_block * Bc * d, sK, N, d);
        load_block_async<Bc, d>(V + kv_block * Bc * d, sV, N, d);
        __syncthreads();

        // Compute S = Q @ K^T using tensor cores
        half S_frag[Br / WARPS_PER_BLOCK][Bc / 32];
        mma_sync(S_frag, sQ, sK);  // Simplified - actual uses wmma/mma

        // Online softmax update (in registers)
        #pragma unroll
        for (int i = 0; i < Br / WARPS_PER_BLOCK; i++) {
            float row_max = -INFINITY;
            #pragma unroll
            for (int j = 0; j < Bc / 32; j++) {
                row_max = fmaxf(row_max, __half2float(S_frag[i][j]));
            }
            row_max = warp_reduce_max(row_max);

            // Rescale the running sum and output for the new max
            float new_max = fmaxf(m_acc[i], row_max);
            float scale = expf(m_acc[i] - new_max);
            l_acc[i] *= scale;
            #pragma unroll
            for (int j = 0; j < d / 32; j++) {
                O_acc[i][j] *= scale;
            }

            // Accumulate this block
            float row_sum = 0.0f;
            #pragma unroll
            for (int j = 0; j < Bc / 32; j++) {
                float p = expf(__half2float(S_frag[i][j]) - new_max);
                row_sum += p;
                // O_acc += p * V - done via mma
            }
            // Each lane holds only part of the row, so reduce across the warp
            l_acc[i] += warp_reduce_sum(row_sum);
            m_acc[i] = new_max;
        }
        __syncthreads();
    }

    // Final normalization and store
    #pragma unroll
    for (int i = 0; i < Br / WARPS_PER_BLOCK; i++) {
        float inv_l = 1.0f / l_acc[i];
        #pragma unroll
        for (int j = 0; j < d / 32; j++) {
            O_acc[i][j] *= inv_l;
        }
    }
    store_block(O + blockIdx.x * Br * d, O_acc);
}
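The register side of that budget is easy to underestimate. As a back-of-the-envelope sketch, the snippet below counts the float accumulators each thread would hold for the O_acc, l_acc, and m_acc arrays declared in the kernel above, for a few hypothetical WARPS_PER_BLOCK values. The helper name and the warp counts are assumptions for illustration; the 255 figure is the per-thread register limit on current NVIDIA GPUs, and the compiler's actual allocation (including S_frag, addresses, and temporaries) will differ.

```python
MAX_REGISTERS_PER_THREAD = 255  # hardware limit per thread


def float_accumulators_per_thread(br: int, d: int, warps_per_block: int) -> int:
    """32-bit accumulators implied by O_acc, l_acc and m_acc in the sketch."""
    rows = br // warps_per_block
    o_acc = rows * (d // 32)   # O_acc[rows][d / 32]
    l_acc = rows               # l_acc[rows]
    m_acc = rows               # m_acc[rows]
    return o_acc + l_acc + m_acc


for warps in (4, 8, 16):
    count = float_accumulators_per_thread(br=128, d=128, warps_per_block=warps)
    print(f"WARPS_PER_BLOCK={warps:2d}: {count:3d} accumulator registers "
          f"({count / MAX_REGISTERS_PER_THREAD:.0%} of the per-thread limit)")
```

With few warps per block the accumulators alone consume most of the register budget, which is one reason the choice of warp count and tile size has to be made together.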
Roofline Analysis
FlashAttention Roofline Position (A100)
| Operation | Arithmetic Intensity | Achieved FLOPS | Bound |
|---|---|---|---|
| Standard Attention | 2.8 FLOP/byte | 1.2 TFLOPS | Memory |
| FlashAttention | 89 FLOP/byte | 124 TFLOPS | Compute |
| FlashAttention-2 | 102 FLOP/byte | 156 TFLOPS | Compute |
| A100 Peak | - | 312 TFLOPS | - |
FlashAttention moves attention from memory-bound to compute-bound by increasing arithmetic intensity roughly 32x (from 2.8 to 89 FLOP/byte).
FlashAttention-2 Improvements
FlashAttention-2 achieves additional speedup through:
- Reduced non-matmul FLOPs: Moved the softmax normalization (the division by the running sum) out of the inner loop
- Better parallelism: Parallelize over sequence length, not just batch and heads
- Improved warp scheduling: Better occupancy on Ampere/Hopper
# FlashAttention-2 key optimization: deferred normalization
# The running-max correction still happens for every block, but the division
# by l is applied once at the end instead of after each block
import torch

def flash_attention_2_block(Q_block, K_block, V_block, acc):
    S = Q_block @ K_block.T
    m_block = S.max(dim=-1).values
    # New running max and the correction factors relative to it
    m_new = torch.maximum(acc.m, m_block)
    scale_prev = torch.exp(acc.m - m_new)     # corrects what is already accumulated
    scale_block = torch.exp(m_block - m_new)  # corrects this block's local softmax
    P = torch.exp(S - m_block.unsqueeze(-1))  # Local softmax numerator
    PV = P @ V_block
    # Accumulate without normalizing: O_unscaled is never divided by l here
    acc.O_unscaled = (acc.O_unscaled * scale_prev.unsqueeze(-1)
                      + PV * scale_block.unsqueeze(-1))
    acc.l = acc.l * scale_prev + P.sum(dim=-1) * scale_block
    acc.m = m_new
    return acc

# Final normalization (once, after all blocks):
#   O_final = acc.O_unscaled / acc.l.unsqueeze(-1)
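The saving is easiest to see side by side. The sketch below processes one query row with scalar values, once keeping the output normalized after every block and once dividing only at the end, and counts the normalizations. The function names and the sample data are made up for illustration, and real kernels apply these steps to whole tiles rather than scalars.

```python
import math


def normalize_every_block(score_blocks, value_blocks):
    """Keeps o normalized after each block (one division per block)."""
    m, l, o = -math.inf, 0.0, 0.0
    divisions = 0
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in s_blk]
        l_new = l * scale + sum(p)
        # Undo the previous normalization, rescale, add the block, renormalize
        o = (o * l * scale + sum(pi * vi for pi, vi in zip(p, v_blk))) / l_new
        divisions += 1
        m, l = m_new, l_new
    return o, divisions


def normalize_once(score_blocks, value_blocks):
    """Keeps o unnormalized and divides by l a single time at the end."""
    m, l, o = -math.inf, 0.0, 0.0
    for s_blk, v_blk in zip(score_blocks, value_blocks):
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * scale + sum(p)
        o = o * scale + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return o / l, 1


score_blocks = [[0.5, -1.0], [2.0, 0.1], [-0.3, 3.0]]
value_blocks = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

o_every, n_every = normalize_every_block(score_blocks, value_blocks)
o_once, n_once = normalize_once(score_blocks, value_blocks)
print(o_every, n_every)
print(o_once, n_once)
```

Both variants still rescale by exp(m - m_new) whenever the running max changes; only the division by l moves out of the loop.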
Profiling FlashAttention
# Measure achieved memory bandwidth
ncu --set full \
--metrics dram__bytes_read.sum,dram__bytes_write.sum,\
sm__sass_thread_inst_executed_op_fadd_pred_on.sum,\
sm__sass_thread_inst_executed_op_fmul_pred_on.sum \
python -c "
import torch
from flash_attn import flash_attn_func
q = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
k = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
v = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
for _ in range(100):
    o = flash_attn_func(q, k, v)
torch.cuda.synchronize()
"
Expected results on A100:
- HBM Read: ~101 MB (Q, K, V for 32 heads)
- HBM Write: ~34 MB (O for 32 heads), ~134 MB in total (vs 4.4 GB for standard)
- Achieved TFLOPS: 150+ (48% of peak)
- SM Occupancy: 75-85%
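To relate a measured kernel time to the TFLOPS figure, the usual convention is to count only the two matmuls (QK^T and PV) at 2 FLOPs per multiply-add. The sketch below applies that convention to the benchmark shape used above; the helper names are illustrative, and the kernel duration would come from your own profiler output rather than from this snippet.

```python
A100_PEAK_TFLOPS = 312.0  # FP16 tensor-core peak from the roofline table


def attention_forward_flops(batch: int, heads: int, seq_len: int, head_dim: int) -> int:
    """Matmul FLOPs of one non-causal forward pass (QK^T plus PV)."""
    return 4 * batch * heads * seq_len * seq_len * head_dim


def achieved_tflops(flops: int, seconds: float) -> float:
    return flops / seconds / 1e12


flops = attention_forward_flops(batch=1, heads=32, seq_len=4096, head_dim=128)
kernel_ms_at_150 = flops / 150e12 * 1e3   # duration implied by 150 TFLOPS
fraction_of_peak = 150.0 / A100_PEAK_TFLOPS

print(f"FLOPs per forward pass: {flops:.3e}")
print(f"Kernel time at 150 TFLOPS: {kernel_ms_at_150:.2f} ms")
print(f"Fraction of peak: {fraction_of_peak:.0%}")
```

If the measured duration per call is noticeably longer than the implied time, the kernel is running below the 150 TFLOPS mark for this shape.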
When FlashAttention Isn't Optimal
FlashAttention has overhead for:
- Very short sequences (N < 256): Tiling overhead dominates
- Very small batch sizes: Can't saturate SMs
- Non-standard attention patterns: Requires custom kernels
For N < 512, cuBLAS GEMM-based attention often outperforms FlashAttention: the N×N score matrix is at most 512 KB per head in FP16, so materializing it is cheap, and the GEMMs achieve better SM utilization at small problem sizes.
Conclusion
FlashAttention's 33x reduction in HBM traffic comes from a fundamental restructuring of the attention algorithm, not approximation. Understanding this restructuring, and the memory hierarchy constraints that motivate it, is essential for anyone optimizing transformer inference.
The key insight generalizes: Any O(N²) intermediate tensor that can be computed on-the-fly should be. This principle applies beyond attention to any algorithm with large intermediate materialization.
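As one illustration of the same pattern outside attention, consider finding each point's nearest-neighbor distance. The sketch below is a toy example with made-up names and 1-D points: the first version builds the full N×N distance table before reducing it, while the second visits the same pairs block by block and keeps only one running minimum per row.

```python
def nearest_distance_materialized(points):
    """Builds the full N x N distance table, then reduces each row."""
    n = len(points)
    table = [[abs(points[i] - points[j]) if i != j else float("inf")
              for j in range(n)] for i in range(n)]
    return [min(row) for row in table], n * n   # result, values held at once


def nearest_distance_streaming(points, block=4):
    """Visits the same pairs block by block with one running min per row."""
    n = len(points)
    best = [float("inf")] * n
    for start in range(0, n, block):
        for i in range(n):
            for j in range(start, min(start + block, n)):
                if i != j:
                    best[i] = min(best[i], abs(points[i] - points[j]))
    return best, n                              # only N running minima persist


points = [0.0, 3.0, 7.0, 7.5, 12.0, 20.0]
full, full_state = nearest_distance_materialized(points)
streamed, streamed_state = nearest_distance_streaming(points)
print(full, full_state)
print(streamed, streamed_state)
```

The running minimum plays the role that the running max and sum play in online softmax: a small per-row state that lets the quadratic intermediate be consumed as it is produced.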