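Autoregressive decoding is fundamentally latency-bound: each token depends on all previous tokens, so the large model must run one forward pass per generated token. Speculative decoding breaks this serial dependency by cheaply guessing several tokens ahead with a small model, then verifying all of the guesses in a single parallel forward pass of the large model. When the guesses are correct, several tokens are produced for roughly the cost of one large-model step.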
The Core Algorithm
def speculative_decode(
    target_model,  # Large model (e.g., 70B)
    draft_model,  # Small model (e.g., 7B)
    prompt_tokens,
    num_speculative_tokens: int = 4,
    temperature: float = 1.0,
):
    """
    Generate tokens using speculative decoding.

    Each iteration drafts K = num_speculative_tokens tokens with the draft
    model, scores all of them with ONE target-model forward pass, and keeps
    the longest prefix that survives modified rejection sampling. The target
    then contributes one more token: a resampled correction on the first
    rejection, or a bonus token when every draft was accepted.

    Cost per iteration:            K * T_draft + T_target
    Expected tokens per iteration: (1 - alpha**(K + 1)) / (1 - alpha)
    where alpha is the per-token acceptance rate.

    Relies on helpers not shown here: should_stop, softmax, sample.

    Args:
        target_model: large model whose sampling distribution is preserved.
        draft_model: small, fast model that proposes tokens.
        prompt_tokens: iterable of prompt token ids.
        num_speculative_tokens: K, the number of draft tokens per iteration.
        temperature: sampling temperature applied to both models; it divides
            the logits, so it must be non-zero.

    Returns:
        A list of token ids: the prompt followed by the generated tokens.
    """
    generated = list(prompt_tokens)
    while not should_stop(generated):
        # Step 1: draft model proposes K tokens autoregressively.
        # Keep each FULL draft distribution: the residual distribution used on
        # rejection needs q(x) for every token x, not just q(draft_token).
        draft_tokens = []
        draft_dists = []
        for _ in range(num_speculative_tokens):
            # assumes draft_model returns next-token logits for the sequence — TODO confirm
            draft_dist = softmax(draft_model(generated + draft_tokens) / temperature)
            draft_tokens.append(sample(draft_dist))
            draft_dists.append(draft_dist)

        # Step 2: target model scores ALL draft tokens in ONE forward pass.
        # The relevant next-token predictions sit at the last K+1 input
        # positions (len(generated)-1 .. len(generated)+K-1): row i scores
        # draft_tokens[i], row K feeds the bonus token. Slicing the last K+1
        # rows is correct whether the model returns logits for every position
        # or only for those K+1.
        target_logits = target_model(generated + draft_tokens)
        target_probs = softmax(target_logits / temperature)[-(num_speculative_tokens + 1):]

        # Step 3: accept/reject each draft in order (modified rejection sampling).
        num_accepted = 0
        for i, draft_token in enumerate(draft_tokens):
            target_dist = target_probs[i]
            draft_dist = draft_dists[i]
            # Accept with probability min(1, p(x) / q(x)); q(x) > 0 because x
            # was sampled from q.
            acceptance_prob = min(1.0, target_dist[draft_token] / draft_dist[draft_token])
            if random.random() < acceptance_prob:
                generated.append(draft_token)
                num_accepted += 1
            else:
                # Rejection: resample from the renormalized residual max(0, p - q).
                # This correction keeps the output distribution exactly p.
                residual = torch.clamp(target_dist - draft_dist, min=0)
                generated.append(sample(residual / residual.sum()))
                break  # Later drafts were conditioned on the rejected token.

        # All K drafts accepted: the target's extra row yields one more token.
        if num_accepted == num_speculative_tokens:
            generated.append(sample(target_probs[num_speculative_tokens]))

    return generated
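Speculative decoding produces exactly the same output distribution as standard decoding: the modified rejection sampling guarantees this, so there is no quality loss. Whether it is actually faster depends on the acceptance rate and the batch size (see below).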
Acceptance Rate Analysis
Speedup depends critically on acceptance rate α:
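$$\text{Speedup} = \frac{1 + \alpha + \alpha^2 + \cdots + \alpha^K}{1 + \text{overhead}} = \frac{1 - \alpha^{K+1}}{(1 - \alpha)(1 + \text{overhead})}, \qquad \text{overhead} = K \cdot \frac{T_{\text{draft}}}{T_{\text{target}}}$$
The numerator is the expected number of tokens produced per target forward pass: the accepted draft tokens plus the one token the target always contributes.
For K=4, α=0.8: Speedup ≈ 3.36x (before overhead)
For K=4, α=0.5: Speedup ≈ 1.94x (before overhead)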
Figure: Theoretical speedup (×) vs. acceptance rate (K=4)
Draft Model Selection
The draft model must balance speed and accuracy:
Draft Model Comparison for Llama-70B Target
| Draft Model | Acceptance Rate | Draft Time | Net Speedup |
|---|---|---|---|
| Llama-7B | 78% | 12ms | 2.1x |
| Llama-1B | 61% | 4ms | 1.9x |
| TinyLlama-1.1B | 58% | 3ms | 1.7x |
| Llama-70B (self) | 95% | 45ms | 1.3x |
| Medusa heads | 72% | 2ms | 2.4x |
Draft Model Requirements
def is_good_draft_model(
    draft_model,
    target_model,
    eval_data,
    *,
    max_latency_ratio: float = 0.15,
    min_acceptance_rate: float = 0.5,
) -> bool:
    """
    Evaluate draft model suitability for speculative decoding.

    Checks run cheapest-first and stop at the first failure:
      1. Vocabulary compatibility: draft and target report the same vocab_size.
      2. Speed: draft latency is at most max_latency_ratio * target latency.
      3. Quality: acceptance rate on eval_data is at least min_acceptance_rate.

    Args:
        draft_model: candidate draft model.
        target_model: model whose outputs the draft must approximate.
        eval_data: data passed to measure_acceptance_rate.
        max_latency_ratio: largest allowed draft/target latency ratio
            (default 0.15, i.e. draft under 15% of target time).
        min_acceptance_rate: smallest acceptable acceptance rate (default 0.5).

    Returns:
        True if every criterion passes, otherwise False.
    """
    # Criterion 1: vocabulary compatibility. Checked first because it is free,
    # and benchmarking/measuring an incompatible draft would be wasted work.
    # NOTE(review): equal vocab sizes do not prove both tokenizers assign the
    # same ids to the same strings.
    if draft_model.vocab_size != target_model.vocab_size:
        return False  # Or implement vocabulary mapping

    # Criterion 2: fast enough
    draft_time = benchmark_latency(draft_model)
    target_time = benchmark_latency(target_model)
    if draft_time > target_time * max_latency_ratio:
        return False

    # Criterion 3: high enough acceptance rate
    acceptance_rate = measure_acceptance_rate(draft_model, target_model, eval_data)
    if acceptance_rate < min_acceptance_rate:
        return False

    return True
Tree-Structured Speculation
Instead of a single chain, speculate a tree of possibilities:
class TreeSpeculativeDecoder:
    """
    Speculate multiple branches, verify all in parallel.

    Higher acceptance probability but more complex verification: the draft
    model expands a tree of candidate continuations, and the target model
    scores every node of that tree in a single forward pass.
    """

    def __init__(self, draft_model, target_model, tree_config):
        self.draft_model = draft_model
        self.target_model = target_model
        # Tree structure: [width at level 1, width at level 2, ...]. Every node
        # on a level gets `width` children from the draft model's top-k.
        # e.g., [1, 3, 2] = the top-1 token after the context, its top-3
        # continuations, and the top-2 after each of those (1 + 3 + 6 = 10
        # speculative tokens).
        self.tree_structure = tree_config.tree_structure

    def generate_speculation_tree(self, context) -> "SpeculationTree":
        """Generate a tree of speculative tokens by breadth-first expansion.

        Calls the draft model once per node on each level being expanded.
        Presumably `context` is a list of token ids (it is concatenated with
        `[token]`) — TODO confirm.
        """
        tree = SpeculationTree()
        # Frontier entries pair a tree node with the token context that leads to it.
        frontier = [(context, tree.root)]
        for width in self.tree_structure:
            next_frontier = []
            for ctx, node in frontier:
                # assumes draft_model returns 1-D next-token logits — TODO confirm
                logits = self.draft_model(ctx)
                # .tolist() turns the top-k indices into plain int token ids so
                # contexts and tree nodes hold ints rather than 0-d tensors.
                for token in torch.topk(logits, width).indices.tolist():
                    child = node.add_child(token)
                    next_frontier.append((ctx + [token], child))
            frontier = next_frontier
        return tree

    def verify_tree(self, tree: "SpeculationTree", context) -> "List[int]":
        """
        Verify entire tree in single target model forward pass.

        Uses an attention mask (and matching position ids) produced by the
        tree so each speculative token only attends to its own ancestors.
        """
        # Flatten tree to sequence with special attention mask
        flat_tokens, attention_mask, position_ids = tree.flatten_for_verification()
        # Single forward pass verifies all paths
        target_logits = self.target_model(
            input_ids=context + flat_tokens,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )
        # Find longest accepted path.
        # NOTE(review): _find_best_path is not defined in this class; it must
        # be supplied elsewhere (e.g. by a subclass) or this raises AttributeError.
        return self._find_best_path(tree, target_logits)
When Speculative Decoding Helps
Speculative decoding is most beneficial when:
- Batch size = 1: Memory bandwidth bound, compute underutilized
- High acceptance rate: Draft model matches target well
- Long generation: Amortizes fixed overhead
Speculative Decoding Speedup by Scenario
| Scenario | Batch Size | Task | Speedup |
|---|---|---|---|
| Interactive chat | 1 | General | 2.1x |
| Code generation | 1 | Code | 2.4x |
| Translation | 1 | MT | 1.8x |
| Batch inference | 32 | General | 0.9x |
| Very long output | 1 | Story | 2.6x |
With larger batches, the target model is already compute-bound. Adding draft model overhead provides little benefit and can hurt throughput.
Implementation in vLLM
# vLLM speculative decoding configuration
from vllm import LLM, SamplingParams
# NOTE(review): these keyword names match older vLLM releases; newer releases
# take speculative settings through `speculative_config={...}` — check the
# documentation for your installed version.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    # 70B fp16 weights (~140 GB) exceed a single GPU; set to your GPU count.
    tensor_parallel_size=4,
    speculative_model="meta-llama/Llama-2-7b-hf",
    num_speculative_tokens=4,
    speculative_draft_tensor_parallel_size=1,  # Draft on single GPU
)

# Sampling params for speculative decoding
params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=512,
)

# Example prompts — replace with your own workload.
prompts = [
    "Explain speculative decoding in two sentences.",
]

# Generate with speculation
outputs = llm.generate(prompts, params)
Optimizing Acceptance Rate
Temperature Matching
def optimize_temperature_for_speculation(
    draft_model,
    target_model,
    eval_prompts,
    target_temperature: float,
    *,
    candidate_temperatures=None,
) -> float:
    """
    Find draft temperature that maximizes acceptance rate
    for a given target temperature.

    Grid search: every candidate draft temperature is scored with
    measure_acceptance_rate. The first candidate with the strictly highest
    rate wins. If no candidate has a rate above 0, target_temperature is
    returned unchanged.

    Args:
        draft_model: draft model being tuned.
        target_model: target model whose temperature stays fixed.
        eval_prompts: prompts used to measure acceptance.
        target_temperature: sampling temperature of the target model.
        candidate_temperatures: draft temperatures to try; defaults to 16
            evenly spaced values in [0.5, 2.0].

    Returns:
        The best draft temperature, as a Python float.
    """
    if candidate_temperatures is None:
        candidate_temperatures = np.linspace(0.5, 2.0, 16)

    best_acceptance = 0
    best_draft_temp = target_temperature
    for draft_temp in candidate_temperatures:
        # NOTE(review): called here with 5 arguments, but with 3 in
        # is_good_draft_model — confirm the real signature.
        acceptance = measure_acceptance_rate(
            draft_model, draft_temp,
            target_model, target_temperature,
            eval_prompts,
        )
        if acceptance > best_acceptance:
            best_acceptance = acceptance
            best_draft_temp = draft_temp
    # np.linspace yields np.float64; return a plain float as annotated.
    return float(best_draft_temp)
Domain-Specific Draft Models
Fine-tune draft model on target domain:
# Fine-tune draft model to match target on domain data
from transformers import Trainer, TrainingArguments
# Use target model outputs as training signal
def create_distillation_dataset(target_model, domain_prompts):
    """
    Build a sequence-level distillation dataset from greedy target outputs.

    For each prompt, the target model generates a greedy continuation. The
    draft is trained on the full sequence (prompt + continuation).
    `input_ids` and `labels` therefore have the same length, as causal-LM
    training requires. Prompt positions in `labels` are set to -100, the
    ignore index, so loss is computed only on the target's tokens.

    Assumes each prompt is a token-id tensor whose last dimension is the
    sequence, and that target_model.generate returns the prompt followed by
    the continuation (decoder-only behaviour) — TODO confirm. Variable-length
    examples also need a padding data collator for batched training.

    Args:
        target_model: model whose greedy generations are the training signal.
        domain_prompts: iterable of prompt token-id tensors.

    Returns:
        A list of {'input_ids': ..., 'labels': ...} dicts.
    """
    dataset = []
    for prompt in domain_prompts:
        # Generate with target model
        output = target_model.generate(prompt, do_sample=False)
        prompt_len = prompt.shape[-1]
        labels = output.clone()
        labels[..., :prompt_len] = -100  # don't train on reproducing the prompt
        dataset.append({
            'input_ids': output,
            'labels': labels,
        })
    return dataset
# Hyperparameters for fine-tuning the draft model on distilled domain data.
training_args = TrainingArguments(
    output_dir='./draft_finetuned',
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
)

# Train the draft model on the target model's generations for the domain prompts.
trainer = Trainer(
    model=draft_model,
    args=training_args,
    train_dataset=create_distillation_dataset(target_model, domain_data),
)
trainer.train()
# Domain-tuned draft model typically achieves 5-15% higher acceptance rate
Conclusion
Speculative decoding trades compute for latency, achieving 1.5-2.5x speedup for single-request inference. Key success factors:
- Fast draft model (under 15% of target time)
- High acceptance rate (over 70% for meaningful speedup)
- Batch size 1 (otherwise already compute-bound)
- Domain matching between draft and target
For batch inference or throughput-focused deployments, stick with standard decoding. For latency-critical single-user scenarios, speculative decoding is highly effective.