Training quantization is fundamentally different from inference quantization. In inference, we quantize a frozen model to reduce memory and increase throughput — any quality loss is permanent. In training, we quantize the forward and backward passes to speed up gradient computation, but maintain high-precision master weights that accumulate gradients. The quantization error in each step is ephemeral: it affects the gradient estimate but not the converged model.
This distinction means training can tolerate higher quantization error per step (because gradient descent is inherently noisy), but it requires careful management of numerical range (because gradients can be extremely small or large). The history of low-precision training is a progression from FP32 to FP16 to BF16 to FP8, with each step requiring new techniques to handle the precision/range tradeoff.
Mixed-Precision Training: The FP16 Recipe
Mixed-precision training (Micikevicius et al., 2018) established the fundamental pattern:
- Master weights: FP32 (full precision, stored in optimizer)
- Forward pass: FP16 (weights cast down, activations computed in FP16)
- Backward pass: FP16 (gradients computed in FP16)
- Weight update: FP32 (gradients cast up, applied to FP32 master weights)
import numpy as np
import torch
import torch.nn as nn
class FP16MixedPrecisionTrainer:
    """Simplified FP16 mixed-precision training loop.

    Keeps FP32 master weights, runs forward/backward in FP16 with
    dynamic loss scaling, and applies a plain SGD update to the master
    copies using the optimizer's configured learning rate.
    """

    def __init__(self, model, optimizer, loss_scale_init=65536.0,
                 growth_interval=2000):
        self.model = model
        self.optimizer = optimizer
        self.loss_scale = loss_scale_init
        self.max_loss_scale = loss_scale_init
        # Grow the scale only after this many consecutive overflow-free
        # steps. Doubling after every step (the previous behavior) makes
        # the scale ping-pong at the cap: halve on overflow, double right
        # back, overflow again — skipping roughly every other step.
        self.growth_interval = growth_interval
        self._good_steps = 0
        # FP32 master copies, keyed by parameter name.
        self.master_weights = {}
        for name, param in model.named_parameters():
            self.master_weights[name] = param.data.float().clone()

    def train_step(self, input_data, targets):
        """One training step with FP16 forward/backward + FP32 update.

        Returns:
            (loss_value, skipped): the unscaled loss and whether the
            step was skipped because of a gradient overflow.
        """
        # Step 1: cast master weights down to FP16 for the forward pass.
        for name, param in self.model.named_parameters():
            param.data = self.master_weights[name].half()
        # Step 2: forward pass in FP16; loss in FP32 for stable softmax.
        output = self.model(input_data.half())
        loss = nn.functional.cross_entropy(output.float(), targets)
        # Step 3: scale loss so small gradients stay in FP16's range.
        scaled_loss = loss * self.loss_scale
        # Step 4: backward pass in FP16.
        self.optimizer.zero_grad()
        scaled_loss.backward()
        # Step 5: unscale gradients and detect inf/nan overflow.
        overflow = False
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                param.grad.data /= self.loss_scale
                if not torch.isfinite(param.grad).all():
                    overflow = True
                    break
        if overflow:
            # Skip this step and back off the loss scale.
            self.loss_scale /= 2
            self._good_steps = 0
            return loss.item(), True  # Skipped
        # Step 6: SGD update applied to the FP32 master weights.
        lr = self.optimizer.defaults['lr']
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                self.master_weights[name] -= lr * param.grad.data.float()
        # Increase the loss scale only after a run of clean steps.
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.loss_scale = min(self.loss_scale * 2, self.max_loss_scale)
            self._good_steps = 0
        return loss.item(), False  # Not skipped
