Constrained decoding guarantees valid structured output from LLMs. The foundational approach — compile a JSON schema into a finite state machine, compute allowed token sets per state, mask logits at each decoding step — works. But “works” and “works at production throughput” are different problems. A naive FSM implementation adds 2-5ms per token in mask computation overhead. At 100 tokens/second decode rate on an A100, that 2ms overhead represents a 20% throughput regression. For a serving system handling 1000 concurrent requests, the cumulative overhead is the difference between meeting SLOs and dropping traffic.
This post covers the performance engineering that makes structured output viable at scale: compressed FSMs that minimize state transitions, speculative JSON decoding that exploits grammar structure for multi-token prediction, XGrammar’s byte-level pushdown automaton architecture, cross-request grammar caching that amortizes compilation cost, and a complete implementation of a precompiled FSM with cached token masks.
The Performance Problem with Naive Constrained Decoding
1.1 Where Time Goes
A naive constrained decoding step does the following work per generated token:
def naive_constrained_decode_step(fsm, current_state, logits, tokenizer):
"""
Per-token constrained decoding. This runs on every single token.
"""
vocab_size = logits.shape[-1]
mask = torch.full((vocab_size,), float('-inf'), device=logits.device)
# For each token in vocabulary, check if it leads to a valid state
for token_id in range(vocab_size):
token_bytes = tokenizer.decode([token_id]).encode('utf-8')
next_state = fsm.try_advance(current_state, token_bytes)
if next_state is not None:
mask[token_id] = 0.0
masked_logits = logits + mask
return masked_logits
The for token_id in range(vocab_size) loop is the killer. For a vocabulary of 128K tokens (Llama 3), this loop runs 128,000 FSM transitions per decoding step. Each transition involves decoding the token to bytes and walking the FSM. Even at 100ns per transition, that is 12.8ms per token — completely unacceptable.
Per-Token Overhead of Constrained Decoding Approaches
| Approach | Vocab Size | Per-Token Overhead | Throughput Impact |
|---|---|---|---|
| Naive (full scan) | 128K | 12.8ms | -56% |
| Naive (full scan) | 32K | 3.2ms | -24% |
| Precomputed masks | 128K | 0.1ms (lookup) | -0.8% |
| Precomputed masks | 32K | 0.05ms (lookup) | -0.4% |
| XGrammar (byte-level) | 128K | 0.08ms | -0.6% |
1.2 The Three Costs
Structured output overhead decomposes into three components:
-
Compilation cost: Converting a JSON schema into an FSM. This is where is the number of schema nodes and is the regex complexity of each constraint. For complex schemas, compilation takes 50-500ms.
-
Per-token mask computation: Determining which tokens are valid at the current FSM state. Naive: FSM transitions. Precomputed: lookup.
-
Mask application: Element-wise addition of the mask tensor to logits. This is but runs on GPU and takes under 0.01ms. Negligible.
The optimization strategy is to shift all possible work from per-token to per-schema (compilation) or per-state (precomputation), leaving only lookups in the hot path.
Compressed FSMs: Minimizing State Count
2.1 Why State Count Matters
A JSON schema like {"name": string, "age": integer} compiles naively into an FSM with hundreds of states. Each character in the fixed strings ({"name":, "age":, etc.) gets its own state. The string value matching adds states for escape sequences. The integer matching adds states for sign, digits, and bounds.
The problem: precomputed token masks are stored per state. If we precompute a bitmask of size for each state, memory usage is bytes. For 500 states and 128K vocabulary: MB. For a complex schema with 5000 states: 80 MB. When serving hundreds of different schemas concurrently, this memory adds up.
State minimization reduces both memory and compilation time.
2.2 DFA Minimization for JSON Schemas
The key insight: many FSM states have identical allowed token sets. In JSON generation, the states for "n", "a", "m", "e" in the key "name" all allow exactly one token each. But after minimization, sequential single-token-allowed states can be collapsed.
import hashlib
from collections import defaultdict
class CompressedFSM:
"""
FSM with state minimization via partition refinement.
States with identical transition functions are merged.
"""
def __init__(self):
self.transitions = {} # (state, byte) -> next_state
self.accepting = set()
self.start_state = 0
self.num_states = 0
def add_state(self, accepting=False):
state_id = self.num_states
self.num_states += 1
if accepting:
self.accepting.add(state_id)
return state_id
def add_transition(self, from_state, byte_val, to_state):
self.transitions[(from_state, byte_val)] = to_state
def minimize(self):
"""
Hopcroft's algorithm for DFA minimization.
Partition states into equivalence classes where states in the
same class have identical behavior for all inputs.
"""
# Initial partition: accepting vs non-accepting
accepting = frozenset(self.accepting)
non_accepting = frozenset(set(range(self.num_states)) - self.accepting)
partitions = set()
if accepting:
partitions.add(accepting)
if non_accepting:
partitions.add(non_accepting)
worklist = list(partitions)
alphabet = set(b for (_, b) in self.transitions.keys())
while worklist:
target_set = worklist.pop()
for byte_val in alphabet:
# States that transition into target_set on byte_val
sources = set()
for s in range(self.num_states):
if self.transitions.get((s, byte_val)) in target_set:
sources.add(s)
if not sources:
continue
new_partitions = set()
for partition in partitions:
intersection = partition & sources
difference = partition - sources
if intersection and difference:
new_partitions.add(frozenset(intersection))
new_partitions.add(frozenset(difference))
if partition in worklist:
worklist.remove(partition)
worklist.append(frozenset(intersection))
worklist.append(frozenset(difference))
else:
if len(intersection) <= len(difference):
worklist.append(frozenset(intersection))
else:
worklist.append(frozenset(difference))
else:
new_partitions.add(partition)
partitions = new_partitions
return self._build_minimized(partitions)
def _build_minimized(self, partitions):
"""Build a new FSM from the minimized partition."""
state_to_partition = {}
for i, partition in enumerate(partitions):
for state in partition:
state_to_partition[state] = i
minimized = CompressedFSM()
num_new_states = len(partitions)
for i in range(num_new_states):
minimized.add_state()
minimized.start_state = state_to_partition[self.start_state]
for (from_state, byte_val), to_state in self.transitions.items():
new_from = state_to_partition[from_state]
new_to = state_to_partition[to_state]
minimized.add_transition(new_from, byte_val, new_to)
for acc_state in self.accepting:
minimized.accepting.add(state_to_partition[acc_state])
return minimized
2.3 State Reduction Results
FSM State Count: Raw vs Minimized
(states)The 92% reduction for complex schemas means the mask cache goes from 37 MB to 3 MB. More importantly, precomputation time drops proportionally because we compute one mask per unique state.
Precomputed Token Masks
3.1 Per-State Mask Precomputation
The core optimization: for every FSM state, precompute the set of vocabulary tokens that lead to a valid next state. Store as a bitmask or an index tensor. At decode time, mask application is a single tensor lookup and addition.
import torch
import numpy as np
class PrecomputedMaskCache:
"""
Precomputes and caches token masks for every reachable FSM state.
Converts per-token O(|V|) FSM work into O(1) lookup.
"""
def __init__(self, fsm, tokenizer, device='cuda'):
self.fsm = fsm
self.tokenizer = tokenizer
self.device = device
self.masks = {} # state -> tensor of shape (vocab_size,)
self._token_bytes = self._precompute_token_bytes()
def _precompute_token_bytes(self):
"""
Decode every token ID to its byte representation once.
This avoids repeated tokenizer decode calls.
"""
token_bytes = {}
for token_id in range(self.tokenizer.vocab_size):
try:
decoded = self.tokenizer.decode([token_id])
token_bytes[token_id] = decoded.encode('utf-8')
except Exception:
token_bytes[token_id] = None
return token_bytes
def precompute_all_states(self):
"""
For every reachable state, compute the allowed token mask.
This is the expensive operation -- runs once per schema.
"""
reachable = self._find_reachable_states()
for state in reachable:
mask = torch.full(
(self.tokenizer.vocab_size,),
float('-inf'),
dtype=torch.float32
)
for token_id, token_bytes in self._token_bytes.items():
if token_bytes is None:
continue
if self._can_advance(state, token_bytes):
mask[token_id] = 0.0
self.masks[state] = mask.to(self.device)
return len(reachable)
def _can_advance(self, state, token_bytes):
"""Check if consuming these bytes leads to any valid state."""
current = state
for byte_val in token_bytes:
next_state = self.fsm.transitions.get((current, byte_val))
if next_state is None:
return False
current = next_state
return True
def _find_reachable_states(self):
"""BFS from start state to find all reachable states."""
visited = set()
queue = [self.fsm.start_state]
while queue:
state = queue.pop(0)
if state in visited:
continue
visited.add(state)
for (s, _), next_s in self.fsm.transitions.items():
if s == state and next_s not in visited:
queue.append(next_s)
return visited
def get_mask(self, state):
"""O(1) mask lookup at decode time."""
return self.masks[state]
def memory_usage_bytes(self):
"""Total GPU memory used by cached masks."""
per_mask = self.tokenizer.vocab_size * 4 # float32
return len(self.masks) * per_mask
3.2 Compilation Time vs Runtime Tradeoff
The precomputation is expensive: for each state, we scan the entire vocabulary. With 28 minimized states and 128K tokens, that is FSM advance checks. At 100ns each: ~360ms total compilation time.
But this cost is paid once per schema. If the schema is used for 10,000 requests (common for API endpoints), the amortized cost is 0.036ms per request — negligible.
Precomputation cost of 360ms amortized over requests: ms per request. At , this is 3.6ms — already cheaper than 1 second of naive per-token overhead at 50 tokens. At , it is 0.36ms. The breakeven point is approximately requests for a typical 200-token output.
Speculative JSON Decoding
4.1 Exploiting Grammar Predictability
JSON has a critical property: large portions of the output are deterministic given the schema. When generating {"name": "Alice", "age": 30}, the tokens for {"name": and ", "age": are completely determined by the schema. Only the values (Alice and 30) require model sampling.
Speculative JSON decoding exploits this by identifying deterministic spans and emitting them without model forward passes.
class SpeculativeJSONDecoder:
"""
Exploits JSON schema structure to skip model forward passes
for deterministic tokens (keys, punctuation, fixed strings).
"""
def __init__(self, fsm, mask_cache, tokenizer):
self.fsm = fsm
self.mask_cache = mask_cache
self.tokenizer = tokenizer
self._deterministic_cache = {}
self._precompute_deterministic_spans()
def _precompute_deterministic_spans(self):
"""
For each state, check if only one token is allowed.
If so, follow the chain until a non-deterministic state.
"""
for state in self.mask_cache.masks:
mask = self.mask_cache.get_mask(state)
allowed = (mask == 0.0).nonzero(as_tuple=True)[0]
if len(allowed) == 1:
# Exactly one token allowed -- deterministic
chain = self._follow_deterministic_chain(state)
if len(chain) > 1:
self._deterministic_cache[state] = chain
def _follow_deterministic_chain(self, start_state):
"""Follow chain of single-allowed-token states."""
chain = []
current_state = start_state
while current_state in self.mask_cache.masks:
mask = self.mask_cache.get_mask(current_state)
allowed = (mask == 0.0).nonzero(as_tuple=True)[0]
if len(allowed) != 1:
break
token_id = allowed[0].item()
chain.append(token_id)
# Advance FSM state
token_bytes = self.tokenizer.decode([token_id]).encode('utf-8')
next_state = current_state
for b in token_bytes:
next_state = self.fsm.transitions.get((next_state, b))
if next_state is None:
return chain
current_state = next_state
return chain
def decode_step(self, model, input_ids, current_state):
"""
Generate next token(s). If current state starts a deterministic
chain, emit all deterministic tokens without model calls.
"""
if current_state in self._deterministic_cache:
# Emit deterministic tokens without model forward pass
chain = self._deterministic_cache[current_state]
return chain, self._advance_state(current_state, chain)
# Non-deterministic: run model forward pass with mask
logits = model(input_ids).logits[:, -1, :]
mask = self.mask_cache.get_mask(current_state)
masked_logits = logits + mask
token_id = torch.multinomial(
torch.softmax(masked_logits, dim=-1), 1
).item()
new_state = self._advance_state_single(current_state, token_id)
return [token_id], new_state
def _advance_state(self, state, token_ids):
"""Advance FSM through a sequence of tokens."""
current = state
for tid in token_ids:
current = self._advance_state_single(current, tid)
return current
def _advance_state_single(self, state, token_id):
"""Advance FSM by one token."""
token_bytes = self.tokenizer.decode([token_id]).encode('utf-8')
current = state
for b in token_bytes:
current = self.fsm.transitions.get((current, b), current)
return current
4.2 Tokens Saved Analysis
The fraction of deterministic tokens depends on schema structure and value lengths.
Deterministic Token Fraction by Schema Type
(%)For enum-heavy schemas, speculative JSON decoding eliminates 71% of model forward passes. Each skipped forward pass saves 5-15ms on a 70B model, yielding 3-5x end-to-end speedup for structured output generation.
4.3 Interaction with Standard Speculative Decoding
Speculative JSON decoding is orthogonal to standard speculative decoding (draft model + verification). They can be composed: use the grammar to skip deterministic spans, and use a draft model for the non-deterministic value tokens. The combined speedup is multiplicative for the non-deterministic portion.
class CombinedSpeculativeDecoder:
"""
Combines grammar-based speculation (deterministic spans)
with draft-model speculation (non-deterministic values).
"""
def __init__(self, target_model, draft_model, fsm, mask_cache, tokenizer):
self.target = target_model
self.draft = draft_model
self.json_spec = SpeculativeJSONDecoder(fsm, mask_cache, tokenizer)
self.mask_cache = mask_cache
self.tokenizer = tokenizer
def generate(self, input_ids, max_tokens=512, num_speculative=4):
generated = []
current_state = self.json_spec.fsm.start_state
while len(generated) < max_tokens:
# Phase 1: emit deterministic tokens (no model call)
if current_state in self.json_spec._deterministic_cache:
chain = self.json_spec._deterministic_cache[current_state]
generated.extend(chain)
current_state = self.json_spec._advance_state(
current_state, chain
)
input_ids = self._append_tokens(input_ids, chain)
continue
# Phase 2: draft-model speculation for value tokens
draft_tokens = self._draft_with_constraints(
input_ids, current_state, num_speculative
)
# Phase 3: verify with target model
accepted = self._verify_and_accept(
input_ids, draft_tokens, current_state
)
generated.extend(accepted)
for tid in accepted:
current_state = self.json_spec._advance_state_single(
current_state, tid
)
input_ids = self._append_tokens(input_ids, accepted)
# Check for end state
if current_state in self.json_spec.fsm.accepting:
break
return generated
def _draft_with_constraints(self, input_ids, state, k):
"""Generate k draft tokens respecting grammar constraints."""
draft_tokens = []
current_state = state
draft_input = input_ids
for _ in range(k):
logits = self.draft(draft_input).logits[:, -1, :]
mask = self.mask_cache.get_mask(current_state)
masked_logits = logits + mask
token_id = torch.multinomial(
torch.softmax(masked_logits, dim=-1), 1
).item()
draft_tokens.append(token_id)
current_state = self.json_spec._advance_state_single(
current_state, token_id
)
draft_input = self._append_tokens(draft_input, [token_id])
return draft_tokens
def _verify_and_accept(self, input_ids, draft_tokens, state):
"""Verify draft tokens with target model, accept prefix."""
all_tokens = input_ids
for t in draft_tokens:
all_tokens = self._append_tokens(all_tokens, [t])
target_logits = self.target(all_tokens).logits
accepted = []
current_state = state
for i, draft_token in enumerate(draft_tokens):
logit_position = -(len(draft_tokens) - i)
step_logits = target_logits[:, logit_position, :]
mask = self.mask_cache.get_mask(current_state)
masked_logits = step_logits + mask
probs = torch.softmax(masked_logits, dim=-1)
if probs[0, draft_token].item() > 0.1:
accepted.append(draft_token)
current_state = self.json_spec._advance_state_single(
current_state, draft_token
)
else:
# Reject: sample from target distribution
token_id = torch.multinomial(probs, 1).item()
accepted.append(token_id)
break
return accepted
def _append_tokens(self, input_ids, tokens):
new_tokens = torch.tensor([tokens], device=input_ids.device)
return torch.cat([input_ids, new_tokens], dim=-1)
XGrammar: Byte-Level Pushdown Automaton
5.1 Beyond Regular Languages
FSMs handle regular languages: JSON structure, regex patterns, fixed enum values. But some constraints require context-free grammars: matched parentheses, nested brackets to arbitrary depth, recursive schemas. JSON itself is context-free due to nested objects and arrays.
XGrammar (used in SGLang) implements a byte-level pushdown automaton (PDA) that handles full context-free grammars while maintaining near-FSM performance for the common case.
class BytePushdownAutomaton:
"""
Pushdown automaton operating on byte-level transitions.
Handles context-free grammars (nested JSON, recursive schemas)
while maintaining O(1) amortized token masking via adaptive caching.
"""
def __init__(self, grammar_rules):
self.rules = grammar_rules # BNF-style production rules
self.state_stack = []
self.current_state = 0
self.transition_table = {}
self.push_actions = {} # state -> (new_state, push_symbol)
self.pop_actions = {} # (state, stack_top) -> new_state
self._compile_grammar()
def _compile_grammar(self):
"""
Compile grammar rules into PDA transitions.
Each rule becomes a set of byte-level transitions with
push/pop annotations for recursive structures.
"""
for rule_name, alternatives in self.rules.items():
for alt in alternatives:
self._compile_alternative(rule_name, alt)
def _compile_alternative(self, rule_name, tokens):
"""Compile one alternative of a grammar rule."""
state = self._get_or_create_state(rule_name, 0)
for i, token in enumerate(tokens):
if token.startswith('<') and token.endswith('>'):
# Non-terminal: push current context, recurse
next_state = self._get_or_create_state(rule_name, i + 1)
ref_rule = token[1:-1]
ref_start = self._get_or_create_state(ref_rule, 0)
self.push_actions[state] = (ref_start, (rule_name, i + 1))
state = next_state
else:
# Terminal: byte-level transitions
for byte_val in token.encode('utf-8'):
next_state = self._get_or_create_state(
rule_name, i + 1
)
self.transition_table[(state, byte_val)] = next_state
state = next_state
def _get_or_create_state(self, rule, position):
key = (rule, position)
if key not in self._state_map:
self._state_map[key] = self._next_state_id
self._next_state_id += 1
return self._state_map[key]
def advance(self, byte_val):
"""
Advance the automaton by one byte.
Handles push/pop for recursive structures.
"""
# Check for push (entering nested structure)
if self.current_state in self.push_actions:
new_state, push_symbol = self.push_actions[self.current_state]
self.state_stack.append(push_symbol)
self.current_state = new_state
# Normal transition
next_state = self.transition_table.get(
(self.current_state, byte_val)
)
if next_state is not None:
self.current_state = next_state
return True
# Check for pop (exiting nested structure)
if self.state_stack:
pop_symbol = self.state_stack[-1]
pop_state = self.pop_actions.get(
(self.current_state, pop_symbol)
)
if pop_state is not None:
self.state_stack.pop()
self.current_state = pop_state
return self.advance(byte_val)
return False
def _state_map(self):
return {}
_state_map = {}
_next_state_id = 0
5.2 Adaptive Masking: The XGrammar Optimization
XGrammar’s key insight: the token mask depends on the PDA state AND the stack. But in practice, JSON schemas have bounded nesting depth (typically 3-5 levels). XGrammar caches masks for (state, stack_hash) pairs, where stack_hash captures the top K stack elements.
class AdaptiveMaskCache:
"""
Cache token masks keyed by (pda_state, stack_signature).
The stack signature uses the top K elements for bounded-depth grammars.
"""
def __init__(self, pda, tokenizer, stack_depth=4):
self.pda = pda
self.tokenizer = tokenizer
self.stack_depth = stack_depth
self.cache = {}
self.hit_count = 0
self.miss_count = 0
def get_mask(self, state, stack):
"""
Get or compute token mask for current PDA configuration.
Cache key includes stack signature for context sensitivity.
"""
stack_sig = tuple(stack[-self.stack_depth:])
cache_key = (state, stack_sig)
if cache_key in self.cache:
self.hit_count += 1
return self.cache[cache_key]
self.miss_count += 1
mask = self._compute_mask(state, stack)
self.cache[cache_key] = mask
return mask
def _compute_mask(self, state, stack):
"""Compute allowed token mask by simulating PDA for each token."""
vocab_size = self.tokenizer.vocab_size
mask = torch.full((vocab_size,), float('-inf'))
for token_id in range(vocab_size):
token_bytes = self.tokenizer.decode([token_id]).encode('utf-8')
if self._simulate_token(state, list(stack), token_bytes):
mask[token_id] = 0.0
return mask
def _simulate_token(self, state, stack, token_bytes):
"""Simulate PDA consuming a multi-byte token."""
sim_state = state
sim_stack = list(stack)
for byte_val in token_bytes:
if sim_state in self.pda.push_actions:
new_state, push_sym = self.pda.push_actions[sim_state]
sim_stack.append(push_sym)
sim_state = new_state
next_state = self.pda.transition_table.get(
(sim_state, byte_val)
)
if next_state is None:
return False
sim_state = next_state
return True
def cache_stats(self):
total = self.hit_count + self.miss_count
hit_rate = self.hit_count / total if total > 0 else 0
return {
'entries': len(self.cache),
'hits': self.hit_count,
'misses': self.miss_count,
'hit_rate': hit_rate,
'memory_mb': len(self.cache) * self.tokenizer.vocab_size * 4 / 1e6
}
For typical JSON schemas with 2-3 levels of nesting, the adaptive cache achieves 95-99% hit rate after warming up on the first few requests. The cache stabilizes at 50-200 entries for most schemas, consuming 25-100 MB of GPU memory. For schemas with unbounded recursion depth, the hit rate drops to 60-80%, and a cache eviction policy (LRU) becomes necessary.
Cross-Request Grammar Caching
6.1 The Observation
In production, the same JSON schema is used across thousands of requests. An API endpoint that returns structured responses uses one schema. A function-calling system has a fixed set of tool schemas. Recompiling the FSM and recomputing token masks for every request wastes enormous computation.
6.2 Schema-Level Cache Architecture
import hashlib
import threading
import time
class GrammarCache:
"""
Cross-request cache for compiled grammars and precomputed masks.
Key insight: same schema = same FSM = same masks.
"""
def __init__(self, max_entries=1024, max_memory_mb=4096):
self.cache = {} # schema_hash -> CachedGrammar
self.access_order = []
self.max_entries = max_entries
self.max_memory_mb = max_memory_mb
self.lock = threading.Lock()
self.stats = {
'hits': 0, 'misses': 0,
'compilations': 0, 'evictions': 0
}
def get_or_compile(self, json_schema, tokenizer, device='cuda'):
"""
Return compiled FSM + mask cache for the given schema.
Compiles on first request, returns cached on subsequent.
"""
schema_hash = self._hash_schema(json_schema)
with self.lock:
if schema_hash in self.cache:
self.stats['hits'] += 1
entry = self.cache[schema_hash]
entry.last_access = time.time()
entry.access_count += 1
self._promote(schema_hash)
return entry.fsm, entry.mask_cache
self.stats['misses'] += 1
# Compile outside lock (expensive, don't block other requests)
fsm = self._compile_schema(json_schema)
minimized_fsm = fsm.minimize()
mask_cache = PrecomputedMaskCache(minimized_fsm, tokenizer, device)
num_states = mask_cache.precompute_all_states()
entry = CachedGrammar(
fsm=minimized_fsm,
mask_cache=mask_cache,
schema_hash=schema_hash,
num_states=num_states,
memory_mb=mask_cache.memory_usage_bytes() / 1e6
)
with self.lock:
# Check again in case another thread compiled the same schema
if schema_hash in self.cache:
return self.cache[schema_hash].fsm, \
self.cache[schema_hash].mask_cache
self._evict_if_needed(entry.memory_mb)
self.cache[schema_hash] = entry
self.access_order.append(schema_hash)
self.stats['compilations'] += 1
return minimized_fsm, mask_cache
def _hash_schema(self, schema):
"""Deterministic hash of JSON schema for cache key."""
import json
canonical = json.dumps(schema, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(canonical.encode()).hexdigest()
def _evict_if_needed(self, new_entry_mb):
"""Evict LRU entries if cache exceeds memory or entry limits."""
current_memory = sum(e.memory_mb for e in self.cache.values())
while (len(self.cache) >= self.max_entries or
current_memory + new_entry_mb > self.max_memory_mb):
if not self.access_order:
break
victim_hash = self.access_order.pop(0)
if victim_hash in self.cache:
current_memory -= self.cache[victim_hash].memory_mb
del self.cache[victim_hash]
self.stats['evictions'] += 1
def _promote(self, schema_hash):
"""Move schema to end of LRU order."""
if schema_hash in self.access_order:
self.access_order.remove(schema_hash)
self.access_order.append(schema_hash)
def _compile_schema(self, json_schema):
"""Compile JSON schema to FSM."""
fsm = CompressedFSM()
start = fsm.add_state()
fsm.start_state = start
self._compile_object(fsm, json_schema, start)
return fsm
def _compile_object(self, fsm, schema, start_state):
"""Recursively compile JSON schema into FSM transitions."""
schema_type = schema.get('type', 'object')
if schema_type == 'object':
current = start_state
# Opening brace
brace_state = fsm.add_state()
fsm.add_transition(current, ord('{'), brace_state)
current = brace_state
properties = schema.get('properties', {})
for i, (key, value_schema) in enumerate(properties.items()):
# Key string
current = self._compile_literal(
fsm, current, f'"{key}":'
)
# Value
current = self._compile_value(fsm, current, value_schema)
# Comma between fields
if i < len(properties) - 1:
comma_state = fsm.add_state()
fsm.add_transition(current, ord(','), comma_state)
current = comma_state
# Closing brace
end = fsm.add_state(accepting=True)
fsm.add_transition(current, ord('}'), end)
elif schema_type == 'string':
current = start_state
quote_open = fsm.add_state()
fsm.add_transition(current, ord('"'), quote_open)
# String content (simplified: printable ASCII)
for b in range(32, 127):
if b != ord('"') and b != ord('\\'):
fsm.add_transition(quote_open, b, quote_open)
quote_close = fsm.add_state(accepting=True)
fsm.add_transition(quote_open, ord('"'), quote_close)
def _compile_literal(self, fsm, start, literal):
"""Add transitions for a literal byte string."""
current = start
for byte_val in literal.encode('utf-8'):
next_state = fsm.add_state()
fsm.add_transition(current, byte_val, next_state)
current = next_state
return current
def _compile_value(self, fsm, start, schema):
"""Compile a value schema into FSM transitions."""
value_type = schema.get('type', 'string')
if value_type == 'string':
return self._compile_string_value(fsm, start)
elif value_type == 'integer':
return self._compile_integer_value(fsm, start)
elif value_type == 'boolean':
return self._compile_boolean_value(fsm, start)
return start
def _compile_string_value(self, fsm, start):
quote_open = fsm.add_state()
fsm.add_transition(start, ord('"'), quote_open)
for b in range(32, 127):
if b != ord('"') and b != ord('\\'):
fsm.add_transition(quote_open, b, quote_open)
quote_close = fsm.add_state()
fsm.add_transition(quote_open, ord('"'), quote_close)
return quote_close
def _compile_integer_value(self, fsm, start):
digit_state = fsm.add_state()
for d in range(ord('0'), ord('9') + 1):
fsm.add_transition(start, d, digit_state)
fsm.add_transition(digit_state, d, digit_state)
return digit_state
def _compile_boolean_value(self, fsm, start):
# "true" path
t = self._compile_literal(fsm, start, 'true')
# "false" path
f = self._compile_literal(fsm, start, 'false')
merge = fsm.add_state()
fsm.add_transition(t, 0, merge)
fsm.add_transition(f, 0, merge)
return merge
class CachedGrammar:
"""Metadata for a cached grammar entry."""
def __init__(self, fsm, mask_cache, schema_hash,
num_states, memory_mb):
self.fsm = fsm
self.mask_cache = mask_cache
self.schema_hash = schema_hash
self.num_states = num_states
self.memory_mb = memory_mb
self.created_at = time.time()
self.last_access = time.time()
self.access_count = 1
Grammar Cache Performance (Production API Serving)
| Metric | No Cache | With Cache | Improvement |
|---|---|---|---|
| First request latency | 380ms | 380ms | None (cold start) |
| Subsequent request latency | 380ms | 0.1ms | 3800x |
| p99 TTFT (100 schemas) | 520ms | 12ms | 43x |
| GPU memory for masks | N/A (recomputed) | 2.4 GB (100 schemas) | Bounded |
| Throughput (req/s) | 145 | 892 | 6.2x |
Putting It All Together: Production Pipeline
7.1 End-to-End Structured Output Engine
import torch
import threading
from collections import OrderedDict
class StructuredOutputEngine:
"""
Production structured output engine combining:
- Compressed FSM compilation
- Precomputed token masks
- Speculative JSON decoding
- Cross-request grammar caching
"""
def __init__(self, model, tokenizer, device='cuda',
cache_max_schemas=512, cache_max_memory_mb=2048):
self.model = model
self.tokenizer = tokenizer
self.device = device
self.grammar_cache = GrammarCache(
max_entries=cache_max_schemas,
max_memory_mb=cache_max_memory_mb
)
self.request_count = 0
self.total_tokens_saved = 0
def generate(self, input_ids, json_schema, max_tokens=1024,
temperature=1.0, use_speculation=True):
"""
Generate structured output conforming to json_schema.
Returns:
generated_ids: list of token IDs
stats: dict with generation statistics
"""
self.request_count += 1
# Step 1: Get or compile grammar (cached)
fsm, mask_cache = self.grammar_cache.get_or_compile(
json_schema, self.tokenizer, self.device
)
# Step 2: Initialize decoder
if use_speculation:
decoder = SpeculativeJSONDecoder(
fsm, mask_cache, self.tokenizer
)
else:
decoder = None
# Step 3: Generate token by token
generated = []
current_state = fsm.start_state
current_input = input_ids
tokens_from_speculation = 0
for step in range(max_tokens):
# Check if we reached an accepting state
if current_state in fsm.accepting:
break
if use_speculation and decoder is not None:
tokens, current_state = decoder.decode_step(
self.model, current_input, current_state
)
if len(tokens) > 1:
tokens_from_speculation += len(tokens) - 1
else:
# Standard constrained decoding with cached masks
with torch.no_grad():
logits = self.model(current_input).logits[:, -1, :]
mask = mask_cache.get_mask(current_state)
masked_logits = logits + mask
if temperature != 1.0:
masked_logits = masked_logits / temperature
probs = torch.softmax(masked_logits, dim=-1)
token_id = torch.multinomial(probs, 1).item()
tokens = [token_id]
# Advance FSM
token_bytes = self.tokenizer.decode(
[token_id]
).encode('utf-8')
for b in token_bytes:
next_state = fsm.transitions.get(
(current_state, b)
)
if next_state is not None:
current_state = next_state
generated.extend(tokens)
new_tokens = torch.tensor(
[tokens], device=current_input.device
)
current_input = torch.cat(
[current_input, new_tokens], dim=-1
)
self.total_tokens_saved += tokens_from_speculation
stats = {
'total_tokens': len(generated),
'speculated_tokens': tokens_from_speculation,
'speculation_ratio': (
tokens_from_speculation / len(generated)
if generated else 0
),
'cache_stats': self.grammar_cache.stats,
'fsm_states': fsm.num_states,
}
return generated, stats
def warmup(self, schemas):
"""
Pre-compile a set of known schemas during server startup.
Eliminates cold-start latency for known schemas.
"""
for schema in schemas:
self.grammar_cache.get_or_compile(
schema, self.tokenizer, self.device
)
7.2 Benchmark: Full Pipeline Performance
End-to-End Structured Output Performance (A100, Llama 3 70B)
| Configuration | Tokens/sec | p50 Latency | p99 Latency | Validity |
|---|---|---|---|---|
| Unconstrained + retry | 42 | 285ms | 2100ms | 97.2% |
| Naive FSM (per-token) | 31 | 412ms | 890ms | 100% |
| Precomputed masks only | 41 | 290ms | 520ms | 100% |
| + Speculative JSON | 58 | 195ms | 380ms | 100% |
| + Grammar cache (warm) | 58 | 192ms | 350ms | 100% |
| + FSM minimization | 58 | 192ms | 345ms | 100% |
The full pipeline achieves 100% output validity with higher throughput than unconstrained generation with retries. The throughput gain comes from eliminating retry overhead and the latency reduction from speculative JSON decoding. Grammar caching eliminates the cold-start penalty after the first request per schema.