A transformer decode step launches 80 kernels. At 10 microseconds per launch, that is 800 microseconds of CPU overhead. The actual GPU compute takes 2 milliseconds, so you are spending 29% of total latency on launch overhead. CUDA Graphs collapse those 80 launches into a single graph replay operation that costs under 10 microseconds, roughly an 80x reduction in launch overhead. For vLLM serving Llama-7B, graph capture reduces decode latency from 6.2 ms to 4.8 ms, a 23% speedup from eliminating launch tax alone. The catch: graphs require fixed memory addresses and fixed kernel parameters, so variable batch sizes require either multiple graphs or careful memory pre-allocation schemes.
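The arithmetic behind those numbers is worth making explicit. A back-of-envelope helper (pure Python, using the figures from the paragraph above; it assumes launches and compute fully serialize, which is the worst case since launches can partly overlap GPU work):

```python
def launch_overhead_fraction(num_kernels, launch_us, compute_ms):
    """Fraction of total decode latency spent on kernel launch overhead."""
    launch_ms = num_kernels * launch_us / 1000.0  # total CPU launch cost
    total_ms = launch_ms + compute_ms             # serialized worst case
    return launch_ms, launch_ms / total_ms

launch_ms, frac = launch_overhead_fraction(80, 10, 2.0)
print(f"{launch_ms:.1f} ms launch overhead, {frac:.0%} of total latency")
# 0.8 ms launch overhead, 29% of total latency
```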
CUDA Graphs solve the kernel launch problem by recording a sequence of kernel launches into a graph, then replaying the entire graph with a single CPU-side launch. The overhead drops from 1 millisecond to under 10 microseconds. The catch: CUDA Graphs require fixed memory addresses, fixed kernel parameters, and no CPU-side control flow during replay. This post covers the mechanics, the constraints, and the production strategies used by systems like vLLM.
Why Kernel Launch Overhead Matters
Quantifying the Overhead
import torch
import time

def measure_launch_overhead(num_kernels=80, tensor_size=4096):
    """Measure CPU overhead of launching many small kernels."""
    device = 'cuda'
    x = torch.randn(1, tensor_size, device=device, dtype=torch.float16)
    bias = torch.randn(tensor_size, device=device, dtype=torch.float16)
    weight = torch.randn(tensor_size, tensor_size, device=device,
                         dtype=torch.float16)

    # Warmup
    for _ in range(10):
        y = x
        for _ in range(num_kernels):
            y = y + bias  # Tiny kernel
    torch.cuda.synchronize()

    # Timed runs
    start = time.perf_counter()
    for _ in range(100):
        y = x
        for _ in range(num_kernels):
            y = y + bias
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / 100

    # Estimate per-kernel overhead
    # The actual add computation is negligible for 1x4096
    per_kernel_us = elapsed * 1e6 / num_kernels
    print(f"Total time for {num_kernels} kernels: "
          f"{elapsed*1e6:.0f} us")
    print(f"Per-kernel overhead: {per_kernel_us:.1f} us")
    print(f"This is {elapsed*1e6:.0f} us of CPU overhead per "
          f"decode step")

measure_launch_overhead()
The Decode Bottleneck
During autoregressive decoding, each generated token requires one full forward pass through the model. For a 7B parameter model with 32 layers, each forward pass launches approximately:
def count_kernels_per_decode(num_layers=32):
    """Count kernel launches in one decode step."""
    per_layer = {
        'rmsnorm_1': 1,
        'qkv_proj_gemm': 1,
        'rope': 1,
        'attention_score': 1,  # Or FlashDecoding
        'attention_softmax': 1,
        'attention_value': 1,
        'o_proj_gemm': 1,
        'residual_add_1': 1,
        'rmsnorm_2': 1,
        'gate_proj_gemm': 1,
        'up_proj_gemm': 1,
        'silu_mul': 1,
        'down_proj_gemm': 1,
        'residual_add_2': 1,
    }
    total = num_layers * sum(per_layer.values())
    print(f"Kernels per layer: {sum(per_layer.values())}")
    print(f"Total kernels per decode step: {total}")
    print(f"At 8 us/launch: {total * 8 / 1000:.1f} ms overhead")
    print(f"At 12 us/launch: {total * 12 / 1000:.1f} ms overhead")
    return total

count_kernels_per_decode(32)  # 7B model
print()
count_kernels_per_decode(80)  # 70B model
[Figure: Kernel Launch Overhead vs GPU Compute Time in ms (Llama-2 7B Decode, batch=1)]

CUDA Graph Capture Mechanics
Basic Capture and Replay
def basic_cuda_graph_example():
    """Demonstrate CUDA graph capture and replay."""
    device = 'cuda'

    # Allocate tensors BEFORE capture
    # These tensors must persist -- the graph stores their addresses
    x = torch.randn(1, 4096, device=device, dtype=torch.float16)
    weight = torch.randn(4096, 4096, device=device, dtype=torch.float16)
    bias = torch.randn(4096, device=device, dtype=torch.float16)
    output = torch.empty(1, 4096, device=device, dtype=torch.float16)

    # Warmup (required before capture)
    # CUDA lazily initializes kernel code on first launch
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        y = torch.matmul(x, weight.T) + bias
        _ = torch.nn.functional.gelu(y)
    torch.cuda.current_stream().wait_stream(s)

    # Capture the graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Everything inside this context is recorded, not executed
        y = torch.matmul(x, weight.T)
        y = y + bias
        # Write in place so the pre-allocated output keeps its address
        # (rebinding `output = gelu(y)` would discard the buffer)
        output.copy_(torch.nn.functional.gelu(y))

    # The graph is now captured. output tensor has a fixed address.
    # To run inference:
    # 1. Update input (write new data to the same tensor)
    x.copy_(torch.randn(1, 4096, device=device, dtype=torch.float16))
    # 2. Replay the graph (single CPU launch)
    graph.replay()
    # 3. Read output (from the same tensor used during capture)
    result = output.clone()  # Clone to get a snapshot
    print(f"Output shape: {result.shape}, "
          f"Output sum: {result.sum().item():.4f}")
    return graph, x, output

basic_cuda_graph_example()
CUDA Graphs record the GPU memory addresses of all tensors accessed during capture. After capture, you must not reallocate, resize, or move these tensors. Writing new values to the same tensor (via .copy_()) is fine, since the address stays the same. Creating new tensors or resizing existing ones breaks the graph.
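The rule has a plain-Python analogy, since what the graph records is object identity (the raw device pointer), not the values. An illustrative sketch with `id()` standing in for the pointer (no CUDA involved):

```python
buf = [0.0] * 4
recorded = id(buf)               # what a captured graph "remembers": the address

buf[:] = [1.0, 2.0, 3.0, 4.0]    # in-place write, like tensor.copy_()
assert id(buf) == recorded       # same object, same address: replay stays valid

old = buf                        # keep the old object alive
buf = [5.0, 6.0, 7.0, 8.0]       # rebinding, like allocating a new tensor
assert id(buf) != recorded       # new address: a replay would read stale memory
print("in-place writes preserve identity; reallocation does not")
```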
What Gets Captured
def what_gets_captured():
    """Explain what a CUDA graph records."""
    captured = {
        "Kernel launches": "Function pointer, grid/block dims, "
                           "shared memory size, arguments (pointers, scalars)",
        "Memory copies": "cudaMemcpyAsync (device-to-device, or "
                         "host-device with pinned memory)",
        "Stream operations": "Events, stream synchronization within "
                             "the graph",
        "Memory allocs (CUDA 11.4+)": "cudaMallocAsync within graph "
                                      "(allocated once, reused on replay)",
    }
    not_captured = {
        "CPU operations": "Python code, control flow (if/else), "
                          "printing, logging",
        "Host-device memcpy": "cudaMemcpy involving pageable host memory",
        "cuBLAS handle creation": "Must be created before capture",
        "NCCL operations": "Collective communication (allreduce, etc.) "
                           "-- needs a separate mechanism",
        "Dynamic allocations": "torch.empty() with new sizes "
                               "(creates new addresses)",
    }
    print("=== Captured by CUDA Graph ===")
    for item, desc in captured.items():
        print(f"  {item}: {desc}")
    print("\n=== NOT Captured ===")
    for item, desc in not_captured.items():
        print(f"  {item}: {desc}")
Memory Management for CUDA Graphs
Pre-Allocation Strategy
class CUDAGraphMemoryPool:
    """Pre-allocate memory for CUDA graph execution.

    All tensors used during graph capture must be pre-allocated.
    This class manages the memory pool.
    """

    def __init__(self, max_batch_size, max_seq_len, hidden_dim,
                 num_layers, num_heads, head_dim, dtype=torch.float16):
        self.device = 'cuda'
        self.dtype = dtype

        # Pre-allocate all intermediate tensors
        self.hidden_states = torch.empty(
            max_batch_size, 1, hidden_dim,  # decode: seq_len=1
            device=self.device, dtype=dtype
        )
        self.residual = torch.empty_like(self.hidden_states)

        # QKV projection outputs
        self.qkv = torch.empty(
            max_batch_size, 1, 3 * hidden_dim,
            device=self.device, dtype=dtype
        )

        # Attention output
        self.attn_output = torch.empty(
            max_batch_size, 1, hidden_dim,
            device=self.device, dtype=dtype
        )

        # MLP intermediates (SwiGLU: gate and up)
        mlp_intermediate_dim = int(hidden_dim * 2.6875)  # Llama-style
        self.gate_output = torch.empty(
            max_batch_size, 1, mlp_intermediate_dim,
            device=self.device, dtype=dtype
        )
        self.up_output = torch.empty_like(self.gate_output)
        self.mlp_output = torch.empty_like(self.hidden_states)

        # Logits output
        vocab_size = 128256  # Llama-3 vocab
        self.logits = torch.empty(
            max_batch_size, 1, vocab_size,
            device=self.device, dtype=dtype
        )

        total_bytes = sum(
            t.nelement() * t.element_size()
            for t in [self.hidden_states, self.residual, self.qkv,
                      self.attn_output, self.gate_output,
                      self.up_output, self.mlp_output, self.logits]
        )
        print(f"Pre-allocated {total_bytes / 1e6:.1f} MB for "
              f"CUDA graph workspace")

    def get_input_buffer(self):
        """Return the input buffer for writing new data."""
        return self.hidden_states
CUDA Memory Pool for Graph Allocations
def setup_cuda_graph_memory_pool():
    """Configure PyTorch memory allocator for CUDA graph compatibility.

    PyTorch's caching allocator can interfere with CUDA graphs
    because it may return different addresses for same-size
    allocations. The graph pool ID forces consistent allocation.
    """
    # Create a private memory pool for graph capture
    # All allocations during capture will use this pool
    pool = torch.cuda.graph_pool_handle()

    graph = torch.cuda.CUDAGraph()

    # Warmup on a side stream, then rejoin the main stream
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        # Run the operations once to populate the allocator cache
        pass
    torch.cuda.current_stream().wait_stream(s)

    # Capture with the pool
    with torch.cuda.graph(graph, pool=pool):
        # Allocations here are pinned to the pool
        pass

    return graph, pool
Dynamic Shape Handling
The Fundamental Problem
CUDA Graphs require fixed tensor shapes (and therefore fixed memory addresses). But LLM inference has dynamic shapes:
- Batch size varies as requests arrive and complete
- Sequence length grows during decode (but only by 1 each step)
- KV cache size depends on context length
class DynamicShapeStrategy:
    """Strategies for handling dynamic shapes with CUDA graphs."""

    def strategy_1_pad_to_max(self, actual_batch, max_batch=256):
        """Pad batch to maximum size. Wasteful but simple.

        - Capture one graph at max_batch_size
        - Pad actual input to max_batch_size
        - Mask output for actual batch entries

        Pro: Only one graph to maintain
        Con: Wastes compute on padding (e.g., batch=1 pays for batch=256)
        """
        pad_size = max_batch - actual_batch
        print(f"Padding: batch {actual_batch} -> {max_batch} "
              f"(wasting {pad_size/max_batch*100:.0f}% compute)")

    def strategy_2_bucketed_graphs(self, actual_batch, buckets=None):
        """Capture multiple graphs at specific bucket sizes.

        Bucket sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256]
        Round up actual batch to next bucket, pad the difference.

        Pro: Much less waste than max padding
        Con: N graphs to capture and store
        """
        if buckets is None:
            buckets = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        # Find smallest bucket >= actual_batch
        bucket = next(b for b in buckets if b >= actual_batch)
        waste_pct = (bucket - actual_batch) / bucket * 100
        print(f"Bucket: batch {actual_batch} -> {bucket} "
              f"(wasting {waste_pct:.0f}% compute)")
        return bucket

    def strategy_3_graph_per_shape(self, shapes_seen):
        """Cache a graph for each unique shape seen.

        Pro: No wasted compute
        Con: Potentially many graphs, high memory usage
        Used by: TensorRT-LLM (with a cap on cache size)
        """
        print(f"Cached graphs: {len(shapes_seen)}")
        for shape, count in shapes_seen.items():
            print(f"  batch={shape}: used {count} times")
vLLM's Graph Caching Strategy
class VLLMCUDAGraphRunner:
    """Simplified version of vLLM's CUDA graph strategy.

    vLLM captures graphs for a set of padded batch sizes and
    reuses them during decode. The key insight: during decode,
    sequence length is always 1 (one new token per sequence),
    so only batch size varies.

    Graph capture happens at server startup for each batch size
    bucket. During serving, the actual batch is padded up to
    the next bucket and the corresponding graph is replayed.
    """

    # Batch size buckets used by vLLM
    BATCH_SIZE_BUCKETS = [1, 2, 4, 8, 16, 24, 32, 48, 64,
                          96, 128, 192, 256, 384, 512]

    def __init__(self, model, max_batch_size=512):
        self.model = model
        self.graphs = {}          # bucket_size -> CUDAGraph
        self.input_buffers = {}   # bucket_size -> input tensor
        self.output_buffers = {}  # bucket_size -> output tensor
        # Only capture graphs for buckets up to max_batch_size
        self.buckets = [b for b in self.BATCH_SIZE_BUCKETS
                        if b <= max_batch_size]

    def capture_all(self):
        """Capture graphs for all batch size buckets.

        This runs at server startup. Each capture:
        1. Creates an input buffer at the bucket size
        2. Runs a warmup forward pass
        3. Captures the forward pass as a CUDA graph
        """
        for bucket in self.buckets:
            print(f"Capturing graph for batch_size={bucket}...")
            self._capture_one(bucket)
        print(f"Captured {len(self.buckets)} graphs")

    def _capture_one(self, batch_size):
        """Capture a single graph for a given batch size."""
        device = 'cuda'
        # Allocate the input buffer
        input_buf = torch.empty(
            batch_size, 1, self.model.hidden_size,
            device=device, dtype=torch.float16
        )
        # Warmup (required before capture)
        with torch.no_grad():
            self.model.decode_forward(input_buf)
        torch.cuda.synchronize()
        # Capture. The logits tensor produced inside the capture lives
        # in the graph's memory pool, so its address is fixed across
        # replays; we keep it as the persistent output buffer.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                output_buf = self.model.decode_forward(input_buf)
        self.graphs[batch_size] = graph
        self.input_buffers[batch_size] = input_buf
        self.output_buffers[batch_size] = output_buf

    def run(self, hidden_states):
        """Run inference using the appropriate cached graph.

        Args:
            hidden_states: [actual_batch, 1, hidden] tensor
        Returns:
            logits: [actual_batch, 1, vocab] tensor
        """
        actual_batch = hidden_states.shape[0]
        # Find the smallest bucket that fits
        bucket = self._find_bucket(actual_batch)
        if bucket is None:
            # Batch too large for any graph -- fall back to eager
            return self.model.decode_forward(hidden_states)
        # Pad input to bucket size
        input_buf = self.input_buffers[bucket]
        input_buf[:actual_batch].copy_(hidden_states)
        if actual_batch < bucket:
            input_buf[actual_batch:].zero_()
        # Replay graph
        self.graphs[bucket].replay()
        # Extract actual outputs (discard padding)
        return self.output_buffers[bucket][:actual_batch].clone()

    def _find_bucket(self, batch_size):
        for b in self.buckets:
            if b >= batch_size:
                return b
        return None

    def memory_usage(self):
        """Report memory used by all cached graphs."""
        total = 0
        for bucket in self.buckets:
            input_bytes = self.input_buffers[bucket].nelement() * 2   # fp16
            output_bytes = self.output_buffers[bucket].nelement() * 2
            total += input_bytes + output_bytes
        print(f"Graph buffer memory: {total / 1e9:.2f} GB")
        return total
CUDA Graph Memory Overhead (Llama-2 7B, All Bucket Sizes)
| Component | Per Bucket | 15 Buckets Total |
|---|---|---|
| Input buffers (hidden states) | Varies (bucket * hidden * 2B) | ~4 MB |
| Output buffers (logits) | Varies (bucket * vocab * 2B) | ~60 MB |
| Graph metadata (kernel params) | ~0.5 MB per graph | ~7.5 MB |
| Workspace tensors (intermediates) | ~200 MB per graph (largest bucket) | ~200 MB (shared pool) |
| Total overhead | --- | ~270 MB |
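The buffer rows of the table can be sanity-checked with simple arithmetic. A sketch, assuming Llama-2 7B dimensions (hidden=4096, vocab=32000), fp16 elements (2 bytes), and the runner's bucket list capped at batch 256; the exact totals depend on which buckets are actually captured, so this lands in the same ballpark as the table rather than matching it digit for digit:

```python
BUCKETS = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 192, 256]
HIDDEN, VOCAB, FP16 = 4096, 32000, 2  # assumed Llama-2 7B dims, fp16 bytes

# One [bucket, 1, hidden] input and one [bucket, 1, vocab] output per bucket
input_bytes = sum(b * 1 * HIDDEN * FP16 for b in BUCKETS)
output_bytes = sum(b * 1 * VOCAB * FP16 for b in BUCKETS)

print(f"Input buffers:  {input_bytes / 1e6:.1f} MB")   # single-digit MB
print(f"Output buffers: {output_bytes / 1e6:.1f} MB")  # tens of MB
```

The output buffers dominate because the vocabulary dimension is roughly eight times the hidden dimension.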
CUDA Graph Capture for a Transformer Forward Pass
Complete Implementation
class CUDAGraphTransformerDecoder:
    """Transformer decoder with CUDA graph capture for decode."""

    def __init__(self, config):
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_layers = config.num_layers
        self.vocab_size = config.vocab_size
        self.graph = None
        self.captured = False

    def eager_decode(self, hidden_states, positions, kv_caches):
        """Standard eager-mode decode (no graph)."""
        for i in range(self.num_layers):
            hidden_states = self.layers[i].forward(
                hidden_states, positions, kv_caches[i]
            )
        logits = self.lm_head(self.norm(hidden_states))
        return logits

    def capture_graph(self, batch_size, kv_caches):
        """Capture the decode forward pass as a CUDA graph."""
        device = 'cuda'
        # Pre-allocate persistent input buffers
        self._graph_hidden = torch.zeros(
            batch_size, 1, self.hidden_size,
            device=device, dtype=torch.float16
        )
        self._graph_positions = torch.zeros(
            batch_size, dtype=torch.long, device=device
        )
        # Warmup runs (populate CUDA caches)
        for _ in range(3):
            self.eager_decode(
                self._graph_hidden,
                self._graph_positions,
                kv_caches
            )
        torch.cuda.synchronize()
        # Capture. The logits tensor produced inside the capture is
        # allocated from the graph's memory pool and keeps a fixed
        # address, so we retain it as the persistent output buffer.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self._graph_logits = self.eager_decode(
                self._graph_hidden,
                self._graph_positions,
                kv_caches
            )
        self.captured = True
        print(f"Graph captured for batch_size={batch_size}")

    def decode(self, hidden_states, positions, kv_caches):
        """Decode using CUDA graph if available."""
        if not self.captured:
            return self.eager_decode(hidden_states, positions, kv_caches)
        actual_batch = hidden_states.shape[0]
        graph_batch = self._graph_hidden.shape[0]
        if actual_batch > graph_batch:
            # Cannot use graph -- too large
            return self.eager_decode(hidden_states, positions, kv_caches)
        # Copy input to graph buffer
        self._graph_hidden[:actual_batch].copy_(hidden_states)
        self._graph_positions[:actual_batch].copy_(positions)
        # Replay graph
        self.graph.replay()
        # Return output
        return self._graph_logits[:actual_batch]
Handling KV Cache Updates Within Graphs
class GraphCompatibleKVCache:
    """KV cache that works with CUDA graphs.

    The KV cache grows during generation (each decode adds one
    new K/V entry). The graph cannot handle dynamic allocation,
    so we pre-allocate to max sequence length and track the
    current length externally.
    """

    def __init__(self, num_layers, num_heads, head_dim,
                 max_seq_len, dtype=torch.float16):
        self.device = 'cuda'
        self.max_seq_len = max_seq_len
        self.current_len = 0
        # Pre-allocate to max sequence length
        # Shape: [num_layers, 2, batch, heads, max_seq, head_dim]
        self.cache = torch.zeros(
            num_layers, 2, 1, num_heads, max_seq_len, head_dim,
            device=self.device, dtype=dtype
        )
        # Slot index tensor (for scatter update); lives on the GPU so
        # the captured scatter kernel reads its value at replay time
        self.slot_idx = torch.zeros(1, dtype=torch.long,
                                    device=self.device)

    def append(self, layer_idx, new_k, new_v):
        """Append new K,V to the cache at current position.

        This operation is captured in the CUDA graph.
        The slot_idx is updated externally (CPU-side) before replay.
        """
        # In-place scatter: write to the pre-allocated slot
        self.cache[layer_idx, 0, :, :, self.slot_idx] = new_k
        self.cache[layer_idx, 1, :, :, self.slot_idx] = new_v

    def advance(self):
        """Advance the write position (called before graph replay)."""
        self.current_len += 1
        self.slot_idx.fill_(self.current_len - 1)

    def get_kv(self, layer_idx):
        """Get the current K,V tensors for attention."""
        k = self.cache[layer_idx, 0, :, :, :self.current_len]
        v = self.cache[layer_idx, 1, :, :, :self.current_len]
        return k, v
The KV cache length changes every decode step (grows by 1). If attention accesses cache[:current_len], the slice changes each step and the graph would need recapture. The solution: use a fixed-size cache and a position index, where attention always reads the full buffer but masks out unused positions. Alternatively, use a separate mechanism (like PagedAttention) that updates page tables outside the graph.
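The masking trick can be checked numerically in miniature: attention over the full fixed-size buffer, with the scores of unused slots forced to negative infinity, yields exactly the same result as attention over the sliced cache. A pure-Python sketch (single head, scalar values, no torch):

```python
import math

def softmax(xs):
    m = max(xs)                       # subtract max for numerical stability
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

def attend(scores, values):
    w = softmax(scores)
    return sum(wi * vi for wi, vi in zip(w, values))

MAX_LEN, current_len = 8, 3
scores = [0.5, 1.2, -0.3] + [0.0] * (MAX_LEN - current_len)  # garbage past current_len
values = [1.0, 2.0, 3.0] + [0.0] * (MAX_LEN - current_len)

# Fixed-size path: read the whole buffer, mask unused positions to -inf
# (exp(-inf) == 0, so they contribute zero attention weight)
masked = [s if i < current_len else float('-inf')
          for i, s in enumerate(scores)]
fixed = attend(masked, values)

# Dynamic path: slice to the live region (the shape change a graph cannot do)
sliced = attend(scores[:current_len], values[:current_len])

assert abs(fixed - sliced) < 1e-12
print(f"fixed={fixed:.6f}  sliced={sliced:.6f}")
```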
Performance Measurements
Benchmarking Graph vs Eager
def benchmark_graph_vs_eager(model, batch_sizes, num_iters=1000):
    """Compare graph vs eager decode latency."""
    results = []
    for bs in batch_sizes:
        hidden = torch.randn(bs, 1, model.hidden_size,
                             device='cuda', dtype=torch.float16)
        positions = torch.zeros(bs, dtype=torch.long, device='cuda')

        # Eager timing
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):
            model.eager_decode(hidden, positions, model.kv_caches)
        torch.cuda.synchronize()
        eager_ms = (time.perf_counter() - start) / num_iters * 1000

        # Graph timing
        model.capture_graph(bs, model.kv_caches)
        model._graph_hidden[:bs].copy_(hidden)
        model._graph_positions[:bs].copy_(positions)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(num_iters):
            model.graph.replay()
        torch.cuda.synchronize()
        graph_ms = (time.perf_counter() - start) / num_iters * 1000

        speedup = eager_ms / graph_ms
        results.append((bs, eager_ms, graph_ms, speedup))
        print(f"batch={bs:3d}: eager={eager_ms:.3f} ms, "
              f"graph={graph_ms:.3f} ms, "
              f"speedup={speedup:.2f}x")
    return results
CUDA Graph vs Eager Decode Latency (Llama-2 7B, H100)
| Batch Size | Eager (ms) | Graph (ms) | Speedup | Launch Overhead Saved |
|---|---|---|---|---|
| 1 | 5.8 | 2.3 | 2.52x | 3.5 ms (60%) |
| 4 | 6.1 | 2.6 | 2.35x | 3.5 ms (57%) |
| 16 | 7.2 | 4.0 | 1.80x | 3.2 ms (44%) |
| 64 | 12.5 | 9.8 | 1.28x | 2.7 ms (22%) |
| 256 | 38.2 | 36.1 | 1.06x | 2.1 ms (5%) |
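The shrinking speedup follows from a simple model (a sketch, not measured data): replay removes a roughly fixed chunk of per-step CPU overhead, while GPU compute grows with batch size, so the relative win shrinks. Taking the table's graph-mode latencies as the compute term and a fixed 3.5 ms overhead, the model reproduces the batch=1 row exactly and mildly overestimates large batches, where launches partly hide under compute:

```python
def predicted_speedup(compute_ms, overhead_ms=3.5):
    """Eager = compute + launch overhead; graph replay ~= compute alone."""
    return (compute_ms + overhead_ms) / compute_ms

# compute_ms values taken from the Graph column of the table above
for bs, compute_ms in [(1, 2.3), (16, 4.0), (64, 9.8), (256, 36.1)]:
    print(f"batch={bs:3d}: predicted speedup {predicted_speedup(compute_ms):.2f}x")
```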
[Figure: CUDA Graph Speedup (x) by Batch Size (Llama-2 7B Decode)]

Limitations and Workarounds
Operations That Cannot Be Captured
def graph_incompatible_operations():
    """Operations that prevent CUDA graph capture."""
    incompatible = {
        "CPU-GPU synchronization": {
            "example": "torch.cuda.synchronize(), .item(), print(tensor)",
            "workaround": "Remove all sync points from the captured region",
        },
        "Dynamic shapes": {
            "example": "torch.cat with variable-length tensors",
            "workaround": "Pre-allocate to max size, use masking",
        },
        "Host-device transfers": {
            "example": "tensor.cpu(), tensor.to('cuda')",
            "workaround": "Keep all data on GPU during capture",
        },
        "cuDNN with non-deterministic algorithms": {
            "example": "torch.backends.cudnn.benchmark = True",
            "workaround": "Set cudnn.benchmark = False during capture",
        },
        "NCCL collectives": {
            "example": "dist.all_reduce(tensor)",
            "workaround": "Split graph at NCCL boundaries, or use "
                          "NCCL graph support (CUDA 12.x+)",
        },
        "Dynamic control flow": {
            "example": "if tensor.sum() > threshold: ...",
            "workaround": "Move control flow outside the graph, or "
                          "capture separate graphs for each branch",
        },
    }
    for op, info in incompatible.items():
        print(f"\n{op}:")
        print(f"  Example: {info['example']}")
        print(f"  Workaround: {info['workaround']}")
If graph capture fails silently (the graph captures but produces wrong results), the most common cause is a tensor allocation inside the captured region that returns a different address on replay. Use CUDA_LAUNCH_BLOCKING=1 and torch.cuda.set_sync_debug_mode(1) during development to catch synchronization issues.
Summary
CUDA Graphs eliminate kernel launch overhead by recording a sequence of operations and replaying them with a single CPU-side launch. For small-batch LLM decode (where launch overhead exceeds GPU compute time), CUDA Graphs provide 2-2.5x speedup. The implementation requires pre-allocating all tensors, handling dynamic batch sizes through bucketed graph capture, and managing KV cache updates outside the graph.
The production pattern (used by vLLM, TensorRT-LLM): capture graphs at startup for 10-15 batch size buckets, pad actual batches to the nearest bucket during serving, and fall back to eager execution for batches exceeding the maximum captured size. The memory overhead is 200-300 MB per model (negligible relative to model weight memory), and the latency reduction is material for all batch sizes below 64.