You cannot ship a custom sampling strategy by forking vLLM — every upstream merge becomes a merge conflict nightmare. You cannot add support for a new model architecture by modifying core files — your changes break with every release. vLLM v1 solves this by exposing formal extension points: custom samplers that run after logits are computed, custom schedulers that change batching policy, custom attention backends that replace the kernel implementation, and model plugins that add new architectures. These extension points come with strict contracts — latency budgets, API signatures, and performance constraints — that ensure plugins integrate cleanly without degrading the serving engine.
Extension Point Architecture
vLLM v1 provides four primary extension points.
class VLLMExtensionPoints:
    """
    Catalog of vLLM v1 extension points and their interfaces.

    Each entry records the plugin interface name, where in the engine it
    is called, its per-call latency budget (milliseconds), its I/O
    contract, and how a plugin registers itself.
    """

    EXTENSION_POINTS = {
        "sampler": {
            "interface": "SamplerBase",
            "called_at": "Every decode step, after logits are computed",
            "latency_budget_ms": 0.5,
            "input": "Logits tensor [batch_size, vocab_size]",
            "output": "Selected token IDs [batch_size]",
            "registration": "Via sampling_params or engine config",
        },
        "scheduler": {
            "interface": "SchedulerBase",
            "called_at": "Every scheduler iteration (every ~10-50ms)",
            "latency_budget_ms": 1.0,
            "input": "Waiting queue + running queue + KV cache state",
            "output": "Batch of requests to process this iteration",
            "registration": "Via engine config",
        },
        "attention_backend": {
            "interface": "AttentionBackend",
            "called_at": "Every attention layer in every forward pass",
            # FIX: was 0.1, inconsistent with the 0.25 ms per-layer budget
            # derived elsewhere in this document (20 ms step / 80 layers).
            "latency_budget_ms": 0.25,  # Per layer
            "input": "Q, K, V tensors + attention mask",
            "output": "Attention output tensor",
            "registration": "Via VLLM_ATTENTION_BACKEND env var",
        },
        "model": {
            "interface": "ModelForCausalLM",
            "called_at": "Model loading + every forward pass",
            # FIX: this key was missing only for the model entry; None means
            # "no per-call budget" (a model plugin replaces the forward pass).
            "latency_budget_ms": None,
            "input": "Input token IDs + KV cache",
            "output": "Logits",
            "registration": "Via model config auto-detection",
        },
    }
Extension Point Latency Budgets
| Extension Point | Called Per | Latency Budget | Performance Impact If Exceeded |
|---|---|---|---|
| Custom Sampler | Decode step | 0.5 ms | Increases TPOT by same amount |
| Custom Scheduler | Iteration | 1.0 ms | Reduces max throughput |
| Attention Backend | Layer * Step | 0.25 ms | Catastrophic (80 layers) |
| Model Plugin | Forward pass | N/A | Replaces entire model |
Custom Sampler Implementation
Custom samplers modify token selection after logits are computed.
from typing import Optional
class SamplerBase:
    """
    Abstract base class for custom vLLM samplers.

    A sampler receives a batch of logits plus per-request sampling
    parameters and returns one selected token ID per request.
    """

    def sample(
        self,
        logits: list,  # [batch_size, vocab_size] float32
        sampling_params: list,  # Per-request params
    ) -> list:
        """Select one token per row of *logits*.

        Subclasses must implement this. It is invoked on every decode
        step for every sequence, so implementations must be fast.
        """
        raise NotImplementedError
