Part 8 of 30 in the Quantization Masterclass series

Quantization is not a binary choice. You do not quantize an entire model to INT8 or FP8 and call it done. In a real inference pipeline, different operations use different precisions, determined by the numerical sensitivity of each operation and the hardware support available. The GEMM (matrix multiply) uses FP8 or INT8 tensor cores. LayerNorm runs in FP32 because it computes a variance that is numerically unstable at lower precision. Softmax runs in FP32 because its exponential can overflow FP16. Residual additions run in FP32 to prevent error accumulation across layers.

Getting this wrong — running softmax in FP16, for example — does not produce a slightly worse model. It produces NaN outputs or catastrophic quality collapse. This post documents the precision requirements for every operation in a transformer inference pipeline, explains the numerical reasons behind each choice, and implements a per-op precision annotation system.

The Precision Hierarchy

Overview of Operations and Their Precisions

In a standard transformer decoder layer (Llama-style architecture), the operations and their recommended precisions are:


Per-Operation Precision Requirements in Transformer Inference

| Operation | Recommended Precision | Reason | Cost of Wrong Precision |
| --- | --- | --- | --- |
| Token Embedding | FP16/BF16 | Lookup table, no compute | Negligible impact |
| RoPE Encoding | FP32 | Trigonometric precision | Position encoding errors |
| QKV Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| Attention Score (QK^T) | FP8/INT8 compute, FP32 accumulate | Overflow at long contexts | NaN at seq_len > 2K |
| Softmax | FP32 | exp() overflow/underflow | NaN outputs |
| Attention Value (Score * V) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| Output Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| Residual Addition | FP32 | Error accumulation across layers | Quality degradation at 40+ layers |
| RMSNorm / LayerNorm | FP32 | Variance computation stability | NaN or quality collapse |
| MLP Gate Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| SiLU/GELU Activation | FP16/BF16 | Smooth function, tolerant | Negligible impact |
| MLP Down Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| LM Head (Final GEMM) | FP16 or FP32 | Output logit precision matters | Top-k/top-p sampling errors |

Note: GEMMs dominate runtime (>90%). Using FP8/INT8 for GEMMs and FP32 for everything else is the standard mixed-precision strategy.

Why Each Operation Needs Its Precision

GEMMs: FP8/INT8 Tensor Cores

GEMMs (General Matrix Multiplications) account for over 90% of compute in transformer inference. They map directly to tensor core instructions:

  • A100: INT8 tensor cores at 624 TOPS (vs 312 TFLOPS FP16)
  • H100: FP8 tensor cores at 3958 TFLOPS (vs 1979 TFLOPS FP16), INT8 at 1979 TOPS (spec-sheet peaks; the pattern that matters is the 2x step per halving of bit width)
  • Blackwell B200: FP4 tensor cores at 9000+ TFLOPS

The GEMM computes $Y = XW$ where $X$ is the activation matrix and $W$ is the weight matrix. Both inputs can be quantized to FP8 or INT8, and the tensor core performs the multiply-accumulate in the quantized format. The accumulator is always FP32 (or at minimum FP16), preventing error accumulation within the GEMM.

import torch
import torch.nn.functional as F

def gemm_precision_comparison(M=2048, N=4096, K=4096):
    """Compare GEMM output across precisions."""
    # Reference: FP32
    A_fp32 = torch.randn(M, K, dtype=torch.float32, device='cuda')
    B_fp32 = torch.randn(K, N, dtype=torch.float32, device='cuda')
    Y_ref = A_fp32 @ B_fp32

    # FP16 GEMM (tensor core, FP32 accumulate)
    A_fp16 = A_fp32.half()
    B_fp16 = B_fp32.half()
    Y_fp16 = (A_fp16 @ B_fp16).float()

    # BF16 GEMM
    A_bf16 = A_fp32.bfloat16()
    B_bf16 = B_fp32.bfloat16()
    Y_bf16 = (A_bf16 @ B_bf16).float()

    # Simulated INT8 GEMM
    a_scale = A_fp32.abs().max() / 127.0
    b_scale = B_fp32.abs().max() / 127.0
    A_int8 = torch.clamp(torch.round(A_fp32 / a_scale), -128, 127)
    B_int8 = torch.clamp(torch.round(B_fp32 / b_scale), -128, 127)
    Y_int8 = (A_int8.float() @ B_int8.float()) * (a_scale * b_scale)

    for name, Y in [("FP16", Y_fp16), ("BF16", Y_bf16), ("INT8", Y_int8)]:
        rel_err = ((Y - Y_ref).norm() / Y_ref.norm()).item()
        max_err = (Y - Y_ref).abs().max().item()
        print(f"{name}: relative_error={rel_err:.6f}, max_abs_error={max_err:.4f}")
ℹ️ FP32 Accumulation Is Non-Negotiable

Tensor cores always accumulate in FP32 (or at minimum a higher precision than the inputs). The FP8 products feed an addition tree that accumulates in FP32. Without this, a 4096-element dot product would overflow the FP8 range (max value 448 for E4M3) within the first few terms.
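A toy simulation makes the point concrete. The values are chosen near the top of the E4M3 range, and the clamp stands in for a hypothetical accumulator that could not exceed E4M3's maximum of 448:

```python
import torch

torch.manual_seed(0)
# Inputs near the top of the FP8 E4M3 range: values up to ~8, products up to ~64
x = torch.rand(4096) * 8
w = torch.rand(4096) * 8

# FP32 accumulation: what tensor cores actually do
acc_fp32 = (x * w).sum().item()

# Hypothetical FP8-range accumulator: every partial sum clamped to +/-448
E4M3_MAX = 448.0
acc_fp8 = 0.0
for p in (x * w).tolist():
    acc_fp8 = max(-E4M3_MAX, min(E4M3_MAX, acc_fp8 + p))

print(f"FP32 accumulator:          {acc_fp32:,.0f}")
print(f"Range-limited accumulator: {acc_fp8:,.0f}")  # saturates at 448
```

With all-positive products averaging ~16, the running sum blows past 448 within the first few dozen terms and stays pinned there, while the FP32 accumulator reaches roughly 65,000.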

LayerNorm / RMSNorm: FP32

LayerNorm computes:

$$\text{LayerNorm}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} \cdot \gamma + \beta$$

RMSNorm (used in Llama, Mistral) computes:

$$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \cdot \gamma$$

The variance computation $\sigma^2 = \frac{1}{d}\sum_i (x_i - \mu)^2$ is the critical operation. In FP16 (max representable value: 65504, precision: ~0.001 at value 1.0), summing 4096 squared values can:

  1. Overflow: even moderate values are dangerous. If $x_i = 4$, then $x_i^2 = 16$, and the sum over 4096 terms is 65,536, already past the FP16 maximum of 65504. With activation outliers of $x_i = 50$, $x_i^2 = 2500$ and the sum reaches 10,240,000, overflowing FP16 by two orders of magnitude.
  2. Lose precision: the subtraction $x_i - \mu$ can produce catastrophic cancellation when $x_i \approx \mu$.
def layernorm_precision_failure():
    """Demonstrate LayerNorm failure in FP16 vs FP32."""
    d = 4096
    # Activations with outliers (realistic for LLMs)
    x = torch.randn(1, d, dtype=torch.float32, device='cuda') * 0.5
    # Insert outlier channels: ten channels at 100.0 contribute
    # 10 * 100^2 = 100,000 to the sum of squares, past FP16 max (65504)
    x[0, 0:10] = 100.0

    # FP32 LayerNorm (correct)
    ln_fp32 = torch.nn.LayerNorm(d, device='cuda', dtype=torch.float32)
    y_fp32 = ln_fp32(x)

    # FP16 LayerNorm (problematic)
    ln_fp16 = torch.nn.LayerNorm(d, device='cuda', dtype=torch.float16)
    ln_fp16.weight.data = ln_fp32.weight.data.half()
    ln_fp16.bias.data = ln_fp32.bias.data.half()
    x_fp16 = x.half()

    # Reference variance in FP32
    variance_fp32 = x.var(dim=-1)

    # A naive FP16 sum of squares overflows to inf
    sum_sq_fp16 = (x_fp16 * x_fp16).sum(dim=-1)
    print(f"FP16 sum of x^2: {sum_sq_fp16.item()}")  # inf
    print(f"FP32 variance: {variance_fp32.item():.4f}")

    # Note: PyTorch's CUDA LayerNorm kernel accumulates internally in FP32,
    # so ln_fp16 itself may still produce finite output; a kernel that kept
    # the accumulator in FP16 would not.
    y_fp16 = ln_fp16(x_fp16)
    has_nan = torch.isnan(y_fp16).any().item()
    print(f"FP16 LayerNorm output has NaN: {has_nan}")

Softmax: FP32

Softmax computes:

$$\text{softmax}(x_i) = \frac{e^{x_i - x_{\max}}}{\sum_j e^{x_j - x_{\max}}}$$

Even with the $x_{\max}$ subtraction for numerical stability, FP16 softmax fails in several scenarios:

  1. Exponent range: FP16 max is 65504, and $e^{11} \approx 59874$ is already near it. Attention logits $qk^T / \sqrt{d}$ can exceed 11 for long sequences or when attention is sharply focused.

  2. Small probabilities: After softmax, most attention weights are near zero. FP16's smallest positive normal value is $\approx 6 \times 10^{-5}$, and subnormals extend only down to $\approx 6 \times 10^{-8}$ with rapidly degrading precision. Attention weights below that flush to exactly zero, losing information about low-attention tokens.

  3. Sum precision: The denominator $\sum_j e^{x_j - x_{\max}}$ sums potentially thousands of terms. FP16 accumulation loses precision.

def softmax_precision_failure(seq_len=4096):
    """Demonstrate softmax precision issues in FP16."""
    # Simulate attention scores
    # At long contexts, some scores can be large
    scores = torch.randn(1, 32, seq_len, seq_len, device='cuda') * 2.0

    # Inject a few very strong attention positions
    scores[0, :, :, 0] = 15.0  # Strong attention to position 0

    # FP32 softmax (reference)
    probs_fp32 = torch.softmax(scores, dim=-1)

    # FP16 softmax
    scores_fp16 = scores.half()
    probs_fp16 = torch.softmax(scores_fp16, dim=-1).float()

    # Compare
    max_diff = (probs_fp32 - probs_fp16).abs().max().item()
    mean_diff = (probs_fp32 - probs_fp16).abs().mean().item()
    num_zeros_fp32 = (probs_fp32 == 0).sum().item()
    num_zeros_fp16 = (probs_fp16 == 0).sum().item()

    print(f"Max probability difference: {max_diff:.8f}")
    print(f"Mean probability difference: {mean_diff:.8f}")
    print(f"Zero entries FP32: {num_zeros_fp32}, FP16: {num_zeros_fp16}")
    print(f"FP16 lost {num_zeros_fp16 - num_zeros_fp32} non-zero probabilities")
🚨 FlashAttention Handles This Internally

FlashAttention computes softmax in FP32 within the kernel, even when the inputs and outputs are FP16/BF16. The online softmax algorithm maintains the running max and sum in FP32 registers. If you are using FlashAttention, you do not need to worry about softmax precision. If you are writing a custom attention kernel, you must handle this yourself.
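The core recurrence is easy to sketch. The following is a simplified two-pass version in plain PyTorch (the real kernel fuses the normalization into the V accumulation and never materializes the full row); `online_softmax` is an illustrative name, not a FlashAttention API:

```python
import torch

def online_softmax(scores: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Blockwise softmax over the last dim, keeping running max and sum in FP32."""
    m = torch.full(scores.shape[:-1], float('-inf'))  # running max (FP32)
    s = torch.zeros(scores.shape[:-1])                # running sum of exp (FP32)
    blocks = []
    for i in range(0, scores.shape[-1], block):
        blk = scores[..., i:i + block].float()        # cast each block up to FP32
        m_new = torch.maximum(m, blk.max(dim=-1).values)
        # Rescale the old sum to the new max, then add this block's contribution
        s = s * torch.exp(m - m_new) + torch.exp(blk - m_new.unsqueeze(-1)).sum(-1)
        m = m_new
        blocks.append(blk)
    # Second pass: normalize with the final max and sum
    return torch.cat([torch.exp(b - m.unsqueeze(-1)) for b in blocks], -1) / s.unsqueeze(-1)

scores = (torch.randn(2, 8, 1024) * 3).half()  # FP16 attention scores
ref = torch.softmax(scores.float(), dim=-1)
out = online_softmax(scores)
print((ref - out).abs().max().item())          # FP32-level agreement
```

The key design point is that only $m$ and $s$ must persist across blocks, which is why FlashAttention can keep them in FP32 registers regardless of the input dtype.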

Residual Additions: FP32

In a transformer with LL layers, each residual add contributes:

$$h_{l+1} = h_l + \text{Attn}(h_l) + \text{MLP}(h_l + \text{Attn}(h_l))$$

The hidden state $h$ passes through all $L$ layers via residual connections. If residual additions are done in FP16, rounding errors accumulate:

  • Each FP16 addition has relative error up to $\epsilon \approx 5 \times 10^{-4}$ (half-precision unit roundoff)
  • After $L$ layers with 2 residual adds each, the worst-case accumulated error is $O(L \cdot \epsilon)$
  • For $L = 80$ (Llama-2 70B), that bound is $O(0.08)$, an 8% relative error on the hidden state (random errors partially cancel, so the observed error is lower, but it grows with depth)
def residual_accumulation_error(num_layers=80, hidden_dim=8192):
    """Simulate residual error accumulation across layers."""
    # FP32 reference path
    h_fp32 = torch.randn(1, 1, hidden_dim, dtype=torch.float32, device='cuda')

    # FP16 path
    h_fp16 = h_fp32.half()

    for layer in range(num_layers):
        # Simulate attention output
        attn_out = torch.randn_like(h_fp32) * 0.1
        # Simulate MLP output
        mlp_out = torch.randn_like(h_fp32) * 0.1

        # FP32 residual path
        h_fp32 = h_fp32 + attn_out
        h_fp32 = h_fp32 + mlp_out

        # FP16 residual path
        h_fp16 = h_fp16 + attn_out.half()
        h_fp16 = h_fp16 + mlp_out.half()

        if layer % 10 == 0:
            rel_err = ((h_fp32 - h_fp16.float()).norm() / h_fp32.norm()).item()
            print(f"Layer {layer:3d}: relative error = {rel_err:.6f}")

    final_err = ((h_fp32 - h_fp16.float()).norm() / h_fp32.norm()).item()
    print(f"Final relative error after {num_layers} layers: {final_err:.6f}")
    return final_err

RoPE: FP32 for Trigonometric Computation

Rotary Position Embeddings compute $\cos(m\theta)$ and $\sin(m\theta)$ where $m$ is the position index and $\theta$ varies by dimension. At long contexts ($m > 100{,}000$), the product $m\theta$ is large, and FP16 trigonometric computation loses precision.

def rope_precision_analysis(max_pos=131072, dim=128):
    """Show RoPE precision loss in FP16 at long positions."""
    # Standard RoPE frequencies
    base = 10000.0
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    positions = torch.tensor([1, 1000, 10000, 100000, max_pos], dtype=torch.float32)

    for pos in positions:
        angles_fp32 = pos * freqs
        cos_fp32 = torch.cos(angles_fp32)

        angles_fp16 = (pos * freqs).half()
        cos_fp16 = torch.cos(angles_fp16.float())  # cos in fp32 of fp16 angle

        # Angle quantization error
        angle_err = (angles_fp32 - angles_fp16.float()).abs().max().item()
        cos_err = (cos_fp32 - cos_fp16).abs().max().item()

        print(f"Position {int(pos.item()):>7d}: "
              f"max_angle_error={angle_err:.6f}, "
              f"max_cos_error={cos_err:.6f}")

Embedding Lookup: FP16

Token embeddings are a lookup table. The embedding table stores a $V \times d$ matrix (vocabulary size by hidden dimension). There is no computation — just a table lookup. FP16 is sufficient because:

  1. The lookup itself is a memory operation, not a compute operation
  2. The embedding vectors have moderate magnitudes (initialized near zero, trained to values in the range of approximately -1 to 1)
  3. The subsequent LayerNorm (in FP32) handles any precision issues
def embedding_precision_analysis(vocab_size=128256, hidden=4096):
    """Embedding tables are fine in FP16."""
    embed_fp32 = torch.nn.Embedding(vocab_size, hidden, dtype=torch.float32)
    embed_fp16 = torch.nn.Embedding(vocab_size, hidden, dtype=torch.float16)
    embed_fp16.weight.data = embed_fp32.weight.data.half()

    # Random input tokens
    input_ids = torch.randint(0, vocab_size, (1, 128))

    out_fp32 = embed_fp32(input_ids)
    out_fp16 = embed_fp16(input_ids).float()

    rel_err = ((out_fp32 - out_fp16).norm() / out_fp32.norm()).item()
    print(f"Embedding FP16 relative error: {rel_err:.8f}")
    # Typically < 1e-3 -- negligible

Implementation: Per-Op Precision Annotation

The Precision Policy

from dataclasses import dataclass, field
from enum import Enum

class OpPrecision(Enum):
    FP32 = "float32"
    FP16 = "float16"
    BF16 = "bfloat16"
    FP8_E4M3 = "float8_e4m3fn"
    FP8_E5M2 = "float8_e5m2"
    INT8 = "int8"
    INT4 = "int4"

@dataclass
class LayerPrecisionPolicy:
    """Precision policy for a single transformer layer."""
    # GEMM inputs (weights and activations)
    gemm_weight: OpPrecision = OpPrecision.FP8_E4M3
    gemm_activation: OpPrecision = OpPrecision.FP8_E4M3
    gemm_accumulator: OpPrecision = OpPrecision.FP32

    # Normalization
    layernorm: OpPrecision = OpPrecision.FP32

    # Attention
    softmax: OpPrecision = OpPrecision.FP32
    rope: OpPrecision = OpPrecision.FP32

    # Residual path
    residual: OpPrecision = OpPrecision.FP32

    # Activations (SiLU, GELU)
    activation_fn: OpPrecision = OpPrecision.FP16

    # KV cache storage
    kv_cache: OpPrecision = OpPrecision.FP8_E4M3

    # Output
    lm_head: OpPrecision = OpPrecision.FP16

@dataclass
class ModelPrecisionPolicy:
    """Precision policy for the full model."""
    embedding: OpPrecision = OpPrecision.FP16
    layers: LayerPrecisionPolicy = field(default_factory=LayerPrecisionPolicy)
    lm_head: OpPrecision = OpPrecision.FP16

    def summary(self):
        print("=== Model Precision Policy ===")
        print(f"Embedding:       {self.embedding.value}")
        print(f"GEMM weight:     {self.layers.gemm_weight.value}")
        print(f"GEMM activation: {self.layers.gemm_activation.value}")
        print(f"GEMM accumulate: {self.layers.gemm_accumulator.value}")
        print(f"LayerNorm:       {self.layers.layernorm.value}")
        print(f"Softmax:         {self.layers.softmax.value}")
        print(f"RoPE:            {self.layers.rope.value}")
        print(f"Residual add:    {self.layers.residual.value}")
        print(f"Activation fn:   {self.layers.activation_fn.value}")
        print(f"KV cache:        {self.layers.kv_cache.value}")
        print(f"LM Head:         {self.layers.lm_head.value}")

Standard Policies for Common Configurations

def policy_fp16_baseline():
    """Standard FP16 inference -- no quantization."""
    return ModelPrecisionPolicy(
        embedding=OpPrecision.FP16,
        layers=LayerPrecisionPolicy(
            gemm_weight=OpPrecision.FP16,
            gemm_activation=OpPrecision.FP16,
            gemm_accumulator=OpPrecision.FP32,
            layernorm=OpPrecision.FP32,
            softmax=OpPrecision.FP32,
            rope=OpPrecision.FP32,
            residual=OpPrecision.FP32,
            activation_fn=OpPrecision.FP16,
            kv_cache=OpPrecision.FP16,
            lm_head=OpPrecision.FP16,
        ),
        lm_head=OpPrecision.FP16,
    )

def policy_w8a8_int8():
    """W8A8 INT8 inference (SmoothQuant style)."""
    return ModelPrecisionPolicy(
        embedding=OpPrecision.FP16,
        layers=LayerPrecisionPolicy(
            gemm_weight=OpPrecision.INT8,
            gemm_activation=OpPrecision.INT8,
            gemm_accumulator=OpPrecision.FP32,
            layernorm=OpPrecision.FP32,
            softmax=OpPrecision.FP32,
            rope=OpPrecision.FP32,
            residual=OpPrecision.FP32,
            activation_fn=OpPrecision.FP16,
            kv_cache=OpPrecision.INT8,
            lm_head=OpPrecision.FP16,
        ),
        lm_head=OpPrecision.FP16,
    )

def policy_fp8_h100():
    """FP8 inference on H100 (optimal for Hopper)."""
    return ModelPrecisionPolicy(
        embedding=OpPrecision.FP16,
        layers=LayerPrecisionPolicy(
            gemm_weight=OpPrecision.FP8_E4M3,
            gemm_activation=OpPrecision.FP8_E4M3,
            gemm_accumulator=OpPrecision.FP32,
            layernorm=OpPrecision.FP32,
            softmax=OpPrecision.FP32,
            rope=OpPrecision.FP32,
            residual=OpPrecision.FP32,
            activation_fn=OpPrecision.BF16,
            kv_cache=OpPrecision.FP8_E4M3,
            lm_head=OpPrecision.BF16,
        ),
        lm_head=OpPrecision.BF16,
    )

def policy_w4a16_gptq():
    """W4A16: INT4 weights, FP16 activations (GPTQ/AWQ style)."""
    return ModelPrecisionPolicy(
        embedding=OpPrecision.FP16,
        layers=LayerPrecisionPolicy(
            gemm_weight=OpPrecision.INT4,
            gemm_activation=OpPrecision.FP16,
            gemm_accumulator=OpPrecision.FP32,
            layernorm=OpPrecision.FP32,
            softmax=OpPrecision.FP32,
            rope=OpPrecision.FP32,
            residual=OpPrecision.FP32,
            activation_fn=OpPrecision.FP16,
            kv_cache=OpPrecision.FP16,
            lm_head=OpPrecision.FP16,
        ),
        lm_head=OpPrecision.FP16,
    )

Applying the Policy to a Model

class MixedPrecisionWrapper(torch.nn.Module):
    """Wrap a transformer layer to enforce precision policy."""

    def __init__(self, layer, policy):
        super().__init__()
        self.layer = layer
        self.policy = policy

    def cast_for_gemm(self, weight, activation):
        """Cast weight and activation to GEMM precision."""
        wp = self.policy.gemm_weight
        ap = self.policy.gemm_activation

        # quantize_* return (fake-quantized tensor, scale); the scale is
        # already folded back into the tensor, so it is discarded here
        if wp == OpPrecision.FP8_E4M3:
            w, _ = self.quantize_fp8(weight)
        elif wp == OpPrecision.INT8:
            w, _ = self.quantize_int8(weight)
        elif wp == OpPrecision.INT4:
            w = weight  # INT4 dequantized at kernel level
        else:
            w = weight.to(getattr(torch, wp.value))

        if ap == OpPrecision.FP8_E4M3:
            a, _ = self.quantize_fp8(activation)
        elif ap == OpPrecision.INT8:
            a, _ = self.quantize_int8(activation)
        else:
            a = activation.to(getattr(torch, ap.value))

        return w, a

    def quantize_fp8(self, tensor):
        """Quantize to FP8 E4M3 with per-tensor scale."""
        abs_max = tensor.detach().abs().max()
        # FP8 E4M3 max value is 448
        scale = abs_max / 448.0
        scale = max(scale.item(), 1e-12)
        # Simulate FP8: scale down, clamp, scale back
        t_scaled = tensor / scale
        t_clamped = torch.clamp(t_scaled, -448.0, 448.0)
        return t_clamped * scale, scale

    def quantize_int8(self, tensor):
        """Quantize to INT8 with per-tensor scale."""
        abs_max = tensor.detach().abs().max()
        scale = abs_max / 127.0
        scale = max(scale.item(), 1e-12)
        t_int = torch.clamp(torch.round(tensor / scale), -128, 127)
        return t_int * scale, scale

    def forward_norm(self, norm_module, x):
        """Run normalization in policy-specified precision."""
        # Assumes norm_module's parameters are stored in the same dtype;
        # norm weights are tiny, so keeping them in FP32 costs nothing
        target_dtype = getattr(torch, self.policy.layernorm.value)
        x_cast = x.to(target_dtype)
        out = norm_module(x_cast)
        return out.to(x.dtype)

    def forward_residual(self, residual, new_value):
        """Run residual addition in policy-specified precision."""
        target_dtype = getattr(torch, self.policy.residual.value)
        return (residual.to(target_dtype) + new_value.to(target_dtype)).to(residual.dtype)
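As a standalone sanity check of the per-tensor absmax fake quantization used in the wrapper above, the INT8 round-trip error on a Gaussian weight matrix is easy to measure (the ~1% figure is typical for this distribution, not a guarantee):

```python
import torch

def fake_quant_int8(t: torch.Tensor) -> torch.Tensor:
    """Per-tensor symmetric INT8 quantize + dequantize."""
    scale = max(t.abs().max().item() / 127.0, 1e-12)
    q = torch.clamp(torch.round(t / scale), -128, 127)
    return q * scale

torch.manual_seed(0)
w = torch.randn(4096, 4096)
w_dq = fake_quant_int8(w)
rel_err = ((w - w_dq).norm() / w.norm()).item()
print(f"INT8 round-trip relative error: {rel_err:.4f}")  # roughly 1e-2
```

The error is dominated by the rounding step of size `absmax/127`, which is why outlier values (which inflate absmax) are so damaging to per-tensor scaling.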

The Memory Bandwidth Perspective

Why GEMM Precision Affects Decode Speed

During autoregressive decode, generating each token requires reading the entire weight matrix from HBM. The decode step is memory-bandwidth-bound, not compute-bound. Reducing weight precision from FP16 to FP8 halves the data read, translating almost directly into 2x higher tokens/second.

def compute_decode_bandwidth_requirements(
    model_params_B=70,
    num_layers=80,
    hidden_dim=8192,
    batch_size=1,
    hbm_bandwidth_GBs=3350,  # H100
):
    """Calculate time per token for different weight precisions."""
    total_weight_bytes = {
        "FP16": model_params_B * 1e9 * 2,      # 2 bytes per param
        "FP8":  model_params_B * 1e9 * 1,       # 1 byte per param
        "INT8": model_params_B * 1e9 * 1,       # 1 byte per param
        "INT4": model_params_B * 1e9 * 0.5,     # 0.5 bytes per param
    }

    print(f"Model: {model_params_B}B params, HBM BW: {hbm_bandwidth_GBs} GB/s")
    print(f"{'Precision':<10} {'Weight Size':<15} {'Time/Token':<15} {'Tokens/sec':<15}")

    for precision, size_bytes in total_weight_bytes.items():
        size_gb = size_bytes / 1e9
        time_per_token_ms = (size_gb / hbm_bandwidth_GBs) * 1000
        tokens_per_sec = 1000 / time_per_token_ms

        print(f"{precision:<10} {size_gb:>8.1f} GB     "
              f"{time_per_token_ms:>8.2f} ms     "
              f"{tokens_per_sec:>8.1f} tok/s")

compute_decode_bandwidth_requirements()

Decode tokens/sec by weight precision (Llama-2 70B, H100 SXM):

  • FP16 (baseline): 24 tokens/sec
  • FP8/INT8 (2.0x): 48 tokens/sec
  • INT4 (4.0x): 96 tokens/sec
Non-GEMM Ops Are Negligible for Bandwidth

LayerNorm, softmax, activation functions, and residual adds operate on tensors of size $[B, S, d]$ (batch, sequence length, hidden dim). For decode ($S = 1$), these tensors are tiny compared to the weight matrices. Running them in FP32 instead of FP16 doubles their size but has negligible impact on total memory bandwidth (<1% of total traffic).
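A back-of-envelope calculation confirms this. Assume roughly 10 non-GEMM activation tensors of shape [1, 1, hidden] are touched per layer (a rough count covering norms, residuals, and activation functions, and ignoring KV cache reads, which scale separately with context length):

```python
def decode_traffic_ratio(hidden=8192, layers=80, params_B=70,
                         act_bytes=4, weight_bytes=2, tensors_per_layer=10):
    """FP32 activation traffic vs FP16 weight traffic for one decode step."""
    act_traffic = layers * tensors_per_layer * hidden * act_bytes
    weight_traffic = params_B * 1e9 * weight_bytes  # every weight read once
    return act_traffic / weight_traffic

ratio = decode_traffic_ratio()
print(f"activation/weight traffic ratio: {ratio:.2e}")  # ~2e-4, far below 1%
```

Even if the per-layer tensor count is off by 10x, activation traffic stays below 1% of weight traffic, which is why FP32 for non-GEMM ops is essentially free during decode.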

Prefill vs Decode: Different Bottlenecks

During prefill (processing the prompt), the operation is compute-bound, not memory-bound. The weight matrices are read once but used for many tokens. Here, the tensor core throughput matters:

def prefill_vs_decode_analysis(
    model_params_B=70,
    hidden_dim=8192,
    prompt_length=2048,
    hbm_bandwidth_GBs=3350,
    fp16_tflops=990,     # H100 FP16 tensor core
    fp8_tflops=1979,     # H100 FP8 tensor core
    int8_tops=1979,      # H100 INT8 tensor core
):
    """Compare prefill (compute-bound) and decode (memory-bound)."""
    weight_bytes_fp16 = model_params_B * 1e9 * 2
    weight_bytes_fp8 = model_params_B * 1e9 * 1

    # Prefill: compute-bound
    # FLOPs = 2 * params * seq_len (for each token, 2 FLOPs per weight)
    flops_prefill = 2 * model_params_B * 1e9 * prompt_length

    prefill_fp16_ms = (flops_prefill / (fp16_tflops * 1e12)) * 1000
    prefill_fp8_ms = (flops_prefill / (fp8_tflops * 1e12)) * 1000
    prefill_int8_ms = (flops_prefill / (int8_tops * 1e12)) * 1000

    # Decode: memory-bound
    decode_fp16_ms = (weight_bytes_fp16 / (hbm_bandwidth_GBs * 1e9)) * 1000
    decode_fp8_ms = (weight_bytes_fp8 / (hbm_bandwidth_GBs * 1e9)) * 1000

    print("=== Prefill (compute-bound) ===")
    print(f"FP16: {prefill_fp16_ms:.1f} ms ({prompt_length} tokens)")
    print(f"FP8:  {prefill_fp8_ms:.1f} ms (speedup: {prefill_fp16_ms/prefill_fp8_ms:.2f}x)")
    print(f"INT8: {prefill_int8_ms:.1f} ms (speedup: {prefill_fp16_ms/prefill_int8_ms:.2f}x)")

    print("\n=== Decode (memory-bound) ===")
    print(f"FP16: {decode_fp16_ms:.2f} ms per token")
    print(f"FP8:  {decode_fp8_ms:.2f} ms per token "
          f"(speedup: {decode_fp16_ms/decode_fp8_ms:.2f}x)")

prefill_vs_decode_analysis()

Prefill vs Decode Speedup from FP8 (Llama-2 70B, H100)

| Phase | FP16 Time | FP8 Time | Speedup | Bottleneck |
| --- | --- | --- | --- | --- |
| Prefill (2K tokens) | 290 ms | 145 ms | 2.0x | Compute (tensor cores) |
| Decode (1 token) | 41.8 ms | 20.9 ms | 2.0x | Memory bandwidth |
| Prefill (128 tokens) | 18.1 ms | 9.1 ms | 2.0x | Compute |
| Decode (batch=32) | 41.8 ms | 20.9 ms | 2.0x | Memory bandwidth |

Note: FP8 provides 2x speedup for both phases but for different reasons. Prefill: 2x tensor core throughput. Decode: 2x less data to read from HBM.
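The two regimes fall out of a simple roofline argument. For a weight matrix read once and reused across $n$ tokens, the arithmetic intensity is roughly $n$ FLOPs per weight byte at FP16 (2 FLOPs per weight, 2 bytes per weight), while the H100 ridge point sits near 990 TFLOPS / 3.35 TB/s, about 295 FLOPs/byte:

```python
def arithmetic_intensity(tokens, params_B=70, bytes_per_weight=2):
    """FLOPs per byte of weight traffic when weights are reused across `tokens`."""
    flops = 2 * params_B * 1e9 * tokens          # 2 FLOPs per weight per token
    traffic = params_B * 1e9 * bytes_per_weight  # weights read once
    return flops / traffic

RIDGE = 990e12 / 3.35e12  # H100 FP16 dense TFLOPS / HBM bandwidth, ~295 FLOPs/byte

for n in [1, 32, 256, 2048]:
    ai = arithmetic_intensity(n)
    regime = "compute-bound" if ai > RIDGE else "memory-bound"
    print(f"{n:5d} tokens: {ai:7.1f} FLOPs/byte -> {regime}")
```

A notable consequence of this estimate: even a decode batch of 256 remains memory-bound on H100; only large prefills (or very large batches) cross the ridge into the compute-bound regime.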

Mixed Precision in Production Systems

vLLM’s Precision Handling

vLLM applies precision per-op based on the quantization configuration:

# Pseudocode showing vLLM's mixed precision handling
class LlamaDecoderLayer:
    def forward(self, hidden_states, kv_cache):
        # 1. RMSNorm: always FP32 internally
        residual = hidden_states
        normed = self.input_layernorm(hidden_states)  # FP32 internal

        # 2. QKV Projection: quantized GEMM
        # Weights stored in INT4/INT8/FP8
        # Dequantize + GEMM in one fused kernel
        qkv = self.qkv_proj(normed)  # FP8 compute, FP16 output

        # 3. RoPE: FP32 trig, cast back to FP16
        q, k = apply_rope(qkv, positions)  # FP32 sin/cos

        # 4. Attention: FlashAttention handles precision internally
        # QK^T in reduced precision, softmax in FP32,
        # output in FP16
        attn_out = flash_attention(q, k, v, kv_cache)

        # 5. Output projection: quantized GEMM
        attn_out = self.o_proj(attn_out)  # FP8 compute, FP16 output

        # 6. Residual: FP32
        hidden_states = residual.float() + attn_out.float()
        hidden_states = hidden_states.to(residual.dtype)

        # 7. Post-attention norm: FP32 internal
        residual = hidden_states
        normed = self.post_attention_layernorm(hidden_states)

        # 8. MLP: quantized GEMMs
        gate = self.gate_proj(normed)   # FP8 compute
        up = self.up_proj(normed)       # FP8 compute
        mlp_out = F.silu(gate) * up     # FP16 activation
        mlp_out = self.down_proj(mlp_out)  # FP8 compute

        # 9. Residual: FP32
        hidden_states = residual.float() + mlp_out.float()
        hidden_states = hidden_states.to(residual.dtype)

        return hidden_states

BF16 vs FP16 for Non-GEMM Operations

BF16 (bfloat16) has the same exponent range as FP32 (8 exponent bits) but only 7 mantissa bits (vs 10 for FP16). For non-GEMM operations:

def bf16_vs_fp16_comparison():
    """BF16 has larger range but lower precision than FP16."""
    x = torch.tensor([65504.0, 0.00006, 1.0009765625])

    fp16 = x.half()
    bf16 = x.bfloat16()
    fp32 = x.float()

    print("Value (FP32)   | FP16          | BF16")
    print("-" * 50)
    for i in range(len(x)):
        print(f"{fp32[i].item():<14} | "
              f"{fp16[i].item():<13} | "
              f"{bf16[i].item():<13}")

    # BF16 advantage: no overflow for large intermediate values
    large_val = torch.tensor([100000.0])
    print(f"\n100000.0 in FP16: {large_val.half().item()}")    # inf!
    print(f"100000.0 in BF16: {large_val.bfloat16().item()}")  # 99840.0
ℹ️ BF16 Is Preferred Over FP16 for Non-GEMM Ops

Modern LLM inference (H100 and later) uses BF16 for non-GEMM operations instead of FP16. BF16 cannot overflow for values that occur in practice (range up to $3.4 \times 10^{38}$), eliminating the need for explicit overflow checks. The precision loss (7 vs 10 mantissa bits) is acceptable for intermediate computations.

Precision Validation Framework

Automated Precision Testing

class PrecisionValidator:
    """Validate that each op produces correct results at its assigned precision."""

    def __init__(self, model, policy, tolerance=0.01):
        self.model = model
        self.policy = policy
        self.tolerance = tolerance
        self.results = {}

    def _run_norm(self, module, x):
        """Apply a norm module to x with its weights cast to x.dtype."""
        w = module.weight.to(x.dtype) if module.weight is not None else None
        if isinstance(module, torch.nn.LayerNorm):
            b = module.bias.to(x.dtype) if module.bias is not None else None
            return F.layer_norm(x, module.normalized_shape, w, b, module.eps)
        return F.rms_norm(x, module.normalized_shape, w, module.eps)  # torch >= 2.4

    def validate_layernorm(self, test_input):
        """Test each norm layer at FP16/BF16 against an FP32 reference."""
        for name, module in self.model.named_modules():
            # torch.nn.RMSNorm requires PyTorch >= 2.4
            if not isinstance(module, (torch.nn.LayerNorm, torch.nn.RMSNorm)):
                continue

            # FP32 reference
            ref = self._run_norm(module, test_input.float()).float()

            result = {}
            for dtype, key in [(torch.float16, 'fp16'), (torch.bfloat16, 'bf16')]:
                try:
                    out = self._run_norm(module, test_input.to(dtype)).float()
                    result[f'{key}_error'] = ((ref - out).norm() / ref.norm()).item()
                    result[f'{key}_has_nan'] = torch.isnan(out).any().item()
                except RuntimeError:
                    result[f'{key}_error'] = float('inf')
                    result[f'{key}_has_nan'] = True

            result['recommendation'] = (
                'FP32' if result['fp16_has_nan'] or result['fp16_error'] > self.tolerance
                else 'FP16/BF16')
            self.results[name] = result

        return self.results

    def validate_softmax(self, score_tensor):
        """Test softmax at different precisions."""
        # FP32 reference
        ref = torch.softmax(score_tensor.float(), dim=-1)

        # FP16
        fp16_out = torch.softmax(score_tensor.half(), dim=-1).float()
        fp16_err = ((ref - fp16_out).norm() / ref.norm()).item()
        fp16_nan = torch.isnan(fp16_out).any().item()

        # BF16
        bf16_out = torch.softmax(score_tensor.bfloat16(), dim=-1).float()
        bf16_err = ((ref - bf16_out).norm() / ref.norm()).item()

        self.results['softmax'] = {
            'fp16_error': fp16_err,
            'fp16_has_nan': fp16_nan,
            'bf16_error': bf16_err,
        }
        return self.results

    def report(self):
        """Print validation report."""
        print("\n=== Precision Validation Report ===")
        for name, result in self.results.items():
            status = "PASS" if not result.get('fp16_has_nan', False) else "FAIL"
            print(f"{name}: {status}")
            for key, val in result.items():
                print(f"  {key}: {val}")

Summary

Mixed precision inference is not optional — it is required for correct and efficient LLM serving. The rules are concrete: GEMMs use FP8/INT8 for throughput (this is where 90%+ of compute lives), LayerNorm and softmax use FP32 for numerical correctness (FP16 produces NaN or overflow), residual additions use FP32 to prevent error accumulation across layers, and RoPE uses FP32 for trigonometric precision at long contexts.

The implementation is straightforward: define a precision policy per operation type, apply casts at operation boundaries, and validate with automated tests. Production systems like vLLM handle this internally — the user specifies the weight quantization format, and the framework applies the correct mixed precision policy for every other operation.