Why Loss Scaling is Necessary
FP16 has a range of roughly 5.96e-8 (smallest subnormal) to 65,504. Gradients in deep networks often fall below 2^(-24) ≈ 6e-8, becoming zero in FP16 (underflow). Loss scaling multiplies the loss by a large factor (e.g., 2^16 = 65,536) before backward, which scales all gradients up proportionally, keeping them in FP16’s representable range.
def demonstrate_gradient_underflow(hidden_dim=4096, depth=32):
    """Show that gradients underflow in FP16 without loss scaling.

    Models each layer as attenuating the backward signal by a factor
    of sqrt(2 / hidden_dim) and compares the resulting magnitude with
    FP16's smallest positive subnormal, 2^(-24) ~ 5.96e-8.

    Args:
        hidden_dim: model width used for the per-layer attenuation.
        depth: number of layers the gradient flows through.

    Returns:
        (grad_scale, required_scale): the simulated gradient magnitude
        and the minimum loss scale needed to avoid underflow, or
        (grad_scale, None) when no underflow is expected.
    """
    # Simulate gradient magnitudes through a deep network:
    # each layer multiplies the gradient by ~sqrt(2 / hidden_dim).
    grad_scale = 1.0
    for _ in range(depth):
        grad_scale *= (2 / hidden_dim) ** 0.5
    print(f"Expected gradient scale after {depth} layers: {grad_scale:.2e}")
    # FP16 smallest representable positive value (subnormal): 2^(-24).
    # (The original printed np.float16(0).itemsize here — the byte size
    # of the type, 2 — mislabeled as the minimum subnormal.)
    fp16_min = 2 ** (-24)
    print(f"FP16 min subnormal: {fp16_min:.2e}")
    if grad_scale < fp16_min:
        print(f"UNDERFLOW: gradient ({grad_scale:.2e}) < FP16 min ({fp16_min:.2e})")
        required_scale = fp16_min / grad_scale
        print(f"Required loss scale: >= {required_scale:.0f}")
        return grad_scale, required_scale
    print("No underflow risk")
    return grad_scale, None
BF16: Why It Displaced FP16
BF16 (Brain Float 16) has 8 exponent bits and 7 mantissa bits, compared to FP16’s 5 exponent bits and 10 mantissa bits:
| Format | Sign | Exponent | Mantissa | Range | Precision |
|---|---|---|---|---|---|
| FP16 | 1 | 5 | 10 | ±65,504 | ~3.3 decimal digits |
| BF16 | 1 | 8 | 7 | ~±3.4 × 10^38 | ~2.4 decimal digits |
BF16 has the same range as FP32 (both have 8 exponent bits), which eliminates the need for loss scaling entirely. The reduced precision (7 vs 10 mantissa bits) causes slightly more rounding error per operation, but this is compensated by the elimination of gradient underflow and overflow.
def compare_fp16_bf16_training():
    """Compare FP16 and BF16 training characteristics."""
    # Normal-range limits for each 16-bit format.
    fp16_max, fp16_min_normal = 65504, 2 ** (-14)      # ~6.1e-5 floor
    bf16_max, bf16_min_normal = 3.389e38, 2 ** (-126)  # range ~ FP32
    print("FP16:")
    print(f"  Max: {fp16_max:.0f}")
    print(f"  Min normal: {fp16_min_normal:.2e}")
    print("  Needs loss scaling: Yes")
    print("\nBF16:")
    print(f"  Max: {bf16_max:.2e}")
    print(f"  Min normal: {bf16_min_normal:.2e}")
    print("  Needs loss scaling: No")
    # Unit roundoff is set by mantissa width: 10 bits vs 7 bits.
    fp16_eps = 2 ** (-10)
    bf16_eps = 2 ** (-7)
    print("\nRelative precision:")
    print(f"  FP16: {fp16_eps:.2e} (~3.3 decimal digits)")
    print(f"  BF16: {bf16_eps:.2e} (~2.4 decimal digits)")
    print(f"  BF16 rounding error is {bf16_eps / fp16_eps:.0f}x "
          f"larger per operation")
BF16 training is simpler than FP16 training because the FP32-equivalent range eliminates gradient underflow. The training recipe reduces to: (1) master weights in FP32, (2) forward and backward in BF16, (3) weight update in FP32. No loss scaling, no overflow detection, no skipped steps. This simplicity is why BF16 became the default for LLM training starting with T5 (2019).
FP8 Training on Hopper
FP8 training uses two FP8 formats:
- E4M3 (4 exponent, 3 mantissa): for forward pass weights and activations. Higher precision, narrower range.
- E5M2 (5 exponent, 2 mantissa): for backward pass gradients. Wider range, lower precision. Gradients need range more than precision.
def fp8_format_comparison():
    """Compare E4M3 and E5M2 FP8 formats."""
    # (exponent_bits, mantissa_bits, bias, max_value, min_normal,
    #  precision, use) per format, expanded into dicts below.
    field_names = ('exponent_bits', 'mantissa_bits', 'bias', 'max_value',
                   'min_normal', 'precision', 'use')
    specs = {
        'E4M3': (4, 3, 7, 448, 2 ** (-6), 2 ** (-3),
                 'Forward pass (weights, activations)'),
        'E5M2': (5, 2, 15, 57344, 2 ** (-14), 2 ** (-2),
                 'Backward pass (gradients)'),
    }
    formats = {name: dict(zip(field_names, vals))
               for name, vals in specs.items()}
    for name, fmt in formats.items():
        print(f"\n{name}:")
        print(f"  Range: [{fmt['min_normal']:.2e}, {fmt['max_value']}]")
        print(f"  Precision: {fmt['precision']} relative")
        print(f"  Dynamic range: {fmt['max_value'] / fmt['min_normal']:.0f}x")
        print(f"  Use: {fmt['use']}")
    return formats
