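Autoregressive decoding is fundamentally latency-bound: each token depends on all previous tokens, so tokens are produced one target-model forward pass at a time. Speculative decoding works around this dependency by cheaply guessing several tokens with a small draft model, then verifying all of them in a single parallel forward pass of the large target model. Every accepted guess is a token that did not need its own target-model pass.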

The Core Algorithm

def speculative_decode(
    target_model,      # Large model (e.g., 70B)
    draft_model,       # Small model (e.g., 7B)
    prompt_tokens,
    num_speculative_tokens: int = 4,
    temperature: float = 1.0
):
    """
    Generate tokens using speculative decoding.

    Each iteration:
      1. the draft model proposes K tokens autoregressively,
      2. the target model scores all K of them in ONE forward pass,
      3. modified rejection sampling keeps the longest accepted prefix. On the
         first rejection, a replacement token is drawn from the residual
         distribution max(0, p_target - p_draft) (renormalised). If all K are
         accepted, a bonus token is drawn from the target's next position.
    Every iteration therefore appends between 1 and K+1 tokens.

    Cost per generated token, with per-token acceptance rate alpha:
    - Without speculation: T_target
    - With speculation: (K * T_draft + T_target) / E[tokens per iteration],
      where E[tokens per iteration] = (1 - alpha**(K+1)) / (1 - alpha)

    Where K = num_speculative_tokens.

    Assumptions about names defined elsewhere (TODO confirm):
      - draft_model(tokens) returns next-token logits over the vocabulary.
      - target_model(tokens) returns logits whose last K+1 rows predict the
        K draft positions plus the bonus position. This holds both for
        full-sequence logits and for an already-sliced K+1-row output.
      - softmax, sample and should_stop are provided by the enclosing module.

    Returns:
        A new list: the prompt tokens followed by all generated tokens.

    Raises:
        ValueError: if num_speculative_tokens < 0 or temperature <= 0
            (greedy decoding needs a separate argmax-matching path).
    """
    if num_speculative_tokens < 0:
        raise ValueError("num_speculative_tokens must be >= 0")
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    k = num_speculative_tokens
    generated = list(prompt_tokens)

    while not should_stop(generated):
        # Step 1: Draft model generates K speculative tokens.
        draft_tokens = []
        # Keep the FULL draft distribution per position: the rejection branch
        # needs it to build the residual distribution, not just q(token).
        draft_dists = []
        for _ in range(k):
            logits = draft_model(generated + draft_tokens)
            probs = softmax(logits / temperature)
            token = sample(probs)
            draft_tokens.append(token)
            draft_dists.append(probs)

        # Step 2: Target model verifies ALL tokens in ONE forward pass.
        # Logits at position j predict token j+1, so the distributions for the
        # K draft positions and the bonus position are the LAST K+1 rows.
        # Slicing makes row i correspond to draft token i.
        target_logits = target_model(generated + draft_tokens)
        target_probs = softmax(target_logits[-(k + 1):] / temperature)

        # Step 3: Accept/reject using modified rejection sampling.
        num_accepted = 0
        for i, draft_token in enumerate(draft_tokens):
            p_target = target_probs[i][draft_token]
            p_draft = draft_dists[i][draft_token]
            acceptance_prob = min(1.0, p_target / p_draft)

            if random.random() < acceptance_prob:
                generated.append(draft_token)
                num_accepted += 1
                continue

            # Rejection: resample from the residual max(0, p_target - p_draft).
            residual = torch.clamp(target_probs[i] - draft_dists[i], min=0)
            total = residual.sum()
            # Round-off can make the residual all zeros when the two
            # distributions (nearly) coincide; fall back to the target
            # distribution instead of dividing by zero.
            adjusted_probs = residual / total if total > 0 else target_probs[i]
            generated.append(sample(adjusted_probs))
            break  # Stop accepting after first rejection

        # If all K tokens were accepted, the target's last row is still unused:
        # sample one extra "bonus" token from it.
        if num_accepted == k:
            generated.append(sample(target_probs[k]))

    return generated
ℹ️ Mathematical Guarantee

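Speculative decoding with this rejection-sampling rule produces exactly the same output distribution as sampling from the target model alone (up to floating-point error). The rejection step corrects for any mismatch between draft and target, so there is no quality loss, only a speedup.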

Acceptance Rate Analysis

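Speedup depends critically on the acceptance rate α, the probability that a single draft token is accepted (treated as independent across positions). One verification step yields, in expectation, 1 + α + α² + ⋯ + α^K tokens, a geometric series:

$$
\text{Speedup} = \frac{1 + \alpha + \alpha^2 + \cdots + \alpha^K}{1 + \text{overhead}} = \frac{1 - \alpha^{K+1}}{(1 - \alpha)\,(1 + \text{overhead})}
$$

Here overhead ≈ K · T_draft / T_target, the cost of the K draft passes relative to one target pass.

For K=4, α=0.8: Speedup ≈ 3.36x (before overhead)
For K=4, α=0.5: Speedup ≈ 1.94x (before overhead)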

Theoretical Speedup vs Acceptance Rate (K=4)

| α | Speedup (before overhead) |
|---|---|
| 0.9 | 4.10x |
| 0.8 | 3.36x |
| 0.7 | 2.77x |
| 0.6 | 2.31x |
| 0.5 | 1.94x |
| 0.4 | 1.65x |

Draft Model Selection

The draft model must balance speed and accuracy:

📊

Draft Model Comparison for Llama-70B Target

| Draft Model | Acceptance Rate | Draft Time | Net Speedup |
|---|---|---|---|
| Llama-7B | 78% | 12ms | 2.1x |
| Llama-1B | 61% | 4ms | 1.9x |
| TinyLlama-1.1B | 58% | 3ms | 1.7x |
| Llama-70B (self) | 95% | 45ms | 1.3x |
| Medusa heads | 72% | 2ms | 2.4x |
Note: A100-80GB, K=4, batch_size=1, code generation task

Draft Model Requirements

def is_good_draft_model(
    draft_model,
    target_model,
    eval_data,
    max_latency_ratio: float = 0.15,
    min_acceptance_rate: float = 0.5,
) -> bool:
    """
    Evaluate draft model suitability.

    Checks run cheapest-first and stop at the first failure:
      1. Vocabulary compatibility: the verifier compares draft token ids
         directly against the target's distribution, so vocab sizes must
         match. Equal sizes are necessary but not sufficient; the two
         tokenizers must also map ids identically.
      2. Latency: draft latency must be at most ``max_latency_ratio`` times
         the target's latency (default 15%).
      3. Acceptance rate on ``eval_data`` must be at least
         ``min_acceptance_rate`` (default 0.5).

    Returns:
        True only if every check passes.
    """
    # Criterion 1: Vocabulary compatibility. This check is free, and an
    # acceptance rate measured across mismatched vocabularies is meaningless,
    # so it runs before any benchmarking.
    if draft_model.vocab_size != target_model.vocab_size:
        return False  # Or implement vocabulary mapping

    # Criterion 2: Fast enough
    draft_time = benchmark_latency(draft_model)
    target_time = benchmark_latency(target_model)
    if draft_time > target_time * max_latency_ratio:
        return False

    # Criterion 3: High enough acceptance rate (most expensive, so last)
    acceptance_rate = measure_acceptance_rate(draft_model, target_model, eval_data)
    if acceptance_rate < min_acceptance_rate:
        return False

    return True

Tree-Structured Speculation

Instead of a single chain, speculate a tree of possibilities:

class TreeSpeculativeDecoder:
    """
    Speculate multiple branches, verify all in parallel.

    The draft model expands a tree of candidate continuations instead of a
    single chain. The target model then scores every node in one forward pass
    (tree-shaped attention mask), and the best accepted path is kept. This
    gives a higher acceptance probability but more complex verification.
    """

    def __init__(self, draft_model, target_model, tree_config):
        self.draft_model = draft_model
        self.target_model = target_model
        # Tree structure: [branching factor at depth 0, at depth 1, ...].
        # e.g. [1, 3, 2]: tree.root gets 1 child, each of those gets 3
        # children, and each of those gets 2 (1 + 3 + 6 = 10 draft nodes).
        self.tree_structure = tree_config.tree_structure

    def generate_speculation_tree(self, context) -> SpeculationTree:
        """Generate a tree of speculative tokens by breadth-first expansion.

        Every node on the current frontier gets one draft forward pass, and
        its top-``width`` next tokens become its children.

        Args:
            context: prefix token ids (assumed to be a list of ints — TODO
                confirm); each child's context is its parent's plus one token.
        """
        tree = SpeculationTree()
        frontier = [(context, tree.root)]

        for width in self.tree_structure:
            next_frontier = []
            for ctx, node in frontier:
                # assumes draft_model returns 1-D next-token logits — TODO confirm
                logits = self.draft_model(ctx)
                # .tolist() yields plain int token ids rather than 0-d tensors,
                # so tree nodes and extended contexts hold ordinary ints.
                for token in torch.topk(logits, width).indices.tolist():
                    child = node.add_child(token)
                    next_frontier.append((ctx + [token], child))
            frontier = next_frontier

        return tree

    def verify_tree(self, tree: SpeculationTree, context) -> List[int]:
        """
        Verify entire tree in single target model forward pass.

        The tree is flattened into one token sequence. The attention mask and
        position ids from ``tree.flatten_for_verification()`` (defined
        elsewhere) presumably let each node attend only to its ancestors.

        Returns:
            Whatever ``self._find_best_path`` selects (defined elsewhere;
            intended to be the longest accepted path).
        """
        flat_tokens, attention_mask, position_ids = tree.flatten_for_verification()

        # Single forward pass scores every node of every branch.
        target_logits = self.target_model(
            input_ids=context + flat_tokens,
            attention_mask=attention_mask,
            position_ids=position_ids,
        )

        return self._find_best_path(tree, target_logits)

When Speculative Decoding Helps

Speculative decoding is most beneficial when:

  1. Batch size = 1: Memory bandwidth bound, compute underutilized
  2. High acceptance rate: Draft model matches target well
  3. Long generation: Amortizes fixed overhead
📊

Speculative Decoding Speedup by Scenario

| Scenario | Batch Size | Task | Speedup |
|---|---|---|---|
| Interactive chat | 1 | General | 2.1x |
| Code generation | 1 | Code | 2.4x |
| Translation | 1 | MT | 1.8x |
| Batch inference | 32 | General | 0.9x |
| Very long output | 1 | Story | 2.6x |
Note: Llama-70B target, Llama-7B draft, K=4
⚠️ Batch Size > 1

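With larger batches, the target model's forward pass is already much closer to compute-bound. Verifying K extra tokens per sequence is then no longer nearly free, so the draft model's overhead provides little benefit and can reduce throughput (e.g. 0.9x at batch size 32).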

Implementation in vLLM

# vLLM speculative decoding configuration.
# NOTE(review): `speculative_model` / `num_speculative_tokens` /
# `speculative_draft_tensor_parallel_size` are the older vLLM keyword
# arguments; newer releases take a `speculative_config` dict instead.
# Check against the installed vLLM version.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    # ~140 GB of fp16 weights do not fit on one 80 GB GPU, so the target is
    # sharded with tensor parallelism; set this to the GPUs you have.
    tensor_parallel_size=4,
    speculative_model="meta-llama/Llama-2-7b-hf",
    num_speculative_tokens=4,
    speculative_draft_tensor_parallel_size=1,  # Draft on single GPU
)

# Sampling params for speculative decoding
params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=512,
)

# Example inputs (previously undefined in this snippet)
prompts = [
    "Write a Python function that checks whether a number is prime.",
]

# Generate with speculation
outputs = llm.generate(prompts, params)

Optimizing Acceptance Rate

Temperature Matching

def optimize_temperature_for_speculation(
    draft_model,
    target_model,
    eval_prompts,
    target_temperature: float,
    candidate_temperatures=None,
) -> float:
    """
    Find draft temperature that maximizes acceptance rate
    for a given target temperature.

    The rejection step corrects for whatever draft distribution is used,
    provided the verifier uses the draft's actual temperature-scaled
    probabilities. Tuning the draft temperature therefore affects speed,
    not output quality.

    Args:
        draft_model, target_model, eval_prompts: forwarded to
            ``measure_acceptance_rate`` (defined elsewhere).
        target_temperature: sampling temperature of the target model.
        candidate_temperatures: draft temperatures to try. Defaults to
            ``np.linspace(0.5, 2.0, 16)`` (0.5, 0.6, ..., 2.0).

    Returns:
        The best draft temperature as a Python float. The target temperature
        is measured first as the baseline, so it is returned unless some
        candidate achieves a strictly higher acceptance rate.
    """
    if candidate_temperatures is None:
        candidate_temperatures = np.linspace(0.5, 2.0, 16)

    def acceptance_at(draft_temp):
        return measure_acceptance_rate(
            draft_model, draft_temp,
            target_model, target_temperature,
            eval_prompts
        )

    # Baseline: actually measure the matched temperature so that it competes
    # on equal terms, even when it is not on the candidate grid.
    best_draft_temp = float(target_temperature)
    best_acceptance = acceptance_at(best_draft_temp)

    for draft_temp in candidate_temperatures:
        draft_temp = float(draft_temp)  # np.float64 -> float for the return type
        acceptance = acceptance_at(draft_temp)
        if acceptance > best_acceptance:
            best_acceptance = acceptance
            best_draft_temp = draft_temp

    return best_draft_temp

Domain-Specific Draft Models

Fine-tune draft model on target domain:

# Fine-tune draft model to match target on domain data
from transformers import Trainer, TrainingArguments

# Use target model outputs as training signal
def create_distillation_dataset(target_model, domain_prompts):
    """Build a sequence-level distillation dataset from target-model outputs.

    For each prompt, the target model's greedy continuation becomes a causal-LM
    training example for the draft model. ``input_ids`` is the full sequence
    (prompt + completion). ``labels`` is the same sequence with prompt
    positions set to -100 (ignored by the loss), so the draft is trained only
    to imitate the target's completion. The two always have equal length.

    Assumes each prompt is a ``(1, prompt_len)`` tensor of token ids and that
    ``generate`` returns ``(1, prompt_len + new_len)`` ids that include the
    prompt (as decoder-only Hugging Face models do). TODO confirm.

    Returns:
        A list of ``{'input_ids': tensor, 'labels': tensor}`` dicts.
    """
    dataset = []
    for prompt in domain_prompts:
        # Greedy decoding: the target's most likely continuation.
        output = target_model.generate(prompt, do_sample=False)
        sequence = output[0]
        labels = sequence.clone()
        labels[: prompt.shape[-1]] = -100  # don't train on reproducing the prompt
        dataset.append({
            'input_ids': sequence,
            'labels': labels,
        })
    return dataset

# Distillation run: fine-tune the draft on the target's greedy completions
# of in-domain prompts (domain_data), writing checkpoints to ./draft_finetuned.
training_args = TrainingArguments(
    output_dir='./draft_finetuned',
    learning_rate=1e-5,
    num_train_epochs=3,
    per_device_train_batch_size=8,
)

# NOTE(review): with 8 examples per batch and variable-length sequences,
# Trainer likely needs a data_collator that pads input_ids/labels; confirm
# one is configured before running.
trainer = Trainer(
    args=training_args,
    model=draft_model,
    train_dataset=create_distillation_dataset(target_model, domain_data),
)

trainer.train()
# Domain-tuned draft model typically achieves 5-15% higher acceptance rate

Conclusion

Speculative decoding trades compute for latency, achieving 1.5-2.5x speedup for single-request inference. Key success factors:

  1. Fast draft model (under 15% of target time)
  2. High acceptance rate (over 70% for meaningful speedup)
  3. Batch size 1 (otherwise already compute-bound)
  4. Domain matching between draft and target

For batch inference or throughput-focused deployments, stick with standard decoding. For latency-critical single-user scenarios, speculative decoding is highly effective.