A Llama 70B model with GQA (8 KV heads, head dimension 128, 80 layers) at 128K context length requires 41.9 GB of KV cache per request in FP16. On an 8x H100 cluster (640 GB total HBM), the model weights alone consume 140 GB (FP16), leaving 500 GB for KV cache. At 41.9 GB per request, the system can serve only 11 concurrent 128K-context requests. Reducing KV cache size directly increases serving throughput.
Four strategies address this, each with different tradeoffs between memory savings and output quality.
Strategy 1: KV Cache Quantization (INT8/FP8)
The most straightforward approach: store KV cache values in lower precision. KV cache values have a bounded range (post-softmax attention is applied to these values), making them amenable to quantization with minimal quality loss.
Per-Token Asymmetric INT8
import torch
class KVCacheQuantizerINT8:
    """Quantize KV cache to INT8 with per-token scaling.

    Uses asymmetric quantization: each token vector (across head_dim) gets
    its own scale and zero point, stored in FP16 alongside the uint8 codes.
    """

    def __init__(self):
        pass

    def quantize_kv(self, key, value):
        """Quantize K and V tensors to INT8.

        key:   [batch, num_kv_heads, seq_len, head_dim] in FP16
        value: [batch, num_kv_heads, seq_len, head_dim] in FP16
        returns: dict with uint8 codes plus per-token FP16 scales/zero points
        """
        k_quant, k_scale, k_zero = self._quantize_per_token(key)
        v_quant, v_scale, v_zero = self._quantize_per_token(value)
        return {
            "k_quant": k_quant,  # uint8 codes
            "k_scale": k_scale,  # fp16, per-token
            "k_zero": k_zero,    # fp16, per-token
            "v_quant": v_quant,  # uint8 codes
            "v_scale": v_scale,  # fp16
            "v_zero": v_zero,    # fp16
        }

    def _quantize_per_token(self, tensor):
        """Asymmetric per-token INT8 quantization.

        Each token (across head_dim) gets its own scale and zero point.
        tensor: [batch, heads, seq_len, head_dim]
        """
        # Per-token min/max along head_dim.
        t_min = tensor.amin(dim=-1, keepdim=True)
        t_max = tensor.amax(dim=-1, keepdim=True)
        # Map [t_min, t_max] onto the 256 uint8 levels.
        scale = (t_max - t_min) / 255.0
        scale = torch.clamp(scale, min=1e-8)  # Avoid division by zero
        zero = t_min
        quantized = torch.clamp(
            torch.round((tensor - zero) / scale), 0, 255
        ).to(torch.uint8)
        return quantized, scale.to(torch.float16), zero.to(torch.float16)

    def dequantize_kv(self, quantized_kv):
        """Dequantize back to FP16 for attention computation."""
        k = (quantized_kv["k_quant"].float() *
             quantized_kv["k_scale"] + quantized_kv["k_zero"]).half()
        v = (quantized_kv["v_quant"].float() *
             quantized_kv["v_scale"] + quantized_kv["v_zero"]).half()
        return k, v

    def memory_savings(self, seq_len, num_kv_heads, head_dim, num_layers):
        """Calculate memory savings from INT8 quantization.

        Accounting per request:
          elements = num_layers * 2 (K and V) * num_kv_heads * seq_len * head_dim
        FP16 stores 2 bytes/element; INT8 stores 1 byte/element plus two
        per-token FP16 values (scale + zero point) for each of K and V.

        Bug fix: the previous formula carried an extra factor of 2 in every
        term (e.g. FP16 for Llama-70B at 128K reported ~86 GB instead of
        the ~42 GB that matches the per-request figure quoted above).
        """
        # K and V tensors across all layers, heads, and tokens.
        num_elements = num_layers * 2 * num_kv_heads * seq_len * head_dim
        # FP16: 2 bytes per element.
        fp16_bytes = num_elements * 2
        # INT8: 1 byte per element.
        int8_data = num_elements * 1
        # Overhead: 2 FP16 values (2 bytes each) per token, per head,
        # per K/V tensor, per layer.
        scale_overhead = num_layers * 2 * num_kv_heads * seq_len * 2 * 2
        int8_total = int8_data + scale_overhead
        savings = 1 - int8_total / fp16_bytes
        return {
            "fp16_gb": fp16_bytes / 1e9,
            "int8_gb": int8_total / 1e9,
            "savings_pct": savings * 100,
            "ratio": fp16_bytes / int8_total,
        }
