Quantization is not a binary choice. You do not quantize an entire model to INT8 or FP8 and call it done. In a real inference pipeline, different operations use different precisions, determined by the numerical sensitivity of each operation and the hardware support available. The GEMM (matrix multiply) uses FP8 or INT8 tensor cores. LayerNorm runs in FP32 because it computes a variance that is numerically unstable at lower precision. Softmax runs in FP32 because its exponential can overflow FP16. Residual additions run in FP32 to prevent error accumulation across layers.
Getting this wrong — running softmax in FP16, for example — does not produce a slightly worse model. It produces NaN outputs or catastrophic quality collapse. This post documents the precision requirements for every operation in a transformer inference pipeline, explains the numerical reasons behind each choice, and implements a per-op precision annotation system.
The Precision Hierarchy
Overview of Operations and Their Precisions
In a standard transformer decoder layer (Llama-style architecture), the operations and their recommended precisions are:
Per-Operation Precision Requirements in Transformer Inference
| Operation | Recommended Precision | Reason | Cost of Wrong Precision |
|---|---|---|---|
| Token Embedding | FP16/BF16 | Lookup table, no compute | Negligible impact |
| RoPE Encoding | FP32 | Trigonometric precision | Position encoding errors |
| QKV Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| Attention Score (QK^T) | FP8/INT8 compute, FP32 accumulate | Overflow at long contexts | NaN at seq_len greater than 2K |
| Softmax | FP32 | exp() overflow/underflow | NaN outputs |
| Attention Value (Score * V) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| Output Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| Residual Addition | FP32 | Error accumulation across layers | Quality degradation at 40+ layers |
| RMSNorm / LayerNorm | FP32 | Variance computation stability | NaN or quality collapse |
| MLP Gate Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| SiLU/GELU Activation | FP16/BF16 | Smooth function, tolerant | Negligible impact |
| MLP Down Projection (GEMM) | FP8/INT8 compute, FP16 output | Tensor core throughput | 2x slower if FP16 |
| LM Head (Final GEMM) | FP16 or FP32 | Output logit precision matters | Top-k/top-p sampling errors |
Why Each Operation Needs Its Precision
GEMMs: FP8/INT8 Tensor Cores
GEMMs (General Matrix Multiplications) account for over 90% of compute in transformer inference. They map directly to tensor core instructions:
- A100: INT8 tensor cores at 624 TOPS (vs 312 TFLOPS FP16)
- H100: FP8 tensor cores at 3958 TFLOPS (vs 1979 TFLOPS FP16), INT8 at 3958 TOPS (sparsity-enabled figures; dense rates are half)
- Blackwell B200: FP4 tensor cores at 9000+ TFLOPS
The GEMM computes Y = XW, where X is the activation matrix and W is the weight matrix. Both inputs can be quantized to FP8 or INT8, and the tensor core performs the multiply-accumulate in the quantized format. The accumulator is always FP32 (or at minimum FP16), preventing error accumulation within the GEMM.
import torch
import torch.nn.functional as F
def gemm_precision_comparison(M=2048, N=4096, K=4096):
"""Compare GEMM output across precisions."""
# Reference: FP32
A_fp32 = torch.randn(M, K, dtype=torch.float32, device='cuda')
B_fp32 = torch.randn(K, N, dtype=torch.float32, device='cuda')
Y_ref = A_fp32 @ B_fp32
# FP16 GEMM (tensor core, FP32 accumulate)
A_fp16 = A_fp32.half()
B_fp16 = B_fp32.half()
Y_fp16 = (A_fp16 @ B_fp16).float()
# BF16 GEMM
A_bf16 = A_fp32.bfloat16()
B_bf16 = B_fp32.bfloat16()
Y_bf16 = (A_bf16 @ B_bf16).float()
# Simulated INT8 GEMM
a_scale = A_fp32.abs().max() / 127.0
b_scale = B_fp32.abs().max() / 127.0
A_int8 = torch.clamp(torch.round(A_fp32 / a_scale), -128, 127)
B_int8 = torch.clamp(torch.round(B_fp32 / b_scale), -128, 127)
Y_int8 = (A_int8.float() @ B_int8.float()) * (a_scale * b_scale)
for name, Y in [("FP16", Y_fp16), ("BF16", Y_bf16), ("INT8", Y_int8)]:
rel_err = ((Y - Y_ref).norm() / Y_ref.norm()).item()
max_err = (Y - Y_ref).abs().max().item()
print(f"{name}: relative_error={rel_err:.6f}, max_abs_error={max_err:.4f}")
Tensor cores always accumulate in FP32 (or at minimum higher precision than the inputs). An FP8 multiply produces a 16-bit intermediate, and the addition tree uses FP32. Without this, a 4096-element dot product would overflow FP8 range (max value 448 for E4M3) within the first few terms.
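That accumulation requirement is easy to see in a toy simulation. The sketch below stands in for a real FP8 accumulator by clamping the running sum to the E4M3 maximum of 448 (no hardware FP8 type is used):

```python
import torch

def fp8_accumulation_demo(K=4096):
    """Dot product of a vector with itself: every partial product is positive."""
    torch.manual_seed(0)
    x = torch.randn(K)
    # FP32 accumulation: the correct result (~K for unit-variance x)
    ref = float((x * x).sum())
    # Simulated FP8-range accumulator: the running sum saturates at the
    # E4M3 maximum of 448, long before most terms have been added
    acc = 0.0
    for p in (x * x).tolist():
        acc = min(acc + p, 448.0)
    return ref, acc
```

With unit-variance inputs the true sum is around 4096, roughly 9x past the E4M3 range, so the simulated accumulator saturates almost immediately.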
LayerNorm / RMSNorm: FP32
LayerNorm computes:

y = (x − μ) / √(σ² + ε) · γ + β,  with μ = mean(x) and σ² = var(x) over the hidden dimension.

RMSNorm (used in Llama, Mistral) computes:

y = x / √(mean(x²) + ε) · γ
The variance computation is the critical operation. In FP16 (max representable value: 65504, precision: ~0.001 at value 1.0), summing 4096 squared values can:
- Overflow: if xᵢ ≈ 10, then xᵢ² = 100, and a naive FP16 sum of 4096 such terms reaches 409,600, already past the FP16 maximum of 65504. With activation outliers of xᵢ ≈ 50 (common in LLMs), xᵢ² = 2500 and the sum reaches 10,240,000.
- Lose precision: computing the variance as E[x²] − (E[x])² invites catastrophic cancellation when E[x²] ≈ (E[x])².
def layernorm_precision_failure():
"""Demonstrate LayerNorm failure in FP16 vs FP32."""
d = 4096
# Activations with outliers (realistic for LLMs)
x = torch.randn(1, d, dtype=torch.float32, device='cuda') * 0.5
# Insert outlier channels
x[0, 0:10] = 80.0 # Large outliers
# FP32 LayerNorm (correct)
ln_fp32 = torch.nn.LayerNorm(d, device='cuda', dtype=torch.float32)
y_fp32 = ln_fp32(x)
# FP16 LayerNorm (problematic)
ln_fp16 = torch.nn.LayerNorm(d, device='cuda', dtype=torch.float16)
ln_fp16.weight.data = ln_fp32.weight.data.half()
ln_fp16.bias.data = ln_fp32.bias.data.half()
x_fp16 = x.half()
# Check for overflow in variance computation
variance_fp32 = x.var(dim=-1)
variance_fp16 = x_fp16.float().var(dim=-1) # Compute in float for comparison
# The FP16 computation of x^2 can overflow
x_squared_fp16 = (x_fp16 * x_fp16)
has_inf = torch.isinf(x_squared_fp16).any().item()
print(f"FP16 x^2 has inf: {has_inf}")
print(f"FP32 variance: {variance_fp32.item():.4f}")
try:
y_fp16 = ln_fp16(x_fp16)
has_nan = torch.isnan(y_fp16).any().item()
print(f"FP16 LayerNorm output has NaN: {has_nan}")
except RuntimeError as e:
print(f"FP16 LayerNorm failed: {e}")
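The standard fix, used by Llama-style implementations, is to keep half-precision inputs and outputs but upcast to FP32 before squaring. A minimal sketch (the class name `RMSNormFP32` is ours, not from any particular library):

```python
import torch

class RMSNormFP32(torch.nn.Module):
    """RMSNorm with half-precision I/O but FP32 internal statistics."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        in_dtype = x.dtype
        x32 = x.float()  # upcast BEFORE squaring: x^2 cannot overflow in FP32
        rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + self.eps)
        return (self.weight * x32 * rms).to(in_dtype)
```

Fed the same outlier-laden FP16 input as the failure demo above (channels set to 80.0), the output stays NaN-free because the squares are formed in FP32.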
Softmax: FP32
Softmax computes:

softmax(xᵢ) = exp(xᵢ − max(x)) / Σⱼ exp(xⱼ − max(x))

Even with the max subtraction for numerical stability, FP16 softmax fails in several scenarios:
- Exponent range: FP16 max is 65504, and e¹¹ ≈ 59,874 is already near it. Attention logits can exceed 11 for long sequences or when attention is sharply focused, so any unsubtracted exponential overflows.
- Small probabilities: after softmax, most attention weights are near zero. FP16's minimum positive normal is 2⁻¹⁴ ≈ 6.1 × 10⁻⁵; attention weights below that underflow toward zero, losing information about low-attention tokens.
- Sum precision: the denominator sums potentially thousands of terms, and FP16 accumulation loses precision.
def softmax_precision_failure(seq_len=4096):
"""Demonstrate softmax precision issues in FP16."""
# Simulate attention scores
# At long contexts, some scores can be large
scores = torch.randn(1, 32, seq_len, seq_len, device='cuda') * 2.0
# Inject a few very strong attention positions
scores[0, :, :, 0] = 15.0 # Strong attention to position 0
# FP32 softmax (reference)
probs_fp32 = torch.softmax(scores, dim=-1)
# FP16 softmax
scores_fp16 = scores.half()
probs_fp16 = torch.softmax(scores_fp16, dim=-1).float()
# Compare
max_diff = (probs_fp32 - probs_fp16).abs().max().item()
mean_diff = (probs_fp32 - probs_fp16).abs().mean().item()
num_zeros_fp32 = (probs_fp32 == 0).sum().item()
num_zeros_fp16 = (probs_fp16 == 0).sum().item()
print(f"Max probability difference: {max_diff:.8f}")
print(f"Mean probability difference: {mean_diff:.8f}")
print(f"Zero entries FP32: {num_zeros_fp32}, FP16: {num_zeros_fp16}")
print(f"FP16 lost {num_zeros_fp16 - num_zeros_fp32} non-zero probabilities")
FlashAttention computes softmax in FP32 within the kernel, even when the inputs and outputs are FP16/BF16. The online softmax algorithm maintains the running max and sum in FP32 registers. If you are using FlashAttention, you do not need to worry about softmax precision. If you are writing a custom attention kernel, you must handle this yourself.
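The online-softmax bookkeeping FlashAttention uses can be sketched in a few lines. This is a simplified single-row version for illustration; the real kernel does the same rescaling blockwise in registers:

```python
import torch

def online_softmax(scores: torch.Tensor, block: int = 256) -> torch.Tensor:
    """One-pass softmax over a 1-D score vector with FP32 running max/sum."""
    m = torch.tensor(float("-inf"))  # running max, kept in FP32
    s = torch.tensor(0.0)            # running sum of exp(x - m), FP32
    for i in range(0, scores.numel(), block):
        chunk = scores[i:i + block].float()
        m_new = torch.maximum(m, chunk.max())
        # rescale the old sum to the new max, then add this block's terms
        s = s * torch.exp(m - m_new) + torch.exp(chunk - m_new).sum()
        m = m_new
    return torch.exp(scores.float() - m) / s
```

Because m and s live in FP32, neither the rescaling factor nor the sum can overflow, even when the FP16 input scores are large.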
Residual Additions: FP32
In a transformer with L layers, the hidden state is threaded through every layer by residual connections, h_{l+1} = h_l + f_l(h_l), so each residual add contributes its own rounding error. If residual additions are done in FP16, those errors accumulate:
- Each FP16 addition has relative error up to ε ≈ 2⁻¹¹ ≈ 4.9 × 10⁻⁴ (half-precision unit roundoff)
- After L layers with 2 residual adds each, the worst-case accumulated error is on the order of 2Lε
- For L = 80 (Llama-2 70B), this is 2 · 80 · 4.9 × 10⁻⁴ ≈ 0.08 — an 8% relative error on the hidden state
def residual_accumulation_error(num_layers=80, hidden_dim=8192):
"""Simulate residual error accumulation across layers."""
# FP32 reference path
h_fp32 = torch.randn(1, 1, hidden_dim, dtype=torch.float32, device='cuda')
# FP16 path
h_fp16 = h_fp32.half()
for layer in range(num_layers):
# Simulate attention output
attn_out = torch.randn_like(h_fp32) * 0.1
# Simulate MLP output
mlp_out = torch.randn_like(h_fp32) * 0.1
# FP32 residual path
h_fp32 = h_fp32 + attn_out
h_fp32 = h_fp32 + mlp_out
# FP16 residual path
h_fp16 = h_fp16 + attn_out.half()
h_fp16 = h_fp16 + mlp_out.half()
if layer % 10 == 0:
rel_err = ((h_fp32 - h_fp16.float()).norm() / h_fp32.norm()).item()
print(f"Layer {layer:3d}: relative error = {rel_err:.6f}")
final_err = ((h_fp32 - h_fp16.float()).norm() / h_fp32.norm()).item()
print(f"Final relative error after {num_layers} layers: {final_err:.6f}")
return final_err
RoPE: FP32 for Trigonometric Computation
Rotary Position Embeddings compute cos(mθᵢ) and sin(mθᵢ), where m is the position index and the frequency θᵢ varies by dimension. At long contexts (m > 100K), the product mθᵢ is large, and FP16 cannot hold the angle: at magnitude ~10⁴ the spacing between adjacent FP16 values is 8, more than a full 2π period, so the trigonometric functions of the rounded angle are essentially arbitrary.
def rope_precision_analysis(max_pos=131072, dim=128):
"""Show RoPE precision loss in FP16 at long positions."""
# Standard RoPE frequencies
base = 10000.0
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
positions = torch.tensor([1, 1000, 10000, 100000, max_pos], dtype=torch.float32)
for pos in positions:
angles_fp32 = pos * freqs
cos_fp32 = torch.cos(angles_fp32)
angles_fp16 = (pos * freqs).half()
cos_fp16 = torch.cos(angles_fp16.float()) # cos in fp32 of fp16 angle
# Angle quantization error
angle_err = (angles_fp32 - angles_fp16.float()).abs().max().item()
cos_err = (cos_fp32 - cos_fp16).abs().max().item()
print(f"Position {int(pos.item()):>7d}: "
f"max_angle_error={angle_err:.6f}, "
f"max_cos_error={cos_err:.6f}")
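One common mitigation is to compute the angles and trig functions in FP32 up front, via a precomputed per-position table, and cast only the rotated activations back to half precision. A minimal cache builder sketch:

```python
import torch

def build_rope_cache(max_pos, dim, base=10000.0):
    """Precompute FP32 cos/sin tables once; kernels index them by position."""
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.arange(max_pos, dtype=torch.float32)[:, None] * freqs[None, :]
    return torch.cos(angles), torch.sin(angles)  # each (max_pos, dim // 2)
```

The table costs max_pos × dim FP32 values, negligible next to the weights, and removes the half-precision angle quantization entirely.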
Embedding Lookup: FP16
Token embeddings are a lookup table: the embedding table stores V × d values (vocabulary size times hidden dimension). There is no computation — just a table lookup. FP16 is sufficient because:
- The lookup itself is a memory operation, not a compute operation
- The embedding vectors have moderate magnitudes (initialized near zero, trained to values in the range of approximately -1 to 1)
- The subsequent LayerNorm (in FP32) handles any precision issues
def embedding_precision_analysis(vocab_size=128256, hidden=4096):
"""Embedding tables are fine in FP16."""
embed_fp32 = torch.nn.Embedding(vocab_size, hidden, dtype=torch.float32)
embed_fp16 = torch.nn.Embedding(vocab_size, hidden, dtype=torch.float16)
embed_fp16.weight.data = embed_fp32.weight.data.half()
# Random input tokens
input_ids = torch.randint(0, vocab_size, (1, 128))
out_fp32 = embed_fp32(input_ids)
out_fp16 = embed_fp16(input_ids).float()
rel_err = ((out_fp32 - out_fp16).norm() / out_fp32.norm()).item()
print(f"Embedding FP16 relative error: {rel_err:.8f}")
# Typically < 1e-3 -- negligible
Implementation: Per-Op Precision Annotation
The Precision Policy
from dataclasses import dataclass, field
from enum import Enum
class OpPrecision(Enum):
FP32 = "float32"
FP16 = "float16"
BF16 = "bfloat16"
FP8_E4M3 = "float8_e4m3fn"
FP8_E5M2 = "float8_e5m2"
INT8 = "int8"
INT4 = "int4"
@dataclass
class LayerPrecisionPolicy:
"""Precision policy for a single transformer layer."""
# GEMM inputs (weights and activations)
gemm_weight: OpPrecision = OpPrecision.FP8_E4M3
gemm_activation: OpPrecision = OpPrecision.FP8_E4M3
gemm_accumulator: OpPrecision = OpPrecision.FP32
# Normalization
layernorm: OpPrecision = OpPrecision.FP32
# Attention
softmax: OpPrecision = OpPrecision.FP32
rope: OpPrecision = OpPrecision.FP32
# Residual path
residual: OpPrecision = OpPrecision.FP32
# Activations (SiLU, GELU)
activation_fn: OpPrecision = OpPrecision.FP16
# KV cache storage
kv_cache: OpPrecision = OpPrecision.FP8_E4M3
# Output
lm_head: OpPrecision = OpPrecision.FP16
@dataclass
class ModelPrecisionPolicy:
"""Precision policy for the full model."""
embedding: OpPrecision = OpPrecision.FP16
layers: LayerPrecisionPolicy = field(default_factory=LayerPrecisionPolicy)
lm_head: OpPrecision = OpPrecision.FP16
def summary(self):
print("=== Model Precision Policy ===")
print(f"Embedding: {self.embedding.value}")
print(f"GEMM weight: {self.layers.gemm_weight.value}")
print(f"GEMM activation: {self.layers.gemm_activation.value}")
print(f"GEMM accumulate: {self.layers.gemm_accumulator.value}")
print(f"LayerNorm: {self.layers.layernorm.value}")
print(f"Softmax: {self.layers.softmax.value}")
print(f"RoPE: {self.layers.rope.value}")
print(f"Residual add: {self.layers.residual.value}")
print(f"Activation fn: {self.layers.activation_fn.value}")
print(f"KV cache: {self.layers.kv_cache.value}")
print(f"LM Head: {self.layers.lm_head.value}")
Standard Policies for Common Configurations
def policy_fp16_baseline():
"""Standard FP16 inference -- no quantization."""
return ModelPrecisionPolicy(
embedding=OpPrecision.FP16,
layers=LayerPrecisionPolicy(
gemm_weight=OpPrecision.FP16,
gemm_activation=OpPrecision.FP16,
gemm_accumulator=OpPrecision.FP32,
layernorm=OpPrecision.FP32,
softmax=OpPrecision.FP32,
rope=OpPrecision.FP32,
residual=OpPrecision.FP32,
activation_fn=OpPrecision.FP16,
kv_cache=OpPrecision.FP16,
lm_head=OpPrecision.FP16,
),
lm_head=OpPrecision.FP16,
)
def policy_w8a8_int8():
"""W8A8 INT8 inference (SmoothQuant style)."""
return ModelPrecisionPolicy(
embedding=OpPrecision.FP16,
layers=LayerPrecisionPolicy(
gemm_weight=OpPrecision.INT8,
gemm_activation=OpPrecision.INT8,
gemm_accumulator=OpPrecision.FP32,
layernorm=OpPrecision.FP32,
softmax=OpPrecision.FP32,
rope=OpPrecision.FP32,
residual=OpPrecision.FP32,
activation_fn=OpPrecision.FP16,
kv_cache=OpPrecision.INT8,
lm_head=OpPrecision.FP16,
),
lm_head=OpPrecision.FP16,
)
def policy_fp8_h100():
"""FP8 inference on H100 (optimal for Hopper)."""
return ModelPrecisionPolicy(
embedding=OpPrecision.FP16,
layers=LayerPrecisionPolicy(
gemm_weight=OpPrecision.FP8_E4M3,
gemm_activation=OpPrecision.FP8_E4M3,
gemm_accumulator=OpPrecision.FP32,
layernorm=OpPrecision.FP32,
softmax=OpPrecision.FP32,
rope=OpPrecision.FP32,
residual=OpPrecision.FP32,
activation_fn=OpPrecision.BF16,
kv_cache=OpPrecision.FP8_E4M3,
lm_head=OpPrecision.BF16,
),
lm_head=OpPrecision.BF16,
)
def policy_w4a16_gptq():
"""W4A16: INT4 weights, FP16 activations (GPTQ/AWQ style)."""
return ModelPrecisionPolicy(
embedding=OpPrecision.FP16,
layers=LayerPrecisionPolicy(
gemm_weight=OpPrecision.INT4,
gemm_activation=OpPrecision.FP16,
gemm_accumulator=OpPrecision.FP32,
layernorm=OpPrecision.FP32,
softmax=OpPrecision.FP32,
rope=OpPrecision.FP32,
residual=OpPrecision.FP32,
activation_fn=OpPrecision.FP16,
kv_cache=OpPrecision.FP16,
lm_head=OpPrecision.FP16,
),
lm_head=OpPrecision.FP16,
)
Applying the Policy to a Model
class MixedPrecisionWrapper(torch.nn.Module):
"""Wrap a transformer layer to enforce precision policy."""
def __init__(self, layer, policy):
super().__init__()
self.layer = layer
self.policy = policy
def cast_for_gemm(self, weight, activation):
"""Cast weight and activation to GEMM precision."""
wp = self.policy.gemm_weight
ap = self.policy.gemm_activation
if wp == OpPrecision.FP8_E4M3:
w = self.quantize_fp8(weight)
elif wp == OpPrecision.INT8:
w = self.quantize_int8(weight)
elif wp == OpPrecision.INT4:
w = weight # INT4 dequantized at kernel level
else:
w = weight.to(getattr(torch, wp.value))
if ap == OpPrecision.FP8_E4M3:
a = self.quantize_fp8(activation)
elif ap == OpPrecision.INT8:
a = self.quantize_int8(activation)
else:
a = activation.to(getattr(torch, ap.value))
return w, a
    def quantize_fp8(self, tensor):
        """Fake-quantize to FP8 E4M3 range with a per-tensor scale."""
        abs_max = tensor.detach().abs().max()
        # FP8 E4M3 max value is 448
        scale = max((abs_max / 448.0).item(), 1e-12)
        # Simulate FP8: scale down, clamp to range, scale back up
        t_clamped = torch.clamp(tensor / scale, -448.0, 448.0)
        return t_clamped * scale

    def quantize_int8(self, tensor):
        """Fake-quantize to INT8 with a per-tensor scale."""
        abs_max = tensor.detach().abs().max()
        scale = max((abs_max / 127.0).item(), 1e-12)
        t_int = torch.clamp(torch.round(tensor / scale), -128, 127)
        return t_int * scale
def forward_norm(self, norm_module, x):
"""Run normalization in policy-specified precision."""
target_dtype = getattr(torch, self.policy.layernorm.value)
x_cast = x.to(target_dtype)
out = norm_module(x_cast)
return out.to(x.dtype)
def forward_residual(self, residual, new_value):
"""Run residual addition in policy-specified precision."""
target_dtype = getattr(torch, self.policy.residual.value)
return (residual.to(target_dtype) + new_value.to(target_dtype)).to(residual.dtype)
The Memory Bandwidth Perspective
Why GEMM Precision Affects Decode Speed
During autoregressive decode, each token requires reading the entire weight matrix from HBM. The decode step is memory-bandwidth-bound, not compute-bound. Reducing weight precision from FP16 to FP8 halves the data read, directly translating to 2x higher tokens/second.
def compute_decode_bandwidth_requirements(
model_params_B=70,
num_layers=80,
hidden_dim=8192,
batch_size=1,
hbm_bandwidth_GBs=3350, # H100
):
"""Calculate time per token for different weight precisions."""
total_weight_bytes = {
"FP16": model_params_B * 1e9 * 2, # 2 bytes per param
"FP8": model_params_B * 1e9 * 1, # 1 byte per param
"INT8": model_params_B * 1e9 * 1, # 1 byte per param
"INT4": model_params_B * 1e9 * 0.5, # 0.5 bytes per param
}
print(f"Model: {model_params_B}B params, HBM BW: {hbm_bandwidth_GBs} GB/s")
print(f"{'Precision':<10} {'Weight Size':<15} {'Time/Token':<15} {'Tokens/sec':<15}")
for precision, size_bytes in total_weight_bytes.items():
size_gb = size_bytes / 1e9
time_per_token_ms = (size_gb / hbm_bandwidth_GBs) * 1000
tokens_per_sec = 1000 / time_per_token_ms
print(f"{precision:<10} {size_gb:>8.1f} GB "
f"{time_per_token_ms:>8.2f} ms "
f"{tokens_per_sec:>8.1f} tok/s")
compute_decode_bandwidth_requirements()
Decode Tokens/sec by Weight Precision (Llama-2 70B, H100 SXM)

LayerNorm, softmax, activation functions, and residual adds operate on tensors of size (batch, seq_len, hidden_dim). For decode (seq_len = 1 per step), these tensors are tiny compared to the weight matrices. Running them in FP32 instead of FP16 doubles their size but has negligible impact on total memory bandwidth (<1% of total traffic).
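A back-of-the-envelope check of that claim, with shapes assumed for a Llama-2-70B-like model and a rough (assumed) count of pointwise-op tensors per layer:

```python
def activation_vs_weight_traffic(num_layers=80, hidden=8192, params_B=70):
    """Per-token bytes moved: FP32 pointwise-op tensors vs FP16 weights."""
    weight_bytes = params_B * 1e9 * 2          # every weight read once per token
    # Rough per-layer count of (1, 1, hidden) FP32 tensors read or written by
    # norms, residual adds, and activation functions (assumed estimate)
    act_tensors_per_layer = 16
    act_bytes = num_layers * act_tensors_per_layer * hidden * 4
    return act_bytes / weight_bytes
```

Even with a generous tensor count the ratio comes out well under 0.1%, so keeping these ops in FP32 is effectively free on the bandwidth side.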
Prefill vs Decode: Different Bottlenecks
During prefill (processing the prompt), the operation is compute-bound, not memory-bound. The weight matrices are read once but used for many tokens. Here, the tensor core throughput matters:
def prefill_vs_decode_analysis(
model_params_B=70,
hidden_dim=8192,
prompt_length=2048,
hbm_bandwidth_GBs=3350,
fp16_tflops=990, # H100 FP16 tensor core
fp8_tflops=1979, # H100 FP8 tensor core
int8_tops=1979, # H100 INT8 tensor core
):
"""Compare prefill (compute-bound) and decode (memory-bound)."""
weight_bytes_fp16 = model_params_B * 1e9 * 2
weight_bytes_fp8 = model_params_B * 1e9 * 1
# Prefill: compute-bound
# FLOPs = 2 * params * seq_len (for each token, 2 FLOPs per weight)
flops_prefill = 2 * model_params_B * 1e9 * prompt_length
prefill_fp16_ms = (flops_prefill / (fp16_tflops * 1e12)) * 1000
prefill_fp8_ms = (flops_prefill / (fp8_tflops * 1e12)) * 1000
prefill_int8_ms = (flops_prefill / (int8_tops * 1e12)) * 1000
# Decode: memory-bound
decode_fp16_ms = (weight_bytes_fp16 / (hbm_bandwidth_GBs * 1e9)) * 1000
decode_fp8_ms = (weight_bytes_fp8 / (hbm_bandwidth_GBs * 1e9)) * 1000
print("=== Prefill (compute-bound) ===")
print(f"FP16: {prefill_fp16_ms:.1f} ms ({prompt_length} tokens)")
print(f"FP8: {prefill_fp8_ms:.1f} ms (speedup: {prefill_fp16_ms/prefill_fp8_ms:.2f}x)")
print(f"INT8: {prefill_int8_ms:.1f} ms (speedup: {prefill_fp16_ms/prefill_int8_ms:.2f}x)")
print("\n=== Decode (memory-bound) ===")
print(f"FP16: {decode_fp16_ms:.2f} ms per token")
print(f"FP8: {decode_fp8_ms:.2f} ms per token "
f"(speedup: {decode_fp16_ms/decode_fp8_ms:.2f}x)")
prefill_vs_decode_analysis()
Prefill vs Decode Speedup from FP8 (Llama-2 70B, H100)
| Phase | FP16 Time | FP8 Time | Speedup | Bottleneck |
|---|---|---|---|---|
| Prefill (2K tokens) | 290 ms | 145 ms | 2.0x | Compute (tensor cores) |
| Decode (1 token) | 41.8 ms | 20.9 ms | 2.0x | Memory bandwidth |
| Prefill (128 tokens) | 18.1 ms | 9.1 ms | 2.0x | Compute |
| Decode (batch=32) | 41.8 ms | 20.9 ms | 2.0x | Memory bandwidth |
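A back-of-envelope corollary to the table: decode stays memory-bound up to surprisingly large batch sizes. Setting GEMM compute time equal to weight-streaming time gives the crossover batch (H100 dense FP16 figures assumed):

```python
def decode_crossover_batch(dense_tflops=990, hbm_GBs=3350, bytes_per_param=2):
    """Batch size at which decode GEMM compute time catches up with the
    time to stream the weights from HBM (per-token FLOPs ~= 2 * params)."""
    # compute time: 2 * P * batch / FLOPS ; memory time: P * bytes / BW
    # they are equal when batch = FLOPS * bytes / (2 * BW)
    return dense_tflops * 1e12 * bytes_per_param / (2 * hbm_GBs * 1e9)
```

The result is roughly 300 for FP16 weights on an H100, which is why the batch=32 decode row above shows the same latency as batch=1.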
Mixed Precision in Production Systems
vLLM’s Precision Handling
vLLM applies precision per-op based on the quantization configuration:
# Pseudocode showing vLLM's mixed precision handling
class LlamaDecoderLayer:
def forward(self, hidden_states, kv_cache):
# 1. RMSNorm: always FP32 internally
residual = hidden_states
normed = self.input_layernorm(hidden_states) # FP32 internal
# 2. QKV Projection: quantized GEMM
# Weights stored in INT4/INT8/FP8
# Dequantize + GEMM in one fused kernel
qkv = self.qkv_proj(normed) # FP8 compute, FP16 output
# 3. RoPE: FP32 trig, cast back to FP16
q, k = apply_rope(qkv, positions) # FP32 sin/cos
# 4. Attention: FlashAttention handles precision internally
# QK^T in reduced precision, softmax in FP32,
# output in FP16
attn_out = flash_attention(q, k, v, kv_cache)
# 5. Output projection: quantized GEMM
attn_out = self.o_proj(attn_out) # FP8 compute, FP16 output
# 6. Residual: FP32
hidden_states = residual.float() + attn_out.float()
hidden_states = hidden_states.to(residual.dtype)
# 7. Post-attention norm: FP32 internal
residual = hidden_states
normed = self.post_attention_layernorm(hidden_states)
# 8. MLP: quantized GEMMs
gate = self.gate_proj(normed) # FP8 compute
up = self.up_proj(normed) # FP8 compute
mlp_out = F.silu(gate) * up # FP16 activation
mlp_out = self.down_proj(mlp_out) # FP8 compute
# 9. Residual: FP32
hidden_states = residual.float() + mlp_out.float()
hidden_states = hidden_states.to(residual.dtype)
return hidden_states
BF16 vs FP16 for Non-GEMM Operations
BF16 (bfloat16) has the same exponent range as FP32 (8 exponent bits) but only 7 mantissa bits (vs 10 for FP16). For non-GEMM operations:
def bf16_vs_fp16_comparison():
"""BF16 has larger range but lower precision than FP16."""
x = torch.tensor([65504.0, 0.00006, 1.0009765625])
fp16 = x.half()
bf16 = x.bfloat16()
fp32 = x.float()
print("Value | FP32 | FP16 | BF16")
print("-" * 65)
for i in range(len(x)):
print(f"{fp32[i].item():<12} | "
f"{fp32[i].item():<13} | "
f"{fp16[i].item():<13} | "
f"{bf16[i].item():<13}")
# BF16 advantage: no overflow for large intermediate values
large_val = torch.tensor([100000.0])
print(f"\n100000.0 in FP16: {large_val.half().item()}") # inf!
print(f"100000.0 in BF16: {large_val.bfloat16().item()}") # 99840.0
Modern LLM inference (H100 and later) typically uses BF16 for non-GEMM operations instead of FP16. BF16 shares FP32's exponent range (max ≈ 3.4 × 10³⁸), so it cannot overflow for values that occur in practice, eliminating the need for explicit overflow checks. The precision loss (7 vs 10 mantissa bits) is acceptable for intermediate computations.
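The range and precision trade-off can be read directly from `torch.finfo`:

```python
import torch

# Print max value, smallest positive normal, and machine epsilon per dtype
for dt in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dt)
    print(f"{str(dt):<15} max={info.max:.3e}  "
          f"min_normal={info.tiny:.3e}  eps={info.eps:.3e}")
```

BF16's max matches FP32's (~3.4e38) while its epsilon is 8x coarser than FP16's, which is exactly the range-for-precision trade described above.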
Precision Validation Framework
Automated Precision Testing
class PrecisionValidator:
"""Validate that each op produces correct results at its assigned precision."""
def __init__(self, model, policy, tolerance=0.01):
self.model = model
self.policy = policy
self.tolerance = tolerance
self.results = {}
def validate_layernorm(self, test_input):
"""Test LayerNorm at different precisions."""
for name, module in self.model.named_modules():
if not isinstance(module, (torch.nn.LayerNorm, torch.nn.RMSNorm)):
continue
# FP32 reference
ref = module(test_input.float()).float()
# Test FP16
try:
fp16_out = module(test_input.half()).float()
fp16_err = ((ref - fp16_out).norm() / ref.norm()).item()
has_nan = torch.isnan(fp16_out).any().item()
except RuntimeError:
fp16_err = float('inf')
has_nan = True
# Test BF16
bf16_out = module(test_input.bfloat16()).float()
bf16_err = ((ref - bf16_out).norm() / ref.norm()).item()
self.results[name] = {
'fp16_error': fp16_err,
'fp16_has_nan': has_nan,
'bf16_error': bf16_err,
'recommendation': 'FP32' if has_nan or fp16_err > self.tolerance
else 'FP16/BF16'
}
return self.results
def validate_softmax(self, score_tensor):
"""Test softmax at different precisions."""
# FP32 reference
ref = torch.softmax(score_tensor.float(), dim=-1)
# FP16
fp16_out = torch.softmax(score_tensor.half(), dim=-1).float()
fp16_err = ((ref - fp16_out).norm() / ref.norm()).item()
fp16_nan = torch.isnan(fp16_out).any().item()
# BF16
bf16_out = torch.softmax(score_tensor.bfloat16(), dim=-1).float()
bf16_err = ((ref - bf16_out).norm() / ref.norm()).item()
self.results['softmax'] = {
'fp16_error': fp16_err,
'fp16_has_nan': fp16_nan,
'bf16_error': bf16_err,
}
return self.results
def report(self):
"""Print validation report."""
print("\n=== Precision Validation Report ===")
for name, result in self.results.items():
status = "PASS" if not result.get('fp16_has_nan', False) else "FAIL"
print(f"{name}: {status}")
for key, val in result.items():
print(f" {key}: {val}")
Summary
Mixed precision inference is not optional — it is required for correct and efficient LLM serving. The rules are concrete: GEMMs use FP8/INT8 for throughput (this is where 90%+ of compute lives), LayerNorm and softmax use FP32 for numerical correctness (FP16 produces NaN or overflow), residual additions use FP32 to prevent error accumulation across layers, and RoPE uses FP32 for trigonometric precision at long contexts.
The implementation is straightforward: define a precision policy per operation type, apply casts at operation boundaries, and validate with automated tests. Production systems like vLLM handle this internally — the user specifies the weight quantization format, and the framework applies the correct mixed precision policy for every other operation.