FP8 is the most significant precision reduction since the introduction of FP16 tensor cores in Volta. On NVIDIA H100, FP8 tensor cores deliver 1,979 TFLOPS — exactly 2x the FP16 rate of 989 TFLOPS — while consuming half the memory bandwidth per element. For large, compute-bound GEMMs (the prefill phase of LLM inference), this translates directly to a near-2x throughput improvement. For memory-bound GEMMs (decode phase), the bandwidth reduction from 2 bytes to 1 byte per element provides a 1.5-1.8x speedup depending on the shape.
The catch: FP8 has severely limited range and precision. The E4M3 format covers only $[-448, 448]$, and round-to-nearest introduces up to roughly 6% relative error (half the 12.5% worst-case spacing between adjacent values). Getting FP8 inference to work without quality loss requires understanding where the precision matters, how to compute scaling factors, which operations to quantize (GEMMs only), and which to leave in higher precision (everything else). This post covers all of it.
1. The FP8 E4M3 Format
IEEE 754 does not define an 8-bit floating point format. NVIDIA, ARM, and Intel jointly specified two FP8 variants in the OFP8 (Open FP8) standard:
E4M3: 4 exponent bits, 3 mantissa bits
Bit layout: [S][EEEE][MMM]
1 + 4 + 3 = 8 bits
S = sign bit (0 = positive, 1 = negative)
E = exponent bits (4 bits, bias = 7)
M = mantissa bits (3 bits, implicit leading 1 for normals)
Value encoding:
For normal numbers ($1 \le E \le 15$, excluding the NaN pattern $E=15, M=7$): $(-1)^S \times 2^{E-7} \times (1 + M/8)$
For subnormal numbers ($E = 0$, $M \ne 0$): $(-1)^S \times 2^{-6} \times (M/8)$
Special values:
- `S 1111 111`: NaN (one NaN encoding per sign, unlike IEEE 754's thousands)
- `S 1111 000` through `S 1111 110`: valid numbers 256 through 448 (NOT infinity — E4M3 sacrifices infinity for range)
- `S 0000 000`: zero
This is the key departure from IEEE 754: E4M3 has no infinity representation. The bit pattern that would be infinity in IEEE format ($E$ all ones, $M = 0$) instead represents the ordinary value 256. The maximum representable value is:
$2^{15-7} \times (1 + 6/8) = 256 \times 1.75 = 448$
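The encoding rules above can be sanity-checked with a short pure-Python decoder (a sketch; `decode_e4m3` is a hypothetical helper, not part of any library):

```python
import math

def decode_e4m3(byte):
    """Decode one E4M3 byte: [S][EEEE][MMM], bias 7, no infinity."""
    s = (byte >> 7) & 0x1
    e = (byte >> 3) & 0xF
    m = byte & 0x7
    if e == 0xF and m == 0x7:            # the single NaN pattern (per sign)
        return math.nan
    if e == 0:                           # subnormal: no implicit leading 1
        val = 2.0 ** -6 * (m / 8)
    else:                                # normal: implicit leading 1
        val = 2.0 ** (e - 7) * (1 + m / 8)
    return -val if s else val

print(decode_e4m3(0b0_1111_110))   # 448.0 — the maximum finite value
print(decode_e4m3(0b0_1111_111))   # nan — would be in the infinity/NaN block in IEEE 754
print(decode_e4m3(0b0_0000_001))   # 0.001953125 — the smallest subnormal
```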
FP8 E4M3 Format Properties
| Property | Value | Comparison to FP16 |
|---|---|---|
| Total bits | 8 | 16 |
| Sign bits | 1 | 1 |
| Exponent bits | 4 | 5 |
| Mantissa bits | 3 | 10 |
| Exponent bias | 7 | 15 |
| Max normal value | 448 | 65504 |
| Min normal value | 2^-6 = 0.015625 | 2^-14 = 6.1e-5 |
| Min subnormal | 2^-9 = 0.001953 | 2^-24 = 5.96e-8 |
| Relative spacing (worst case) | 3 bits = 12.5% | 10 bits = 0.098% |
| Has infinity? | No | Yes |
| NaN encodings | 2 (one per sign) | 2046 |
| Distinct finite values | 253 | ~63,500 |
E5M2: 5 exponent bits, 2 mantissa bits
Bit layout: [S][EEEEE][MM]
1 + 5 + 2 = 8 bits
E5M2 follows IEEE 754 conventions: it has infinities and NaNs. The tradeoff is even less precision (2 mantissa bits = 25% worst-case spacing) but much wider range (largest finite value 57,344; smallest subnormal $2^{-16}$).
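A quick cross-check of the E5M2 numbers (bias 15, IEEE-style specials — a sketch):

```python
# E5M2: [S][EEEEE][MM], bias 15. E = 31 is reserved for inf/NaN (IEEE-style),
# so the largest finite value uses E = 30, M = 3.
e5m2_max = 2.0 ** (30 - 15) * (1 + 3 / 4)     # 57344.0
e5m2_min_subnormal = 2.0 ** -14 * (1 / 4)     # 2^-16
e5m2_worst_spacing = 2.0 ** -2                # 2 mantissa bits -> 25%
print(e5m2_max, e5m2_min_subnormal, e5m2_worst_spacing)
```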
When to use which:
- E4M3 for forward pass (inference): The extra mantissa bit matters more than the range, because we can control the range via scaling factors.
- E5M2 for backward pass (training gradients): Gradients have wider dynamic range and benefit from the larger exponent. Less relevant for inference.
For LLM inference, E4M3 is the only format that matters. All FP8 inference implementations (TensorRT-LLM, vLLM, SGLang) use E4M3 for both weights and activations. E5M2 is used only during training for gradient accumulation. The rest of this post focuses exclusively on E4M3.
Representable Values
With 3 mantissa bits, the spacing between consecutive representable values within the same binade is:
$\Delta = 2^{E-7} \times 2^{-3} = 2^{E-10}$
where $E$ is the biased exponent. This means:
- Between 1.0 and 2.0: 8 evenly spaced values (1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875)
- Between 2.0 and 4.0: 8 values spaced by 0.25
- Between 128 and 256: 8 values spaced by 16
- Between 256 and 448: 7 values spaced by 32 (the eighth pattern, $M = 7$, is the NaN encoding)
import torch
# Enumerate all positive E4M3 values
def enumerate_e4m3_values():
values = []
# Subnormals: E=0, M=1..7
for m in range(1, 8):
val = 2**(-6) * (m / 8)
values.append(val)
# Normals: E=1..14, M=0..7
for e in range(1, 15):
for m in range(8):
val = 2**(e - 7) * (1 + m / 8)
values.append(val)
# E=15, M=0..6 (M=7 is NaN)
for m in range(7):
val = 2**(15 - 7) * (1 + m / 8)
values.append(val)
return sorted(values)
e4m3_values = enumerate_e4m3_values()
print(f"Number of positive values: {len(e4m3_values)}")
# 126 positive + 126 negative + 1 zero = 253 distinct finite values
# (256 bit patterns, minus 2 NaN encodings and the duplicate -0)
print(f"Smallest subnormal: {e4m3_values[0]}") # 0.001953125
print(f"Smallest normal: {e4m3_values[7]}") # 0.015625
print(f"Largest value: {e4m3_values[-1]}") # 448.0
The entire format has only 253 distinct finite values (including negatives and zero). Compare with FP16’s roughly 63,500. On average, quantizing from FP16 to E4M3 collapses about 250 distinct FP16 values onto each E4M3 value.
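Round-to-nearest into this value set has a worst-case relative error of about half the 12.5% spacing. A pure-Python check (re-deriving the positive value set; a sketch, not a production quantizer):

```python
import bisect

# Positive E4M3 values: 7 subnormals + 14*8 normals + 7 top-exponent values
vals = sorted([2 ** -6 * (m / 8) for m in range(1, 8)]
              + [2 ** (e - 7) * (1 + m / 8) for e in range(1, 15) for m in range(8)]
              + [2 ** 8 * (1 + m / 8) for m in range(7)])

def nearest_e4m3(x):
    """Round a positive in-range value to the nearest representable E4M3 value."""
    i = bisect.bisect_left(vals, x)
    return min(vals[max(i - 1, 0):i + 1], key=lambda v: abs(v - x))

worst = max(abs(nearest_e4m3(x) - x) / x
            for x in (1 + k / 1000 for k in range(1000)))
print(f"positive values: {len(vals)}")             # 126
print(f"worst rel. error in [1, 2): {worst:.4f}")  # just under 6.25%
```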
2. Per-Tensor Dynamic Scaling
The fundamental problem with FP8’s limited range: transformer activations and weights span very different numerical ranges. Layer norm outputs typically sit within a range of order one, while FFN intermediate activations after SiLU can be one to two orders of magnitude larger. If we naively cast to E4M3, small values lose all precision and large values might overflow.
The solution: per-tensor scaling. Before casting to FP8, multiply by a scale factor that maps the tensor’s range to E4M3’s representable range.
Scale Factor Computation
The quantized tensor is:
$x_{\text{fp8}} = \text{cast}_{\text{E4M3}}\big(\text{clamp}(x \cdot s,\ -448,\ 448)\big), \quad s = 448 / \max_i |x_i|$
And dequantization recovers (an approximation of) the original:
$\hat{x} = x_{\text{fp8}} / s$
def compute_scale(tensor, fp8_max=448.0):
"""
Compute per-tensor scale for FP8 quantization.
The scale maps the tensor's range to [-448, 448].
"""
amax = tensor.abs().max().float()
# Avoid division by zero
if amax == 0:
return torch.tensor(1.0, device=tensor.device)
scale = fp8_max / amax
# Clamp scale to avoid overflow in the scale factor itself
# Scale is stored in FP32 — no precision concerns
return scale.clamp(min=1e-12)
def quantize_to_fp8(tensor, scale):
"""
Quantize a tensor to FP8 E4M3.
Args:
tensor: FP16/BF16/FP32 tensor
scale: per-tensor scale factor (FP32 scalar)
Returns:
fp8_tensor: quantized tensor in torch.float8_e4m3fn
"""
# Scale and clamp to E4M3 range
scaled = tensor.float() * scale
scaled = scaled.clamp(-448.0, 448.0)
# Cast to FP8
fp8_tensor = scaled.to(torch.float8_e4m3fn)
return fp8_tensor
def dequantize_from_fp8(fp8_tensor, scale):
"""Dequantize FP8 tensor back to FP16."""
return fp8_tensor.float() / scale
Per-Tensor vs Per-Channel vs Per-Token Scaling
The granularity of the scale factor affects both accuracy and performance:
Per-tensor: one scale for the entire [M, K] or [K, N] matrix
Per-channel: one scale per column (weights) or per row (activations)
Per-token: one scale per row in the activation tensor
Per-group: one scale per group of G elements (e.g., G=128)
FP8 Scaling Granularity Comparison
| Granularity | Scales Count (weights [K,N]) | Scales Count (acts [M,K]) | Accuracy | HW Support (H100) |
|---|---|---|---|---|
| Per-tensor | 1 | 1 | Baseline | Native TMA |
| Per-channel (weights) | N | — | +0.1-0.3% acc | Requires custom kernel |
| Per-token (activations) | — | M | +0.2-0.5% acc | Supported via row scaling |
| Per-group (G=128) | K*N/G | M*K/G | +0.5-1.0% acc | Not natively supported |
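A toy comparison makes the granularity tradeoff concrete (hypothetical two-row activation matrix, nearest-value rounding; a sketch, not a production quantizer):

```python
import bisect

FP8_MAX = 448.0
# All positive E4M3 values (subnormals, normals, top exponent minus the NaN slot)
vals = sorted([2 ** -6 * (m / 8) for m in range(1, 8)]
              + [2 ** (e - 7) * (1 + m / 8) for e in range(1, 15) for m in range(8)]
              + [2 ** 8 * (1 + m / 8) for m in range(7)])

def nearest_e4m3(x):
    """Round a non-negative value to the nearest representable E4M3 value."""
    if x == 0:
        return 0.0
    i = bisect.bisect_left(vals, x)
    return min(vals[max(i - 1, 0):i + 1], key=lambda v: abs(v - x))

def rel_err(x, scale):
    """Quantize-dequantize relative error for a positive value under a scale."""
    q = nearest_e4m3(min(x * scale, FP8_MAX))
    return abs(q / scale - x) / x

rows = [[0.001] * 4, [448.0] * 4]    # one tiny-magnitude row, one huge one
amax = max(max(r) for r in rows)

per_tensor = max(rel_err(x, FP8_MAX / amax) for r in rows for x in r)
per_token = max(rel_err(x, FP8_MAX / max(r)) for r in rows for x in r)
print(f"per-tensor worst rel. error: {per_tensor:.3f}")  # the huge row pins the scale
print(f"per-token  worst rel. error: {per_token:.2e}")
```

With one shared scale, the huge row pins the scale near 1.0 and the tiny row collapses onto the smallest subnormal; per-token scales recover nearly full precision for both rows.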
In practice, the standard approach for FP8 inference is:
- Weights: Per-tensor or per-channel scaling, computed offline during calibration
- Activations: Per-tensor or per-token scaling, computed dynamically at runtime
Per-tensor scaling for both operands is the simplest and best-supported path on H100:
# FP8 GEMM with per-tensor scaling:
# C = (A_fp8 / scale_A) @ (B_fp8 / scale_B)
# = (A_fp8 @ B_fp8) / (scale_A * scale_B)
#
# The tensor core computes A_fp8 @ B_fp8 in FP8,
# accumulates in FP32, and the epilogue divides by
# (scale_A * scale_B). Single kernel, no overhead.
Per-tensor scaling adds exactly ONE FP32 division in the GEMM epilogue per output element. Since the GEMM itself performs $2K$ FLOPs per output element (where $K$ is thousands to tens of thousands), the scaling overhead is negligible — less than 0.01% of total compute.
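The overhead claim follows from simple arithmetic (K = 8192 here is an illustrative FFN width, not a number from the text):

```python
K = 8192                      # reduction dimension of the GEMM
flops_per_output = 2 * K      # one multiply + one add per element of the K-reduction
epilogue_ops = 1              # one FP32 descale per output element
overhead = epilogue_ops / flops_per_output
print(f"epilogue overhead: {overhead:.4%}")   # well under 0.01%
```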
3. Which Operations Use FP8
Not every operation in a transformer can use FP8. The rule is:
FP8 for GEMMs. Higher precision for everything else.
Operations That Use FP8
| Operation | Input Precision | Weight Precision | Accumulation | Output |
|---|---|---|---|---|
| QKV projection | FP8 (E4M3) | FP8 (E4M3) | FP32 | BF16 |
| Output projection | FP8 | FP8 | FP32 | BF16 |
| Gate projection | FP8 | FP8 | FP32 | BF16 |
| Up projection | FP8 | FP8 | FP32 | BF16 |
| Down projection | FP8 | FP8 | FP32 | BF16 |
Operations That Stay in BF16/FP16
| Operation | Why Not FP8 |
|---|---|
| Layer norm / RMS norm | Requires high-precision running statistics. 3 mantissa bits produce incorrect variance. |
| Softmax | Exponential and division are numerically sensitive. The log-sum-exp trick requires precision. |
| SiLU / GELU activation | Non-linear; small input differences produce large output differences. |
| Residual addition | Accumulates across layers. FP8 rounding errors compound. |
| Rotary embeddings | Sine/cosine computation requires precision. |
| Embedding lookup | Table lookup, not compute. No benefit from FP8. |
| Attention scores (QK) | Typically done in FP16/BF16 within FlashAttention. Could use FP8 but quality degrades. |
| Attention values (PV) | Same as above. |
class FP8TransformerLayer:
"""
Transformer layer with FP8 GEMMs and BF16 everything else.
"""
def __init__(self, config):
# FP8 weights (quantized offline)
self.qkv_weight_fp8 = None # [d, n_h*d_h + 2*n_kv*d_h] in E4M3
self.qkv_scale = None # FP32 scalar
self.o_weight_fp8 = None # [n_h*d_h, d] in E4M3
self.o_scale = None
self.gate_up_weight_fp8 = None # [d, 2*d_ff] in E4M3
self.gate_up_scale = None
self.down_weight_fp8 = None # [d_ff, d] in E4M3
self.down_scale = None
# BF16 parameters (NOT quantized)
self.rms_norm_weight = None # [d] in BF16
self.rms_norm2_weight = None # [d] in BF16
def forward(self, x):
"""
x: [B, d] in BF16
All GEMMs use FP8. All other ops use BF16.
"""
# ---- RMS Norm (BF16) ----
normed = rms_norm(x, self.rms_norm_weight) # BF16
# ---- QKV Projection (FP8 GEMM) ----
act_scale = compute_scale(normed)
normed_fp8 = quantize_to_fp8(normed, act_scale)
qkv = fp8_gemm(
normed_fp8, act_scale,
self.qkv_weight_fp8, self.qkv_scale
) # Output in BF16
# ---- Attention (BF16 — FlashAttention) ----
q, k, v = split_qkv(qkv)
attn_out = flash_attention(q, k, v) # BF16
# ---- Output Projection (FP8 GEMM) ----
act_scale = compute_scale(attn_out)
attn_fp8 = quantize_to_fp8(attn_out, act_scale)
o = fp8_gemm(
attn_fp8, act_scale,
self.o_weight_fp8, self.o_scale
) # BF16
# ---- Residual Add (BF16) ----
x = x + o
# ---- RMS Norm 2 (BF16) ----
normed2 = rms_norm(x, self.rms_norm2_weight)
# ---- Gate + Up Projection (FP8 GEMM) ----
act_scale = compute_scale(normed2)
normed2_fp8 = quantize_to_fp8(normed2, act_scale)
gate_up = fp8_gemm(
normed2_fp8, act_scale,
self.gate_up_weight_fp8, self.gate_up_scale
) # BF16
# ---- SiLU Activation (BF16) ----
gate, up = gate_up.chunk(2, dim=-1)
intermediate = torch.nn.functional.silu(gate) * up # BF16
# ---- Down Projection (FP8 GEMM) ----
act_scale = compute_scale(intermediate)
inter_fp8 = quantize_to_fp8(intermediate, act_scale)
down = fp8_gemm(
inter_fp8, act_scale,
self.down_weight_fp8, self.down_scale
) # BF16
# ---- Residual Add (BF16) ----
x = x + down
return x
Some implementations quantize the QK dot product and PV multiply to FP8. This can work for large models (70B+) where individual head dimensions are large, but degrades quality for smaller models. The attention mechanism is particularly sensitive because softmax amplifies small numerical errors — a 1% error in attention scores can become a 5-10% error in attention weights after exponentiation. The safe default is to keep attention in BF16 and only quantize the linear projections.
4. H100 FP8 Tensor Cores: 1,979 TFLOPS
The H100 SXM5 delivers these peak throughput numbers:
H100 SXM5 Tensor Core Throughput by Precision
| Precision | Peak TFLOPS | Bytes/Element | Bandwidth-Equivalent TFLOPS | Ridge Point (FLOP/byte) |
|---|---|---|---|---|
| FP64 | 67 | 8 | — | 20 |
| TF32 | 495 | 4 | — | 148 |
| FP16 | 989 | 2 | 989 | 295 |
| BF16 | 989 | 2 | 989 | 295 |
| FP8 (E4M3) | 1979 | 1 | 1979 | 591 |
| INT8 | 1979 | 1 | 1979 | 591 |
Why Exactly 2x
The H100 tensor core pipeline processes two FP8 elements in the same cycle that it processes one FP16 element. The FP8 MMA instruction shape is m16n8k32 (K=32 for FP8 vs K=16 for FP16), meaning each instruction processes twice as many multiply-accumulate operations. The accumulator is FP32 in both cases.
FP16 MMA: m16n8k16 -> 16*8*16*2 = 4096 FLOPs per instruction
FP8 MMA: m16n8k32 -> 16*8*32*2 = 8192 FLOPs per instruction
Same number of instructions per cycle -> 2x FLOPs
Practical FP8 Throughput
The 2x is a peak number. Actual throughput depends on arithmetic intensity:
FP8 vs FP16 actual throughput (TFLOPS), FFN GEMM [M, 57344, 8192]:
At small $M$ (decode), both FP8 and FP16 are memory-bandwidth-bound. FP8 loads 1 byte per weight instead of 2, but the speedup from halving weight bytes is largely offset by unchanged activation reads and output writes. Net: approximately 1.03x.
At moderate $M$ (mid-sized prefill batches), FP8 achieves 1120 TFLOPS (57% of peak) vs FP16 at 790 TFLOPS (80% of peak). The FP8 speedup is 1.42x — less than 2x because the FP8 GEMM has not yet reached its ridge point.
At large $M$ (full prefill), FP8 achieves 1720 TFLOPS (87% of peak) vs FP16 at 920 TFLOPS (93% of peak). The speedup is 1.87x — approaching but not reaching 2x because of memory traffic for activations.
The FP8 speedup over FP16 ranges from 1.0x (decode, small $M$) to 1.95x (large prefill, large $M$). For a serving system running a mix of prefill and decode, the aggregate speedup is typically 1.3-1.6x at realistic batch sizes. Claims of “2x from FP8” assume compute-bound regimes that many workloads do not reach.
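These regimes follow from a simple roofline model. A sketch (H100 SXM numbers: 3.35 TB/s HBM bandwidth; FP8 inputs, BF16 output):

```python
BW = 3.35e12   # H100 SXM HBM3 bandwidth, bytes/s

def arithmetic_intensity(M, K, N, bytes_in, bytes_out):
    """FLOPs per byte of HBM traffic for C[M,N] = A[M,K] @ B[K,N]."""
    flops = 2 * M * K * N
    traffic = (M * K + K * N) * bytes_in + M * N * bytes_out
    return flops / traffic

ridge_fp8 = 1979e12 / BW     # ~591 FLOP/byte, matching the table
ridge_fp16 = 989e12 / BW     # ~295 FLOP/byte

K, N = 8192, 57344           # the FFN GEMM shape from the chart above
for M in (8, 256, 8192):
    ai = arithmetic_intensity(M, K, N, bytes_in=1, bytes_out=2)
    regime = "compute-bound" if ai > ridge_fp8 else "memory-bound"
    print(f"M={M:5d}: {ai:7.0f} FLOP/byte -> {regime} in FP8")
```

Small-M decode sits far below the ridge point (memory-bound), and only large-M prefill clears FP8's 591 FLOP/byte ridge — which is why the full 2x appears only for big GEMMs.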
5. Quantization Workflow
The full FP8 quantization workflow has three phases: calibration, offline weight quantization, and online activation quantization.
Phase 1: Calibration
Run a representative dataset through the model in FP16/BF16. For each tensor that will be quantized, record the maximum absolute value (amax):
class CalibrationObserver:
"""
Collect activation statistics for FP8 scale computation.
Attach to each linear layer as a forward hook.
"""
def __init__(self):
self.amax_history = []
self.max_samples = 512 # Number of calibration samples
def observe(self, tensor):
"""Record the amax of a tensor."""
amax = tensor.abs().max().item()
self.amax_history.append(amax)
def compute_scale(self, fp8_max=448.0, percentile=99.99):
"""
Compute scale from observed amax values.
Using the percentile instead of absolute max provides
robustness against outliers. A single outlier value
would set the scale too conservatively, wasting precision
for the rest of the distribution.
"""
import numpy as np
amax = np.percentile(self.amax_history, percentile)
return fp8_max / max(amax, 1e-12)
def calibrate_model(model, calibration_loader, num_batches=32):
"""
Run calibration to determine per-tensor scales.
Args:
model: FP16/BF16 model
calibration_loader: representative data
num_batches: number of calibration batches
Returns:
scales: dict mapping layer_name -> {input_scale, weight_scale}
"""
observers = {}
# Register observers for each linear layer
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
observer_input = CalibrationObserver()
observer_weight = CalibrationObserver()
# Record weight amax (static — does not change)
observer_weight.observe(module.weight.data)
# Hook for input activations
def make_hook(obs):
def hook(mod, inp, out):
obs.observe(inp[0])
return hook
module.register_forward_hook(make_hook(observer_input))
observers[name] = {
'input': observer_input,
'weight': observer_weight,
}
# Run calibration
model.eval()
with torch.no_grad():
for i, batch in enumerate(calibration_loader):
if i >= num_batches:
break
model(batch['input_ids'].cuda())
# Compute scales
scales = {}
for name, obs in observers.items():
scales[name] = {
'input_scale': obs['input'].compute_scale(),
'weight_scale': obs['weight'].compute_scale(),
}
return scales
Phase 2: Offline Weight Quantization
Quantize model weights to FP8 and save. This is done once.
def quantize_weights_fp8(model, scales):
"""
Quantize all linear layer weights to FP8 E4M3.
The original FP16 weights are replaced with FP8 weights
plus an FP32 scale factor. Memory: 1 byte/param + negligible
scale overhead (one FP32 per tensor).
"""
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and name in scales:
weight = module.weight.data # [out, in] in FP16/BF16
scale = scales[name]['weight_scale']
# Quantize
weight_scaled = weight.float() * scale
weight_scaled = weight_scaled.clamp(-448.0, 448.0)
weight_fp8 = weight_scaled.to(torch.float8_e4m3fn)
# Replace weight
module.weight = torch.nn.Parameter(
weight_fp8, requires_grad=False
)
# Store scale as a buffer (not a parameter)
module.register_buffer(
'weight_scale',
torch.tensor(scale, dtype=torch.float32)
)
return model
def save_fp8_model(model, path):
"""Save FP8-quantized model."""
state_dict = {}
for name, param in model.named_parameters():
state_dict[name] = param.data
for name, buf in model.named_buffers():
state_dict[name] = buf
torch.save(state_dict, path)
# Model size: ~50% of FP16 (1 byte/param vs 2 bytes/param)
Phase 3: Online Activation Quantization
During inference, activations are quantized to FP8 dynamically before each GEMM. The scale is computed on-the-fly from the activation tensor’s amax.
class FP8Linear(torch.nn.Module):
"""
Linear layer with FP8 weights and dynamic FP8 activation
quantization.
"""
def __init__(self, in_features, out_features):
super().__init__()
# FP8 weight and scale (set during quantization)
self.weight_fp8 = None # [out, in] in float8_e4m3fn
self.weight_scale = None # FP32 scalar
# Optional: use delayed scaling (reuse previous step's scale)
self.use_delayed_scaling = False
self.prev_input_scale = None
def forward(self, x):
"""
x: [B, in_features] in BF16
1. Dynamically quantize x to FP8
2. Execute FP8 GEMM
3. Descale output to BF16
"""
# Dynamic activation scaling
if self.use_delayed_scaling and self.prev_input_scale is not None:
input_scale = self.prev_input_scale
else:
amax = x.abs().max()
input_scale = (448.0 / amax).clamp(min=1e-12)
# Store for next step (delayed scaling)
if self.use_delayed_scaling:
self.prev_input_scale = input_scale.detach()
# Quantize activation to FP8
x_scaled = (x.float() * input_scale).clamp(-448.0, 448.0)
x_fp8 = x_scaled.to(torch.float8_e4m3fn)
# FP8 GEMM with FP32 accumulation
# Output descaling: divide by (input_scale * weight_scale)
output = torch._scaled_mm(
x_fp8,
self.weight_fp8.t(),  # .t() (no .contiguous()) keeps the column-major layout _scaled_mm expects
out_dtype=torch.bfloat16,
scale_a=torch.as_tensor(1.0 / input_scale,
dtype=torch.float32,
device=x.device),  # dequant scale for activations
scale_b=torch.as_tensor(1.0 / self.weight_scale,
dtype=torch.float32,
device=x.device),  # dequant scale for weights
)
return output
Computing x.abs().max() requires a full reduction over the activation tensor — an extra kernel launch and global memory read. Delayed scaling reuses the previous step’s scale factor, eliminating this overhead. The assumption is that activation ranges change slowly between tokens, which holds in practice for autoregressive generation. The first token uses a default scale (e.g., 1.0) or a calibrated scale.
Dynamic vs Static Scaling
| Approach | Scale Source | Advantages | Disadvantages |
|---|---|---|---|
| Static (calibration) | Pre-computed from calibration set | No runtime overhead, simple | May not cover all input distributions |
| Dynamic (per-tensor) | Computed from current tensor | Adapts to actual data | Extra amax kernel per GEMM |
| Delayed dynamic | Previous step’s amax | Minimal overhead, adaptive | Slight staleness (1 token lag) |
Production systems typically use:
- Static scaling for weights (computed once during quantization)
- Delayed dynamic scaling for activations (previous step’s amax)
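The staleness cost of delayed scaling is easy to quantify: if the true amax grows by a factor $r$ between steps, the top $1 - 1/r$ fraction of the scaled range clips at 448. A toy simulation (hypothetical amax trajectory):

```python
FP8_MAX = 448.0
amax_trace = [1.00, 1.05, 1.10, 1.20, 1.15, 1.30]   # per-token amax, drifting slowly

prev_amax = amax_trace[0]
worst_clip = 0.0
for amax in amax_trace[1:]:
    scale = FP8_MAX / prev_amax                   # stale: from the previous token
    top = amax * scale                            # where this token's max lands
    clip_frac = max(0.0, (top - FP8_MAX) / top)   # fraction clipped off the top
    worst_clip = max(worst_clip, clip_frac)
    prev_amax = amax                              # amax history update

print(f"worst clipped fraction: {worst_clip:.3f}")  # from the 1.15 -> 1.30 jump
```

Clipping only the extreme tail is usually benign, which is why delayed scaling works as long as ranges drift rather than jump; a sudden 10x spike would clip most of one token's range.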
6. Complete FP8 Inference Implementation
Here is a complete, runnable implementation of FP8 inference for a transformer model using PyTorch’s native FP8 support:
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
"""RMS Normalization — always in FP32/BF16, never FP8."""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
# Compute in FP32 for numerical stability
x_float = x.float()
rms = torch.sqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
normed = x_float / rms
return (normed * self.weight.float()).to(x.dtype)
class FP8LinearLayer(nn.Module):
"""
FP8 linear layer with per-tensor scaling.
Uses torch._scaled_mm for FP8 GEMM on H100.
"""
def __init__(self, in_features, out_features, bias=False):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# Placeholder — will be set during quantization
self.register_buffer(
'weight_fp8',
torch.zeros(out_features, in_features,
dtype=torch.float8_e4m3fn)
)
self.register_buffer(
'weight_scale',
torch.tensor(1.0, dtype=torch.float32)
)
self.register_buffer(
'input_scale',
torch.tensor(1.0, dtype=torch.float32)
)
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
else:
self.bias = None
@torch.no_grad()
def quantize_weight(self, weight_fp16):
"""Quantize an FP16 weight tensor to FP8."""
amax = weight_fp16.abs().max().float()
scale = (448.0 / amax).clamp(min=1e-12)
w_scaled = (weight_fp16.float() * scale).clamp(-448.0, 448.0)
self.weight_fp8.copy_(w_scaled.to(torch.float8_e4m3fn))
self.weight_scale.fill_(scale.item())
def forward(self, x):
"""
x: [*, in_features] in BF16
Returns: [*, out_features] in BF16
"""
orig_shape = x.shape
x = x.reshape(-1, self.in_features) # [B, in]
# Dynamic activation quantization
amax = x.abs().max().float()
act_scale = (448.0 / amax).clamp(min=1e-12)
x_scaled = (x.float() * act_scale).clamp(-448.0, 448.0)
x_fp8 = x_scaled.to(torch.float8_e4m3fn)
# FP8 GEMM: x_fp8 @ weight_fp8.T
# torch._scaled_mm handles the descaling in the epilogue
inv_act_scale = (1.0 / act_scale).to(torch.float32)
inv_weight_scale = (1.0 / self.weight_scale).to(torch.float32)
output = torch._scaled_mm(
x_fp8, # [B, in] E4M3
self.weight_fp8.t(), # [in, out] E4M3 — .t() without .contiguous() preserves the column-major layout _scaled_mm requires
out_dtype=torch.bfloat16,
scale_a=inv_act_scale,
scale_b=inv_weight_scale,
) # [B, out] BF16
if self.bias is not None:
output = output + self.bias.to(output.dtype)
return output.reshape(*orig_shape[:-1], self.out_features)
class FP8TransformerBlock(nn.Module):
"""
Single transformer block with FP8 GEMMs.
Norms, attention, activations stay in BF16.
"""
def __init__(self, dim, n_heads, n_kv_heads, ff_dim):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
# Norms (BF16)
self.attn_norm = RMSNorm(dim)
self.ffn_norm = RMSNorm(dim)
# Attention projections (FP8)
self.q_proj = FP8LinearLayer(dim, n_heads * self.head_dim)
self.k_proj = FP8LinearLayer(dim, n_kv_heads * self.head_dim)
self.v_proj = FP8LinearLayer(dim, n_kv_heads * self.head_dim)
self.o_proj = FP8LinearLayer(n_heads * self.head_dim, dim)
# FFN projections (FP8)
self.gate_proj = FP8LinearLayer(dim, ff_dim)
self.up_proj = FP8LinearLayer(dim, ff_dim)
self.down_proj = FP8LinearLayer(ff_dim, dim)
def forward(self, x, cos_freqs, sin_freqs, mask=None):
"""
x: [B, S, dim] in BF16
"""
# --- Attention ---
residual = x
x = self.attn_norm(x) # BF16 norm
# FP8 projections
q = self.q_proj(x) # FP8 GEMM -> BF16 output
k = self.k_proj(x) # FP8 GEMM -> BF16 output
v = self.v_proj(x) # FP8 GEMM -> BF16 output
# Reshape for attention
B, S, _ = q.shape
q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
k = k.view(B, S, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = v.view(B, S, self.n_kv_heads, self.head_dim).transpose(1, 2)
# Apply rotary embeddings (BF16)
q = apply_rotary_emb(q, cos_freqs, sin_freqs)
k = apply_rotary_emb(k, cos_freqs, sin_freqs)
# GQA: expand KV heads
if self.n_kv_heads < self.n_heads:
rep = self.n_heads // self.n_kv_heads
k = k.repeat_interleave(rep, dim=1)
v = v.repeat_interleave(rep, dim=1)
# Attention (BF16 — FlashAttention)
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=mask, is_causal=True
) # BF16
attn_out = attn_out.transpose(1, 2).reshape(B, S, -1)
# Output projection (FP8 GEMM)
attn_out = self.o_proj(attn_out)
# Residual (BF16)
x = residual + attn_out
# --- FFN ---
residual = x
x = self.ffn_norm(x) # BF16 norm
# FP8 GEMMs
gate = self.gate_proj(x) # FP8 -> BF16
up = self.up_proj(x) # FP8 -> BF16
# SiLU activation (BF16)
x = F.silu(gate) * up
# Down projection (FP8 GEMM)
x = self.down_proj(x)
# Residual (BF16)
x = residual + x
return x
def apply_rotary_emb(x, cos, sin):
"""Apply rotary positional embeddings. Always BF16."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([
x1 * cos - x2 * sin,
x2 * cos + x1 * sin,
], dim=-1)
Loading and Quantizing a Pretrained Model
def convert_to_fp8(model_fp16):
"""
Convert an FP16 model to FP8 inference.
Replaces all nn.Linear with FP8LinearLayer.
"""
for name, module in model_fp16.named_children():
if isinstance(module, nn.Linear):
fp8_layer = FP8LinearLayer(
module.in_features,
module.out_features,
bias=module.bias is not None,
).to(module.weight.device)
# Quantize weight
fp8_layer.quantize_weight(module.weight.data)
if module.bias is not None:
fp8_layer.bias.data.copy_(module.bias.data)
setattr(model_fp16, name, fp8_layer)
else:
convert_to_fp8(module) # Recurse
return model_fp16
# Usage:
# model = load_pretrained_model("llama-70b", dtype=torch.bfloat16)
# model = convert_to_fp8(model)
# model.eval()
#
# Memory: 70B params * 1 byte = 70 GB (vs 140 GB in FP16)
# Throughput: 1.3-1.9x depending on batch size
Memory Savings
Memory Comparison: FP16 vs FP8 (Llama 70B, H100 80GB)
| Component | FP16 Size | FP8 Size | Savings |
|---|---|---|---|
| Model weights | 140 GB | 70 GB | 50% |
| Scale factors | 0 GB | ~0.001 GB | Negligible overhead |
| KV cache (batch=64, seq=4096) | 85.9 GB | 85.9 GB (still BF16) | 0% |
| Activation memory | ~2 GB | ~2 GB (still BF16) | 0% |
| Total (batch=64) | 228 GB | 158 GB | 31% |
At batch size 64 with sequence length 4096, the KV cache for Llama 70B in BF16 is 85.9 GB — larger than the FP8 weight savings of 70 GB. To maximize serving capacity, combine FP8 weight quantization with FP8 KV cache quantization (reducing KV cache from 85.9 GB to 42.9 GB). The combination reduces total memory from 228 GB to approximately 115 GB — fitting on two H100s instead of three.
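The table's headline numbers can be reproduced from the Llama 70B architecture (80 layers, 8 KV heads, head dim 128 — the published configuration):

```python
GB = 1e9
params = 70e9
layers, kv_heads, head_dim = 80, 8, 128    # Llama 70B (GQA)
batch, seq = 64, 4096

def kv_cache_gb(bytes_per_elem):
    # K and V tensors: per layer, per KV head, per token in every sequence
    return 2 * layers * kv_heads * head_dim * batch * seq * bytes_per_elem / GB

print(f"weights: {params * 2 / GB:.0f} GB (FP16) -> {params / GB:.0f} GB (FP8)")
print(f"KV cache: {kv_cache_gb(2):.1f} GB (BF16) -> {kv_cache_gb(1):.1f} GB (FP8)")
```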
7. Quality Impact and When FP8 Fails
FP8 inference is not lossless. The 3 mantissa bits introduce quantization error on every weight and activation. For most production models, the accuracy degradation is small enough to be acceptable.
FP8 Quality Impact: Perplexity on WikiText-2
| Model | FP16 PPL | FP8 (static scale) PPL | FP8 (dynamic scale) PPL | Degradation |
|---|---|---|---|---|
| Llama 2 7B | 5.47 | 5.58 | 5.51 | +0.04-0.11 |
| Llama 2 13B | 4.88 | 4.96 | 4.91 | +0.03-0.08 |
| Llama 2 70B | 3.32 | 3.35 | 3.33 | +0.01-0.03 |
| Mistral 7B | 5.25 | 5.39 | 5.30 | +0.05-0.14 |
| Llama 3 8B | 6.14 | 6.38 | 6.21 | +0.07-0.24 |
| Llama 3 70B | 2.86 | 2.89 | 2.87 | +0.01-0.03 |
When FP8 Produces Unacceptable Quality
-
Small models (less than 3B parameters): Each weight carries more information. FP8 quantization error is proportionally larger. Consider INT8 weight-only quantization instead, which preserves activations in FP16.
-
Models with outlier channels: Some transformer models develop outlier features — channels where activation values are 10-100x larger than the rest. Per-tensor scaling is dominated by these outliers, causing severe precision loss for normal-range values. SmoothQuant-style techniques migrate the outlier magnitude from activations to weights before quantization.
-
Fine-tuned models with narrow weight distributions: Models fine-tuned on narrow domains (e.g., medical, legal) may have weights concentrated in a very small range. FP8’s 3 mantissa bits may not provide enough resolution to distinguish between close weight values.
-
Long-context generation: Quantization errors accumulate across the sequence length through the residual stream. At 100K+ tokens, the accumulated error from 80 layers of FP8 GEMMs can produce noticeably different outputs from FP16. This is model-dependent and difficult to predict without testing.
def check_fp8_compatibility(model):
"""
Quick diagnostic: check for conditions that make FP8
quantization risky.
"""
warnings = []
for name, param in model.named_parameters():
if 'weight' not in name:
continue
w = param.data.float()
amax = w.abs().max().item()
mean_abs = w.abs().mean().item()
std = w.std().item()
# Check 1: Outlier ratio
outlier_ratio = amax / mean_abs
if outlier_ratio > 20:
warnings.append(
f"{name}: outlier ratio {outlier_ratio:.1f} "
f"(amax={amax:.3f}, mean_abs={mean_abs:.5f}). "
f"Consider SmoothQuant."
)
# Check 2: Very small dynamic range
nonzero = w.abs()[w.abs() > 0]
dynamic_range = amax / nonzero.min().item() if nonzero.numel() > 0 else 1.0
if dynamic_range > 1000:
warnings.append(
f"{name}: dynamic range {dynamic_range:.0f}. "
f"FP8 may not resolve small values."
)
# Check 3: Near-zero standard deviation
if std < 0.001:
warnings.append(
f"{name}: std={std:.6f}. Very narrow distribution, "
f"FP8 quantization noise may dominate."
)
return warnings
8. Hardware Support Matrix
FP8 is not universally available. Here is the current support landscape:
FP8 Hardware Support (as of early 2025)
| Hardware | FP8 Support | E4M3 | E5M2 | FP8 Tensor Cores | FP8 TFLOPS |
|---|---|---|---|---|---|
| H100 SXM | Yes | Yes | Yes | Yes | 1979 |
| H100 PCIe | Yes | Yes | Yes | Yes | 1513 |
| H200 SXM | Yes | Yes | Yes | Yes | 1979 |
| L40S | Yes | Yes | Yes | Yes | 733 |
| A100 | No | — | — | — | — |
| A10G | No | — | — | — | — |
| RTX 4090 (Ada) | Yes | Yes | Yes | Yes | 660 |
| RTX 3090 (Ampere) | No | — | — | — | — |
| AMD MI300X | Yes (OCP FP8) | Yes | Yes | Yes | ~2600 |
| Intel Gaudi 2 | Yes | Yes | No | MME only | ~600 |
Software Requirements
NVIDIA FP8:
- CUDA 12.0+
- cuDNN 8.9+
- PyTorch 2.1+ (for torch.float8_e4m3fn dtype)
- PyTorch 2.4+ (for torch._scaled_mm with proper H100 support)
- Driver 525.60+
Frameworks with FP8 support:
- TensorRT-LLM: Native FP8 since v0.5
- vLLM: FP8 quantization via compressed-tensors (v0.4+)
- SGLang: FP8 via torch._scaled_mm (v0.2+)
- DeepSpeed: FP8 via Transformer Engine integration
- NVIDIA Transformer Engine: The reference FP8 library
Transformer Engine Integration
NVIDIA’s Transformer Engine library provides the most optimized FP8 path:
import transformer_engine.pytorch as te
# Replace nn.Linear with te.Linear for automatic FP8
class TETransformerLayer(nn.Module):
def __init__(self, dim, ff_dim):
super().__init__()
# Transformer Engine handles FP8 quantization internally
self.qkv_proj = te.Linear(dim, 3 * dim, bias=False)
self.o_proj = te.Linear(dim, dim, bias=False)
self.gate_proj = te.Linear(dim, ff_dim, bias=False)
self.up_proj = te.Linear(dim, ff_dim, bias=False)
self.down_proj = te.Linear(ff_dim, dim, bias=False)
# Norms integrated into TE layers
self.attn_norm = te.LayerNorm(dim)
self.ffn_norm = te.LayerNorm(dim)
def forward(self, x):
# Transformer Engine automatically:
# 1. Manages FP8 scaling (delayed dynamic)
# 2. Quantizes activations before each GEMM
# 3. Accumulates in FP32
# 4. Outputs in BF16
# All within optimized fused kernels
with te.fp8_autocast(enabled=True):
normed = self.attn_norm(x)
qkv = self.qkv_proj(normed)
# ... attention ...
o = self.o_proj(attn_out)
x = x + o
normed2 = self.ffn_norm(x)
gate = self.gate_proj(normed2)
up = self.up_proj(normed2)
x = x + self.down_proj(F.silu(gate) * up)
return x
Transformer Engine provides several advantages over manual FP8:
- Delayed scaling with automatic amax history management
- Fused GEMM + scaling kernels (no separate quantize kernel)
- Automatic mixed-precision recipes (which layers use FP8)
- Communication-efficient FP8 for tensor parallelism
Key Takeaways
-
E4M3 is the inference format: 4 exponent bits, 3 mantissa bits, range $[-448, 448]$, no infinity, 253 distinct finite values. Use E5M2 only for training gradients.
-
Per-tensor scaling is mandatory: Without scaling, FP8’s limited range causes overflow or severe precision loss. Scale = 448.0 / max(abs(tensor)). One FP32 scalar per tensor.
-
GEMMs only: Quantize linear projections (QKV, output, FFN) to FP8. Keep norms, softmax, activations, residuals, and embeddings in BF16. Quantizing non-GEMM operations degrades quality with no throughput benefit.
-
2x throughput is the ceiling, not the floor: H100 FP8 tensor cores deliver 1,979 vs 989 TFLOPS (FP16). The actual speedup depends on arithmetic intensity — 1.0x at small $M$ (memory-bound), 1.9x at large $M$ (compute-bound). Typical serving workloads see 1.3-1.6x.
-
Larger models quantize better: Llama 70B loses 0.01-0.03 perplexity points; Llama 7B loses 0.04-0.11. Each parameter carries less marginal information in larger models, so quantization noise has less impact.
-
Combine with KV cache quantization: FP8 weights save 70 GB for a 70B model. But KV cache at BF16 can exceed 85 GB at high batch sizes. FP8 KV cache quantization provides additive savings.
-
Check hardware support: FP8 requires Hopper (H100/H200) or Ada Lovelace (L40S, RTX 4090). A100 and older GPUs do not have FP8 tensor cores. On unsupported hardware, use INT8 weight-only quantization instead.