FP8 KV Cache (Hopper Native)
On H100 GPUs, FP8 (E4M3 or E5M2) is natively supported in tensor cores:
class KVCacheQuantizerFP8:
"""FP8 KV cache quantization using Hopper native FP8 support."""
def __init__(self, fp8_format="e4m3"):
self.fp8_dtype = torch.float8_e4m3fn # E4M3: more precision, less range
# E5M2 alternative: torch.float8_e5m2 for more range, less precision
def quantize_kv(self, key, value):
"""Quantize to FP8 with per-tensor scaling.
FP8 E4M3 range: [-448, 448], precision: ~3.5 decimal digits
FP16 range: [-65504, 65504], precision: ~3.3 decimal digits
"""
# Per-tensor scale to fit FP16 range into FP8 range
k_amax = key.abs().amax()
v_amax = value.abs().amax()
k_scale = k_amax / 448.0 # Max representable in E4M3
v_scale = v_amax / 448.0
k_fp8 = (key / k_scale).to(self.fp8_dtype)
v_fp8 = (value / v_scale).to(self.fp8_dtype)
return {
"k_fp8": k_fp8,
"k_scale": k_scale,
"v_fp8": v_fp8,
"v_scale": v_scale,
}
def dequantize_for_attention(self, quantized_kv, query):
"""Dequantize and compute attention.
On H100, FP8 GEMMs are natively supported, so we can
compute Q @ K^T directly in FP8 without explicit dequant."""
# Option 1: Dequantize then compute (fallback)
k = quantized_kv["k_fp8"].to(torch.float16) * quantized_kv["k_scale"]
v = quantized_kv["v_fp8"].to(torch.float16) * quantized_kv["v_scale"]
# Option 2: FP8 matmul (Hopper native, 2x throughput)
# scores = torch._scaled_mm(
# query.to(torch.float8_e4m3fn),
# quantized_kv["k_fp8"].transpose(-2, -1),
# scale_a=query_scale,
# scale_b=quantized_kv["k_scale"],
# )
return k, v
KV Cache Quantization: Memory and Quality Impact
| Precision | Bytes/Element | Memory (128K, Llama 70B) | Quality Loss (PPL) | Throughput Gain |
|---|---|---|---|---|
| FP16 (baseline) | 2 | 41.9 GB | 0 (baseline) | 1.0x |
| FP8 (E4M3) | 1 | 21.0 GB | +0.02 PPL | 2.0x |
| INT8 (per-token) | ~1.03 | 21.6 GB | +0.05 PPL | 1.94x |
| INT4 (per-group) | ~0.56 | 11.7 GB | +0.3 PPL | 3.6x |
Strategy 2: H2O (Heavy Hitter Oracle)
H2O observes that attention patterns are highly non-uniform: a small fraction of tokens receive the majority of attention weight. These “heavy hitter” tokens should be kept, while low-attention tokens can be evicted.
Attention Score Tracking
class H2OKVCache:
    """H2O: Heavy-Hitter Oracle for KV cache eviction.

    Tracks the cumulative attention score each cached token has received
    and, when the cache exceeds its budget, keeps only (a) the most recent
    tokens and (b) the "heavy hitter" tokens with the highest cumulative
    attention. Everything else is evicted.
    """

    def __init__(self, max_cache_size, num_layers, num_kv_heads,
                 head_dim, heavy_hitter_ratio=0.5, recent_ratio=0.25):
        self.max_size = max_cache_size
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Budget allocation: heavy hitters + recent window. The remainder
        # of max_cache_size is slack before eviction triggers.
        self.heavy_budget = int(max_cache_size * heavy_hitter_ratio)
        self.recent_budget = int(max_cache_size * recent_ratio)
        # Per-layer cumulative attention each cached token has received.
        self.attention_scores = {}  # layer -> [batch, heads, cached_len]
        # KV cache storage.
        self.kv_cache = {}       # layer -> (K, V) tensors
        self.token_indices = {}  # layer -> indices kept at the last eviction

    def update_attention_scores(self, layer_idx, attention_weights):
        """Accumulate the attention received by each cached token.

        attention_weights: [batch, num_heads, 1, seq_len] (decode step).

        Bug fix: the accumulator is zero-padded when the cache has grown
        since the last step, so the in-place add no longer fails on a
        shape mismatch once new tokens are appended.
        """
        new_scores = attention_weights.squeeze(2)  # [batch, heads, seq_len]
        acc = self.attention_scores.get(layer_idx)
        if acc is None:
            acc = torch.zeros_like(new_scores)
        elif acc.shape[-1] < new_scores.shape[-1]:
            # Newly appended tokens start with zero accumulated attention.
            pad = torch.zeros(
                *acc.shape[:-1], new_scores.shape[-1] - acc.shape[-1],
                dtype=acc.dtype, device=acc.device,
            )
            acc = torch.cat([acc, pad], dim=-1)
        self.attention_scores[layer_idx] = acc + new_scores

    def evict_if_needed(self, layer_idx):
        """Evict low-importance tokens if the cache exceeds its budget."""
        if layer_idx not in self.kv_cache:
            return
        k, v = self.kv_cache[layer_idx]
        current_size = k.shape[2]  # seq_len dimension
        if current_size <= self.max_size:
            return  # No eviction needed
        scores = self.attention_scores[layer_idx]  # [batch, heads, seq_len]
        # Average across heads to get per-token importance.
        token_importance = scores.mean(dim=1)  # [batch, seq_len]
        # Heavy hitters: top-K by cumulative attention score, excluding the
        # recent window (which is always kept anyway).
        non_recent_scores = token_importance.clone()
        non_recent_scores[:, -self.recent_budget:] = -float("inf")
        _, heavy_indices = torch.topk(
            non_recent_scores, self.heavy_budget, dim=-1
        )
        # Build keep mask: recent tokens + heavy hitters.
        # Bug fix: allocate the mask fresh instead of scatter_-ing into an
        # expand() view — writing in place through expand() aliases a single
        # row across the batch and is rejected by PyTorch for batch > 1.
        keep_mask = torch.zeros_like(token_importance, dtype=torch.bool)
        keep_mask[:, -self.recent_budget:] = True
        keep_mask.scatter_(1, heavy_indices, True)
        # Evict: keep only the selected tokens (assumes batch size 1).
        keep_indices = keep_mask[0].nonzero().squeeze(-1)
        self.kv_cache[layer_idx] = (
            k[:, :, keep_indices, :],
            v[:, :, keep_indices, :],
        )
        self.attention_scores[layer_idx] = scores[:, :, keep_indices]
        self.token_indices[layer_idx] = keep_indices

    def get_kv(self, layer_idx):
        """Get the current KV cache for attention computation."""
        return self.kv_cache[layer_idx]

    def append_kv(self, layer_idx, new_k, new_v, attention_weights):
        """Append new KV, update attention scores, then evict if needed."""
        if layer_idx in self.kv_cache:
            k, v = self.kv_cache[layer_idx]
            self.kv_cache[layer_idx] = (
                torch.cat([k, new_k], dim=2),
                torch.cat([v, new_v], dim=2),
            )
        else:
            self.kv_cache[layer_idx] = (new_k, new_v)
        self.update_attention_scores(layer_idx, attention_weights)
        self.evict_if_needed(layer_idx)
