Part of Series: Inference Optimization Timeline (23 of 23)

  1. LLM Inference Fundamentals: Prefill, Decode, and the Memory-Compute Divide
  2. KV Cache: The Hidden Memory Giant in LLM Serving
  3. Quantization for LLM Inference: From FP16 to INT4 — A Deep Dive into Precision, Performance, and Production Deployment
  4. FlashAttention: Why Tiling Attention Through the Memory Hierarchy Changes Everything
  5. PagedAttention: How vLLM Borrowed OS Virtual Memory to Fix LLM Serving
  6. Continuous Batching: The Complete Guide to LLM Inference Scheduling
  7. Speculative Decoding: Why Autoregressive LLMs Leave 99% of Your GPU Idle and How to Fix It
  8. Prefix Caching: RadixAttention, Cache Hierarchies, and Reusing Computation Across Requests
  9. LoRA and QLoRA for Serving: Multi-Adapter Inference, S-LoRA, and When to Merge
  10. Disaggregated Prefill-Decode: Why Splitting LLM Inference Changes Everything
  11. Constrained Generation: FSM-Based Decoding, Outlines, and Grammar-Guided LLM Output
  12. Mamba and State Space Models: The O(n) Alternative to Attention
  13. Inference-Time Compute Scaling: When More Thinking Helps (o1, DeepSeek-R1, and the Reasoning Frontier)
  14. CPU and Edge Inference: llama.cpp Internals, GGUF Format, and When CPU Actually Wins
  15. Inference Cost Economics: Tokens per Dollar, GPU-Hours, and the Real Math of LLM Serving
  16. Batched GEMM: Why Matrix Multiply Throughput Determines Everything in LLM Inference
  17. Token Generation Pipeline: Logit Processing, Sampling Strategies, and Stop Criteria
  18. Memory Pool Management: Slab Allocators for GPU Inference
  19. Vision-Language Model Serving: ViT Encoding, Cross-Attention, and KV Cache Paging for Multimodal
  20. Long-Context Serving: Ring Attention, KV Offloading, and Chunked Processing in Production
  21. Inference Profiling: Nsight Systems, torch.profiler, and Finding Where Time Actually Goes
  22. FP8 Inference: E4M3 Format, Per-Tensor Scaling, and the Hardware Support Matrix
  23. Speculative Decoding v2: Medusa, EAGLE, Lookahead, and Token Tree Verification

Standard speculative decoding uses a separate draft model: generate K tokens with the draft model, then verify all K with the target model in one forward pass. The verification is “free” because decode is memory-bandwidth-bound, so a single forward pass costs roughly the same whether it processes 1 or K tokens. With per-token acceptance rate $\alpha$, the expected number of tokens accepted per target-model pass is $\frac{1 - \alpha^{K+1}}{1 - \alpha}$.
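As a quick sanity check, the acceptance-rate formula can be evaluated directly (a minimal sketch; α and K as defined above):

```python
def expected_tokens_per_step(alpha: float, k: int) -> float:
    """Expected tokens accepted per target-model forward pass.

    Geometric-series result from the standard speculative decoding
    analysis: alpha is the per-token acceptance rate, k the draft length.
    """
    return (1 - alpha ** (k + 1)) / (1 - alpha)

# A strong drafter (alpha = 0.8) drafting 4 tokens yields ~3.36 tokens/pass
print(round(expected_tokens_per_step(0.8, 4), 2))
```

Note how the payoff saturates: pushing K higher helps little once $\alpha^{K+1}$ is small, which is why draft lengths in practice stay in the 4-8 range.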

But maintaining a separate draft model has costs: extra GPU memory (a 7B draft alongside a 70B target = 14 GB overhead), draft model loading, and the engineering complexity of running two models in one serving system. The next generation of speculative decoding eliminates the draft model entirely.

Medusa: Self-Drafting with Extra Prediction Heads

Medusa adds lightweight prediction heads to the target model itself. Each head predicts a different future token:

  • Head 0 (standard LM head): predicts token t+1 (the next token)
  • Head 1 (Medusa): predicts token t+2 (two tokens ahead)
  • Head 2 (Medusa): predicts token t+3 (three tokens ahead)
  • Head K (Medusa): predicts token t+K+1
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """Single Medusa prediction head for future token prediction."""

    def __init__(self, hidden_dim, vocab_size):
        super().__init__()
        # Simple 1-hidden-layer MLP
        self.linear1 = nn.Linear(hidden_dim, hidden_dim)
        self.act = nn.SiLU()
        self.linear2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, hidden_states):
        # hidden_states: [batch, seq_len, hidden_dim]
        return self.linear2(self.act(self.linear1(hidden_states)))

class MedusaModel(nn.Module):
    """Target model augmented with Medusa heads."""

    def __init__(self, base_model, num_medusa_heads=4):
        super().__init__()
        self.base_model = base_model
        hidden_dim = base_model.config.hidden_size
        vocab_size = base_model.config.vocab_size

        self.medusa_heads = nn.ModuleList([
            MedusaHead(hidden_dim, vocab_size)
            for _ in range(num_medusa_heads)
        ])

    def forward(self, input_ids, **kwargs):
        # Run base model normally
        outputs = self.base_model(input_ids, **kwargs)
        hidden = outputs.last_hidden_state  # [B, S, D]

        # Standard next-token prediction from the base model's LM head
        base_logits = self.base_model.lm_head(hidden)

        # Medusa heads predict future tokens
        medusa_logits = [head(hidden) for head in self.medusa_heads]

        return base_logits, medusa_logits

Training cost: Only the Medusa heads are trained (the base model's weights stay frozen). Each head has $D^2 + D \times V$ parameters (one $D \times D$ hidden layer plus a $D \times V$ output projection). For D = 4096, V = 128K: ~540M params per head, ~2.1B for 4 heads. Training: ~1 day on 8 GPUs with distillation from the base model's own outputs.

Memory overhead: ~4.2 GB in FP16 for 4 heads. Compare to 14 GB for a separate 7B draft model. Medusa saves 70% of the draft memory.
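These overhead figures follow from simple parameter counting; a back-of-envelope check (bias terms ignored, vocab size illustrative):

```python
D, V = 4096, 128_000           # hidden dim, vocab size (illustrative values)
NUM_HEADS = 4

head_params = D * D + D * V    # linear1 (D x D) + linear2 (D x V), biases omitted
total_params = NUM_HEADS * head_params
fp16_bytes = total_params * 2  # 2 bytes per parameter in FP16

print(f"{head_params / 1e6:.0f}M params/head, "
      f"{total_params / 1e9:.1f}B total, "
      f"{fp16_bytes / 1e9:.1f} GB in FP16")
```

The dominant term is the $D \times V$ output projection, which is why Medusa head size is driven almost entirely by vocabulary size rather than hidden dimension.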

Medusa Tree-Based Verification

Medusa doesn’t verify a single sequence — it constructs a token tree and verifies all branches in one forward pass:

def medusa_generate_step(model, input_ids, kv_cache):
    """One Medusa generation step with tree verification."""
    # 1. Get predictions from all heads
    base_logits, medusa_logits = model(input_ids, kv_cache=kv_cache)

    # 2. Build token tree
    # Head 0 predicts next token: top-k candidates
    t1_candidates = base_logits[:, -1, :].topk(k=5).indices  # 5 candidates

    # Head 1 predicts t+2; all heads draft from the same hidden state,
    # so this top-3 is shared across every t1 branch
    t2_candidates = medusa_logits[0][:, -1, :].topk(k=3).indices  # 3 per branch

    # Head 2 predicts t+3, likewise shared across every (t1, t2) path
    t3_candidates = medusa_logits[1][:, -1, :].topk(k=2).indices  # 2 per branch

    # Total tree: 5 * 3 * 2 = 30 candidate sequences of length 3

    # 3. Verify all 30 sequences in ONE forward pass using tree attention
    # Construct tree attention mask that allows each node to attend to its ancestors
    tree_tokens = construct_tree(t1_candidates, t2_candidates, t3_candidates)
    verified_logits = model.base_model(tree_tokens, tree_attention_mask=...)

    # 4. Accept the longest matching prefix
    accepted = find_longest_accepted_path(verified_logits, tree_tokens)
    return accepted  # Typically 2-4 tokens accepted per step
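The tree attention mask referenced above lets each candidate token attend only to itself and its ancestors, so sibling branches stay causally independent inside one forward pass. A minimal pure-Python mask builder (the flattened `parents` encoding is my own convention for illustration; in practice this block is appended after the ordinary causal mask over the prefix):

```python
def tree_attention_mask(parents):
    """Build a tree attention mask from a flattened candidate tree.

    parents[i] is the index of node i's parent (-1 for tree roots).
    mask[i][j] is True iff node i may attend to node j, i.e. j is
    node i itself or one of its ancestors.
    """
    n = len(parents)
    mask = [[False] * n for _ in range(n)]
    for i in range(n):
        j = i
        while j != -1:        # walk up from node i to its root
            mask[i][j] = True
            j = parents[j]
    return mask

# Tiny tree: two t+1 candidates (nodes 0, 1), each with one t+2 child (2, 3)
mask = tree_attention_mask([-1, -1, 0, 1])
print(mask[2])  # node 2 sees itself and parent 0, never sibling branch 1/3
```

With this mask in place, the 30-sequence tree above verifies in a single batched pass instead of 30 separate ones.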
📊 Medusa Performance (Vicuna-33B on A100)

| Configuration | Tokens/Step | Speedup vs Greedy | Memory Overhead |
|---|---|---|---|
| Greedy (no speculation) | 1.0 | 1.0x | 0 GB |
| Medusa-1 (1 head) | 1.8 | 1.6x | 1.1 GB |
| Medusa-2 (2 heads, tree) | 2.3 | 2.0x | 2.1 GB |
| Medusa-4 (4 heads, tree) | 2.7 | 2.3x | 4.2 GB |
| Separate 7B draft | 2.5 | 2.2x | 14 GB |

Note: Medusa-4 matches or exceeds a separate draft model's performance at ~70% less memory overhead.

EAGLE: Hidden-State Drafting

EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) takes a different approach: instead of predicting future tokens from the last hidden state, it predicts the hidden state itself at future positions, then uses the base model’s LM head to convert those hidden states to tokens.

class EAGLEDrafter(nn.Module):
    """Predicts future hidden states, not tokens directly."""

    def __init__(self, hidden_dim):
        super().__init__()
        # Predict next hidden state from current hidden state + current token embedding
        self.fc = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, hidden_state, token_embedding):
        """Predict hidden state at next position."""
        combined = torch.cat([hidden_state, token_embedding], dim=-1)
        return self.fc(combined)  # Predicted hidden state

Why hidden states instead of tokens? The base model’s LM head is already optimized to convert hidden states to accurate token predictions. By predicting in hidden-state space and reusing the LM head, EAGLE leverages the base model’s existing prediction capability. The drafter only needs to predict how the hidden state evolves — a simpler task than predicting tokens directly.
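The drafting loop then alternates between the tiny drafter and the frozen LM head. A sketch of that loop, with `embed` and `lm_head` standing in for the base model's embedding table and output head (illustrative names, not a specific library API):

```python
import torch

def eagle_draft(drafter, lm_head, embed, hidden, token, steps=4):
    """Autoregressive drafting in hidden-state space.

    The drafter predicts the next position's hidden state from the
    current hidden state plus the current token's embedding; the frozen
    base-model LM head turns each predicted hidden state into a token.
    """
    draft = []
    for _ in range(steps):
        hidden = drafter(hidden, embed(token))   # extrapolate hidden state
        token = lm_head(hidden).argmax(dim=-1)   # reuse base LM head
        draft.append(token)
    return draft
```

Only `drafter` is new and trainable here; `embed` and `lm_head` are reused from the target model, which is the core of EAGLE's parameter efficiency.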

EAGLE v2 improvement: Instead of a fixed draft tree shape, EAGLE v2 dynamically adjusts the tree based on confidence. High-confidence branches get deeper exploration, low-confidence branches are pruned. This adaptive tree yields 20-50% more accepted tokens than fixed trees.

Speculative Decoding: Tokens Accepted per Step

| Method | Tokens/Step |
|---|---|
| Separate 7B draft | 2.5 |
| Medusa-4 | 2.7 |
| EAGLE v1 | 3.1 |
| EAGLE v2 (adaptive) — best self-draft | 3.6 |

Lookahead Decoding

Lookahead takes yet another approach: instead of a separate drafter, it uses the target model’s own n-gram patterns to predict future tokens.

The idea: during generation, the model often produces predictable sequences (common phrases, code patterns, formatting). Lookahead maintains a cache of the model’s own n-gram predictions and uses them as speculative drafts.

class LookaheadCache:
    """Cache n-gram predictions from the target model itself."""

    def __init__(self, window_size=7, ngram_size=3):
        self.window_size = window_size  # lookahead window (unused in this simplified sketch)
        self.ngram_size = ngram_size
        self.cache = {}  # (token_{n-2}, token_{n-1}) -> predicted token_n

    def update(self, tokens):
        """Update cache with observed n-gram patterns."""
        for i in range(len(tokens) - self.ngram_size + 1):
            key = tuple(tokens[i:i+self.ngram_size-1])
            value = tokens[i+self.ngram_size-1]
            self.cache[key] = value

    def predict(self, last_tokens, max_draft=5):
        """Predict future tokens using cached n-gram patterns."""
        draft = []
        current = tuple(last_tokens[-(self.ngram_size-1):])
        for _ in range(max_draft):
            if current in self.cache:
                next_token = self.cache[current]
                draft.append(next_token)
                current = current[1:] + (next_token,)
            else:
                break
        return draft

Advantage: Zero training required. No extra model or heads. Works with any model out of the box. Disadvantage: Lower acceptance rate than trained drafters. Works best for repetitive or formulaic text.
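On repetitive streams the cache pays off quickly. A standalone mirror of the update/predict logic above makes the effect concrete (reimplemented as plain functions so the snippet runs on its own):

```python
def build_ngram_cache(tokens, n=3):
    # Mirrors LookaheadCache.update: map (n-1)-gram prefix -> next token
    cache = {}
    for i in range(len(tokens) - n + 1):
        cache[tuple(tokens[i:i + n - 1])] = tokens[i + n - 1]
    return cache

def ngram_predict(cache, last_tokens, n=3, max_draft=4):
    # Mirrors LookaheadCache.predict: chain cached n-grams into a draft
    draft, current = [], tuple(last_tokens[-(n - 1):])
    for _ in range(max_draft):
        if current not in cache:
            break
        nxt = cache[current]
        draft.append(nxt)
        current = current[1:] + (nxt,)
    return draft

# A looping token stream: every bigram has been seen, so drafts run long
cache = build_ngram_cache([1, 2, 3, 1, 2, 3, 1, 2])
print(ngram_predict(cache, [3, 1, 2]))  # → [3, 1, 2, 3]
```

On genuinely novel text the chain breaks after zero or one cached hits, which is exactly why Lookahead's acceptance rate trails the trained drafters above.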

When Each Approach Wins

📊 Self-Speculative Method Selection Guide

| Method | Best For | Memory Overhead | Training Required | Typical Speedup |
|---|---|---|---|---|
| Separate draft model | Maximum acceptance rate | 14+ GB | No (use existing small model) | 2.0-2.5x |
| Medusa | Memory-constrained serving | 2-4 GB | Yes (~1 day) | 1.8-2.3x |
| EAGLE v2 | Maximum self-draft performance | 1-2 GB | Yes (~1 day) | 2.5-3.6x |
| Lookahead | Zero-setup, any model | 0 GB | No | 1.3-1.8x |
| Multi-token prediction | Models trained for it | 0 GB (built-in) | During pretraining | 1.8-2.5x |

💡 The Practical Choice in 2026

For new model deployments: use multi-token prediction (train the model with MTP heads from the start, like DeepSeek V3). For existing models: EAGLE v2 provides the best speedup-to-cost ratio. For quick experiments: Lookahead requires zero setup. Medusa is the sweet spot when you want self-drafting without EAGLE’s complexity.

Reviewer Agent Validation

Challenge: Using only this post, implement a simple Medusa-style generation step that: (1) gets base logits + one Medusa head’s logits, (2) picks top-1 from each, (3) verifies the Medusa token by checking if the target model agrees.

Expected:

import torch

def simple_medusa_step(model, medusa_head, input_ids, kv_cache):
    outputs = model(input_ids, kv_cache=kv_cache)
    hidden = outputs.last_hidden_state[:, -1:, :]

    # Base prediction: token t+1
    t1 = model.lm_head(hidden).argmax(dim=-1)  # [B, 1]

    # Medusa prediction: token t+2
    t2_draft = medusa_head(hidden).argmax(dim=-1)  # [B, 1]

    # Verify t2 by running target model on t1
    verify_out = model(t1, kv_cache=kv_cache)
    t2_target = model.lm_head(verify_out.last_hidden_state[:, -1:, :]).argmax(dim=-1)

    if (t2_draft == t2_target).all():
        return torch.cat([t1, t2_draft], dim=-1)  # Accept both
    else:
        return t1  # Accept only t1