Part of Series: vLLM v1 & Omni Internals (6 of 25)

vLLM v1 Rejection Sampler: Native CFG and Speculative Verification Kernels

Speculative decoding promises to generate K tokens for the cost of one forward pass, but only if the draft model is good enough. Accept too few draft tokens and you waste compute on verification. Accept too many bad tokens and your output distribution diverges from the target model. The rejection sampling criterion solves this with a statistical guarantee: it accepts as many draft tokens as possible while ensuring the final distribution is mathematically identical to sampling from the target model directly. Classifier-free guidance (CFG) has a different goal, strengthening prompt adherence by blending conditional and unconditional logits, but it shares the same core infrastructure: running multiple forward passes per step and comparing probability distributions token by token.

This post covers the rejection sampling criterion, the batched verification CUDA kernel, CFG companion request pairing, the integration with vLLM's scheduler, and a complete implementation.

The Rejection Sampling Criterion

The Problem Setup

In speculative decoding, a draft model $M_d$ generates $K$ candidate tokens $x_1, x_2, \ldots, x_K$ autoregressively. Each token $x_i$ is sampled from the draft distribution $q(x_i \mid x_{<i})$. The target model $M_t$ then evaluates all $K$ positions in a single forward pass, producing target distributions $p(x_i \mid x_{<i})$ for each position.

The goal: accept as many draft tokens as possible while ensuring the final output distribution is identical to sampling from $p$ directly. We want the speed of the draft model with the quality of the target model.

The Acceptance Criterion

For each draft token $x_i$, the acceptance probability is:

ฮฑi=minโก(1,p(xiโˆฃx<i)q(xiโˆฃx<i))\alpha_i = \min\left(1, \frac{p(x_i \mid x_{<i})}{q(x_i \mid x_{<i})}\right)

The algorithm:

  1. For each position $i = 1, 2, \ldots, K$:
    • Draw $r_i \sim \text{Uniform}(0, 1)$
    • If $r_i \leq \alpha_i$: accept $x_i$, continue to position $i+1$
    • If $r_i > \alpha_i$: reject $x_i$ and all subsequent tokens. Resample $x_i$ from a corrected distribution.
  2. The corrected distribution at the first rejection point is:

$$p'(x) = \text{normalize}\left(\max(0,\, p(x) - q(x))\right)$$

This ensures the final output is distributed exactly as $p$, not as a mixture of $p$ and $q$.
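The algorithm above is easy to trace on a toy example. The sketch below uses a three-token vocabulary with made-up values for $p$ and $q$ at a single position:

```python
# Toy illustration of the acceptance criterion and corrected distribution
# over a 3-token vocabulary. All numbers are illustrative.
p = [0.5, 0.3, 0.2]  # target distribution p(x)
q = [0.2, 0.6, 0.2]  # draft distribution q(x)

# Acceptance probability for each candidate token: min(1, p/q)
alpha = [min(1.0, pi / qi) for pi, qi in zip(p, q)]
# alpha = [1.0, 0.5, 1.0] -- token 1 is over-proposed by the draft,
# so it is accepted only half the time it is drafted.

# Corrected distribution at a rejection: normalize(max(0, p - q))
residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
z = sum(residual)                      # z ≈ 0.3
corrected = [r / z for r in residual]  # [1.0, 0.0, 0.0]
# A rejection can only occur on token 1 (the only alpha < 1), and the
# resample then lands on token 0, exactly where the draft under-proposes.
```

Note how the correction routes all rejected probability mass to the tokens the draft under-represents; this is what makes the overall scheme exact.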

Why This Works: The Proof

The combined probability of generating token $x$ through speculative decoding is:

$$P_{\text{spec}}(x) = q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right) + \left(1 - \sum_{x'} q(x') \cdot \min\left(1, \frac{p(x')}{q(x')}\right)\right) \cdot p'(x)$$

The first term covers accepted tokens. The second term covers rejected-and-resampled tokens. Expanding:

For tokens where $p(x) \geq q(x)$: the acceptance probability is 1, so the first term contributes $q(x)$.

For tokens where $p(x) < q(x)$: the acceptance probability is $p(x)/q(x)$, so the first term contributes $q(x) \cdot p(x)/q(x) = p(x)$.

The probability of reaching the rejection branch is:

$$\beta = 1 - \sum_{x'} q(x') \cdot \min\left(1, \frac{p(x')}{q(x')}\right) = 1 - \sum_{x': p(x') \geq q(x')} q(x') - \sum_{x': p(x') < q(x')} p(x')$$

The corrected distribution is $p'(x) = \max(0, p(x) - q(x)) / Z$ where $Z = \sum_x \max(0, p(x) - q(x))$.

Substituting and simplifying: note that $Z = \sum_{x: p(x) \geq q(x)} (p(x) - q(x)) = \beta$, so the second term contributes $\beta \cdot p'(x) = \max(0, p(x) - q(x))$. Adding this to the first term gives $P_{\text{spec}}(x) = p(x)$ for all $x$: where $p \geq q$, $q(x) + (p(x) - q(x)) = p(x)$; where $p < q$, $p(x) + 0 = p(x)$. The output distribution is exactly the target distribution.

โ„น๏ธ Exactness Guarantee

Rejection sampling for speculative decoding is not an approximation. The output distribution is mathematically identical to sampling from the target model directly. The draft model affects only throughput (how many tokens are accepted per iteration), never quality.
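The guarantee is easy to check empirically. This standalone sketch (illustrative toy distributions, not vLLM code) simulates one rejection-sampling step over a small vocabulary many times and compares the empirical output distribution against $p$:

```python
import random
from collections import Counter

random.seed(0)

p = [0.5, 0.3, 0.2]  # target distribution (illustrative)
q = [0.2, 0.6, 0.2]  # draft distribution (illustrative)

def spec_sample_one() -> int:
    """One rejection-sampling step at a single position: draft, accept or resample."""
    x = random.choices(range(3), weights=q)[0]     # draft proposes x ~ q
    if random.random() <= min(1.0, p[x] / q[x]):   # accept with prob min(1, p/q)
        return x
    # Rejected: resample from normalize(max(0, p - q)).
    # random.choices normalizes the weights, so no explicit division is needed.
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return random.choices(range(3), weights=residual)[0]

N = 200_000
counts = Counter(spec_sample_one() for _ in range(N))
empirical = [counts[i] / N for i in range(3)]
# empirical lands within sampling noise of p = [0.5, 0.3, 0.2],
# even though every token was first proposed from q
```

Even with a badly miscalibrated draft $q$, the output matches $p$; only the acceptance rate suffers.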

Expected Acceptance Rate

The expected number of accepted tokens depends on the overlap between $p$ and $q$. Define the total variation distance:

$$\text{TV}(p, q) = \frac{1}{2} \sum_x |p(x) - q(x)|$$

The expected acceptance rate per token is:

$$\mathbb{E}[\alpha] = \sum_x q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right) = 1 - \text{TV}(p, q)$$

If $p = q$ (a perfect draft model), the acceptance rate is 100%. If $p$ and $q$ are completely disjoint, the acceptance rate is 0%. In practice, a well-chosen draft model achieves a 70-90% acceptance rate.
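The identity $\mathbb{E}[\alpha] = 1 - \text{TV}(p, q)$ can be verified directly on toy distributions (the values below are illustrative):

```python
def total_variation(p, q) -> float:
    """TV(p, q) = (1/2) * sum_x |p(x) - q(x)|."""
    return 0.5 * sum(abs(pi - qi) for pi, qi in zip(p, q))

def expected_acceptance(p, q) -> float:
    """E[alpha] = sum_x q(x) * min(1, p(x) / q(x))."""
    return sum(qi * min(1.0, pi / qi) for pi, qi in zip(p, q) if qi > 0)

p = [0.5, 0.3, 0.2]  # target (illustrative)
q = [0.2, 0.6, 0.2]  # draft (illustrative)

tv = total_variation(p, q)        # ≈ 0.3
rate = expected_acceptance(p, q)  # ≈ 0.7, i.e. exactly 1 - TV(p, q)
```

This is why distillation helps draft models so much: it directly minimizes the divergence from the target distribution, which is the only quantity the acceptance rate depends on.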

📊 Typical Acceptance Rates by Draft Model Quality

| Target Model | Draft Model | Acceptance Rate | Avg Accepted Tokens (K=5) |
|---|---|---|---|
| Llama 70B | Llama 7B | 65-75% | 2.5-3.2 |
| Llama 70B | Llama 8B (distilled) | 75-85% | 3.2-4.0 |
| Llama 70B | Llama 70B (quantized W4) | 85-92% | 4.0-4.5 |
| Llama 70B | MLP draft head (2-layer) | 55-65% | 1.8-2.5 |
| Code Llama 34B | Code Llama 7B | 70-80% | 2.8-3.6 |

Note: Acceptance rates measured on mixed instruction-following workloads. Distilled and quantized draft models have higher overlap with the target distribution.

The Batched Verification Kernel

Why a Custom Kernel

The naive Python implementation of rejection sampling iterates over each sequence and each token position, performing random number generation and comparison. For a batch of 64 sequences with $K=5$ draft tokens each, this is 320 sequential operations on the CPU, far too slow for the decode hot path.

vLLM implements a batched CUDA kernel that processes all sequences and all draft positions in parallel:

import torch

def rejection_sample_batch_python(
    target_probs: torch.Tensor,    # [batch_size, K, vocab_size]
    draft_probs: torch.Tensor,     # [batch_size, K, vocab_size]
    draft_tokens: torch.Tensor,    # [batch_size, K]
) -> tuple:
    """
    Batched rejection sampling (reference implementation).

    Returns:
        accepted_tokens: [batch_size, K] - accepted token IDs (-1 for rejected)
        num_accepted: [batch_size] - count of accepted tokens per sequence
        bonus_tokens: [batch_size] - resampled token at the first rejection point
    """
    batch_size, K, vocab_size = target_probs.shape
    device = target_probs.device

    accepted_tokens = torch.full((batch_size, K), -1, dtype=torch.long, device=device)
    num_accepted = torch.zeros(batch_size, dtype=torch.long, device=device)
    bonus_tokens = torch.full((batch_size,), -1, dtype=torch.long, device=device)

    # Random numbers for acceptance test
    r = torch.rand(batch_size, K, device=device)

    for b in range(batch_size):
        for k in range(K):
            token = draft_tokens[b, k]
            p_target = target_probs[b, k, token]
            p_draft = draft_probs[b, k, token]

            # Acceptance probability
            if p_draft > 0:
                alpha = min(1.0, (p_target / p_draft).item())
            else:
                alpha = 1.0 if p_target > 0 else 0.0

            if r[b, k].item() <= alpha:
                accepted_tokens[b, k] = token
                num_accepted[b] += 1
            else:
                # Reject: compute corrected distribution and resample
                corrected = torch.clamp(
                    target_probs[b, k] - draft_probs[b, k], min=0
                )
                corrected_sum = corrected.sum()
                if corrected_sum > 0:
                    corrected = corrected / corrected_sum
                else:
                    corrected = target_probs[b, k]
                bonus_tokens[b] = torch.multinomial(corrected, 1).item()
                break  # Stop checking further positions

        # If all K tokens accepted, sample one bonus token from p(x_{K+1})
        if num_accepted[b] == K:
            # The target model also computed logits for position K+1
            # Sample from p(x | x_{<K+1}) directly
            # (bonus_tokens[b] is set by the caller in this case)
            pass

    return accepted_tokens, num_accepted, bonus_tokens

The CUDA Kernel

The actual CUDA kernel parallelizes across sequences (one thread block per sequence) and uses warp-level operations for the sequential acceptance chain within each sequence:

// Simplified CUDA kernel for batched rejection sampling
__global__ void rejection_sample_kernel(
    const float* __restrict__ target_probs,  // [B, K, V]
    const float* __restrict__ draft_probs,   // [B, K, V]
    const int64_t* __restrict__ draft_tokens, // [B, K]
    int64_t* __restrict__ accepted_tokens,   // [B, K]
    int64_t* __restrict__ num_accepted,      // [B]
    int64_t* __restrict__ bonus_tokens,      // [B]
    const float* __restrict__ uniform_rand,  // [B, K]
    int K, int V
) {
    int b = blockIdx.x;  // One block per sequence

    // Thread 0 performs the sequential acceptance chain
    if (threadIdx.x == 0) {
        int accepted = 0;
        for (int k = 0; k < K; k++) {
            int64_t token = draft_tokens[b * K + k];
            float p_t = target_probs[(b * K + k) * V + token];
            float p_d = draft_probs[(b * K + k) * V + token];

            float alpha = (p_d > 0.0f) ? fminf(1.0f, p_t / p_d) : 0.0f;
            float r = uniform_rand[b * K + k];

            if (r <= alpha) {
                accepted_tokens[b * K + k] = token;
                accepted++;
            } else {
                accepted_tokens[b * K + k] = -1;
                // Mark remaining as rejected
                for (int j = k + 1; j < K; j++) {
                    accepted_tokens[b * K + j] = -1;
                }
                break;
            }
        }
        num_accepted[b] = accepted;
    }

    // All threads collaborate on resampling from corrected distribution
    // (done for the first rejected position using parallel prefix scan)
    __syncthreads();
    int rej_pos = (int)num_accepted[b];
    if (rej_pos < K) {
        // Parallel computation of corrected distribution
        // Each thread handles a chunk of the vocabulary
        // ... (prefix sum + multinomial sampling)
    }
}

The kernel launches with batch_size thread blocks. The acceptance chain is sequential (each position depends on the previous one), so it runs on a single thread. The resampling at the rejection point uses all threads in the block for parallel multinomial sampling over the vocabulary.

Performance

Rejection Sampling Latency (batch=64, K=5, vocab=128K)

| Implementation | Latency |
|---|---|
| Python loop (CPU) | 45 ms |
| PyTorch vectorized | 2.8 ms |
| CUDA kernel (375x faster than Python) | 0.12 ms |

CFG: Classifier-Free Guidance

How CFG Works

Classifier-free guidance strengthens the model's adherence to the prompt by computing two sets of logits:

  1. Conditional logits $\ell_c$: the model forward pass with the full prompt
  2. Unconditional logits $\ell_u$: the model forward pass with an empty or generic prompt

The guided logits are:

$$\ell_{\text{guided}} = \ell_u + s \cdot (\ell_c - \ell_u)$$

where $s$ is the guidance scale. At $s = 1.0$, the guided logits equal the conditional logits (no guidance). At $s > 1.0$, the model's prompt-conditioned behavior is amplified. Typical values are $s = 1.5$ to $s = 3.0$.

The intuition: the difference $\ell_c - \ell_u$ isolates the effect of the prompt on the output distribution. Multiplying by $s > 1$ amplifies this effect.
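A quick numeric sketch of the formula (toy logits, pure Python; the specific values are made up): with $s = 2$, every logit gap between the conditional and unconditional passes is doubled before the softmax.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

cond   = [2.0, 0.0, -1.0]  # conditional logits (toy values)
uncond = [1.0, 0.5, -1.0]  # unconditional logits (toy values)
scale  = 2.0

# guided = uncond + s * (cond - uncond)
guided = [u + scale * (c - u) for c, u in zip(cond, uncond)]
# guided = [3.0, -0.5, -1.0]: token 0, which the prompt favors,
# gains probability mass relative to plain conditional sampling.
```

Token 0's logit lead over token 1 grows from 2.0 to 3.5, so its softmax probability rises; tokens the prompt does not favor are correspondingly suppressed.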

Companion Request Pairing in vLLM

vLLM implements CFG by creating a "companion" request for each CFG-enabled request. The companion request runs the same model with the unconditional prompt:

class CFGRequestPair:
    """A pair of requests for classifier-free guidance."""

    def __init__(self, original_request, guidance_scale: float):
        self.conditional_request = original_request
        self.guidance_scale = guidance_scale

        # Create companion with unconditional prompt
        self.unconditional_request = Request(
            request_id=f"{original_request.request_id}_cfg_uncond",
            prompt_tokens=get_unconditional_tokens(original_request),
            sampling_params=original_request.sampling_params,
            # Share the same output stream
            output_stream=original_request.output_stream,
        )

    def combine_logits(
        self,
        cond_logits: torch.Tensor,
        uncond_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Apply CFG formula to combine logits."""
        return uncond_logits + self.guidance_scale * (cond_logits - uncond_logits)

def get_unconditional_tokens(request):
    """Generate the unconditional prompt for CFG."""
    # Option 1: Empty prompt (just BOS token)
    # return [request.tokenizer.bos_token_id]

    # Option 2: Generic system prompt without user content
    # This preserves the model's general behavior while removing
    # the specific instruction
    return request.tokenizer.encode("You are a helpful assistant.")

Scheduler Integration

The scheduler must co-schedule conditional and unconditional requests to ensure they run in the same batch (so logits are available simultaneously):

class CFGAwareScheduler:
    """Scheduler that co-schedules CFG request pairs."""

    def schedule(self, waiting_queue, running_queue, max_batch_size):
        scheduled = []

        for request in waiting_queue:
            if request.cfg_pair is not None:
                # This is a CFG-enabled request
                # Must schedule both conditional and unconditional together
                if len(scheduled) + 2 <= max_batch_size:
                    scheduled.append(request.cfg_pair.conditional_request)
                    scheduled.append(request.cfg_pair.unconditional_request)
                else:
                    break  # Not enough room for the pair
            else:
                if len(scheduled) + 1 <= max_batch_size:
                    scheduled.append(request)
                else:
                    break

        return scheduled

The cost of CFG is exactly 2x: every CFG-enabled request consumes two batch slots and requires two KV cache allocations. The unconditional request's KV cache is typically smaller (shorter prompt), but it still occupies a sequence slot.

โš ๏ธ CFG Doubles Resource Consumption

Each CFG-enabled request requires two forward passes and two KV cache allocations. At high load, CFG can halve your effective throughput. Consider using CFG only for requests where prompt adherence is critical, and route non-CFG requests to separate workers.
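The slot accounting above reduces to a simple capacity model. A back-of-envelope sketch (the function and numbers are illustrative, not a vLLM API): with $B$ batch slots and a fraction $f$ of requests using CFG, each CFG request occupies two slots, so the number of distinct requests served per scheduling round is $B / (1 + f)$.

```python
def effective_requests_per_batch(batch_slots: int, cfg_fraction: float) -> float:
    """Distinct requests served per scheduling round when each CFG request
    occupies two batch slots (conditional + unconditional companion).

    Serving n requests, of which a fraction f use CFG, consumes n * (1 + f)
    slots on average, so n = batch_slots / (1 + f).
    """
    return batch_slots / (1.0 + cfg_fraction)

no_cfg  = effective_requests_per_batch(256, 0.0)  # 256.0
all_cfg = effective_requests_per_batch(256, 1.0)  # 128.0 -- concurrency halved
mixed   = effective_requests_per_batch(256, 0.5)  # ≈ 170.7
```

This is the arithmetic behind the routing advice: keeping CFG traffic on dedicated workers keeps `cfg_fraction` at 0 for everyone else.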

CFG Logit Combination Kernel

For efficiency, the logit combination is fused into a single kernel rather than executed as three separate PyTorch operations:

@torch.compile
def cfg_combine_logits(
    cond_logits: torch.Tensor,    # [num_cfg_pairs, vocab_size]
    uncond_logits: torch.Tensor,  # [num_cfg_pairs, vocab_size]
    guidance_scales: torch.Tensor, # [num_cfg_pairs]
) -> torch.Tensor:
    """Fused CFG logit combination."""
    # guidance_scales: [N, 1] for broadcasting
    scales = guidance_scales.unsqueeze(-1)
    return uncond_logits + scales * (cond_logits - uncond_logits)

The @torch.compile decorator fuses this into a single Triton kernel, avoiding two intermediate tensor allocations.

Speculative Decoding + CFG: The Combined Pipeline

When both speculative decoding and CFG are active, the pipeline becomes:

  1. Draft model generates $K$ tokens (no CFG applied to the draft; too expensive)
  2. Target model runs a single forward pass over all $K$ positions
  3. For CFG-enabled requests, the target model also runs the unconditional forward pass
  4. CFG combination produces guided logits at each position
  5. Rejection sampling compares the guided target probabilities against the draft probabilities

class SpeculativeCFGPipeline:
    """Combined speculative decoding + CFG pipeline."""

    def verify_and_accept(
        self,
        draft_tokens,      # [batch_size, K]
        draft_probs,       # [batch_size, K, vocab_size]
        target_cond_probs, # [batch_size, K, vocab_size]
        target_uncond_probs, # [cfg_batch_size, K, vocab_size] (only CFG requests)
        guidance_scales,   # [batch_size] (1.0 for non-CFG requests)
        cfg_mask,          # [batch_size] bool (which requests use CFG)
    ):
        # Step 1: Apply CFG to target probs
        target_probs = target_cond_probs.clone()
        if cfg_mask.any():
            cfg_indices = cfg_mask.nonzero(as_tuple=True)[0]
            # Convert logits to probs, apply CFG, convert back
            cond_logits = torch.log(target_cond_probs[cfg_indices] + 1e-10)
            uncond_logits = torch.log(target_uncond_probs + 1e-10)
            scales = guidance_scales[cfg_indices].unsqueeze(-1).unsqueeze(-1)
            guided_logits = uncond_logits + scales * (cond_logits - uncond_logits)
            target_probs[cfg_indices] = torch.softmax(guided_logits, dim=-1)

        # Step 2: Run rejection sampling with guided probs
        accepted, num_accepted, bonus = rejection_sample_batch(
            target_probs, draft_probs, draft_tokens
        )
        return accepted, num_accepted, bonus

Token-Level Probability Extraction

From Logits to Probabilities

The target model produces logits, not probabilities. Converting requires softmax, which is expensive for large vocabularies:

# vocab_size = 128,000 for Llama 3
# batch_size * K = 64 * 5 = 320 positions
# Softmax over 128K elements for 320 positions = 41M elements

# Naive: full softmax
target_probs = torch.softmax(target_logits, dim=-1)  # [320, 128000]
# Memory: 320 * 128000 * 4 bytes = 163 MB

# Optimized: only compute probs for draft tokens
# We only need p(x_i) for the specific token x_i that was drafted
target_probs_at_draft = torch.gather(
    torch.softmax(target_logits, dim=-1),
    dim=-1,
    index=draft_tokens.unsqueeze(-1)
).squeeze(-1)  # [320]

But we also need the full distribution for resampling at the rejection point. The approach in vLLM:

def extract_probs_for_rejection(
    target_logits: torch.Tensor,   # [batch_size, K, vocab_size]
    draft_logits: torch.Tensor,    # [batch_size, K, vocab_size]
    draft_tokens: torch.Tensor,    # [batch_size, K]
):
    """Extract probabilities needed for rejection sampling."""
    # Full softmax (needed for resampling)
    target_probs = torch.softmax(target_logits, dim=-1)
    draft_probs = torch.softmax(draft_logits, dim=-1)

    # Extract p(x_i) and q(x_i) for each draft token
    batch_size, K, V = target_logits.shape
    batch_idx = torch.arange(batch_size, device=target_logits.device)
    batch_idx = batch_idx.unsqueeze(1).expand(-1, K)
    k_idx = torch.arange(K, device=target_logits.device).unsqueeze(0).expand(batch_size, -1)

    p_at_draft = target_probs[batch_idx, k_idx, draft_tokens]  # [batch_size, K]
    q_at_draft = draft_probs[batch_idx, k_idx, draft_tokens]   # [batch_size, K]

    return target_probs, draft_probs, p_at_draft, q_at_draft

Top-K Optimization

For resampling, we do not need the full vocabulary distribution. The corrected distribution $p'(x) = \max(0, p(x) - q(x))$ is nonzero only where $p(x) > q(x)$. In practice, this is a small subset of the vocabulary. A top-K optimization:

def fast_corrected_resample(
    target_probs: torch.Tensor,  # [vocab_size]
    draft_probs: torch.Tensor,   # [vocab_size]
    top_k: int = 1024,
):
    """Resample from corrected distribution using top-K approximation."""
    # Get top-K tokens by target probability
    top_vals, top_indices = torch.topk(target_probs, top_k)
    draft_top = draft_probs[top_indices]

    # Corrected distribution over top-K
    corrected = torch.clamp(top_vals - draft_top, min=0)
    corrected_sum = corrected.sum()

    if corrected_sum > 0:
        corrected = corrected / corrected_sum
    else:
        corrected = top_vals / top_vals.sum()

    # Sample from corrected
    idx = torch.multinomial(corrected.unsqueeze(0), 1).squeeze()
    return top_indices[idx]

This reduces the resampling cost from $O(V)$ to $O(k)$, where $k$ is the top-K cutoff and $k \ll V$. Since the true corrected distribution can have support outside the target's top-$k$ tokens, this is an approximation; in practice the truncated mass is negligible.

KV Cache Management for Speculated Tokens

The Problem

When the draft model generates $K$ tokens, the target model runs a forward pass over all $K$ positions. This writes KV cache data for all $K$ positions. But if token $i$ is rejected, positions $i+1, \ldots, K$ are invalid and their KV cache must be discarded.

class SpeculativeKVManager:
    """Manage KV cache for speculated tokens."""

    def __init__(self, block_manager):
        self.block_manager = block_manager

    def allocate_speculative_slots(self, seq_id, num_draft_tokens):
        """Pre-allocate KV cache slots for draft tokens."""
        # These slots are tentative -- may be freed if tokens are rejected
        slots = []
        for i in range(num_draft_tokens):
            slot = self.block_manager.allocate_slot(seq_id)
            slots.append(slot)
        return slots

    def commit_accepted(self, seq_id, slots, num_accepted):
        """Commit accepted token slots, free rejected ones."""
        # Keep slots 0..num_accepted-1
        # Free slots num_accepted..K-1
        for i in range(num_accepted, len(slots)):
            self.block_manager.free_slot(seq_id, slots[i])

        # Also free the draft model's KV cache for all positions
        # (draft KV is never needed after verification)
        self.block_manager.free_draft_kv(seq_id)

    def rollback(self, seq_id, slots):
        """Free all speculative slots (all tokens rejected)."""
        for slot in slots:
            self.block_manager.free_slot(seq_id, slot)

The key insight: speculative token slots are allocated optimistically and freed on rejection. The block manager must support fine-grained slot-level deallocation, not just full-block deallocation.
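The allocate/commit/rollback lifecycle can be exercised against a stub block manager. Everything below (the `StubBlockManager` and its slot bookkeeping) is hypothetical scaffolding for illustration, not vLLM's real block manager API:

```python
class StubBlockManager:
    """Minimal in-memory stand-in for a block manager (illustrative only)."""
    def __init__(self):
        self.next_slot = 0
        self.live_slots = set()

    def allocate_slot(self, seq_id) -> int:
        """Hand out the next free slot ID and mark it live."""
        slot = self.next_slot
        self.next_slot += 1
        self.live_slots.add(slot)
        return slot

    def free_slot(self, seq_id, slot) -> None:
        """Release a single slot (fine-grained, not whole-block)."""
        self.live_slots.discard(slot)

bm = StubBlockManager()

# Optimistically allocate K=5 speculative slots (allocate_speculative_slots)
slots = [bm.allocate_slot("seq0") for _ in range(5)]

# Verification accepts 3 tokens; free the rejected tail (commit_accepted)
num_accepted = 3
for s in slots[num_accepted:]:
    bm.free_slot("seq0", s)

# Only the slots for accepted tokens remain live
```

The stub makes the requirement concrete: `free_slot` must work on individual slots, because the rejected tail rarely aligns with a block boundary.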

Temperature and Sampling Interaction

Temperature Scaling

Rejection sampling interacts with temperature in a subtle way. The acceptance criterion uses probabilities, which depend on temperature:

$$p_T(x) = \frac{\exp(\ell_x / T)}{\sum_{x'} \exp(\ell_{x'} / T)}$$

Both the draft model and the target model must use the same temperature for rejection sampling to maintain its exactness guarantee. If the draft model uses $T_d$ and the target model uses $T_t$ with $T_d \neq T_t$, the acceptance rate degrades because the distributions diverge more.

def verify_temperature_consistency(
    draft_sampling_params,
    target_sampling_params,
):
    """Verify that draft and target use compatible sampling parameters."""
    issues = []

    if draft_sampling_params.temperature != target_sampling_params.temperature:
        issues.append(
            f"Temperature mismatch: draft={draft_sampling_params.temperature}, "
            f"target={target_sampling_params.temperature}. "
            f"Acceptance rate will degrade."
        )

    if draft_sampling_params.top_p != target_sampling_params.top_p:
        issues.append(
            f"top_p mismatch: draft={draft_sampling_params.top_p}, "
            f"target={target_sampling_params.top_p}. "
            f"Output distribution will not match target exactly."
        )

    return issues

Top-P / Top-K Interaction

When top-p or top-k sampling is applied, the effective distributions are truncated. Rejection sampling on truncated distributions is still valid but requires careful handling:

def rejection_sample_with_topk(
    target_logits: torch.Tensor,
    draft_logits: torch.Tensor,
    draft_tokens: torch.Tensor,
    top_k: int,
    temperature: float,
):
    """Rejection sampling with top-K truncation."""
    # Apply temperature
    target_logits = target_logits / temperature
    draft_logits = draft_logits / temperature

    # Apply top-K to both distributions
    target_probs = top_k_softmax(target_logits, top_k)
    draft_probs = top_k_softmax(draft_logits, top_k)

    # Standard rejection sampling on truncated distributions
    return rejection_sample_batch(target_probs, draft_probs, draft_tokens)

def top_k_softmax(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Softmax with top-K truncation."""
    top_k_vals, top_k_indices = torch.topk(logits, k, dim=-1)
    # Zero out everything not in top-K
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask.scatter_(-1, top_k_indices, True)
    logits = logits.masked_fill(~mask, float("-inf"))
    return torch.softmax(logits, dim=-1)

Complete Rejection Sampler Implementation

import torch
from dataclasses import dataclass

@dataclass
class RejectionResult:
    """Result of batched rejection sampling."""
    accepted_tokens: torch.Tensor    # [batch_size, max_accepted] padded with -1
    num_accepted: torch.Tensor       # [batch_size]
    bonus_tokens: torch.Tensor       # [batch_size] resampled at rejection point
    acceptance_rate: float           # Average acceptance rate across batch

class RejectionSampler:
    """
    Production-quality rejection sampler for speculative decoding.
    Handles batched verification, temperature, top-K, and CFG.
    """

    def __init__(self, strict_mode: bool = True):
        self.strict_mode = strict_mode
        self._total_proposed = 0
        self._total_accepted = 0

    @torch.no_grad()
    def __call__(
        self,
        target_logits: torch.Tensor,   # [batch_size, K+1, vocab_size]
        draft_logits: torch.Tensor,    # [batch_size, K, vocab_size]
        draft_tokens: torch.Tensor,    # [batch_size, K]
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
    ) -> RejectionResult:
        batch_size, K = draft_tokens.shape
        device = target_logits.device

        # Apply temperature
        if temperature != 1.0:
            target_logits = target_logits / temperature
            draft_logits = draft_logits / temperature

        # Convert to probabilities (with optional top-K/top-P)
        target_probs = self._to_probs(target_logits[:, :K], top_k, top_p)
        draft_probs = self._to_probs(draft_logits, top_k, top_p)

        # Extract p(x_i) and q(x_i) for each drafted token
        p_at_x = self._gather_probs(target_probs, draft_tokens)  # [B, K]
        q_at_x = self._gather_probs(draft_probs, draft_tokens)   # [B, K]

        # Compute acceptance probabilities
        alpha = torch.where(
            q_at_x > 0,
            torch.clamp(p_at_x / q_at_x, max=1.0),
            torch.where(p_at_x > 0, torch.ones_like(p_at_x), torch.zeros_like(p_at_x)),
        )

        # Draw uniform random numbers
        r = torch.rand(batch_size, K, device=device)

        # Compute acceptance mask: accepted[b, k] = True iff all positions 0..k are accepted
        accepted_per_pos = r <= alpha                  # [B, K]
        # Chain: position k is accepted only if all 0..k-1 are also accepted
        accepted_chain = torch.cumprod(accepted_per_pos.long(), dim=1)  # [B, K]

        num_accepted = accepted_chain.sum(dim=1)       # [B]

        # Build accepted_tokens tensor
        accepted_tokens = torch.where(
            accepted_chain.bool(),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        # Compute bonus tokens at the first rejection point
        bonus_tokens = self._resample_at_rejection(
            target_probs, draft_probs, num_accepted, K, target_logits[:, K]
        )

        # Update stats
        self._total_proposed += batch_size * K
        self._total_accepted += num_accepted.sum().item()

        return RejectionResult(
            accepted_tokens=accepted_tokens,
            num_accepted=num_accepted,
            bonus_tokens=bonus_tokens,
            acceptance_rate=num_accepted.float().mean().item() / K,
        )

    def _to_probs(self, logits, top_k, top_p):
        """Convert logits to probabilities with optional truncation."""
        if top_k > 0:
            top_vals, top_idx = torch.topk(logits, top_k, dim=-1)
            mask = torch.zeros_like(logits, dtype=torch.bool)
            mask.scatter_(-1, top_idx, True)
            logits = logits.masked_fill(~mask, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumulative - torch.softmax(sorted_logits, dim=-1) >= top_p
            sorted_logits[mask] = float("-inf")
            logits = sorted_logits.scatter(-1, sorted_idx, sorted_logits)
        return torch.softmax(logits, dim=-1)

    def _gather_probs(self, probs, tokens):
        """Gather probabilities for specific tokens."""
        return probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    def _resample_at_rejection(self, target_probs, draft_probs, num_accepted, K, next_logits):
        """Resample token at the first rejection point."""
        batch_size = target_probs.shape[0]
        device = target_probs.device
        bonus = torch.zeros(batch_size, dtype=torch.long, device=device)

        for b in range(batch_size):
            n = num_accepted[b].item()
            if n < K:
                # Rejection at position n: sample from corrected distribution
                corrected = torch.clamp(target_probs[b, n] - draft_probs[b, n], min=0)
                csum = corrected.sum()
                if csum > 0:
                    corrected = corrected / csum
                else:
                    corrected = target_probs[b, n]
                bonus[b] = torch.multinomial(corrected.unsqueeze(0), 1).squeeze()
            else:
                # All K accepted: sample bonus from position K+1
                bonus_probs = torch.softmax(next_logits[b], dim=-1)
                bonus[b] = torch.multinomial(bonus_probs.unsqueeze(0), 1).squeeze()

        return bonus

    @property
    def overall_acceptance_rate(self):
        if self._total_proposed == 0:
            return 0.0
        return self._total_accepted / self._total_proposed

Monitoring and Tuning

Key Metrics

class SpeculativeMetrics:
    """Track speculative decoding performance metrics."""

    def __init__(self):
        self.acceptance_rates = []
        self.draft_latencies = []
        self.verify_latencies = []
        self.tokens_per_step = []

    def record_step(self, result: RejectionResult, draft_ms: float, verify_ms: float, K: int):
        self.acceptance_rates.append(result.acceptance_rate)
        self.draft_latencies.append(draft_ms)
        self.verify_latencies.append(verify_ms)
        # Accepted tokens + 1 bonus token
        self.tokens_per_step.append(result.num_accepted.float().mean().item() + 1)

    def report(self):
        avg_rate = sum(self.acceptance_rates) / len(self.acceptance_rates)
        avg_draft = sum(self.draft_latencies) / len(self.draft_latencies)
        avg_verify = sum(self.verify_latencies) / len(self.verify_latencies)
        avg_tokens = sum(self.tokens_per_step) / len(self.tokens_per_step)

        # Speedup = tokens_per_step / (draft_time + verify_time) * baseline_time
        avg_step_time = avg_draft + avg_verify
        effective_tps = avg_tokens / (avg_step_time / 1000)

        return {
            "acceptance_rate": avg_rate,
            "avg_tokens_per_step": avg_tokens,
            "avg_step_time_ms": avg_step_time,
            "effective_tokens_per_sec": effective_tps,
        }

When Speculative Decoding Helps

Speculative decoding is beneficial when:

$$\frac{\bar{n} + 1}{t_{\text{draft}} + t_{\text{verify}}} > \frac{1}{t_{\text{baseline}}}$$

where $\bar{n}$ is the average number of accepted tokens, $t_{\text{draft}}$ is the draft model latency, $t_{\text{verify}}$ is the target model verification latency, and $t_{\text{baseline}}$ is the target model decode latency without speculation.

Since verification processes $K$ tokens in one forward pass (same cost as one decode step due to batching), $t_{\text{verify}} \approx t_{\text{baseline}}$. The speedup ratio simplifies to:

$$\text{speedup} = \frac{(\bar{n} + 1) \cdot t_{\text{baseline}}}{t_{\text{draft}} + t_{\text{baseline}}} = \frac{\bar{n} + 1}{1 + t_{\text{draft}} / t_{\text{baseline}}}$$

For $t_{\text{draft}} = 0.1 \cdot t_{\text{baseline}}$ (the draft is 10x faster) and $\bar{n} = 3$:

$$\text{speedup} = \frac{4}{1.1} \approx 3.6\times$$
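The back-of-envelope above can be packaged into a small helper. One stated assumption beyond the source formulas: per-position acceptances are treated as independent with rate $\alpha$, so $\bar{n} = \alpha + \alpha^2 + \cdots + \alpha^K$ (the standard geometric-chain simplification):

```python
def expected_accepted(alpha: float, K: int) -> float:
    """E[accepted tokens] for a chain of K positions, each accepted with
    probability alpha. Position k survives only if all earlier positions
    were accepted, so E[n] = alpha + alpha^2 + ... + alpha^K."""
    return sum(alpha ** k for k in range(1, K + 1))

def speculative_speedup(n_bar: float, draft_over_baseline: float) -> float:
    """speedup = (n_bar + 1) / (1 + t_draft / t_baseline)."""
    return (n_bar + 1.0) / (1.0 + draft_over_baseline)

worked = speculative_speedup(3.0, 0.1)                         # ≈ 3.64 (the ~3.6x above)
chained = speculative_speedup(expected_accepted(0.8, 5), 0.1)  # ≈ 3.35
```

Plugging in your measured acceptance rate and draft/target latency ratio tells you whether speculation pays off before you deploy it.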

Speculative Decoding Speedup vs. Acceptance Rate (K=5)

| Acceptance Rate | Speedup |
|---|---|
| 50% | 1.64× (marginal) |
| 60% | 2.00× |
| 70% | 2.36× |
| 80% | 2.73× |
| 90% | 3.18× (excellent) |