Part 6 of 30 in the Quantization Masterclass series:

1. Number Formats for AI: FP32, BF16, FP16, FP8 E4M3, FP8 E5M2, NVFP4, MXFP4, INT8, INT4
2. Weight Quantization: GPTQ, AWQ, and Round-To-Nearest — Algorithms and Implementation
3. Activation Quantization: SmoothQuant, Per-Tensor Scaling, and W8A8 Inference
4. FP8 for Training and Inference: E4M3, E5M2, Transformer Engine, and Delayed Scaling
5. FP4 and MXFP4: The Blackwell Frontier — Sub-Byte Quantization for Next-Gen Inference
6. KV Cache Quantization: FP8, INT8, INT4, Per-Token Scaling, and the Quality-Memory Tradeoff
7. Quantization-Aware Training: Fake Quantization, Straight-Through Estimator, and QAT vs PTQ
8. Mixed Precision Inference: Which Ops Use Which Precision and Why
9. Calibration for Post-Training Quantization: MinMax, Percentile, MSE-Optimal, and Cross-Layer
10. Quantization Hardware Support: Tensor Core Precision Matrix, cuBLAS INT8, and Marlin Kernels
11. Per-Channel vs Per-Group vs Per-Tensor Scaling: Granularity Tradeoffs in Weight Quantization
12. The Outlier Channel Problem: Why LLM Activations Break Simple Quantization
13. W4A16 Inference: 4-Bit Weights with FP16 Activations and the Marlin Kernel
14. W8A8 INT8 Inference: cuBLAS INT8 GEMM, Per-Tensor Scaling, and When INT8 Beats FP8
15. GGUF Quantization Types: Q4_K_M, Q5_K_M, Q8_0 — How llama.cpp Quantizes for CPU
16. AWQ Deep Dive: Activation-Aware Weight Quantization — The Algorithm Step by Step
17. GPTQ Deep Dive: Hessian-Based One-Shot Quantization — OBS, Column-Wise Updates, and Lazy Batch
18. SqueezeLLM and Non-Uniform Quantization: Lookup Tables, Sparse Outliers, and Mixed Strategies
19. Quantization for Training: FP8 GEMM, Loss Scaling, and Why BF16 Remains the Default
20. Quantization Production Guide: Choosing the Right Method for Your Model, Hardware, and Latency SLO
21. Combining Sparsity and Quantization: 2:4 Structured Sparsity with INT8 for Maximum Throughput
22. Dynamic vs Static Quantization: Online Calibration, Offline Calibration, and When Each Wins
23. AQLM and Extreme Compression: 2-Bit Quantization with Additive Codebooks
24. Quantized Draft Models for Speculative Decoding: INT4 Drafters with FP16 Verification
25. Quantization Benchmarking: How to Properly Measure Quality Loss, Throughput, and Cost Impact
26. INT4 Weight Packing: Bit Manipulation, Dequantization Kernels, and Memory Layout
27. Serving Quantized Models: vLLM, TRT-LLM, and llama.cpp Integration
28. Debugging Quantization: Layer Sensitivity, Outlier Detection, and Quality Recovery
29. Future of Quantization: Sub-4-Bit, Ternary, and Binary Neural Networks
30. End-to-End Quantization Pipeline: From FP16 Checkpoint to Production INT4 Deployment

KV cache quantization is the highest-leverage memory optimization in LLM serving. At production batch sizes and sequence lengths, the KV cache dominates GPU memory — often consuming 2-5x more than the model weights themselves. Quantizing the KV cache from FP16 to FP8 halves this memory, doubling the number of concurrent requests you can serve. Quantizing to INT4 quarters it, enabling 4x more requests on the same hardware.

This is distinct from weight quantization (Parts 2 and 5) in every way that matters. Weights are static — quantize once, serve forever. The KV cache is dynamic — new key-value pairs are generated for every token, and they must be quantized online during generation with zero additional latency. Weights can use offline calibration data. KV cache quantization must work with whatever distribution the current request produces. The engineering constraints are completely different.

This post covers why KV cache deserves its own quantization strategy, per-token scaling and its implementation, FP8/INT8/INT4 KV cache at the algorithm level, quality impact across model sizes and precision targets, and a complete implementation of online KV cache quantization during autoregressive generation.

KV Cache Memory: The Serving Bottleneck

For a transformer model with L layers, H KV heads, head dimension d, serving B requests at sequence length S:

KV cache (bytes) = 2 × L × H × d × S × B × bytes_per_element

The factor of 2 is for K and V. For Llama 70B (L = 80, H = 8 GQA KV heads, d = 128):

FP16 KV = 2 × 80 × 8 × 128 × S × B × 2 = 327,680 × S × B bytes

At S = 4096, B = 64: 85.9 GB for KV cache alone. The model weights (even at INT4) are only 35 GB. The KV cache is 2.5x the model.
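These numbers are easy to sanity-check in a few lines. The helper below is a back-of-the-envelope sketch; the 35 GB figure assumes roughly 0.5 bytes per parameter for INT4 weights, ignoring scale and zero-point overhead:

```python
def kv_cache_gb(num_layers, kv_heads, head_dim, seq_len, batch, bytes_per_elem):
    """KV cache size in decimal GB: 2 (K and V) x L x H x d x S x B x bytes."""
    return (2 * num_layers * kv_heads * head_dim
            * seq_len * batch * bytes_per_elem) / 1e9

# Llama 70B: 80 layers, 8 GQA KV heads, head_dim 128
fp16_kv = kv_cache_gb(80, 8, 128, seq_len=4096, batch=64, bytes_per_elem=2)
int4_weights_gb = 70e9 * 0.5 / 1e9  # ~0.5 bytes per parameter

print(f"FP16 KV cache: {fp16_kv:.1f} GB")          # 85.9 GB
print(f"INT4 weights:  {int4_weights_gb:.1f} GB")  # 35.0 GB
print(f"KV / weights:  {fp16_kv / int4_weights_gb:.1f}x")  # 2.5x
```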

Memory Split: Model Weights vs KV Cache (Llama 70B, seq=4096)

Component        Size   Relative to INT4 weights
INT4 weights     35 GB  1.0x
FP16 KV (B=32)   43 GB  1.2x
FP16 KV (B=64)   86 GB  2.5x
FP8 KV (B=64)    43 GB  1.2x
INT4 KV (B=64)   21 GB  0.6x

Quantizing the KV cache is the only way to serve at high batch sizes without adding more GPUs. Each halving of KV precision roughly doubles the batch capacity:


Maximum Batch Size by KV Precision (Llama 70B, INT4 Weights, H100-80GB, seq=4096)

KV Precision  KV per Request  KV at Max Batch  Max Batch  Improvement
FP16          1.34 GB         42.9 GB          32         1.0x
FP8 E4M3      0.67 GB         42.9 GB          64         2.0x
INT8          0.67 GB         42.9 GB          64         2.0x
INT4          0.34 GB         40.8 GB          120        3.75x

Note: Assumes 35 GB for INT4 model weights, leaving 45 GB for KV cache. KV per request = 2 × 80 × 8 × 128 × 4096 × bytes_per_element.

Why KV Cache Quantization Is Different

Dynamic vs Static

Model weights are fixed at load time. You can spend hours running GPTQ or AWQ with calibration data to find optimal scale factors. KV cache values are generated token-by-token during inference. Each new token produces a new K and V vector that must be quantized immediately.

Per-Token vs Per-Tensor

Each token’s K and V vectors have their own distribution. Token 1 might have key values in the range [-0.5, 0.5], while token 500 might have values in [-2.0, 2.0]. A single scale factor for the entire sequence would be dominated by the worst-case token, wasting precision for all the others.

Per-token scaling gives each token’s K (or V) vector its own scale factor. This adds minimal storage overhead (one FP32 or FP16 scale per token per layer per head) but dramatically improves quality because each token’s values use the full quantized range.
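A small synthetic experiment makes the gap concrete. The position-dependent magnitudes below are an illustrative stand-in for real KV statistics, not measurements; the comparison is INT8 round-trip error under one sequence-wide scale versus one scale per token:

```python
import torch

torch.manual_seed(0)
seq_len, head_dim = 512, 128

# Simulated keys whose magnitude grows with position, so a single
# sequence-wide scale is dictated by the late, large-magnitude tokens.
magnitudes = torch.linspace(0.5, 2.0, seq_len).unsqueeze(-1)
keys = torch.randn(seq_len, head_dim) * magnitudes

def int8_mse(x, scale):
    """Round-trip INT8 quantization error under the given scale(s)."""
    q = (x / scale).round().clamp(-128, 127)
    return ((q * scale - x) ** 2).mean()

# Per-tensor: one scale for the whole sequence
mse_tensor = int8_mse(keys, keys.abs().max() / 127.0)

# Per-token: one scale per position
mse_token = int8_mse(keys, keys.abs().amax(dim=-1, keepdim=True) / 127.0)

print(f"per-tensor MSE: {mse_tensor.item():.2e}")
print(f"per-token  MSE: {mse_token.item():.2e}")  # several times lower
```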

No Calibration Data Available

Weight quantization can use calibration data to determine optimal scale factors, group assignments, and channel priorities. KV cache quantization has no access to future tokens — it must quantize each K/V vector as it is produced, based only on that vector’s own statistics.

⚠️ Online Quantization Constraint

KV cache quantization happens on the critical path of token generation. Any overhead — computing scale factors, performing the quantization — directly increases per-token latency. The quantization must be fused into the attention kernel or the KV cache write path to avoid extra memory traffic.

Per-Token Scaling Implementation

Per-token scaling computes one scale factor per token, per layer, per attention head. For each token t, head h, and layer l:

s^K_{t,h,l} = max(|K_{t,h,l}|) / Q_max

where Q_max is the maximum representable value in the target format (127 for INT8, 448 for FP8 E4M3, 7 for INT4).

import torch
import torch.nn.functional as F

class KVCacheQuantizer:
    """Online KV cache quantizer with per-token scaling."""

    def __init__(self, precision='fp8', head_dim=128):
        """
        precision: 'fp8', 'int8', or 'int4'
        head_dim: dimension of each attention head
        """
        self.precision = precision
        self.head_dim = head_dim

        if precision == 'fp8':
            self.qmax = 448.0
            self.qmin = -448.0
            # FP16 fallback when the FP8 dtype is unavailable: an int8
            # container cannot represent values up to 448.
            self.dtype = (torch.float8_e4m3fn
                          if hasattr(torch, 'float8_e4m3fn') else torch.float16)
        elif precision == 'int8':
            self.qmax = 127.0
            self.qmin = -128.0
            self.dtype = torch.int8
        elif precision == 'int4':
            self.qmax = 7.0
            self.qmin = -8.0
            self.dtype = torch.int8  # Store INT4 in INT8 container
        else:
            raise ValueError(f"Unknown precision: {precision}")

    def quantize_token(self, kv_vector):
        """Quantize a single token's K or V vector.

        kv_vector: (num_heads, head_dim) -- one token's K or V across all heads

        Returns:
            quantized: (num_heads, head_dim) in target dtype
            scales: (num_heads, 1) per-head scale factors
        """
        # Per-head scaling: one scale per attention head
        amax = kv_vector.abs().amax(dim=-1, keepdim=True)  # (num_heads, 1)
        scales = amax / self.qmax
        scales = scales.clamp(min=1e-12)

        scaled = (kv_vector / scales).clamp(self.qmin, self.qmax)
        if self.precision == 'fp8':
            # The FP8 cast quantizes implicitly; rounding to integers would
            # throw away the format's fine resolution near zero.
            quantized = scaled.to(self.dtype)
        else:
            quantized = scaled.round().to(self.dtype)

        return quantized, scales

    def dequantize_token(self, quantized, scales):
        """Dequantize a single token's K or V vector.

        quantized: (num_heads, head_dim) quantized values
        scales: (num_heads, 1) scale factors

        Returns: (num_heads, head_dim) FP32 values
        """
        return quantized.float() * scales

    # Uniform interface shared with the per-group INT4 quantizer below
    dequantize = dequantize_token

    def quantize_kv_pair(self, key, value):
        """Quantize both K and V for a single token.

        key: (num_heads, head_dim)
        value: (num_heads, head_dim)

        Returns: (q_key, k_scale, q_value, v_scale)
        """
        q_key, k_scale = self.quantize_token(key)
        q_value, v_scale = self.quantize_token(value)
        return q_key, k_scale, q_value, v_scale

Scale Factor Storage Overhead

Per-token scaling adds one scale factor (typically FP16 or FP32) per token per head per K/V. For Llama 70B:

Scale overhead per token = 2 × L × H × scale_bytes = 2 × 80 × 8 × 2 = 2,560 bytes

The KV values per token at INT8 are:

KV per token = 2 × 80 × 8 × 128 × 1 = 163,840 bytes

Scale overhead is 2,560 / 163,840 = 1.6% — negligible.
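The same arithmetic as a small reusable function; it also shows that with packed INT4 payloads the relative overhead doubles to about 3.1%, since the scale count stays fixed while the data halves:

```python
def scale_overhead(num_layers=80, kv_heads=8, head_dim=128,
                   kv_bytes_per_elem=1.0, scale_bytes=2):
    """Per-token scale overhead as a fraction of per-token KV data.

    One FP16 scale per token per head for each of K and V; KV data is
    2 x L x H x d x bytes_per_element per token.
    """
    scales = 2 * num_layers * kv_heads * scale_bytes
    data = 2 * num_layers * kv_heads * head_dim * kv_bytes_per_elem
    return scales / data

print(f"INT8 overhead: {scale_overhead(kv_bytes_per_elem=1.0):.1%}")  # 1.6%
print(f"INT4 overhead: {scale_overhead(kv_bytes_per_elem=0.5):.1%}")  # 3.1%
```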

Complete Online KV Cache with Quantization

class QuantizedKVCache:
    """KV cache with online quantization during generation.

    Supports FP8, INT8, and INT4 precision with per-token scaling.
    """

    def __init__(self, num_layers, num_heads, head_dim, max_seq_len,
                 precision='fp8', device='cuda'):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.precision = precision
        self.device = device

        self.quantizer = KVCacheQuantizer(precision, head_dim)

        # Determine storage dtype: FP8 keeps the quantizer's float8 dtype
        # (or its fallback); INT8 and INT4 use an int8 container.
        store_dtype = self.quantizer.dtype if precision == 'fp8' else torch.int8

        # Allocate quantized KV storage
        # Shape: (num_layers, max_seq_len, num_heads, head_dim)
        self.k_cache = torch.zeros(
            num_layers, max_seq_len, num_heads, head_dim,
            dtype=store_dtype, device=device
        )
        self.v_cache = torch.zeros(
            num_layers, max_seq_len, num_heads, head_dim,
            dtype=store_dtype, device=device
        )

        # Scale factors: (num_layers, max_seq_len, num_heads, 1)
        self.k_scales = torch.zeros(
            num_layers, max_seq_len, num_heads, 1,
            dtype=torch.float16, device=device
        )
        self.v_scales = torch.zeros(
            num_layers, max_seq_len, num_heads, 1,
            dtype=torch.float16, device=device
        )

        self.seq_len = 0

    def append(self, layer_idx, key, value):
        """Append a new token's K,V to the cache.

        key: (num_heads, head_dim) -- this token's key vectors
        value: (num_heads, head_dim) -- this token's value vectors

        Called once per layer per generated token.
        """
        pos = self.seq_len

        # Quantize K and V for this token
        q_key, k_scale, q_value, v_scale = self.quantizer.quantize_kv_pair(key, value)

        # Store in cache
        self.k_cache[layer_idx, pos] = q_key
        self.v_cache[layer_idx, pos] = q_value
        self.k_scales[layer_idx, pos] = k_scale.half()
        self.v_scales[layer_idx, pos] = v_scale.half()

    def advance_position(self):
        """Call after all layers have appended for a token."""
        self.seq_len += 1

    def get_keys(self, layer_idx):
        """Get dequantized keys for attention computation.

        Returns: (seq_len, num_heads, head_dim) FP32 tensor
        """
        q_keys = self.k_cache[layer_idx, :self.seq_len]  # (seq, heads, dim)
        scales = self.k_scales[layer_idx, :self.seq_len]  # (seq, heads, 1)
        return q_keys.float() * scales.float()

    def get_values(self, layer_idx):
        """Get dequantized values for attention computation.

        Returns: (seq_len, num_heads, head_dim) FP32 tensor
        """
        q_vals = self.v_cache[layer_idx, :self.seq_len]
        scales = self.v_scales[layer_idx, :self.seq_len]
        return q_vals.float() * scales.float()

    def memory_usage(self):
        """Report memory usage in bytes."""
        seq = self.seq_len
        kv_bytes = 2 * self.num_layers * seq * self.num_heads * self.head_dim
        if self.precision == 'int4':
            kv_bytes = kv_bytes // 2  # Pack 2 INT4 values per byte
        scale_bytes = 2 * self.num_layers * seq * self.num_heads * 2  # FP16 scales

        return {
            'kv_data_bytes': kv_bytes,
            'scale_bytes': scale_bytes,
            'total_bytes': kv_bytes + scale_bytes,
            'total_gb': (kv_bytes + scale_bytes) / (1024 ** 3),
        }

Simulating Autoregressive Generation with Quantized KV Cache

def simulate_generation(model_config, num_tokens=100, precision='fp8'):
    """Simulate autoregressive generation with quantized KV cache.

    model_config: dict with num_layers, num_heads, head_dim
    """
    cache = QuantizedKVCache(
        num_layers=model_config['num_layers'],
        num_heads=model_config['num_heads'],
        head_dim=model_config['head_dim'],
        max_seq_len=num_tokens + 1024,
        precision=precision,
        device='cpu',  # keep the demo device-agnostic; use 'cuda' on a GPU box
    )

    # Simulate generating tokens
    for token_idx in range(num_tokens):
        for layer_idx in range(model_config['num_layers']):
            # Simulate K,V from this layer's attention computation
            key = torch.randn(model_config['num_heads'],
                             model_config['head_dim']) * 0.5
            value = torch.randn(model_config['num_heads'],
                               model_config['head_dim']) * 0.3

            # Quantize and store
            cache.append(layer_idx, key, value)

            # During attention: retrieve and dequantize all previous K,V.
            # (Done here only on the final token; doing it every token makes
            # this pure-Python demo quadratically slow.)
            if token_idx == num_tokens - 1:
                all_keys = cache.get_keys(layer_idx)
                all_values = cache.get_values(layer_idx)
                # Attention computation would happen here

        cache.advance_position()

    mem = cache.memory_usage()
    print(f"Generated {num_tokens} tokens with {precision} KV cache")
    print(f"KV cache memory: {mem['total_gb']:.3f} GB")

    return cache

# Llama 70B config
llama70b = {
    'num_layers': 80,
    'num_heads': 8,  # GQA KV heads
    'head_dim': 128,
}

for prec in ['fp8', 'int8', 'int4']:
    simulate_generation(llama70b, num_tokens=4096, precision=prec)
    print()

FP8 KV Cache: The Sweet Spot

FP8 E4M3 KV cache is the most common production choice. It provides 2x memory savings with minimal quality loss — typically less than 0.1 perplexity points on standard benchmarks.

Why FP8 Works So Well for KV

KV cache values have a natural distribution that FP8 handles well:

  1. K values after RoPE (Rotary Position Embedding) are bounded and roughly symmetric. The magnitude depends on the head dimension normalization (1/√d) and typically falls in [-2, 2] for most heads.

  2. V values are projections of the hidden state, roughly Gaussian with occasional moderate outliers. Unlike activations before linear layers, V values do not exhibit the extreme channel-wise outliers that plague INT8 activation quantization.

  3. FP8’s non-uniform spacing provides more resolution near zero where the density of KV values is highest.

def analyze_kv_distribution(num_layers=80, num_heads=8, head_dim=128,
                             seq_len=2048):
    """Analyze the distribution of KV cache values to understand
    why FP8 works well.
    """
    # Simulate realistic KV distributions
    # K values: post-RoPE, roughly Gaussian with head-dependent scale
    k_values = torch.randn(num_layers, seq_len, num_heads, head_dim) * 0.3
    # V values: linear projection of hidden state
    v_values = torch.randn(num_layers, seq_len, num_heads, head_dim) * 0.5

    for name, tensor in [("K", k_values), ("V", v_values)]:
        flat = tensor.flatten()
        print(f"{name} value statistics:")
        print(f"  Mean:     {flat.mean():.4f}")
        print(f"  Std:      {flat.std():.4f}")
        print(f"  Min:      {flat.min():.4f}")
        print(f"  Max:      {flat.max():.4f}")
        print(f"  Abs max:  {flat.abs().max():.4f}")

        # FP8 E4M3 quantization error. Cast to the real FP8 dtype when
        # available; integer rounding would simulate a uniform ~9-bit grid,
        # not FP8's non-uniform spacing.
        amax = flat.abs().max()
        scale = amax / 448.0
        scaled = (flat / scale).clamp(-448, 448)
        if hasattr(torch, 'float8_e4m3fn'):
            fp8_q = scaled.to(torch.float8_e4m3fn).float() * scale
        else:
            fp8_q = scaled.round() * scale  # coarse integer-grid stand-in
        mse = ((flat - fp8_q) ** 2).mean()
        snr = 10 * torch.log10(flat.pow(2).mean() / mse)
        print(f"  FP8 MSE:  {mse:.8f}")
        print(f"  FP8 SNR:  {snr:.1f} dB")
        print()

analyze_kv_distribution()

KV Cache Quantization Quality (Llama 70B, WikiText-2 Perplexity)

KV Precision     Scaling          Perplexity  Degradation  Memory Savings
FP16 (baseline)  N/A              3.32        0.00         1.0x
FP8 E4M3         Per-token        3.33        +0.01        2.0x
INT8             Per-token        3.34        +0.02        2.0x
INT8             Per-tensor       3.41        +0.09        2.0x
INT4             Per-token        3.48        +0.16        4.0x
INT4             Per-group (g32)  3.42        +0.10        3.5x
INT4             Per-tensor       4.21        +0.89        4.0x
Note: FP8 with per-token scaling achieves near-lossless quality. INT4 with per-token scaling shows measurable but often acceptable degradation. Per-tensor scaling for INT4 is catastrophic.

INT4 KV Cache: Maximum Compression

INT4 KV cache provides 4x memory savings but with measurable quality degradation. The key to making INT4 KV viable is aggressive per-token (or per-group) scaling.

Per-Token INT4 Quantization

class INT4KVQuantizer:
    """INT4 KV cache quantizer with per-token or per-group scaling."""

    def __init__(self, group_size=None):
        """
        group_size: None for per-token scaling (one scale per head),
                    or integer for per-group scaling (one scale per group within head)
        """
        self.group_size = group_size

    def quantize_per_token(self, kv_vector):
        """Per-token INT4 quantization.

        kv_vector: (num_heads, head_dim)
        Returns: (quantized, scales) where scales is (num_heads, 1)
        """
        amax = kv_vector.abs().amax(dim=-1, keepdim=True)
        scale = amax / 7.0
        scale = scale.clamp(min=1e-12)

        quantized = (kv_vector / scale).round().clamp(-8, 7).to(torch.int8)
        return quantized, scale

    def quantize_per_group(self, kv_vector):
        """Per-group INT4 quantization for finer granularity.

        kv_vector: (num_heads, head_dim)
        Returns: (quantized, scales) where scales is (num_heads, head_dim // group_size)
        """
        num_heads, head_dim = kv_vector.shape
        gs = self.group_size
        assert head_dim % gs == 0

        grouped = kv_vector.reshape(num_heads, -1, gs)
        amax = grouped.abs().amax(dim=-1, keepdim=True)
        scale = amax / 7.0
        scale = scale.clamp(min=1e-12)

        quantized = (grouped / scale).round().clamp(-8, 7).to(torch.int8)
        quantized = quantized.reshape(num_heads, head_dim)
        scale = scale.squeeze(-1)  # (num_heads, head_dim // gs)

        return quantized, scale

    def quantize(self, kv_vector):
        """Quantize using configured granularity."""
        if self.group_size is None:
            return self.quantize_per_token(kv_vector)
        return self.quantize_per_group(kv_vector)

    def dequantize_per_token(self, quantized, scale):
        return quantized.float() * scale

    def dequantize_per_group(self, quantized, scale):
        num_heads, head_dim = quantized.shape
        gs = self.group_size
        grouped = quantized.reshape(num_heads, -1, gs)
        scale_expanded = scale.unsqueeze(-1)
        return (grouped.float() * scale_expanded).reshape(num_heads, head_dim)

    def dequantize(self, quantized, scale):
        if self.group_size is None:
            return self.dequantize_per_token(quantized, scale)
        return self.dequantize_per_group(quantized, scale)

INT4 KV Bit Packing

In production, two INT4 values are packed into a single byte to achieve the full 4x memory savings:

def pack_int4(values):
    """Pack pairs of INT4 values into bytes.

    values: (N,) tensor of int8 values in [-8, 7] range
    Returns: (N//2,) tensor of uint8 packed bytes
    """
    assert len(values) % 2 == 0
    # Convert to unsigned: add 8 to map [-8,7] to [0,15]
    unsigned = (values + 8).to(torch.uint8)
    # Pack: high nibble = even indices, low nibble = odd indices
    packed = (unsigned[0::2] << 4) | unsigned[1::2]
    return packed

def unpack_int4(packed):
    """Unpack bytes to pairs of INT4 values.

    packed: (N//2,) tensor of uint8
    Returns: (N,) tensor of int8 values in [-8, 7]
    """
    high = (packed >> 4).to(torch.int8) - 8
    low = (packed & 0x0F).to(torch.int8) - 8
    return torch.stack([high, low], dim=-1).flatten()

# Verify round-trip
original = torch.randint(-8, 8, (128,), dtype=torch.int8)
packed = pack_int4(original)
unpacked = unpack_int4(packed)
assert torch.all(original == unpacked)
print(f"Original: {len(original)} bytes, Packed: {len(packed)} bytes")
# 128 bytes -> 64 bytes

Attention with Quantized KV Cache

The attention kernel must dequantize K and V before computing attention scores. In production, this dequantization is fused into the attention kernel to avoid materializing the full FP16 K/V tensors in memory.

def attention_with_quantized_kv(query, k_cache_q, k_scales, v_cache_q, v_scales,
                                  quantizer, head_dim=128):
    """Compute attention using quantized KV cache.

    query: (num_heads, head_dim) -- current token's query
    k_cache_q: (seq_len, num_heads, head_dim) -- quantized keys
    k_scales: (seq_len, num_heads, ...) -- key scale factors
    v_cache_q: (seq_len, num_heads, head_dim) -- quantized values
    v_scales: (seq_len, num_heads, ...) -- value scale factors

    Returns: (num_heads, head_dim) attention output
    """
    seq_len = k_cache_q.shape[0]
    num_heads = query.shape[0]

    # Dequantize K: (seq_len, num_heads, head_dim)
    keys_fp = torch.zeros(seq_len, num_heads, head_dim)
    values_fp = torch.zeros(seq_len, num_heads, head_dim)
    for t in range(seq_len):
        keys_fp[t] = quantizer.dequantize(k_cache_q[t], k_scales[t])
        values_fp[t] = quantizer.dequantize(v_cache_q[t], v_scales[t])

    # Attention scores: Q @ K^T / sqrt(d)
    # query: (num_heads, head_dim)
    # keys: (seq_len, num_heads, head_dim)
    scale = head_dim ** -0.5
    scores = torch.einsum('hd,shd->hs', query.float(), keys_fp.float()) * scale

    # Softmax
    attn_weights = F.softmax(scores, dim=-1)  # (num_heads, seq_len)

    # Weighted sum of values
    output = torch.einsum('hs,shd->hd', attn_weights, values_fp.float())

    return output

def benchmark_kv_precision(head_dim=128, num_heads=8, seq_len=2048):
    """Compare attention output quality across KV precisions."""
    query = torch.randn(num_heads, head_dim) * (head_dim ** -0.5)
    keys_fp16 = torch.randn(seq_len, num_heads, head_dim) * 0.3
    values_fp16 = torch.randn(seq_len, num_heads, head_dim) * 0.5

    # Reference: FP16 attention
    scale = head_dim ** -0.5
    scores_ref = torch.einsum('hd,shd->hs', query, keys_fp16) * scale
    attn_ref = F.softmax(scores_ref, dim=-1)
    output_ref = torch.einsum('hs,shd->hd', attn_ref, values_fp16)

    for precision in ['fp8', 'int8', 'int4']:
        quantizer = KVCacheQuantizer(precision, head_dim)

        # Quantize all K,V
        # Container dtype must match the quantizer (float8 for fp8, int8 otherwise)
        k_q = torch.zeros(seq_len, num_heads, head_dim, dtype=quantizer.dtype)
        v_q = torch.zeros(seq_len, num_heads, head_dim, dtype=quantizer.dtype)
        k_s = torch.zeros(seq_len, num_heads, 1)
        v_s = torch.zeros(seq_len, num_heads, 1)

        for t in range(seq_len):
            qk, sk, qv, sv = quantizer.quantize_kv_pair(
                keys_fp16[t], values_fp16[t]
            )
            k_q[t], k_s[t] = qk, sk
            v_q[t], v_s[t] = qv, sv

        output_q = attention_with_quantized_kv(
            query, k_q, k_s, v_q, v_s, quantizer, head_dim
        )

        mse = ((output_ref - output_q) ** 2).mean().item()
        cos_sim = F.cosine_similarity(
            output_ref.flatten().unsqueeze(0),
            output_q.flatten().unsqueeze(0)
        ).item()

        print(f"{precision:5s} KV: MSE={mse:.8f}, cos_sim={cos_sim:.6f}")

benchmark_kv_precision()

Quality-Memory Tradeoff Analysis

The right KV precision depends on your serving constraints and quality requirements.

Quality vs Memory Savings Tradeoff (Llama 70B)

KV Precision          Memory Savings  Perplexity Degradation  Verdict
FP8                   2x              +0.01                   near-lossless
INT8 per-token        2x              +0.02                   near-lossless
INT4 per-group (g32)  3.5x            +0.10                   acceptable
INT4 per-token        4x              +0.16                   noticeable
INT4 per-tensor       4x              +0.89                   unacceptable

Decision Framework

def recommend_kv_precision(model_size_b, target_batch, seq_len,
                            gpu_memory_gb=80, weight_precision='int4'):
    """Recommend KV cache precision based on serving constraints."""

    # Estimate model weight memory
    if weight_precision == 'int4':
        weight_gb = model_size_b * 0.5 / 1e9  # 0.5 bytes per param
    elif weight_precision == 'fp8':
        weight_gb = model_size_b * 1.0 / 1e9
    else:
        weight_gb = model_size_b * 2.0 / 1e9

    available_gb = gpu_memory_gb - weight_gb - 2  # 2 GB overhead

    # Estimate KV per request (Llama-style GQA)
    # Rough: 2 * num_layers * kv_heads * head_dim * seq_len * bytes_per_elem
    # For Llama 70B: 2*80*8*128*4096*2 bytes ~= 1.34 GB/request at FP16
    kv_per_request_fp16 = 1.34 * (seq_len / 4096) * (model_size_b / 70e9)

    results = {}
    for prec, divisor, quality_note in [
        ('FP16', 1.0, 'lossless'),
        ('FP8',  2.0, 'near-lossless (0.01 PPL)'),
        ('INT8', 2.0, 'near-lossless (0.02 PPL)'),
        ('INT4', 4.0, 'slight degradation (0.1-0.2 PPL)'),
    ]:
        kv_per_req = kv_per_request_fp16 / divisor
        max_batch = int(available_gb / kv_per_req) if kv_per_req > 0 else 0
        fits = max_batch >= target_batch

        results[prec] = {
            'kv_per_request_gb': kv_per_req,
            'max_batch': max_batch,
            'fits': fits,
            'quality': quality_note,
        }

    # Find recommendation
    for prec in ['FP8', 'INT8', 'INT4']:
        if results[prec]['fits']:
            print(f"Recommendation: {prec} KV cache")
            print(f"  Quality: {results[prec]['quality']}")
            print(f"  Max batch: {results[prec]['max_batch']} "
                  f"(target: {target_batch})")
            return prec

    print("WARNING: Even INT4 KV cannot fit the target batch size.")
    print("Consider: tensor parallelism, shorter context, or more GPUs.")
    return None

# Example: Llama 70B on single H100, targeting batch=64
recommend_kv_precision(
    model_size_b=70e9,
    target_batch=64,
    seq_len=4096,
    gpu_memory_gb=80,
    weight_precision='int4'
)

K vs V Quantization Sensitivity

An important but often overlooked detail: K and V have different sensitivity to quantization error.

K quantization errors affect attention score computation. Errors in K shift the dot products QKTQ \cdot K^T, which are then passed through softmax. Small errors in K can cause the softmax to redistribute attention weight incorrectly.

V quantization errors affect the output directly. The output is softmax(QKT)V\text{softmax}(QK^T) \cdot V. Errors in V are linearly weighted by the (correct) attention distribution. If the attention is concentrated on a few tokens, only those tokens’ V errors matter.

In practice, K is more sensitive than V. Some production systems quantize K to FP8 and V to INT4, or use per-group scaling for K and per-token scaling for V.

def measure_k_v_sensitivity(head_dim=128, num_heads=8, seq_len=1024):
    """Measure the relative sensitivity of K vs V to quantization."""
    query = torch.randn(num_heads, head_dim) * (head_dim ** -0.5)
    keys = torch.randn(seq_len, num_heads, head_dim) * 0.3
    values = torch.randn(seq_len, num_heads, head_dim) * 0.5

    # Reference output
    scale = head_dim ** -0.5
    scores = torch.einsum('hd,shd->hs', query, keys) * scale
    attn = F.softmax(scores, dim=-1)
    output_ref = torch.einsum('hs,shd->hd', attn, values)

    # Quantize only K (INT8), keep V in FP16
    k_quant = KVCacheQuantizer('int8', head_dim)
    k_q_all, k_s_all = [], []
    for t in range(seq_len):
        kq, ks = k_quant.quantize_token(keys[t])
        k_q_all.append(kq)
        k_s_all.append(ks)
    keys_deq = torch.stack([k_quant.dequantize_token(k_q_all[t], k_s_all[t])
                            for t in range(seq_len)])
    scores_kq = torch.einsum('hd,shd->hs', query, keys_deq) * scale
    attn_kq = F.softmax(scores_kq, dim=-1)
    output_kq = torch.einsum('hs,shd->hd', attn_kq, values)
    mse_konly = ((output_ref - output_kq) ** 2).mean().item()

    # Quantize only V (INT8), keep K in FP16
    v_quant = KVCacheQuantizer('int8', head_dim)
    v_q_all, v_s_all = [], []
    for t in range(seq_len):
        vq, vs = v_quant.quantize_token(values[t])
        v_q_all.append(vq)
        v_s_all.append(vs)
    values_deq = torch.stack([v_quant.dequantize_token(v_q_all[t], v_s_all[t])
                              for t in range(seq_len)])
    output_vq = torch.einsum('hs,shd->hd', attn, values_deq)
    mse_vonly = ((output_ref - output_vq) ** 2).mean().item()

    print(f"K-only INT8 MSE: {mse_konly:.8f}")
    print(f"V-only INT8 MSE: {mse_vonly:.8f}")
    print(f"K/V sensitivity ratio: {mse_konly / mse_vonly:.2f}x")

measure_k_v_sensitivity()
💡 Mixed KV Precision

Some systems use different precision for K and V. For example, FP8 for K (more sensitive) and INT4 for V (less sensitive). This gives memory savings closer to INT4 while maintaining quality closer to FP8, because the attention distribution (determined by K) is accurate, and only the value aggregation (determined by V) has reduced precision.
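A minimal sketch of mixed K/V precision with per-head scales, in the same style as the quantizers above. The `quantize_mixed` helper is illustrative, not a production kernel, and the FP16 branch is only a stand-in for builds without the FP8 dtype:

```python
import torch

def quantize_mixed(key, value):
    """FP8-range quantization for K (more sensitive), INT4 for V.

    key, value: (num_heads, head_dim) tensors for one token.
    Returns ((k_q, k_scale), (v_q, v_scale)).
    """
    # K -> FP8 E4M3 range (|x| <= 448); the cast itself quantizes
    k_scale = key.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 448.0
    k_scaled = (key / k_scale).clamp(-448, 448)
    if hasattr(torch, 'float8_e4m3fn'):
        k_q = k_scaled.to(torch.float8_e4m3fn)
    else:
        k_q = k_scaled.half()  # stand-in when FP8 is unavailable

    # V -> INT4 ([-8, 7]), kept in an INT8 container here (unpacked)
    v_scale = value.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 7.0
    v_q = (value / v_scale).round().clamp(-8, 7).to(torch.int8)

    return (k_q, k_scale), (v_q, v_scale)

torch.manual_seed(0)
key = torch.randn(8, 128) * 0.3
value = torch.randn(8, 128) * 0.5
(k_q, k_s), (v_q, v_s) = quantize_mixed(key, value)

k_err = (k_q.float() * k_s - key).abs().max()
v_err = (v_q.float() * v_s - value).abs().max()
# K (FP8) round-trip error stays well below V (INT4) error
print(f"max K error: {k_err.item():.4f}, max V error: {v_err.item():.4f}")
```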

Production Integration: vLLM and TensorRT-LLM

Both vLLM and TensorRT-LLM support KV cache quantization as a runtime configuration option.

vLLM implements FP8 KV cache quantization natively. The quantization is fused into the paged attention kernel — each page stores quantized KV values with per-token scale factors. No separate dequantization kernel is needed.

TensorRT-LLM supports FP8 and INT8 KV cache through its attention plugins. The user specifies the KV precision in the model configuration, and the engine builder generates optimized kernels.

# vLLM configuration for FP8 KV cache (conceptual)
# from vllm import LLM, SamplingParams
#
# llm = LLM(
#     model="meta-llama/Llama-2-70b",
#     quantization="awq",           # INT4 weights
#     kv_cache_dtype="fp8_e4m3",    # FP8 KV cache
#     max_model_len=8192,
#     gpu_memory_utilization=0.9,
# )
#
# # The combination of INT4 weights + FP8 KV enables:
# # - 70B model on a single H100-80GB
# # - Batch size ~64 at seq_len=4096
# # - Near-lossless quality

# TensorRT-LLM configuration (conceptual)
# trtllm-build \
#   --checkpoint_dir ./llama-70b-awq/ \
#   --kv_cache_type FP8 \
#   --max_batch_size 64 \
#   --max_input_len 4096 \
#   --max_seq_len 8192
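The headline numbers in the vLLM comment above can be sanity-checked with back-of-the-envelope arithmetic. This sketch assumes Llama-2-70B's geometry (80 layers, grouped-query attention with 8 KV heads, head_dim 128) and ignores activation memory and scale-factor overhead:

```python
# Back-of-the-envelope: does INT4 weights + FP8 KV fit on one 80 GiB H100?
layers, kv_heads, head_dim = 80, 8, 128      # assumed Llama-2-70B geometry
GiB = 1024 ** 3

weights_gib = 70e9 * 0.5 / GiB               # INT4 ~= 0.5 bytes/param
kv_bytes_per_token = 2 * layers * kv_heads * head_dim  # K and V, 1 byte each
batch, seq = 64, 4096
kv_gib = batch * seq * kv_bytes_per_token / GiB

print(f"weights: {weights_gib:.1f} GiB, KV: {kv_gib:.1f} GiB, "
      f"total: {weights_gib + kv_gib:.1f} GiB of 80 GiB")
# -> weights: 32.6 GiB, KV: 40.0 GiB, total: 72.6 GiB of 80 GiB
```

At FP16 the same KV cache would be 80 GiB by itself, which is why the weight and KV quantization have to be combined to hit this batch size on a single GPU.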

Sequence Length Scaling: How Quality Degrades

KV cache quantization error accumulates with sequence length. Each new token attends to all previous tokens, and the quantization errors in early tokens affect every subsequent attention computation.

def measure_quality_vs_seqlen(seq_lens, precision='int4', num_heads=8,
                               head_dim=128):
    """Measure how KV quantization quality degrades with sequence length."""
    results = []
    quantizer = KVCacheQuantizer(precision, head_dim)

    for seq_len in seq_lens:
        query = torch.randn(num_heads, head_dim) * (head_dim ** -0.5)
        keys = torch.randn(seq_len, num_heads, head_dim) * 0.3
        values = torch.randn(seq_len, num_heads, head_dim) * 0.5

        # Reference
        scale = head_dim ** -0.5
        scores = torch.einsum('hd,shd->hs', query, keys) * scale
        attn = F.softmax(scores, dim=-1)
        output_ref = torch.einsum('hs,shd->hd', attn, values)

        # Quantized
        k_q = torch.zeros(seq_len, num_heads, head_dim, dtype=torch.int8)
        v_q = torch.zeros(seq_len, num_heads, head_dim, dtype=torch.int8)
        k_s = torch.zeros(seq_len, num_heads, 1)
        v_s = torch.zeros(seq_len, num_heads, 1)

        for t in range(seq_len):
            qk, sk, qv, sv = quantizer.quantize_kv_pair(keys[t], values[t])
            k_q[t], k_s[t] = qk, sk
            v_q[t], v_s[t] = qv, sv

        output_q = attention_with_quantized_kv(
            query, k_q, k_s, v_q, v_s, quantizer, head_dim
        )

        mse = ((output_ref - output_q) ** 2).mean().item()
        cos_sim = F.cosine_similarity(
            output_ref.flatten().unsqueeze(0),
            output_q.flatten().unsqueeze(0)
        ).item()

        results.append({
            'seq_len': seq_len,
            'mse': mse,
            'cos_sim': cos_sim,
        })

    return results

seq_lens = [256, 512, 1024, 2048, 4096, 8192]
for prec in ['fp8', 'int8', 'int4']:
    print(f"\n{prec} KV cache quality vs sequence length:")
    results = measure_quality_vs_seqlen(seq_lens, precision=prec)
    for r in results:
        print(f"  seq={r['seq_len']:5d}: MSE={r['mse']:.8f}, "
              f"cos_sim={r['cos_sim']:.6f}")
📊 Quality Degradation vs Sequence Length (Llama 70B, INT4 KV Per-Token)

Seq Length   PPL Degradation   MMLU Impact   Acceptable?
   512           +0.05            -0.1%      Yes
  2048           +0.10            -0.3%      Yes
  4096           +0.16            -0.5%      Marginal
  8192           +0.28            -1.1%      Task-dependent
 16384           +0.52            -2.3%      Consider FP8
Note: Quality degradation grows sub-linearly with sequence length for per-token scaling. For very long contexts (16K+), FP8 is recommended over INT4.
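The table can be turned into a simple serving-time policy. This is a hypothetical helper (not part of vLLM or TensorRT-LLM) that picks a KV dtype from the maximum context length and a perplexity budget, using the INT4 measurements above as a worst-case lookup:

```python
# Hypothetical policy sketch derived from the table above: prefer INT4 until
# its measured PPL degradation at the target context length exceeds a budget.
PPL_TABLE_INT4 = {512: 0.05, 2048: 0.10, 4096: 0.16, 8192: 0.28, 16384: 0.52}

def pick_kv_dtype(max_seq_len, ppl_budget=0.3):
    # Use the smallest tabulated length >= max_seq_len as a worst-case bound.
    for n in sorted(PPL_TABLE_INT4):
        if max_seq_len <= n:
            return 'int4' if PPL_TABLE_INT4[n] <= ppl_budget else 'fp8'
    return 'fp8'  # beyond the table: fall back to FP8

print(pick_kv_dtype(4096))    # -> int4
print(pick_kv_dtype(16384))   # -> fp8
```

The right budget is workload-dependent; for tasks that are sensitive to small perplexity shifts (e.g. code generation), a tighter budget pushes the crossover to FP8 at shorter contexts.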

Summary

KV cache quantization is fundamentally different from weight quantization: it operates on dynamic data generated during inference, requires online quantization with zero latency overhead, and uses per-token scaling to handle the variable distributions across sequence positions.

FP8 E4M3 KV is the production sweet spot: 2x memory savings with less than 0.02 perplexity degradation. The non-uniform FP8 representation naturally handles the distributions found in key and value projections.

INT8 KV with per-token scaling provides equivalent memory savings to FP8 with marginally higher error. It is preferred on hardware without FP8 support (pre-Hopper GPUs).

INT4 KV with per-token or per-group scaling provides 4x memory savings (3.5x with scale overhead) but with measurable quality degradation (0.1-0.2 PPL at moderate sequence lengths). Quality degrades further at very long sequence lengths.

Per-token scaling is essential for any KV precision below FP16. Without it, a single outlier token dominates the scale factor and wastes precision for all other tokens.
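A minimal demonstration of that failure mode, using synthetic keys with one scaled-up outlier token:

```python
import torch

torch.manual_seed(0)
seq_len, head_dim = 512, 128
k = torch.randn(seq_len, head_dim) * 0.3
k[0] *= 20.0                                  # one outlier token

def int8_rtn(x, scale):
    # Symmetric round-to-nearest INT8 quantize/dequantize round trip.
    return torch.clamp(torch.round(x / scale), -127, 127) * scale

# Per-tensor: the outlier sets one global scale for every token.
s_tensor = k.abs().max() / 127.0
err_tensor = ((k - int8_rtn(k, s_tensor)) ** 2).mean().item()

# Per-token: each token gets its own scale; the outlier only hurts itself.
s_token = k.abs().amax(dim=-1, keepdim=True) / 127.0
err_token = ((k - int8_rtn(k, s_token)) ** 2).mean().item()

print(f"per-tensor MSE: {err_tensor:.8f}")
print(f"per-token  MSE: {err_token:.8f}")
```

With a 20x outlier, the per-tensor scale is roughly 20x too coarse for the other 511 tokens, so the per-tensor MSE comes out orders of magnitude worse than per-token.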

K is more sensitive than V to quantization error, because K errors affect attention score computation (pre-softmax) while V errors are linearly attenuated by the attention distribution. Mixed-precision approaches (FP8 K, INT4 V) can exploit this asymmetry.

This concludes the first six parts of the Quantization Masterclass. From number formats (Part 1) through weight quantization (Part 2), activation quantization (Part 3), FP8 training and inference (Part 4), FP4 on Blackwell (Part 5), and KV cache quantization (Part 6), you now have the technical foundation for the core quantization decisions in modern AI systems; the remaining parts of the series build on these fundamentals.