# E4M3 for forward: precision matters (weight values affect output directly)
# E5M2 for backward: range matters (gradients can be tiny or huge)
The FP8 Training Pipeline
class FP8TrainingStep:
    """One training step with FP8 GEMM on Hopper (simulated).

    Forward tensors use E4M3 (precision-oriented), backward gradients
    use E5M2 (range-oriented). Per-tensor scales map each tensor's
    absolute max onto the format's representable maximum, and GEMMs
    accumulate in FP32 as Hopper tensor cores do.
    """

    # Representable maxima for the two FP8 formats.
    _FP8_MAX = {'e4m3': 448.0, 'e5m2': 57344.0}

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        # Per-tensor scale factors (maintained across steps).
        self.weight_scales = {}
        self.activation_scales = {}
        self.gradient_scales = {}

    def compute_scale(self, tensor, format_type='e4m3'):
        """Compute a dynamic per-tensor scale for FP8 quantization.

        The scale maps the tensor's absolute max to FP8's max value.

        Raises:
            ValueError: for an unknown format. (The original left
            fp8_max unbound and crashed with NameError instead.)
        """
        try:
            fp8_max = self._FP8_MAX[format_type]
        except KeyError:
            raise ValueError(f"Unknown FP8 format: {format_type!r}") from None
        abs_max = tensor.abs().max()
        # Clamp avoids division by zero for all-zero tensors.
        return fp8_max / abs_max.clamp(min=1e-12)

    def forward_fp8(self, x, weight, layer_name):
        """FP8 forward pass for a single linear layer.

        1. Quantize weight and activation to FP8 E4M3.
        2. Run the GEMM (FP32 accumulation on real hardware).
        3. Remember the scales for the backward pass.
        """
        w_scale = self.compute_scale(weight, 'e4m3')
        x_scale = self.compute_scale(x, 'e4m3')
        # Quantize to FP8 (simulated -- real HW uses native FP8).
        w_fp8 = self.simulate_fp8_quantize(weight * w_scale, 'e4m3') / w_scale
        x_fp8 = self.simulate_fp8_quantize(x * x_scale, 'e4m3') / x_scale
        # On Hopper: fp8_e4m3 x fp8_e4m3 -> fp32 (tensor cores).
        output = x_fp8 @ w_fp8.T  # Simulated as FP32
        # Store scales for backward.
        self.weight_scales[layer_name] = w_scale
        self.activation_scales[layer_name] = x_scale
        return output

    def backward_fp8(self, grad_output, x, weight, layer_name):
        """FP8 backward pass.

        Gradient w.r.t. input:  dX = dY @ W   (E5M2 gradients)
        Gradient w.r.t. weight: dW = dY^T @ X (E5M2 gradients)
        """
        # Quantize the incoming gradient to FP8 E5M2 (range over precision).
        g_scale = self.compute_scale(grad_output, 'e5m2')
        grad_fp8 = self.simulate_fp8_quantize(
            grad_output * g_scale, 'e5m2'
        ) / g_scale
        # dX = grad @ W (FP8 GEMM), reusing the forward weight scale.
        w_scale = self.weight_scales[layer_name]
        w_fp8 = self.simulate_fp8_quantize(weight * w_scale, 'e4m3') / w_scale
        dX = grad_fp8 @ w_fp8  # E5M2 x E4M3 -> FP32
        # dW = grad^T @ X (FP8 GEMM), reusing the forward activation scale.
        x_scale = self.activation_scales[layer_name]
        x_fp8 = self.simulate_fp8_quantize(x * x_scale, 'e4m3') / x_scale
        dW = grad_fp8.T @ x_fp8  # E5M2 x E4M3 -> FP32
        return dX, dW

    @staticmethod
    def simulate_fp8_quantize(tensor, format_type):
        """Simulate FP8 quantization by rounding to FP8 precision.

        Uses stochastic rounding (better for training than RTN).

        Raises:
            ValueError: for an unknown format. (The original left
            mantissa_bits unbound and crashed with NameError instead.)
        """
        if format_type == 'e4m3':
            mantissa_bits = 3  # round to nearest 1/8
        elif format_type == 'e5m2':
            mantissa_bits = 2  # round to nearest 1/4
        else:
            raise ValueError(f"Unknown FP8 format: {format_type!r}")
        # Simplified simulation: round(x + u), u ~ U(-0.5, 0.5), is
        # stochastic rounding at the format's relative precision.
        precision = 2 ** (-mantissa_bits)
        noise = torch.rand_like(tensor) - 0.5
        return torch.round(tensor / precision + noise) * precision
