Part of Series: Inference Optimization Timeline, 22 of 23
1. LLM Inference Fundamentals: Prefill, Decode, and the Memory-Compute Divide
2. KV Cache: The Hidden Memory Giant in LLM Serving
3. Quantization for LLM Inference: From FP16 to INT4 — A Deep Dive into Precision, Performance, and Production Deployment
4. FlashAttention: Why Tiling Attention Through the Memory Hierarchy Changes Everything
5. PagedAttention: How vLLM Borrowed OS Virtual Memory to Fix LLM Serving
6. Continuous Batching: The Complete Guide to LLM Inference Scheduling
7. Speculative Decoding: Why Autoregressive LLMs Leave 99% of Your GPU Idle and How to Fix It
8. Prefix Caching: RadixAttention, Cache Hierarchies, and Reusing Computation Across Requests
9. LoRA and QLoRA for Serving: Multi-Adapter Inference, S-LoRA, and When to Merge
10. Disaggregated Prefill-Decode: Why Splitting LLM Inference Changes Everything
11. Constrained Generation: FSM-Based Decoding, Outlines, and Grammar-Guided LLM Output
12. Mamba and State Space Models: The O(n) Alternative to Attention
13. Inference-Time Compute Scaling: When More Thinking Helps (o1, DeepSeek-R1, and the Reasoning Frontier)
14. CPU and Edge Inference: llama.cpp Internals, GGUF Format, and When CPU Actually Wins
15. Inference Cost Economics: Tokens per Dollar, GPU-Hours, and the Real Math of LLM Serving
16. Batched GEMM: Why Matrix Multiply Throughput Determines Everything in LLM Inference
17. Token Generation Pipeline: Logit Processing, Sampling Strategies, and Stop Criteria
18. Memory Pool Management: Slab Allocators for GPU Inference
19. Vision-Language Model Serving: ViT Encoding, Cross-Attention, and KV Cache Paging for Multimodal
20. Long-Context Serving: Ring Attention, KV Offloading, and Chunked Processing in Production
21. Inference Profiling: Nsight Systems, torch.profiler, and Finding Where Time Actually Goes
22. FP8 Inference: E4M3 Format, Per-Tensor Scaling, and the Hardware Support Matrix
23. Speculative Decoding v2: Medusa, EAGLE, Lookahead, and Token Tree Verification

FP8 is the most significant precision reduction since the introduction of FP16 tensor cores in Volta. On NVIDIA H100, FP8 tensor cores deliver 1,979 TFLOPS — exactly 2x the FP16 rate of 989 TFLOPS — while consuming half the memory bandwidth per element. For large, compute-bound GEMMs (the prefill phase of LLM inference), this translates directly to a near-2x throughput improvement. For memory-bound GEMMs (decode phase), the bandwidth reduction from 2 bytes to 1 byte per element provides a 1.5-1.8x speedup depending on the shape.

The catch: FP8 has severely limited range and precision. The E4M3 format covers $[-448, 448]$, and with only 3 mantissa bits the gap between adjacent values is 12.5% of their magnitude, so rounding error can reach roughly 6% relative. Getting FP8 inference to work without quality loss requires understanding where the precision matters, how to compute scaling factors, which operations to quantize (GEMMs only), and which to leave in higher precision (everything else). This post covers all of it.


1. The FP8 E4M3 Format

IEEE 754 does not define an 8-bit floating point format. NVIDIA, ARM, and Intel jointly specified two FP8 variants in the OFP8 (Open FP8) standard:

E4M3: 4 exponent bits, 3 mantissa bits

Bit layout: [S][EEEE][MMM]
             1   4     3    = 8 bits

S = sign bit (0 = positive, 1 = negative)
E = exponent bits (4 bits, bias = 7)
M = mantissa bits (3 bits, implicit leading 1 for normals)

Value encoding:

For normal numbers ($E \neq 0$ and $E \neq 15$):

$$\text{value} = (-1)^S \times 2^{E - 7} \times (1 + M / 8)$$

For subnormal numbers ($E = 0$, $M \neq 0$):

$$\text{value} = (-1)^S \times 2^{-6} \times (M / 8)$$

Special values:

  • $E = 15$, $M = 7$: NaN (a single NaN mantissa pattern, or two encodings counting the sign bit, versus thousands in IEEE 754)
  • $E = 15$, $M \neq 7$: valid number (NOT infinity; E4M3 sacrifices infinity for range)
  • $E = 0$, $M = 0$: zero

This is the key departure from IEEE 754: E4M3 has no infinity representation. The bit pattern that would be infinity in IEEE format ($E = \text{all ones}$, $M = 0$) instead represents the value $(-1)^S \times 2^8 \times 1.0 = \pm 256$. The maximum representable value is:

$$\text{max} = 2^8 \times (1 + 6/8) = 256 \times 1.75 = 448$$

(The largest usable mantissa at $E = 15$ is $M = 6$, since $M = 7$ encodes NaN.)
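These encoding rules are easy to check directly. Here is a minimal pure-Python decoder (an illustrative sketch, not tied to any library) that maps an 8-bit pattern [S|EEEE|MMM] to its E4M3 value using the formulas above:

```python
import math

def decode_e4m3(byte):
    """Decode one E4M3 bit pattern [S|EEEE|MMM] to a Python float."""
    s = (byte >> 7) & 0x1
    e = (byte >> 3) & 0xF
    m = byte & 0x7
    if e == 15 and m == 7:
        return math.nan                      # the single NaN mantissa pattern
    if e == 0:
        val = 2.0 ** -6 * (m / 8)            # subnormal (zero when m == 0)
    else:
        val = 2.0 ** (e - 7) * (1 + m / 8)   # normal, implicit leading 1
    return -val if s else val

print(decode_e4m3(0b0_0111_000))  # 1.0   (E equals the bias 7, M=0)
print(decode_e4m3(0b0_1111_110))  # 448.0 (E=15, M=6: the maximum)
print(decode_e4m3(0b1_0000_001))  # -0.001953125 (smallest negative subnormal)
```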

📊

FP8 E4M3 Format Properties

| Property | E4M3 | FP16 |
|---|---|---|
| Total bits | 8 | 16 |
| Sign bits | 1 | 1 |
| Exponent bits | 4 | 5 |
| Mantissa bits | 3 | 10 |
| Exponent bias | 7 | 15 |
| Max normal value | 448 | 65504 |
| Min normal value | 2^-6 = 0.015625 | 2^-14 = 6.1e-5 |
| Min subnormal | 2^-9 = 0.001953 | 2^-24 = 5.96e-8 |
| Precision (mantissa) | 3 bits = 12.5% | 10 bits = 0.098% |
| Has infinity? | No | Yes |
| NaN encodings | 2 | 2046 |
| Unique finite values | 253 | ~63,500 |
Note: E4M3 sacrifices infinity and most NaN encodings to maximize the representable range. The 448 max value is sufficient for post-scaling weights and activations in transformer models.

E5M2: 5 exponent bits, 2 mantissa bits

Bit layout: [S][EEEEE][MM]
             1    5     2    = 8 bits

E5M2 follows IEEE 754 conventions: it has infinity and NaN. The tradeoff is even less precision (2 mantissa bits = 25% relative precision) but wider range ($\pm 57344$).
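Plugging the E5M2 parameters into the same value formula confirms the quoted range (a quick sanity check, not library code):

```python
# E5M2: bias 15; E=31 is reserved for inf/NaN (IEEE-style), so the
# largest finite value uses E=30, M=3
e5m2_max = 2.0 ** (30 - 15) * (1 + 3 / 4)
print(e5m2_max)            # 57344.0

# Smallest normal: E=1, M=0
e5m2_min_normal = 2.0 ** (1 - 15)
print(e5m2_min_normal)     # 6.103515625e-05
```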

When to use which:

  • E4M3 for forward pass (inference): The extra mantissa bit matters more than the range, because we can control the range via scaling factors.
  • E5M2 for backward pass (training gradients): Gradients have wider dynamic range and benefit from the larger exponent. Less relevant for inference.
ℹ️ E4M3 vs E5M2 in Practice

For LLM inference, E4M3 is the only format that matters. All FP8 inference implementations (TensorRT-LLM, vLLM, SGLang) use E4M3 for both weights and activations. E5M2 is used only during training for gradient accumulation. The rest of this post focuses exclusively on E4M3.

Representable Values

With 3 mantissa bits, the spacing between consecutive representable values within the same exponent range is:

$$\text{ULP}(e) = 2^{e - 7} / 8 = 2^{e - 10}$$

where $e$ is the biased exponent. This means:

  • Between 1.0 and 2.0: 8 evenly spaced values (1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875)
  • Between 2.0 and 4.0: 8 values spaced by 0.25
  • Between 128 and 256: 8 values spaced by 16
  • Between 256 and 448: 7 values spaced by 32 (the eighth slot, $M = 7$, is NaN)
import torch

# Enumerate all positive E4M3 values
def enumerate_e4m3_values():
    values = []

    # Subnormals: E=0, M=1..7
    for m in range(1, 8):
        val = 2**(-6) * (m / 8)
        values.append(val)

    # Normals: E=1..14, M=0..7
    for e in range(1, 15):
        for m in range(8):
            val = 2**(e - 7) * (1 + m / 8)
            values.append(val)

    # E=15, M=0..6 (M=7 is NaN)
    for m in range(7):
        val = 2**(15 - 7) * (1 + m / 8)
        values.append(val)

    return sorted(values)

e4m3_values = enumerate_e4m3_values()
print(f"Number of positive values: {len(e4m3_values)}")
# 126 positive values; with 126 negatives, +/-0, and +/-NaN this
# accounts for all 256 bit patterns (253 distinct finite values)

print(f"Smallest subnormal: {e4m3_values[0]}")   # 0.001953125
print(f"Smallest normal:    {e4m3_values[7]}")    # 0.015625
print(f"Largest value:      {e4m3_values[-1]}")   # 448.0

The entire format has only 253 distinct finite values (126 positive, 126 negative, and zero). Compare with FP16's roughly 63,500. On average, quantizing from FP16 to E4M3 maps about 250 FP16 values onto a single E4M3 value.
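To see what this coarseness means in practice, here is a small self-contained sketch (reusing the enumeration formulas above) that rounds a value to its nearest positive representable E4M3 value:

```python
def e4m3_positive_values():
    """All 126 positive E4M3 values, from the encoding formulas above."""
    vals = [2.0 ** -6 * (m / 8) for m in range(1, 8)]        # subnormals
    vals += [2.0 ** (e - 7) * (1 + m / 8)
             for e in range(1, 15) for m in range(8)]        # normals
    vals += [2.0 ** 8 * (1 + m / 8) for m in range(7)]       # E=15, M < 7
    return sorted(vals)

def round_to_e4m3(x):
    """Nearest positive representable value (ties resolve to the smaller)."""
    return min(e4m3_positive_values(), key=lambda v: abs(v - x))

print(round_to_e4m3(100.0))  # 96.0 (neighbors are 96 and 104: up to 4% error)
print(round_to_e4m3(1.06))   # 1.0  (neighbors are 1.0 and 1.125)
print(round_to_e4m3(500.0))  # 448.0 (everything above 448 saturates)
```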

2. Per-Tensor Dynamic Scaling

The fundamental problem with FP8's limited range: transformer activations and weights span very different numerical ranges. Layer norm outputs might be in $[-3, 3]$, while FFN intermediate activations after SiLU might be in $[-50, 50]$. If we naively cast to E4M3, small values lose all precision and large values might overflow.

The solution: per-tensor scaling. Before casting to FP8, multiply by a scale factor that maps the tensor’s range to E4M3’s representable range.

Scale Factor Computation

$$\text{scale} = \frac{448.0}{\max(|\text{tensor}|)}$$

The quantized tensor is:

$$\text{tensor\_fp8} = \text{cast\_to\_e4m3}(\text{tensor} \times \text{scale})$$

And dequantization recovers the original:

$$\text{tensor\_recovered} = \text{tensor\_fp8} / \text{scale}$$

def compute_scale(tensor, fp8_max=448.0):
    """
    Compute per-tensor scale for FP8 quantization.

    The scale maps the tensor's range to [-448, 448].
    """
    amax = tensor.abs().max().float()

    # Avoid division by zero
    if amax == 0:
        return torch.tensor(1.0, device=tensor.device)

    scale = fp8_max / amax

    # Guard against a degenerate (zero or underflowed) scale;
    # the scale itself is stored in FP32, so precision is not a concern
    return scale.clamp(min=1e-12)


def quantize_to_fp8(tensor, scale):
    """
    Quantize a tensor to FP8 E4M3.

    Args:
        tensor: FP16/BF16/FP32 tensor
        scale: per-tensor scale factor (FP32 scalar)
    Returns:
        fp8_tensor: quantized tensor in torch.float8_e4m3fn
    """
    # Scale and clamp to E4M3 range
    scaled = tensor.float() * scale
    scaled = scaled.clamp(-448.0, 448.0)

    # Cast to FP8
    fp8_tensor = scaled.to(torch.float8_e4m3fn)

    return fp8_tensor


def dequantize_from_fp8(fp8_tensor, scale):
    """Dequantize FP8 tensor back to FP16."""
    return fp8_tensor.float() / scale

Per-Tensor vs Per-Channel vs Per-Token Scaling

The granularity of the scale factor affects both accuracy and performance:

Per-tensor:  one scale for the entire [M, K] or [K, N] matrix
Per-channel: one scale per column (weights) or per row (activations)
Per-token:   one scale per row in the activation tensor
Per-group:   one scale per group of G elements (e.g., G=128)
📊

FP8 Scaling Granularity Comparison

| Granularity | Scale count, weights [K, N] | Scale count, acts [M, K] | Accuracy | HW support (H100) |
|---|---|---|---|---|
| Per-tensor | 1 | 1 | Baseline | Native (TMA) |
| Per-channel (weights) | N | — | +0.1-0.3% acc | Requires custom kernel |
| Per-token (activations) | — | M | +0.2-0.5% acc | Supported via row scaling |
| Per-group (G=128) | K*N/G | M*K/G | +0.5-1.0% acc | Not natively supported |
Note: H100 tensor cores natively support per-tensor scaling for both operands. Per-token scaling for activations is supported via a row-wise scale vector. Finer granularities require decomposing the GEMM or using custom epilogues.
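The granularity differences are easiest to see in NumPy. A toy example (illustrative shapes and values, not a production kernel) showing how per-token scaling rescues a low-magnitude row that a per-tensor scale would squeeze into a sliver of the FP8 range:

```python
import numpy as np

FP8_MAX = 448.0

x = np.array([[0.01, -0.02],    # token 0: tiny activations
              [30.0, -50.0]])   # token 1: large activations

# Per-tensor: one scale for the whole activation matrix, set by the -50.0
scale_tensor = FP8_MAX / np.abs(x).max()                      # 8.96
# Token 0 is squeezed into [-0.18, 0.18], a tiny slice of [-448, 448]

# Per-token: one scale per row; each row independently fills the range
scale_token = FP8_MAX / np.abs(x).max(axis=1, keepdims=True)  # [[22400.], [8.96]]

print(x * scale_tensor)
print(x * scale_token)
```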

In practice, the standard approach for FP8 inference is:

  • Weights: Per-tensor or per-channel scaling, computed offline during calibration
  • Activations: Per-tensor or per-token scaling, computed dynamically at runtime

Per-tensor scaling for both operands is the simplest and best-supported path on H100:

# FP8 GEMM with per-tensor scaling:
# C = (A_fp8 / scale_A) @ (B_fp8 / scale_B)
#   = (A_fp8 @ B_fp8) / (scale_A * scale_B)
#
# The tensor core computes A_fp8 @ B_fp8 in FP8,
# accumulates in FP32, and the epilogue divides by
# (scale_A * scale_B). Single kernel, no overhead.
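The identity in the comment above can be checked numerically. This NumPy sketch applies the scales but skips the actual 8-bit rounding, isolating just the scale algebra:

```python
import numpy as np

FP8_MAX = 448.0
rng = np.random.default_rng(0)

A = rng.standard_normal((4, 8)).astype(np.float32)
B = rng.standard_normal((8, 3)).astype(np.float32)

scale_A = FP8_MAX / np.abs(A).max()
scale_B = FP8_MAX / np.abs(B).max()

# "Quantize" (scale only, no rounding), multiply, then descale in one step
C = ((A * scale_A) @ (B * scale_B)) / (scale_A * scale_B)

print(np.allclose(C, A @ B, rtol=1e-5))   # True: descaling recovers A @ B
```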
Scale Factor Overhead

Per-tensor scaling adds exactly one FP32 division (in practice, a multiply by the precomputed reciprocal) in the GEMM epilogue per output element. Since the GEMM itself performs $2K$ FLOPs per output element (where $K$ is thousands to tens of thousands), the scaling overhead is negligible: well under 0.01% of total compute.

3. Which Operations Use FP8

Not every operation in a transformer can use FP8. The rule is:

FP8 for GEMMs. Higher precision for everything else.

Operations That Use FP8

| Operation | Input precision | Weight precision | Accumulation | Output |
|---|---|---|---|---|
| QKV projection | FP8 (E4M3) | FP8 (E4M3) | FP32 | BF16 |
| Output projection | FP8 | FP8 | FP32 | BF16 |
| Gate projection | FP8 | FP8 | FP32 | BF16 |
| Up projection | FP8 | FP8 | FP32 | BF16 |
| Down projection | FP8 | FP8 | FP32 | BF16 |

Operations That Stay in BF16/FP16

| Operation | Why not FP8 |
|---|---|
| Layer norm / RMS norm | Requires high-precision running statistics. 3 mantissa bits produce incorrect variance. |
| Softmax | Exponential and division are numerically sensitive. The log-sum-exp trick requires precision. |
| SiLU / GELU activation | Non-linear; small input differences produce large output differences. |
| Residual addition | Accumulates across layers. FP8 rounding errors compound. |
| Rotary embeddings | Sine/cosine computation requires precision. |
| Embedding lookup | Table lookup, not compute. No benefit from FP8. |
| Attention scores (QK) | Typically done in FP16/BF16 within FlashAttention. Could use FP8 but quality degrades. |
| Attention values (PV) | Same as above. |
class FP8TransformerLayer:
    """
    Transformer layer with FP8 GEMMs and BF16 everything else.
    (Pseudocode sketch: rms_norm, fp8_gemm, split_qkv, and
    flash_attention stand in for real kernels; compute_scale and
    quantize_to_fp8 are defined above.)
    """

    def __init__(self, config):
        # FP8 weights (quantized offline)
        self.qkv_weight_fp8 = None      # [d, n_h*d_h + 2*n_kv*d_h] in E4M3
        self.qkv_scale = None            # FP32 scalar
        self.o_weight_fp8 = None         # [n_h*d_h, d] in E4M3
        self.o_scale = None
        self.gate_up_weight_fp8 = None   # [d, 2*d_ff] in E4M3
        self.gate_up_scale = None
        self.down_weight_fp8 = None      # [d_ff, d] in E4M3
        self.down_scale = None

        # BF16 parameters (NOT quantized)
        self.rms_norm_weight = None      # [d] in BF16
        self.rms_norm2_weight = None     # [d] in BF16

    def forward(self, x):
        """
        x: [B, d] in BF16

        All GEMMs use FP8. All other ops use BF16.
        """
        # ---- RMS Norm (BF16) ----
        normed = rms_norm(x, self.rms_norm_weight)  # BF16

        # ---- QKV Projection (FP8 GEMM) ----
        act_scale = compute_scale(normed)
        normed_fp8 = quantize_to_fp8(normed, act_scale)
        qkv = fp8_gemm(
            normed_fp8, act_scale,
            self.qkv_weight_fp8, self.qkv_scale
        )  # Output in BF16

        # ---- Attention (BF16 — FlashAttention) ----
        q, k, v = split_qkv(qkv)
        attn_out = flash_attention(q, k, v)  # BF16

        # ---- Output Projection (FP8 GEMM) ----
        act_scale = compute_scale(attn_out)
        attn_fp8 = quantize_to_fp8(attn_out, act_scale)
        o = fp8_gemm(
            attn_fp8, act_scale,
            self.o_weight_fp8, self.o_scale
        )  # BF16

        # ---- Residual Add (BF16) ----
        x = x + o

        # ---- RMS Norm 2 (BF16) ----
        normed2 = rms_norm(x, self.rms_norm2_weight)

        # ---- Gate + Up Projection (FP8 GEMM) ----
        act_scale = compute_scale(normed2)
        normed2_fp8 = quantize_to_fp8(normed2, act_scale)
        gate_up = fp8_gemm(
            normed2_fp8, act_scale,
            self.gate_up_weight_fp8, self.gate_up_scale
        )  # BF16

        # ---- SiLU Activation (BF16) ----
        gate, up = gate_up.chunk(2, dim=-1)
        intermediate = torch.nn.functional.silu(gate) * up  # BF16

        # ---- Down Projection (FP8 GEMM) ----
        act_scale = compute_scale(intermediate)
        inter_fp8 = quantize_to_fp8(intermediate, act_scale)
        down = fp8_gemm(
            inter_fp8, act_scale,
            self.down_weight_fp8, self.down_scale
        )  # BF16

        # ---- Residual Add (BF16) ----
        x = x + down

        return x
⚠️ Attention in FP8 Is Risky

Some implementations quantize the QK dot product and PV multiply to FP8. This can work for large models (70B+) where individual head dimensions are large, but degrades quality for smaller models. The attention mechanism is particularly sensitive because softmax amplifies small numerical errors — a 1% error in attention scores can become a 5-10% error in attention weights after exponentiation. The safe default is to keep attention in BF16 and only quantize the linear projections.
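The amplification claim is easy to demonstrate: push a 1% relative error through a softmax with a typical logit spread (a toy two-logit example, not model data):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([8.0, 0.0])     # a confidently-attended position
noisy = logits * 1.01             # 1% relative error on the scores

w_exact, w_noisy = softmax(logits), softmax(noisy)
rel_err = np.abs(w_noisy - w_exact) / w_exact
print(rel_err)   # the small attention weight shifts by ~8%: an 8x amplification
```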

4. H100 FP8 Tensor Cores: 1,979 TFLOPS

The H100 SXM5 delivers these peak throughput numbers:

📊

H100 SXM5 Tensor Core Throughput by Precision

| Precision | Peak TFLOPS | Bytes/element | Bandwidth-equivalent TFLOPS | Ridge point (FLOP/byte) |
|---|---|---|---|---|
| FP64 | 67 | 8 | — | 20 |
| TF32 | 495 | 4 | — | 148 |
| FP16 | 989 | 2 | 989 | 295 |
| BF16 | 989 | 2 | 989 | 295 |
| FP8 (E4M3) | 1,979 | 1 | 1,979 | 591 |
| INT8 | 1,979 | 1 | 1,979 | 591 |
Note: Peak TFLOPS with sparsity disabled. HBM3 bandwidth: 3.35 TB/s. Ridge point = peak TFLOPS / BW. FP8 doubles compute at the same memory bandwidth, so the ridge point doubles — harder to saturate.

Why Exactly 2x

The H100 tensor core pipeline processes two FP8 elements in the same cycle that it processes one FP16 element. The FP8 MMA instruction shape is m16n8k32 (K=32 for FP8 vs K=16 for FP16), meaning each instruction processes twice as many multiply-accumulate operations. The accumulator is FP32 in both cases.

FP16 MMA: m16n8k16 -> 16*8*16*2 = 4096 FLOPs per instruction
FP8 MMA:  m16n8k32 -> 16*8*32*2 = 8192 FLOPs per instruction

Same number of instructions per cycle -> 2x FLOPs

Practical FP8 Throughput

The 2x is a peak number. Actual throughput depends on arithmetic intensity:

FP8 vs FP16 Actual Throughput: FFN GEMM [M, 57344, 8192]

| Batch size M | FP16 (TFLOPS) | FP8 (TFLOPS) | FP8 speedup |
|---|---|---|---|
| 1 | 3.2 | 3.3 | 1.03x |
| 128 | 385 | 410 | 1.06x |
| 512 | 790 | 1,120 | 1.42x |
| 2,048 | 920 | 1,720 | 1.87x |
| Peak (reference) | 989 | 1,979 | 2.00x |

At $B = 1$ (decode), both FP8 and FP16 are memory-bandwidth-bound. FP8 loads 1 byte per weight instead of 2, which should roughly halve the dominant weight traffic; in this measurement, however, the gain is largely eaten by the overhead of dynamic activation quantization (an extra amax reduction and cast kernel before each GEMM) and fixed kernel launch costs, which loom large at tiny $M$. Net: approximately 1.03x.

At $B = 512$, FP8 achieves 1,120 TFLOPS (57% of peak) vs FP16 at 790 TFLOPS (80% of peak). The FP8 speedup is 1.42x, less than 2x because the FP8 GEMM has not yet reached its ridge point.

At $B = 2048$, FP8 achieves 1,720 TFLOPS (87% of peak) vs FP16 at 920 TFLOPS (93% of peak). The speedup is 1.87x, approaching but not reaching 2x because of memory traffic for activations.

FP8 Speedup Depends on Batch Size

The FP8 speedup over FP16 ranges from about 1.0x (decode, $B = 1$) to 1.95x (large prefill, $B = 4096$ and up). For a serving system running a mix of prefill and decode, the aggregate speedup is typically 1.3-1.6x at realistic batch sizes. Claims of "2x from FP8" assume compute-bound regimes that many workloads do not reach.

5. Quantization Workflow

The full FP8 quantization workflow has three phases: calibration, offline weight quantization, and online activation quantization.

Phase 1: Calibration

Run a representative dataset through the model in FP16/BF16. For each tensor that will be quantized, record the maximum absolute value (amax):

class CalibrationObserver:
    """
    Collect activation statistics for FP8 scale computation.
    Attach to each linear layer as a forward hook.
    """

    def __init__(self):
        self.amax_history = []
        self.max_samples = 512  # Cap on recorded calibration samples

    def observe(self, tensor):
        """Record the amax of a tensor (up to max_samples entries)."""
        if len(self.amax_history) >= self.max_samples:
            return
        amax = tensor.abs().max().item()
        self.amax_history.append(amax)

    def compute_scale(self, fp8_max=448.0, percentile=99.99):
        """
        Compute scale from observed amax values.

        Using the percentile instead of absolute max provides
        robustness against outliers. A single outlier value
        would set the scale too conservatively, wasting precision
        for the rest of the distribution.
        """
        import numpy as np
        amax = np.percentile(self.amax_history, percentile)
        return fp8_max / max(amax, 1e-12)


def calibrate_model(model, calibration_loader, num_batches=32):
    """
    Run calibration to determine per-tensor scales.

    Args:
        model: FP16/BF16 model
        calibration_loader: representative data
        num_batches: number of calibration batches
    Returns:
        scales: dict mapping layer_name -> {input_scale, weight_scale}
    """
    observers = {}
    hook_handles = []

    # Register observers for each linear layer
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            observer_input = CalibrationObserver()
            observer_weight = CalibrationObserver()

            # Record weight amax (static, does not change across batches)
            observer_weight.observe(module.weight.data)

            # Hook for input activations
            def make_hook(obs):
                def hook(mod, inp, out):
                    obs.observe(inp[0])
                return hook

            hook_handles.append(
                module.register_forward_hook(make_hook(observer_input))
            )
            observers[name] = {
                'input': observer_input,
                'weight': observer_weight,
            }

    # Run calibration
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(calibration_loader):
            if i >= num_batches:
                break
            model(batch['input_ids'].cuda())

    # Remove hooks so they do not fire during later inference
    for handle in hook_handles:
        handle.remove()

    # Compute scales
    scales = {}
    for name, obs in observers.items():
        scales[name] = {
            'input_scale': obs['input'].compute_scale(),
            'weight_scale': obs['weight'].compute_scale(),
        }

    return scales
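The percentile choice in `compute_scale` matters when the calibration run contains rare outliers. A quick NumPy illustration with a synthetic amax history (made-up numbers):

```python
import numpy as np

FP8_MAX = 448.0
history = [10.0] * 999 + [400.0]       # 1000 calibration steps, one outlier

scale_from_max = FP8_MAX / max(history)                  # 1.12: outlier-driven
scale_from_pct = FP8_MAX / np.percentile(history, 99.9)  # ~43: tracks the bulk

print(scale_from_max, scale_from_pct)
# With scale 1.12, typical amax-10 tensors only reach 11.2 of the available
# 448 range; the percentile-based scale uses most of it.
```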

Phase 2: Offline Weight Quantization

Quantize model weights to FP8 and save. This is done once.

def quantize_weights_fp8(model, scales):
    """
    Quantize all linear layer weights to FP8 E4M3.

    The original FP16 weights are replaced with FP8 weights
    plus an FP32 scale factor. Memory: 1 byte/param + negligible
    scale overhead (one FP32 per tensor).
    """
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear) and name in scales:
            weight = module.weight.data  # [out, in] in FP16/BF16
            scale = scales[name]['weight_scale']

            # Quantize
            weight_scaled = weight.float() * scale
            weight_scaled = weight_scaled.clamp(-448.0, 448.0)
            weight_fp8 = weight_scaled.to(torch.float8_e4m3fn)

            # Replace weight
            module.weight = torch.nn.Parameter(
                weight_fp8, requires_grad=False
            )

            # Store scale as a buffer (not a parameter)
            module.register_buffer(
                'weight_scale',
                torch.tensor(scale, dtype=torch.float32)
            )

    return model


def save_fp8_model(model, path):
    """Save FP8-quantized model."""
    state_dict = {}
    for name, param in model.named_parameters():
        state_dict[name] = param.data
    for name, buf in model.named_buffers():
        state_dict[name] = buf
    torch.save(state_dict, path)
    # Model size: ~50% of FP16 (1 byte/param vs 2 bytes/param)

Phase 3: Online Activation Quantization

During inference, activations are quantized to FP8 dynamically before each GEMM. The scale is computed on-the-fly from the activation tensor’s amax.

class FP8Linear(torch.nn.Module):
    """
    Linear layer with FP8 weights and dynamic FP8 activation
    quantization.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # FP8 weight and scale (set during quantization)
        self.weight_fp8 = None   # [out, in] in float8_e4m3fn
        self.weight_scale = None # FP32 scalar

        # Optional: use delayed scaling (reuse previous step's scale)
        self.use_delayed_scaling = False
        self.prev_input_scale = None

    def forward(self, x):
        """
        x: [B, in_features] in BF16

        1. Dynamically quantize x to FP8
        2. Execute FP8 GEMM
        3. Descale output to BF16
        """
        # Dynamic activation scaling
        if self.use_delayed_scaling and self.prev_input_scale is not None:
            input_scale = self.prev_input_scale
        else:
            amax = x.abs().max()
            input_scale = (448.0 / amax).clamp(min=1e-12)

        # Store for next step (delayed scaling)
        if self.use_delayed_scaling:
            self.prev_input_scale = input_scale.detach()

        # Quantize activation to FP8
        x_scaled = (x.float() * input_scale).clamp(-448.0, 448.0)
        x_fp8 = x_scaled.to(torch.float8_e4m3fn)

        # FP8 GEMM with FP32 accumulation
        # Output descaling: divide by (input_scale * weight_scale)
        output = torch._scaled_mm(
            x_fp8,
            self.weight_fp8.t(),  # column-major view, as _scaled_mm expects
            out_dtype=torch.bfloat16,
            scale_a=(1.0 / input_scale).to(torch.float32),
            scale_b=(1.0 / self.weight_scale).to(torch.float32),
        )

        return output
💡 Delayed Scaling Avoids the Amax Kernel

Computing x.abs().max() requires a full reduction over the activation tensor — an extra kernel launch and global memory read. Delayed scaling reuses the previous step’s scale factor, eliminating this overhead. The assumption is that activation ranges change slowly between tokens, which holds in practice for autoregressive generation. The first token uses a default scale (e.g., 1.0) or a calibrated scale.
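A pure-Python sketch of the staleness tradeoff (with synthetic amax values): when the activation range jumps, the stale scale over-amplifies the tensor and the saturating FP8 cast clips until the next step's refreshed scale catches up:

```python
FP8_MAX = 448.0

amax_stream = [3.0, 3.1, 2.9, 12.0, 11.8]   # per-step activation amax (synthetic)

flags = []
scale = 1.0                                  # default scale for the first token
for amax in amax_stream:
    flags.append(amax * scale > FP8_MAX)     # would the saturating cast clip?
    scale = FP8_MAX / amax                   # delayed: takes effect NEXT step

print(flags)   # [False, True, False, True, False]
# Step 1 clips marginally (amax drifted up ~3% past the stale scale's range);
# step 3 clips hard because it still uses the scale computed from amax=2.9.
```

Saturating casts turn this clipping into bounded error rather than overflow, which is why the one-token lag is acceptable in practice.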

Dynamic vs Static Scaling

| Approach | Scale source | Advantages | Disadvantages |
|---|---|---|---|
| Static (calibration) | Pre-computed from calibration set | No runtime overhead, simple | May not cover all input distributions |
| Dynamic (per-tensor) | Computed from current tensor | Adapts to actual data | Extra amax kernel per GEMM |
| Delayed dynamic | Previous step's amax | Minimal overhead, adaptive | Slight staleness (1-token lag) |

Production systems typically use:

  • Static scaling for weights (computed once during quantization)
  • Delayed dynamic scaling for activations (previous step’s amax)

6. Complete FP8 Inference Implementation

Here is a complete, runnable implementation of FP8 inference for a transformer model using PyTorch’s native FP8 support:

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """RMS Normalization — always in FP32/BF16, never FP8."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Compute in FP32 for numerical stability
        x_float = x.float()
        rms = torch.sqrt(x_float.pow(2).mean(-1, keepdim=True) + self.eps)
        normed = x_float / rms
        return (normed * self.weight.float()).to(x.dtype)


class FP8LinearLayer(nn.Module):
    """
    FP8 linear layer with per-tensor scaling.
    Uses torch._scaled_mm for FP8 GEMM on H100.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Placeholder — will be set during quantization
        self.register_buffer(
            'weight_fp8',
            torch.zeros(out_features, in_features,
                        dtype=torch.float8_e4m3fn)
        )
        self.register_buffer(
            'weight_scale',
            torch.tensor(1.0, dtype=torch.float32)
        )
        self.register_buffer(
            'input_scale',
            torch.tensor(1.0, dtype=torch.float32)
        )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.bias = None

    @torch.no_grad()
    def quantize_weight(self, weight_fp16):
        """Quantize an FP16 weight tensor to FP8."""
        amax = weight_fp16.abs().max().float()
        scale = (448.0 / amax).clamp(min=1e-12)

        w_scaled = (weight_fp16.float() * scale).clamp(-448.0, 448.0)
        self.weight_fp8.copy_(w_scaled.to(torch.float8_e4m3fn))
        self.weight_scale.fill_(scale.item())

    def forward(self, x):
        """
        x: [*, in_features] in BF16
        Returns: [*, out_features] in BF16
        """
        orig_shape = x.shape
        x = x.reshape(-1, self.in_features)  # [B, in]

        # Dynamic activation quantization
        amax = x.abs().max().float()
        act_scale = (448.0 / amax).clamp(min=1e-12)

        x_scaled = (x.float() * act_scale).clamp(-448.0, 448.0)
        x_fp8 = x_scaled.to(torch.float8_e4m3fn)

        # FP8 GEMM: x_fp8 @ weight_fp8.T
        # torch._scaled_mm handles the descaling in the epilogue
        inv_act_scale = (1.0 / act_scale).to(torch.float32)
        inv_weight_scale = (1.0 / self.weight_scale).to(torch.float32)

        output = torch._scaled_mm(
            x_fp8,                 # [B, in] E4M3, row-major
            self.weight_fp8.t(),   # [in, out] E4M3; the transposed view of a
                                   # row-major tensor is the column-major layout
                                   # that torch._scaled_mm requires for mat2
            out_dtype=torch.bfloat16,
            scale_a=inv_act_scale,
            scale_b=inv_weight_scale,
        )  # [B, out] BF16

        if self.bias is not None:
            output = output + self.bias.to(output.dtype)

        return output.reshape(*orig_shape[:-1], self.out_features)


class FP8TransformerBlock(nn.Module):
    """
    Single transformer block with FP8 GEMMs.
    Norms, attention, activations stay in BF16.
    """

    def __init__(self, dim, n_heads, n_kv_heads, ff_dim):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads

        # Norms (BF16)
        self.attn_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)

        # Attention projections (FP8)
        self.q_proj = FP8LinearLayer(dim, n_heads * self.head_dim)
        self.k_proj = FP8LinearLayer(dim, n_kv_heads * self.head_dim)
        self.v_proj = FP8LinearLayer(dim, n_kv_heads * self.head_dim)
        self.o_proj = FP8LinearLayer(n_heads * self.head_dim, dim)

        # FFN projections (FP8)
        self.gate_proj = FP8LinearLayer(dim, ff_dim)
        self.up_proj = FP8LinearLayer(dim, ff_dim)
        self.down_proj = FP8LinearLayer(ff_dim, dim)

    def forward(self, x, cos_freqs, sin_freqs, mask=None):
        """
        x: [B, S, dim] in BF16
        """
        # --- Attention ---
        residual = x
        x = self.attn_norm(x)  # BF16 norm

        # FP8 projections
        q = self.q_proj(x)     # FP8 GEMM -> BF16 output
        k = self.k_proj(x)     # FP8 GEMM -> BF16 output
        v = self.v_proj(x)     # FP8 GEMM -> BF16 output

        # Reshape for attention
        B, S, _ = q.shape
        q = q.view(B, S, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, S, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, S, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings (BF16)
        q = apply_rotary_emb(q, cos_freqs, sin_freqs)
        k = apply_rotary_emb(k, cos_freqs, sin_freqs)

        # GQA: expand KV heads
        if self.n_kv_heads < self.n_heads:
            rep = self.n_heads // self.n_kv_heads
            k = k.repeat_interleave(rep, dim=1)
            v = v.repeat_interleave(rep, dim=1)

        # Attention (BF16 — FlashAttention)
        attn_out = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=mask,
            is_causal=mask is None,  # attn_mask and is_causal are mutually exclusive
        )  # BF16

        attn_out = attn_out.transpose(1, 2).reshape(B, S, -1)

        # Output projection (FP8 GEMM)
        attn_out = self.o_proj(attn_out)

        # Residual (BF16)
        x = residual + attn_out

        # --- FFN ---
        residual = x
        x = self.ffn_norm(x)  # BF16 norm

        # FP8 GEMMs
        gate = self.gate_proj(x)   # FP8 -> BF16
        up = self.up_proj(x)       # FP8 -> BF16

        # SiLU activation (BF16)
        x = F.silu(gate) * up

        # Down projection (FP8 GEMM)
        x = self.down_proj(x)

        # Residual (BF16)
        x = residual + x

        return x


def apply_rotary_emb(x, cos, sin):
    """Apply rotary positional embeddings. Always BF16."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([
        x1 * cos - x2 * sin,
        x2 * cos + x1 * sin,
    ], dim=-1)

Loading and Quantizing a Pretrained Model

def convert_to_fp8(model_fp16):
    """
    Convert an FP16 model to FP8 inference.
    Replaces all nn.Linear with FP8LinearLayer.
    """
    for name, module in model_fp16.named_children():
        if isinstance(module, nn.Linear):
            fp8_layer = FP8LinearLayer(
                module.in_features,
                module.out_features,
                bias=module.bias is not None,
            ).to(module.weight.device)

            # Quantize weight
            fp8_layer.quantize_weight(module.weight.data)

            if module.bias is not None:
                fp8_layer.bias.data.copy_(module.bias.data)

            setattr(model_fp16, name, fp8_layer)
        else:
            convert_to_fp8(module)  # Recurse

    return model_fp16


# Usage:
# model = load_pretrained_model("llama-70b", dtype=torch.bfloat16)
# model = convert_to_fp8(model)
# model.eval()
#
# Memory: 70B params * 1 byte = 70 GB (vs 140 GB in FP16)
# Throughput: 1.3-1.9x depending on batch size
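The memory figures in the comment above are easy to sanity-check. A minimal sketch (the 80 x 7 linear-projection layer count is an illustrative Llama-70B-like assumption, as is the one-FP32-scale-per-tensor accounting; the helper name is hypothetical):

```python
def fp8_weight_memory(n_params: float, n_linear_layers: int) -> dict:
    """Estimate weight memory for FP16 vs FP8 in decimal GB.

    Assumes one FP32 scale per tensor (per-tensor scaling); the
    layer count is illustrative for a Llama-70B-like stack.
    """
    return {
        "fp16_gb": n_params * 2 / 1e9,           # 2 bytes per weight
        "fp8_gb": n_params * 1 / 1e9,            # 1 byte per weight
        "scales_gb": n_linear_layers * 4 / 1e9,  # one FP32 scalar each
    }

# 80 decoder layers x 7 linear projections (q, k, v, o, gate, up, down)
est = fp8_weight_memory(70e9, 80 * 7)
print(est)  # fp16_gb: 140.0, fp8_gb: 70.0, scales_gb: ~2.2e-06
```

Per-tensor scale overhead is a few kilobytes for the whole model; the ~0.001 GB row in the table below corresponds to finer-grained (e.g. per-channel) scales.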

Memory Savings

Memory Comparison: FP16 vs FP8 (Llama 70B, H100 80GB)

| Component | FP16 Size | FP8 Size | Savings |
|---|---|---|---|
| Model weights | 140 GB | 70 GB | 50% |
| Scale factors | 0 GB | ~0.001 GB | Negligible overhead |
| KV cache (batch=64, seq=4096) | 85.9 GB | 85.9 GB (still BF16) | 0% |
| Activation memory | ~2 GB | ~2 GB (still BF16) | 0% |
| Total (batch=64) | 228 GB | 158 GB | 31% |

Note: FP8 halves weight memory but KV cache and activations stay in BF16. For large-batch serving where KV cache dominates, FP8 weight quantization alone has limited memory impact — combine with KV cache quantization for maximum effect.
⚠️ FP8 Weights Alone May Not Be Enough

At batch size 64 with sequence length 4096, the KV cache for Llama 70B in BF16 is 85.9 GB — larger than the FP8 weight savings of 70 GB. To maximize serving capacity, combine FP8 weight quantization with FP8 KV cache quantization (reducing KV cache from 85.9 GB to 42.9 GB). The combination reduces total memory from 228 GB to approximately 115 GB — fitting on two H100s instead of three.
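The 85.9 GB and 42.9 GB figures fall out of the cache geometry. A quick sketch, assuming Llama 70B's GQA shape (80 layers, 8 KV heads, head_dim 128); the helper name is hypothetical:

```python
def kv_cache_bytes(batch, seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem):
    # K and V each store one head_dim vector per token, per layer, per KV head
    return 2 * batch * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_elem

# Llama 70B: 80 layers, 8 KV heads (GQA), head_dim 128
bf16_gb = kv_cache_bytes(64, 4096, 80, 8, 128, 2) / 1e9  # ~85.9 GB
fp8_gb = kv_cache_bytes(64, 4096, 80, 8, 128, 1) / 1e9   # ~42.9 GB
```

Note that the cache scales linearly in both batch size and sequence length, which is why it overtakes the weights at serving-scale batches.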

7. Quality Impact and When FP8 Fails

FP8 inference is not lossless. The 3 mantissa bits introduce quantization error on every weight and activation. For most production models, the accuracy degradation is small enough to be acceptable.

FP8 Quality Impact: Perplexity on WikiText-2

| Model | FP16 PPL | FP8 (static scale) PPL | FP8 (dynamic scale) PPL | Degradation |
|---|---|---|---|---|
| Llama 2 7B | 5.47 | 5.58 | 5.51 | +0.04-0.11 |
| Llama 2 13B | 4.88 | 4.96 | 4.91 | +0.03-0.08 |
| Llama 2 70B | 3.32 | 3.35 | 3.33 | +0.01-0.03 |
| Mistral 7B | 5.25 | 5.39 | 5.30 | +0.05-0.14 |
| Llama 3 8B | 6.14 | 6.38 | 6.21 | +0.07-0.24 |
| Llama 3 70B | 2.86 | 2.89 | 2.87 | +0.01-0.03 |

Note: Dynamic per-tensor scaling consistently outperforms static scaling. Larger models degrade less — more parameters means each individual parameter's quantization error has less impact. Llama 3 8B shows larger degradation than Llama 2 7B, likely due to tighter weight distributions from more aggressive training.

When FP8 Produces Unacceptable Quality

  1. Small models (less than 3B parameters): Each weight carries more information. FP8 quantization error is proportionally larger. Consider INT8 weight-only quantization instead, which preserves activations in FP16.

  2. Models with outlier channels: Some transformer models develop outlier features — channels where activation values are 10-100x larger than the rest. Per-tensor scaling is dominated by these outliers, causing severe precision loss for normal-range values. SmoothQuant-style techniques migrate the outlier magnitude from activations to weights before quantization.

  3. Fine-tuned models with narrow weight distributions: Models fine-tuned on narrow domains (e.g., medical, legal) may have weights concentrated in a very small range. FP8’s 3 mantissa bits may not provide enough resolution to distinguish between close weight values.

  4. Long-context generation: Quantization errors accumulate across the sequence length through the residual stream. At 100K+ tokens, the accumulated error from 80 layers of FP8 GEMMs can produce noticeably different outputs from FP16. This is model-dependent and difficult to predict without testing.
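Failure mode 2 can be reproduced in a few lines. The sketch below is a pure-Python round-to-nearest model of E4M3 (bias 7, 3 mantissa bits, max normal 448, subnormals down to 2^-9), not production code, and the input values are contrived to show how a single outlier starves normal-range values of precision under per-tensor scaling:

```python
import math

def round_e4m3(x: float) -> float:
    """Round-to-nearest simulation of FP8 E4M3 (saturating, no infinity)."""
    if x == 0.0:
        return 0.0
    s = math.copysign(1.0, x)
    a = abs(x)
    if a >= 448.0:
        return s * 448.0                 # saturate at max normal
    _, e = math.frexp(a)                 # a = m * 2**e with 0.5 <= m < 1
    e -= 1                               # now a = 1.f * 2**e
    # 3 mantissa bits -> 8 steps per binade; subnormal spacing is 2**-9
    step = 2.0 ** -9 if e < -6 else 2.0 ** (e - 3)
    return s * min(round(a / step) * step, 448.0)

def quantize_per_tensor(values):
    """Per-tensor scaling: scale by 448 / amax, round to E4M3, dequantize."""
    amax = max(abs(v) for v in values)
    scale = 448.0 / amax
    return [round_e4m3(v * scale) / scale for v in values]

# One outlier (1000.0) forces a tiny scale; the 0.01-range values land in
# E4M3's subnormal region, where relative error is large.
with_outlier = quantize_per_tensor([0.01] * 7 + [1000.0])
rel_err = abs(with_outlier[0] - 0.01) / 0.01        # ~13% error

# Without the outlier, 0.01 maps near the top of the range: ~0% error.
without_outlier = quantize_per_tensor([0.01] * 8)
rel_err_clean = abs(without_outlier[0] - 0.01) / 0.01
```

This is exactly the effect SmoothQuant-style rebalancing targets: shrinking the outlier before quantization lets the shared scale serve the bulk of the distribution.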

def check_fp8_compatibility(model):
    """
    Quick diagnostic: check for conditions that make FP8
    quantization risky.
    """
    warnings = []

    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue

        w = param.data.float()
        amax = w.abs().max().item()
        mean_abs = w.abs().mean().item()
        std = w.std().item()

        # Check 1: Outlier ratio
        outlier_ratio = amax / mean_abs
        if outlier_ratio > 20:
            warnings.append(
                f"{name}: outlier ratio {outlier_ratio:.1f} "
                f"(amax={amax:.3f}, mean_abs={mean_abs:.5f}). "
                f"Consider SmoothQuant."
            )

        # Check 2: Very wide dynamic range (amax vs. smallest nonzero value)
        nonzero = w.abs()[w.abs() > 0]
        if nonzero.numel() > 0:
            dynamic_range = amax / nonzero.min().item()
            if dynamic_range > 1000:
                warnings.append(
                    f"{name}: dynamic range {dynamic_range:.0f}. "
                    f"FP8 may not resolve small values."
                )

        # Check 3: Near-zero standard deviation
        if std < 0.001:
            warnings.append(
                f"{name}: std={std:.6f}. Very narrow distribution, "
                f"FP8 quantization noise may dominate."
            )

    return warnings

8. Hardware Support Matrix

FP8 is not universally available. Here is the current support landscape:

FP8 Hardware Support (as of early 2025)

| Hardware | FP8 Support | E4M3 | E5M2 | FP8 Tensor Cores | FP8 TFLOPS |
|---|---|---|---|---|---|
| H100 SXM | Yes | Yes | Yes | Yes | 1979 |
| H100 PCIe | Yes | Yes | Yes | Yes | 1513 |
| H200 SXM | Yes | Yes | Yes | Yes | 1979 |
| L40S | Yes | Yes | Yes | Yes | 733 |
| A100 | No | - | - | - | - |
| A10G | No | - | - | - | - |
| RTX 4090 (Ada) | Yes | Yes | Yes | Yes | 660 |
| RTX 3090 (Ampere) | No | - | - | - | - |
| AMD MI300X | Yes (OCP FP8) | Yes | Yes | Yes | ~2600 |
| Intel Gaudi 2 | Yes | Yes | No | MME only | ~600 |

Note: FP8 requires Hopper (SM 9.0) or Ada Lovelace (SM 8.9) on NVIDIA. Ampere (SM 8.0) does not support FP8. AMD MI300X supports FP8 via the OCP (Open Compute Project) standard, which is compatible with E4M3/E5M2.
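Given the fragmented support matrix, a runtime guard avoids silent fallbacks. This hypothetical helper encodes the NVIDIA rows of the table above (feed it the tuple from `torch.cuda.get_device_capability()`); it does not cover AMD or Gaudi:

```python
def nvidia_supports_fp8(major: int, minor: int) -> bool:
    """FP8 tensor cores ship with Ada Lovelace (SM 8.9) and Hopper (SM 9.0+).

    Ampere (SM 8.0 / 8.6) and older return False; fall back to INT8
    weight-only quantization on those GPUs.
    """
    return (major, minor) >= (8, 9)

# A100 is SM 8.0, RTX 4090 is SM 8.9, H100 is SM 9.0
for name, cc in [("A100", (8, 0)), ("RTX 4090", (8, 9)), ("H100", (9, 0))]:
    print(name, nvidia_supports_fp8(*cc))
```

In a torch program the capability tuple would come from `torch.cuda.get_device_capability()` on the target device before choosing a quantization scheme.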

Software Requirements

NVIDIA FP8:
  - CUDA 12.0+
  - cuDNN 8.9+
  - PyTorch 2.1+ (for torch.float8_e4m3fn dtype)
  - PyTorch 2.4+ (for torch._scaled_mm with proper H100 support)
  - Driver 525.60+

Frameworks with FP8 support:
  - TensorRT-LLM: Native FP8 since v0.5
  - vLLM: FP8 quantization via compressed-tensors (v0.4+)
  - SGLang: FP8 via torch._scaled_mm (v0.2+)
  - DeepSpeed: FP8 via Transformer Engine integration
  - NVIDIA Transformer Engine: The reference FP8 library

Transformer Engine Integration

NVIDIA’s Transformer Engine library provides the most optimized FP8 path:

import transformer_engine.pytorch as te

# Replace nn.Linear with te.Linear for automatic FP8
class TETransformerLayer(nn.Module):
    def __init__(self, dim, ff_dim):
        super().__init__()
        # Transformer Engine handles FP8 quantization internally
        self.qkv_proj = te.Linear(dim, 3 * dim, bias=False)
        self.o_proj = te.Linear(dim, dim, bias=False)
        self.gate_proj = te.Linear(dim, ff_dim, bias=False)
        self.up_proj = te.Linear(dim, ff_dim, bias=False)
        self.down_proj = te.Linear(ff_dim, dim, bias=False)

        # Norms integrated into TE layers
        self.attn_norm = te.LayerNorm(dim)
        self.ffn_norm = te.LayerNorm(dim)

    def forward(self, x):
        # Transformer Engine automatically:
        # 1. Manages FP8 scaling (delayed dynamic)
        # 2. Quantizes activations before each GEMM
        # 3. Accumulates in FP32
        # 4. Outputs in BF16
        # All within optimized fused kernels

        with te.fp8_autocast(enabled=True):
            normed = self.attn_norm(x)
            qkv = self.qkv_proj(normed)
            # ... attention in BF16 produces attn_out ...
            o = self.o_proj(attn_out)
            x = x + o

            normed2 = self.ffn_norm(x)
            gate = self.gate_proj(normed2)
            up = self.up_proj(normed2)
            x = x + self.down_proj(F.silu(gate) * up)

        return x

Transformer Engine provides several advantages over manual FP8:

  • Delayed scaling with automatic amax history management
  • Fused GEMM + scaling kernels (no separate quantize kernel)
  • Automatic mixed-precision recipes (which layers use FP8)
  • Communication-efficient FP8 for tensor parallelism

Key Takeaways

  1. E4M3 is the inference format: 4 exponent bits, 3 mantissa bits, range [-448, 448], no infinity. 240 total representable values. Use E5M2 only for training gradients.

  2. Per-tensor scaling is mandatory: Without scaling, FP8’s limited range causes overflow or severe precision loss. Scale = 448.0 / max(abs(tensor)). One FP32 scalar per tensor.

  3. GEMMs only: Quantize linear projections (QKV, output, FFN) to FP8. Keep norms, softmax, activations, residuals, and embeddings in BF16. Quantizing non-GEMM operations degrades quality with no throughput benefit.

  4. 2x throughput is the ceiling, not the floor: H100 FP8 tensor cores deliver 1,979 TFLOPS vs 989 TFLOPS (FP16). The actual speedup depends on arithmetic intensity — 1.0x at B = 1 (memory-bound), 1.9x at B = 2048+ (compute-bound). Typical serving workloads see 1.3-1.6x.

  5. Larger models quantize better: Llama 2 70B loses 0.01-0.03 perplexity points; Llama 2 7B loses 0.04-0.11. Each parameter carries less marginal information in larger models, so quantization noise has less impact.

  6. Combine with KV cache quantization: FP8 weights save 70 GB for a 70B model. But KV cache at BF16 can exceed 85 GB at high batch sizes. FP8 KV cache quantization provides additive savings.

  7. Check hardware support: FP8 requires Hopper (H100/H200) or Ada Lovelace (L40S, RTX 4090). A100 and older GPUs do not have FP8 tensor cores. On unsupported hardware, use INT8 weight-only quantization instead.