Part 2 of 23 in the Inference Optimization series.

If you have ever tried to serve a large language model in production, you have encountered a frustrating reality: the model weights fit on your GPUs, but you cannot serve many concurrent users. The bottleneck is almost always the KV cache — a runtime data structure that grows with every token generated, for every request, across every layer of the model. At serving scale, the KV cache routinely consumes more GPU memory than the model itself.

This post is a comprehensive treatment of KV cache memory management. We will start from first principles — why the cache exists and what happens without it — then work through exact memory arithmetic for modern models, memory allocation strategies (from naive contiguous buffers to PagedAttention), compression and eviction techniques, bandwidth analysis during decode, and production tuning considerations. The goal is to give you the full systems-level picture so that you can reason about KV cache tradeoffs in your own serving stack.

Why KV Cache Exists

The Autoregressive Bottleneck

Transformer-based language models generate text one token at a time. At each step, the model must compute attention over all previous tokens. The attention mechanism for a single head looks like this:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

During generation, the query $Q$ comes from the new token being generated. But the keys $K$ and values $V$ must include every previous token in the sequence. Without caching, this means recomputing the $K$ and $V$ projections for all prior tokens at every single generation step.

Without KV Cache: O(n^2) Per Token

Consider generating a sequence of length $n$. At step $t$, the model must:

  1. Project all $t$ tokens through $W_K$ and $W_V$ to get keys and values
  2. Compute attention between the new query and all $t$ keys
  3. Produce one output token

The projection step alone is $O(t \cdot d)$ per layer. Summed over all $n$ generation steps:

$$\text{Total projection work} = \sum_{t=1}^{n} t \cdot d = \frac{n(n+1)}{2} \cdot d = O(n^2 \cdot d)$$

For a 4096-token sequence on a model with 80 layers, this means recomputing the K/V projections for tokens 1 through 4095 just to generate token 4096. Then recomputing tokens 1 through 4096 to generate token 4097. The redundant computation is enormous.

def generate_without_cache(model, prompt_tokens, max_new_tokens):
    """Naive generation: recompute everything at each step."""
    tokens = prompt_tokens.clone()

    for step in range(max_new_tokens):
        # Full forward pass over ALL tokens every step
        # Attention cost: O(seq_len^2) per layer
        logits = model(tokens)  # processes entire sequence
        next_token = sample(logits[:, -1, :])
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)

    return tokens
# Redundant K/V projection work: O(n^2 * d * L), where L = num_layers

With KV Cache: O(n) Per Token

The fix is straightforward: cache the K and V projections from previous tokens. At each generation step, we only compute K and V for the single new token, then concatenate them with the cached values.

def generate_with_cache(model, prompt_tokens, max_new_tokens):
    """Cached generation: only compute new token's K/V."""
    kv_cache = None

    # Prefill: process entire prompt (one-time cost)
    logits, kv_cache = model(prompt_tokens, kv_cache=None)
    next_token = sample(logits[:, -1, :])
    generated = [next_token]

    # Decode: process one token at a time, reusing cache
    for step in range(max_new_tokens - 1):
        logits, kv_cache = model(
            next_token.unsqueeze(1),
            kv_cache=kv_cache  # reuse all previous K/V
        )
        next_token = sample(logits[:, -1, :])
        generated.append(next_token)

    return generated
# K/V projection work: O(n * d * L) -- linear in n, not quadratic

The computational savings are dramatic:

📊 Computational Cost: With vs Without KV Cache (4096-token generation)

| Method | K/V Projections (per token) | Attention Ops (per token) | Wall Time (est.) | Speedup |
|---|---|---|---|---|
| No cache | O(n * d * L) | O(n^2 * d * L) | ~180 s | 1.0x |
| With KV cache | O(d * L) | O(n * d * L) | ~12 s | ~15x |
💡 The trade-off is memory for compute

KV caching converts a compute-bound quadratic problem into a memory-bound linear problem. You trade GPU memory (storing all those cached K/V tensors) for a massive reduction in redundant computation. This trade-off is worthwhile for virtually all serving scenarios, but it means memory management becomes the critical bottleneck.

The cache itself is simple in concept. For each layer, we store two tensors:

class CachedAttention(nn.Module):
    def forward(self, x, kv_cache=None):
        B, T, D = x.shape
        Q = self.W_q(x)  # [B, T, D]
        K_new = self.W_k(x)  # [B, T, D]
        V_new = self.W_v(x)  # [B, T, D]

        # Reshape to [B, num_heads, T, head_dim]
        Q = Q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        K_new = K_new.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        V_new = V_new.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if kv_cache is not None:
            K = torch.cat([kv_cache['K'], K_new], dim=2)
            V = torch.cat([kv_cache['V'], V_new], dim=2)
        else:
            K, V = K_new, V_new

        # Standard attention (or GQA with head broadcasting)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)

        new_cache = {'K': K, 'V': V}
        return out, new_cache

Simple in concept. Devastating in memory consumption at scale.

The Memory Arithmetic: KV Cache vs Model Weights

The Formula

The total KV cache memory for a model is:

$$\text{KV\_cache\_bytes} = 2 \times L \times H_{kv} \times d_h \times S \times B \times \text{bytes\_per\_element}$$

Where:

  • $2$ = one K tensor + one V tensor
  • $L$ = number of layers
  • $H_{kv}$ = number of KV heads (may differ from query heads in GQA)
  • $d_h$ = head dimension
  • $S$ = sequence length
  • $B$ = batch size
  • bytes_per_element = 2 for FP16/BF16, 1 for INT8

Llama 3 70B: A Worked Example

Llama 3 70B uses Grouped Query Attention (GQA) with these parameters:

  • $L = 80$ layers
  • $H_{kv} = 8$ KV heads (while having 64 query heads)
  • $d_h = 128$ head dimension
  • Model weights: ~140 GB in FP16

Let us compute the KV cache for a realistic serving scenario: batch size 64, sequence length 4096.

$$\text{KV\_bytes} = 2 \times 80 \times 8 \times 128 \times 4096 \times 64 \times 2$$

Breaking this down step by step (the leading 2 counts K plus V; the trailing 2 is bytes per FP16 element):

$$2 \times 80 \times 8 \times 128 = 163{,}840 \text{ elements per token per batch element}$$
$$163{,}840 \times 4096 = 671{,}088{,}640 \text{ elements per batch element}$$
$$671{,}088{,}640 \times 64 = 42{,}949{,}672{,}960 \text{ elements across the batch}$$
$$42{,}949{,}672{,}960 \times 2 \text{ bytes (FP16)} = 85{,}899{,}345{,}920 \text{ bytes} \approx 80 \text{ GB}$$
⚠️ KV cache approaches model weight size

For Llama 3 70B at batch=64 and seq_len=4096, the KV cache alone is ~80 GB. The model weights are ~140 GB in FP16. The KV cache is already 57% of model weight size — and it scales linearly with both batch size and sequence length. Double the batch to 128 and you hit 160 GB of KV cache, exceeding the model weights.

Memory Breakdown Across Models

Let us compute KV cache sizes for several popular models at a fixed serving point of batch=32, seq_len=4096, FP16:

def kv_cache_gb(layers, kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    """Compute KV cache size in GB."""
    return (2 * layers * kv_heads * head_dim * seq_len * batch * dtype_bytes) / (1024**3)
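
For instance, the function reproduces the ~80 GB figure from the worked example above, and it generates each row of the table below:

# Llama 3 70B at batch=64, seq=4096 (the worked example)
print(kv_cache_gb(layers=80, kv_heads=8, head_dim=128, seq_len=4096, batch=64))  # -> 80.0

# Llama 3 8B at this section's serving point (batch=32, seq=4096)
print(kv_cache_gb(layers=32, kv_heads=8, head_dim=128, seq_len=4096, batch=32))  # -> 16.0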
📊 KV Cache Memory for Popular Models (batch=32, seq=4096, FP16)

| Model | Layers | KV Heads | d_head | Weights (GB) | KV Cache (GB) | Cache/Weights |
|---|---|---|---|---|---|---|
| Llama 3 8B | 32 | 8 | 128 | 16 | 16.0 | 100% |
| Mistral 7B | 32 | 8 | 128 | 14 | 16.0 | 114% |
| Llama 3 70B | 80 | 8 | 128 | 140 | 40.0 | 29% |
| Llama 3 70B (batch=64) | 80 | 8 | 128 | 140 | 80.0 | 57% |
| Llama 3 70B (batch=128) | 80 | 8 | 128 | 140 | 160.0 | 114% |
| Llama 3 405B | 126 | 8 | 128 | 810 | 63.0 | 8% |
| GPT-4 class (est.) | 120 | 16 | 128 | ~800 | 120.0 | ~15% |

Several observations emerge from this table. First, GQA is a lifesaver: Llama 3 70B uses only 8 KV heads (vs. 64 query heads), which gives an 8x reduction in KV cache size compared to standard multi-head attention. Without GQA, the 70B model at batch=64 would need 640 GB of KV cache — clearly impossible. Second, smaller models are more KV-cache-dominated relative to their weight size, because the KV cache scales with kv_heads * d_head * layers while weights scale with d_model^2 * layers. Third, batch size is the multiplier that pushes KV cache from manageable to dominant.
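
The GQA saving quoted above is easy to check with the same helper by setting kv_heads to the full 64 query heads, which models standard multi-head attention:

gqa = kv_cache_gb(layers=80, kv_heads=8, head_dim=128, seq_len=4096, batch=64)   # 80.0 GB
mha = kv_cache_gb(layers=80, kv_heads=64, head_dim=128, seq_len=4096, batch=64)  # 640.0 GB
print(mha / gqa)  # 8.0 -- the 8x reduction from GQA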

GPU Memory Layout: Llama 3 70B Serving (batch=64, seq=4096)

Approximate memory breakdown on 4x A100 80GB (320 GB total)

| Region | Size | Notes |
|---|---|---|
| Model Weights (FP16) | 140 GB | Static, loaded once at startup |
| KV Cache | 80 GB | Dynamic, grows per-request |
| Activations + Workspace | 20 GB | Transient, reused per step |
| Free / Headroom | 80 GB | Needed for burst capacity |

KV Cache Size vs Batch Size (Llama 3 70B, seq=4096, FP16)

| Batch Size | 1 | 8 | 16 | 32 | 64 | 128 | 256 |
|---|---|---|---|---|---|---|---|
| KV Cache (GB) | 1.25 | 10 | 20 | 40 | 80 | 160 | 320 |
| Model Weights (GB, constant) | 140 | 140 | 140 | 140 | 140 | 140 | 140 |

At batch=128 the KV cache already exceeds the model weights, and at batch=256 it reaches 320 GB. For models with full MHA (no GQA), the crossover happens at a much smaller batch.

Memory Management: From Naive to PagedAttention

The Naive Approach: Contiguous Pre-allocation

The simplest KV cache implementation pre-allocates a contiguous buffer for each request at max_seq_len:

class NaiveKVCache:
    def __init__(self, max_seq_len, batch_size, layers, kv_heads, head_dim, dtype=torch.float16):
        # Pre-allocate maximum possible size for every request
        self.k_cache = torch.zeros(
            (batch_size, layers, kv_heads, max_seq_len, head_dim),
            dtype=dtype, device='cuda'
        )
        self.v_cache = torch.zeros_like(self.k_cache)
        self.lengths = torch.zeros(batch_size, dtype=torch.int32)

    def append(self, batch_idx, layer_idx, new_k, new_v):
        pos = self.lengths[batch_idx]
        self.k_cache[batch_idx, layer_idx, :, pos, :] = new_k
        self.v_cache[batch_idx, layer_idx, :, pos, :] = new_v
        self.lengths[batch_idx] += 1

This approach has a fatal flaw: internal fragmentation. If max_seq_len = 4096 but the average request only uses 800 tokens, then 80% of the allocated memory is wasted. The memory is reserved and cannot be used by other requests.

The vLLM team measured this waste empirically and found that 60-80% of KV cache memory is wasted under realistic workloads with contiguous allocation; the representative traces below work out to roughly 68% average waste.

📊 Memory Waste in Contiguous Allocation (measured)

| Workload | Avg Seq Len | Max Seq Len | Allocated | Actually Used | Waste |
|---|---|---|---|---|---|
| ShareGPT (chat) | 830 | 4096 | 100% | 20% | 80% |
| Alpaca (instruction) | 380 | 2048 | 100% | 19% | 81% |
| Code generation | 1200 | 4096 | 100% | 29% | 71% |
| Summarization | 2800 | 4096 | 100% | 68% | 32% |
| Weighted average | -- | -- | 100% | 32% | 68% |

There is also external fragmentation: as requests complete and free their contiguous blocks, the free memory becomes a patchwork of different-sized holes. A new request needing a 4096-token block may fail even though total free memory is sufficient, because no single contiguous region is large enough.

PagedAttention: Virtual Memory for KV Cache

The breakthrough insight from vLLM (Kwon et al., 2023) is to apply the same idea that operating systems use for process memory: paging. Instead of allocating one contiguous buffer per request, divide KV cache memory into fixed-size blocks (analogous to 4KB pages in virtual memory) and map them to requests through a page table.

Each block stores the K and V vectors for a fixed number of tokens (the block size, typically 16 tokens). A request’s KV cache is a linked list of blocks, not necessarily contiguous in physical GPU memory.

class KVBlock:
    """A fixed-size block storing K/V for block_size tokens across all layers."""
    def __init__(self, block_id, block_size, layers, kv_heads, head_dim, dtype):
        # Shape: [layers, 2, kv_heads, block_size, head_dim]
        self.data = torch.zeros(
            (layers, 2, kv_heads, block_size, head_dim),
            dtype=dtype, device='cuda'
        )
        self.block_id = block_id
        self.num_filled = 0
        self.block_size = block_size

class BlockAllocator:
    """Free-list allocator for KV cache blocks. O(1) alloc and free."""
    def __init__(self, num_blocks, block_size, layers, kv_heads, head_dim, dtype):
        self.blocks = [
            KVBlock(i, block_size, layers, kv_heads, head_dim, dtype)
            for i in range(num_blocks)
        ]
        self.free_list = list(range(num_blocks))
        self.ref_counts = [0] * num_blocks

    def allocate(self):
        if not self.free_list:
            return None  # OOM
        block_id = self.free_list.pop()
        self.ref_counts[block_id] = 1
        return block_id

    def free(self, block_id):
        self.ref_counts[block_id] -= 1
        if self.ref_counts[block_id] == 0:
            self.blocks[block_id].num_filled = 0
            self.free_list.append(block_id)

    def incref(self, block_id):
        """For copy-on-write sharing."""
        self.ref_counts[block_id] += 1

class PageTable:
    """Maps (request_id, logical_block_idx) to physical block_id."""
    def __init__(self):
        self.tables = {}  # request_id -> list[int] (physical block ids)

    def get_physical_blocks(self, request_id):
        return self.tables.get(request_id, [])

    def append_block(self, request_id, physical_block_id):
        if request_id not in self.tables:
            self.tables[request_id] = []
        self.tables[request_id].append(physical_block_id)

    def release(self, request_id, allocator):
        for block_id in self.tables.pop(request_id, []):
            allocator.free(block_id)

The key properties of this design:

Near-zero waste. A request only allocates blocks as it needs them. A request at 800 tokens with block_size=16 uses exactly 50 blocks. The only waste is in the last partially-filled block — at most block_size - 1 tokens, or 15 tokens in this example. For a 4096-max system, that is under 0.4% waste vs. 80% with contiguous allocation.

No external fragmentation. All blocks are the same size, so any free block can satisfy any allocation. The “fragmentation kills you” problem of contiguous allocation vanishes entirely.

O(1) allocation and deallocation. The free list gives constant-time alloc/free. No searching, no compaction, no defragmentation.
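
A quick sanity check of the waste arithmetic above, for a single 800-token request in a system configured for a 4096-token maximum and 16-token blocks:

import math

block_size, max_seq_len, used = 16, 4096, 800

# Contiguous pre-allocation reserves max_seq_len tokens up front
contiguous_wasted = max_seq_len - used               # 3296 tokens (~80%)

# Paged allocation: ceil(800 / 16) = 50 blocks; only the last block can be partial
paged_blocks = math.ceil(used / block_size)          # 50
paged_wasted = paged_blocks * block_size - used      # 0 here, at most 15 tokens

print(contiguous_wasted, paged_blocks, paged_wasted)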

📊 Memory Utilization: Contiguous vs PagedAttention

| Allocator | Utilization | External Frag | Internal Frag | Alloc Time |
|---|---|---|---|---|
| Contiguous (max_len) | 20-40% | High | Very high | O(1) |
| Contiguous (growing) | 60-70% | High | Medium | O(n) realloc |
| PagedAttention (block=16) | 96-99% | Zero | < 0.4% | O(1) |
| PagedAttention (block=1) | ~100% | Zero | Zero | O(1), high overhead |

Block Size: A Real Tuning Knob

The block size presents a classic overhead-vs-waste tradeoff:

  • Smaller blocks (e.g., 1 token): minimal internal fragmentation, but more page table entries, more pointer chasing, worse memory access patterns for the attention kernel.
  • Larger blocks (e.g., 64 tokens): better memory locality, fewer page table entries, but more internal fragmentation for short or variable-length sequences.

In practice, block sizes of 16-32 tokens hit the sweet spot. vLLM defaults to 16. The overhead from page table management is negligible compared to the memory savings.

Block Size Tradeoff (Llama 3 70B, mixed workload)

| Block Size (tokens) | 1 | 4 | 8 | 16 | 32 | 64 | 128 |
|---|---|---|---|---|---|---|---|
| Internal fragmentation (%) | 0 | 0.5 | 1 | 1.8 | 3.2 | 5.5 | 9 |
| Page table overhead (%) | 8 | 3.5 | 1.8 | 1 | 0.6 | 0.3 | 0.2 |
| Total overhead (%) | 8 | 4 | 2.8 | 2.8 | 3.8 | 5.8 | 9.2 |

Copy-on-Write: Sharing Blocks Across Beams

Beam search generates multiple candidate sequences that share a common prefix. Without copy-on-write, you must duplicate the entire KV cache for each beam — multiplying memory usage by the beam width.

With paged allocation, beams can share physical blocks for their common prefix. Each beam’s page table points to the same physical blocks. When a beam diverges (writes to a position in a shared block), the system copies only that single block — the “copy-on-write” pattern from OS virtual memory.

def fork_beam(page_table, allocator, parent_request_id, child_request_id):
    """Fork a beam: child shares parent's blocks via refcounting."""
    parent_blocks = page_table.get_physical_blocks(parent_request_id)
    page_table.tables[child_request_id] = parent_blocks.copy()
    for block_id in parent_blocks:
        allocator.incref(block_id)

def cow_write(page_table, allocator, request_id, logical_block_idx, new_kv_data):
    """Copy-on-write: only copy the block being modified."""
    blocks = page_table.get_physical_blocks(request_id)
    old_block_id = blocks[logical_block_idx]

    if allocator.ref_counts[old_block_id] > 1:
        # Shared block -- must copy before writing
        new_block_id = allocator.allocate()
        allocator.blocks[new_block_id].data.copy_(
            allocator.blocks[old_block_id].data
        )
        allocator.free(old_block_id)  # decrements refcount
        blocks[logical_block_idx] = new_block_id
        old_block_id = new_block_id

    # Now safe to write in-place
    write_kv_to_block(allocator.blocks[old_block_id], new_kv_data)

For beam width $w$ and sequence length $S$ with block size $b$, naive duplication costs $w \times \lceil S/b \rceil$ blocks, while copy-on-write costs only $\lceil S/b \rceil + (w-1) \times \lceil \text{divergent\_tokens}/b \rceil$ blocks. For beam search on a 2048-token sequence with beam width 4 that diverges only in the last 50 tokens, naive allocation uses 4x the single-beam memory (512 blocks) while CoW uses about 1.1x (140 blocks).
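
Plugging in those numbers directly:

import math

w, S, b, divergent = 4, 2048, 16, 50   # beam width, seq length, block size, divergent tokens

naive_blocks = w * math.ceil(S / b)                                  # 4 * 128 = 512
cow_blocks = math.ceil(S / b) + (w - 1) * math.ceil(divergent / b)   # 128 + 3 * 4 = 140

print(naive_blocks / cow_blocks)   # ~3.7x fewer blocks with copy-on-write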

Admission Control and Preemption

With paged allocation, admission control becomes precise: the scheduler knows exactly how many free blocks remain, so it can decide deterministically whether a new request will fit.

def can_admit(allocator, estimated_tokens, block_size=16):
    """Admit only if we have enough blocks for the estimated request."""
    needed_blocks = (estimated_tokens + block_size - 1) // block_size
    return len(allocator.free_list) >= needed_blocks

def preempt_if_needed(allocator, scheduler, min_free_blocks):
    """Preempt lowest-priority request if memory is critically low."""
    while len(allocator.free_list) < min_free_blocks:
        victim = scheduler.get_lowest_priority_request()
        if victim is None:
            break  # nothing to preempt
        # Two strategies: swap to CPU or recompute later
        if swap_is_cheaper(victim):
            swap_to_cpu(victim, allocator)
        else:
            mark_for_recompute(victim, allocator)
💡 Swap vs Recompute

When preempting a request, you have two options. Swap copies the KV cache blocks to CPU memory and restores them later — good for long sequences where recomputation would be expensive. Recompute discards the KV cache entirely and re-runs the prefill when the request is rescheduled — good for short prefills or when CPU memory is also tight. The breakeven depends on PCIe bandwidth vs. compute throughput and on per-transfer overhead: recompute tends to win only for very short prefills or when PCIe is heavily contended; otherwise swapping a single request's cache is fast (see the breakeven table later in this post).

KV Cache Compression

Even with perfect memory management (zero fragmentation), you eventually hit the hard limit of GPU memory. Compression techniques trade a small amount of quality for significant memory savings.

Quantization: FP16 to INT8 and INT4

The most straightforward compression is reducing the numerical precision of cached K and V tensors.

FP16 to INT8 (2x savings). Each element goes from 2 bytes to 1 byte. The quality impact is remarkably small because K and V tensors have well-behaved distributions with limited dynamic range. Per-channel or per-token quantization with calibrated scales preserves almost all information.

def quantize_kv_to_int8(kv_tensor):
    """Per-token INT8 quantization: one scale per vector along the channel dim."""
    # One scale per row, reducing over the channel (last) dimension
    abs_max = kv_tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = abs_max / 127.0

    # Quantize to signed 8-bit integers
    quantized = (kv_tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return quantized, scale

def dequantize_kv(quantized, scale):
    """Dequantize back to FP16 for attention computation."""
    return quantized.to(torch.float16) * scale

INT8 to INT4 (4x total savings from FP16). Each element uses 4 bits, packed two per byte. Quality loss becomes measurable but often acceptable for many applications. GPTQ-style or AWQ-style quantization can be applied to KV caches specifically.
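
A minimal sketch of the bit-packing step (two unsigned 4-bit codes per byte); real kernels fuse this with dequantization and keep per-group scales and zero-points alongside the packed bytes:

import torch

def pack_int4(codes):
    """Pack integer codes in [0, 15] two-per-byte along the last dimension."""
    assert codes.shape[-1] % 2 == 0
    lo, hi = codes[..., 0::2], codes[..., 1::2]
    return (lo | (hi << 4)).to(torch.uint8)

def unpack_int4(packed):
    """Inverse of pack_int4: recover the original code order."""
    lo, hi = packed & 0x0F, (packed >> 4) & 0x0F
    return torch.stack([lo, hi], dim=-1).flatten(-2)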

📊 KV Cache Quantization Impact (Llama 3 70B, batch=64, seq=4096)

| Precision | Bytes/Element | KV Cache Size | Savings | Avg Quality Loss |
|---|---|---|---|---|
| FP16 | 2 | 80 GB | Baseline | 0% |
| INT8 (per-channel) | 1 | 40 GB | 2x | < 0.1% perplexity |
| INT4 (grouped) | 0.5 | 20 GB | 4x | 0.3-1.0% perplexity |
| INT4 + FP16 outliers | ~0.6 | 24 GB | 3.3x | < 0.2% perplexity |
| INT2 (experimental) | 0.25 | 10 GB | 8x | 2-5% perplexity |

The “INT4 + FP16 outliers” approach deserves special mention: it keeps a small percentage of high-magnitude values in full precision while quantizing the rest to INT4. This gives nearly the compression of INT4 with quality close to INT8. Systems like KIVI and KVQuant implement this pattern.
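
A sketch of the outlier-splitting idea (not the exact KIVI or KVQuant algorithm): keep roughly the top 1% of magnitudes in FP16 and quantize the rest to a symmetric 4-bit range. A real implementation would pack the codes two per byte and use per-channel or per-group scales.

import torch

def quantize_with_outliers(x, outlier_frac=0.01):
    """Split a KV tensor into INT4-range codes plus a sparse FP16 outlier set."""
    k = max(1, int(x.numel() * outlier_frac))
    threshold = x.abs().flatten().topk(k).values.min()
    outlier_mask = x.abs() >= threshold

    inliers = torch.where(outlier_mask, torch.zeros_like(x), x)
    scale = inliers.abs().amax().clamp(min=1e-8) / 7.0           # symmetric 4-bit range
    codes = (inliers / scale).round().clamp(-8, 7).to(torch.int8)
    return codes, scale, outlier_mask, x[outlier_mask].to(torch.float16)

def dequantize_with_outliers(codes, scale, outlier_mask, outliers):
    x = codes.to(torch.float16) * scale
    x[outlier_mask] = outliers   # restore the full-precision heavy hitters
    return x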

Memory Savings vs Quality Degradation (Llama 3 70B on MMLU)

| Precision | FP16 | INT8 | INT4 + outliers | INT4 | INT3 | INT2 |
|---|---|---|---|---|---|---|
| KV Cache Size (GB, batch=64) | 80 | 40 | 24 | 20 | 15 | 10 |
| MMLU Score | 79.5 | 79.4 | 79.2 | 78.8 | 77.5 | 75.0 |

Eviction Policies: When Memory Is Full

When GPU memory is exhausted and you cannot (or choose not to) preempt entire requests, you can instead evict individual KV cache entries — dropping the K/V for specific tokens at specific layers. The question is: which entries are least important?

Heavy Hitter Oracle (H2O). The key observation is that attention patterns are sparse: a small fraction of tokens receive a disproportionate share of attention mass. H2O keeps the “heavy hitter” tokens (those that consistently receive high attention scores) plus the most recent tokens, and evicts the rest.

class H2OEviction:
    """Heavy Hitter Oracle: keep tokens with highest cumulative attention."""
    def __init__(self, budget, recent_window=128):
        self.budget = budget  # max tokens to keep
        self.recent_window = recent_window
        self.cumulative_attention = None

    def update(self, attention_weights):
        """attention_weights: [batch, heads, 1, seq_len] from latest step."""
        scores = attention_weights.sum(dim=(0, 1, 2))  # [seq_len]
        if self.cumulative_attention is None:
            self.cumulative_attention = scores
        else:
            # Extend for new position
            self.cumulative_attention = torch.cat([
                self.cumulative_attention,
                scores[-1:]
            ])
            self.cumulative_attention[:-1] += scores[:-1]

    def get_eviction_mask(self, seq_len):
        """Returns boolean mask of tokens to KEEP."""
        if seq_len <= self.budget:
            return torch.ones(seq_len, dtype=torch.bool)

        keep = torch.zeros(seq_len, dtype=torch.bool)

        # Always keep recent tokens
        keep[-self.recent_window:] = True

        # Keep top-k heavy hitters from older tokens
        remaining_budget = self.budget - self.recent_window
        older_scores = self.cumulative_attention[:seq_len - self.recent_window]
        topk_indices = older_scores.topk(remaining_budget).indices
        keep[topk_indices] = True

        return keep

Attention Sinks. Xiao et al. (2023) discovered that the first few tokens in a sequence consistently receive high attention regardless of content — they act as “attention sinks.” A simple but effective eviction policy keeps the first $k$ tokens (typically 4) plus the most recent $W$ tokens, evicting everything in between.

class AttentionSinkEviction:
    """StreamingLLM-style: keep first k 'sink' tokens + recent window."""
    def __init__(self, num_sinks=4, window_size=1024):
        self.num_sinks = num_sinks
        self.window_size = window_size

    def trim_cache(self, k_cache, v_cache, current_len):
        """Trim cache to sinks + window."""
        if current_len <= self.num_sinks + self.window_size:
            return k_cache, v_cache  # no trimming needed

        # Keep first num_sinks tokens
        sink_k = k_cache[:, :, :self.num_sinks, :]
        sink_v = v_cache[:, :, :self.num_sinks, :]

        # Keep last window_size tokens
        recent_k = k_cache[:, :, -self.window_size:, :]
        recent_v = v_cache[:, :, -self.window_size:, :]

        return (
            torch.cat([sink_k, recent_k], dim=2),
            torch.cat([sink_v, recent_v], dim=2),
        )
    # Resulting cache length: num_sinks + window_size (constant)
📊 Eviction Policy Comparison (Llama 3 8B, 4096-token generation, budget=1024)

| Policy | Tokens Kept | Memory | Quality (ppl) | Implementation |
|---|---|---|---|---|
| Full cache | 4096 | 100% | 5.12 (baseline) | N/A |
| Random eviction | 1024 | 25% | 7.85 | Trivial |
| Recent-only | 1024 | 25% | 6.41 | Trivial |
| Attention sinks (4+1020) | 1024 | 25% | 5.38 | Simple |
| H2O | 1024 | 25% | 5.24 | Moderate |
| H2O + sinks | 1024 | 25% | 5.19 | Moderate |

Sliding Window Attention: Mistral’s Approach

Mistral takes a more radical approach: the model is architecturally designed to attend only to the last $W$ tokens. The KV cache has a hard upper bound regardless of sequence length.

$$\text{KV\_cache\_sliding} = 2 \times L \times H_{kv} \times d_h \times W \times B \times \text{bytes}$$

For Mistral 7B with $W = 4096$ at batch 32, this caps the cache at 16 GB regardless of how long the sequence grows. A fixed-size circular buffer implements it:

class SlidingWindowKV:
    """Fixed-size circular buffer for sliding window attention."""
    def __init__(self, window_size, layers, kv_heads, head_dim, batch_size, dtype):
        self.window_size = window_size
        self.buffer_k = torch.zeros(
            (batch_size, layers, kv_heads, window_size, head_dim),
            dtype=dtype, device='cuda'
        )
        self.buffer_v = torch.zeros_like(self.buffer_k)
        self.write_pos = 0  # circular write position

    def append(self, new_k, new_v):
        """Append new token's K/V, overwriting oldest if full."""
        pos = self.write_pos % self.window_size
        self.buffer_k[:, :, :, pos, :] = new_k
        self.buffer_v[:, :, :, pos, :] = new_v
        self.write_pos += 1

    def get_kv(self):
        """Return K/V in correct temporal order."""
        if self.write_pos <= self.window_size:
            return self.buffer_k[:, :, :, :self.write_pos, :], \
                   self.buffer_v[:, :, :, :self.write_pos, :]
        # Reorder circular buffer to temporal order
        start = self.write_pos % self.window_size
        indices = torch.cat([
            torch.arange(start, self.window_size),
            torch.arange(0, start)
        ])
        return self.buffer_k[:, :, :, indices, :], \
               self.buffer_v[:, :, :, indices, :]

The advantage is guaranteed constant memory regardless of sequence length. The disadvantage is that the model literally cannot attend to tokens beyond the window — this is a hard architectural constraint, not just a memory optimization. For tasks requiring long-range dependencies (document summarization, multi-turn conversations), this can degrade quality.

Mistral relies on information propagating through the layer stack: each layer attends only $W$ tokens back, but after $L$ layers the effective receptive field grows to roughly $W \times L$ tokens. More recent hybrid architectures, such as Jamba (attention plus Mamba layers) and Mamba-2-based hybrids, keep full attention on only a few layers, getting constant-memory behavior on most layers while those few full-attention layers preserve long-range capability.

Memory Bandwidth Analysis: The Decode Bottleneck

Why Decode is Memory-Bound

LLM inference has two distinct phases with very different computational profiles:

Prefill processes all prompt tokens in parallel. It is compute-bound: the GPU’s tensor cores are busy multiplying large matrices, and memory bandwidth is not the bottleneck.

Decode generates one token at a time. It is memory-bound: at each step, the model must load all its weights AND the entire KV cache from HBM to compute a single token’s output. The arithmetic intensity (FLOPs per byte loaded) is extremely low.

The Bandwidth Math

For each decode step, the GPU must load:

  1. Model weights: All parameter matrices. For Llama 3 70B in FP16: ~140 GB.
  2. KV cache: All K and V tensors for all previous tokens. Size grows with sequence length.

The total bytes loaded per token:

$$\text{bytes\_per\_token} = \text{model\_weights} + \text{KV\_cache\_size}$$

At batch=1, the KV bytes read per layer for a sequence of $S$ tokens are:

$$\text{KV\_per\_layer} = 2 \times H_{kv} \times d_h \times S \times 2 = 2 \times 8 \times 128 \times S \times 2 = 4096 \times S \text{ bytes}$$

For 80 layers at $S = 2048$:

$$\text{Total KV read} = 80 \times 4096 \times 2048 = 671{,}088{,}640 \text{ bytes} \approx 0.625 \text{ GB}$$

At batch=1, model weights dominate (140 GB vs 0.625 GB). But at batch=64:

$$\text{Total KV read} = 80 \times 4096 \times 2048 \times 64 \text{ bytes} \approx 40 \text{ GB}$$

Now KV cache is 22% of total memory traffic. At batch=256 and seq=4096:

$$\text{Total KV read} \approx 320 \text{ GB}$$

The KV cache read now exceeds the weight read by 2.3x.

📊 Memory Bandwidth Breakdown per Decode Step (Llama 3 70B, seq=2048)

| Batch Size | Weight Read (GB) | KV Read (GB) | Total (GB) | KV Fraction | A100 Time (ms) |
|---|---|---|---|---|---|
| 1 | 140 | 0.6 | 140.6 | 0.4% | 70 |
| 8 | 140 | 5.0 | 145.0 | 3.4% | 73 |
| 32 | 140 | 20.0 | 160.0 | 12.5% | 80 |
| 64 | 140 | 40.0 | 180.0 | 22.2% | 90 |
| 128 | 140 | 80.0 | 220.0 | 36.4% | 110 |
| 256 | 140 | 160.0 | 300.0 | 53.3% | 150 |

The A100’s HBM2e bandwidth is ~2 TB/s. At batch=256, loading 300 GB takes ~150 ms per token. To put that in perspective, the actual compute (matrix multiplications) for batch=256 is only ~20 ms. The GPU is idle 87% of the time, waiting for memory.
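
The same arithmetic as a small estimator. This is a lower bound from memory traffic alone; the ~2 TB/s figure and the assumption that nothing overlaps with compute are simplifications:

def decode_step_ms(weight_gb, kv_gb_per_seq, batch, hbm_gbps=2000):
    """Lower-bound decode-step latency (ms) from the bytes that must be read."""
    total_gb = weight_gb + kv_gb_per_seq * batch
    return total_gb / hbm_gbps * 1000

# Llama 3 70B, seq=2048: ~0.625 GB of KV per sequence
print(decode_step_ms(140, 0.625, 1))    # ~70 ms
print(decode_step_ms(140, 0.625, 256))  # ~150 ms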

⚠️ KV cache becomes the bandwidth bottleneck at high batch

At low batch sizes, weight loading dominates bandwidth. But batch size is the lever you pull for throughput. As you increase batch size, KV cache bandwidth grows linearly while weight bandwidth stays constant. Past a critical batch size, KV cache fetching becomes the primary bottleneck — and compressing KV cache directly increases throughput.

This is why KV cache quantization has a throughput benefit beyond just saving memory. Reducing KV cache from FP16 to INT8 cuts the KV bandwidth in half, which at batch=256 saves 80 GB of memory reads per token — a 27% reduction in total bandwidth demand.

Decode Throughput vs Batch Size (Llama 3 70B, A100 80GB)

| Batch Size | 1 | 8 | 32 | 64 | 128 | 256 |
|---|---|---|---|---|---|---|
| FP16 KV (tokens/sec) | 14 | 100 | 350 | 580 | 850 | 1050 |
| INT8 KV (tokens/sec) | 14 | 105 | 370 | 640 | 980 | 1320 |
| INT4 KV (tokens/sec) | 14 | 107 | 385 | 670 | 1050 | 1500 |

At batch=256, INT8 KV cache delivers 26% more throughput than FP16, and INT4 delivers 43% more — purely from reduced bandwidth demand.

Production Considerations

Memory Watermark Tuning

In a production serving system, you need to decide how much GPU memory to reserve for KV cache vs. other uses. The memory watermark is the threshold at which the scheduler stops admitting new requests.

class MemoryWatermarkScheduler:
    """Production scheduler with memory watermark control."""
    def __init__(self, total_gpu_memory_gb, model_weight_gb, watermark_ratio=0.90):
        self.total_memory = total_gpu_memory_gb
        self.model_weights = model_weight_gb
        self.activation_overhead = 2.0  # GB, for activations and CUDA workspace
        self.available_for_kv = (
            self.total_memory - self.model_weights - self.activation_overhead
        )
        # Admission watermark: stop admitting new requests above this usage
        self.admission_watermark = self.available_for_kv * watermark_ratio
        # Preemption watermark: start preempting existing requests above this usage
        self.preemption_watermark = self.available_for_kv * 0.98

    def should_admit(self, current_kv_usage_gb, estimated_new_request_gb):
        return (current_kv_usage_gb + estimated_new_request_gb) < self.admission_watermark

    def should_preempt(self, current_kv_usage_gb):
        return current_kv_usage_gb > self.preemption_watermark

Setting the watermark too conservatively (e.g., 70%) wastes GPU memory and reduces throughput. Setting it too aggressively (e.g., 98%) causes frequent preemptions, which destroy latency. The right value depends on your workload’s sequence length distribution and variance.

📊 Watermark Tuning Impact (Llama 3 70B, mixed workload)

| Watermark | Max Batch | Throughput | Preemptions/min | p99 Latency |
|---|---|---|---|---|
| 70% | 42 | 580 tok/s | 0 | 120 ms |
| 85% | 54 | 720 tok/s | 0.2 | 135 ms |
| 90% | 58 | 760 tok/s | 1.5 | 180 ms |
| 95% | 61 | 780 tok/s | 8.0 | 450 ms |
| 98% | 63 | 790 tok/s | 25.0 | 1200 ms |

The sweet spot is typically 80-90%. Beyond 90%, preemptions spike and p99 latency degrades rapidly.

Preemption Strategies: Swap vs Recompute

When the scheduler must preempt a request, the choice between swapping to CPU and recomputing from scratch depends on the request’s state:

$$\text{swap\_time} = \frac{\text{KV\_cache\_size}}{\text{PCIe\_bandwidth}} \qquad \text{recompute\_time} = \text{prefill\_tokens} \times \text{time\_per\_prefill\_token}$$

For PCIe Gen4 x16 (~25 GB/s effective) and Llama 3 70B:

📊 Swap vs Recompute Breakeven (Llama 3 70B)

| Prompt Length | KV Size (GB) | Swap Time (ms) | Recompute Time (ms) | Winner |
|---|---|---|---|---|
| 128 | 0.04 | 1.7 | 15 | Swap |
| 512 | 0.16 | 6.7 | 55 | Swap |
| 2048 | 0.63 | 27 | 210 | Swap |
| 8192 | 2.5 | 107 | 830 | Swap |
| 32768 | 10.0 | 430 | 3300 | Swap |

For typical LLM workloads, swap almost always wins because the KV cache for a single request (without the batch dimension) is relatively small. Recompute only becomes competitive when CPU memory is exhausted or PCIe is heavily contended.
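
A back-of-the-envelope estimator under the assumptions above; the 25 GB/s PCIe figure matches the table, and prefill_tokens_per_sec is an assumed throughput, not a measured one:

def swap_ms(prompt_tokens, kv_bytes_per_token=327_680, pcie_gbps=25):
    """Time to move one request's KV cache (Llama 3 70B: ~320 KB/token) over PCIe."""
    return prompt_tokens * kv_bytes_per_token / (pcie_gbps * 1e9) * 1000

def recompute_ms(prompt_tokens, prefill_tokens_per_sec=10_000):
    """Time to re-run the prefill when the request is rescheduled (assumed rate)."""
    return prompt_tokens / prefill_tokens_per_sec * 1000

for n in (512, 2048, 8192):
    print(n, round(swap_ms(n), 1), round(recompute_ms(n)))
# Swap stays well below recompute at every prompt length in the table above.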

Multi-Tenant Isolation

When serving multiple models or multiple tenants on the same GPU, KV cache memory must be partitioned:

class MultiTenantKVManager:
    """Isolate KV cache allocations between tenants."""
    def __init__(self, total_blocks, tenant_quotas):
        """
        tenant_quotas: dict mapping tenant_id to fraction of total blocks
        Example: {"tenant_a": 0.6, "tenant_b": 0.3, "shared": 0.1}
        """
        self.allocators = {}
        allocated = 0
        for tenant_id, fraction in tenant_quotas.items():
            num_blocks = int(total_blocks * fraction)
            self.allocators[tenant_id] = BlockAllocator(
                num_blocks=num_blocks,
                block_size=16,
                layers=80, kv_heads=8, head_dim=128,
                dtype=torch.float16
            )
            allocated += num_blocks

    def allocate(self, tenant_id):
        allocator = self.allocators.get(tenant_id)
        if allocator is None:
            raise ValueError(f"Unknown tenant: {tenant_id}")
        return allocator.allocate()

    def get_utilization(self, tenant_id):
        alloc = self.allocators[tenant_id]
        total = len(alloc.blocks)
        free = len(alloc.free_list)
        return (total - free) / total

The “shared” pool handles burst capacity: when a tenant exceeds its quota, it can borrow from the shared pool (with lower priority and possible preemption).
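
One way to implement that borrowing, as a sketch on top of the class above; the "shared" pool key and the preempt-borrowed-blocks-first policy are assumptions, not part of any particular serving framework:

def allocate_with_borrowing(manager, tenant_id):
    """Try the tenant's own pool first, then fall back to the shared pool."""
    block_id = manager.allocators[tenant_id].allocate()
    if block_id is not None:
        return tenant_id, block_id            # served from the tenant's own quota

    shared = manager.allocators.get("shared")
    if shared is not None:
        block_id = shared.allocate()
        if block_id is not None:
            return "shared", block_id         # borrowed; first in line for preemption
    return None, None                         # out of blocks: queue or preempt

Returning the pool name alongside the block id lets the caller free the block back into the correct allocator when the request finishes.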

Prefix Caching

Many serving workloads share common prefixes: system prompts, few-shot examples, or RAG context. Prefix caching stores these shared KV cache blocks once and reuses them across requests.

class PrefixCache:
    """Cache KV blocks for common prefixes, keyed by token hash."""
    def __init__(self, allocator, max_cached_prefixes=1000):
        self.allocator = allocator
        self.max_cached_prefixes = max_cached_prefixes
        self.cache = {}  # hash(token_ids) -> list[block_id]
        self.access_order = []  # LRU tracking (oldest first)

    def lookup(self, token_ids):
        """Check if prefix KV cache exists."""
        key = hash(tuple(token_ids))
        if key in self.cache:
            self._touch(key)
            block_ids = self.cache[key]
            for bid in block_ids:
                self.allocator.incref(bid)  # shared reference
            return block_ids
        return None

    def insert(self, token_ids, block_ids):
        """Cache a computed prefix."""
        key = hash(tuple(token_ids))
        self.cache[key] = block_ids
        for bid in block_ids:
            self.allocator.incref(bid)
        self.access_order.append(key)
        self._maybe_evict()

    def _touch(self, key):
        """Move key to the most-recently-used end of the LRU list."""
        self.access_order.remove(key)
        self.access_order.append(key)

    def _maybe_evict(self):
        """Drop least-recently-used prefixes beyond the capacity limit."""
        while len(self.cache) > self.max_cached_prefixes:
            victim = self.access_order.pop(0)
            for bid in self.cache.pop(victim):
                self.allocator.free(bid)  # release the cache's reference

For workloads with a shared system prompt (common in chat applications), prefix caching saves both the memory and the prefill compute for that prefix. A 2000-token system prompt shared across 64 concurrent requests is prefilled once instead of 64 times, and its KV blocks are stored once instead of 64 times.
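
A typical lookup-or-compute flow on top of the class above; compute_prefill_blocks is a hypothetical stand-in for the real prefill path that writes the prefix's K/V into freshly allocated blocks:

def get_or_build_prefix(prefix_cache, system_prompt_tokens, compute_prefill_blocks):
    """Reuse the system prompt's KV blocks if cached, otherwise compute and cache them."""
    block_ids = prefix_cache.lookup(system_prompt_tokens)
    if block_ids is not None:
        return block_ids                                   # prefix prefill skipped entirely

    block_ids = compute_prefill_blocks(system_prompt_tokens)  # one-time prefill cost
    prefix_cache.insert(system_prompt_tokens, block_ids)
    return block_ids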

When NOT to Optimize KV Cache

Not every deployment benefits from aggressive KV cache optimization. Here are the cases where simpler approaches work fine:

Short Sequences

If your maximum sequence length is under 512 tokens, KV cache memory is negligible. For Llama 3 8B at batch=32, seq=512:

$$\text{KV\_cache} = 2 \times 32 \times 8 \times 128 \times 512 \times 32 \times 2 \text{ bytes} = 2 \text{ GB}$$

Two gigabytes. On an A100 with 80 GB, this is under 3% of memory. Naive contiguous allocation with 50% waste still only costs 4 GB total. The engineering complexity of PagedAttention is not justified.

Small Models

Models under 3B parameters have small KV caches by construction (fewer layers, fewer heads). For Llama 3.2 1B at batch=64, seq=4096:

$$\text{KV\_cache} = 2 \times 16 \times 8 \times 64 \times 4096 \times 64 \times 2 \text{ bytes} = 8 \text{ GB}$$

Still manageable with simple allocation on modern GPUs. The model weights are only ~2 GB, so even with wasteful allocation you have ample headroom.

Prefill-Dominated Workloads

If your workload is mostly long prompts with short outputs (e.g., classification, extraction, scoring), then:

  • Prefill is the bottleneck, not decode
  • KV cache lifetime is short (few decode steps)
  • Memory pressure is transient, not sustained

In this regime, optimizing prefill throughput (chunked prefill, FlashAttention, tensor parallelism) matters more than optimizing KV cache management.

📊 When KV Cache Optimization Matters Most

| Scenario | Seq Length | Batch Size | Output Length | KV Optimization Impact |
|---|---|---|---|---|
| Short chat | < 512 | < 32 | < 256 | Low -- simple alloc is fine |
| Small model API | < 4096 | < 64 | Any | Low-Medium |
| Classification | < 2048 | Large | < 10 | Low -- prefill dominated |
| Long-form generation | > 2048 | > 32 | > 512 | High |
| Multi-turn chat | > 4096 | > 64 | Variable | Very High |
| RAG with long context | > 8192 | > 16 | > 256 | Critical |

Throughput Impact: Real Numbers

Let us put together a comprehensive throughput comparison showing the combined impact of KV cache optimizations on a realistic serving setup.

Setup: Llama 3 70B on 4x A100 80GB (tensor parallel), mixed workload with average input 1024 tokens, average output 512 tokens, max sequence 8192.

📊 End-to-End Serving Throughput (Llama 3 70B, 4x A100)

| Configuration | Max Batch | Throughput (tok/s) | p50 Latency | p99 Latency | GPU Util |
|---|---|---|---|---|---|
| Contiguous FP16 | 24 | 420 | 85 ms | 350 ms | 55% |
| Paged FP16 | 58 | 780 | 72 ms | 180 ms | 82% |
| Paged INT8 | 96 | 1050 | 65 ms | 160 ms | 88% |
| Paged INT8 + prefix cache | 96 | 1180 | 55 ms | 145 ms | 90% |
| Paged INT4 + prefix cache | 140 | 1420 | 52 ms | 155 ms | 91% |

Throughput Improvement Stack (Llama 3 70B, 4x A100): 420 → 780 → 1050 → 1180 → 1420 tokens/sec going from contiguous FP16 to paged FP16, paged INT8, paged INT8 + prefix cache, and paged INT4 + prefix cache.

The progression tells a clear story:

  1. Paged allocation (1.86x over contiguous): eliminates fragmentation, allows 2.4x more concurrent requests
  2. INT8 quantization (1.35x over paged FP16): fits more requests AND reduces bandwidth per request
  3. Prefix caching (1.12x incremental): saves prefill compute for repeated system prompts
  4. INT4 quantization (1.20x over INT8 + cache): pushes batch size even higher, though with small quality tradeoffs

The cumulative improvement from the simplest to the most optimized configuration is 3.4x — and this is on the same hardware, serving the same model, with no changes to model quality (for the INT8 path).

The Full Picture

Optimized GPU Memory Layout: Llama 3 70B Serving

4x A100 80GB with PagedAttention + INT8 KV cache

| Region | Size | Notes |
|---|---|---|
| Model Weights (FP16) | 140 GB | Sharded across 4 GPUs via tensor parallelism |
| KV Block Pool (INT8) | 48 GB | Paged blocks, 96+ concurrent requests |
| Prefix Cache | 10 GB | Shared system prompt KV, refcounted |
| Activations + CUDA Workspace | 12 GB | Transient, reused per step |
| Watermark Headroom | 10 GB | Absorbs bursts without preemption |

KV cache management is not a single optimization but an interlocking system of decisions:

  1. Allocation strategy determines memory utilization and fragmentation. PagedAttention is now the standard.
  2. Quantization level trades precision for capacity and bandwidth. INT8 is the sweet spot for most deployments.
  3. Eviction policy determines graceful degradation under memory pressure. Attention sinks + heavy hitters preserve quality.
  4. Bandwidth awareness explains why KV compression improves throughput, not just capacity.
  5. Watermark tuning balances utilization against latency stability.
  6. Prefix sharing and copy-on-write amortize common computation.

The difference between a naive implementation and a well-tuned one is not incremental — it is the difference between serving 24 concurrent users and serving 140 on the same hardware. For anyone running LLM inference at scale, KV cache management is where the most impactful systems work happens.