class ConstrainedJSONSampler(SamplerBase):
    """
    Custom sampler that constrains output to valid JSON.

    A state machine (state supplied by the caller via ``sampling_params``)
    masks structurally invalid tokens at each step. Token classification
    is precomputed once in ``__init__`` so per-step work stays within the
    sampler latency budget.
    """

    def __init__(self, tokenizer):
        """
        Args:
            tokenizer: object exposing ``vocab_size`` and ``decode(ids)``.
        """
        self.tokenizer = tokenizer
        # Classify every vocabulary token exactly once, up front.
        self.digit_tokens = set()
        self.string_tokens = set()
        self.bracket_tokens = set()
        self.colon_tokens = set()
        self.comma_tokens = set()
        for token_id in range(tokenizer.vocab_size):
            text = tokenizer.decode([token_id])
            stripped = text.strip()
            # BUG FIX: guard against the empty string -- '' is "in" every
            # string, so whitespace-only tokens were previously classified
            # as both digits and brackets.
            if stripped and stripped in '0123456789':
                self.digit_tokens.add(token_id)
            if '"' in text:
                self.string_tokens.add(token_id)
            if stripped and stripped in '{}[]':
                self.bracket_tokens.add(token_id)
            # BUG FIX: the original never populated colon_tokens or
            # comma_tokens, so the "colon" and "comma_or_close" states
            # masked every token in the vocabulary.
            if stripped == ':':
                self.colon_tokens.add(token_id)
            if stripped == ',':
                self.comma_tokens.add(token_id)

    def sample(self, logits: list, sampling_params: list) -> list:
        """
        Apply JSON structural constraints to logits, then sample.

        Args:
            logits: [batch_size, vocab_size] rows of float logits.
            sampling_params: per-request dicts; recognized keys are
                "json_state", "temperature", and "top_p".

        Returns:
            One selected token ID per request.
        """
        results = []
        for logit_row, params in zip(logits, sampling_params):
            state = params.get("json_state", "value")
            allowed_tokens = self._get_allowed_tokens(state)
            # Mask disallowed tokens to -inf so they receive zero probability.
            # NOTE(perf): a Python loop over the vocabulary blows the 0.5 ms
            # budget for real vocab sizes; production code should apply a
            # precomputed GPU tensor mask instead.
            masked_logits = [
                logit if token_id in allowed_tokens else float('-inf')
                for token_id, logit in enumerate(logit_row)
            ]
            # Standard temperature + top-p sampling on masked logits.
            results.append(self._sample_from_logits(
                masked_logits,
                temperature=params.get("temperature", 1.0),
                top_p=params.get("top_p", 1.0),
            ))
        return results

    def _get_allowed_tokens(self, state: str) -> set:
        """Return the token IDs permitted in the given JSON parser state."""
        if state == "value":
            return self.digit_tokens | self.string_tokens | self.bracket_tokens
        elif state == "key":
            return self.string_tokens
        elif state == "colon":
            return self.colon_tokens
        elif state == "comma_or_close":
            return self.comma_tokens | self.bracket_tokens
        return set(range(self.tokenizer.vocab_size))  # Unknown state: allow all

    def _sample_from_logits(self, logits: list, temperature: float,
                            top_p: float) -> int:
        """Sample one token ID using temperature + top-p (nucleus) sampling.

        Raises:
            ValueError: if every logit is -inf (all tokens masked). The
                original silently produced NaN probabilities in this case.
        """
        import math
        import random
        if max(logits) == float('-inf'):
            # Every token was masked: the constraint state admits nothing.
            raise ValueError("all tokens masked; cannot sample")
        # BUG FIX: temperature == 0 conventionally means greedy decoding;
        # the original fell through and sampled at temperature 1 instead.
        if temperature <= 0:
            return max(range(len(logits)), key=logits.__getitem__)
        logits = [l / temperature for l in logits]
        # Numerically stable softmax.
        max_logit = max(logits)
        exp_logits = [math.exp(l - max_logit) for l in logits]
        total = sum(exp_logits)
        probs = [e / total for e in exp_logits]
        # Top-p (nucleus) filtering: keep the smallest prefix of tokens,
        # by descending probability, whose cumulative mass reaches top_p.
        sorted_indices = sorted(range(len(probs)), key=lambda i: -probs[i])
        cumsum = 0
        cutoff_idx = len(sorted_indices)
        for idx, i in enumerate(sorted_indices):
            cumsum += probs[i]
            if cumsum >= top_p:
                cutoff_idx = idx + 1
                break
        allowed = sorted_indices[:cutoff_idx]
        allowed_probs = [probs[i] for i in allowed]
        total_allowed = sum(allowed_probs)
        allowed_probs = [p / total_allowed for p in allowed_probs]
        # Inverse-CDF sampling over the renormalized nucleus.
        r = random.random()
        cumsum = 0
        for i, p in zip(allowed, allowed_probs):
            cumsum += p
            if r <= cumsum:
                return i
        return allowed[-1]  # Float-rounding fallback
