FP8 is the most important precision innovation since BF16. It halves the memory and doubles the throughput compared to BF16/FP16 for both training and inference, and it does so without the outlier problems that plague INT8 activation quantization. FP8's floating-point representation naturally handles the non-uniform distributions found in neural network tensors: no SmoothQuant transformation needed.
This post covers FP8 in full depth: why two FP8 variants exist (E4M3 and E5M2) and when to use each, how NVIDIA Transformer Engine manages FP8 mixed precision automatically, the delayed scaling algorithm that makes FP8 practical, which operations benefit from FP8 (and which do not), how DeepSeek trained a 671B MoE model with FP8 GEMMs, and a complete implementation of an FP8 linear layer with proper scaling.
The FP8 Precision Strategy
Recall from Part 1 that FP8 comes in two variants:
- E4M3: 4 exponent bits, 3 mantissa bits. Maximum value 448. More precision, less range.
- E5M2: 5 exponent bits, 2 mantissa bits. Maximum value 57344. Less precision, more range.
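These limits follow directly from the bit layouts. As a sanity check, the maximum normal value and smallest subnormal of each format can be derived in a few lines of plain Python (an illustrative helper, not part of any FP8 library):

```python
def fp8_limits(exp_bits, mant_bits, ieee_inf):
    """Derive (max normal, min subnormal) for an OCP FP8 format.

    E5M2 follows IEEE convention: the entire top exponent is reserved
    for inf/NaN. E4M3 reserves only the all-ones exponent+mantissa
    pattern for NaN, so its top exponent still carries normal values.
    """
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_inf:
        top_exp = 2 ** exp_bits - 2       # E5M2: one binade below the top
        top_mant = 2 ** mant_bits - 1
    else:
        top_exp = 2 ** exp_bits - 1       # E4M3: top binade still usable
        top_mant = 2 ** mant_bits - 2     # all-ones mantissa is NaN
    max_val = (1 + top_mant / 2 ** mant_bits) * 2.0 ** (top_exp - bias)
    min_subnormal = 2.0 ** (1 - bias - mant_bits)
    return max_val, min_subnormal

print(fp8_limits(4, 3, ieee_inf=False))  # E4M3: (448.0, 0.001953125)
print(fp8_limits(5, 2, ieee_inf=True))   # E5M2: (57344.0, 1.52587890625e-05)
```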
The standard FP8 recipe uses both variants in different roles:
| Tensor | Format | Rationale |
|---|---|---|
| Forward activations | E4M3 | Values are bounded; precision matters |
| Forward weights | E4M3 | Static, well-distributed; precision matters |
| Backward gradients | E5M2 | Extreme dynamic range; range matters |
| Master weights | FP32 | Full precision for optimizer state |
| Optimizer state | FP32 | Momentum/variance need full precision |
E4M3's maximum value is 448. Gradient values during training regularly exceed this. A single gradient overflow to NaN can destabilize the entire training run. E5M2's maximum of 57344 provides the safety margin needed for gradients. Loss scaling could extend E4M3's effective range, but E5M2 avoids the complexity and fragility of loss scaling entirely.
Which Operations Use FP8
Not all operations benefit from FP8 precision. The rule is simple: GEMMs use FP8, everything else stays in higher precision.
A transformer layer contains these compute operations:
Forward pass through one transformer block:
1. LayerNorm(x) -- FP32 (reduction, needs precision)
2. Q = x @ W_q -- FP8 GEMM (E4M3 inputs, FP32 accum)
3. K = x @ W_k -- FP8 GEMM
4. V = x @ W_v -- FP8 GEMM
5. attn = softmax(Q @ K^T) -- FP32 (softmax needs precision)
6. out = attn @ V -- FP8 GEMM (attention @ values)
7. proj = out @ W_o -- FP8 GEMM
8. residual add -- FP32
9. LayerNorm(x) -- FP32
10. up = x @ W_up -- FP8 GEMM
11. gate = x @ W_gate -- FP8 GEMM
12. SiLU(gate) * up -- FP32 (element-wise)
13. down = h @ W_down -- FP8 GEMM
14. residual add -- FP32
Of the 14 operations, 8 are GEMMs that run in FP8. The remaining 6 (LayerNorm, softmax, activation functions, residual adds) stay in FP32 or BF16 because they are either numerically sensitive (softmax, LayerNorm) or trivially cheap (element-wise operations, residual adds).
In a typical transformer, GEMMs account for over 95% of FLOPs. Running them in FP8 provides nearly 2x throughput improvement for the entire layer, even though the non-GEMM operations remain in higher precision. The non-GEMM operations contribute negligible runtime.
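A rough per-token FLOPs budget for one block makes this concrete (illustrative Llama-7B-like dimensions; the attention-score GEMMs, which scale with context length, are omitted for simplicity):

```python
h, ffn = 4096, 11008                        # hidden / FFN width (illustrative)
# GEMM FLOPs per token: 2 FLOPs per multiply-accumulate
gemm_flops = 2 * (4 * h * h + 3 * h * ffn)  # QKVO projections + gate/up/down
# Non-GEMM work: a handful of element-wise ops per element for the two
# norms, two residual adds, and the SiLU gating (generous estimate)
elementwise_flops = 20 * h
frac = gemm_flops / (gemm_flops + elementwise_flops)
print(f"GEMM share of block FLOPs: {frac:.4%}")
```

Even with a generous allowance for the element-wise work, the GEMM share comes out above 99%.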
The Scaling Problem
FP8 E4M3 can represent values from about 0.002 (the smallest subnormal, 2^-9) to 448. If your tensor values fall outside this range, you get underflow (small values become zero) or overflow (large values become NaN). Neither is acceptable.
The solution is scaling: multiply the tensor by a scale factor before casting to FP8, then divide the GEMM output by the same factor.
The GEMM is performed in FP8 with FP32 accumulation, and the scale factors are applied to the FP32 result. The critical question is: how do you choose the input and weight scale factors?
Per-Tensor Scaling
The simplest approach: compute the max absolute value of the tensor, and set the scale factor to map that value to the FP8 max:
import torch
def compute_fp8_scale(tensor, fp8_max=448.0):
"""Compute per-tensor scale to map tensor into FP8 E4M3 range."""
amax = tensor.abs().max().item()
if amax == 0:
return 1.0
return fp8_max / amax
def quantize_to_fp8_e4m3(tensor, scale):
"""Quantize FP32/BF16 tensor to simulated FP8 E4M3.
In practice, this is a hardware cast instruction (CUDA __nv_fp8_e4m3).
Here we simulate it with clamping and reduced precision.
"""
scaled = tensor * scale
# Clamp to E4M3 range
clamped = scaled.clamp(-448.0, 448.0)
    # Simulate E4M3 precision loss. A faithful simulation would also round
    # to the nearest representable value (3 mantissa bits = 8 steps per
    # power-of-2 interval); here we only clamp, which captures the range
    # behavior but not the rounding error.
    return clamped  # On real hardware, this becomes an actual FP8 cast
def dequantize_from_fp8(fp8_tensor, scale):
"""Dequantize FP8 back to FP32."""
return fp8_tensor / scale
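A quick end-to-end check of this recipe (the scale helper is restated inline so the snippet runs standalone; the values are made up):

```python
def compute_fp8_scale(amax, fp8_max=448.0):
    """Map the observed amax onto the FP8 maximum."""
    return 1.0 if amax == 0 else fp8_max / amax

values = [0.003, -0.7, 12.5, -300.0]         # toy "tensor"
amax = max(abs(v) for v in values)
scale = compute_fp8_scale(amax)
scaled = [max(-448.0, min(448.0, v * scale)) for v in values]
# The largest-magnitude element maps onto the FP8 max (to float rounding) ...
print(max(abs(v) for v in scaled))
# ... and dividing by the same scale recovers the originals, up to the
# mantissa rounding that a real FP8 cast would add on top
recovered = [v / scale for v in scaled]
print(recovered)
```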
The Problem with Just-In-Time Scaling
Computing tensor.abs().max() requires reading the entire tensor from memory. For a large activation tensor, this is a separate kernel launch that reads all the data, computes the max, and then the quantization kernel reads the data again. Two full memory reads where there should be one.
This overhead is why delayed scaling exists.
Delayed Scaling: The Key Algorithm
Delayed scaling uses the max absolute value from a previous iteration to compute the scale factor for the current iteration. The insight: tensor distributions change slowly during training. The max value at step t-1 is a good approximation of the max value at step t.
NVIDIA Transformer Engine maintains an amax history buffer for each FP8 tensor. The algorithm:
- At step t, use the scale computed from step t-1 (or earlier) to cast tensors to FP8
- During the FP8 GEMM, record the actual max absolute value of the current tensors (piggyback on the GEMM kernel)
- Update the amax history buffer
- Compute the scale for step t+1 from the history
class DelayedScaling:
"""Delayed scaling algorithm for FP8 tensors.
Maintains a history of amax values and uses them to compute
scale factors one step behind.
"""
def __init__(self, history_len=1024, fp8_max=448.0, margin=0):
"""
history_len: number of past amax values to keep
fp8_max: maximum representable FP8 value (448 for E4M3)
margin: safety margin in powers of 2 (2^margin headroom)
"""
self.history_len = history_len
self.fp8_max = fp8_max
self.margin = margin
# Circular buffer of amax values
self.amax_history = torch.zeros(history_len)
self.history_idx = 0
self.scale = 1.0
def update(self, current_amax):
"""Record the amax from the current step and update scale."""
# Store current amax in history
self.amax_history[self.history_idx % self.history_len] = current_amax
self.history_idx += 1
# Compute scale from history
# Use the max of recent history for safety
valid_len = min(self.history_idx, self.history_len)
amax_from_history = self.amax_history[:valid_len].max().item()
if amax_from_history == 0:
self.scale = 1.0
else:
self.scale = (self.fp8_max / amax_from_history) / (2 ** self.margin)
return self.scale
def get_scale(self):
"""Get the current scale factor (computed from previous step)."""
return self.scale
class FP8TensorManager:
"""Manages delayed scaling for all FP8 tensors in a layer."""
def __init__(self, fp8_max_fwd=448.0, fp8_max_bwd=57344.0):
self.input_scaling = DelayedScaling(fp8_max=fp8_max_fwd)
self.weight_scaling = DelayedScaling(fp8_max=fp8_max_fwd)
self.grad_output_scaling = DelayedScaling(fp8_max=fp8_max_bwd)
def get_forward_scales(self):
"""Get scale factors for forward pass (E4M3)."""
return self.input_scaling.get_scale(), self.weight_scaling.get_scale()
def get_backward_scale(self):
"""Get scale factor for backward pass (E5M2)."""
return self.grad_output_scaling.get_scale()
def update_forward(self, input_amax, weight_amax):
"""Update forward scales after computing the GEMM."""
self.input_scaling.update(input_amax)
self.weight_scaling.update(weight_amax)
def update_backward(self, grad_amax):
"""Update backward scale after computing the backward GEMM."""
self.grad_output_scaling.update(grad_amax)
Delayed scaling uses stale scale factors. If the tensor distribution changes abruptly (e.g., a sudden gradient spike), the scale factor from the previous step may be too small, causing overflow. The margin parameter provides headroom: a margin of 1 means the scale leaves a 2x safety margin. Transformer Engine uses a default margin of 0 and relies on the amax history taking the max over recent steps to handle transients.
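The interaction between staleness and margin is easy to see in a toy trace (a condensed stand-in for the DelayedScaling class above; the amax values are made up):

```python
from collections import deque

def count_overflows(margin, trace, fp8_max=448.0, hist_len=16):
    """Run delayed scaling over an amax trace; count steps where the
    stale scale would push the current tensor past the FP8 max."""
    history, scale = deque(maxlen=hist_len), 1.0
    overflows = 0
    for amax in trace:
        if amax * scale > fp8_max:   # cast with stale scale would clip
            overflows += 1
        history.append(amax)         # record the fresh amax ...
        scale = fp8_max / max(history) / (2 ** margin)  # ... update scale
    return overflows

trace = [2.0, 2.1, 2.0, 3.9, 2.0]   # step 3 spikes ~1.9x over recent max
print(count_overflows(margin=0, trace=trace))  # even mild growth overflows
print(count_overflows(margin=1, trace=trace))  # 2x headroom absorbs it
```

With margin=0, the scale maps the previous max exactly onto 448, so any step whose amax exceeds the history max clips; margin=1 tolerates up to a 2x jump.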
Complete FP8 Linear Layer
Here is a full FP8 linear layer implementation with delayed scaling, suitable for both training and inference:
class FP8Linear(torch.nn.Module):
"""Linear layer with FP8 compute and delayed scaling.
Forward: Y = (X_e4m3 @ W_e4m3^T) * (sx * sw) -- FP8 GEMM, FP32 accum
Backward: dX = (dY_e5m2 @ W_e4m3) * (sdy * sw) -- FP8 GEMM
dW = (dY_e5m2^T @ X_e4m3) * (sdy * sx) -- FP8 GEMM
"""
def __init__(self, in_features, out_features, bias=True):
super().__init__()
self.in_features = in_features
self.out_features = out_features
# Master weights in FP32
self.weight = torch.nn.Parameter(
torch.randn(out_features, in_features) * (2 / in_features) ** 0.5
)
if bias:
self.bias = torch.nn.Parameter(torch.zeros(out_features))
else:
self.bias = None
# FP8 scaling managers
self.fp8_manager = FP8TensorManager()
# Cached FP8 weight (recomputed when scale changes)
self._cached_weight_fp8 = None
self._cached_weight_scale = None
def cast_to_fp8_e4m3(self, tensor, scale):
"""Cast tensor to FP8 E4M3 (simulated)."""
scaled = tensor.float() * scale
clamped = scaled.clamp(-448.0, 448.0)
# Record amax for next step's scale computation
amax = tensor.abs().max().item()
return clamped, amax
def cast_to_fp8_e5m2(self, tensor, scale):
"""Cast tensor to FP8 E5M2 (simulated)."""
scaled = tensor.float() * scale
clamped = scaled.clamp(-57344.0, 57344.0)
amax = tensor.abs().max().item()
return clamped, amax
def forward(self, x):
"""FP8 forward pass."""
# Get delayed scales
s_input, s_weight = self.fp8_manager.get_forward_scales()
# Cast input to E4M3
x_fp8, x_amax = self.cast_to_fp8_e4m3(x, s_input)
# Cast weight to E4M3 (cache if scale unchanged)
if (self._cached_weight_scale is None or
self._cached_weight_scale != s_weight):
self._cached_weight_fp8, w_amax = self.cast_to_fp8_e4m3(
self.weight.data, s_weight
)
self._cached_weight_scale = s_weight
else:
w_amax = self.weight.data.abs().max().item()
# FP8 GEMM with FP32 accumulation
# On real hardware: cublasFp8Gemm
y_fp32 = torch.matmul(x_fp8, self._cached_weight_fp8.T)
# Dequantize: divide by both scales
y_fp32 = y_fp32 / (s_input * s_weight)
# Update delayed scaling with current amax
self.fp8_manager.update_forward(x_amax, w_amax)
if self.bias is not None:
y_fp32 = y_fp32 + self.bias
return y_fp32
def fp8_training_step(model, optimizer, data, target, loss_fn):
"""Single training step with FP8 linear layers."""
optimizer.zero_grad()
# Forward pass: FP8 GEMMs with E4M3
output = model(data)
loss = loss_fn(output, target)
# Backward pass: FP8 GEMMs with E5M2
# (In practice, the autograd backward through FP8Linear
# uses E5M2 for gradient tensors)
loss.backward()
# Optimizer step in FP32 (master weights)
optimizer.step()
return loss.item()
Transformer Engine Integration
NVIDIA Transformer Engine wraps the FP8 complexity behind a drop-in API. You replace torch.nn.Linear with te.Linear and the framework handles all scaling, casting, and history management automatically.
# Standard PyTorch
import torch.nn as nn
layer = nn.Linear(4096, 4096)
# Transformer Engine FP8 equivalent
# import transformer_engine.pytorch as te
# layer = te.Linear(4096, 4096)
# The te.Linear layer:
# Maintains delayed scaling state for input, weight, and gradient
# Casts input and weight to E4M3 on forward
# Runs the GEMM on FP8 tensor cores
# Accumulates in FP32
# Casts gradient to E5M2 on backward
# Updates amax history after each step
The integration into a training loop requires wrapping the forward pass in an FP8 context manager:
def train_with_transformer_engine(model, optimizer, dataloader, num_steps):
"""Training loop with Transformer Engine FP8.
Requires: NVIDIA H100 or later, Transformer Engine installed.
"""
# import transformer_engine.pytorch as te
model.train()
for step, (data, target) in enumerate(dataloader):
if step >= num_steps:
break
optimizer.zero_grad()
# The fp8_autocast context manages all FP8 state
# with te.fp8_autocast(enabled=True):
# output = model(data)
# loss = loss_fn(output, target)
# Simulated version (without real TE):
output = model(data)
loss = torch.nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
if step % 100 == 0:
print(f"Step {step}, loss={loss.item():.4f}")
Transformer Engine FP8 vs BF16 Training Performance
| Model | GPU | BF16 (tokens/sec) | FP8 TE (tokens/sec) | Speedup |
|---|---|---|---|---|
| GPT-3 175B | 256x H100 | ~12,000 | ~20,400 | 1.70x |
| Llama 2 70B | 64x H100 | ~8,500 | ~14,000 | 1.65x |
| Llama 7B | 8x H100 | ~45,000 | ~76,000 | 1.69x |
| Mistral 7B | 8x H100 | ~48,000 | ~80,000 | 1.67x |
DeepSeek V3: FP8 Training at 671B Scale
DeepSeek V3 is the most significant real-world validation of FP8 training. They trained a 671B MoE model (37B active parameters per token) using FP8 on 2048 H800 GPUs for approximately 14.8 trillion tokens. Key technical decisions:
FP8 for the core GEMMs in both forward and backward passes. Attention QKV projections, output projections, and MoE expert layers run their GEMMs (Fprop, Dgrad, and Wgrad) in FP8.
E4M3 everywhere instead of the E4M3/E5M2 split. Rather than switching gradients to E5M2, DeepSeek kept E4M3 for all FP8 tensors and relied on fine-grained scaling to cover the dynamic range. Numerically sensitive components (embedding, output head, MoE gating, normalization, attention operators, and optimizer state) stayed in BF16 or FP32.
Fine-grained quantization. Instead of per-tensor scaling, DeepSeek used per-group FP8 quantization with a group size of 128. Each group of 128 elements has its own FP8 scale factor, similar to the per-group approach used in INT4 weight quantization. This provided better quality than per-tensor scaling at the cost of slightly more metadata.
Online scaling instead of delayed scaling. DeepSeek computed the actual amax of each block and used it immediately, avoiding the staleness risk of delayed scaling. Their custom kernels fused the amax computation with the quantization, avoiding the separate memory read that makes just-in-time scaling expensive.
def deepseek_fp8_gemm(x, weight, group_size=128):
"""Simulate DeepSeek V3's fine-grained FP8 GEMM approach.
Key differences from standard Transformer Engine:
1. Per-group scaling (group_size=128) instead of per-tensor
2. Online scaling (compute amax and quantize in one pass)
    3. E4M3 for all FP8 tensors; sensitive ops stay in BF16/FP32
"""
batch_tokens, hidden = x.shape
out_features = weight.shape[0]
# Per-group quantize input
x_groups = x.reshape(batch_tokens, -1, group_size)
x_amax = x_groups.abs().amax(dim=2, keepdim=True)
x_scale = 448.0 / x_amax.clamp(min=1e-12)
x_fp8 = (x_groups * x_scale).clamp(-448, 448)
x_fp8 = x_fp8.reshape(batch_tokens, hidden)
# Per-group quantize weight
w_groups = weight.reshape(out_features, -1, group_size)
w_amax = w_groups.abs().amax(dim=2, keepdim=True)
w_scale = 448.0 / w_amax.clamp(min=1e-12)
w_fp8 = (w_groups * w_scale).clamp(-448, 448)
w_fp8 = w_fp8.reshape(out_features, hidden)
    # FP8 GEMM (simulated). With per-group scales along the reduction
    # dimension, the output of a single full GEMM cannot be rescaled
    # afterward: each group of 128 partial products carries its own scale.
    # Real kernels accumulate group-wise partial products and apply the
    # scales during accumulation. We simulate the same numerics by
    # dequantizing per group first, then running one float GEMM.
    x_deq = (x_fp8.reshape(batch_tokens, -1, group_size) / x_scale)
    x_deq = x_deq.reshape(batch_tokens, hidden)
    w_deq = (w_fp8.reshape(out_features, -1, group_size) / w_scale)
    w_deq = w_deq.reshape(out_features, hidden)
    y = x_deq @ w_deq.T
    return y
DeepSeek V3 trained for 2.788 million H800 GPU hours at a total cost of approximately $5.6M (at roughly $2/GPU-hour). FP8 training reduced the cost by an estimated 40% compared to BF16 training for the same model and data scale. Without FP8, the training would have required either more GPUs or more time.
FP8 Inference: Simpler Than Training
FP8 inference is simpler than FP8 training because:
- No backward pass (no E5M2 needed)
- Weights are static (quantize once, use forever)
- Scale factors for weights can be computed offline with calibration data
The inference recipe:
class FP8InferenceLinear:
"""FP8 linear layer optimized for inference.
Weights are pre-quantized to E4M3 with offline calibration.
Activations are dynamically quantized per-tensor or per-token.
"""
def __init__(self, weight_fp8, weight_scale, activation_scale=None):
"""
weight_fp8: (out_features, in_features) pre-quantized E4M3
weight_scale: scalar or per-channel scale for weight
activation_scale: optional static scale (if None, uses dynamic)
"""
self.weight_fp8 = weight_fp8
self.weight_scale = weight_scale
self.static_act_scale = activation_scale
@classmethod
def from_float(cls, linear, calibration_data=None):
"""Quantize a float linear layer for FP8 inference."""
weight = linear.weight.data.float()
# Quantize weight to E4M3
w_amax = weight.abs().max()
w_scale = 448.0 / w_amax.item()
weight_fp8 = (weight * w_scale).clamp(-448, 448)
# Optional: compute static activation scale from calibration
act_scale = None
if calibration_data is not None:
max_act = 0.0
for x in calibration_data:
if x.dim() == 3:
x = x.reshape(-1, x.shape[-1])
max_act = max(max_act, x.abs().max().item())
act_scale = 448.0 / max_act
return cls(weight_fp8, w_scale, act_scale)
def forward(self, x):
"""FP8 inference forward pass."""
if x.dim() == 3:
batch, seq, hidden = x.shape
x = x.reshape(-1, hidden)
reshape_back = True
else:
reshape_back = False
# Quantize activation
if self.static_act_scale is not None:
act_scale = self.static_act_scale
else:
# Dynamic per-tensor scaling
act_scale = 448.0 / x.abs().max().clamp(min=1e-12).item()
x_fp8 = (x * act_scale).clamp(-448, 448)
# FP8 GEMM
y = x_fp8 @ self.weight_fp8.T
# Dequantize
y = y / (act_scale * self.weight_scale)
if reshape_back:
y = y.reshape(batch, seq, -1)
return y
Inference Throughput by Precision, tokens/sec (Llama 70B, H100, Batch=32)
FP8 vs INT8: Why FP8 Is Better for Activations
FP8 has a fundamental advantage over INT8 for activation quantization: non-uniform quantization levels. INT8 has uniform spacing, 256 values evenly distributed across the range. FP8 E4M3 has roughly 240 values with logarithmic spacing, denser near zero and sparser at large values.
Neural network activations typically follow a distribution with most values near zero and a long tail. FP8โs logarithmic spacing matches this distribution naturally, providing more resolution where the data density is highest.
def compare_fp8_vs_int8_coverage(data):
"""Compare how well FP8 and INT8 cover a realistic activation distribution."""
import numpy as np
data_np = data.numpy().flatten()
data_abs = np.abs(data_np)
amax = data_abs.max()
# INT8: 256 uniform levels
int8_scale = amax / 127.0
int8_levels = np.arange(-128, 128) * int8_scale
int8_q = np.round(data_np / int8_scale).clip(-128, 127) * int8_scale
int8_mse = np.mean((data_np - int8_q) ** 2)
# FP8 E4M3: ~240 non-uniform levels
# Generate all positive E4M3 values
e4m3_values = []
for exp in range(16):
for mant in range(8):
if exp == 15 and mant == 7:
continue # NaN
if exp == 0:
val = (mant / 8.0) * (2 ** -6)
else:
val = (1.0 + mant / 8.0) * (2 ** (exp - 7))
e4m3_values.append(val)
e4m3_values = np.array(sorted(set(e4m3_values)))
fp8_scale = 448.0 / amax
scaled = data_np * fp8_scale
# Quantize to nearest E4M3 value
fp8_q = np.zeros_like(scaled)
for i, val in enumerate(scaled):
sign = 1 if val >= 0 else -1
abs_val = abs(val)
idx = np.argmin(np.abs(e4m3_values - abs_val))
fp8_q[i] = sign * e4m3_values[idx]
fp8_recon = fp8_q / fp8_scale
fp8_mse = np.mean((data_np - fp8_recon) ** 2)
print(f"INT8 MSE: {int8_mse:.8f}")
print(f"FP8 MSE: {fp8_mse:.8f}")
print(f"FP8 advantage: {int8_mse / fp8_mse:.2f}x lower error")
return int8_mse, fp8_mse
# Test with realistic activation distribution
activation = torch.randn(10000) * 0.5
activation[torch.randperm(10000)[:100]] *= 20 # Add outliers
compare_fp8_vs_int8_coverage(activation)
FP8 vs INT8 Activation Quantization Quality
| Scenario | INT8 MSE | FP8 E4M3 MSE | FP8 Advantage |
|---|---|---|---|
| Gaussian (no outliers) | 1.2e-5 | 1.8e-5 | 0.67x (INT8 wins) |
| Gaussian + 1% outliers | 8.4e-4 | 3.1e-4 | 2.7x |
| Gaussian + 5% outliers | 4.2e-3 | 9.8e-4 | 4.3x |
| Real LLM activations | 2.1e-3 | 6.2e-4 | 3.4x |
FP8 Training Stability Considerations
FP8 training can be unstable if not managed carefully. Key failure modes and mitigations:
Loss spikes: Caused by sudden distribution shifts (e.g., encountering a batch with unusual data). Delayed scaling uses a stale scale factor that is too small, causing overflow. Mitigation: increase the amax history length or add a safety margin.
Gradient underflow: Small gradients in E5M2 can underflow to zero, causing dead parameters. Mitigation: monitor the fraction of zero gradients and increase the gradient scale if it exceeds a threshold.
Accumulation error: FP8 GEMMs accumulate in FP32, but the input values are already quantized. For very large matrix dimensions (e.g., 16K hidden size), the accumulation of many small quantization errors can be significant. Mitigation: use block-wise FP8 with smaller blocks.
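The accumulation-error mode can be demonstrated directly: with a fixed quantization step, the error of a dot product grows with the reduction dimension. The sketch below uses a crude uniform quantizer as a stand-in for an FP8 cast:

```python
import numpy as np

rng = np.random.default_rng(0)

def mean_dot_error(k, step=0.05, trials=50):
    """Average |dot(quantized) - dot(exact)| over random length-k vectors."""
    errs = []
    for _ in range(trials):
        a = rng.standard_normal(k)
        b = rng.standard_normal(k)
        aq = np.round(a / step) * step   # uniform grid stand-in for FP8
        bq = np.round(b / step) * step
        errs.append(abs(aq @ bq - a @ b))
    return float(np.mean(errs))

for k in (1024, 4096, 16384):
    print(f"k={k:6d}  mean |dot error| = {mean_dot_error(k):.4f}")
```

The per-element errors behave like a random walk, so the expected dot-product error grows roughly with the square root of the reduction dimension; block-wise FP8 keeps each accumulation short.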
class FP8TrainingMonitor:
"""Monitor FP8 training health metrics."""
def __init__(self):
self.overflow_count = 0
self.underflow_count = 0
self.total_count = 0
def check_tensor(self, tensor, name, fp8_max=448.0):
"""Check a tensor for FP8 scaling issues."""
self.total_count += 1
amax = tensor.abs().max().item()
if amax > fp8_max:
self.overflow_count += 1
print(f"WARNING: {name} amax={amax:.2f} exceeds FP8 max={fp8_max}")
zero_frac = (tensor == 0).float().mean().item()
if zero_frac > 0.5:
self.underflow_count += 1
print(f"WARNING: {name} has {zero_frac:.1%} zeros (potential underflow)")
return {
'name': name,
'amax': amax,
'zero_fraction': zero_frac,
'overflow_risk': amax > 0.9 * fp8_max,
}
def report(self):
"""Print summary of FP8 health metrics."""
print(f"FP8 Health: {self.overflow_count} overflows, "
f"{self.underflow_count} underflows out of "
f"{self.total_count} tensors checked")
Summary
FP8 delivers roughly 2x throughput improvement over BF16/FP16 for both training and inference, with minimal quality degradation when properly managed.
E4M3 provides precision-optimized 8-bit representation for forward pass tensors (weights and activations). Its maximum value of 448 requires scaling, but the 3 mantissa bits preserve adequate precision.
E5M2 provides range-optimized 8-bit representation for backward pass gradients. Its maximum of 57344 and 5 exponent bits handle the extreme dynamic range of gradients without loss scaling.
Delayed scaling is the algorithm that makes FP8 practical: it amortizes the cost of computing scale factors by using amax values from previous steps, updated as a side-effect of the GEMM kernel.
NVIDIA Transformer Engine wraps all of this complexity behind a simple API, managing scale factors, amax history, and format selection automatically.
DeepSeek V3 proved FP8 training works at 671B scale, running its core forward and backward GEMMs in FP8 E4M3 with fine-grained scaling and reducing training cost by an estimated 40%.
The next post covers the frontier beyond FP8: 4-bit floating-point formats on NVIDIA Blackwell that promise another 2x throughput improvement.