NVIDIA Transformer Engine
NVIDIA’s Transformer Engine (TE) is the production implementation of FP8 training. It manages per-tensor scales using a delayed scaling strategy:
class TransformerEngineScaling:
    """Simplified Transformer Engine delayed scaling.

    TE maintains a history of tensor max values and uses the max from
    PREVIOUS iterations to set the scale for the CURRENT iteration.
    This avoids an extra synchronization point within each iteration.
    """

    def __init__(self, format_type='e4m3', history_length=1024, margin=0):
        self.format_type = format_type
        self.history = []
        self.history_length = history_length
        # Safety margin: scale is divided by 2**margin to leave headroom.
        self.margin = margin
        if format_type == 'e4m3':
            self.fp8_max = 448.0
        elif format_type == 'e5m2':
            self.fp8_max = 57344.0
        else:
            # Fail fast; the original silently left fp8_max unset and
            # crashed later with AttributeError in get_scale().
            raise ValueError(f"Unknown FP8 format: {format_type!r}")

    def get_scale(self):
        """Get the scale factor based on recorded history.

        Uses the max of recent history with the safety margin. Returns
        1.0 (identity) when there is no history or every recorded max
        is zero — the original divided by zero in that case.
        """
        if not self.history:
            return 1.0
        amax = max(self.history)
        if amax <= 0:
            return 1.0
        scale = self.fp8_max / (amax * (2 ** self.margin))
        return max(scale, 1e-12)

    def update_history(self, tensor_abs_max):
        """Record the current tensor's abs-max for future scaling."""
        self.history.append(tensor_abs_max)
        # Bound the window so stale magnitudes eventually age out.
        if len(self.history) > self.history_length:
            self.history.pop(0)
