Every floating-point multiply in FP32 that could have run in a lower precision format is wasted memory bandwidth, wasted compute cycles, and wasted power. Modern GPU architectures have made this tradeoff explicit: their fastest execution units — Tensor Cores — operate on FP16, BF16, or FP8, not FP32. If your training loop runs in pure FP32, the most powerful hardware on your chip sits idle. Mixed precision training exists to fix this: run forward and backward passes in reduced precision for throughput, keep master weights in FP32 for numerical correctness.
This post covers the full precision landscape from FP32 down to FP4, with the engineering details that matter in practice: why BF16 displaced FP16, how dynamic loss scaling actually works, what NVIDIA’s Transformer Engine does under the hood for FP8, and exact memory calculations for production-scale models.
Why FP32 Is Wasteful for Training
The Memory Bandwidth Argument
GPU compute has scaled faster than memory bandwidth for over a decade. An H100 SXM delivers 3,958 TFLOPS of FP8 Tensor Core compute but only 3.35 TB/s of HBM3 bandwidth. The arithmetic intensity required to keep the chip busy at peak throughput is over 1,000 ops/byte for FP8. Most training workloads fall well below this, making them memory-bandwidth bound.
Every parameter stored in FP32 (4 bytes) instead of FP16/BF16 (2 bytes) doubles the bandwidth required to load and store it. For the massive weight matrices in transformer models, this means the GPU spends more cycles waiting for data than doing useful math. Reducing precision to 16 bits halves the traffic. Reducing to 8 bits quarters it.
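To put numbers on this, here is a back-of-the-envelope sketch (my own illustration, using the H100 bandwidth figure quoted above and a hypothetical 70B-parameter model) of how long a single full read of the weights takes at peak HBM bandwidth:

```python
def weight_stream_time_ms(n_params=70e9, bytes_per_param=2, bw_tb_s=3.35):
    """Time (ms) to stream every weight once through HBM at peak bandwidth."""
    return n_params * bytes_per_param / (bw_tb_s * 1e12) * 1e3

# Hypothetical 70B-parameter model on H100-class HBM3 (3.35 TB/s):
for name, b in [("FP32", 4), ("BF16", 2), ("FP8", 1)]:
    print(f"{name}: {weight_stream_time_ms(bytes_per_param=b):.1f} ms")
```

Halving the bytes per parameter halves the minimum time per weight pass, independent of how fast the compute units are.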
Memory Bandwidth per Parameter by Precision (bytes)
The Tensor Core Argument
NVIDIA Tensor Cores, starting with Volta (V100) in 2017, are specialized matrix-multiply-accumulate units. They operate on reduced-precision inputs and accumulate in higher precision. Critically, they are the only way to reach peak FLOPS on modern NVIDIA GPUs. The CUDA core FP32 throughput on an H100 is 67 TFLOPS. The Tensor Core FP16 throughput is 1,979 TFLOPS — a 30x gap. Running pure FP32 training on an H100 means you are using roughly 1.7% of the chip’s peak compute capability.
H100 SXM Peak Throughput by Precision
| Precision | Tensor Core TFLOPS | vs FP32 CUDA | Format |
|---|---|---|---|
| FP32 (CUDA cores) | 67 | 1.0x (baseline) | IEEE 754 |
| TF32 (Tensor Cores) | 989 | 14.8x | 19-bit internal |
| FP16 (Tensor Cores) | 1,979 | 29.5x | IEEE 754 half |
| BF16 (Tensor Cores) | 1,979 | 29.5x | Brain float |
| FP8 (Tensor Cores) | 3,958 | 59.1x | E4M3 / E5M2 |
The takeaway is clear: if you want to use the hardware you paid for, you must run in reduced precision. Mixed precision training is the technique that lets you do this without destroying model quality.
The Core Idea: Mixed Precision Training
The foundational approach, introduced by Micikevicius et al. (2018), maintains three copies of information:
- FP16 weights — used in the forward and backward passes for compute
- FP32 master weights — the authoritative copy, updated by the optimizer
- FP16 gradients — computed during backpropagation, optionally scaled
Each training step: (a) copy FP32 master weights to FP16, (b) run forward pass in FP16, (c) compute loss in FP16, (d) scale the loss, (e) run backward pass in FP16, (f) unscale gradients, (g) update FP32 master weights with optimizer. The FP32 master weights are essential because optimizer updates can be extremely small — on the order of 1e-8 when multiplying a learning rate of 1e-4 by a gradient of 1e-4. In FP16, the smallest representable subnormal is about 6e-8 (2^-24), and values below the minimum normal of 6.1e-5 (2^-14) lose significant precision. The FP32 master copy preserves these tiny updates.
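A two-line NumPy experiment (illustrative magnitudes, chosen for this post) makes the master-weight argument concrete: a 1e-8 update applied to a 1e-3 weight survives in FP32 but vanishes entirely in FP16:

```python
import numpy as np

# A 1e-8 update on a 1e-3 weight: FP32 keeps it, FP16 loses it entirely
# (1e-8 is below FP16's smallest subnormal, ~6e-8, so it flushes to zero).
w32, w16 = np.float32(1e-3), np.float16(1e-3)
update = 1e-8

assert np.float32(w32 - np.float32(update)) != w32   # FP32 master weight moves
assert np.float16(w16 - np.float16(update)) == w16   # FP16 weight is stuck
```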
IEEE 754 Floating-Point: The Bit Layout That Matters
To understand why BF16 displaced FP16, you need to understand how floating-point numbers are encoded.
A floating-point number is represented as:

value = (-1)^s × 2^(E − bias) × (1 + f)

where s is the sign bit, E is the stored exponent, bias is 2^(e−1) − 1 for e exponent bits, and f is the fractional mantissa.
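The encoding can be checked directly. Here is a small decoder for FP16 bit patterns (a sketch written for this post, not a library function) implementing the formula above with FP16's parameters (5 exponent bits, 10 mantissa bits, bias 15):

```python
def decode_fp16(bits):
    """Decode a raw 16-bit pattern per the formula above (e=5, m=10, bias=15).
    E == 31 would be inf/NaN and is omitted for brevity."""
    s = (bits >> 15) & 1
    E = (bits >> 10) & 0x1F
    f = (bits & 0x3FF) / 1024
    if E == 0:                          # subnormal: exponent fixed at 1 - bias, no implicit 1
        return (-1) ** s * 2.0 ** (1 - 15) * f
    return (-1) ** s * 2.0 ** (E - 15) * (1 + f)

assert decode_fp16(0x3C00) == 1.0       # 0 01111 0000000000
assert decode_fp16(0xC000) == -2.0      # 1 10000 0000000000
assert decode_fp16(0x7BFF) == 65504.0   # largest finite FP16 value
```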
Floating-Point Format Comparison
| Format | Total Bits | Sign | Exponent | Mantissa | Range (max) | Smallest Normal | Precision (decimal digits) |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | 3.4e38 | 1.2e-38 | ~7.2 |
| TF32 | 19 | 1 | 8 | 10 | 3.4e38 | 1.2e-38 | ~3.3 |
| FP16 | 16 | 1 | 5 | 10 | 65,504 | 6.1e-5 | ~3.3 |
| BF16 | 16 | 1 | 8 | 7 | 3.4e38 | 1.2e-38 | ~2.4 |
| FP8 E4M3 | 8 | 1 | 4 | 3 | 448 | 1.5e-2 (sub: 2e-3) | ~1.1 |
| FP8 E5M2 | 8 | 1 | 5 | 2 | 57,344 | 6.1e-5 | ~0.9 |
| FP4 E2M1 | 4 | 1 | 2 | 1 | 6 | 1.0 | ~0.6 |
The two numbers that matter most are exponent bits (which determine representable range) and mantissa bits (which determine precision within that range). This is the fundamental tradeoff that defines every format in the precision landscape.
FP16 vs BF16: Why BF16 Won
FP16: More Precision, Less Range
FP16 (IEEE 754 binary16) uses 5 exponent bits and 10 mantissa bits. This gives it decent precision — roughly 3.3 decimal digits — but a maximum representable value of only 65,504 and a minimum normal of about 6.1e-5.
The problem for training: gradients and activations routinely exceed 65,504 (overflow) or fall below 6.1e-5 (underflow into the subnormal range where precision degrades, or to zero). Both situations are catastrophic. Overflow produces infinities and NaNs that propagate and destroy the training run. Underflow silently zeros out gradients, causing the model to stop learning.
This is why FP16 mixed precision training requires loss scaling — an engineering workaround to shift gradient magnitudes into FP16’s representable range.
BF16: Same Range as FP32, Less Precision
BF16 (Brain Floating Point, developed at Google Brain) uses 8 exponent bits and 7 mantissa bits. The 8 exponent bits give it exactly the same dynamic range as FP32: max value 3.4e38, min normal 1.2e-38. The tradeoff is precision — only about 2.4 decimal digits, compared to FP16’s 3.3.
Why Range Matters More Than Precision
In practice, BF16 won decisively for training because:
- No loss scaling needed. BF16’s range matches FP32, so gradients almost never overflow or underflow. This eliminates an entire class of training instabilities and removes the engineering complexity of dynamic loss scaling.
- Gradient distributions are wide. During training, gradient magnitudes span many orders of magnitude across layers. Early layers often have gradients near 1e-8 while loss-adjacent layers can have gradients near 1e-1. This range fits comfortably in BF16 but not in FP16.
- Precision loss is tolerable. The stochastic nature of SGD-based optimizers means that tiny per-step precision errors are equivalent to noise, which SGD is inherently robust to. The FP32 master weights absorb any precision deficit in BF16 during the optimizer step.
- Conversion is trivial. Converting FP32 to BF16 is a simple truncation of the lower 16 mantissa bits — no rounding logic needed at the hardware level. This makes the cast essentially free.
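The truncation cast can be demonstrated in a few lines of NumPy (an illustrative sketch; production hardware casts may round to nearest rather than truncate):

```python
import numpy as np

def fp32_to_bf16_truncate(x):
    """Drop the low 16 bits of the FP32 word: sign and exponent survive intact,
    the 23-bit mantissa is truncated to BF16's 7 bits (round toward zero)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)

# Pi loses its low mantissa bits but keeps sign and exponent:
assert fp32_to_bf16_truncate(np.float32(3.14159265)) == np.float32(3.140625)
assert fp32_to_bf16_truncate(np.float32(1e38)) > 0    # FP32-scale magnitudes survive
```

The masked word is still a valid FP32 value, which is exactly why the cast needs no arithmetic at all.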
The single most important practical benefit of BF16 over FP16 is eliminating loss scaling. With FP16, a bad loss scale causes either gradient underflow (scale too low) or overflow (scale too high), and dynamic adjustment adds complexity and occasional skipped steps. With BF16, you simply cast and compute. This reduced training instability and simplified codebases significantly.
FP16 vs BF16: Practical Training Comparison
| Property | FP16 | BF16 |
|---|---|---|
| Loss scaling required | Yes (dynamic) | No |
| Gradient overflow risk | High (max 65,504) | Negligible (max 3.4e38) |
| Gradient underflow risk | High (min normal 6.1e-5) | Negligible (min normal 1.2e-38) |
| Precision (mantissa bits) | 10 | 7 |
| Tensor Core support (H100) | 1,979 TFLOPS | 1,979 TFLOPS |
| Typical training stability | Requires careful tuning | Drop-in replacement for FP32 |
| First GPU support | Volta V100 (2017) | Ampere A100 (2020) |
The industry largely migrated to BF16 once Ampere hardware became available. Today, nearly all large-scale training runs (GPT-4, Llama 3, Gemini, Claude, etc.) use BF16 as the default reduced-precision format.
Loss Scaling for FP16: The Engineering Workaround
Even though BF16 has largely superseded FP16 for training, understanding loss scaling is important: it explains a fundamental numerical challenge in reduced-precision training, and variants of scaling appear in FP8 training.
The Underflow Problem
Consider a gradient value of 1e-7. In FP32, this is perfectly representable. In FP16, the smallest normal number is 6.1e-5. Values below this enter the subnormal (denormalized) range where precision degrades rapidly, and values below approximately 6e-8 round to zero. If a significant fraction of your gradients live in this region, the model stops learning.
Empirically, Micikevicius et al. showed that for many networks, a large fraction of gradient values fall below FP16’s minimum normal. For SSD object detection, over 80% of gradient values were below 2^-24 (about 6e-8), meaning they would be flushed to zero in FP16.
How Loss Scaling Works
The fix is simple in concept: multiply the loss by a large constant S before backpropagation. By the chain rule, all gradients are also multiplied by S, shifting their magnitudes up into FP16’s representable range. After backpropagation but before the optimizer step, divide by S to recover the true gradient values (in FP32).
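A NumPy sketch of the rescue (illustrative magnitudes chosen for this post): a gradient of 1e-8 flushes to zero in FP16, but survives once scaled by S = 2^16:

```python
import numpy as np

S = np.float32(2.0 ** 16)        # loss scale
g = np.float32(1e-8)             # a true gradient value, fine in FP32

assert np.float16(g) == 0.0      # unscaled: flushed to zero in FP16
scaled = np.float16(g * S)       # ~6.55e-4, comfortably inside FP16's range
assert scaled != 0.0
recovered = np.float32(scaled) / S
assert np.isclose(recovered, g, rtol=1e-3)   # true gradient recovered in FP32
```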
Static vs Dynamic Loss Scaling
Static loss scaling uses a fixed constant S (e.g., 2^15). This works for some models but fails when gradient magnitudes change during training.
Dynamic loss scaling is what production systems use. The algorithm:
- Start with a large scale factor (e.g., S = 2^16 = 65,536).
- After each backward pass, check gradients for inf/NaN (overflow detection).
- If overflow detected: halve S, skip this optimizer step, discard the gradients.
- If no overflow for N consecutive steps (e.g., N = 2,000): double S.
import torch

class DynamicLossScaler:
    def __init__(self, init_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0

    def check_overflow(self, gradients):
        """Returns True if any gradient is inf or NaN."""
        for grad in gradients:
            if grad is not None:
                if torch.isinf(grad).any() or torch.isnan(grad).any():
                    return True
        return False

    def update(self, overflow_detected):
        if overflow_detected:
            # Overflow: reduce scale, skip step
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False  # Signal: do NOT run optimizer.step()
        else:
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                # No overflow for a while: try increasing scale
                self.scale *= self.growth_factor
                self.good_steps = 0
            return True  # Signal: safe to run optimizer.step()
With dynamic loss scaling in FP16 training, it is normal to see occasional skipped optimizer steps when the scaler detects overflow. A few skipped steps per thousand are benign. If you see more than 1-2% of steps being skipped, the training may be numerically unstable and may need hyperparameter adjustment (lower learning rate, gradient clipping, or switching to BF16).
PyTorch AMP: The Standard Implementation
PyTorch’s Automatic Mixed Precision (AMP) wraps this machinery in a clean API:
import torch
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # Manages dynamic loss scaling

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16 (on CUDA) / BF16
    with autocast(dtype=torch.float16):
        outputs = model(batch["input_ids"].cuda())
        loss = criterion(outputs, batch["labels"].cuda())

    # Backward pass: scaler scales loss, then calls backward
    scaler.scale(loss).backward()

    # Unscale gradients, then clip
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step: scaler checks for inf/NaN,
    # skips step if overflow, otherwise calls optimizer.step()
    scaler.step(optimizer)
    scaler.update()
With BF16, the code simplifies because no scaler is needed:
for batch in dataloader:
    optimizer.zero_grad()
    with autocast(dtype=torch.bfloat16):
        outputs = model(batch["input_ids"].cuda())
        loss = criterion(outputs, batch["labels"].cuda())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
No GradScaler, no scaler.scale(), no scaler.unscale_(). This simplicity is another reason BF16 won.
FP8 Training: The Hopper/Blackwell Generation
FP8 training cuts the precision further to 8 bits, offering approximately 2x the throughput of BF16 on Hopper (H100) and Blackwell (B200) GPUs. This is not a minor optimization — it is a generational step that required new formats, new scaling strategies, and NVIDIA’s Transformer Engine library to make practical.
Two FP8 Formats: E4M3 and E5M2
FP8 defines two sub-formats that trade off range and precision differently:
E4M3 (4 exponent bits, 3 mantissa bits): Maximum value of 448, roughly 1.1 decimal digits of precision. Optimized for values that need more precision but have a bounded range — primarily forward pass activations and weights.
E5M2 (5 exponent bits, 2 mantissa bits): Maximum value of 57,344, roughly 0.9 decimal digits of precision. Optimized for values that span a wider range but can tolerate less precision — primarily backward pass gradients.
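Both maxima fall straight out of the bit layouts. A quick sketch (my own helper, not a library function) that derives them, accounting for E4M3's non-IEEE treatment of the top exponent (covered in more detail below):

```python
def fmt_max(exp_bits, man_bits, ieee_inf=True):
    """Max finite value of a floating-point format.

    ieee_inf=True reserves the top exponent for inf/NaN (IEEE style, like E5M2).
    ieee_inf=False keeps the top exponent for finite values and reserves only
    the all-ones mantissa pattern for NaN (like FP8 E4M3).
    """
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_inf:
        max_exp = (2 ** exp_bits - 2) - bias                    # top exponent reserved
        max_man = 1 + (2 ** man_bits - 1) / 2 ** man_bits
    else:
        max_exp = (2 ** exp_bits - 1) - bias                    # top exponent usable
        max_man = 1 + (2 ** man_bits - 2) / 2 ** man_bits       # all-ones mantissa is NaN
    return max_man * 2 ** max_exp

assert fmt_max(4, 3, ieee_inf=False) == 448.0     # FP8 E4M3
assert fmt_max(5, 2, ieee_inf=True) == 57344.0    # FP8 E5M2
assert fmt_max(5, 10, ieee_inf=True) == 65504.0   # FP16, as a sanity check
```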
FP8 Sub-Format Assignment in Training
| Training Phase | Tensor | FP8 Format | Reason |
|---|---|---|---|
| Forward | Weights | E4M3 | Bounded range, need precision |
| Forward | Activations | E4M3 | Bounded range after normalization |
| Forward | GEMM output | BF16/FP32 | Accumulated in higher precision |
| Backward | Gradient activations | E5M2 | Wide range, precision less critical |
| Backward | Weight gradients | E5M2 | Wide range across layers |
| Optimizer | Master weights | FP32 | Must preserve small updates |
The FP8 E4M3 format uses a non-standard encoding: the bit patterns 0x7F (0111 1111) and 0xFF represent NaN, and there is no representation for positive or negative infinity. Only the all-ones exponent combined with the all-ones mantissa encodes NaN; the remaining all-ones-exponent patterns encode finite values, which is how E4M3 reaches a maximum of 448 rather than the 240 it would have if the top exponent were reserved IEEE-style. E5M2, by contrast, follows the IEEE 754 convention with distinct inf and NaN encodings.
Per-Tensor Scaling: Why FP8 Needs It
FP8 E4M3 has a maximum value of 448. If a tensor has values exceeding 448, they overflow to NaN. If values are much smaller than 1, they lose almost all precision or underflow to zero (the smallest representable subnormal in E4M3 is approximately 2e-3). The solution is per-tensor scaling: each tensor gets its own scale factor that maps its dynamic range into FP8’s representable range.
The scale is chosen so that the maximum absolute value in the tensor maps to the maximum representable FP8 value:

scale = FP8_max / amax(tensor)

For E4M3, FP8_max = 448. For E5M2, FP8_max = 57,344.
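In NumPy, the scale computation is one line (a sketch; the function name is mine):

```python
import numpy as np

E4M3_MAX, E5M2_MAX = 448.0, 57344.0

def per_tensor_scale(t, fp8_max=E4M3_MAX):
    """scale = fp8_max / amax, so the tensor's largest magnitude lands exactly
    on the largest representable FP8 value after multiplication."""
    return fp8_max / np.abs(t).max()

t = np.array([0.002, -0.13, 0.07], dtype=np.float32)
s = per_tensor_scale(t)
assert np.isclose(np.abs(t * s).max(), E4M3_MAX)
```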
Delayed Scaling: The Production Strategy
Computing the optimal scale requires a pass over the entire tensor to find its maximum absolute value — this adds latency. NVIDIA’s Transformer Engine uses delayed scaling: use the scale factor from the previous iteration (or a running maximum over recent iterations) to quantize the current iteration’s tensors.
The reasoning: tensor statistics change slowly between adjacent training steps. A scale computed from step t-1 is a good approximation for step t. The delayed scaling algorithm:
- Maintain a history buffer of recent amax values (e.g., the last 1024 iterations) for each tensor.
- Compute the scale from the maximum of this history buffer.
- Quantize the current tensor using this scale.
- After the GEMM, record the actual amax of the current tensor into the history buffer.
# Simplified delayed scaling logic (Transformer Engine internals)
import torch

class DelayedScaling:
    def __init__(self, fp8_max=448.0, margin=0, history_len=1024):
        self.fp8_max = fp8_max
        self.margin = margin
        self.amax_history = torch.zeros(history_len)
        self.history_idx = 0
        self.scale = 1.0

    def compute_scale(self):
        """Compute scale from historical amax values."""
        amax = self.amax_history.max().clamp(min=1e-12)  # avoid div-by-zero on cold start
        # Add safety margin to prevent overflow
        exp = torch.floor(torch.log2(self.fp8_max / amax)) - self.margin
        self.scale = (2.0 ** exp).item()
        return self.scale

    def record_amax(self, tensor):
        """Record current tensor's amax for future scale computation."""
        self.amax_history[self.history_idx % len(self.amax_history)] = tensor.abs().max()
        self.history_idx += 1

    def quantize(self, tensor):
        """Quantize tensor to FP8 using the delayed scale."""
        scale = self.compute_scale()
        scaled = tensor * scale
        fp8_tensor = cast_to_fp8_e4m3(scaled)  # placeholder for the actual FP8 cast kernel
        self.record_amax(tensor)
        return fp8_tensor, scale
Delayed scaling assumes tensor statistics are locally stationary. This assumption breaks during learning rate warmup, sudden loss spikes, or when transitioning between training phases. Transformer Engine handles this by detecting overflow (NaN in output) and falling back to recomputation with a freshly computed scale. In practice, such fallbacks are rare after the first few hundred steps.
Which Operations Use FP8
Not all operations benefit from or tolerate FP8 precision. The Transformer Engine selectively applies FP8 only where it provides the most benefit:
FP8 operations (GEMMs only):
- Linear layer forward: Y = XW^T, where X and W are cast to FP8, accumulation in FP32
- Linear layer backward: gradient computation via GEMMs
Higher precision operations (BF16/FP32):
- LayerNorm / RMSNorm (requires precision for mean/variance computation)
- Softmax (exponentials are sensitive to precision)
- Attention score computation (after QK^T, before softmax)
- Embedding lookups
- Residual additions
- All non-GEMM element-wise operations
The reason GEMMs dominate: in a transformer model, over 70% of training FLOPS are in the linear projections (Q, K, V, O projections, and feed-forward layers). These are all matrix multiplications, and they map directly to Tensor Core GEMMs. Making only these operations FP8 captures most of the throughput benefit while keeping numerically sensitive operations in higher precision.
FLOPS Distribution in Transformer Training (per layer)
| Operation | % of FLOPS | FP8 Eligible | Precision Used |
|---|---|---|---|
| QKV Projection (GEMM) | ~18% | Yes | FP8 E4M3 |
| Attention Output Projection (GEMM) | ~6% | Yes | FP8 E4M3 |
| FFN Up/Gate Projection (GEMM) | ~24% | Yes | FP8 E4M3 |
| FFN Down Projection (GEMM) | ~24% | Yes | FP8 E4M3 |
| Attention Scores (QK^T) | ~10% | Partially | BF16 or FP8 |
| Softmax | ~3% | No | FP32 |
| LayerNorm / RMSNorm | ~2% | No | FP32 / BF16 |
| Other (residual, etc.) | ~13% | No | BF16 |
Transformer Engine: NVIDIA’s FP8 Library
NVIDIA’s Transformer Engine (TE) is the production library for FP8 mixed-precision training. It provides drop-in replacements for torch.nn.Linear, torch.nn.LayerNorm, and transformer layer building blocks that automatically manage FP8 casting, scaling, and format selection.
import torch
import transformer_engine.pytorch as te

# Replace standard linear layers with TE equivalents
class TransformerBlock(torch.nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size, num_heads):
        super().__init__()
        # These layers automatically handle FP8 quantization
        self.self_attention = te.MultiheadAttention(
            hidden_size, num_heads,
            fuse_qkv_params=True,
        )
        self.layernorm1 = te.LayerNorm(hidden_size)
        self.ffn = te.LayerNormMLP(
            hidden_size, ffn_hidden_size,
            activation="gelu",
        )

    def forward(self, x):
        # FP8 is handled internally by TE layers
        residual = x
        x = self.layernorm1(x)
        x = self.self_attention(x) + residual
        x = self.ffn(x) + x
        return x

# Enable FP8 training with a context manager
with te.fp8_autocast(enabled=True):
    output = model(input_tensor)
    loss = criterion(output, labels)

loss.backward()
Under the hood, te.Linear performs:
- Compute amax of the input tensor and cache it for the delayed scaling history.
- Look up the current scale factor from the delayed scaling state.
- Quantize the input to FP8 E4M3 using this scale.
- Quantize the weight to FP8 E4M3 using its own scale.
- Call a Tensor Core FP8 GEMM with FP32 accumulation.
- Dequantize the output (divide by input_scale * weight_scale).
- During backward, use E5M2 for gradient tensors.
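The scale bookkeeping in the steps above can be sketched in NumPy, with the actual FP8 cast mocked out (clipping only, no E4M3 rounding), since the point here is how the two scales are applied and then divided back out:

```python
import numpy as np

def quantize(t, fp8_max=448.0):
    """Per-tensor scale plus a mock FP8 cast (clipping only; real E4M3
    rounding is omitted for clarity)."""
    scale = fp8_max / np.abs(t).max()
    return np.clip(t * scale, -fp8_max, fp8_max), scale

def fp8_linear_sketch(x, w):
    """Scale bookkeeping of an FP8 linear forward: quantize both operands,
    GEMM with FP32 accumulation, then divide out both scales."""
    xq, sx = quantize(x)
    wq, sw = quantize(w)
    y = xq @ wq.T                 # accumulation stays in FP32
    return y / (sx * sw)          # dequantize

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
w = rng.standard_normal((3, 8)).astype(np.float32)
assert np.allclose(fp8_linear_sketch(x, w), x @ w.T, rtol=1e-4, atol=1e-4)
```

With a real FP8 cast the round trip would not be exact; the residual error is what delayed scaling and FP32 accumulation are designed to keep small.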
FP8 Training Results: Throughput and Quality
The throughput gains from FP8 are substantial and well-documented.
Training Throughput on H100 SXM: BF16 vs FP8 (achieved TFLOPS)
FP8 vs BF16 Training Quality (Published Results)
| Model | Format | Benchmark | Score | Delta vs BF16 |
|---|---|---|---|---|
| GPT-3 175B | BF16 | LAMBADA acc | 76.2% | baseline |
| GPT-3 175B | FP8 | LAMBADA acc | 76.0% | -0.2% |
| LLaMA 7B | BF16 | HellaSwag | 76.1% | baseline |
| LLaMA 7B | FP8 | HellaSwag | 75.9% | -0.2% |
| DeepSeek V3 671B | FP8 | MMLU | 87.1% | N/A (FP8 only) |
DeepSeek V3: FP8 at 671B Scale
DeepSeek V3, a 671B parameter Mixture-of-Experts model, was trained entirely in FP8 on H800 GPUs. This was a landmark result because:
- No loss spikes. Previous large-scale FP8 experiments sometimes produced training instabilities at scale. DeepSeek V3 trained stably through 14.8 trillion tokens.
- No BF16 fallback. All GEMMs used FP8 throughout training, with no phases that reverted to BF16.
- Cost efficiency. The training cost was approximately $5.6M, roughly 1/10th of comparably-sized models trained in BF16, largely due to the throughput gains from FP8.
Their approach included fine-grained quantization: instead of per-tensor scaling, they used per-block scaling with 128-element blocks, which better handles tensors with non-uniform value distributions. They also employed an auxiliary loss-free load balancing strategy for the MoE routing, which may have contributed to training stability.
FP4: The Emerging Frontier
FP4 training is an active research area, primarily driven by Microsoft and DeepSeek. With only 4 bits per value, FP4 offers a theoretical 2x memory and bandwidth reduction over FP8, but the engineering challenges are severe.
The Precision Challenge
FP4 E2M1 (2 exponent bits, 1 mantissa bit) can represent exactly 7 distinct positive values (plus zero and their negatives): 0.5, 1, 1.5, 2, 3, 4, and 6. This is barely a quantization grid, not a continuous number line. The rounding error for any individual value can be up to 33% of the value itself.
Research Approaches
Microsoft’s FP4 Training (2024): Proposed a mixed FP4/FP8 scheme where forward-pass GEMMs use FP4 weights and FP8 activations. Key innovations include:
- Outlier-aware quantization: Identifying and separately handling activation outliers that would destroy FP4 accuracy
- Hadamard rotation: Applying random orthogonal transforms to spread outlier energy across dimensions before quantization
- Two-level scaling: Block-level scaling with a coarser group-level scale to handle dynamic range
DeepSeek FP4 Research: Explored per-channel quantization with learned scale factors, showing that FP4 training of language models up to 7B parameters could match BF16 quality with careful compensation strategies.
As of early 2025, FP4 training remains a research topic. No major production training run has been published using FP4 as the primary precision. The Blackwell B200 GPU includes FP4 Tensor Core support (theoretically 2x the FP8 throughput at ~8,000 TFLOPS), but the software ecosystem and numerical techniques are still maturing. Expect FP4 training to become practical for production use in the 2025-2026 timeframe as Transformer Engine and frameworks add support.
Emerging FP4 Research Results
| Paper / Group | Model Size | Method | Quality vs BF16 | Status |
|---|---|---|---|---|
| Microsoft (2024) | Up to 7B | FP4 weights + FP8 activations | Within 0.5% on MMLU | Research |
| DeepSeek (2024) | Up to 7B | Per-channel FP4 + compensation | Within 0.3% on HellaSwag | Research |
| NVIDIA Blackwell | TBD | Hardware FP4 Tensor Cores | TBD | Hardware available |
Memory Savings Breakdown: The 70B Model Case Study
Understanding exact memory requirements is critical for capacity planning. Let us trace through a 70B parameter model (e.g., LLaMA 2 70B) under different precision regimes.
Parameter Memory
The number of parameters is 70 billion. Raw parameter storage:
- FP32: 70e9 × 4 bytes = 280 GB
- BF16: 70e9 × 2 bytes = 140 GB
- FP8: 70e9 × 1 byte = 70 GB
Optimizer State Memory
Adam/AdamW maintains two additional states per parameter: first moment (m) and second moment (v). These are always kept in FP32 for numerical stability.
- Optimizer states (always FP32): 2 × 70e9 × 4 bytes = 560 GB
Gradient Memory
Gradients are the same size as parameters, stored in the training precision:
- FP32 gradients: 70e9 × 4 bytes = 280 GB
- BF16 gradients: 70e9 × 2 bytes = 140 GB
- FP8 gradients: 70e9 × 1 byte = 70 GB
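These components can be tallied with a short helper (a sketch matching the accounting above; activation memory is excluded, as in the text):

```python
def training_memory_gb(n_params=70e9, param_bytes=2, grad_bytes=2):
    """Mixed-precision training memory (GB), excluding activations:
    compute-precision params + FP32 master weights + Adam m/v in FP32 + gradients."""
    GB = 1e9
    mem = {
        "params": n_params * param_bytes / GB,
        "master_fp32": n_params * 4 / GB,
        "adam_m_v": n_params * 2 * 4 / GB,
        "grads": n_params * grad_bytes / GB,
    }
    mem["total"] = sum(mem.values())
    return mem

assert training_memory_gb(param_bytes=2, grad_bytes=2)["total"] == 1120.0  # BF16 mixed
assert training_memory_gb(param_bytes=1, grad_bytes=1)["total"] == 980.0   # FP8 mixed
```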
Total Training Memory (Excluding Activations)
Memory Breakdown: 70B Parameter Model (Excluding Activations)
| Component | Pure FP32 | BF16 Mixed | FP8 Mixed |
|---|---|---|---|
| Model params | 280 GB (FP32) | 140 GB (BF16) | 70 GB (FP8) |
| FP32 master weights | --- (same as above) | 280 GB | 280 GB |
| Optimizer (m + v) | 560 GB (FP32) | 560 GB (FP32) | 560 GB (FP32) |
| Gradients | 280 GB (FP32) | 140 GB (BF16) | 70 GB (FP8) |
| Total | 1,120 GB | 1,120 GB | 980 GB |
| Total (with ZeRO-3, 8 GPUs) | 140 GB/GPU | 140 GB/GPU | 122.5 GB/GPU |
A surprising insight from this breakdown: for Adam-based optimizers, the optimizer states (m and v) consume more memory than the model parameters in every precision regime. This is why memory-efficient optimizers (Adafactor, CAME, 8-bit Adam) and techniques like ZeRO offloading matter so much for large-scale training — they attack the largest memory consumer. Reducing parameter precision from FP32 to FP8 saves 210 GB on the parameters and gradients, but the optimizer states remain at 560 GB regardless.
A more practical way to think about the memory savings from mixed precision:
Effective Memory Savings: 70B Model with Adam
| Precision Regime | Total Memory | vs Pure FP32 | Practical Savings |
|---|---|---|---|
| Pure FP32 | 1,120 GB | baseline | --- |
| BF16 mixed (standard) | 1,120 GB | 0% less total | But enables Tensor Cores (2x compute) |
| BF16 mixed + FP32 opt | ~1,120 GB | 0% | Working set fits in less HBM |
| FP8 mixed + FP32 opt | ~980 GB | 12.5% less | 2x compute + reduced bandwidth |
| FP8 + 8-bit Adam | ~560 GB | 50% less | Aggressive but proven at scale |
The real value of reduced precision is not just memory reduction — it is the compute throughput gain from using Tensor Cores and the bandwidth reduction from moving smaller tensors.
Throughput Comparison: Real Training Numbers on H100
These are representative achieved TFLOPS for training different model sizes on H100 SXM GPUs, using standard configurations (Megatron-LM style parallelism, activation checkpointing, sequence length 2048-4096).
Achieved Training TFLOPS on H100 SXM (Single Node, 8 GPUs)
| Model | FP32 (CUDA) | TF32 | BF16 | FP8 | FP8 vs BF16 |
|---|---|---|---|---|---|
| 1.3B | ~12 | ~85 | ~145 | ~240 | 1.66x |
| 7B | ~10 | ~90 | ~160 | ~275 | 1.72x |
| 13B | ~9 | ~88 | ~155 | ~270 | 1.74x |
| 70B | N/A (OOM) | ~80 | ~140 | ~245 | 1.75x |
| 175B (multi-node) | N/A | ~70 | ~130 | ~225 | 1.73x |
Achieved Training TFLOPS by Precision (7B Model, H100 SXM)
The progression is dramatic: moving from FP32 CUDA cores to FP8 Tensor Cores delivers a ~27x throughput improvement on the same hardware. Even the “lazy” optimization of enabling TF32 yields a 9x improvement with a single setting (torch.backends.cuda.matmul.allow_tf32 = True; since PyTorch 1.12, TF32 is enabled by default only for cuDNN convolutions, not matmuls).
When Precision Reduction Fails
Not every workload benefits from reduced precision, and some actively break. Understanding the failure modes is as important as understanding the benefits.
Small Models
For models with fewer than ~100M parameters, the overhead of mixed-precision infrastructure (maintaining master weights, scaling, casting) can outweigh the compute savings. Small models also tend to be less tolerant of numerical noise because each parameter carries proportionally more “responsibility” for the model’s behavior. At small scale, the memory savings are also less impactful since the model fits comfortably on a single GPU in FP32.
Tasks Requiring Fine Numerical Precision
Some applications produce outputs where small numerical differences matter:
- Scientific simulation models where outputs represent physical quantities
- Financial modeling where rounding errors accumulate over long sequences of operations
- Regression tasks with targets spanning many orders of magnitude
- Reinforcement learning with reward signals near the precision floor
Architecture-Specific Challenges
Very deep networks without normalization: In networks with hundreds of layers and no LayerNorm/BatchNorm, gradient magnitudes can vary by 20+ orders of magnitude across layers. Even BF16 handles this (same range as FP32), but FP8 with per-tensor scaling may struggle because a single scale factor cannot simultaneously represent very large and very small values in the same tensor.
Attention with extremely long sequences: Softmax over long sequences (greater than 100K tokens) involves computing e^(x − x_max) for values of x − x_max that can be very negative. In FP16, the limited range can cause issues. BF16 and FP8 E5M2 handle this better due to wider range, but the softmax itself should always run in FP32.
Training instability during critical phases: The first few hundred steps of training (when gradients are large and chaotic) and fine-tuning with very small learning rates (when gradient signals are small) are the most precision-sensitive phases. Some practitioners start training in BF16, switch to FP8 after warmup, and revert to BF16 for the final fine-tuning stage.
The most dangerous failure mode of reduced precision is not a crash or NaN — it is silent quality degradation. The model trains, the loss decreases, but final evaluation metrics are 1-2% worse than they would have been in higher precision. This is hard to detect without running a BF16 baseline for comparison. Always validate reduced-precision training against a higher-precision reference on a representative subset of your evaluation suite.
When to Avoid Reduced Precision
| Scenario | Risk Level | Failure Mode | Mitigation |
|---|---|---|---|
| Model under 100M params | Medium | Overhead exceeds benefit | Use BF16 (simpler than FP8) |
| Scientific regression | High | Accumulated rounding error | Keep FP32 for critical ops |
| Very deep nets (500+ layers) | Medium | Gradient range exceeds format | Per-layer scaling, BF16 minimum |
| Long-context attention (128K+) | Medium | Softmax precision | Softmax always in FP32 |
| Fine-tuning with lr under 1e-6 | High | Updates below precision floor | FP32 optimizer, BF16 compute |
| RL with sparse rewards | High | Reward signal lost to rounding | FP32 for reward/value heads |
Production Setup: Framework-Specific Configurations
PyTorch Native (torch.amp)
The simplest path for single-GPU or DDP training:
import torch
from torch.amp import autocast
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP

# BF16 training (recommended for Ampere+ GPUs)
model = MyModel().cuda()
model = DDP(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    optimizer.zero_grad(set_to_none=True)  # set_to_none saves memory
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

# FP16 training (when BF16 not available, e.g., V100)
scaler = GradScaler()
for batch in dataloader:
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
DeepSpeed ZeRO + Mixed Precision
DeepSpeed configuration for BF16 training with ZeRO Stage 2:
{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "steps_per_print": 100
}
For FP16 with loss scaling:
{
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  }
}
"loss_scale": 0 enables dynamic loss scaling. initial_scale_power: 16 means the initial scale is . hysteresis: 2 means the scale must overflow 2 consecutive times before being reduced.
Megatron-LM: Large-Scale Training
Megatron-LM command-line flags for FP8 training with Transformer Engine:
python pretrain_gpt.py \
--tensor-model-parallel-size 8 \
--pipeline-model-parallel-size 4 \
--num-layers 80 \
--hidden-size 8192 \
--num-attention-heads 64 \
--seq-length 4096 \
--micro-batch-size 2 \
--global-batch-size 1024 \
--bf16 \
--fp8-format hybrid \
--fp8-amax-history-len 1024 \
--fp8-amax-compute-algo max \
--transformer-impl transformer_engine \
--attention-softmax-in-fp32 \
--accumulate-allreduce-grads-in-fp32 \
--use-flash-attn
Key flags explained:
- --bf16: Model weights, activations, and non-GEMM ops in BF16 (the optimizer keeps FP32 main weights)
- --fp8-format hybrid: E4M3 for forward, E5M2 for backward
- --fp8-amax-history-len 1024: Delayed scaling history window
- --fp8-amax-compute-algo max: Use max of history for scale computation
- --transformer-impl transformer_engine: Use TE layers
- --attention-softmax-in-fp32: Keep softmax in FP32 for stability
- --accumulate-allreduce-grads-in-fp32: Reduce precision errors in gradient all-reduce
Production Precision Configurations by Scale
| Model Scale | Recommended Precision | Framework | Key Configuration |
|---|---|---|---|
| Under 1B | BF16 | PyTorch native AMP | autocast(dtype=torch.bfloat16) |
| 1B - 13B | BF16 or FP8 | DeepSpeed ZeRO-2 | bf16.enabled + ZeRO Stage 2 |
| 13B - 70B | FP8 | Megatron-LM + TE | TP=8, PP=2-4, FP8 hybrid |
| 70B+ | FP8 | Megatron-LM + TE + ZeRO | TP=8, PP=8+, FP8 hybrid + ZeRO-1 |
| MoE 200B+ | FP8 | Custom (DeepSeek-style) | Per-block FP8 scaling, EP |
The Precision Ladder: A Summary
Training precision has evolved in a clear progression, driven by hardware support and numerical research:
The Precision Landscape: Past, Present, and Future
| Format | Era | GPU Generation | Key Innovation | Status (2025) |
|---|---|---|---|---|
| FP32 | 2012-2017 | Kepler through Pascal | Default, no tricks needed | Baseline / optimizer states only |
| FP16 + loss scaling | 2017-2020 | Volta V100 | Tensor Cores + dynamic loss scaling | Legacy, replaced by BF16 |
| BF16 | 2020-present | Ampere A100+ | FP32 range in 16 bits, no loss scaling | Current default for training |
| TF32 (auto) | 2020-present | Ampere A100+ | Transparent FP32 replacement using TC | Enabled by default in PyTorch |
| FP8 (E4M3/E5M2) | 2022-present | Hopper H100+ | Per-tensor delayed scaling, TE | Production-ready, widely adopted |
| FP4 (E2M1) | 2024-future | Blackwell B200+ | Block quantization, compensation | Research, early adoption |
Each step down the precision ladder roughly doubles throughput while requiring increasingly sophisticated numerical techniques to maintain training quality. The common thread: the compute formats get smaller, but the master weights and optimizer states remain in FP32. The “mixed” in mixed precision is essential — it is not about training in low precision, it is about computing in low precision while maintaining a high-precision anchor for correctness.
Theoretical Peak Throughput Scaling by Precision (H100 SXM, TFLOPS)
The trajectory is clear: within the next two years, FP4 Tensor Core training will likely become standard practice for large models, just as FP8 became standard in 2023-2024 and BF16 in 2020-2022. Each transition required new hardware support, new scaling techniques, and new software infrastructure — but each delivered roughly 2x training efficiency for the same model quality. For practitioners, the imperative is straightforward: use the lowest precision your hardware supports and your model tolerates, keep master weights in FP32, and let the frameworks handle the scaling.