Part of Series Transformer Anatomy 24 of 36

Training a 70B-parameter model in FP32 requires 280 GB just for weights (70 billion parameters times 4 bytes). AdamW adds two more FP32 tensors of the same size, the first and second moments, which brings the total to 840 GB. FP32 gradients add another 280 GB, for 1.12 TB before a single activation is stored. No single GPU has anywhere near this much memory.

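Mixed precision training reduces this cost by running the bulk of computation in a lower-precision format (BF16 or FP16) while keeping precision-critical operations and state in FP32. The result: roughly half the activation memory, much faster matrix multiplications on tensor cores, and, when done correctly, no measurable loss in training quality. The FP32 weight and optimizer state remains, and at this scale it still has to be sharded across GPUs.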

This post covers the exact precision hierarchy used in modern LLM training: which operations run in which precision, why, and how to implement it.

Number Formats

1.1 The Three Formats That Matter

import torch

def analyze_format(name, torch_dtype):
    """Show the bit layout and range of a floating-point format."""
    info = torch.finfo(torch_dtype)
    return {
        "name": name,
        "bits": info.bits,
        "sign_bits": 1,
        "exponent_bits": {
            torch.float32: 8,
            torch.float16: 5,
            torch.bfloat16: 8,
        }[torch_dtype],
        "mantissa_bits": {
            torch.float32: 23,
            torch.float16: 10,
            torch.bfloat16: 7,
        }[torch_dtype],
        "max": info.max,
        "min_positive": info.tiny,
        "eps": info.eps,
        "resolution": info.resolution,  # approx. decimal precision, e.g. 1e-06 for FP32
    }

for dtype in [torch.float32, torch.float16, torch.bfloat16]:
    info = analyze_format(str(dtype), dtype)
    print(f"{info['name']}: {info['bits']} bits "
          f"({info['exponent_bits']}e + {info['mantissa_bits']}m), "
          f"range [{info['min_positive']:.2e}, {info['max']:.2e}], "
          f"eps={info['eps']:.2e}")

Output:

torch.float32:  32 bits (8e + 23m), range [1.18e-38, 3.40e+38], eps=1.19e-07
torch.float16:  16 bits (5e + 10m), range [6.10e-05, 6.55e+04], eps=9.77e-04
torch.bfloat16: 16 bits (8e + 7m),  range [1.18e-38, 3.39e+38], eps=7.81e-03

The critical difference: BF16 has the same exponent range as FP32 (8 exponent bits) but much lower precision (7 mantissa bits vs. 23). FP16 has higher precision than BF16 (10 mantissa bits) but a drastically smaller range (max 65,504 vs. $3.4 \times 10^{38}$).
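The bit layouts can be made concrete without a GPU. The sketch below is a minimal pure-Python illustration, and its helper names are illustrative rather than a PyTorch API. It relies on BF16 being the upper 16 bits of an FP32 word, so rounding to BF16 means rounding away the lower 16 bits. It does that with round-to-nearest-even, the rounding mode PyTorch's FP32-to-BF16 cast uses. The sketch also checks, over random FP32 inputs, that this rounding never moves a value by more than $2^{-8}$ of its magnitude.

```python
import random
import struct


def fp32_bits(x: float) -> int:
    """Bit pattern of x after rounding it to FP32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def bits_to_fp32(bits: int) -> float:
    """Reinterpret a 32-bit pattern as an FP32 value."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def to_bf16(x: float) -> float:
    """Round x to BF16 (round-to-nearest-even) by way of FP32.

    BF16 keeps the upper 16 bits of the FP32 word. Rounding therefore adds
    a bias to the 16 bits that will be dropped and then clears them.
    """
    bits = fp32_bits(x)
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even value
    bits = (((bits + rounding_bias) & 0xFFFFFFFF) >> 16) << 16
    return bits_to_fp32(bits)


def bf16_fields(x: float) -> tuple:
    """Split the BF16 value of x into (sign, 8-bit exponent, 7-bit mantissa)."""
    bits = fp32_bits(to_bf16(x)) >> 16
    return bits >> 15, (bits >> 7) & 0xFF, bits & 0x7F


print(bf16_fields(1.0))   # (0, 127, 0): exponent bias is 127, as in FP32
print(bf16_fields(-2.5))  # (1, 128, 32): -1.25 * 2**1, mantissa 0.25 * 128

# Worst-case relative rounding error over random FP32 inputs
random.seed(0)
samples = [bits_to_fp32(fp32_bits(random.uniform(1e-3, 1.0))) for _ in range(10_000)]
worst = max(abs(to_bf16(s) - s) / s for s in samples)
print(f"worst relative error: {worst:.5f}  (bound 2**-8 = {2**-8:.5f})")
```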

1.2 Why BF16 Won Over FP16

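FP16's range problem: gradients during LLM training can have magnitudes spanning from $10^{-8}$ to $10^{4}$. FP16's maximum value is 65,504. A single value above it overflows to infinity. That infinity propagates into NaNs and corrupts the weights if the step is applied. At the small end, magnitudes below about $3 \times 10^{-8}$ (half of FP16's smallest subnormal, $2^{-24}$) round to zero.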

BF16's range matches FP32 ($\approx 3.4 \times 10^{38}$), so overflows that would not occur in FP32 will not occur in BF16 either. The tradeoff is precision. BF16 has only 7 mantissa bits (8 significant bits counting the implicit leading 1), which is about 2.4 decimal digits. With round-to-nearest, the relative quantization error is at most half the spacing between adjacent values, independent of magnitude:

$$\text{relative error} \leq \tfrac{1}{2} \cdot 2^{-7} = 2^{-8} \approx 0.39\%$$

This is acceptable for forward and backward passes, where the result will be accumulated into FP32 master weights.

def demonstrate_precision_loss():
    """Show the precision difference between BF16 and FP32."""
    # Create a weight value
    w_fp32 = torch.tensor(0.123456789, dtype=torch.float32)
    w_bf16 = w_fp32.to(torch.bfloat16)
    w_fp16 = w_fp32.to(torch.float16)

    print(f"FP32:  {w_fp32.item():.10f}")
    print(f"BF16:  {w_bf16.item():.10f}")
    print(f"FP16:  {w_fp16.item():.10f}")
    print(f"BF16 error: {abs(w_fp32.item() - w_bf16.item()):.2e}")
    print(f"FP16 error: {abs(w_fp32.item() - w_fp16.item()):.2e}")
    # FP32:  0.1234567910  (nearest FP32 value)
    # BF16:  0.1235351562  (error: ~7.8e-05)
    # FP16:  0.1234741211  (error: ~1.7e-05)

def demonstrate_range_problem():
    """Show why FP16 overflows but BF16 does not."""
    # Gradient that exceeds FP16 range
    grad = torch.tensor(100000.0, dtype=torch.float32)

    grad_bf16 = grad.to(torch.bfloat16)
    grad_fp16 = grad.to(torch.float16)

    print(f"FP32 grad:  {grad.item()}")
    print(f"BF16 grad:  {grad_bf16.item()}")  # 99840.0 (rounded but finite)
    print(f"FP16 grad:  {grad_fp16.item()}")  # inf (overflow)
⚠️ Warning

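FP16's narrow range, overflow at the top and underflow at the bottom, is the primary reason BF16 replaced FP16 for LLM training. With FP16, you need loss scaling to keep small gradients from underflowing to zero. You also need overflow checks that back the scale off when it grows too large. With BF16, loss scaling is unnecessary because the exponent range matches FP32. BF16 is now the default choice for LLM training on hardware that supports it.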

The Precision Hierarchy

2.1 Overview

Modern LLM training uses a three-tier precision hierarchy:

| Operation | Precision | Why |
|---|---|---|
| Forward pass (matmuls) | BF16 | 2x faster on tensor cores, sufficient precision |
| Backward pass (matmuls) | BF16 | Same as forward |
| Forward pass (norms, softmax) | FP32 | Numerical stability requires high precision |
| Backward pass (norms, softmax) | FP32 | Gradient precision for sensitive ops |
| Master weights | FP32 | Accumulation of small updates requires precision |
| Optimizer states (m, v) | FP32 | Running averages must not lose precision |
| Gradient accumulation | FP32 | Sum of many small values needs precision |
| Weight update | FP32 | update = lr * m / sqrt(v); computed in FP32 |
| Loss computation | FP32 | Cross-entropy with log-softmax needs range |

2.2 Why Master Weights Must Be FP32

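The weight update in AdamW is shown below. The decoupled weight-decay term $-\eta \lambda w_t$ is omitted because it does not change the argument:

$$w_{t+1} = w_t - \eta \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$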

For a typical learning rate $\eta = 3 \times 10^{-4}$ and a typical update magnitude $\hat{m}/\sqrt{\hat{v}} \approx 0.01$, the update per step is roughly $3 \times 10^{-6}$. A weight of magnitude 0.1 lies in $[2^{-4}, 2^{-3})$, where adjacent BF16 values are

$$2^{-4} \times 2^{-7} = 2^{-11} \approx 4.9 \times 10^{-4}$$

apart. Under round-to-nearest, any change smaller than half that gap ($\approx 2.4 \times 10^{-4}$) rounds back to the original value. The update ($3 \times 10^{-6}$) is about 80x smaller than that threshold, so in BF16 every update is rounded away. The weight would never change, and training would not converge.

def demonstrate_stale_weights():
    """Show that BF16 weights cannot accumulate small updates."""
    w_bf16 = torch.tensor(0.1, dtype=torch.bfloat16)
    w_fp32 = torch.tensor(0.1, dtype=torch.float32)

    update = 3e-6  # Typical per-step update magnitude

    # Simulate 10000 updates
    for _ in range(10000):
        w_bf16 = w_bf16 - update  # BF16 arithmetic
        w_fp32 = w_fp32 - update  # FP32 arithmetic

    print(f"BF16 after 10K updates: {w_bf16.item():.6f}")  # 0.100098 (BF16 value of 0.1, unchanged)
    print(f"FP32 after 10K updates: {w_fp32.item():.6f}")  # ~0.069974 (exact: 0.07; FP32 rounding drifts slightly)
    # BF16 rounds every update away -- no learning happens
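The arithmetic behind this demonstration can also be checked without PyTorch. The sketch below runs one scalar AdamW step with bias correction and an optional decoupled weight-decay term. It also computes the spacing between BF16 values near a given number. Both helper names are hypothetical. The 0.01 steady-state ratio is the assumption from the text; the code does not derive it.

```python
import math


def adamw_scalar_step(w, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95,
                      eps=1e-8, weight_decay=0.0):
    """One AdamW step for a single scalar parameter; t counts from 1."""
    m = beta1 * m + (1 - beta1) * g           # first moment
    v = beta2 * v + (1 - beta2) * g * g       # second moment
    m_hat = m / (1 - beta1 ** t)              # bias corrections
    v_hat = v / (1 - beta2 ** t)
    # Decoupled weight decay acts on w directly, not through m and v
    w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * w)
    return w, m, v


def bf16_spacing(x):
    """Gap between adjacent BF16 values at x (8 significant bits)."""
    _, e = math.frexp(abs(x))                 # abs(x) = f * 2**e, 0.5 <= f < 1
    return math.ldexp(1.0, e - 8)             # 2**(e - 1) * 2**-7


# At t = 1 the bias-corrected ratio m_hat / sqrt(v_hat) is +/-1, so the
# first step moves the weight by almost exactly lr
w1, _, _ = adamw_scalar_step(0.1, g=0.02, m=0.0, v=0.0, t=1)
print(f"first-step update: {0.1 - w1:.2e}")

# Later in training the text assumes |m_hat / sqrt(v_hat)| ~ 0.01
typical_update = 3e-4 * 0.01
half_gap = bf16_spacing(0.1) / 2
print(f"update {typical_update:.1e} vs. half BF16 gap at 0.1 {half_gap:.2e}")
print(f"the update is {half_gap / typical_update:.0f}x too small to register")
```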

The solution: maintain an FP32 copy of all weights (the master weights) and apply every update to that copy. The training loop becomes:

  1. Cast master weights (FP32) to BF16 for forward/backward
  2. Compute gradients in BF16
  3. Cast gradients to FP32
  4. Update master weights in FP32
  5. Repeat
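The five steps above can be traced on a toy problem. This pure-Python sketch fits a single weight to a target on the loss 0.5 * (w - 0.07)^2. Python floats are FP64, so FP32 and BF16 are emulated by rounding through struct. It uses plain SGD instead of AdamW to keep it short, and all names are hypothetical. One run keeps an FP32 master copy. The other stores the weight only in BF16.

```python
import struct


def to_fp32(x: float) -> float:
    """Round to the nearest FP32 value (Python floats are FP64)."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 via FP32, using round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


TARGET, LR, STEPS = 0.07, 1e-3, 10_000   # toy loss: 0.5 * (w - TARGET)**2


def train_with_master_weights(w0: float) -> float:
    master = to_fp32(w0)                          # FP32 master copy
    for _ in range(STEPS):
        w_bf16 = to_bf16(master)                  # 1. cast to BF16 for fwd/bwd
        grad_bf16 = to_bf16(w_bf16 - TARGET)      # 2. gradient computed in BF16
        grad = to_fp32(grad_bf16)                 # 3. cast gradient to FP32
        master = to_fp32(master - to_fp32(LR * grad))  # 4. FP32 update
    return master


def train_bf16_only(w0: float) -> float:
    w = to_bf16(w0)
    for _ in range(STEPS):
        grad = to_bf16(w - TARGET)
        w = to_bf16(w - LR * grad)                # update rounded straight to BF16
    return w


print(f"FP32 master: {train_with_master_weights(0.1):.5f}")  # close to 0.07
print(f"BF16 only:   {train_bf16_only(0.1):.5f}")            # stuck near 0.1
```

With lr = 1e-3 the first updates are about 3e-5, already far below half the BF16 gap at 0.1, so the BF16-only weight never moves. The master copy converges to within BF16 resolution of the target.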

2.3 Memory Budget

def memory_budget(params_B):
    """Calculate model-state memory for mixed-precision training with AdamW.

    Args:
        params_B: number of parameters in billions

    Activations are not included; they depend on batch size, sequence
    length, and the checkpointing strategy.
    """
    params = params_B * 1e9

    # Weights
    master_weights_fp32 = params * 4  # 4 bytes per param
    bf16_weights = params * 2          # 2 bytes per param

    # Optimizer (AdamW)
    optimizer_m = params * 4  # First moment (FP32)
    optimizer_v = params * 4  # Second moment (FP32)

    # Gradients
    gradients_fp32 = params * 4  # Accumulated in FP32

    # Total model state
    total_model = (master_weights_fp32 + bf16_weights +
                   optimizer_m + optimizer_v + gradients_fp32)

    print(f"Model: {params_B}B parameters")
    print(f"  Master weights (FP32): {master_weights_fp32 / 1e9:.1f} GB")
    print(f"  BF16 weights:          {bf16_weights / 1e9:.1f} GB")
    print(f"  Optimizer m (FP32):    {optimizer_m / 1e9:.1f} GB")
    print(f"  Optimizer v (FP32):    {optimizer_v / 1e9:.1f} GB")
    print(f"  Gradients (FP32):      {gradients_fp32 / 1e9:.1f} GB")
    print(f"  Total model state:     {total_model / 1e9:.1f} GB")
    print(f"  Per-param bytes:       {total_model / params:.1f}")

    return total_model

# Example: Llama 70B
memory_budget(70)
# Master weights (FP32): 280.0 GB
# BF16 weights:          140.0 GB
# Optimizer m (FP32):    280.0 GB
# Optimizer v (FP32):    280.0 GB
# Gradients (FP32):      280.0 GB
# Total model state:     1260.0 GB
# Per-param bytes:       18.0

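18 bytes per parameter. For 70B parameters, that is 1.26 TB of model state alone, before activations. That is slightly more than the 16 bytes per parameter of pure FP32 training: mixed precision saves activation memory and compute time, not model-state memory. Sharded with FSDP across 16 H100 GPUs (80 GB each, 1.28 TB total), the model state would fill all but about 20 GB of the cluster. That leaves almost nothing for activations and temporary buffers, so in practice 70B training is spread over considerably more GPUs than this lower bound.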
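The same accounting can be extended to fully sharded training (FSDP or ZeRO-3 style, where model state is split evenly across GPUs) to show the per-GPU squeeze. This is an illustrative helper, not part of any library. It counts model state only, so activations, temporary buffers, and allocator fragmentation come on top. It also includes a pure-FP32 recipe for comparison.

```python
def model_state_gb(params_b, bytes_per_param, n_gpus=1):
    """Model-state memory per GPU in GB (1e9 bytes), sharded evenly."""
    return params_b * 1e9 * bytes_per_param / n_gpus / 1e9


FP32_ONLY = 4 + 4 + 4 + 4        # weights, grads, Adam m, Adam v
MIXED_BF16 = 4 + 2 + 4 + 4 + 4   # same plus a BF16 weight copy, as in memory_budget

print(f"70B, pure FP32:  {model_state_gb(70, FP32_ONLY):.0f} GB")
print(f"70B, mixed BF16: {model_state_gb(70, MIXED_BF16):.0f} GB")

H100_GB = 80
for n_gpus in (16, 32, 64):
    per_gpu = model_state_gb(70, MIXED_BF16, n_gpus)
    print(f"{n_gpus:>2} GPUs: {per_gpu:6.2f} GB/GPU of model state, "
          f"{H100_GB - per_gpu:6.2f} GB left for everything else")
```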

Per-Operation Precision Requirements

3.1 Matrix Multiplications: BF16

GEMMs (General Matrix Multiplications) dominate training FLOPs. On NVIDIA tensor cores, BF16 GEMMs run at twice the TF32 tensor-core rate. On H100 SXM (dense), that is about 990 TFLOP/s BF16 vs. 495 TFLOP/s TF32, and roughly 15x the rate of plain FP32 on CUDA cores (about 67 TFLOP/s). The tensor core accumulates internally in FP32; only the inputs and outputs are BF16.

import torch

def bf16_matmul(a, b):
    """BF16 matmul with FP32 internal accumulation.

    Tensor cores compute: C_fp32 = A_bf16 @ B_bf16 (accumulated in FP32)
    Then: C_bf16 = cast(C_fp32)

    The accumulation in FP32 is critical -- without it,
    summing many BF16 products would lose too much precision.
    """
    # PyTorch handles this automatically when inputs are BF16
    a_bf16 = a.to(torch.bfloat16)
    b_bf16 = b.to(torch.bfloat16)

    # Tensor core does: FP32 accumulation of BF16 * BF16 products
    c = torch.matmul(a_bf16, b_bf16)  # Output is BF16
    return c

# The key matmuls in a transformer:
# QKV projection:    X_bf16 @ W_qkv_bf16
# Attention scores:  Q_bf16 @ K^T_bf16
# Attention output:  attn_bf16 @ V_bf16
# Output projection: attn_out_bf16 @ W_o_bf16
# FFN up/gate/down:  X_bf16 @ W_ff_bf16
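The FP32 accumulator matters because a long dot product is a long running sum. This pure-Python emulation uses hypothetical helpers, not a PyTorch API. It computes a BF16 dot product twice: once rounding the running sum to FP32 after every addition, roughly what an FP32 accumulator does, and once rounding it to BF16. Real tensor cores sum in their own internal order, so only the qualitative gap carries over.

```python
import math
import random
import struct


def to_fp32(x: float) -> float:
    """Round to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 via FP32 with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def dot(a, b, round_acc):
    """Dot product of BF16 vectors; round_acc rounds the running sum.

    A product of two BF16 values has at most 16 significant bits, so it is
    exact in FP32; only the accumulator is rounded.
    """
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_acc(acc + x * y)
    return acc


ones = [1.0] * 4096
print(dot(ones, ones, to_fp32))  # 4096.0
print(dot(ones, ones, to_bf16))  # 256.0: from 256 on, 256 + 1 rounds to 256

random.seed(0)
a = [to_bf16(random.gauss(0.0, 1.0)) for _ in range(4096)]
b = [to_bf16(random.gauss(0.0, 1.0)) for _ in range(4096)]
exact = math.fsum(x * y for x, y in zip(a, b))
err_fp32 = abs(dot(a, b, to_fp32) - exact)
err_bf16 = abs(dot(a, b, to_bf16) - exact)
print(f"|error| with FP32 accumulator: {err_fp32:.2e}, BF16: {err_bf16:.2e}")
```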

3.2 Normalization: FP32

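RMSNorm and LayerNorm need FP32 because they compute statistics (RMS, variance) over the entire hidden dimension. A running sum such as $\sum_i x_i^2$ kept in BF16 loses precision quickly. Once the partial sum is large, each new term is rounded to BF16's 8 significant bits, and small terms are swamped. LayerNorm's variance carries an additional risk: catastrophic cancellation if it is computed as $\mathbb{E}[x^2] - \mathbb{E}[x]^2$. The module below upcasts to FP32 for the statistics: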

class RMSNormMixedPrecision(torch.nn.Module):
    """RMSNorm with FP32 computation, BF16 input/output."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Input may be BF16
        input_dtype = x.dtype

        # Upcast to FP32 for the norm computation
        x_fp32 = x.float()

        # Compute RMS in FP32
        rms = torch.sqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)

        # Normalize in FP32, then cast back
        normed = x_fp32 / rms

        # Apply weight (in FP32) and cast back to input dtype
        return (normed * self.weight.float()).to(input_dtype)

Why FP32 is needed for norms: take hidden dimension $d = 4096$ with every element equal to 1.0. The exact sum of squares is 4096 and the mean is 1.0. If the running sum is kept in BF16 and accumulated sequentially, it stalls at 256. With only 8 significant bits, 257 is not representable, and $256 + 1$ rounds back to 256 (round-half-to-even). The computed mean square would be $256 / 4096 = 0.0625$ instead of 1.0, and the normalized output would be 4x too large. Tree reductions on GPUs shrink this error considerably, but it still grows with $d$, which is why norm statistics are computed in FP32.
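The stalled sum is easy to reproduce. This sketch emulates BF16 and FP32 rounding in pure Python rather than using PyTorch, whose own reductions may accumulate in higher precision internally and hide the effect. It applies RMSNorm, without the learned weight, to a vector of 4096 ones. The helper names are illustrative.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 via FP32 with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def rms_normalize(xs, round_acc, eps=1e-6):
    """RMSNorm without the learned weight; round_acc rounds the sum of squares."""
    sum_sq = 0.0
    for x in xs:
        sum_sq = round_acc(sum_sq + x * x)   # sequential accumulation
    rms = math.sqrt(sum_sq / len(xs) + eps)
    return [x / rms for x in xs]


x = [1.0] * 4096
print(rms_normalize(x, to_fp32)[0])  # ~1.0
print(rms_normalize(x, to_bf16)[0])  # ~4.0: sum of squares stalled at 256
```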

3.3 Softmax: FP32

The softmax function $\text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$ involves exponentiation and normalization. Both are sensitive to precision:

def softmax_precision_comparison():
    """Show why softmax needs FP32."""
    # Attention scores for a sequence of length 2048
    torch.manual_seed(42)
    scores = torch.randn(1, 32, 2048, 2048) * 3.0  # Typical scale

    # BF16 softmax
    scores_bf16 = scores.to(torch.bfloat16)
    attn_bf16 = torch.softmax(scores_bf16, dim=-1)

    # FP32 softmax
    attn_fp32 = torch.softmax(scores, dim=-1)

    # Compare: do rows sum to 1.0?
    row_sums_bf16 = attn_bf16.float().sum(dim=-1)
    row_sums_fp32 = attn_fp32.sum(dim=-1)

    print(f"BF16 softmax row sums: "
          f"mean={row_sums_bf16.mean():.6f}, "
          f"std={row_sums_bf16.std():.6f}")
    print(f"FP32 softmax row sums: "
          f"mean={row_sums_fp32.mean():.6f}, "
          f"std={row_sums_fp32.std():.6f}")
    # Each BF16 probability keeps only 8 significant bits, so BF16 rows
    # drift from summing to exactly 1.0 far more than FP32 rows do
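The row-sum effect can be reproduced exactly with a three-way tie. Each probability is 1/3, which rounds to 0.333984375 in BF16, so the row sums to 1.001953125. In FP32 the error is about 3e-8. The sketch also applies the max subtraction from the softmax formula, which keeps exp(1000) from overflowing. The helpers are illustrative and emulate the formats in pure Python.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 via FP32 with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def softmax(logits):
    """Softmax in FP64, subtracting the max first so exp() cannot overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = math.fsum(exps)
    return [e / total for e in exps]


probs = softmax([1000.0, 1000.0, 1000.0])  # math.exp(1000.0) would raise OverflowError
row_sum_fp32 = math.fsum(to_fp32(p) for p in probs)
row_sum_bf16 = math.fsum(to_bf16(p) for p in probs)
print(row_sum_fp32)  # 1.0000000298023224 (1 + 2**-25)
print(row_sum_bf16)  # 1.001953125 (3 * 0.333984375)
```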

3.4 Cross-Entropy Loss: FP32

The cross-entropy loss involves log-softmax, which can produce large negative values (the log of small probabilities). Range is not the issue here, since BF16's most negative finite value ($\approx -3.39 \times 10^{38}$) matches FP32's. Precision is. The log-sum-exp over a 128K-entry vocabulary and the per-logit gradients both need more than BF16's 8 significant bits:

def cross_entropy_precision():
    """Cross-entropy always computed in FP32."""
    vocab_size = 128256
    batch_seq = 4 * 4096  # batch * seq_len

    # Logits from the model (BF16)
    logits_bf16 = torch.randn(batch_seq, vocab_size, dtype=torch.bfloat16)
    targets = torch.randint(0, vocab_size, (batch_seq,))

    # Must upcast to FP32 for loss computation
    logits_fp32 = logits_bf16.float()
    loss = torch.nn.functional.cross_entropy(logits_fp32, targets)

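    # The gradient of CE w.r.t. logits is:
    # d_loss/d_logit_i = softmax(logit_i) - target_i
    # For a confident prediction the target term is p - 1 with p near 1.
    # BF16 cannot resolve that: p = 0.999 rounds to exactly 1.0 in BF16,
    # so the gradient for the target logit would become 0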
    return loss
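A two-class example shows the effect concretely. The logits are chosen so the target probability is 0.999. In FP32, the gradient for the target logit is p - 1, about -0.001. In BF16, any probability above about 0.998 rounds to exactly 1.0, so the gradient vanishes. The sketch computes the softmax in FP64 and rounds only the probability, which isolates the effect of the storage format. Helper names are illustrative.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 via FP32 with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = ((bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFFFFFF) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def target_logit_grad(logits, target, round_prob):
    """d(CE)/d(logit[target]) = p_target - 1, with p_target rounded first."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    p_target = exps[target] / math.fsum(exps)
    return round_prob(p_target) - 1.0


logits = [math.log(999.0), 0.0]               # p_target = 999 / 1000 = 0.999
print(target_logit_grad(logits, 0, to_fp32))  # about -0.001
print(target_logit_grad(logits, 0, to_bf16))  # 0.0: 0.999 rounds to 1.0 in BF16
```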

3.5 Summary: The Precision Map

class TransformerLayerMixedPrecision(torch.nn.Module):
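    """A single transformer layer with mixed precision.

    Run it under torch.autocast(device_type="cuda", dtype=torch.bfloat16),
    or cast the module to BF16, so the Linear layers run as BF16 GEMMs.
    No attention mask is applied, to keep the example short.
    """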

    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn_norm = RMSNormMixedPrecision(d_model)      # FP32 internal
        self.ffn_norm = RMSNormMixedPrecision(d_model)        # FP32 internal

        # Linear weights are created in FP32; autocast (or .to(torch.bfloat16)) makes the matmuls BF16
        self.q_proj = torch.nn.Linear(d_model, d_model, bias=False)
        self.k_proj = torch.nn.Linear(d_model, d_model, bias=False)
        self.v_proj = torch.nn.Linear(d_model, d_model, bias=False)
        self.o_proj = torch.nn.Linear(d_model, d_model, bias=False)

        self.gate_proj = torch.nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = torch.nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = torch.nn.Linear(d_ff, d_model, bias=False)

        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

    def forward(self, x):
        """
        x: BF16 tensor [B, S, d_model]

        Precision flow:
        1. x (BF16) -> norm (FP32 internal, BF16 output) -> BF16
        2. BF16 -> Q,K,V projections (BF16 matmul) -> BF16
        3. BF16 -> attention scores Q@K^T (BF16 matmul) -> BF16
        4. BF16 -> softmax (FP32 internal, BF16 output) -> BF16
        5. BF16 -> attention @ V (BF16 matmul) -> BF16
        6. BF16 -> output proj (BF16 matmul) -> BF16
        7. BF16 -> residual add -> BF16
        8. Repeat for FFN
        """
        # Attention block
        normed = self.attn_norm(x)  # BF16 -> FP32 -> BF16

        B, S, D = normed.shape
        q = self.q_proj(normed)  # BF16 matmul
        k = self.k_proj(normed)
        v = self.v_proj(normed)

        q = q.reshape(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, S, self.n_heads, self.head_dim).transpose(1, 2)

        # Attention scores (BF16 matmul)
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # Softmax in FP32
        attn_weights = torch.softmax(scores.float(), dim=-1).to(x.dtype)

        # Weighted sum (BF16 matmul)
        attn_out = attn_weights @ v
        attn_out = attn_out.transpose(1, 2).reshape(B, S, D)
        attn_out = self.o_proj(attn_out)

        # Residual (BF16 add)
        x = x + attn_out

        # FFN block
        normed = self.ffn_norm(x)
        ffn_out = self.down_proj(
            torch.nn.functional.silu(self.gate_proj(normed)) * self.up_proj(normed)
        )
        x = x + ffn_out

        return x

Loss Scaling (FP16 Only)

4.1 The FP16 Gradient Problem

If you must use FP16 instead of BF16 (older hardware without BF16 support), gradients can underflow. Small gradients (around $10^{-8}$) fall below FP16's smallest subnormal value ($2^{-24} \approx 6.0 \times 10^{-8}$) and round to zero. Anything below the smallest normal value ($6.1 \times 10^{-5}$) keeps only a few significant bits. Loss scaling fixes this by multiplying the loss by a large constant before backward, which scales every gradient by the same factor. After backward, it divides the gradients by that constant:

class LossScaler:
    """Dynamic loss scaling for FP16 training.

    Not needed for BF16 -- only for FP16 on older hardware.
    """

    def __init__(self, init_scale=2**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.steps_since_last_overflow = 0

    def scale_loss(self, loss):
        """Multiply loss by scale factor before backward."""
        return loss * self.scale

    def unscale_gradients(self, optimizer):
        """Divide gradients by scale factor after backward."""
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    param.grad.data /= self.scale

    def update(self, overflow_detected):
        """Adjust scale based on whether overflow occurred."""
        if overflow_detected:
            # Overflow: reduce scale, skip this step
            self.scale *= self.backoff_factor
            self.steps_since_last_overflow = 0
        else:
            self.steps_since_last_overflow += 1
            if self.steps_since_last_overflow >= self.growth_interval:
                # No overflow for a while: try increasing scale
                self.scale *= self.growth_factor
                self.steps_since_last_overflow = 0

    def check_overflow(self, optimizer):
        """Check if any gradient contains inf or nan."""
        for group in optimizer.param_groups:
            for param in group["params"]:
                if param.grad is not None:
                    if torch.isinf(param.grad).any() or torch.isnan(param.grad).any():
                        return True
        return False
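The reason a scale factor rescues small gradients can be checked with Python's built-in IEEE half-precision packing (struct format 'e'). This sketch is illustrative. It uses a fixed scale of 2^16, the LossScaler default above, and unscales in Python's FP64 floats where a real implementation would use FP32.

```python
import math
import struct


def to_fp16(x: float) -> float:
    """Round to IEEE FP16 (struct format 'e'); overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses values that round beyond +/-65504
        return math.copysign(math.inf, x)


grad = 1e-8
scale = 2.0 ** 16                      # LossScaler's init_scale

unscaled = to_fp16(grad)               # gradient produced without scaling
scaled = to_fp16(grad * scale)         # same gradient from a scaled loss
recovered = scaled / scale             # unscale in higher precision

print(unscaled)                        # 0.0: below FP16's smallest subnormal
print(f"{recovered:.3e}")              # close to 1e-08
print(to_fp16(70000.0))                # inf: why the scale must back off
```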

4.2 Why BF16 Does Not Need Loss Scaling

BF16 has the same exponent range as FP32 (8 exponent bits), so any magnitude representable in FP32 is representable in BF16, only with less precision. A gradient of $10^{-8}$ maps to the nearest BF16 value, about $1.001 \times 10^{-8}$. The relative error is at most $2^{-8} \approx 0.39\%$, but the value is not zero.

def bf16_vs_fp16_gradient_survival():
    """Compare how small gradients survive in BF16 vs FP16."""
    small_grads = [1e-4, 1e-5, 1e-6, 1e-7, 1e-8]

    for g in small_grads:
        g_fp32 = torch.tensor(g, dtype=torch.float32)
        g_bf16 = g_fp32.to(torch.bfloat16)
        g_fp16 = g_fp32.to(torch.float16)

        print(f"Grad {g:.0e}: "
              f"FP32={g_fp32.item():.2e}, "
              f"BF16={g_bf16.item():.2e}, "
              f"FP16={g_fp16.item():.2e}")

    # Output:
    # Grad 1e-04: FP32=1.00e-04, BF16=1.00e-04, FP16=1.00e-04
    # Grad 1e-05: FP32=1.00e-05, BF16=1.00e-05, FP16=1.00e-05  <-- FP16 subnormal
    # Grad 1e-06: FP32=1.00e-06, BF16=9.98e-07, FP16=1.01e-06  <-- FP16 subnormal
    # Grad 1e-07: FP32=1.00e-07, BF16=1.00e-07, FP16=1.19e-07  <-- FP16 subnormal, 19% error
    # Grad 1e-08: FP32=1.00e-08, BF16=1.00e-08, FP16=0.00e+00  <-- FP16 underflow
ℹ️ Note

If your hardware supports BF16 (NVIDIA Ampere and later, AMD MI250 and later, Google TPUs), always use BF16 over FP16. You avoid the entire loss scaling complexity with no quality cost. The only reason to use FP16 is on older hardware (NVIDIA V100, T4) that lacks BF16 tensor cores.

Implementation with PyTorch AMP

5.1 The AMP Autocast Context Manager

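PyTorch's Automatic Mixed Precision (AMP) handles per-op precision casting automatically. You wrap the forward pass and loss in an autocast region, and it runs each op in BF16/FP16 or FP32 according to built-in op lists. The current spelling is torch.autocast(device_type="cuda", dtype=...); the older torch.cuda.amp.autocast still works but is deprecated in recent PyTorch releases. Autocast does not create master weights. Keep the model's parameters in FP32, and the parameters themselves serve as the FP32 master copy, with FP32 gradients and optimizer state: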

import functools
import torch

# torch.cuda.amp.autocast is deprecated in recent PyTorch; bind the
# device-generic torch.autocast to CUDA so code can call autocast(dtype=...)
autocast = functools.partial(torch.autocast, device_type="cuda")

def training_step_bf16(model, batch, optimizer):
    """A single training step with BF16 mixed precision.

    Assumes the model's parameters are FP32 (they act as master weights).
    No GradScaler needed for BF16 -- only for FP16.
    """
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()

    # autocast handles: matmuls in BF16, norms/softmax in FP32
    with autocast(dtype=torch.bfloat16):
        outputs = model(input_ids)
        logits = outputs.logits.float()  # explicit FP32 upcast for the loss
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),  # [B*S, vocab]
            labels.reshape(-1),                   # [B*S]
        )

    # Backward: parameter gradients come out FP32 because parameters are FP32
    loss.backward()

    # Gradient clipping on the FP32 gradients
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Optimizer step updates the FP32 parameters (master weights)
    optimizer.step()
    optimizer.zero_grad()

    return loss.item()

def training_step_fp16(model, batch, optimizer, scaler):
    """A single training step with FP16 mixed precision.

    Requires a GradScaler (torch.amp.GradScaler("cuda") in recent PyTorch,
    torch.cuda.amp.GradScaler in older releases) to prevent gradient underflow.
    """
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()

    with autocast(dtype=torch.float16):
        outputs = model(input_ids)
        logits = outputs.logits.float()
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1),
        )

    # Scale loss before backward to prevent gradient underflow
    scaler.scale(loss).backward()

    # Unscale gradients so clipping sees their true magnitudes
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Skips the step if any gradient is inf/NaN, then adjusts the scale
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()

    return loss.item()

5.2 What autocast Does Internally

Autocast decides each op's precision from built-in, per-device op lists; the full CUDA lists are in the PyTorch AMP documentation. A simplified summary of the entries that matter for a transformer:

# Ops that autocast runs in BF16/FP16 (fast on tensor cores):
LOWER_PRECISION_OPS = [
    "torch.matmul",
    "torch.nn.functional.linear",
    "torch.nn.functional.conv1d",
    "torch.nn.functional.conv2d",
    "torch.bmm",
    "torch.addmm",
    "torch.addbmm",
    "torch.baddbmm",
]

# Ops that autocast runs in FP32 (need precision or range):
FP32_OPS = [
    "torch.nn.functional.softmax",
    "torch.nn.functional.cross_entropy",
    "torch.nn.functional.log_softmax",
    "torch.nn.functional.layer_norm",
    "torch.nn.functional.group_norm",
    "torch.exp",
    "torch.log",
    "torch.pow",
    "torch.norm",
    "torch.sum",  # Large reductions
]

# Ops that autocast promotes to the widest input type:
PROMOTE_OPS = [
    "torch.addcmul",
    "torch.addcdiv",
    "torch.cat",
    "torch.stack",
]

# Not on any list: elementwise add/sub/mul/div follow ordinary type
# promotion (BF16 + FP32 -> FP32), and ops such as torch.mean run in
# the dtype of their inputs.

5.3 Full Training Loop

def train(model, train_loader, val_loader, config):
    """Complete training loop with BF16 mixed precision."""

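    # Keep the parameters in FP32: they are the master weights.
    # autocast casts them to BF16 per op inside the forward pass.
    # (Do NOT call model.to(torch.bfloat16) here: torch.optim.AdamW keeps
    # no FP32 copy, so its states and updates would all be BF16.)
    model = model.cuda()

    # Optimizer state (m, v) is allocated like the FP32 parameters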
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=config["weight_decay"],
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=config["total_steps"], eta_min=config["min_lr"]
    )

    # Gradient accumulation for effective batch size
    grad_accum_steps = config["grad_accum_steps"]

    step = 0
    for epoch in range(config["epochs"]):
        model.train()

        for micro_step, batch in enumerate(train_loader):
            with autocast(dtype=torch.bfloat16):
                outputs = model(batch["input_ids"].cuda())
                loss = torch.nn.functional.cross_entropy(
                    outputs.logits.float().reshape(-1, outputs.logits.size(-1)),
                    batch["labels"].cuda().reshape(-1),
                )
                # Scale for gradient accumulation
                loss = loss / grad_accum_steps

            loss.backward()

            if (micro_step + 1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=1.0
                )
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                step += 1

                if step % config["log_interval"] == 0:
                    print(f"Step {step}: loss={loss.item() * grad_accum_steps:.4f}, "
                          f"lr={scheduler.get_last_lr()[0]:.2e}")

                if step % config["eval_interval"] == 0:
                    val_loss = evaluate(model, val_loader)
                    print(f"Step {step}: val_loss={val_loss:.4f}")
                    model.train()  # evaluate() switched the model to eval mode

    return model

def evaluate(model, val_loader):
    """Evaluate model in BF16."""
    model.eval()
    total_loss = 0
    total_tokens = 0

    with torch.no_grad(), autocast(dtype=torch.bfloat16):
        for batch in val_loader:
            outputs = model(batch["input_ids"].cuda())
            loss = torch.nn.functional.cross_entropy(
                outputs.logits.float().reshape(-1, outputs.logits.size(-1)),
                batch["labels"].cuda().reshape(-1),
                reduction="sum",
            )
            total_loss += loss.item()
            total_tokens += batch["labels"].numel()

    return total_loss / total_tokens

6.1 Loss Spikes from BF16 Accumulation

When using BF16, gradient accumulation across many micro-batches can lose precision. Each BF16 addition can introduce a rounding error of up to $2^{-8} \approx 0.39\%$ of the running sum, and over 64 accumulation steps the errors compound:

def gradient_accumulation_error(n_steps=64):
    """Show precision loss from BF16 gradient accumulation."""
    # Simulate: accumulate small gradients
    grad_per_step = torch.randn(4096, dtype=torch.float32) * 0.001

    # FP32 accumulation (ground truth)
    accum_fp32 = torch.zeros(4096, dtype=torch.float32)
    for _ in range(n_steps):
        accum_fp32 += grad_per_step

    # BF16 accumulation (problematic)
    accum_bf16 = torch.zeros(4096, dtype=torch.bfloat16)
    for _ in range(n_steps):
        accum_bf16 += grad_per_step.to(torch.bfloat16)

    # Compare
    error = (accum_bf16.float() - accum_fp32).abs()
    relative_error = error / (accum_fp32.abs() + 1e-10)

    print(f"Accumulation over {n_steps} steps:")
    print(f"  Mean relative error: {relative_error.mean():.4f}")
    print(f"  Max relative error:  {relative_error.max():.4f}")
    # With 64 steps: mean relative error can reach 5-10%

The fix: accumulate gradients in FP32, even if individual gradients are computed in BF16.

def safe_gradient_accumulation(model, micro_batches, grad_accum_steps):
    """Accumulate gradients in FP32 for precision.

    Only needed when the parameters themselves are BF16. With FP32
    parameters under autocast, param.grad is already FP32.
    """
    # Option 1: PyTorch autograd accumulates in param.grad dtype
    # If param is BF16, grad is BF16 -- problematic for many steps

    # Option 2: Keep separate FP32 gradient buffers
    fp32_grads = {
        name: torch.zeros_like(param, dtype=torch.float32)
        for name, param in model.named_parameters()
    }

    for micro_batch in micro_batches:
        with autocast(dtype=torch.bfloat16):
            # compute_loss is a placeholder for your model's loss function
            loss = compute_loss(model, micro_batch) / grad_accum_steps
        loss.backward()

        # Accumulate in FP32
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.grad is not None:
                    fp32_grads[name] += param.grad.float()
                    param.grad = None  # Free BF16 grad

    # Copy FP32 accumulated grads back
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.grad = fp32_grads[name].to(param.dtype)

6.2 Norm Instability

If norms accidentally run in BF16, training can become unstable after tens of thousands of steps. The symptom: loss spikes that recover but become more frequent:

def diagnose_norm_precision(model):
    """Flag norm layers whose weights are not FP32 and report I/O dtypes.

    This inspects dtypes only; it cannot see whether the statistics inside
    forward() are computed in FP32 (read the module's code for that).
    """
    for name, module in model.named_modules():
        weight = getattr(module, "weight", None)
        if "norm" not in name.lower() or not isinstance(weight, torch.Tensor):
            continue

        if weight.dtype != torch.float32:
            print(f"WARNING: {name} weight is {weight.dtype}, should be FP32")

        # Test forward with a BF16 input on the module's own device
        test_input = torch.randn(1, 10, weight.shape[0],
                                 device=weight.device, dtype=torch.bfloat16)
        with torch.no_grad():
            output = module(test_input)
        print(f"{name}: input={test_input.dtype}, output={output.dtype}")

6.3 Embedding and Output Head Precision

The embedding lookup and the final output projection (logits) are often overlooked. The embedding lookup itself is exact (just a table lookup), but the output logits feed into softmax and cross-entropy, which require FP32:

import torch

class SafeLMHead(torch.nn.Module):
    """Output head that ensures FP32 for loss computation."""

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.proj = torch.nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, hidden_states, labels=None):
        # Under autocast this matmul runs in BF16, so the logits are BF16
        logits = self.proj(hidden_states)

        if labels is not None:
            # Upcast to FP32 so log-softmax and the mean loss are computed in FP32
            # (labels are assumed to be already shifted for next-token prediction)
            loss = torch.nn.functional.cross_entropy(
                logits.float().reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )
            return logits, loss

        return logits

FP8: The Next Frontier

7.1 FP8 Formats on Hopper/Blackwell

NVIDIA Hopper (H100) introduced FP8 tensor cores. Two formats:

  • E4M3: 4 exponent bits, 3 mantissa bits. Range: $\pm 448$. Precision: $\epsilon = 2^{-4} = 0.0625$ (6.25% maximum relative rounding error). Used for the forward pass.
  • E5M2: 5 exponent bits, 2 mantissa bits. Range: $\pm 57344$. Precision: $\epsilon = 2^{-3} = 0.125$ (12.5%). Used for the backward pass, where gradients need the extra range.

def fp8_training_concept():
    """FP8 training: the next step after BF16.

    H100 FP8 tensor cores: ~1980 TFLOP/s dense (2x BF16, 4x TF32;
    non-tensor-core FP32 is only ~67 TFLOP/s)
    """
    config = {
        "forward_matmuls": "FP8-E4M3 (weights and activations; more mantissa precision)",
        "backward_matmuls": "FP8-E5M2 (gradients; more exponent range)",
        "norms": "FP32 (still needs high precision)",
        "softmax": "FP32 (still needs high precision)",
        "master_weights": "FP32 (still needs accumulation precision)",
        "optimizer": "FP32 (still needs accumulation precision)",
        "per_tensor_scaling": "Required (dynamic scale per tensor)",
    }
    return config
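The range and precision figures in the format list follow directly from the bit layouts. This sketch assumes exponent biases of 7 (E4M3) and 15 (E5M2), that E4M3 has no infinities and reserves only the all-ones exponent-and-mantissa pattern for NaN, and that E5M2 follows IEEE conventions, reserving its top exponent field for Inf/NaN. Precision is reported as the maximum relative rounding error, $2^{-(m+1)}$ for $m$ mantissa bits; `fp8_stats` is a hypothetical helper.

```python
def fp8_stats(man_bits, bias, max_exp_field, max_man_field):
    """Return (largest finite value, smallest normal value, max relative rounding error)."""
    largest = 2.0 ** (max_exp_field - bias) * (1 + max_man_field / 2 ** man_bits)
    smallest_normal = 2.0 ** (1 - bias)
    unit_roundoff = 2.0 ** -(man_bits + 1)
    return largest, smallest_normal, unit_roundoff

# E4M3: exponent field 15 is usable, but mantissa 0b111 there encodes NaN
e4m3 = fp8_stats(man_bits=3, bias=7, max_exp_field=15, max_man_field=0b110)
# E5M2: exponent field 31 is Inf/NaN, so the top finite exponent field is 30
e5m2 = fp8_stats(man_bits=2, bias=15, max_exp_field=30, max_man_field=0b11)

for name, (largest, smallest_normal, u) in [("E4M3", e4m3), ("E5M2", e5m2)]:
    print(f"{name}: max={largest:g}, min normal={smallest_normal:g}, "
          f"max relative rounding error={u:.2%}")
```

The script reproduces the $\pm 448$ and $\pm 57344$ limits and shows why E5M2 suits gradients: its smallest normal value is $2^{-14}$ versus $2^{-6}$ for E4M3, at the cost of one bit of precision.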

7.2 Per-Tensor Scaling for FP8

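FP8's narrow range ($\pm 448$ for E4M3) means that tensors must be scaled to fit: large values would overflow, and small ones (many gradients, for example) would underflow to zero. Each tensor gets its own scale factor that maps its largest magnitude onto the top of the FP8 representable range: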

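import torch

def fp8_quantize(tensor, fp8_max=448.0, fp8_dtype=torch.float8_e4m3fn):
    """Quantize a tensor to FP8 with per-tensor scaling (simulated).

    scale       = fp8_max / max(|tensor|)
    fp8_tensor  = round_to_fp8(tensor * scale)  # nearest FP8 value, not integer rounding
    Dequantize: tensor_approx = fp8_tensor / scale

    fp8_max must match fp8_dtype (448 for E4M3, 57344 for E5M2).
    Requires a PyTorch build with float8 dtypes (2.1 or later).
    """
    amax = tensor.abs().max().float()
    scale = fp8_max / amax.clamp(min=1e-12)

    # Scale so the largest magnitude lands on fp8_max; clamp first so no value
    # falls outside the FP8 range
    scaled = (tensor.float() * scale).clamp(-fp8_max, fp8_max)

    # Real hardware stores the 8-bit values. Here we cast to the FP8 dtype to
    # apply its rounding, then back to FP32 so ordinary ops can consume it.
    quantized = scaled.to(fp8_dtype).to(torch.float32)
    return quantized, scale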

def fp8_matmul(a, b, a_scale, b_scale):
    """Simulated FP8 matrix multiplication.

    Real hardware: tensor cores compute A_fp8 @ B_fp8 with FP32 accumulation.
    Result: C_fp32 = (A_fp8 @ B_fp8) / (a_scale * b_scale)
    """
    c = torch.matmul(a, b)  # FP8 matmul (simulated in float)
    c = c / (a_scale * b_scale)  # Descale
    return c
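To see what the per-tensor scale in `fp8_quantize` buys, the sketch below emulates E4M3 rounding in pure Python and quantizes a few small, gradient-like values with and without scaling. `round_e4m3` is a hypothetical helper, not hardware-exact: it saturates at $\pm 448$ rather than producing NaN, and it assumes E4M3 subnormals have a fixed spacing of $2^{-9}$.

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 value
E4M3_MIN_EXP = -6  # exponent of the smallest normal value, 2**-6
E4M3_MAN_BITS = 3

def round_e4m3(x: float) -> float:
    """Emulate rounding to the nearest E4M3 value (ties to even).

    Values past +/-448 saturate instead of becoming NaN; subnormals keep a
    fixed spacing of 2**-9, so magnitudes of at most 2**-10 round to 0.
    """
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    _, e = math.frexp(mag)                  # mag = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, E4M3_MIN_EXP)          # floor(log2(mag)), clamped for subnormals
    spacing = 2.0 ** (exp - E4M3_MAN_BITS)  # gap between neighboring E4M3 values here
    q = round(mag / spacing) * spacing      # Python's round() ties to even
    return math.copysign(min(q, E4M3_MAX), x)

def fake_quantize(values, scale):
    """Scale, round to E4M3, then divide the scale back out."""
    return [round_e4m3(v * scale) / scale for v in values]

# Hypothetical gradient values, all below E4M3's smallest subnormal (2**-9, about 0.002)
grads = [1e-5, -3e-5, 7e-5, 2e-4]

unscaled = fake_quantize(grads, scale=1.0)
scale = E4M3_MAX / max(abs(g) for g in grads)  # same per-tensor rule as fp8_quantize
scaled = fake_quantize(grads, scale)
max_rel_err = max(abs(q - g) / abs(g) for q, g in zip(scaled, grads))

print("without scaling:", unscaled)
print("with scaling:   ", [f"{q:.3g}" for q in scaled])
print(f"max relative error with scaling: {max_rel_err:.2%}")
```

Without scaling, every value flushes to zero. With the per-tensor scale, each value survives with at most E4M3's 6.25% rounding error; for these inputs the worst case is under 5%.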
Performance

The H100's FP8 tensor cores peak at about 1980 TFLOP/s (dense) for matrix multiplications, 2x the BF16 peak of about 990 TFLOP/s. The end-to-end gain is smaller: for a 70B model, FP8 reduces training time by roughly 30-40% rather than 2x, because the operations that are not matmuls (norms, softmax, elementwise ops, communication) do not get faster. With proper per-tensor scaling the quality impact is minimal, typically less than a 0.1% increase in loss.
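The gap between the 2x matmul speedup and the end-to-end gain is Amdahl's law. Here is a minimal sketch. It assumes matmul time halves while everything else stays the same, and it ignores the extra cost of FP8 casts and scale computation. The matmul shares of step time are illustrative, not measured.

```python
def end_to_end_speedup(matmul_fraction: float, matmul_speedup: float = 2.0) -> float:
    """Amdahl's law: only the matmul share of step time gets faster."""
    other_fraction = 1.0 - matmul_fraction
    return 1.0 / (other_fraction + matmul_fraction / matmul_speedup)

# Illustrative shares of step time spent in FP8-eligible matmuls (assumptions)
for share in (0.6, 0.7, 0.8):
    speedup = end_to_end_speedup(share)
    print(f"matmul share {share:.0%}: {speedup:.2f}x faster, "
          f"step time {1 - 1 / speedup:.0%} shorter")
```

With a 2x matmul speedup, the step-time saving is exactly half the matmul share. Shares of 60-80% therefore give the 30-40% reduction quoted above.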

Practical Checklist

def mixed_precision_checklist():
    """Checklist for correct mixed precision training."""
    return {
        "1_use_bf16": (
            "Use BF16, not FP16, if hardware supports it. "
            "Eliminates need for loss scaling."
        ),
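        "2_fp32_master_weights": (
            "Optimizer must maintain FP32 copies of all weights. "
            "With torch.autocast the parameters stay FP32 and serve as the "
            "master weights; if the model itself is cast to BF16, plain "
            "torch.optim.AdamW does not keep FP32 copies for you."
        ),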
        "3_fp32_norms": (
            "RMSNorm/LayerNorm internal computation must be FP32. "
            "Upcast input, compute, downcast output."
        ),
        "4_fp32_softmax": (
            "Attention softmax must be computed in FP32. "
            "FlashAttention handles this internally."
        ),
        "5_fp32_loss": (
            "Cross-entropy loss must be computed in FP32. "
            "Upcast logits before loss function."
        ),
        "6_fp32_grad_accum": (
            "If using many gradient accumulation steps (more than 8), "
            "accumulate in FP32 to prevent drift."
        ),
        "7_no_bf16_reduction": (
            "Never reduce (sum, mean) over large dimensions in BF16. "
            "Upcast first, reduce in FP32."
        ),
        "8_check_grad_norms": (
            "Monitor gradient norms per layer during training. "
            "Sudden spikes indicate precision issues."
        ),
    }

Mixed Precision Training Speed (Llama 7B, single H100)

| Precision | Throughput (tok/s) | Speedup vs FP32 |
|---|---|---|
| FP32 (all operations) | 1,200 | baseline |
| BF16 mixed (standard) | 3,400 | +183% |
| BF16 + torch.compile | 4,100 | +242% |
| FP8 mixed (H100) | 5,800 | +383% |

Mixed precision training is not optional for LLMs; it is a prerequisite. The precision hierarchy described in this post (BF16 for matmuls; FP32 for norms, softmax, loss, optimizer state, and master weights) is the default in mainstream LLM training frameworks. Understanding why each operation needs its specific precision prevents subtle training bugs that show up as loss spikes, divergence, or silently degraded model quality.