H2O’s key observation: in Llama-family models, approximately 5-10% of tokens consistently receive over 90% of cumulative attention weight. These are typically: (1) the BOS/system prompt tokens, (2) tokens marking structural boundaries (newlines, punctuation), and (3) semantically important content tokens. The heavy hitter pattern is consistent across layers, meaning the same tokens tend to be important at every layer.
H2O Quality Analysis
def analyze_h2o_quality(model, dataset, cache_sizes):
    """Measure quality degradation at different H2O cache budgets.

    For each cache budget in `cache_sizes`, runs `model` over `dataset`
    with an H2O-managed KV cache and records perplexity plus the ratio of
    cache budget to (average) input length.

    Bug fixes vs the original:
      - uses torch.nn.functional.cross_entropy with reduction="sum" (the
        bare `cross_entropy` name was undefined), matching the other
        evaluation helpers in this file;
      - memory_ratio is computed against the average input length over the
        whole dataset instead of whichever sample happened to be iterated
        last (a loop-variable leak).
    """
    results = []
    for max_cache in cache_sizes:
        h2o_cache = H2OKVCache(
            max_cache_size=max_cache,
            num_layers=model.config.num_hidden_layers,
            num_kv_heads=model.config.num_key_value_heads,
            head_dim=model.config.hidden_size // model.config.num_attention_heads,
        )
        total_loss = 0.0
        num_tokens = 0
        total_input_len = 0
        num_samples = 0
        for sample in dataset:
            # Run the model with the H2O-managed KV cache.
            logits = model.forward_with_h2o(sample.input_ids, h2o_cache)
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                sample.labels.view(-1),
                reduction="sum",
            )
            total_loss += loss.item()
            num_tokens += sample.labels.numel()
            total_input_len += len(sample.input_ids)
            num_samples += 1
        ppl = torch.exp(torch.tensor(total_loss / num_tokens))
        avg_input_len = total_input_len / num_samples
        results.append({
            "max_cache": max_cache,
            "perplexity": ppl.item(),
            "memory_ratio": max_cache / avg_input_len,
        })
    return results
