Each CUDA kernel launch costs 5-15µs of CPU overhead. In LLM decode, we launch hundreds of kernels per token. Even at only 100 launches per token and 100 tokens/second, launch overhead adds up to 50-150ms of CPU time per second of generation, a significant fraction of the time budget. CUDA Graphs eliminate most of this by capturing an entire kernel sequence once and replaying it with a single launch.
The Launch Overhead Problem
Profiling a typical decode iteration:
$ nsys profile --stats=true python decode_one_token.py
CUDA API Statistics:
Time(%) Total Time (ns) Num Calls Avg (ns) Name
------- --------------- --------- -------- ----
42.3 12,690,000 847 14,988 cudaLaunchKernel
28.1 8,430,000 212 39,764 cudaMemcpyAsync
18.4 5,520,000 424 13,019 cudaStreamSynchronize
...
847 kernel launches at ~15µs each = 12.7ms of CPU overhead per token. With a 30ms/token target latency, this is 42% overhead.
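The arithmetic behind that figure can be checked with a short sketch. The function names here are illustrative, not part of any profiler API; the inputs are the launch count and average launch time from the trace above.

```python
def launch_overhead_ms(num_launches: int, per_launch_us: float) -> float:
    """CPU time spent issuing kernel launches, in milliseconds."""
    return num_launches * per_launch_us / 1000.0


def overhead_fraction(overhead_ms: float, budget_ms: float) -> float:
    """Share of a per-token latency budget consumed by launch overhead."""
    return overhead_ms / budget_ms


# Figures from the nsys trace: 847 launches averaging 14,988 ns (14.988 µs) each.
per_token_ms = launch_overhead_ms(847, 14.988)
share = overhead_fraction(per_token_ms, budget_ms=30.0)
print(f"{per_token_ms:.1f} ms per token, {share:.0%} of a 30 ms budget")
# prints: 12.7 ms per token, 42% of a 30 ms budget
```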
CUDA Graph Basics
A CUDA graph captures a sequence of operations for replay:
// Basic graph capture and execution
cudaGraph_t graph;
cudaGraphExec_t graphExec;
// Capture phase
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
// All operations here are captured, not executed
kernel_a<<<blocks, threads, 0, stream>>>(args...);
kernel_b<<<blocks, threads, 0, stream>>>(args...);
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
kernel_c<<<blocks, threads, 0, stream>>>(args...);
cudaStreamEndCapture(stream, &graph);
// Instantiate for execution
cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);
// Replay phase - all captured ops execute with single CPU call
for (int i = 0; i < num_iterations; i++) {
    cudaGraphLaunch(graphExec, stream); // ~5µs total, not 847 × 15µs
}
847 individual launches: ~12.7ms CPU time. One graph launch: ~5µs CPU time. That’s a 2,500x reduction in launch overhead.
The Dynamic Shape Challenge
LLM inference has dynamic shapes:
- Batch size changes as requests arrive/complete
- Sequence lengths vary per request
- KV cache positions change each iteration
Standard CUDA graphs require fixed shapes. Solutions:
1. Graph Pools by Shape
from typing import Dict

import torch


class CUDAGraphPool:
    """Maintain separate graphs for different batch sizes."""

    def __init__(self, model, max_batch_size: int):
        self.model = model
        self.graphs: Dict[int, torch.cuda.CUDAGraph] = {}
        self.static_inputs: Dict[int, Dict[str, torch.Tensor]] = {}
        # Must exist before _capture_graph() stores captured outputs in it
        self.static_outputs: Dict[int, torch.Tensor] = {}
        # Pre-capture graphs for common batch sizes
        for batch_size in [1, 2, 4, 8, 16, 32]:
            if batch_size <= max_batch_size:
                self._capture_graph(batch_size)

    def _capture_graph(self, batch_size: int):
        """Capture graph for specific batch size."""
        # Create static input tensors (will be reused on every replay)
        static_inputs = {
            'input_ids': torch.zeros((batch_size, 1), dtype=torch.long, device='cuda'),
            'position_ids': torch.zeros((batch_size, 1), dtype=torch.long, device='cuda'),
            'kv_cache_positions': torch.zeros(batch_size, dtype=torch.long, device='cuda'),
        }
        self.static_inputs[batch_size] = static_inputs
        # Warm up
        with torch.no_grad():
            self.model(**static_inputs)
        torch.cuda.synchronize()
        # Capture
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                output = self.model(**static_inputs)
        self.graphs[batch_size] = graph
        self.static_outputs[batch_size] = output

    def execute(self, input_ids, position_ids, kv_cache_positions):
        """Execute appropriate graph, padding if necessary."""
        batch_size = input_ids.shape[0]
        # Find smallest graph that fits
        graph_batch_size = None
        for size in sorted(self.graphs.keys()):
            if size >= batch_size:
                graph_batch_size = size
                break
        if graph_batch_size is None:
            # Fall back to eager execution for large batches
            with torch.no_grad():
                return self.model(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    kv_cache_positions=kv_cache_positions,
                )
        # Copy inputs into the static tensors the graph was captured with
        static = self.static_inputs[graph_batch_size]
        static['input_ids'][:batch_size].copy_(input_ids)
        static['position_ids'][:batch_size].copy_(position_ids)
        static['kv_cache_positions'][:batch_size].copy_(kv_cache_positions)
        # Replay graph
        self.graphs[graph_batch_size].replay()
        # Return only valid outputs. This is a view into the static output
        # buffer; clone it if it must outlive the next replay.
        return self.static_outputs[graph_batch_size][:batch_size]
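The "smallest graph that fits" lookup is the core of the pool and can be isolated from PyTorch. The following is a minimal sketch using binary search over a sorted list of captured sizes; `pick_graph_batch_size` is a hypothetical helper name, not part of the class above.

```python
from bisect import bisect_left
from typing import Optional, Sequence


def pick_graph_batch_size(captured_sizes: Sequence[int], batch_size: int) -> Optional[int]:
    """Return the smallest captured size >= batch_size, or None if none fits.

    `captured_sizes` must be sorted in ascending order.
    """
    idx = bisect_left(captured_sizes, batch_size)
    return captured_sizes[idx] if idx < len(captured_sizes) else None


sizes = [1, 2, 4, 8, 16, 32]
print(pick_graph_batch_size(sizes, 5))   # 8
print(pick_graph_batch_size(sizes, 16))  # 16
print(pick_graph_batch_size(sizes, 33))  # None -> caller falls back to eager
```

With only a handful of captured sizes the linear scan is just as fast in practice; binary search matters mainly if the pool grows to dozens of entries.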
2. Padded Fixed-Size Batches
def pad_to_graph_batch(inputs: Dict[str, torch.Tensor], target_size: int):
    """Pad batch to fixed size for graph execution."""
    batch_size = inputs['input_ids'].shape[0]
    if batch_size > target_size:
        raise ValueError(f"Batch of {batch_size} does not fit graph size {target_size}")
    if batch_size == target_size:
        return inputs, slice(None)
    padded = {}
    for key, tensor in inputs.items():
        pad_shape = list(tensor.shape)
        pad_shape[0] = target_size
        padded_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        padded_tensor[:batch_size] = tensor
        padded[key] = padded_tensor
    # The slice selects the rows of the output that belong to real requests
    return padded, slice(0, batch_size)
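Padding is not free: every padded row is computed and then discarded. A rough way to quantify this, assuming cost scales with the number of rows (an approximation, since small decode batches are often memory-bound), is sketched below. `padding_waste` is an illustrative name.

```python
def padding_waste(batch_size: int, graph_batch_size: int) -> float:
    """Fraction of rows in the padded batch that carry no real request."""
    if not 0 < batch_size <= graph_batch_size:
        raise ValueError("batch_size must be in (0, graph_batch_size]")
    return (graph_batch_size - batch_size) / graph_batch_size


# Worst case for power-of-two buckets is a batch just above the previous bucket.
for real, padded in [(16, 16), (17, 32), (24, 32)]:
    print(f"{real:>2} -> {padded:>2}: {padding_waste(real, padded):.0%} padded rows")
```

A batch of 17 padded to 32 leaves almost half the rows empty, which is one reason to capture more bucket sizes at the cost of extra graph memory.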
vLLM’s CUDA Graph Implementation
vLLM captures graphs only for the decode phase (single-token generation):
# Simplified from vLLM's cuda_graph_runner.py
from dataclasses import dataclass
from typing import Dict

import torch


@dataclass
class CapturedGraph:
    """Illustrative container (not vLLM's own type): a graph plus its static buffers."""
    graph: torch.cuda.CUDAGraph
    input_tokens: torch.Tensor
    positions: torch.Tensor
    output: torch.Tensor


class CUDAGraphRunner:
    def __init__(self, model, max_batch_size: int, max_context_len: int):
        self.model = model
        self.graphs: Dict[int, CapturedGraph] = {}
        # Capture graphs for batch sizes that are powers of 2
        self.batch_size_pool = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        self.batch_size_pool = [b for b in self.batch_size_pool if b <= max_batch_size]

    def capture(self):
        """Capture all graphs. Call once after model warmup."""
        # Use the actual model memory for static tensors
        # This avoids extra memory allocation
        for batch_size in self.batch_size_pool:
            graph = torch.cuda.CUDAGraph()
            # Prepare static inputs matching model's expected format
            static_input_tokens = torch.zeros(
                batch_size, dtype=torch.long, device='cuda'
            )
            static_positions = torch.zeros(
                batch_size, dtype=torch.long, device='cuda'
            )
            # Warm-up run (required before capture)
            with torch.inference_mode():
                self.model.forward(
                    static_input_tokens,
                    static_positions,
                    is_cuda_graph_capture=True,
                )
            torch.cuda.synchronize()
            # Actual capture. No stream is passed: torch.cuda.graph captures on
            # its own side stream, and capture on the default stream is rejected.
            with torch.cuda.graph(graph):
                with torch.inference_mode():
                    hidden_states = self.model.forward(
                        static_input_tokens,
                        static_positions,
                        is_cuda_graph_capture=True,
                    )
            self.graphs[batch_size] = CapturedGraph(
                graph=graph,
                input_tokens=static_input_tokens,
                positions=static_positions,
                output=hidden_states,
            )

    def _get_graph_batch_size(self, batch_size: int) -> int:
        """Illustrative helper: smallest captured size that fits the batch."""
        for size in self.batch_size_pool:
            if size >= batch_size:
                return size
        raise ValueError(f"No captured graph for batch size {batch_size}")

    def execute(self, input_tokens, positions) -> torch.Tensor:
        batch_size = input_tokens.shape[0]
        # Find appropriate graph
        graph_batch_size = self._get_graph_batch_size(batch_size)
        captured = self.graphs[graph_batch_size]
        # Update static inputs
        captured.input_tokens[:batch_size].copy_(input_tokens)
        captured.positions[:batch_size].copy_(positions)
        # Replay
        captured.graph.replay()
        return captured.output[:batch_size]
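During graph capture, you cannot allocate memory with cudaMalloc, use control flow that depends on tensor values, access CPU tensors, or synchronize with the CPU. Python control flow is evaluated once at capture time, so only the branch taken during capture is recorded and replayed. All shapes must be static.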
Memory Considerations
CUDA graphs consume GPU memory for storing captured operations:
def measure_graph_memory_overhead(model, batch_sizes):
    """Measure memory overhead of graph capture."""
    torch.cuda.reset_peak_memory_stats()
    base_memory = torch.cuda.memory_allocated()
    graphs = {}
    for bs in batch_sizes:
        # Capture graph
        graph = torch.cuda.CUDAGraph()
        static_input = torch.zeros(bs, 1, dtype=torch.long, device='cuda')
        with torch.cuda.graph(graph):
            model(static_input)
        graphs[bs] = graph
    graph_memory = torch.cuda.memory_allocated() - base_memory
    peak_memory = torch.cuda.max_memory_allocated() - base_memory
    return {
        'graph_memory_mb': graph_memory / 1024**2,
        'peak_memory_mb': peak_memory / 1024**2,
        'per_graph_mb': graph_memory / len(batch_sizes) / 1024**2,
    }

# Typical results for Llama-70B:
# graph_memory_mb: 480 (for 8 batch sizes)
# per_graph_mb: 60 (per graph)
CUDA Graph Memory Overhead (Llama-70B)
| Batch Sizes | Graph Memory | Capture Time | Note |
|---|---|---|---|
| [1] | 45 MB | 2.1s | Minimal |
| [1, 4, 16] | 142 MB | 6.3s | Common |
| [1, 2, 4, 8, 16, 32, 64] | 485 MB | 14.8s | Full pool |
| [1..256, step=1] | 1.8 GB | 52s | Excessive |
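Dividing each row's total by its number of captured sizes shows how the average cost per graph varies across pools. This sketch only re-derives numbers from the table; the 1.8 GB entry is assumed to mean 1.8 × 1024 MB.

```python
# (number of captured batch sizes, total graph memory in MB), from the table.
measurements = {
    "[1]": (1, 45.0),
    "[1, 4, 16]": (3, 142.0),
    "[1, 2, 4, 8, 16, 32, 64]": (7, 485.0),
    "[1..256, step=1]": (256, 1.8 * 1024),  # assumes 1 GB = 1024 MB
}


def per_graph_mb(num_graphs: int, total_mb: float) -> float:
    """Average graph memory per captured batch size."""
    return total_mb / num_graphs


averages = {
    pool: per_graph_mb(count, total) for pool, (count, total) in measurements.items()
}
for pool, avg in averages.items():
    print(f"{pool:<26} {avg:6.1f} MB per graph")
```

The averages range from about 7 MB to about 69 MB per graph, so a flat per-graph figure is only a rough planning number; measure the specific pool you intend to deploy.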
Performance Impact
[Figure: Decode Latency Reduction with CUDA Graphs (ms)]
End-to-End Performance Comparison (A100-80GB)
| Configuration | Throughput | P50 Latency | P99 Latency |
|---|---|---|---|
| Eager Execution | 3,234 tok/s | 28ms | 45ms |
| CUDA Graphs (8 sizes) | 4,891 tok/s | 18ms | 32ms |
| Improvement | +51% | -36% | -29% |
Debugging Graph Issues
Common problems and solutions:
# Problem: Graph capture fails silently
# Solution: Capture in the strictest error mode and surface the exception.
# capture_error_mode accepts "global", "thread_local", or "relaxed";
# "global" (the default) is the strictest.
try:
    with torch.cuda.graph(graph, capture_error_mode="global"):
        model(static_input)
except RuntimeError as e:
    print(f"Graph capture failed: {e}")
    # Common causes:
    # - cudaMalloc during capture (dynamic allocation)
    # - CPU tensor access
    # - Data-dependent control flow

# Problem: Graph replay produces wrong results
# Solution: Verify input tensors are updated correctly
def debug_graph_inputs(captured_graph, expected_inputs):
    """Verify static inputs match expected values."""
    for name, expected in expected_inputs.items():
        actual = getattr(captured_graph, name)
        if not torch.equal(actual, expected):
            print(f"Mismatch in {name}:")
            print(f"  Expected: {expected}")
            print(f"  Actual: {actual}")

# Problem: Graphs don't cover all code paths
# Solution: Profile to find uncaptured operations
import torch.profiler as profiler

# CPU activity is included so that with_stack can record Python call stacks
with profiler.profile(
    activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
    with_stack=True,
) as prof:
    for _ in range(100):
        model.decode_step(inputs)

# Look for cudaLaunchKernel calls - these bypass graphs
for event in prof.events():
    if 'cudaLaunchKernel' in event.name:
        print(f"Uncaptured kernel: {event.stack}")
When Not to Use CUDA Graphs
Graphs have overhead; they’re not always beneficial:
- Prefill phase: Long sequence, many different shapes
- Very large batches: Launch overhead is amortized anyway
- Rapidly changing shapes: Graph lookup overhead dominates
- Memory-constrained: Graph storage may exceed budget
Use CUDA graphs for the decode phase with batch sizes ≤256. For prefill or larger batches, eager execution is often faster because it avoids padding overhead.
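That rule of thumb can be written as a small dispatch predicate. This is a sketch of the guideline only: the function name, the `phase` strings, and the default limit of 256 are assumptions for illustration, and a real scheduler would also weigh memory budget and measured latency.

```python
def use_cuda_graph(phase: str, batch_size: int, max_graph_batch_size: int = 256) -> bool:
    """Return True when the decode-only, bounded-batch guideline says to replay a graph."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return phase == "decode" and batch_size <= max_graph_batch_size


print(use_cuda_graph("decode", 32))    # True
print(use_cuda_graph("decode", 512))   # False: batch too large, run eagerly
print(use_cuda_graph("prefill", 8))    # False: prefill shapes vary too much
```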
Conclusion
CUDA graphs move LLM decode from CPU-bound toward GPU-bound by eliminating most kernel launch overhead. The key implementation challenges are handling dynamic batch sizes (via graph pools) and managing memory overhead. In the benchmark above, graphs cut decode latency by 29-36% and raised throughput by 51%.