Part of series: Inference Optimization Timeline (37 of 60)
KV Cache Compression and Eviction: H2O, Attention Sinks, Sliding Window, and Quantized KV

A Llama 70B model with GQA (8 KV heads, head dimension 128, 80 layers) at 128K context length requires 2 × 80 × 8 × 128 × 128,000 × 2 bytes = 41.9 GB of KV cache per request in FP16 (K and V tensors × layers × KV heads × head dim × tokens × 2 bytes). On an 8x H100 cluster (640 GB total HBM), the model weights alone consume 140 GB (FP16), leaving 500 GB for KV cache. At 41.9 GB per request, the system can serve only 11 concurrent 128K-context requests. Reducing KV cache size directly increases serving throughput.
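The arithmetic above can be packaged as a small helper (a quick sketch; the constants are the Llama 70B figures from the text):

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """KV cache size per request: 2 tensors (K and V) per layer."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem

# Llama 70B with GQA: 80 layers, 8 KV heads, head_dim 128, 128K context, FP16
per_request = kv_cache_bytes(80, 8, 128, 128_000)
print(f"{per_request / 1e9:.1f} GB per request")                # 41.9 GB

# 8x H100 = 640 GB HBM, minus 140 GB FP16 weights = 500 GB for KV
print(f"{int(500e9 // per_request)} concurrent 128K requests")  # 11
```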

Four strategies address this, each with different tradeoffs between memory savings and output quality.

Strategy 1: KV Cache Quantization (INT8/FP8)

The most straightforward approach: store KV cache values in lower precision. KV cache activations have a fairly bounded dynamic range, and the V tensors in particular are consumed only through a post-softmax weighted average, which smooths quantization noise. This makes the cache amenable to low-bit storage with minimal quality loss.

Per-Token Asymmetric INT8

import torch

class KVCacheQuantizerINT8:
    """Quantize KV cache to INT8 with per-token scaling."""

    def __init__(self):
        pass

    def quantize_kv(self, key, value):
        """Quantize K and V tensors to INT8.
        key: [batch, num_kv_heads, seq_len, head_dim] in FP16
        value: [batch, num_kv_heads, seq_len, head_dim] in FP16
        returns: quantized tensors + scales
        """
        k_quant, k_scale, k_zero = self._quantize_per_token(key)
        v_quant, v_scale, v_zero = self._quantize_per_token(value)

        return {
            "k_quant": k_quant,    # int8
            "k_scale": k_scale,    # fp16, per-token
            "k_zero": k_zero,      # fp16, per-token
            "v_quant": v_quant,    # int8
            "v_scale": v_scale,    # fp16
            "v_zero": v_zero,      # fp16
        }

    def _quantize_per_token(self, tensor):
        """Asymmetric per-token INT8 quantization.
        Each token (across head_dim) gets its own scale and zero point."""
        # tensor: [batch, heads, seq_len, head_dim]
        # Compute min/max along head_dim
        t_min = tensor.amin(dim=-1, keepdim=True)
        t_max = tensor.amax(dim=-1, keepdim=True)

        # Scale and zero point
        scale = (t_max - t_min) / 255.0
        scale = torch.clamp(scale, min=1e-8)  # Avoid division by zero
        zero = t_min

        # Quantize
        quantized = torch.clamp(
            torch.round((tensor - zero) / scale), 0, 255
        ).to(torch.uint8)

        return quantized, scale.to(torch.float16), zero.to(torch.float16)

    def dequantize_kv(self, quantized_kv):
        """Dequantize for attention computation."""
        k = (quantized_kv["k_quant"].float() *
             quantized_kv["k_scale"] + quantized_kv["k_zero"]).half()
        v = (quantized_kv["v_quant"].float() *
             quantized_kv["v_scale"] + quantized_kv["v_zero"]).half()
        return k, v

    def memory_savings(self, seq_len, num_kv_heads, head_dim, num_layers):
        """Calculate memory savings from INT8 quantization."""
        # FP16: 2 tensors (K and V) x 2 bytes per element
        fp16_bytes = 2 * num_layers * num_kv_heads * seq_len * head_dim * 2

        # INT8: 2 tensors (K and V) x 1 byte per element
        int8_data = 2 * num_layers * num_kv_heads * seq_len * head_dim * 1
        # Scale and zero: 2 FP16 values (4 bytes) per token, per head, for K and V
        scale_overhead = 2 * num_layers * num_kv_heads * seq_len * 2 * 2

        int8_total = int8_data + scale_overhead
        savings = 1 - int8_total / fp16_bytes

        return {
            "fp16_gb": fp16_bytes / 1e9,
            "int8_gb": int8_total / 1e9,
            "savings_pct": savings * 100,
            "ratio": fp16_bytes / int8_total,
        }
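A quick round trip shows the error bound of the per-token scheme. This is a pure-Python sketch operating on a single token's head_dim values (the input values are arbitrary):

```python
def quantize_token(values, n_levels=256):
    """Asymmetric quantization of one token's values, as in the class above."""
    t_min, t_max = min(values), max(values)
    scale = max((t_max - t_min) / (n_levels - 1), 1e-8)  # avoid divide-by-zero
    q = [min(n_levels - 1, max(0, round((v - t_min) / scale))) for v in values]
    return q, scale, t_min  # t_min doubles as the zero point

def dequantize_token(q, scale, zero):
    return [qi * scale + zero for qi in q]

vals = [0.013, -1.7, 0.42, 2.9, -0.003, 1.1]
q, scale, zero = quantize_token(vals)
recon = dequantize_token(q, scale, zero)
max_err = max(abs(a - b) for a, b in zip(vals, recon))
# Worst-case reconstruction error is half a quantization step
assert max_err <= scale / 2 + 1e-9
```

Because the scale is computed per token, one outlier token widens only its own quantization step rather than the whole tensor's, which is why per-token scaling loses so little quality.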

FP8 KV Cache (Hopper Native)

On H100 GPUs, FP8 (E4M3 or E5M2) is natively supported in tensor cores:

class KVCacheQuantizerFP8:
    """FP8 KV cache quantization using Hopper native FP8 support."""

    def __init__(self, fp8_format="e4m3"):
        # E4M3: more precision, less range; E5M2: more range, less precision
        self.fp8_dtype = (torch.float8_e4m3fn if fp8_format == "e4m3"
                          else torch.float8_e5m2)

    def quantize_kv(self, key, value):
        """Quantize to FP8 with per-tensor scaling.
        FP8 E4M3 range: [-448, 448], 3 mantissa bits (~1 decimal digit)
        FP16 range: [-65504, 65504], 10 mantissa bits (~3.3 decimal digits)
        """
        # Per-tensor scale to fit FP16 range into FP8 range
        k_amax = key.abs().amax()
        v_amax = value.abs().amax()

        k_scale = torch.clamp(k_amax / 448.0, min=1e-8)  # 448 = E4M3 max
        v_scale = torch.clamp(v_amax / 448.0, min=1e-8)

        k_fp8 = (key / k_scale).to(self.fp8_dtype)
        v_fp8 = (value / v_scale).to(self.fp8_dtype)

        return {
            "k_fp8": k_fp8,
            "k_scale": k_scale,
            "v_fp8": v_fp8,
            "v_scale": v_scale,
        }

    def dequantize_for_attention(self, quantized_kv, query):
        """Dequantize and compute attention.
        On H100, FP8 GEMMs are natively supported, so we can
        compute Q @ K^T directly in FP8 without explicit dequant."""

        # Option 1: Dequantize then compute (fallback)
        k = quantized_kv["k_fp8"].to(torch.float16) * quantized_kv["k_scale"]
        v = quantized_kv["v_fp8"].to(torch.float16) * quantized_kv["v_scale"]

        # Option 2: FP8 matmul (Hopper native, 2x throughput)
        # scores = torch._scaled_mm(
        #     query.to(torch.float8_e4m3fn),
        #     quantized_kv["k_fp8"].transpose(-2, -1),
        #     scale_a=query_scale,
        #     scale_b=quantized_kv["k_scale"],
        # )

        return k, v
📊

KV Cache Quantization: Memory and Quality Impact

| Precision | Bytes/Element | Memory (128K, Llama 70B) | Quality Loss (PPL) | Throughput Gain |
|---|---|---|---|---|
| FP16 (baseline) | 2 | 41.9 GB | 0 (baseline) | 1.0x |
| FP8 (E4M3) | 1 | 21.0 GB | +0.02 PPL | 2.0x |
| INT8 (per-token) | ~1.03 | 21.6 GB | +0.05 PPL | 1.94x |
| INT4 (per-group) | ~0.56 | 11.7 GB | +0.3 PPL | 3.6x |
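The fractional bytes-per-element figures come from amortizing the scale/zero-point overhead over each quantization group. A quick sketch of that arithmetic (the group size of 64 for INT4 is an assumption that reproduces the table's ~0.56):

```python
def effective_bytes_per_element(bits, group_size, overhead_values=2,
                                overhead_bytes=2):
    """Data bits plus amortized scale/zero overhead (FP16 each) per element."""
    return bits / 8 + overhead_values * overhead_bytes / group_size

# INT8 per-token: scale + zero amortized over head_dim = 128 elements
print(effective_bytes_per_element(8, 128))  # 1.03125 -> "~1.03" in the table
# INT4 per-group: assuming a group size of 64
print(effective_bytes_per_element(4, 64))   # 0.5625  -> "~0.56"
```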

Strategy 2: H2O (Heavy Hitter Oracle)

H2O observes that attention patterns are highly non-uniform: a small fraction of tokens receive the majority of attention weight. These “heavy hitter” tokens should be kept, while low-attention tokens can be evicted.

Attention Score Tracking

class H2OKVCache:
    """H2O: Heavy-Hitter Oracle for KV cache eviction.
    Tracks cumulative attention scores and evicts low-importance tokens."""

    def __init__(self, max_cache_size, num_layers, num_kv_heads,
                 head_dim, heavy_hitter_ratio=0.5, recent_ratio=0.25):
        self.max_size = max_cache_size
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim

        # Budget allocation
        self.heavy_budget = int(max_cache_size * heavy_hitter_ratio)
        self.recent_budget = int(max_cache_size * recent_ratio)

        # Per-layer, per-head attention score accumulators
        # Tracks cumulative attention each token has received
        self.attention_scores = {}  # layer -> [batch, heads, seq_len]

        # KV cache storage
        self.kv_cache = {}  # layer -> (K, V) tensors
        self.token_indices = {}  # layer -> which original positions are cached

    def update_attention_scores(self, layer_idx, attention_weights):
        """Called after each attention computation.
        attention_weights: [batch, num_heads, 1, seq_len] (decode step)
        """
        new_scores = attention_weights.squeeze(2)  # [batch, heads, seq_len]

        if layer_idx not in self.attention_scores:
            self.attention_scores[layer_idx] = new_scores.clone()
            return

        # The cache grows by one token per decode step: zero-pad the
        # accumulator to the new length before adding
        old = self.attention_scores[layer_idx]
        pad = new_scores.shape[-1] - old.shape[-1]
        if pad > 0:
            old = torch.nn.functional.pad(old, (0, pad))

        # Accumulate: each token's total received attention
        self.attention_scores[layer_idx] = old + new_scores

    def evict_if_needed(self, layer_idx):
        """Evict low-importance tokens if cache exceeds budget."""
        if layer_idx not in self.kv_cache:
            return

        k, v = self.kv_cache[layer_idx]
        current_size = k.shape[2]  # seq_len dimension

        if current_size <= self.max_size:
            return  # No eviction needed

        scores = self.attention_scores[layer_idx]  # [batch, heads, seq_len]

        # Average across heads to get per-token importance
        token_importance = scores.mean(dim=1)  # [batch, seq_len]

        # Always keep: recent tokens (last recent_budget positions)
        recent_mask = torch.zeros(current_size, dtype=torch.bool,
                                   device=k.device)
        recent_mask[-self.recent_budget:] = True

        # Heavy hitters: top-K by cumulative attention score
        # Exclude recent tokens from heavy hitter selection
        non_recent_scores = token_importance.clone()
        non_recent_scores[:, -self.recent_budget:] = -float("inf")

        _, heavy_indices = torch.topk(
            non_recent_scores, self.heavy_budget, dim=-1
        )

        # Build keep mask: heavy hitters + recent tokens
        # (clone: scatter_ is in-place and cannot write to an expanded view)
        keep_mask = recent_mask.unsqueeze(0).expand_as(token_importance).clone()
        keep_mask.scatter_(1, heavy_indices, True)

        # Evict: keep only selected tokens
        keep_indices = keep_mask[0].nonzero().squeeze(-1)  # Assume batch=1

        self.kv_cache[layer_idx] = (
            k[:, :, keep_indices, :],
            v[:, :, keep_indices, :],
        )
        self.attention_scores[layer_idx] = scores[:, :, keep_indices]
        self.token_indices[layer_idx] = keep_indices

    def get_kv(self, layer_idx):
        """Get current KV cache for attention computation."""
        return self.kv_cache[layer_idx]

    def append_kv(self, layer_idx, new_k, new_v, attention_weights):
        """Append new KV and update scores, then evict if needed."""
        if layer_idx in self.kv_cache:
            k, v = self.kv_cache[layer_idx]
            self.kv_cache[layer_idx] = (
                torch.cat([k, new_k], dim=2),
                torch.cat([v, new_v], dim=2),
            )
        else:
            self.kv_cache[layer_idx] = (new_k, new_v)

        self.update_attention_scores(layer_idx, attention_weights)
        self.evict_if_needed(layer_idx)
ℹ️ Note

H2O’s key observation: in Llama-family models, approximately 5-10% of tokens consistently receive over 90% of cumulative attention weight. These are typically: (1) the BOS/system prompt tokens, (2) tokens marking structural boundaries (newlines, punctuation), and (3) semantically important content tokens. The heavy hitter pattern is consistent across layers, meaning the same tokens tend to be important at every layer.

H2O Quality Analysis

def analyze_h2o_quality(model, dataset, cache_sizes):
    """Measure quality degradation at different H2O cache budgets."""
    results = []

    for max_cache in cache_sizes:
        h2o_cache = H2OKVCache(
            max_cache_size=max_cache,
            num_layers=model.config.num_hidden_layers,
            num_kv_heads=model.config.num_key_value_heads,
            head_dim=model.config.hidden_size // model.config.num_attention_heads,
        )

        total_loss = 0
        num_tokens = 0

        for sample in dataset:
            # Run model with H2O cache
            logits = model.forward_with_h2o(sample.input_ids, h2o_cache)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), sample.labels.view(-1)
            )
            total_loss += loss.item() * sample.labels.numel()
            num_tokens += sample.labels.numel()

        ppl = torch.exp(torch.tensor(total_loss / num_tokens))
        results.append({
            "max_cache": max_cache,
            "perplexity": ppl.item(),
            "memory_ratio": max_cache / len(sample.input_ids),
        })

    return results

H2O Perplexity vs Cache Budget (Llama 70B, 128K Context)

| Cache budget | 5% | 10% | 20% | 30% | 50% | 75% | 100% |
|---|---|---|---|---|---|---|---|
| H2O (heavy hitter + recent) | 12.8 | 7.2 | 5.8 | 5.3 | 5.05 | 4.98 | 4.95 |
| Random eviction | 45.2 | 18.5 | 9.8 | 7.2 | 5.6 | 5.1 | 4.95 |
| Full KV (baseline) | 4.95 | 4.95 | 4.95 | 4.95 | 4.95 | 4.95 | 4.95 |

Strategy 3: Attention Sinks

StreamingLLM discovered that the first few tokens in any sequence receive disproportionately high attention, regardless of their semantic content. These “attention sinks” act as learned bias terms in the attention computation. Removing them causes catastrophic quality degradation even if the content tokens are preserved.

class AttentionSinkCache:
    """Attention sink + sliding window KV cache.
    Always keeps first N tokens (sinks) + last W tokens (window)."""

    def __init__(self, num_sink_tokens=4, window_size=1024,
                 num_layers=80, num_kv_heads=8, head_dim=128):
        self.num_sinks = num_sink_tokens
        self.window_size = window_size
        self.total_budget = num_sink_tokens + window_size

        # Pre-allocate cache
        self.k_cache = {}  # layer -> [batch, heads, total_budget, head_dim]
        self.v_cache = {}
        self.current_len = 0  # Total tokens seen so far

    def append(self, layer_idx, new_k, new_v):
        """Append new KV token, maintaining sink + window invariant.
        new_k, new_v: [batch, heads, 1, head_dim]
        """
        if layer_idx == 0:  # Count each token once, not once per layer
            self.current_len += 1

        if layer_idx not in self.k_cache:
            self.k_cache[layer_idx] = new_k
            self.v_cache[layer_idx] = new_v
            return

        k = self.k_cache[layer_idx]
        v = self.v_cache[layer_idx]
        seq_len = k.shape[2]

        if seq_len < self.total_budget:
            # Cache not full yet, just append
            self.k_cache[layer_idx] = torch.cat([k, new_k], dim=2)
            self.v_cache[layer_idx] = torch.cat([v, new_v], dim=2)
        else:
            # Cache full: keep sinks + shift window + add new token
            # Layout: [sink_0, ..., sink_N, window_start, ..., window_end]
            sinks_k = k[:, :, :self.num_sinks, :]
            sinks_v = v[:, :, :self.num_sinks, :]

            # Window: drop oldest window token, append new
            window_k = k[:, :, self.num_sinks + 1:, :]  # Drop first window token
            window_v = v[:, :, self.num_sinks + 1:, :]

            self.k_cache[layer_idx] = torch.cat(
                [sinks_k, window_k, new_k], dim=2
            )
            self.v_cache[layer_idx] = torch.cat(
                [sinks_v, window_v, new_v], dim=2
            )

    def get_kv(self, layer_idx):
        """Return current KV cache for attention."""
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def get_position_ids(self):
        """Position IDs for the cached tokens.
        Following StreamingLLM, positions are assigned by the token's slot
        WITHIN the cache, not its position in the original text, so RoPE
        never sees a distance beyond num_sinks + window_size."""
        cached_len = min(self.current_len, self.total_budget)
        return torch.arange(cached_len)
⚠️ Warning

Position IDs must be assigned relative to positions within the cache, not the original text. If sink tokens 0-3 are followed by window tokens that originally sat at positions 95000-96023, RoPE should still see the consecutive positions [0, 1, 2, …, 1027]. Feeding it the original absolute positions would expose the model to relative distances far beyond anything seen in training. This is why StreamingLLM caches keys before the rotary transformation and applies RoPE with cache-relative positions at attention time.

Attention Sinks Quality Analysis

The critical question: how many sink tokens do you need? Research shows that 4 sink tokens capture the dominant attention bias pattern. Adding more sinks beyond 4 provides diminishing returns:

def evaluate_sink_counts(model, eval_data, window_size=1024):
    """Measure quality impact of different sink token counts."""
    results = []
    for num_sinks in [0, 1, 2, 4, 8, 16, 32]:
        cache = AttentionSinkCache(
            num_sink_tokens=num_sinks,
            window_size=window_size,
            num_layers=model.config.num_hidden_layers,
            num_kv_heads=model.config.num_key_value_heads,
            head_dim=model.config.hidden_size // model.config.num_attention_heads,
        )

        total_loss = 0
        total_tokens = 0
        for sample in eval_data:
            logits = model.forward_with_cache(sample.input_ids, cache)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                sample.labels.view(-1),
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += sample.labels.numel()

        ppl = torch.exp(torch.tensor(total_loss / total_tokens)).item()
        results.append({
            "num_sinks": num_sinks,
            "perplexity": ppl,
            "total_cache_size": num_sinks + window_size,
        })
    return results
📊

Sink Token Count vs Perplexity (Llama 70B, Window=1024, 32K Input)

| Sink Tokens | Total Cache | Perplexity | vs Full KV |
|---|---|---|---|
| 0 (window only) | 1024 | 15.8 | +10.85 |
| 1 | 1025 | 6.2 | +1.25 |
| 4 | 1028 | 5.6 | +0.65 |
| 16 | 1040 | 5.5 | +0.55 |
| 32 | 1056 | 5.5 | +0.55 |
| Full KV | 32768 | 4.95 | baseline |

Without any sink tokens (pure sliding window on a model not trained for it), perplexity degrades catastrophically from 4.95 to 15.8. Adding just 4 sink tokens reduces the gap to +0.65 PPL. Beyond 4 sinks, the improvement plateaus, confirming that the attention sink phenomenon is concentrated in the first few token positions.

Strategy 4: Sliding Window Attention

Some models (Mistral, Phi) are trained with sliding window attention: each token can only attend to the last W tokens. This inherently bounds the KV cache to W entries per layer.

class SlidingWindowKVCache:
    """Fixed-size sliding window KV cache.
    Only works correctly with models trained using sliding window attention."""

    def __init__(self, window_size=4096, num_layers=32,
                 num_kv_heads=8, head_dim=128, device="cuda"):
        self.W = window_size
        self.num_layers = num_layers
        # Pre-allocate circular buffers
        self.k_buffer = torch.zeros(
            num_layers, 1, num_kv_heads, window_size, head_dim,
            dtype=torch.float16, device=device
        )
        self.v_buffer = torch.zeros(
            num_layers, 1, num_kv_heads, window_size, head_dim,
            dtype=torch.float16, device=device
        )
        self.write_pos = 0  # Circular buffer write position
        self.total_written = 0

    def append(self, layer_idx, new_k, new_v):
        """Append new KV at circular buffer position."""
        pos = self.write_pos % self.W
        self.k_buffer[layer_idx, :, :, pos, :] = new_k.squeeze(2)
        self.v_buffer[layer_idx, :, :, pos, :] = new_v.squeeze(2)

        # Advance only after the LAST layer has written this token;
        # incrementing at layer 0 would shift the slot for layers 1..N-1
        if layer_idx == self.num_layers - 1:
            self.write_pos += 1
            self.total_written += 1

    def get_kv(self, layer_idx):
        """Get KV cache in correct temporal order."""
        if self.total_written <= self.W:
            # Buffer not full yet, return what we have
            return (
                self.k_buffer[layer_idx, :, :, :self.total_written, :],
                self.v_buffer[layer_idx, :, :, :self.total_written, :],
            )

        # Buffer full: reorder from circular to temporal
        start = self.write_pos % self.W
        k = torch.cat([
            self.k_buffer[layer_idx, :, :, start:, :],
            self.k_buffer[layer_idx, :, :, :start, :],
        ], dim=2)
        v = torch.cat([
            self.v_buffer[layer_idx, :, :, start:, :],
            self.v_buffer[layer_idx, :, :, :start, :],
        ], dim=2)
        return k, v

    def memory_usage(self):
        """Fixed memory regardless of sequence length."""
        return self.k_buffer.nbytes + self.v_buffer.nbytes
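The circular-to-temporal reordering in get_kv can be sanity-checked with a scalar stand-in (a pure-Python sketch, one value per token instead of a KV tensor):

```python
W = 4                 # window size
buf = [None] * W      # circular buffer
write_pos = 0

def append_token(tok):
    """Write one token into the circular buffer."""
    global write_pos
    buf[write_pos % W] = tok
    write_pos += 1

def get_in_order():
    """Reorder circular buffer to temporal order, as get_kv does."""
    if write_pos <= W:
        return buf[:write_pos]
    start = write_pos % W          # oldest surviving slot
    return buf[start:] + buf[:start]

for t in range(7):                 # tokens 0..6
    append_token(t)
assert get_in_order() == [3, 4, 5, 6]   # last W tokens, oldest first
```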

Combining Strategies

The four strategies are not mutually exclusive. Production systems combine them:

class CombinedKVCacheStrategy:
    """Combine quantization + H2O + attention sinks."""

    def __init__(self, config):
        self.quantizer = KVCacheQuantizerFP8()

        # Sink tokens: always keep first 4 tokens in FP16 (no quantization)
        self.num_sinks = 4

        # Heavy hitters: keep top 20% by attention score in FP8
        self.heavy_ratio = 0.2

        # Recent window: keep last 512 tokens in FP8
        self.recent_window = 512

        # Everything else: evicted
        self.max_cache = config.max_kv_cache_tokens

    def manage_cache(self, layer_idx, kv_cache, attention_weights):
        """Combined cache management after each decode step."""
        k, v = kv_cache
        seq_len = k.shape[2]

        if seq_len <= self.max_cache:
            # Under budget: just quantize
            return self.quantizer.quantize_kv(k, v)

        # Over budget: apply eviction + quantization
        # 1. Keep sinks (always, FP16)
        sink_k = k[:, :, :self.num_sinks, :]
        sink_v = v[:, :, :self.num_sinks, :]

        # 2. Keep recent window (FP8)
        recent_k = k[:, :, -self.recent_window:, :]
        recent_v = v[:, :, -self.recent_window:, :]
        recent_quant = self.quantizer.quantize_kv(recent_k, recent_v)

        # 3. H2O on middle tokens
        middle_k = k[:, :, self.num_sinks:-self.recent_window, :]
        middle_v = v[:, :, self.num_sinks:-self.recent_window, :]
        middle_scores = attention_weights[:, :, :, self.num_sinks:-self.recent_window]

        heavy_budget = self.max_cache - self.num_sinks - self.recent_window
        _, heavy_idx = torch.topk(
            middle_scores.mean(dim=1).squeeze(1), heavy_budget, dim=-1
        )
        heavy_k = middle_k[:, :, heavy_idx.squeeze(0), :]
        heavy_v = middle_v[:, :, heavy_idx.squeeze(0), :]
        heavy_quant = self.quantizer.quantize_kv(heavy_k, heavy_v)

        return {
            "sinks": (sink_k, sink_v),  # FP16
            "heavy_hitters": heavy_quant,  # FP8
            "recent": recent_quant,  # FP8
        }
📊

Combined Strategy: Memory Usage and Quality (Llama 70B, 128K Context)

| Strategy | Memory | Compression | PPL Impact | Max Concurrent Requests |
|---|---|---|---|---|
| Full FP16 | 41.9 GB | 1.0x | baseline | 11 |
| FP8 only | 21.0 GB | 2.0x | +0.02 | 23 |
| H2O (50% budget) | 21.0 GB | 2.0x | +0.10 | 23 |
| Sinks + Window (W=4K) | 1.3 GB | 32x | +0.8 (long deps) | 384 |
| FP8 + H2O (50%) | 10.5 GB | 4.0x | +0.12 | 47 |
| Sinks + H2O (30%) + FP8 | 6.3 GB | 6.6x | +0.25 | 79 |
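The max-concurrent column follows directly from the per-request KV footprint (a sketch reproducing the table's arithmetic; the 500 GB budget is the opening example's 640 GB HBM minus 140 GB of weights):

```python
KV_BUDGET_GB = 500  # 8x H100 HBM minus FP16 weights

strategies = {
    "Full FP16":               41.9,
    "FP8 only":                21.0,
    "Sinks + Window (W=4K)":    1.3,
    "FP8 + H2O (50%)":         10.5,
    "Sinks + H2O (30%) + FP8":  6.3,
}
for name, gb in strategies.items():
    # Concurrent 128K-context requests that fit in the KV budget
    print(f"{name:26s} {int(KV_BUDGET_GB / gb):4d}")
```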


Impact on Attention Kernel

Quantized and evicted KV caches require modified attention kernels:

def quantized_paged_attention(query, kv_blocks_fp8, kv_scales,
                               page_table, seq_lens):
    """Attention kernel that operates on FP8 KV cache.
    Dequantizes on-the-fly during attention computation."""

    batch_size = query.shape[0]
    num_heads = query.shape[1]
    head_dim = query.shape[-1]

    output = torch.zeros_like(query)

    for b in range(batch_size):
        # Gather this request's KV blocks
        num_blocks = (seq_lens[b] + 15) // 16
        all_scores = []
        all_values = []

        for block_idx in range(num_blocks):
            physical_block = page_table[b, block_idx]

            # Load FP8 K block and dequantize
            k_fp8 = kv_blocks_fp8[physical_block, 0]  # [kv_heads, block_size, head_dim]
            k_scale = kv_scales[physical_block, 0]
            k_block = k_fp8.to(torch.float16) * k_scale

            v_fp8 = kv_blocks_fp8[physical_block, 1]
            v_scale = kv_scales[physical_block, 1]
            v_block = v_fp8.to(torch.float16) * v_scale

            # Compute Q @ K^T for this block
            # query: [1, num_heads, 1, head_dim]
            # k_block: [num_kv_heads, block_size, head_dim] (needs GQA expansion)
            scores = torch.matmul(
                query[b:b+1], k_block.transpose(-2, -1)
            ) / (head_dim ** 0.5)
            all_scores.append(scores)
            all_values.append(v_block)

        # Concatenate and compute attention
        all_scores = torch.cat(all_scores, dim=-1)
        all_values = torch.cat(all_values, dim=-2)

        # Trim to actual sequence length
        all_scores = all_scores[:, :, :, :seq_lens[b]]
        all_values = all_values[:, :seq_lens[b], :]

        attn = torch.softmax(all_scores, dim=-1)
        output[b] = torch.matmul(attn, all_values)

    return output
Performance

FP8 KV cache quantization with on-the-fly dequantization in the attention kernel adds negligible latency (under 2% overhead) because the FP8-to-FP16 conversion happens in registers after the data has already been loaded from HBM; it costs no extra memory traffic. And because decode attention is memory-bandwidth-bound, loading 1 byte instead of 2 per element translates directly into 2x faster KV cache reads.
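The bandwidth argument can be made concrete with decode-step arithmetic (a sketch; 3.35 TB/s is the H100 SXM HBM3 figure, and the numbers ignore weight reads and compute/communication overlap):

```python
HBM_BW = 3.35e12  # bytes/s per H100 (HBM3, SXM)

def kv_read_time_ms(cache_gb, num_gpus=8):
    """Time to stream the full KV cache once per decode step,
    assuming the cache is sharded evenly across GPUs."""
    return cache_gb * 1e9 / (HBM_BW * num_gpus) * 1e3

fp16_ms = kv_read_time_ms(41.9)  # FP16 cache, 128K context
fp8_ms = kv_read_time_ms(21.0)   # FP8 cache: ~2x faster reads
print(f"FP16: {fp16_ms:.2f} ms/step, FP8: {fp8_ms:.2f} ms/step")
```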

Choosing the Right Strategy

The right strategy depends on the workload:

📊

Strategy Selection Guide

| Workload | Recommended Strategy | Reason |
|---|---|---|
| Short context (<4K) | FP8 quantization only | Minimal KV; just save memory for more batching |
| Medium context (4K-32K) | FP8 + H2O (50%) | Good compression with minimal quality loss |
| Long context (32K-128K) | Sinks + H2O (30%) + FP8 | Aggressive compression needed |
| Streaming/infinite context | Sinks + sliding window | Fixed memory; accepts long-range quality loss |
| Multi-turn chat | FP8 + prefix caching | Cache shared system prompt; compress per-turn KV |

Implementation in vLLM and SGLang

Both major serving frameworks have implemented KV cache quantization:

# vLLM: enable FP8 KV cache via command line
"""
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Llama-3.1-70B \
    --kv-cache-dtype fp8_e4m3 \
    --tensor-parallel-size 8
"""

# SGLang: enable FP8 KV cache
"""
python -m sglang.launch_server \
    --model-path meta-llama/Llama-3.1-70B \
    --kv-cache-dtype fp8_e4m3 \
    --tp 8
"""

# In vLLM, FP8 KV cache is implemented in the attention backends:
class FP8KVCacheAttention:
    """Simplified FP8 KV cache management in vLLM."""

    def __init__(self, num_blocks, block_size, num_kv_heads, head_dim):
        # Allocate KV cache in FP8 format
        # Half the memory of FP16 cache
        self.k_cache = torch.zeros(
            num_blocks, block_size, num_kv_heads, head_dim,
            dtype=torch.float8_e4m3fn, device="cuda"
        )
        self.v_cache = torch.zeros(
            num_blocks, block_size, num_kv_heads, head_dim,
            dtype=torch.float8_e4m3fn, device="cuda"
        )
        # Per-block scale factors (FP32 for accuracy)
        self.k_scales = torch.ones(
            num_blocks, dtype=torch.float32, device="cuda"
        )
        self.v_scales = torch.ones(
            num_blocks, dtype=torch.float32, device="cuda"
        )

    def write_kv(self, block_idx, slot_in_block, k_fp16, v_fp16):
        """Write FP16 KV values into FP8 cache with scaling."""
        # Update the running per-block scale (exponential moving average)
        alpha = 0.1
        k_scale = torch.clamp(k_fp16.abs().amax() / 448.0, min=1e-8)
        v_scale = torch.clamp(v_fp16.abs().amax() / 448.0, min=1e-8)

        self.k_scales[block_idx] = (
            (1 - alpha) * self.k_scales[block_idx] + alpha * k_scale
        )
        self.v_scales[block_idx] = (
            (1 - alpha) * self.v_scales[block_idx] + alpha * v_scale
        )

        # Quantize with the SAME per-block scale the reader will use for
        # dequantization; quantizing with a fresh per-write scale while
        # storing only the block-level scale would corrupt earlier slots.
        # (vLLM sidesteps the drift entirely with a fixed calibrated scale.)
        self.k_cache[block_idx, slot_in_block] = (
            k_fp16 / self.k_scales[block_idx]
        ).to(torch.float8_e4m3fn)
        self.v_cache[block_idx, slot_in_block] = (
            v_fp16 / self.v_scales[block_idx]
        ).to(torch.float8_e4m3fn)

Benchmarking KV Compression Impact

The only way to validate a KV compression strategy is to measure both memory savings and quality impact on your target workload:

def benchmark_kv_compression(model, strategies, eval_dataset):
    """Compare KV compression strategies on throughput and quality."""
    results = []

    for strategy_name, strategy in strategies.items():
        # Measure quality (perplexity on evaluation set)
        total_loss = 0
        total_tokens = 0
        for sample in eval_dataset:
            logits = model.forward_with_kv_strategy(
                sample.input_ids, strategy
            )
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                sample.labels.view(-1),
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += sample.labels.numel()

        ppl = torch.exp(torch.tensor(total_loss / total_tokens)).item()

        # Measure memory
        kv_memory_gb = strategy.memory_usage() / 1e9

        # Estimate max concurrent requests
        total_hbm = 80 * 8  # 8x H100
        weight_memory = 140  # 70B FP16
        kv_budget = total_hbm - weight_memory
        max_requests = int(kv_budget / kv_memory_gb)

        results.append({
            "strategy": strategy_name,
            "perplexity": ppl,
            "kv_memory_gb": kv_memory_gb,
            "max_concurrent": max_requests,
            "throughput_relative": max_requests / results[0]["max_concurrent"] if results else 1.0,
        })

    return results

KV cache compression is the primary lever for increasing serving throughput at long context lengths. The 41.9 GB per request for 128K context means that without compression, most of the GPU cluster’s memory is consumed by a handful of requests. FP8 quantization alone doubles capacity with negligible quality impact. Adding H2O eviction on top provides another 2x. Together, they enable 4x more concurrent long-context requests, directly translating to 4x higher throughput for the same hardware. The choice between strategies is ultimately an empirical question: measure perplexity on your specific workload at your target compression ratio, and pick the strategy that preserves quality at the memory budget you need.