If you have ever tried to serve a large language model in production, you have encountered a frustrating reality: the model weights fit on your GPUs, but you cannot serve many concurrent users. The bottleneck is almost always the KV cache — a runtime data structure that grows with every token generated, for every request, across every layer of the model. At serving scale, the KV cache routinely consumes more GPU memory than the model itself.
This post is a comprehensive treatment of KV cache memory management. We will start from first principles — why the cache exists and what happens without it — then work through exact memory arithmetic for modern models, memory allocation strategies (from naive contiguous buffers to PagedAttention), compression and eviction techniques, bandwidth analysis during decode, and production tuning considerations. The goal is to give you the full systems-level picture so that you can reason about KV cache tradeoffs in your own serving stack.
Why KV Cache Exists
The Autoregressive Bottleneck
Transformer-based language models generate text one token at a time. At each step, the model must compute attention over all previous tokens. The attention mechanism for a single head looks like this:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_head)) V

During generation, the query Q comes from the new token being generated. But the keys K and values V must include every previous token in the sequence. Without caching, this means recomputing the K and V projections for all prior tokens at every single generation step.
Without KV Cache: O(n^2) Per Token
Consider generating a sequence of length n. At step t, the model must:
- Project all t tokens through W_k and W_v to get keys and values
- Compute attention between the new query and all t keys
- Produce one output token
The projection step alone is O(t * d) per layer. Summed over all n generation steps, this is O(n^2 * d) per layer, or O(n^2 * d * L) across all L layers:
For a 4096-token sequence on a model with 80 layers, this means recomputing the K/V projections for tokens 1 through 4095 just to generate token 4096. Then recomputing tokens 1 through 4096 to generate token 4097. The redundant computation is enormous.
def generate_without_cache(model, prompt_tokens, max_new_tokens):
    """Naive generation: recompute everything at each step."""
    tokens = prompt_tokens.clone()
    for step in range(max_new_tokens):
        # Full forward pass over ALL tokens every step
        # Attention cost: O(seq_len^2) per layer
        logits = model(tokens)  # processes entire sequence
        next_token = sample(logits[:, -1, :])
        tokens = torch.cat([tokens, next_token.unsqueeze(1)], dim=1)
    return tokens

# Total compute: O(n^2 * d * L) where L = num_layers
With KV Cache: O(n) Per Token
The fix is straightforward: cache the K and V projections from previous tokens. At each generation step, we only compute K and V for the single new token, then concatenate them with the cached values.
def generate_with_cache(model, prompt_tokens, max_new_tokens):
    """Cached generation: only compute new token's K/V."""
    kv_cache = None
    # Prefill: process entire prompt (one-time cost)
    logits, kv_cache = model(prompt_tokens, kv_cache=None)
    next_token = sample(logits[:, -1, :])
    generated = [next_token]
    # Decode: process one token at a time, reusing cache
    for step in range(max_new_tokens - 1):
        logits, kv_cache = model(
            next_token.unsqueeze(1),
            kv_cache=kv_cache  # reuse all previous K/V
        )
        next_token = sample(logits[:, -1, :])
        generated.append(next_token)
    return generated

# Total compute: O(n * d * L) -- linear, not quadratic
The computational savings are dramatic:
Computational Cost: With vs Without KV Cache (4096-token generation)
| Method | K/V Projections | Attention Ops | Wall Time (est.) | Speedup |
|---|---|---|---|---|
| No cache | O(n^2 * d * L) | O(n^2 * d * L) | ~180 s | 1.0x |
| With KV cache | O(n * d * L) | O(n * d * L) | ~12 s | ~15x |
KV caching converts a compute-bound quadratic problem into a memory-bound linear problem. You trade GPU memory (storing all those cached K/V tensors) for a massive reduction in redundant computation. This trade-off is worthwhile for virtually all serving scenarios, but it means memory management becomes the critical bottleneck.
The cache itself is simple in concept. For each layer, we store two tensors:
class CachedAttention(nn.Module):
    def forward(self, x, kv_cache=None):
        B, T, D = x.shape
        Q = self.W_q(x)      # [B, T, n_heads * head_dim]
        K_new = self.W_k(x)  # [B, T, n_kv_heads * head_dim] (smaller under GQA)
        V_new = self.W_v(x)  # [B, T, n_kv_heads * head_dim]
        # Reshape to [B, num_heads, T, head_dim]
        Q = Q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        K_new = K_new.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        V_new = V_new.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Concatenate with cached K/V along the sequence dimension
        if kv_cache is not None:
            K = torch.cat([kv_cache['K'], K_new], dim=2)
            V = torch.cat([kv_cache['V'], V_new], dim=2)
        else:
            K, V = K_new, V_new
        # Standard attention (or GQA with head broadcasting)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, V)
        new_cache = {'K': K, 'V': V}
        return out, new_cache
Simple in concept. Devastating in memory consumption at scale.
The Memory Arithmetic: KV Cache vs Model Weights
The Formula
The total KV cache memory for a model is:

KV bytes = 2 * n_layers * n_kv_heads * d_head * seq_len * batch_size * bytes_per_element

Where:
- 2 = one K tensor + one V tensor
- n_layers = number of layers
- n_kv_heads = number of KV heads (may differ from query heads in GQA)
- d_head = head dimension
- seq_len = sequence length
- batch_size = batch size
- bytes_per_element = 2 for FP16/BF16, 1 for INT8
Llama 3 70B: A Worked Example
Llama 3 70B uses Grouped Query Attention (GQA) with these parameters:
- n_layers = 80
- n_kv_heads = 8 (while having 64 query heads)
- d_head = 128
- Model weights: ~140 GB in FP16
Let us compute the KV cache for a realistic serving scenario: batch size 64, sequence length 4096.
Breaking this down step by step, keeping in mind that the factor of 2 for "K + V" and the factor of 2 for FP16 bytes are separate:

2 (K+V) * 80 layers * 8 KV heads * 128 head_dim * 4096 seq_len * 64 batch * 2 bytes
= 85,899,345,920 bytes
≈ 80 GB

For Llama 3 70B at batch=64 and seq_len=4096, the KV cache alone is ~80 GB. The model weights are ~140 GB in FP16. The KV cache is already 57% of model weight size — and it scales linearly with both batch size and sequence length. Double the batch to 128 and you hit 160 GB of KV cache, exceeding the model weights.
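The arithmetic above can be sanity-checked in a couple of lines, with both factors of 2 written out explicitly:

```python
# Llama 3 70B KV cache at batch=64, seq_len=4096, FP16
#          K+V  layers heads d_head  seq   batch fp16
kv_bytes = 2  * 80   * 8   * 128  * 4096 * 64  * 2
kv_gib = kv_bytes / 2**30
print(kv_gib)  # 80.0
```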
Memory Breakdown Across Models
Let us compute KV cache sizes for several popular models at a fixed serving point of batch=32, seq_len=4096, FP16:
def kv_cache_gb(layers, kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    """Compute KV cache size in GB."""
    return (2 * layers * kv_heads * head_dim * seq_len * batch * dtype_bytes) / (1024**3)
KV Cache Memory for Popular Models (batch=32, seq=4096, FP16)
| Model | Layers | KV Heads | d_head | Weights (GB) | KV Cache (GB) | Cache/Weights |
|---|---|---|---|---|---|---|
| Llama 3 8B | 32 | 8 | 128 | 16 | 16.0 | 100% |
| Mistral 7B | 32 | 8 | 128 | 14 | 16.0 | 114% |
| Llama 3 70B | 80 | 8 | 128 | 140 | 40.0 | 29% |
| Llama 3 70B (batch=64) | 80 | 8 | 128 | 140 | 80.0 | 57% |
| Llama 3 70B (batch=128) | 80 | 8 | 128 | 140 | 160.0 | 114% |
| Llama 3 405B | 126 | 8 | 128 | 810 | 63.0 | 8% |
| GPT-4 class (est.) | 120 | 16 | 128 | ~800 | 120.0 | ~15% |
Several observations emerge from this table. First, GQA is a lifesaver: Llama 3 70B uses only 8 KV heads (vs. 64 query heads), which gives an 8x reduction in KV cache size compared to standard multi-head attention. Without GQA, the 70B model at batch=64 would need 640 GB of KV cache — clearly impossible. Second, smaller models are more KV-cache-dominated relative to their weight size, because the KV cache scales with kv_heads * d_head * layers while weights scale with d_model^2 * layers. Third, batch size is the multiplier that pushes KV cache from manageable to dominant.
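The GQA claim is easy to verify with the formula from above; the hypothetical full-MHA variant simply swaps 8 KV heads for 64:

```python
def kv_gb(layers, kv_heads, head_dim, seq_len, batch, dtype_bytes=2):
    """KV cache size in GiB: 2 tensors (K and V) per layer."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * dtype_bytes / 2**30

gqa = kv_gb(80, 8, 128, 4096, 64)   # Llama 3 70B as shipped: 8 KV heads (GQA)
mha = kv_gb(80, 64, 128, 4096, 64)  # hypothetical full-MHA variant: 64 KV heads
print(gqa, mha)  # 80.0 640.0
```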
GPU Memory Layout: Llama 3 70B Serving (batch=64, seq=4096)
Approximate memory breakdown on 4x A100 80GB (320 GB total):

| Component | Size |
|---|---|
| Model weights (FP16) | 140 GB |
| KV cache | 80 GB |
| Activations + CUDA workspace | 20 GB |
| Free headroom | 80 GB |

KV Cache Size vs Batch Size (Llama 3 70B, seq=4096, FP16)

| Batch size | 1 | 8 | 16 | 32 | 64 | 128 | 256 |
|---|---|---|---|---|---|---|---|
| KV Cache (GB) | 1.25 | 10 | 20 | 40 | 80 | 160 | 320 |
| Model Weights (GB, constant) | 140 | 140 | 140 | 140 | 140 | 140 | 140 |

At batch=128 the KV cache (160 GB) already exceeds the model weights, and by batch=256 it reaches 320 GB. For models with full MHA (no GQA), the crossover happens much sooner.
Memory Management: From Naive to PagedAttention
The Naive Approach: Contiguous Pre-allocation
The simplest KV cache implementation pre-allocates a contiguous buffer for each request at max_seq_len:
class NaiveKVCache:
    def __init__(self, max_seq_len, batch_size, layers, kv_heads, head_dim, dtype=torch.float16):
        # Pre-allocate maximum possible size for every request
        self.k_cache = torch.zeros(
            (batch_size, layers, kv_heads, max_seq_len, head_dim),
            dtype=dtype, device='cuda'
        )
        self.v_cache = torch.zeros_like(self.k_cache)
        self.lengths = torch.zeros(batch_size, dtype=torch.int32)

    def append(self, batch_idx, layer_idx, new_k, new_v):
        pos = self.lengths[batch_idx]
        self.k_cache[batch_idx, layer_idx, :, pos, :] = new_k
        self.v_cache[batch_idx, layer_idx, :, pos, :] = new_v
        self.lengths[batch_idx] += 1
This approach has a fatal flaw: internal fragmentation. If max_seq_len = 4096 but the average request only uses 800 tokens, then 80% of the allocated memory is wasted. The memory is reserved and cannot be used by other requests.
The vLLM team measured this waste empirically and found that 60-80% of KV cache memory is wasted under realistic workloads with contiguous allocation. Their published figure is 68% average waste across several serving traces.
Memory Waste in Contiguous Allocation (measured)
| Workload | Avg Seq Len | Max Seq Len | Allocated | Actually Used | Waste |
|---|---|---|---|---|---|
| ShareGPT (chat) | 830 | 4096 | 100% | 20% | 80% |
| Alpaca (instruction) | 380 | 2048 | 100% | 19% | 81% |
| Code generation | 1200 | 4096 | 100% | 29% | 71% |
| Summarization | 2800 | 4096 | 100% | 68% | 32% |
| Weighted average | -- | -- | 100% | 32% | 68% |
There is also external fragmentation: as requests complete and free their contiguous blocks, the free memory becomes a patchwork of different-sized holes. A new request needing a 4096-token block may fail even though total free memory is sufficient, because no single contiguous region is large enough.
PagedAttention: Virtual Memory for KV Cache
The breakthrough insight from vLLM (Kwon et al., 2023) is to apply the same idea that operating systems use for process memory: paging. Instead of allocating one contiguous buffer per request, divide KV cache memory into fixed-size blocks (analogous to 4KB pages in virtual memory) and map them to requests through a page table.
Each block stores the K and V vectors for a fixed number of tokens (the block size, typically 16 tokens). A request’s KV cache is a linked list of blocks, not necessarily contiguous in physical GPU memory.
class KVBlock:
    """A fixed-size block storing K/V for block_size tokens across all layers."""
    def __init__(self, block_id, block_size, layers, kv_heads, head_dim, dtype):
        # Shape: [layers, 2, kv_heads, block_size, head_dim]
        self.data = torch.zeros(
            (layers, 2, kv_heads, block_size, head_dim),
            dtype=dtype, device='cuda'
        )
        self.block_id = block_id
        self.num_filled = 0
        self.block_size = block_size

class BlockAllocator:
    """Free-list allocator for KV cache blocks. O(1) alloc and free."""
    def __init__(self, num_blocks, block_size, layers, kv_heads, head_dim, dtype):
        self.blocks = [
            KVBlock(i, block_size, layers, kv_heads, head_dim, dtype)
            for i in range(num_blocks)
        ]
        self.free_list = list(range(num_blocks))
        self.ref_counts = [0] * num_blocks

    def allocate(self):
        if not self.free_list:
            return None  # OOM
        block_id = self.free_list.pop()
        self.ref_counts[block_id] = 1
        return block_id

    def free(self, block_id):
        self.ref_counts[block_id] -= 1
        if self.ref_counts[block_id] == 0:
            self.blocks[block_id].num_filled = 0
            self.free_list.append(block_id)

    def incref(self, block_id):
        """For copy-on-write sharing."""
        self.ref_counts[block_id] += 1

class PageTable:
    """Maps (request_id, logical_block_idx) to physical block_id."""
    def __init__(self):
        self.tables = {}  # request_id -> list[int] (physical block ids)

    def get_physical_blocks(self, request_id):
        return self.tables.get(request_id, [])

    def append_block(self, request_id, physical_block_id):
        if request_id not in self.tables:
            self.tables[request_id] = []
        self.tables[request_id].append(physical_block_id)

    def release(self, request_id, allocator):
        for block_id in self.tables.pop(request_id, []):
            allocator.free(block_id)
The key properties of this design:
Near-zero waste. A request only allocates blocks as it needs them. A request at 800 tokens with block_size=16 uses exactly 50 blocks. The only waste is in the last partially-filled block — at most block_size - 1 tokens, or 15 tokens in this example. For a 4096-max system, that is under 0.4% waste vs. 80% with contiguous allocation.
No external fragmentation. All blocks are the same size, so any free block can satisfy any allocation. The “fragmentation kills you” problem of contiguous allocation vanishes entirely.
O(1) allocation and deallocation. The free list gives constant-time alloc/free. No searching, no compaction, no defragmentation.
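The block accounting behind the "near-zero waste" point can be made concrete. blocks_needed is a hypothetical helper (the same ceiling division the allocator would use), not part of vLLM's API:

```python
def blocks_needed(num_tokens, block_size=16):
    """Ceiling division: number of fixed-size blocks to hold num_tokens."""
    return (num_tokens + block_size - 1) // block_size

print(blocks_needed(800))   # 50 blocks, last one exactly full (800 = 50 * 16)
print(blocks_needed(801))   # 51 blocks, 15 unused slots in the last block
waste_tokens = blocks_needed(801) * 16 - 801  # 15, the worst case for block_size=16
```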
Memory Utilization: Contiguous vs PagedAttention
| Allocator | Utilization | External Frag | Internal Frag | Alloc Time |
|---|---|---|---|---|
| Contiguous (max_len) | 20-40% | High | Very high | O(1) |
| Contiguous (growing) | 60-70% | High | Medium | O(n) realloc |
| PagedAttention (block=16) | 96-99% | Zero | < 0.4% | O(1) |
| PagedAttention (block=1) | ~100% | Zero | Zero | O(1), high overhead |
Block Size: A Real Tuning Knob
The block size presents a classic overhead-vs-waste tradeoff:
- Smaller blocks (e.g., 1 token): minimal internal fragmentation, but more page table entries, more pointer chasing, worse memory access patterns for the attention kernel.
- Larger blocks (e.g., 64 tokens): better memory locality, fewer page table entries, but more internal fragmentation for short or variable-length sequences.
In practice, block sizes of 16-32 tokens hit the sweet spot. vLLM defaults to 16. The overhead from page table management is negligible compared to the memory savings.
[Figure: Block Size Tradeoff (Llama 3 70B, mixed workload): internal fragmentation (%) grows with block size while page table overhead (%) shrinks; the total overhead curve bottoms out around block sizes of 16-32.]
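The internal-fragmentation side of this tradeoff can be estimated directly. This is a sketch under two assumptions not stated in the text: an ~800-token average sequence, and a uniformly distributed fill level for each sequence's final block:

```python
def expected_waste_pct(block_size, avg_seq_len=800):
    """Expected internal fragmentation: the last block averages
    (block_size - 1) / 2 empty slots per sequence."""
    return 100 * (block_size - 1) / 2 / avg_seq_len

for b in (1, 16, 32, 128):
    print(b, round(expected_waste_pct(b), 2))
# 1 0.0
# 16 0.94
# 32 1.94
# 128 7.94
```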
Copy-on-Write for Beam Search
Beam search generates multiple candidate sequences that share a common prefix. Without copy-on-write, you must duplicate the entire KV cache for each beam — multiplying memory usage by the beam width.
With paged allocation, beams can share physical blocks for their common prefix. Each beam’s page table points to the same physical blocks. When a beam diverges (writes to a position in a shared block), the system copies only that single block — the “copy-on-write” pattern from OS virtual memory.
def fork_beam(page_table, allocator, parent_request_id, child_request_id):
    """Fork a beam: child shares parent's blocks via refcounting."""
    parent_blocks = page_table.get_physical_blocks(parent_request_id)
    page_table.tables[child_request_id] = parent_blocks.copy()
    for block_id in parent_blocks:
        allocator.incref(block_id)

def cow_write(page_table, allocator, request_id, logical_block_idx, new_kv_data):
    """Copy-on-write: only copy the block being modified."""
    blocks = page_table.get_physical_blocks(request_id)
    old_block_id = blocks[logical_block_idx]
    if allocator.ref_counts[old_block_id] > 1:
        # Shared block -- must copy before writing
        new_block_id = allocator.allocate()
        allocator.blocks[new_block_id].data.copy_(
            allocator.blocks[old_block_id].data
        )
        allocator.free(old_block_id)  # decrements refcount
        blocks[logical_block_idx] = new_block_id
        old_block_id = new_block_id
    # Now safe to write in-place
    write_kv_to_block(allocator.blocks[old_block_id], new_kv_data)
For beam width B and sequence length n with block size b, naive duplication costs B * ceil(n / b) blocks, while copy-on-write costs only ceil(n / b) shared blocks plus roughly B * ceil(d / b) private blocks, where d is the number of tokens generated after the beams diverge. For beam search on a 2048-token sequence with beam width 4 that diverges in the last 50 tokens, naive allocation uses 4x memory while CoW uses approximately 1.1x.
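Plugging the example's numbers into those block counts — a sketch that, for simplicity, assumes the shared prefix ends on a block boundary:

```python
import math

BLOCK = 16

def naive_blocks(beams, seq_len, block=BLOCK):
    """Every beam duplicates the full cache."""
    return beams * math.ceil(seq_len / block)

def cow_blocks(beams, seq_len, diverged, block=BLOCK):
    """Shared prefix stored once; each beam owns only its diverged suffix."""
    shared = math.ceil((seq_len - diverged) / block)
    private = beams * math.ceil(diverged / block)
    return shared + private

print(naive_blocks(4, 2048))    # 512
print(cow_blocks(4, 2048, 50))  # 141
```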
Admission Control and Preemption
With paged allocation, admission control becomes precise: you know exactly how many free blocks remain, so you can make exact decisions about whether to admit a new request.
def can_admit(allocator, estimated_tokens, block_size=16):
    """Admit only if we have enough blocks for the estimated request."""
    needed_blocks = (estimated_tokens + block_size - 1) // block_size
    return len(allocator.free_list) >= needed_blocks

def preempt_if_needed(allocator, scheduler, min_free_blocks):
    """Preempt lowest-priority request if memory is critically low."""
    while len(allocator.free_list) < min_free_blocks:
        victim = scheduler.get_lowest_priority_request()
        if victim is None:
            break  # nothing to preempt
        # Two strategies: swap to CPU or recompute later
        # (swap_is_cheaper, swap_to_cpu, mark_for_recompute are serving-system hooks)
        if swap_is_cheaper(victim):
            swap_to_cpu(victim, allocator)
        else:
            mark_for_recompute(victim, allocator)
When preempting a request, you have two options. Swap copies the KV cache blocks to CPU memory and restores them later — good for long sequences where recomputation would be expensive. Recompute discards the KV cache entirely and re-runs the prefill when the request is rescheduled — good when CPU memory is also tight. The breakeven point depends on PCIe bandwidth vs. compute throughput: with a fast interconnect, swapping usually wins, and recompute becomes attractive mainly for very short prefills or when PCIe is heavily contended.
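The breakeven can be sketched with back-of-the-envelope numbers. The 0.1 ms/token prefill rate is an assumption for illustration, roughly consistent with the table below; the KV bytes per token come from the earlier Llama 3 70B arithmetic:

```python
KV_BYTES_PER_TOKEN = 2 * 80 * 8 * 128 * 2  # Llama 3 70B, FP16: 320 KiB/token

def swap_ms(tokens, pcie_gb_s=25):
    """Time to move one request's KV cache over PCIe (one direction)."""
    return tokens * KV_BYTES_PER_TOKEN / (pcie_gb_s * 1e9) * 1e3

def recompute_ms(tokens, prefill_ms_per_token=0.1):  # assumed prefill throughput
    return tokens * prefill_ms_per_token

for n in (512, 2048, 8192):
    print(n, round(swap_ms(n), 1), recompute_ms(n))
```

Under these assumptions, swapping beats recompute at every prompt length, which matches the table that follows.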
KV Cache Compression
Even with perfect memory management (zero fragmentation), you eventually hit the hard limit of GPU memory. Compression techniques trade a small amount of quality for significant memory savings.
Quantization: FP16 to INT8 and INT4
The most straightforward compression is reducing the numerical precision of cached K and V tensors.
FP16 to INT8 (2x savings). Each element goes from 2 bytes to 1 byte. The quality impact is remarkably small because K and V tensors have well-behaved distributions with limited dynamic range. Per-channel or per-token quantization with calibrated scales preserves almost all information.
def quantize_kv_to_int8(kv_tensor):
    """Per-token INT8 quantization of KV cache."""
    # One scale per vector: reduce over the head dimension (last axis)
    abs_max = kv_tensor.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = abs_max / 127.0
    # Quantize
    quantized = (kv_tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return quantized, scale

def dequantize_kv(quantized, scale):
    """Dequantize back to FP16 for attention computation."""
    return quantized.to(torch.float16) * scale
INT8 to INT4 (4x total savings from FP16). Each element uses 4 bits, packed two per byte. Quality loss becomes measurable but often acceptable for many applications. GPTQ-style or AWQ-style quantization can be applied to KV caches specifically.
KV Cache Quantization Impact (Llama 3 70B, batch=64, seq=4096)
| Precision | Bytes/Element | KV Cache Size | Savings | Avg Quality Loss |
|---|---|---|---|---|
| FP16 | 2 | 80 GB | Baseline | 0% |
| INT8 (per-channel) | 1 | 40 GB | 2x | < 0.1% perplexity |
| INT4 (grouped) | 0.5 | 20 GB | 4x | 0.3-1.0% perplexity |
| INT4 + FP16 outliers | ~0.6 | 24 GB | 3.3x | < 0.2% perplexity |
| INT2 (experimental) | 0.25 | 10 GB | 8x | 2-5% perplexity |
The “INT4 + FP16 outliers” approach deserves special mention: it keeps a small percentage of high-magnitude values in full precision while quantizing the rest to INT4. This gives nearly the compression of INT4 with quality close to INT8. Systems like KIVI and KVQuant implement this pattern.
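The outlier-preserving idea can be sketched in plain Python. This is a conceptual illustration, not the actual KIVI or KVQuant algorithm: real systems pick outliers per channel, pack two INT4 values per byte, and store outliers in FP16 rather than full precision:

```python
def quantize_with_outliers(values, outlier_frac=0.05):
    """Sketch: keep the top-magnitude values exact, quantize the
    rest symmetrically to the INT4-like range [-7, 7]."""
    n_outliers = max(1, int(len(values) * outlier_frac))
    order = sorted(range(len(values)), key=lambda i: -abs(values[i]))
    outlier_idx = set(order[:n_outliers])
    body_max = max((abs(v) for i, v in enumerate(values)
                    if i not in outlier_idx), default=0.0)
    scale = max(body_max, 1e-8) / 7.0
    quantized = [
        None if i in outlier_idx else max(-7, min(7, round(v / scale)))
        for i, v in enumerate(values)
    ]
    outliers = {i: values[i] for i in outlier_idx}  # FP16 in a real system
    return quantized, scale, outliers

def dequantize(quantized, scale, outliers):
    return [outliers[i] if q is None else q * scale
            for i, q in enumerate(quantized)]
```

The high-magnitude entries survive exactly, so the quantization scale is set by the well-behaved bulk of the distribution instead of being stretched by a few outliers.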
[Figure: Memory Savings vs Quality Degradation (Llama 3 70B on MMLU): KV cache size at batch=64 falls from 80 GB (FP16) to 40 GB (INT8), 24 GB (INT4 + outliers), 20 GB (INT4), and 10 GB (INT2), while the MMLU score degrades only gradually until the most aggressive settings.]
Eviction Policies: When Memory Is Full
When GPU memory is exhausted and you cannot (or choose not to) preempt entire requests, you can instead evict individual KV cache entries — dropping the K/V for specific tokens at specific layers. The question is: which entries are least important?
Heavy Hitter Oracle (H2O). The key observation is that attention patterns are sparse: a small fraction of tokens receive a disproportionate share of attention mass. H2O keeps the “heavy hitter” tokens (those that consistently receive high attention scores) plus the most recent tokens, and evicts the rest.
class H2OEviction:
    """Heavy Hitter Oracle: keep tokens with highest cumulative attention."""
    def __init__(self, budget, recent_window=128):
        self.budget = budget  # max tokens to keep
        self.recent_window = recent_window
        self.cumulative_attention = None

    def update(self, attention_weights):
        """attention_weights: [batch, heads, 1, seq_len] from latest step."""
        scores = attention_weights.sum(dim=(0, 1, 2))  # [seq_len]
        if self.cumulative_attention is None:
            self.cumulative_attention = scores
        else:
            # Extend for new position
            self.cumulative_attention = torch.cat([
                self.cumulative_attention,
                scores[-1:]
            ])
            self.cumulative_attention[:-1] += scores[:-1]

    def get_eviction_mask(self, seq_len):
        """Returns boolean mask of tokens to KEEP."""
        if seq_len <= self.budget:
            return torch.ones(seq_len, dtype=torch.bool)
        keep = torch.zeros(seq_len, dtype=torch.bool)
        # Always keep recent tokens
        keep[-self.recent_window:] = True
        # Keep top-k heavy hitters from older tokens
        remaining_budget = self.budget - self.recent_window
        older_scores = self.cumulative_attention[:seq_len - self.recent_window]
        topk_indices = older_scores.topk(remaining_budget).indices
        keep[topk_indices] = True
        return keep
Attention Sinks. Xiao et al. (2023) discovered that the first few tokens in a sequence consistently receive high attention regardless of content — they act as “attention sinks.” A simple but effective eviction policy keeps the first few sink tokens (typically 4) plus a window of the most recent tokens, evicting everything in between.
class AttentionSinkEviction:
    """StreamingLLM-style: keep first k 'sink' tokens + recent window."""
    def __init__(self, num_sinks=4, window_size=1024):
        self.num_sinks = num_sinks
        self.window_size = window_size

    def trim_cache(self, k_cache, v_cache, current_len):
        """Trim cache to sinks + window."""
        if current_len <= self.num_sinks + self.window_size:
            return k_cache, v_cache  # no trimming needed
        # Keep first num_sinks tokens
        sink_k = k_cache[:, :, :self.num_sinks, :]
        sink_v = v_cache[:, :, :self.num_sinks, :]
        # Keep last window_size tokens
        recent_k = k_cache[:, :, -self.window_size:, :]
        recent_v = v_cache[:, :, -self.window_size:, :]
        return (
            torch.cat([sink_k, recent_k], dim=2),
            torch.cat([sink_v, recent_v], dim=2),
        )

# Resulting cache length: num_sinks + window_size (constant)
Eviction Policy Comparison (Llama 3 8B, 4096-token generation, budget=1024)
| Policy | Tokens Kept | Memory | Quality (ppl) | Implementation |
|---|---|---|---|---|
| Full cache | 4096 | 100% | 5.12 (baseline) | N/A |
| Random eviction | 1024 | 25% | 7.85 | Trivial |
| Recent-only | 1024 | 25% | 6.41 | Trivial |
| Attention sinks (4+1020) | 1024 | 25% | 5.38 | Simple |
| H2O | 1024 | 25% | 5.24 | Moderate |
| H2O + sinks | 1024 | 25% | 5.19 | Moderate |
Sliding Window Attention: Mistral’s Approach
Mistral takes a more radical approach: the model is architecturally designed to only attend to the last W tokens. The KV cache has a hard upper bound regardless of sequence length.
For Mistral 7B with W = 4096, the per-sequence cache is capped at:

2 (K+V) * 32 layers * 8 KV heads * 128 head_dim * 4096 window * 2 bytes = 512 MB
class SlidingWindowKV:
    """Fixed-size circular buffer for sliding window attention."""
    def __init__(self, window_size, layers, kv_heads, head_dim, batch_size, dtype):
        self.window_size = window_size
        self.buffer_k = torch.zeros(
            (batch_size, layers, kv_heads, window_size, head_dim),
            dtype=dtype, device='cuda'
        )
        self.buffer_v = torch.zeros_like(self.buffer_k)
        self.write_pos = 0  # circular write position

    def append(self, new_k, new_v):
        """Append new token's K/V, overwriting oldest if full."""
        pos = self.write_pos % self.window_size
        self.buffer_k[:, :, :, pos, :] = new_k
        self.buffer_v[:, :, :, pos, :] = new_v
        self.write_pos += 1

    def get_kv(self):
        """Return K/V in correct temporal order."""
        if self.write_pos <= self.window_size:
            return self.buffer_k[:, :, :, :self.write_pos, :], \
                   self.buffer_v[:, :, :, :self.write_pos, :]
        # Reorder circular buffer to temporal order
        start = self.write_pos % self.window_size
        indices = torch.cat([
            torch.arange(start, self.window_size),
            torch.arange(0, start)
        ])
        return self.buffer_k[:, :, :, indices, :], \
               self.buffer_v[:, :, :, indices, :]
The advantage is guaranteed constant memory regardless of sequence length. The disadvantage is that the model literally cannot attend to tokens beyond the window — this is a hard architectural constraint, not just a memory optimization. For tasks requiring long-range dependencies (document summarization, multi-turn conversations), this can degrade quality.
Mistral mitigates this by applying the sliding window at every layer: the receptive field grows with depth, so after L layers a token can be influenced by information from roughly L * W positions back, even though no single layer sees past its window. More recent hybrid architectures go further — some interleave sliding-window layers with full-attention layers, while Jamba-style Mamba hybrids replace most attention layers with state-space layers — getting the memory benefits on most layers while preserving long-range capability on a few.
Memory Bandwidth Analysis: The Decode Bottleneck
Why Decode is Memory-Bound
LLM inference has two distinct phases with very different computational profiles:
Prefill processes all prompt tokens in parallel. It is compute-bound: the GPU’s tensor cores are busy multiplying large matrices, and memory bandwidth is not the bottleneck.
Decode generates one token at a time. It is memory-bound: at each step, the model must load all its weights AND the entire KV cache from HBM to compute a single token’s output. The arithmetic intensity (FLOPs per byte loaded) is extremely low.
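The imbalance can be quantified with arithmetic intensity; a rough back-of-the-envelope sketch for batch=1 decode:

```python
# Llama 3 70B, FP16, batch=1 decode: every weight is read once per token,
# and each parameter contributes ~2 FLOPs (one multiply, one add).
params = 70e9
flops_per_token = 2 * params
bytes_per_token = 2 * params  # FP16: 2 bytes per weight
intensity = flops_per_token / bytes_per_token
print(intensity)  # 1.0 FLOP/byte

# An A100 needs roughly 312e12 / 2e12 = 156 FLOPs/byte of intensity to be
# compute-bound, so batch=1 decode sits two orders of magnitude below the roofline.
```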
The Bandwidth Math
For each decode step, the GPU must load:
- Model weights: All parameter matrices. For Llama 3 70B in FP16: ~140 GB.
- KV cache: All K and V tensors for all previous tokens. Size grows with sequence length.
The total bytes loaded per token:

bytes_per_step = model_weight_bytes + 2 * n_layers * n_kv_heads * d_head * seq_len * batch_size * bytes_per_element

At batch=1, the KV cache per token per layer is:

2 (K+V) * 8 KV heads * 128 head_dim * 2 bytes = 4 KB per cached token per layer

For 80 layers at seq_len = 2048:

4 KB * 80 layers * 2048 tokens = 640 MB ≈ 0.625 GB

At batch=1, model weights dominate (140 GB vs 0.625 GB). But at batch=64:

KV read = 0.625 GB * 64 = 40 GB, for a total of 140 + 40 = 180 GB per step

Now KV cache is 22% of total memory traffic. At batch=256 and seq=4096:

KV read = 1.25 GB per sequence * 256 = 320 GB, vs 140 GB of weights

The KV cache read now exceeds the weight read by 2.3x.
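The per-step traffic in the table that follows can be reproduced with one function (weights_gb=140 is the FP16 Llama 3 70B figure from earlier):

```python
def decode_traffic_gb(batch, seq_len, weights_gb=140,
                      layers=80, kv_heads=8, head_dim=128, dtype_bytes=2):
    """Total HBM bytes read per decode step: all weights + the whole KV cache."""
    kv_gb = 2 * layers * kv_heads * head_dim * seq_len * batch * dtype_bytes / 2**30
    return weights_gb + kv_gb

print(decode_traffic_gb(64, 2048))   # 180.0
print(decode_traffic_gb(256, 2048))  # 300.0
```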
Memory Bandwidth Breakdown per Decode Step (Llama 3 70B, seq=2048)
| Batch Size | Weight Read (GB) | KV Read (GB) | Total (GB) | KV Fraction | A100 Time (ms) |
|---|---|---|---|---|---|
| 1 | 140 | 0.6 | 140.6 | 0.4% | 70 |
| 8 | 140 | 5.0 | 145.0 | 3.4% | 73 |
| 32 | 140 | 20.0 | 160.0 | 12.5% | 80 |
| 64 | 140 | 40.0 | 180.0 | 22.2% | 90 |
| 128 | 140 | 80.0 | 220.0 | 36.4% | 110 |
| 256 | 140 | 160.0 | 300.0 | 53.3% | 150 |
The A100’s HBM2e bandwidth is ~2 TB/s. At batch=256, loading 300 GB takes ~150 ms per token. To put that in perspective, the actual compute (matrix multiplications) for batch=256 is only ~20 ms. The GPU is idle 87% of the time, waiting for memory.
At low batch sizes, weight loading dominates bandwidth. But batch size is the lever you pull for throughput. As you increase batch size, KV cache bandwidth grows linearly while weight bandwidth stays constant. Past a critical batch size, KV cache fetching becomes the primary bottleneck — and compressing KV cache directly increases throughput.
This is why KV cache quantization has a throughput benefit beyond just saving memory. Reducing KV cache from FP16 to INT8 cuts the KV bandwidth in half, which at batch=256 saves 80 GB of memory reads per token — a 27% reduction in total bandwidth demand.
[Figure: Decode Throughput vs Batch Size (Llama 3 70B, A100 80GB): tokens/sec for FP16, INT8, and INT4 KV caches. The curves diverge as batch size grows, because KV cache traffic becomes a larger share of total bandwidth.]
At batch=256 (seq=2048), INT8 KV cache cuts total memory traffic from 300 GB to 220 GB per step, and INT4 cuts it to 180 GB. If decode is fully bandwidth-bound, that translates to roughly 36% and 67% more throughput respectively — purely from reduced bandwidth demand.
Production Considerations
Memory Watermark Tuning
In a production serving system, you need to decide how much GPU memory to reserve for KV cache vs. other uses. The memory watermark is the threshold at which the scheduler stops admitting new requests.
class MemoryWatermarkScheduler:
    """Production scheduler with memory watermark control."""
    def __init__(self, total_gpu_memory_gb, model_weight_gb, watermark_ratio=0.90):
        self.total_memory = total_gpu_memory_gb
        self.model_weights = model_weight_gb
        self.activation_overhead = 2.0  # GB, for activations and CUDA workspace
        self.available_for_kv = (
            self.total_memory - self.model_weights - self.activation_overhead
        )
        # Admission watermark: stop admitting new requests above this
        self.admission_watermark = self.available_for_kv * watermark_ratio
        # Preemption watermark: start preempting existing requests above this
        self.preemption_watermark = self.available_for_kv * 0.98

    def should_admit(self, current_kv_usage_gb, estimated_new_request_gb):
        return (current_kv_usage_gb + estimated_new_request_gb) < self.admission_watermark

    def should_preempt(self, current_kv_usage_gb):
        return current_kv_usage_gb > self.preemption_watermark
Setting the watermark too conservatively (e.g., 70%) wastes GPU memory and reduces throughput. Setting it too aggressively (e.g., 98%) causes frequent preemptions, which destroy latency. The right value depends on your workload’s sequence length distribution and variance.
Watermark Tuning Impact (Llama 3 70B, mixed workload)
| Watermark | Max Batch | Throughput | Preemptions/min | p99 Latency |
|---|---|---|---|---|
| 70% | 42 | 580 tok/s | 0 | 120 ms |
| 85% | 54 | 720 tok/s | 0.2 | 135 ms |
| 90% | 58 | 760 tok/s | 1.5 | 180 ms |
| 95% | 61 | 780 tok/s | 8.0 | 450 ms |
| 98% | 63 | 790 tok/s | 25.0 | 1200 ms |
The sweet spot is typically 80-90%. Beyond 90%, preemptions spike and p99 latency degrades rapidly.
Preemption Strategies: Swap vs Recompute
When the scheduler must preempt a request, the choice between swapping to CPU and recomputing from scratch depends on the request’s state:
For PCIe Gen4 x16 (~25 GB/s bidirectional) and Llama 3 70B:
Swap vs Recompute Breakeven (Llama 3 70B)
| Prompt Length | KV Size (GB) | Swap Time (ms) | Recompute Time (ms) | Winner |
|---|---|---|---|---|
| 128 | 0.04 | 1.7 | 15 | Swap |
| 512 | 0.17 | 6.7 | 55 | Swap |
| 2048 | 0.67 | 27 | 210 | Swap |
| 8192 | 2.7 | 107 | 830 | Swap |
| 32768 | 10.7 | 430 | 3300 | Swap |
For typical LLM workloads, swap almost always wins because the KV cache for a single request (without the batch dimension) is relatively small. Recompute only becomes competitive when CPU memory is exhausted or PCIe is heavily contended.
Multi-Tenant Isolation
When serving multiple models or multiple tenants on the same GPU, KV cache memory must be partitioned:
class MultiTenantKVManager:
    """Isolate KV cache allocations between tenants."""
    def __init__(self, total_blocks, tenant_quotas):
        """
        tenant_quotas: dict mapping tenant_id to fraction of total blocks
        Example: {"tenant_a": 0.6, "tenant_b": 0.3, "shared": 0.1}
        """
        self.allocators = {}
        for tenant_id, fraction in tenant_quotas.items():
            num_blocks = int(total_blocks * fraction)
            self.allocators[tenant_id] = BlockAllocator(
                num_blocks=num_blocks,
                block_size=16,
                layers=80, kv_heads=8, head_dim=128,
                dtype=torch.float16
            )

    def allocate(self, tenant_id):
        allocator = self.allocators.get(tenant_id)
        if allocator is None:
            raise ValueError(f"Unknown tenant: {tenant_id}")
        return allocator.allocate()

    def get_utilization(self, tenant_id):
        alloc = self.allocators[tenant_id]
        total = len(alloc.blocks)
        free = len(alloc.free_list)
        return (total - free) / total
The “shared” pool handles burst capacity: when a tenant exceeds its quota, it can borrow from the shared pool (with lower priority and possible preemption).
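The borrow-from-shared fallback can be sketched on top of any per-tenant allocator. SimplePools is a toy stand-in here (plain free lists instead of the GPU-backed BlockAllocator above), so the example is self-contained:

```python
class SimplePools:
    """Toy per-tenant free lists, a stand-in for per-tenant BlockAllocators."""
    def __init__(self, quotas):
        self.free = {tenant: list(range(n)) for tenant, n in quotas.items()}

    def allocate(self, tenant):
        pool = self.free.get(tenant, [])
        return pool.pop() if pool else None

def allocate_with_fallback(pools, tenant_id):
    """Try the tenant's own quota first, then borrow from the shared pool."""
    block = pools.allocate(tenant_id)
    if block is not None:
        return tenant_id, block  # caller records provenance for freeing later
    return "shared", pools.allocate("shared")

pools = SimplePools({"tenant_a": 1, "shared": 1})
print(allocate_with_fallback(pools, "tenant_a"))  # ('tenant_a', 0)
print(allocate_with_fallback(pools, "tenant_a"))  # ('shared', 0)
print(allocate_with_fallback(pools, "tenant_a"))  # ('shared', None) -> reject or preempt
```

Recording which pool a block came from matters: blocks borrowed from the shared pool should be freed back to it, and are the first candidates for preemption.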
Prefix Caching
Many serving workloads share common prefixes: system prompts, few-shot examples, or RAG context. Prefix caching stores these shared KV cache blocks once and reuses them across requests.
class PrefixCache:
    """Cache KV blocks for common prefixes, keyed by token hash."""
    def __init__(self, allocator, max_cached_prefixes=1000):
        self.allocator = allocator
        self.max_cached_prefixes = max_cached_prefixes
        self.cache = {}  # hash(token_ids) -> list[block_id]
        self.access_order = []  # LRU tracking

    def lookup(self, token_ids):
        """Check if prefix KV cache exists."""
        key = hash(tuple(token_ids))
        if key in self.cache:
            self._touch(key)
            block_ids = self.cache[key]
            for bid in block_ids:
                self.allocator.incref(bid)  # shared reference
            return block_ids
        return None

    def insert(self, token_ids, block_ids):
        """Cache a computed prefix."""
        key = hash(tuple(token_ids))
        self.cache[key] = block_ids
        for bid in block_ids:
            self.allocator.incref(bid)
        self.access_order.append(key)
        self._maybe_evict()

    def _touch(self, key):
        """Move key to the most-recently-used position."""
        self.access_order.remove(key)
        self.access_order.append(key)

    def _maybe_evict(self):
        """Drop least-recently-used prefixes beyond capacity."""
        while len(self.cache) > self.max_cached_prefixes:
            victim = self.access_order.pop(0)
            for bid in self.cache.pop(victim):
                self.allocator.free(bid)  # releases the cache's reference
For workloads with a shared system prompt (common in chat applications), prefix caching saves both the memory and the compute cost of the prefill for that prefix. A 2000-token system prompt shared across 64 concurrent requests is prefilled once instead of 64 times, and its KV blocks are stored once instead of 64 times.
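The savings are easy to quantify. Using the 70B geometry from earlier in this section (80 layers, 8 KV heads, head_dim 128, FP16), each token carries 2 × 80 × 8 × 128 × 2 = 320 KiB of KV cache:

```python
def prefix_cache_savings(prefix_len, num_requests, kv_bytes_per_token):
    """Memory and prefill-token savings from sharing one prefix.

    Without sharing, every request stores and prefills its own copy;
    with sharing, the prefix is stored and prefilled exactly once.
    """
    memory_saved = (num_requests - 1) * prefix_len * kv_bytes_per_token
    prefill_tokens_saved = (num_requests - 1) * prefix_len
    return memory_saved, prefill_tokens_saved

# 2000-token system prompt, 64 concurrent requests, 320 KiB/token:
mem, toks = prefix_cache_savings(2000, 64, 2 * 80 * 8 * 128 * 2)
print(mem / 1e9, toks)  # ~41.3 GB of KV cache and 126,000 prefill tokens saved
```

At this scale the shared prefix is worth tens of gigabytes of KV cache — more than many GPUs have in total.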
When NOT to Optimize KV Cache
Not every deployment benefits from aggressive KV cache optimization. Here are the cases where simpler approaches work fine:
Short Sequences
If your maximum sequence length is under 512 tokens, KV cache memory is negligible. For Llama 3 8B (32 layers, 8 KV heads, head_dim 128) in FP16 at batch=32, seq=512:

2 × 32 × 8 × 128 × 2 bytes × 32 × 512 ≈ 2 GB

Two gigabytes. On an A100 with 80 GB, this is under 3% of memory. Naive contiguous allocation with 50% waste still only costs 4 GB total. The engineering complexity of PagedAttention is not justified.
Small Models
Models under 3B parameters have small KV caches by construction (fewer layers, fewer heads). For Llama 3.2 1B (16 layers, 8 KV heads, head_dim 64) in FP16 at batch=64, seq=4096:

2 × 16 × 8 × 64 × 2 bytes × 64 × 4096 ≈ 8 GB

Still manageable with simple allocation on modern GPUs. The model weights are only ~2 GB, so even with wasteful allocation you have ample headroom.
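Both examples fall out of the standard sizing formula. A minimal sketch (the model geometries are assumptions based on the published configs):

```python
def kv_cache_bytes(layers, kv_heads, head_dim, dtype_bytes, batch, seq_len):
    """Total KV cache size: 2x for keys and values, per layer, per KV head."""
    return 2 * layers * kv_heads * head_dim * dtype_bytes * batch * seq_len

# Llama 3.2 1B geometry (16 layers, 8 KV heads, head_dim 64), FP16:
size = kv_cache_bytes(16, 8, 64, 2, batch=64, seq_len=4096)
print(size / 1024**3)  # 8.0 GiB
```

Plugging in your own model's layer count, KV head count, and head dimension gives a quick ceiling on whether KV cache management is worth engineering effort at all.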
Prefill-Dominated Workloads
If your workload is mostly long prompts with short outputs (e.g., classification, extraction, scoring), then:
- Prefill is the bottleneck, not decode
- KV cache lifetime is short (few decode steps)
- Memory pressure is transient, not sustained
In this regime, optimizing prefill throughput (chunked prefill, FlashAttention, tensor parallelism) matters more than optimizing KV cache management.
When KV Cache Optimization Matters Most
| Scenario | Seq Length | Batch Size | Output Length | KV Optimization Impact |
|---|---|---|---|---|
| Short chat | < 512 | < 32 | < 256 | Low -- simple alloc is fine |
| Small model API | < 4096 | < 64 | Any | Low-Medium |
| Classification | < 2048 | Large | < 10 | Low -- prefill dominated |
| Long-form generation | > 2048 | > 32 | > 512 | High |
| Multi-turn chat | > 4096 | > 64 | Variable | Very High |
| RAG with long context | > 8192 | > 16 | > 256 | Critical |
Throughput Impact: Real Numbers
Let us put together a comprehensive throughput comparison showing the combined impact of KV cache optimizations on a realistic serving setup.
Setup: Llama 3 70B on 4x A100 80GB (tensor parallel), mixed workload with average input 1024 tokens, average output 512 tokens, max sequence 8192.
End-to-End Serving Throughput (Llama 3 70B, 4x A100)
| Configuration | Max Batch | Throughput (tok/s) | p50 Latency | p99 Latency | GPU Util |
|---|---|---|---|---|---|
| Contiguous FP16 | 24 | 420 | 85 ms | 350 ms | 55% |
| Paged FP16 | 58 | 780 | 72 ms | 180 ms | 82% |
| Paged INT8 | 96 | 1050 | 65 ms | 160 ms | 88% |
| Paged INT8 + prefix cache | 96 | 1180 | 55 ms | 145 ms | 90% |
| Paged INT4 + prefix cache | 140 | 1420 | 52 ms | 155 ms | 91% |
Throughput Improvement Stack (Llama 3 70B, 4x A100)
| Metric | Contiguous FP16 | Paged FP16 | Paged INT8 | + Prefix Cache | Paged INT4 + Cache |
|---|---|---|---|---|---|
| Throughput (tokens/sec) | 420 | 780 | 1050 | 1180 | 1420 |
| Speedup vs baseline | 1.00x | 1.86x | 2.50x | 2.81x | 3.38x |
The progression tells a clear story:
- Paged allocation (1.86x over contiguous): eliminates fragmentation, allows 2.4x more concurrent requests
- INT8 quantization (1.35x over paged FP16): fits more requests AND reduces bandwidth per request
- Prefix caching (1.12x incremental): saves prefill compute for repeated system prompts
- INT4 quantization (1.20x over INT8 + cache): pushes batch size even higher, though with small quality tradeoffs
The cumulative improvement from the simplest to the most optimized configuration is 3.4x — and this is on the same hardware, serving the same model, with no changes to model quality (for the INT8 path).
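The per-step multipliers above fall directly out of the throughput column; a quick check:

```python
# Throughput (tok/s) for each configuration, from the end-to-end table above.
throughput = [
    ("Contiguous FP16", 420),
    ("Paged FP16", 780),
    ("Paged INT8", 1050),
    ("Paged INT8 + prefix cache", 1180),
    ("Paged INT4 + prefix cache", 1420),
]

def stepwise_speedups(rows):
    """Per-step speedups between adjacent configs, plus cumulative over baseline."""
    steps = [(cur[0], cur[1] / prev[1]) for prev, cur in zip(rows, rows[1:])]
    cumulative = rows[-1][1] / rows[0][1]
    return steps, cumulative

steps, cumulative = stepwise_speedups(throughput)
for name, s in steps:
    print(f"{name}: {s:.2f}x")
print(f"cumulative: {cumulative:.2f}x")
```

Each optimization compounds with the ones before it, which is why the per-step gains (1.86x, 1.35x, 1.12x, 1.20x) multiply out to the 3.4x total rather than adding.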
The Full Picture
Optimized GPU Memory Layout: Llama 3 70B Serving
4x A100 80GB with PagedAttention + INT8 KV cache
[Figure: optimized GPU memory layout for Llama 3 70B serving — model weights (~140 GB in FP16 across the four GPUs), a ~48 GB paged INT8 KV cache pool, and smaller activation and workspace regions of roughly 10–12 GB each.]

KV cache management is not a single optimization but an interlocking system of decisions:
- Allocation strategy determines memory utilization and fragmentation. PagedAttention is now the standard.
- Quantization level trades precision for capacity and bandwidth. INT8 is the sweet spot for most deployments.
- Eviction policy determines graceful degradation under memory pressure. Attention sinks + heavy hitters preserve quality.
- Bandwidth awareness explains why KV compression improves throughput, not just capacity.
- Watermark tuning balances utilization against latency stability.
- Prefix sharing and copy-on-write amortize common computation.
The difference between a naive implementation and a well-tuned one is not incremental — it is the difference between serving 24 concurrent users and serving 140 on the same hardware. For anyone running LLM inference at scale, KV cache management is where the most impactful systems work happens.