Speculative decoding promises to generate K tokens for the cost of one forward pass, but only if the draft model is good enough. Accept too few draft tokens and you waste compute on verification. Accept too many bad tokens and your output distribution diverges from the target model. The rejection sampling criterion solves this with a statistical guarantee: it accepts as many draft tokens as possible while ensuring the final distribution is mathematically identical to sampling from the target model directly. Classifier-free guidance (CFG) has a different goal, strengthening prompt adherence by blending conditional and unconditional logits, but it shares the same core infrastructure: running multiple forward passes per step and comparing probability distributions token-by-token.
This post covers the rejection sampling criterion, the batched verification CUDA kernel, CFG companion request pairing, the integration with vLLM's scheduler, and a complete implementation.
The Rejection Sampling Criterion
The Problem Setup
In speculative decoding, a draft model generates $K$ candidate tokens autoregressively. Each token $x_i$ is sampled from the draft distribution $q(x_i \mid x_{<i})$. The target model then evaluates all $K$ positions in a single forward pass, producing target distributions $p(x_i \mid x_{<i})$ for each position.
The goal: accept as many draft tokens as possible while ensuring the final output distribution is identical to sampling from $p$ directly. We want the speed of the draft model with the quality of the target model.
The Acceptance Criterion
For each draft token $x_i$, the acceptance probability is:

$$\alpha_i = \min\!\left(1, \frac{p(x_i \mid x_{<i})}{q(x_i \mid x_{<i})}\right)$$
The algorithm:
- For each position $i = 1, \dots, K$:
  - Draw $r_i \sim \mathrm{Uniform}(0, 1)$
  - If $r_i \le \alpha_i$: accept $x_i$, continue to position $i + 1$
  - If $r_i > \alpha_i$: reject $x_i$ and all subsequent tokens. Resample from a corrected distribution.
- The corrected distribution at the first rejection point is:

$$p'(x) = \frac{\max(0,\ p(x) - q(x))}{\sum_{x'} \max(0,\ p(x') - q(x'))}$$
This ensures the final output is distributed exactly as $p$, not as a mixture of $p$ and $q$.
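As a worked toy example (the distributions below are invented for illustration), the acceptance probabilities and corrected distribution over a three-token vocabulary:

```python
import torch

# Toy distributions over a 3-token vocabulary (illustrative values only)
p = torch.tensor([0.5, 0.3, 0.2])  # target
q = torch.tensor([0.2, 0.6, 0.2])  # draft

# Acceptance probability per possible draft token: min(1, p/q)
alpha = torch.clamp(p / q, max=1.0)
# alpha = [1.0, 0.5, 1.0]: token 0 is always accepted (the target likes
# it more than the draft does), token 1 only half the time.

# Corrected distribution used on rejection: norm(max(0, p - q))
corrected = torch.clamp(p - q, min=0)
corrected = corrected / corrected.sum()
# Only token 0 has p > q, so all rejection mass goes to it.
```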
Why This Works: The Proof
The combined probability of generating token $x$ at a given position through speculative decoding is:

$$P(x) = q(x) \cdot \min\!\left(1, \frac{p(x)}{q(x)}\right) + P(\text{reject}) \cdot p'(x)$$

The first term covers accepted tokens. The second term covers rejected-and-resampled tokens. Expanding the first term:
For tokens where $p(x) \ge q(x)$: the acceptance probability is 1, so the first term contributes $q(x)$.
For tokens where $p(x) < q(x)$: the acceptance probability is $p(x)/q(x)$, so the first term contributes $q(x) \cdot \frac{p(x)}{q(x)} = p(x)$.
In both cases the first term equals $\min(p(x), q(x))$.
The probability of reaching the rejection branch is:

$$P(\text{reject}) = 1 - \sum_{x'} q(x') \cdot \min\!\left(1, \frac{p(x')}{q(x')}\right) = 1 - \sum_{x'} \min(p(x'), q(x'))$$
The corrected distribution is $p'(x) = \max(0,\ p(x) - q(x)) / Z$ where $Z = \sum_{x'} \max(0,\ p(x') - q(x')) = 1 - \sum_{x'} \min(p(x'), q(x'))$.
Substituting and simplifying, $P(x) = \min(p(x), q(x)) + \max(0,\ p(x) - q(x)) = p(x)$ for all $x$. The output distribution is exactly the target distribution.
Rejection sampling for speculative decoding is not an approximation. The output distribution is mathematically identical to sampling from the target model directly. The draft model affects only throughput (how many tokens are accepted per iteration), never quality.
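This exactness is easy to check empirically. The sketch below (toy distributions, not from vLLM) simulates one draft position many times and confirms the output histogram matches $p$ rather than $q$:

```python
import torch

torch.manual_seed(0)
p = torch.tensor([0.5, 0.3, 0.2])  # target
q = torch.tensor([0.2, 0.6, 0.2])  # draft
N = 200_000

# Draft proposes a token, then we accept/reject against the target
x = torch.multinomial(q, N, replacement=True)
r = torch.rand(N)
alpha = torch.clamp(p / q, max=1.0)
accept = r <= alpha[x]

# On rejection, resample from the corrected distribution norm(max(0, p - q))
corrected = torch.clamp(p - q, min=0)
corrected = corrected / corrected.sum()
resampled = torch.multinomial(corrected, N, replacement=True)
out = torch.where(accept, x, resampled)

# Empirical output distribution: matches p, not q or a mixture
hist = torch.bincount(out, minlength=3).float() / N
```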
Expected Acceptance Rate
The expected number of accepted tokens depends on the overlap between $p$ and $q$. Define the total variation distance:

$$D_{\mathrm{TV}}(p, q) = \frac{1}{2} \sum_{x} \left| p(x) - q(x) \right|$$
The expected acceptance rate per token is:

$$\mathbb{E}[\alpha] = 1 - D_{\mathrm{TV}}(p, q) = \sum_{x} \min(p(x), q(x))$$
If $q = p$ (perfect draft model), the acceptance rate is 100%. If $p$ and $q$ are completely disjoint, the acceptance rate is 0%. In practice, a well-chosen draft model achieves a 70-90% acceptance rate.
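Both quantities follow directly from the two distributions; a quick sketch with invented toy values:

```python
import torch

p = torch.tensor([0.5, 0.3, 0.2])  # target
q = torch.tensor([0.2, 0.6, 0.2])  # draft

# Total variation distance: half the L1 distance between p and q
tv = 0.5 * (p - q).abs().sum()

# Expected per-token acceptance rate: 1 - D_TV = sum_x min(p(x), q(x))
accept_rate = torch.minimum(p, q).sum()
# For these values: tv = 0.3, accept_rate = 0.7
```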
Typical Acceptance Rates by Draft Model Quality
| Target Model | Draft Model | Acceptance Rate | Avg Accepted Tokens (K=5) |
|---|---|---|---|
| Llama 70B | Llama 7B | 65-75% | 2.5-3.2 |
| Llama 70B | Llama 8B (distilled) | 75-85% | 3.2-4.0 |
| Llama 70B | Llama 70B (quantized W4) | 85-92% | 4.0-4.5 |
| Llama 70B | MLP draft head (2-layer) | 55-65% | 1.8-2.5 |
| Code Llama 34B | Code Llama 7B | 70-80% | 2.8-3.6 |
The Batched Verification Kernel
Why a Custom Kernel
The naive Python implementation of rejection sampling iterates over each sequence and each token position, performing random number generation and comparison. For a batch of 64 sequences with $K = 5$ draft tokens each, this is 320 sequential operations on CPU, far too slow for the decode hot path.
vLLM implements a batched CUDA kernel that processes all sequences and all draft positions in parallel:
```python
import torch

def rejection_sample_batch_python(
    target_probs: torch.Tensor,  # [batch_size, K, vocab_size]
    draft_probs: torch.Tensor,   # [batch_size, K, vocab_size]
    draft_tokens: torch.Tensor,  # [batch_size, K]
) -> tuple:
    """
    Batched rejection sampling (reference implementation).

    Returns:
        accepted_tokens: [batch_size, K] - accepted token IDs (-1 for rejected)
        num_accepted: [batch_size] - count of accepted tokens per sequence
        bonus_tokens: [batch_size] - resampled token at the first rejection point
    """
    batch_size, K, vocab_size = target_probs.shape
    device = target_probs.device

    accepted_tokens = torch.full((batch_size, K), -1, dtype=torch.long, device=device)
    num_accepted = torch.zeros(batch_size, dtype=torch.long, device=device)
    bonus_tokens = torch.full((batch_size,), -1, dtype=torch.long, device=device)

    # Random numbers for acceptance test
    r = torch.rand(batch_size, K, device=device)

    for b in range(batch_size):
        for k in range(K):
            token = draft_tokens[b, k]
            p_target = target_probs[b, k, token]
            p_draft = draft_probs[b, k, token]

            # Acceptance probability
            if p_draft > 0:
                alpha = min(1.0, (p_target / p_draft).item())
            else:
                alpha = 1.0 if p_target > 0 else 0.0

            if r[b, k].item() <= alpha:
                accepted_tokens[b, k] = token
                num_accepted[b] += 1
            else:
                # Reject: compute corrected distribution and resample
                corrected = torch.clamp(
                    target_probs[b, k] - draft_probs[b, k], min=0
                )
                corrected_sum = corrected.sum()
                if corrected_sum > 0:
                    corrected = corrected / corrected_sum
                else:
                    corrected = target_probs[b, k]
                bonus_tokens[b] = torch.multinomial(corrected, 1).item()
                break  # Stop checking further positions

        # If all K tokens accepted, sample one bonus token from p(x_{K+1})
        if num_accepted[b] == K:
            # The target model also computed logits for position K+1
            # Sample from p(x | x_{<K+1}) directly
            # (bonus_tokens[b] is set by the caller in this case)
            pass

    return accepted_tokens, num_accepted, bonus_tokens
```
The CUDA Kernel
The actual CUDA kernel parallelizes across sequences (one thread block per sequence) and uses warp-level operations for the sequential acceptance chain within each sequence:
```cuda
// Simplified CUDA kernel for batched rejection sampling
__global__ void rejection_sample_kernel(
    const float* __restrict__ target_probs,   // [B, K, V]
    const float* __restrict__ draft_probs,    // [B, K, V]
    const int64_t* __restrict__ draft_tokens, // [B, K]
    int64_t* __restrict__ accepted_tokens,    // [B, K]
    int64_t* __restrict__ num_accepted,       // [B]
    int64_t* __restrict__ bonus_tokens,       // [B]
    const float* __restrict__ uniform_rand,   // [B, K]
    int K, int V
) {
    int b = blockIdx.x;  // One block per sequence

    // Thread 0 performs the sequential acceptance chain
    if (threadIdx.x == 0) {
        int accepted = 0;
        for (int k = 0; k < K; k++) {
            int64_t token = draft_tokens[b * K + k];
            float p_t = target_probs[(b * K + k) * V + token];
            float p_d = draft_probs[(b * K + k) * V + token];
            float alpha = (p_d > 0.0f) ? fminf(1.0f, p_t / p_d) : 0.0f;
            float r = uniform_rand[b * K + k];
            if (r <= alpha) {
                accepted_tokens[b * K + k] = token;
                accepted++;
            } else {
                accepted_tokens[b * K + k] = -1;
                // Mark remaining as rejected
                for (int j = k + 1; j < K; j++) {
                    accepted_tokens[b * K + j] = -1;
                }
                break;
            }
        }
        num_accepted[b] = accepted;
    }

    // All threads collaborate on resampling from corrected distribution
    // (done for the first rejected position using parallel prefix scan)
    __syncthreads();
    int rej_pos = (int)num_accepted[b];
    if (rej_pos < K) {
        // Parallel computation of corrected distribution
        // Each thread handles a chunk of the vocabulary
        // ... (prefix sum + multinomial sampling)
    }
}
```
The kernel launches with batch_size thread blocks. The acceptance chain is sequential (each position depends on the previous one), so it runs on a single thread. The resampling at the rejection point uses all threads in the block for parallel multinomial sampling over the vocabulary.
Performance
[Chart: Rejection Sampling Latency in ms (batch=64, K=5, vocab=128K)]
CFG: Classifier-Free Guidance
How CFG Works
Classifier-free guidance strengthens the model's adherence to the prompt by computing two sets of logits:
- Conditional logits $\ell_{\text{cond}}$: the model forward pass with the full prompt
- Unconditional logits $\ell_{\text{uncond}}$: the model forward pass with an empty or generic prompt

The guided logits are:

$$\ell_{\text{guided}} = \ell_{\text{uncond}} + s \cdot (\ell_{\text{cond}} - \ell_{\text{uncond}})$$

where $s$ is the guidance scale. At $s = 1$, guided logits equal conditional logits (no guidance). At $s > 1$, the model's prompt-conditioned behavior is amplified. Typical values are roughly $s = 1.5$ to $s = 3$.
The intuition: the difference $\ell_{\text{cond}} - \ell_{\text{uncond}}$ isolates the effect of the prompt on the output distribution. Multiplying by $s$ amplifies this effect.
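A minimal numeric sketch of the formula (logit values invented for illustration):

```python
import torch

cond_logits = torch.tensor([2.0, 1.0, 0.0])    # with the full prompt
uncond_logits = torch.tensor([1.0, 1.0, 1.0])  # with the generic prompt
s = 2.0  # guidance scale

guided = uncond_logits + s * (cond_logits - uncond_logits)
# guided = [3.0, 1.0, -1.0]: the token the prompt favored gets pushed
# up, the one it disfavored gets pushed down. s = 1 would recover the
# conditional logits unchanged.
```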
Companion Request Pairing in vLLM
vLLM implements CFG by creating a "companion" request for each CFG-enabled request. The companion request runs the same model with the unconditional prompt:
```python
class CFGRequestPair:
    """A pair of requests for classifier-free guidance."""

    def __init__(self, original_request, guidance_scale: float):
        self.conditional_request = original_request
        self.guidance_scale = guidance_scale
        # Create companion with unconditional prompt
        self.unconditional_request = Request(
            request_id=f"{original_request.request_id}_cfg_uncond",
            prompt_tokens=get_unconditional_tokens(original_request),
            sampling_params=original_request.sampling_params,
            # Share the same output stream
            output_stream=original_request.output_stream,
        )

    def combine_logits(
        self,
        cond_logits: torch.Tensor,
        uncond_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Apply CFG formula to combine logits."""
        return uncond_logits + self.guidance_scale * (cond_logits - uncond_logits)


def get_unconditional_tokens(request):
    """Generate the unconditional prompt for CFG."""
    # Option 1: Empty prompt (just BOS token)
    # return [request.tokenizer.bos_token_id]

    # Option 2: Generic system prompt without user content
    # This preserves the model's general behavior while removing
    # the specific instruction
    return request.tokenizer.encode("You are a helpful assistant.")
```
Scheduler Integration
The scheduler must co-schedule conditional and unconditional requests to ensure they run in the same batch (so logits are available simultaneously):
```python
class CFGAwareScheduler:
    """Scheduler that co-schedules CFG request pairs."""

    def schedule(self, waiting_queue, running_queue, max_batch_size):
        scheduled = []
        for request in waiting_queue:
            if request.cfg_pair is not None:
                # This is a CFG-enabled request
                # Must schedule both conditional and unconditional together
                if len(scheduled) + 2 <= max_batch_size:
                    scheduled.append(request.cfg_pair.conditional_request)
                    scheduled.append(request.cfg_pair.unconditional_request)
                else:
                    break  # Not enough room for the pair
            else:
                if len(scheduled) + 1 <= max_batch_size:
                    scheduled.append(request)
                else:
                    break
        return scheduled
```
The cost of CFG is exactly 2x: every CFG-enabled request consumes two batch slots and requires two KV cache allocations. The unconditional request's KV cache is typically smaller (shorter prompt), but it still occupies a sequence slot.
Each CFG-enabled request requires two forward passes and two KV cache allocations. At high load, CFG can halve your effective throughput. Consider using CFG only for requests where prompt adherence is critical, and route non-CFG requests to separate workers.
CFG Logit Combination Kernel
For efficiency, the logit combination is done in a single CUDA kernel rather than three separate PyTorch operations:
```python
@torch.compile
def cfg_combine_logits(
    cond_logits: torch.Tensor,     # [num_cfg_pairs, vocab_size]
    uncond_logits: torch.Tensor,   # [num_cfg_pairs, vocab_size]
    guidance_scales: torch.Tensor, # [num_cfg_pairs]
) -> torch.Tensor:
    """Fused CFG logit combination."""
    # guidance_scales: [N, 1] for broadcasting
    scales = guidance_scales.unsqueeze(-1)
    return uncond_logits + scales * (cond_logits - uncond_logits)
```
The @torch.compile decorator fuses this into a single Triton kernel, avoiding two intermediate tensor allocations.
Speculative Decoding + CFG: The Combined Pipeline
When both speculative decoding and CFG are active, the pipeline becomes:
- Draft model generates $K$ tokens (no CFG applied to the draft; too expensive)
- Target model runs a single forward pass over all positions
- For CFG-enabled requests, the target model also runs the unconditional forward pass
- CFG combination produces guided logits at each position
- Rejection sampling compares guided target logits against draft logits
```python
class SpeculativeCFGPipeline:
    """Combined speculative decoding + CFG pipeline."""

    def verify_and_accept(
        self,
        draft_tokens,        # [batch_size, K]
        draft_probs,         # [batch_size, K, vocab_size]
        target_cond_probs,   # [batch_size, K, vocab_size]
        target_uncond_probs, # [cfg_batch_size, K, vocab_size] (only CFG requests)
        guidance_scales,     # [batch_size] (1.0 for non-CFG requests)
        cfg_mask,            # [batch_size] bool (which requests use CFG)
    ):
        # Step 1: Apply CFG to target probs
        target_probs = target_cond_probs.clone()
        if cfg_mask.any():
            cfg_indices = cfg_mask.nonzero(as_tuple=True)[0]
            # Convert probs to log space, apply CFG, convert back
            cond_logits = torch.log(target_cond_probs[cfg_indices] + 1e-10)
            uncond_logits = torch.log(target_uncond_probs + 1e-10)
            scales = guidance_scales[cfg_indices].unsqueeze(-1).unsqueeze(-1)
            guided_logits = uncond_logits + scales * (cond_logits - uncond_logits)
            target_probs[cfg_indices] = torch.softmax(guided_logits, dim=-1)

        # Step 2: Run rejection sampling with guided probs
        accepted, num_accepted, bonus = rejection_sample_batch(
            target_probs, draft_probs, draft_tokens
        )
        return accepted, num_accepted, bonus
```
Token-Level Probability Extraction
From Logits to Probabilities
The target model produces logits, not probabilities. Converting requires softmax, which is expensive for large vocabularies:
```python
# vocab_size = 128,000 for Llama 3
# batch_size * K = 64 * 5 = 320 positions
# Softmax over 128K elements for 320 positions = 41M elements

# Naive: full softmax
target_probs = torch.softmax(target_logits, dim=-1)  # [320, 128000]
# Memory: 320 * 128000 * 4 bytes = 164 MB

# Optimized: keep only the drafted tokens' probabilities
# (the full softmax is still computed transiently, but only a [320]
# slice is materialized as output)
target_probs_at_draft = torch.gather(
    torch.softmax(target_logits, dim=-1),
    dim=-1,
    index=draft_tokens.unsqueeze(-1)
).squeeze(-1)  # [320]
```
But we also need the full distribution for resampling at the rejection point. The approach in vLLM:
```python
def extract_probs_for_rejection(
    target_logits: torch.Tensor, # [batch_size, K, vocab_size]
    draft_logits: torch.Tensor,  # [batch_size, K, vocab_size]
    draft_tokens: torch.Tensor,  # [batch_size, K]
):
    """Extract probabilities needed for rejection sampling."""
    # Full softmax (needed for resampling)
    target_probs = torch.softmax(target_logits, dim=-1)
    draft_probs = torch.softmax(draft_logits, dim=-1)

    # Extract p(x_i) and q(x_i) for each draft token
    batch_size, K, V = target_logits.shape
    batch_idx = torch.arange(batch_size, device=target_logits.device)
    batch_idx = batch_idx.unsqueeze(1).expand(-1, K)
    k_idx = torch.arange(K, device=target_logits.device).unsqueeze(0).expand(batch_size, -1)

    p_at_draft = target_probs[batch_idx, k_idx, draft_tokens]  # [batch_size, K]
    q_at_draft = draft_probs[batch_idx, k_idx, draft_tokens]   # [batch_size, K]

    return target_probs, draft_probs, p_at_draft, q_at_draft
```
Top-K Optimization
For resampling, we do not need the full vocabulary distribution. The corrected distribution is nonzero only where $p(x) > q(x)$. In practice, this is a small subset of the vocabulary. A top-K optimization:
```python
def fast_corrected_resample(
    target_probs: torch.Tensor, # [vocab_size]
    draft_probs: torch.Tensor,  # [vocab_size]
    top_k: int = 1024,
):
    """Resample from corrected distribution using top-K approximation."""
    # Get top-K tokens by target probability
    top_vals, top_indices = torch.topk(target_probs, top_k)
    draft_top = draft_probs[top_indices]

    # Corrected distribution over top-K
    corrected = torch.clamp(top_vals - draft_top, min=0)
    corrected_sum = corrected.sum()
    if corrected_sum > 0:
        corrected = corrected / corrected_sum
    else:
        corrected = top_vals / top_vals.sum()

    # Sample from corrected
    idx = torch.multinomial(corrected.unsqueeze(0), 1).squeeze()
    return top_indices[idx]
```
This reduces the resampling cost from $O(V)$ to $O(k)$ where $k \ll V$.
KV Cache Management for Speculated Tokens
The Problem
When the draft model generates $K$ tokens, the target model runs a forward pass over all $K$ positions. This writes KV cache data for all $K$ positions. But if token $x_i$ is rejected, positions $i$ through $K$ are invalid and their KV cache must be discarded.
```python
class SpeculativeKVManager:
    """Manage KV cache for speculated tokens."""

    def __init__(self, block_manager):
        self.block_manager = block_manager

    def allocate_speculative_slots(self, seq_id, num_draft_tokens):
        """Pre-allocate KV cache slots for draft tokens."""
        # These slots are tentative -- may be freed if tokens are rejected
        slots = []
        for i in range(num_draft_tokens):
            slot = self.block_manager.allocate_slot(seq_id)
            slots.append(slot)
        return slots

    def commit_accepted(self, seq_id, slots, num_accepted):
        """Commit accepted token slots, free rejected ones."""
        # Keep slots 0..num_accepted-1
        # Free slots num_accepted..K-1
        for i in range(num_accepted, len(slots)):
            self.block_manager.free_slot(seq_id, slots[i])
        # Also free the draft model's KV cache for all positions
        # (draft KV is never needed after verification)
        self.block_manager.free_draft_kv(seq_id)

    def rollback(self, seq_id, slots):
        """Free all speculative slots (all tokens rejected)."""
        for slot in slots:
            self.block_manager.free_slot(seq_id, slot)
```
The key insight: speculative token slots are allocated optimistically and freed on rejection. The block manager must support fine-grained slot-level deallocation, not just full-block deallocation.
Temperature and Sampling Interaction
Temperature Scaling
Rejection sampling interacts with temperature in a subtle way. The acceptance criterion uses probabilities, which depend on temperature:

$$p_T(x) = \frac{\exp(\ell_x / T)}{\sum_{x'} \exp(\ell_{x'} / T)}$$

Both the draft model and target model must use the same temperature for rejection sampling to maintain its exactness guarantee. If the draft model uses $T_q$ and the target model uses $T_p$ with $T_q \ne T_p$, the acceptance rate degrades because the distributions diverge more.
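A hedged sketch of the degradation: since the per-token acceptance rate is $\sum_x \min(p(x), q(x))$, sharpening only the draft distribution shrinks the overlap even when both models produce identical logits (the logit values below are illustrative):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.0])  # shared base logits (made up)

def probs(l, T):
    # Temperature-scaled softmax
    return torch.softmax(l / T, dim=-1)

p = probs(logits, T=1.0)             # target at temperature 1.0
q_matched = probs(logits, T=1.0)     # draft at the same temperature
q_mismatched = probs(logits, T=0.5)  # draft sampled too greedily

rate_matched = torch.minimum(p, q_matched).sum()        # = 1.0
rate_mismatched = torch.minimum(p, q_mismatched).sum()  # < 1.0
```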
```python
def verify_temperature_consistency(
    draft_sampling_params,
    target_sampling_params,
):
    """Verify that draft and target use compatible sampling parameters."""
    issues = []
    if draft_sampling_params.temperature != target_sampling_params.temperature:
        issues.append(
            f"Temperature mismatch: draft={draft_sampling_params.temperature}, "
            f"target={target_sampling_params.temperature}. "
            f"Acceptance rate will degrade."
        )
    if draft_sampling_params.top_p != target_sampling_params.top_p:
        issues.append(
            f"top_p mismatch: draft={draft_sampling_params.top_p}, "
            f"target={target_sampling_params.top_p}. "
            f"Output distribution will not match target exactly."
        )
    return issues
```
Top-P / Top-K Interaction
When top-p or top-k sampling is applied, the effective distributions are truncated. Rejection sampling on truncated distributions is still valid but requires careful handling:
```python
def rejection_sample_with_topk(
    target_logits: torch.Tensor,
    draft_logits: torch.Tensor,
    draft_tokens: torch.Tensor,
    top_k: int,
    temperature: float,
):
    """Rejection sampling with top-K truncation."""
    # Apply temperature
    target_logits = target_logits / temperature
    draft_logits = draft_logits / temperature

    # Apply top-K to both distributions
    target_probs = top_k_softmax(target_logits, top_k)
    draft_probs = top_k_softmax(draft_logits, top_k)

    # Standard rejection sampling on truncated distributions
    return rejection_sample_batch(target_probs, draft_probs, draft_tokens)


def top_k_softmax(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Softmax with top-K truncation."""
    top_k_vals, top_k_indices = torch.topk(logits, k, dim=-1)
    # Zero out everything not in top-K
    mask = torch.zeros_like(logits, dtype=torch.bool)
    mask.scatter_(-1, top_k_indices, True)
    logits = logits.masked_fill(~mask, float("-inf"))
    return torch.softmax(logits, dim=-1)
```
Complete Rejection Sampler Implementation
```python
import torch
from dataclasses import dataclass


@dataclass
class RejectionResult:
    """Result of batched rejection sampling."""
    accepted_tokens: torch.Tensor  # [batch_size, max_accepted] padded with -1
    num_accepted: torch.Tensor     # [batch_size]
    bonus_tokens: torch.Tensor     # [batch_size] resampled at rejection point
    acceptance_rate: float         # Average acceptance rate across batch


class RejectionSampler:
    """
    Production-quality rejection sampler for speculative decoding.
    Handles batched verification, temperature, top-K, and CFG.
    """

    def __init__(self, strict_mode: bool = True):
        self.strict_mode = strict_mode
        self._total_proposed = 0
        self._total_accepted = 0

    @torch.no_grad()
    def __call__(
        self,
        target_logits: torch.Tensor, # [batch_size, K+1, vocab_size]
        draft_logits: torch.Tensor,  # [batch_size, K, vocab_size]
        draft_tokens: torch.Tensor,  # [batch_size, K]
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
    ) -> RejectionResult:
        batch_size, K = draft_tokens.shape
        device = target_logits.device

        # Apply temperature
        if temperature != 1.0:
            target_logits = target_logits / temperature
            draft_logits = draft_logits / temperature

        # Convert to probabilities (with optional top-K/top-P)
        target_probs = self._to_probs(target_logits[:, :K], top_k, top_p)
        draft_probs = self._to_probs(draft_logits, top_k, top_p)

        # Extract p(x_i) and q(x_i) for each drafted token
        p_at_x = self._gather_probs(target_probs, draft_tokens)  # [B, K]
        q_at_x = self._gather_probs(draft_probs, draft_tokens)   # [B, K]

        # Compute acceptance probabilities
        alpha = torch.where(
            q_at_x > 0,
            torch.clamp(p_at_x / q_at_x, max=1.0),
            torch.where(p_at_x > 0, torch.ones_like(p_at_x), torch.zeros_like(p_at_x)),
        )

        # Draw uniform random numbers
        r = torch.rand(batch_size, K, device=device)

        # Compute acceptance mask: accepted[b, k] = True iff all positions 0..k are accepted
        accepted_per_pos = r <= alpha  # [B, K]
        # Chain: position k is accepted only if all 0..k-1 are also accepted
        accepted_chain = torch.cumprod(accepted_per_pos.long(), dim=1)  # [B, K]
        num_accepted = accepted_chain.sum(dim=1)  # [B]

        # Build accepted_tokens tensor
        accepted_tokens = torch.where(
            accepted_chain.bool(),
            draft_tokens,
            torch.full_like(draft_tokens, -1),
        )

        # Compute bonus tokens at the first rejection point
        bonus_tokens = self._resample_at_rejection(
            target_probs, draft_probs, num_accepted, K, target_logits[:, K]
        )

        # Update stats
        self._total_proposed += batch_size * K
        self._total_accepted += num_accepted.sum().item()

        return RejectionResult(
            accepted_tokens=accepted_tokens,
            num_accepted=num_accepted,
            bonus_tokens=bonus_tokens,
            acceptance_rate=num_accepted.float().mean().item() / K,
        )

    def _to_probs(self, logits, top_k, top_p):
        """Convert logits to probabilities with optional truncation."""
        if top_k > 0:
            top_vals, top_idx = torch.topk(logits, top_k, dim=-1)
            mask = torch.zeros_like(logits, dtype=torch.bool)
            mask.scatter_(-1, top_idx, True)
            logits = logits.masked_fill(~mask, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
            cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            # Mask tokens whose cumulative mass (excluding themselves) already
            # reaches top_p
            mask = cumulative - torch.softmax(sorted_logits, dim=-1) >= top_p
            sorted_logits[mask] = float("-inf")
            # Scatter back to original token order (sorted_idx is a full
            # permutation, so every position is overwritten)
            logits = sorted_logits.scatter(-1, sorted_idx, sorted_logits)
        return torch.softmax(logits, dim=-1)

    def _gather_probs(self, probs, tokens):
        """Gather probabilities for specific tokens."""
        return probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    def _resample_at_rejection(self, target_probs, draft_probs, num_accepted, K, next_logits):
        """Resample token at the first rejection point."""
        batch_size = target_probs.shape[0]
        device = target_probs.device
        bonus = torch.zeros(batch_size, dtype=torch.long, device=device)

        for b in range(batch_size):
            n = num_accepted[b].item()
            if n < K:
                # Rejection at position n: sample from corrected distribution
                corrected = torch.clamp(target_probs[b, n] - draft_probs[b, n], min=0)
                csum = corrected.sum()
                if csum > 0:
                    corrected = corrected / csum
                else:
                    corrected = target_probs[b, n]
                bonus[b] = torch.multinomial(corrected.unsqueeze(0), 1).squeeze()
            else:
                # All K accepted: sample bonus from position K+1
                bonus_probs = torch.softmax(next_logits[b], dim=-1)
                bonus[b] = torch.multinomial(bonus_probs.unsqueeze(0), 1).squeeze()
        return bonus

    @property
    def overall_acceptance_rate(self):
        if self._total_proposed == 0:
            return 0.0
        return self._total_accepted / self._total_proposed
```
Monitoring and Tuning
Key Metrics
```python
class SpeculativeMetrics:
    """Track speculative decoding performance metrics."""

    def __init__(self):
        self.acceptance_rates = []
        self.draft_latencies = []
        self.verify_latencies = []
        self.tokens_per_step = []

    def record_step(self, result: RejectionResult, draft_ms: float, verify_ms: float, K: int):
        self.acceptance_rates.append(result.acceptance_rate)
        self.draft_latencies.append(draft_ms)
        self.verify_latencies.append(verify_ms)
        # Accepted tokens + 1 bonus token
        self.tokens_per_step.append(result.num_accepted.float().mean().item() + 1)

    def report(self):
        avg_rate = sum(self.acceptance_rates) / len(self.acceptance_rates)
        avg_draft = sum(self.draft_latencies) / len(self.draft_latencies)
        avg_verify = sum(self.verify_latencies) / len(self.verify_latencies)
        avg_tokens = sum(self.tokens_per_step) / len(self.tokens_per_step)

        # Effective throughput: tokens emitted per second of step time
        avg_step_time = avg_draft + avg_verify
        effective_tps = avg_tokens / (avg_step_time / 1000)

        return {
            "acceptance_rate": avg_rate,
            "avg_tokens_per_step": avg_tokens,
            "avg_step_time_ms": avg_step_time,
            "effective_tokens_per_sec": effective_tps,
        }
```
When Speculative Decoding Helps
Speculative decoding is beneficial when:

$$K \cdot T_{\text{draft}} + T_{\text{verify}} < (\bar{n} + 1) \cdot T_{\text{target}}$$

where $\bar{n}$ is the average number of accepted tokens, $T_{\text{draft}}$ is the draft model latency per token, $T_{\text{verify}}$ is the target model verification latency, and $T_{\text{target}}$ is the target model decode latency without speculation.
Since verification processes $K + 1$ tokens in one forward pass (same cost as one decode step due to batching), $T_{\text{verify}} \approx T_{\text{target}}$. The speedup ratio simplifies to:

$$\text{speedup} = \frac{(\bar{n} + 1) \cdot T_{\text{target}}}{K \cdot T_{\text{draft}} + T_{\text{target}}}$$
For $T_{\text{draft}} = T_{\text{target}} / 10$ (draft is 10x faster) and, say, $\bar{n} = 3$ accepted tokens at $K = 5$: speedup $= \frac{3 + 1}{5 \cdot 0.1 + 1} \approx 2.7\times$.
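The break-even arithmetic can be wrapped in a small helper (the function and its parameter names are mine, not vLLM's):

```python
def speculative_speedup(n_accepted: float, K: int,
                        t_draft: float, t_target: float) -> float:
    """Expected speedup over plain autoregressive decoding.

    Per speculative step we emit n_accepted + 1 tokens (accepted drafts
    plus one bonus/resampled token) and pay K draft steps plus one
    target verification pass (~ one decode step).
    """
    step_time = K * t_draft + t_target
    tokens_per_step = n_accepted + 1
    return tokens_per_step * t_target / step_time

speculative_speedup(3, 5, 0.1, 1.0)  # = 4 / 1.5, roughly 2.67x
```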