Part of series: Inference Optimization Timeline (24 of 60)

A 1000-token response from Llama 70B requires 1000 sequential decode steps. Each step loads 140 GB of weights from HBM, performs a tiny GEMV, and produces a single token. A decode step takes approximately 40ms on an H100: about 33ms of memory-bandwidth time and roughly 7ms of overhead (kernel launch, CPU-GPU synchronization, the Python interpreter, and tensor allocation). Over 1000 tokens, that 7ms of per-step overhead accumulates to 7 seconds of pure waste.

The Decode Overhead Budget

Let us measure where time goes in a single decode step:

import torch
import time

class DecodeProfiler:
    """Measure overhead breakdown for a single decode step."""

    def __init__(self, model, device="cuda:0"):
        self.model = model
        self.device = device

    def profile_step(self, input_ids, kv_cache, num_trials=100):
        """Measure each overhead component independently."""
        results = {}

        # 1. CPU-side Python overhead (module dispatch)
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(num_trials):
            # Just call forward without GPU sync
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
        torch.cuda.synchronize()
        total = (time.perf_counter() - t0) / num_trials
        results["total_ms"] = total * 1000

        # 2. GPU compute time (using CUDA events)
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)

        gpu_times = []
        for _ in range(num_trials):
            start_event.record()
            with torch.no_grad():
                outputs = self.model(
                    input_ids=input_ids,
                    past_key_values=kv_cache,
                    use_cache=True,
                )
            end_event.record()
            torch.cuda.synchronize()
            gpu_times.append(start_event.elapsed_time(end_event))

        results["gpu_ms"] = sum(gpu_times) / len(gpu_times)
        results["overhead_ms"] = results["total_ms"] - results["gpu_ms"]
        results["overhead_pct"] = results["overhead_ms"] / results["total_ms"] * 100

        return results
Decode Step Overhead Breakdown (Llama 70B, H100, Batch=1)

Component                Time (ms)   Percentage   Source
HBM weight loading       33.2        78.9%        Memory bandwidth
Tensor core compute      0.5         1.2%         GPU ALU
Kernel launch overhead   3.8         9.0%         CPU-GPU launch queue
Python module dispatch   2.1         5.0%         PyTorch nn.Module
Tensor allocation/free   1.4         3.3%         CUDA allocator
CPU-GPU sync points      1.1         2.6%         cudaStreamSynchronize
Total                    42.1        100%

The 8.4ms of non-compute, non-bandwidth overhead (kernel launch + Python dispatch + allocation + sync) is 20% of the total decode step time. Over 1000 tokens, that is 8.4 seconds wasted. CUDA graphs eliminate most of this.

CUDA Graphs: Recording and Replaying Kernel Sequences

A CUDA graph is a recorded sequence of GPU operations (kernel launches, memory copies, memory sets) that can be replayed with a single API call. Instead of the CPU submitting each of the ~1600 kernels per decode step individually, the entire sequence is submitted as one unit.

How Graph Capture Works

import torch

class CUDAGraphManager:
    """Manage CUDA graph capture and replay for decode steps."""

    def __init__(self, model, device="cuda:0"):
        self.model = model
        self.device = device
        self.captured_graphs = {}  # batch_size -> (graph, static_io)

    def _allocate_static_buffers(self, batch_size, max_seq_len):
        """Allocate fixed-address buffers for graph I/O.
        CUDA graphs require that tensor addresses do not change
        between capture and replay."""

        static_io = {
            # Input buffers (filled before each replay)
            "input_ids": torch.zeros(
                (batch_size, 1), dtype=torch.long, device=self.device
            ),
            "position_ids": torch.zeros(
                (batch_size, 1), dtype=torch.long, device=self.device
            ),
            "slot_mapping": torch.zeros(
                (batch_size,), dtype=torch.long, device=self.device
            ),
            # Output buffer (read after each replay)
            "logits": torch.zeros(
                (batch_size, 1, self.model.config.vocab_size),
                dtype=torch.float16, device=self.device,
            ),
            # KV cache is pre-allocated and persistent
            # (not part of graph capture, just indexed by slot_mapping)
        }
        return static_io

    def capture(self, batch_size, max_seq_len=8192):
        """Capture the decode forward pass as a CUDA graph."""

        static_io = self._allocate_static_buffers(batch_size, max_seq_len)

        # Step 1: Warmup runs (populate CUDA caches, JIT compile)
        # The warmup must use the EXACT same tensor addresses
        for _ in range(3):
            with torch.no_grad():
                logits = self.model.decode_forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                    slot_mapping=static_io["slot_mapping"],
                )
                static_io["logits"].copy_(logits)

        # Step 2: Capture
        torch.cuda.synchronize()
        graph = torch.cuda.CUDAGraph()

        with torch.cuda.graph(graph, pool=None):
            with torch.no_grad():
                logits = self.model.decode_forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                    slot_mapping=static_io["slot_mapping"],
                )
                static_io["logits"].copy_(logits)

        self.captured_graphs[batch_size] = (graph, static_io)
        return graph, static_io

    def replay(self, batch_size, input_ids, position_ids, slot_mapping):
        """Execute decode step by replaying the captured graph."""

        if batch_size not in self.captured_graphs:
            raise RuntimeError(f"No graph captured for batch_size={batch_size}")

        graph, static_io = self.captured_graphs[batch_size]

        # Copy dynamic inputs into static buffers
        # These copies are tiny (batch_size * 8 bytes each)
        static_io["input_ids"].copy_(input_ids)
        static_io["position_ids"].copy_(position_ids)
        static_io["slot_mapping"].copy_(slot_mapping)

        # Replay: one CUDA API call launches all ~1600 kernels
        graph.replay()

        # Read output from static buffer
        return static_io["logits"]

Graph Capture Constraints

CUDA graphs impose strict constraints:

  1. Fixed tensor addresses: every tensor used during capture must remain at the same GPU memory address during replay. This means no dynamic allocation inside the captured region.

  2. Fixed control flow: no Python if/else based on tensor values. The kernel sequence is fixed at capture time.

  3. Fixed tensor shapes: all tensors must have the same shape during replay as during capture. Batch size cannot change.

  4. No CPU-GPU synchronization: torch.cuda.synchronize() inside the captured region is not allowed.

# What CANNOT be inside a CUDA graph:

def bad_graph_example():
    """These operations break CUDA graph capture."""

    # BAD: dynamic allocation
    temp = torch.empty(dynamic_size, device="cuda")  # Address changes each call

    # BAD: CPU-dependent control flow
    if tensor.item() > 0.5:  # Requires CPU-GPU sync to read tensor value
        do_something()

    # BAD: dynamic shapes
    output = tensor[:variable_length]  # Shape depends on runtime value

    # BAD: Python print/logging
    print(f"Value: {tensor}")  # Forces sync to read tensor

# What CAN be inside a CUDA graph:

def good_graph_example(static_input, static_output, static_temp):
    """These operations work with CUDA graph capture."""

    # GOOD: operations on pre-allocated static tensors
    static_temp.copy_(static_input)

    # GOOD: fixed-shape operations
    result = torch.nn.functional.linear(static_temp, weight)

    # GOOD: in-place operations on static buffers
    static_output.copy_(result)
⚠️ Warning

The most common CUDA graph failure in LLM serving: the scheduler changes the batch size between iterations (requests arrive and finish). Since graphs are captured per-batch-size, you need either (a) one graph per possible batch size, (b) padding to a fixed batch size, or (c) graph pool management with multiple pre-captured sizes.

Handling Variable Batch Sizes

In production, the batch size changes every iteration as requests arrive and complete. There are three strategies:

Strategy 1: Pad to Power-of-Two

class PaddedGraphManager:
    """Pad batch to next power-of-two and use pre-captured graphs."""

    BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

    def __init__(self, model):
        self.model = model
        self.graph_manager = CUDAGraphManager(model)
        # Pre-capture graphs for all power-of-two batch sizes
        for bs in self.BATCH_SIZES:
            self.graph_manager.capture(bs)

    def _next_power_of_two(self, n):
        """Find smallest captured batch size >= n."""
        for bs in self.BATCH_SIZES:
            if bs >= n:
                return bs
        return self.BATCH_SIZES[-1]

    def decode_step(self, input_ids, position_ids, slot_mapping):
        actual_batch = input_ids.shape[0]
        padded_batch = self._next_power_of_two(actual_batch)

        # Pad inputs to the graph's expected batch size
        pad_size = padded_batch - actual_batch
        if pad_size > 0:
            input_ids = torch.nn.functional.pad(input_ids, (0, 0, 0, pad_size))
            position_ids = torch.nn.functional.pad(position_ids, (0, 0, 0, pad_size))
            slot_mapping = torch.nn.functional.pad(slot_mapping, (0, pad_size))

        logits = self.graph_manager.replay(
            padded_batch, input_ids, position_ids, slot_mapping
        )

        # Return only the non-padded outputs
        return logits[:actual_batch]

Padding waste analysis: for batch sizes uniformly distributed in [1, 512], the average padding overhead with power-of-two buckets is:

\text{Average waste} = \frac{1}{512} \sum_{n=1}^{512} \frac{\text{nextpow2}(n) - n}{\text{nextpow2}(n)} \approx 25\%
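The ~25% figure is easy to check numerically under the same assumption (batch sizes uniform in [1, 512]); a quick sketch:

```python
def next_pow2(n: int) -> int:
    """Smallest power of two >= n."""
    p = 1
    while p < n:
        p *= 2
    return p

# Average padding waste for batch sizes uniform in [1, 512]
waste = sum((next_pow2(n) - n) / next_pow2(n) for n in range(1, 513)) / 512
print(f"{waste:.1%}")  # ~24.5%
```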

Strategy 2: Fine-Grained Buckets

class BucketedGraphManager:
    """Use fine-grained batch size buckets to reduce padding waste."""

    def __init__(self, model, bucket_step=8, max_batch=512):
        self.model = model
        self.graph_manager = CUDAGraphManager(model)
        # Capture every multiple of 8 from 8 to max_batch
        # Plus individual sizes 1-7 for small batches
        self.buckets = list(range(1, 8)) + list(range(8, max_batch + 1, bucket_step))

        for bs in self.buckets:
            self.graph_manager.capture(bs)

    def _next_bucket(self, n):
        for bs in self.buckets:
            if bs >= n:
                return bs
        return self.buckets[-1]

With step=8 buckets, the average waste drops to roughly 3% under the same uniform batch-size assumption. The cost is more captured graphs: 64 + 7 = 71 graphs instead of 10, each consuming GPU memory for the graph’s kernel argument buffers (typically 10-50 MB per graph).
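Both the graph count and the average waste can be checked directly under the same uniform-distribution assumption (a quick sketch; the exact waste depends on your real batch-size distribution):

```python
# Buckets 1..7 individually, then multiples of 8 up to 512: 7 + 64 = 71 graphs
buckets = list(range(1, 8)) + list(range(8, 513, 8))

def next_bucket(n: int) -> int:
    """Smallest captured bucket >= n."""
    return next(bs for bs in buckets if bs >= n)

# Average padding waste for batch sizes uniform in [1, 512]
waste = sum((next_bucket(n) - n) / next_bucket(n) for n in range(1, 513)) / 512
print(len(buckets), f"{waste:.1%}")  # 71 graphs, ~2.6% average waste
```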

Strategy 3: Graph Pool with LRU Eviction

from collections import OrderedDict

class LRUGraphPool:
    """Capture graphs on-demand with LRU eviction."""

    def __init__(self, model, max_cached_graphs=32):
        self.model = model
        self.max_cached = max_cached_graphs
        self.graph_manager = CUDAGraphManager(model)
        self.cache = OrderedDict()  # batch_size -> graph, LRU order

    def get_graph(self, batch_size):
        if batch_size in self.cache:
            # Move to end (most recently used)
            self.cache.move_to_end(batch_size)
            return self.cache[batch_size]

        # Capture new graph
        if len(self.cache) >= self.max_cached:
            # Evict least recently used
            evicted_bs, _ = self.cache.popitem(last=False)
            del self.graph_manager.captured_graphs[evicted_bs]
            torch.cuda.empty_cache()

        graph, static_io = self.graph_manager.capture(batch_size)
        self.cache[batch_size] = (graph, static_io)
        return graph, static_io
Graph Management Strategy Comparison

Strategy         Num Graphs   GPU Memory (MB)   Avg Padding Waste   First-Use Latency
Power-of-two     10           ~500              25%                 ~100ms (capture)
Step-8 buckets   71           ~3500             ~3%                 ~100ms (capture)
LRU pool (32)    up to 32     ~1600             0%                  ~100ms (miss)
No graphs        0            0                 0%                  0ms

Persistent Batches: Eliminating Tensor Re-allocation

Beyond CUDA graphs, another source of overhead is tensor allocation. Every decode step allocates intermediate tensors (attention scores, FFN activations, layer outputs) and frees them afterward. The CUDA memory allocator, even with caching (PyTorch’s CUDACachingAllocator), incurs overhead per allocation.

Persistent batches pre-allocate all intermediate tensors and reuse them across decode iterations:

class PersistentBatchDecoder:
    """Pre-allocate all intermediate tensors for decode.
    Eliminates per-step allocation overhead."""

    def __init__(self, model_config, max_batch_size, device="cuda:0"):
        self.config = model_config
        self.max_batch = max_batch_size
        self.device = device
        d = model_config.hidden_size
        num_heads = model_config.num_attention_heads
        head_dim = d // num_heads
        num_kv_heads = model_config.num_key_value_heads
        ffn_dim = model_config.intermediate_size
        num_layers = model_config.num_hidden_layers

        # Pre-allocate ALL intermediate tensors
        self.buffers = {
            # Per-layer intermediates (reused across layers)
            "hidden_states": torch.empty(max_batch_size, 1, d,
                                         dtype=torch.float16, device=device),
            "residual": torch.empty(max_batch_size, 1, d,
                                    dtype=torch.float16, device=device),
            "normed": torch.empty(max_batch_size, 1, d,
                                  dtype=torch.float16, device=device),
            # Attention intermediates
            "q": torch.empty(max_batch_size, num_heads, 1, head_dim,
                             dtype=torch.float16, device=device),
            "k": torch.empty(max_batch_size, num_kv_heads, 1, head_dim,
                             dtype=torch.float16, device=device),
            "v": torch.empty(max_batch_size, num_kv_heads, 1, head_dim,
                             dtype=torch.float16, device=device),
            "attn_output": torch.empty(max_batch_size, num_heads, 1, head_dim,
                                       dtype=torch.float16, device=device),
            # FFN intermediates
            "gate": torch.empty(max_batch_size, 1, ffn_dim,
                                dtype=torch.float16, device=device),
            "up": torch.empty(max_batch_size, 1, ffn_dim,
                              dtype=torch.float16, device=device),
            "ffn_out": torch.empty(max_batch_size, 1, d,
                                   dtype=torch.float16, device=device),
        }

    def decode_layer(self, layer_idx, batch_size):
        """Execute one decoder layer using persistent buffers."""
        # All operations write to pre-allocated buffers
        # No torch.empty() or torch.zeros() calls during decode
        b = self.buffers

        # RMSNorm (in-place to normed buffer)
        rms_norm_inplace(
            b["hidden_states"][:batch_size],
            b["normed"][:batch_size],
            self.weights[layer_idx]["input_layernorm"],
        )

        # Save residual
        b["residual"][:batch_size].copy_(b["hidden_states"][:batch_size])

        # Q, K, V projections (write to persistent buffers)
        torch.mm(
            b["normed"][:batch_size].view(-1, self.config.hidden_size),
            self.weights[layer_idx]["q_proj"],
            out=b["q"][:batch_size].view(-1, self.config.hidden_size),
        )
        # ... K and V projections similarly ...

        # Attention (writes to attn_output buffer)
        flash_attn_decode(
            b["q"][:batch_size],
            b["k"][:batch_size],
            b["v"][:batch_size],
            out=b["attn_output"][:batch_size],
        )

        # O projection + residual (in-place)
        torch.addmm(
            b["residual"][:batch_size].view(-1, self.config.hidden_size),
            b["attn_output"][:batch_size].view(-1, self.config.hidden_size),
            self.weights[layer_idx]["o_proj"],
            out=b["hidden_states"][:batch_size].view(-1, self.config.hidden_size),
        )

        # FFN (similar pattern with gate, up, down projections)
        # All using persistent buffers with batch_size slicing
Performance

Persistent buffers combined with CUDA graphs eliminate both allocation overhead and launch overhead. The buffers satisfy the graph constraint of fixed tensor addresses, and the pre-allocation satisfies the constraint of no dynamic memory operations during the captured region. Together, they reduce per-step overhead from 8.4ms to less than 0.5ms.
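Plugging the overhead-budget numbers in gives a quick sanity check on the batch-1 gain (assuming the ~0.5 ms residual overhead claimed above):

```python
total_ms = 42.1      # measured decode step (overhead budget table)
overhead_ms = 8.4    # launch + dispatch + allocation + sync
residual_ms = 0.5    # overhead remaining after graphs + persistent buffers

optimized_ms = total_ms - overhead_ms + residual_ms
speedup = total_ms / optimized_ms
print(f"{optimized_ms:.1f} ms/step, {speedup:.2f}x")  # ~34.2 ms/step, ~1.23x
```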

Speculative Verification: Amortizing Decode Cost

Each decode step produces one token but pays the full cost of loading all model weights. Speculative decoding amortizes this cost by verifying multiple candidate tokens in a single forward pass.

The key insight: during decode, the forward pass cost is dominated by weight loading, which is the same whether processing 1 token or K tokens. Verifying K speculative tokens costs approximately the same as generating 1 token.
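This claim can be made concrete with a back-of-envelope roofline model (all hardware numbers here are illustrative assumptions: 140 GB of FP16 weights, H100 HBM3 bandwidth, and FP16 tensor core peak):

```python
WEIGHT_BYTES = 140e9   # Llama 70B in FP16 (assumed)
HBM_BW = 3.35e12       # H100 HBM3 bandwidth, bytes/s (assumed)
PEAK_FLOPS = 989e12    # H100 FP16 tensor core peak (assumed)
PARAMS = 70e9

def decode_step_ms(num_tokens: int) -> float:
    """Roofline estimate: a decode-like step costs the max of
    weight-loading time and compute time (~2 FLOPs/param/token)."""
    load_ms = WEIGHT_BYTES / HBM_BW * 1e3
    compute_ms = 2 * PARAMS * num_tokens / PEAK_FLOPS * 1e3
    return max(load_ms, compute_ms)

# Verifying 6 tokens costs the same as generating 1: both bandwidth-bound
print(round(decode_step_ms(1), 1), round(decode_step_ms(6), 1))
```

Weight loading dominates by two orders of magnitude, so a K+1-token verification pass is effectively free relative to K+1 separate decode steps.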

class SpeculativeVerifier:
    """Verify speculative draft tokens in a single decode-like forward pass."""

    def __init__(self, target_model, draft_model, num_speculative=5):
        self.target = target_model
        self.draft = draft_model
        self.K = num_speculative

    def speculative_decode_step(self, input_token, kv_cache_target, kv_cache_draft):
        """One speculative decode step:
        1. Draft K tokens with small model
        2. Verify all K+1 positions with target model in one pass
        3. Accept prefix of correct tokens
        """

        # Step 1: Draft K tokens autoregressively with draft model
        draft_tokens = []
        draft_probs = []
        current = input_token

        for i in range(self.K):
            with torch.no_grad():
                logits = self.draft(current, past_key_values=kv_cache_draft)
                prob = torch.softmax(logits[:, -1, :], dim=-1)
                token = torch.multinomial(prob, 1)
                draft_tokens.append(token)
                draft_probs.append(prob)
                current = token

        # Step 2: Verify all K+1 positions in ONE target model forward pass
        # Input: [input_token, draft_0, draft_1, ..., draft_{K-1}]
        verify_input = torch.cat([input_token] + draft_tokens, dim=1)

        with torch.no_grad():
            # This forward pass processes K+1 tokens but costs ~same as 1 token
            # because weight loading (the bottleneck) is identical
            target_logits = self.target(
                verify_input,
                past_key_values=kv_cache_target,
                use_cache=True,
            )
            target_probs = torch.softmax(target_logits, dim=-1)

        # Step 3: Accept/reject using modified rejection sampling
        accepted = 0
        for i in range(self.K):
            # Target probability of the draft token at position i
            p_target = target_probs[0, i, draft_tokens[i].item()]
            # Draft probability of the draft token at position i
            p_draft = draft_probs[i][0, draft_tokens[i].item()]

            # Accept if target prob >= draft prob
            # Otherwise accept with probability p_target / p_draft
            if torch.rand(1).item() < min(1.0, p_target / p_draft):
                accepted += 1
            else:
                break

        # Step 4: Sample one more token. At a rejection, resample from the
        # adjusted distribution; if all K drafts were accepted, sample the
        # bonus token from the target distribution at the final position.
        if accepted < self.K:
            # Resample from max(0, p_target - p_draft), renormalized
            adjusted = torch.clamp(
                target_probs[0, accepted] - draft_probs[accepted][0], min=0
            )
            adjusted = adjusted / adjusted.sum()
            bonus_token = torch.multinomial(adjusted, 1).view(1, 1)
        else:
            bonus_token = torch.multinomial(target_probs[0, self.K], 1).view(1, 1)

        # Accepted draft tokens + 1 bonus token, from one verification pass
        output_tokens = draft_tokens[:accepted] + [bonus_token]
        return output_tokens, len(output_tokens)

Speedup Analysis

Expected tokens per verification step:

E[\text{tokens}] = \frac{1 - \alpha^{K+1}}{1 - \alpha}

where \alpha is the acceptance rate (the probability that a draft token matches the target model).

def speculative_speedup(alpha, K, draft_cost_ratio=0.1):
    """Calculate expected speedup from speculative decoding.

    Args:
        alpha: acceptance rate (0-1)
        K: number of speculative tokens
        draft_cost_ratio: draft model cost / target model cost per token
    """
    # Expected accepted tokens per step
    expected_tokens = (1 - alpha**(K+1)) / (1 - alpha)

    # Cost per step: K draft steps + 1 verification step
    # Draft steps are cheap (small model, ~10% of target cost)
    cost_per_step = K * draft_cost_ratio + 1.0

    # Speedup = expected_tokens / cost_per_step
    speedup = expected_tokens / cost_per_step

    return {
        "expected_tokens": expected_tokens,
        "cost_per_step": cost_per_step,
        "speedup": speedup,
    }

# Typical values (from the formula above):
# alpha=0.7, K=5: expected 2.94 tokens, cost 1.5, speedup 1.96x
# alpha=0.8, K=5: expected 3.69 tokens, cost 1.5, speedup 2.46x
# alpha=0.9, K=5: expected 4.69 tokens, cost 1.5, speedup 3.12x

Speculative Decode Speedup vs Acceptance Rate (K=5)

Acceptance rate α   Expected tokens/step   Speedup (draft cost 0.1)   Theoretical max (no draft cost)
0.50                1.97                   1.31                       1.97
0.60                2.38                   1.59                       2.38
0.70                2.94                   1.96                       2.94
0.75                3.29                   2.19                       3.29
0.80                3.69                   2.46                       3.69
0.85                4.15                   2.77                       4.15
0.90                4.69                   3.12                       4.69
0.95                5.30                   3.53                       5.30

CUDA Graph + Speculative Verification Combined

The verification forward pass processes K+1 tokens, which changes the tensor shapes from standard decode. This requires a separate CUDA graph:

class GraphedSpeculativeDecoder:
    """CUDA graph-accelerated speculative decode."""

    def __init__(self, target_model, draft_model, K=5, device="cuda:0"):
        self.target = target_model
        self.draft = draft_model
        self.K = K
        self.device = device
        self.graph_manager = CUDAGraphManager(target_model)

        # Capture graphs for:
        # 1. Standard decode (batch_size, seq_len=1) for draft model
        # 2. Verification (batch_size, seq_len=K+1) for target model
        self.draft_graphs = {}
        self.verify_graphs = {}

    def capture_verification_graph(self, batch_size):
        """Capture a graph for the verification pass.
        Shape: (batch_size, K+1) input tokens."""
        verify_len = self.K + 1

        static_io = {
            "input_ids": torch.zeros(
                (batch_size, verify_len), dtype=torch.long, device=self.device
            ),
            "position_ids": torch.zeros(
                (batch_size, verify_len), dtype=torch.long, device=self.device
            ),
            "logits": torch.zeros(
                (batch_size, verify_len, self.target.config.vocab_size),
                dtype=torch.float16, device=self.device,
            ),
        }

        # Warmup
        for _ in range(3):
            with torch.no_grad():
                logits = self.target.forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                )
                static_io["logits"].copy_(logits)

        # Capture
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                logits = self.target.forward(
                    input_ids=static_io["input_ids"],
                    position_ids=static_io["position_ids"],
                )
                static_io["logits"].copy_(logits)

        self.verify_graphs[batch_size] = (graph, static_io)

    def verify_step(self, batch_size, candidate_tokens, position_ids):
        """Run verification using captured graph."""
        if batch_size not in self.verify_graphs:
            self.capture_verification_graph(batch_size)

        graph, static_io = self.verify_graphs[batch_size]
        static_io["input_ids"].copy_(candidate_tokens)
        static_io["position_ids"].copy_(position_ids)
        graph.replay()
        return static_io["logits"]
Decode Optimization Stack (Llama 70B, H100, Batch=64)

Optimization                     Time per Output Token (ms)   Cumulative Speedup   Overhead Eliminated
Baseline (eager PyTorch)         48.5                         1.00x                N/A
+ CUDA graphs                    40.2                         1.21x                Kernel launch
+ Persistent buffers             38.8                         1.25x                Tensor allocation
+ Speculative (K=5, alpha=0.8)   15.1                         3.21x                Per-token weight load
+ FP8 quantization               8.2                          5.91x                50% bandwidth

vLLM’s CUDA Graph Implementation

vLLM captures CUDA graphs at startup for a set of padded batch sizes:

# Simplified from vllm/worker/model_runner.py
class ModelRunner:
    GRAPH_BATCH_SIZES = [1, 2, 4] + list(range(8, 513, 8))  # 1, 2, 4, 8, 16, 24, ..., 512 (67 sizes)

    def capture_graphs(self):
        """Capture decode graphs for all batch sizes at startup."""
        # This runs once during server initialization
        # Takes ~30-60 seconds for all batch sizes

        for batch_size in self.GRAPH_BATCH_SIZES:
            # Create dummy inputs with correct shapes
            input_ids = torch.zeros(batch_size, dtype=torch.long, device="cuda")
            positions = torch.zeros(batch_size, dtype=torch.long, device="cuda")
            slot_mapping = torch.arange(batch_size, dtype=torch.long, device="cuda")

            # Warmup
            for _ in range(2):
                self.model.forward(
                    input_ids=input_ids,
                    positions=positions,
                    kv_caches=self.kv_caches,
                    attn_metadata=self._build_decode_metadata(batch_size),
                )

            # Capture
            graph = torch.cuda.CUDAGraph()
            with torch.cuda.graph(graph, pool=self.graph_pool):
                output = self.model.forward(
                    input_ids=input_ids,
                    positions=positions,
                    kv_caches=self.kv_caches,
                    attn_metadata=self._build_decode_metadata(batch_size),
                )

            self.graph_runners[batch_size] = CUDAGraphRunner(
                graph, input_ids, positions, slot_mapping, output
            )

    def execute_model(self, scheduled_batch):
        """Execute one iteration, using graph when possible."""
        if scheduled_batch.is_decode_only:
            # Pad the batch up to the smallest captured size that fits
            padded_bs = self._get_padded_batch_size(
                scheduled_batch.batch_size
            )
            return self.graph_runners[padded_bs].replay(
                scheduled_batch.input_ids,
                scheduled_batch.positions,
                scheduled_batch.slot_mapping,
            )
        else:
            # Prefill or mixed batch: cannot use graph (variable shapes)
            return self.model.forward(...)
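The `_get_padded_batch_size` lookup rounds the live batch up to the nearest captured size; slots beyond the real batch are filled with dummy tokens whose outputs are discarded after replay. A standalone sketch of that bucketing (vLLM's actual helper differs in detail):

```python
import bisect

# Same bucket list as the capture code above
GRAPH_BATCH_SIZES = [1, 2, 4] + list(range(8, 513, 8))

def get_padded_batch_size(batch_size: int) -> int:
    """Return the smallest captured batch size >= the actual batch size."""
    idx = bisect.bisect_left(GRAPH_BATCH_SIZES, batch_size)
    if idx == len(GRAPH_BATCH_SIZES):
        raise ValueError(f"batch size {batch_size} exceeds largest captured graph")
    return GRAPH_BATCH_SIZES[idx]

print(get_padded_batch_size(3))    # → 4
print(get_padded_batch_size(37))   # → 40
print(get_padded_batch_size(512))  # → 512
```

Because the buckets are sorted, `bisect_left` finds the round-up target in O(log n) rather than scanning the list.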

The graph_pool parameter is important: it allows multiple graphs to share a memory pool, so CUDA does not allocate separate memory for each graph’s internal buffers. Without pooling, 64 captured graphs could consume 3+ GB of GPU memory just for graph bookkeeping.

ℹ️ Note

CUDA graphs in vLLM are only used for pure decode batches. Mixed prefill+decode batches (from chunked prefill) use eager execution because the prefill token count varies per iteration, violating the fixed-shape constraint.

Measuring the Full Stack

import torch
import time

def benchmark_decode_optimizations(model, batch_size=64, num_steps=200):
    """Benchmark each decode optimization independently."""
    device = "cuda:0"
    dummy_ids = torch.randint(0, 32000, (batch_size, 1), device=device)
    dummy_pos = torch.zeros(batch_size, 1, dtype=torch.long, device=device)

    results = {}

    # Baseline: eager, no graphs, no persistent buffers
    with torch.no_grad():
        for _ in range(10):  # Warmup so one-time compilation/caching is excluded
            _ = model(dummy_ids, dummy_pos)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_steps):
            _ = model(dummy_ids, dummy_pos)
    torch.cuda.synchronize()
    results["eager_ms"] = (time.perf_counter() - t0) * 1000 / num_steps

    # With CUDA graph
    gm = CUDAGraphManager(model)
    gm.capture(batch_size)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(num_steps):
        _ = gm.replay(batch_size, dummy_ids, dummy_pos,
                       torch.arange(batch_size, device=device))
    torch.cuda.synchronize()
    results["graph_ms"] = (time.perf_counter() - t0) * 1000 / num_steps

    results["graph_speedup"] = results["eager_ms"] / results["graph_ms"]

    return results

CUDA Graph Limitations and Workarounds

Several common inference features conflict with CUDA graphs:

CUDA_GRAPH_COMPATIBILITY = {
    "standard_decode": {
        "compatible": True,
        "notes": "Core use case, works well",
    },
    "chunked_prefill": {
        "compatible": False,
        "reason": "Variable prefill chunk size changes input tensor shape",
        "workaround": "Use eager mode for mixed batches, graph for pure decode",
    },
    "speculative_decoding": {
        "compatible": "Partial",
        "reason": "Verification pass has variable accepted token count",
        "workaround": "Capture separate graphs for draft (fixed K) and verify (fixed K+1)",
    },
    "structured_output": {
        "compatible": "Partial",
        "reason": "Grammar-guided sampling changes logit masks per step",
        "workaround": "Apply grammar mask outside the graph, to the static output buffer",
    },
    "lora_adapters": {
        "compatible": False,
        "reason": "Different LoRA weights per request change the GEMM operands",
        "workaround": "Capture per-adapter graphs or use batched LoRA with fixed adapter set",
    },
    "dynamic_rope_scaling": {
        "compatible": True,
        "notes": "Position IDs are inputs, not control flow. Graph works fine.",
    },
    "prefix_caching": {
        "compatible": True,
        "notes": "Cache hit/miss changes slot_mapping but not tensor shapes",
    },
}
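The structured-output workaround deserves spelling out: the captured graph writes raw logits into a static output buffer, and the per-step grammar mask is applied to that buffer outside the captured region, where dynamic control flow is allowed. A toy sketch with plain Python lists (helper name is ours):

```python
NEG_INF = float("-inf")

def apply_grammar_mask(logits: list, allowed_token_ids: set) -> list:
    """Mask disallowed tokens in-place on the graph's static output
    buffer. Runs in eager mode, after graph replay, so the allowed
    set can change every decode step without invalidating the graph."""
    for tok in range(len(logits)):
        if tok not in allowed_token_ids:
            logits[tok] = NEG_INF
    return logits

logits = [1.2, 0.4, -0.3, 2.1]      # Written by graph replay
apply_grammar_mask(logits, {0, 3})  # Grammar allows only tokens 0 and 3
print(logits)  # → [1.2, -inf, -inf, 2.1]
```

The key property is that the graph never sees the mask: its kernels always produce the same-shaped logits, and all grammar-dependent logic lives in the eager epilogue.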

Graph Capture Memory Overhead

Each CUDA graph stores a copy of the kernel arguments and launch parameters. For a 70B model with approximately 1600 kernels per decode step, each graph consumes 20-50 MB of GPU memory:

def estimate_graph_memory(num_kernels, avg_args_per_kernel=8,
                          avg_arg_bytes=8, driver_bytes_per_kernel=18_000,
                          num_graphs=64):
    """Estimate memory overhead of the CUDA graph pool."""
    # Per-graph: kernel node storage + argument buffers
    arg_bytes = num_kernels * (
        64 +  # Node metadata
        avg_args_per_kernel * avg_arg_bytes  # Kernel arguments
    )

    # Internal CUDA driver allocations (graph topology, per-node
    # workspace reservations) dominate: empirically ~18 KB per kernel
    driver_bytes = num_kernels * driver_bytes_per_kernel

    per_graph_total = arg_bytes + driver_bytes
    total_bytes = num_graphs * per_graph_total
    return {
        "per_graph_mb": per_graph_total / 1e6,
        "total_mb": total_bytes / 1e6,
        "total_gb": total_bytes / 1e9,
    }

# Llama 70B: ~1600 kernels, 64 graphs (batch sizes 1-512 in steps of 8)
# Per graph: ~30 MB, Total: ~1.9 GB
# This is 2.4% of an H100's HBM -- a worthwhile tradeoff for 20% decode speedup

CUDA Graph Memory Overhead vs Decode Speedup

| Num Graphs | Batch Sizes Covered | Memory Overhead (MB) | Avg Padding Waste | Decode Speedup |
|---|---|---|---|---|
| 10 | Powers of 2 (1-512) | 300 | 25% | 1.15x (avg with padding) |
| 32 | Every 16 (1-512) | 960 | 8% | 1.19x |
| 64 | Every 8 (1-512) | 1920 | 4% | 1.20x |
| 128 | Every 4 (1-512) | 3840 | 2% | 1.21x |

The sweet spot is 64 graphs with step-8 buckets: 1.9 GB of memory overhead (2.4% of H100 HBM) for 1.20x decode speedup with only 4% padding waste. Going finer-grained to 128 graphs doubles the memory cost but provides diminishing returns on padding efficiency.
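The padding-waste column can be approximated from first principles. Assuming batch sizes uniformly distributed over 1-512 (an idealization; real traffic is not uniform, so the table's measured figures run somewhat different), the average wasted fraction for a given bucket step is:

```python
def avg_padding_waste(bucket_step: int, max_bs: int = 512) -> float:
    """Average fraction of padded (dummy) slots when every batch is
    rounded up to the next multiple of bucket_step, assuming batch
    sizes are uniform over 1..max_bs."""
    total = 0.0
    for bs in range(1, max_bs + 1):
        padded = -(-bs // bucket_step) * bucket_step  # Ceil to bucket
        total += (padded - bs) / padded
    return total / max_bs

for step in (4, 8, 16):
    print(f"step={step}: {avg_padding_waste(step):.1%}")
```

Production batch-size distributions skew toward small batches, where relative padding waste is highest, which is one plausible reason the table's figures come out a bit above this uniform-distribution estimate.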

The decode optimization stack is cumulative: CUDA graphs remove launch overhead, persistent buffers remove allocation overhead, speculative decoding removes the per-token weight loading cost, and quantization reduces the weight bytes themselves. Each targets a different component of the total decode latency, and together they can reduce per-token time by 5-6x on production workloads.