H2O Perplexity vs Cache Budget (Llama 70B, 128K Context)
| Metric | 5% | 10% | 20% | 30% | 50% | 75% | 100% |
|---|---|---|---|---|---|---|---|
| H2O (heavy hitter + recent) | |||||||
| Random eviction | |||||||
| Full KV (baseline) |
Strategy 3: Attention Sinks
StreamingLLM discovered that the first few tokens in any sequence receive disproportionately high attention, regardless of their semantic content. These “attention sinks” act as learned bias terms in the attention computation. Removing them causes catastrophic quality degradation even if the content tokens are preserved.
class AttentionSinkCache:
    """Attention sink + sliding window KV cache.

    Always keeps the first N tokens ("attention sinks") plus the last W
    tokens (sliding window); tokens in between are dropped.
    """

    def __init__(self, num_sink_tokens=4, window_size=1024,
                 num_layers=80, num_kv_heads=8, head_dim=128):
        self.num_sinks = num_sink_tokens
        self.window_size = window_size
        # Per-layer cache never holds more than sinks + window tokens.
        self.total_budget = num_sink_tokens + window_size
        # Lazily populated per-layer caches.
        self.k_cache = {}  # layer -> [batch, heads, <=total_budget, head_dim]
        self.v_cache = {}
        self.current_len = 0  # Total tokens seen so far (decode steps)

    def append(self, layer_idx, new_k, new_v):
        """Append one token's KV, maintaining the sink + window invariant.

        new_k, new_v: [batch, heads, 1, head_dim]

        Bug fix: `current_len` counts decode steps, so it is incremented
        only on layer 0 — the same convention SlidingWindowKVCache uses.
        Previously it was bumped once per layer per token, inflating the
        position IDs returned by get_position_ids by a factor of num_layers.
        """
        if layer_idx == 0:
            self.current_len += 1
        if layer_idx not in self.k_cache:
            self.k_cache[layer_idx] = new_k
            self.v_cache[layer_idx] = new_v
            return
        k = self.k_cache[layer_idx]
        v = self.v_cache[layer_idx]
        seq_len = k.shape[2]
        if seq_len < self.total_budget:
            # Cache not full yet, just append.
            self.k_cache[layer_idx] = torch.cat([k, new_k], dim=2)
            self.v_cache[layer_idx] = torch.cat([v, new_v], dim=2)
        else:
            # Cache full: keep sinks + shift window + add new token.
            # Layout: [sink_0, ..., sink_{N-1}, window_start, ..., window_end]
            sinks_k = k[:, :, :self.num_sinks, :]
            sinks_v = v[:, :, :self.num_sinks, :]
            # Window: drop the oldest window token, append the new one.
            window_k = k[:, :, self.num_sinks + 1:, :]
            window_v = v[:, :, self.num_sinks + 1:, :]
            self.k_cache[layer_idx] = torch.cat(
                [sinks_k, window_k, new_k], dim=2
            )
            self.v_cache[layer_idx] = torch.cat(
                [sinks_v, window_v, new_v], dim=2
            )

    def get_kv(self, layer_idx):
        """Return the current (K, V) cache for attention."""
        return self.k_cache[layer_idx], self.v_cache[layer_idx]

    def get_position_ids(self):
        """Return ORIGINAL position IDs for the cached tokens.

        Sink tokens keep positions 0..num_sinks-1; window tokens keep their
        absolute positions relative to current_len (leaving a gap between
        the two groups). RoPE encodes absolute positions, so consecutive
        IDs would be wrong.

        Bug fix: handles the warm-up phase (fewer tokens seen than sinks)
        and drops a stale unused intermediate from the original.
        """
        if self.current_len <= self.num_sinks:
            # Everything cached so far is still part of the sink prefix.
            return torch.tensor(list(range(self.current_len)))
        sink_positions = list(range(self.num_sinks))
        window_start = max(self.num_sinks, self.current_len - self.window_size)
        window_positions = list(range(window_start, self.current_len))
        return torch.tensor(sink_positions + window_positions)
