Part of the series "vLLM v1 & Omni Internals" (13 of 25)
vLLM v1 Speculative Decoding: Draft Model Integration and Token Verification Pipeline

Speculative decoding can double throughput on latency-sensitive workloads, or it can make things 20% slower. The difference depends on acceptance rate: if your draft model is good enough that 70% of tokens are accepted, you generate 3-4 tokens per target forward pass and throughput soars. If the acceptance rate drops below 40%, the verification overhead dominates and you would have been better off running the target model directly. vLLM v1 implements speculative decoding with deep scheduler integration, making the draft-verify cycle automatic: the scheduler coordinates draft generation and target verification within a single batch, the KV cache manager handles speculated tokens that may get rejected, and the acceptance criterion guarantees output quality matches the target model exactly.

The Speculative Decoding Pipeline

Step-by-Step Execution

Each speculative decoding iteration has three phases:

Phase 1: Draft Generation (cheap)
  - Draft model generates K tokens autoregressively
  - Each token: one forward pass of the draft model
  - Total: K forward passes of the small model
  - Output: K draft tokens + K draft probability distributions

Phase 2: Target Verification (expensive, but batched)
  - Target model processes all K draft positions in ONE forward pass
  - Also computes logits for position K+1 (bonus token)
  - Total: one forward pass of the large model
  - Output: K+1 target probability distributions

Phase 3: Acceptance/Rejection
  - Compare draft and target distributions at each position
  - Accept tokens sequentially until the first rejection
  - Resample at the rejection point from corrected distribution
  - Output: n accepted tokens + 1 bonus token (n+1 total, 0 <= n <= K)
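
The acceptance rule in Phase 3 is standard speculative-sampling rejection: accept draft token x with probability min(1, p_target(x) / p_draft(x)); at the first rejection, resample from the normalized residual max(0, p_target - p_draft). A dependency-free sketch of the three-phase acceptance logic (the token IDs, distributions, and injectable RNG here are illustrative, not vLLM's kernel):

```python
import random

def accept_or_reject(draft_tokens, draft_probs, target_probs, rng=random.random):
    """Sequentially accept draft tokens; resample at the first rejection.

    draft_probs[k] / target_probs[k] are full distributions (lists summing
    to 1) at draft position k. Returns (accepted_tokens, resampled_token);
    resampled_token is None when all K drafts are accepted, in which case
    the caller samples the bonus token from the K+1-th target distribution.
    """
    accepted = []
    for k, tok in enumerate(draft_tokens):
        p_t = target_probs[k][tok]
        p_d = max(draft_probs[k][tok], 1e-9)  # guard against zero draft prob
        if rng() < min(1.0, p_t / p_d):       # accept with prob min(1, p_t/p_d)
            accepted.append(tok)
            continue
        # Rejected: sample from the corrected distribution max(0, p_t - p_d)
        residual = [max(0.0, t - d) for t, d in zip(target_probs[k], draft_probs[k])]
        norm = sum(residual)
        r, acc = rng() * norm, 0.0
        for tok_id, mass in enumerate(residual):
            acc += mass
            if r <= acc:
                return accepted, tok_id
        return accepted, len(residual) - 1
    return accepted, None  # all K accepted; caller appends the bonus token
```

This criterion is what makes the output distribution provably identical to sampling from the target model alone.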

The expected speedup:

speedup = (n_bar + 1) * t_target / (K * t_draft + t_target)

where n_bar is the average number of accepted tokens, t_target is the latency of one target forward pass, and t_draft is the latency of one draft forward pass.
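
To make the formula concrete, here is a direct evaluation with illustrative numbers (the 28 ms target-pass latency is an assumption for the sake of the example, not a measured value):

```python
def expected_speedup(n_bar: float, t_draft_ms: float,
                     t_target_ms: float, K: int) -> float:
    """speedup = (n_bar + 1) * t_target / (K * t_draft + t_target)"""
    return (n_bar + 1) * t_target_ms / (K * t_draft_ms + t_target_ms)

# Hypothetical 70B setup: 3.2 ms draft pass, 28 ms target pass, K = 5.
# With n_bar = 3.4 accepted tokens (4.4 tokens/step incl. the bonus):
print(round(expected_speedup(3.4, 3.2, 28.0, 5), 2))  # -> 2.8

# If acceptance collapses (n_bar = 0.5), drafting overhead wins:
print(round(expected_speedup(0.5, 3.2, 28.0, 5), 2))  # -> 0.95
```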

πŸ“Š

Speculative Decoding Speedup (Llama 70B target, K=5)

Draft ModelDraft LatencyAccept RateAvg Tokens/StepSpeedup
Llama 8B 3.2ms 78% 4.4 2.8x
Llama 8B (distilled) 3.2ms 85% 4.8 3.0x
Llama 1B 0.8ms 62% 3.6 2.9x
MLP head (2-layer) 0.2ms 55% 3.2 2.7x
Medusa (2 heads) 0.3ms 70% 4.0 3.3x
Eagle (autoregressive head) 0.5ms 82% 4.6 3.5x
Note: Measurements on 8xH100 with TP=8. Draft latency is per-token. Speedup is wall-clock improvement in tokens/sec for a single request.

Draft Model Management

Loading the Draft Model

vLLM loads the draft model alongside the target model. The draft model can be:

  1. A separate smaller model (e.g., Llama 8B drafting for Llama 70B)
  2. A quantized version of the target (e.g., W4A16 Llama 70B drafting for FP16 Llama 70B)
  3. A lightweight head attached to the target (e.g., Medusa, Eagle)
class DraftModelManager:
    """Manages the draft model for speculative decoding."""

    def __init__(
        self,
        draft_model_name: str,
        target_model: "Model",
        draft_tp_size: int = 1,
        draft_device: str = "cuda",
    ):
        self.target_model = target_model
        self.draft_tp_size = draft_tp_size

        if draft_model_name == "__medusa__":
            # Medusa: lightweight heads on top of target model
            self.draft_type = "medusa"
            self.draft_model = MedusaHeads(
                target_model.config,
                num_heads=5,  # K = number of Medusa heads
            ).to(draft_device)
        elif draft_model_name == "__eagle__":
            # Eagle: autoregressive draft head
            self.draft_type = "eagle"
            self.draft_model = EagleDraftHead(
                target_model.config,
            ).to(draft_device)
        else:
            # Separate model
            self.draft_type = "model"
            self.draft_model = load_model(
                draft_model_name,
                tp_size=draft_tp_size,
                device=draft_device,
            )

        self.tokenizer = target_model.tokenizer  # Must share tokenizer

    def generate_draft_tokens(
        self,
        input_ids: "torch.Tensor",
        positions: "torch.Tensor",
        kv_cache: "KVCache",
        num_tokens: int,
    ):
        """Generate K draft tokens autoregressively."""
        if self.draft_type == "medusa":
            return self._generate_medusa(input_ids, positions, kv_cache, num_tokens)
        elif self.draft_type == "eagle":
            return self._generate_eagle(input_ids, positions, kv_cache, num_tokens)
        else:
            return self._generate_autoregressive(input_ids, positions, kv_cache, num_tokens)

Autoregressive Draft Generation

For separate draft models, generation is standard autoregressive decoding, just much faster because the draft model is small:

def _generate_autoregressive(self, input_ids, positions, kv_cache, K):
    """Generate K tokens autoregressively with the draft model."""
    draft_tokens = []
    draft_probs = []

    current_token = input_ids[:, -1:]  # Last token
    current_pos = positions[:, -1:] + 1

    for k in range(K):
        # Forward pass through draft model
        logits = self.draft_model(
            current_token,
            current_pos,
            kv_cache=kv_cache,
        )

        # Sample from draft distribution
        probs = torch.softmax(logits[:, -1, :], dim=-1)
        token = torch.multinomial(probs, num_samples=1)

        draft_tokens.append(token)
        draft_probs.append(probs)

        # Update for next iteration
        current_token = token
        current_pos = current_pos + 1

    return (
        torch.cat(draft_tokens, dim=1),   # [batch_size, K]
        torch.stack(draft_probs, dim=1),   # [batch_size, K, vocab_size]
    )

Medusa: Parallel Draft Generation

Medusa uses multiple prediction heads that generate all K tokens in a single forward pass (non-autoregressive):

class MedusaHeads(torch.nn.Module):
    """Medusa-style parallel draft heads."""

    def __init__(self, config, num_heads=5):
        super().__init__()
        self.num_heads = num_heads
        # Each head is a small MLP that predicts the next token
        # Head i predicts token at position current + i + 1
        self.heads = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.SiLU(),
                torch.nn.Linear(config.hidden_size, config.vocab_size),
            )
            for _ in range(num_heads)
        ])

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: [batch_size, hidden_dim] - last hidden state from target model
        Returns:
            draft_logits: [batch_size, num_heads, vocab_size]
        """
        logits = []
        for head in self.heads:
            logits.append(head(hidden_states))
        return torch.stack(logits, dim=1)

    def generate(self, hidden_states):
        logits = self.forward(hidden_states)
        probs = torch.softmax(logits, dim=-1)
        tokens = torch.argmax(logits, dim=-1)  # Greedy for Medusa
        return tokens, probs

Medusa's advantage: all K draft tokens are generated in a single forward pass (vs. K passes for autoregressive drafting). The cost is lower quality: each head predicts independently, so later heads have less context.

Eagle: Autoregressive Draft Head

Eagle uses a small autoregressive model that takes the target model's hidden states as input:

class EagleDraftHead(torch.nn.Module):
    """Eagle-style autoregressive draft head."""

    def __init__(self, config):
        super().__init__()
        # Feature projection: combine target hidden state + token embedding
        self.fc = torch.nn.Linear(
            config.hidden_size * 2,
            config.hidden_size,
        )
        # Small transformer for autoregressive draft
        self.transformer = torch.nn.TransformerDecoder(
            torch.nn.TransformerDecoderLayer(
                d_model=config.hidden_size,
                nhead=8,
                dim_feedforward=config.hidden_size * 2,
                batch_first=True,
            ),
            num_layers=2,
        )
        self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, target_hidden, token_embeddings):
        """
        Args:
            target_hidden: [batch_size, seq_len, hidden_dim]
            token_embeddings: [batch_size, seq_len, hidden_dim]
        """
        combined = torch.cat([target_hidden, token_embeddings], dim=-1)
        features = self.fc(combined)
        hidden = self.transformer(features, features)
        return self.lm_head(hidden)

Scheduler Coordination

The Two-Phase Scheduler

The scheduler must orchestrate draft and verification steps for each sequence. In vLLM v1, this is handled by extending the unified scheduler with speculative awareness:

class SpeculativeScheduler:
    """Scheduler that coordinates draft and verification phases."""

    def __init__(self, base_scheduler, num_speculative_tokens: int):
        self.base_scheduler = base_scheduler
        self.K = num_speculative_tokens
        self.seq_states = {}  # seq_id -> "drafting" | "verifying" | "normal"

    def schedule(self):
        """
        Schedule one iteration of speculative decoding.

        The iteration either:
        1. Runs draft generation for sequences in "drafting" state
        2. Runs target verification for sequences in "verifying" state
        3. Both can happen in the same iteration for different sequences
        """
        scheduled = self.base_scheduler.schedule()

        draft_batch = []
        verify_batch = []
        normal_batch = []

        for seq in scheduled.running:
            state = self.seq_states.get(seq.seq_id, "drafting")

            if state == "drafting":
                draft_batch.append(seq)
            elif state == "verifying":
                verify_batch.append(seq)
            else:
                normal_batch.append(seq)

        return ScheduleOutput(
            draft_batch=draft_batch,
            verify_batch=verify_batch,
            normal_batch=normal_batch,
            prefill_batch=scheduled.prefill,
        )

    def on_verification_complete(self, seq_id, num_accepted, bonus_token):
        """Update state after verification."""
        # Accepted tokens + bonus = tokens produced this round
        # Reset to drafting state for next round
        self.seq_states[seq_id] = "drafting"

    def on_draft_complete(self, seq_id, draft_tokens):
        """Update state after draft generation."""
        self.seq_states[seq_id] = "verifying"

Interleaving with Regular Decode

Not all sequences use speculative decoding. Some may be in prefill, some may have speculative decoding disabled (e.g., short outputs where the overhead is not worth it). The scheduler handles mixed batches:

def execute_step(self, schedule_output):
    """Execute one scheduler step with mixed batches."""
    # 1. Run prefill for new sequences (no speculation)
    if schedule_output.prefill_batch:
        self.run_prefill(schedule_output.prefill_batch)

    # 2. Run draft generation (cheap, on draft model)
    if schedule_output.draft_batch:
        draft_results = self.run_draft(schedule_output.draft_batch)
        for seq, (tokens, probs) in zip(
            schedule_output.draft_batch, draft_results
        ):
            seq.draft_tokens = tokens
            seq.draft_probs = probs
            self.scheduler.on_draft_complete(seq.seq_id, tokens)

    # 3. Run target verification (expensive, on target model)
    # Verification batch includes draft tokens as additional input positions
    if schedule_output.verify_batch:
        verify_results = self.run_verification(schedule_output.verify_batch)
        for seq, (accepted, bonus) in zip(
            schedule_output.verify_batch, verify_results
        ):
            self.process_verification_result(seq, accepted, bonus)

    # 4. Run normal decode for non-speculative sequences
    if schedule_output.normal_batch:
        self.run_decode(schedule_output.normal_batch)

Target Model Verification

Batched Verification Forward Pass

The verification step runs the target model over all draft positions plus one additional position in a single forward pass. The input is constructed by appending all draft tokens to the sequence:

def prepare_verification_input(self, sequences):
    """Prepare input for target model verification."""
    all_input_ids = []
    all_positions = []
    all_slot_mappings = []
    seq_lens = []

    for seq in sequences:
        K = seq.draft_tokens.shape[0]

        # Feed the most recently sampled token followed by the K draft
        # tokens. Logits row k then conditions on all inputs up to k, so
        # rows 0..K-1 are the target distributions at the K draft slots
        # and row K is the bonus distribution: K+1 inputs, K+1 rows.
        # (seq.last_token_id is assumed to hold the last sampled token,
        # whose KV has not been written yet.)
        input_ids = torch.cat([
            seq.last_token_id.reshape(1), seq.draft_tokens
        ])  # [K + 1]

        # Positions: last sampled token sits at current_length,
        # drafts continue after it
        start_pos = seq.current_length
        positions = torch.arange(start_pos, start_pos + K + 1, device="cuda")

        # Slot mappings for KV cache writes (K + 1 new positions)
        slots = self.kv_manager.get_slots(seq.seq_id, K + 1)

        all_input_ids.append(input_ids)
        all_positions.append(positions)
        all_slot_mappings.append(slots)
        seq_lens.append(K + 1)

    # Batch the inputs.
    # Since K may vary per sequence (if adaptive K is used),
    # we pad to the longest sequence in the batch.
    max_len = max(seq_lens)
    batch_size = len(sequences)

    padded_ids = torch.zeros(batch_size, max_len, dtype=torch.long, device="cuda")
    padded_pos = torch.zeros(batch_size, max_len, dtype=torch.long, device="cuda")

    for i, (ids, pos, sl) in enumerate(zip(all_input_ids, all_positions, seq_lens)):
        padded_ids[i, :sl] = ids
        padded_pos[i, :sl] = pos

    return padded_ids, padded_pos, all_slot_mappings, seq_lens

def run_verification(self, sequences):
    """Run target model verification on draft tokens."""
    input_ids, positions, slot_mappings, seq_lens = \
        self.prepare_verification_input(sequences)

    # Single forward pass over all draft positions
    with torch.no_grad():
        target_logits = self.target_model(
            input_ids, positions,
            kv_cache=self.target_kv_cache,
            slot_mappings=slot_mappings,
        )

    # Run rejection sampling
    results = []
    for i, seq in enumerate(sequences):
        K = seq_lens[i] - 1                  # number of draft tokens
        t_logits = target_logits[i, :K + 1]  # K draft slots + bonus row
        d_probs = seq.draft_probs[:K]
        d_tokens = seq.draft_tokens[:K]

        accepted, num_accepted, bonus = self.rejection_sampler(
            t_logits.unsqueeze(0),
            d_probs.unsqueeze(0),
            d_tokens.unsqueeze(0),
            temperature=seq.sampling_params.temperature,
        )
        results.append((accepted.squeeze(0), num_accepted.item(), bonus.item()))

    return results
ℹ️ Verification Is Cheap

The verification forward pass processes K tokens per sequence, but this is structured as an extended prefill: the target model processes multiple tokens in one pass using the same compute as a short prefill. For K=5 with batch size 64, the verification pass costs roughly the same as a single decode step with batch size 320.

KV Cache Management for Speculated Tokens

Optimistic Allocation

When the draft model generates K tokens, the KV cache manager must allocate slots for all K positions. These allocations are "optimistic": they may be freed if the corresponding tokens are rejected.

class SpeculativeKVCacheManager:
    """KV cache management for speculative decoding."""

    def __init__(self, block_manager, block_size=16):
        self.block_manager = block_manager
        self.block_size = block_size
        self.speculative_allocations = {}  # seq_id -> list of (block_id, slot_idx)

    def allocate_for_draft(self, seq_id, num_tokens):
        """Allocate KV cache slots for draft tokens."""
        slots = []
        for _ in range(num_tokens):
            block_id, slot_idx = self.block_manager.allocate_slot(seq_id)
            slots.append((block_id, slot_idx))

        self.speculative_allocations[seq_id] = slots
        return [block_id * self.block_size + slot_idx for block_id, slot_idx in slots]

    def commit(self, seq_id, num_accepted):
        """Commit accepted slots, free rejected ones."""
        if seq_id not in self.speculative_allocations:
            return

        slots = self.speculative_allocations[seq_id]

        # Free rejected slots (from position num_accepted onward)
        for block_id, slot_idx in slots[num_accepted:]:
            self.block_manager.free_slot(seq_id, block_id, slot_idx)

        # Keep accepted slots (positions 0 to num_accepted-1)
        # Also allocate one more slot for the bonus token
        bonus_block_id, bonus_slot_idx = self.block_manager.allocate_slot(seq_id)

        del self.speculative_allocations[seq_id]

        return {
            'accepted_slots': num_accepted,
            'freed_slots': len(slots) - num_accepted,
            'bonus_slot': bonus_block_id * self.block_size + bonus_slot_idx,
        }

    def rollback(self, seq_id):
        """Free all speculative slots (complete rejection)."""
        if seq_id in self.speculative_allocations:
            for block_id, slot_idx in self.speculative_allocations[seq_id]:
                self.block_manager.free_slot(seq_id, block_id, slot_idx)
            del self.speculative_allocations[seq_id]

KV Cache Invalidation on Rejection

When the token at position i is rejected, the KV cache entries at positions i, i+1, ..., K-1 contain invalid data (computed from wrong token IDs). These entries must be either:

  1. Freed (the slots return to the free pool), or
  2. Overwritten in the next verification pass

vLLM v1 uses approach (1): free the slots immediately. The next iteration will re-allocate and rewrite them with correct data.

def process_verification_result(self, seq, accepted, num_accepted, bonus_token):
    """Process the result of token verification."""
    K = len(seq.draft_tokens)

    # Commit accepted tokens to the sequence
    for i in range(num_accepted):
        seq.append_token(accepted[i])

    # Append the bonus token
    seq.append_token(bonus_token)

    # Manage KV cache
    kv_result = self.kv_manager.commit(seq.seq_id, num_accepted)

    # Update sequence position
    seq.current_length += num_accepted + 1  # accepted + bonus

    # Log for monitoring
    self.metrics.record_speculation(
        seq_id=seq.seq_id,
        proposed=K,
        accepted=num_accepted,
        acceptance_rate=num_accepted / K,
    )

Draft Model KV Cache

The draft model also builds up KV cache during its K autoregressive steps. This KV cache has two options:

  1. Discard after each verification: Simple but wasteful. The draft model must re-encode the entire context for the next round.
  2. Keep and extend: The draft model's KV cache is kept and extended across iterations. Only the rejected positions are rolled back.
class DraftKVCacheStrategy:
    """Manage draft model KV cache across speculative iterations."""

    def __init__(self, strategy="keep_and_rollback"):
        self.strategy = strategy

    def after_verification(self, draft_kv_cache, seq_id, num_accepted, K):
        if self.strategy == "discard":
            # Free all draft KV cache -- simple but requires re-encoding
            draft_kv_cache.free_all(seq_id)
        elif self.strategy == "keep_and_rollback":
            # Keep positions 0..num_accepted, free num_accepted..K-1
            # The accepted positions have valid KV data
            draft_kv_cache.truncate(seq_id, num_accepted)
            # Next round starts drafting from position num_accepted + 1
            # (the bonus token will be the new starting point)

When Speculative Decoding Helps vs. Hurts

The Crossover Analysis

Speculative decoding helps when:

(n_bar + 1) / (K * t_d + t_v) > 1 / t_v

Rearranging: n_bar > K * t_d / t_v

This means the average number of accepted tokens must exceed the ratio of total draft time to verification time.

def should_use_speculation(
    draft_latency_ms: float,
    target_latency_ms: float,
    acceptance_rate: float,
    K: int,
) -> dict:
    """Determine if speculative decoding improves throughput."""
    # Expected number of accepted tokens: with per-token acceptance
    # rate p, P(n >= k) = p^k, so E[n] = sum_{k=1..K} p^k
    avg_accepted = sum(
        acceptance_rate ** k for k in range(1, K + 1)
    )

    # Time per step
    spec_time = K * draft_latency_ms + target_latency_ms
    normal_time = target_latency_ms

    # Tokens per step
    spec_tokens = avg_accepted + 1  # accepted + bonus
    normal_tokens = 1

    # Throughput comparison
    spec_throughput = spec_tokens / spec_time
    normal_throughput = normal_tokens / normal_time
    speedup = spec_throughput / normal_throughput

    return {
        'avg_accepted': avg_accepted,
        'speedup': speedup,
        'beneficial': speedup > 1.0,
        'spec_tokens_per_ms': spec_throughput,
        'normal_tokens_per_ms': normal_throughput,
    }
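
As a quick sanity check of the crossover condition, the closed-form E[n] = sum of p^k can be evaluated standalone (the latencies and acceptance rates below are illustrative, not measurements):

```python
def speculative_speedup(p: float, K: int,
                        t_draft_ms: float, t_target_ms: float) -> float:
    """Speedup under the crossover model: E[n] = sum_{k=1..K} p^k,
    tokens per step = E[n] + 1, time per step = K*t_draft + t_target."""
    avg_accepted = sum(p ** k for k in range(1, K + 1))
    return (avg_accepted + 1) * t_target_ms / (K * t_draft_ms + t_target_ms)

# A strong, cheap draft clears the crossover easily ...
good = speculative_speedup(p=0.8, K=5, t_draft_ms=0.5, t_target_ms=28.0)
# ... a weak, slow draft falls below 1.0x
bad = speculative_speedup(p=0.3, K=5, t_draft_ms=3.2, t_target_ms=28.0)
print(round(good, 2), round(bad, 2))  # -> 3.39 0.91
```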

Batch Size Interaction

Speculative decoding's benefit depends on batch size. At large batch sizes, the GPU is already well-utilized, and the draft model's compute cost is proportionally larger:

Speculative Decoding Speedup vs. Batch Size (Llama 70B, K=5)

Batch Size   Speedup
1            2.8x    (best case)
4            2.5x
16           2.0x
32           1.5x
64           1.1x    (marginal)
128          0.85x   (hurts)

At batch size 128, speculative decoding actually hurts throughput because:

  1. The draft model competes for GPU compute with the target model
  2. Verification requires K times more positions per sequence, inflating the effective batch size
  3. The overhead of rejection sampling and KV cache management exceeds the benefit
⚠️ Batch Size Threshold

Speculative decoding is most beneficial at small batch sizes (1-16) and becomes marginal or harmful above batch sizes of 64. For high-throughput serving with large batches, standard continuous batching without speculation is usually faster.
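
One way to act on this threshold is to gate speculation on the current batch size. The cutoff value and the K = 0 convention below are illustrative choices for this sketch, not vLLM defaults:

```python
def speculation_enabled(batch_size: int, cutoff: int = 32) -> bool:
    """Disable speculation once the batch is large enough that the GPU
    is compute-bound and draft passes steal target-model throughput."""
    return batch_size <= cutoff

class GatedSpeculator:
    """Wraps a fixed per-sequence K with a batch-level on/off switch."""

    def __init__(self, default_K: int = 5, cutoff: int = 32):
        self.default_K = default_K
        self.cutoff = cutoff

    def get_K(self, batch_size: int) -> int:
        # K = 0 means "run a normal decode step, no drafting"
        if speculation_enabled(batch_size, self.cutoff):
            return self.default_K
        return 0
```

In practice the cutoff should come from benchmarking your own draft/target pair, since the crossover point moves with draft latency and acceptance rate.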

Advanced: Multi-Sequence Speculation

Sequence-Level Adaptive K

Different sequences benefit from different numbers of draft tokens. A sequence generating creative text (low confidence, high entropy) will have lower acceptance rates than a sequence generating structured output (high confidence, low entropy).

class AdaptiveSpeculator:
    """Adapt K per sequence based on running acceptance rate."""

    def __init__(self, min_K=1, max_K=8, target_efficiency=0.7):
        self.min_K = min_K
        self.max_K = max_K
        self.target_efficiency = target_efficiency
        self.seq_stats = {}  # seq_id -> running stats

    def get_K(self, seq_id):
        """Get the optimal K for this sequence."""
        if seq_id not in self.seq_stats:
            return 5  # Default starting K

        stats = self.seq_stats[seq_id]
        rate = stats['accepted'] / max(stats['proposed'], 1)

        # Binary search for K that achieves target efficiency
        # Efficiency = (avg_accepted + 1) / (K * t_draft / t_target + 1)
        # Simplified: if rate is high, increase K; if low, decrease K
        current_K = stats['current_K']

        if rate > 0.85 and current_K < self.max_K:
            return current_K + 1
        elif rate < 0.55 and current_K > self.min_K:
            return current_K - 1
        else:
            return current_K

    def update(self, seq_id, K, num_accepted):
        if seq_id not in self.seq_stats:
            self.seq_stats[seq_id] = {
                'accepted': 0, 'proposed': 0, 'current_K': 5
            }
        stats = self.seq_stats[seq_id]
        stats['accepted'] += num_accepted
        stats['proposed'] += K
        stats['current_K'] = self.get_K(seq_id)

Tree-Based Speculation (Medusa/Eagle Extension)

Instead of generating K tokens in a linear chain, tree-based speculation generates a tree of candidate continuations:

Position 0: token A
  Position 1a: token B (given A)
    Position 2a: token D (given A,B)
    Position 2b: token E (given A,B)
  Position 1b: token C (given A)
    Position 2c: token F (given A,C)
    Position 2d: token G (given A,C)

The tree has more candidates than a linear chain, increasing the chance that at least one path matches the target distribution. The verification step uses a tree attention mask:

def build_tree_attention_mask(tree_structure):
    """Build attention mask for tree-based verification."""
    num_nodes = len(tree_structure)
    mask = torch.zeros(num_nodes, num_nodes, dtype=torch.bool)

    for node_id, node in enumerate(tree_structure):
        # Each node can attend to itself and all ancestors
        current = node_id
        while current >= 0:
            mask[node_id, current] = True
            current = tree_structure[current].parent_id

    return mask

# Tree verification: target model processes all tree nodes at once
# with a tree attention mask that enforces the correct causal structure
def verify_tree(target_model, tree_tokens, tree_structure, kv_cache):
    tree_mask = build_tree_attention_mask(tree_structure)
    target_logits = target_model(
        tree_tokens,
        attention_mask=tree_mask,
        kv_cache=kv_cache,
    )

    # Find the longest accepted path in the tree
    # using the rejection sampling criterion at each node
    accepted_path = find_longest_accepted_path(
        target_logits, tree_tokens, tree_structure
    )
    return accepted_path
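
To make the ancestor-mask logic concrete, here is a dependency-free version of the same traversal over the seven-node tree from the diagram (the root's parent is -1; node numbering follows breadth-first order and is an illustrative choice):

```python
from dataclasses import dataclass

@dataclass
class TreeNode:
    parent_id: int  # -1 marks the root

def ancestor_mask(tree):
    """Boolean mask: row i is True at i and at every ancestor of i,
    mirroring build_tree_attention_mask without the torch dependency."""
    n = len(tree)
    mask = [[False] * n for _ in range(n)]
    for node_id in range(n):
        current = node_id
        while current >= 0:
            mask[node_id][current] = True
            current = tree[current].parent_id
    return mask

# The tree from the diagram: A; B, C under A; D, E under B; F, G under C
tree = [TreeNode(-1), TreeNode(0), TreeNode(0),
        TreeNode(1), TreeNode(1), TreeNode(2), TreeNode(2)]
mask = ancestor_mask(tree)
# Node F (index 5) attends to itself, C (2), and A (0) -- but not B (1)
print([i for i, v in enumerate(mask[5]) if v])  # -> [0, 2, 5]
```

Each row of this mask is exactly the causal context of one candidate path, which is what lets the target model score every branch of the tree in a single forward pass.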

Complete Integration Example

class SpeculativeDecodingEngine:
    """
    Complete speculative decoding integration.
    Coordinates draft model, target model, scheduler, and KV cache.
    """

    def __init__(
        self,
        target_model,
        draft_model_name,
        num_speculative_tokens=5,
        max_batch_size=256,
    ):
        self.target_model = target_model
        self.K = num_speculative_tokens

        # Initialize draft model
        self.draft_manager = DraftModelManager(
            draft_model_name, target_model
        )

        # KV cache for both models
        self.target_kv = KVCachePool(target_model.config)
        self.draft_kv = KVCachePool(
            self.draft_manager.draft_model.config
            if hasattr(self.draft_manager.draft_model, 'config')
            else target_model.config
        )

        # Speculative KV manager
        self.spec_kv_manager = SpeculativeKVCacheManager(self.target_kv)

        # Rejection sampler
        from vllm_v1_rejection_sampler_cfg import RejectionSampler
        self.sampler = RejectionSampler()

        # Adaptive K
        self.adaptive = AdaptiveSpeculator()

        # Metrics
        self.total_steps = 0
        self.total_tokens_generated = 0

    def step(self, active_sequences):
        """Execute one speculative decoding step."""
        self.total_steps += 1
        batch_results = {}

        for seq in active_sequences:
            K = self.adaptive.get_K(seq.seq_id)

            # Phase 1: Draft
            draft_slots = self.spec_kv_manager.allocate_for_draft(seq.seq_id, K)
            draft_tokens, draft_probs = self.draft_manager.generate_draft_tokens(
                seq.get_input(), seq.get_positions(),
                self.draft_kv, K,
            )

            # Phase 2: Verify
            target_logits = self._run_target_verification(
                seq, draft_tokens, draft_slots
            )

            # Phase 3: Accept/Reject
            result = self.sampler(
                target_logits.unsqueeze(0),
                draft_probs.unsqueeze(0),
                draft_tokens.unsqueeze(0),
                temperature=seq.sampling_params.temperature,
            )

            num_accepted = result.num_accepted[0].item()
            bonus = result.bonus_tokens[0].item()

            # Commit KV cache
            self.spec_kv_manager.commit(seq.seq_id, num_accepted)

            # Update sequence
            for i in range(num_accepted):
                seq.append_token(result.accepted_tokens[0, i].item())
            seq.append_token(bonus)

            # Update adaptive K
            self.adaptive.update(seq.seq_id, K, num_accepted)

            tokens_produced = num_accepted + 1
            self.total_tokens_generated += tokens_produced
            batch_results[seq.seq_id] = tokens_produced

        return batch_results

    def _run_target_verification(self, seq, draft_tokens, draft_slots):
        """Run target model verification pass."""
        # Prepend the most recently sampled token (its KV has not been
        # written yet): logits row k then conditions on inputs up to k,
        # so rows 0..K-1 cover the K draft slots and row K is the bonus.
        # (seq.last_token_id is assumed to hold the last sampled token.)
        K = draft_tokens.shape[0]
        input_ids = torch.cat([seq.last_token_id.reshape(1), draft_tokens])
        positions = torch.arange(
            seq.current_length,
            seq.current_length + K + 1,  # last token + K drafts
            device="cuda",
        )

        # Forward pass
        with torch.no_grad():
            logits = self.target_model(
                input_ids,
                positions,
                kv_cache=self.target_kv,
            )
        return logits

    @property
    def avg_tokens_per_step(self):
        if self.total_steps == 0:
            return 0
        return self.total_tokens_generated / self.total_steps

Performance Monitoring

class SpeculativeMetricsCollector:
    """Collect and report speculative decoding metrics."""

    def __init__(self):
        self.acceptance_rates = []
        self.tokens_per_step = []
        self.draft_latencies = []
        self.verify_latencies = []

    def report(self):
        if not self.acceptance_rates:
            return {}

        import numpy as np
        return {
            'acceptance_rate_mean': np.mean(self.acceptance_rates),
            'acceptance_rate_p50': np.percentile(self.acceptance_rates, 50),
            'tokens_per_step_mean': np.mean(self.tokens_per_step),
            'draft_latency_ms_mean': np.mean(self.draft_latencies),
            'verify_latency_ms_mean': np.mean(self.verify_latencies),
            'overhead_ratio': (
                np.mean(self.draft_latencies) /
                np.mean(self.verify_latencies)
            ),
            'effective_speedup': (
                np.mean(self.tokens_per_step) /
                (1 + np.mean(self.draft_latencies) / np.mean(self.verify_latencies))
            ),
        }