# Transformer Engine integration in PyTorch:
# import transformer_engine.pytorch as te
#
# # Replace nn.Linear with te.Linear
# linear = te.Linear(4096, 4096, bias=False)
#
# # The te.Linear layer automatically:
# # 1. Maintains FP8 scale history for weights and activations
# # 2. Quantizes to FP8 E4M3 in forward, E5M2 in backward
# # 3. Uses FP8 tensor cores for GEMM
# # 4. Accumulates in FP32
# # 5. Returns BF16 output
#
# # Wrap training loop in FP8 context:
# with te.fp8_autocast(enabled=True):
# output = model(input)
# loss = criterion(output, target)
# loss.backward()
Transformer Engine uses the previous iteration’s tensor max to compute the current iteration’s scale. This is safe because tensor magnitudes change slowly between iterations (learning rate is small). The alternative — computing the max within the current iteration — would require an extra device-wide synchronization, adding latency to every GEMM.
FP8 Training Results
FP8 vs BF16 Training: Quality and Throughput
| Model | BF16 Loss | FP8 Loss | Degradation | Throughput Gain |
|---|---|---|---|---|
| GPT-3 175B (reported) | 2.80 | 2.80 | +0.00 | ~1.6x |
| Llama-2 7B (reproduced) | 1.82 | 1.83 | +0.01 | ~1.4x |
| Llama-2 13B (reproduced) | 1.72 | 1.72 | +0.00 | ~1.5x |
| Llama-2 70B (reproduced) | 1.56 | 1.56 | +0.00 | ~1.7x |
| Mistral 7B (reported) | --- | --- | matched | ~1.4x |
H100 Training Throughput by Precision
(Tokens/sec, normalized to BF16 = 1.0.) FP8 tensor cores are 2x faster than BF16, but end-to-end training throughput is only 1.4-1.7x faster. The gap comes from: (1) non-GEMM operations (LayerNorm, softmax, activation functions) still run in BF16/FP32, (2) FP8 quantization/dequantization overhead, (3) scale factor management. As the model grows larger (more compute vs overhead), the throughput gain approaches 2x.
Why INT8 Training Fails
INT8 has been tried for training but does not work well:
def why_int8_training_fails():
    """Demonstrate why INT8 is unsuitable for training.

    Three fundamental problems:
    1. No subnormals: INT8 has a hard zero at 0 and jumps to +/- 1.
       Gradients near zero are quantized to exactly 0, destroying
       information. FP8's subnormal range provides gradual underflow.
    2. Uniform spacing: Gradients span many orders of magnitude within
       a single layer. INT8's uniform spacing wastes precision on the
       large range while losing small gradients. FP8's logarithmic
       spacing naturally handles this.
    3. No signed zero: INT8 symmetric has zero but cannot distinguish
       very small positive from very small negative gradients. Both
       round to 0, losing the sign information needed for SGD.

    Returns:
        (zeros_pct, fp8_nonzero_pct): percentage of gradients INT8
        quantizes to zero, and percentage FP8 E5M2 can represent.
    """
    # Fixed seed so the demo (and its printed numbers) are reproducible;
    # the original drew from the global RNG and changed on every run.
    rng = np.random.default_rng(0)
    # Simulate gradient magnitudes in a deep network, scaled to a
    # typical late-training gradient range.
    grad_magnitudes = np.abs(rng.standard_normal(10000))
    grad_magnitudes *= 1e-5
    # INT8 symmetric: one scale mapping the observed max to 127.
    int8_max = 127
    scale = np.max(grad_magnitudes) / int8_max
    int8_quant = np.round(grad_magnitudes / scale).clip(-128, 127)
    # Count gradients that collapse to exactly zero.
    zeros_pct = np.mean(int8_quant == 0) * 100
    # FP8 E5M2: logarithmic spacing preserves small values
    # (simplified simulation, no per-tensor scaling applied).
    fp8_min = 2 ** (-16)  # E5M2 smallest subnormal
    fp8_nonzero_pct = np.mean(grad_magnitudes > fp8_min) * 100
    print(f"Gradients zeroed by INT8: {zeros_pct:.1f}%")
    print(f"Gradients representable by FP8 E5M2: {fp8_nonzero_pct:.1f}%")
    # NOTE(review): without per-tensor scaling many of these tiny
    # gradients also fall below E5M2's subnormal floor; production FP8
    # recipes rely on scale factors to map the gradient distribution
    # into the format's window.
    return zeros_pct, fp8_nonzero_pct
FP4 Training on Blackwell
Blackwell’s FP4 tensor cores enable 4-bit training, achieving 4x the throughput of BF16. Early results use FP4 for the forward pass with FP8 for the backward:
class FP4TrainingConfig:
    """Configuration for FP4 training on Blackwell."""

    def __init__(self):
        # Forward pass runs in FP4 E2M1 (2 exponent bits, 1 mantissa
        # bit) with FP32 accumulation.
        self.forward_weight_format = 'fp4_e2m1'
        self.forward_activation_format = 'fp4_e2m1'
        self.forward_accumulation = 'fp32'
        # Backward pass stays in FP8: FP4 gradients lose too much.
        self.backward_gradient_format = 'fp8_e5m2'
        self.backward_weight_format = 'fp8_e4m3'
        # High-precision master weights and Adam states.
        self.master_weight_format = 'fp32'
        self.optimizer_format = 'fp32'
        # MXFP4-style block scaling: 16 elements share one scale;
        # gradients use a single per-tensor scale.
        self.weight_block_size = 16
        self.activation_block_size = 16
        self.gradient_per_tensor = True

    def memory_estimate(self, num_params_B):
        """Estimate static memory (GB) for num_params_B billion params."""
        fp32_bytes = 4
        # FP32 master weights: 4 bytes per parameter.
        master = num_params_B * fp32_bytes
        # Adam keeps two FP32 states (momentum, variance) per parameter.
        adam_states = num_params_B * fp32_bytes * 2
        # FP4 payload is 0.5 bytes/param, plus one FP16 scale per 16
        # parameters. (Activation memory is excluded: it depends on
        # batch size, sequence length, and checkpointing strategy.)
        fp4_weights = num_params_B * 0.5 + num_params_B * 2 / 16
        return {
            'master_weights_GB': master,
            'optimizer_states_GB': adam_states,
            'fp4_weights_GB': fp4_weights,
            'total_static_GB': master + adam_states + fp4_weights,
        }