Attention sink positions must use their ORIGINAL position IDs with RoPE, not consecutive IDs. If sink tokens 0-3 are at positions [0, 1, 2, 3] and window tokens are at positions [95000, 95001, …, 96023], the position IDs must reflect this gap. Using consecutive IDs [0, 1, 2, 3, 4, 5, …] breaks the model because RoPE encodes absolute position information.
Attention Sinks Quality Analysis
The critical question: how many sink tokens do you need? Research shows that 4 sink tokens capture the dominant attention bias pattern. Adding more sinks beyond 4 provides diminishing returns:
def evaluate_sink_counts(model, eval_data, window_size=1024):
    """Measure quality impact of different sink token counts."""
    results = []
    head_dim = model.config.hidden_size // model.config.num_attention_heads
    for num_sinks in [0, 1, 2, 4, 8, 16, 32]:
        cache = AttentionSinkCache(
            num_sink_tokens=num_sinks,
            window_size=window_size,
            num_layers=model.config.num_hidden_layers,
            num_kv_heads=model.config.num_key_value_heads,
            head_dim=head_dim,
        )
        loss_sum = 0.0
        token_count = 0
        for sample in eval_data:
            logits = model.forward_with_cache(sample.input_ids, cache)
            sample_loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                sample.labels.view(-1),
                reduction="sum",
            )
            loss_sum += sample_loss.item()
            token_count += sample.labels.numel()
        perplexity = torch.exp(torch.tensor(loss_sum / token_count)).item()
        results.append({
            "num_sinks": num_sinks,
            "perplexity": perplexity,
            "total_cache_size": num_sinks + window_size,
        })
    return results
