Standard speculative decoding uses a separate draft model: generate K tokens with the draft model, verify all K with the target model in one forward pass. The verification is “free” because a single forward pass costs roughly the same whether it processes 1 or K tokens (decode is memory-bandwidth-bound). Expected tokens per step: (1 − α^(K+1)) / (1 − α), where α is the per-token acceptance rate.
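Under the usual simplifying assumption that each drafted token is accepted independently, the expectation is a truncated geometric series. A quick sketch:

```python
def expected_tokens_per_step(alpha, k):
    """Expected tokens accepted per target forward pass, assuming each of
    the k draft tokens is accepted independently with probability alpha.
    Truncated geometric series: 1 + alpha + alpha**2 + ... + alpha**k."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

# With alpha = 0.8 and k = 4 drafted tokens, one target forward pass
# yields about 3.36 tokens on average instead of 1.
print(expected_tokens_per_step(0.8, 4))
```

This is why acceptance rate matters more than draft length: pushing K higher buys little once α^K becomes small.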
But maintaining a separate draft model has costs: extra GPU memory (a 7B FP16 draft alongside a 70B target = 14 GB of overhead), a second set of weights to load and version, and the engineering complexity of running two models in one serving system. The next generation of speculative decoding eliminates the draft model entirely.
Medusa: Self-Drafting with Extra Prediction Heads
Medusa adds lightweight prediction heads to the target model itself. Each head predicts a different future token:
- Head 0 (standard): predicts token t+1 (the next token)
- Head 1 (Medusa): predicts token t+2 (two tokens ahead)
- Head 2 (Medusa): predicts token t+3 (three tokens ahead)
- Head K (Medusa): predicts token t+K+1
```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """Single Medusa prediction head for future token prediction."""

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        # Simple 1-hidden-layer MLP
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.SiLU()
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, hidden_dim]
        return self.linear2(self.act(self.linear1(hidden_states)))

class MedusaModel(nn.Module):
    """Target model augmented with Medusa heads."""

    def __init__(self, base_model, num_medusa_heads=4):
        super().__init__()
        self.base_model = base_model
        hidden_dim = base_model.config.hidden_size
        vocab_size = base_model.config.vocab_size
        self.medusa_heads = nn.ModuleList([
            MedusaHead(hidden_dim, vocab_size)
            for _ in range(num_medusa_heads)
        ])

    def forward(self, input_ids, **kwargs):
        # Run base model normally
        outputs = self.base_model(input_ids, **kwargs)
        hidden = outputs.last_hidden_state  # [B, S, D]
        # Standard next-token prediction from the base model's LM head
        base_logits = self.base_model.lm_head(hidden)
        # Medusa heads predict future tokens
        medusa_logits = [head(hidden) for head in self.medusa_heads]
        return base_logits, medusa_logits
```
Training cost: Only the Medusa heads are trained (base model weights frozen). Each head has roughly D·V + D² parameters, dominated by the D×V output projection. For D=4096, V=128K: ~530M params per head. 4 heads: ~2.1B params. Training: ~1 day on 8 GPUs with distillation from the base model’s own outputs.
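A quick sanity check of that arithmetic (ignoring biases, which is why the exact totals differ slightly from the rounded figures in the text):

```python
D, V = 4096, 128_000

# One MedusaHead: linear1 is D x D, linear2 is D x V
linear1 = D * D        # ~16.8M
linear2 = D * V        # ~524M, the dominant term
per_head = linear1 + linear2

num_heads = 4
total_params = num_heads * per_head
fp16_gb = total_params * 2 / 1e9   # 2 bytes per parameter in FP16

print(per_head, total_params, round(fp16_gb, 1))
```

The D×V output projection dominates, so head size scales with vocabulary size far more than with the MLP width.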
Memory overhead: ~4.2 GB in FP16 for 4 heads. Compare to 14 GB for a separate 7B draft model. Medusa saves 70% of the draft memory.
Medusa Tree-Based Verification
Medusa doesn’t verify a single sequence — it constructs a token tree and verifies all branches in one forward pass:
```python
def medusa_generate_step(model, input_ids, kv_cache):
    """One Medusa generation step with tree verification."""
    # 1. Get predictions from all heads
    base_logits, medusa_logits = model(input_ids, kv_cache=kv_cache)

    # 2. Build token tree
    # Head 0 predicts next token t+1: top-k candidates
    t1_candidates = base_logits[:, -1, :].topk(k=5).indices       # 5 candidates
    # Head 1 predicts t+2 for EACH t1 candidate
    t2_candidates = medusa_logits[0][:, -1, :].topk(k=3).indices  # 3 per branch
    # Head 2 predicts t+3 for each (t1, t2) path
    t3_candidates = medusa_logits[1][:, -1, :].topk(k=2).indices  # 2 per branch
    # Total tree: 5 * 3 * 2 = 30 candidate sequences of length 3

    # 3. Verify all 30 sequences in ONE forward pass using tree attention.
    # The tree attention mask allows each node to attend only to its ancestors.
    tree_tokens = construct_tree(t1_candidates, t2_candidates, t3_candidates)
    verified_logits = model.base_model(tree_tokens, tree_attention_mask=...)

    # 4. Accept the longest matching prefix
    accepted = find_longest_accepted_path(verified_logits, tree_tokens)
    return accepted  # Typically 2-4 tokens accepted per step
```
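`construct_tree` is left abstract above; the key ingredient it produces is the tree attention mask. A minimal sketch of building one from parent pointers (the flat node layout is an assumption for illustration):

```python
import torch

def build_tree_attention_mask(parents):
    """Boolean [N, N] mask where entry (i, j) is True iff node j is node i
    itself or one of its ancestors, so each candidate branch sees only its
    own path. parents[i] is node i's parent index, or -1 for a root."""
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:
            mask[i, j] = True
            j = parents[j]
    return mask

# Tiny tree: two t+1 candidates (nodes 0 and 1), each with one t+2 child
mask = build_tree_attention_mask([-1, -1, 0, 1])
```

Node 2 can attend to itself and node 0 but not to the sibling branch, which is what lets all branches share one forward pass without contaminating each other.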
Medusa Performance (Vicuna-33B on A100)
| Configuration | Tokens/Step | Speedup vs Greedy | Memory Overhead |
|---|---|---|---|
| Greedy (no speculation) | 1.0 | 1.0x | 0 GB |
| Medusa-1 (1 head) | 1.8 | 1.6x | 1.1 GB |
| Medusa-2 (2 heads, tree) | 2.3 | 2.0x | 2.1 GB |
| Medusa-4 (4 heads, tree) | 2.7 | 2.3x | 4.2 GB |
| Separate 7B draft | 2.5 | 2.2x | 14 GB |
EAGLE: Hidden-State Drafting
EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) takes a different approach: instead of predicting future tokens from the last hidden state, it predicts the hidden state itself at future positions, then uses the base model’s LM head to convert those hidden states to tokens.
```python
class EAGLEDrafter(nn.Module):
    """Predicts future hidden states, not tokens directly."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Predict next hidden state from current hidden state + current token embedding
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, hidden_state, token_embedding):
        """Predict hidden state at next position."""
        combined = torch.cat([hidden_state, token_embedding], dim=-1)
        return self.fc(combined)  # Predicted hidden state
```
Why hidden states instead of tokens? The base model’s LM head is already optimized to convert hidden states to accurate token predictions. By predicting in hidden-state space and reusing the LM head, EAGLE leverages the base model’s existing prediction capability. The drafter only needs to predict how the hidden state evolves — a simpler task than predicting tokens directly.
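The resulting draft loop can be sketched as follows (the module names and calling convention are assumptions, not EAGLE's actual implementation): the drafter extrapolates hidden states autoregressively, and the frozen LM head decodes each one into a draft token.

```python
def eagle_draft(drafter, lm_head, embed, hidden, token, k=4):
    """Draft k tokens by extrapolating in hidden-state space (sketch).

    drafter: predicts the next hidden state from (hidden, token embedding)
    lm_head: the base model's own output projection, reused frozen
    embed:   the base model's token embedding
    """
    draft = []
    for _ in range(k):
        hidden = drafter(hidden, embed(token))   # extrapolate hidden state
        token = lm_head(hidden).argmax(dim=-1)   # decode with the base LM head
        draft.append(token)
    return draft
```

Note that drafting errors compound here: each step feeds a *predicted* hidden state back in, which is why EAGLE trains the drafter on its own predicted states, not just teacher-forced ones.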
EAGLE v2 improvement: Instead of a fixed draft tree shape, EAGLE v2 dynamically adjusts the tree based on confidence. High-confidence branches get deeper exploration, low-confidence branches are pruned. This adaptive tree yields 20-50% more accepted tokens than fixed trees.
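The confidence-driven expansion can be sketched as a best-first search where a node's score is the product of probabilities along its path (the `expand` callback and node budget are illustrative assumptions, not EAGLE v2's exact mechanics):

```python
import heapq

def adaptive_draft_tree(expand, root, budget=8):
    """Grow a draft tree best-first by path probability (sketch).

    expand(node) -> list of (prob, child). High-confidence branches keep
    winning the heap and grow deeper; weak branches are never expanded
    once the node budget runs out.
    """
    heap = [(-1.0, 0, root)]   # (negated path probability, tiebreaker, node)
    counter = 1
    chosen = []
    while heap and len(chosen) < budget:
        neg_score, _, node = heapq.heappop(heap)
        chosen.append(node)
        for prob, child in expand(node):
            heapq.heappush(heap, (neg_score * prob, counter, child))
            counter += 1
    return chosen

# One confident branch (p=0.9) vs. one weak branch (p=0.1): with a budget
# of 4 nodes, the search goes three levels deep on the confident side
expand = lambda node: [(0.9, node + "a"), (0.1, node + "b")] if len(node) < 3 else []
print(adaptive_draft_tree(expand, "", budget=4))
```

A fixed-shape tree would have spent part of the same budget on the 0.1-probability branch; the adaptive version reallocates those nodes to the path most likely to be accepted.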
[Figure: speculative decoding methods compared by tokens accepted per step.]
Lookahead Decoding
Lookahead takes yet another approach: instead of a separate drafter, it uses the target model’s own n-gram patterns to predict future tokens.
The idea: during generation, the model often produces predictable sequences (common phrases, code patterns, formatting). Lookahead maintains a cache of the model’s own n-gram predictions and uses them as speculative drafts.
```python
class LookaheadCache:
    """Cache n-gram predictions from the target model itself."""

    def __init__(self, window_size=7, ngram_size=3):
        self.window_size = window_size  # lookahead window (unused in this simplified sketch)
        self.ngram_size = ngram_size
        self.cache = {}  # (token_{n-2}, token_{n-1}) -> predicted token_n

    def update(self, tokens):
        """Update cache with observed n-gram patterns."""
        for i in range(len(tokens) - self.ngram_size + 1):
            key = tuple(tokens[i:i + self.ngram_size - 1])
            value = tokens[i + self.ngram_size - 1]
            self.cache[key] = value

    def predict(self, last_tokens, max_draft=5):
        """Predict future tokens using cached n-gram patterns."""
        draft = []
        current = tuple(last_tokens[-(self.ngram_size - 1):])
        for _ in range(max_draft):
            if current in self.cache:
                next_token = self.cache[current]
                draft.append(next_token)
                current = current[1:] + (next_token,)
            else:
                break
        return draft
```
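Verification then follows the same greedy rule as the other methods: run the target model over the drafted positions in one pass and keep the longest matching prefix. A minimal sketch, with plain token lists standing in for tensors:

```python
def longest_accepted(draft, target):
    """draft:  tokens proposed from the n-gram cache.
    target: the target model's greedy choice at each drafted position.
    Returns the prefix of the draft that the target model agrees with."""
    accepted = []
    for d, t in zip(draft, target):
        if d != t:
            break
        accepted.append(d)
    return accepted
```

Even when the whole draft is rejected, the verification pass still produced the target model's own next token, so a step never yields fewer tokens than plain greedy decoding.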
Advantage: Zero training required. No extra model or heads. Works with any model out of the box. Disadvantage: Lower acceptance rate than trained drafters. Works best for repetitive or formulaic text.
When Each Approach Wins
Self-Speculative Method Selection Guide
| Method | Best For | Memory Overhead | Training Required | Typical Speedup |
|---|---|---|---|---|
| Separate draft model | Maximum acceptance rate | 14+ GB | No (use existing small model) | 2.0-2.5x |
| Medusa | Memory-constrained serving | 2-4 GB | Yes (~1 day) | 1.8-2.3x |
| EAGLE v2 | Maximum self-draft performance | 1-2 GB | Yes (~1 day) | 2.5-3.6x |
| Lookahead | Zero-setup, any model | 0 GB | No | 1.3-1.8x |
| Multi-token prediction | Models trained for it | 0 GB (built-in) | During pretraining | 1.8-2.5x |
For new model deployments: use multi-token prediction (train the model with MTP heads from the start, like DeepSeek V3). For existing models: EAGLE v2 provides the best speedup-to-cost ratio. For quick experiments: Lookahead requires zero setup. Medusa is the sweet spot when you want self-drafting without EAGLE’s complexity.
Reviewer Agent Validation
Challenge: Using only this post, implement a simple Medusa-style generation step that: (1) gets base logits + one Medusa head’s logits, (2) picks top-1 from each, (3) verifies the Medusa token by checking if the target model agrees.
Expected:
```python
def simple_medusa_step(model, medusa_head, input_ids, kv_cache):
    outputs = model(input_ids, kv_cache=kv_cache)
    hidden = outputs.last_hidden_state[:, -1:, :]
    # Base prediction: token t+1
    t1 = model.lm_head(hidden).argmax(dim=-1)  # [B, 1]
    # Medusa prediction: token t+2
    t2_draft = medusa_head(hidden).argmax(dim=-1)  # [B, 1]
    # Verify t2 by running the target model on t1
    verify_out = model(t1, kv_cache=kv_cache)
    t2_target = model.lm_head(verify_out.last_hidden_state[:, -1:, :]).argmax(dim=-1)
    if (t2_draft == t2_target).all():
        return torch.cat([t1, t2_draft], dim=-1)  # Accept both
    else:
        return t1  # Accept only t1
```