# Demo: static memory footprint for a 70B-parameter FP4 training run.
config = FP4TrainingConfig()
mem = config.memory_estimate(70)  # 70B model
print(f"70B FP4 training memory estimate:")
for label, key in (
    ('Master weights', 'master_weights_GB'),
    ('Optimizer states', 'optimizer_states_GB'),
    ('FP4 weights', 'fp4_weights_GB'),
    ('Total static', 'total_static_GB'),
):
    print(f"  {label}: {mem[key]:.0f} GB")
The BF16 Default: Why It Persists
Despite the availability of FP8 and FP4, BF16 remains the default training precision for most organizations:
Training Precision Selection Guide
| Format | Throughput vs FP32 | Quality Risk | Complexity | When to Use |
|---|---|---|---|---|
| FP32 | 1.0x | None | Minimal | Debugging, small models |
| BF16 mixed | 2.0x | None | Low | Default for all training |
| FP16 mixed | 2.0x | Low (need loss scaling) | Medium | Legacy, A100 without BF16 |
| FP8 (TE) | ~3.0x | Very low | Medium | Large models on H100+ |
| FP4 (MXFP4) | ~4.0x | Under research | High | Blackwell, experimental |
def recommend_training_precision(
    gpu_type,
    model_size_B,
    risk_tolerance,
    team_experience,
):
    """Recommend a training precision for the given hardware and team.

    Returns a (format, reason) tuple. Unknown GPUs fall back to the
    safe BF16 default.
    """
    legacy_gpus = {'V100', 'T4'}
    hopper_gpus = {'H100', 'H200'}
    blackwell_gpus = {'B200', 'B100'}

    if gpu_type in legacy_gpus:
        return 'FP16 mixed', 'Only FP16 tensor cores available'

    if gpu_type == 'A100':
        reason = ('Safe default, well-validated'
                  if risk_tolerance == 'zero'
                  else 'FP8 not available on A100')
        return 'BF16 mixed', reason

    if gpu_type in hopper_gpus:
        if risk_tolerance == 'zero':
            return 'BF16 mixed', 'Proven safe, slight throughput sacrifice'
        if model_size_B >= 13 and team_experience == 'advanced':
            return 'FP8 (TE)', 'Meaningful throughput gain at this scale'
        return 'BF16 mixed', 'FP8 benefit small for models < 13B'

    if gpu_type in blackwell_gpus:
        if team_experience == 'advanced' and model_size_B >= 70:
            return 'FP4 experimental', 'Maximum throughput for large models'
        if model_size_B >= 13:
            return 'FP8 (TE)', 'Well-validated on Blackwell'
        return 'BF16 mixed', 'Default safe choice'

    return 'BF16 mixed', 'Unknown GPU, use safe default'
Stochastic Rounding for Training
Training benefits from stochastic rounding instead of round-to-nearest. With RTN, small gradients that are below the quantization step are always rounded to zero, introducing a consistent bias. Stochastic rounding randomly rounds up or down with probability proportional to the distance to the nearest level:
def stochastic_round(x, scale):
    """Stochastically round *x* to integer multiples of *scale*.

    Unlike round-to-nearest, stochastic rounding is unbiased:
    E[SR(x)] = x. On average the gradient direction is preserved even
    when individual gradients are smaller than one quantization step.
    """
    levels = x / scale
    lower = torch.floor(levels)
    # Round up with probability equal to the fractional part.
    frac = levels - lower
    went_up = (torch.rand_like(levels) < frac).float()
    return (lower + went_up) * scale
# Demonstrate: repeated applications of SR (10,000 draws) to a small
# value should average to the true value, while RTN is biased to zero.
value = torch.tensor(0.3)
scale = 1.0  # Step size
# Each draw is an independent 0-or-1 outcome; the mean converges to 0.3.
rounds = torch.tensor([stochastic_round(value, scale).item()
                       for _ in range(10000)])
