After the model forward pass produces a logits tensor of shape [B, V] — where B is the batch size and V is the vocabulary size — most inference guides stop. The tensor is there, the model has spoken, and the rest is “just sampling.” In practice, the pipeline between raw logits and an emitted token involves seven distinct stages, several of which have nontrivial compute cost, and the engineering of this pipeline affects output quality, latency, correctness, and throughput.
This post walks through the complete token generation pipeline: logit processors (repetition penalty, frequency penalty, presence penalty), temperature scaling, top-k filtering, top-p (nucleus) filtering, multinomial sampling, stop criteria evaluation (EOS token, max_tokens, stop strings, regex), and streaming detokenization. We cover the compute cost of each stage, show how to fuse them into a single GPU kernel, and provide a complete reference implementation.
1. The Full Pipeline: Stage by Stage
Here is the complete pipeline from logits to emitted token, in order:
logits [B, V]
|
v
Stage 1: Logit processors (repetition/frequency/presence penalty)
|
v
Stage 2: Temperature scaling
|
v
Stage 3: Top-k filtering
|
v
Stage 4: Top-p (nucleus) filtering
|
v
Stage 5: Multinomial sampling
|
v
Stage 6: Stop criteria evaluation
|
v
Stage 7: Detokenization + streaming output
Each stage transforms or filters the logits tensor before the final sampling step. The order matters — applying temperature before penalties produces different results than the reverse. The convention established by HuggingFace Transformers (and followed by vLLM, SGLang, and TensorRT-LLM) is the order shown above.
Compute Cost per Pipeline Stage (V=128256, Llama 3 vocab)
| Stage | Operation | Complexity | Time at B=1 (GPU) | Time at B=256 (GPU) |
|---|---|---|---|---|
| Logit processors | Scatter + conditional add | O(B * S) per penalty | 0.8 us | 12 us |
| Temperature | Element-wise divide | O(B * V) | 0.3 us | 5 us |
| Top-k filter | Partial sort (k-th element) | O(B * V) | 2.1 us | 45 us |
| Top-p filter | Full sort + prefix sum | O(B * V log V) | 8.5 us | 180 us |
| Sampling | CDF build + binary search | O(B * V) | 1.2 us | 15 us |
| Stop check | Token comparison | O(B * num_stops) | 0.1 us | 0.5 us |
| Detokenize | Lookup + UTF-8 decode | O(B * token_len) | 0.2 us | 2 us |
| Total | — | — | 13.2 us | 260 us |
At batch size 1, the entire pipeline is negligible compared to the model forward pass. At large batch sizes, top-p sorting becomes measurable. The real motivation for optimizing this pipeline is not latency but correctness and flexibility — getting the sampling semantics exactly right.
2. Stage 1: Logit Processors
Logit processors modify the raw logits before any probabilistic operation. They implement penalties designed to improve output quality.
Repetition Penalty
Introduced by Keskar et al. (2019). For each token i that appears in the context, modify its logit:
l_i' = l_i / θ if l_i > 0, else l_i · θ
where θ ≥ 1 is the repetition penalty parameter. This reduces the probability of tokens that have already appeared, regardless of how many times they appeared.
def apply_repetition_penalty(logits, input_ids, penalty):
"""
Apply repetition penalty to logits.
Args:
logits: [B, V] raw logits
input_ids: [B, S] token IDs in context
penalty: float, typically 1.0-1.3
"""
# Gather logits for tokens that appear in context
# input_ids may contain duplicates; we only care about unique tokens
for b in range(logits.shape[0]):
unique_tokens = torch.unique(input_ids[b])
token_logits = logits[b, unique_tokens]
# Apply asymmetric penalty
positive_mask = token_logits > 0
token_logits[positive_mask] /= penalty
token_logits[~positive_mask] *= penalty
logits[b, unique_tokens] = token_logits
return logits
The per-batch loop is problematic for GPU efficiency. A vectorized implementation:
def apply_repetition_penalty_vectorized(logits, input_ids, penalty):
"""Vectorized repetition penalty — no Python loops."""
# Create a mask: which (batch, vocab) positions appear in context
# input_ids: [B, S], values in [0, V)
score = torch.gather(logits, 1, input_ids) # [B, S]
# Apply penalty
score = torch.where(score > 0, score / penalty, score * penalty)
# Scatter back — this handles duplicates by overwriting
# (all duplicates get the same penalty, so overwrite is correct)
logits.scatter_(1, input_ids, score)
return logits
Frequency Penalty
Penalizes tokens proportionally to how many times they appear in the context. Used by the OpenAI API.
l_i' = l_i − α_f · c_i
where α_f is the frequency penalty coefficient (typically 0.0-2.0) and c_i is the number of times token i appears in the generated text.
def apply_frequency_penalty(logits, output_ids, frequency_penalty):
"""
Args:
logits: [B, V]
output_ids: [B, S_out] generated token IDs (not prompt)
frequency_penalty: float, 0.0-2.0
"""
# Count occurrences of each token in generated output
bin_counts = torch.zeros_like(logits) # [B, V]
bin_counts.scatter_add_(
1, output_ids,
torch.ones_like(output_ids, dtype=logits.dtype)
)
# Subtract penalty * count
logits -= frequency_penalty * bin_counts
return logits
Presence Penalty
A binary version of frequency penalty — penalizes a token if it appears at all, regardless of count.
def apply_presence_penalty(logits, output_ids, presence_penalty):
"""Binary penalty: applied once if token appears at all."""
bin_counts = torch.zeros_like(logits)
bin_counts.scatter_add_(
1, output_ids,
torch.ones_like(output_ids, dtype=logits.dtype)
)
# Convert counts to binary presence
presence = (bin_counts > 0).float()
logits -= presence_penalty * presence
return logits
Repetition penalty is multiplicative (divides/multiplies logits). Frequency and presence penalties are additive (subtract from logits). Using both simultaneously can produce unexpected distributions. Most production APIs expose only one penalty family — OpenAI uses frequency + presence, while HuggingFace uses repetition penalty.
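The difference between the two families is easy to see on a toy example. The sketch below (arbitrary logit values; token 0 repeated twice) contrasts the multiplicative repetition penalty with the additive frequency penalty:

```python
import torch

# Toy vocabulary of 4 tokens; token 0 appeared twice, token 1 once.
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
seen = torch.tensor([[0, 0, 1]])

# Multiplicative (repetition penalty, theta = 1.2): the penalty is
# identical whether a token appeared once or ten times.
rep = logits.clone()
score = torch.gather(rep, 1, seen)
score = torch.where(score > 0, score / 1.2, score * 1.2)
rep.scatter_(1, seen, score)
# rep[0, 0] = 2.0 / 1.2 ~ 1.667; rep[0, 1] = -1.0 * 1.2 = -1.2

# Additive (frequency penalty, alpha = 0.5): scales with the count.
freq = logits.clone()
counts = torch.zeros_like(freq)
counts.scatter_add_(1, seen, torch.ones_like(seen, dtype=freq.dtype))
freq -= 0.5 * counts
# freq[0, 0] = 2.0 - 0.5 * 2 = 1.0; freq[0, 1] = -1.0 - 0.5 = -1.5
```

Note how the frequency penalty hits the twice-repeated token 0 harder than the once-seen token 1, while the repetition penalty applies the same relative adjustment to both.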
3. Stage 2: Temperature Scaling
Temperature scaling divides logits by a scalar T before softmax:
p_i = exp(l_i / T) / Σ_j exp(l_j / T)
The effect on the probability distribution:
- T = 1: no change
- T → 0: distribution concentrates on the highest-logit token (greedy)
- T → ∞: distribution approaches uniform
The entropy of the distribution is a monotonically increasing function of T. Specifically, for a categorical distribution with logits l and probabilities p = softmax(l / T), dH/dT = Var_p(l) / T³ ≥ 0 for all T > 0.
def apply_temperature(logits, temperature):
"""
Scale logits by temperature.
temperature=0 is handled as greedy (argmax).
"""
if temperature == 0.0:
# Greedy: set all logits to -inf except the max
max_logits = logits.max(dim=-1, keepdim=True).values
logits = torch.where(
logits == max_logits,
logits,
torch.full_like(logits, float('-inf'))
)
return logits
return logits / temperature
The correct implementation of temperature 0 is argmax (greedy decoding), not a division. Frameworks handle this as a special case. In vLLM, temperature=0 skips the entire sampling pipeline and directly calls torch.argmax. This is both semantically correct and faster.
Temperature and Calibration
Temperature also appears in knowledge distillation and model calibration, but with different semantics. In inference serving, temperature is purely a generation-time control. A temperature of 0.7 is common for chat applications (reduces but does not eliminate randomness). A temperature of 1.0 is standard for evaluation benchmarks. Temperatures above 1.0 are used for creative writing or when diversity is desired.
The compute cost is a single element-wise division: B · V FLOPs, fully memory-bound. On GPU, this fuses trivially with subsequent operations.
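The monotonicity of entropy in temperature can be checked numerically. A minimal sketch with arbitrary logits:

```python
import torch

def entropy_at(logits, T):
    """Shannon entropy of softmax(logits / T), in nats."""
    p = torch.softmax(logits / T, dim=-1)
    return -(p * p.log()).sum().item()

l = torch.tensor([3.0, 1.0, 0.0, -1.0])
entropies = [entropy_at(l, T) for T in (0.5, 1.0, 2.0)]
# entropy rises monotonically as T increases
```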
4. Stage 3: Top-k Filtering
Top-k filtering retains only the k highest-logit tokens and sets all others to −∞:
l_i' = l_i if l_i ≥ l_(k), else −∞
where l_(k) is the k-th largest logit value.
The implementation requires finding the -th largest element — a selection problem, not a full sort:
def apply_top_k(logits, k):
"""
Retain only top-k logits, set rest to -inf.
Uses torch.topk which internally uses a partial sort
(radix select or heap-based selection).
"""
if k <= 0 or k >= logits.shape[-1]:
return logits # No filtering
# Find the k-th largest value per batch element
top_k_values, _ = torch.topk(logits, k, dim=-1) # [B, k]
threshold = top_k_values[:, -1:] # [B, 1] — the k-th value
# Mask everything below threshold
logits = torch.where(
logits >= threshold,
logits,
torch.full_like(logits, float('-inf'))
)
return logits
Compute Cost
torch.topk on GPU uses a radix-based selection algorithm with complexity O(V) per batch element (not O(V log V) — it does not fully sort). For typical k and V = 128256, this is fast: approximately 2 microseconds per batch element on H100.
The key insight: top-k is cheap because it does NOT require sorting. It only needs to find the k-th order statistic, which can be done in linear time via partial sort or selection algorithms.
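A quick sanity check of this property: thresholding against the k-th value returned by torch.topk leaves exactly k finite logits per row (random values, so ties are not a concern):

```python
import torch

torch.manual_seed(0)
logits = torch.randn(2, 1000)
k = 50
# k-th order statistic per row via topk (a selection, not a full sort)
threshold = torch.topk(logits, k, dim=-1).values[:, -1:]
masked = torch.where(
    logits >= threshold, logits,
    torch.full_like(logits, float('-inf'))
)
survivors = torch.isfinite(masked).sum(dim=-1)
# exactly k logits survive in each row
```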
[Chart: top-k kernel time vs k, in microseconds (V=128256, B=1, H100)]
5. Stage 4: Top-p (Nucleus) Filtering
Top-p sampling (Holtzman et al., 2020) is more nuanced than top-k. Instead of a fixed count, it retains the smallest set of tokens whose cumulative probability exceeds a threshold p:
keep the smallest prefix S of the sorted vocabulary such that Σ_{i ∈ S} p_i ≥ p
where the sum is over tokens sorted by probability in descending order, and we include the token that causes the cumulative sum to first exceed p.
This requires a full sort of the vocabulary by logit value.
def apply_top_p(logits, p):
"""
Nucleus (top-p) filtering.
Retains tokens whose cumulative probability mass is <= p.
Requires: sorted probabilities and cumulative sum.
"""
if p >= 1.0:
return logits # No filtering
# Sort logits in descending order
sorted_logits, sorted_indices = torch.sort(
logits, dim=-1, descending=True
) # Both: [B, V]
# Convert to probabilities for cumulative sum
sorted_probs = torch.softmax(sorted_logits, dim=-1) # [B, V]
# Cumulative sum of probabilities
cumulative_probs = torch.cumsum(sorted_probs, dim=-1) # [B, V]
# Create mask: remove tokens where cumulative prob > p
# Shift right by 1 so the token that crosses p is included
sorted_mask = cumulative_probs - sorted_probs > p # [B, V]
# Set filtered tokens to -inf in sorted space
sorted_logits[sorted_mask] = float('-inf')
# Unsort: scatter back to original positions
logits = torch.zeros_like(logits)
logits.scatter_(1, sorted_indices, sorted_logits)
return logits
The Sort Bottleneck
The dominant cost is torch.sort on a tensor of size [B, V]. GPU sorting algorithms (radix sort, bitonic sort) have complexity O(V log V) per batch element. For V = 128256:
V log₂ V ≈ 128256 × 17 ≈ 2.2 × 10⁶
So we perform approximately 2.2 million comparison-swap operations per batch element. At large batch sizes, this is the most expensive step in the post-model pipeline.
Top-p Sort Time vs Vocabulary Size (B=1, H100)
| Vocabulary Size | Sort Time | Cumsum Time | Total Top-p Time | Fraction of Pipeline |
|---|---|---|---|---|
| 32000 (Llama 2) | 2.5 us | 0.8 us | 4.2 us | 52% |
| 50257 (GPT-2) | 3.8 us | 1.2 us | 6.1 us | 55% |
| 128256 (Llama 3) | 8.5 us | 2.8 us | 13.1 us | 62% |
| 151936 (Qwen 2) | 9.8 us | 3.2 us | 15.0 us | 63% |
| 256000 (Gemma 2) | 15.2 us | 5.1 us | 23.0 us | 67% |
Combined Top-k + Top-p
Most serving systems apply top-k before top-p. This is not just for quality — it is an optimization. If top-k reduces the candidate set to k tokens, the subsequent sort for top-p operates on k elements instead of V:
def apply_top_k_top_p(logits, k, p):
"""Combined: top-k first (cheap), then top-p on the reduced set."""
# Step 1: Top-k (O(V) selection, not sort)
if k > 0:
top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
# Now work with [B, k] tensor instead of [B, V]
else:
top_k_values = logits
        top_k_indices = torch.arange(
            logits.shape[-1], device=logits.device
        ).expand_as(logits)
# Step 2: Top-p on the top-k set (O(k log k) sort)
sorted_values, sorted_local = torch.sort(
top_k_values, dim=-1, descending=True
) # [B, k]
sorted_probs = torch.softmax(sorted_values, dim=-1)
cumsum = torch.cumsum(sorted_probs, dim=-1)
mask = (cumsum - sorted_probs) > p
sorted_values[mask] = float('-inf')
# Unsort within top-k
top_k_values.scatter_(1, sorted_local, sorted_values)
# Scatter back to full vocabulary
result = torch.full_like(logits, float('-inf'))
result.scatter_(1, top_k_indices, top_k_values)
return result
With k = 50 and V = 128256: the sort is on 50 elements instead of 128256 — a reduction of more than three orders of magnitude in sort work.
When both top-k and top-p are active, applying top-k first reduces the sort from O(V log V) to O(k log k). For k = 50, this is a 3-4 orders of magnitude reduction in sort work. This is why vLLM and SGLang always apply top-k before top-p, even if the user only requested top-p — they set a default (e.g., −1, meaning “all”) but internally cap it.
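Back-of-envelope arithmetic, using comparison counts as a rough proxy for sort work, makes the reduction concrete:

```python
import math

# Sort work saved by applying top-k (k = 50) before top-p on Llama 3's
# vocabulary (V = 128256), counting V log2 V vs k log2 k comparisons.
V, k = 128256, 50
full_sort = V * math.log2(V)      # ~2.2e6 comparison units
reduced_sort = k * math.log2(k)   # ~2.8e2 comparison units
ratio = full_sort / reduced_sort  # ~7.7e3, i.e. ~3.9 orders of magnitude
```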
6. Stage 5: Multinomial Sampling
After filtering, the remaining logits are converted to a probability distribution via softmax, and a token is sampled:
def sample_token(logits):
"""
Sample a token from filtered logits.
Args:
logits: [B, V] with -inf for filtered tokens
Returns:
token_ids: [B] sampled token indices
"""
probs = torch.softmax(logits, dim=-1) # [B, V]
token_ids = torch.multinomial(probs, num_samples=1) # [B, 1]
return token_ids.squeeze(-1) # [B]
How torch.multinomial Works on GPU
torch.multinomial with num_samples=1 performs:
- Compute the CDF: cumsum(probs) — O(V) per batch element
- Draw a uniform random number u ~ Uniform(0, 1)
- Binary search for the smallest index i where CDF[i] ≥ u — O(log V)
The total cost is dominated by the cumulative sum: O(B · V).
def multinomial_manual(probs):
"""Manual implementation of multinomial sampling."""
# CDF via cumulative sum
cdf = torch.cumsum(probs, dim=-1) # [B, V]
# Random uniform
u = torch.rand(probs.shape[0], 1, device=probs.device) # [B, 1]
# Find first index where CDF >= u
# This is equivalent to binary search
mask = cdf >= u # [B, V]
# argmax on a boolean tensor returns the first True index
token_ids = mask.to(torch.int32).argmax(dim=-1) # [B]
return token_ids
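Worth noting as an aside (it is not part of the pipeline described above): the Gumbel-max trick draws from the same distribution without building a CDF at all, which is sometimes convenient in fused kernels. A sketch:

```python
import torch

def gumbel_max_sample(logits):
    """Sample from softmax(logits) without a cumulative sum:
    argmax(logits + Gumbel noise) has exactly the softmax distribution.
    -inf logits (filtered tokens) stay -inf and are never selected."""
    u = torch.rand_like(logits).clamp_min(1e-20)  # avoid log(0)
    gumbel = -torch.log(-torch.log(u))
    return torch.argmax(logits + gumbel, dim=-1)

torch.manual_seed(0)
logits = torch.tensor([[5.0, 0.0, 0.0, float('-inf')]])
samples = [gumbel_max_sample(logits).item() for _ in range(1000)]
# token 0 holds ~98% of the probability mass; the filtered token 3
# can never be sampled
```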
Greedy Decoding as a Special Case
When temperature is 0 (or when the user requests greedy decoding), the entire sampling pipeline reduces to:
token_ids = torch.argmax(logits, dim=-1) # O(BV)
This skips all filtering stages. Production systems detect temperature=0 early and bypass the pipeline.
Beam Search
Beam search is not sampling — it maintains k candidate sequences (beams) and expands each by the top-k tokens:
def beam_search_step(logits, beam_scores, beam_width):
"""
One step of beam search.
Args:
logits: [B * beam_width, V] logits from all beams
beam_scores: [B * beam_width] accumulated log-probs
beam_width: int
"""
log_probs = torch.log_softmax(logits, dim=-1) # [B*beam, V]
# Add accumulated beam scores
next_scores = log_probs + beam_scores.unsqueeze(-1) # [B*beam, V]
# Reshape to [B, beam * V] and select top beam_width
B = next_scores.shape[0] // beam_width
next_scores = next_scores.view(B, -1) # [B, beam*V]
top_scores, top_indices = torch.topk(
next_scores, beam_width, dim=-1
) # [B, beam]
# Decode beam and token indices
beam_indices = top_indices // logits.shape[-1] # which beam
token_indices = top_indices % logits.shape[-1] # which token
return top_scores, beam_indices, token_indices
Beam search with k beams requires a top-k selection over k · V elements per batch — more expensive than sampling but deterministic.
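The beam/token index arithmetic at the end of beam_search_step can be illustrated with a tiny hand-built example:

```python
import torch

# Recovering (beam, token) pairs from a flattened [beam * V] score
# tensor: integer division and modulo by the vocabulary size V.
V, beam_width = 10, 3
scores = torch.full((beam_width, V), -1e9)
scores[0, 7] = 1.0   # best candidate lives on beam 0, token 7
scores[2, 3] = 0.5   # runner-up lives on beam 2, token 3
flat = scores.view(1, -1)                    # [1, beam * V]
top_scores, top_idx = torch.topk(flat, 2, dim=-1)
beam_indices = top_idx // V    # which beam:  [[0, 2]]
token_indices = top_idx % V    # which token: [[7, 3]]
```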
7. Stage 6: Stop Criteria
After sampling a token, the system must decide whether generation is complete for each request in the batch.
EOS Token
The simplest stop criterion. Each model has a designated end-of-sequence token (or set of tokens). Llama 3 uses token IDs 128001 (<|end_of_text|>) and 128009 (<|eot_id|>).
def check_eos(token_ids, eos_token_ids):
"""
Check if any sampled token is an EOS token.
Args:
token_ids: [B] sampled tokens
eos_token_ids: set of EOS token IDs
Returns:
finished: [B] boolean mask
"""
finished = torch.zeros(token_ids.shape[0], dtype=torch.bool,
device=token_ids.device)
for eos_id in eos_token_ids:
finished |= (token_ids == eos_id)
return finished
Max Tokens
A hard limit on the number of generated tokens. Trivial to implement:
def check_max_tokens(generated_count, max_tokens):
"""Check if we have generated enough tokens."""
return generated_count >= max_tokens
Stop Strings
Stop strings (e.g., "\n\n", "```", "Human:") require matching against the decoded text, not token IDs. This is where things get complicated.
A stop string may span multiple tokens. The string "\n\n" could be:
- One token:
\n\n(if the tokenizer has this as a single token) - Two tokens:
\n+\n
The matching must happen on decoded text, not token sequences:
class StopStringChecker:
"""
Check if any stop string appears in the generated text.
Handles the case where a stop string spans a token boundary.
Maintains a buffer of recent decoded text.
"""
def __init__(self, stop_strings, tokenizer):
self.stop_strings = stop_strings
self.tokenizer = tokenizer
# Maximum stop string length determines buffer size
self.max_stop_len = max(len(s) for s in stop_strings)
# Per-request text buffers
self.buffers = {}
def check(self, request_id, new_token_id):
"""
Decode the new token, append to buffer, check for stops.
Returns: (is_stopped, matched_string, trimmed_output)
"""
new_text = self.tokenizer.decode(
[new_token_id], skip_special_tokens=False
)
if request_id not in self.buffers:
self.buffers[request_id] = ""
self.buffers[request_id] += new_text
# Check all stop strings
for stop_str in self.stop_strings:
idx = self.buffers[request_id].find(stop_str)
if idx != -1:
# Found a stop string — trim output
trimmed = self.buffers[request_id][:idx]
return True, stop_str, trimmed
# Trim buffer to max_stop_len (no need to keep old text)
if len(self.buffers[request_id]) > self.max_stop_len * 2:
self.buffers[request_id] = (
self.buffers[request_id][-self.max_stop_len:]
)
return False, None, None
Stop string detection must handle partial matches at token boundaries. A stop string "Human:" might appear as tokens ["Hum", "an", ":"]. The system must buffer decoded text and check after each token. This is why stop string evaluation is always done on the CPU side in decoded text, not in the GPU sampling kernel.
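A tokenizer-free sketch of the failure mode: checking each decoded piece in isolation never matches the stop string, while checking the accumulated buffer does:

```python
# The stop string "Human:" arriving as the pieces ["Hum", "an", ":"]:
# no individual piece contains it; only the accumulated buffer does.
pieces = ["Hum", "an", ":"]
stop = "Human:"

per_piece = [p.find(stop) for p in pieces]   # [-1, -1, -1]

buf, per_buffer = "", []
for p in pieces:
    buf += p
    per_buffer.append(buf.find(stop))        # [-1, -1, 0]
```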
Structured Stop Criteria
For structured generation (JSON, XML, function calls), stop criteria become more complex:
class JSONStopChecker:
"""Stop when a valid JSON object is complete."""
def __init__(self):
self.brace_depth = 0
self.in_string = False
self.escape_next = False
def check_char(self, char):
if self.escape_next:
self.escape_next = False
return False
if char == '\\' and self.in_string:
self.escape_next = True
return False
if char == '"':
self.in_string = not self.in_string
return False
if self.in_string:
return False
if char == '{':
self.brace_depth += 1
elif char == '}':
self.brace_depth -= 1
if self.brace_depth == 0:
return True # Complete JSON object
return False
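For testing, the same logic can be restated as a compact self-contained function (an equivalent of the class above, not a different algorithm) and exercised on an input with a brace inside a string:

```python
def json_complete_index(text):
    """Return the index at which a top-level JSON object closes,
    ignoring braces inside strings and escaped characters."""
    depth, in_str, esc = 0, False, False
    for i, ch in enumerate(text):
        if esc:
            esc = False
            continue
        if ch == '\\' and in_str:
            esc = True
            continue
        if ch == '"':
            in_str = not in_str
            continue
        if in_str:
            continue
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return i
    return -1

text = '{"note": "brace } inside a string", "child": {"a": 1}}'
idx = json_complete_index(text)  # the object closes only at the final brace
```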
8. Stage 7: Streaming and Detokenization
Streaming Architecture
In streaming mode, tokens are yielded to the client as they are generated. The server sends Server-Sent Events (SSE):
async def generate_stream(request, model, tokenizer):
"""
Streaming token generation.
Yields tokens as SSE events.
"""
input_ids = tokenizer.encode(request.prompt)
generated_ids = []
stop_checker = StopStringChecker(
request.stop_strings, tokenizer
)
detokenizer = IncrementalDetokenizer(tokenizer)
for step in range(request.max_tokens):
# Model forward pass
logits = model.forward(input_ids + generated_ids)
# logits: [1, V] (batch size 1 for this example)
# Apply sampling pipeline
logits = apply_logit_processors(
logits, input_ids, generated_ids, request
)
logits = apply_temperature(logits, request.temperature)
logits = apply_top_k(logits, request.top_k)
logits = apply_top_p(logits, request.top_p)
token_id = sample_token(logits).item()
generated_ids.append(token_id)
# Check EOS
if token_id in tokenizer.eos_token_ids:
yield {"event": "done", "data": ""}
break
# Incremental detokenization
new_text = detokenizer.decode_token(token_id)
# Check stop strings
stopped, _, trimmed = stop_checker.check(
request.id, token_id
)
if stopped:
if trimmed:
yield {"event": "token", "data": trimmed}
yield {"event": "done", "data": ""}
break
# Yield the new text fragment
if new_text:
yield {"event": "token", "data": new_text}
else:
# max_tokens reached
yield {"event": "done", "data": "[max_tokens]"}
Incremental Detokenization
Tokenizers like SentencePiece and BPE produce tokens that may not align with UTF-8 character boundaries. A single Unicode character might be split across multiple tokens. Incremental detokenization must handle this:
class IncrementalDetokenizer:
"""
Detokenize one token at a time, handling multi-token characters.
The challenge: decoding [token_1, token_2, ..., token_n] may produce
different text than decode(token_1) + decode(token_2) + ...
because tokenizers use context-dependent decoding (e.g., SentencePiece
adds a space prefix for certain tokens).
"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
self.token_ids = []
self.prev_text = ""
def decode_token(self, token_id):
"""
Decode one new token, return only the new text.
Strategy: decode all tokens so far, diff with previous decode.
This correctly handles context-dependent decoding.
"""
self.token_ids.append(token_id)
# Full decode of all tokens
full_text = self.tokenizer.decode(
self.token_ids, skip_special_tokens=True
)
# New text is the difference
new_text = full_text[len(self.prev_text):]
self.prev_text = full_text
return new_text
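The need for the diff can be demonstrated with a toy SentencePiece-like tokenizer (entirely hypothetical) whose decode drops a leading space only at the start of the text:

```python
class ToyTokenizer:
    """Toy context-dependent tokenizer: '_' marks a word-initial space,
    which is dropped at the start of the decoded text."""
    def __init__(self, vocab):
        self.vocab = vocab  # id -> piece

    def decode(self, ids, skip_special_tokens=True):
        text = "".join(self.vocab[i].replace("_", " ") for i in ids)
        return text.lstrip(" ")

tok = ToyTokenizer({0: "_Hello", 1: "_world", 2: "!"})

# Naive per-token decode + concatenation loses the inter-word spaces:
naive = "".join(tok.decode([i]) for i in [0, 1, 2])   # "Helloworld!"

# Decode-everything-and-diff recovers the correct text:
ids, prev, out = [], "", ""
for t in [0, 1, 2]:
    ids.append(t)
    full = tok.decode(ids)
    out += full[len(prev):]
    prev = full
# out == "Hello world!"
```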
This decode-everything-and-diff approach is correct but O(n²) over the full generation (decoding 1 + 2 + … + n tokens total). For long generations, this becomes expensive. The optimization is to use a sliding window:
class IncrementalDetokenizerOptimized:
    """
    Optimized incremental detokenizer.
    Only re-decodes a small window of recent tokens; text outside the
    window is committed and never decoded again.
    """
    def __init__(self, tokenizer, context_window=5):
        self.tokenizer = tokenizer
        self.context_window = context_window
        self.window_ids = []   # recent, uncommitted token IDs
        self.window_text = ""  # decoded text of window_ids
    def decode_token(self, token_id):
        """Decode one new token, return only the new text."""
        self.window_ids.append(token_id)
        text = self.tokenizer.decode(
            self.window_ids, skip_special_tokens=True
        )
        # New text is the diff against the previous window decode
        new_text = text[len(self.window_text):]
        self.window_text = text
        # Slide the window: once it exceeds context_window tokens,
        # drop the oldest token and re-anchor the window text.
        # Production implementations instead track prefix/read offsets
        # and guard against splitting inside an incomplete UTF-8
        # sequence.
        if len(self.window_ids) > self.context_window:
            self.window_ids.pop(0)
            self.window_text = self.tokenizer.decode(
                self.window_ids, skip_special_tokens=True
            )
        return new_text
9. Fused Logit Processing Kernels
In production, the separate stages (temperature, top-k, top-p, sample) are fused into a single GPU kernel to avoid multiple passes over the tensor:
// Fused sampling kernel: temperature + top-k + top-p + sample
// in a single pass over the vocabulary
__global__ void fused_sampling_kernel(
const float* __restrict__ logits, // [B, V]
const int* __restrict__ input_ids, // [B, S] for penalties
const int* __restrict__ seq_lens, // [B] actual lengths
float* __restrict__ output_probs, // [B, V] workspace
int* __restrict__ sampled_tokens, // [B] output
float temperature,
int top_k,
float top_p,
float rep_penalty,
int V,
int S,
unsigned long long seed
) {
int batch_idx = blockIdx.x;
int tid = threadIdx.x;
// Step 1: Apply repetition penalty (cooperative across threads)
// Each thread handles a chunk of the vocabulary
extern __shared__ float shared_logits[];
for (int v = tid; v < V; v += blockDim.x) {
float logit = logits[batch_idx * V + v];
// Check if token v appears in context
bool in_context = false;
int seq_len = seq_lens[batch_idx];
for (int s = 0; s < seq_len; s++) {
if (input_ids[batch_idx * S + s] == v) {
in_context = true;
break;
}
}
if (in_context && rep_penalty != 1.0f) {
logit = (logit > 0) ? logit / rep_penalty
: logit * rep_penalty;
}
// Step 2: Temperature
logit /= temperature;
output_probs[batch_idx * V + v] = logit;
}
__syncthreads();
// Step 3: Top-k via partial sort (warp-level reduction)
// Find k-th largest value using iterative threshold
// ... (radix-based selection in shared memory)
// Step 4: Softmax on surviving tokens
// Step 5: Cumulative sum for top-p
// Step 6: Sample from CDF
// This is simplified — real implementations use multi-pass
// with shared memory and warp shuffles
}
In practice, frameworks like vLLM use Triton kernels for this fusion:
import triton
import triton.language as tl
@triton.jit
def fused_top_k_top_p_sampling_kernel(
logits_ptr, # [B, V]
output_ptr, # [B] sampled token IDs
temperature,
top_k: tl.constexpr,
top_p,
V: tl.constexpr,
BLOCK_V: tl.constexpr,
seed,
):
batch_idx = tl.program_id(0)
# Load logits for this batch element
offsets = tl.arange(0, BLOCK_V)
mask = offsets < V
logits = tl.load(
logits_ptr + batch_idx * V + offsets,
mask=mask,
other=float('-inf')
)
# Temperature scaling
logits = logits / temperature
# Top-k: find k-th value via iterative approach
# (Triton does not have a native topk — use approximate methods
# or multiple passes)
# ... top-k filtering logic ...
# Softmax
max_logit = tl.max(logits, axis=0)
logits = logits - max_logit
exp_logits = tl.exp(logits)
sum_exp = tl.sum(exp_logits, axis=0)
probs = exp_logits / sum_exp
# Top-p via cumulative sum (requires sorted order)
# ... sorting in Triton is limited, typically uses
# bitonic sort for small V or falls back to PyTorch ...
# Sample: CDF + uniform random
cdf = tl.cumsum(probs, axis=0)
rand_val = tl.rand(seed, batch_idx)
selected = tl.sum((cdf < rand_val).to(tl.int32), axis=0)
tl.store(output_ptr + batch_idx, selected)
vLLM’s sampling kernel fuses temperature, top-k, top-p, and multinomial sampling into a single Triton kernel for batch sizes up to 256. For larger batches, it falls back to separate PyTorch operations because the Triton kernel’s shared memory usage limits occupancy. SGLang takes a similar approach with a custom CUDA kernel. The fusion eliminates 3-4 kernel launches and avoids writing intermediate tensors to HBM.
10. Complete Reference Pipeline
Here is the complete pipeline assembled from all stages:
import torch
from dataclasses import dataclass
@dataclass
class SamplingParams:
temperature: float = 1.0
top_k: int = -1 # -1 means disabled
top_p: float = 1.0 # 1.0 means disabled
repetition_penalty: float = 1.0
frequency_penalty: float = 0.0
presence_penalty: float = 0.0
max_tokens: int = 2048
stop_strings: list = None
eos_token_ids: set = None
def __post_init__(self):
if self.stop_strings is None:
self.stop_strings = []
if self.eos_token_ids is None:
self.eos_token_ids = set()
class TokenGenerationPipeline:
"""
Complete token generation pipeline.
Handles batched generation with per-request sampling params.
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def process_logits(self, logits, input_ids, output_ids, params):
"""
Full logit processing pipeline.
Args:
logits: [B, V] raw logits from model
input_ids: [B, S_in] prompt token IDs
output_ids: [B, S_out] generated token IDs so far
params: SamplingParams
Returns:
token_ids: [B] sampled token indices
"""
B, V = logits.shape
# Stage 0: Greedy shortcut
if params.temperature == 0.0:
return torch.argmax(logits, dim=-1)
# Stage 1a: Repetition penalty
if params.repetition_penalty != 1.0:
all_ids = torch.cat([input_ids, output_ids], dim=1)
score = torch.gather(logits, 1, all_ids)
score = torch.where(
score > 0,
score / params.repetition_penalty,
score * params.repetition_penalty,
)
logits.scatter_(1, all_ids, score)
# Stage 1b: Frequency penalty
if params.frequency_penalty != 0.0:
bin_counts = torch.zeros_like(logits)
bin_counts.scatter_add_(
1, output_ids,
torch.ones_like(output_ids, dtype=logits.dtype),
)
logits -= params.frequency_penalty * bin_counts
# Stage 1c: Presence penalty
if params.presence_penalty != 0.0:
bin_counts = torch.zeros_like(logits)
bin_counts.scatter_add_(
1, output_ids,
torch.ones_like(output_ids, dtype=logits.dtype),
)
presence = (bin_counts > 0).float()
logits -= params.presence_penalty * presence
# Stage 2: Temperature
logits = logits / params.temperature
# Stage 3: Top-k
if 0 < params.top_k < V:
top_k_vals, _ = torch.topk(logits, params.top_k, dim=-1)
threshold = top_k_vals[:, -1:]
logits = torch.where(
logits >= threshold, logits,
torch.full_like(logits, float('-inf')),
)
# Stage 4: Top-p
if params.top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(
logits, dim=-1, descending=True
)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumsum = torch.cumsum(sorted_probs, dim=-1)
mask = (cumsum - sorted_probs) > params.top_p
sorted_logits[mask] = float('-inf')
logits = torch.zeros_like(logits)
logits.scatter_(1, sorted_idx, sorted_logits)
# Stage 5: Sample
probs = torch.softmax(logits, dim=-1)
token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
return token_ids
def generate(self, prompt, params):
"""
Full generation loop.
Args:
prompt: str
params: SamplingParams
Yields:
str: generated text fragments
"""
input_ids = self.tokenizer.encode(prompt)
input_tensor = torch.tensor(
[input_ids], device='cuda'
) # [1, S]
output_ids = torch.zeros(
1, 0, dtype=torch.long, device='cuda'
)
detokenizer = IncrementalDetokenizer(self.tokenizer)
stop_checker = StopStringChecker(
params.stop_strings, self.tokenizer
)
for step in range(params.max_tokens):
# Forward pass (with KV cache in practice)
with torch.no_grad():
logits = self.model(
torch.cat([input_tensor, output_ids], dim=1)
) # [1, S+step, V]
logits = logits[:, -1, :] # [1, V] — last position
# Sample
token_id = self.process_logits(
logits, input_tensor, output_ids, params
) # [1]
# Append to output
            output_ids = torch.cat(
                [output_ids, token_id.view(1, 1)], dim=1
            )
tid = token_id.item()
# Check EOS
if tid in params.eos_token_ids:
break
# Detokenize
new_text = detokenizer.decode_token(tid)
# Check stop strings
stopped, _, trimmed = stop_checker.check("req_0", tid)
if stopped:
if trimmed:
yield trimmed
break
if new_text:
yield new_text
Per-Request Sampling Parameters in Batched Serving
In continuous batching, each request in the batch may have different sampling parameters. This requires either:
-
Per-request kernel dispatch: Execute sampling separately for each request. Simple but serializes work.
-
Batched with parameter arrays: Pass arrays of per-request parameters to the kernel.
def batched_sample_heterogeneous(
logits, # [B, V]
temperatures, # [B]
top_ks, # [B]
top_ps, # [B]
):
"""
Batched sampling where each request has different params.
Each request is processed independently but in parallel.
"""
B, V = logits.shape
token_ids = torch.empty(B, dtype=torch.long, device='cuda')
# Temperature: vectorized (different T per request)
logits = logits / temperatures.unsqueeze(-1) # [B, V] / [B, 1]
# Top-k: must handle different k per request
for b in range(B):
k = top_ks[b].item()
p = top_ps[b].item()
row = logits[b] # [V]
if 0 < k < V:
vals, _ = torch.topk(row, k)
row = torch.where(
row >= vals[-1], row,
torch.tensor(float('-inf'), device='cuda'),
)
if p < 1.0:
sorted_row, sorted_idx = torch.sort(
row, descending=True
)
probs = torch.softmax(sorted_row, dim=-1)
cumsum = torch.cumsum(probs, dim=-1)
mask = (cumsum - probs) > p
sorted_row[mask] = float('-inf')
row = torch.zeros_like(row)
row.scatter_(0, sorted_idx, sorted_row)
probs = torch.softmax(row, dim=-1)
token_ids[b] = torch.multinomial(
probs.unsqueeze(0), 1
).squeeze()
return token_ids
The per-request loop is the bottleneck. vLLM solves this by grouping requests with identical sampling parameters and processing each group as a batch. In practice, most requests in a deployment use the same parameters (the API defaults), so this grouping is effective.
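A minimal sketch of that grouping step, assuming hypothetical request objects that carry their sampling parameters:

```python
from collections import defaultdict, namedtuple

# Hypothetical request objects carrying per-request sampling params.
Request = namedtuple("Request", ["temperature", "top_k", "top_p"])

def group_by_sampling_params(requests):
    """Map (temperature, top_k, top_p) -> list of request indices, so
    each group can run through the homogeneous batched pipeline."""
    groups = defaultdict(list)
    for i, r in enumerate(requests):
        groups[(r.temperature, r.top_k, r.top_p)].append(i)
    return dict(groups)

reqs = [
    Request(0.7, 50, 0.9),   # API default
    Request(1.0, -1, 1.0),   # pure sampling
    Request(0.7, 50, 0.9),   # API default again
]
groups = group_by_sampling_params(reqs)
# {(0.7, 50, 0.9): [0, 2], (1.0, -1, 1.0): [1]}
```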
[Chart: sampling pipeline latency vs batch size, in microseconds (V=128256, H100)]
Key Takeaways
-
The pipeline order matters: Penalties first, then temperature, then top-k, then top-p, then sample. Changing the order produces different output distributions.
-
Top-p is the expensive step: The sort dominates the sampling pipeline. Apply top-k first to reduce the sort domain.
-
Stop strings require CPU-side text matching: Token-level checking is insufficient because stop strings can span token boundaries. Buffer decoded text and check after each token.
-
Incremental detokenization is not trivial: Context-dependent tokenizers (SentencePiece) require decoding all tokens and diffing, or a carefully managed sliding window.
-
Fused kernels eliminate HBM round-trips: Combining temperature + top-k + top-p + sample into one kernel avoids writing and reading the tensor multiple times. This matters at large vocabulary sizes.
-
At typical batch sizes, the sampling pipeline is negligible: The model forward pass takes 10-30 ms; the sampling pipeline takes 10-300 microseconds. Optimize the forward pass first. But get the sampling semantics exactly right — incorrect penalties or filtering produce measurably worse output quality.