Sink Token Count vs Perplexity (Llama 70B, Window=1024, 32K Input)
| Sink Tokens | Total Cache | Perplexity | vs Full KV |
|---|---|---|---|
| 0 (window only) | 1024 | 15.8 | +10.85 |
| 1 | 1025 | 6.2 | +1.25 |
| 4 | 1028 | 5.6 | +0.65 |
| 16 | 1040 | 5.5 | +0.55 |
| 32 | 1056 | 5.5 | +0.55 |
| Full KV | 32768 | 4.95 | baseline |
Without any sink tokens (pure sliding window on a model not trained for it), perplexity degrades catastrophically from 4.95 to 15.8. Adding just 4 sink tokens reduces the gap to +0.65 PPL. Beyond 4 sinks, the improvement plateaus, confirming that the attention sink phenomenon is concentrated in the first few token positions.
Strategy 4: Sliding Window Attention
Some models (Mistral, Phi) are trained with sliding window attention: each token can only attend to the last W tokens. This inherently bounds the KV cache to W entries per layer.
class SlidingWindowKVCache:
"""Fixed-size sliding window KV cache.
Only works correctly with models trained using sliding window attention."""
def __init__(self, window_size=4096, num_layers=32,
num_kv_heads=8, head_dim=128, device="cuda"):
self.W = window_size
# Pre-allocate circular buffers
self.k_buffer = torch.zeros(
num_layers, 1, num_kv_heads, window_size, head_dim,
dtype=torch.float16, device=device
)
self.v_buffer = torch.zeros(
num_layers, 1, num_kv_heads, window_size, head_dim,
dtype=torch.float16, device=device
)
self.write_pos = 0 # Circular buffer write position
self.total_written = 0
def append(self, layer_idx, new_k, new_v):
"""Append new KV at circular buffer position."""
pos = self.write_pos % self.W
self.k_buffer[layer_idx, :, :, pos, :] = new_k.squeeze(2)
self.v_buffer[layer_idx, :, :, pos, :] = new_v.squeeze(2)
if layer_idx == 0: # Only increment once per token
self.write_pos += 1
self.total_written += 1
def get_kv(self, layer_idx):
"""Get KV cache in correct temporal order."""
if self.total_written <= self.W:
# Buffer not full yet, return what we have
return (
self.k_buffer[layer_idx, :, :, :self.total_written, :],
self.v_buffer[layer_idx, :, :, :self.total_written, :],
)
# Buffer full: reorder from circular to temporal
start = self.write_pos % self.W
k = torch.cat([
self.k_buffer[layer_idx, :, :, start:, :],
self.k_buffer[layer_idx, :, :, :start, :],
], dim=2)
v = torch.cat([
self.v_buffer[layer_idx, :, :, start:, :],
self.v_buffer[layer_idx, :, :, :start, :],
], dim=2)
return k, v
def memory_usage(self):
"""Fixed memory regardless of sequence length."""
return self.k_buffer.nbytes + self.v_buffer.nbytes
Combining Strategies
The four strategies are not mutually exclusive. Production systems combine them:
class CombinedKVCacheStrategy:
    """Combine quantization + H2O + attention sinks.

    Layout after eviction: the first `num_sinks` tokens stay in FP16, the
    H2O-selected middle tokens and the recent window are stored in FP8,
    and everything else is dropped.
    """

    def __init__(self, config):
        self.quantizer = KVCacheQuantizerFP8()
        # Sink tokens: always keep first 4 tokens in FP16 (no quantization).
        self.num_sinks = 4
        # Heavy hitters: keep top 20% by attention score in FP8.
        # NOTE(review): heavy_ratio is never read below — the heavy budget
        # in manage_cache is derived from max_cache instead. Confirm which
        # is intended.
        self.heavy_ratio = 0.2
        # Recent window: keep last 512 tokens in FP8.
        self.recent_window = 512
        # Everything else: evicted.
        self.max_cache = config.max_kv_cache_tokens

    def manage_cache(self, layer_idx, kv_cache, attention_weights):
        """Combined cache management after each decode step.

        kv_cache: (K, V), each [batch, heads, seq_len, head_dim].
        attention_weights: presumably [batch, heads, 1, seq_len] for the
        current decode step — TODO confirm against the caller.
        Returns either a plain quantized dict (under budget) or a dict
        split into "sinks"/"heavy_hitters"/"recent" segments.
        """
        k, v = kv_cache
        seq_len = k.shape[2]
        if seq_len <= self.max_cache:
            # Under budget: no eviction, just quantize everything to FP8.
            return self.quantizer.quantize_kv(k, v)
        # Over budget: apply eviction + quantization.
        # 1. Keep sinks (always, FP16 — sinks are never quantized).
        sink_k = k[:, :, :self.num_sinks, :]
        sink_v = v[:, :, :self.num_sinks, :]
        # 2. Keep recent window (FP8).
        recent_k = k[:, :, -self.recent_window:, :]
        recent_v = v[:, :, -self.recent_window:, :]
        recent_quant = self.quantizer.quantize_kv(recent_k, recent_v)
        # 3. H2O selection on the middle tokens (between sinks and window).
        middle_k = k[:, :, self.num_sinks:-self.recent_window, :]
        middle_v = v[:, :, self.num_sinks:-self.recent_window, :]
        middle_scores = attention_weights[:, :, :, self.num_sinks:-self.recent_window]
        heavy_budget = self.max_cache - self.num_sinks - self.recent_window
        # NOTE(review): torch.topk raises if the middle region holds fewer
        # than heavy_budget tokens (possible when seq_len is barely over
        # budget) — consider clamping heavy_budget to the middle length.
        _, heavy_idx = torch.topk(
            middle_scores.mean(dim=1).squeeze(1), heavy_budget, dim=-1
        )
        # Index assumes batch size 1 (squeeze(0) on the index tensor).
        heavy_k = middle_k[:, :, heavy_idx.squeeze(0), :]
        heavy_v = middle_v[:, :, heavy_idx.squeeze(0), :]
        heavy_quant = self.quantizer.quantize_kv(heavy_k, heavy_v)
        return {
            "sinks": (sink_k, sink_v),  # FP16
            "heavy_hitters": heavy_quant,  # FP8
            "recent": recent_quant,  # FP8
        }
