When you deploy vLLM, it doesn’t ask which attention backend you want. It looks at your GPU architecture, checks your model’s data type and head dimensions, and silently picks one of four kernels — FlashAttention-2, FlashAttention-3, PagedAttention v2, or FlashInfer. Get this choice wrong and your H100 GPU performs like an A100. Get it right and you extract every last TFLOP the hardware can deliver. The selection happens once at startup, buried in selector.py, and most users never see the decision tree. This post makes it visible: given your hardware and workload, you’ll know exactly which kernel vLLM uses and why — and whether changing it would help.
vLLM v1 does not use a single attention kernel. It dispatches to one of four backends depending on the GPU architecture, the data type, the sequence length, the attention pattern, and the phase (prefill vs. decode). The selection happens at model initialization time, not per-request, because switching kernels mid-flight would require reallocating workspace buffers and rebuilding CUDA graphs.
This post traces the selection logic in vllm/attention/selector.py, explains what each backend does differently at the CUDA kernel level, and benchmarks all four across representative workloads.
The Four Backends
vLLM v1 supports four attention computation backends. Each implements the same mathematical operation — scaled dot-product attention — but with radically different memory access patterns, tiling strategies, and hardware utilization.
FlashAttention-2
FlashAttention-2 (Dao, 2023) rewrites the attention computation to avoid materializing the full attention matrix in HBM. Instead, it tiles the computation across SRAM (shared memory) and accumulates partial softmax results using the online softmax algorithm. The kernel processes Q, K, V in tiles (typically 64x64 or 128x128, depending on head dimension), keeping only one tile of the score matrix in SRAM at a time.
The key data flow:
- Load a block of Q rows from HBM to SRAM
- Iterate over K, V blocks: load K block, compute tile, apply causal mask, compute local softmax, multiply by V block, accumulate
- Write the final output block back to HBM
Total HBM reads: one pass over Q, plus repeated passes over the K and V tiles. Total HBM writes: one pass over the output. The intermediate N x N attention matrix never touches HBM at all; this is where the memory savings come from.
// FlashAttention-2 kernel pseudocode (simplified from flash_fwd_kernel.h).
// Each thread block processes kBlockM rows of Q against all K,V blocks.
// Helper routines (load_tile, gemm_in_sram, ...) are schematic.
template <int kBlockM, int kBlockN, int kHeadDim>
__global__ void flash_fwd_kernel(
    const half* __restrict__ Q,   // [batch, seqlen_q, num_heads, head_dim]
    const half* __restrict__ K,   // [batch, seqlen_k, num_kv_heads, head_dim]
    const half* __restrict__ V,   // [batch, seqlen_k, num_kv_heads, head_dim]
    half* __restrict__ O,         // [batch, seqlen_q, num_heads, head_dim]
    float* __restrict__ L,        // [batch, num_heads, seqlen_q] - logsumexp for online softmax
    const int seqlen_q,
    const int seqlen_k,
    const float softmax_scale
) {
    // Each thread block handles kBlockM rows of Q
    const int q_start = blockIdx.x * kBlockM;

    // Shared memory for Q, K, V tiles
    __shared__ half sQ[kBlockM][kHeadDim];
    __shared__ half sK[kBlockN][kHeadDim];
    __shared__ half sV[kBlockN][kHeadDim];

    // Load the Q tile once (stays in SRAM for the entire K,V iteration)
    load_tile(Q, sQ, q_start, kBlockM, kHeadDim);

    // Accumulators (held in registers, distributed across the block's threads)
    float acc[kBlockM][kHeadDim] = {0};
    float row_max[kBlockM];
    float row_sum[kBlockM] = {0};
    for (int i = 0; i < kBlockM; i++) row_max[i] = -INFINITY;

    // Iterate over K,V blocks
    for (int kv_start = 0; kv_start < seqlen_k; kv_start += kBlockN) {
        // Load the K tile
        load_tile(K, sK, kv_start, kBlockN, kHeadDim);
        __syncthreads();

        // Compute S = Q @ K^T (in SRAM)
        float S[kBlockM][kBlockN];
        gemm_in_sram(sQ, sK, S, softmax_scale);  // S[i][j] = sum_d(sQ[i][d] * sK[j][d]) * scale

        // Apply causal mask
        apply_causal_mask(S, q_start, kv_start, kBlockM, kBlockN);

        // Online softmax: update the running max
        float prev_max[kBlockM];
        copy(row_max, prev_max, kBlockM);
        update_row_max(S, row_max, kBlockM, kBlockN);

        // Rescale the previous accumulator to the new running max
        for (int i = 0; i < kBlockM; i++) {
            const float scale_factor = expf(prev_max[i] - row_max[i]);
            row_sum[i] *= scale_factor;
            for (int d = 0; d < kHeadDim; d++) {
                acc[i][d] *= scale_factor;
            }
        }

        // Compute P = exp(S - row_max) and extend the running sum
        float P[kBlockM][kBlockN];
        for (int i = 0; i < kBlockM; i++) {
            for (int j = 0; j < kBlockN; j++) {
                P[i][j] = expf(S[i][j] - row_max[i]);
                row_sum[i] += P[i][j];
            }
        }

        // Load the V tile and accumulate: acc += P @ V
        load_tile(V, sV, kv_start, kBlockN, kHeadDim);
        __syncthreads();
        gemm_accumulate(P, sV, acc, kBlockM, kBlockN, kHeadDim);
    }

    // Final normalization: O = acc / row_sum
    // (Indexing below omits batch and head strides for brevity.)
    for (int i = 0; i < kBlockM; i++) {
        for (int d = 0; d < kHeadDim; d++) {
            O[(q_start + i) * kHeadDim + d] = __float2half(acc[i][d] / row_sum[i]);
        }
    }
}
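The rescale-and-accumulate trick in the loop above is easy to verify numerically. Here is a minimal NumPy sketch (single query row, no masking, no softmax scale, nothing vLLM-specific) showing that streaming over K,V tiles with online softmax reproduces the full-matrix result:

```python
import numpy as np

def tiled_softmax_attention(q, k, v, block_n):
    """Online-softmax attention over K/V tiles, mirroring the kernel loop.

    q: (d,) single query row; k, v: (n, d). Pure-NumPy illustration.
    """
    acc = np.zeros_like(v[0], dtype=np.float64)
    row_max = -np.inf
    row_sum = 0.0
    for start in range(0, k.shape[0], block_n):
        s = k[start:start + block_n] @ q        # scores for this tile
        new_max = max(row_max, s.max())
        scale = np.exp(row_max - new_max)       # rescale the previous accumulator
        acc *= scale
        row_sum *= scale
        p = np.exp(s - new_max)
        row_sum += p.sum()
        acc += p @ v[start:start + block_n]
        row_max = new_max
    return acc / row_sum

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((256, 64))
v = rng.standard_normal((256, 64))

# Reference: materialize the full softmax
w = np.exp(k @ q - (k @ q).max())
ref = (w / w.sum()) @ v

assert np.allclose(tiled_softmax_attention(q, k, v, block_n=32), ref)
```

The same identity holds for any tile size, which is why the kernel is free to pick kBlockN purely for hardware reasons.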
FlashAttention-2 runs on Ampere (SM80+) and Hopper (SM90+) GPUs with FP16 or BF16. It supports standard multi-head attention (MHA), multi-query attention (MQA), and grouped-query attention (GQA) natively. It does not support FP8.
FlashAttention-3
FlashAttention-3 (Shah et al., 2024) targets Hopper GPUs exclusively. It exploits three Hopper-specific hardware features that do not exist on Ampere:
WGMMA (Warpgroup Matrix Multiply-Accumulate): Hopper introduces warpgroup-level matrix instructions in which four warps cooperatively issue one large matmul (64 rows by up to 256 columns per instruction), compared to Ampere's mma.sync, which operates on small per-warp fragments (e.g., 16x8x16). A single WGMMA instruction achieves higher throughput per clock cycle.
TMA (Tensor Memory Accelerator): A hardware unit that handles async memory copies between global memory and shared memory without consuming SM compute resources. On Ampere, loading a tile from HBM to SRAM requires explicit cp.async instructions that still consume warp scheduler slots. On Hopper, TMA offloads the copy entirely to a dedicated engine.
Warp Specialization: FlashAttention-3 divides warps within a thread block into producer warps (that load data via TMA) and consumer warps (that compute matmuls via WGMMA). Producers and consumers operate concurrently — while consumers compute on the current tile, producers prefetch the next tile. This hides memory latency behind compute.
// FlashAttention-3: Hopper warp-specialization pattern (schematic)
// Producer warps handle TMA loads, consumer warps handle WGMMA compute
__global__ void flash3_fwd_kernel_hopper(/* ... */) {
    const int warp_id = threadIdx.x / 32;
    const int num_warps = blockDim.x / 32;

    // Warp role assignment
    const bool is_producer = (warp_id < num_warps / 4);  // 25% producers
    const bool is_consumer = !is_producer;               // 75% consumers

    if (is_producer) {
        // TMA-based async loads -- do NOT stall compute warps
        for (int kv_block = 0; kv_block < num_kv_blocks; kv_block++) {
            const int buf = kv_block % 2;  // double buffering
            // Wait until consumers are done with this buffer before overwriting it
            if (kv_block >= 2) wait_barrier(compute_barrier[buf]);
            // Issue TMA descriptor-based copies: a dedicated DMA engine
            // moves data from HBM to shared memory
            tma_load_async(&smem_K[buf], &gmem_K[kv_block]);
            tma_load_async(&smem_V[buf], &gmem_V[kv_block]);
            // Signal consumers that the data is ready
            arrive_barrier(load_barrier[buf]);
        }
    } else {
        // Consumer warps: WGMMA compute on tiles in shared memory
        for (int kv_block = 0; kv_block < num_kv_blocks; kv_block++) {
            const int buf = kv_block % 2;
            // Wait for the producer to finish loading this buffer
            wait_barrier(load_barrier[buf]);
            // WGMMA: one large warpgroup matmul per instruction
            wgmma_mma_async(acc_S, smem_Q, smem_K[buf]);
            // Online softmax + V accumulation via WGMMA
            online_softmax_and_accumulate(acc_S, smem_V[buf], acc_O);
            // Signal the producer that this buffer can be reused
            arrive_barrier(compute_barrier[buf]);
        }
    }
}
FlashAttention-3 also adds native FP8 support on Hopper. FP8 (E4M3 for the forward pass, E5M2 for the backward pass) halves the memory traffic compared to FP16, giving nearly 2x throughput for bandwidth-bound decode operations. The kernel handles FP8 quantization and dequantization internally, with per-block scaling factors to maintain accuracy.
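To build intuition for why per-block scaling preserves accuracy, here is a toy NumPy sketch. The rounding function is a crude stand-in for E4M3 (4 significant bits, no subnormals), not the real datatype or the FA3 kernel's scheme; it only illustrates that scaling each block to the representable range bounds the relative error:

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def fake_e4m3(x):
    """Crude stand-in for E4M3 rounding: keep 4 significant bits.
    Ignores subnormals; illustration only, not the real format."""
    m, e = np.frexp(x)                 # x = m * 2**e with 0.5 <= |m| < 1
    m = np.round(m * 16) / 16
    return np.clip(m * np.exp2(e), -E4M3_MAX, E4M3_MAX)

def quant_dequant_per_block(x, block=128):
    """Scale each block so its max magnitude maps to E4M3_MAX, quantize,
    then dequantize -- the per-block scaling idea described above."""
    blocks = x.reshape(-1, block)
    scales = np.maximum(np.abs(blocks).max(axis=1, keepdims=True), 1e-12) / E4M3_MAX
    return (fake_e4m3(blocks / scales) * scales).reshape(x.shape)

rng = np.random.default_rng(0)
x = rng.standard_normal(1024)
y = quant_dequant_per_block(x)
rel_err = np.abs(y - x) / np.maximum(np.abs(x), 1e-30)
assert rel_err.max() <= 1 / 16  # 4 significant bits -> <= 2**-4 relative error
```

With a single tensor-wide scale instead, one outlier would force every other value toward zero; per-block scales confine that damage to one block.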
PagedAttention v2
PagedAttention is vLLM’s original attention kernel, designed specifically for the decode phase where KV cache is stored in non-contiguous physical blocks. Unlike FlashAttention (which assumes contiguous K, V tensors), PagedAttention reads K and V through a block table indirection.
// PagedAttention v2 decode kernel (simplified)
// Each thread block handles one query head for one sequence
template <int HEAD_DIM, int BLOCK_SIZE, int NUM_WARPS>
__global__ void paged_attention_v2_kernel(
    const half* __restrict__ Q,            // [num_seqs, num_heads, head_dim]
    const half* __restrict__ K_cache,      // [num_blocks, block_size, num_kv_heads, head_dim]
    const half* __restrict__ V_cache,      // [num_blocks, block_size, num_kv_heads, head_dim]
    half* __restrict__ O,                  // [num_seqs, num_heads, head_dim]
    const int* __restrict__ block_tables,  // [num_seqs, max_blocks_per_seq]
    const int* __restrict__ context_lens,  // [num_seqs]
    const int max_blocks_per_seq,
    const float softmax_scale
) {
    const int seq_idx = blockIdx.x;
    const int head_idx = blockIdx.y;
    const int context_len = context_lens[seq_idx];
    const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE;

    // Per-token scores staged in shared memory (simplified)
    extern __shared__ float scores[];

    // Phase 1: compute QK^T scores for all tokens
    float qk_max = -INFINITY;

    // Load the query vector (one row, shared across all warps)
    half q_vec[HEAD_DIM];
    load_query(Q, q_vec, seq_idx, head_idx);

    // Each warp processes a strided subset of the paged KV blocks
    for (int block_idx = threadIdx.x / 32; block_idx < num_blocks; block_idx += NUM_WARPS) {
        // Look up the physical block ID from the block table
        const int physical_block = block_tables[seq_idx * max_blocks_per_seq + block_idx];

        // Read K from the non-contiguous physical block
        for (int tok = 0; tok < BLOCK_SIZE; tok++) {
            const int global_tok_idx = block_idx * BLOCK_SIZE + tok;
            if (global_tok_idx >= context_len) break;
            // Dot product: q . k[tok]
            const float score = dot_product(
                q_vec,
                &K_cache[physical_block * BLOCK_SIZE * HEAD_DIM + tok * HEAD_DIM],
                HEAD_DIM
            ) * softmax_scale;
            qk_max = fmaxf(qk_max, score);
            scores[global_tok_idx] = score;
        }
    }

    // Phase 2: cross-warp reduction of the max, then softmax, then weighted V sum
    // (v2 splits this into a two-pass approach for better parallelism)
    // ...
}
The “v2” in PagedAttention v2 refers to the two-pass approach: the first pass computes partial softmax results per warp, the second pass reduces across warps. v1 used a single-pass approach that required atomic operations and had lower throughput at long context lengths.
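The two-pass approach works because partial softmax statistics are mergeable: each partition only needs to record its local max, its exp-sum, and its unnormalized output, and a second pass can combine them into the exact result. A NumPy sketch of that merge (single query row, no scale factor; schematic, not the kernel's actual reduction code):

```python
import numpy as np

def partial(q, k, v):
    """First pass: per-partition (local max, exp-sum, unnormalized output)."""
    s = k @ q
    m = s.max()
    p = np.exp(s - m)
    return m, p.sum(), p @ v

def merge(parts):
    """Second pass: combine partials into the exact softmax-weighted sum."""
    m = max(p[0] for p in parts)
    total, acc = 0.0, 0.0
    for part_max, part_sum, part_out in parts:
        w = np.exp(part_max - m)   # rescale each partition to the global max
        total += w * part_sum
        acc = acc + w * part_out
    return acc / total

rng = np.random.default_rng(0)
q = rng.standard_normal(64)
k = rng.standard_normal((200, 64))
v = rng.standard_normal((200, 64))

# Reference: full, unpartitioned softmax
s = k @ q
w = np.exp(s - s.max())
ref = (w / w.sum()) @ v

parts = [partial(q, k[i:i + 64], v[i:i + 64]) for i in range(0, 200, 64)]
assert np.allclose(merge(parts), ref)
```

Because the merge is cheap relative to the score computation, v2 can parallelize phase one across warps (or across sequence partitions) without atomics.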
PagedAttention does not require contiguous K, V. This makes it the only backend that can serve decode requests without copying KV cache into contiguous buffers.
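The block-table indirection itself is simple to picture in NumPy. This sketch (hypothetical sizes, not vLLM code) reassembles a sequence's logically contiguous K tensor from physically scattered cache blocks, exactly the lookup the kernel performs per token:

```python
import numpy as np

BLOCK_SIZE = 16
HEAD_DIM = 8

# Physical KV cache: blocks live at arbitrary slots
rng = np.random.default_rng(0)
k_cache = rng.standard_normal((10, BLOCK_SIZE, HEAD_DIM))

# A 40-token sequence mapped to physical blocks 7, 2, 5 (last one partially full)
block_table = [7, 2, 5]
context_len = 40

def gather_keys(k_cache, block_table, context_len):
    """Reassemble a sequence's contiguous K tensor through the block table."""
    rows = []
    for logical_idx, physical in enumerate(block_table):
        start = logical_idx * BLOCK_SIZE
        n = min(BLOCK_SIZE, context_len - start)
        if n <= 0:
            break
        rows.append(k_cache[physical, :n])
    return np.concatenate(rows)

k = gather_keys(k_cache, block_table, context_len)
assert k.shape == (context_len, HEAD_DIM)
assert np.array_equal(k[16:32], k_cache[2])      # second logical block -> physical 2
assert np.array_equal(k[32:40], k_cache[5, :8])  # last block partially filled
```

FlashAttention's contiguous-tensor assumption would require materializing `k` like this before every call; PagedAttention and FlashInfer instead follow the indirection inside the kernel.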
FlashInfer
FlashInfer (Ye et al., 2024) is an external library that provides both prefill and decode attention kernels with PagedAttention support. Its key differentiator: it generates specialized CUDA kernels at runtime using a JIT compilation approach, tuning tile sizes and memory access patterns for the specific GPU architecture and problem size.
# FlashInfer backend integration in vLLM (simplified)
import torch
import flashinfer

class FlashInferBackend:
    def __init__(self, num_heads, num_kv_heads, head_dim, dtype, page_size):
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.page_size = page_size
        self.workspace_buffer = torch.empty(
            128 * 1024 * 1024, dtype=torch.uint8, device="cuda"
        )
        # Prefill: ragged tensor format (variable-length sequences packed together)
        self.prefill_wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
            self.workspace_buffer, kv_layout="NHD"
        )
        # Decode: one query token per sequence, paged KV
        self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
            self.workspace_buffer, kv_layout="NHD"
        )

    def forward_prefill(self, q, kv_cache, block_tables, seq_lens):
        # Plan: precompute memory access patterns for this batch
        self.prefill_wrapper.plan(
            qo_indptr=compute_indptr(seq_lens),
            paged_kv_indptr=compute_kv_indptr(block_tables),
            paged_kv_indices=flatten_block_tables(block_tables),
            paged_kv_last_page_len=compute_last_page_lens(seq_lens),
            num_qo_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            page_size=self.page_size,
        )
        return self.prefill_wrapper.run(q, kv_cache)

    def forward_decode(self, q, kv_cache, block_tables, seq_lens):
        self.decode_wrapper.plan(
            indptr=compute_indptr_ones(len(seq_lens)),
            paged_kv_indptr=compute_kv_indptr(block_tables),
            paged_kv_indices=flatten_block_tables(block_tables),
            paged_kv_last_page_len=compute_last_page_lens(seq_lens),
            num_qo_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            page_size=self.page_size,
        )
        return self.decode_wrapper.run(q, kv_cache)
FlashInfer’s plan/run separation is important: plan() precomputes all memory-access metadata (indirection pointers, page offsets) on the CPU, and run() launches the GPU kernel with no further CPU-side metadata work. When the batch composition does not change between iterations (common in decode), the plan can be cached and reused.
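The indptr arrays that plan() consumes are plain CSR-style prefix sums over per-sequence lengths: entry i marks where sequence i begins in the packed ragged tensor. A sketch of what helpers like the compute_indptr above (which are schematic here, not FlashInfer API) would return:

```python
import itertools

def indptr_from_lens(lens):
    """CSR-style offsets: indptr[i] is where sequence i's rows start in the
    packed (ragged) tensor; indptr[-1] is the total row count."""
    return [0] + list(itertools.accumulate(lens))

# Three prefill sequences of 5, 3, and 7 tokens packed into one ragged batch
assert indptr_from_lens([5, 3, 7]) == [0, 5, 8, 15]

# Decode: exactly one query token per sequence
assert indptr_from_lens([1, 1, 1]) == [0, 1, 2, 3]
```

The paged_kv_indptr array has the same shape but counts KV blocks per sequence rather than tokens, and paged_kv_indices is the flattened block table those offsets index into.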
Backend Feature Comparison
| Feature | FlashAttention-2 | FlashAttention-3 | PagedAttention v2 | FlashInfer |
|---|---|---|---|---|
| Min GPU Arch | Ampere (SM80) | Hopper (SM90) | Any CUDA | Ampere (SM80) |
| FP16/BF16 | Yes | Yes | Yes | Yes |
| FP8 | No | Yes (E4M3) | No | Yes (E4M3/E5M2) |
| Paged KV | No (contiguous) | No (contiguous) | Yes (native) | Yes (native) |
| GQA/MQA | Yes | Yes | Yes | Yes |
| MLA (DeepSeek) | No | No | Custom kernel | Yes |
| Prefill | Optimal | Optimal (Hopper) | Suboptimal | Good |
| Decode | Requires copy | Requires copy | Optimal | Good |
| WGMMA/TMA | No | Yes | No | Partial |
The Selection Logic
The backend selection runs once during model initialization in vllm/attention/selector.py. The function get_attn_backend() walks a decision tree that considers the environment, the hardware, and the model configuration.
The Decision Tree
import os
import torch

def get_attn_backend(
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    dtype: torch.dtype,
    block_size: int,
    is_attention_free: bool,
    use_mla: bool,
) -> AttentionBackend:
    """
    Selection priority:
      1. Environment override (VLLM_ATTENTION_BACKEND env var)
      2. Attention-free models (Mamba, RWKV) -> no backend needed
      3. MLA models (DeepSeek-V2/V3) -> special path
      4. Hopper + FP8 -> FlashAttention-3
      5. Hopper + FP16/BF16 -> FlashAttention-3 (if available) else FlashAttention-2
      6. Ampere + FP16/BF16 -> FlashAttention-2 (prefill) + PagedAttention v2 (decode)
      7. Fallback -> PagedAttention v2 (both phases)
    """
    # Step 1: Environment override
    env_backend = os.environ.get("VLLM_ATTENTION_BACKEND")
    if env_backend is not None:
        return _resolve_backend(env_backend)

    # Step 2: Attention-free models need no attention backend
    if is_attention_free:
        return PlaceholderAttentionBackend

    # Step 3: MLA (Multi-head Latent Attention) special path
    if use_mla:
        return _select_mla_backend(num_heads, num_kv_heads, head_dim, dtype)

    # Steps 4-7: Standard selection based on GPU arch and dtype
    gpu_arch = torch.cuda.get_device_capability()
    if gpu_arch >= (9, 0):    # Hopper (SM90+)
        return _select_hopper_backend(dtype, head_dim, block_size)
    elif gpu_arch >= (8, 0):  # Ampere (SM80+)
        return _select_ampere_backend(dtype, head_dim, block_size)
    else:
        # Pre-Ampere: only PagedAttention
        return PagedAttentionBackend

def _select_hopper_backend(dtype, head_dim, block_size):
    """Hopper GPUs: prefer FlashAttention-3 when available."""
    if dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # FP8: only FlashAttention-3 supports this
        if _is_flash3_available():
            return FlashAttention3Backend
        raise ValueError("FP8 attention requires FlashAttention-3 on Hopper")
    if dtype in (torch.float16, torch.bfloat16):
        # FP16/BF16: prefer FA3 for Hopper-specific optimizations
        if _is_flash3_available() and head_dim in (64, 128, 256):
            return FlashAttention3Backend
        elif _is_flash2_available():
            return FlashAttention2Backend
        else:
            return PagedAttentionBackend
    # FP32: no flash support
    return PagedAttentionBackend

def _select_ampere_backend(dtype, head_dim, block_size):
    """Ampere GPUs: FlashAttention-2 for prefill, PagedAttention for decode."""
    if dtype not in (torch.float16, torch.bfloat16):
        return PagedAttentionBackend
    if head_dim not in (64, 128, 256):
        # Flash kernels only support specific head dims
        return PagedAttentionBackend
    if _is_flash2_available():
        # Hybrid: FA2 for prefill (contiguous), Paged for decode (non-contiguous)
        return FlashAttention2Backend
    return PagedAttentionBackend
Summary
The backend selection in vLLM v1 is a function of hardware capability and workload characteristics. FlashAttention-3 on Hopper represents the current throughput ceiling for prefill, exploiting WGMMA, TMA, and warp specialization for up to 1.7x gains over FlashAttention-2. PagedAttention v2 remains the workhorse for decode, handling non-contiguous KV cache without copy overhead. FlashInfer fills the gaps: MLA support, non-standard head dimensions, and superior small-batch long-context decode.
The decision tree is deterministic and runs once at initialization. Understanding it means understanding why your deployment gets the throughput it does — and whether changing the backend could improve it.