Part 5 of 30 in the Quantization Masterclass series

NVIDIA Blackwell (B100, B200, GB200) introduces native tensor core support for 4-bit floating-point arithmetic. This is a hardware-level capability: the tensor cores consume FP4 operands directly and produce FP32-accumulated results. Dense FP4 throughput is 2x the chip's FP8 tensor core peak: where H100 reaches approximately 1,979 TFLOPS of dense FP8, B200 reaches approximately 4,500 TFLOPS of dense FP4.

Two FP4 formats are supported: NVFP4 (NVIDIA’s proprietary format) and MXFP4 (the Open Compute Project microscaling standard). Both use 4 bits per element with a shared scaling mechanism, but they differ in how the scaling is organized and how the hardware consumes the data.

This post covers the bit-level details of both formats, the shared exponent mechanism that makes 4-bit precision viable, throughput and quality implications, when FP4 is sufficient versus when it degrades, and a complete implementation of FP4 quantization with block scaling.

Why 4-Bit Floating Point

The progression from FP16 to FP8 to FP4 follows a consistent pattern: halving precision doubles throughput because the tensor cores process twice as many elements per cycle. The memory bandwidth savings are equally dramatic — FP4 reads half the bytes of FP8, which means 2x the effective bandwidth for weight-bound operations.

Tensor Core Peak Throughput by Precision (Single B200 GPU)

| Precision | Peak Throughput | vs FP16 |
|-----------|-----------------|---------|
| FP16 | 1,125 TOPS | 1x |
| BF16 | 1,125 TOPS | 1x |
| FP8 | 2,250 TOPS | 2x |
| FP4 (NVFP4/MXFP4) | 4,500 TOPS | 4x |

For LLM inference at batch size 1 (the decode phase), performance is determined almost entirely by how fast you can read model weights from HBM. A 70B model at FP4 is about 35 GB before scale overhead, half the 70 GB at FP8. This directly translates to lower latency per token and higher throughput for serving.
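This bandwidth floor is simple arithmetic, sketched here with the 8 TB/s B200 HBM figure used throughout this post (`decode_read_time_ms` is a hypothetical helper; real decode also reads KV cache and activations, so these are lower bounds):

```python
def decode_read_time_ms(n_params, bits_per_param, hbm_bytes_per_s=8e12):
    """Lower bound on per-token decode time: stream every weight once from HBM."""
    model_bytes = n_params * bits_per_param / 8
    return model_bytes / hbm_bytes_per_s * 1e3

for label, bits in [("FP16", 16), ("FP8", 8), ("FP4+scales", 4.5), ("FP4 dense", 4)]:
    print(f"{label:>10}: {decode_read_time_ms(70e9, bits):.2f} ms/token floor")
```

FP16 works out to 17.5 ms per token and dense FP4 to about 4.4 ms, matching the table below.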

Model Size and Decode Bandwidth Cost by Precision (70B Parameters)

| Precision | Bits/Param | Model Size | Decode Read Time (B200 HBM, 8 TB/s) | Relative Speed |
|-----------|------------|------------|-------------------------------------|----------------|
| FP16 | 16 | 140 GB | 17.5 ms | 1.0x |
| FP8 | 8 | 70 GB | 8.75 ms | 2.0x |
| INT4 (GPTQ/AWQ) | 4.5 (with scales) | ~40 GB | 5.0 ms | 3.5x |
| FP4 + scale | ~5.0 (with scales) | ~44 GB | 5.5 ms | 3.2x |
| FP4 (dense, ideal) | 4 | 35 GB | 4.4 ms | 4.0x |

Note: FP4 memory size includes per-block scale factors. The effective bits/param depends on block size: larger blocks amortize scale overhead better but reduce quality.

NVFP4: NVIDIA’s 4-Bit Floating Point

NVFP4 is NVIDIA’s proprietary 4-bit format. Each element uses exactly 4 bits:

NVFP4 element (4 bits):
[S | EE | M]
 1   2    1

S: sign bit
E: 2-bit exponent (bias = 1)
M: 1-bit mantissa

The complete set of NVFP4 representable values (positive):

| Bits | Exponent | Mantissa | Value |
|------|----------|----------|-------|
| 0000 | 0 | 0 | 0.0 |
| 0001 | 0 | 1 | 0.5 |
| 0010 | 1 | 0 | 1.0 |
| 0011 | 1 | 1 | 1.5 |
| 0100 | 2 | 0 | 2.0 |
| 0101 | 2 | 1 | 3.0 |
| 0110 | 3 | 0 | 4.0 |
| 0111 | 3 | 1 | 6.0 |

With the sign bit, the full set is ±{0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}.

That is 14 distinct non-zero values plus zero: the 16 codes yield only 15 distinct values, since negative zero is redundant. This is strikingly coarse; plain INT4 gets 16 distinct levels from the same 4 bits. The advantage of NVFP4 over INT4 is the non-uniform spacing: the values are denser near zero and sparser at large magnitudes, matching typical neural network weight distributions.
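The value table above follows mechanically from the S1E2M1 fields with an exponent bias of 1: exponent code 0 is the subnormal case (no implicit leading 1), and codes 1-3 are normal. A minimal decoder (illustrative, not from any library):

```python
def decode_nvfp4(code):
    """Decode a 4-bit S1E2M1 code (exponent bias = 1) to its real value."""
    sign = -1.0 if (code >> 3) & 1 else 1.0
    e = (code >> 1) & 0b11
    m = code & 1
    if e == 0:
        mag = m * 0.5                        # subnormal: no implicit leading 1
    else:
        mag = 2.0 ** (e - 1) * (1 + m / 2)   # normal: implicit 1, exponent e - 1
    return sign * mag

print([decode_nvfp4(c) for c in range(8)])
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
```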

NVFP4 Block Scaling

NVFP4 is always used with a per-block scale factor. A block of $B$ consecutive NVFP4 elements shares one scale factor stored in FP8 E4M3 (or FP16). The dequantized value is:

$$x_{\text{real}} = \text{scale}_{\text{block}} \times x_{\text{NVFP4}}$$

The block size $B$ is a critical parameter. NVIDIA's recommended default is $B = 16$:

  • One FP8 scale per 16 FP4 elements = 8 bits / 16 = 0.5 bits overhead per element
  • Effective bits per element: 4 + 0.5 = 4.5 bits
  • The scale provides 8 bits of dynamic range adjustment, shifting the representable range to wherever the data lives
A reference implementation of NVFP4 quantization with block scaling (written for clarity; the per-element loops would be vectorized or fused into a kernel in practice):

import torch

# Complete NVFP4 value table
NVFP4_TABLE = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,  # positive (codes 0-7)
], dtype=torch.float32)

# Full signed table
NVFP4_SIGNED = torch.cat([NVFP4_TABLE, -NVFP4_TABLE[1:]])  # 0 only once

def nvfp4_quantize_block(block, block_size=16):
    """Quantize a 1D block of FP32 values to NVFP4.

    Args:
        block: (block_size,) FP32 tensor
        block_size: number of elements per block

    Returns:
        codes: (block_size,) uint8 (4-bit codes stored in uint8)
        scale: scalar FP32 scale factor
    """
    amax = block.abs().max().item()
    nvfp4_max = 6.0

    if amax < 1e-12:
        return torch.zeros(block_size, dtype=torch.uint8), 0.0

    scale = amax / nvfp4_max

    # Scale block to NVFP4 range
    scaled = block / scale

    # Quantize each element to nearest NVFP4 value
    codes = torch.zeros(block_size, dtype=torch.uint8)
    for i in range(block_size):
        val = scaled[i].item()
        sign = 0
        if val < 0:
            sign = 1
            val = -val

        # Find nearest positive NVFP4 value
        dists = (NVFP4_TABLE - val).abs()
        best_code = dists.argmin().item()

        codes[i] = (sign << 3) | best_code

    return codes, scale

def nvfp4_dequantize_block(codes, scale):
    """Dequantize NVFP4 codes back to FP32."""
    values = torch.zeros(len(codes), dtype=torch.float32)
    for i, code in enumerate(codes):
        code = code.item()
        sign = (code >> 3) & 1
        idx = code & 0x07
        val = NVFP4_TABLE[idx].item()
        values[i] = -val if sign else val
    return values * scale

def nvfp4_quantize_tensor(tensor, block_size=16):
    """Quantize a 2D weight tensor to NVFP4 with per-block scaling.

    Args:
        tensor: (out_features, in_features) FP32 weight matrix
        block_size: elements per scale factor

    Returns:
        codes: (out_features, in_features) uint8
        scales: (out_features, in_features // block_size) FP32
    """
    out_f, in_f = tensor.shape
    assert in_f % block_size == 0

    num_blocks = in_f // block_size
    codes = torch.zeros(out_f, in_f, dtype=torch.uint8)
    scales = torch.zeros(out_f, num_blocks, dtype=torch.float32)

    for row in range(out_f):
        for b in range(num_blocks):
            start = b * block_size
            end = start + block_size
            block = tensor[row, start:end]
            block_codes, block_scale = nvfp4_quantize_block(block, block_size)
            codes[row, start:end] = block_codes
            scales[row, b] = block_scale

    return codes, scales

def nvfp4_dequantize_tensor(codes, scales, block_size=16):
    """Dequantize NVFP4 tensor back to FP32."""
    out_f, in_f = codes.shape
    num_blocks = in_f // block_size
    result = torch.zeros(out_f, in_f, dtype=torch.float32)

    for row in range(out_f):
        for b in range(num_blocks):
            start = b * block_size
            end = start + block_size
            block_vals = nvfp4_dequantize_block(
                codes[row, start:end], scales[row, b].item()
            )
            result[row, start:end] = block_vals

    return result

MXFP4: Microscaling FP4 (OCP Standard)

MXFP4 is defined by the Open Compute Project (OCP) Microscaling Formats specification. The per-element encoding is identical to NVFP4 (1 sign + 2 exponent + 1 mantissa), but the shared scaling mechanism is different:

MXFP4 block structure:
[shared_exponent (8 bits)] [element_0 (4 bits)] [element_1 (4 bits)] ... [element_31 (4 bits)]

Block size: 32 elements (fixed by spec)
Shared exponent: 8-bit unsigned integer (E8M0 format)
Per-element: 4-bit FP4 (S1E2M1)

Key differences from NVFP4:

  1. Fixed block size of 32 (NVFP4 allows variable block sizes, typically 16)
  2. Shared exponent format: The scale is an 8-bit E8M0 value (a pure power of 2), not a general floating-point scale. This means $\text{scale} = 2^{e - 127}$ for stored exponent $e$.
  3. No mantissa in the scale: The shared exponent is always a power of 2. This simplifies hardware (scaling by a power of 2 is a bit shift, not a multiply) but provides less precise scaling than an FP8 scale factor.
ℹ️ E8M0 Shared Exponent

The MXFP4 shared exponent is an E8M0 value: 8 exponent bits, 0 mantissa bits. This means the scale is always an exact power of 2. While this is less flexible than an FP8 E4M3 scale factor (which has 3 mantissa bits of precision), it enables the tensor core to apply the scale as a simple exponent addition rather than a multiply, reducing hardware complexity and latency.

import math

def mxfp4_quantize_block(block):
    """Quantize a 32-element block to MXFP4.

    Returns:
        codes: (32,) uint8 (4-bit codes)
        shared_exp: uint8 (E8M0 shared exponent)
    """
    assert len(block) == 32
    amax = block.abs().max().item()

    if amax < 1e-30:
        return torch.zeros(32, dtype=torch.uint8), 0

    # Shared exponent: floor(log2(amax / fp4_max)) with fp4_max = 6.0.
    # floor() rounds the scale down, so the largest elements may clip to 6.0.
    target_scale = amax / 6.0
    exp_unbiased = math.floor(math.log2(target_scale)) if target_scale > 0 else -127
    shared_exp = max(0, min(255, exp_unbiased + 127))  # E8M0 bias = 127

    # Actual scale is exactly a power of 2
    scale = 2.0 ** (shared_exp - 127)

    # Quantize each element
    scaled = block / scale
    codes = torch.zeros(32, dtype=torch.uint8)

    for i in range(32):
        val = scaled[i].item()
        sign = 0
        if val < 0:
            sign = 1
            val = -val
        # Round to nearest NVFP4 value
        dists = (NVFP4_TABLE - val).abs()
        best_code = dists.argmin().item()
        codes[i] = (sign << 3) | best_code

    return codes, shared_exp

def mxfp4_dequantize_block(codes, shared_exp):
    """Dequantize MXFP4 block."""
    scale = 2.0 ** (shared_exp - 127)
    values = torch.zeros(32, dtype=torch.float32)
    for i, code in enumerate(codes):
        code = code.item()
        sign = (code >> 3) & 1
        idx = code & 0x07
        val = NVFP4_TABLE[idx].item()
        values[i] = (-val if sign else val) * scale
    return values

def mxfp4_quantize_tensor(tensor):
    """Quantize 2D tensor to MXFP4 (block size = 32)."""
    out_f, in_f = tensor.shape
    # Pad to multiple of 32
    pad = (32 - in_f % 32) % 32
    if pad > 0:
        tensor = torch.nn.functional.pad(tensor, (0, pad))
    in_f_padded = tensor.shape[1]
    num_blocks = in_f_padded // 32

    codes = torch.zeros(out_f, in_f_padded, dtype=torch.uint8)
    shared_exps = torch.zeros(out_f, num_blocks, dtype=torch.uint8)

    for row in range(out_f):
        for b in range(num_blocks):
            start = b * 32
            end = start + 32
            block_codes, block_exp = mxfp4_quantize_block(tensor[row, start:end])
            codes[row, start:end] = block_codes
            shared_exps[row, b] = block_exp

    return codes[:, :in_f], shared_exps, in_f

NVFP4 vs MXFP4: Detailed Comparison

NVFP4 vs MXFP4 Feature Comparison

| Feature | NVFP4 | MXFP4 |
|---------|-------|-------|
| Standard | NVIDIA proprietary | OCP open standard |
| Element format | S1E2M1 (4 bits) | S1E2M1 (4 bits) |
| Block size | 16 (typical, flexible) | 32 (fixed by spec) |
| Scale format | FP8 E4M3 (8 bits) | E8M0 (8 bits, power-of-2 only) |
| Scale precision | 3 mantissa bits | 0 mantissa bits (exact power of 2) |
| Overhead per element | 0.5 bits (8/16) | 0.25 bits (8/32) |
| Effective bits | 4.5 bits/elem | 4.25 bits/elem |
| Hardware scaling cost | FP multiply | Exponent add (cheaper) |
| Blackwell support | Yes | Yes |

Note: NVFP4's FP8 scale provides finer granularity but costs more hardware and has higher per-element overhead. MXFP4's power-of-2 scale is cheaper to apply but less precise.

The quality difference between NVFP4 and MXFP4 is typically small for well-distributed weights. NVFP4’s FP8 scale can represent any value in the E4M3 range, while MXFP4’s E8M0 scale is restricted to powers of 2. In practice, this means MXFP4 has up to 2x scale quantization error compared to NVFP4, but the impact on model quality is usually within 0.1-0.3 perplexity points.

def compare_nvfp4_mxfp4_quality(weight):
    """Compare quantization error of NVFP4 vs MXFP4."""
    out_f, in_f = weight.shape

    # NVFP4 with block_size=16
    codes_nv, scales_nv = nvfp4_quantize_tensor(weight, block_size=16)
    recon_nv = nvfp4_dequantize_tensor(codes_nv, scales_nv, block_size=16)
    mse_nv = ((weight - recon_nv) ** 2).mean().item()

    # MXFP4 with block_size=32
    codes_mx, exps_mx, orig_cols = mxfp4_quantize_tensor(weight)
    # Dequantize
    recon_mx = torch.zeros_like(weight)
    num_blocks = (in_f + 31) // 32
    for row in range(out_f):
        for b in range(num_blocks):
            start = b * 32
            end = min(start + 32, in_f)
            block_size = end - start
            if block_size < 32:
                padded_codes = torch.zeros(32, dtype=torch.uint8)
                padded_codes[:block_size] = codes_mx[row, start:end]
            else:
                padded_codes = codes_mx[row, start:end]
            vals = mxfp4_dequantize_block(padded_codes, exps_mx[row, b].item())
            recon_mx[row, start:end] = vals[:block_size]

    mse_mx = ((weight - recon_mx) ** 2).mean().item()

    print(f"NVFP4 (block=16) MSE: {mse_nv:.8f}")
    print(f"MXFP4 (block=32) MSE: {mse_mx:.8f}")
    print(f"MXFP4/NVFP4 MSE ratio: {mse_mx / mse_nv:.2f}x")

    return mse_nv, mse_mx

# Test
torch.manual_seed(42)
weight = torch.randn(256, 1024) * 0.02
compare_nvfp4_mxfp4_quality(weight)

Throughput: How FP4 Achieves 2x Over FP8

The 2x throughput gain of FP4 over FP8 comes from two multiplicative effects:

1. Doubled operand density: FP4 elements are half the size of FP8. The tensor core matrix multiply unit processes a fixed number of bytes per cycle. At FP4, those bytes contain twice as many elements, so the GEMM processes 2x more output elements per cycle.

2. Halved weight memory bandwidth: Loading weights from HBM requires half the bytes. For memory-bound operations (batch-1 decode), this directly translates to 2x throughput.

For compute-bound operations (large-batch prefill), the 2x tensor core throughput is the binding constraint. For memory-bound operations (single-request decode), the 2x bandwidth savings is the binding constraint. Either way, FP4 provides approximately 2x improvement over FP8.
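A quick roofline sanity check makes the memory-bound claim concrete, under the assumption of ~2 FLOPs (one multiply-add) per weight element for a batch-1 GEMV and the B200 figures used in this post:

```python
HBM_BW = 8e12  # bytes/s (B200 HBM figure used in this post)

for label, peak_flops, bits in [("FP8", 2250e12, 8), ("FP4", 4500e12, 4)]:
    ridge = peak_flops / HBM_BW           # FLOPs/byte where compute becomes the limit
    decode_intensity = 2 / (bits / 8)     # ~2 FLOPs per weight element read
    regime = "compute-bound" if decode_intensity > ridge else "memory-bound"
    print(f"{label}: ridge {ridge:.1f} FLOPs/byte, decode at {decode_intensity:.0f} -> {regime}")
```

Decode sits orders of magnitude below the ridge at either precision, so for batch-1 serving it is the halved weight bytes, not the doubled tensor core rate, that delivers the speedup.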

Inference Throughput: FP4 vs FP8 on B200 (Llama 70B)

| Workload | FP8 | FP4 | FP4 Speedup |
|----------|-----|-----|-------------|
| Decode (batch=1) | 180 tokens/sec | 340 tokens/sec | 1.89x |
| Prefill (batch=64) | 8,500 tokens/sec | 15,200 tokens/sec | 1.79x |

Quality Impact: When FP4 Is Enough

FP4 has only 16 codes, yielding 15 distinct values (negative zero is redundant). This is extremely coarse quantization. Whether it is "good enough" depends on the model size, the task, and the quality metric.

Model Size Matters

Larger models tolerate more aggressive quantization because they have more redundancy. A 70B model at FP4 loses less quality than a 7B model at FP4, relative to each model’s FP16 baseline.

Perplexity Impact of FP4 vs FP8 (WikiText-2)

| Model | FP16 PPL | FP8 PPL | FP4 (block=16) PPL | FP4 Degradation |
|-------|----------|---------|--------------------|-----------------|
| Llama 2 7B | 5.47 | 5.49 | 5.78 | +0.31 |
| Llama 2 13B | 4.88 | 4.89 | 5.05 | +0.17 |
| Llama 2 70B | 3.32 | 3.33 | 3.41 | +0.09 |
| Mixtral 8x7B | 3.84 | 3.85 | 3.96 | +0.12 |

Note: FP4 degradation decreases with model size. At 70B, the degradation of 0.09 PPL is comparable to INT4 GPTQ/AWQ quality.

Task Sensitivity

Some tasks are more sensitive to quantization than others:

  • General text generation: Relatively tolerant. FP4 works well for chatbot-style applications.
  • Code generation: Moderately sensitive. FP4 may introduce subtle logic errors in complex code.
  • Math/reasoning: Highly sensitive. The precision loss at FP4 can cause arithmetic errors that propagate through chain-of-thought reasoning.
  • Long-context tasks: Sensitivity increases with context length because quantization errors accumulate across more attention computations.
def fp4_quality_analysis(weight, x, block_sizes=(16, 32, 64)):
    """Analyze FP4 quality at different block sizes."""
    y_ref = x @ weight.T
    results = {}

    for bs in block_sizes:
        # Quantize with NVFP4
        codes, scales = nvfp4_quantize_tensor(weight, block_size=bs)
        recon = nvfp4_dequantize_tensor(codes, scales, block_size=bs)
        y_q = x @ recon.T

        mse = ((y_ref - y_q) ** 2).mean().item()
        cos_sim = torch.nn.functional.cosine_similarity(
            y_ref.flatten().unsqueeze(0),
            y_q.flatten().unsqueeze(0)
        ).item()

        results[bs] = {'mse': mse, 'cosine_similarity': cos_sim}
        print(f"Block size {bs}: MSE={mse:.8f}, cos_sim={cos_sim:.6f}")

    return results

torch.manual_seed(42)
weight = torch.randn(512, 2048) * 0.02
x = torch.randn(32, 2048)
fp4_quality_analysis(weight, x)

Advanced: Two-Stage Quantization (FP16 to FP8 to FP4)

A practical technique for FP4 deployment is two-stage quantization: first quantize the model to FP8 using calibration data (with per-tensor or per-block scaling), then further quantize the FP8 weights to FP4. This can leverage FP8 calibration infrastructure (Transformer Engine, TensorRT-LLM) while targeting FP4 deployment.

def two_stage_fp4_quantize(weight, calibration_data=None, block_size=16):
    """Two-stage quantization: FP32/16 -> FP8 (simulated) -> FP4.

    calibration_data is accepted for API symmetry but unused in this sketch;
    a real pipeline would use it to calibrate the FP8 scales.

    Stage 1: Quantize to FP8 E4M3 with per-block scaling
    Stage 2: Quantize the FP8 values to FP4 with sub-block scaling
    """
    out_f, in_f = weight.shape

    # Stage 1: Per-block FP8 quantization (block_size = 128)
    fp8_block_size = 128
    num_fp8_blocks = in_f // fp8_block_size
    w_grouped = weight.reshape(out_f, num_fp8_blocks, fp8_block_size)
    fp8_amax = w_grouped.abs().amax(dim=2, keepdim=True)
    fp8_scale = 448.0 / fp8_amax.clamp(min=1e-12)
    w_fp8 = (w_grouped * fp8_scale).clamp(-448, 448)

    # Simulate E4M3 rounding: keep 1 implicit + 3 stored mantissa bits.
    # frexp yields mantissa in [0.5, 1), so rounding it to 1/16 steps keeps
    # 4 significant bits (subnormal and boundary cases are ignored here).
    mant, exp = torch.frexp(w_fp8)
    w_fp8 = torch.ldexp(torch.round(mant * 16) / 16, exp)

    w_fp8_dequant = w_fp8 / fp8_scale
    w_fp8_flat = w_fp8_dequant.reshape(out_f, in_f)

    # Stage 2: FP4 quantization on the FP8-quantized weights
    fp4_codes, fp4_scales = nvfp4_quantize_tensor(w_fp8_flat, block_size=block_size)
    w_fp4_recon = nvfp4_dequantize_tensor(fp4_codes, fp4_scales, block_size=block_size)

    # Compare quality
    mse_fp8 = ((weight - w_fp8_flat) ** 2).mean().item()
    mse_fp4 = ((weight - w_fp4_recon) ** 2).mean().item()

    print(f"Stage 1 (FP8) MSE: {mse_fp8:.8f}")
    print(f"Stage 2 (FP4) MSE: {mse_fp4:.8f}")
    print(f"FP4/FP8 error ratio: {mse_fp4/mse_fp8:.1f}x")

    return fp4_codes, fp4_scales

FP4 with GPTQ/AWQ: Combining Techniques

The weight quantization algorithms from Part 2 (GPTQ, AWQ) can be adapted for FP4 targets. Instead of rounding to the nearest INT4 value, the quantizer rounds to the nearest FP4 value. The error compensation (GPTQ) and channel scaling (AWQ) mechanisms are format-agnostic — they only need a quantize/dequantize function.

class FP4GPTQ:
    """GPTQ adapted for FP4 target format."""

    def __init__(self, layer, block_size=16, damp_percent=0.01):
        self.layer = layer
        self.block_size = block_size
        self.damp = damp_percent
        self.rows = layer.weight.shape[0]
        self.cols = layer.weight.shape[1]
        self.H = torch.zeros(self.cols, self.cols, dtype=torch.float32)
        self.nsamples = 0

    def add_batch(self, inp):
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        self.H += inp.float().T @ inp.float()
        self.nsamples += inp.shape[0]

    def quantize_value_fp4(self, value, scale):
        """Quantize a single value to NVFP4; returns (4-bit code, dequantized value)."""
        if scale < 1e-12:
            return 0, 0.0
        scaled = value / scale
        sign = 1 if scaled < 0 else 0
        abs_val = abs(scaled)
        # Round to nearest FP4 magnitude
        levels = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
        best_idx = min(range(8), key=lambda i: abs(abs_val - levels[i]))
        code = (sign << 3) | best_idx
        return code, (-levels[best_idx] if sign else levels[best_idx]) * scale

    def quantize(self):
        """Run FP4-GPTQ: GPTQ with FP4 as the target format."""
        W = self.layer.weight.data.clone().float()
        H = self.H / self.nsamples

        damp = self.damp * torch.diag(H).mean()
        diag_idx = torch.arange(self.cols)
        H[diag_idx, diag_idx] += damp

        try:
            L = torch.linalg.cholesky(H)
        except RuntimeError:
            H[diag_idx, diag_idx] += 10 * damp
            L = torch.linalg.cholesky(H)

        H_inv = torch.cholesky_inverse(L)

        codes = torch.zeros(self.rows, self.cols, dtype=torch.uint8)
        scales = torch.zeros(self.rows, self.cols // self.block_size)

        for block_start in range(0, self.cols, self.block_size):
            block_end = min(block_start + self.block_size, self.cols)
            block_size = block_end - block_start
            group_idx = block_start // self.block_size

            W_block = W[:, block_start:block_end].clone()

            # Per-block scale for FP4
            amax = W_block.abs().amax(dim=1, keepdim=True)
            scale = amax / 6.0
            scale = scale.clamp(min=1e-10)
            scales[:, group_idx] = scale.squeeze()

            H_block_inv = H_inv[block_start:block_end, block_start:block_end]

            for col in range(block_size):
                w = W_block[:, col].clone()  # clone: the column is overwritten below

                # Quantize to FP4 instead of INT4
                cq = [self.quantize_value_fp4(w[r].item(), scale[r].item())
                      for r in range(self.rows)]
                codes[:, block_start + col] = torch.tensor(
                    [c for c, _ in cq], dtype=torch.uint8
                )
                w_q = torch.tensor([v for _, v in cq])
                W_block[:, col] = w_q

                # Error compensation: push the quantization error onto the
                # not-yet-quantized columns of this block
                err = (w - w_q) / H_block_inv[col, col]
                if col < block_size - 1:
                    W_block[:, col + 1:] -= (
                        err.unsqueeze(1) * H_block_inv[col, col + 1:].unsqueeze(0)
                    )

        return codes, scales

Deployment: FP4 in TensorRT-LLM

TensorRT-LLM supports FP4 inference on Blackwell GPUs. The deployment workflow:

  1. Quantize the model offline (GPTQ/AWQ adapted for FP4, or direct RTN FP4)
  2. Convert to TensorRT-LLM checkpoint format with FP4 weight tensors
  3. Build the TensorRT engine with FP4 GEMMs enabled
  4. Serve with the standard TensorRT-LLM runtime
# Conceptual TensorRT-LLM FP4 workflow (pseudo-code)

def prepare_fp4_model(model_path, output_path, block_size=16):
    """Prepare a model for FP4 deployment on Blackwell."""

    # Step 1: Load the FP16/BF16 model
    # model = load_model(model_path)

    # Step 2: Calibrate (128 samples)
    # calib_data = load_calibration_data(dataset='c4', n_samples=128)

    # Step 3: Quantize each linear layer to FP4
    # for name, layer in model.linear_layers():
    #     codes, scales = nvfp4_quantize_with_gptq(
    #         layer, calib_data, block_size=block_size
    #     )
    #     layer.replace_with_fp4(codes, scales)

    # Step 4: Save in TRT-LLM format
    # model.save_checkpoint(output_path, format='fp4')

    # Step 5: Build TRT engine
    # trtllm-build --checkpoint_dir output_path \
    #              --gemm_plugin fp4 \
    #              --max_batch_size 64

    pass

Summary

FP4 quantization on NVIDIA Blackwell represents the current frontier of low-precision inference. NVFP4 and MXFP4 both use 4-bit floating-point elements with shared scaling, achieving 2x throughput over FP8 tensor cores.

NVFP4 uses a flexible block size (typically 16) with an FP8 E4M3 scale factor per block. The scale has 3 mantissa bits of precision, providing fine-grained range adjustment.

MXFP4 uses a fixed block size of 32 with an E8M0 (power-of-2 only) shared exponent. The hardware applies the scale as an exponent addition, which is cheaper than a full multiply.

Quality at FP4 depends strongly on model size. At 70B parameters, FP4 degradation is under 0.1 perplexity points — comparable to INT4 GPTQ/AWQ. At 7B, degradation is larger (0.3+ PPL) and may be unacceptable for precision-sensitive tasks.

GPTQ and AWQ can target FP4 instead of INT4. The Hessian-based error compensation (GPTQ) and activation-aware scaling (AWQ) are format-agnostic and work with any quantization target, including non-uniform FP4 levels.

The next post in this series covers KV cache quantization — a separate but equally important optimization that operates on the dynamic state generated during inference rather than the static model weights.