Part of Series Quantization Masterclass 4 of 30

FP8 is the most important precision innovation since BF16. Compared to BF16/FP16, it halves the memory of every tensor stored in it and doubles peak tensor-core throughput, for both training and inference. It also largely avoids the outlier problems that plague INT8 activation quantization. FP8's floating-point representation handles the non-uniform distributions found in neural network tensors naturally, so no SmoothQuant-style transformation is needed.

This post covers FP8 in full depth: why two FP8 variants exist (E4M3 and E5M2) and when to use each, how NVIDIA Transformer Engine manages FP8 mixed precision automatically, the delayed scaling algorithm that makes FP8 practical, which operations benefit from FP8 (and which do not), how DeepSeek trained a 671B-parameter MoE model with FP8 GEMMs, and a complete implementation of an FP8 linear layer with proper scaling.

The FP8 Precision Strategy

Recall from Part 1 that FP8 comes in two variants:

  • E4M3: 4 exponent bits, 3 mantissa bits. Range $\pm 448$. More precision, less range.
  • E5M2: 5 exponent bits, 2 mantissa bits. Range $\pm 57344$. Less precision, more range.

The standard FP8 recipe uses both variants in different roles:

| Tensor | Format | Rationale |
|---|---|---|
| Forward activations | E4M3 | Values are bounded; precision matters |
| Forward weights | E4M3 | Static, well-distributed; precision matters |
| Backward gradients | E5M2 | Extreme dynamic range; range matters |
| Master weights | FP32 | Full precision for optimizer state |
| Optimizer state | FP32 | Momentum/variance need full precision |
Why Not E4M3 Everywhere?

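E4M3's maximum value is 448 and its smallest subnormal is $2^{-9}$, a dynamic range of about $2^{18}$. A scale factor can shift that window but cannot widen it, and gradients routinely span a wider range of magnitudes than weights or activations. Values that fall off the top overflow, and a single NaN can destabilize the entire training run. Values that fall off the bottom flush to zero. E5M2 trades one mantissa bit for range: its maximum is 57344 and its smallest subnormal is $2^{-16}$, about $2^{32}$ of dynamic range. Gradient tensors are still scaled, but the wider window makes the choice of scale far less fragile.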
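The ranges quoted above follow directly from the bit layouts. The sketch below enumerates every finite value of each format. It assumes the OCP FP8 encodings that NVIDIA hardware follows: E4M3 with exponent bias 7 and no infinities (only S.1111.111 is NaN), and E5M2 with bias 15 and IEEE-style Inf/NaN. The helper `fp8_values` is hypothetical and written only for illustration.

```python
import math

def fp8_values(exp_bits, man_bits, bias, ieee_specials):
    """Return all finite non-negative values of a small float format, sorted.

    ieee_specials=True:  the all-ones exponent is reserved for Inf/NaN (E5M2).
    ieee_specials=False: only all-ones exponent + all-ones mantissa is NaN (E4M3).
    """
    values = set()
    max_exp = 2 ** exp_bits - 1
    for e in range(max_exp + 1):
        for m in range(2 ** man_bits):
            if e == max_exp and (ieee_specials or m == 2 ** man_bits - 1):
                continue  # Inf/NaN encodings
            if e == 0:  # subnormal: 0.m * 2^(1 - bias)
                v = (m / 2 ** man_bits) * 2.0 ** (1 - bias)
            else:       # normal: 1.m * 2^(e - bias)
                v = (1 + m / 2 ** man_bits) * 2.0 ** (e - bias)
            values.add(v)
    return sorted(values)

for name, args in {"E4M3": (4, 3, 7, False), "E5M2": (5, 2, 15, True)}.items():
    vals = fp8_values(*args)
    vmax, vmin = vals[-1], vals[1]  # vals[0] is zero
    print(f"{name}: max={vmax}, min subnormal=2^{math.log2(vmin):.0f}, "
          f"dynamic range=2^{math.log2(vmax / vmin):.1f}, "
          f"distinct values={2 * (len(vals) - 1) + 1}")
# E4M3: max=448.0, min subnormal=2^-9, dynamic range=2^17.8, distinct values=253
# E5M2: max=57344.0, min subnormal=2^-16, dynamic range=2^31.8, distinct values=247
```

E5M2 gains about 14 binades of dynamic range over E4M3 at the cost of one mantissa bit.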

Which Operations Use FP8

Not all operations benefit from FP8 precision. The rule is simple: GEMMs use FP8, everything else stays in higher precision.

A transformer layer contains these compute operations:

Forward pass through one transformer block:
1. LayerNorm(x)              -- FP32 (reduction, needs precision)
2. Q = x @ W_q               -- FP8 GEMM (E4M3 inputs, FP32 accum)
3. K = x @ W_k               -- FP8 GEMM
4. V = x @ W_v               -- FP8 GEMM
5. attn = softmax(Q @ K^T)   -- FP32 (softmax needs precision)
6. out = attn @ V             -- FP8 GEMM (attention @ values)
7. proj = out @ W_o           -- FP8 GEMM
8. residual add               -- FP32
9. LayerNorm(x)              -- FP32
10. up = x @ W_up            -- FP8 GEMM
11. gate = x @ W_gate        -- FP8 GEMM
12. h = SiLU(gate) * up      -- FP32 (element-wise)
13. down = h @ W_down        -- FP8 GEMM
14. residual add             -- FP32

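Of the 14 operations, 8 are GEMMs that run in FP8. The remaining 6 (LayerNorm, softmax, activation functions, residual adds) stay in FP32 or BF16. They are either numerically sensitive (softmax, LayerNorm) or cheap element-wise operations (activation functions, residual adds) that gain little from FP8. Step 5 also contains a GEMM, Q @ K^T. This recipe keeps it in higher precision together with the softmax it feeds.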

GEMMs Dominate Compute

In a typical transformer, GEMMs account for over 95% of FLOPs. Running them in FP8 can nearly double the throughput of the GEMMs themselves, but the end-to-end gain for the layer is smaller. The non-GEMM operations are cheap in FLOPs but memory-bound, so they take a far larger share of runtime than of FLOPs. FP8 also adds cast kernels of its own.
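A back-of-the-envelope FLOP count shows why GEMMs dominate. The sketch makes several assumptions:

- a Llama-7B-like block (d_model=4096, gated FFN with d_ff=11008, 32 heads) on a 2048-token sequence;
- round, assumed per-element op counts for the non-GEMM work.

The helper `block_flops` is hypothetical. It counts FLOPs, not time, so it says nothing about how long the memory-bound operations take.

```python
def block_flops(seq, d_model, d_ff, n_heads):
    """Rough forward FLOPs for one transformer block on one sequence.

    Returns (gemm_flops, other_flops). A GEMM (m x k) @ (k x n) costs 2*m*k*n.
    The non-GEMM per-element op counts are assumed round numbers.
    """
    gemm = {
        "qkv_proj": 2 * seq * d_model * 3 * d_model,
        "q_kt": 2 * seq * seq * d_model,        # summed over heads
        "attn_v": 2 * seq * seq * d_model,
        "out_proj": 2 * seq * d_model * d_model,
        "up_gate_down": 2 * seq * d_model * 3 * d_ff,
    }
    other = {
        "layernorm_x2": 2 * 5 * seq * d_model,  # ~5 ops per element (assumed)
        "softmax": 5 * n_heads * seq * seq,     # ~5 ops per score (assumed)
        "residual_x2": 2 * seq * d_model,
        "silu_mul": 5 * seq * d_ff,             # ~5 ops per element (assumed)
    }
    return sum(gemm.values()), sum(other.values())

# Llama-7B-like dimensions (assumed for illustration)
g, o = block_flops(seq=2048, d_model=4096, d_ff=11008, n_heads=32)
print(f"GEMM share of FLOPs: {g / (g + o):.2%}")  # about 99.9%
```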

The Scaling Problem

FP8 E4M3 can represent magnitudes from $2^{-9}$ (smallest subnormal) to 448. Values outside this range either underflow (small values flush to zero) or overflow. Overflowing values become NaN, or clamp to $\pm 448$ if the cast saturates. Either way, information is lost.

The solution is scaling: multiply the tensor by a scale factor before casting to FP8, then divide the GEMM output by the same factor.

$$Y = XW^T \approx \left(\frac{X_{\text{fp8}}}{s_x}\right)\left(\frac{W_{\text{fp8}}}{s_w}\right)^T = \frac{1}{s_x s_w}\left(X_{\text{fp8}} W_{\text{fp8}}^T\right), \quad X_{\text{fp8}} = \mathrm{fp8}(s_x X),\; W_{\text{fp8}} = \mathrm{fp8}(s_w W)$$

The GEMM is performed in FP8 with FP32 accumulation, and the scale factors are applied to the FP32 result. The critical question is: how do you choose $s_x$ and $s_w$?

Per-Tensor Scaling

The simplest approach: compute the max absolute value of the tensor, and set the scale factor to map that value to the FP8 max:

$$s = \frac{\text{FP8\_MAX}}{\max(|T|)} = \frac{448}{\max(|T|)} \quad \text{(E4M3)}$$

import torch

def compute_fp8_scale(tensor, fp8_max=448.0):
    """Compute per-tensor scale to map tensor into FP8 E4M3 range."""
    amax = tensor.abs().max().item()
    if amax == 0:
        return 1.0
    return fp8_max / amax

def quantize_to_fp8_e4m3(tensor, scale):
    """Quantize FP32/BF16 tensor to simulated FP8 E4M3.

    In practice, this is a hardware cast instruction (CUDA __nv_fp8_e4m3).
    Here we simulate it: scale, clamp to +-448, then round to the nearest
    E4M3 value. The result is stored in FP32 but holds only E4M3 values.
    """
    scaled = (tensor.float() * scale).clamp(-448.0, 448.0)
    # E4M3 has 3 mantissa bits = 8 values per power-of-2 interval, so the
    # spacing in [2^e, 2^(e+1)) is 2^(e-3). Below 2^-6 (subnormals) the
    # spacing stays 2^-9.
    exp = torch.floor(torch.log2(scaled.abs().clamp(min=2.0 ** -6)))
    step = 2.0 ** (exp - 3)
    # torch.round rounds half to even, the usual cast rounding mode
    return torch.round(scaled / step) * step

def dequantize_from_fp8(fp8_tensor, scale):
    """Dequantize FP8 back to FP32."""
    return fp8_tensor / scale
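The per-tensor recipe can be checked end to end. The sketch below uses NumPy and a hypothetical `round_to_e4m3` helper that models a round-to-nearest-even cast with saturation at $\pm 448$. It scales X and W into the E4M3 range, rounds them, multiplies, and divides the product by $s_x s_w$. With 3 mantissa bits, the relative error against the reference should be on the order of a few percent.

```python
import numpy as np

def round_to_e4m3(a):
    """Round to the nearest E4M3 value (half to even), saturating at +-448."""
    a = np.clip(a, -448.0, 448.0)
    # 8 steps per binade (3 mantissa bits); subnormals below 2^-6 keep step 2^-9
    exp = np.floor(np.log2(np.maximum(np.abs(a), 2.0 ** -6)))
    step = 2.0 ** (exp - 3)
    return np.round(a / step) * step

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 256))          # activations
W = 0.02 * rng.standard_normal((128, 256))  # weights

# Per-tensor scales map each tensor's amax to the E4M3 maximum
s_x = 448.0 / np.abs(X).max()
s_w = 448.0 / np.abs(W).max()

X_fp8 = round_to_e4m3(X * s_x)
W_fp8 = round_to_e4m3(W * s_w)

# Multiply the FP8 values (float64 accumulation stands in for FP32),
# then undo both scales on the output
Y_fp8 = (X_fp8 @ W_fp8.T) / (s_x * s_w)
Y_ref = X @ W.T
rel_err = np.linalg.norm(Y_fp8 - Y_ref) / np.linalg.norm(Y_ref)
print(f"relative error vs FP64 reference: {rel_err:.3f}")
```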

The Problem with Just-In-Time Scaling

Computing tensor.abs().max() requires reading the entire tensor from memory. For a large activation tensor, this is a separate kernel launch that reads all the data, computes the max, and then the quantization kernel reads the data again. Two full memory reads where there should be one.

This overhead is why delayed scaling exists.
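The cost of the extra pass is easy to estimate. The sketch assumes a BF16 source tensor, an FP8 destination, and no cache reuse. The helper `cast_traffic_bytes` is hypothetical. A separate amax pass adds one full read of the source on top of the read and write that the cast needs anyway.

```python
def cast_traffic_bytes(numel, src_bytes=2, fp8_bytes=1, separate_amax_pass=True):
    """Approximate DRAM traffic (bytes) to cast one tensor to FP8, ignoring caches."""
    read_for_amax = numel * src_bytes if separate_amax_pass else 0
    read_for_cast = numel * src_bytes
    write_fp8 = numel * fp8_bytes
    return read_for_amax + read_for_cast + write_fp8

n = 8192 * 8192  # e.g. an 8192 x 8192 BF16 activation
mib = 2 ** 20
print(cast_traffic_bytes(n) / mib)                            # 320.0 (just-in-time)
print(cast_traffic_bytes(n, separate_amax_pass=False) / mib)  # 192.0 (amax fused into the cast)
```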

Delayed Scaling: The Key Algorithm

Delayed scaling uses the max absolute value from a previous iteration to compute the scale factor for the current iteration. The insight: tensor distributions change slowly during training. The max value at step $t$ is a good approximation of the max value at step $t+1$.

NVIDIA Transformer Engine maintains an amax history buffer for each FP8 tensor. The algorithm:

  1. At step $t$, use the scale computed from step $t-1$ (or earlier) to cast tensors to FP8
  2. While casting, record the actual max absolute value of the current tensors (the amax reduction is fused into a kernel that already reads the tensor, so it needs no extra pass)
  3. Update the amax history buffer
  4. Compute the scale for step $t+1$ from the history
class DelayedScaling:
    """Delayed scaling algorithm for FP8 tensors.

    Maintains a history of amax values and uses them to compute
    scale factors one step behind.
    """

    def __init__(self, history_len=1024, fp8_max=448.0, margin=0):
        """
        history_len: number of past amax values to keep
        fp8_max: maximum representable FP8 value (448 for E4M3)
        margin: safety margin in powers of 2 (2^margin headroom)
        """
        self.history_len = history_len
        self.fp8_max = fp8_max
        self.margin = margin

        # Circular buffer of amax values
        self.amax_history = torch.zeros(history_len)
        self.history_idx = 0
        self.scale = 1.0

    def update(self, current_amax):
        """Record the amax from the current step and update scale."""
        # Store current amax in history
        self.amax_history[self.history_idx % self.history_len] = current_amax
        self.history_idx += 1

        # Compute scale from history
        # Use the max of recent history for safety
        valid_len = min(self.history_idx, self.history_len)
        amax_from_history = self.amax_history[:valid_len].max().item()

        if amax_from_history == 0:
            self.scale = 1.0
        else:
            self.scale = (self.fp8_max / amax_from_history) / (2 ** self.margin)

        return self.scale

    def get_scale(self):
        """Get the current scale factor (computed from previous step)."""
        return self.scale

class FP8TensorManager:
    """Manages delayed scaling for all FP8 tensors in a layer."""

    def __init__(self, fp8_max_fwd=448.0, fp8_max_bwd=57344.0):
        self.input_scaling = DelayedScaling(fp8_max=fp8_max_fwd)
        self.weight_scaling = DelayedScaling(fp8_max=fp8_max_fwd)
        self.grad_output_scaling = DelayedScaling(fp8_max=fp8_max_bwd)

    def get_forward_scales(self):
        """Get scale factors for forward pass (E4M3)."""
        return self.input_scaling.get_scale(), self.weight_scaling.get_scale()

    def get_backward_scale(self):
        """Get scale factor for backward pass (E5M2)."""
        return self.grad_output_scaling.get_scale()

    def update_forward(self, input_amax, weight_amax):
        """Update forward scales after computing the GEMM."""
        self.input_scaling.update(input_amax)
        self.weight_scaling.update(weight_amax)

    def update_backward(self, grad_amax):
        """Update backward scale after computing the backward GEMM."""
        self.grad_output_scaling.update(grad_amax)
Delayed Scaling Lag

Delayed scaling uses stale scale factors. If the tensor distribution changes abruptly (e.g., a sudden gradient spike), the scale factor computed from earlier steps is too large for the new values, and they overflow. The margin parameter provides headroom: a margin of 1 halves the scale, leaving a 2x safety margin. Transformer Engine uses a default margin of 0. It relies on the amax history, which takes the max over recent steps, to keep the scale conservative after a transient.
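The lag shows up clearly in a toy simulation. This pure-Python sketch replays a sequence of per-step amax values through the delayed-scaling update and reports the steps where the stale scale lets values exceed 448. The function name is hypothetical, and the history length of 4 is chosen only for brevity. In this example:

- a 3x spike overflows with a margin of 0 or 1;
- a margin of 2 absorbs it, at the cost of mapping the usual amax to 112 instead of 448 on every step.

```python
def simulate_delayed_scaling(amax_per_step, history_len=4, fp8_max=448.0, margin=0):
    """Return the steps at which a delayed (stale) scale lets values overflow."""
    history = []
    scale = 1.0  # initial scale, before any amax has been recorded
    overflow_steps = []
    for step, amax in enumerate(amax_per_step):
        # The cast at this step uses the scale computed from earlier steps
        if amax * scale > fp8_max:
            overflow_steps.append(step)
        # Record this step's amax, then compute the scale for the next step
        history = (history + [amax])[-history_len:]
        scale = fp8_max / max(history) / 2 ** margin
    return overflow_steps

# Steady amax of 1.0, then a one-step 3x spike at step 5
amaxes = [1.0] * 5 + [3.0] + [1.0] * 3
print(simulate_delayed_scaling(amaxes, margin=0))  # [5]
print(simulate_delayed_scaling(amaxes, margin=1))  # [5]: 2x headroom < 3x spike
print(simulate_delayed_scaling(amaxes, margin=2))  # []: 4x headroom covers it
```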

Complete FP8 Linear Layer

Here is a full FP8 linear layer implementation with delayed scaling, suitable for both training and inference:

class FP8Linear(torch.nn.Module):
    """Linear layer with simulated FP8 compute and delayed scaling.

    Forward:  Y  = (X_e4m3 @ W_e4m3^T) / (sx * sw)      -- FP8 GEMM, FP32 accum
    On FP8 hardware the backward GEMMs are FP8 as well:
              dX = (dY_e5m2 @ W_e4m3) / (sdy * sw)
              dW = (dY_e5m2^T @ X_e4m3) / (sdy * sx)
    This simulation quantizes only the forward pass. Gradients pass through
    the casts unchanged (straight-through estimator) and the backward GEMMs
    run in FP32, so grad_output_scaling is never updated here.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Master weights in FP32
        self.weight = torch.nn.Parameter(
            torch.randn(out_features, in_features) * (2 / in_features) ** 0.5
        )
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

        # FP8 scaling managers
        self.fp8_manager = FP8TensorManager()

    def cast_to_fp8_e4m3(self, tensor, scale):
        """Cast tensor to FP8 E4M3 (simulated, straight-through gradient)."""
        scaled = tensor.float() * scale
        # Forward value: rounded and clamped E4M3; backward: identity
        q = quantize_to_fp8_e4m3(tensor.detach(), scale)
        fp8 = scaled + (q - scaled).detach()
        # Record unscaled amax for next step's scale computation
        amax = tensor.detach().abs().max().item()
        return fp8, amax

    def cast_to_fp8_e5m2(self, tensor, scale):
        """Cast tensor to FP8 E5M2 (simulated range only). Unused in this
        forward-only simulation; shown for the backward-pass format."""
        scaled = tensor.float() * scale
        clamped = scaled.clamp(-57344.0, 57344.0)
        amax = tensor.detach().abs().max().item()
        return clamped, amax

    def forward(self, x):
        """Simulated FP8 forward pass."""
        # Get delayed scales (computed from previous steps' amax)
        s_input, s_weight = self.fp8_manager.get_forward_scales()

        # Cast input and weight to E4M3. The weight is re-cast on every call:
        # the optimizer updates it after every step, so a cache keyed only
        # on the scale would serve stale weights during training.
        x_fp8, x_amax = self.cast_to_fp8_e4m3(x, s_input)
        w_fp8, w_amax = self.cast_to_fp8_e4m3(self.weight, s_weight)

        # FP8 GEMM with FP32 accumulation
        # On real hardware: cublasLtMatmul with FP8 input types
        y_fp32 = torch.matmul(x_fp8, w_fp8.T)

        # Dequantize: divide by both scales
        y_fp32 = y_fp32 / (s_input * s_weight)

        # Update delayed scaling with current amax
        self.fp8_manager.update_forward(x_amax, w_amax)

        if self.bias is not None:
            y_fp32 = y_fp32 + self.bias

        return y_fp32

def fp8_training_step(model, optimizer, data, target, loss_fn):
    """Single training step with FP8 linear layers."""
    optimizer.zero_grad()

    # Forward pass: FP8 GEMMs with E4M3
    output = model(data)
    loss = loss_fn(output, target)

    # Backward pass: this simulation has no custom backward, so autograd
    # runs the backward GEMMs in FP32. On FP8 hardware (e.g. Transformer
    # Engine) they would consume E5M2 gradient tensors.
    loss.backward()

    # Optimizer step in FP32 (master weights)
    optimizer.step()

    return loss.item()

Transformer Engine Integration

NVIDIA Transformer Engine wraps the FP8 complexity behind a drop-in API. You replace torch.nn.Linear with te.Linear and the framework handles all scaling, casting, and history management automatically.

# Standard PyTorch
import torch.nn as nn
layer = nn.Linear(4096, 4096)

# Transformer Engine FP8 equivalent
# import transformer_engine.pytorch as te
# layer = te.Linear(4096, 4096)

# When run inside te.fp8_autocast, the te.Linear layer:
# Maintains delayed scaling state for input, weight, and gradient
# Casts input and weight to E4M3 on forward
# Runs the GEMM on FP8 tensor cores
# Accumulates in FP32
# Casts gradient to E5M2 on backward
# Updates amax history after each step

The integration into a training loop requires wrapping the forward pass in an FP8 context manager:

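def train_with_transformer_engine(model, optimizer, dataloader, num_steps):
    """Training loop with Transformer Engine FP8.

    Requires: an FP8-capable GPU (Ada, Hopper, or newer) and Transformer
    Engine installed. The TE calls are commented out so the loop also runs
    as plain PyTorch.
    """
    # import transformer_engine.pytorch as te
    # from transformer_engine.common import recipe
    # HYBRID = E4M3 in the forward pass, E5M2 for gradients
    # fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID,
    #                                    amax_history_len=16,
    #                                    amax_compute_algo="max")

    model.train()
    for step, (data, target) in enumerate(dataloader):
        if step >= num_steps:
            break

        optimizer.zero_grad()

        # The fp8_autocast context manages all FP8 state. Only the forward
        # pass goes inside it; backward runs outside.
        # with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        #     output = model(data)
        # loss = torch.nn.functional.cross_entropy(output, target)

        # Simulated version (without real TE):
        output = model(data)
        loss = torch.nn.functional.cross_entropy(output, target)

        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}, loss={loss.item():.4f}")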

Transformer Engine FP8 vs BF16 Training Performance

| Model | GPUs | BF16 (tokens/sec) | FP8 TE (tokens/sec) | Speedup |
|---|---|---|---|---|
| GPT-3 175B | 256x H100 | ~12,000 | ~20,400 | 1.70x |
| Llama 2 70B | 64x H100 | ~8,500 | ~14,000 | 1.65x |
| Llama 7B | 8x H100 | ~45,000 | ~76,000 | 1.69x |
| Mistral 7B | 8x H100 | ~48,000 | ~80,000 | 1.67x |

Note: FP8 training throughput gain is consistently 1.6-1.7x over BF16 on H100. The theoretical 2x is reduced by non-GEMM operations running in BF16 and by memory access patterns.

DeepSeek V3: FP8 Training at 671B Scale

DeepSeek V3 is the most significant real-world validation of FP8 training. They trained a 671B MoE model (37B active parameters per token) using FP8 on 2048 H800 GPUs for approximately 14.8 trillion tokens. Key technical decisions:

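FP8 for the linear-layer GEMMs, forward and backward. All three GEMMs of each linear layer run in FP8: Fprop (forward), Dgrad (activation gradient), and Wgrad (weight gradient). This covers the attention projections and the MoE expert layers. Components that are numerically sensitive or cheap stay in BF16 or FP32: the embedding module, the output head, the MoE gating modules, the normalization operators, and the attention operators themselves.

E4M3 everywhere, no E5M2. Instead of the hybrid recipe (E4M3 forward, E5M2 backward), DeepSeek used E4M3 for all FP8 tensors, including gradients, to keep the extra mantissa bit. They credit their fine-grained scaling, described next, with making E4M3's narrower range workable.

Fine-grained quantization. Instead of per-tensor scaling, DeepSeek scaled activations per 1x128 tile (one scale per token per 128 channels) and weights per 128x128 block. Each tile or block has its own FP8 scale factor, similar to the per-group approach used in INT4 weight quantization. This provided better quality than per-tensor scaling at the cost of slightly more metadata.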

Online scaling instead of delayed scaling. DeepSeek computed the actual amax of each block and used it immediately, avoiding the staleness risk of delayed scaling. Their custom kernels fused the amax computation with the quantization, avoiding the separate memory read that makes just-in-time scaling expensive.

def deepseek_fp8_gemm(x, weight, group_size=128):
    """Simulate DeepSeek V3's fine-grained FP8 GEMM approach.

    Key differences from standard Transformer Engine:
    1. Fine-grained scaling: one scale per 1 x group_size activation tile
       and per group_size x group_size weight block, instead of per-tensor
    2. Online scaling (compute amax and quantize in one pass)
    3. E4M3 for all FP8 tensors (no E5M2); this sketch covers only the
       forward GEMM
    Assumes hidden and out_features are multiples of group_size.
    """
    batch_tokens, hidden = x.shape
    out_features = weight.shape[0]
    n_groups = hidden // group_size

    # Activations: one scale per token per group of channels -> (T, G)
    x_groups = x.float().reshape(batch_tokens, n_groups, group_size)
    x_amax = x_groups.abs().amax(dim=2)
    x_scale = 448.0 / x_amax.clamp(min=1e-12)
    x_fp8 = quantize_to_fp8_e4m3(x_groups, x_scale.unsqueeze(-1))

    # Weights: one scale per group_size x group_size block -> (O / gs, G)
    w_blocks = weight.float().reshape(out_features // group_size, group_size,
                                      n_groups, group_size)
    w_amax = w_blocks.abs().amax(dim=(1, 3))
    w_scale = 448.0 / w_amax.clamp(min=1e-12)
    w_fp8 = quantize_to_fp8_e4m3(w_blocks, w_scale[:, None, :, None])
    w_fp8 = w_fp8.reshape(out_features, n_groups, group_size)
    w_scale_rows = w_scale.repeat_interleave(group_size, dim=0)  # (O, G)

    # FP8 GEMM: one partial product per K-group, accumulated in FP32
    partial = torch.einsum('tgk,ogk->tog', x_fp8, w_fp8)  # (T, O, G)

    # Dequantize each partial product with its own pair of scales, then sum
    # over groups (real kernels fold this into the GEMM main loop)
    dequant = 1.0 / (x_scale[:, None, :] * w_scale_rows[None, :, :])
    return (partial * dequant).sum(dim=2)
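The benefit of fine-grained scaling is easy to reproduce. The sketch below uses NumPy, and its helper names are hypothetical. It plants one outlier channel, exaggerated to 1e5x so that a per-tensor scale pushes ordinary values toward E4M3's subnormal range. It then compares the error on the ordinary channels under a single per-tensor scale against one online scale per 1x128 tile. It simulates only the quantization, not DeepSeek's kernels.

```python
import numpy as np

def round_to_e4m3(a):
    """Round to the nearest E4M3 value (half to even), saturating at +-448."""
    a = np.clip(a, -448.0, 448.0)
    exp = np.floor(np.log2(np.maximum(np.abs(a), 2.0 ** -6)))
    step = 2.0 ** (exp - 3)
    return np.round(a / step) * step

def e4m3_roundtrip(x, group_size=None):
    """Quantize x to E4M3 with online (just-computed) amax scales, then dequantize.

    group_size=None -> one scale for the whole tensor (per-tensor)
    group_size=g    -> one scale per 1 x g tile along the last dimension
    """
    g = x.size if group_size is None else group_size
    tiles = x.reshape(-1, g)
    amax = np.abs(tiles).max(axis=1, keepdims=True)
    scale = 448.0 / np.maximum(amax, 1e-12)
    return (round_to_e4m3(tiles * scale) / scale).reshape(x.shape)

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 1024))
x[:, 7] *= 1e5  # one exaggerated outlier channel

# Measure error only on the ordinary (non-outlier) channels
bulk = np.ones(x.shape[1], dtype=bool)
bulk[7] = False

def bulk_rel_mse(recon):
    return np.mean((x[:, bulk] - recon[:, bulk]) ** 2) / np.mean(x[:, bulk] ** 2)

per_tensor = bulk_rel_mse(e4m3_roundtrip(x))
per_group = bulk_rel_mse(e4m3_roundtrip(x, group_size=128))
print(f"per-tensor: {per_tensor:.2e}  per-1x128-tile: {per_group:.2e}")
```

With per-tile scales, only the tile that contains the outlier loses precision. The other tiles in each row keep a scale matched to their own values.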
DeepSeek V3 Training Cost

DeepSeek V3 trained for 2.788 million H800 GPU hours at a total cost of approximately 5.576 million USD (at 2 USD per GPU-hour). DeepSeek's report notes that this figure covers only the official training run, not prior research and ablation experiments. FP8 training reduced the cost by an estimated 40% compared to BF16 training for the same model and data scale. Without FP8, the training would have required either more GPUs or more time.

FP8 Inference: Simpler Than Training

FP8 inference is simpler than FP8 training because:

  1. No backward pass (no E5M2 needed)
  2. Weights are static (quantize once, use forever)
  3. Weight scales can be computed offline directly from the weights, and static activation scales can be calibrated offline on sample data

The inference recipe:

class FP8InferenceLinear:
    """FP8 linear layer optimized for inference.

    Weights are pre-quantized to E4M3 with offline calibration.
    Activations use a per-tensor scale: static (from calibration) or dynamic (computed per batch).
    """

    def __init__(self, weight_fp8, weight_scale, activation_scale=None):
        """
        weight_fp8: (out_features, in_features) pre-quantized E4M3
        weight_scale: scalar or per-channel scale for weight
        activation_scale: optional static scale (if None, uses dynamic)
        """
        self.weight_fp8 = weight_fp8
        self.weight_scale = weight_scale
        self.static_act_scale = activation_scale

    @classmethod
    def from_float(cls, linear, calibration_data=None):
        """Quantize a float linear layer for FP8 inference."""
        weight = linear.weight.data.float()

        # Quantize weight to E4M3
        w_amax = weight.abs().max()
        w_scale = 448.0 / w_amax.item()
        weight_fp8 = (weight * w_scale).clamp(-448, 448)

        # Optional: compute static activation scale from calibration
        act_scale = None
        if calibration_data is not None:
            max_act = 0.0
            for x in calibration_data:
                if x.dim() == 3:
                    x = x.reshape(-1, x.shape[-1])
                max_act = max(max_act, x.abs().max().item())
            act_scale = 448.0 / max(max_act, 1e-12)

        return cls(weight_fp8, w_scale, act_scale)

    def forward(self, x):
        """FP8 inference forward pass."""
        if x.dim() == 3:
            batch, seq, hidden = x.shape
            x = x.reshape(-1, hidden)
            reshape_back = True
        else:
            reshape_back = False

        # Quantize activation
        if self.static_act_scale is not None:
            act_scale = self.static_act_scale
        else:
            # Dynamic per-tensor scaling
            act_scale = 448.0 / x.abs().max().clamp(min=1e-12).item()

        x_fp8 = (x * act_scale).clamp(-448, 448)

        # FP8 GEMM
        y = x_fp8 @ self.weight_fp8.T

        # Dequantize
        y = y / (act_scale * self.weight_scale)

        if reshape_back:
            y = y.reshape(batch, seq, -1)

        return y
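The static-versus-dynamic choice in `FP8InferenceLinear` trades overhead against robustness. The sketch below uses NumPy, and its helper names are hypothetical. It calibrates a static scale on a few batches, then feeds a batch whose activations are 3x larger. The static scale saturates a sizable fraction of values at $\pm 448$. A dynamic scale computed from the batch itself never does.

```python
import numpy as np

def fp8_activation_scale(calibration_batches, fp8_max=448.0):
    """Static per-tensor scale: map the largest |value| seen in calibration to fp8_max."""
    max_act = max(float(np.abs(b).max()) for b in calibration_batches)
    return fp8_max / max(max_act, 1e-12)

def clipped_fraction(x, scale, fp8_max=448.0):
    """Fraction of elements that would saturate at +-fp8_max after scaling."""
    # Tiny tolerance so float rounding of amax * (fp8_max / amax) is not counted
    return float(np.mean(np.abs(x) * scale > fp8_max * (1 + 1e-9)))

rng = np.random.default_rng(0)
calibration = [rng.standard_normal((8, 512)) for _ in range(4)]
static_scale = fp8_activation_scale(calibration)

# A later batch whose activations are 3x larger than anything seen in calibration
x = 3.0 * rng.standard_normal((8, 512))
dynamic_scale = 448.0 / np.abs(x).max()

print(f"static:  {clipped_fraction(x, static_scale):.1%} saturated")
print(f"dynamic: {clipped_fraction(x, dynamic_scale):.1%} saturated")  # 0.0%
```

The dynamic scale pays for that robustness with an extra amax reduction per batch.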

Inference Throughput by Precision (Llama 70B, H100, Batch=32)

| Precision | Throughput (tokens/sec) | Speedup vs FP16 |
|---|---|---|
| FP16 | 1,800 | baseline |
| INT8 W8A8 | 3,100 | 1.72x |
| FP8 E4M3 | 3,400 | 1.89x |
| INT4 W4A16 | 3,800 | 2.11x |

FP8 vs INT8: Why FP8 Is Better for Activations

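FP8 has a fundamental advantage over INT8 for activation quantization: non-uniform quantization levels. INT8 has uniform spacing, with 256 values evenly distributed across the range. FP8 E4M3 has 253 distinct finite values with roughly logarithmic spacing. The spacing is uniform within each power-of-two interval and doubles from one interval to the next, so the values are denser near zero and sparser at large magnitudes.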

Neural network activations typically follow a distribution with most values near zero and a long tail. FP8's logarithmic spacing matches this distribution naturally, providing more resolution where the data density is highest.

def compare_fp8_vs_int8_coverage(data):
    """Compare how well FP8 and INT8 cover a realistic activation distribution."""
    import numpy as np

    data_np = data.numpy().flatten()
    amax = np.abs(data_np).max()

    # INT8: uniform levels, symmetric scale mapping amax to 127
    int8_scale = amax / 127.0
    int8_q = np.round(data_np / int8_scale).clip(-128, 127) * int8_scale
    int8_mse = np.mean((data_np - int8_q) ** 2)

    # FP8 E4M3: 253 distinct non-uniform levels
    # Generate all non-negative E4M3 values (exponent bias 7)
    e4m3_values = []
    for exp in range(16):
        for mant in range(8):
            if exp == 15 and mant == 7:
                continue  # NaN
            if exp == 0:
                val = (mant / 8.0) * (2 ** -6)  # subnormal
            else:
                val = (1.0 + mant / 8.0) * (2 ** (exp - 7))
            e4m3_values.append(val)
    e4m3_values = np.array(sorted(set(e4m3_values)))

    fp8_scale = 448.0 / amax
    scaled = data_np * fp8_scale

    # Quantize to the nearest E4M3 magnitude, keeping the sign (vectorized;
    # exact ties go to the smaller magnitude rather than round-half-even)
    idx = np.abs(np.abs(scaled)[:, None] - e4m3_values[None, :]).argmin(axis=1)
    fp8_q = np.sign(scaled) * e4m3_values[idx]

    fp8_recon = fp8_q / fp8_scale
    fp8_mse = np.mean((data_np - fp8_recon) ** 2)

    print(f"INT8 MSE: {int8_mse:.8f}")
    print(f"FP8  MSE: {fp8_mse:.8f}")
    print(f"FP8 advantage: {int8_mse / fp8_mse:.2f}x lower error")

    return int8_mse, fp8_mse

# Test with realistic activation distribution
activation = torch.randn(10000) * 0.5
activation[torch.randperm(10000)[:100]] *= 20  # Add outliers
compare_fp8_vs_int8_coverage(activation)

FP8 vs INT8 Activation Quantization Quality

| Scenario | INT8 MSE | FP8 E4M3 MSE | FP8 Advantage (INT8 MSE / FP8 MSE) |
|---|---|---|---|
| Gaussian (no outliers) | 1.2e-5 | 1.8e-5 | 0.67x (INT8 wins) |
| Gaussian + 1% outliers | 8.4e-4 | 3.1e-4 | 2.7x |
| Gaussian + 5% outliers | 4.2e-3 | 9.8e-4 | 4.3x |
| Real LLM activations | 2.1e-3 | 6.2e-4 | 3.4x |

Note: FP8 outperforms INT8 when outliers are present (the realistic case for LLMs). For pure Gaussian data without outliers, INT8's uniform spacing is actually slightly better.

FP8 Training Stability Considerations

FP8 training can be unstable if not managed carefully. Key failure modes and mitigations:

Loss spikes: Caused by sudden distribution shifts (e.g., encountering a batch with unusual data). Delayed scaling applies a stale scale factor that is too large for the new values, causing overflow. Mitigation: increase the amax history length so the scale stays conservative after a spike, or add a safety margin.

Gradient underflow: Small gradients in E5M2 can underflow to zero, so the affected parameters receive no update from that step. Mitigation: monitor the fraction of nonzero gradients that flush to zero, and increase the gradient scale if it exceeds a threshold.

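Accumulation error: FP8 GEMMs nominally accumulate in FP32, but the inputs are already quantized. For very large reduction dimensions (e.g., a 16K hidden size), many small quantization errors add up. Block-wise FP8 with smaller blocks reduces the per-element quantization error. The accumulator itself can also be a limit. DeepSeek V3 reports that FP8 GEMMs on H800 tensor cores retain only around 14 bits of accumulation precision. They mitigate this by promoting partial sums to FP32 on CUDA cores every 128 elements.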
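The accumulator-precision effect can be reproduced in a few lines of pure Python. The sketch models a narrow accumulator that rounds every partial sum to 14 significant bits. This is a simplification loosely based on the roughly 14-bit figure DeepSeek V3 reports for H800 FP8 GEMMs; real tensor-core accumulation differs in detail. The sketch compares that accumulator with promoting each 128-element partial sum into a full-precision total, the interval DeepSeek reports using. The function names are hypothetical.

```python
import math
import random

def round_sig(x, bits):
    """Round x to `bits` significant binary digits (round half to even)."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    return math.ldexp(round(m * 2 ** bits), e - bits)

def dot_narrow(a, b, bits=14):
    """Dot product whose running sum is rounded to `bits` bits after every add."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = round_sig(acc + x * y, bits)
    return acc

def dot_promoted(a, b, bits=14, interval=128):
    """Narrow accumulation within each chunk, full-precision sum across chunks."""
    total = 0.0  # Python float (FP64), standing in for the FP32 accumulator
    for start in range(0, len(a), interval):
        total += dot_narrow(a[start:start + interval], b[start:start + interval], bits)
    return total

rng = random.Random(0)
a = [rng.random() for _ in range(65536)]  # all-positive inputs: worst case for drift
b = [1.0] * len(a)
exact = math.fsum(x * y for x, y in zip(a, b))

narrow = dot_narrow(a, b)
promoted = dot_promoted(a, b)
print(f"narrow relative error:   {abs(narrow - exact) / exact:.2e}")
print(f"promoted relative error: {abs(promoted - exact) / exact:.2e}")
```

With all-positive inputs, the narrow sum stalls near $2^{14}$: past that point, half its spacing exceeds every input. The chunked version stays accurate.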

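class FP8TrainingMonitor:
    """Monitor FP8 training health metrics."""

    def __init__(self):
        self.overflow_count = 0
        self.underflow_count = 0
        self.total_count = 0

    def check_tensor(self, tensor, name, scale=1.0, fp8_max=448.0,
                     min_subnormal=2.0 ** -9, underflow_warn=0.5):
        """Check how a tensor would fare when cast to FP8 with `scale`.

        Use fp8_max=448, min_subnormal=2**-9 for E4M3 and
        fp8_max=57344, min_subnormal=2**-16 for E5M2.
        """
        self.total_count += 1
        t = tensor.detach().float()
        scaled_abs = t.abs() * scale
        amax = scaled_abs.max().item()  # amax after scaling

        if amax > fp8_max:
            self.overflow_count += 1
            print(f"WARNING: {name} scaled amax={amax:.2f} exceeds FP8 max={fp8_max}")

        # Nonzero values below half the smallest subnormal round to zero
        nonzero = t != 0
        flushed = nonzero & (scaled_abs < min_subnormal / 2)
        underflow_frac = (flushed.sum() / nonzero.sum().clamp(min=1)).item()
        if underflow_frac > underflow_warn:
            self.underflow_count += 1
            print(f"WARNING: {name} has {underflow_frac:.1%} of nonzero values "
                  f"flushing to zero (underflow)")

        return {
            'name': name,
            'amax': amax,
            'underflow_fraction': underflow_frac,
            'overflow_risk': amax > 0.9 * fp8_max,
        }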

    def report(self):
        """Print summary of FP8 health metrics."""
        print(f"FP8 Health: {self.overflow_count} overflows, "
              f"{self.underflow_count} underflows out of "
              f"{self.total_count} tensors checked")

Summary

FP8 delivers up to 2x peak GEMM throughput over BF16/FP16 for both training and inference (about 1.6-1.7x end to end in the training benchmarks above), with minimal quality degradation when properly managed.

E4M3 provides precision-optimized 8-bit representation for forward pass tensors (weights and activations). Its maximum value of 448 requires scaling, but the 3 mantissa bits preserve adequate precision.

E5M2 provides range-optimized 8-bit representation for backward pass gradients. Its maximum of 57344 and 5 exponent bits cover roughly $2^{32}$ of dynamic range (about $2^{18}$ for E4M3), enough headroom for gradients under a per-tensor scale.

Delayed scaling is the algorithm that makes FP8 practical: it avoids a separate amax pass by using amax values from previous steps, recorded as a side effect of kernels that already read the tensor, such as the FP8 cast.

NVIDIA Transformer Engine wraps all of this complexity behind a simple API, managing scale factors, amax history, and format selection automatically.

DeepSeek V3 proved FP8 training works at 671B scale. It ran its linear-layer GEMMs in FP8 with fine-grained scaling while keeping sensitive components in higher precision, reducing training cost by an estimated 40%.

The next post covers the frontier beyond FP8: 4-bit floating-point formats on NVIDIA Blackwell that promise another 2x throughput improvement.