Combined Strategy: Memory Usage and Quality (Llama 70B, 128K Context)
| Strategy | Memory | Compression | PPL Impact | Max Concurrent Requests |
|---|---|---|---|---|
| Full FP16 | 41.9 GB | 1.0x | Baseline | 11 |
| FP8 only | 21.0 GB | 2.0x | +0.02 | 23 |
| H2O (50% budget) | 21.0 GB | 2.0x | +0.10 | 23 |
| Sinks + Window (W=4K) | 1.3 GB | 32x | +0.8 (long deps) | 384 |
| FP8 + H2O (50%) | 10.5 GB | 4.0x | +0.12 | 47 |
| Sinks + H2O (30%) + FP8 | 6.3 GB | 6.6x | +0.25 | 79 |
Max Concurrent 128K Requests vs Compression Strategy (8x H100, 500 GB KV Budget)
| Metric | Full FP16 | FP8 | H2O 50% | Sink+Window | FP8+H2O | Sink+H2O+FP8 |
|---|---|---|---|---|---|---|
| Max Concurrent Requests |
Impact on Attention Kernel
Quantized and evicted KV caches require modified attention kernels:
def quantized_paged_attention(query, kv_blocks_fp8, kv_scales,
                              page_table, seq_lens):
    """Attention over a paged FP8 KV cache with on-the-fly dequantization.

    Reference (non-kernel) implementation: for each request, gathers its
    KV blocks through the page table, converts FP8 -> FP16 in-flight, and
    runs standard softmax attention.

    query:         decode-step queries; assumed [batch, num_heads, 1,
                   head_dim] from the indexing below — TODO confirm.
    kv_blocks_fp8: physical KV blocks; index [block, 0] is K, [block, 1] is V.
    kv_scales:     per-block dequantization scales, same indexing.
    page_table:    [batch, max_blocks] logical -> physical block mapping.
    seq_lens:      per-request cached sequence lengths.
    """
    batch_size = query.shape[0]
    num_heads = query.shape[1]
    head_dim = query.shape[-1]
    output = torch.zeros_like(query)
    for b in range(batch_size):
        # Gather this request's KV blocks (block size hard-coded to 16;
        # the +15 rounds the token count up to whole blocks).
        num_blocks = (seq_lens[b] + 15) // 16
        all_scores = []
        all_values = []
        for block_idx in range(num_blocks):
            physical_block = page_table[b, block_idx]
            # Load FP8 K block and dequantize to FP16 in registers.
            k_fp8 = kv_blocks_fp8[physical_block, 0]  # [kv_heads, block_size, head_dim]
            k_scale = kv_scales[physical_block, 0]
            k_block = k_fp8.to(torch.float16) * k_scale
            v_fp8 = kv_blocks_fp8[physical_block, 1]
            v_scale = kv_scales[physical_block, 1]
            v_block = v_fp8.to(torch.float16) * v_scale
            # Compute scaled Q @ K^T for this block.
            # query[b:b+1]: [1, num_heads, 1, head_dim]
            # k_block: [num_kv_heads, block_size, head_dim]
            # NOTE(review): GQA head expansion is acknowledged but not
            # performed — the broadcasted matmul only lines up when
            # num_heads == num_kv_heads.
            scores = torch.matmul(
                query[b:b+1], k_block.transpose(-2, -1)
            ) / (head_dim ** 0.5)
            all_scores.append(scores)
            all_values.append(v_block)
        # Concatenate across blocks and run softmax attention.
        all_scores = torch.cat(all_scores, dim=-1)
        all_values = torch.cat(all_values, dim=-2)
        # Trim the block padding beyond the actual sequence length.
        all_scores = all_scores[:, :, :, :seq_lens[b]]
        all_values = all_values[:, :seq_lens[b], :]
        attn = torch.softmax(all_scores, dim=-1)
        output[b] = torch.matmul(attn, all_values)
    return output