print(f"True value: {value.item()}")
print(f"RTN: {torch.round(value / scale).item() * scale}")
print(f"SR mean: {rounds.mean():.4f}")
print(f"SR std: {rounds.std():.4f}")
# SR mean ~ 0.3 (unbiased), RTN = 0.0 (biased to zero)
At FP8 and especially FP4 precision, the quantization step size is large enough that many gradients fall below it. Stochastic rounding preserves the expected gradient direction, allowing convergence despite aggressive quantization. Transformer Engine uses stochastic rounding for FP8 backward pass quantization.
Quantized Optimizer States
Beyond GEMM quantization, optimizer states can also be quantized to reduce memory:
class Int8AdamW:
    """AdamW optimizer with INT8 quantized momentum and variance states.

    Standard Adam: ~12 bytes per parameter (FP32 m, v, master weight).
    INT8 Adam: ~6 bytes per parameter (INT8 m, v, FP32 master weight).

    Block-wise quantization: each block of `block_size` parameters
    shares one FP32 scale for m and one for v. Weight decay is applied
    decoupled from the gradient, as in AdamW.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01, block_size=2048):
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.block_size = block_size
        # Per-parameter state, keyed by the parameter tensor itself.
        self.states = {}
        for p in params:
            n = p.numel()
            num_blocks = (n + block_size - 1) // block_size
            self.states[p] = {
                'm_int8': torch.zeros(n, dtype=torch.int8, device=p.device),
                'm_scales': torch.zeros(num_blocks, device=p.device),
                'v_int8': torch.zeros(n, dtype=torch.int8, device=p.device),
                'v_scales': torch.zeros(num_blocks, device=p.device),
                'step': 0,
            }

    def step(self, params_with_grads):
        """Apply one AdamW update for each (param, grad) pair."""
        for p, grad in params_with_grads:
            state = self.states[p]
            state['step'] += 1
            t = state['step']
            # Dequantize INT8 states to FP32 for the update math.
            m = self._dequantize(state['m_int8'], state['m_scales'])
            v = self._dequantize(state['v_int8'], state['v_scales'])
            g = grad.flatten()
            # Standard Adam moment updates with bias correction.
            m = self.beta1 * m + (1 - self.beta1) * g
            v = self.beta2 * v + (1 - self.beta2) * g ** 2
            m_hat = m / (1 - self.beta1 ** t)
            v_hat = v / (1 - self.beta2 ** t)
            update = m_hat / (v_hat.sqrt() + self.eps)
            if self.weight_decay > 0:
                # Decoupled weight decay (AdamW-style).
                update = update + self.weight_decay * p.data.flatten()
            # BUG FIX: the original did `p.data.flatten().add_(...)`.
            # flatten() returns a COPY for non-contiguous tensors, so
            # the in-place add could be silently dropped. Add the
            # reshaped update directly into p.data instead.
            p.data.add_(update.reshape(p.data.shape), alpha=-self.lr)
            # Re-quantize the fresh FP32 moments back to INT8.
            state['m_int8'], state['m_scales'] = self._quantize(m)
            state['v_int8'], state['v_scales'] = self._quantize(v)

    def _quantize(self, tensor):
        """Block-wise symmetric INT8 quantization.

        Returns (int8_values, per_block_scales); each block's abs-max
        maps to 127. All-zero blocks get scale 1.0 to avoid dividing
        by zero.
        """
        n = tensor.numel()
        num_blocks = (n + self.block_size - 1) // self.block_size
        scales = torch.zeros(num_blocks, device=tensor.device)
        q = torch.zeros(n, dtype=torch.int8, device=tensor.device)
        for b in range(num_blocks):
            start = b * self.block_size
            end = min(start + self.block_size, n)
            block = tensor[start:end]
            amax = block.abs().max()
            s = amax / 127 if amax > 0 else 1.0
            scales[b] = s
            q[start:end] = (block / s).round().clamp(-128, 127).to(torch.int8)
        return q, scales

    def _dequantize(self, q, scales):
        """Reconstruct FP32 values from block-quantized INT8 + scales."""
        n = q.numel()
        result = torch.zeros(n, device=q.device)
        for b in range(len(scales)):
            start = b * self.block_size
            end = min(start + self.block_size, n)
            result[start:end] = q[start:end].float() * scales[b]
        return result