Part 19 of 30 in the Quantization Masterclass series

Quantization for Training: FP8 GEMM, Loss Scaling, and Why BF16 Remains the Default

Training quantization is fundamentally different from inference quantization. In inference, we quantize a frozen model to reduce memory and increase throughput — any quality loss is permanent. In training, we quantize the forward and backward passes to speed up gradient computation, but maintain high-precision master weights that accumulate gradients. The quantization error in each step is ephemeral: it affects the gradient estimate but not the converged model.

This distinction means training can tolerate higher quantization error per step (because gradient descent is inherently noisy), but it requires careful management of numerical range (because gradients can be extremely small or large). The history of low-precision training is a progression from FP32 to FP16 to BF16 to FP8, with each step requiring new techniques to handle the precision/range tradeoff.
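The tradeoff is mechanical: with a fixed bit budget, exponent bits buy range and mantissa bits buy precision. A small sketch deriving each format's limits from its bit layout (standard IEEE-754-style formulas; the real FP8 formats bend these rules, as discussed later):

```python
def ieee_limits(exp_bits, mant_bits):
    """Range and precision of an IEEE-style float from its bit layout."""
    bias = 2 ** (exp_bits - 1) - 1
    max_normal = (2 - 2 ** -mant_bits) * 2.0 ** bias  # largest finite value
    min_normal = 2.0 ** (1 - bias)                    # smallest normal value
    eps = 2.0 ** -mant_bits                           # relative precision
    return max_normal, min_normal, eps

for name, e, m in [("FP32", 8, 23), ("FP16", 5, 10), ("BF16", 8, 7)]:
    mx, mn, eps = ieee_limits(e, m)
    print(f"{name}: max {mx:.3e}  min normal {mn:.3e}  eps {eps:.1e}")

# FP16 spends its 15 non-sign bits mostly on precision; BF16 spends them
# on range. (FP8 E4M3 deviates from these formulas: it drops infinities
# and keeps a single NaN encoding, so its max is 448 rather than 480.)
```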

Mixed-Precision Training: The FP16 Recipe

Mixed-precision training (Micikevicius et al., 2018) established the fundamental pattern:

  1. Master weights: FP32 (full precision, stored in optimizer)
  2. Forward pass: FP16 (weights cast down, activations computed in FP16)
  3. Backward pass: FP16 (gradients computed in FP16)
  4. Weight update: FP32 (gradients cast up, applied to FP32 master weights)
import torch
import torch.nn as nn

class FP16MixedPrecisionTrainer:
    """Simplified FP16 mixed-precision training loop."""

    def __init__(self, model, optimizer, loss_scale_init=65536.0):
        self.model = model
        self.optimizer = optimizer
        self.loss_scale = loss_scale_init
        self.master_weights = {}

        # Store FP32 master copies
        for name, param in model.named_parameters():
            self.master_weights[name] = param.data.float().clone()

    def train_step(self, input_data, targets):
        """One training step with FP16 forward/backward + FP32 update."""

        # Step 1: Cast weights to FP16 for forward pass
        for name, param in self.model.named_parameters():
            param.data = self.master_weights[name].half()

        # Step 2: Forward pass in FP16
        output = self.model(input_data.half())
        loss = nn.functional.cross_entropy(output.float(), targets)

        # Step 3: Scale loss before backward (for FP16 gradient range)
        scaled_loss = loss * self.loss_scale

        # Step 4: Backward pass in FP16
        self.optimizer.zero_grad()
        scaled_loss.backward()

        # Step 5: Unscale gradients and check for overflow
        overflow = False
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                param.grad.data /= self.loss_scale
                if torch.any(torch.isinf(param.grad)) or torch.any(torch.isnan(param.grad)):
                    overflow = True
                    break

        if overflow:
            # Skip this step, reduce loss scale
            self.loss_scale /= 2
            return loss.item(), True  # Skipped

        # Step 6: Update FP32 master weights (plain SGD here for clarity;
        # a real trainer steps the wrapped optimizer instead)
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_fp32 = param.grad.data.float()
                self.master_weights[name] -= (
                    self.optimizer.defaults['lr'] * grad_fp32
                )

        # Increase loss scale on success. Doubling every step is a
        # simplification: production scalers (e.g. torch.cuda.amp.GradScaler)
        # double only after ~2000 consecutive overflow-free steps.
        self.loss_scale = min(self.loss_scale * 2, 65536.0)

        return loss.item(), False  # Not skipped

Why Loss Scaling is Necessary

FP16 has a range of roughly $[6 \times 10^{-8}, 65504]$ (including subnormals). Gradients in deep networks often fall below $6 \times 10^{-8}$, becoming zero in FP16 (underflow). Loss scaling multiplies the loss by a large factor (e.g., $2^{16} = 65536$) before backward, which scales all gradients up proportionally, keeping them in FP16's representable range.

def demonstrate_gradient_underflow(hidden_dim=4096, depth=32):
    """Show that gradients underflow in FP16 without loss scaling."""

    # Simulate gradient magnitudes through a deep network:
    # each layer multiplies the gradient by ~sqrt(2/hidden_dim)
    grad_scale = 1.0
    for layer in range(depth):
        grad_scale *= (2 / hidden_dim) ** 0.5

    print(f"Expected gradient scale after {depth} layers: {grad_scale:.2e}")

    # FP16 smallest representable positive (subnormal): 2^(-24) ~ 5.96e-8
    fp16_min = 2 ** (-24)
    print(f"FP16 min subnormal: {fp16_min:.2e}")

    if grad_scale < fp16_min:
        print(f"UNDERFLOW: gradient ({grad_scale:.2e}) < FP16 min ({fp16_min:.2e})")
        required_scale = fp16_min / grad_scale
        print(f"Required loss scale: >= {required_scale:.0f}")
    else:
        print("No underflow risk")

BF16: Why It Displaced FP16

BF16 (Brain Float 16) has 8 exponent bits and 7 mantissa bits, compared to FP16’s 5 exponent bits and 10 mantissa bits:

Format  Sign  Exponent  Mantissa  Range          Precision
FP16    1     5         10        ±65504         ~3.3 digits
BF16    1     8         7         ±3.4 × 10^38   ~2.4 digits

BF16 has the same range as FP32 (both have 8 exponent bits), which eliminates the need for loss scaling entirely. The reduced precision (7 vs 10 mantissa bits) causes slightly more rounding error per operation, but this is compensated by the elimination of gradient underflow and overflow.

def compare_fp16_bf16_training():
    """Compare FP16 and BF16 training characteristics."""

    # FP16 gradient range issues
    fp16_max = 65504
    fp16_min_normal = 2 ** (-14)  # ~6.1e-5
    fp16_min_subnormal = 2 ** (-24)  # ~5.96e-8

    # BF16 gradient range
    bf16_max = 3.389e38  # Same as FP32
    bf16_min_normal = 2 ** (-126)  # ~1.18e-38
    bf16_min_subnormal = 2 ** (-133)  # ~9.18e-41

    print("FP16:")
    print(f"  Max: {fp16_max:.0f}")
    print(f"  Min normal: {fp16_min_normal:.2e}")
    print(f"  Needs loss scaling: Yes")

    print("\nBF16:")
    print(f"  Max: {bf16_max:.2e}")
    print(f"  Min normal: {bf16_min_normal:.2e}")
    print(f"  Needs loss scaling: No")

    # Precision comparison
    fp16_precision = 2 ** (-10)  # Relative precision
    bf16_precision = 2 ** (-7)

    print(f"\nRelative precision:")
    print(f"  FP16: {fp16_precision:.2e} (~3.3 decimal digits)")
    print(f"  BF16: {bf16_precision:.2e} (~2.4 decimal digits)")
    print(f"  BF16 rounding error is {bf16_precision / fp16_precision:.0f}x "
          f"larger per operation")
💡 BF16: No Loss Scaling Required

BF16 training is simpler than FP16 training because the FP32-equivalent range eliminates gradient underflow. The training recipe reduces to: (1) master weights in FP32, (2) forward and backward in BF16, (3) weight update in FP32. No loss scaling, no overflow detection, no skipped steps. This simplicity is why BF16 became the default for LLM training starting with T5 (2019).
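A minimal sketch of this three-step recipe using PyTorch's built-in autocast (the tiny model and data here are illustrative):

```python
import torch
import torch.nn as nn

# BF16 recipe: FP32 master weights (the parameters themselves),
# BF16 forward/backward via autocast, FP32 weight update.
# Note: no GradScaler, no overflow checks -- contrast with the FP16 loop above.
model = nn.Linear(64, 8)                       # parameters stay FP32
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(16, 64)
y = torch.randint(0, 8, (16,))

# device_type="cpu" so the sketch runs anywhere; use "cuda" on GPU
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)                             # GEMM runs in BF16
    loss = nn.functional.cross_entropy(out, y)

loss.backward()                                # param grads come back in FP32
opt.step()                                     # FP32 update
opt.zero_grad()
print(out.dtype)                               # torch.bfloat16
```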

FP8 Training on Hopper

FP8 training uses two FP8 formats:

  • E4M3 (4 exponent, 3 mantissa): for forward pass weights and activations. Higher precision, narrower range.
  • E5M2 (5 exponent, 2 mantissa): for backward pass gradients. Wider range, lower precision. Gradients need range more than precision.
def fp8_format_comparison():
    """Compare E4M3 and E5M2 FP8 formats."""
    formats = {
        'E4M3': {
            'exponent_bits': 4,
            'mantissa_bits': 3,
            'bias': 7,
            'max_value': 448,
            'min_normal': 2 ** (-6),  # 0.015625
            'precision': 2 ** (-3),   # 0.125 relative
            'use': 'Forward pass (weights, activations)',
        },
        'E5M2': {
            'exponent_bits': 5,
            'mantissa_bits': 2,
            'bias': 15,
            'max_value': 57344,
            'min_normal': 2 ** (-14),  # ~6.1e-5
            'precision': 2 ** (-2),    # 0.25 relative
            'use': 'Backward pass (gradients)',
        },
    }

    for name, fmt in formats.items():
        print(f"\n{name}:")
        print(f"  Range: [{fmt['min_normal']:.2e}, {fmt['max_value']}]")
        print(f"  Precision: {fmt['precision']} relative")
        print(f"  Dynamic range: {fmt['max_value'] / fmt['min_normal']:.0f}x")
        print(f"  Use: {fmt['use']}")

    return formats

# E4M3 for forward: precision matters (weight values affect output directly)
# E5M2 for backward: range matters (gradients can be tiny or huge)

The FP8 Training Pipeline

class FP8TrainingStep:
    """One training step with FP8 GEMM on Hopper."""

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        # Per-tensor scale factors (maintained across steps)
        self.weight_scales = {}
        self.activation_scales = {}
        self.gradient_scales = {}

    def compute_scale(self, tensor, format_type='e4m3'):
        """Compute dynamic per-tensor scale for FP8 quantization.

        Scale maps the tensor's range to FP8's representable range.
        """
        abs_max = tensor.abs().max()

        if format_type == 'e4m3':
            fp8_max = 448.0
        elif format_type == 'e5m2':
            fp8_max = 57344.0

        # Scale so that abs_max maps to fp8_max
        scale = fp8_max / abs_max.clamp(min=1e-12)
        return scale

    def forward_fp8(self, x, weight, layer_name):
        """FP8 forward pass for a single linear layer.

        1. Quantize weight to FP8 E4M3
        2. Quantize activation to FP8 E4M3
        3. Run FP8 GEMM (accumulate in FP32)
        4. Output in BF16/FP32
        """
        # Compute scales
        w_scale = self.compute_scale(weight, 'e4m3')
        x_scale = self.compute_scale(x, 'e4m3')

        # Quantize to FP8 (simulated -- real HW uses native FP8)
        w_fp8 = self.simulate_fp8_quantize(weight * w_scale, 'e4m3') / w_scale
        x_fp8 = self.simulate_fp8_quantize(x * x_scale, 'e4m3') / x_scale

        # FP8 GEMM with FP32 accumulation
        # On Hopper: fp8_e4m3 x fp8_e4m3 -> fp32 (tensor cores)
        output = x_fp8 @ w_fp8.T  # Simulated as FP32

        # Store scales for backward
        self.weight_scales[layer_name] = w_scale
        self.activation_scales[layer_name] = x_scale

        return output

    def backward_fp8(self, grad_output, x, weight, layer_name):
        """FP8 backward pass.

        Gradient w.r.t. input: dX = dY @ W (use E5M2 for gradients)
        Gradient w.r.t. weight: dW = dY^T @ X (use E5M2 for gradients)
        """
        # Quantize gradients to FP8 E5M2
        g_scale = self.compute_scale(grad_output, 'e5m2')
        grad_fp8 = self.simulate_fp8_quantize(
            grad_output * g_scale, 'e5m2'
        ) / g_scale

        # dX = grad @ W (FP8 GEMM)
        w_scale = self.weight_scales[layer_name]
        w_fp8 = self.simulate_fp8_quantize(weight * w_scale, 'e4m3') / w_scale
        dX = grad_fp8 @ w_fp8  # E5M2 x E4M3 -> FP32

        # dW = grad^T @ X (FP8 GEMM)
        x_scale = self.activation_scales[layer_name]
        x_fp8 = self.simulate_fp8_quantize(x * x_scale, 'e4m3') / x_scale
        dW = grad_fp8.T @ x_fp8  # E5M2 x E4M3 -> FP32

        return dX, dW

    @staticmethod
    def simulate_fp8_quantize(tensor, format_type):
        """Simulate FP8 quantization by rounding to FP8 precision."""
        if format_type == 'e4m3':
            # 3 mantissa bits: round to nearest 1/8
            mantissa_bits = 3
        elif format_type == 'e5m2':
            # 2 mantissa bits: round to nearest 1/4
            mantissa_bits = 2

        # Simplified simulation: add noise proportional to precision
        precision = 2 ** (-mantissa_bits)
        # Stochastic rounding (better for training than RTN)
        noise = torch.rand_like(tensor) - 0.5
        quantized = torch.round(tensor / precision + noise) * precision

        return quantized

NVIDIA Transformer Engine

NVIDIA’s Transformer Engine (TE) is the production implementation of FP8 training. It manages per-tensor scales using a delayed scaling strategy:

class TransformerEngineScaling:
    """Simplified Transformer Engine delayed scaling.

    TE maintains a history of tensor max values and uses the
    max from the PREVIOUS iteration to set the scale for the
    CURRENT iteration. This avoids an extra synchronization
    point within each iteration.
    """

    def __init__(self, format_type='e4m3', history_length=1024, margin=0):
        self.format_type = format_type
        self.history = []
        self.history_length = history_length
        self.margin = margin

        if format_type == 'e4m3':
            self.fp8_max = 448.0
        elif format_type == 'e5m2':
            self.fp8_max = 57344.0

    def get_scale(self):
        """Get scale factor based on history.

        Uses max of recent history with safety margin.
        """
        if not self.history:
            return 1.0

        amax = max(self.history)
        scale = self.fp8_max / (amax * (2 ** self.margin))
        return max(scale, 1e-12)

    def update_history(self, tensor_abs_max):
        """Record the current tensor's max for future scaling."""
        self.history.append(tensor_abs_max)
        if len(self.history) > self.history_length:
            self.history.pop(0)

# Transformer Engine integration in PyTorch:
# import transformer_engine.pytorch as te
#
# # Replace nn.Linear with te.Linear
# linear = te.Linear(4096, 4096, bias=False)
#
# # The te.Linear layer automatically:
# # 1. Maintains FP8 scale history for weights and activations
# # 2. Quantizes to FP8 E4M3 in forward, E5M2 in backward
# # 3. Uses FP8 tensor cores for GEMM
# # 4. Accumulates in FP32
# # 5. Returns BF16 output
#
# # Wrap training loop in FP8 context:
# with te.fp8_autocast(enabled=True):
#     output = model(input)
#     loss = criterion(output, target)
#     loss.backward()
ℹ️ Delayed Scaling Avoids Synchronization

Transformer Engine uses the previous iteration’s tensor max to compute the current iteration’s scale. This is safe because tensor magnitudes change slowly between iterations (learning rate is small). The alternative — computing the max within the current iteration — would require an extra device-wide synchronization, adding latency to every GEMM.
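The "magnitudes change slowly" argument can be checked with a toy simulation (the 3% per-step drift and the margin value are illustrative assumptions):

```python
# Toy check of delayed scaling: build this step's FP8 scale from the
# PREVIOUS step's amax, with tensor magnitudes drifting upward each step.
E4M3_MAX = 448.0
margin = 1                        # one extra bit of headroom (TE's `margin`)

amax_prev = 1.0                   # amax observed on the previous iteration
overflowed = False

for step in range(1000):
    # Delayed scaling: scale derived from last iteration's statistics
    scale = E4M3_MAX / (amax_prev * 2 ** margin)
    amax_now = amax_prev * 1.03   # magnitudes drift +3% per step (illustrative)
    if amax_now * scale > E4M3_MAX:
        overflowed = True         # the new peak would clip in FP8
    amax_prev = amax_now

print(f"overflow with 3% drift, margin=1: {overflowed}")
# With margin=0, even a single step of upward drift would clip.
```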

FP8 Training Results

FP8 vs BF16 Training: Quality and Throughput

Model                     BF16 Loss   FP8 Loss   Degradation   Throughput Gain
GPT-3 175B (reported)     2.80        2.80       +0.00         ~1.6x
Llama-2 7B (reproduced)   1.82        1.83       +0.01         ~1.4x
Llama-2 13B (reproduced)  1.72        1.72       +0.00         ~1.5x
Llama-2 70B (reproduced)  1.56        1.56       +0.00         ~1.7x
Mistral 7B (reported)     ---         ---        matched       ~1.4x

Note: FP8 training achieves near-lossless quality (< 0.01 loss degradation) while providing 1.4-1.7x throughput improvement on H100. Larger models benefit more because they are more compute-bound.

H100 Training Throughput by Precision (tokens/sec, normalized to BF16 = 1.0)

FP32 (no tensor cores)       0.5
BF16 (baseline)              1.0
FP8 (Transformer Engine)     1.5
FP8 (theoretical TC peak)    2.0
⚠️ FP8 Achieves 1.5x, Not 2x, Throughput

FP8 tensor cores are 2x faster than BF16, but end-to-end training throughput is only 1.4-1.7x faster. The gap comes from: (1) non-GEMM operations (LayerNorm, softmax, activation functions) still run in BF16/FP32, (2) FP8 quantization/dequantization overhead, (3) scale factor management. As the model grows larger (more compute vs overhead), the throughput gain approaches 2x.
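This is Amdahl's law applied to the training step; a quick sketch (the GEMM time fractions are illustrative assumptions):

```python
def end_to_end_speedup(gemm_fraction, gemm_speedup):
    """Amdahl's law: only the GEMM portion of step time accelerates."""
    return 1.0 / ((1.0 - gemm_fraction) + gemm_fraction / gemm_speedup)

# If 85% of a BF16 step is GEMM time and FP8 doubles GEMM throughput:
print(f"{end_to_end_speedup(0.85, 2.0):.2f}x")   # 1.74x

# As models grow, the GEMM fraction rises and the gain approaches 2x:
for f in (0.7, 0.8, 0.9, 0.95):
    print(f"GEMM fraction {f:.0%}: {end_to_end_speedup(f, 2.0):.2f}x")
```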

Why INT8 Training Fails

INT8 has been tried for training but does not work well:

import numpy as np

def why_int8_training_fails():
    """Demonstrate why INT8 is unsuitable for training gradients.

    Three fundamental problems:

    1. No subnormals: INT8 jumps from 0 straight to +/- one full step.
       Gradients below half a step quantize to exactly 0, destroying
       information. FP8's subnormal range provides gradual underflow.

    2. Uniform spacing: gradients span many orders of magnitude within
       a single layer. INT8's uniform spacing wastes resolution on the
       large values while losing the small ones. FP8's logarithmic
       spacing naturally handles this.

    3. Sign loss near zero: small positive and small negative gradients
       both collapse to 0 under INT8's coarse step, losing the sign
       information that SGD needs.
    """
    rng = np.random.default_rng(0)

    # Heavy-tailed (log-normal) gradient magnitudes: values span
    # several orders of magnitude, as in real deep networks
    grad_magnitudes = rng.lognormal(mean=-11.5, sigma=2.0, size=10_000)
    amax = grad_magnitudes.max()

    # INT8 symmetric per-tensor scaling: amax maps to 127.
    # Anything below half a quantization step rounds to zero.
    int8_scale = amax / 127
    int8_zeroed_pct = np.mean(grad_magnitudes < int8_scale / 2) * 100

    # FP8 E5M2 per-tensor scaling: amax maps to 57344. Scaled values
    # remain representable down to the smallest subnormal, 2^-16
    # (simplified: treat anything below the min subnormal as lost).
    fp8_scale = 57344.0 / amax
    fp8_min_subnormal = 2 ** (-16)
    fp8_zeroed_pct = np.mean(
        grad_magnitudes * fp8_scale < fp8_min_subnormal) * 100

    print(f"Gradients zeroed by INT8:     {int8_zeroed_pct:.1f}%")
    print(f"Gradients zeroed by FP8 E5M2: {fp8_zeroed_pct:.1f}%")

# With a heavy-tailed gradient distribution, INT8 zeroes a large
# fraction of gradients while FP8 E5M2 preserves essentially all of them

FP4 Training on Blackwell

Blackwell’s FP4 tensor cores enable 4-bit training, with up to 4x the GEMM throughput of BF16. Early results use FP4 for the forward pass with FP8 for the backward:

class FP4TrainingConfig:
    """Configuration for FP4 training on Blackwell."""

    def __init__(self):
        # Forward pass: FP4 weights and activations
        self.forward_weight_format = 'fp4_e2m1'  # 2 exponent, 1 mantissa
        self.forward_activation_format = 'fp4_e2m1'
        self.forward_accumulation = 'fp32'  # FP32 accumulation

        # Backward pass: FP8 gradients (FP4 gradients lose too much)
        self.backward_gradient_format = 'fp8_e5m2'
        self.backward_weight_format = 'fp8_e4m3'

        # Master weights: FP32
        self.master_weight_format = 'fp32'

        # Optimizer states: FP32 (momentum, variance for Adam)
        self.optimizer_format = 'fp32'

        # Scaling
        self.weight_block_size = 16  # 16 elements per scale (NVFP4-style; MXFP4 uses 32)
        self.activation_block_size = 16
        self.gradient_per_tensor = True

    def memory_estimate(self, num_params_B):
        """Estimate memory usage for FP4 training."""
        # Master weights: FP32 = 4 bytes per param
        master = num_params_B * 4

        # Optimizer states (Adam): 2 FP32 states per param
        optimizer = num_params_B * 4 * 2

        # FP4 weights: 0.5 bytes per param + scales
        fp4_weights = num_params_B * 0.5 + num_params_B * 2 / 16  # FP16 scales per 16

        # Activation memory: depends on batch size and seq length
        # Roughly proportional to batch * seq * hidden * num_layers * 0.5
        # (FP4 activations for checkpointing)

        return {
            'master_weights_GB': master,
            'optimizer_states_GB': optimizer,
            'fp4_weights_GB': fp4_weights,
            'total_static_GB': master + optimizer + fp4_weights,
        }

config = FP4TrainingConfig()
mem = config.memory_estimate(70)  # 70B model
print(f"70B FP4 training memory estimate:")
print(f"  Master weights: {mem['master_weights_GB']:.0f} GB")
print(f"  Optimizer states: {mem['optimizer_states_GB']:.0f} GB")
print(f"  FP4 weights: {mem['fp4_weights_GB']:.0f} GB")
print(f"  Total static: {mem['total_static_GB']:.0f} GB")

The BF16 Default: Why It Persists

Despite the availability of FP8 and FP4, BF16 remains the default training precision for most organizations:

Training Precision Selection Guide

Format       Throughput vs FP32   Quality Risk              Complexity   When to Use
FP32         1.0x                 None                      Minimal      Debugging, small models
BF16 mixed   2.0x                 None                      Low          Default for all training
FP16 mixed   2.0x                 Low (needs loss scaling)  Medium       Legacy, pre-Ampere GPUs
FP8 (TE)     ~3.0x                Very low                  Medium       Large models on H100+
FP4 (MXFP4)  ~4.0x                Under research            High         Blackwell, experimental

Note: BF16 is the sweet spot: 2x speedup with zero quality risk and minimal complexity. FP8 adds another ~1.5x throughput but requires Transformer Engine and careful validation. FP4 is bleeding-edge.
def recommend_training_precision(
    gpu_type,
    model_size_B,
    risk_tolerance,
    team_experience,
):
    """Recommend training precision."""

    if gpu_type in ['V100', 'T4']:
        return 'FP16 mixed', 'Only FP16 tensor cores available'

    if gpu_type == 'A100':
        if risk_tolerance == 'zero':
            return 'BF16 mixed', 'Safe default, well-validated'
        return 'BF16 mixed', 'FP8 not available on A100'

    if gpu_type in ['H100', 'H200']:
        if risk_tolerance == 'zero':
            return 'BF16 mixed', 'Proven safe, slight throughput sacrifice'
        if model_size_B >= 13 and team_experience == 'advanced':
            return 'FP8 (TE)', 'Meaningful throughput gain at this scale'
        return 'BF16 mixed', 'FP8 benefit small for models < 13B'

    if gpu_type in ['B200', 'B100']:
        if team_experience == 'advanced' and model_size_B >= 70:
            return 'FP4 experimental', 'Maximum throughput for large models'
        if model_size_B >= 13:
            return 'FP8 (TE)', 'Well-validated on Blackwell'
        return 'BF16 mixed', 'Default safe choice'

    return 'BF16 mixed', 'Unknown GPU, use safe default'

Stochastic Rounding for Training

Training benefits from stochastic rounding instead of round-to-nearest. With RTN, small gradients that are below the quantization step are always rounded to zero, introducing a consistent bias. Stochastic rounding randomly rounds up or down with probability proportional to the distance to the nearest level:

$$\text{SR}(x) = \begin{cases} \lfloor x \rfloor & \text{with probability } \lceil x \rceil - x \\ \lceil x \rceil & \text{with probability } x - \lfloor x \rfloor \end{cases}$$

def stochastic_round(x, scale):
    """Stochastic rounding for quantization.

    Unlike RTN, stochastic rounding is unbiased:
    E[SR(x)] = x

    This means that on average, the gradient direction is preserved
    even when individual gradients are below the quantization step.
    """
    x_scaled = x / scale
    x_floor = torch.floor(x_scaled)
    # Probability of rounding up = fractional part
    prob_up = x_scaled - x_floor
    # Random rounding
    x_rounded = x_floor + (torch.rand_like(x_scaled) < prob_up).float()
    return x_rounded * scale

# Demonstrate: 1000 applications of SR to a small value
# should average to the true value
value = torch.tensor(0.3)
scale = 1.0  # Step size
rounds = torch.tensor([stochastic_round(value, scale).item()
                        for _ in range(10000)])
print(f"True value: {value.item()}")
print(f"RTN: {torch.round(value / scale).item() * scale}")
print(f"SR mean: {rounds.mean():.4f}")
print(f"SR std: {rounds.std():.4f}")
# SR mean ~ 0.3 (unbiased), RTN = 0.0 (biased to zero)
ℹ️ Stochastic Rounding is Critical for Sub-8-bit Training

At FP8 and especially FP4 precision, the quantization step size is large enough that many gradients fall below it. Stochastic rounding preserves the expected gradient direction, allowing convergence despite aggressive quantization. Transformer Engine uses stochastic rounding for FP8 backward pass quantization.

Quantized Optimizer States

Beyond GEMM quantization, optimizer states can also be quantized to reduce memory:

class Int8AdamW:
    """Adam optimizer with INT8 quantized momentum and variance states.

    Standard Adam: ~12 bytes per parameter (FP32 m, v, master weight)
    INT8 Adam: ~6 bytes per parameter (INT8 m, v, FP32 master weight)

    Block-wise quantization: each block of 2048 parameters shares
    one FP32 scale for m and one for v.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.999),
                 eps=1e-8, weight_decay=0.01, block_size=2048):
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.weight_decay = weight_decay
        self.block_size = block_size

        self.states = {}
        for p in params:
            n = p.numel()
            num_blocks = (n + block_size - 1) // block_size
            self.states[p] = {
                'm_int8': torch.zeros(n, dtype=torch.int8, device=p.device),
                'm_scales': torch.zeros(num_blocks, device=p.device),
                'v_int8': torch.zeros(n, dtype=torch.int8, device=p.device),
                'v_scales': torch.zeros(num_blocks, device=p.device),
                'step': 0,
            }

    def step(self, params_with_grads):
        for p, grad in params_with_grads:
            state = self.states[p]
            state['step'] += 1
            t = state['step']

            # Dequantize states
            m = self._dequantize(state['m_int8'], state['m_scales'])
            v = self._dequantize(state['v_int8'], state['v_scales'])

            # Standard Adam update
            m = self.beta1 * m + (1 - self.beta1) * grad.flatten()
            v = self.beta2 * v + (1 - self.beta2) * grad.flatten() ** 2

            m_hat = m / (1 - self.beta1 ** t)
            v_hat = v / (1 - self.beta2 ** t)

            update = m_hat / (v_hat.sqrt() + self.eps)

            if self.weight_decay > 0:
                update += self.weight_decay * p.data.flatten()

            # In-place update on the parameter itself; flatten() can
            # return a copy, so reshape the update to the param instead
            p.data.add_(update.view_as(p.data), alpha=-self.lr)

            # Re-quantize states
            state['m_int8'], state['m_scales'] = self._quantize(m)
            state['v_int8'], state['v_scales'] = self._quantize(v)

    def _quantize(self, tensor):
        n = tensor.numel()
        num_blocks = (n + self.block_size - 1) // self.block_size
        scales = torch.zeros(num_blocks, device=tensor.device)
        q = torch.zeros(n, dtype=torch.int8, device=tensor.device)

        for b in range(num_blocks):
            start = b * self.block_size
            end = min(start + self.block_size, n)
            block = tensor[start:end]
            amax = block.abs().max()
            s = amax / 127 if amax > 0 else 1.0
            scales[b] = s
            q[start:end] = (block / s).round().clamp(-128, 127).to(torch.int8)

        return q, scales

    def _dequantize(self, q, scales):
        n = q.numel()
        result = torch.zeros(n, device=q.device)
        for b in range(len(scales)):
            start = b * self.block_size
            end = min(start + self.block_size, n)
            result[start:end] = q[start:end].float() * scales[b]
        return result