FP8 KV cache quantization with on-the-fly dequantization in the attention kernel adds negligible latency (less than 2% overhead) because the dequantization is bandwidth-free: the FP8 to FP16 conversion happens in registers after the data is already loaded from HBM. The bandwidth savings of loading 1 byte instead of 2 bytes per element directly translates to 2x faster KV cache reads during decode attention.
Choosing the Right Strategy
The right strategy depends on the workload:
Strategy Selection Guide
| Workload | Recommended Strategy | Reason |
|---|---|---|
| Short context (less than 4K) | FP8 quantization only | Minimal KV, just save memory for more batching |
| Medium context (4K-32K) | FP8 + H2O (50%) | Good compression with minimal quality loss |
| Long context (32K-128K) | Sinks + H2O (30%) + FP8 | Aggressive compression needed |
| Streaming/infinite context | Sinks + sliding window | Fixed memory, accepts long-range quality loss |
| Multi-turn chat | FP8 + prefix caching | Cache shared system prompt, compress per-turn KV |
Implementation in vLLM and SGLang
Both major serving frameworks have implemented KV cache quantization:
# vLLM: enable the FP8 KV cache via command-line flags (halves KV memory
# relative to the default FP16 cache).
"""
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3.1-70B \
--kv-cache-dtype fp8_e4m3 \
--tensor-parallel-size 8
"""
# SGLang: the equivalent FP8 KV cache flag.
"""
python -m sglang.launch_server \
--model-path meta-llama/Llama-3.1-70B \
--kv-cache-dtype fp8_e4m3 \
--tp 8
"""
# In vLLM, FP8 KV cache is implemented in the attention backends:
class FP8KVCacheAttention:
    """Simplified FP8 KV cache management in vLLM.

    Stores K/V in paged FP8 blocks with one FP32 scale per block.
    """

    def __init__(self, num_blocks, block_size, num_kv_heads, head_dim):
        # Allocate the KV cache in FP8 format — half the memory of an
        # FP16 cache of the same geometry.
        self.k_cache = torch.zeros(
            num_blocks, block_size, num_kv_heads, head_dim,
            dtype=torch.float8_e4m3fn, device="cuda"
        )
        self.v_cache = torch.zeros(
            num_blocks, block_size, num_kv_heads, head_dim,
            dtype=torch.float8_e4m3fn, device="cuda"
        )
        # Per-block scale factors (kept in FP32 for accuracy).
        self.k_scales = torch.ones(
            num_blocks, dtype=torch.float32, device="cuda"
        )
        self.v_scales = torch.ones(
            num_blocks, dtype=torch.float32, device="cuda"
        )

    def write_kv(self, block_idx, slot_in_block, k_fp16, v_fp16):
        """Write FP16 KV values into the FP8 cache with scaling.

        k_fp16, v_fp16: one slot's K/V values, presumably
        [num_kv_heads, head_dim] — TODO confirm against the caller.
        """
        # Per-write scale mapping this slot's amax to the E4M3 max (448).
        k_amax = k_fp16.abs().amax()
        v_amax = v_fp16.abs().amax()
        k_scale = k_amax / 448.0
        v_scale = v_amax / 448.0
        # Quantize and store.
        self.k_cache[block_idx, slot_in_block] = (k_fp16 / k_scale).to(
            torch.float8_e4m3fn
        )
        self.v_cache[block_idx, slot_in_block] = (v_fp16 / v_scale).to(
            torch.float8_e4m3fn
        )
        # Update the block's running scale (exponential moving average).
        # NOTE(review): slots were quantized with their own per-write
        # scales, but only this EMA is retained per block — dequantizing
        # with the EMA will not exactly invert each slot's quantization.
        # Confirm this approximation is intended.
        alpha = 0.1
        self.k_scales[block_idx] = (
            (1 - alpha) * self.k_scales[block_idx] + alpha * k_scale
        )
        self.v_scales[block_idx] = (
            (1 - alpha) * self.v_scales[block_idx] + alpha * v_scale
        )
Benchmarking KV Compression Impact
The only way to validate a KV compression strategy is to measure both memory savings and quality impact on your target workload:
def benchmark_kv_compression(model, strategies, eval_dataset):
    """Compare KV compression strategies on throughput and quality.

    For each strategy: measures perplexity over `eval_dataset`, the
    strategy's KV memory footprint, and how many concurrent requests fit
    in the cluster's KV budget (8x H100 HBM minus FP16 weights).
    """
    results = []
    for strategy_name, strategy in strategies.items():
        # --- Quality: summed token cross-entropy -> perplexity ---
        loss_sum = 0.0
        token_count = 0
        for sample in eval_dataset:
            logits = model.forward_with_kv_strategy(
                sample.input_ids, strategy
            )
            sample_loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                sample.labels.view(-1),
                reduction="sum",
            )
            loss_sum += sample_loss.item()
            token_count += sample.labels.numel()
        ppl = torch.exp(torch.tensor(loss_sum / token_count)).item()
        # --- Memory footprint and serving capacity ---
        kv_memory_gb = strategy.memory_usage() / 1e9
        total_hbm = 80 * 8     # 8x H100, 80 GB HBM each
        weight_memory = 140    # Llama 70B weights in FP16
        kv_budget = total_hbm - weight_memory
        max_requests = int(kv_budget / kv_memory_gb)
        # Throughput is reported relative to the first strategy measured.
        if results:
            relative = max_requests / results[0]["max_concurrent"]
        else:
            relative = 1.0
        results.append({
            "strategy": strategy_name,
            "perplexity": ppl,
            "kv_memory_gb": kv_memory_gb,
            "max_concurrent": max_requests,
            "throughput_relative": relative,
        })
    return results
KV cache compression is the primary lever for increasing serving throughput at long context lengths. The 41.9 GB per request for 128K context means that without compression, most of the GPU cluster’s memory is consumed by a handful of requests. FP8 quantization alone doubles capacity with negligible quality impact. Adding H2O eviction on top provides another 2x. Together, they enable 4x more concurrent long-context requests, directly translating to 4x higher throughput for the same hardware. The choice between strategies is ultimately an empirical question: measure perplexity on your specific workload at your target compression ratio, and pick the strategy that preserves quality at the memory budget you need.