Standard attention requires O(N²) HBM reads/writes for sequence length N. FlashAttention restructures the algorithm so that the N×N intermediates never leave the chip, bringing HBM access down to O(N) in the ideal case while performing the same computation. This isn’t magic; it’s careful exploitation of the GPU memory hierarchy.

The Memory Bandwidth Problem

Consider attention for a single head with sequence length N=4096 and head dimension d=128:

Standard Attention Memory Traffic:

Q, K, V load:       3 × N × d × 2 bytes = 3 × 4096 × 128 × 2 = 3.14 MB
S = QK^T store:     N × N × 2 bytes = 4096 × 4096 × 2 = 33.5 MB
S load for softmax: N × N × 2 bytes = 33.5 MB
P store:            N × N × 2 bytes = 33.5 MB
P load, V load:     N × N × 2 + N × d × 2 bytes = 34.6 MB
O store:            N × d × 2 bytes = 1.05 MB
─────────────────────────────────────────────────────
Total:              ~139 MB per head

For 32 heads: about 4.45 GB of HBM traffic per layer.

FlashAttention Memory Traffic:

Q, K, V load:       3 × N × d × 2 bytes = 3.14 MB
O store:            N × d × 2 bytes = 1.05 MB
─────────────────────────────────────────────────────
Total:              ~4.2 MB per head

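For 32 heads: 134 MB of HBM traffic per layer. This is the ideal case in which each tensor crosses HBM once; it assumes the K and V re-reads issued by different Q blocks are served from L2 cache rather than HBM.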

HBM Traffic Comparison (32 heads, seq_len=4096)

Standard Attention: 4,450 MB
FlashAttention:     134 MB

That’s a 33x reduction in memory traffic.
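The arithmetic above is easy to reproduce. The following is a minimal sketch of the same idealized traffic model; the function names are illustrative rather than part of any library, and it assumes FP16 (2 bytes per element) with each tensor crossing HBM exactly as listed in the two tallies.

```python
def standard_attention_bytes(n: int, d: int, elem_bytes: int = 2) -> int:
    """HBM bytes per head when S and P are materialized (idealized model)."""
    qkv_load = 3 * n * d * elem_bytes
    s_store = n * n * elem_bytes
    s_load = n * n * elem_bytes
    p_store = n * n * elem_bytes
    pv_load = n * n * elem_bytes + n * d * elem_bytes
    o_store = n * d * elem_bytes
    return qkv_load + s_store + s_load + p_store + pv_load + o_store


def flash_attention_bytes(n: int, d: int, elem_bytes: int = 2) -> int:
    """HBM bytes per head when only Q, K, V are read and O is written."""
    return 3 * n * d * elem_bytes + n * d * elem_bytes


if __name__ == "__main__":
    n, d, heads = 4096, 128, 32
    std = standard_attention_bytes(n, d)
    flash = flash_attention_bytes(n, d)
    print(f"standard: {std / 1e6:.1f} MB/head, {heads * std / 1e9:.2f} GB/layer")
    print(f"flash:    {flash / 1e6:.1f} MB/head, {heads * flash / 1e6:.0f} MB/layer")
    print(f"reduction: {std / flash:.1f}x")
```

Under this model the ratio simplifies to N/d + 1.25, so the saving grows linearly with sequence length for a fixed head dimension.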

GPU Memory Hierarchy Review

A100 GPU Memory Hierarchy

Level   Memory               Capacity     Bandwidth
L0      Registers            256KB/SM     ~19 TB/s
L1      Shared Memory / L1   192KB/SM     ~19 TB/s aggregate
L2      L2 Cache             40MB         ~5 TB/s
HBM     HBM2e                80GB         2.0 TB/s

The key insight: SRAM (registers + shared memory) bandwidth is roughly 10x HBM bandwidth (~19 TB/s versus 2.0 TB/s). FlashAttention restructures attention to maximize SRAM reuse.

The Tiling Strategy

FlashAttention divides Q, K, V into blocks that fit in SRAM:

// Conceptual tiling (simplified)
// Block sizes chosen to fit in shared memory
constexpr int Br = 128;  // Q block rows
constexpr int Bc = 128;  // K, V block columns
constexpr int d = 128;   // Head dimension

// SRAM usage per thread block:
// Q block:  Br × d × 2 bytes = 32 KB
// K block:  Bc × d × 2 bytes = 32 KB
// V block:  Bc × d × 2 bytes = 32 KB
// O block:  Br × d × 2 bytes = 32 KB
// S block:  Br × Bc × 2 bytes = 32 KB
// Total: ~160 KB (fits in 192KB shared memory)

Block Size Selection

Optimal block sizes depend on head dimension and shared memory capacity. For A100 with d=128, Br=Bc=128 uses about 160 KB of the 192 KB per SM, roughly 83% shared memory utilization.
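To explore other block sizes, the budget above can be tabulated with a small helper. This is a sketch of the simplified layout only: the function names are hypothetical, the 192 KB capacity is the figure quoted in this article, and real kernels also face per-block limits and keep some tiles in registers.

```python
def tile_sram_bytes(br: int, bc: int, d: int, elem_bytes: int = 2) -> dict:
    """Bytes of on-chip storage for each tile in the simplified layout."""
    return {
        "Q": br * d * elem_bytes,
        "K": bc * d * elem_bytes,
        "V": bc * d * elem_bytes,
        "O": br * d * elem_bytes,
        "S": br * bc * elem_bytes,
    }


def fits(br: int, bc: int, d: int, capacity_bytes: int = 192 * 1024) -> bool:
    """True if all five tiles fit in the assumed capacity at once."""
    return sum(tile_sram_bytes(br, bc, d).values()) <= capacity_bytes


if __name__ == "__main__":
    for br, bc in [(64, 64), (128, 64), (128, 128), (256, 128)]:
        total = sum(tile_sram_bytes(br, bc, 128).values())
        print(f"Br={br:3d} Bc={bc:3d}: {total // 1024:3d} KB  fits={fits(br, bc, 128)}")
```

Doubling Br to 256 pushes the Q, O, and S tiles to 64 KB each, which no longer fits; this is why the block size has to shrink as the head dimension grows.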

The Online Softmax Trick

Standard softmax requires two passes over S:

  1. Find max: m = max(S)
  2. Compute: softmax(S) = exp(S - m) / sum(exp(S - m))
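The two passes can be sketched for a single row of scores as follows. This is a plain-Python illustration rather than the kernel's code, and softmax_two_pass is a name chosen here for the example.

```python
import math


def softmax_two_pass(scores: list[float]) -> list[float]:
    # Pass 1: row maximum, used to keep exp() in range
    m = max(scores)
    # Pass 2: exponentials and their sum, then normalize
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    row = [1000.0, 1001.0, 1002.0]  # math.exp(1000.0) on its own would overflow
    print(softmax_two_pass(row))
```

Subtracting the maximum does not change the result, because softmax is invariant to adding a constant to every score. The catch for tiling is that the maximum of the whole row must be known before any exponential can be computed.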

FlashAttention uses online softmax to compute in a single pass:

import torch

# Online softmax accumulation
def online_softmax_attention_block(Q_block, K_block, V_block, 
                                    O_prev, l_prev, m_prev):
    """
    Process one K,V block while maintaining running softmax statistics.
    
    Args:
        Q_block: [Br, d] query block
        K_block: [Bc, d] key block  
        V_block: [Bc, d] value block
        O_prev: [Br, d] running output accumulator
        l_prev: [Br] running sum of exponentials
        m_prev: [Br] running max
    
    Returns:
        O_new, l_new, m_new: Updated accumulators
    """
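    # Compute attention scores for this block
    # (the 1/sqrt(d) scaling is omitted here; assume Q is pre-scaled)
    S = Q_block @ K_block.T  # [Br, Bc]
    
    # Block-wise max and new global max
    m_block = S.max(dim=-1).values  # [Br]; max(dim=...) returns (values, indices)
    m_new = torch.maximum(m_prev, m_block)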
    
    # Rescale previous accumulator for new max
    scale_prev = torch.exp(m_prev - m_new)
    l_prev_scaled = l_prev * scale_prev
    O_prev_scaled = O_prev * scale_prev.unsqueeze(-1)
    
    # Compute new block contribution
    P_block = torch.exp(S - m_new.unsqueeze(-1))  # [Br, Bc]
    l_block = P_block.sum(dim=-1)  # [Br]
    
    # Accumulate
    l_new = l_prev_scaled + l_block
    O_new = O_prev_scaled + P_block @ V_block
    
    return O_new, l_new, m_new

# Final normalization
O_final = O_new / l_new.unsqueeze(-1)
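To check that the running statistics really reproduce ordinary softmax attention, here is a dependency-free sketch for a single query row. It follows the same update rule as the function above but uses plain Python lists instead of tensors; the function names and the random test data are assumptions made for this example.

```python
import math
import random


def attention_row_reference(q, keys, values):
    """Softmax attention for one query row with the full score row in memory."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def attention_row_online(q, keys, values, block_size):
    """Same result, visiting keys and values one block at a time."""
    dim = len(values[0])
    m = -math.inf      # running max
    l = 0.0            # running sum of exponentials
    o = [0.0] * dim    # running unnormalized output
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)  # 0.0 on the first block, since m = -inf
        p = [math.exp(s - m_new) for s in scores]
        l = l * scale + sum(p)
        o = [o[j] * scale + sum(pi * v[j] for pi, v in zip(p, v_blk)) for j in range(dim)]
        m = m_new
    return [x / l for x in o]  # normalize once, after the last block


if __name__ == "__main__":
    rng = random.Random(0)
    n, d = 64, 8
    q = [rng.gauss(0, 1) for _ in range(d)]
    keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
    ref = attention_row_reference(q, keys, values)
    out = attention_row_online(q, keys, values, block_size=16)
    print("max abs difference:", max(abs(a - b) for a, b in zip(ref, out)))
```

The two results agree up to floating-point rounding for any block size, including one that does not divide the sequence length, which is the sense in which FlashAttention is exact rather than approximate.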

CUDA Implementation Considerations

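The actual CUDA kernel involves careful register and shared memory management. The sketch below is simplified: load_block_async, mma_sync, warp_reduce_max, and store_block stand in for helpers that are not shown, and the P·V accumulation is only indicated by a comment.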

#include <cuda_fp16.h>  // half, __half2float

template<int Br, int Bc, int d, int WARPS_PER_BLOCK>
__global__ void flash_attention_forward(
    const half* __restrict__ Q,
    const half* __restrict__ K,
    const half* __restrict__ V,
    half* __restrict__ O,
    int N
) {
    // Shared memory allocation
    extern __shared__ char smem[];
    half* sQ = reinterpret_cast<half*>(smem);
    half* sK = sQ + Br * d;
    half* sV = sK + Bc * d;
    
    // Per-thread accumulators (in registers)
    float O_acc[Br / WARPS_PER_BLOCK][d / 32];  // Each thread handles a tile
    float l_acc[Br / WARPS_PER_BLOCK];           // Running sum
    float m_acc[Br / WARPS_PER_BLOCK];           // Running max
    
    // Initialize accumulators
    #pragma unroll
    for (int i = 0; i < Br / WARPS_PER_BLOCK; i++) {
        m_acc[i] = -INFINITY;
        l_acc[i] = 0.0f;
        #pragma unroll
        for (int j = 0; j < d / 32; j++) {
            O_acc[i][j] = 0.0f;
        }
    }
    
    // Load Q block once (reused across all K,V blocks)
    load_block_async<Br, d>(Q + blockIdx.x * Br * d, sQ, N, d);
    __syncthreads();
    
    // Iterate over K,V blocks
    for (int kv_block = 0; kv_block < (N + Bc - 1) / Bc; kv_block++) {
        // Load K, V blocks
        load_block_async<Bc, d>(K + kv_block * Bc * d, sK, N, d);
        load_block_async<Bc, d>(V + kv_block * Bc * d, sV, N, d);
        __syncthreads();
        
        // Compute S = Q @ K^T using tensor cores
        half S_frag[Br / WARPS_PER_BLOCK][Bc / 32];
        mma_sync(S_frag, sQ, sK);  // Simplified - actual uses wmma/mma
        
        // Online softmax update (in registers)
        #pragma unroll
        for (int i = 0; i < Br / WARPS_PER_BLOCK; i++) {
            float row_max = -INFINITY;
            #pragma unroll
            for (int j = 0; j < Bc / 32; j++) {
                row_max = fmaxf(row_max, __half2float(S_frag[i][j]));
            }
            row_max = warp_reduce_max(row_max);
            
            float new_max = fmaxf(m_acc[i], row_max);
            float scale = expf(m_acc[i] - new_max);
            
            l_acc[i] *= scale;
            #pragma unroll
            for (int j = 0; j < d / 32; j++) {
                O_acc[i][j] *= scale;  // Rescale the running output to the new max
            }
            
            // Accumulate this block
            float row_sum = 0.0f;
            #pragma unroll
            for (int j = 0; j < Bc / 32; j++) {
                float p = expf(__half2float(S_frag[i][j]) - new_max);
                row_sum += p;
                // O_acc += p * V - done via mma
            }
            l_acc[i] += row_sum;
            m_acc[i] = new_max;
        }
        __syncthreads();
    }
    
    // Final normalization and store
    #pragma unroll
    for (int i = 0; i < Br / WARPS_PER_BLOCK; i++) {
        float inv_l = 1.0f / l_acc[i];
        #pragma unroll
        for (int j = 0; j < d / 32; j++) {
            O_acc[i][j] *= inv_l;
        }
    }
    store_block(O + blockIdx.x * Br * d, O_acc);
}
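The kernel calls warp_reduce_max without defining it. One common way to write such a helper, which is an assumption here and not taken from the kernel above, is a butterfly reduction built on warp shuffles (__shfl_xor_sync in CUDA). The Python simulation below models the 32 lanes of a warp as a list to show the data movement.

```python
WARP_SIZE = 32


def warp_reduce_max(lane_values: list[float]) -> list[float]:
    """Simulate a butterfly max-reduction across the 32 lanes of a warp.

    In each round, lane i combines its value with the value held by lane
    (i XOR offset). After log2(32) = 5 rounds every lane holds the maximum
    of all 32 inputs.
    """
    assert len(lane_values) == WARP_SIZE
    vals = list(lane_values)
    offset = WARP_SIZE // 2
    while offset > 0:
        # All lanes read their partner's old value, then update together
        vals = [max(vals[i], vals[i ^ offset]) for i in range(WARP_SIZE)]
        offset //= 2
    return vals


if __name__ == "__main__":
    partial_maxima = [float((7 * lane) % 32) for lane in range(WARP_SIZE)]
    print(warp_reduce_max(partial_maxima))
```

Because every lane ends up with the same row maximum, each thread can rescale its own slice of the accumulator without a round trip through shared memory.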

Roofline Analysis

FlashAttention Roofline Position (A100)

Operation            Arithmetic Intensity   Achieved FLOPS   Bound
Standard Attention   2.8 FLOP/byte          1.2 TFLOPS       Memory
FlashAttention       89 FLOP/byte           124 TFLOPS       Compute
FlashAttention-2     102 FLOP/byte          156 TFLOPS       Compute
A100 Peak            -                      312 TFLOPS       -

Note: FP16, batch=1, heads=32, seq_len=4096, d=128

FlashAttention moves attention from memory-bound to compute-bound by increasing arithmetic intensity 30x.
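A back-of-envelope version of this argument compares two lower bounds on runtime: the time to do the arithmetic at peak throughput and the time to move the bytes at HBM bandwidth. The sketch below uses the idealized traffic totals from earlier in the article and assumes 4·N²·d FLOPs per head (two matmuls); it gives bounds, not measurements, and the function names are illustrative.

```python
HBM_BANDWIDTH = 2.0e12  # bytes/s, A100 figure quoted in this article
PEAK_FLOPS = 312e12     # FP16 FLOP/s, A100 peak quoted in this article


def attention_flops(heads: int, n: int, d: int) -> int:
    # Two matmuls (Q @ K^T and P @ V), each 2 * N * N * d FLOPs per head
    return heads * 4 * n * n * d


def time_bounds_ms(flops: float, hbm_bytes: float) -> tuple[float, float]:
    """Lower bounds on runtime from compute alone and from HBM traffic alone."""
    return 1e3 * flops / PEAK_FLOPS, 1e3 * hbm_bytes / HBM_BANDWIDTH


if __name__ == "__main__":
    flops = attention_flops(heads=32, n=4096, d=128)
    for name, traffic in [("standard", 4.45e9), ("flash", 134e6)]:
        compute_ms, memory_ms = time_bounds_ms(flops, traffic)
        bound = "memory" if memory_ms > compute_ms else "compute"
        print(f"{name:8s} compute >= {compute_ms:.3f} ms, "
              f"memory >= {memory_ms:.3f} ms -> {bound}-bound")
```

Whichever bound is larger dominates: with the N×N intermediates in HBM the memory time exceeds the compute time, and without them it falls well below it.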

FlashAttention-2 Improvements

FlashAttention-2 achieves additional speedup through:

  1. Reduced non-matmul FLOPs: Moved the softmax normalization (division by the running sum) outside the inner loop
  2. Better parallelism: Parallelize over sequence length, not just batch and heads
  3. Improved warp scheduling: Better occupancy on Ampere/Hopper

# FlashAttention-2 key optimization: delayed normalization
# Instead of dividing O by the running sum in every block, keep O
# unnormalized across blocks and divide once at the end

def flash_attention_2_block(Q_block, K_block, V_block, acc):
    S = Q_block @ K_block.T
    m_block = S.max(dim=-1).values  # max(dim=...) returns (values, indices)
    
    # Track the running max and the factors that map old and new terms onto it
    m_new = torch.maximum(acc.m, m_block)
    scale_prev = torch.exp(acc.m - m_new)     # for everything accumulated so far
    scale_block = torch.exp(m_block - m_new)  # for this block's contribution
    
    P = torch.exp(S - m_block.unsqueeze(-1))  # Local (unnormalized) softmax
    PV = P @ V_block
    
    # Accumulate without normalizing; the previous O must be rescaled too
    acc.O_unscaled = acc.O_unscaled * scale_prev.unsqueeze(-1) + PV * scale_block.unsqueeze(-1)
    acc.l = acc.l * scale_prev + P.sum(dim=-1) * scale_block
    acc.m = m_new
    
    return acc

# Final normalization (once, after all blocks)
O_final = acc.O_unscaled / acc.l.unsqueeze(-1)
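The second improvement in the list above, parallelizing over sequence length, can be illustrated by counting independent thread blocks. This is a simplified model: it assumes one thread block per (batch, head) pair without sequence splitting and one per Q row-tile with it, and thread_blocks is a hypothetical helper. The only hardware fact used is that an A100 has 108 SMs.

```python
import math

A100_SMS = 108  # streaming multiprocessors on an A100


def thread_blocks(batch: int, heads: int, n: int, br: int, split_sequence: bool) -> int:
    """Number of independent thread blocks available to the scheduler."""
    blocks = batch * heads
    if split_sequence:
        blocks *= math.ceil(n / br)  # one block per Q row-tile as well
    return blocks


if __name__ == "__main__":
    for split in (False, True):
        blocks = thread_blocks(batch=1, heads=32, n=4096, br=128, split_sequence=split)
        print(f"split_sequence={split}: {blocks} blocks for {A100_SMS} SMs")
```

With batch=1 and 32 heads, splitting only over batch and heads yields fewer blocks than there are SMs, so most of the GPU would sit idle; adding the sequence dimension yields far more blocks than SMs.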

Profiling FlashAttention

# Measure HBM traffic and FP32 add/mul instruction counts
# (the metric list stays on one line: a line break inside it would split the argument)
ncu --set full \
    --metrics dram__bytes_read.sum,dram__bytes_write.sum,sm__sass_thread_inst_executed_op_fadd_pred_on.sum,sm__sass_thread_inst_executed_op_fmul_pred_on.sum \
    python -c "
import torch
from flash_attn import flash_attn_func
q = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
k = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
v = torch.randn(1, 4096, 32, 128, device='cuda', dtype=torch.float16)
for _ in range(100):
    o = flash_attn_func(q, k, v)
torch.cuda.synchronize()
"

Expected results on A100:

  • HBM Read: ~101 MB (Q, K, V for 32 heads)
  • HBM Write: ~34 MB (O for 32 heads), for ~134 MB in total (vs about 4.45 GB for standard)
  • Achieved TFLOPS: 150+ (48% of peak)
  • SM Occupancy: 75-85%
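Profiler counters are easier to compare against these figures once they are converted to rates. The helpers below are a sketch: they assume 4·N²·d FLOPs per head for the forward pass, the function names are illustrative, and the kernel duration in the example is a hypothetical value chosen for the arithmetic, not a measurement.

```python
def achieved_tflops(batch: int, heads: int, n: int, d: int, seconds: float) -> float:
    # Forward attention: Q @ K^T and P @ V, each 2 * N * N * d FLOPs per head
    flops = 4 * batch * heads * n * n * d
    return flops / seconds / 1e12


def achieved_bandwidth_gbs(hbm_bytes: float, seconds: float) -> float:
    # hbm_bytes is the sum of the DRAM read and write counters for one kernel
    return hbm_bytes / seconds / 1e9


if __name__ == "__main__":
    kernel_seconds = 1.83e-3  # hypothetical per-kernel duration
    tflops = achieved_tflops(1, 32, 4096, 128, kernel_seconds)
    print(f"{tflops:.0f} TFLOPS, {100 * tflops / 312:.0f}% of A100 FP16 peak")
    print(f"{achieved_bandwidth_gbs(134e6, kernel_seconds):.0f} GB/s of HBM traffic")
```

A kernel that reaches a large fraction of peak FLOPS while using only a small fraction of the 2.0 TB/s HBM bandwidth is behaving as a compute-bound kernel should.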

When FlashAttention Isn’t Optimal

FlashAttention has overhead for:

  • Very short sequences (N < 256): Tiling overhead dominates
  • Very small batch sizes: Can’t saturate SMs
  • Non-standard attention patterns: Requires custom kernels

Sequence Length Crossover

For N < 512, cuBLAS GEMM-based attention often outperforms FlashAttention due to lower kernel launch overhead and better SM utilization at small problem sizes.

Conclusion

FlashAttention’s 33x reduction in HBM traffic comes from a fundamental restructuring of the attention algorithm, not approximation. Understanding this restructuring, and the memory hierarchy constraints that motivate it, is essential for anyone optimizing transformer inference.

The key insight generalizes: Any O(N²) intermediate tensor that can be computed on-the-fly should be. This principle applies beyond attention to any algorithm with large intermediate materialization.