KV cache quantization is the highest-leverage memory optimization in LLM serving. At production batch sizes and sequence lengths, the KV cache dominates GPU memory — often consuming 2-5x more than the model weights themselves. Quantizing the KV cache from FP16 to FP8 halves this memory, doubling the number of concurrent requests you can serve. Quantizing to INT4 quarters it, enabling 4x more requests on the same hardware.
This is distinct from weight quantization (Parts 2 and 5) in every way that matters. Weights are static — quantize once, serve forever. The KV cache is dynamic — new key-value pairs are generated for every token, and they must be quantized online during generation with zero additional latency. Weights can use offline calibration data. KV cache quantization must work with whatever distribution the current request produces. The engineering constraints are completely different.
This post covers why KV cache deserves its own quantization strategy, per-token scaling and its implementation, FP8/INT8/INT4 KV cache at the algorithm level, quality impact across model sizes and precision targets, and a complete implementation of online KV cache quantization during autoregressive generation.
KV Cache Memory: The Serving Bottleneck
For a transformer model with L layers, H_kv KV heads, head dimension d, serving B concurrent requests at sequence length S:

KV bytes = 2 × L × H_kv × d × B × S × bytes_per_element

The factor of 2 is for K and V. For Llama 70B (L = 80, H_kv = 8 GQA KV heads, d = 128), one request at S = 4096 in FP16 needs:

2 × 80 × 8 × 128 × 4096 × 2 bytes ≈ 1.34 GB

At B = 64, S = 4096: 85.9 GB for KV cache alone. The model weights (even at INT4) are only 35 GB. The KV cache is 2.5x the model.
Memory Split: Model Weights vs KV Cache (Llama 70B, seq=4096)
Quantizing the KV cache is the only way to serve at high batch sizes without adding more GPUs. Each halving of KV precision roughly doubles the batch capacity:
Maximum Batch Size by KV Precision (Llama 70B, INT4 Weights, H100-80GB, seq=4096)
| KV Precision | KV per Request | KV at Max Batch | Max Batch | Improvement |
|---|---|---|---|---|
| FP16 | 1.34 GB | 42.9 GB | 32 | 1.0x |
| FP8 E4M3 | 0.67 GB | 42.9 GB | 64 | 2.0x |
| INT8 | 0.67 GB | 42.9 GB | 64 | 2.0x |
| INT4 | 0.34 GB | 40.8 GB | 120 | 3.75x |
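The capacities in this table follow from a few lines of arithmetic. A minimal sketch, where the 43 GB KV budget is an assumption (roughly an H100-80GB minus ~35 GB of INT4 weights and runtime overhead); INT4 comes out slightly above the table's 120 because scale-factor overhead is ignored here:

```python
def kv_cache_gb(num_layers, kv_heads, head_dim, seq_len, bytes_per_elem):
    """KV cache for one request: 2 (K and V) x layers x heads x dim x seq."""
    return 2 * num_layers * kv_heads * head_dim * seq_len * bytes_per_elem / 1e9

# Llama 70B GQA config; the KV budget approximates what is left on an
# H100-80GB after INT4 weights and runtime overhead
L, H_KV, D, S = 80, 8, 128, 4096
KV_BUDGET_GB = 43.0
for name, nbytes in [("FP16", 2.0), ("FP8", 1.0), ("INT8", 1.0), ("INT4", 0.5)]:
    per_req = kv_cache_gb(L, H_KV, D, S, nbytes)
    print(f"{name}: {per_req:.2f} GB/request, max batch ~{int(KV_BUDGET_GB / per_req)}")
```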
Why KV Cache Quantization Is Different
Dynamic vs Static
Model weights are fixed at load time. You can spend hours running GPTQ or AWQ with calibration data to find optimal scale factors. KV cache values are generated token-by-token during inference. Each new token produces a new K and V vector that must be quantized immediately.
Per-Token vs Per-Tensor
Each token’s K and V vectors have their own distribution. Token 1 might have key values in the range [-0.5, 0.5], while token 500 might have values in [-2.0, 2.0]. A single scale factor for the entire sequence would be dominated by the worst-case token, wasting precision for all the others.
Per-token scaling gives each token’s K (or V) vector its own scale factor. This adds minimal storage overhead (one FP32 or FP16 scale per token per layer per head) but dramatically improves quality because each token’s values use the full quantized range.
No Calibration Data Available
Weight quantization can use calibration data to determine optimal scale factors, group assignments, and channel priorities. KV cache quantization has no access to future tokens — it must quantize each K/V vector as it is produced, based only on that vector’s own statistics.
Zero Latency Budget
KV cache quantization happens on the critical path of token generation. Any overhead — computing scale factors, performing the quantization — directly increases per-token latency. The quantization must be fused into the attention kernel or the KV cache write path to avoid extra memory traffic.
Per-Token Scaling Implementation
Per-token scaling computes one scale factor per token, per layer, per attention head. For each token t, head h, and layer l:

scale(t, h, l) = max(|KV(t, h, l, :)|) / q_max

where q_max is the maximum representable value in the target format (127 for INT8, 448 for FP8 E4M3, 7 for INT4).
import torch
import torch.nn.functional as F
class KVCacheQuantizer:
"""Online KV cache quantizer with per-token scaling."""
def __init__(self, precision='fp8', head_dim=128):
"""
precision: 'fp8', 'int8', or 'int4'
head_dim: dimension of each attention head
"""
self.precision = precision
self.head_dim = head_dim
        if precision == 'fp8':
            if hasattr(torch, 'float8_e4m3fn'):
                self.qmax = 448.0
                self.qmin = -448.0
                self.dtype = torch.float8_e4m3fn
            else:
                # No native FP8 support: fall back to symmetric INT8 simulation
                self.qmax = 127.0
                self.qmin = -127.0
                self.dtype = torch.int8
elif precision == 'int8':
self.qmax = 127.0
self.qmin = -128.0
self.dtype = torch.int8
elif precision == 'int4':
self.qmax = 7.0
self.qmin = -8.0
self.dtype = torch.int8 # Store INT4 in INT8 container
else:
raise ValueError(f"Unknown precision: {precision}")
def quantize_token(self, kv_vector):
"""Quantize a single token's K or V vector.
kv_vector: (num_heads, head_dim) -- one token's K or V across all heads
Returns:
quantized: (num_heads, head_dim) in target dtype
scales: (num_heads, 1) per-head scale factors
"""
        # Per-head scaling: one scale per attention head
        amax = kv_vector.abs().amax(dim=-1, keepdim=True)  # (num_heads, 1)
        scales = amax / self.qmax
        scales = scales.clamp(min=1e-12)
        scaled = (kv_vector / scales).clamp(self.qmin, self.qmax)
        if self.precision == 'fp8' and self.dtype != torch.int8:
            # FP8: the cast itself rounds to the nearest representable value;
            # integer rounding here would be incorrect
            quantized = scaled.to(self.dtype)
        else:
            # Integer formats: round to the nearest integer level
            quantized = scaled.round().to(self.dtype)
        return quantized, scales
def dequantize_token(self, quantized, scales):
"""Dequantize a single token's K or V vector.
quantized: (num_heads, head_dim) quantized values
scales: (num_heads, 1) scale factors
Returns: (num_heads, head_dim) FP16/FP32 values
"""
return quantized.float() * scales
def quantize_kv_pair(self, key, value):
"""Quantize both K and V for a single token.
key: (num_heads, head_dim)
value: (num_heads, head_dim)
Returns: (q_key, k_scale, q_value, v_scale)
"""
q_key, k_scale = self.quantize_token(key)
q_value, v_scale = self.quantize_token(value)
return q_key, k_scale, q_value, v_scale
Scale Factor Storage Overhead
Per-token scaling adds one scale factor (typically FP16 or FP32) per token per head for each of K and V. For Llama 70B (80 layers, 8 KV heads, FP16 scales), the scale bytes per token are:

2 (K and V) × 80 layers × 8 heads × 2 bytes = 2,560 bytes

The KV values per token at INT8 are:

2 × 80 × 8 × 128 × 1 byte = 163,840 bytes

Scale overhead is 2,560 / 163,840 ≈ 1.6%, which is negligible.
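As a quick check of this arithmetic:

```python
# Per-token overhead arithmetic for Llama 70B (80 layers, 8 KV heads, head_dim 128)
layers, kv_heads, head_dim = 80, 8, 128

scale_bytes_per_token = 2 * layers * kv_heads * 2           # K and V, one FP16 scale per head
int8_kv_bytes_per_token = 2 * layers * kv_heads * head_dim  # one byte per INT8 element

print(scale_bytes_per_token)                                     # 2560
print(int8_kv_bytes_per_token)                                   # 163840
print(f"{scale_bytes_per_token / int8_kv_bytes_per_token:.2%}")  # 1.56%
```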
Complete Online KV Cache with Quantization
class QuantizedKVCache:
"""KV cache with online quantization during generation.
Supports FP8, INT8, and INT4 precision with per-token scaling.
"""
def __init__(self, num_layers, num_heads, head_dim, max_seq_len,
precision='fp8', device='cuda'):
self.num_layers = num_layers
self.num_heads = num_heads
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.precision = precision
self.device = device
self.quantizer = KVCacheQuantizer(precision, head_dim)
        # Storage dtype follows the quantizer: INT8 container for INT8/INT4,
        # native FP8 when available (storing FP8 values in an INT8 buffer
        # would silently truncate them)
        store_dtype = self.quantizer.dtype
# Allocate quantized KV storage
# Shape: (num_layers, max_seq_len, num_heads, head_dim)
self.k_cache = torch.zeros(
num_layers, max_seq_len, num_heads, head_dim,
dtype=store_dtype, device=device
)
self.v_cache = torch.zeros(
num_layers, max_seq_len, num_heads, head_dim,
dtype=store_dtype, device=device
)
# Scale factors: (num_layers, max_seq_len, num_heads, 1)
self.k_scales = torch.zeros(
num_layers, max_seq_len, num_heads, 1,
dtype=torch.float16, device=device
)
self.v_scales = torch.zeros(
num_layers, max_seq_len, num_heads, 1,
dtype=torch.float16, device=device
)
self.seq_len = 0
def append(self, layer_idx, key, value):
"""Append a new token's K,V to the cache.
key: (num_heads, head_dim) -- this token's key vectors
value: (num_heads, head_dim) -- this token's value vectors
Called once per layer per generated token.
"""
pos = self.seq_len
# Quantize K and V for this token
q_key, k_scale, q_value, v_scale = self.quantizer.quantize_kv_pair(key, value)
# Store in cache
self.k_cache[layer_idx, pos] = q_key
self.v_cache[layer_idx, pos] = q_value
self.k_scales[layer_idx, pos] = k_scale.half()
self.v_scales[layer_idx, pos] = v_scale.half()
def advance_position(self):
"""Call after all layers have appended for a token."""
self.seq_len += 1
def get_keys(self, layer_idx):
"""Get dequantized keys for attention computation.
        Returns: (seq_len, num_heads, head_dim) FP32 tensor
"""
q_keys = self.k_cache[layer_idx, :self.seq_len] # (seq, heads, dim)
scales = self.k_scales[layer_idx, :self.seq_len] # (seq, heads, 1)
return q_keys.float() * scales.float()
def get_values(self, layer_idx):
"""Get dequantized values for attention computation.
        Returns: (seq_len, num_heads, head_dim) FP32 tensor
"""
q_vals = self.v_cache[layer_idx, :self.seq_len]
scales = self.v_scales[layer_idx, :self.seq_len]
return q_vals.float() * scales.float()
def memory_usage(self):
"""Report memory usage in bytes."""
seq = self.seq_len
kv_bytes = 2 * self.num_layers * seq * self.num_heads * self.head_dim
if self.precision == 'int4':
kv_bytes = kv_bytes // 2 # Pack 2 INT4 values per byte
scale_bytes = 2 * self.num_layers * seq * self.num_heads * 2 # FP16 scales
return {
'kv_data_bytes': kv_bytes,
'scale_bytes': scale_bytes,
'total_bytes': kv_bytes + scale_bytes,
'total_gb': (kv_bytes + scale_bytes) / (1024 ** 3),
}
Simulating Autoregressive Generation with Quantized KV Cache
def simulate_generation(model_config, num_tokens=100, precision='fp8'):
"""Simulate autoregressive generation with quantized KV cache.
model_config: dict with num_layers, num_heads, head_dim
"""
    cache = QuantizedKVCache(
        num_layers=model_config['num_layers'],
        num_heads=model_config['num_heads'],
        head_dim=model_config['head_dim'],
        max_seq_len=num_tokens + 1024,
        precision=precision,
        device='cpu',  # simulated K/V tensors below live on CPU
    )
# Simulate generating tokens
for token_idx in range(num_tokens):
for layer_idx in range(model_config['num_layers']):
# Simulate K,V from this layer's attention computation
key = torch.randn(model_config['num_heads'],
model_config['head_dim']) * 0.5
value = torch.randn(model_config['num_heads'],
model_config['head_dim']) * 0.3
# Quantize and store
cache.append(layer_idx, key, value)
# During attention: retrieve and dequantize all previous K,V
if token_idx > 0:
all_keys = cache.get_keys(layer_idx)
all_values = cache.get_values(layer_idx)
# Attention computation would happen here
cache.advance_position()
mem = cache.memory_usage()
print(f"Generated {num_tokens} tokens with {precision} KV cache")
print(f"KV cache memory: {mem['total_gb']:.3f} GB")
return cache
# Llama 70B config
llama70b = {
'num_layers': 80,
'num_heads': 8, # GQA KV heads
'head_dim': 128,
}
# Keep the token count modest: this pure-Python simulation re-dequantizes the
# full cache at every step, so thousands of tokens would take hours
for prec in ['fp8', 'int8', 'int4']:
    simulate_generation(llama70b, num_tokens=512, precision=prec)
    print()
FP8 KV Cache: The Sweet Spot
FP8 E4M3 KV cache is the most common production choice. It provides 2x memory savings with minimal quality loss — typically less than 0.1 perplexity points on standard benchmarks.
Why FP8 Works So Well for KV
KV cache values have a natural distribution that FP8 handles well:
- K values after RoPE (Rotary Position Embedding) are bounded and roughly symmetric. The magnitude depends on the head dimension normalization (1/√d) and typically falls in [-2, 2] for most heads.
- V values are projections of the hidden state, roughly Gaussian with occasional moderate outliers. Unlike activations before linear layers, V values do not exhibit the extreme channel-wise outliers that plague INT8 activation quantization.
- FP8's non-uniform spacing provides more resolution near zero, where the density of KV values is highest.
def analyze_kv_distribution(num_layers=80, num_heads=8, head_dim=128,
seq_len=2048):
"""Analyze the distribution of KV cache values to understand
why FP8 works well.
"""
# Simulate realistic KV distributions
# K values: post-RoPE, roughly Gaussian with head-dependent scale
k_values = torch.randn(num_layers, seq_len, num_heads, head_dim) * 0.3
# V values: linear projection of hidden state
v_values = torch.randn(num_layers, seq_len, num_heads, head_dim) * 0.5
for name, tensor in [("K", k_values), ("V", v_values)]:
flat = tensor.flatten()
print(f"{name} value statistics:")
print(f" Mean: {flat.mean():.4f}")
print(f" Std: {flat.std():.4f}")
print(f" Min: {flat.min():.4f}")
print(f" Max: {flat.max():.4f}")
print(f" Abs max: {flat.abs().max():.4f}")
        # FP8 E4M3 quantization error: use the native FP8 cast when available;
        # otherwise approximate E4M3 with a uniform 448-level grid
        amax = flat.abs().max()
        scale = 448.0 / amax
        if hasattr(torch, 'float8_e4m3fn'):
            fp8_q = (flat * scale).clamp(-448, 448).to(torch.float8_e4m3fn).float() / scale
        else:
            fp8_q = (flat * scale).clamp(-448, 448).round() / scale
mse = ((flat - fp8_q) ** 2).mean()
snr = 10 * torch.log10(flat.pow(2).mean() / mse)
print(f" FP8 MSE: {mse:.8f}")
print(f" FP8 SNR: {snr:.1f} dB")
print()
analyze_kv_distribution()
KV Cache Quantization Quality (Llama 70B, WikiText-2 Perplexity)
| KV Precision | Scaling | Perplexity | Degradation | Memory Savings |
|---|---|---|---|---|
| FP16 (baseline) | N/A | 3.32 | 0.00 | 1.0x |
| FP8 E4M3 | Per-token | 3.33 | +0.01 | 2.0x |
| INT8 | Per-token | 3.34 | +0.02 | 2.0x |
| INT8 | Per-tensor | 3.41 | +0.09 | 2.0x |
| INT4 | Per-token | 3.48 | +0.16 | 4.0x |
| INT4 | Per-group (g32) | 3.42 | +0.10 | 3.5x |
| INT4 | Per-tensor | 4.21 | +0.89 | 4.0x |
INT4 KV Cache: Maximum Compression
INT4 KV cache provides 4x memory savings but with measurable quality degradation. The key to making INT4 KV viable is aggressive per-token (or per-group) scaling.
Per-Token INT4 Quantization
class INT4KVQuantizer:
"""INT4 KV cache quantizer with per-token or per-group scaling."""
def __init__(self, group_size=None):
"""
group_size: None for per-token scaling (one scale per head),
or integer for per-group scaling (one scale per group within head)
"""
self.group_size = group_size
def quantize_per_token(self, kv_vector):
"""Per-token INT4 quantization.
kv_vector: (num_heads, head_dim)
Returns: (quantized, scales) where scales is (num_heads, 1)
"""
amax = kv_vector.abs().amax(dim=-1, keepdim=True)
scale = amax / 7.0
scale = scale.clamp(min=1e-12)
quantized = (kv_vector / scale).round().clamp(-8, 7).to(torch.int8)
return quantized, scale
def quantize_per_group(self, kv_vector):
"""Per-group INT4 quantization for finer granularity.
kv_vector: (num_heads, head_dim)
Returns: (quantized, scales) where scales is (num_heads, head_dim // group_size)
"""
num_heads, head_dim = kv_vector.shape
gs = self.group_size
assert head_dim % gs == 0
grouped = kv_vector.reshape(num_heads, -1, gs)
amax = grouped.abs().amax(dim=-1, keepdim=True)
scale = amax / 7.0
scale = scale.clamp(min=1e-12)
quantized = (grouped / scale).round().clamp(-8, 7).to(torch.int8)
quantized = quantized.reshape(num_heads, head_dim)
scale = scale.squeeze(-1) # (num_heads, head_dim // gs)
return quantized, scale
def quantize(self, kv_vector):
"""Quantize using configured granularity."""
if self.group_size is None:
return self.quantize_per_token(kv_vector)
return self.quantize_per_group(kv_vector)
def dequantize_per_token(self, quantized, scale):
return quantized.float() * scale
def dequantize_per_group(self, quantized, scale):
num_heads, head_dim = quantized.shape
gs = self.group_size
grouped = quantized.reshape(num_heads, -1, gs)
scale_expanded = scale.unsqueeze(-1)
return (grouped.float() * scale_expanded).reshape(num_heads, head_dim)
def dequantize(self, quantized, scale):
if self.group_size is None:
return self.dequantize_per_token(quantized, scale)
return self.dequantize_per_group(quantized, scale)
INT4 KV Bit Packing
In production, two INT4 values are packed into a single byte to achieve the full 4x memory savings:
def pack_int4(values):
"""Pack pairs of INT4 values into bytes.
values: (N,) tensor of int8 values in [-8, 7] range
Returns: (N//2,) tensor of uint8 packed bytes
"""
assert len(values) % 2 == 0
# Convert to unsigned: add 8 to map [-8,7] to [0,15]
unsigned = (values + 8).to(torch.uint8)
# Pack: high nibble = even indices, low nibble = odd indices
packed = (unsigned[0::2] << 4) | unsigned[1::2]
return packed
def unpack_int4(packed):
"""Unpack bytes to pairs of INT4 values.
packed: (N//2,) tensor of uint8
Returns: (N,) tensor of int8 values in [-8, 7]
"""
high = (packed >> 4).to(torch.int8) - 8
low = (packed & 0x0F).to(torch.int8) - 8
return torch.stack([high, low], dim=-1).flatten()
# Verify round-trip
original = torch.randint(-8, 8, (128,), dtype=torch.int8)
packed = pack_int4(original)
unpacked = unpack_int4(packed)
assert torch.all(original == unpacked)
print(f"Original: {len(original)} bytes, Packed: {len(packed)} bytes")
# 128 bytes -> 64 bytes
Attention with Quantized KV Cache
The attention kernel must dequantize K and V before computing attention scores. In production, this dequantization is fused into the attention kernel to avoid materializing the full FP16 K/V tensors in memory.
def attention_with_quantized_kv(query, k_cache_q, k_scales, v_cache_q, v_scales,
quantizer, head_dim=128):
"""Compute attention using quantized KV cache.
query: (num_heads, head_dim) -- current token's query
k_cache_q: (seq_len, num_heads, head_dim) -- quantized keys
k_scales: (seq_len, num_heads, ...) -- key scale factors
v_cache_q: (seq_len, num_heads, head_dim) -- quantized values
v_scales: (seq_len, num_heads, ...) -- value scale factors
Returns: (num_heads, head_dim) attention output
"""
seq_len = k_cache_q.shape[0]
num_heads = query.shape[0]
# Dequantize K: (seq_len, num_heads, head_dim)
keys_fp = torch.zeros(seq_len, num_heads, head_dim)
values_fp = torch.zeros(seq_len, num_heads, head_dim)
    for t in range(seq_len):
        keys_fp[t] = quantizer.dequantize_token(k_cache_q[t], k_scales[t])
        values_fp[t] = quantizer.dequantize_token(v_cache_q[t], v_scales[t])
# Attention scores: Q @ K^T / sqrt(d)
# query: (num_heads, head_dim)
# keys: (seq_len, num_heads, head_dim)
scale = head_dim ** -0.5
scores = torch.einsum('hd,shd->hs', query.float(), keys_fp.float()) * scale
# Softmax
attn_weights = F.softmax(scores, dim=-1) # (num_heads, seq_len)
# Weighted sum of values
output = torch.einsum('hs,shd->hd', attn_weights, values_fp.float())
return output
def benchmark_kv_precision(head_dim=128, num_heads=8, seq_len=2048):
"""Compare attention output quality across KV precisions."""
query = torch.randn(num_heads, head_dim) * (head_dim ** -0.5)
keys_fp16 = torch.randn(seq_len, num_heads, head_dim) * 0.3
values_fp16 = torch.randn(seq_len, num_heads, head_dim) * 0.5
# Reference: FP16 attention
scale = head_dim ** -0.5
scores_ref = torch.einsum('hd,shd->hs', query, keys_fp16) * scale
attn_ref = F.softmax(scores_ref, dim=-1)
output_ref = torch.einsum('hs,shd->hd', attn_ref, values_fp16)
for precision in ['fp8', 'int8', 'int4']:
quantizer = KVCacheQuantizer(precision, head_dim)
# Quantize all K,V
        # Use the quantizer's dtype so native FP8 values are stored faithfully
        k_q = torch.zeros(seq_len, num_heads, head_dim, dtype=quantizer.dtype)
        v_q = torch.zeros(seq_len, num_heads, head_dim, dtype=quantizer.dtype)
k_s = torch.zeros(seq_len, num_heads, 1)
v_s = torch.zeros(seq_len, num_heads, 1)
for t in range(seq_len):
qk, sk, qv, sv = quantizer.quantize_kv_pair(
keys_fp16[t], values_fp16[t]
)
k_q[t], k_s[t] = qk, sk
v_q[t], v_s[t] = qv, sv
output_q = attention_with_quantized_kv(
query, k_q, k_s, v_q, v_s, quantizer, head_dim
)
mse = ((output_ref - output_q) ** 2).mean().item()
cos_sim = F.cosine_similarity(
output_ref.flatten().unsqueeze(0),
output_q.flatten().unsqueeze(0)
).item()
print(f"{precision:5s} KV: MSE={mse:.8f}, cos_sim={cos_sim:.6f}")
benchmark_kv_precision()
Quality-Memory Tradeoff Analysis
The right KV precision depends on your serving constraints and quality requirements.
Quality vs Memory Savings Tradeoff (Llama 70B)
Decision Framework
def recommend_kv_precision(model_size_b, target_batch, seq_len,
gpu_memory_gb=80, weight_precision='int4'):
"""Recommend KV cache precision based on serving constraints."""
# Estimate model weight memory
if weight_precision == 'int4':
weight_gb = model_size_b * 0.5 / 1e9 # 0.5 bytes per param
elif weight_precision == 'fp8':
weight_gb = model_size_b * 1.0 / 1e9
else:
weight_gb = model_size_b * 2.0 / 1e9
available_gb = gpu_memory_gb - weight_gb - 2 # 2 GB overhead
    # Estimate KV per request (Llama-style GQA)
    # Rough: 2 * num_layers * kv_heads * head_dim * seq_len * bytes_per_elem
    # For Llama 70B: ~1.34 GB/request at FP16 for seq=4096
    kv_per_request_fp16 = 1.34 * (seq_len / 4096) * (model_size_b / 70e9)
results = {}
for prec, divisor, quality_note in [
('FP16', 1.0, 'lossless'),
('FP8', 2.0, 'near-lossless (0.01 PPL)'),
('INT8', 2.0, 'near-lossless (0.02 PPL)'),
('INT4', 4.0, 'slight degradation (0.1-0.2 PPL)'),
]:
kv_per_req = kv_per_request_fp16 / divisor
max_batch = int(available_gb / kv_per_req) if kv_per_req > 0 else 0
fits = max_batch >= target_batch
results[prec] = {
'kv_per_request_gb': kv_per_req,
'max_batch': max_batch,
'fits': fits,
'quality': quality_note,
}
    # Find recommendation: the highest-quality precision that fits the target batch
    for prec in ['FP16', 'FP8', 'INT8', 'INT4']:
if results[prec]['fits']:
print(f"Recommendation: {prec} KV cache")
print(f" Quality: {results[prec]['quality']}")
print(f" Max batch: {results[prec]['max_batch']} "
f"(target: {target_batch})")
return prec
print("WARNING: Even INT4 KV cannot fit the target batch size.")
print("Consider: tensor parallelism, shorter context, or more GPUs.")
return None
# Example: Llama 70B on single H100, targeting batch=64
recommend_kv_precision(
model_size_b=70e9,
target_batch=64,
seq_len=4096,
gpu_memory_gb=80,
weight_precision='int4'
)
K vs V Quantization Sensitivity
An important but often overlooked detail: K and V have different sensitivity to quantization error.
K quantization errors affect attention score computation. Errors in K shift the dot products q · k_t, which are then passed through softmax. Small errors in K can cause the softmax to redistribute attention weight incorrectly.
V quantization errors affect the output directly. The output is the attention-weighted sum Σ_t a_t · v_t. Errors in V are linearly weighted by the (correct) attention distribution. If the attention is concentrated on a few tokens, only those tokens' V errors matter.
In practice, K is more sensitive than V. Some production systems quantize K to FP8 and V to INT4, or use per-group scaling for K and per-token scaling for V.
def measure_k_v_sensitivity(head_dim=128, num_heads=8, seq_len=1024):
"""Measure the relative sensitivity of K vs V to quantization."""
query = torch.randn(num_heads, head_dim) * (head_dim ** -0.5)
keys = torch.randn(seq_len, num_heads, head_dim) * 0.3
values = torch.randn(seq_len, num_heads, head_dim) * 0.5
# Reference output
scale = head_dim ** -0.5
scores = torch.einsum('hd,shd->hs', query, keys) * scale
attn = F.softmax(scores, dim=-1)
output_ref = torch.einsum('hs,shd->hd', attn, values)
# Quantize only K (INT8), keep V in FP16
k_quant = KVCacheQuantizer('int8', head_dim)
k_q_all, k_s_all = [], []
for t in range(seq_len):
kq, ks = k_quant.quantize_token(keys[t])
k_q_all.append(kq)
k_s_all.append(ks)
keys_deq = torch.stack([k_quant.dequantize_token(k_q_all[t], k_s_all[t])
for t in range(seq_len)])
scores_kq = torch.einsum('hd,shd->hs', query, keys_deq) * scale
attn_kq = F.softmax(scores_kq, dim=-1)
output_kq = torch.einsum('hs,shd->hd', attn_kq, values)
mse_konly = ((output_ref - output_kq) ** 2).mean().item()
# Quantize only V (INT8), keep K in FP16
v_quant = KVCacheQuantizer('int8', head_dim)
v_q_all, v_s_all = [], []
for t in range(seq_len):
vq, vs = v_quant.quantize_token(values[t])
v_q_all.append(vq)
v_s_all.append(vs)
values_deq = torch.stack([v_quant.dequantize_token(v_q_all[t], v_s_all[t])
for t in range(seq_len)])
output_vq = torch.einsum('hs,shd->hd', attn, values_deq)
mse_vonly = ((output_ref - output_vq) ** 2).mean().item()
print(f"K-only INT8 MSE: {mse_konly:.8f}")
print(f"V-only INT8 MSE: {mse_vonly:.8f}")
print(f"K/V sensitivity ratio: {mse_konly / mse_vonly:.2f}x")
measure_k_v_sensitivity()
Some systems use different precision for K and V. For example, FP8 for K (more sensitive) and INT4 for V (less sensitive). This gives memory savings closer to INT4 while maintaining quality closer to FP8, because the attention distribution (determined by K) is accurate, and only the value aggregation (determined by V) has reduced precision.
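This asymmetry can be checked directly on synthetic data. A minimal sketch comparing mixed-precision KV against uniform INT4; the 448-level uniform grid below stands in for FP8 (true E4M3 spacing is non-uniform), and the magnitudes are illustrative:

```python
import torch
import torch.nn.functional as F

def quant_dequant(x, qmax):
    """Fake-quantize: per-(token, head) symmetric scale, round, clamp, rescale."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / qmax
    return (x / scale).round().clamp(-qmax, qmax) * scale

def attend(q, k, v, head_dim):
    """Single-query attention over the full cache."""
    scores = torch.einsum('hd,shd->hs', q, k) * head_dim ** -0.5
    return torch.einsum('hs,shd->hd', F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
seq, heads, d = 512, 8, 128
q = torch.randn(heads, d) * d ** -0.5
k = torch.randn(seq, heads, d) * 0.3
v = torch.randn(seq, heads, d) * 0.5

ref = attend(q, k, v, d)
mixed = attend(q, quant_dequant(k, 448.0), quant_dequant(v, 7.0), d)  # FP8-ish K, INT4 V
int4 = attend(q, quant_dequant(k, 7.0), quant_dequant(v, 7.0), d)     # INT4 K and V

mixed_mse = ((ref - mixed) ** 2).mean().item()
int4_mse = ((ref - int4) ** 2).mean().item()
print(f"mixed (high-precision K, INT4 V) MSE: {mixed_mse:.8f}")
print(f"INT4 K and V MSE:                     {int4_mse:.8f}")
```

Keeping K accurate while dropping V to INT4 should land much closer to the reference than quantizing both, since the attention distribution itself stays intact.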
Production Integration: vLLM and TensorRT-LLM
Both vLLM and TensorRT-LLM support KV cache quantization as a runtime configuration option.
vLLM implements FP8 KV cache quantization natively. The quantization is fused into the paged attention kernel — each page stores quantized KV values with per-token scale factors. No separate dequantization kernel is needed.
TensorRT-LLM supports FP8 and INT8 KV cache through its attention plugins. The user specifies the KV precision in the model configuration, and the engine builder generates optimized kernels.
# vLLM configuration for FP8 KV cache (conceptual)
# from vllm import LLM, SamplingParams
#
# llm = LLM(
# model="meta-llama/Llama-2-70b",
# quantization="awq", # INT4 weights
# kv_cache_dtype="fp8_e4m3", # FP8 KV cache
# max_model_len=8192,
# gpu_memory_utilization=0.9,
# )
#
# # The combination of INT4 weights + FP8 KV enables:
# # - 70B model on a single H100-80GB
# # - Batch size ~64 at seq_len=4096
# # - Near-lossless quality
# TensorRT-LLM configuration (conceptual)
# trtllm-build \
# --checkpoint_dir ./llama-70b-awq/ \
# --kv_cache_type FP8 \
# --max_batch_size 64 \
# --max_input_len 4096 \
# --max_seq_len 8192
Sequence Length Scaling: How Quality Degrades
KV cache quantization error accumulates with sequence length. Each new token attends to all previous tokens, and the quantization errors in early tokens affect every subsequent attention computation.
def measure_quality_vs_seqlen(seq_lens, precision='int4', num_heads=8,
head_dim=128):
"""Measure how KV quantization quality degrades with sequence length."""
results = []
quantizer = KVCacheQuantizer(precision, head_dim)
for seq_len in seq_lens:
query = torch.randn(num_heads, head_dim) * (head_dim ** -0.5)
keys = torch.randn(seq_len, num_heads, head_dim) * 0.3
values = torch.randn(seq_len, num_heads, head_dim) * 0.5
# Reference
scale = head_dim ** -0.5
scores = torch.einsum('hd,shd->hs', query, keys) * scale
attn = F.softmax(scores, dim=-1)
output_ref = torch.einsum('hs,shd->hd', attn, values)
# Quantized
        # Use the quantizer's dtype so native FP8 values are stored faithfully
        k_q = torch.zeros(seq_len, num_heads, head_dim, dtype=quantizer.dtype)
        v_q = torch.zeros(seq_len, num_heads, head_dim, dtype=quantizer.dtype)
k_s = torch.zeros(seq_len, num_heads, 1)
v_s = torch.zeros(seq_len, num_heads, 1)
for t in range(seq_len):
qk, sk, qv, sv = quantizer.quantize_kv_pair(keys[t], values[t])
k_q[t], k_s[t] = qk, sk
v_q[t], v_s[t] = qv, sv
output_q = attention_with_quantized_kv(
query, k_q, k_s, v_q, v_s, quantizer, head_dim
)
mse = ((output_ref - output_q) ** 2).mean().item()
cos_sim = F.cosine_similarity(
output_ref.flatten().unsqueeze(0),
output_q.flatten().unsqueeze(0)
).item()
results.append({
'seq_len': seq_len,
'mse': mse,
'cos_sim': cos_sim,
})
return results
seq_lens = [256, 512, 1024, 2048, 4096, 8192]
for prec in ['fp8', 'int8', 'int4']:
print(f"\n{prec} KV cache quality vs sequence length:")
results = measure_quality_vs_seqlen(seq_lens, precision=prec)
for r in results:
print(f" seq={r['seq_len']:5d}: MSE={r['mse']:.8f}, "
f"cos_sim={r['cos_sim']:.6f}")
Quality Degradation vs Sequence Length (Llama 70B, INT4 KV Per-Token)
| Seq Length | PPL Degradation | MMLU Impact | Acceptable? |
|---|---|---|---|
| 512 | +0.05 | -0.1% | Yes |
| 2048 | +0.10 | -0.3% | Yes |
| 4096 | +0.16 | -0.5% | Marginal |
| 8192 | +0.28 | -1.1% | Task-dependent |
| 16384 | +0.52 | -2.3% | Consider FP8 |
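One way to operationalize this table is a small lookup helper; the thresholds come from the table above, but the default 0.2 PPL tolerance is an illustrative assumption, not a standard:

```python
# Tabulated INT4 KV perplexity degradation by sequence length (from the table above)
PPL_DEGRADATION_INT4 = {512: 0.05, 2048: 0.10, 4096: 0.16, 8192: 0.28, 16384: 0.52}

def int4_kv_acceptable(seq_len, max_ppl_increase=0.2):
    """True if INT4 KV stays within tolerance at this context length, judged by
    the nearest tabulated sequence length at or above seq_len."""
    for s in sorted(PPL_DEGRADATION_INT4):
        if seq_len <= s:
            return PPL_DEGRADATION_INT4[s] <= max_ppl_increase
    return False  # beyond the tabulated range: fall back to FP8/INT8

print(int4_kv_acceptable(4096))   # True
print(int4_kv_acceptable(16384))  # False: consider FP8 instead
```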
Summary
KV cache quantization is fundamentally different from weight quantization: it operates on dynamic data generated during inference, requires online quantization with zero latency overhead, and uses per-token scaling to handle the variable distributions across sequence positions.
FP8 E4M3 KV is the production sweet spot: 2x memory savings with less than 0.02 perplexity degradation. The non-uniform FP8 representation naturally handles the distributions found in key and value projections.
INT8 KV with per-token scaling provides equivalent memory savings to FP8 with marginally higher error. It is preferred on hardware without FP8 support (pre-Hopper GPUs).
INT4 KV with per-token or per-group scaling provides 4x memory savings (3.5x with scale overhead) but with measurable quality degradation (0.1-0.2 PPL at moderate sequence lengths). Quality degrades further at very long sequence lengths.
Per-token scaling is essential for any KV precision below FP16. Without it, a single outlier token dominates the scale factor and wastes precision for all other tokens.
K is more sensitive than V to quantization error, because K errors affect attention score computation (pre-softmax) while V errors are linearly attenuated by the attention distribution. Mixed-precision approaches (FP8 K, INT4 V) can exploit this asymmetry.
This concludes the Quantization Masterclass series. From number formats (Part 1) through weight quantization (Part 2), activation quantization (Part 3), FP8 training and inference (Part 4), FP4 on Blackwell (Part 5), and KV cache quantization (Part 6), you now have a complete technical foundation for every quantization decision in modern AI systems.