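Each CUDA kernel launch costs roughly 5-15µs of CPU overhead. In LLM decode, a single forward pass launches hundreds of kernels per token. Even at a conservative 100 launches per token, generating 100 tokens/second means 10,000 launches per second, or 50-150ms of CPU time every second spent only on launching kernels. That is a significant fraction of the time budget. CUDA Graphs eliminate this cost by capturing an entire kernel sequence once and replaying it with a single launch.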

The Launch Overhead Problem

Profiling a typical decode iteration:

$ nsys profile --stats=true python decode_one_token.py

CUDA API Statistics:
 Time(%)  Total Time (ns)  Num Calls   Avg (ns)   Name
 -------  ---------------  ---------   --------   ----
    42.3       12,690,000       847     14,988   cudaLaunchKernel
    28.1        8,430,000       212     39,764   cudaMemcpyAsync
    18.4        5,520,000       424     13,019   cudaStreamSynchronize
    ...

847 kernel launches at ~15µs each = 12.7ms of CPU overhead per token. With a 30ms/token target latency, this is 42% overhead.
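
The arithmetic is easy to reproduce. Below is a small sketch that uses the profile numbers above (847 launches averaging 14,988 ns) and the intro's rough estimate (100 launches per token at 5-15µs, 100 tokens/second). The per-launch cost is an average, not a constant:

```python
def launch_overhead(num_launches: int, us_per_launch: float, token_budget_ms: float):
    """Return (CPU launch overhead per token in ms, fraction of the token budget)."""
    overhead_ms = num_launches * us_per_launch / 1000.0
    return overhead_ms, overhead_ms / token_budget_ms


# Profile above: 847 launches averaging 14,988 ns against a 30 ms/token target
overhead_ms, fraction = launch_overhead(847, 14.988, 30.0)
print(f"{overhead_ms:.1f} ms/token, {fraction:.0%} of a 30 ms budget")

# Intro estimate: 100 launches/token at 5-15 µs, 100 tokens/s (10 ms/token budget)
for us in (5.0, 15.0):
    per_token_ms, _ = launch_overhead(100, us, 10.0)
    print(f"{per_token_ms * 100:.0f} ms of launch overhead per second")
```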

CUDA Graph Basics

A CUDA graph captures a sequence of operations for replay:

// Basic graph capture and execution
cudaGraph_t graph;
cudaGraphExec_t graphExec;

// Capture phase
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

// All operations here are captured, not executed
kernel_a<<<blocks, threads, 0, stream>>>(args...);
kernel_b<<<blocks, threads, 0, stream>>>(args...);
cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToDevice, stream);
kernel_c<<<blocks, threads, 0, stream>>>(args...);

cudaStreamEndCapture(stream, &graph);

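// Instantiate for execution. cudaGraphInstantiateWithFlags is available in
// CUDA 11.4+ and 12.x (CUDA 12 dropped the five-argument cudaGraphInstantiate)
cudaGraphInstantiateWithFlags(&graphExec, graph, 0);

// Replay phase - all captured ops execute with single CPU call
for (int i = 0; i < num_iterations; i++) {
    cudaGraphLaunch(graphExec, stream);  // ~5µs total, not 847 × 15µs
}
cudaStreamSynchronize(stream);

// Release the executable graph and the graph it was instantiated from
cudaGraphExecDestroy(graphExec);
cudaGraphDestroy(graph);
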
Launch Time Comparison

847 individual launches: ~12.7ms CPU time. One graph launch: ~5µs CPU time. That is a roughly 2,500x reduction in CPU-side launch overhead (the GPU still executes all 847 kernels).

The Dynamic Shape Challenge

LLM inference has dynamic shapes:

  • Batch size changes as requests arrive/complete
  • Sequence lengths vary per request
  • KV cache positions change each iteration

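A captured graph replays fixed kernel launch parameters and fixed memory addresses, so it is only valid for the shapes it was captured with. Two common solutions: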

1. Graph Pools by Shape

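from typing import Dict

import torch


class CUDAGraphPool:
    """Maintain separate graphs for different batch sizes."""

    def __init__(self, model, max_batch_size: int):
        # Assumed: model(input_ids=..., position_ids=..., kv_cache_positions=...)
        # returns one tensor whose first dimension is the batch
        self.model = model
        self.graphs: Dict[int, torch.cuda.CUDAGraph] = {}
        self.static_inputs: Dict[int, Dict[str, torch.Tensor]] = {}
        self.static_outputs: Dict[int, torch.Tensor] = {}

        # Pre-capture graphs for common batch sizes
        for batch_size in [1, 2, 4, 8, 16, 32]:
            if batch_size <= max_batch_size:
                self._capture_graph(batch_size)

    def _capture_graph(self, batch_size: int):
        """Capture graph for specific batch size."""
        # Create static input tensors (reused): every replay reads these addresses.
        # Assumption: the model's KV-cache write skips negative positions (padding).
        static_inputs = {
            'input_ids': torch.zeros((batch_size, 1), dtype=torch.long, device='cuda'),
            'position_ids': torch.zeros((batch_size, 1), dtype=torch.long, device='cuda'),
            'kv_cache_positions': torch.full((batch_size,), -1, dtype=torch.long, device='cuda'),
        }
        self.static_inputs[batch_size] = static_inputs

        # Warm up on a side stream so one-time setup (library handles,
        # allocator growth) happens before capture
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side), torch.no_grad():
            for _ in range(3):
                self.model(**static_inputs)
        torch.cuda.current_stream().wait_stream(side)

        # Capture: kernels are recorded, not executed
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                output = self.model(**static_inputs)

        self.graphs[batch_size] = graph
        self.static_outputs[batch_size] = output

    def execute(self, input_ids, position_ids, kv_cache_positions):
        """Execute appropriate graph, padding if necessary."""
        batch_size = input_ids.shape[0]

        # Find smallest graph that fits
        graph_batch_size = None
        for size in sorted(self.graphs.keys()):
            if size >= batch_size:
                graph_batch_size = size
                break

        if graph_batch_size is None:
            # Fall back to eager execution for large batches
            with torch.no_grad():
                return self.model(input_ids=input_ids, position_ids=position_ids,
                                  kv_cache_positions=kv_cache_positions)

        # Copy inputs to static tensors
        static = self.static_inputs[graph_batch_size]
        static['input_ids'][:batch_size].copy_(input_ids)
        static['position_ids'][:batch_size].copy_(position_ids)
        static['kv_cache_positions'][:batch_size].copy_(kv_cache_positions)

        # Padding rows still hold values from an earlier batch; reset them so
        # they cannot write into a live sequence's KV cache (see assumption above)
        static['input_ids'][batch_size:].zero_()
        static['position_ids'][batch_size:].zero_()
        static['kv_cache_positions'][batch_size:].fill_(-1)

        # Replay graph
        self.graphs[graph_batch_size].replay()

        # Return only valid outputs (a view: the next replay overwrites it)
        return self.static_outputs[graph_batch_size][:batch_size]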
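
The linear scan in execute() re-sorts the graph keys on every decode step. Because the captured sizes never change after construction, a binary search over a list sorted once does the same lookup. This is a stdlib-only sketch; `pick_graph_size` is an illustrative name, not part of PyTorch:

```python
import bisect
from typing import Optional, Sequence


def pick_graph_size(batch_size: int, captured_sizes: Sequence[int]) -> Optional[int]:
    """Smallest captured batch size >= batch_size, or None (eager fallback).

    captured_sizes must be sorted ascending.
    """
    i = bisect.bisect_left(captured_sizes, batch_size)
    return captured_sizes[i] if i < len(captured_sizes) else None


sizes = [1, 2, 4, 8, 16, 32]
print(pick_graph_size(5, sizes))   # 8
print(pick_graph_size(32, sizes))  # 32
print(pick_graph_size(33, sizes))  # None
```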

2. Padded Fixed-Size Batches

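from typing import Dict, Tuple

import torch


def pad_to_graph_batch(inputs: Dict[str, torch.Tensor],
                       target_size: int) -> Tuple[Dict[str, torch.Tensor], slice]:
    """Pad batch to fixed size for graph execution.

    Returns the padded inputs and the slice selecting the real rows. The padded
    tensors are new allocations: copy them into the graph's static inputs
    before replay.
    """
    batch_size = inputs['input_ids'].shape[0]
    if batch_size > target_size:
        raise ValueError(f"batch of {batch_size} exceeds graph size {target_size}")

    if batch_size == target_size:
        return inputs, slice(None)

    padded = {}
    for key, tensor in inputs.items():
        pad_shape = list(tensor.shape)
        pad_shape[0] = target_size

        # Zero-filled padding rows; their outputs are dropped via the slice
        padded_tensor = tensor.new_zeros(pad_shape)
        padded_tensor[:batch_size] = tensor
        padded[key] = padded_tensor

    return padded, slice(0, batch_size)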
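
Both approaches trade graph count against padding: the fewer sizes you capture, the more rows each replay spends on padding. Below is a rough sketch for estimating that cost. It rests on a simplifying, hypothetical assumption that every batch size from 1 up to the largest bucket is equally likely; real traffic will differ:

```python
from typing import Sequence


def padding_waste(buckets: Sequence[int]) -> float:
    """Fraction of replayed rows that are padding.

    Assumes every batch size from 1 to max(buckets) is equally likely and that
    each batch replays the smallest bucket that fits. buckets: sorted ascending.
    """
    total_rows = padded_rows = 0
    for bs in range(1, buckets[-1] + 1):
        graph_bs = next(b for b in buckets if b >= bs)
        total_rows += graph_bs
        padded_rows += graph_bs - bs
    return padded_rows / total_rows


for buckets in ([32], [1, 2, 4, 8, 16, 32], list(range(1, 33))):
    print(f"buckets={len(buckets)}: {padding_waste(buckets):.1%} padding")
```

Under this uniform assumption, the six power-of-two graphs used above waste about 23% of replayed rows, versus about 48% with a single 32-row graph and none with all 32 sizes captured. That is the memory-versus-padding trade-off behind the graph memory numbers later in this article.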

vLLM’s CUDA Graph Implementation

vLLM captures graphs only for decode steps, where each sequence contributes a single new token:

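# Simplified from vLLM's cuda_graph_runner.py
# Names and the model interface are illustrative, not vLLM's exact API
from dataclasses import dataclass
from typing import Dict, Optional

import torch


@dataclass
class CapturedGraph:
    graph: torch.cuda.CUDAGraph
    input_tokens: torch.Tensor  # static input view read by replay()
    positions: torch.Tensor
    output: torch.Tensor        # static output written by replay()


class CUDAGraphRunner:
    def __init__(self, model, max_batch_size: int, max_context_len: int):
        self.model = model
        self.max_context_len = max_context_len  # not used in this sketch
        self.graphs: Dict[int, CapturedGraph] = {}

        # Capture graphs for batch sizes that are powers of 2
        self.batch_size_pool = [1, 2, 4, 8, 16, 32, 64, 128, 256]
        self.batch_size_pool = [b for b in self.batch_size_pool if b <= max_batch_size]

        # Allocate static input buffers once, at the largest size; each graph
        # captures a prefix slice of them, which avoids extra allocations
        self.input_tokens_buf = torch.zeros(max_batch_size, dtype=torch.long, device='cuda')
        self.positions_buf = torch.zeros(max_batch_size, dtype=torch.long, device='cuda')
        # One capture memory pool shared by all graphs (never replayed concurrently)
        self.mempool = torch.cuda.graph_pool_handle()

    def capture(self):
        """Capture all graphs. Call once after model warmup."""
        # Largest batch first, so smaller captures can reuse its pool memory
        for batch_size in reversed(self.batch_size_pool):
            static_input_tokens = self.input_tokens_buf[:batch_size]
            static_positions = self.positions_buf[:batch_size]

            # Warm-up run (required before capture); is_cuda_graph_capture is
            # a flag of this hypothetical model interface
            with torch.inference_mode():
                self.model.forward(
                    static_input_tokens,
                    static_positions,
                    is_cuda_graph_capture=True,
                )
            torch.cuda.synchronize()

            # Actual capture. No stream argument: torch.cuda.graph then uses its
            # own side stream, since capture on the default stream is not allowed.
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.mempool):
                with torch.inference_mode():
                    hidden_states = self.model.forward(
                        static_input_tokens,
                        static_positions,
                        is_cuda_graph_capture=True,
                    )

            self.graphs[batch_size] = CapturedGraph(
                graph=graph,
                input_tokens=static_input_tokens,
                positions=static_positions,
                output=hidden_states,
            )

    def _get_graph_batch_size(self, batch_size: int) -> Optional[int]:
        """Smallest captured batch size that fits, or None."""
        return next((b for b in self.batch_size_pool if b >= batch_size), None)

    def execute(self, input_tokens, positions) -> torch.Tensor:
        batch_size = input_tokens.shape[0]

        # Find appropriate graph; larger batches fall back to eager mode
        graph_batch_size = self._get_graph_batch_size(batch_size)
        if graph_batch_size is None:
            with torch.inference_mode():
                return self.model.forward(input_tokens, positions)
        captured = self.graphs[graph_batch_size]

        # Update static inputs; rows past batch_size are padding
        captured.input_tokens[:batch_size].copy_(input_tokens)
        captured.positions[:batch_size].copy_(positions)

        # Replay
        captured.graph.replay()

        # View into the static output: consume it before the next replay
        return captured.output[:batch_size]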
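
⚠️ Capture Constraints

During graph capture, you cannot call cudaMalloc directly (PyTorch's caching allocator serves capture-time allocations from a private memory pool), copy data to or from CPU tensors, synchronize with the CPU (for example via .item(), .cpu(), or torch.cuda.synchronize()), or branch on GPU values, since reading them requires a sync. Python control flow is not an error, but only the branch taken during capture is recorded. Every replay reuses the captured shapes and memory addresses, so all shapes must be static.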
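
The effect of these rules can be modeled without a GPU. The sketch below is a CPU-only toy, not the CUDA or PyTorch API: the hypothetical `ToyGraph` records closures during "capture" and re-runs them on "replay". It mirrors two real properties. First, a Python branch is evaluated once, at capture time. Second, replay reads whatever is in the captured buffers, so new inputs must be copied into them rather than passed as new objects.

```python
from typing import Callable, List


class ToyGraph:
    """CPU-only stand-in for stream capture: records ops, re-runs them on replay."""

    def __init__(self) -> None:
        self.ops: List[Callable[[], None]] = []
        self.capturing = False

    def launch(self, op: Callable[[], None]) -> None:
        # During capture an op is recorded, not executed
        if self.capturing:
            self.ops.append(op)
        else:
            op()

    def replay(self) -> None:
        for op in self.ops:
            op()


def step(g: ToyGraph, inp: list, out: list, use_double: bool) -> None:
    # Python branch: evaluated once, at capture time; only one op is recorded
    if use_double:
        g.launch(lambda: out.__setitem__(0, inp[0] * 2))
    else:
        g.launch(lambda: out.__setitem__(0, inp[0] + 1))


static_in, static_out = [0], [0]
g = ToyGraph()
g.capturing = True
step(g, static_in, static_out, use_double=True)
g.capturing = False

static_in[0] = 21      # write the new input INTO the captured buffer
g.replay()
print(static_out[0])   # 42

static_in = [100]      # rebinding the name: the graph still reads the old list
g.replay()
print(static_out[0])   # 42, not 200
```

Rebinding `static_in` is the toy analogue of passing a fresh tensor instead of calling `copy_()` on the static one: replay silently keeps using the old buffer.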

Memory Considerations

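CUDA graphs consume GPU memory: each captured graph keeps a private memory pool holding its intermediate and output tensors, plus driver-side storage for the graph itself:

import torch

def measure_graph_memory_overhead(model, batch_sizes):
    """Measure memory overhead of graph capture.

    Assumed: model(token_ids) takes a (batch, 1) LongTensor on the GPU.
    """
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    base_allocated = torch.cuda.memory_allocated()
    base_reserved = torch.cuda.memory_reserved()

    graphs = {}
    with torch.inference_mode():
        for bs in batch_sizes:
            static_input = torch.zeros(bs, 1, dtype=torch.long, device='cuda')

            # Warm up on a side stream before capture
            side = torch.cuda.Stream()
            side.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(side):
                model(static_input)
            torch.cuda.current_stream().wait_stream(side)

            # Capture graph; keep the graph and its static tensors alive
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph):
                static_output = model(static_input)
            graphs[bs] = (graph, static_input, static_output)

    torch.cuda.synchronize()
    # Drop cached warmup blocks; pools owned by live graphs stay reserved even
    # after capture-time intermediates are freed, so reserved memory is the
    # better estimate of graph overhead
    torch.cuda.empty_cache()
    graph_memory = torch.cuda.memory_reserved() - base_reserved
    peak_memory = torch.cuda.max_memory_allocated() - base_allocated

    return {
        'graph_memory_mb': graph_memory / 1024**2,
        'peak_memory_mb': peak_memory / 1024**2,
        'per_graph_mb': graph_memory / len(batch_sizes) / 1024**2,
    }

# Typical results for Llama-70B:
# graph_memory_mb: 480 (for 8 batch sizes)
# per_graph_mb: 60 (per graph)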

CUDA Graph Memory Overhead (Llama-70B)

Batch Sizes                 Graph Memory   Capture Time   Note
[1]                         45 MB          2.1 s          Minimal
[1, 4, 16]                  142 MB         6.3 s          Common
[1, 2, 4, 8, 16, 32, 64]    485 MB         14.8 s         Full pool
[1..256, step=1]            1.8 GB         52 s           Excessive

Note: Memory measured after capture; capture time includes warmup.

Performance Impact

Decode Latency Reduction with CUDA Graphs (ms): 32.4 ms without graphs, 19.8 ms with graphs (-39%)

End-to-End Performance Comparison (A100-80GB)

Configuration            Throughput     P50 Latency   P99 Latency
Eager Execution          3,234 tok/s    28 ms         45 ms
CUDA Graphs (8 sizes)    4,891 tok/s    18 ms         32 ms
Improvement              +51%           -36%          -29%

Note: Llama-70B, batch 1-32, decode phase only
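
The improvement row, and the -39% in the latency chart, are signed relative changes, (after - before) / before. Here is a quick check using the figures above:

```python
def relative_change(before: float, after: float) -> float:
    """Signed relative change from before to after, as a fraction."""
    return (after - before) / before


rows = {
    "throughput (tok/s)": (3234, 4891),
    "P50 latency (ms)": (28, 18),
    "P99 latency (ms)": (45, 32),
    "decode latency (ms)": (32.4, 19.8),
}
for name, (before, after) in rows.items():
    print(f"{name}: {relative_change(before, after):+.0%}")
```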

Debugging Graph Issues

Common problems and solutions:

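import torch
import torch.profiler as profiler

# Problem: Graph capture fails with an opaque CUDA error
# Solution: Capture with capture_error_mode="global" (PyTorch's default and
# strictest mode; "thread_local" and "relaxed" are looser) and report the error

graph = torch.cuda.CUDAGraph()
try:
    with torch.cuda.graph(graph, capture_error_mode="global"):
        model(static_input)  # model and static_input assumed from earlier sections
except RuntimeError as e:
    print(f"Graph capture failed: {e}")
    # Common causes:
    # - cudaMalloc outside PyTorch's allocator during capture
    # - CPU tensor access or host sync (.item(), .cpu())
    # - Data-dependent control flow

# Problem: Graph replay produces wrong results
# Solution: Verify input tensors are updated correctly

def debug_graph_inputs(captured_graph, expected_inputs):
    """Verify static inputs match expected values (ignoring padding rows)."""
    for name, expected in expected_inputs.items():
        actual = getattr(captured_graph, name)[:expected.shape[0]]
        if not torch.equal(actual, expected):
            print(f"Mismatch in {name}:")
            print(f"  Expected: {expected}")
            print(f"  Actual: {actual}")

# Problem: Graphs don't cover all code paths
# Solution: Profile to find uncaptured operations

with profiler.profile(
    activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
    with_stack=True,
) as prof:
    for _ in range(100):
        model.decode_step(inputs)  # hypothetical decode entry point

# Each cudaLaunchKernel is a kernel launched outside a graph;
# fully captured steps appear as cudaGraphLaunch instead
names = [event.name for event in prof.events()]
print(f"eager launches: {names.count('cudaLaunchKernel')}, "
      f"graph launches: {names.count('cudaGraphLaunch')}")
# Group by Python call stack to locate the code issuing eager launches
print(prof.key_averages(group_by_stack_n=5).table(sort_by="count", row_limit=15))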

When Not to Use CUDA Graphs

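Graphs add capture time, memory, and padding costs, so they are not always beneficial:

  1. Prefill phase: long sequences produce many distinct shapes, and long-running kernels already dwarf launch overhead
  2. Very large batches: launch overhead is already amortized over larger kernels
  3. Rapidly changing shapes: unless shapes fall into a few buckets, padding waste or repeated captures eat the savings
  4. Memory-constrained deployments: graph storage may exceed the budget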
💡 Rule of Thumb

Use CUDA graphs for the decode phase with batch sizes ≤256. For prefill or large batches, eager execution is often faster because it avoids padding overhead.
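
The rule of thumb boils down to a per-step dispatch decision. Here is a minimal sketch under stated assumptions: `use_cuda_graph` is a hypothetical helper, the 256 cutoff comes from the rule above, and a real engine may weigh more state than phase and batch size:

```python
from typing import Sequence


def use_cuda_graph(is_decode: bool, batch_size: int,
                   captured_sizes: Sequence[int], max_graph_batch: int = 256) -> bool:
    """Replay a graph only for decode steps that fit a captured size."""
    if not is_decode or batch_size > max_graph_batch:
        return False  # prefill or very large batch: run eagerly
    return any(size >= batch_size for size in captured_sizes)


sizes = [1, 2, 4, 8, 16, 32]
print(use_cuda_graph(True, 12, sizes))   # True: replays the 16-row graph
print(use_cuda_graph(False, 12, sizes))  # False: prefill
print(use_cuda_graph(True, 48, sizes))   # False: no captured size fits
```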

Conclusion

CUDA graphs transform LLM decode from CPU-bound to GPU-bound by eliminating kernel launch overhead. The key implementation challenges are handling dynamic batch sizes (via graph pools) and managing memory overhead. Properly implemented, graphs provide 30-50% latency reduction for decode-heavy workloads.