A naive attention kernel for sequence length 8192 allocates a 256 MB FP32 attention matrix (8192^2 x 4 bytes), writes it to HBM, applies softmax, reads it back, then multiplies by V. For 32 attention heads, that is 8 GB of temporary storage, as much as the entire FP16 weights of a 4B-parameter model. FlashAttention eliminates this materialization by computing attention in tiles: load a block of Q and K, compute QK^T scores, apply softmax incrementally using an online algorithm, multiply by V, and discard the scores. Per-head temporary storage drops from 256 MB in HBM to a tile-sized shared-memory footprint per thread block (96 KB for the FP32 tiles used later in this post). Throughput improves by 3-7x because the O(N^2) score matrix is never written to or read back from HBM.
All measurements target NVIDIA Ampere (A100-80GB SXM, SM 8.0) unless stated otherwise.
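The memory figures above are simple arithmetic worth making explicit. A quick host-side helper (plain C++; FP32 scores assumed, function name is ours):

```cpp
#include <cassert>
#include <cstdint>

// Bytes needed to materialize one head's N x N FP32 score matrix.
uint64_t score_matrix_bytes(uint64_t seq_len) {
    return seq_len * seq_len * sizeof(float);
}

// 8192^2 * 4 B = 256 MB per head; 32 heads -> 8 GB of temporaries.
// At 16384, a single head's matrix is already 1 GB.
```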
The Math: Standard Attention
Given:
- $Q \in \mathbb{R}^{N \times d}$ — queries
- $K \in \mathbb{R}^{N \times d}$ — keys
- $V \in \mathbb{R}^{N \times d}$ — values
Standard attention:
$$O = \mathrm{softmax}\!\left(\frac{QK^T}{\sqrt{d}}\right) V$$
Where softmax is applied row-wise to $S = QK^T/\sqrt{d}$:
$$P_{ij} = \frac{e^{S_{ij}}}{\sum_{k} e^{S_{ik}}}$$
For numerical stability, subtract the row maximum $m_i = \max_k S_{ik}$:
$$P_{ij} = \frac{e^{S_{ij} - m_i}}{\sum_{k} e^{S_{ik} - m_i}}$$
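Before writing any kernels, it helps to have a CPU reference to validate them against. A minimal single-head FP32 implementation of the formulas above (a sketch; the function name is ours, not from any library):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Reference single-head attention: O = softmax(Q K^T / sqrt(d)) V.
// Q, K, V are row-major [N, d]; O is written row-major [N, d].
void attention_ref(const std::vector<float>& Q, const std::vector<float>& K,
                   const std::vector<float>& V, std::vector<float>& O,
                   int N, int d) {
    float scale = 1.0f / std::sqrt((float)d);
    std::vector<float> s(N);
    for (int i = 0; i < N; i++) {
        // Scores for query row i, tracking the row max for stability.
        float m = -INFINITY;
        for (int j = 0; j < N; j++) {
            float dot = 0.0f;
            for (int k = 0; k < d; k++) dot += Q[i * d + k] * K[j * d + k];
            s[j] = dot * scale;
            m = std::max(m, s[j]);
        }
        // exp with the max subtracted, then the row sum.
        float sum = 0.0f;
        for (int j = 0; j < N; j++) { s[j] = std::exp(s[j] - m); sum += s[j]; }
        // O[i] = (P[i,:] @ V) with normalization folded in at the end.
        for (int c = 0; c < d; c++) {
            float acc = 0.0f;
            for (int j = 0; j < N; j++) acc += s[j] * V[j * d + c];
            O[i * d + c] = acc / sum;
        }
    }
}
```

With all-zero queries every key gets equal weight, so each output row is the mean of V, which makes a convenient smoke test for the GPU versions.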
Version 0: Naive Attention (Full Materialization)
#include <cuda_runtime.h>
#include <cmath>
#include <cfloat>
// Step 1: Compute S = Q @ K^T
__global__ void compute_qk(const float* __restrict__ Q,
const float* __restrict__ K,
float* __restrict__ S,
int N, int d, float scale) {
int row = blockIdx.y * blockDim.y + threadIdx.y; // Query index
int col = blockIdx.x * blockDim.x + threadIdx.x; // Key index
if (row < N && col < N) {
float sum = 0.0f;
for (int k = 0; k < d; k++) {
sum += Q[row * d + k] * K[col * d + k];
}
S[row * N + col] = sum * scale;
}
}
// Step 2: Row-wise softmax of S
__global__ void softmax_rows(float* S, int N) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
if (row >= N) return;
// Find max for numerical stability
float max_val = -FLT_MAX;
for (int j = 0; j < N; j++) {
max_val = fmaxf(max_val, S[row * N + j]);
}
// Compute exp and sum
float sum = 0.0f;
for (int j = 0; j < N; j++) {
S[row * N + j] = expf(S[row * N + j] - max_val);
sum += S[row * N + j];
}
// Normalize
float inv_sum = 1.0f / sum;
for (int j = 0; j < N; j++) {
S[row * N + j] *= inv_sum;
}
}
// Step 3: O = P @ V
__global__ void compute_pv(const float* __restrict__ P,
const float* __restrict__ V,
float* __restrict__ O,
int N, int d) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;
if (row < N && col < d) {
float sum = 0.0f;
for (int k = 0; k < N; k++) {
sum += P[row * N + k] * V[k * d + col];
}
O[row * d + col] = sum;
}
}
void naive_attention(const float* d_Q, const float* d_K, const float* d_V,
float* d_O, float* d_S, int N, int d) {
float scale = 1.0f / sqrtf((float)d);
dim3 block(32, 32);
dim3 grid_qk((N + 31) / 32, (N + 31) / 32);
compute_qk<<<grid_qk, block>>>(d_Q, d_K, d_S, N, d, scale);
softmax_rows<<<(N + 255) / 256, 256>>>(d_S, N);
dim3 grid_pv((d + 31) / 32, (N + 31) / 32);
compute_pv<<<grid_pv, block>>>(d_S, d_V, d_O, N, d);
}
Naive Attention: Memory and Compute (A100)
| Seq Length | Head Dim | Attention Matrix | Total Memory | Time (ms) |
|---|---|---|---|---|
| 1024 | 128 | 4 MB | ~5 MB | 0.8 |
| 4096 | 128 | 64 MB | ~66 MB | 12.4 |
| 8192 | 128 | 256 MB | ~260 MB | 49.2 |
| 16384 | 128 | 1 GB | OOM | N/A |
Version 1: Fused QK + Softmax + PV (Still Materializing S)
Fuse the three kernels to reduce global memory traffic:
// Fused: compute one row of attention output at a time
// Still materializes S, but row-by-row in registers
__global__ void attention_fused_rowwise(const float* __restrict__ Q,
const float* __restrict__ K,
const float* __restrict__ V,
float* __restrict__ O,
int N, int d, float scale) {
int query_idx = blockIdx.x; // One block per query
int tid = threadIdx.x; // Thread within block
extern __shared__ float smem[];
float* s_row = smem; // N floats for attention scores
float* s_q = smem + N; // d floats for query vector
// Load query vector to shared memory
for (int i = tid; i < d; i += blockDim.x) {
s_q[i] = Q[query_idx * d + i] * scale;
}
__syncthreads();
// Compute attention scores: s_row[j] = dot(Q[query_idx], K[j])
for (int j = tid; j < N; j += blockDim.x) {
float dot = 0.0f;
for (int k = 0; k < d; k++) {
dot += s_q[k] * K[j * d + k];
}
s_row[j] = dot;
}
__syncthreads();
// Softmax: find max
float local_max = -FLT_MAX;
for (int j = tid; j < N; j += blockDim.x) {
local_max = fmaxf(local_max, s_row[j]);
}
// Warp reduction for max
for (int offset = 16; offset > 0; offset >>= 1) {
float other = __shfl_down_sync(0xffffffff, local_max, offset);
local_max = fmaxf(local_max, other);
}
__shared__ float warp_max[32];
int warp_id = tid / 32;
int lane = tid & 31;
if (lane == 0) warp_max[warp_id] = local_max;
__syncthreads();
if (warp_id == 0) {
float val = (lane < blockDim.x / 32) ? warp_max[lane] : -FLT_MAX;
for (int offset = 16; offset > 0; offset >>= 1) {
float other = __shfl_down_sync(0xffffffff, val, offset);
val = fmaxf(val, other);
}
if (lane == 0) warp_max[0] = val;
}
__syncthreads();
float row_max = warp_max[0];
// Softmax: exp and sum
float local_sum = 0.0f;
for (int j = tid; j < N; j += blockDim.x) {
s_row[j] = expf(s_row[j] - row_max);
local_sum += s_row[j];
}
// Warp reduction for sum
for (int offset = 16; offset > 0; offset >>= 1) {
local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
}
__shared__ float warp_sum[32];
if (lane == 0) warp_sum[warp_id] = local_sum;
__syncthreads();
if (warp_id == 0) {
float val = (lane < blockDim.x / 32) ? warp_sum[lane] : 0.0f;
for (int offset = 16; offset > 0; offset >>= 1) {
val += __shfl_down_sync(0xffffffff, val, offset);
}
if (lane == 0) warp_sum[0] = val;
}
__syncthreads();
float inv_sum = 1.0f / warp_sum[0];
// Normalize
for (int j = tid; j < N; j += blockDim.x) {
s_row[j] *= inv_sum;
}
__syncthreads();
// Compute output: O[query_idx] = P[query_idx,:] @ V
for (int col = tid; col < d; col += blockDim.x) {
float sum = 0.0f;
for (int j = 0; j < N; j++) {
sum += s_row[j] * V[j * d + col];
}
O[query_idx * d + col] = sum;
}
}
This eliminates the global attention matrix but still stores the full attention row in shared memory ($N$ floats, plus $d$ for the query vector). For $N = 8192$, that is 32 KB per block — feasible on Ampere but limits occupancy.
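One detail the listing leaves implicit: the kernel declares extern shared memory, so the launch must pass a dynamic shared memory size of N + d floats. A helper for the arithmetic (our naming):

```cpp
#include <cassert>
#include <cstddef>

// Dynamic shared memory required by attention_fused_rowwise:
// N floats for the attention row plus d floats for the scaled query vector.
size_t fused_rowwise_smem_bytes(int N, int d) {
    return (size_t)(N + d) * sizeof(float);
}

// Launch sketch:
//   attention_fused_rowwise<<<N, 256, fused_rowwise_smem_bytes(N, d)>>>(...);
// For N = 8192, d = 128 this is 33,280 bytes (~32.5 KB); past N ~ 12K it
// exceeds the 48 KB default and needs the larger opt-in carve-out.
```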
The Online Softmax Algorithm
The key insight of FlashAttention: you do not need the entire attention row to compute softmax. You can compute softmax incrementally using the online softmax algorithm.
Maintain running statistics $(m, \ell, o)$ as you process blocks of keys. After seeing a new block with scores $s_j$ and values $v_j$:
$$m^{\text{new}} = \max\left(m, \max_j s_j\right)$$
$$\ell^{\text{new}} = \ell \, e^{m - m^{\text{new}}} + \sum_j e^{s_j - m^{\text{new}}}$$
$$o^{\text{new}} = o \, e^{m - m^{\text{new}}} + \sum_j e^{s_j - m^{\text{new}}} \, v_j$$
At the end:
$$O_i = \frac{o_i}{\ell_i}$$
This processes the attention computation in tiles of $B_c$ keys at a time, never materializing the full $N \times N$ matrix.
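The update rules are easy to sanity-check on the CPU: stream over a vector in blocks while maintaining the running $(m, \ell)$ pair, then compare against a two-pass computation. A sketch (function name ours):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Streaming log-sum-exp over x in blocks of size B, maintaining a running
// max m and running sum of exponentials ell, exactly as the kernel does
// per query row. Returns m + log(ell) = log(sum_j exp(x_j)).
float online_logsumexp(const std::vector<float>& x, size_t B) {
    float m = -INFINITY, ell = 0.0f;
    for (size_t start = 0; start < x.size(); start += B) {
        size_t end = std::min(x.size(), start + B);
        float tile_max = -INFINITY;
        for (size_t j = start; j < end; j++) tile_max = std::max(tile_max, x[j]);
        float new_m = std::max(m, tile_max);
        ell *= std::exp(m - new_m); // rescale the old sum to the new max
        for (size_t j = start; j < end; j++) ell += std::exp(x[j] - new_m);
        m = new_m;
    }
    return m + std::log(ell);
}
```

The result is independent of the block size B, which is exactly why FlashAttention is free to pick tile sizes for hardware reasons alone.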
Version 2: FlashAttention-Style Tiled Attention
#include <cuda_runtime.h>
#include <cfloat>
#include <cmath>
// Tile sizes
#define Br 64 // Query tile size (rows of Q processed per block)
#define Bc 64 // Key tile size (columns of K processed per inner loop iteration)
#define d_head 128 // Head dimension (compile-time for this example)
__global__ void flash_attention_forward(
const float* __restrict__ Q, // [N, d]
const float* __restrict__ K, // [N, d]
const float* __restrict__ V, // [N, d]
float* __restrict__ O, // [N, d]
float* __restrict__ L, // [batch_heads, N] — log-sum-exp for backward pass
int N, float scale)
{
int batch_head = blockIdx.y; // Combined batch and head index
int tile_q = blockIdx.x; // Which query tile
// Offset to this batch/head
const float* q = Q + batch_head * N * d_head;
const float* k = K + batch_head * N * d_head;
const float* v = V + batch_head * N * d_head;
float* o = O + batch_head * N * d_head;
float* l = L + batch_head * N;
int tid = threadIdx.x;
int q_start = tile_q * Br;
// Shared memory: the three tiles total 96 KB, above the 48 KB static
// limit, so carve them out of dynamically allocated shared memory
// (the launcher must opt in via cudaFuncSetAttribute).
extern __shared__ float smem[];
float (*s_Q)[d_head] = reinterpret_cast<float (*)[d_head]>(smem);                      // Query tile [Br][d_head]
float (*s_K)[d_head] = reinterpret_cast<float (*)[d_head]>(smem + Br * d_head);        // Key tile [Bc][d_head]
float (*s_V)[d_head] = reinterpret_cast<float (*)[d_head]>(smem + (Br + Bc) * d_head); // Value tile [Bc][d_head]
// Load query tile to shared memory
for (int i = tid; i < Br * d_head; i += blockDim.x) {
int r = i / d_head;
int c = i % d_head;
int global_r = q_start + r;
s_Q[r][c] = (global_r < N) ? q[global_r * d_head + c] * scale : 0.0f;
}
__syncthreads();
// Per-thread accumulators for output and softmax stats.
// A full Br x d_head accumulator is far too large for registers, so each
// thread owns exactly one of the Br query rows. This assumes
// blockDim.x == Br; an early return for excess threads would deadlock
// at the __syncthreads() calls inside the loop below.
// (In production FlashAttention, the thread-to-row mapping is more complex.)
int my_row = tid; // Thread tid handles query row tid within tile
float my_m = -FLT_MAX;
float my_ell = 0.0f;
float my_o[d_head];
for (int i = 0; i < d_head; i++) my_o[i] = 0.0f;
// Iterate over key/value tiles
int num_kv_tiles = (N + Bc - 1) / Bc;
for (int tile_kv = 0; tile_kv < num_kv_tiles; tile_kv++) {
int kv_start = tile_kv * Bc;
// Load K tile to shared memory
__syncthreads();
for (int i = tid; i < Bc * d_head; i += blockDim.x) {
int r = i / d_head;
int c = i % d_head;
int global_r = kv_start + r;
s_K[r][c] = (global_r < N) ? k[global_r * d_head + c] : 0.0f;
}
// Load V tile to shared memory
for (int i = tid; i < Bc * d_head; i += blockDim.x) {
int r = i / d_head;
int c = i % d_head;
int global_r = kv_start + r;
s_V[r][c] = (global_r < N) ? v[global_r * d_head + c] : 0.0f;
}
__syncthreads();
// Compute S[my_row][j] = Q[my_row] @ K[j]^T for j in current tile
float s_local[Bc];
float tile_max = -FLT_MAX;
for (int j = 0; j < Bc; j++) {
if (kv_start + j >= N) {
s_local[j] = -FLT_MAX;
continue;
}
float dot = 0.0f;
for (int dd = 0; dd < d_head; dd++) {
dot += s_Q[my_row][dd] * s_K[j][dd];
}
s_local[j] = dot;
tile_max = fmaxf(tile_max, dot);
}
// Online softmax update
float new_m = fmaxf(my_m, tile_max);
float correction = expf(my_m - new_m);
// Rescale previous accumulator
my_ell *= correction;
for (int dd = 0; dd < d_head; dd++) {
my_o[dd] *= correction;
}
// Add current tile's contribution
float tile_sum = 0.0f;
for (int j = 0; j < Bc; j++) {
if (kv_start + j >= N) continue;
float p_ij = expf(s_local[j] - new_m);
tile_sum += p_ij;
for (int dd = 0; dd < d_head; dd++) {
my_o[dd] += p_ij * s_V[j][dd];
}
}
my_m = new_m;
my_ell += tile_sum;
}
// Final normalization
int global_row = q_start + my_row;
if (global_row < N) {
float inv_ell = 1.0f / my_ell;
for (int dd = 0; dd < d_head; dd++) {
o[global_row * d_head + dd] = my_o[dd] * inv_ell;
}
l[global_row] = my_m + logf(my_ell); // Log-sum-exp for backward
}
}
The FlashAttention-style kernel never materializes the attention matrix. Memory usage is $O(Nd)$ in HBM for Q, K, V, and O, plus $O((B_r + 2B_c)d)$ per block in shared memory. For $N = 16384$, $d = 128$: naive attention requires 1 GB for the attention matrix alone; FlashAttention uses only 96 KB of shared memory per block (three 64 x 128 FP32 tiles), independent of sequence length. Note that 96 KB exceeds the default 48 KB carve-out, so the launch must request the larger dynamic allocation (on A100, up to 163 KB per block via cudaFuncAttributeMaxDynamicSharedMemorySize).
Optimizing the Inner Loop
The inner loop (QK dot product + softmax + PV accumulation) dominates runtime. Key optimizations:
Register Tiling
// Each thread owns one query row and accumulates all d_head output
// values in registers instead of shared memory.
// Thread block: 128 threads
// For d_head = 128, that is 128 floats per thread = 512 bytes.
// At 32 bits (one register) per float, that is 128 registers just for output.
// Plus intermediates: ~160 registers per thread.
// At 160 regs: 65536 regs/SM / 160 = 409 threads per SM = 12 whole warps
// of the 64 possible = 18.75% occupancy.
// This is acceptable because the kernel is compute-bound in the inner loop.
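The occupancy arithmetic generalizes to other register budgets. A small helper for the register-pressure bound (A100 figures; our naming, and a simplification that ignores per-warp allocation granularity, shared memory, and block-slot limits):

```cpp
#include <cassert>

// Warps per SM allowed by register pressure alone on A100 (SM 8.0):
// 65,536 32-bit registers per SM, 64-warp hardware maximum.
int warps_per_sm_from_regs(int regs_per_thread) {
    int threads = 65536 / regs_per_thread;
    int warps = threads / 32;        // only whole warps can be resident
    return warps < 64 ? warps : 64;  // cap at the hardware maximum
}

// 160 regs/thread -> 409 threads -> 12 warps -> 12/64 = 18.75% occupancy.
```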
Vectorized K/V Loads to Shared Memory
// Load K tile using float4 for 128-bit transactions
for (int i = tid; i < Bc * (d_head / 4); i += blockDim.x) {
int r = i / (d_head / 4);
int c4 = i % (d_head / 4);
int global_r = kv_start + r;
if (global_r < N) {
float4 val = reinterpret_cast<const float4*>(
&k[global_r * d_head])[c4];
s_K[r][c4 * 4 + 0] = val.x;
s_K[r][c4 * 4 + 1] = val.y;
s_K[r][c4 * 4 + 2] = val.z;
s_K[r][c4 * 4 + 3] = val.w;
}
}
FMA-Heavy Dot Product
// Use fmaf for fused multiply-add (single instruction, higher throughput)
float dot = 0.0f;
for (int dd = 0; dd < d_head; dd += 4) {
dot = fmaf(s_Q[my_row][dd], s_K[j][dd], dot);
dot = fmaf(s_Q[my_row][dd+1], s_K[j][dd+1], dot);
dot = fmaf(s_Q[my_row][dd+2], s_K[j][dd+2], dot);
dot = fmaf(s_Q[my_row][dd+3], s_K[j][dd+3], dot);
}
Causal (Autoregressive) Masking
For decoder-style attention, add a causal mask that prevents attending to future tokens:
// In the inner loop, after computing dot products:
for (int j = 0; j < Bc; j++) {
int key_pos = kv_start + j;
int query_pos = q_start + my_row;
if (key_pos > query_pos) {
// Causal mask: future tokens get -inf
s_local[j] = -FLT_MAX;
} else if (key_pos >= N) {
s_local[j] = -FLT_MAX;
}
tile_max = fmaxf(tile_max, s_local[j]);
}
// Optimization: skip entire KV tiles where kv_start > q_start + Br - 1
// (all keys in this tile are in the future for all queries in this tile).
// Place this check at the TOP of the tile_kv loop, before the K/V loads;
// it is safe because the condition is uniform across the whole block.
if (kv_start > q_start + Br - 1) {
continue; // Skip this tile entirely
}
This early-exit optimization cuts the number of tiles processed roughly in half for causal attention, since just under half the tiles are fully masked.
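The "roughly half" figure can be checked by counting. A host-side count of KV tiles that survive the early exit, mirroring the kernel's skip condition (function name ours):

```cpp
#include <cassert>

// Number of (query tile, KV tile) pairs actually processed with causal
// early exit: a KV tile starting at kv is skipped when kv > q + Br - 1.
long causal_tiles_processed(int N, int tile_q_size, int tile_k_size) {
    long count = 0;
    for (int q = 0; q < N; q += tile_q_size)
        for (int kv = 0; kv < N; kv += tile_k_size)
            if (kv <= q + tile_q_size - 1) count++;
    return count;
}

// For N = 8192, Br = Bc = 64: 8,256 of 16,384 tiles are processed (~50.4%).
// The extra ~0.4% is the diagonal tiles, which are partially masked but
// cannot be skipped outright.
```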
Multi-Head Attention Integration
// Full multi-head attention launcher
void multi_head_attention(
const float* d_Q, // [batch, num_heads, seq_len, d_head]
const float* d_K, // [batch, num_heads, seq_len, d_head]
const float* d_V, // [batch, num_heads, seq_len, d_head]
float* d_O, // [batch, num_heads, seq_len, d_head]
float* d_L, // [batch, num_heads, seq_len]
int batch, int num_heads, int seq_len, int d_head_dim,
bool causal)
{
// The kernel compiles with fixed Br/Bc/d_head, so d_head_dim must match
// the compile-time d_head. The `causal` flag requires the masked kernel
// variant from the previous section and is ignored here.
float scale = 1.0f / sqrtf((float)d_head_dim);
int num_q_tiles = (seq_len + Br - 1) / Br;
int total_batch_heads = batch * num_heads;
dim3 grid(num_q_tiles, total_batch_heads);
dim3 block(Br); // One thread per query row in tile
// Shared memory: Q tile + K tile + V tile = 96 KB, above the 48 KB
// default, so opt in to the larger dynamic carve-out (A100: up to 163 KB).
size_t smem = (Br * d_head_dim + 2 * Bc * d_head_dim) * sizeof(float);
cudaFuncSetAttribute(flash_attention_forward,
                     cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
flash_attention_forward<<<grid, block, smem>>>(
    d_Q, d_K, d_V, d_O, d_L, seq_len, scale);
}
FP16 Attention with Tensor Cores
For production performance, use FP16 inputs with FP32 accumulation on tensor cores via wmma or mma:
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda::wmma;
// Simplified FP16 attention tile using WMMA; launch with one warp
// (32 threads) per block, since WMMA operations are warp-collective.
// Each warp computes a 16x16 output tile of QK^T
__global__ void attention_fp16_wmma(
const half* __restrict__ Q, // [N, d]
const half* __restrict__ K, // [N, d]
const half* __restrict__ V, // [N, d]
half* __restrict__ O, // [N, d]
int N, int d, float scale)
{
// WMMA fragment declarations for 16x16x16 matrix multiply
fragment<matrix_a, 16, 16, 16, half, row_major> frag_Q;
fragment<matrix_b, 16, 16, 16, half, col_major> frag_K;
fragment<accumulator, 16, 16, 16, float> frag_S;
// Initialize accumulator
fill_fragment(frag_S, 0.0f);
// Compute S[16x16] = Q[16xd] @ K[16xd]^T using WMMA tiles of 16x16x16
int q_row = blockIdx.y * 16;
int k_row = blockIdx.x * 16;
for (int kk = 0; kk < d; kk += 16) {
load_matrix_sync(frag_Q, Q + q_row * d + kk, d);
load_matrix_sync(frag_K, K + k_row * d + kk, d);
mma_sync(frag_S, frag_Q, frag_K, frag_S);
}
// Apply scale
for (int i = 0; i < frag_S.num_elements; i++) {
frag_S.x[i] *= scale;
}
// ... softmax and PV multiply follow the same online pattern
}
Real FlashAttention implementations (FlashAttention-2, FlashAttention-3) use tensor core mma instructions for the QK and PV matrix multiplies, achieving 50-70% of peak TFLOPS. The FP32 version in this post is for pedagogical clarity — production code should use FP16/BF16 with wmma or inline PTX mma instructions.
Performance Comparison
Attention Kernel Performance (A100, batch=1, heads=32, d=128)
| Implementation | Seq=1024 (ms) | Seq=4096 (ms) | Seq=8192 (ms) | Memory |
|---|---|---|---|---|
| Naive (3 kernels) | 0.8 | 12.4 | 49.2 | O(N^2) |
| Fused row-wise | 0.6 | 8.8 | 34.1 | O(N) smem |
| Tiled (this post, FP32) | 0.4 | 4.2 | 15.8 | O(Br*Bc) smem |
| FlashAttention-2 (FP16) | 0.12 | 0.9 | 3.2 | O(Br*Bc) smem |
| cuDNN attention (FP16) | 0.10 | 0.8 | 2.9 | O(Br*Bc) |
[Chart omitted: attention latency at sequence length 4096, d = 128, in ms across 32 heads, for the implementations in the table above.]
Backward Pass Sketch
The backward pass recomputes the attention matrix from $Q$ and $K$ (using the saved log-sum-exp $L$) rather than storing it:
// Forward saves: O (output) and L (log-sum-exp per row)
// Backward receives: dO (gradient of output)
// Backward computes: dQ, dK, dV
// Key insight: recompute S = QK^T/sqrt(d) and P = softmax(S) from Q,K,L
// This trades compute for memory — no N x N storage needed
// dV = P^T @ dO
// dP = dO @ V^T
// dS = P * (dP - rowsum(dP * P))   [softmax backward]
// dQ = (dS @ K) / sqrt(d)          [the 1/sqrt(d) scale propagates through S]
// dK = (dS^T @ Q) / sqrt(d)
// Same tiling strategy as forward, with online recomputation of P
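Of these lines, the softmax backward is the one most often gotten wrong; it can be verified numerically on the CPU against finite differences. A sketch for a single row (helper names ours):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Softmax of one row, with the max subtracted for stability.
std::vector<float> softmax(const std::vector<float>& s) {
    float m = -INFINITY;
    for (float v : s) m = std::max(m, v);
    std::vector<float> p(s.size());
    float sum = 0.0f;
    for (size_t j = 0; j < s.size(); j++) { p[j] = std::exp(s[j] - m); sum += p[j]; }
    for (float& v : p) v /= sum;
    return p;
}

// Analytic softmax backward for one row: dS = P * (dP - sum_j dP_j P_j).
std::vector<float> softmax_backward(const std::vector<float>& p,
                                    const std::vector<float>& dp) {
    float inner = 0.0f;
    for (size_t j = 0; j < p.size(); j++) inner += dp[j] * p[j];
    std::vector<float> ds(p.size());
    for (size_t j = 0; j < p.size(); j++) ds[j] = p[j] * (dp[j] - inner);
    return ds;
}
```

Treating dot(dP, softmax(s)) as a scalar loss, central differences on s should match the analytic gradient to within finite-difference error.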
Design Decisions for Custom Attention Kernels
When should you write a custom attention kernel versus using FlashAttention or cuDNN?
When to Write Custom vs Use Library
| Scenario | Recommendation | Reason |
|---|---|---|
| Standard MHA/GQA | Use FlashAttention-2/3 | Highly optimized, battle-tested |
| Custom masking pattern | Consider custom | Sliding window, block-sparse, etc. |
| Fused attention + bias | Consider custom | ALiBi, relative position encoding |
| Quantized KV cache | Custom required | INT4/INT8 K/V not in standard libraries |
| Non-standard head dim | May need custom | d != 64, 128, 256 |
| Research prototyping | Use Triton | Faster iteration than CUDA C++ |
FlashAttention-2 and cuDNN 9.x attention cover the vast majority of use cases with near-optimal performance. Write a custom CUDA kernel only when you have a non-standard attention pattern (e.g., sparse attention, fused ALiBi, quantized KV cache) that the libraries do not support. For prototyping, Triton is 10x faster to iterate on than raw CUDA.
Summary
Custom attention kernel development progresses from naive (three separate kernels, $O(N^2)$ memory) through fused row-wise computation (single kernel, $O(N)$ shared memory) to the FlashAttention-style tiled approach with online softmax ($O(B_r \cdot B_c)$ shared memory). The online softmax algorithm maintains running max and sum-of-exponentials statistics, allowing the attention matrix to be processed in tiles without materialization. The key optimization dimensions are: tile size selection (balancing shared memory capacity against parallelism), vectorized loads for K/V tiles, FMA-dense dot products, causal mask tile skipping, and FP16 tensor core utilization. For production, FlashAttention-2/3 or cuDNN attention should be the default; write custom kernels for non-standard attention patterns, fused bias computation, or quantized KV caches.