Part of the Inference Optimization series.

After the model forward pass produces a logits tensor of shape [B, V] — where B is the batch size and V is the vocabulary size — most inference guides stop. The tensor is there, the model has spoken, and the rest is “just sampling.” In practice, the pipeline between raw logits and an emitted token involves seven distinct stages, several of which have nontrivial compute cost, and the engineering of this pipeline affects output quality, latency, correctness, and throughput.

This post walks through the complete token generation pipeline: logit processors (repetition penalty, frequency penalty, presence penalty), temperature scaling, top-k filtering, top-p (nucleus) filtering, multinomial sampling, stop criteria evaluation (EOS token, max_tokens, stop strings, regex), and streaming detokenization. We cover the compute cost of each stage, show how to fuse them into a single GPU kernel, and provide a complete reference implementation.


1. The Full Pipeline: Stage by Stage

Here is the complete pipeline from logits to emitted token, in order:

logits [B, V]
  |
  v
Stage 1: Logit processors (repetition/frequency/presence penalty)
  |
  v
Stage 2: Temperature scaling
  |
  v
Stage 3: Top-k filtering
  |
  v
Stage 4: Top-p (nucleus) filtering
  |
  v
Stage 5: Multinomial sampling
  |
  v
Stage 6: Stop criteria evaluation
  |
  v
Stage 7: Detokenization + streaming output

Each stage transforms or filters the logits tensor before the final sampling step. The order matters — applying temperature before penalties produces different results than the reverse. The convention established by HuggingFace Transformers (and followed by vLLM, SGLang, and TensorRT-LLM) is the order shown above.
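To see the ordering effect concretely, here is a toy two-token example in pure Python (the penalty, temperature, and logit values are made up): an additive frequency penalty applied before temperature scaling is itself divided by T, so swapping the two stages changes the final distribution:

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

alpha, T = 0.5, 0.7      # frequency penalty and temperature (made up)
logits = [2.0, 1.0]      # token 0 already appeared once in the output

# Convention above: penalty first, then temperature
penalty_first = softmax([(logits[0] - alpha) / T, logits[1] / T])

# Reverse order: the penalty is no longer scaled by 1/T
temperature_first = softmax([logits[0] / T - alpha, logits[1] / T])

assert penalty_first != temperature_first
# With T < 1, applying the penalty first amplifies it
assert penalty_first[0] < temperature_first[0]
```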

📊

Compute Cost per Pipeline Stage (V=128256, Llama 3 vocab)

| Stage | Operation | Complexity | Time at B=1 (GPU) | Time at B=256 (GPU) |
|---|---|---|---|---|
| Logit processors | Scatter + conditional add | O(B * S) per penalty | 0.8 us | 12 us |
| Temperature | Element-wise divide | O(B * V) | 0.3 us | 5 us |
| Top-k filter | Partial sort (k-th element) | O(B * V) | 2.1 us | 45 us |
| Top-p filter | Full sort + prefix sum | O(B * V log V) | 8.5 us | 180 us |
| Sampling | CDF build + binary search | O(B * num_stops) → O(B * V) | 1.2 us | 15 us |
| Stop check | Token comparison | O(B * num_stops) | 0.1 us | 0.5 us |
| Detokenize | Lookup + UTF-8 decode | O(B * token_len) | 0.2 us | 2 us |
| **Total** | | | **13.2 us** | **260 us** |
Note: V=128256 (Llama 3). Times are GPU kernel execution time, excluding launch overhead. Top-p dominates due to the sort. For comparison, a single decode forward pass for Llama 70B at B=1 takes ~30ms — so the sampling pipeline is 0.04% of total latency.

At batch size 1, the entire pipeline is negligible compared to the model forward pass. At large batch sizes, top-p sorting becomes measurable. The real motivation for optimizing this pipeline is not latency but correctness and flexibility — getting the sampling semantics exactly right.

2. Stage 1: Logit Processors

Logit processors modify the raw logits before any probabilistic operation. They implement penalties designed to improve output quality.

Repetition Penalty

Introduced by Keskar et al. (2019). For each token t that appears in the context, modify its logit:

\text{logit}_t' = \begin{cases} \text{logit}_t / \theta & \text{if } \text{logit}_t > 0 \\ \text{logit}_t \times \theta & \text{if } \text{logit}_t \leq 0 \end{cases}

where \theta > 1.0 is the repetition penalty parameter. This reduces the probability of tokens that have already appeared, regardless of how many times they appeared.

def apply_repetition_penalty(logits, input_ids, penalty):
    """
    Apply repetition penalty to logits.

    Args:
        logits: [B, V] raw logits
        input_ids: [B, S] token IDs in context
        penalty: float, typically 1.0-1.3
    """
    # Gather logits for tokens that appear in context
    # input_ids may contain duplicates; we only care about unique tokens
    for b in range(logits.shape[0]):
        unique_tokens = torch.unique(input_ids[b])
        token_logits = logits[b, unique_tokens]

        # Apply asymmetric penalty
        positive_mask = token_logits > 0
        token_logits[positive_mask] /= penalty
        token_logits[~positive_mask] *= penalty

        logits[b, unique_tokens] = token_logits

    return logits

The per-batch loop is problematic for GPU efficiency. A vectorized implementation:

def apply_repetition_penalty_vectorized(logits, input_ids, penalty):
    """Vectorized repetition penalty — no Python loops."""
    # Gather the logits at every context position (duplicates included)
    # input_ids: [B, S], values in [0, V)
    score = torch.gather(logits, 1, input_ids)  # [B, S]

    # Apply penalty
    score = torch.where(score > 0, score / penalty, score * penalty)

    # Scatter back — this handles duplicates by overwriting
    # (all duplicates get the same penalty, so overwrite is correct)
    logits.scatter_(1, input_ids, score)
    return logits

Frequency Penalty

Penalizes tokens proportionally to how many times they appear in the context. Used by the OpenAI API.

\text{logit}_t' = \text{logit}_t - \alpha_f \times \text{count}(t)

where \alpha_f is the frequency penalty coefficient (typically 0.0-2.0) and \text{count}(t) is the number of times token t appears in the generated text.

def apply_frequency_penalty(logits, output_ids, frequency_penalty):
    """
    Args:
        logits: [B, V]
        output_ids: [B, S_out] generated token IDs (not prompt)
        frequency_penalty: float, 0.0-2.0
    """
    # Count occurrences of each token in generated output
    bin_counts = torch.zeros_like(logits)  # [B, V]
    bin_counts.scatter_add_(
        1, output_ids,
        torch.ones_like(output_ids, dtype=logits.dtype)
    )

    # Subtract penalty * count
    logits -= frequency_penalty * bin_counts
    return logits

Presence Penalty

A binary version of frequency penalty — penalizes a token if it appears at all, regardless of count.

\text{logit}_t' = \text{logit}_t - \alpha_p \times \mathbb{1}[\text{count}(t) > 0]

def apply_presence_penalty(logits, output_ids, presence_penalty):
    """Binary penalty: applied once if token appears at all."""
    bin_counts = torch.zeros_like(logits)
    bin_counts.scatter_add_(
        1, output_ids,
        torch.ones_like(output_ids, dtype=logits.dtype)
    )

    # Convert counts to binary presence
    presence = (bin_counts > 0).float()
    logits -= presence_penalty * presence
    return logits
ℹ️ Penalty Interaction

Repetition penalty is multiplicative (divides/multiplies logits). Frequency and presence penalties are additive (subtract from logits). Using both simultaneously can produce unexpected distributions. Most production APIs expose only one penalty family — OpenAI uses frequency + presence, while HuggingFace uses repetition penalty.
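A quick numeric sketch of why the two families behave differently (the values are made up): the multiplicative repetition penalty scales with the magnitude of the logit, while the additive frequency penalty subtracts a fixed amount no matter how large or small the logit is:

```python
theta = 1.3   # repetition penalty
alpha = 0.5   # frequency penalty, count = 1

def repetition(logit):
    # Multiplicative: divide positive logits, multiply negative ones
    return logit / theta if logit > 0 else logit * theta

def frequency(logit):
    # Additive: subtract a constant per occurrence
    return logit - alpha

# Large positive logit: the multiplicative penalty is much stronger
assert abs(10.0 - repetition(10.0)) > abs(10.0 - frequency(10.0))
# Near-zero logit: the multiplicative penalty is nearly a no-op
assert abs(0.1 - repetition(0.1)) < abs(0.1 - frequency(0.1))
```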

3. Stage 2: Temperature Scaling

Temperature scaling divides logits by a scalar T > 0 before softmax:

p_i = \frac{\exp(\text{logit}_i / T)}{\sum_j \exp(\text{logit}_j / T)}

The effect on the probability distribution:

  • T = 1.0: no change
  • T \to 0^+: distribution concentrates on the highest-logit token (greedy)
  • T \to \infty: distribution approaches uniform

The entropy of the distribution is a monotonically increasing function of T. Specifically, for a categorical distribution with logits z:

H(T) = \log\left(\sum_j \exp(z_j / T)\right) - \frac{\sum_j z_j \exp(z_j / T)}{T \sum_j \exp(z_j / T)}
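The monotonicity is easy to check numerically; a pure-Python sketch with arbitrary logits:

```python
import math

def entropy(logits, T):
    """Shannon entropy of softmax(logits / T), in nats."""
    exps = [math.exp(z / T) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return -sum(p * math.log(p) for p in probs if p > 0)

z = [3.0, 1.0, 0.5, -2.0]
hs = [entropy(z, T) for T in (0.25, 0.5, 1.0, 2.0, 8.0)]

assert all(a < b for a, b in zip(hs, hs[1:]))  # entropy rises with T
assert hs[-1] < math.log(len(z))               # bounded by uniform entropy
```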

def apply_temperature(logits, temperature):
    """
    Scale logits by temperature.
    temperature=0 is handled as greedy (argmax).
    """
    if temperature == 0.0:
        # Greedy: set all logits to -inf except the max
        max_logits = logits.max(dim=-1, keepdim=True).values
        logits = torch.where(
            logits == max_logits,
            logits,
            torch.full_like(logits, float('-inf'))
        )
        return logits

    return logits / temperature
⚠️ Temperature 0 Is Not Division by Zero

The correct implementation of temperature 0 is argmax (greedy decoding), not a division. Frameworks handle this as a special case. In vLLM, temperature=0 skips the entire sampling pipeline and directly calls torch.argmax. This is both semantically correct and faster.

Temperature and Calibration

Temperature also appears in knowledge distillation and model calibration, but with different semantics. In inference serving, temperature is purely a generation-time control. A temperature of 0.7 is common for chat applications (reduces but does not eliminate randomness). A temperature of 1.0 is standard for evaluation benchmarks. Temperatures above 1.0 are used for creative writing or when diversity is desired.

The compute cost is a single element-wise division: O(BV) FLOPs, fully memory-bound. On GPU, this fuses trivially with subsequent operations.

4. Stage 3: Top-k Filtering

Top-k filtering retains only the k highest-logit tokens and sets all others to -\infty:

\text{logit}_i' = \begin{cases} \text{logit}_i & \text{if } \text{logit}_i \geq \text{logit}_{(k)} \\ -\infty & \text{otherwise} \end{cases}

where \text{logit}_{(k)} is the k-th largest logit value.

The implementation requires finding the k-th largest element — a selection problem, not a full sort:

def apply_top_k(logits, k):
    """
    Retain only top-k logits, set rest to -inf.

    Uses torch.topk which internally uses a partial sort
    (radix select or heap-based selection).
    """
    if k <= 0 or k >= logits.shape[-1]:
        return logits  # No filtering

    # Find the k-th largest value per batch element
    top_k_values, _ = torch.topk(logits, k, dim=-1)  # [B, k]
    threshold = top_k_values[:, -1:]  # [B, 1] — the k-th value

    # Mask everything below threshold
    logits = torch.where(
        logits >= threshold,
        logits,
        torch.full_like(logits, float('-inf'))
    )
    return logits

Compute Cost

torch.topk on GPU uses a radix-based selection algorithm with complexity O(V) per batch element (not O(V \log V) — it does not fully sort). For V = 128256 and k = 50, this is fast: approximately 2 microseconds per batch element on H100.

The key insight: top-k is cheap because it does NOT require sorting. It only needs to find the k-th order statistic, which can be done in linear time via partial sort or selection algorithms.
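The same selection idea in pure Python: heapq.nlargest finds the k-th largest value without sorting the whole list — an O(V log k) stand-in for the GPU's radix select (top_k_mask is a hypothetical helper, not a library function):

```python
import heapq
import math

def top_k_mask(logits, k):
    """Keep the k largest logits, set the rest to -inf."""
    # k-th largest value = the threshold; no full sort needed
    threshold = heapq.nlargest(k, logits)[-1]
    return [x if x >= threshold else -math.inf for x in logits]

masked = top_k_mask([0.1, 2.0, -1.0, 3.0, 0.5], k=2)
assert masked == [-math.inf, 2.0, -math.inf, 3.0, -math.inf]
```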

Top-k Kernel Time vs k (V=128256, B=1, H100)

| k | Kernel time (us) |
|---|---|
| 1 (argmax) | 0.5 |
| 10 | 1.2 |
| 50 | 2.1 |
| 100 | 2.4 |
| 1000 | 3.8 |
| 10000 | 6.5 |
| V (no-op) | 0 |

5. Stage 4: Top-p (Nucleus) Filtering

Top-p sampling (Holtzman et al., 2020) is more nuanced than top-k. Instead of a fixed count, it retains the smallest set of tokens whose cumulative probability exceeds a threshold p:

\text{nucleus}(p) = \{t : \sum_{t' \in \text{sorted}} P(t') \leq p\}

where the sum is over tokens sorted by probability in descending order, and we include the token that causes the cumulative sum to first exceed p.

This requires a full sort of the vocabulary by logit value.

def apply_top_p(logits, p):
    """
    Nucleus (top-p) filtering.
    Retains tokens whose cumulative probability mass is <= p.

    Requires: sorted probabilities and cumulative sum.
    """
    if p >= 1.0:
        return logits  # No filtering

    # Sort logits in descending order
    sorted_logits, sorted_indices = torch.sort(
        logits, dim=-1, descending=True
    )  # Both: [B, V]

    # Convert to probabilities for cumulative sum
    sorted_probs = torch.softmax(sorted_logits, dim=-1)  # [B, V]

    # Cumulative sum of probabilities
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)  # [B, V]

    # Create mask: remove tokens where cumulative prob > p
    # Shift right by 1 so the token that crosses p is included
    sorted_mask = cumulative_probs - sorted_probs > p  # [B, V]

    # Set filtered tokens to -inf in sorted space
    sorted_logits[sorted_mask] = float('-inf')

    # Unsort: scatter back to original positions
    logits = torch.zeros_like(logits)
    logits.scatter_(1, sorted_indices, sorted_logits)

    return logits
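A worked example of the shift-right trick on a five-token distribution (pure Python, made-up probabilities): with p = 0.8, the token that first crosses the threshold is still kept, because the mass *before* it does not exceed p:

```python
import itertools

probs = [0.5, 0.3, 0.1, 0.06, 0.04]   # already sorted descending
p = 0.8

cumulative = list(itertools.accumulate(probs))  # ~[0.5, 0.8, 0.9, 0.96, 1.0]

# Keep token i if the cumulative mass BEFORE it is <= p
keep = [c - q <= p for c, q in zip(cumulative, probs)]

# Token 2 (cumulative 0.9 > p) survives: mass before it is 0.8 <= p
assert keep == [True, True, True, False, False]
```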

The Sort Bottleneck

The dominant cost is torch.sort on a tensor of size [B, V]. GPU sorting algorithms (radix sort, bitonic sort) have complexity O(V \log V) per batch element. For V = 128256:

\log_2(128256) \approx 17

So we perform approximately 128256 \times 17 \approx 2.18\text{M} comparison-swap operations per batch element. At large batch sizes, this is the most expensive step in the post-model pipeline.

📊

Top-p Sort Time vs Vocabulary Size (B=1, H100)

| Vocabulary Size | Sort Time | Cumsum Time | Total Top-p Time | Fraction of Pipeline |
|---|---|---|---|---|
| 32000 (Llama 2) | 2.5 us | 0.8 us | 4.2 us | 52% |
| 50257 (GPT-2) | 3.8 us | 1.2 us | 6.1 us | 55% |
| 128256 (Llama 3) | 8.5 us | 2.8 us | 13.1 us | 62% |
| 151936 (Qwen 2) | 9.8 us | 3.2 us | 15.0 us | 63% |
| 256000 (Gemma 2) | 15.2 us | 5.1 us | 23.0 us | 67% |

Note: Sort dominates top-p cost. Larger vocabularies (the trend in modern models) make top-p proportionally more expensive.

Combined Top-k + Top-p

Most serving systems apply top-k before top-p. This is not just for quality — it is an optimization. If top-k reduces the candidate set to k tokens, the subsequent sort for top-p operates on k elements instead of V:

def apply_top_k_top_p(logits, k, p):
    """Combined: top-k first (cheap), then top-p on the reduced set."""
    # Step 1: Top-k (O(V) selection, not sort)
    if k > 0:
        top_k_values, top_k_indices = torch.topk(logits, k, dim=-1)
        # Now work with [B, k] tensor instead of [B, V]
    else:
        # No top-k: operate on the full vocabulary. Clone so the in-place
        # scatter_ below does not mutate the caller's logits.
        top_k_values = logits.clone()
        top_k_indices = torch.arange(
            logits.shape[-1], device=logits.device
        ).expand_as(logits).contiguous()

    # Step 2: Top-p on the top-k set (O(k log k) sort)
    sorted_values, sorted_local = torch.sort(
        top_k_values, dim=-1, descending=True
    )  # [B, k]

    sorted_probs = torch.softmax(sorted_values, dim=-1)
    cumsum = torch.cumsum(sorted_probs, dim=-1)

    mask = (cumsum - sorted_probs) > p
    sorted_values[mask] = float('-inf')

    # Unsort within top-k
    top_k_values.scatter_(1, sorted_local, sorted_values)

    # Scatter back to full vocabulary
    result = torch.full_like(logits, float('-inf'))
    result.scatter_(1, top_k_indices, top_k_values)
    return result

With k = 50 and V = 128256: the sort is on 50 elements instead of 128256 — a 2565× reduction in sort work.

Always Apply Top-k Before Top-p

When both top-k and top-p are active, applying top-k first reduces the sort from O(V \log V) to O(k \log k). For k = 50, this is a 3-4 orders of magnitude reduction in sort work. This is why vLLM and SGLang always apply top-k before top-p, even if the user only requested top-p — they expose a default k (k = -1 meaning “all”) but internally cap it.

6. Stage 5: Multinomial Sampling

After filtering, the remaining logits are converted to a probability distribution via softmax, and a token is sampled:

def sample_token(logits):
    """
    Sample a token from filtered logits.

    Args:
        logits: [B, V] with -inf for filtered tokens
    Returns:
        token_ids: [B] sampled token indices
    """
    probs = torch.softmax(logits, dim=-1)  # [B, V]
    token_ids = torch.multinomial(probs, num_samples=1)  # [B, 1]
    return token_ids.squeeze(-1)  # [B]

How torch.multinomial Works on GPU

torch.multinomial with num_samples=1 performs:

  1. Compute the CDF via cumsum(probs): O(V) per batch element
  2. Draw a uniform random number u \sim U(0, 1)
  3. Binary search for the smallest index i where \text{CDF}[i] \geq u: O(\log V)

The total cost is dominated by the cumulative sum: O(BV).

def multinomial_manual(probs):
    """Manual implementation of multinomial sampling."""
    # CDF via cumulative sum
    cdf = torch.cumsum(probs, dim=-1)  # [B, V]

    # Random uniform
    u = torch.rand(probs.shape[0], 1, device=probs.device)  # [B, 1]

    # Find first index where CDF >= u
    # This is equivalent to binary search
    mask = cdf >= u  # [B, V]
    # argmax on a boolean tensor returns the first True index
    token_ids = mask.to(torch.int32).argmax(dim=-1)  # [B]

    return token_ids
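The same CDF-plus-binary-search idea in pure Python, using bisect (a sketch; sample_from_probs is a hypothetical helper, and the RNG is injected so runs are reproducible):

```python
import bisect
import itertools
import random

def sample_from_probs(probs, rng):
    """Sample a category index via CDF build + binary search."""
    cdf = list(itertools.accumulate(probs))
    u = rng.random() * cdf[-1]   # scale by total mass to absorb rounding
    # bisect_right skips any zero-probability prefix whose cdf entries tie
    return bisect.bisect_right(cdf, u)

rng = random.Random(0)
samples = [sample_from_probs([0.1, 0.0, 0.9], rng) for _ in range(1000)]
assert 1 not in samples                      # zero-prob token never drawn
assert samples.count(2) > samples.count(0)   # 0.9 beats 0.1
```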

Greedy Decoding as a Special Case

When temperature is 0 (or when the user requests greedy decoding), the entire sampling pipeline reduces to:

token_ids = torch.argmax(logits, dim=-1)  # O(BV)

This skips all filtering stages. Production systems detect temperature=0 early and bypass the pipeline.

Beam search is not sampling — it maintains b candidate sequences and expands each by the top-b tokens:

def beam_search_step(logits, beam_scores, beam_width):
    """
    One step of beam search.

    Args:
        logits: [B * beam_width, V] logits from all beams
        beam_scores: [B * beam_width] accumulated log-probs
        beam_width: int
    """
    log_probs = torch.log_softmax(logits, dim=-1)  # [B*beam, V]

    # Add accumulated beam scores
    next_scores = log_probs + beam_scores.unsqueeze(-1)  # [B*beam, V]

    # Reshape to [B, beam * V] and select top beam_width
    B = next_scores.shape[0] // beam_width
    next_scores = next_scores.view(B, -1)  # [B, beam*V]

    top_scores, top_indices = torch.topk(
        next_scores, beam_width, dim=-1
    )  # [B, beam]

    # Decode beam and token indices
    beam_indices = top_indices // logits.shape[-1]  # which beam
    token_indices = top_indices % logits.shape[-1]  # which token

    return top_scores, beam_indices, token_indices

Beam search with b beams requires a top-k selection over b \times V elements per batch — more expensive than sampling but deterministic.
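The flat-index decoding trick (beam = index // V, token = index % V) is easier to see in a pure-Python sketch over the flattened (beam, token) score matrix (beam_step and its inputs are illustrative, not from any framework):

```python
import heapq

def beam_step(log_probs, beam_scores, beam_width):
    """One beam-search step.
    log_probs: per-beam lists of next-token log-probs, each length V."""
    V = len(log_probs[0])
    # Flatten to beam*V candidate scores
    scores = [beam_scores[b] + lp
              for b, row in enumerate(log_probs)
              for lp in row]
    best = heapq.nlargest(beam_width, range(len(scores)),
                          key=scores.__getitem__)
    # Decode each flat index back into (beam, token)
    return [(scores[i], i // V, i % V) for i in best]

result = beam_step(
    log_probs=[[-0.1, -2.0], [-0.5, -3.0]],
    beam_scores=[0.0, -1.0],
    beam_width=2,
)
assert result == [(-0.1, 0, 0), (-1.5, 1, 0)]
```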

7. Stage 6: Stop Criteria

After sampling a token, the system must decide whether generation is complete for each request in the batch.

EOS Token

The simplest stop criterion. Each model has a designated end-of-sequence token (or set of tokens). Llama 3 uses token ID 128001 (<|end_of_text|>) and 128009 (<|eot_id|>).

def check_eos(token_ids, eos_token_ids):
    """
    Check if any sampled token is an EOS token.

    Args:
        token_ids: [B] sampled tokens
        eos_token_ids: set of EOS token IDs
    Returns:
        finished: [B] boolean mask
    """
    finished = torch.zeros(token_ids.shape[0], dtype=torch.bool,
                           device=token_ids.device)
    for eos_id in eos_token_ids:
        finished |= (token_ids == eos_id)
    return finished

Max Tokens

A hard limit on the number of generated tokens. Trivial to implement:

def check_max_tokens(generated_count, max_tokens):
    """Check if we have generated enough tokens."""
    return generated_count >= max_tokens

Stop Strings

Stop strings (e.g., "\n\n", "```", "Human:") require matching against the decoded text, not token IDs. This is where things get complicated.

A stop string may span multiple tokens. The string "\n\n" could be:

  • One token: \n\n (if the tokenizer has this as a single token)
  • Two tokens: \n + \n

The matching must happen on decoded text, not token sequences:

class StopStringChecker:
    """
    Check if any stop string appears in the generated text.

    Handles the case where a stop string spans a token boundary.
    Maintains a buffer of recent decoded text.
    """

    def __init__(self, stop_strings, tokenizer):
        self.stop_strings = stop_strings
        self.tokenizer = tokenizer
        # Maximum stop string length determines buffer size
        self.max_stop_len = max(len(s) for s in stop_strings)
        # Per-request text buffers
        self.buffers = {}

    def check(self, request_id, new_token_id):
        """
        Decode the new token, append to buffer, check for stops.

        Returns: (is_stopped, matched_string, trimmed_output)
        """
        new_text = self.tokenizer.decode(
            [new_token_id], skip_special_tokens=False
        )

        if request_id not in self.buffers:
            self.buffers[request_id] = ""

        self.buffers[request_id] += new_text

        # Check all stop strings
        for stop_str in self.stop_strings:
            idx = self.buffers[request_id].find(stop_str)
            if idx != -1:
                # Found a stop string — trim output
                trimmed = self.buffers[request_id][:idx]
                return True, stop_str, trimmed

        # Trim buffer to max_stop_len (no need to keep old text)
        if len(self.buffers[request_id]) > self.max_stop_len * 2:
            self.buffers[request_id] = (
                self.buffers[request_id][-self.max_stop_len:]
            )

        return False, None, None
⚠️ Stop Strings Are Tricky

Stop string detection must handle partial matches at token boundaries. A stop string "Human:" might appear as tokens ["Hum", "an", ":"]. The system must buffer decoded text and check after each token. This is why stop string evaluation is always done on the CPU side in decoded text, not in the GPU sampling kernel.
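A related streaming subtlety: text that could still grow into a stop string must be held back rather than sent to the client, or the client may see the first half of a stop string. A minimal pure-Python sketch (held_back_length is a hypothetical helper, not from any framework):

```python
def held_back_length(buffer, stop_strings):
    """How many trailing characters of `buffer` match a proper prefix
    of some stop string and must not be streamed out yet."""
    longest = 0
    for stop in stop_strings:
        # Check proper prefixes of `stop`, longest first
        for plen in range(min(len(stop) - 1, len(buffer)), 0, -1):
            if buffer.endswith(stop[:plen]):
                longest = max(longest, plen)
                break
    return longest

# "Hum" could become "Human:" — hold back 3 chars, stream the rest
assert held_back_length("Hello Hum", ["Human:"]) == 3
assert held_back_length("Hello there", ["Human:"]) == 0
```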

Structured Stop Criteria

For structured generation (JSON, XML, function calls), stop criteria become more complex:

class JSONStopChecker:
    """Stop when a valid JSON object is complete."""

    def __init__(self):
        self.brace_depth = 0
        self.in_string = False
        self.escape_next = False

    def check_char(self, char):
        if self.escape_next:
            self.escape_next = False
            return False

        if char == '\\' and self.in_string:
            self.escape_next = True
            return False

        if char == '"':
            self.in_string = not self.in_string
            return False

        if self.in_string:
            return False

        if char == '{':
            self.brace_depth += 1
        elif char == '}':
            self.brace_depth -= 1
            if self.brace_depth == 0:
                return True  # Complete JSON object

        return False

8. Stage 7: Streaming and Detokenization

Streaming Architecture

In streaming mode, tokens are yielded to the client as they are generated. The server sends Server-Sent Events (SSE):

async def generate_stream(request, model, tokenizer):
    """
    Streaming token generation.
    Yields tokens as SSE events.
    """
    input_ids = tokenizer.encode(request.prompt)
    generated_ids = []

    stop_checker = StopStringChecker(
        request.stop_strings, tokenizer
    )
    detokenizer = IncrementalDetokenizer(tokenizer)

    for step in range(request.max_tokens):
        # Model forward pass
        logits = model.forward(input_ids + generated_ids)
        # logits: [1, V] (batch size 1 for this example)

        # Apply sampling pipeline
        logits = apply_logit_processors(
            logits, input_ids, generated_ids, request
        )
        logits = apply_temperature(logits, request.temperature)
        logits = apply_top_k(logits, request.top_k)
        logits = apply_top_p(logits, request.top_p)
        token_id = sample_token(logits).item()

        generated_ids.append(token_id)

        # Check EOS
        if token_id in tokenizer.eos_token_ids:
            yield {"event": "done", "data": ""}
            break

        # Incremental detokenization
        new_text = detokenizer.decode_token(token_id)

        # Check stop strings
        stopped, _, trimmed = stop_checker.check(
            request.id, token_id
        )
        if stopped:
            if trimmed:
                yield {"event": "token", "data": trimmed}
            yield {"event": "done", "data": ""}
            break

        # Yield the new text fragment
        if new_text:
            yield {"event": "token", "data": new_text}

    else:
        # max_tokens reached
        yield {"event": "done", "data": "[max_tokens]"}

Incremental Detokenization

Tokenizers like SentencePiece and BPE produce tokens that may not align with UTF-8 character boundaries. A single Unicode character might be split across multiple tokens. Incremental detokenization must handle this:
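The byte-level version of the problem is easy to reproduce with Python's stdlib: splitting a multi-byte UTF-8 character across chunks breaks naive per-chunk decoding, while a stateful incremental decoder buffers the dangling bytes:

```python
import codecs

data = "café".encode("utf-8")   # b'caf\xc3\xa9' — 'é' is two bytes
chunks = [data[:4], data[4:]]   # split inside 'é'

# Naive: decode each chunk independently — the split char is mangled
naive = "".join(c.decode("utf-8", errors="replace") for c in chunks)

# Stateful: buffer the dangling byte until the character completes
dec = codecs.getincrementaldecoder("utf-8")()
streamed = "".join(dec.decode(c) for c in chunks)

assert naive == "caf\ufffd\ufffd"   # two replacement characters
assert streamed == "café"
```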

class IncrementalDetokenizer:
    """
    Detokenize one token at a time, handling multi-token characters.

    The challenge: decoding [token_1, token_2, ..., token_n] may produce
    different text than decode(token_1) + decode(token_2) + ...
    because tokenizers use context-dependent decoding (e.g., SentencePiece
    adds a space prefix for certain tokens).
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.token_ids = []
        self.prev_text = ""

    def decode_token(self, token_id):
        """
        Decode one new token, return only the new text.

        Strategy: decode all tokens so far, diff with previous decode.
        This correctly handles context-dependent decoding.
        """
        self.token_ids.append(token_id)

        # Full decode of all tokens
        full_text = self.tokenizer.decode(
            self.token_ids, skip_special_tokens=True
        )

        # New text is the difference
        new_text = full_text[len(self.prev_text):]
        self.prev_text = full_text

        return new_text

This decode-everything-and-diff approach is correct but O(n^2) over the full generation (decoding n tokens total). For long generations, this becomes expensive. The optimization is to use a sliding window:

class IncrementalDetokenizerOptimized:
    """
    Optimized incremental detokenizer.
    Re-decodes only a small window of recent tokens, so per-token cost
    is O(context_window) instead of O(n).

    Caveat: a tokenizer may revise earlier text when new tokens arrive
    (e.g., SentencePiece leading-space handling), but only within a few
    tokens of the boundary, so a small window suffices in practice.
    """

    def __init__(self, tokenizer, context_window=5):
        self.tokenizer = tokenizer
        self.context_window = context_window
        self.window_ids = []     # the last few token IDs
        self.window_text = ""    # decoded text of window_ids

    def decode_token(self, token_id):
        """Decode one new token, return only the newly produced text."""
        self.window_ids.append(token_id)
        text = self.tokenizer.decode(
            self.window_ids, skip_special_tokens=True
        )
        # New text is the diff against the previous decode of the window
        new_text = text[len(self.window_text):]

        # Slide the window: drop the oldest token and re-decode, so the
        # next diff is taken against the shortened baseline
        if len(self.window_ids) > self.context_window:
            self.window_ids.pop(0)
            self.window_text = self.tokenizer.decode(
                self.window_ids, skip_special_tokens=True
            )
        else:
            self.window_text = text

        return new_text

9. Fused Logit Processing Kernels

In production, the separate stages (temperature, top-k, top-p, sample) are fused into a single GPU kernel to avoid multiple passes over the [B, V] tensor:

// Fused sampling kernel: temperature + top-k + top-p + sample
// in a single pass over the vocabulary
__global__ void fused_sampling_kernel(
    const float* __restrict__ logits,   // [B, V]
    const int* __restrict__ input_ids,  // [B, S] for penalties
    const int* __restrict__ seq_lens,   // [B] actual lengths
    float* __restrict__ output_probs,   // [B, V] workspace
    int* __restrict__ sampled_tokens,   // [B] output
    float temperature,
    int top_k,
    float top_p,
    float rep_penalty,
    int V,
    int S,
    unsigned long long seed
) {
    int batch_idx = blockIdx.x;
    int tid = threadIdx.x;

    // Step 1: Apply repetition penalty (cooperative across threads)
    // Each thread handles a chunk of the vocabulary
    extern __shared__ float shared_logits[];

    for (int v = tid; v < V; v += blockDim.x) {
        float logit = logits[batch_idx * V + v];

        // Check if token v appears in context
        bool in_context = false;
        int seq_len = seq_lens[batch_idx];
        for (int s = 0; s < seq_len; s++) {
            if (input_ids[batch_idx * S + s] == v) {
                in_context = true;
                break;
            }
        }

        if (in_context && rep_penalty != 1.0f) {
            logit = (logit > 0) ? logit / rep_penalty
                                : logit * rep_penalty;
        }

        // Step 2: Temperature
        logit /= temperature;

        output_probs[batch_idx * V + v] = logit;
    }
    __syncthreads();

    // Step 3: Top-k via partial sort (warp-level reduction)
    // Find k-th largest value using iterative threshold
    // ... (radix-based selection in shared memory)

    // Step 4: Softmax on surviving tokens
    // Step 5: Cumulative sum for top-p
    // Step 6: Sample from CDF

    // This is simplified — real implementations use multi-pass
    // with shared memory and warp shuffles
}

In practice, frameworks like vLLM use Triton kernels for this fusion:

import triton
import triton.language as tl

@triton.jit
def fused_top_k_top_p_sampling_kernel(
    logits_ptr,        # [B, V]
    output_ptr,        # [B] sampled token IDs
    temperature,
    top_k: tl.constexpr,
    top_p,
    V: tl.constexpr,
    BLOCK_V: tl.constexpr,
    seed,
):
    batch_idx = tl.program_id(0)

    # Load logits for this batch element
    offsets = tl.arange(0, BLOCK_V)
    mask = offsets < V
    logits = tl.load(
        logits_ptr + batch_idx * V + offsets,
        mask=mask,
        other=float('-inf')
    )

    # Temperature scaling
    logits = logits / temperature

    # Top-k: find k-th value via iterative approach
    # (Triton does not have a native topk — use approximate methods
    #  or multiple passes)
    # ... top-k filtering logic ...

    # Softmax
    max_logit = tl.max(logits, axis=0)
    logits = logits - max_logit
    exp_logits = tl.exp(logits)
    sum_exp = tl.sum(exp_logits, axis=0)
    probs = exp_logits / sum_exp

    # Top-p via cumulative sum (requires sorted order)
    # ... sorting in Triton is limited, typically uses
    #     bitonic sort for small V or falls back to PyTorch ...

    # Sample: CDF + uniform random
    cdf = tl.cumsum(probs, axis=0)
    rand_val = tl.rand(seed, batch_idx)
    selected = tl.sum((cdf < rand_val).to(tl.int32), axis=0)

    tl.store(output_ptr + batch_idx, selected)
💡 Fused Kernels in Practice

vLLM’s sampling kernel fuses temperature, top-k, top-p, and multinomial sampling into a single Triton kernel for batch sizes up to 256. For larger batches, it falls back to separate PyTorch operations because the Triton kernel’s shared memory usage limits occupancy. SGLang takes a similar approach with a custom CUDA kernel. The fusion eliminates 3-4 kernel launches and avoids writing intermediate [B, V] tensors to HBM.
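The savings in the callout can be sanity-checked with quick arithmetic. The sketch below assumes fp32 intermediates and three avoided round trips (one per unfused stage); both are illustrative assumptions, not measured vLLM internals:

```python
# Back-of-envelope HBM traffic saved by fusing the sampling pipeline.
# Assumes an fp32 [B, V] intermediate per unfused stage; V matches the
# Llama-3 vocabulary size used elsewhere in this article.
B, V = 256, 128_256
bytes_per_tensor = B * V * 4  # fp32

# Unfused: temperature, top-k, and top-p each write an intermediate
# [B, V] tensor that the next kernel reads back (3 round trips).
round_trips = 3
extra_traffic = round_trips * 2 * bytes_per_tensor  # write + read

print(f"intermediate tensor: {bytes_per_tensor / 1e6:.0f} MB")
print(f"extra HBM traffic avoided: {extra_traffic / 1e9:.2f} GB")
```

At B=256 each intermediate is roughly 131 MB, so fusion avoids nearly 0.8 GB of HBM traffic per sampling step.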

10. Complete Reference Pipeline

Here is the complete pipeline assembled from all stages:

import torch
from dataclasses import dataclass

@dataclass
class SamplingParams:
    temperature: float = 1.0
    top_k: int = -1           # -1 means disabled
    top_p: float = 1.0        # 1.0 means disabled
    repetition_penalty: float = 1.0
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    max_tokens: int = 2048
    stop_strings: list | None = None
    eos_token_ids: set | None = None

    def __post_init__(self):
        if self.stop_strings is None:
            self.stop_strings = []
        if self.eos_token_ids is None:
            self.eos_token_ids = set()


class TokenGenerationPipeline:
    """
    Complete token generation pipeline.
    Handles batched generation with per-request sampling params.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def process_logits(self, logits, input_ids, output_ids, params):
        """
        Full logit processing pipeline.

        Args:
            logits: [B, V] raw logits from model
            input_ids: [B, S_in] prompt token IDs
            output_ids: [B, S_out] generated token IDs so far
            params: SamplingParams
        Returns:
            token_ids: [B] sampled token indices
        """
        B, V = logits.shape

        # Stage 0: Greedy shortcut
        if params.temperature == 0.0:
            return torch.argmax(logits, dim=-1)

        # Stage 1a: Repetition penalty
        if params.repetition_penalty != 1.0:
            all_ids = torch.cat([input_ids, output_ids], dim=1)
            score = torch.gather(logits, 1, all_ids)
            score = torch.where(
                score > 0,
                score / params.repetition_penalty,
                score * params.repetition_penalty,
            )
            logits.scatter_(1, all_ids, score)

        # Stage 1b: Frequency penalty
        if params.frequency_penalty != 0.0:
            bin_counts = torch.zeros_like(logits)
            bin_counts.scatter_add_(
                1, output_ids,
                torch.ones_like(output_ids, dtype=logits.dtype),
            )
            logits -= params.frequency_penalty * bin_counts

        # Stage 1c: Presence penalty
        if params.presence_penalty != 0.0:
            bin_counts = torch.zeros_like(logits)
            bin_counts.scatter_add_(
                1, output_ids,
                torch.ones_like(output_ids, dtype=logits.dtype),
            )
            presence = (bin_counts > 0).float()
            logits -= params.presence_penalty * presence

        # Stage 2: Temperature
        logits = logits / params.temperature

        # Stage 3: Top-k
        if 0 < params.top_k < V:
            top_k_vals, _ = torch.topk(logits, params.top_k, dim=-1)
            threshold = top_k_vals[:, -1:]
            logits = torch.where(
                logits >= threshold, logits,
                torch.full_like(logits, float('-inf')),
            )

        # Stage 4: Top-p
        if params.top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(
                logits, dim=-1, descending=True
            )
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = (cumsum - sorted_probs) > params.top_p
            sorted_logits[mask] = float('-inf')
            logits = torch.zeros_like(logits)
            logits.scatter_(1, sorted_idx, sorted_logits)

        # Stage 5: Sample
        probs = torch.softmax(logits, dim=-1)
        token_ids = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return token_ids

    def generate(self, prompt, params):
        """
        Full generation loop.

        Args:
            prompt: str
            params: SamplingParams
        Yields:
            str: generated text fragments
        """
        input_ids = self.tokenizer.encode(prompt)
        input_tensor = torch.tensor(
            [input_ids], device='cuda'
        )  # [1, S]
        output_ids = torch.zeros(
            1, 0, dtype=torch.long, device='cuda'
        )

        detokenizer = IncrementalDetokenizer(self.tokenizer)
        stop_checker = StopStringChecker(
            params.stop_strings, self.tokenizer
        )

        for step in range(params.max_tokens):
            # Forward pass (with KV cache in practice)
            with torch.no_grad():
                logits = self.model(
                    torch.cat([input_tensor, output_ids], dim=1)
                )  # [1, S+step, V]
                logits = logits[:, -1, :]  # [1, V] — last position

            # Sample
            token_id = self.process_logits(
                logits, input_tensor, output_ids, params
            )  # [1]

            # Append to output
            output_ids = torch.cat(
                [output_ids, token_id.view(1, 1)],  # [1] -> [1, 1]
                dim=1,
            )

            tid = token_id.item()

            # Check EOS
            if tid in params.eos_token_ids:
                break

            # Detokenize
            new_text = detokenizer.decode_token(tid)

            # Check stop strings
            stopped, _, trimmed = stop_checker.check("req_0", tid)
            if stopped:
                if trimmed:
                    yield trimmed
                break

            if new_text:
                yield new_text
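The Stage 4 mask `(cumsum - probs) > top_p` is easy to misread: subtracting each token's own probability from the cumulative sum means a token is filtered only if the mass *before* it already exceeds top_p, so the token that crosses the threshold is kept. A tiny plain-Python example with made-up, pre-sorted probabilities:

```python
# Worked example of the Stage 4 top-p mask. Probabilities are
# illustrative and already sorted in descending order.
probs = [0.5, 0.3, 0.15, 0.05]
top_p = 0.7

cumsum, running = [], 0.0
for p in probs:
    running += p
    cumsum.append(running)

# mask[i] True => token i is filtered out
mask = [(c - p) > top_p for c, p in zip(cumsum, probs)]
kept = [p for p, m in zip(probs, mask) if not m]

print(mask)  # [False, False, True, True]
print(kept)  # [0.5, 0.3]
```

Note that token 1 (p=0.3) survives even though the cumulative sum crosses 0.7 at it; a naive `cumsum > top_p` mask would wrongly drop it and leave only 0.5 of the probability mass.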

Per-Request Sampling Parameters in Batched Serving

In continuous batching, each request in the batch may have different sampling parameters. This requires either:

  1. Per-request kernel dispatch: Execute sampling separately for each request. Simple but serializes work.

  2. Batched with parameter arrays: Pass arrays of per-request parameters to the kernel.

def batched_sample_heterogeneous(
    logits,          # [B, V]
    temperatures,    # [B]
    top_ks,          # [B]
    top_ps,          # [B]
):
    """
    Batched sampling where each request has different params.
    Each request is processed independently but in parallel.
    """
    B, V = logits.shape
    token_ids = torch.empty(B, dtype=torch.long, device='cuda')

    # Temperature: vectorized (different T per request)
    logits = logits / temperatures.unsqueeze(-1)  # [B, V] / [B, 1]

    # Top-k / top-p: different k and p per request force a Python loop
    for b in range(B):
        k = top_ks[b].item()
        p = top_ps[b].item()

        row = logits[b]  # [V]

        if 0 < k < V:
            vals, _ = torch.topk(row, k)
            row = torch.where(
                row >= vals[-1], row,
                torch.tensor(float('-inf'), device='cuda'),
            )

        if p < 1.0:
            sorted_row, sorted_idx = torch.sort(
                row, descending=True
            )
            probs = torch.softmax(sorted_row, dim=-1)
            cumsum = torch.cumsum(probs, dim=-1)
            mask = (cumsum - probs) > p
            sorted_row[mask] = float('-inf')
            row = torch.zeros_like(row)
            row.scatter_(0, sorted_idx, sorted_row)

        probs = torch.softmax(row, dim=-1)
        token_ids[b] = torch.multinomial(
            probs.unsqueeze(0), 1
        ).squeeze()

    return token_ids

The per-request loop is the bottleneck. vLLM solves this by grouping requests with identical sampling parameters and processing each group as a batch. In practice, most requests in a deployment use the same parameters (the API defaults), so this grouping is effective.
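The grouping step itself is simple bookkeeping. A minimal sketch (a hypothetical simplification of what vLLM does; the request dicts and key construction here are illustrative):

```python
# Bucket request indices by their (temperature, top_k, top_p) tuple so
# each group can be sampled with one vectorized call instead of a
# Python loop over individual rows.
from collections import defaultdict

def group_by_params(requests):
    """requests: list of dicts with 'temperature', 'top_k', 'top_p'."""
    groups = defaultdict(list)
    for idx, req in enumerate(requests):
        key = (req["temperature"], req["top_k"], req["top_p"])
        groups[key].append(idx)
    return dict(groups)

requests = [
    {"temperature": 1.0, "top_k": -1, "top_p": 1.0},   # API default
    {"temperature": 0.7, "top_k": 50, "top_p": 0.9},
    {"temperature": 1.0, "top_k": -1, "top_p": 1.0},   # API default
    {"temperature": 1.0, "top_k": -1, "top_p": 1.0},   # API default
]
groups = group_by_params(requests)
print(groups)
```

Each group would then index-select its rows from the [B, V] logits tensor and run the vectorized sampling path once. With mostly-default traffic, the number of groups is far smaller than B.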

Sampling Pipeline Latency vs Batch Size (V=128256, H100)

  Batch size   Latency (µs)
  B=1                    13
  B=8                    28
  B=32                   65
  B=64                  110
  B=128                 195
  B=256                 360
  B=512                 680
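To put these numbers in context, compare them against a nominal decode-step forward pass. The 20 ms figure below is an assumed, illustrative latency, not a measurement from this article:

```python
# Sampling latency (microseconds) from the measurements above, as a
# fraction of an assumed 20 ms decode-step forward pass.
latency_us = {1: 13, 8: 28, 32: 65, 64: 110, 128: 195, 256: 360, 512: 680}
forward_ms = 20.0  # assumed forward-pass latency per decode step

for b, us in latency_us.items():
    overhead = us / (forward_ms * 1000) * 100
    print(f"B={b:>3}: sampling = {overhead:.2f}% of step time")
```

Even at B=512 the sampling pipeline is about 3.4% of the step under this assumption, which is why the takeaways below recommend optimizing the forward pass first.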

Key Takeaways

  1. The pipeline order matters: Penalties first, then temperature, then top-k, then top-p, then sample. Changing the order produces different output distributions.

  2. Top-p is the expensive step: The O(V log V) sort dominates the sampling pipeline. Apply top-k first to reduce the sort domain.

  3. Stop strings require CPU-side text matching: Token-level checking is insufficient because stop strings can span token boundaries. Buffer decoded text and check after each token.

  4. Incremental detokenization is not trivial: Context-dependent tokenizers (SentencePiece) require decoding all tokens and diffing, or a carefully managed sliding window.

  5. Fused kernels eliminate HBM round-trips: Combining temperature + top-k + top-p + sample into one kernel avoids writing and reading the [B, V] tensor multiple times. This matters at large vocabulary sizes.

  6. At typical batch sizes, the sampling pipeline is negligible: The model forward pass takes 10-30 ms; the sampling pipeline takes 10-300 microseconds. Optimize the forward pass first. But get the sampling semantics exactly right — incorrect penalties or filtering produce measurably worse output quality.
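Takeaway 4's decode-and-diff approach can be illustrated with a toy character-level tokenizer. Both classes below are hypothetical stand-ins for a real tokenizer and the IncrementalDetokenizer used in the pipeline above:

```python
# Decode-and-diff incremental detokenization: decode the full token
# sequence each step and emit only the newly appended text suffix.
class ToyTokenizer:
    """Toy stand-in: each token ID is just a Unicode code point."""
    def decode(self, ids):
        return "".join(chr(i) for i in ids)

class DiffDetokenizer:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.ids = []
        self.prev_text = ""

    def decode_token(self, token_id):
        self.ids.append(token_id)
        text = self.tokenizer.decode(self.ids)
        new = text[len(self.prev_text):]  # diff against previous decode
        self.prev_text = text
        return new

d = DiffDetokenizer(ToyTokenizer())
pieces = [d.decode_token(i) for i in [72, 105, 33]]
print(pieces)  # ['H', 'i', '!']
```

With a real context-dependent tokenizer (SentencePiece), re-decoding the full sequence is O(n) per token; production detokenizers bound this with a sliding window of recent tokens, at the cost of carefully handling window boundaries.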