Training a 70B parameter model in FP32 requires 280 GB just for weights (70 billion parameters times 4 bytes). Add optimizer states (AdamW stores two additional FP32 copies: first and second moment), and you need 840 GB. Add activations for a reasonable batch size, and you exceed 1 TB. No single GPU has this much memory.
Mixed precision training solves this by using lower-precision formats (BF16 or FP16) for the bulk of computation while keeping critical operations in FP32. The result: 2x memory savings, 2-8x faster matrix multiplications, and — when done correctly — zero loss in training quality.
This post covers the exact precision hierarchy used in modern LLM training: which operations run in which precision, why, and how to implement it.
1. Number Formats
1.1 The Three Formats That Matter
import struct
import torch
def analyze_format(name, torch_dtype):
"""Show the bit layout and range of a floating-point format."""
info = torch.finfo(torch_dtype)
return {
"name": name,
"bits": info.bits,
"sign_bits": 1,
"exponent_bits": {
torch.float32: 8,
torch.float16: 5,
torch.bfloat16: 8,
}[torch_dtype],
"mantissa_bits": {
torch.float32: 23,
torch.float16: 10,
torch.bfloat16: 7,
}[torch_dtype],
"max": info.max,
"min_positive": info.tiny,
"eps": info.eps,
"decimal_digits": info.resolution,
}
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
info = analyze_format(str(dtype), dtype)
print(f"{info['name']}: {info['bits']} bits "
f"({info['exponent_bits']}e + {info['mantissa_bits']}m), "
f"range [{info['min_positive']:.2e}, {info['max']:.2e}], "
f"eps={info['eps']:.2e}")
Output:
torch.float32: 32 bits (8e + 23m), range [1.18e-38, 3.40e+38], eps=1.19e-07
torch.float16: 16 bits (5e + 10m), range [6.10e-05, 6.55e+04], eps=9.77e-04
torch.bfloat16: 16 bits (8e + 7m), range [1.18e-38, 3.39e+38], eps=7.81e-03
The critical difference: BF16 has the same exponent range as FP32 (8 exponent bits) but much lower precision (7 mantissa bits vs. 23). FP16 has higher precision than BF16 (10 mantissa bits) but a drastically smaller range (max 65,504 vs. ~3.4e38).
1.2 Why BF16 Won Over FP16
FP16’s range problem: gradient magnitudes during LLM training span many orders of magnitude, from tiny values near underflow to occasional spikes well above 65,504 — FP16’s maximum value. A single gradient magnitude above this causes overflow, producing infinity, which corrupts the entire training run.
BF16’s range matches FP32 (max ~3.4e38), so overflows that would not occur in FP32 will not occur in BF16 either. The tradeoff is precision: BF16 has only 7 mantissa bits, meaning it can represent about 2 to 3 decimal digits. For weight magnitudes (typically 0.001 to 1.0), BF16 quantization error is at most about 0.78% of the value (eps = 7.81e-3). This is acceptable for forward and backward passes, where the result will be accumulated into FP32 master weights.
def demonstrate_precision_loss():
"""Show the precision difference between BF16 and FP32."""
# Create a weight value
w_fp32 = torch.tensor(0.123456789, dtype=torch.float32)
w_bf16 = w_fp32.to(torch.bfloat16)
w_fp16 = w_fp32.to(torch.float16)
print(f"FP32: {w_fp32.item():.10f}")
print(f"BF16: {w_bf16.item():.10f}")
print(f"FP16: {w_fp16.item():.10f}")
print(f"BF16 error: {abs(w_fp32.item() - w_bf16.item()):.2e}")
print(f"FP16 error: {abs(w_fp32.item() - w_fp16.item()):.2e}")
    # FP32: 0.1234567910
    # BF16: 0.1235351562 (error: ~8e-5)
    # FP16: 0.1234741211 (error: ~2e-5)
def demonstrate_range_problem():
"""Show why FP16 overflows but BF16 does not."""
# Gradient that exceeds FP16 range
grad = torch.tensor(100000.0, dtype=torch.float32)
grad_bf16 = grad.to(torch.bfloat16)
grad_fp16 = grad.to(torch.float16)
print(f"FP32 grad: {grad.item()}")
    print(f"BF16 grad: {grad_bf16.item()}")  # 99840.0 (rounded but finite)
print(f"FP16 grad: {grad_fp16.item()}") # inf (overflow)
FP16 overflow is the primary reason BF16 replaced FP16 for LLM training. With FP16, you need loss scaling to prevent overflow. With BF16, loss scaling is unnecessary because the exponent range matches FP32. Every major training framework and hardware vendor now defaults to BF16.
2. The Precision Hierarchy
2.1 Overview
Modern LLM training uses a three-tier precision hierarchy:
| Operation | Precision | Why |
|---|---|---|
| Forward pass (matmuls) | BF16 | 2x faster on tensor cores, sufficient precision |
| Backward pass (matmuls) | BF16 | Same as forward |
| Forward pass (norms, softmax) | FP32 | Numerical stability requires high precision |
| Backward pass (norms, softmax) | FP32 | Gradient precision for sensitive ops |
| Master weights | FP32 | Accumulation of small updates requires precision |
| Optimizer states (m, v) | FP32 | Running averages must not lose precision |
| Gradient accumulation | FP32 | Sum of many small values needs precision |
| Weight update | FP32 | update = lr * m / sqrt(v); computed in FP32 |
| Loss computation | FP32 | Cross-entropy with log-softmax needs range |
2.2 Why Master Weights Must Be FP32
The weight update in AdamW is:

update = lr * m_hat / (sqrt(v_hat) + eps) + lr * weight_decay * w

For a typical learning rate of 3e-4 and a typical normalized update magnitude m_hat / sqrt(v_hat) of about 0.01, the update per step is roughly 3e-6. In BF16, the smallest representable change to a weight of magnitude 0.1 is about 0.1 * eps = 0.1 * 7.81e-3 ≈ 8e-4. The update (3e-6) is roughly 250x smaller than BF16’s resolution at that magnitude. In BF16, the update would be rounded to zero. The weight would never change. Training would not converge.
def demonstrate_stale_weights():
"""Show that BF16 weights cannot accumulate small updates."""
w_bf16 = torch.tensor(0.1, dtype=torch.bfloat16)
w_fp32 = torch.tensor(0.1, dtype=torch.float32)
update = 3e-6 # Typical per-step update magnitude
# Simulate 10000 updates
for _ in range(10000):
w_bf16 = w_bf16 - update # BF16 arithmetic
w_fp32 = w_fp32 - update # FP32 arithmetic
print(f"BF16 after 10K updates: {w_bf16.item():.6f}") # Still ~0.1
print(f"FP32 after 10K updates: {w_fp32.item():.6f}") # 0.070000
# BF16 rounds every update to 0 -- no learning happens
The solution: maintain a FP32 copy of all weights (master weights). The training loop becomes:
- Cast master weights (FP32) to BF16 for forward/backward
- Compute gradients in BF16
- Cast gradients to FP32
- Update master weights in FP32
- Repeat
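The loop above can be sketched with a toy scalar weight. This is a sketch only: plain SGD stands in for AdamW, and the constant 3e-6 "gradient" is the typical per-step update magnitude from above, not the output of a real backward pass:

```python
import torch

# FP32 master weight; a constant BF16 "gradient" stands in for backprop output
master_w = torch.tensor(0.1, dtype=torch.float32)
lr = 1.0  # learning rate folded into the 3e-6 toy update magnitude

for _ in range(10000):
    w_bf16 = master_w.to(torch.bfloat16)                  # 1. cast for forward/backward
    # (w_bf16 would feed the forward pass; unused in this toy)
    grad_bf16 = torch.tensor(3e-6, dtype=torch.bfloat16)  # 2. gradient computed in BF16
    grad_fp32 = grad_bf16.float()                         # 3. cast gradient to FP32
    master_w = master_w - lr * grad_fp32                  # 4. update master weights in FP32

print(f"{master_w.item():.4f}")  # ~0.0700 -- updates accumulate, unlike pure BF16
```

Because the subtraction happens on the FP32 master copy, the 3e-6 updates accumulate instead of rounding away as in `demonstrate_stale_weights` above.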
2.3 Memory Budget
def memory_budget(params_B):
    """Calculate model-state memory for mixed-precision training.
    Args:
        params_B: number of parameters in billions
    """
    params = params_B * 1e9
# Weights
master_weights_fp32 = params * 4 # 4 bytes per param
bf16_weights = params * 2 # 2 bytes per param
# Optimizer (AdamW)
optimizer_m = params * 4 # First moment (FP32)
optimizer_v = params * 4 # Second moment (FP32)
# Gradients
gradients_fp32 = params * 4 # Accumulated in FP32
# Total model state
total_model = (master_weights_fp32 + bf16_weights +
optimizer_m + optimizer_v + gradients_fp32)
print(f"Model: {params_B}B parameters")
print(f" Master weights (FP32): {master_weights_fp32 / 1e9:.1f} GB")
print(f" BF16 weights: {bf16_weights / 1e9:.1f} GB")
print(f" Optimizer m (FP32): {optimizer_m / 1e9:.1f} GB")
print(f" Optimizer v (FP32): {optimizer_v / 1e9:.1f} GB")
print(f" Gradients (FP32): {gradients_fp32 / 1e9:.1f} GB")
print(f" Total model state: {total_model / 1e9:.1f} GB")
print(f" Per-param bytes: {total_model / params:.1f}")
return total_model
# Example: Llama 70B
memory_budget(70)
# Master weights (FP32): 280.0 GB
# BF16 weights: 140.0 GB
# Optimizer m (FP32): 280.0 GB
# Optimizer v (FP32): 280.0 GB
# Gradients (FP32): 280.0 GB
# Total model state: 1260.0 GB
# Per-param bytes: 18.0
18 bytes per parameter. For 70B parameters, that is 1.26 TB of model state alone, before activations. This is why 70B training requires at least 16 H100 GPUs (80 GB each, 1.28 TB total) with FSDP to shard the model state across GPUs.
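That GPU count is a one-line back-of-envelope calculation. A sketch, assuming even FSDP sharding of model state only (`min_gpus` is a hypothetical helper; it ignores activations and framework overhead, which push the real requirement higher):

```python
import math

def min_gpus(params_B, gpu_mem_GB=80, bytes_per_param=18):
    """Minimum GPU count whose combined memory holds the sharded model state."""
    state_GB = params_B * bytes_per_param  # 1e9 params * bytes, over 1e9 bytes/GB
    return math.ceil(state_GB / gpu_mem_GB)

print(min_gpus(70))  # 16 H100s just for 70B model state
print(min_gpus(7))   # 2 -- even a 7B model's state exceeds one 80 GB GPU
```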
3. Per-Operation Precision Requirements
3.1 Matrix Multiplications: BF16
GEMMs (General Matrix Multiplications) dominate training FLOPs. On NVIDIA tensor cores, BF16 GEMMs are 2x faster than FP32-precision GEMMs (H100 SXM: ~990 dense TFLOP/s BF16 vs. ~495 TFLOP/s TF32, the tensor-core FP32 mode). The internal accumulation inside the tensor core is done in FP32; only the inputs and outputs are BF16.
import torch
def bf16_matmul(a, b):
"""BF16 matmul with FP32 internal accumulation.
Tensor cores compute: C_fp32 = A_bf16 @ B_bf16 (accumulated in FP32)
Then: C_bf16 = cast(C_fp32)
The accumulation in FP32 is critical -- without it,
summing many BF16 products would lose too much precision.
"""
# PyTorch handles this automatically when inputs are BF16
a_bf16 = a.to(torch.bfloat16)
b_bf16 = b.to(torch.bfloat16)
# Tensor core does: FP32 accumulation of BF16 * BF16 products
c = torch.matmul(a_bf16, b_bf16) # Output is BF16
return c
# The key matmuls in a transformer:
# QKV projection: X_bf16 @ W_qkv_bf16
# Attention scores: Q_bf16 @ K^T_bf16
# Attention output: attn_bf16 @ V_bf16
# Output projection: attn_out_bf16 @ W_o_bf16
# FFN up/gate/down: X_bf16 @ W_ff_bf16
3.2 Normalization: FP32
RMSNorm and LayerNorm require FP32 because they compute statistics (variance, RMS) over the hidden dimension. In BF16, the sum of squares can overflow or suffer catastrophic cancellation:
class RMSNormMixedPrecision(torch.nn.Module):
"""RMSNorm with FP32 computation, BF16 input/output."""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x):
# Input may be BF16
input_dtype = x.dtype
# Upcast to FP32 for the norm computation
x_fp32 = x.float()
# Compute RMS in FP32
rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
# Normalize in FP32, then cast back
normed = x_fp32 / rms
# Apply weight (in FP32) and cast back to input dtype
return (normed * self.weight.float()).to(input_dtype)
Why FP32 is needed for norms: consider hidden dimension d = 4096. If all elements are 1.0, the true sum of squares is 4096.0 and the mean is 1.0. But with sequential BF16 addition of 4096 terms, the accumulated rounding error can swallow most of the sum: once the running total reaches 256, adding 1.0 rounds straight back to 256, because adjacent BF16 values at that magnitude are 2.0 apart. An error of that order in the variance estimate would cause wildly wrong normalization.
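A sequential BF16 accumulation makes the problem concrete. A minimal sketch (the explicit loop matters: `Tensor.sum()` may internally use pairwise or higher-precision accumulation, so the loop forces one BF16 add per element):

```python
import torch

ones = torch.ones(4096, dtype=torch.bfloat16)

s = torch.tensor(0.0, dtype=torch.bfloat16)
for x in ones:
    s = s + x  # each add rounds the running sum back to BF16

print(s.item())                   # 256.0 -- the sum stalls far below 4096
print(ones.float().sum().item())  # 4096.0 with FP32 accumulation
```

The stall at 256 happens because 256 + 1 = 257 lies exactly halfway between the adjacent BF16 values 256 and 258, and round-to-nearest-even picks 256 every time.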
3.3 Softmax: FP32
The softmax function involves exponentiation and normalization. Both are sensitive to precision:
def softmax_precision_comparison():
"""Show why softmax needs FP32."""
# Attention scores for a sequence of length 2048
torch.manual_seed(42)
scores = torch.randn(1, 32, 2048, 2048) * 3.0 # Typical scale
# BF16 softmax
scores_bf16 = scores.to(torch.bfloat16)
attn_bf16 = torch.softmax(scores_bf16, dim=-1)
# FP32 softmax
attn_fp32 = torch.softmax(scores, dim=-1)
# Compare: do rows sum to 1.0?
row_sums_bf16 = attn_bf16.float().sum(dim=-1)
row_sums_fp32 = attn_fp32.sum(dim=-1)
print(f"BF16 softmax row sums: "
f"mean={row_sums_bf16.mean():.6f}, "
f"std={row_sums_bf16.std():.6f}")
print(f"FP32 softmax row sums: "
f"mean={row_sums_fp32.mean():.6f}, "
f"std={row_sums_fp32.std():.6f}")
# BF16 rows may not sum to exactly 1.0, causing attention to
# systematically over- or under-weight certain positions
3.4 Cross-Entropy Loss: FP32
The cross-entropy loss involves log-softmax, which can produce very negative values (the log of small probabilities). BF16 can represent these magnitudes (its exponent range matches FP32), but its precision is insufficient for stable gradient computation:
def cross_entropy_precision():
"""Cross-entropy always computed in FP32."""
vocab_size = 128256
batch_seq = 4 * 4096 # batch * seq_len
# Logits from the model (BF16)
logits_bf16 = torch.randn(batch_seq, vocab_size, dtype=torch.bfloat16)
targets = torch.randint(0, vocab_size, (batch_seq,))
# Must upcast to FP32 for loss computation
logits_fp32 = logits_bf16.float()
loss = torch.nn.functional.cross_entropy(logits_fp32, targets)
# The gradient of CE w.r.t. logits is:
# d_loss/d_logit_i = softmax(logit_i) - target_i
# This difference can be very small (1e-6 to 1e-8),
# which would be rounded to 0 in BF16
return loss
3.5 Summary: The Precision Map
class TransformerLayerMixedPrecision(torch.nn.Module):
"""A single transformer layer with correct mixed precision."""
def __init__(self, d_model, n_heads, d_ff):
super().__init__()
self.attn_norm = RMSNormMixedPrecision(d_model) # FP32 internal
self.ffn_norm = RMSNormMixedPrecision(d_model) # FP32 internal
# All linear layers store weights in BF16
self.q_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.v_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.o_proj = torch.nn.Linear(d_model, d_model, bias=False)
self.gate_proj = torch.nn.Linear(d_model, d_ff, bias=False)
self.up_proj = torch.nn.Linear(d_model, d_ff, bias=False)
self.down_proj = torch.nn.Linear(d_ff, d_model, bias=False)
self.n_heads = n_heads
self.head_dim = d_model // n_heads
def forward(self, x):
"""
x: BF16 tensor [B, S, d_model]
Precision flow:
1. x (BF16) -> norm (FP32 internal, BF16 output) -> BF16
2. BF16 -> Q,K,V projections (BF16 matmul) -> BF16
3. BF16 -> attention scores Q@K^T (BF16 matmul) -> BF16
4. BF16 -> softmax (FP32 internal, BF16 output) -> BF16
5. BF16 -> attention @ V (BF16 matmul) -> BF16
6. BF16 -> output proj (BF16 matmul) -> BF16
7. BF16 -> residual add -> BF16
8. Repeat for FFN
"""
# Attention block
normed = self.attn_norm(x) # BF16 -> FP32 -> BF16
B, S, D = normed.shape
q = self.q_proj(normed) # BF16 matmul
k = self.k_proj(normed)
v = self.v_proj(normed)
q = q.reshape(B, S, self.n_heads, self.head_dim).transpose(1, 2)
k = k.reshape(B, S, self.n_heads, self.head_dim).transpose(1, 2)
v = v.reshape(B, S, self.n_heads, self.head_dim).transpose(1, 2)
# Attention scores (BF16 matmul)
scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
# Softmax in FP32
attn_weights = torch.softmax(scores.float(), dim=-1).to(x.dtype)
# Weighted sum (BF16 matmul)
attn_out = attn_weights @ v
attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
attn_out = self.o_proj(attn_out)
# Residual (BF16 add)
x = x + attn_out
# FFN block
normed = self.ffn_norm(x)
ffn_out = self.down_proj(
torch.nn.functional.silu(self.gate_proj(normed)) * self.up_proj(normed)
)
x = x + ffn_out
return x
4. Loss Scaling (FP16 Only)
4.1 The FP16 Gradient Problem
If you must use FP16 instead of BF16 (older hardware without BF16 support), gradients can underflow. Gradients below FP16’s smallest subnormal (about 6e-8) become exactly zero, and anything below its smallest normal value (6.1e-5) retains only a few bits of precision. Loss scaling fixes this by multiplying the loss by a large constant before backward, then dividing the gradients by the same constant after backward:
class LossScaler:
"""Dynamic loss scaling for FP16 training.
Not needed for BF16 -- only for FP16 on older hardware.
"""
def __init__(self, init_scale=2**16, growth_factor=2.0,
backoff_factor=0.5, growth_interval=2000):
self.scale = init_scale
self.growth_factor = growth_factor
self.backoff_factor = backoff_factor
self.growth_interval = growth_interval
self.steps_since_last_overflow = 0
def scale_loss(self, loss):
"""Multiply loss by scale factor before backward."""
return loss * self.scale
def unscale_gradients(self, optimizer):
"""Divide gradients by scale factor after backward."""
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is not None:
param.grad.data /= self.scale
def update(self, overflow_detected):
"""Adjust scale based on whether overflow occurred."""
if overflow_detected:
# Overflow: reduce scale, skip this step
self.scale *= self.backoff_factor
self.steps_since_last_overflow = 0
else:
self.steps_since_last_overflow += 1
if self.steps_since_last_overflow >= self.growth_interval:
# No overflow for a while: try increasing scale
self.scale *= self.growth_factor
self.steps_since_last_overflow = 0
def check_overflow(self, optimizer):
"""Check if any gradient contains inf or nan."""
for group in optimizer.param_groups:
for param in group["params"]:
if param.grad is not None:
if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
return True
return False
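The rescue mechanic in miniature, on a bare tensor rather than a real backward pass (a sketch; 2**16 is the same default initial scale used above):

```python
import torch

scale = 2.0 ** 16
g_fp32 = torch.tensor(1e-8)  # a gradient magnitude FP16 cannot represent

dead = g_fp32.to(torch.float16)                # 0.0 -- underflows
survived = (g_fp32 * scale).to(torch.float16)  # ~6.55e-4, comfortably representable
recovered = survived.float() / scale           # ~1e-8 after unscaling in FP32

print(dead.item(), survived.item(), recovered.item())
```

Scaling the loss scales every gradient by the same factor (backprop is linear in the loss), which is why one global multiplier is enough to lift the whole gradient distribution above the underflow threshold.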
4.2 Why BF16 Does Not Need Loss Scaling
BF16 has the same exponent range as FP32 (8 exponent bits). Any magnitude representable in FP32 is representable in BF16, just with reduced precision. A gradient of magnitude 1e-7 maps to the nearest representable BF16 value (about 1.001e-7). The precision loss is up to 0.78%, but the value is not zero.
def bf16_vs_fp16_gradient_survival():
"""Compare how small gradients survive in BF16 vs FP16."""
small_grads = [1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
for g in small_grads:
g_fp32 = torch.tensor(g, dtype=torch.float32)
g_bf16 = g_fp32.to(torch.bfloat16)
g_fp16 = g_fp32.to(torch.float16)
print(f"Grad {g:.0e}: "
f"FP32={g_fp32.item():.2e}, "
f"BF16={g_bf16.item():.2e}, "
f"FP16={g_fp16.item():.2e}")
    # Output:
    # Grad 1e-04: FP32=1.00e-04, BF16=1.00e-04, FP16=1.00e-04
    # Grad 1e-05: FP32=1.00e-05, BF16=1.00e-05, FP16=1.00e-05
    # Grad 1e-06: FP32=1.00e-06, BF16=9.98e-07, FP16=1.01e-06
    # Grad 1e-07: FP32=1.00e-07, BF16=1.00e-07, FP16=1.19e-07 (subnormal, ~1 bit left)
    # Grad 1e-08: FP32=1.00e-08, BF16=1.00e-08, FP16=0.00e+00 <-- FP16 underflow
If your hardware supports BF16 (NVIDIA Ampere and later, AMD MI250 and later, Google TPUs), always use BF16 over FP16. You avoid the entire loss scaling complexity with no quality cost. The only reason to use FP16 is on older hardware (NVIDIA V100, T4) that lacks BF16 tensor cores.
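A small startup helper along these lines is common. A sketch (`pick_autocast_dtype` is a made-up name, and the check covers NVIDIA only: compute capability 8.0, i.e. Ampere, is the first generation with BF16 tensor cores):

```python
import torch

def pick_autocast_dtype():
    """BF16 wherever supported; FP16 only as a fallback on pre-Ampere GPUs."""
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16  # V100/T4 -> FP16
    return torch.bfloat16  # CPU autocast supports BF16

dtype = pick_autocast_dtype()
print(dtype)
```

If this returns `torch.float16`, you also need the `GradScaler` machinery from the previous section; with `torch.bfloat16` you can skip it entirely.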
5. Implementation with torch.cuda.amp
5.1 The AMP Autocast Context Manager
PyTorch’s Automatic Mixed Precision (AMP) handles precision casting automatically. You wrap the forward pass in torch.cuda.amp.autocast, and it applies the precision hierarchy described above:
import torch
from torch.cuda.amp import autocast, GradScaler
def training_step_bf16(model, batch, optimizer):
"""A single training step with BF16 mixed precision.
No GradScaler needed for BF16 -- only for FP16.
"""
input_ids = batch["input_ids"].cuda()
labels = batch["labels"].cuda()
# autocast handles: matmuls in BF16, norms/softmax in FP32
with autocast(dtype=torch.bfloat16):
outputs = model(input_ids)
# Loss computed in FP32 (autocast upcasts for CE loss)
        loss = torch.nn.functional.cross_entropy(
            outputs.logits.float().reshape(-1, outputs.logits.size(-1)),  # upcast + flatten
            labels.reshape(-1),  # flatten to match the flattened logits
        )
# Backward (gradients computed in mixed precision)
loss.backward()
# Gradient clipping (in FP32, on master weights' gradients)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Optimizer step (updates FP32 master weights)
optimizer.step()
optimizer.zero_grad()
return loss.item()
def training_step_fp16(model, batch, optimizer, scaler):
"""A single training step with FP16 mixed precision.
Requires GradScaler to prevent gradient underflow.
"""
input_ids = batch["input_ids"].cuda()
labels = batch["labels"].cuda()
with autocast(dtype=torch.float16):
outputs = model(input_ids)
        loss = torch.nn.functional.cross_entropy(
            outputs.logits.float().reshape(-1, outputs.logits.size(-1)),
            labels.reshape(-1),
        )
# Scale loss before backward to prevent gradient underflow
scaler.scale(loss).backward()
# Unscale gradients for clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Step (with overflow checking)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
return loss.item()
5.2 What autocast Does Internally
The autocast context manager maintains a list of operations and their target precision:
# Operations that autocast runs in BF16/FP16 (fast on tensor cores):
LOWER_PRECISION_OPS = [
"torch.matmul",
"torch.nn.functional.linear",
"torch.nn.functional.conv1d",
"torch.nn.functional.conv2d",
"torch.bmm",
"torch.addmm",
"torch.addbmm",
"torch.baddbmm",
]
# Operations that autocast keeps in FP32 (need precision):
FP32_OPS = [
"torch.nn.functional.softmax",
"torch.nn.functional.cross_entropy",
"torch.nn.functional.log_softmax",
"torch.nn.functional.layer_norm",
"torch.nn.functional.group_norm",
"torch.nn.functional.batch_norm",
"torch.pow",
"torch.norm",
"torch.sum", # Large reductions
"torch.mean",
]
# Operations that autocast promotes to the widest input type:
PROMOTE_OPS = [
"torch.add",
"torch.sub",
"torch.mul",
"torch.div",
"torch.cat",
"torch.stack",
]
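These rules can be observed directly. The sketch below uses CPU autocast so it runs anywhere; the CPU op lists differ in coverage from CUDA, but matmul is a lower-precision op on both:

```python
import torch

a = torch.randn(64, 64)  # FP32 inputs
b = torch.randn(64, 64)

c_plain = a @ b  # no autocast: stays FP32
with torch.autocast("cpu", dtype=torch.bfloat16):
    c_auto = a @ b  # autocast casts the inputs and runs the matmul in BF16

print(c_plain.dtype, c_auto.dtype)  # torch.float32 torch.bfloat16
```

Note that autocast never touches the stored parameters; it casts per-op, which is why the same FP32 weights can serve as master weights for the optimizer.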
5.3 Full Training Loop
def train(model, train_loader, val_loader, config):
"""Complete training loop with BF16 mixed precision."""
    # Keep parameters in FP32 -- these are the master weights.
    # autocast casts them to BF16 per-op during forward/backward.
    # Do NOT cast the whole model to BF16: plain torch.optim.AdamW
    # keeps no separate FP32 copies, so BF16 parameters would lose
    # small updates to rounding (see section 2.2).
    model = model.cuda()
optimizer = torch.optim.AdamW(
model.parameters(),
lr=config["lr"],
betas=(0.9, 0.95),
eps=1e-8,
weight_decay=config["weight_decay"],
)
# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config["total_steps"], eta_min=config["min_lr"]
)
# Gradient accumulation for effective batch size
grad_accum_steps = config["grad_accum_steps"]
step = 0
for epoch in range(config["epochs"]):
model.train()
for micro_step, batch in enumerate(train_loader):
with autocast(dtype=torch.bfloat16):
outputs = model(batch["input_ids"].cuda())
loss = torch.nn.functional.cross_entropy(
outputs.logits.float().reshape(-1, outputs.logits.size(-1)),
batch["labels"].cuda().reshape(-1),
)
# Scale for gradient accumulation
loss = loss / grad_accum_steps
loss.backward()
if (micro_step + 1) % grad_accum_steps == 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=1.0
)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
step += 1
if step % config["log_interval"] == 0:
print(f"Step {step}: loss={loss.item() * grad_accum_steps:.4f}, "
f"lr={scheduler.get_last_lr()[0]:.2e}")
                if step % config["eval_interval"] == 0:
                    val_loss = evaluate(model, val_loader)
                    print(f"Step {step}: val_loss={val_loss:.4f}")
                    model.train()  # evaluate() switched the model to eval mode
return model
def evaluate(model, val_loader):
"""Evaluate model in BF16."""
model.eval()
total_loss = 0
total_tokens = 0
with torch.no_grad(), autocast(dtype=torch.bfloat16):
for batch in val_loader:
outputs = model(batch["input_ids"].cuda())
loss = torch.nn.functional.cross_entropy(
outputs.logits.float().reshape(-1, outputs.logits.size(-1)),
batch["labels"].cuda().reshape(-1),
reduction="sum",
)
total_loss += loss.item()
total_tokens += batch["labels"].numel()
return total_loss / total_tokens
6. Precision-Related Failure Modes
6.1 Loss Spikes from BF16 Accumulation
When using BF16, gradient accumulation across many micro-batches can lose precision. Each BF16 addition loses up to 0.78% relative error. Over 64 accumulation steps, the error compounds:
def gradient_accumulation_error(n_steps=64):
"""Show precision loss from BF16 gradient accumulation."""
# Simulate: accumulate small gradients
grad_per_step = torch.randn(4096, dtype=torch.float32) * 0.001
# FP32 accumulation (ground truth)
accum_fp32 = torch.zeros(4096, dtype=torch.float32)
for _ in range(n_steps):
accum_fp32 += grad_per_step
# BF16 accumulation (problematic)
accum_bf16 = torch.zeros(4096, dtype=torch.bfloat16)
for _ in range(n_steps):
accum_bf16 += grad_per_step.to(torch.bfloat16)
# Compare
error = (accum_bf16.float() - accum_fp32).abs()
relative_error = error / (accum_fp32.abs() + 1e-10)
print(f"Accumulation over {n_steps} steps:")
print(f" Mean relative error: {relative_error.mean():.4f}")
print(f" Max relative error: {relative_error.max():.4f}")
# With 64 steps: mean relative error can reach 5-10%
The fix: accumulate gradients in FP32, even if individual gradients are computed in BF16.
def safe_gradient_accumulation(model, micro_batches, grad_accum_steps):
"""Accumulate gradients in FP32 for precision."""
# Option 1: PyTorch autograd accumulates in param.grad dtype
# If param is BF16, grad is BF16 -- problematic for many steps
# Option 2: Keep separate FP32 gradient buffers
fp32_grads = {
name: torch.zeros_like(param, dtype=torch.float32)
for name, param in model.named_parameters()
}
for micro_batch in micro_batches:
        with autocast(dtype=torch.bfloat16):
            # compute_loss: user-supplied forward pass + loss (not defined here)
            loss = compute_loss(model, micro_batch) / grad_accum_steps
loss.backward()
# Accumulate in FP32
with torch.no_grad():
for name, param in model.named_parameters():
if param.grad is not None:
fp32_grads[name] += param.grad.float()
param.grad = None # Free BF16 grad
# Copy FP32 accumulated grads back
with torch.no_grad():
for name, param in model.named_parameters():
param.grad = fp32_grads[name].to(param.dtype)
6.2 Norm Instability
If norms accidentally run in BF16, training can become unstable after tens of thousands of steps. The symptom: loss spikes that recover but become more frequent:
def diagnose_norm_precision(model):
"""Check that all norm layers operate in FP32 internally."""
for name, module in model.named_modules():
if "norm" in name.lower():
# Check weight dtype
if hasattr(module, "weight"):
w_dtype = module.weight.dtype
if w_dtype != torch.float32:
print(f"WARNING: {name} weight is {w_dtype}, should be FP32")
# Test forward precision
test_input = torch.randn(1, 10, module.weight.shape[0],
device="cuda", dtype=torch.bfloat16)
with torch.no_grad():
output = module(test_input)
# Check if intermediate computation uses FP32
# (this is a simplified check)
print(f"{name}: input={test_input.dtype}, output={output.dtype}")
6.3 Embedding and Output Head Precision
The embedding lookup and the final output projection (logits) are often overlooked. The embedding lookup itself is exact (just a table lookup), but the output logits feed into softmax and cross-entropy, which require FP32:
class SafeLMHead(torch.nn.Module):
"""Output head that ensures FP32 for loss computation."""
def __init__(self, d_model, vocab_size):
super().__init__()
self.proj = torch.nn.Linear(d_model, vocab_size, bias=False)
def forward(self, hidden_states, labels=None):
# Logits in BF16 (matmul)
logits = self.proj(hidden_states)
if labels is not None:
# Upcast to FP32 for loss
loss = torch.nn.functional.cross_entropy(
logits.float().reshape(-1, logits.size(-1)),
labels.reshape(-1),
)
return logits, loss
return logits
7. FP8: The Next Frontier
7.1 FP8 Formats on Hopper/Blackwell
NVIDIA Hopper (H100) introduced FP8 tensor cores. Two formats:
- E4M3: 4 exponent bits, 3 mantissa bits. Range: up to ±448. Relative precision: ~2^-3 (12.5%). Used for the forward pass.
- E5M2: 5 exponent bits, 2 mantissa bits. Range: up to ±57,344. Relative precision: ~2^-2 (25%). Used for the backward pass (gradients need more range).
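The quoted maxima follow from the bit layouts. A small sketch computes them from first principles (note that the E4M3 used for training is the "FN" variant, which drops inf and frees the top exponent for normal values, leaving only the all-ones mantissa as NaN):

```python
def fp8_max(exp_bits, man_bits, fn_variant=False):
    """Largest finite value of an FP8 format (IEEE-style exponent bias)."""
    bias = 2 ** (exp_bits - 1) - 1
    if fn_variant:
        # 'FN' (e.g. E4M3FN): no inf; the top exponent holds normal values,
        # except the all-ones mantissa, which encodes NaN
        max_exp = (2 ** exp_bits - 1) - bias   # 15 - 7 = 8
        max_mant = 2 - 2 * 2 ** -man_bits      # 1.75 for 3 mantissa bits
    else:
        # IEEE-style (e.g. E5M2): top exponent reserved for inf/NaN
        max_exp = (2 ** exp_bits - 2) - bias   # 30 - 15 = 15
        max_mant = 2 - 2 ** -man_bits          # 1.75 for 2 mantissa bits
    return max_mant * 2 ** max_exp

print(fp8_max(4, 3, fn_variant=True))  # 448.0   (E4M3FN)
print(fp8_max(5, 2))                   # 57344.0 (E5M2)
```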
def fp8_training_concept():
"""FP8 training: the next step after BF16.
H100 FP8 tensor cores: 1980 TFLOP/s (2x BF16, 4x FP32)
"""
config = {
"forward_matmuls": "FP8-E4M3 (higher precision for activations)",
"backward_matmuls": "FP8-E5M2 (higher range for gradients)",
"norms": "FP32 (still needs high precision)",
"softmax": "FP32 (still needs high precision)",
"master_weights": "FP32 (still needs accumulation precision)",
"optimizer": "FP32 (still needs accumulation precision)",
"per_tensor_scaling": "Required (dynamic scale per tensor)",
}
return config
7.2 Per-Tensor Scaling for FP8
FP8’s narrow range (±448 for E4M3) means that tensors must be scaled to fit. Each tensor gets a per-tensor scale factor that maps its values into the FP8 representable range:
def fp8_quantize(tensor, fp8_max=448.0):
"""Quantize a tensor to FP8 with per-tensor scaling.
scale = fp8_max / tensor.abs().max()
fp8_tensor = round(tensor * scale)
Dequantize: tensor_approx = fp8_tensor / scale
"""
amax = tensor.abs().max()
scale = fp8_max / amax.clamp(min=1e-12)
# Scale and clamp to FP8 range
scaled = tensor * scale
quantized = scaled.clamp(-fp8_max, fp8_max)
# In real hardware, this is stored as 8-bit FP values
# Here we simulate with float
return quantized, scale
def fp8_matmul(a, b, a_scale, b_scale):
"""Simulated FP8 matrix multiplication.
Real hardware: tensor cores compute A_fp8 @ B_fp8 with FP32 accumulation.
Result: C_fp32 = (A_fp8 @ B_fp8) / (a_scale * b_scale)
"""
c = torch.matmul(a, b) # FP8 matmul (simulated in float)
c = c / (a_scale * b_scale) # Descale
return c
FP8 training on H100 achieves 1980 TFLOP/s for matrix multiplications — 2x over BF16 (990 TFLOP/s). For a 70B model, this reduces training time by roughly 30-40% (not 2x, because not all operations are matmuls). The quality impact is minimal with proper per-tensor scaling, typically less than 0.1% loss increase.
8. Practical Checklist
def mixed_precision_checklist():
"""Checklist for correct mixed precision training."""
return {
"1_use_bf16": (
"Use BF16, not FP16, if hardware supports it. "
"Eliminates need for loss scaling."
),
"2_fp32_master_weights": (
"Optimizer must maintain FP32 copies of all weights. "
"PyTorch AdamW does this automatically."
),
"3_fp32_norms": (
"RMSNorm/LayerNorm internal computation must be FP32. "
"Upcast input, compute, downcast output."
),
"4_fp32_softmax": (
"Attention softmax must be computed in FP32. "
"FlashAttention handles this internally."
),
"5_fp32_loss": (
"Cross-entropy loss must be computed in FP32. "
"Upcast logits before loss function."
),
"6_fp32_grad_accum": (
"If using many gradient accumulation steps (more than 8), "
"accumulate in FP32 to prevent drift."
),
"7_no_bf16_reduction": (
"Never reduce (sum, mean) over large dimensions in BF16. "
"Upcast first, reduce in FP32."
),
"8_check_grad_norms": (
"Monitor gradient norms per layer during training. "
"Sudden spikes indicate precision issues."
),
}
Mixed Precision Training Speed (Llama 7B, single H100)
| Precision | Throughput (tok/s) | Speedup vs FP32 |
|---|---|---|
| FP32 (all operations) | 1,200 | baseline |
| BF16 mixed (standard) | 3,400 | +183% |
| BF16 + torch.compile | 4,100 | +242% |
| FP8 mixed (H100) | 5,800 | +383% |
Mixed precision training is not optional for LLMs. It is a prerequisite. The precision hierarchy described in this post — BF16 for matmuls, FP32 for norms/softmax/loss/optimizer/master weights — is the standard used by every major training framework. Understanding why each operation requires its specific precision prevents subtle training bugs that manifest as loss spikes, divergence, or silently degraded model quality.
References
- Micikevicius, P. et al. “Mixed Precision Training.” ICLR 2018.
- Kalamkar, D. et al. “A Study of BFLOAT16 for Deep Learning Training.” arXiv 2019.
- NVIDIA. “Transformer Engine: FP8 Training.” Documentation, 2023.
- Dettmers, T. et al. “LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale.” NeurIPS 2022.
- PyTorch. “Automatic Mixed Precision.” Documentation, 2024.