Part 23 of 60 in the Inference Optimization Timeline series

The same GPU kernel that makes prefill fast will destroy your decode throughput. Prefill wants massive batch sizes and compute-bound GEMMs that saturate tensor cores. Decode wants minimal latency and memory-bandwidth optimization because you’re loading 140 GB of model weights to produce a single token. Treating them as one workload guarantees you’ll optimize for neither. Production serving systems recognize this split explicitly: chunked prefill to prevent decode stalls, CUDA graphs to eliminate decode overhead, and separate routing strategies because the bottleneck isn’t the same. This post covers the optimization techniques that acknowledge the arithmetic intensity divide.

The Arithmetic Intensity Divide

The key metric is arithmetic intensity: FLOPs per byte of memory accessed. High arithmetic intensity means compute-bound (GPU cores are the bottleneck). Low arithmetic intensity means bandwidth-bound (memory transfer is the bottleneck).

For a transformer layer with hidden dimension $d$, the attention projection matrices are $W_Q, W_K, W_V, W_O \in \mathbb{R}^{d \times d}$ and the FFN matrices are $W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{d \times 4d}$ and $W_{\text{down}} \in \mathbb{R}^{4d \times d}$. Total weight bytes per layer (FP16): $4d^2 \cdot 2 + 3 \cdot 4d^2 \cdot 2 = 32d^2$ bytes.

Prefill with sequence length $S$ and batch size $B$:

  • The input activation matrix is $X \in \mathbb{R}^{(B \cdot S) \times d}$
  • Each weight matrix GEMM does $2 \cdot (B \cdot S) \cdot d \cdot d$ FLOPs
  • Weight loading: $2d^2$ bytes (FP16)
  • Arithmetic intensity: $\frac{2 \cdot B \cdot S \cdot d^2}{2d^2} = B \cdot S$ FLOPs/byte

Decode with batch size $B$, generating 1 token per request:

  • The input activation matrix is $X \in \mathbb{R}^{B \times d}$
  • Each weight matrix GEMM does $2 \cdot B \cdot d \cdot d$ FLOPs
  • Weight loading: $2d^2$ bytes (same weights, regardless of batch size)
  • Arithmetic intensity: $\frac{2 \cdot B \cdot d^2}{2d^2} = B$ FLOPs/byte

Arithmetic Intensity: Prefill vs Decode

| Phase | Batch Size | Seq Length | Arithmetic Intensity | Bottleneck (H100) |
|---|---|---|---|---|
| Prefill | 1 | 2048 | 2048 FLOPs/byte | Compute-bound |
| Prefill | 1 | 512 | 512 FLOPs/byte | Compute-bound |
| Prefill | 1 | 32 | 32 FLOPs/byte | Borderline |
| Decode | 1 | 1 | 1 FLOP/byte | Bandwidth-bound |
| Decode | 32 | 1 | 32 FLOPs/byte | Borderline |
| Decode | 256 | 1 | 256 FLOPs/byte | Compute-bound |

The H100 has 989 TFLOPS (FP16 tensor core) and 3.35 TB/s HBM bandwidth. The compute-bandwidth balance point is:

$$\text{Balance point} = \frac{989 \times 10^{12}}{3.35 \times 10^{12}} \approx 295 \text{ FLOPs/byte}$$

Prefill with any reasonable sequence length exceeds this. Decode at batch size 1 is 295x below this. This is why you cannot optimize both with the same approach.
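The divide can be checked in a few lines. A minimal sketch (helper names and the classification are mine; the constants are the H100 figures quoted above):

```python
def arithmetic_intensity(batch_size, seq_len):
    """FLOPs per weight byte for a T-token pass through one FP16 GEMM:
    2*T*d^2 FLOPs / 2*d^2 bytes = T, so intensity is just the token count."""
    return batch_size * seq_len

H100_BALANCE = 989e12 / 3.35e12  # ~295 FLOPs/byte

def bottleneck(batch_size, seq_len):
    """Which side of the roofline a phase lands on."""
    if arithmetic_intensity(batch_size, seq_len) > H100_BALANCE:
        return "compute"
    return "bandwidth"

print(bottleneck(1, 2048))  # prefill, S=2048 -> compute
print(bottleneck(1, 1))     # decode, batch 1 -> bandwidth
```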

Prefill Optimization: Maximizing Compute Utilization

Prefill processes the full prompt in one forward pass. The input is a matrix $X \in \mathbb{R}^{S \times d}$ where $S$ is the prompt length. Every GEMM operates on this full matrix.

GEMM Tiling for Prefill

For prefill, the GEMMs are large enough to saturate GPU compute. The optimization goal is to maximize tensor core utilization through proper tiling:

import torch
import triton
import triton.language as tl

@triton.jit
def prefill_gemm_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """GEMM kernel optimized for prefill: large M (seq_len), large K (hidden_dim).
    Tile sizes chosen for H100 SM occupancy."""
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)

    # Map the linear program id onto a 2D tile grid (row-major here;
    # a grouped swizzle would further improve L2 reuse)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # Compute tile offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Initialize accumulator in FP32
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # Main loop over K dimension
    for k in range(0, K, BLOCK_K):
        a = tl.load(
            A_ptr + offs_m[:, None] * stride_am + (k + offs_k[None, :]) * stride_ak,
            mask=(offs_m[:, None] < M) & ((k + offs_k[None, :]) < K),
            other=0.0,
        )
        b = tl.load(
            B_ptr + (k + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn,
            mask=((k + offs_k[:, None]) < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b)  # Tensor core HMMA

    # Store result
    c = acc.to(tl.float16)
    tl.store(
        C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        c,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

For prefill on H100, optimal tile sizes are typically BLOCK_M=128, BLOCK_N=256, BLOCK_K=64 for FP16. This gives 128x256 = 32768 elements per output tile, requiring 128x64 + 64x256 = 24576 input elements per K iteration (49,152 bytes in FP16), fitting comfortably in shared memory (up to 228 KB per SM on H100).
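As a sanity check on that arithmetic, here is a small helper (mine, not part of the kernel) that computes the shared-memory footprint of a tile configuration, including the multiplier from software pipelining (`num_stages` in Triton):

```python
def tile_smem_bytes(block_m, block_n, block_k, dtype_bytes=2, stages=1):
    """Shared-memory bytes for the A tile (block_m x block_k) plus the
    B tile (block_k x block_n). Software pipelining keeps `stages` copies."""
    return stages * (block_m * block_k + block_k * block_n) * dtype_bytes

# The tile shape from the text: single-buffered, then with a 3-stage pipeline
print(tile_smem_bytes(128, 256, 64))            # 49152 bytes
print(tile_smem_bytes(128, 256, 64, stages=3))  # 147456 bytes, still under 228 KB
```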

Attention During Prefill

Prefill attention computes the full $S \times S$ attention matrix. FlashAttention tiles this computation to keep intermediates in SRAM:

def prefill_attention_analysis(seq_len, num_heads, head_dim, dtype_bytes=2):
    """Analyze prefill attention compute and memory."""
    S = seq_len
    d = head_dim
    H = num_heads

    # FLOPs: Q*K^T (S*S*d per head) + softmax (S*S) + attn*V (S*S*d)
    qk_flops = 2 * H * S * S * d
    av_flops = 2 * H * S * S * d
    total_flops = qk_flops + av_flops

    # Memory: Q,K,V tensors (no materialized S*S matrix with FlashAttention)
    qkv_bytes = 3 * H * S * d * dtype_bytes
    output_bytes = H * S * d * dtype_bytes
    total_memory = qkv_bytes + output_bytes

    # Without FlashAttention: must store S*S attention matrix
    naive_memory = qkv_bytes + H * S * S * dtype_bytes + output_bytes

    return {
        "total_tflops": total_flops / 1e12,
        "flash_memory_gb": total_memory / 1e9,
        "naive_memory_gb": naive_memory / 1e9,
        "arithmetic_intensity": total_flops / total_memory,
    }

# Llama 70B: 80 layers, 64 heads, head_dim=128
# Prefill S=4096, attention FLOPs across all layers:
# Total: 80 * 4 * 64 * 4096^2 * 128 ≈ 44 TFLOP
# That is 44 / 989 ≈ 44 ms on H100 for attention alone
# (causal masking roughly halves both numbers)

Decode Optimization: Maximizing Bandwidth Utilization

Decode generates one token per request. The input to each layer is a vector $x \in \mathbb{R}^{d}$ (batch=1) or a thin matrix $X \in \mathbb{R}^{B \times d}$ (batched decode). The GEMM is now a matrix-vector product (GEMV) or a thin GEMM.

The problem: for Llama 70B, each layer’s weights are approximately 1.75 GB (FP16). At batch size 1, each $8192 \times 8192$ GEMM does $2 \times 8192 \times 8192 \approx 134\text{M}$ FLOPs but loads 134 MB of weights. On an H100 (3.35 TB/s bandwidth), loading takes $134\text{ MB} / 3.35\text{ TB/s} \approx 40\ \mu\text{s}$. The compute takes $134\text{M} / 989\text{T} \approx 0.14\ \mu\text{s}$. The GPU spends 99.6% of its time waiting for weight data.
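That back-of-envelope arithmetic, spelled out (a sketch; the function name and defaults are mine, and it models a single $d \times d$ projection, not the whole layer):

```python
def gemv_roofline_us(d=8192, batch=1, bw_tbs=3.35, tflops=989, dtype_bytes=2):
    """Weight-streaming time vs. compute time for one d x d GEMM at small batch."""
    weight_bytes = d * d * dtype_bytes        # ~134 MB for d=8192 in FP16
    flops = 2 * batch * d * d                 # ~134 MFLOPs at batch=1
    load_us = weight_bytes / (bw_tbs * 1e12) * 1e6
    compute_us = flops / (tflops * 1e12) * 1e6
    return load_us, compute_us

load, compute = gemv_roofline_us()
print(f"load {load:.1f} us, compute {compute:.2f} us")  # ~40 us vs ~0.14 us
```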

def decode_roofline(hidden_dim, batch_size, num_layers,
                    bw_tbs=3.35, compute_tflops=989):
    """Roofline analysis for decode phase."""
    d = hidden_dim

    # Weight bytes per layer (Q,K,V,O projections + FFN gate/up/down)
    # Assuming GQA with 8 KV heads and 64 query heads, head_dim=128
    qkv_bytes = (64 + 8 + 8) * 128 * d * 2  # Q + K + V projections
    o_bytes = 64 * 128 * d * 2  # O projection
    ffn_bytes = 3 * d * (7 * d // 2) * 2  # gate + up + down (SwiGLU; 28672 intermediate for Llama 70B)

    weight_bytes_per_layer = qkv_bytes + o_bytes + ffn_bytes

    # FLOPs per layer: 2 * tokens * weight_elements
    weight_elements = weight_bytes_per_layer // 2  # FP16: 2 bytes per element
    flops_per_layer = 2 * batch_size * weight_elements

    # Time per layer
    bandwidth_time = weight_bytes_per_layer / (bw_tbs * 1e12)
    compute_time = flops_per_layer / (compute_tflops * 1e12)

    # Decode is bandwidth-bound when bandwidth_time > compute_time,
    # i.e. whenever batch_size is below the roofline balance point:
    balance_batch = compute_tflops / bw_tbs  # ~295 for H100

    total_time = num_layers * max(bandwidth_time, compute_time)
    tokens_per_sec = batch_size / total_time

    return {
        "weight_bytes_per_layer_mb": weight_bytes_per_layer / 1e6,
        "bandwidth_time_us": bandwidth_time * 1e6,
        "compute_time_us": compute_time * 1e6,
        "bottleneck": "bandwidth" if bandwidth_time > compute_time else "compute",
        "tokens_per_sec": tokens_per_sec,
        "time_per_token_ms": total_time * 1000 / batch_size,
    }

Decode Throughput vs Batch Size (Llama 70B, H100)

| Batch Size | Tokens/sec (bandwidth model) | Tokens/sec (actual measured) | Tokens/sec (compute ceiling) |
|---|---|---|---|
| 1 | 24 | 22 | 15000 |
| 2 | 48 | 44 | 15000 |
| 4 | 96 | 86 | 15000 |
| 8 | 192 | 168 | 15000 |
| 16 | 384 | 328 | 15000 |
| 32 | 768 | 640 | 15000 |
| 64 | 1536 | 1180 | 15000 |
| 128 | 3072 | 2050 | 15000 |
| 256 | 6144 | 3200 | 15000 |
| 512 | 12288 | 4100 | 15000 |

At batch size 1, decode throughput is approximately 24 tokens/sec on an H100 for Llama 70B FP16. The theoretical compute ceiling is over 15,000 tokens/sec. The 600x gap is entirely due to memory bandwidth. Increasing batch size is the primary lever: at batch=256, actual throughput reaches 3,200 tokens/sec (21% of compute ceiling), because weights are loaded once and reused across 256 requests.
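The bandwidth-model column above comes from a one-line calculation. A sketch, assuming 140 GB of FP16 weights streamed once per decode step; it ignores KV cache reads and attention, which is part of why measured throughput falls short at large batch:

```python
def decode_ceiling_tok_s(batch, model_bytes=140e9, bw_tbs=3.35):
    """Bandwidth-ceiling decode throughput: one full weight pass per step,
    amortized across the whole batch."""
    step_s = model_bytes / (bw_tbs * 1e12)  # ~41.8 ms per decode step
    return batch / step_s

print(round(decode_ceiling_tok_s(1)))    # 24 tokens/sec
print(round(decode_ceiling_tok_s(256)))  # ~6100 tokens/sec
```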

CUDA Graphs for Decode

Every decode step executes the exact same sequence of CUDA kernels with the same tensor shapes (batch_size x hidden_dim). CUDA graphs eliminate the CPU-side kernel launch overhead by recording the kernel sequence once and replaying it:

import torch

class CUDAGraphDecoder:
    """Capture and replay decode step as a CUDA graph."""

    def __init__(self, model, max_batch_size, device="cuda:0"):
        self.model = model
        self.max_batch_size = max_batch_size  # upper bound on captured batch sizes
        self.device = device
        self.graphs = {}  # batch_size -> captured graph
        self.static_inputs = {}  # batch_size -> static input tensors

    def capture(self, batch_size):
        """Capture the decode forward pass for a specific batch size."""
        # Allocate static tensors (CUDA graph requires fixed addresses)
        static_input_ids = torch.zeros(
            batch_size, 1, dtype=torch.long, device=self.device
        )
        static_position_ids = torch.zeros(
            batch_size, 1, dtype=torch.long, device=self.device
        )
        static_output = torch.zeros(
            batch_size, 1, self.model.config.vocab_size,
            dtype=torch.float16, device=self.device
        )

        self.static_inputs[batch_size] = {
            "input_ids": static_input_ids,
            "position_ids": static_position_ids,
            "output": static_output,
        }

        # Warmup on a side stream (required before capture; lets allocations
        # and lazy initialization happen outside the graph)
        side = torch.cuda.Stream()
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(3):
                with torch.no_grad():
                    _ = self.model(
                        input_ids=static_input_ids,
                        position_ids=static_position_ids,
                    )
        torch.cuda.current_stream().wait_stream(side)

        # Capture
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                output = self.model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                )
                self.static_inputs[batch_size]["output"].copy_(output.logits)

        self.graphs[batch_size] = graph

    def decode_step(self, input_ids, position_ids):
        """Execute one decode step using the captured graph."""
        batch_size = input_ids.shape[0]

        if batch_size not in self.graphs:
            self.capture(batch_size)

        static = self.static_inputs[batch_size]

        # Copy dynamic data into static buffers
        static["input_ids"].copy_(input_ids)
        static["position_ids"].copy_(position_ids)

        # Replay the captured graph (no CPU kernel launch overhead)
        self.graphs[batch_size].replay()

        return static["output"].clone()

Without CUDA graphs, each decode step involves:

  1. CPU Python overhead to call each module’s forward (5-15 us per module)
  2. CUDA kernel launch for each operation (3-5 us per launch)
  3. For 80 layers with ~20 kernels each: $80 \times 20 \times 5\ \mu\text{s} = 8\text{ ms}$ of launch overhead

With CUDA graphs: 1 graph replay launch = ~5 us total. For a model where the actual GPU compute takes 30 ms, eliminating 8 ms of launch overhead cuts the step from 38 ms to 30 ms: a 21% latency reduction (1.27x speedup).


CUDA Graph Impact on Decode Latency (Llama 70B, H100)

| Batch Size | Without Graph (ms) | With Graph (ms) | Speedup | Launch Overhead Eliminated |
|---|---|---|---|---|
| 1 | 42.1 | 33.8 | 1.25x | 8.3 ms |
| 8 | 43.2 | 34.1 | 1.27x | 9.1 ms |
| 32 | 48.5 | 39.8 | 1.22x | 8.7 ms |
| 128 | 72.3 | 63.1 | 1.15x | 9.2 ms |
| 256 | 118.4 | 109.8 | 1.08x | 8.6 ms |

Chunked Prefill: Taming Long Prompts

A 32K token prompt processed in a single prefill pass creates several problems:

  1. Attention compute: even with FlashAttention the cost is $O(S^2)$, and the 32K x 32K computation takes significant time
  2. KV cache allocation: must reserve KV cache for all 32K positions at once
  3. Decode starvation: while prefill runs, no decode steps happen for other requests

Chunked prefill splits the prompt into chunks of size $C$ and processes them sequentially:

class ChunkedPrefillScheduler:
    """Process prefill in chunks to avoid starving decode requests."""

    def __init__(self, model, chunk_size=512, max_batch_tokens=8192):
        self.model = model
        self.chunk_size = chunk_size
        self.max_batch_tokens = max_batch_tokens

    def schedule_iteration(self, prefill_queue, decode_queue):
        """Schedule one iteration mixing prefill chunks and decode tokens."""
        batch = []
        token_budget = self.max_batch_tokens

        # Priority 1: decode tokens (low latency requirement)
        decode_requests = []
        for req in decode_queue:
            if token_budget >= 1:  # Each decode request uses 1 token
                decode_requests.append(req)
                token_budget -= 1
        batch.extend(decode_requests)

        # Priority 2: prefill chunks (fill remaining budget)
        prefill_chunks = []
        for req in prefill_queue:
            remaining = req.prompt_len - req.processed_tokens
            chunk = min(remaining, self.chunk_size, token_budget)
            if chunk > 0:
                prefill_chunks.append((req, chunk))
                token_budget -= chunk
        batch.extend(prefill_chunks)

        return batch

    def execute_chunked_prefill(self, request, chunk_start, chunk_size):
        """Execute one chunk of prefill for a request."""
        chunk_ids = request.prompt_ids[chunk_start:chunk_start + chunk_size]
        position_ids = torch.arange(
            chunk_start, chunk_start + chunk_size,
            device=chunk_ids.device
        )

        # Forward pass for this chunk only
        # KV cache is appended incrementally
        with torch.no_grad():
            outputs = self.model(
                input_ids=chunk_ids.unsqueeze(0),
                position_ids=position_ids.unsqueeze(0),
                past_key_values=request.kv_cache,
                use_cache=True,
            )

        # Update request state
        request.kv_cache = outputs.past_key_values
        request.processed_tokens += chunk_size

        if request.processed_tokens >= request.prompt_len:
            request.phase = "decode"
            return outputs.logits[:, -1, :]  # Ready to sample first token

        return None  # More chunks needed

Chunked Prefill Performance Analysis

The chunk size CC creates a direct tradeoff:

  • Small $C$ (e.g., 128): minimal decode starvation, but prefill takes many iterations. Each chunk’s GEMM has lower arithmetic intensity ($C$ instead of $S$).
  • Large $C$ (e.g., 4096): efficient prefill GEMMs, but decode requests wait up to one chunk’s execution time between steps.

The inter-token latency (ITL) for decode requests during chunked prefill:

$$\text{ITL} \leq t_{\text{decode\_batch}} + t_{\text{prefill\_chunk}}(C)$$

For a 70B model on H100 with batch=128 decode requests and one prefill chunk:

def chunked_prefill_latency(chunk_size, decode_batch, model_params):
    """Estimate per-iteration latency with chunked prefill."""
    d = model_params["hidden_dim"]
    num_layers = model_params["num_layers"]
    bw = model_params["hbm_bw_tbs"]  # TB/s

    # ~12 d^2 weight elements per layer (GQA attention + SwiGLU FFN),
    # i.e. ~24 d^2 bytes in FP16
    weight_bytes = num_layers * 24 * d * d
    compute_tflops = model_params["compute_tflops"]

    # With chunked prefill, decode and prefill share the same forward pass
    # The total token count determines whether we are compute or bandwidth bound
    total_tokens = decode_batch + chunk_size
    total_flops = num_layers * 2 * total_tokens * 12 * d * d
    bw_time = weight_bytes / (bw * 1e12)
    compute_time = total_flops / (compute_tflops * 1e12)

    iteration_time = max(bw_time, compute_time)

    return {
        "iteration_ms": iteration_time * 1000,
        "decode_itl_ms": iteration_time * 1000,
        "bottleneck": "bandwidth" if bw_time > compute_time else "compute",
        "total_tokens": total_tokens,
    }

Decode ITL vs Prefill Chunk Size (Llama 70B, H100, 128 Decode Requests)

| Prefill Chunk Size | Inter-Token Latency (ms) | Prefill Throughput (tokens/s) |
|---|---|---|
| 0 (decode only) | 33.8 | 0 |
| 128 | 34.2 | 3745 |
| 256 | 34.8 | 7352 |
| 512 | 36.1 | 14173 |
| 1024 | 39.5 | 25907 |
| 2048 | 48.2 | 42491 |
| 4096 | 67.1 | 61042 |
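One practical use of this tradeoff curve is picking the largest chunk that still meets a decode latency SLO. A sketch using the same max(bandwidth, compute) iteration model as `chunked_prefill_latency`; the constants and the ~12 d² weight-elements-per-layer approximation are assumptions, not measurements:

```python
def max_chunk_under_slo(itl_budget_ms, decode_batch=128, d=8192, num_layers=80,
                        bw_tbs=3.35, compute_tflops=989):
    """Largest prefill chunk (multiple of 64) whose mixed iteration still
    meets the inter-token-latency budget."""
    weight_bytes = num_layers * 24 * d * d        # ~12 d^2 FP16 elements/layer
    bw_ms = weight_bytes / (bw_tbs * 1e12) * 1e3  # fixed weight-streaming floor
    best = 0
    for chunk in range(0, 8193, 64):
        tokens = decode_batch + chunk
        flops = num_layers * 2 * tokens * 12 * d * d
        compute_ms = flops / (compute_tflops * 1e12) * 1e3
        if max(bw_ms, compute_ms) <= itl_budget_ms:
            best = chunk
    return best

print(max_chunk_under_slo(45))  # largest chunk that keeps ITL under 45 ms
```

Under this model the weight-streaming floor (~38 ms here) sets a hard minimum ITL; budgets below it admit no prefill at all.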

Why You Cannot Optimize Both With One Approach

The fundamental tension is:

  1. Prefill wants large matrices - high arithmetic intensity, saturate compute. Pack as many prompt tokens as possible into each forward pass.

  2. Decode wants low latency - minimize the time between generated tokens. Any prefill work in the same batch increases decode latency.

  3. Batching helps decode but not prefill - increasing decode batch size improves GPU utilization (weights loaded once, used $B$ times). But prefill already has high utilization; adding more prefill tokens just takes proportionally longer.

Approach 1: Shared Forward Pass (vLLM/SGLang)

Mix prefill chunks and decode tokens in the same forward pass:

def mixed_batch_forward(model, decode_requests, prefill_chunks):
    """Single forward pass with both decode and prefill tokens."""

    # Concatenate all tokens into one batch
    all_input_ids = []
    all_position_ids = []
    all_seq_lens = []

    # Decode tokens: 1 token per request
    for req in decode_requests:
        all_input_ids.append(req.current_token_id)
        all_position_ids.append(req.current_position)
        all_seq_lens.append(req.total_seq_len)

    # Prefill chunks: chunk_size tokens per request
    for req, chunk_start, chunk_size in prefill_chunks:
        all_input_ids.extend(
            req.prompt_ids[chunk_start:chunk_start + chunk_size]
        )
        all_position_ids.extend(
            range(chunk_start, chunk_start + chunk_size)
        )
        all_seq_lens.append(chunk_size)

    # Single forward pass processes everything
    # The attention kernel handles variable sequence lengths
    # via the sequence length metadata
    input_ids = torch.tensor(all_input_ids, device="cuda")
    position_ids = torch.tensor(all_position_ids, device="cuda")

    with torch.no_grad():
        logits = model(input_ids, position_ids, ...)

    return logits

Tradeoff: simple implementation, but decode latency increases with prefill chunk size. Decode tokens “pay” for the prefill compute.

Approach 2: Disaggregated Prefill/Decode (Splitwise, DistServe)

Run prefill and decode on separate GPU pools:

class DisaggregatedServing:
    """Separate GPU pools for prefill and decode."""

    def __init__(self, prefill_workers, decode_workers):
        self.prefill_pool = prefill_workers  # Optimized for compute
        self.decode_pool = decode_workers    # Optimized for bandwidth

    def handle_request(self, prompt_ids):
        # Phase 1: Prefill on compute-optimized GPU
        prefill_worker = self.prefill_pool.get_worker()
        kv_cache = prefill_worker.prefill(prompt_ids)
        first_token = prefill_worker.sample(kv_cache)

        # Transfer KV cache to decode worker
        decode_worker = self.decode_pool.get_worker()
        decode_worker.receive_kv_cache(kv_cache)  # RDMA transfer

        # Phase 2: Decode on bandwidth-optimized GPU
        tokens = [first_token]
        while not is_eos(tokens[-1]):
            next_token = decode_worker.decode_step(tokens[-1])
            tokens.append(next_token)

        return tokens

Tradeoff: optimal performance for each phase, but KV cache transfer between pools adds latency ($\text{KV size} / \text{RDMA bandwidth}$). For a 4K context on Llama 70B: KV cache $= 2 \times 80 \times 8 \times 128 \times 4096 \times 2 \approx 1.3\text{ GB}$; transferring at 400 Gbps adds roughly 27 ms.
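That transfer-cost arithmetic as a helper (a sketch; the Llama 70B GQA shape is assumed and the link speed is a parameter):

```python
def kv_transfer_ms(context_len, num_layers=80, kv_heads=8, head_dim=128,
                   dtype_bytes=2, link_gbps=400):
    """KV cache bytes (K and V across all layers) over an RDMA link, in ms."""
    kv_bytes = 2 * num_layers * kv_heads * head_dim * context_len * dtype_bytes
    link_bytes_per_s = link_gbps / 8 * 1e9
    return kv_bytes / link_bytes_per_s * 1e3

print(f"{kv_transfer_ms(4096):.1f} ms")  # ~27 ms for a 4K context
```

The cost scales linearly with context length, which is why disaggregation hurts most on long-prompt workloads.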

Approach 3: Priority Scheduling

Keep both phases on the same GPU but give decode strict priority:

class PriorityScheduler:
    """Decode-priority scheduler that preempts prefill for decode."""

    def __init__(self, model, decode_slo_ms=50):
        self.model = model
        self.decode_slo_ms = decode_slo_ms
        self.decode_queue = []
        self.prefill_queue = []

    def schedule(self):
        """Always process all pending decode tokens first.
        Use remaining capacity for prefill chunks."""
        batch = []
        token_budget = 8192  # Max tokens per iteration

        # Decode tokens get absolute priority
        for req in self.decode_queue:
            batch.append(("decode", req, 1))
            token_budget -= 1

        # If batch is too small (low decode load), add prefill
        if token_budget > 128:  # Minimum chunk size
            for req in self.prefill_queue:
                remaining = req.prompt_len - req.processed
                chunk = min(remaining, 512, token_budget)
                if chunk >= 128:
                    batch.append(("prefill", req, chunk))
                    token_budget -= chunk

        return batch

Optimization Strategy Comparison (Llama 70B, H100, 100 QPS)

| Strategy | TTFT P50 (ms) | TTFT P99 (ms) | ITL P50 (ms) | ITL P99 (ms) | Throughput (tok/s) |
|---|---|---|---|---|---|
| Shared (no chunking) | 85 | 320 | 68 | 180 | 3200 |
| Chunked prefill (C=512) | 210 | 450 | 36 | 52 | 3800 |
| Chunked prefill (C=2048) | 130 | 380 | 48 | 95 | 4100 |
| Disaggregated | 95 | 180 | 34 | 42 | 4500 |
| Priority scheduling | 250 | 600 | 34 | 45 | 3600 |

Profiling the Phase Boundary

The transition from prefill to decode is visible in profiling traces. Here is what to look for:

import torch.profiler

def profile_phase_transition(model, prompt_ids, num_decode_steps=10):
    """Profile the prefill-to-decode transition to see the bottleneck shift."""
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
        with_stack=True,
    ) as prof:
        # Prefill phase
        with torch.profiler.record_function("PREFILL"):
            with torch.no_grad():
                outputs = model(prompt_ids, use_cache=True)
                kv_cache = outputs.past_key_values
                next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

        # Decode phase
        for step in range(num_decode_steps):
            with torch.profiler.record_function(f"DECODE_STEP_{step}"):
                with torch.no_grad():
                    outputs = model(
                        next_token,
                        past_key_values=kv_cache,
                        use_cache=True,
                    )
                    kv_cache = outputs.past_key_values
                    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Export for Nsight/Chrome trace viewer
    prof.export_chrome_trace("phase_transition.json")

    # Print kernel-level breakdown
    print(prof.key_averages().table(
        sort_by="cuda_time_total", row_limit=20
    ))

In the trace, you will see:

  • Prefill: large GEMM kernels (e.g., sm80_xmma_gemm_f16f16_f16f32_f32_tn_n) dominating, high SM occupancy, low memory bandwidth utilization
  • Decode: small GEMV kernels, low SM occupancy, HBM bandwidth near saturation, many small kernel launches between GEMMs

Quantization Interacts Differently With Each Phase

Quantizing weights from FP16 to INT8 halves the weight bytes. This has different effects on each phase:

Prefill (compute-bound): halving weight bytes does not help because compute is the bottleneck. The GEMM still does the same number of FLOPs. INT8 tensor cores are 2x faster than FP16 on H100 (1979 vs 989 TOPS), so INT8 quantization gives ~2x prefill speedup through faster compute, not reduced memory.

Decode (bandwidth-bound): halving weight bytes directly halves the memory load time, giving ~2x decode speedup. The compute was not the bottleneck, so the reduced precision does not matter.

def quantization_impact(hidden_dim, batch_size, seq_len, phase,
                         bw_tbs=3.35, fp16_tflops=989, int8_tops=1979):
    """Estimate quantization impact on each phase."""
    d = hidden_dim
    weight_elements = 12 * d * d  # Approximate per layer

    if phase == "prefill":
        tokens = seq_len
        fp16_flops = 2 * tokens * weight_elements
        fp16_bytes = weight_elements * 2
        int8_bytes = weight_elements * 1

        fp16_compute_time = fp16_flops / (fp16_tflops * 1e12)
        fp16_bw_time = fp16_bytes / (bw_tbs * 1e12)
        int8_compute_time = fp16_flops / (int8_tops * 1e12)  # INT8 tensor cores
        int8_bw_time = int8_bytes / (bw_tbs * 1e12)

        return {
            "fp16_time": max(fp16_compute_time, fp16_bw_time),
            "int8_time": max(int8_compute_time, int8_bw_time),
            "speedup": max(fp16_compute_time, fp16_bw_time) / max(int8_compute_time, int8_bw_time),
            "fp16_bottleneck": "compute" if fp16_compute_time > fp16_bw_time else "bandwidth",
            "int8_bottleneck": "compute" if int8_compute_time > int8_bw_time else "bandwidth",
        }

    else:  # decode
        tokens = batch_size
        fp16_bytes = weight_elements * 2
        int8_bytes = weight_elements * 1
        fp16_flops = 2 * tokens * weight_elements

        fp16_bw_time = fp16_bytes / (bw_tbs * 1e12)
        int8_bw_time = int8_bytes / (bw_tbs * 1e12)

        return {
            "fp16_time": fp16_bw_time,  # BW-bound at small batch
            "int8_time": int8_bw_time,
            "speedup": fp16_bw_time / int8_bw_time,  # ~2x
            "bottleneck": "bandwidth (both)",
        }

INT8 Quantization Impact by Phase (Llama 70B, H100)

| Configuration | Prefill (S=2048) ms | Decode (B=1) ms | Decode (B=128) ms |
|---|---|---|---|
| FP16 | 28.4 | 42.1 | 72.3 |
| INT8 (W8A8) | 15.2 | 21.8 | 41.5 |
| Speedup | 1.87x | 1.93x | 1.74x |
| Bottleneck shift | Compute -> Compute | BW -> BW | BW -> borderline |

The fundamental insight is that prefill and decode live on opposite sides of the roofline. Any optimization that reduces memory traffic helps decode. Any optimization that increases compute throughput helps prefill. Chunked prefill is the practical compromise that lets both coexist on the same GPU, while disaggregated serving is the principled solution that gives each phase its own hardware optimized for its specific bottleneck.