Constrained decoding (JSON mode, function calling, grammar-guided generation) is one of the most common use cases for custom samplers. The key performance requirement: the constraint check must run in under 0.5ms per step. For complex grammars, pre-compute the token masks for each state to avoid per-token regex evaluation.
Custom Scheduler Extension
Custom schedulers change request prioritization and batching.
class SchedulerBase:
    """
    Abstract base class for custom vLLM schedulers.
    """

    def schedule(
        self,
        waiting_queue: list,
        running_queue: list,
        kv_cache_state: dict
    ) -> dict:
        """Pick the work for this scheduler iteration.

        Returns a dict with three keys:
            prefill_requests -- new requests to start
            decode_requests  -- running requests to continue
            preempt_requests -- running requests to preempt
        """
        raise NotImplementedError
class PriorityFairnessScheduler(SchedulerBase):
    """
    Scheduler that blends request priority with wait-time fairness.

    Priority: higher-priority requests are scored higher and served first.
    Fairness: a request approaching max_wait_sec gets an emergency score
    boost so no request starves indefinitely.
    """

    def __init__(self, config: dict):
        self.max_wait_sec = config.get("max_wait_sec", 30)
        self.priority_weight = config.get("priority_weight", 0.7)
        self.fairness_weight = config.get("fairness_weight", 0.3)

    def schedule(self, waiting_queue: list, running_queue: list,
                 kv_cache_state: dict) -> dict:
        """
        Score every waiting request, then admit by descending score while
        free KV cache blocks remain.
        """
        import time
        now = time.time()

        scored = []
        for request in waiting_queue:
            waited = now - request["arrival_time"]
            # Priority 1 is highest, 5 lowest; map onto [0, 1].
            prio_component = (6 - request.get("priority", 3)) / 5
            fair_component = min(1.0, waited / self.max_wait_sec)
            score = (self.priority_weight * prio_component
                     + self.fairness_weight * fair_component)
            # Emergency boost once a request nears its wait ceiling.
            if waited > self.max_wait_sec * 0.8:
                score += 0.5
            scored.append({
                "request": request,
                "score": score,
                "wait_time": waited,
            })

        # Highest score first (stable for ties).
        ordered = sorted(scored, key=lambda entry: entry["score"], reverse=True)

        # Admit as many requests as the free KV cache blocks allow.
        remaining_blocks = kv_cache_state["free_blocks"]
        admitted = []
        for entry in ordered:
            candidate = entry["request"]
            needed = self._estimate_blocks(candidate)
            if needed <= remaining_blocks:
                admitted.append(candidate)
                remaining_blocks -= needed

        return {
            "prefill_requests": admitted,
            "decode_requests": running_queue,  # Continue all running
            "preempt_requests": [],
        }

    def _estimate_blocks(self, request: dict) -> int:
        """Ceil-divide the worst-case token count into 16-token KV blocks."""
        block_size = 16  # tokens per block
        total_tokens = request.get("input_length", 0) + request.get("max_tokens", 1024)
        return -(-total_tokens // block_size)
Custom Scheduler: Priority-Fairness vs FCFS
| Metric | FCFS (default) | Priority-Fairness | Improvement |
|---|---|---|---|
| P1 avg TTFT | 500 ms | 200 ms | -60% |
| P5 avg TTFT | 300 ms | 800 ms | +167% (trade-off) |
| Max wait time | Unbounded | 30 sec (guaranteed) | Bounded |
| Overall throughput | 3,400 tok/s | 3,300 tok/s | -3% |
| P1 tail latency (P99) | 2,000 ms | 500 ms | -75% |
Custom Attention Backend
Attention backends are the lowest-level extension point. They replace the attention kernel.
class AttentionBackend:
    """
    Abstract base class for attention backend implementations.

    Performance-critical: invoked once per layer per decode step. For an
    80-layer model at 20 ms/step, each call has roughly a 0.25 ms budget.
    """

    def forward(
        self,
        query: list,  # [num_tokens, num_heads, head_dim]
        key: list,  # [num_tokens, num_kv_heads, head_dim]
        value: list,  # [num_tokens, num_kv_heads, head_dim]
        kv_cache: dict,  # Paged KV cache reference
        attn_metadata: dict,  # Sequence lengths, block tables, etc.
    ) -> list:
        """Compute the attention output tensor.

        Implementations must handle:
          1. Prefill attention (causal, full context)
          2. Decode attention (single query token, full KV cache)
          3. Paged KV cache (blocks may be non-contiguous)
          4. GQA (num_heads != num_kv_heads)
        """
        raise NotImplementedError
class CustomSlidingWindowAttention(AttentionBackend):
    """
    Example: attention backend with a custom sliding window.

    Each query position attends only to the most recent W tokens,
    reducing memory and compute for long contexts.
    """

    def __init__(self, window_size: int = 4096):
        self.window_size = window_size

    def forward(self, query, key, value, kv_cache, attn_metadata):
        """
        Compute attention restricted to a sliding window.

        For a query at position q, only keys in [max(0, q - W), q] are
        attended, cutting attention cost from O(n^2) to O(n * W).

        In practice this would be a CUDA kernel; the key optimization is
        loading only the W in-window KV blocks from the paged cache
        rather than every block of the sequence.
        """
        seq_lens = attn_metadata["seq_lens"]
        block_tables = attn_metadata["block_tables"]
        results = []
        for idx in range(len(query)):
            length = seq_lens[idx]
            window_start = max(0, length - self.window_size)
            # Restrict the paged-cache lookup to blocks inside the window.
            window_blocks = self._get_blocks_in_range(
                kv_cache, block_tables[idx], window_start, length
            )
            # Compute attention (simplified):
            #   output = softmax(Q @ K^T / sqrt(d)) @ V
            # evaluated only over the window_start:length range.
            results.append(self._compute_windowed_attention(
                query[idx], window_blocks, attn_metadata
            ))
        return results

    def _get_blocks_in_range(self, kv_cache, block_table,
                             start, end) -> list:
        """Return the KV cache blocks covering token range [start, end)."""
        block_size = 16
        first = start // block_size
        last = (end + block_size - 1) // block_size
        return block_table[first:last]

    def _compute_windowed_attention(self, query, kv_blocks,
                                    metadata) -> list:
        """Compute attention over the windowed KV blocks."""
        pass  # CUDA kernel
Custom attention backends are the hardest extension point to get right. The performance requirements are strict: each call must complete in under 0.25ms for an 80-layer model at 20ms/step. A custom attention backend that is 2x slower than FlashAttention will double the forward pass time. Always benchmark against the built-in backends before deploying.
Model Plugin: Adding New Architectures
class ModelPluginInterface:
    """
    Interface for adding a new model architecture to vLLM.

    Adding a model involves three steps:
      1. Implement the model class
      2. Register it in vLLM's model registry
      3. Handle weight loading from the checkpoint format
    """

    def __init__(self, config):
        self.config = config
        self.layers = []

    def load_weights(self, weight_path: str):
        """Load model weights from a checkpoint.

        Implementations must handle:
          - Weight name mapping (checkpoint format -> model attribute)
          - Tensor parallelism (split weights across GPUs)
          - Quantization (load quantized weights if applicable)
        """
        raise NotImplementedError

    def forward(
        self,
        input_ids: list,
        positions: list,
        kv_caches: list,
        attn_metadata: dict,
    ) -> list:
        """Run a forward pass and return logits.

        Implementations must support:
          - Mixed prefill and decode in the same batch
          - Paged KV cache via attn_metadata
          - Variable sequence lengths
        """
        raise NotImplementedError
class CustomMoEModel(ModelPluginInterface):
    """
    Example: custom MoE model with non-standard expert routing.
    """

    def __init__(self, config):
        super().__init__(config)
        self.num_experts = config.get("num_experts", 64)
        self.top_k = config.get("top_k", 4)
        self.hidden_dim = config.get("hidden_dim", 4096)

    def forward(self, input_ids, positions, kv_caches, attn_metadata):
        """
        Forward pass with custom MoE routing.

        Key integration point: attention must go through vLLM's paged
        attention for KV cache management, while the MoE FFN is free to
        implement custom logic.
        """
        hidden = self.embed(input_ids)
        for idx, block in enumerate(self.layers):
            # Attention sub-layer (vLLM's paged attention) with residual.
            attn_out = block.attention(
                hidden, positions, kv_caches[idx], attn_metadata
            )
            hidden = hidden + attn_out
            # Custom MoE FFN sub-layer with residual.
            router_logits = block.router(hidden)
            expert_ids, expert_weights = self.custom_routing(router_logits)
            moe_out = block.moe_ffn(hidden, expert_ids, expert_weights)
            hidden = hidden + moe_out
        return self.lm_head(hidden)

    def custom_routing(self, router_logits: list) -> tuple:
        """
        Custom routing logic (e.g., expert-choice routing).

        Expert-choice: experts pick their top-K tokens instead of tokens
        picking their top-K experts.
        """
        pass  # Placeholder in this example
Model Plugin Registration Requirements
| Requirement | Purpose | Complexity |
|---|---|---|
| Weight name mapping | Map checkpoint keys to model attributes | Medium |
| TP weight splitting | Shard weights across GPUs | High |
| KV cache integration | Use vLLM's paged attention | Medium |
| Mixed batch support | Handle prefill + decode together | High |
| Quantization support | Load quantized weights | Medium |
Plugin Performance Guidelines
def plugin_performance_guidelines() -> dict:
    """
    Performance requirements and guidelines for vLLM plugins.

    Returns a dict keyed by plugin kind; each entry gives its latency
    budget, call frequency, optimization tips, and (where relevant)
    anti-patterns or warnings.
    """
    sampler_entry = {
        "max_latency_ms": 0.5,
        "called_per": "decode_step * batch_size",
        "optimization_tips": [
            "Pre-compute token masks offline, not per-step",
            "Use GPU tensors (not Python lists) for masking",
            "Batch token classification (vectorized, not per-token)",
            "Cache grammar state transitions",
        ],
        "anti_patterns": [
            "Python loop over vocabulary (50K iterations = slow)",
            "Regex evaluation per token per step",
            "CPU-GPU synchronization in sampler",
        ],
    }
    scheduler_entry = {
        "max_latency_ms": 1.0,
        "called_per": "scheduler_iteration",
        "optimization_tips": [
            "Keep data structures sorted (avoid re-sorting each call)",
            "Use O(1) or O(log n) operations for priority lookup",
            "Cache block estimates (don't recompute every iteration)",
        ],
    }
    attention_entry = {
        "max_latency_ms": 0.25,
        "called_per": "layer * step",
        "optimization_tips": [
            "Must be a CUDA kernel (Python too slow)",
            "Must support paged KV cache block tables",
            "Must handle GQA (grouped query attention) natively",
            "Benchmark against FlashAttention v2 as baseline",
        ],
        "warning": "Custom attention backends that are slower than FlashAttention will dominate latency",
    }
    return {
        "sampler_plugins": sampler_entry,
        "scheduler_plugins": scheduler_entry,
        "attention_plugins": attention_entry,
    }
Plugin Latency Budget per Decode Step (80-Layer Model)
vLLM’s plugin architecture enables customization at every layer of the serving stack. Custom samplers are the most common extension (constrained decoding, function calling). Custom schedulers enable priority-based serving for multi-tenant deployments. Attention backends are the most performance-sensitive and should only be customized when the built-in backends cannot support your model’s attention pattern. Model plugins provide the most flexibility but require the most integration work. The key constraint across all plugins: respect the latency budget. A plugin that adds 1ms per decode step in an 80-layer model adds 80ms to total generation time for a 1000-token response.