A 1000-token response from Llama 70B requires 1000 sequential decode steps. Each step loads 140 GB of weights from HBM, performs a tiny GEMV, and produces a single token. A decode step takes approximately 40 ms on an H100: about 33 ms of memory-bandwidth-bound weight streaming and about 7 ms of overhead from kernel launches, CPU-GPU synchronization, the Python interpreter, and tensor allocation. Over 1000 tokens, that 7 ms of overhead per step accumulates to 7 seconds of pure waste.
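A sanity check on the bandwidth figure: the lower bound on a memory-bound decode step is just weight bytes divided by HBM bandwidth. A minimal sketch, assuming 140 GB of FP16 weights and the H100's nominal 3.35 TB/s (real deployments shard the weights across GPUs and stream them in parallel, which is why measured per-step figures can come in below the single-GPU bound):

```python
def memory_bound_decode_ms(weight_bytes: float, hbm_bytes_per_s: float) -> float:
    """Roofline lower bound: every decode step must stream all weights once."""
    return weight_bytes / hbm_bytes_per_s * 1e3

# 140 GB of FP16 weights against one H100's ~3.35 TB/s HBM3:
single_gpu = memory_bound_decode_ms(140e9, 3.35e12)   # ~41.8 ms
# Sharded across two GPUs (70 GB each, streamed concurrently):
two_gpu = memory_bound_decode_ms(70e9, 3.35e12)       # ~20.9 ms
print(f"{single_gpu:.1f} ms on one GPU, {two_gpu:.1f} ms per GPU when sharded")
```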
## The Decode Overhead Budget
Let us measure where time goes in a single decode step:
```python
import torch
import time


class DecodeProfiler:
    """Measure overhead breakdown for a single decode step."""

    def __init__(self, model, device="cuda:0"):
        self.model = model
        self.device = device

    def profile_step(self, input_ids, kv_cache, num_trials=100):
        """Measure each overhead component independently."""
        results = {}

        # 1. Total wall-clock time per step (CPU dispatch + GPU work)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(num_trials):
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
        torch.cuda.synchronize()
        total = (time.perf_counter() - t0) / num_trials
        results["total_ms"] = total * 1000

        # 2. GPU-only time (CUDA events bracket the kernels, not the dispatch)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        gpu_times = []
        for _ in range(num_trials):
            start_event.record()
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
            end_event.record()
            torch.cuda.synchronize()
            gpu_times.append(start_event.elapsed_time(end_event))
        results["gpu_ms"] = sum(gpu_times) / len(gpu_times)

        # 3. Overhead = wall-clock time minus GPU time
        results["overhead_ms"] = results["total_ms"] - results["gpu_ms"]
        results["overhead_pct"] = results["overhead_ms"] / results["total_ms"] * 100
        return results
```
**Decode Step Overhead Breakdown (Llama 70B, H100, Batch=1)**
| Component | Time (ms) | Percentage | Source |
|---|---|---|---|
| HBM weight loading | 33.2 | 78.9% | Memory bandwidth |
| Tensor core compute | 0.5 | 1.2% | GPU ALU |
| Kernel launch overhead | 3.8 | 9.0% | CPU-GPU launch queue |
| Python module dispatch | 2.1 | 5.0% | PyTorch nn.Module |
| Tensor allocation/free | 1.4 | 3.3% | CUDA allocator |
| CPU-GPU sync points | 1.1 | 2.6% | cudaStreamSynchronize |
| Total | 42.1 | 100% | |
The 8.4ms of non-compute, non-bandwidth overhead (kernel launch + Python dispatch + allocation + sync) is 20% of the total decode step time. Over 1000 tokens, that is 8.4 seconds wasted. CUDA graphs eliminate most of this.
## CUDA Graphs: Recording and Replaying Kernel Sequences
A CUDA graph is a recorded sequence of GPU operations (kernel launches, memory copies, memory sets) that can be replayed with a single API call. Instead of the CPU submitting each of the ~1600 kernels per decode step individually, the entire sequence is submitted as one unit.
### How Graph Capture Works
```python
import torch


class CUDAGraphManager:
    """Manage CUDA graph capture and replay for decode steps."""

    def __init__(self, model, device="cuda:0"):
        self.model = model
        self.device = device
        self.captured_graphs = {}  # batch_size -> (graph, static_io)

    def _allocate_static_buffers(self, batch_size, max_seq_len):
        """Allocate fixed-address buffers for graph I/O.

        CUDA graphs require that tensor addresses do not change
        between capture and replay.
        """
        static_io = {
            # Input buffers (filled before each replay)
            "input_ids": torch.zeros(
                (batch_size, 1), dtype=torch.long, device=self.device
            ),
            "position_ids": torch.zeros(
                (batch_size, 1), dtype=torch.long, device=self.device
            ),
            "slot_mapping": torch.zeros(
                (batch_size,), dtype=torch.long, device=self.device
            ),
            # Output buffer (read after each replay)
            "logits": torch.zeros(
                (batch_size, 1, self.model.config.vocab_size),
                dtype=torch.float16, device=self.device,
            ),
            # The KV cache is pre-allocated and persistent
            # (not part of graph capture, just indexed by slot_mapping)
        }
        return static_io

    def capture(self, batch_size, max_seq_len=8192):
        """Capture the decode forward pass as a CUDA graph."""
        static_io = self._allocate_static_buffers(batch_size, max_seq_len)

        # Step 1: warmup runs (populate CUDA caches, trigger JIT compilation).
        # The warmup must use the EXACT same tensor addresses as the capture.
        for _ in range(3):
            with torch.no_grad():
                logits = self.model.decode_forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                    slot_mapping=static_io["slot_mapping"],
                )
                static_io["logits"].copy_(logits)

        # Step 2: capture
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=None):
            with torch.no_grad():
                logits = self.model.decode_forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                    slot_mapping=static_io["slot_mapping"],
                )
                static_io["logits"].copy_(logits)

        self.captured_graphs[batch_size] = (graph, static_io)
        return graph, static_io

    def replay(self, batch_size, input_ids, position_ids, slot_mapping):
        """Execute a decode step by replaying the captured graph."""
        if batch_size not in self.captured_graphs:
            raise RuntimeError(f"No graph captured for batch_size={batch_size}")
        graph, static_io = self.captured_graphs[batch_size]

        # Copy dynamic inputs into the static buffers.
        # These copies are tiny (batch_size * 8 bytes each).
        static_io["input_ids"].copy_(input_ids)
        static_io["position_ids"].copy_(position_ids)
        static_io["slot_mapping"].copy_(slot_mapping)

        # Replay: one CUDA API call launches all ~1600 kernels
        graph.replay()

        # Read the output from the static buffer
        return static_io["logits"]
```
### Graph Capture Constraints
CUDA graphs impose strict constraints:

1. **Fixed tensor addresses**: every tensor used during capture must remain at the same GPU memory address during replay. This means no dynamic allocation inside the captured region.
2. **Fixed control flow**: no Python `if`/`else` based on tensor values. The kernel sequence is fixed at capture time.
3. **Fixed tensor shapes**: all tensors must have the same shape during replay as during capture. Batch size cannot change.
4. **No CPU-GPU synchronization**: `torch.cuda.synchronize()` inside the captured region is not allowed.
```python
# What CANNOT be inside a CUDA graph:
def bad_graph_example():
    """These operations break CUDA graph capture."""
    # BAD: dynamic allocation
    temp = torch.empty(dynamic_size, device="cuda")  # Address changes each call

    # BAD: CPU-dependent control flow
    if tensor.item() > 0.5:  # Requires a CPU-GPU sync to read the tensor value
        do_something()

    # BAD: dynamic shapes
    output = tensor[:variable_length]  # Shape depends on a runtime value

    # BAD: Python print/logging
    print(f"Value: {tensor}")  # Forces a sync to read the tensor


# What CAN be inside a CUDA graph:
def good_graph_example(static_input, static_output, static_temp):
    """These operations work with CUDA graph capture."""
    # GOOD: operations on pre-allocated static tensors
    static_temp.copy_(static_input)

    # GOOD: fixed-shape operations
    result = torch.nn.functional.linear(static_temp, weight)

    # GOOD: in-place operations on static buffers
    static_output.copy_(result)
```
The most common CUDA graph failure in LLM serving: the scheduler changes the batch size between iterations (requests arrive and finish). Since graphs are captured per-batch-size, you need either (a) one graph per possible batch size, (b) padding to a fixed batch size, or (c) graph pool management with multiple pre-captured sizes.
## Handling Variable Batch Sizes
In production, the batch size changes every iteration as requests arrive and complete. There are three strategies:
### Strategy 1: Pad to Power-of-Two
```python
class PaddedGraphManager:
    """Pad the batch to the next power of two and use pre-captured graphs."""

    BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    def __init__(self, model):
        self.model = model
        self.graph_manager = CUDAGraphManager(model)
        # Pre-capture graphs for all power-of-two batch sizes
        for bs in self.BATCH_SIZES:
            self.graph_manager.capture(bs)

    def _next_power_of_two(self, n):
        """Find the smallest captured batch size >= n."""
        for bs in self.BATCH_SIZES:
            if bs >= n:
                return bs
        return self.BATCH_SIZES[-1]

    def decode_step(self, input_ids, position_ids, slot_mapping):
        actual_batch = input_ids.shape[0]
        padded_batch = self._next_power_of_two(actual_batch)

        # Pad inputs to the graph's expected batch size
        pad_size = padded_batch - actual_batch
        if pad_size > 0:
            input_ids = torch.nn.functional.pad(input_ids, (0, 0, 0, pad_size))
            position_ids = torch.nn.functional.pad(position_ids, (0, 0, 0, pad_size))
            slot_mapping = torch.nn.functional.pad(slot_mapping, (0, pad_size))

        logits = self.graph_manager.replay(
            padded_batch, input_ids, position_ids, slot_mapping
        )
        # Return only the non-padded outputs
        return logits[:actual_batch]
```
Padding waste analysis: for batch sizes uniformly distributed in [1, 512], the average padding overhead with power-of-two buckets is approximately 25%. Within each bucket the number of padded slots ranges from zero up to nearly half the bucket size, so on average about a quarter of each padded batch is wasted work.
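That figure is easy to verify exhaustively. A small sketch that computes the exact average waste for any bucket list under the uniform-batch-size assumption (`avg_padding_waste` is an illustrative helper, not part of the serving code):

```python
def avg_padding_waste(buckets, max_batch=512):
    """Average fraction of padded (wasted) slots over batch sizes 1..max_batch."""
    buckets = sorted(buckets)
    total = 0.0
    for n in range(1, max_batch + 1):
        padded = next(b for b in buckets if b >= n)  # smallest bucket >= n
        total += (padded - n) / padded               # fraction of wasted slots
    return total / max_batch

pow2 = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
step8 = list(range(1, 8)) + list(range(8, 513, 8))

print(f"power-of-two: {avg_padding_waste(pow2):.1%}")   # ~25%
print(f"step-8:       {avg_padding_waste(step8):.1%}")  # a few percent
```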
### Strategy 2: Fine-Grained Buckets
```python
class BucketedGraphManager:
    """Use fine-grained batch-size buckets to reduce padding waste."""

    def __init__(self, model, bucket_step=8, max_batch=512):
        self.model = model
        self.graph_manager = CUDAGraphManager(model)
        # Capture every multiple of 8 from 8 to max_batch,
        # plus individual sizes 1-7 for small batches
        self.buckets = list(range(1, 8)) + list(range(8, max_batch + 1, bucket_step))
        for bs in self.buckets:
            self.graph_manager.capture(bs)

    def _next_bucket(self, n):
        for bs in self.buckets:
            if bs >= n:
                return bs
        return self.buckets[-1]
```
With step=8 buckets, the average waste drops to approximately 4%. The cost is more captured graphs: 64 + 7 = 71 graphs instead of 10, each consuming GPU memory for the graph’s kernel argument buffers (typically 10-50 MB per graph).
### Strategy 3: Graph Pool with LRU Eviction
```python
from collections import OrderedDict


class LRUGraphPool:
    """Capture graphs on demand with LRU eviction."""

    def __init__(self, model, max_cached_graphs=32):
        self.model = model
        self.max_cached = max_cached_graphs
        self.graph_manager = CUDAGraphManager(model)
        self.cache = OrderedDict()  # batch_size -> (graph, static_io), LRU order

    def get_graph(self, batch_size):
        if batch_size in self.cache:
            # Move to the end (most recently used)
            self.cache.move_to_end(batch_size)
            return self.cache[batch_size]

        # Capture a new graph, evicting the least recently used if full
        if len(self.cache) >= self.max_cached:
            evicted_bs, _ = self.cache.popitem(last=False)
            del self.graph_manager.captured_graphs[evicted_bs]
            torch.cuda.empty_cache()

        graph, static_io = self.graph_manager.capture(batch_size)
        self.cache[batch_size] = (graph, static_io)
        return graph, static_io
```
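The eviction policy itself is independent of CUDA, so it can be unit-tested with a stub in place of the real capture call. A minimal sketch (`StubManager` and `LRUPool` are hypothetical stand-ins mirroring the class above):

```python
from collections import OrderedDict

class StubManager:
    """Stands in for CUDAGraphManager: 'capture' just records the batch size."""
    def __init__(self):
        self.captured_graphs = {}
    def capture(self, batch_size):
        graph, static_io = f"graph-{batch_size}", {}
        self.captured_graphs[batch_size] = (graph, static_io)
        return graph, static_io

class LRUPool:
    def __init__(self, manager, max_cached=2):
        self.manager, self.max_cached = manager, max_cached
        self.cache = OrderedDict()
    def get(self, bs):
        if bs in self.cache:
            self.cache.move_to_end(bs)
            return self.cache[bs]
        if len(self.cache) >= self.max_cached:
            evicted, _ = self.cache.popitem(last=False)   # least recently used
            del self.manager.captured_graphs[evicted]
        self.cache[bs] = self.manager.capture(bs)
        return self.cache[bs]

pool = LRUPool(StubManager(), max_cached=2)
pool.get(8); pool.get(16)   # cache holds {8, 16}
pool.get(8)                 # 8 becomes most recently used
pool.get(32)                # evicts 16, the least recently used
print(sorted(pool.cache))   # [8, 32]
```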
**Graph Management Strategy Comparison**
| Strategy | Num Graphs | GPU Memory (MB) | Avg Padding Waste | First-Use Latency |
|---|---|---|---|---|
| Power-of-two | 10 | ~500 | 25% | ~100ms (capture) |
| Step-8 buckets | 71 | ~3500 | 4% | ~100ms (capture) |
| LRU pool (32) | Up to 32 | ~1600 | 0% | ~100ms (miss) |
| No graphs | 0 | 0 | 0% | 0ms |
## Persistent Batches: Eliminating Tensor Re-allocation
Beyond CUDA graphs, another source of overhead is tensor allocation. Every decode step allocates intermediate tensors (attention scores, FFN activations, layer outputs) and frees them afterward. The CUDA memory allocator, even with caching (PyTorch’s CUDACachingAllocator), incurs overhead per allocation.
Persistent batches pre-allocate all intermediate tensors and reuse them across decode iterations:
```python
class PersistentBatchDecoder:
    """Pre-allocate all intermediate tensors for decode.

    Eliminates per-step allocation overhead.
    """

    def __init__(self, model_config, max_batch_size, device="cuda:0"):
        self.config = model_config
        self.max_batch = max_batch_size
        self.device = device

        d = model_config.hidden_size
        num_heads = model_config.num_attention_heads
        head_dim = d // num_heads
        num_kv_heads = model_config.num_key_value_heads
        ffn_dim = model_config.intermediate_size
        num_layers = model_config.num_hidden_layers

        # Pre-allocate ALL intermediate tensors
        self.buffers = {
            # Per-layer intermediates (reused across layers)
            "hidden_states": torch.empty(max_batch_size, 1, d,
                                         dtype=torch.float16, device=device),
            "residual": torch.empty(max_batch_size, 1, d,
                                    dtype=torch.float16, device=device),
            "normed": torch.empty(max_batch_size, 1, d,
                                  dtype=torch.float16, device=device),
            # Attention intermediates
            "q": torch.empty(max_batch_size, num_heads, 1, head_dim,
                             dtype=torch.float16, device=device),
            "k": torch.empty(max_batch_size, num_kv_heads, 1, head_dim,
                             dtype=torch.float16, device=device),
            "v": torch.empty(max_batch_size, num_kv_heads, 1, head_dim,
                             dtype=torch.float16, device=device),
            "attn_output": torch.empty(max_batch_size, num_heads, 1, head_dim,
                                       dtype=torch.float16, device=device),
            # FFN intermediates
            "gate": torch.empty(max_batch_size, 1, ffn_dim,
                                dtype=torch.float16, device=device),
            "up": torch.empty(max_batch_size, 1, ffn_dim,
                              dtype=torch.float16, device=device),
            "ffn_out": torch.empty(max_batch_size, 1, d,
                                   dtype=torch.float16, device=device),
        }

    def decode_layer(self, layer_idx, batch_size):
        """Execute one decoder layer using persistent buffers.

        self.weights, rms_norm_inplace, and flash_attn_decode are assumed
        to be provided by the surrounding runtime.
        """
        # All operations write to pre-allocated buffers:
        # no torch.empty() or torch.zeros() calls during decode.
        b = self.buffers

        # RMSNorm (writes into the normed buffer)
        rms_norm_inplace(
            b["hidden_states"][:batch_size],
            b["normed"][:batch_size],
            self.weights[layer_idx]["input_layernorm"],
        )
        # Save the residual
        b["residual"][:batch_size].copy_(b["hidden_states"][:batch_size])

        # Q, K, V projections (write to persistent buffers)
        torch.mm(
            b["normed"][:batch_size].view(-1, self.config.hidden_size),
            self.weights[layer_idx]["q_proj"],
            out=b["q"][:batch_size].view(-1, self.config.hidden_size),
        )
        # ... K and V projections similarly ...

        # Attention (writes to the attn_output buffer)
        flash_attn_decode(
            b["q"][:batch_size],
            b["k"][:batch_size],
            b["v"][:batch_size],
            out=b["attn_output"][:batch_size],
        )

        # O projection + residual add (fused via addmm, in-place)
        torch.addmm(
            b["residual"][:batch_size].view(-1, self.config.hidden_size),
            b["attn_output"][:batch_size].view(-1, self.config.hidden_size),
            self.weights[layer_idx]["o_proj"],
            out=b["hidden_states"][:batch_size].view(-1, self.config.hidden_size),
        )

        # FFN follows the same pattern with the gate, up, and down projections,
        # all using persistent buffers with batch_size slicing.
```
Persistent buffers combined with CUDA graphs eliminate both allocation overhead and launch overhead. The buffers satisfy the graph constraint of fixed tensor addresses, and the pre-allocation satisfies the constraint of no dynamic memory operations during the captured region. Together, they reduce per-step overhead from 8.4ms to less than 0.5ms.
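The same pattern is visible on the CPU: NumPy's `out=` argument lets every step reuse one pre-allocated buffer instead of allocating a fresh array. A toy analogue of the persistent-buffer idea (names are illustrative, not from any serving framework):

```python
import numpy as np

rng = np.random.default_rng(0)
d, batch = 64, 8
weight = rng.standard_normal((d, d)).astype(np.float32)

# Persistent buffers, allocated once at startup
hidden = rng.standard_normal((batch, d)).astype(np.float32)
out_buf = np.empty((batch, d), dtype=np.float32)

def step_allocating(hidden, weight):
    return hidden @ weight                   # allocates a new array each call

def step_persistent(hidden, weight, out_buf):
    np.matmul(hidden, weight, out=out_buf)   # writes into the same buffer
    return out_buf

a = step_allocating(hidden, weight)
b = step_persistent(hidden, weight, out_buf)
assert np.allclose(a, b)
assert b is out_buf   # no new allocation: the output address is fixed
```

The fixed output address is exactly the property CUDA graphs require of the GPU version.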
## Speculative Verification: Amortizing Decode Cost
Each decode step produces one token but pays the full cost of loading all model weights. Speculative decoding amortizes this cost by verifying multiple candidate tokens in a single forward pass.
The key insight: during decode, the forward pass cost is dominated by weight loading, which is the same whether processing 1 token or K tokens. Verifying K speculative tokens costs approximately the same as generating 1 token.
```python
class SpeculativeVerifier:
    """Verify speculative draft tokens in a single decode-like forward pass."""

    def __init__(self, target_model, draft_model, num_speculative=5):
        self.target = target_model
        self.draft = draft_model
        self.K = num_speculative

    def speculative_decode_step(self, input_token, kv_cache_target, kv_cache_draft):
        """One speculative decode step:

        1. Draft K tokens with the small model.
        2. Verify all K+1 positions with the target model in one pass.
        3. Accept the longest correct prefix.
        """
        # Step 1: draft K tokens autoregressively with the draft model
        draft_tokens = []
        draft_probs = []
        current = input_token
        for i in range(self.K):
            with torch.no_grad():
                logits = self.draft(current, past_key_values=kv_cache_draft)
            prob = torch.softmax(logits[:, -1, :], dim=-1)
            token = torch.multinomial(prob, 1)
            draft_tokens.append(token)
            draft_probs.append(prob)
            current = token

        # Step 2: verify all K+1 positions in ONE target-model forward pass.
        # Input: [input_token, draft_0, draft_1, ..., draft_{K-1}]
        verify_input = torch.cat([input_token] + draft_tokens, dim=1)
        with torch.no_grad():
            # This forward pass processes K+1 tokens but costs about the same
            # as 1 token, because weight loading (the bottleneck) is identical.
            target_logits = self.target(
                verify_input,
                past_key_values=kv_cache_target,
                use_cache=True,
            )
        target_probs = torch.softmax(target_logits, dim=-1)

        # Step 3: accept/reject using modified rejection sampling
        accepted = 0
        for i in range(self.K):
            # Target and draft probability of the draft token at position i
            p_target = target_probs[0, i, draft_tokens[i].item()]
            p_draft = draft_probs[i][0, draft_tokens[i].item()]
            # Accept outright if target prob >= draft prob,
            # otherwise accept with probability p_target / p_draft
            if torch.rand(1).item() < min(1.0, p_target / p_draft):
                accepted += 1
            else:
                break

        if accepted < self.K:
            # Rejection: resample from max(0, p_target - p_draft), normalized
            adjusted = torch.clamp(
                target_probs[0, accepted] - draft_probs[accepted][0], min=0
            )
            adjusted = adjusted / adjusted.sum()
            bonus_token = torch.multinomial(adjusted, 1)
        else:
            # All drafts accepted: sample a bonus token from the final position
            bonus_token = torch.multinomial(target_probs[0, self.K], 1)

        # Return the accepted drafts plus 1 bonus token per verification pass
        output_tokens = draft_tokens[:accepted] + [bonus_token]
        return output_tokens, accepted + 1
```
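The accept/reject rule in step 3 depends only on the two probability sequences, so it can be checked without any model. A stand-alone sketch (`accept_prefix` is a hypothetical helper mirroring the loop above):

```python
import random

def accept_prefix(p_target, p_draft, rng):
    """Return the number of draft tokens accepted under rejection sampling."""
    accepted = 0
    for pt, pd in zip(p_target, p_draft):
        # Accept outright when pt >= pd, else with probability pt / pd
        if rng.random() < min(1.0, pt / pd):
            accepted += 1
        else:
            break
    return accepted

rng = random.Random(0)
# Target agrees with (or exceeds) the draft everywhere: all tokens accepted
assert accept_prefix([0.5, 0.4, 0.9], [0.5, 0.3, 0.8], rng) == 3
# Target assigns zero probability to the first draft token: none accepted
assert accept_prefix([0.0, 0.4], [0.5, 0.3], rng) == 0
```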
### Speedup Analysis
Expected tokens per verification step:

$$\mathbb{E}[\text{tokens per step}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}$$

where $\alpha$ is the acceptance rate (the probability that the draft model matches the target model).
```python
def speculative_speedup(alpha, K, draft_cost_ratio=0.1):
    """Calculate the expected speedup from speculative decoding.

    Args:
        alpha: acceptance rate (0-1)
        K: number of speculative tokens
        draft_cost_ratio: draft-model cost / target-model cost per token
    """
    # Expected accepted tokens per step (geometric series)
    expected_tokens = (1 - alpha ** (K + 1)) / (1 - alpha)
    # Cost per step: K draft steps + 1 verification step.
    # Draft steps are cheap (small model, ~10% of the target cost).
    cost_per_step = K * draft_cost_ratio + 1.0
    speedup = expected_tokens / cost_per_step
    return {
        "expected_tokens": expected_tokens,
        "cost_per_step": cost_per_step,
        "speedup": speedup,
    }

# Typical values:
# alpha=0.8, K=5: expected 3.7 tokens, cost 1.5, speedup 2.46x
# alpha=0.9, K=5: expected 4.7 tokens, cost 1.5, speedup 3.12x
# alpha=0.7, K=5: expected 2.9 tokens, cost 1.5, speedup 1.96x
```
**Speculative Decode Speedup vs Acceptance Rate (K=5)** (values computed from the formula above with draft_cost_ratio=0.1)

| Metric | 0.5 | 0.6 | 0.7 | 0.75 | 0.8 | 0.85 | 0.9 | 0.95 |
|---|---|---|---|---|---|---|---|---|
| Expected tokens per step | 1.97 | 2.38 | 2.94 | 3.29 | 3.69 | 4.15 | 4.69 | 5.30 |
| Speedup (with draft overhead) | 1.31x | 1.59x | 1.96x | 2.19x | 2.46x | 2.77x | 3.12x | 3.53x |
| Theoretical max (no draft cost) | 1.97x | 2.38x | 2.94x | 3.29x | 3.69x | 4.15x | 4.69x | 5.30x |
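A quick Monte Carlo check of the expected-tokens formula, simulating the accept-until-first-reject process with an i.i.d. per-token acceptance probability alpha (a simplifying assumption; real acceptance rates vary by position):

```python
import random

def simulate_tokens_per_step(alpha, K, trials=200_000, seed=0):
    """Average tokens emitted per verification step: accepted drafts + 1 bonus."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        accepted = 0
        while accepted < K and rng.random() < alpha:
            accepted += 1
        total += accepted + 1   # +1 for the bonus/resampled token
    return total / trials

closed_form = (1 - 0.8 ** 6) / (1 - 0.8)   # ~3.69 for alpha=0.8, K=5
print(simulate_tokens_per_step(0.8, 5))     # converges to the closed form
```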
### CUDA Graph + Speculative Verification Combined
The verification forward pass processes K+1 tokens, which changes the tensor shapes from standard decode. This requires a separate CUDA graph:
```python
class GraphedSpeculativeDecoder:
    """CUDA graph-accelerated speculative decode."""

    def __init__(self, target_model, draft_model, K=5, device="cuda:0"):
        self.target = target_model
        self.draft = draft_model
        self.K = K
        self.device = device
        self.graph_manager = CUDAGraphManager(target_model)
        # Separate graphs are captured for:
        # 1. Standard decode (batch_size, seq_len=1) for the draft model
        # 2. Verification (batch_size, seq_len=K+1) for the target model
        self.draft_graphs = {}
        self.verify_graphs = {}

    def capture_verification_graph(self, batch_size):
        """Capture a graph for the verification pass.

        Input shape: (batch_size, K+1) tokens.
        """
        verify_len = self.K + 1
        static_io = {
            "input_ids": torch.zeros(
                (batch_size, verify_len), dtype=torch.long, device=self.device
            ),
            "position_ids": torch.zeros(
                (batch_size, verify_len), dtype=torch.long, device=self.device
            ),
            "logits": torch.zeros(
                (batch_size, verify_len, self.target.config.vocab_size),
                dtype=torch.float16, device=self.device,
            ),
        }

        # Warmup
        for _ in range(3):
            with torch.no_grad():
                logits = self.target.forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                )
                static_io["logits"].copy_(logits)

        # Capture
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                logits = self.target.forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                )
                static_io["logits"].copy_(logits)

        self.verify_graphs[batch_size] = (graph, static_io)

    def verify_step(self, batch_size, candidate_tokens, position_ids):
        """Run verification using the captured graph."""
        if batch_size not in self.verify_graphs:
            self.capture_verification_graph(batch_size)
        graph, static_io = self.verify_graphs[batch_size]
        static_io["input_ids"].copy_(candidate_tokens)
        static_io["position_ids"].copy_(position_ids)
        graph.replay()
        return static_io["logits"]
```
**Decode Optimization Stack (Llama 70B, H100, Batch=64)**
| Optimization | Time per Output Token (ms) | Cumulative Speedup | Overhead Eliminated |
|---|---|---|---|
| Baseline (eager PyTorch) | 48.5 | 1.00x | N/A |
| + CUDA graphs | 40.2 | 1.21x | Kernel launch |
| + Persistent buffers | 38.8 | 1.25x | Tensor allocation |
| + Speculative (K=5, alpha=0.8) | 15.1 | 3.21x | Per-token weight load |
| + FP8 quantization | 8.2 | 5.91x | 50% bandwidth |
## vLLM’s CUDA Graph Implementation
vLLM captures CUDA graphs at startup for a set of padded batch sizes:
```python
# Simplified from vllm/worker/model_runner.py
class ModelRunner:
    GRAPH_BATCH_SIZES = [1, 2, 4] + list(range(8, 513, 8))  # 1, 2, 4, then multiples of 8 up to 512

    def capture_graphs(self):
        """Capture decode graphs for all batch sizes at startup.

        This runs once during server initialization and takes
        ~30-60 seconds for all batch sizes.
        """
        for batch_size in self.GRAPH_BATCH_SIZES:
            # Create dummy inputs with the correct shapes
            input_ids = torch.zeros(batch_size, dtype=torch.long, device="cuda")
            positions = torch.zeros(batch_size, dtype=torch.long, device="cuda")
            slot_mapping = torch.arange(batch_size, dtype=torch.long, device="cuda")

            # Warmup
            for _ in range(2):
                self.model.forward(
                    input_ids=input_ids,
                    positions=positions,
                    kv_caches=self.kv_caches,
                    attn_metadata=self._build_decode_metadata(batch_size),
                )

            # Capture
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.graph_pool):
                output = self.model.forward(
                    input_ids=input_ids,
                    positions=positions,
                    kv_caches=self.kv_caches,
                    attn_metadata=self._build_decode_metadata(batch_size),
                )
            self.graph_runners[batch_size] = CUDAGraphRunner(
                graph, input_ids, positions, slot_mapping, output
            )

    def execute_model(self, scheduled_batch):
        """Execute one iteration, using a graph when possible."""
        if scheduled_batch.is_decode_only:
            # Find the smallest captured graph that fits
            padded_bs = self._get_padded_batch_size(scheduled_batch.batch_size)
            return self.graph_runners[padded_bs].replay(
                scheduled_batch.input_ids,
                scheduled_batch.positions,
                scheduled_batch.slot_mapping,
            )
        else:
            # Prefill or mixed batch: cannot use a graph (variable shapes)
            return self.model.forward(...)
```
The `graph_pool` parameter is important: it allows multiple graphs to share a memory pool, so CUDA does not allocate separate memory for each graph’s internal buffers. Without pooling, 64 captured graphs could consume 3+ GB of GPU memory just for graph bookkeeping.
CUDA graphs in vLLM are only used for pure decode batches. Mixed prefill+decode batches (from chunked prefill) use eager execution because the prefill token count varies per iteration, violating the fixed-shape constraint.
## Measuring the Full Stack
```python
import torch
import time


def benchmark_decode_optimizations(model, batch_size=64, num_steps=200):
    """Benchmark each decode optimization independently."""
    device = "cuda:0"
    dummy_ids = torch.randint(0, 32000, (batch_size, 1), device=device)
    dummy_pos = torch.zeros(batch_size, 1, dtype=torch.long, device=device)
    results = {}

    # Baseline: eager execution, no graphs, no persistent buffers
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(num_steps):
        with torch.no_grad():
            _ = model(dummy_ids, dummy_pos)
    torch.cuda.synchronize()
    results["eager_ms"] = (time.perf_counter() - t0) * 1000 / num_steps

    # With a CUDA graph
    gm = CUDAGraphManager(model)
    gm.capture(batch_size)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(num_steps):
        _ = gm.replay(batch_size, dummy_ids, dummy_pos,
                      torch.arange(batch_size, device=device))
    torch.cuda.synchronize()
    results["graph_ms"] = (time.perf_counter() - t0) * 1000 / num_steps
    results["graph_speedup"] = results["eager_ms"] / results["graph_ms"]
    return results
```
## CUDA Graph Limitations and Workarounds
Several common inference features conflict with CUDA graphs:
```python
CUDA_GRAPH_COMPATIBILITY = {
    "standard_decode": {
        "compatible": True,
        "notes": "Core use case, works well",
    },
    "chunked_prefill": {
        "compatible": False,
        "reason": "Variable prefill chunk size changes the input tensor shape",
        "workaround": "Use eager mode for mixed batches, graphs for pure decode",
    },
    "speculative_decoding": {
        "compatible": "Partial",
        "reason": "Verification pass has a variable accepted-token count",
        "workaround": "Capture separate graphs for draft (fixed K) and verify (fixed K+1)",
    },
    "structured_output": {
        "compatible": "Partial",
        "reason": "Grammar-guided sampling changes logit masks per step",
        "workaround": "Apply the grammar mask outside the graph, on the static output buffer",
    },
    "lora_adapters": {
        "compatible": False,
        "reason": "Different LoRA weights per request change the GEMM operands",
        "workaround": "Capture per-adapter graphs or use batched LoRA with a fixed adapter set",
    },
    "dynamic_rope_scaling": {
        "compatible": True,
        "notes": "Position IDs are inputs, not control flow. Graphs work fine.",
    },
    "prefix_caching": {
        "compatible": True,
        "notes": "Cache hits/misses change slot_mapping but not tensor shapes",
    },
}
```
## Graph Capture Memory Overhead
Each CUDA graph stores a copy of the kernel arguments and launch parameters. For a 70B model with approximately 1600 kernels per decode step, each graph consumes 20-50 MB of GPU memory:
```python
def estimate_graph_memory(num_kernels, per_kernel_bytes=9_500, num_graphs=64):
    """Estimate the memory overhead of a CUDA graph pool.

    per_kernel_bytes bundles node metadata, kernel-argument storage, and the
    CUDA driver's internal per-node allocations, which dominate in practice.
    The value here is an empirical estimate and varies by driver version.
    """
    per_graph_bytes = num_kernels * per_kernel_bytes
    # Plus additional per-graph driver bookkeeping (empirically ~2x)
    per_graph_total = per_graph_bytes * 2
    total_bytes = num_graphs * per_graph_total
    return {
        "per_graph_mb": per_graph_total / 1e6,
        "total_mb": total_bytes / 1e6,
        "total_gb": total_bytes / 1e9,
    }

# Llama 70B: ~1600 kernels, 64 graphs (batch sizes 1-512 in steps of 8)
# Per graph: ~30 MB; total: ~1.9 GB
# That is 2.4% of an H100's HBM -- a worthwhile tradeoff for a 20% decode speedup
```
**CUDA Graph Memory Overhead vs Decode Speedup**
| Num Graphs | Batch Sizes Covered | Memory Overhead (MB) | Avg Padding Waste | Decode Speedup |
|---|---|---|---|---|
| 10 | Powers of 2 (1-512) | 300 | 25% | 1.15x (avg with padding) |
| 32 | Every 16 (1-512) | 960 | 8% | 1.19x |
| 64 | Every 8 (1-512) | 1920 | 4% | 1.20x |
| 128 | Every 4 (1-512) | 3840 | 2% | 1.21x |
The sweet spot is 64 graphs with step-8 buckets: 1.9 GB of memory overhead (2.4% of H100 HBM) for 1.20x decode speedup with only 4% padding waste. Going finer-grained to 128 graphs doubles the memory cost but provides diminishing returns on padding efficiency.
The decode optimization stack is cumulative: CUDA graphs remove launch overhead, persistent buffers remove allocation overhead, speculative decoding removes the per-token weight loading cost, and quantization reduces the weight bytes themselves. Each targets a different component of the total decode latency, and together they can reduce per-token time by 5-6x on production workloads.