How do you serve 100 simultaneous users when each request takes a different amount of compute? One user asks for a three-word answer, another wants a 1000-token essay, and you need to keep the GPU saturated for both. Static batching says “lock them into the same batch and run until everyone finishes,” which means the three-word answer sits idle for 997 decode steps waiting for the essay to complete. GPU utilization collapses to 40-50% as batch slots empty out. The naive solution — smaller batches — makes it worse because you’re not filling the GPU. The real solution is iteration-level scheduling: after every single decode step, re-evaluate the batch. Finished requests leave immediately, waiting requests enter immediately, and the GPU stays at 95% utilization instead of 45%. This is what Orca introduced and every production serving system now implements.
Orca introduced continuous batching with this iteration-level scheduling. Sarathi extended it by splitting prefill computation into chunks, so a long prefill cannot stall ongoing decode requests. Together, these techniques form the scheduling foundation of every production LLM serving system — vLLM, SGLang, TensorRT-LLM, and DeepSpeed-FastGen all implement variants of Orca-style continuous batching with Sarathi-style chunked prefill.
This post covers the scheduling algorithms in detail: what iteration-level scheduling actually means at the implementation level, how chunked prefill works, the token budget optimization problem, priority scheduling, and a complete reference implementation.
The Problem with Static Batching
Consider a batch of 4 requests with different generation lengths:
Request A: prompt=128 tokens, generate=500 tokens -> total=628 steps
Request B: prompt=64 tokens, generate=50 tokens -> total=114 steps
Request C: prompt=256 tokens, generate=200 tokens -> total=456 steps
Request D: prompt=32 tokens, generate=30 tokens -> total=62 steps
With static batching, all 4 requests start together. The batch runs for 628 steps (the maximum). GPU utilization per slot over time:
Step 0-62: A active, B active, C active, D active -> 4/4 = 100%
Step 63-114: A active, B active, C active, D done -> 3/4 = 75%
Step 115-456: A active, B done, C active, D done -> 2/4 = 50%
Step 457-628: A active, B done, C done, D done -> 1/4 = 25%
Weighted average utilization:

utilization = (62 * 1.00 + 52 * 0.75 + 342 * 0.50 + 172 * 0.25) / 628 ~= 50%

Half the GPU capacity is wasted. In a real serving scenario with hundreds of requests, the variance in generation length is even higher (some requests produce 1 token, others produce 4096), making the problem worse.
Throughput Impact
The throughput loss from static batching is direct:

throughput_static = (sum_i n_i) / (t_step * max_i n_i)

where t_step is the time per decode step and n_i is the total number of steps for request i. With continuous batching, those empty slots would be filled by new requests, generating additional tokens during the same wall-clock time.
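These utilization and throughput numbers are easy to reproduce. A short sketch, using the four step counts from the example above:

```python
def static_batch_utilization(step_counts):
    """Weighted average slot utilization when the batch runs until the
    longest request finishes (static batching): each request occupies
    its slot only while active."""
    horizon = max(step_counts)
    active_slot_steps = sum(step_counts)
    return active_slot_steps / (horizon * len(step_counts))

steps = [628, 114, 456, 62]  # requests A-D from the example
util = static_batch_utilization(steps)
print(f"static utilization: {util:.1%}")   # static utilization: 50.2%

# Continuous batching refills empty slots, so over the same horizon all
# 4 slots stay busy: throughput improves by roughly 1 / utilization.
print(f"throughput factor: {1 / util:.1f}x")  # throughput factor: 2.0x
```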
[Figure: GPU slot utilization (%) over time, static vs continuous batching]

Orca: Iteration-Level Scheduling
Orca (Yu et al., OSDI 2022) introduced two key ideas: iteration-level scheduling and selective batching. The core insight: do not commit to a batch for the entire generation. Instead, make a scheduling decision at every single decode iteration.
The Scheduling Loop
At each iteration, the Orca scheduler executes:
# Orca scheduling loop (simplified)
while True:
# Step 1: Remove finished requests
for req in running_batch:
if req.is_finished():
running_batch.remove(req)
completed_requests.append(req)
# Step 2: Admit new requests (if capacity allows)
while len(running_batch) < max_batch_size and waiting_queue:
new_req = waiting_queue.popleft()
# Run prefill for new request
prefill(new_req)
running_batch.append(new_req)
# Step 3: Run one decode step for all running requests
if running_batch:
decode_step(running_batch)
This loop runs every 20-50ms (the time for one decode step on a 70B model), giving the scheduler tens-of-millisecond granularity for admitting and evicting requests.
Selective Batching
Not all operations in a transformer step can be efficiently batched across requests with different states. Orca identified which operations can share computation:
Operations that CAN be batched:
- Linear projections (QKV, O, FFN): same weight matrix, different activations
Batched as: [B_total, d] x [d, d_out]
where B_total = sum of tokens across all requests in the batch
- Layer norm: element-wise, independent per token
- Activation functions: element-wise, independent per token
Operations that CANNOT be trivially batched:
- Attention: each request has different KV cache length
Solution: per-request attention with custom kernels (FlashAttention, PagedAttention)
The key insight is that linear layers (which dominate compute) naturally batch across requests. We concatenate the current token from each decode request into a single tensor and run one GEMM:
def batched_decode_step(requests):
"""Execute one decode step for a batch of requests."""
# Gather current tokens from all requests
# Each decode request contributes exactly 1 token
batch_tokens = torch.stack([req.current_token_embedding for req in requests])
# Shape: [num_requests, hidden_dim]
for layer in transformer_layers:
# QKV projection: batched GEMM
qkv = batch_tokens @ layer.qkv_weight
# Shape: [num_requests, qkv_dim]
q, k, v = split_qkv(qkv)
# Attention: per-request (different KV cache lengths)
attn_outputs = []
for i, req in enumerate(requests):
attn_out = flash_attention(
q[i:i+1], # [1, num_heads, head_dim]
req.kv_cache.k, # [seq_len_i, num_kv_heads, head_dim]
req.kv_cache.v, # [seq_len_i, num_kv_heads, head_dim]
)
attn_outputs.append(attn_out)
batch_tokens = torch.cat(attn_outputs, dim=0)
# Output projection: batched GEMM
batch_tokens = batch_tokens @ layer.o_weight
# FFN: batched GEMM
batch_tokens = ffn_forward(batch_tokens, layer)
return batch_tokens
The per-request attention loop is the bottleneck in naive Orca. PagedAttention (vLLM) and FlashDecoding solve this by batching the attention computation across requests using variable-length attention kernels.
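Variable-length kernels conceptually take all KV caches packed into one buffer plus a cumulative-lengths array, and move the per-request loop inside the kernel. A NumPy sketch of that packed layout (single-head, illustrative only — a real kernel like FlashAttention's varlen interface parallelizes the loop on the GPU):

```python
import numpy as np

def packed_decode_attention(q, k_packed, v_packed, cu_seqlens):
    """Decode attention over ragged KV caches packed into one buffer.

    q:                [num_requests, d]  one query per decode request
    k_packed/v_packed:[total_kv, d]      all requests' KV entries, concatenated
    cu_seqlens:       [num_requests+1]   cumulative KV lengths; request i's
                                         cache is rows cu_seqlens[i]:cu_seqlens[i+1]
    """
    outs = []
    for i in range(len(q)):
        s, e = cu_seqlens[i], cu_seqlens[i + 1]
        k, v = k_packed[s:e], v_packed[s:e]
        scores = (k @ q[i]) / np.sqrt(q.shape[1])
        w = np.exp(scores - scores.max())
        w /= w.sum()
        outs.append(w @ v)
    return np.stack(outs)

# Two requests with KV lengths 3 and 5, hidden dim 4
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4))
k = rng.standard_normal((8, 4))
v = rng.standard_normal((8, 4))
out = packed_decode_attention(q, k, v, np.array([0, 3, 8]))
print(out.shape)  # (2, 4)
```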
Sarathi: Chunked Prefill to Prevent Decode Stalling
Orca has a critical problem: when a new request is admitted, its prefill must be completed before the next decode step. A prompt with 2048 tokens requires processing all 2048 tokens through every transformer layer. On a 70B model, this takes 200-500ms. During that time, all ongoing decode requests are stalled — they receive no service for 200-500ms.
For a service with a 100ms time-between-tokens (TBT) SLO, a single 2048-token prefill causes all concurrent decode requests to violate the SLO.
The Prefill-Decode Interference Problem
Timeline WITHOUT chunked prefill:
Time | Decode batch | New request
-------|------------------|-----------------
0ms | decode step 1 | (waiting)
30ms | decode step 2 | (waiting)
60ms | decode step 3 | admitted!
60ms | STALLED | prefill (2048 tokens)
| STALLED | ...processing...
360ms | STALLED | prefill complete
360ms | decode step 4 | decode step 1
390ms | decode step 5 | decode step 2
Decode requests experienced a 300ms stall: step 4 arrives at 360ms instead of 90ms, and every subsequent step is delayed with it.
TBT for decode batch: 30ms, 30ms, 300ms, 30ms, ... -> SLO violated
Sarathi’s Solution: Chunked Prefill
Sarathi (Agrawal et al., 2023) splits prefill into fixed-size chunks and interleaves them with decode steps:
Timeline WITH chunked prefill (chunk_size=256):
Time | Decode batch | New request
-------|----------------------|------------------
0ms | decode step 1 | (waiting)
30ms | decode step 2 | admitted!
35ms | decode step 3 | prefill chunk 1 (tokens 0-255)
70ms | decode step 4 | prefill chunk 2 (tokens 256-511)
105ms | decode step 5 | prefill chunk 3 (tokens 512-767)
140ms | decode step 6 | prefill chunk 4 (tokens 768-1023)
175ms | decode step 7 | prefill chunk 5 (tokens 1024-1279)
210ms | decode step 8 | prefill chunk 6 (tokens 1280-1535)
245ms | decode step 9 | prefill chunk 7 (tokens 1536-1791)
280ms | decode step 10 | prefill chunk 8 (tokens 1792-2047)
315ms | decode step 11 | decode step 1
TBT for decode batch: 30ms, 35ms, 35ms, 35ms, ... -> SLO maintained
Each iteration now processes both decode tokens and a chunk of prefill tokens. The iteration time increases slightly (from 30ms to 35ms) because of the additional prefill computation, but no decode request experiences a multi-hundred-millisecond stall.
How Chunked Prefill Works
The key implementation insight: in each iteration, we process a mixed batch of tokens where some are decode tokens (one per ongoing request) and others are prefill tokens (a chunk for the new request). The linear layers naturally handle this because they just see a batch of tokens — they do not care whether a token is from decode or prefill.
def chunked_prefill_iteration(decode_requests, prefill_request, chunk_size):
"""One scheduler iteration with chunked prefill."""
# Decode requests contribute 1 token each
decode_tokens = [req.current_token for req in decode_requests]
# Prefill request contributes chunk_size tokens
prefill_start = prefill_request.prefill_progress
prefill_end = min(
prefill_start + chunk_size,
prefill_request.prompt_length,
)
prefill_tokens = prefill_request.prompt_tokens[prefill_start:prefill_end]
# Concatenate all tokens into one batch
all_tokens = decode_tokens + prefill_tokens
# Total batch: num_decode + chunk_size tokens
for layer in transformer_layers:
# Linear projections: one GEMM for all tokens
qkv = torch.cat(all_tokens) @ layer.qkv_weight
# Shape: [num_decode + chunk_size, qkv_dim]
# Split back into decode and prefill portions
decode_qkv = qkv[:len(decode_tokens)]
prefill_qkv = qkv[len(decode_tokens):]
# Attention: different handling for decode vs prefill
# Decode: each request attends to its full KV cache
# Prefill chunk: attends to previous chunks + current chunk
decode_attn = batched_decode_attention(decode_qkv, decode_requests)
prefill_attn = chunked_prefill_attention(
prefill_qkv, prefill_request, prefill_start, prefill_end
)
# Merge and continue through FFN
all_tokens = torch.cat([decode_attn, prefill_attn])
all_tokens = ffn_forward(all_tokens, layer)
# Update prefill progress
prefill_request.prefill_progress = prefill_end
# Update decode KV caches
for i, req in enumerate(decode_requests):
req.append_kv(all_tokens[i])
Chunked prefill adds computational overhead: each prefill chunk requires reading the KV cache from all previous chunks for its attention computation. A 2048-token prompt split into 8 chunks of 256 tokens computes attention 8 times instead of once. However, the attention cost for 256 tokens attending to up to 2048 KV entries is small compared to the linear projection cost, which is identical whether chunked or not. The total extra compute from chunking is typically 3-8% for the prefill phase. The benefit — maintaining decode TBT SLOs — far outweighs this cost.
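The re-read overhead is easy to count. A sketch that tallies how many previously-cached KV entries each chunk must read back (assuming, for simplicity, that the prompt length divides evenly into chunks):

```python
def chunked_prefill_kv_rereads(prompt_len, chunk_size):
    """Count KV-cache entries re-read across chunks: chunk i must attend
    to the cached keys/values of chunks 0..i-1, which a single-shot
    prefill would process in one pass. Illustrative accounting only;
    assumes prompt_len is a multiple of chunk_size."""
    num_chunks = prompt_len // chunk_size
    rereads = sum(i * chunk_size for i in range(num_chunks))
    return num_chunks, rereads

chunks, rereads = chunked_prefill_kv_rereads(2048, 256)
print(chunks, rereads)  # 8 7168
```

So a 2048-token prompt in 256-token chunks re-reads about 7k KV entries — extra memory traffic, but small next to the projection GEMMs, which are identical either way.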
Token Budget Optimization
Each scheduler iteration has a token budget: the maximum number of tokens that can be processed in one forward pass. This budget is constrained by GPU memory (KV cache, activations) and by the latency SLO (more tokens per iteration means longer iteration time, which increases decode TBT).
The Token Budget Constraint
N_decode + N_prefill <= B_iter

where:
- N_decode = number of decode tokens (one per running request, fixed)
- N_prefill = number of prefill tokens to process this iteration (tunable)
- B_iter = the per-iteration token budget

The iteration time is approximately:

T_iter ~= T_gemm + T_attn + T_overhead

For the GEMM component, with total tokens N = N_decode + N_prefill:

T_gemm = 2 * N * L * P_layer / FLOPS

where L is the number of layers and P_layer covers all projections in one layer.

For a 70B model with 80 layers, total projection parameters ~1.7B per layer:

T_gemm = 2 * N * 80 * 1.7e9 / FLOPS ~= N * 272 GFLOPs / FLOPS

At 800 TFLOPS (achievable on H100 with large batches), this is ~0.34ms per token:
- N = 32 (32 decode tokens): T_gemm ~= 10.9ms
- N = 288 (32 decode + 256 prefill): T_gemm ~= 97.9ms
- N = 544 (32 decode + 512 prefill): T_gemm ~= 184.9ms
Iteration Time vs Token Budget (Llama 70B, H100 SXM)
| Token Budget | Decode Tokens | Prefill Tokens | GEMM Time | Attention Time | Total Iter Time |
|---|---|---|---|---|---|
| 32 | 32 | 0 | 10.9ms | 8ms | 21ms |
| 160 | 32 | 128 | 54.4ms | 15ms | 72ms |
| 288 | 32 | 256 | 97.9ms | 25ms | 126ms |
| 544 | 32 | 512 | 184.9ms | 48ms | 237ms |
| 1056 | 32 | 1024 | 359.0ms | 95ms | 460ms |
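The GEMM column of the table can be reproduced from the stated parameters (80 layers, ~1.7B projection parameters per layer, 2 FLOPs per parameter per token, 800 TFLOPS sustained):

```python
def gemm_time_ms(num_tokens, layers=80, params_per_layer=1.7e9, tflops=800):
    """GEMM time for one forward pass: 2 FLOPs per projection parameter
    per token, summed over all layers, at the given sustained throughput."""
    flops = 2 * params_per_layer * layers * num_tokens
    return flops / (tflops * 1e12) * 1e3  # seconds -> ms

for n in (32, 160, 288, 544, 1056):
    print(f"{n:5d} tokens -> {gemm_time_ms(n):6.1f} ms")
# 10.9, 54.4, 97.9, ~185.0, 359.0 — matching the GEMM column (to rounding)
```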
Optimal Chunk Size Derivation
The TBT SLO constrains the iteration time:

T_iter ~= T_gemm + T_attn + T_overhead <= SLO_TBT

Solving for N_prefill:

N_prefill <= (SLO_TBT - T_attn - T_overhead) * FLOPS / (2 * L * P_layer) - N_decode

The optimal prefill chunk size is the largest N_prefill satisfying this bound — it maximizes throughput while respecting the SLO:
def compute_optimal_chunk_size(
tbt_slo_ms,
num_decode_requests,
avg_kv_cache_length,
model_params,
gpu_tflops,
):
"""Compute maximum prefill chunk size that respects TBT SLO."""
# GEMM time per token (all layers)
flops_per_token = 2 * model_params.total_projection_params * model_params.num_layers
gemm_time_per_token_ms = flops_per_token / (gpu_tflops * 1e9)
# Attention time estimate (scales with KV length and batch)
# Approximate: attention time is proportional to
# num_decode * avg_kv_length * num_heads * head_dim
attn_time_ms = estimate_attention_time(
num_decode_requests, avg_kv_cache_length, model_params
)
# Overhead (kernel launch, scheduling, memory allocation)
overhead_ms = 2.0 # typically 1-3ms
# Solve for max prefill tokens
available_ms = tbt_slo_ms - attn_time_ms - overhead_ms
total_tokens = available_ms / gemm_time_per_token_ms
max_prefill_tokens = int(total_tokens - num_decode_requests)
# Floor to multiples of 64 for tensor core alignment
max_prefill_tokens = (max_prefill_tokens // 64) * 64
return max(0, max_prefill_tokens)
Dynamic Token Budget
The optimal chunk size changes as the system state changes:
- More decode requests means fewer tokens available for prefill
- Longer KV caches mean longer attention time, reducing the prefill budget
- GPU thermal throttling reduces TFLOPS, reducing the budget
Production schedulers recompute the token budget every iteration:
class TokenBudgetScheduler:
"""Dynamically adjusts token budget each iteration."""
def __init__(self, tbt_slo_ms, max_tokens_per_iter):
self.tbt_slo_ms = tbt_slo_ms
self.max_tokens = max_tokens_per_iter
self.ema_iter_time = 0.0
self.alpha = 0.1 # EMA smoothing factor
def update_budget(self, last_iter_time_ms, num_decode, num_prefill):
"""Update token budget based on observed iteration time."""
# Update EMA of iteration time per token
total_tokens = num_decode + num_prefill
if total_tokens > 0:
time_per_token = last_iter_time_ms / total_tokens
self.ema_iter_time = (
self.alpha * time_per_token
+ (1 - self.alpha) * self.ema_iter_time
)
def get_prefill_budget(self, num_decode_requests):
"""How many prefill tokens can we process this iteration?"""
if self.ema_iter_time <= 0:
return 256 # default initial budget
# Target: iteration time <= TBT SLO
max_total_tokens = int(self.tbt_slo_ms / self.ema_iter_time)
max_total_tokens = min(max_total_tokens, self.max_tokens)
prefill_budget = max_total_tokens - num_decode_requests
# Align to 64 tokens
prefill_budget = (prefill_budget // 64) * 64
return max(0, prefill_budget)
The token budget is a compute constraint (how many tokens can we process within the latency SLO). There is also a memory budget: the GPU must hold KV caches for all active requests plus activations for the current iteration. The scheduler must respect both constraints. In practice, the memory budget limits the maximum number of concurrent requests, while the token budget limits how much prefill can be interleaved per iteration.
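The two budgets can be checked independently each iteration. A sketch with assumed sizes (block size, per-request sequence cap, and per-token time are illustrative):

```python
def admission_limits(
    free_kv_blocks, block_tokens, max_seq_tokens,   # memory side
    tbt_slo_ms, time_per_token_ms, num_decode,      # compute side
):
    """Two independent caps the scheduler enforces: the memory budget
    bounds how many more requests fit, the compute budget bounds how
    many prefill tokens fit in this iteration's latency envelope."""
    blocks_per_request = -(-max_seq_tokens // block_tokens)  # ceil division
    max_new_requests = free_kv_blocks // blocks_per_request
    prefill_budget = int(tbt_slo_ms / time_per_token_ms) - num_decode
    return max_new_requests, max(0, prefill_budget)

reqs, prefill = admission_limits(
    free_kv_blocks=1000, block_tokens=16, max_seq_tokens=640,
    tbt_slo_ms=100, time_per_token_ms=0.34, num_decode=32,
)
print(reqs, prefill)  # 25 requests fit in memory; 262 prefill tokens fit in the SLO
```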
Iteration-Level vs Request-Level Scheduling
The scheduling granularity determines how quickly the system responds to changes in workload.
Request-Level Scheduling
Request-level scheduling makes decisions when requests arrive or complete:
# Request-level: decisions at request boundaries
def request_level_scheduler(request_queue, gpu_pool):
while True:
# Wait for an event
event = wait_for_event() # new request or request completion
if event.type == "new_request":
gpu = find_gpu_with_capacity(gpu_pool)
if gpu:
assign_request(event.request, gpu)
else:
request_queue.append(event.request)
elif event.type == "request_complete":
gpu = event.gpu
if request_queue:
next_req = request_queue.popleft()
assign_request(next_req, gpu)
Granularity: decisions every few hundred milliseconds to seconds (average request duration). Simple to implement, but responds slowly to load changes.
Iteration-Level Scheduling
Iteration-level scheduling makes decisions every decode step:
# Iteration-level: decisions every 20-50ms
def iteration_level_scheduler(waiting_queue, running_requests, kv_cache_manager):
while True:
# Step 1: Evict finished requests (every iteration)
finished = [r for r in running_requests if r.is_finished()]
for req in finished:
running_requests.remove(req)
kv_cache_manager.free(req.kv_cache_blocks)
# Step 2: Preempt if memory pressure (every iteration)
while kv_cache_manager.free_blocks < safety_margin:
victim = select_preemption_victim(running_requests)
running_requests.remove(victim)
waiting_queue.appendleft(victim) # re-queue at front
kv_cache_manager.free(victim.kv_cache_blocks)
# Step 3: Admit new requests (every iteration)
prefill_budget = compute_prefill_budget(running_requests)
while prefill_budget > 0 and waiting_queue:
candidate = waiting_queue[0]
blocks_needed = estimate_kv_blocks(candidate)
if kv_cache_manager.free_blocks >= blocks_needed:
waiting_queue.popleft()
running_requests.append(candidate)
kv_cache_manager.allocate(candidate, blocks_needed)
prefill_budget -= min(candidate.remaining_prefill, prefill_budget)
else:
break # no memory for new requests
# Step 4: Execute one forward pass
forward_pass(running_requests, prefill_budget)
Request-Level vs Iteration-Level Scheduling
| Metric | Request-Level | Iteration-Level | Improvement |
|---|---|---|---|
| Scheduling granularity | 100ms-10s | 20-50ms | 10-200x finer |
| Slot utilization | 60-75% | 90-98% | +20-30% |
| P99 TBT (during prefill) | 200-500ms | 35-80ms | 3-10x lower |
| Scheduler CPU overhead | 0.01% | 0.1-0.5% | 10-50x higher |
| Implementation complexity | Low | High | — |
| Memory management | Static allocation | Dynamic (PagedAttention) | — |
CPU Overhead of Iteration-Level Scheduling
The scheduler runs on the CPU while the GPU executes kernels. The scheduling decision must complete before the GPU finishes the current iteration, otherwise the GPU stalls waiting for the next batch.
import time
def measure_scheduler_overhead(scheduler, running_requests, waiting_queue):
"""Measure how long the scheduling decision takes."""
iterations = 1000
total_time = 0
for _ in range(iterations):
start = time.perf_counter_ns()
# Simulate scheduling decision
finished = [r for r in running_requests if r.is_finished()]
for req in finished:
running_requests.remove(req)
# Admit new requests
while len(running_requests) < scheduler.max_batch and waiting_queue:
running_requests.append(waiting_queue.popleft())
# Compute token budget
budget = scheduler.get_prefill_budget(len(running_requests))
elapsed_ns = time.perf_counter_ns() - start
total_time += elapsed_ns
avg_us = total_time / iterations / 1000
print(f"Avg scheduling time: {avg_us:.1f} us")
# Typical: 10-100us for batch size 32-256
# Compare to GPU iteration time: 20-50ms
# Overhead: 0.02-0.5%
The scheduling overhead is 10-100 microseconds per iteration, compared to 20-50ms for the GPU forward pass. This is negligible. The CPU has ample time to make the scheduling decision while the GPU is busy.
Priority Scheduling Within an Iteration
Not all requests are equal. In a production serving system, requests have different priorities and different SLO requirements.
Priority Classes
import time
from enum import IntEnum
from dataclasses import dataclass, field
class Priority(IntEnum):
REALTIME = 0 # interactive chat, strict TBT SLO
HIGH = 1 # API requests with SLA
NORMAL = 2 # batch processing, relaxed SLO
LOW = 3 # background tasks, best-effort
@dataclass
class InferenceRequest:
request_id: str
prompt_tokens: list
priority: Priority
tbt_slo_ms: float # time-between-tokens SLO
ttft_slo_ms: float # time-to-first-token SLO
max_tokens: int
# Runtime state
generated_tokens: list = field(default_factory=list)
prefill_progress: int = 0
kv_cache_blocks: list = field(default_factory=list)
last_token_time: float = 0.0
admitted_time: float = 0.0
def is_finished(self):
return (
len(self.generated_tokens) >= self.max_tokens
or (self.generated_tokens and self.generated_tokens[-1] == EOS_TOKEN)
)
@property
def remaining_prefill(self):
return len(self.prompt_tokens) - self.prefill_progress
@property
def tbt_slack(self):
"""How much time until TBT SLO violation, in ms."""
if not self.generated_tokens:
return float("inf") # not yet decoding
elapsed = (time.time() - self.last_token_time) * 1000
return self.tbt_slo_ms - elapsed
Priority-Aware Token Budget Allocation
Within each iteration, the scheduler allocates the token budget with priorities:
def allocate_token_budget(
decode_requests,
prefill_queue,
total_budget,
):
"""
Allocate token budget across decode and prefill requests.
Priority order:
1. All decode tokens (mandatory — one token per running request)
2. Prefill chunks for requests closest to TTFT SLO violation
3. Prefill chunks for lower-priority requests
"""
allocation = {
"decode_tokens": [],
"prefill_chunks": [],
}
remaining_budget = total_budget
# Phase 1: Decode tokens are mandatory
# Sort decode requests by TBT slack (most urgent first)
urgent_decode = sorted(decode_requests, key=lambda r: r.tbt_slack)
for req in urgent_decode:
if remaining_budget <= 0:
# Cannot serve this decode request this iteration
# This means we have too many concurrent requests
break
allocation["decode_tokens"].append(req)
remaining_budget -= 1
# Phase 2: Prefill chunks
# Sort prefill queue by: priority first, then TTFT deadline urgency
sorted_prefill = sorted(
prefill_queue,
        key=lambda r: (r.priority, r.ttft_slack()),  # assumes a ttft_slack() helper analogous to tbt_slack
)
for req in sorted_prefill:
if remaining_budget <= 0:
break
# Allocate up to chunk_size tokens, but not more than remaining prompt
chunk = min(
remaining_budget,
256, # max chunk size
req.remaining_prefill,
)
# Align to 64 for tensor core efficiency
chunk = (chunk // 64) * 64
if chunk <= 0:
continue
allocation["prefill_chunks"].append((req, chunk))
remaining_budget -= chunk
return allocation
Preemption for SLO Preservation
When the system is overloaded, the scheduler must preempt low-priority requests to preserve SLOs of high-priority ones:
def preemption_check(running_requests, kv_cache_manager):
"""Preempt low-priority requests if high-priority SLOs are at risk."""
# Check if any high-priority request is close to SLO violation
at_risk = [
req for req in running_requests
if req.priority <= Priority.HIGH and req.tbt_slack < 10 # less than 10ms slack
]
if not at_risk:
return []
# Find low-priority candidates for preemption
preempt_candidates = sorted(
[req for req in running_requests if req.priority >= Priority.NORMAL],
key=lambda r: (-r.priority, -len(r.generated_tokens)),
# Preempt lowest priority first, then shortest generation
# (shorter generation = less wasted work)
)
preempted = []
for candidate in preempt_candidates:
# Preempt: save KV cache to CPU (swap) or discard (recompute)
if kv_cache_manager.can_swap_to_cpu(candidate):
kv_cache_manager.swap_out(candidate)
else:
# Discard KV cache; will recompute if re-admitted
kv_cache_manager.free(candidate.kv_cache_blocks)
candidate.prefill_progress = 0 # must re-prefill
running_requests.remove(candidate)
preempted.append(candidate)
# Re-check: is the at-risk request now safe?
if all(r.tbt_slack > 20 for r in at_risk):
break
return preempted
Preempting a request that has generated 200 tokens wastes all the compute spent on those 200 decode steps. If the KV cache can be swapped to CPU memory (200ms for a 70B request with 200 tokens of KV cache), the request can resume later without recomputation. If the KV cache must be discarded, the entire prefill and decode history is lost. Production systems (vLLM, SGLang) use swap-based preemption as the default and only discard KV cache under extreme memory pressure.
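The swap-versus-discard choice can be framed as comparing estimated restore costs. A sketch with assumed numbers (per-token KV size, PCIe bandwidth, and prefill cost are illustrative; real swap costs also include allocator and synchronization overheads well beyond raw bandwidth):

```python
def preemption_strategy(
    kv_tokens,            # tokens cached by the victim request
    kv_bytes_per_token,   # KV-cache footprint per token
    pcie_gb_per_s,        # host<->device bandwidth for swapping
    prefill_ms_per_token, # cost to rebuild the cache by recomputation
):
    """Pick the cheaper way to resume a preempted request: swap its KV
    cache back from CPU memory, or recompute it from the prompt tokens.
    All inputs are assumed, illustrative numbers."""
    swap_ms = kv_tokens * kv_bytes_per_token / (pcie_gb_per_s * 1e9) * 1e3
    recompute_ms = kv_tokens * prefill_ms_per_token
    if swap_ms < recompute_ms:
        return "swap", swap_ms
    return "recompute", recompute_ms

# e.g. 200 cached tokens at ~320 KB/token (70B-class model), 25 GB/s PCIe
choice, cost = preemption_strategy(200, 320e3, 25, prefill_ms_per_token=0.34)
print(choice)  # swap — raw transfer is cheaper than re-running prefill
```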
Implementation: Complete Iteration-Level Scheduler
Here is a complete, production-style iteration-level scheduler:
import time
import heapq
from collections import deque
from dataclasses import dataclass, field
from enum import IntEnum
EOS_TOKEN = 2 # model-specific
class Priority(IntEnum):
REALTIME = 0
HIGH = 1
NORMAL = 2
LOW = 3
@dataclass(order=True)
class ScheduledRequest:
sort_key: tuple = field(init=False, repr=False)
request_id: str = field(compare=False)
prompt_tokens: list = field(compare=False)
priority: Priority = field(compare=False)
tbt_slo_ms: float = field(compare=False, default=100.0)
ttft_slo_ms: float = field(compare=False, default=2000.0)
max_new_tokens: int = field(compare=False, default=512)
# Runtime state
generated_tokens: list = field(default_factory=list, compare=False)
prefill_progress: int = field(default=0, compare=False)
kv_block_ids: list = field(default_factory=list, compare=False)
last_token_time_ms: float = field(default=0.0, compare=False)
arrival_time_ms: float = field(default=0.0, compare=False)
state: str = field(default="waiting", compare=False) # waiting, prefilling, decoding, finished
    def __post_init__(self):
        self.arrival_time_ms = time.monotonic() * 1000
        self.sort_key = (self.priority, self.arrival_time_ms)
@property
def is_prefilling(self):
return self.prefill_progress < len(self.prompt_tokens)
@property
def is_finished(self):
if len(self.generated_tokens) >= self.max_new_tokens:
return True
if self.generated_tokens and self.generated_tokens[-1] == EOS_TOKEN:
return True
return False
@property
def remaining_prefill_tokens(self):
return len(self.prompt_tokens) - self.prefill_progress
@property
def tbt_slack_ms(self):
if not self.generated_tokens:
return float("inf")
now = time.monotonic() * 1000
elapsed = now - self.last_token_time_ms
return self.tbt_slo_ms - elapsed
@property
def ttft_slack_ms(self):
now = time.monotonic() * 1000
elapsed = now - self.arrival_time_ms
return self.ttft_slo_ms - elapsed
class KVCacheManager:
"""Simplified block-based KV cache manager."""
def __init__(self, total_blocks, block_size_tokens=16):
self.total_blocks = total_blocks
self.block_size = block_size_tokens
self.free_block_ids = list(range(total_blocks))
self.allocated = {} # request_id -> [block_ids]
@property
def num_free_blocks(self):
return len(self.free_block_ids)
def allocate(self, request_id, num_blocks):
if num_blocks > self.num_free_blocks:
return None
blocks = [self.free_block_ids.pop() for _ in range(num_blocks)]
self.allocated[request_id] = blocks
return blocks
def free(self, request_id):
if request_id in self.allocated:
self.free_block_ids.extend(self.allocated.pop(request_id))
def blocks_needed(self, total_tokens):
return (total_tokens + self.block_size - 1) // self.block_size
class IterationLevelScheduler:
"""
Complete iteration-level scheduler with:
- Chunked prefill (Sarathi)
- Token budget management
- Priority scheduling
- Preemption support
"""
def __init__(
self,
max_batch_tokens=2048,
max_batch_requests=256,
prefill_chunk_size=256,
tbt_slo_ms=100.0,
kv_total_blocks=10000,
kv_block_size=16,
):
self.max_batch_tokens = max_batch_tokens
self.max_batch_requests = max_batch_requests
self.prefill_chunk_size = prefill_chunk_size
self.tbt_slo_ms = tbt_slo_ms
# Request queues
self.waiting_queue = [] # min-heap by (priority, arrival_time)
self.running_requests = {} # request_id -> ScheduledRequest
self.finished_requests = []
# KV cache manager
self.kv_manager = KVCacheManager(kv_total_blocks, kv_block_size)
# Metrics
self.iter_count = 0
self.ema_iter_time_ms = 30.0 # initial estimate
def add_request(self, request):
"""Add a new request to the waiting queue."""
request.state = "waiting"
heapq.heappush(self.waiting_queue, request)
def _evict_finished(self):
"""Remove finished requests and free their KV cache."""
finished_ids = []
for req_id, req in self.running_requests.items():
if req.is_finished:
req.state = "finished"
self.finished_requests.append(req)
self.kv_manager.free(req_id)
finished_ids.append(req_id)
for req_id in finished_ids:
del self.running_requests[req_id]
return len(finished_ids)
def _check_preemption(self):
"""Preempt low-priority requests if high-priority SLOs at risk."""
at_risk = [
req for req in self.running_requests.values()
if req.priority <= Priority.HIGH and req.tbt_slack_ms < 15
]
if not at_risk:
return []
# Sort candidates: lowest priority first, then least work done
candidates = sorted(
[r for r in self.running_requests.values() if r.priority >= Priority.NORMAL],
key=lambda r: (-r.priority, len(r.generated_tokens)),
)
preempted = []
for candidate in candidates:
self.kv_manager.free(candidate.request_id)
del self.running_requests[candidate.request_id]
# Re-queue: reset prefill progress (KV cache lost)
candidate.prefill_progress = 0
candidate.state = "waiting"
candidate.kv_block_ids = []
heapq.heappush(self.waiting_queue, candidate)
preempted.append(candidate.request_id)
# Check if situation resolved
if all(
r.tbt_slack_ms > 30
for r in self.running_requests.values()
if r.priority <= Priority.HIGH
):
break
return preempted
def _admit_requests(self, prefill_budget):
"""Admit waiting requests up to budget and memory limits."""
admitted = []
while (
self.waiting_queue
and prefill_budget > 0
and len(self.running_requests) < self.max_batch_requests
):
candidate = self.waiting_queue[0] # peek
# Check memory: estimate total KV blocks needed
estimated_total_tokens = (
len(candidate.prompt_tokens) + candidate.max_new_tokens
)
blocks_needed = self.kv_manager.blocks_needed(estimated_total_tokens)
if blocks_needed > self.kv_manager.num_free_blocks:
break # no memory
# Admit
heapq.heappop(self.waiting_queue)
block_ids = self.kv_manager.allocate(
candidate.request_id, blocks_needed
)
candidate.kv_block_ids = block_ids
candidate.state = "prefilling"
self.running_requests[candidate.request_id] = candidate
admitted.append(candidate.request_id)
# Deduct from prefill budget
chunk = min(
self.prefill_chunk_size,
candidate.remaining_prefill_tokens,
prefill_budget,
)
prefill_budget -= chunk
return admitted
def schedule_iteration(self):
"""
Make scheduling decision for one iteration.
Returns a ScheduleBatch describing what to execute.
"""
self.iter_count += 1
# Phase 1: Evict finished requests
num_evicted = self._evict_finished()
# Phase 2: Check preemption needs
preempted = self._check_preemption()
# Phase 3: Compute token budget
num_decode = sum(
1 for r in self.running_requests.values()
if not r.is_prefilling
)
remaining_budget = self.max_batch_tokens - num_decode
# Phase 4: Allocate prefill budget to existing prefilling requests
prefill_assignments = []
prefilling = sorted(
[r for r in self.running_requests.values() if r.is_prefilling],
key=lambda r: (r.priority, r.ttft_slack_ms),
)
for req in prefilling:
if remaining_budget <= 0:
break
chunk = min(
self.prefill_chunk_size,
req.remaining_prefill_tokens,
remaining_budget,
)
# Align to 64
chunk = max((chunk // 64) * 64, 64) if chunk >= 64 else chunk
prefill_assignments.append((req.request_id, chunk))
remaining_budget -= chunk
# Phase 5: Admit new requests with remaining budget
self._admit_requests(remaining_budget)
# Re-collect prefill assignments (new admissions included)
prefill_assignments = []
remaining_budget = self.max_batch_tokens - num_decode
for req in sorted(
[r for r in self.running_requests.values() if r.is_prefilling],
key=lambda r: (r.priority, r.ttft_slack_ms),
):
if remaining_budget <= 0:
break
chunk = min(
self.prefill_chunk_size,
req.remaining_prefill_tokens,
remaining_budget,
)
chunk = max((chunk // 64) * 64, 64) if chunk >= 64 else chunk
prefill_assignments.append((req.request_id, chunk))
remaining_budget -= chunk
# Build the batch
decode_request_ids = [
r.request_id
for r in self.running_requests.values()
if not r.is_prefilling
]
return {
"iteration": self.iter_count,
"decode_request_ids": decode_request_ids,
"prefill_assignments": prefill_assignments,
"total_tokens": num_decode + sum(c for _, c in prefill_assignments),
"num_evicted": num_evicted,
"num_preempted": len(preempted),
"num_waiting": len(self.waiting_queue),
"num_running": len(self.running_requests),
"kv_blocks_free": self.kv_manager.num_free_blocks,
}
def complete_iteration(self, iter_time_ms, new_tokens):
    """
    Called after the forward pass completes.
    Updates request state with generated tokens.
    """
    now_ms = time.monotonic() * 1000

    # Update EMA of iteration time
    self.ema_iter_time_ms = (
        0.1 * iter_time_ms + 0.9 * self.ema_iter_time_ms
    )

    # Transition requests whose prefill finished this iteration.
    # prefill_progress is advanced by the caller from the chunk
    # assignment, so is_prefilling already reflects the new state.
    for req in self.running_requests.values():
        if req.state == "prefilling" and not req.is_prefilling:
            req.state = "decoding"

    # Record newly generated decode tokens
    for req_id, token_id in new_tokens.items():
        req = self.running_requests.get(req_id)
        if req is None:
            continue
        req.generated_tokens.append(token_id)
        req.last_token_time_ms = now_ms
# Usage example
def run_serving_loop():
scheduler = IterationLevelScheduler(
max_batch_tokens=2048,
max_batch_requests=128,
prefill_chunk_size=256,
tbt_slo_ms=80.0,
kv_total_blocks=20000,
)
# Simulate incoming requests
for i in range(100):
req = ScheduledRequest(
request_id=f"req_{i}",
prompt_tokens=list(range(128 + i * 10)), # varying prompt lengths
priority=Priority.HIGH if i % 5 == 0 else Priority.NORMAL,
max_new_tokens=256,
)
scheduler.add_request(req)
# Run scheduling loop
for step in range(500):
batch = scheduler.schedule_iteration()
print(
f"Step {batch['iteration']}: "
f"decode={len(batch['decode_request_ids'])}, "
f"prefill_chunks={len(batch['prefill_assignments'])}, "
f"total_tokens={batch['total_tokens']}, "
f"waiting={batch['num_waiting']}, "
f"running={batch['num_running']}"
)
# Simulate forward pass
iter_time = 25.0 + batch["total_tokens"] * 0.05 # ms
new_tokens = {}
for req_id in batch["decode_request_ids"]:
new_tokens[req_id] = 42 # placeholder token
# Update prefill progress
for req_id, chunk_size in batch["prefill_assignments"]:
req = scheduler.running_requests.get(req_id)
if req:
req.prefill_progress += chunk_size
scheduler.complete_iteration(iter_time, new_tokens)
if batch["num_running"] == 0 and batch["num_waiting"] == 0:
break
Real-World Scheduling Decisions in vLLM and SGLang
vLLM Scheduler
vLLM implements Orca-style continuous batching with Sarathi-style chunked prefill. Key implementation details:
# vLLM scheduler configuration (simplified)
# From: vllm/core/scheduler.py
# The scheduler runs three passes per iteration:
# Pass 1: _schedule_running() - handle existing running requests
# - Identify finished requests
# - Handle preemption if memory is low
# - Output: running_queue_output
#
# Pass 2: _schedule_swapped() - re-admit previously preempted requests
# - Check if swapped-out requests can be swapped back in
# - Output: swap_in_queue_output
#
# Pass 3: _schedule_waiting() - admit new requests
# - Check memory availability
# - Apply chunked prefill budgeting
# - Output: waiting_queue_output
# Token budget configuration:
# max_num_batched_tokens: total tokens per iteration (default: varies by model)
# max_num_seqs: maximum concurrent sequences (default: 256)
# enable_chunked_prefill: whether to use Sarathi-style chunking (introduced in v0.4; enabled by default in recent releases)
# max_num_batched_tokens with chunked prefill: typically 512-2048
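These knobs surface directly in vLLM's engine arguments. A minimal configuration sketch (the model name and values are illustrative, and defaults shift between versions, so check the documentation for your release):

```python
from vllm import LLM

# Illustrative values only; tune against your own SLOs.
llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # any supported model
    enable_chunked_prefill=True,    # Sarathi-style chunked prefill
    max_num_batched_tokens=2048,    # per-iteration token budget
    max_num_seqs=256,               # max concurrent sequences
)
```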
SGLang Scheduler
SGLang uses a similar approach with additional optimizations for structured generation:
# SGLang scheduler key differences:
# RadixAttention: shares KV cache across requests with common prefixes
# - If 10 requests share a 1000-token system prompt, store KV once
# - Reduces memory by 10x for the shared prefix
#
# Overlap scheduling: prepare next batch while GPU is executing current batch
# - Scheduler runs on a separate CPU thread
# - Uses double-buffering for batch metadata
#
# Token budget split:
# - Decode budget: always reserved (1 token per running request)
# - Prefill budget: remaining tokens after decode reservation
# - If prefill budget < 64 tokens: skip prefill this iteration
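That decode-first split is easy to state as a pure function. This is a paraphrase of the policy described above, not SGLang's actual code:

```python
def split_token_budget(max_batch_tokens, num_decoding, min_prefill=64):
    """Reserve 1 token per running decode request; hand the
    remainder to prefill, or skip prefill entirely if it falls
    below the minimum useful chunk (64 tokens in the policy above)."""
    decode_budget = min(num_decoding, max_batch_tokens)
    prefill_budget = max_batch_tokens - decode_budget
    if prefill_budget < min_prefill:
        prefill_budget = 0  # not worth launching a prefill chunk
    return decode_budget, prefill_budget
```

With a 2048-token budget and 100 running decoders, 1948 tokens go to prefill; with 2000 decoders only 48 remain, which is below the minimum, so prefill is skipped for that iteration.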
Scheduling Algorithm Performance Comparison (Llama 70B, 128 concurrent users)
| Scheduler | Throughput (tok/s) | P50 TBT (ms) | P99 TBT (ms) | P99 TTFT (ms) |
|---|---|---|---|---|
| Static batching | 1200 | 45 | 850 | 3200 |
| Orca (no chunked prefill) | 2800 | 38 | 420 | 1800 |
| Orca + Sarathi (chunk=512) | 2650 | 40 | 85 | 2100 |
| Orca + Sarathi (chunk=256) | 2550 | 42 | 65 | 2400 |
| Orca + Sarathi (chunk=128) | 2400 | 44 | 55 | 2800 |
The tradeoff is visible in the numbers: smaller prefill chunks give better P99 TBT (fewer decode stalls) but worse TTFT (prefill takes longer) and slightly lower throughput (chunking overhead). A chunk size of 256-512 is the typical production sweet spot.
For interactive chat (strict TBT SLO of 50-100ms): use chunk size 128-256. For batch API (relaxed TBT, care about throughput): use chunk size 512-1024 or disable chunking entirely. For mixed workloads: use adaptive chunk sizing based on the number of concurrent decode requests — fewer decode requests means larger chunks are safe.
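One way to implement that adaptive policy is to size the chunk from whatever iteration-time headroom the decode load leaves under the TBT SLO. The linear cost model below (25 ms base plus 0.05 ms per token) matches the numbers used in the simulation earlier in this post and is an assumption, not a measurement:

```python
def adaptive_chunk_size(num_decoding, tbt_slo_ms,
                        base_ms=25.0, per_token_ms=0.05, align=64):
    """Largest prefill chunk whose projected iteration time
    (base + per-token cost of decode tokens + chunk tokens)
    stays under the TBT SLO, rounded down to `align` tokens."""
    headroom_ms = tbt_slo_ms - base_ms - num_decoding * per_token_ms
    if headroom_ms <= 0:
        return 0  # decode alone already saturates the SLO
    chunk = int(round(headroom_ms / per_token_ms))
    return (chunk // align) * align
```

With few decoders the headroom is large and chunks grow toward the prefill-efficient regime; as decode load rises the chunk shrinks toward zero, which is exactly the "fewer decode requests means larger chunks are safe" rule stated above.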
Throughput-Latency Tradeoff Curves
The token budget directly controls the throughput-latency tradeoff:
Throughput vs P99 TBT at Different Token Budgets (Llama 70B, H100)
Beyond a token budget of 2048, throughput gains flatten because the GEMM is already compute-bound (the batch size exceeds the roofline ridge point), while P99 TBT keeps increasing roughly linearly with the budget. The optimal operating point depends on the SLO, and is typically found empirically: profile the system at several token budgets and pick the knee of the throughput-latency curve.
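The knee search itself can be automated once profiling samples are in hand. A toy version with made-up numbers; the acceptance threshold (here, 10 tok/s of throughput per extra millisecond of P99 TBT) is arbitrary and would be tuned per deployment:

```python
def knee_budget(samples, min_gain_per_ms=10.0):
    """samples: list of (token_budget, throughput_tok_s, p99_tbt_ms),
    sorted by budget. Walk up the curve while each step still buys
    at least `min_gain_per_ms` throughput per added ms of latency."""
    best = samples[0][0]
    for (b0, t0, l0), (b1, t1, l1) in zip(samples, samples[1:]):
        d_tp, d_lat = t1 - t0, l1 - l0
        if d_lat <= 0 or d_tp / d_lat >= min_gain_per_ms:
            best = b1  # step is still worth the latency cost
        else:
            break      # past the knee: latency grows faster than throughput
    return best

# Hypothetical profile of a 70B model at four token budgets:
profile = [
    (512, 1500.0, 40.0),
    (1024, 2400.0, 50.0),
    (2048, 2800.0, 65.0),
    (4096, 2850.0, 110.0),
]
```

On this hypothetical profile the 2048-to-4096 step buys only ~50 tok/s for 45 ms of added P99 TBT, so the search stops at 2048.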