Part of series: Inference Optimization Timeline (52 of 60)

Every floating-point multiply in FP32 that could have run in a lower precision format is wasted memory bandwidth, wasted compute cycles, and wasted power. Modern GPU architectures have made this tradeoff explicit: their fastest execution units — Tensor Cores — operate on FP16, BF16, or FP8, not FP32. If your training loop runs in pure FP32, the most powerful hardware on your chip sits idle. Mixed precision training exists to fix this: run forward and backward passes in reduced precision for throughput, keep master weights in FP32 for numerical correctness.

This post covers the full precision landscape from FP32 down to FP4, with the engineering details that matter in practice: why BF16 displaced FP16, how dynamic loss scaling actually works, what NVIDIA’s Transformer Engine does under the hood for FP8, and exact memory calculations for production-scale models.

Why FP32 Is Wasteful for Training

The Memory Bandwidth Argument

GPU compute has scaled faster than memory bandwidth for over a decade. An H100 SXM delivers 3,958 TFLOPS of FP8 Tensor Core compute but only 3.35 TB/s of HBM3 bandwidth. The arithmetic intensity required to keep the chip busy at peak throughput is over 1,000 ops/byte for FP8. Most training workloads fall well below this, making them memory-bandwidth bound.

Every parameter stored in FP32 (4 bytes) instead of FP16/BF16 (2 bytes) doubles the bandwidth required to load and store it. For the massive weight matrices in transformer models, this means the GPU spends more cycles waiting for data than doing useful math. Reducing precision to 16 bits halves the traffic. Reducing to 8 bits quarters it.
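To put numbers on this, a quick back-of-envelope check using the H100 figures quoted above (a sketch, not a benchmark):

```python
# Back-of-envelope numbers for an H100 SXM (specs quoted in the text).
peak_fp8_flops = 3958e12   # 3,958 TFLOPS of FP8 Tensor Core compute
hbm_bandwidth = 3.35e12    # 3.35 TB/s of HBM3 bandwidth

# Arithmetic intensity needed to keep the Tensor Cores saturated
ridge_point = peak_fp8_flops / hbm_bandwidth
print(f"required intensity: {ridge_point:.0f} FLOPs per byte")   # ~1181

# Traffic to stream a 70B-parameter weight set once per pass
params = 70e9
for fmt, bytes_per_param in [("FP32", 4), ("BF16", 2), ("FP8", 1)]:
    print(f"{fmt}: {params * bytes_per_param / 1e9:.0f} GB")
```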

Memory Bandwidth per Parameter by Precision

| Format | Bytes per parameter |
|---|---|
| FP32 | 4 bytes |
| FP16 / BF16 | 2 bytes |
| FP8 | 1 byte |
| FP4 | 0.5 bytes |

The Tensor Core Argument

NVIDIA Tensor Cores, starting with Volta (V100) in 2017, are specialized matrix-multiply-accumulate units. They operate on reduced-precision inputs and accumulate in higher precision. Critically, they are the only way to reach peak FLOPS on modern NVIDIA GPUs. The CUDA core FP32 throughput on an H100 is 67 TFLOPS. The Tensor Core FP16 throughput is 1,979 TFLOPS — a 30x gap. Running pure FP32 training on an H100 means you are using roughly 1.7% of the chip’s peak compute capability.

H100 SXM Peak Throughput by Precision

| Precision | Tensor Core TFLOPS | vs FP32 CUDA | Format |
|---|---|---|---|
| FP32 (CUDA cores) | 67 | 1.0x (baseline) | IEEE 754 |
| TF32 (Tensor Cores) | 989 | 14.8x | 19-bit internal |
| FP16 (Tensor Cores) | 1,979 | 29.5x | IEEE 754 half |
| BF16 (Tensor Cores) | 1,979 | 29.5x | Brain float |
| FP8 (Tensor Cores) | 3,958 | 59.1x | E4M3 / E5M2 |

Note: H100 SXM5 specs. Sparsity variants deliver 2x these numbers but require structured sparsity patterns.

The takeaway is clear: if you want to use the hardware you paid for, you must run in reduced precision. Mixed precision training is the technique that lets you do this without destroying model quality.

The Core Idea: Mixed Precision Training

The foundational approach, introduced by Micikevicius et al. (2018), maintains three copies of information:

  1. FP16 weights — used in the forward and backward passes for compute
  2. FP32 master weights — the authoritative copy, updated by the optimizer
  3. FP16 gradients — computed during backpropagation, optionally scaled

Each training step: (a) copy FP32 master weights to FP16, (b) run the forward pass in FP16, (c) compute the loss in FP16, (d) scale the loss, (e) run the backward pass in FP16, (f) unscale the gradients, (g) update the FP32 master weights with the optimizer. The FP32 master weights are essential because optimizer updates can be extremely small — on the order of $10^{-7}$ when multiplying a learning rate of $10^{-4}$ by a gradient of $10^{-3}$. In FP16, the smallest representable subnormal is $\approx 5.96 \times 10^{-8}$, and values below $\approx 6.1 \times 10^{-5}$ lose significant precision. The FP32 master copy preserves these tiny updates.
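The steps above can be traced end-to-end on a one-parameter toy model. This is an illustrative sketch: numpy's float16/float32 stand in for GPU FP16 compute and FP32 master state, and the model ($y = w \cdot x$ with squared loss), values, and scale factor are all made up for the example.

```python
import numpy as np

# One mixed-precision training step on a toy model y = w * x with squared loss.
# numpy float16/float32 stand in for GPU FP16 compute and FP32 master state.
master_w = np.float32(1.0)                 # FP32 master weight
x, target = np.float32(2.0), np.float32(4.0)
S = np.float32(1024.0)                     # loss scale
lr = np.float32(0.1)

w16 = np.float16(master_w)                 # (a) cast master weights to FP16
y = w16 * np.float16(x)                    # (b) forward pass in FP16
loss = (y - np.float16(target)) ** 2       # (c) loss in FP16
# (d)+(e) backward pass on the scaled loss: d(S*loss)/dw = S * 2*(y - target)*x
grad16 = np.float16(S) * np.float16(2.0) * (y - np.float16(target)) * np.float16(x)
grad32 = np.float32(grad16) / S            # (f) unscale in FP32
master_w = master_w - lr * grad32          # (g) FP32 optimizer update
print(master_w)                            # 1.8
```

Note that the scaled gradient (-8192) stays comfortably inside FP16's range; without the scale, a much smaller raw gradient could underflow before step (f) ever sees it.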

IEEE 754 Floating-Point: The Bit Layout That Matters

To understand why BF16 displaced FP16, you need to understand how floating-point numbers are encoded.

A floating-point number is represented as:

$$(-1)^{s} \times 2^{e - \text{bias}} \times (1 + m)$$

where $s$ is the sign bit, $e$ is the stored exponent, the bias is $2^{k-1} - 1$ for $k$ exponent bits, and $m$ is the fractional mantissa.
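The formula can be checked directly by pulling the three fields out of an FP16 bit pattern. This short standard-library sketch handles normal numbers only (no subnormals, inf, or NaN); `decode_fp16` is an illustrative helper, not a library function.

```python
import struct

# Decode sign/exponent/mantissa of an FP16 value by hand and re-apply
# (-1)^s * 2^(e - bias) * (1 + m). Normal numbers only.
def decode_fp16(value):
    (bits,) = struct.unpack("<H", struct.pack("<e", value))  # "e" = binary16
    s = bits >> 15
    e = (bits >> 10) & 0x1F      # 5 exponent bits, bias = 2^(5-1) - 1 = 15
    frac = bits & 0x3FF          # 10 mantissa bits
    m = frac / 1024.0            # fractional mantissa
    return (-1) ** s * 2.0 ** (e - 15) * (1 + m)

print(decode_fp16(3.140625))   # 3.140625 (exactly representable in FP16)
```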

Floating-Point Format Comparison

| Format | Total Bits | Sign | Exponent | Mantissa | Range (max) | Smallest Normal | Precision (decimal digits) |
|---|---|---|---|---|---|---|---|
| FP32 | 32 | 1 | 8 | 23 | 3.4e38 | 1.2e-38 | ~7.2 |
| TF32 | 19 | 1 | 8 | 10 | 3.4e38 | 1.2e-38 | ~3.3 |
| FP16 | 16 | 1 | 5 | 10 | 65,504 | 6.1e-5 | ~3.3 |
| BF16 | 16 | 1 | 8 | 7 | 3.4e38 | 1.2e-38 | ~2.4 |
| FP8 E4M3 | 8 | 1 | 4 | 3 | 448 | 1.5e-2 (sub: 2e-3) | ~1.1 |
| FP8 E5M2 | 8 | 1 | 5 | 2 | 57,344 | 6.1e-5 | ~0.9 |
| FP4 E2M1 | 4 | 1 | 2 | 1 | 6 | 1.0 | ~0.6 |

Note: Precision in decimal digits = (mantissa bits + 1) × log10(2), counting the implicit leading bit. FP8 E4M3 uses a special NaN encoding (no ±inf, only NaN at 0x7F/0xFF).

The two numbers that matter most are exponent bits (which determine representable range) and mantissa bits (which determine precision within that range). This is the fundamental tradeoff that defines every format in the precision landscape.

FP16 vs BF16: Why BF16 Won

FP16: More Precision, Less Range

FP16 (IEEE 754 binary16) uses 5 exponent bits and 10 mantissa bits. This gives it decent precision — roughly 3.3 decimal digits — but a maximum representable value of only 65,504 and a minimum normal of about $6.1 \times 10^{-5}$.

The problem for training: gradients and activations routinely exceed 65,504 (overflow) or fall below $6.1 \times 10^{-5}$ (underflow into the subnormal range where precision degrades, or to zero). Both situations are catastrophic. Overflow produces infinities and NaNs that propagate and destroy the training run. Underflow silently zeros out gradients, causing the model to stop learning.
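Both failure modes are easy to reproduce with numpy's float16:

```python
import numpy as np

# FP16 failure modes from the text, reproduced with numpy:
over = np.float16(70000.0)      # above FP16 max (65,504) -> inf
under = np.float16(2.0e-8)      # below half the smallest subnormal -> flushed to 0
coarse = np.float16(3.0e-6)     # subnormal range: representable, but very coarse

print(over)     # inf
print(under)    # 0.0
print(coarse)   # snapped to the nearest multiple of 2^-24 (~5.96e-8)
```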

This is why FP16 mixed precision training requires loss scaling — an engineering workaround to shift gradient magnitudes into FP16’s representable range.

BF16: Same Range as FP32, Less Precision

BF16 (Brain Floating Point, developed at Google Brain) uses 8 exponent bits and 7 mantissa bits. The 8 exponent bits give it exactly the same dynamic range as FP32: max value $\approx 3.4 \times 10^{38}$, min normal $\approx 1.2 \times 10^{-38}$. The tradeoff is precision — only about 2.4 decimal digits, compared to FP16's 3.3.

Why Range Matters More Than Precision

In practice, BF16 won decisively for training because:

  1. No loss scaling needed. BF16’s range matches FP32, so gradients almost never overflow or underflow. This eliminates an entire class of training instabilities and removes the engineering complexity of dynamic loss scaling.

  2. Gradient distributions are wide. During training, gradient magnitudes span many orders of magnitude across layers. Early layers often have gradients near $10^{-6}$ while loss-adjacent layers can have gradients near $10^{2}$. This range fits comfortably in BF16 but not in FP16.

  3. Precision loss is tolerable. The stochastic nature of SGD-based optimizers means that tiny per-step precision errors are equivalent to noise, which SGD is inherently robust to. The FP32 master weights absorb any precision deficit in BF16 during the optimizer step.

  4. Conversion is trivial. Converting FP32 to BF16 is a simple truncation of the lower 16 mantissa bits — no rounding logic needed at the hardware level. This makes the cast essentially free.
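The truncation in point 4 can be demonstrated in a few lines of numpy by masking off the low 16 bits of the FP32 encoding; `fp32_to_bf16` is an illustrative helper, not a library function.

```python
import numpy as np

# FP32 -> BF16: keep the top 16 bits (sign, 8 exponent bits, 7 mantissa
# bits) and drop the low 16 mantissa bits. Pure truncation, no rounding.
def fp32_to_bf16(x):
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    truncated = bits & np.uint32(0xFFFF0000)   # zero the low 16 bits
    return truncated.view(np.float32)

bf16 = fp32_to_bf16(3.14159265)
print(float(bf16))   # 3.140625 -- only ~2-3 decimal digits survive
```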

BF16 Eliminated Loss Scaling

The single most important practical benefit of BF16 over FP16 is eliminating loss scaling. With FP16, a bad loss scale causes either gradient underflow (scale too low) or overflow (scale too high), and dynamic adjustment adds complexity and occasional skipped steps. With BF16, you simply cast and compute. This reduced training instability and simplified codebases significantly.

FP16 vs BF16: Practical Training Comparison

| Property | FP16 | BF16 |
|---|---|---|
| Loss scaling required | Yes (dynamic) | No |
| Gradient overflow risk | High (max 65,504) | Negligible (max 3.4e38) |
| Gradient underflow risk | High (min normal 6.1e-5) | Negligible (min normal 1.2e-38) |
| Precision (mantissa bits) | 10 | 7 |
| Tensor Core support (H100) | 1,979 TFLOPS | 1,979 TFLOPS |
| Typical training stability | Requires careful tuning | Drop-in replacement for FP32 |
| First GPU support | Volta V100 (2017) | Ampere A100 (2020) |

The industry largely migrated to BF16 once Ampere hardware became available. Today, nearly all publicly documented large-scale training runs (Llama 3, and reportedly GPT-4, Gemini, and Claude) use BF16 as the default reduced-precision format.

Loss Scaling for FP16: The Engineering Workaround

Even though BF16 has largely superseded FP16 for training, understanding loss scaling is important: it explains a fundamental numerical challenge in reduced-precision training, and variants of scaling appear in FP8 training.

The Underflow Problem

Consider a gradient value of $3.0 \times 10^{-6}$. In FP32, this is perfectly representable. In FP16, the smallest normal number is $6.1 \times 10^{-5}$. Values below this enter the subnormal (denormalized) range where precision degrades rapidly, and values below approximately $5.96 \times 10^{-8}$ round to zero. If a significant fraction of your gradients live in this region, the model stops learning.

Empirically, Micikevicius et al. showed that for many networks, a large fraction of gradient values fall below FP16's minimum normal. For SSD object detection, over 80% of gradient values were below $2^{-24}$ ($\approx 5.96 \times 10^{-8}$), meaning they would be flushed to zero in FP16.

How Loss Scaling Works

The fix is simple in concept: multiply the loss by a large constant SS before backpropagation. By the chain rule, all gradients are also multiplied by SS, shifting their magnitudes up into FP16’s representable range. After backpropagation but before the optimizer step, divide by SS to recover the true gradient values (in FP32).

$$\text{scaled\_loss} = S \times \mathcal{L}$$

$$\nabla_{\theta}\,\text{scaled\_loss} = S \times \nabla_{\theta}\mathcal{L}$$

$$\nabla_{\theta}\mathcal{L} = \frac{1}{S}\,\nabla_{\theta}\,\text{scaled\_loss}$$

Static vs Dynamic Loss Scaling

Static loss scaling uses a fixed constant (e.g., $S = 2^{16} = 65536$). This works for some models but fails when gradient magnitudes change during training.

Dynamic loss scaling is what production systems use. The algorithm:

  1. Start with a large scale factor $S$ (e.g., $2^{16}$).
  2. After each backward pass, check gradients for inf/NaN (overflow detection).
  3. If overflow detected: halve $S$, skip this optimizer step, discard the gradients.
  4. If no overflow for $N$ consecutive steps (e.g., $N = 2000$): double $S$.
import torch

class DynamicLossScaler:
    def __init__(self, init_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.good_steps = 0

    def check_overflow(self, gradients):
        """Returns True if any gradient is inf or NaN."""
        for grad in gradients:
            if grad is not None:
                if torch.isinf(grad).any() or torch.isnan(grad).any():
                    return True
        return False

    def update(self, overflow_detected):
        if overflow_detected:
            # Overflow: reduce scale, skip step
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False  # Signal: do NOT run optimizer.step()
        else:
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                # No overflow for a while: try increasing scale
                self.scale *= self.growth_factor
                self.good_steps = 0
            return True  # Signal: safe to run optimizer.step()
⚠️ Skipped Steps Are Normal

With dynamic loss scaling in FP16 training, it is normal to see occasional skipped optimizer steps when the scaler detects overflow. A few skipped steps per thousand are benign. If you see more than 1-2% of steps being skipped, the training may be numerically unstable and may need hyperparameter adjustment (lower learning rate, gradient clipping, or switching to BF16).

PyTorch AMP: The Standard Implementation

PyTorch’s Automatic Mixed Precision (AMP) wraps this machinery in a clean API:

import torch
from torch.cuda.amp import autocast, GradScaler

model = MyModel().cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # Manages dynamic loss scaling

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16 (on CUDA) / BF16
    with autocast(dtype=torch.float16):
        outputs = model(batch["input_ids"].cuda())
        loss = criterion(outputs, batch["labels"].cuda())

    # Backward pass: scaler scales loss, then calls backward
    scaler.scale(loss).backward()

    # Unscale gradients, then clip
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step: scaler checks for inf/NaN,
    # skips step if overflow, otherwise calls optimizer.step()
    scaler.step(optimizer)
    scaler.update()

With BF16, the code simplifies because no scaler is needed:

for batch in dataloader:
    optimizer.zero_grad()

    with autocast(dtype=torch.bfloat16):
        outputs = model(batch["input_ids"].cuda())
        loss = criterion(outputs, batch["labels"].cuda())

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

No GradScaler, no scaler.scale(), no scaler.unscale_(). This simplicity is another reason BF16 won.

FP8 Training: The Hopper/Blackwell Generation

FP8 training cuts the precision further to 8 bits, offering approximately 2x the throughput of BF16 on Hopper (H100) and Blackwell (B200) GPUs. This is not a minor optimization — it is a generational step that required new formats, new scaling strategies, and NVIDIA’s Transformer Engine library to make practical.

Two FP8 Formats: E4M3 and E5M2

FP8 defines two sub-formats that trade off range and precision differently:

E4M3 (4 exponent bits, 3 mantissa bits): Maximum value of 448, roughly 1.1 decimal digits of precision. Optimized for values that need more precision but have a bounded range — primarily forward pass activations and weights.

E5M2 (5 exponent bits, 2 mantissa bits): Maximum value of 57,344, roughly 0.9 decimal digits of precision. Optimized for values that span a wider range but can tolerate less precision — primarily backward pass gradients.

FP8 Sub-Format Assignment in Training

| Training Phase | Tensor | FP8 Format | Reason |
|---|---|---|---|
| Forward | Weights | E4M3 | Bounded range, need precision |
| Forward | Activations | E4M3 | Bounded range after normalization |
| Forward | GEMM output | BF16/FP32 | Accumulated in higher precision |
| Backward | Gradient activations | E5M2 | Wide range, precision less critical |
| Backward | Weight gradients | E5M2 | Wide range across layers |
| Optimizer | Master weights | FP32 | Must preserve small updates |

Note: GEMM inputs are cast to FP8; the accumulation inside Tensor Cores happens in FP32. Only the stored/transferred values are FP8.
ℹ️ E4M3 NaN Encoding

The FP8 E4M3 format uses a non-standard encoding: the bit patterns 0x7F and 0xFF (exponent and mantissa all ones) represent NaN, and there is no representation for positive or negative infinity. Reclaiming the all-ones-exponent patterns that IEEE 754 would reserve for inf/NaN extends the maximum representable value to 448. E5M2, by contrast, follows the IEEE 754 convention with distinct inf and NaN encodings.

Per-Tensor Scaling: Why FP8 Needs It

FP8 E4M3 has a maximum value of 448. If a tensor has values exceeding 448, they overflow to NaN. If values are much smaller than 1, they lose almost all precision or underflow to zero (the smallest representable subnormal in E4M3 is approximately $1.95 \times 10^{-3}$). The solution is per-tensor scaling: each tensor gets its own scale factor that maps its dynamic range into FP8's representable range.

$$x_{\text{fp8}} = \text{cast\_to\_fp8}\!\left(\frac{x}{\text{scale}}\right)$$

The scale is chosen so that the maximum absolute value in the tensor maps to the maximum representable FP8 value:

$$\text{scale} = \frac{\max(|x|)}{\text{fp8\_max}}$$

For E4M3, $\text{fp8\_max} = 448$. For E5M2, $\text{fp8\_max} = 57344$.
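A minimal numpy sketch of this scale arithmetic follows. The FP8 cast itself is only simulated by clipping (a real kernel also rounds the mantissa); the point is the scale computation and the dequantize roundtrip. The helper names are illustrative.

```python
import numpy as np

# Per-tensor scaling sketch: map a tensor's dynamic range onto E4M3's
# [-448, 448]. The FP8 cast is simulated by clipping only.
E4M3_MAX = 448.0

def quantize_per_tensor(x):
    scale = np.abs(x).max() / E4M3_MAX               # scale = max|x| / fp8_max
    x_fp8 = np.clip(x / scale, -E4M3_MAX, E4M3_MAX)  # cast_to_fp8(x / scale)
    return x_fp8, scale

def dequantize(x_fp8, scale):
    return x_fp8 * scale

x = np.array([0.001, -2.5, 1200.0], dtype=np.float32)
x_fp8, scale = quantize_per_tensor(x)
print(scale)            # 1200 / 448 ~ 2.68
print(x_fp8.max())      # 448.0: the largest magnitude maps to fp8_max
print(dequantize(x_fp8, scale))   # recovers the originals (up to rounding)
```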

Delayed Scaling: The Production Strategy

Computing the optimal scale requires a pass over the entire tensor to find its maximum absolute value — this adds latency. NVIDIA’s Transformer Engine uses delayed scaling: use the scale factor from the previous iteration (or a running maximum over recent iterations) to quantize the current iteration’s tensors.

The reasoning: tensor statistics change slowly between adjacent training steps. A scale computed from step $t-1$ is a good approximation for step $t$. The delayed scaling algorithm:

  1. Maintain a history buffer of recent amax values (e.g., the last 1024 iterations) for each tensor.
  2. Compute the scale from the maximum of this history buffer.
  3. Quantize the current tensor using this scale.
  4. After the GEMM, record the actual amax of the current tensor into the history buffer.
# Simplified delayed scaling logic (Transformer Engine internals)
import torch

def cast_to_fp8_e4m3(tensor):
    """Placeholder for the hardware FP8 cast (here: just clamp to E4M3 range)."""
    return tensor.clamp(-448.0, 448.0)

class DelayedScaling:
    def __init__(self, fp8_max=448.0, margin=0, history_len=1024):
        self.fp8_max = fp8_max
        self.margin = margin
        self.amax_history = torch.zeros(history_len)
        self.history_idx = 0
        self.scale = 1.0

    def compute_scale(self):
        """Compute scale from historical amax values."""
        # Clamp so the first iterations (empty history) don't divide by zero
        amax = self.amax_history.max().clamp(min=1e-12)
        # Power-of-two scale with a safety margin to prevent overflow
        exp = torch.floor(torch.log2(self.fp8_max / amax)) - self.margin
        self.scale = (2.0 ** exp).item()
        return self.scale

    def record_amax(self, tensor):
        """Record current tensor's amax for future scale computation."""
        self.amax_history[self.history_idx % len(self.amax_history)] = tensor.abs().max()
        self.history_idx += 1

    def quantize(self, tensor):
        """Quantize tensor to FP8 using the delayed scale."""
        scale = self.compute_scale()
        scaled = tensor * scale
        fp8_tensor = cast_to_fp8_e4m3(scaled)
        self.record_amax(tensor)
        return fp8_tensor, scale
⚠️ Delayed Scaling Failure Modes

Delayed scaling assumes tensor statistics are locally stationary. This assumption breaks during learning rate warmup, sudden loss spikes, or when transitioning between training phases. Transformer Engine handles this by detecting overflow (NaN in output) and falling back to recomputation with a freshly computed scale. In practice, such fallbacks are rare after the first few hundred steps.

Which Operations Use FP8

Not all operations benefit from or tolerate FP8 precision. The Transformer Engine selectively applies FP8 only where it provides the most benefit:

FP8 operations (GEMMs only):

  • Linear layer forward: $Y = XW^T + b$ where $X$ and $W$ are cast to FP8, accumulation in FP32
  • Linear layer backward: gradient computation via GEMMs

Higher precision operations (BF16/FP32):

  • LayerNorm / RMSNorm (requires precision for mean/variance computation)
  • Softmax (exponentials are sensitive to precision)
  • Attention score computation (after QK^T, before softmax)
  • Embedding lookups
  • Residual additions
  • All non-GEMM element-wise operations

The reason GEMMs dominate: in a transformer model, over 70% of training FLOPS are in the linear projections (Q, K, V, O projections, and feed-forward layers). These are all matrix multiplications, and they map directly to Tensor Core GEMMs. Making only these operations FP8 captures most of the throughput benefit while keeping numerically sensitive operations in higher precision.
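A back-of-envelope per-token FLOP count for one layer illustrates the imbalance. This sketch assumes a standard non-gated FFN with hidden dim 4d and a context of s tokens, and deliberately ignores softmax and element-wise work (which the surrounding table does count), so the GEMM share comes out even higher than 70%.

```python
# Rough per-token FLOPs for one transformer layer (2 FLOPs per multiply-add).
# Assumptions: non-gated FFN with hidden dim 4*d; attention scores computed
# against a context of s tokens; softmax and element-wise ops ignored.
d, s = 4096, 2048

qkv  = 2 * d * (3 * d)       # QKV projections
out  = 2 * d * d             # attention output projection
ffn  = 2 * (2 * d * 4 * d)   # FFN up + down projections
attn = 2 * (2 * s * d)       # QK^T plus attention-weighted sum of V

linear = qkv + out + ffn
total = linear + attn
print(f"linear-projection share of GEMM FLOPs: {linear / total:.1%}")
```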

FLOPS Distribution in Transformer Training (per layer)

| Operation | % of FLOPS | FP8 Eligible | Precision Used |
|---|---|---|---|
| QKV Projection (GEMM) | ~18% | Yes | FP8 E4M3 |
| Attention Output Projection (GEMM) | ~6% | Yes | FP8 E4M3 |
| FFN Up/Gate Projection (GEMM) | ~24% | Yes | FP8 E4M3 |
| FFN Down Projection (GEMM) | ~24% | Yes | FP8 E4M3 |
| Attention Scores (QK^T) | ~10% | Partially | BF16 or FP8 |
| Softmax | ~3% | No | FP32 |
| LayerNorm / RMSNorm | ~2% | No | FP32 / BF16 |
| Other (residual, etc.) | ~13% | No | BF16 |

Note: Approximate for a standard transformer with FFN hidden dim = 4x model dim. Attention score GEMMs are sometimes run in FP8 depending on implementation.

Transformer Engine: NVIDIA’s FP8 Library

NVIDIA’s Transformer Engine (TE) is the production library for FP8 mixed-precision training. It provides drop-in replacements for torch.nn.Linear, torch.nn.LayerNorm, and transformer layer building blocks that automatically manage FP8 casting, scaling, and format selection.

import transformer_engine.pytorch as te

# Replace standard linear layers with TE equivalents
class TransformerBlock(torch.nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size, num_heads):
        super().__init__()
        # These layers automatically handle FP8 quantization
        self.self_attention = te.MultiheadAttention(
            hidden_size, num_heads,
            fuse_qkv_params=True,
        )
        self.layernorm1 = te.LayerNorm(hidden_size)
        self.ffn = te.LayerNormMLP(
            hidden_size, ffn_hidden_size,
            activation="gelu",
        )

    def forward(self, x):
        # FP8 is handled internally by TE layers
        residual = x
        x = self.layernorm1(x)
        x = self.self_attention(x) + residual
        x = self.ffn(x) + x
        return x

# Enable FP8 training with a context manager
with te.fp8_autocast(enabled=True):
    output = model(input_tensor)
    loss = criterion(output, labels)
    loss.backward()

Under the hood, te.Linear performs:

  1. Compute amax of the input tensor and cache it for the delayed scaling history.
  2. Look up the current scale factor from the delayed scaling state.
  3. Quantize the input to FP8 E4M3 using this scale.
  4. Quantize the weight to FP8 E4M3 using its own scale.
  5. Call a Tensor Core FP8 GEMM with FP32 accumulation.
  6. Dequantize the output (multiply by input_scale * weight_scale).
  7. During backward, use E5M2 for gradient tensors.
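Steps 2–6 can be simulated in numpy to show the scale bookkeeping: scale both operands up into FP8 range, multiply, then divide the product by both scales. This is a sketch, not TE's implementation — the FP8 cast is faked by snapping to roughly 3 mantissa bits, and the scales use the same multiply-up convention as the delayed-scaling snippet above. All helper names are illustrative.

```python
import numpy as np

# Simulated scaled FP8 GEMM with higher-precision accumulation.
E4M3_MAX = 448.0

def fake_fp8(x):
    # Crude E4M3 stand-in: keep ~3 mantissa bits via per-element snapping
    exp = np.floor(np.log2(np.abs(x) + 1e-30))
    step = 2.0 ** (exp - 3)
    return np.clip(np.round(x / step) * step, -E4M3_MAX, E4M3_MAX)

def fp8_linear(x, w):
    x_scale = E4M3_MAX / np.abs(x).max()    # map inputs up into FP8 range
    w_scale = E4M3_MAX / np.abs(w).max()
    x_q = fake_fp8(x * x_scale)
    w_q = fake_fp8(w * w_scale)
    acc = x_q @ w_q.T                       # "FP32 accumulation" of the GEMM
    return acc / (x_scale * w_scale)        # step 6: dequantize the output

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8)).astype(np.float32)
w = rng.normal(size=(16, 8)).astype(np.float32)
y_fp8 = fp8_linear(x, w)
y_ref = x @ w.T
print(np.max(np.abs(y_fp8 - y_ref)))   # small quantization error
```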

FP8 Training Results: Throughput and Quality

The throughput gains from FP8 are substantial and well-documented.

Training Throughput on H100 SXM: BF16 vs FP8 (achieved TFLOPS)

| Model | BF16 | FP8 | Speedup |
|---|---|---|---|
| GPT-3 175B | 420 TFLOPS | 720 TFLOPS | +71% |
| LLaMA 70B | 480 TFLOPS | 830 TFLOPS | +73% |
| LLaMA 7B | 510 TFLOPS | 860 TFLOPS | +69% |
FP8 vs BF16 Training Quality (Published Results)

| Model | Format | Benchmark | Score | Delta vs BF16 |
|---|---|---|---|---|
| GPT-3 175B | BF16 | LAMBADA acc | 76.2% | baseline |
| GPT-3 175B | FP8 | LAMBADA acc | 76.0% | -0.2% |
| LLaMA 7B | BF16 | HellaSwag | 76.1% | baseline |
| LLaMA 7B | FP8 | HellaSwag | 75.9% | -0.2% |
| DeepSeek V3 671B | FP8 | MMLU | 87.1% | N/A (FP8 only) |

Note: DeepSeek V3 was trained entirely in FP8 from scratch, so no BF16 baseline exists for that specific run.

DeepSeek V3: FP8 at 671B Scale

DeepSeek V3, a 671B parameter Mixture-of-Experts model, was trained entirely in FP8 on H800 GPUs. This was a landmark result because:

  1. No loss spikes. Previous large-scale FP8 experiments sometimes produced training instabilities at scale. DeepSeek V3 trained stably through 14.8 trillion tokens.
  2. No BF16 fallback. All GEMMs used FP8 throughout training, with no phases that reverted to BF16.
  3. Cost efficiency. The training cost was approximately $5.6M, roughly 1/10th of comparably-sized models trained in BF16, largely due to the throughput gains from FP8.

Their approach included fine-grained quantization: instead of per-tensor scaling, they used per-block scaling with 128-element blocks, which better handles tensors with non-uniform value distributions. They also employed an auxiliary loss-free load balancing strategy for the MoE routing, which may have contributed to training stability.

FP4: The Emerging Frontier

FP4 training is an active research area, primarily driven by Microsoft and DeepSeek. With only 4 bits per value, FP4 offers a theoretical 2x memory and bandwidth reduction over FP8, but the engineering challenges are severe.

The Precision Challenge

FP4 E2M1 (2 exponent bits, 1 mantissa bit) can represent exactly 7 distinct positive values — 0.5, 1, 1.5, 2, 3, 4, and 6 — plus zero and their negatives, for a maximum representable value of just 6. This is barely a quantization grid, not a continuous number line. The rounding error for any individual value can be up to 33% of the value itself.
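Enumerating all E2M1 bit patterns confirms the value grid. This is a sketch assuming an exponent bias of 1 and a subnormal at the zero-exponent pattern, following the formula from the bit-layout section.

```python
# Enumerate all positive values of FP4 E2M1: 2 exponent bits, 1 mantissa bit,
# assumed bias 1, with the zero-exponent pattern used for the subnormal 0.5.
values = set()
for e in range(4):           # 2 exponent bits -> 4 patterns
    for m in range(2):       # 1 mantissa bit -> 2 patterns
        if e == 0:
            v = m * 0.5      # subnormal: 0 or 0.5
        else:
            v = 2.0 ** (e - 1) * (1 + 0.5 * m)
        values.add(v)
values.discard(0.0)
print(sorted(values))   # [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```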

Research Approaches

Microsoft’s FP4 Training (2024): Proposed a mixed FP4/FP8 scheme where forward-pass GEMMs use FP4 weights and FP8 activations. Key innovations include:

  • Outlier-aware quantization: Identifying and separately handling activation outliers that would destroy FP4 accuracy
  • Hadamard rotation: Applying random orthogonal transforms to spread outlier energy across dimensions before quantization
  • Two-level scaling: Block-level scaling with a coarser group-level scale to handle dynamic range

DeepSeek FP4 Research: Explored per-channel quantization with learned scale factors, showing that FP4 training of language models up to 7B parameters could match BF16 quality with careful compensation strategies.

ℹ️ FP4 Is Not Yet Production-Ready

As of early 2025, FP4 training remains a research topic. No major production training run has been published using FP4 as the primary precision. The Blackwell B200 GPU includes FP4 Tensor Core support (theoretically 2x the FP8 throughput at ~8,000 TFLOPS), but the software ecosystem and numerical techniques are still maturing. Expect FP4 training to become practical for production use in the 2025-2026 timeframe as Transformer Engine and frameworks add support.

Emerging FP4 Research Results

| Paper / Group | Model Size | Method | Quality vs BF16 | Status |
|---|---|---|---|---|
| Microsoft (2024) | Up to 7B | FP4 weights + FP8 activations | Within 0.5% on MMLU | Research |
| DeepSeek (2024) | Up to 7B | Per-channel FP4 + compensation | Within 0.3% on HellaSwag | Research |
| NVIDIA Blackwell | TBD | Hardware FP4 Tensor Cores | TBD | Hardware available |

Memory Savings Breakdown: The 70B Model Case Study

Understanding exact memory requirements is critical for capacity planning. Let us trace through a 70B parameter model (e.g., LLaMA 2 70B) under different precision regimes.

Parameter Memory

The number of parameters is 70 billion. Raw parameter storage:

  • FP32: $70 \times 10^9 \times 4\text{ bytes} = 280\text{ GB}$
  • BF16: $70 \times 10^9 \times 2\text{ bytes} = 140\text{ GB}$
  • FP8: $70 \times 10^9 \times 1\text{ byte} = 70\text{ GB}$

Optimizer State Memory

Adam/AdamW maintains two additional states per parameter: first moment (m) and second moment (v). These are always kept in FP32 for numerical stability.

  • Optimizer states (always FP32): 70 × 10⁹ × 2 × 4 bytes = 560 GB

Gradient Memory

Gradients are the same size as parameters, stored in the training precision:

  • FP32 gradients: 70 × 10⁹ × 4 bytes = 280 GB
  • BF16 gradients: 70 × 10⁹ × 2 bytes = 140 GB
  • FP8 gradients: 70 × 10⁹ × 1 byte = 70 GB
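
These per-component numbers fold into one small helper. This is a back-of-the-envelope calculator mirroring the arithmetic above (Adam m + v in FP32, activations excluded); the function name and defaults are this article's invention, not a framework API:

```python
def training_memory_gb(n_params: float, param_bytes: int, grad_bytes: int,
                       master_weights: bool, optimizer_state_bytes: int = 8) -> float:
    """Total training memory in GB, excluding activations.

    optimizer_state_bytes: 8 = Adam m + v in FP32 (2 x 4 bytes/param).
    master_weights: mixed-precision regimes keep an extra FP32 copy.
    """
    per_param = param_bytes + grad_bytes + optimizer_state_bytes
    if master_weights:
        per_param += 4  # FP32 master copy of the weights
    return n_params * per_param / 1e9

N = 70e9
fp32 = training_memory_gb(N, 4, 4, master_weights=False)  # pure FP32
bf16 = training_memory_gb(N, 2, 2, master_weights=True)   # BF16 mixed
fp8  = training_memory_gb(N, 1, 1, master_weights=True)   # FP8 mixed
```

Running it reproduces the totals in the table below: 1,120 GB, 1,120 GB, and 980 GB respectively.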

Total Training Memory (Excluding Activations)

📊

Memory Breakdown: 70B Parameter Model (Excluding Activations)

| Component | Pure FP32 | BF16 Mixed | FP8 Mixed |
|---|---|---|---|
| Model params | 280 GB (FP32) | 140 GB (BF16) | 70 GB (FP8) |
| FP32 master weights | --- (same as above) | 280 GB | 280 GB |
| Optimizer (m + v) | 560 GB (FP32) | 560 GB (FP32) | 560 GB (FP32) |
| Gradients | 280 GB (FP32) | 140 GB (BF16) | 70 GB (FP8) |
| Total | 1,120 GB | 1,120 GB | 980 GB |
| Total (with ZeRO-3, 8 GPUs) | 140 GB/GPU | 140 GB/GPU | 122.5 GB/GPU |
Note: BF16 mixed precision requires FP32 master weights + BF16 working copies. The optimizer states dominate total memory. Activation memory (not shown) depends on sequence length, batch size, and checkpointing strategy.
The Optimizer Dominates Memory

A surprising insight from this breakdown: for Adam-based optimizers, the optimizer states (m and v) consume more memory than the model parameters in every precision regime. This is why memory-efficient optimizers (Adafactor, CAME, 8-bit Adam) and techniques like ZeRO offloading matter so much for large-scale training — they attack the largest memory consumer. Reducing parameter precision from FP32 to FP8 saves 210 GB on the parameters and gradients, but the optimizer states remain at 560 GB regardless.

A more practical way to think about the memory savings from mixed precision:

📊

Effective Memory Savings: 70B Model with Adam

| Precision Regime | Total Memory | vs Pure FP32 | Practical Savings |
|---|---|---|---|
| Pure FP32 | 1,120 GB | baseline | --- |
| BF16 mixed (standard) | 1,120 GB | 0% less total | But enables Tensor Cores (2x compute) |
| BF16 mixed + FP32 opt | ~1,120 GB | 0% | Working set fits in less HBM |
| FP8 mixed + FP32 opt | ~980 GB | 12.5% less | 2x compute + reduced bandwidth |
| FP8 + 8-bit Adam | ~560 GB | 50% less | Aggressive but proven at scale |
Note: These numbers exclude activation memory. With activation checkpointing, activation memory is roughly proportional to model size x sequence length / checkpoint granularity.

The real value of reduced precision is not just memory reduction — it is the compute throughput gain from using Tensor Cores and the bandwidth reduction from moving smaller tensors.

Throughput Comparison: Real Training Numbers on H100

These are representative achieved TFLOPS for training different model sizes on H100 SXM GPUs, using standard configurations (Megatron-LM style parallelism, activation checkpointing, sequence length 2048-4096).

📊

Achieved Training TFLOPS on H100 SXM (Single Node, 8 GPUs)

| Model | FP32 (CUDA) | TF32 | BF16 | FP8 | FP8 vs BF16 |
|---|---|---|---|---|---|
| 1.3B | ~12 | ~85 | ~145 | ~240 | 1.66x |
| 7B | ~10 | ~90 | ~160 | ~275 | 1.72x |
| 13B | ~9 | ~88 | ~155 | ~270 | 1.74x |
| 70B | N/A (OOM) | ~80 | ~140 | ~245 | 1.75x |
| 175B (multi-node) | N/A | ~70 | ~130 | ~225 | 1.73x |
Note: FP32 CUDA core training at 70B+ does not fit on a single node. TF32 uses Tensor Cores with FP32 inputs (automatic in PyTorch). FP8 numbers assume Transformer Engine with delayed scaling. Achieved TFLOPS is model FLOPS utilization (MFU) x peak.
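
The MFU relationship in the note is easy to apply to your own logs. Here is a hedged sketch using the standard ~6N FLOPs-per-token training approximation (the exact constant depends on architecture and whether activation recomputation is counted; the function names are illustrative):

```python
def achieved_tflops(n_params: float, tokens_per_sec_per_gpu: float) -> float:
    """Model TFLOPS from measured throughput, using ~6N FLOPs per token."""
    return 6.0 * n_params * tokens_per_sec_per_gpu / 1e12

def mfu(n_params: float, tokens_per_sec_per_gpu: float, peak_tflops: float) -> float:
    """Model FLOPS utilization: achieved model TFLOPS / theoretical peak."""
    return achieved_tflops(n_params, tokens_per_sec_per_gpu) / peak_tflops

# Example: a 7B model at ~6,500 tokens/sec/GPU yields ~273 achieved TFLOPS.
print(achieved_tflops(7e9, 6500))
```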

Achieved Training TFLOPS by Precision (7B Model, H100 SXM):

  • FP32 (CUDA): 10 TFLOPS (1.0x baseline)
  • TF32 (auto): 90 TFLOPS (9x)
  • BF16: 160 TFLOPS (16x)
  • FP8: 275 TFLOPS (27.5x)

The progression is dramatic: moving from FP32 CUDA cores to FP8 Tensor Cores delivers a ~27x throughput improvement on the same hardware. Even the “lazy” optimization of enabling TF32 yields a 9x improvement with essentially zero code changes — PyTorch enables TF32 by default for cuDNN convolutions on Ampere+, though since PyTorch 1.12 it must be switched on explicitly for matmuls.
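
Opting in takes one or two lines; a minimal sketch of the relevant PyTorch switches:

```python
import torch

# Allow TF32 Tensor Cores for FP32 matmuls on Ampere+.
# "high" permits TF32; the default "highest" keeps true FP32 matmuls.
torch.set_float32_matmul_precision("high")

# The lower-level flags control the same behavior:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True  # convolutions (already on by default)
```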

When Precision Reduction Fails

Not every workload benefits from reduced precision, and some actively break. Understanding the failure modes is as important as understanding the benefits.

Small Models

For models with fewer than ~100M parameters, the overhead of mixed-precision infrastructure (maintaining master weights, scaling, casting) can outweigh the compute savings. Small models also tend to be less tolerant of numerical noise because each parameter carries proportionally more “responsibility” for the model’s behavior. At small scale, the memory savings are also less impactful since the model fits comfortably on a single GPU in FP32.

Tasks Requiring Fine Numerical Precision

Some applications produce outputs where small numerical differences matter:

  • Scientific simulation models where outputs represent physical quantities
  • Financial modeling where rounding errors accumulate over long sequences of operations
  • Regression tasks with targets spanning many orders of magnitude
  • Reinforcement learning with reward signals near the precision floor

Architecture-Specific Challenges

Very deep networks without normalization: In networks with hundreds of layers and no LayerNorm/BatchNorm, gradient magnitudes can vary by 20+ orders of magnitude across layers. BF16 can usually absorb this (it has the same dynamic range as FP32), but FP8 with per-tensor scaling may struggle because a single scale factor cannot simultaneously represent very large and very small values in the same tensor.

Attention with extremely long sequences: Softmax over long sequences (greater than 100K tokens) involves computing exp(x) for values that can be very negative. In FP16, the limited range can cause issues. BF16 and FP8 E5M2 handle this better due to wider range, but the softmax itself should always run in FP32.
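
The standard pattern — sketched here with NumPy standing in for the attention kernel — is to upcast just the softmax to FP32 and cast back afterwards:

```python
import numpy as np

def softmax_fp32(scores: np.ndarray) -> np.ndarray:
    """Softmax computed in FP32 regardless of the input dtype.

    Upcasting before exp() avoids the underflow a half-precision exp()
    would hit on very negative long-context attention scores.
    """
    x = scores.astype(np.float32)
    x -= x.max(axis=-1, keepdims=True)   # max-subtraction: exp() args <= 0
    e = np.exp(x)
    probs = e / e.sum(axis=-1, keepdims=True)
    return probs.astype(scores.dtype)    # cast back to the compute dtype

# FP16 scores with the wide dynamic range typical of long-context attention
scores = np.array([12.0, 0.0, -30.0, -80.0], dtype=np.float16)
probs = softmax_fp32(scores)
```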

Training instability during critical phases: The first few hundred steps of training (when gradients are large and chaotic) and fine-tuning with very small learning rates (when gradient signals are small) are the most precision-sensitive phases. Some practitioners start training in BF16, switch to FP8 after warmup, and revert to BF16 for the final fine-tuning stage.
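
That schedule can be expressed as a simple step-indexed policy. The phase boundaries below are illustrative, not recommendations, and the function is this article's sketch rather than any framework's API:

```python
def precision_for_step(step: int, warmup_steps: int, total_steps: int,
                       cooldown_frac: float = 0.05) -> str:
    """Pick the compute precision for a training step.

    Policy sketched in the text: BF16 during chaotic early training,
    FP8 for the stable bulk, BF16 again for the sensitive final stage.
    """
    cooldown_start = int(total_steps * (1.0 - cooldown_frac))
    if step < warmup_steps:
        return "bf16"   # large, noisy gradients: stay in BF16
    if step >= cooldown_start:
        return "bf16"   # small-lr fine-tuning phase: back to BF16
    return "fp8"        # stable middle phase: FP8 for throughput
```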

🚨 Silent Degradation

The most dangerous failure mode of reduced precision is not a crash or NaN — it is silent quality degradation. The model trains, the loss decreases, but final evaluation metrics are 1-2% worse than they would have been in higher precision. This is hard to detect without running a BF16 baseline for comparison. Always validate reduced-precision training against a higher-precision reference on a representative subset of your evaluation suite.

📊

When to Avoid Reduced Precision

| Scenario | Risk Level | Failure Mode | Mitigation |
|---|---|---|---|
| Model under 100M params | Medium | Overhead exceeds benefit | Use BF16 (simpler than FP8) |
| Scientific regression | High | Accumulated rounding error | Keep FP32 for critical ops |
| Very deep nets (500+ layers) | Medium | Gradient range exceeds format | Per-layer scaling, BF16 minimum |
| Long-context attention (128K+) | Medium | Softmax precision | Softmax always in FP32 |
| Fine-tuning with lr under 1e-6 | High | Updates below precision floor | FP32 optimizer, BF16 compute |
| RL with sparse rewards | High | Reward signal lost to rounding | FP32 for reward/value heads |

Production Setup: Framework-Specific Configurations

PyTorch Native (torch.amp)

The simplest path for single-GPU or DDP training:

import torch
from torch.amp import autocast, GradScaler  # torch.cuda.amp is deprecated; torch.amp takes device_type
from torch.nn.parallel import DistributedDataParallel as DDP

# BF16 training (recommended for Ampere+ GPUs)
model = MyModel().cuda()
model = DDP(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for batch in dataloader:
    optimizer.zero_grad(set_to_none=True)  # set_to_none saves memory
    with autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model(batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

# FP16 training (when BF16 not available, e.g., V100)
scaler = GradScaler()
for batch in dataloader:
    optimizer.zero_grad(set_to_none=True)
    with autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()

DeepSpeed ZeRO + Mixed Precision

DeepSpeed configuration for BF16 training with ZeRO Stage 2:

{
  "bf16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 5e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 5e8,
    "contiguous_gradients": true
  },
  "gradient_clipping": 1.0,
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "steps_per_print": 100
}

For FP16 with loss scaling:

{
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  }
}

"loss_scale": 0 enables dynamic loss scaling. initial_scale_power: 16 means the initial scale is 216=655362^{16} = 65536. hysteresis: 2 means the scale must overflow 2 consecutive times before being reduced.

Megatron-LM: Large-Scale Training

Megatron-LM command-line flags for FP8 training with Transformer Engine:

python pretrain_gpt.py \
    --tensor-model-parallel-size 8 \
    --pipeline-model-parallel-size 4 \
    --num-layers 80 \
    --hidden-size 8192 \
    --num-attention-heads 64 \
    --seq-length 4096 \
    --micro-batch-size 2 \
    --global-batch-size 1024 \
    --bf16 \
    --fp8-format hybrid \
    --fp8-amax-history-len 1024 \
    --fp8-amax-compute-algo max \
    --transformer-impl transformer_engine \
    --attention-softmax-in-fp32 \
    --accumulate-allreduce-grads-in-fp32 \
    --use-flash-attn

Key flags explained:

  • --bf16: Master weights and non-GEMM ops in BF16
  • --fp8-format hybrid: E4M3 for forward, E5M2 for backward
  • --fp8-amax-history-len 1024: Delayed scaling history window
  • --fp8-amax-compute-algo max: Use max of history for scale computation
  • --transformer-impl transformer_engine: Use TE layers
  • --attention-softmax-in-fp32: Keep softmax in FP32 for stability
  • --accumulate-allreduce-grads-in-fp32: Reduce precision errors in gradient all-reduce

📊

Production Precision Configurations by Scale

| Model Scale | Recommended Precision | Framework | Key Configuration |
|---|---|---|---|
| Under 1B | BF16 | PyTorch native AMP | autocast(dtype=torch.bfloat16) |
| 1B - 13B | BF16 or FP8 | DeepSpeed ZeRO-2 | bf16.enabled + ZeRO Stage 2 |
| 13B - 70B | FP8 | Megatron-LM + TE | TP=8, PP=2-4, FP8 hybrid |
| 70B+ | FP8 | Megatron-LM + TE + ZeRO | TP=8, PP=8+, FP8 hybrid + ZeRO-1 |
| MoE 200B+ | FP8 | Custom (DeepSeek-style) | Per-block FP8 scaling, EP |
Note: TP = tensor parallelism, PP = pipeline parallelism, EP = expert parallelism, TE = Transformer Engine.

The Precision Ladder: A Summary

Training precision has evolved in a clear progression, driven by hardware support and numerical research:

📊

The Precision Landscape: Past, Present, and Future

| Format | Era | GPU Generation | Key Innovation | Status (2025) |
|---|---|---|---|---|
| FP32 | 2012-2017 | Kepler through Pascal | Default, no tricks needed | Baseline / optimizer states only |
| FP16 + loss scaling | 2017-2020 | Volta V100 | Tensor Cores + dynamic loss scaling | Legacy, replaced by BF16 |
| BF16 | 2020-present | Ampere A100+ | FP32 range in 16 bits, no loss scaling | Current default for training |
| TF32 (auto) | 2020-present | Ampere A100+ | Transparent FP32 replacement using TC | On by default for convolutions in PyTorch |
| FP8 (E4M3/E5M2) | 2022-present | Hopper H100+ | Per-tensor delayed scaling, TE | Production-ready, widely adopted |
| FP4 (E2M1) | 2024-future | Blackwell B200+ | Block quantization, compensation | Research, early adoption |

Each step down the precision ladder roughly doubles throughput while requiring increasingly sophisticated numerical techniques to maintain training quality. The common thread: the compute formats get smaller, but the master weights and optimizer states remain in FP32. The “mixed” in mixed precision is essential — it is not about training in low precision, it is about computing in low precision while maintaining a high-precision anchor for correctness.

Theoretical Peak Throughput Scaling by Precision (H100 SXM):

  • FP32 (CUDA cores): 67 TFLOPS (baseline)
  • TF32 (Tensor Cores): 989 TFLOPS (~15x)
  • BF16 (Tensor Cores): 1,979 TFLOPS (~30x)
  • FP8 (Tensor Cores): 3,958 TFLOPS (~59x)
  • FP4 (B200 Blackwell, est.): 8,000 TFLOPS (~119x)

The trajectory is clear: within the next two years, FP4 Tensor Core training will likely become standard practice for large models, just as FP8 became standard in 2023-2024 and BF16 in 2020-2022. Each transition required new hardware support, new scaling techniques, and new software infrastructure — but each delivered roughly 2x training efficiency for the same model quality. For practitioners, the imperative is straightforward: use the lowest precision your hardware supports and your model tolerates, keep master weights in FP32, and let the frameworks handle the scaling.