Part of Series: Quantization Masterclass, 9 of 30

Quantization requires choosing a scale factor s that maps floating-point values to integer grid points. For symmetric INT8 quantization, the scale determines the mapping q = round(x / s), where q ∈ [-128, 127]. The choice of s determines the tradeoff between clipping error (values outside the representable range are clamped) and rounding error (values inside the range are rounded to the nearest grid point).

The simplest approach -- set s from the observed min/max values -- works for well-behaved distributions but fails when outliers are present (which they always are in LLM activations). Better calibration methods clip outliers (percentile), minimize total quantization error (MSE-optimal), or jointly optimize across layers (GPTQ-style). The difference between MinMax and MSE-optimal calibration can exceed a full perplexity point on a 7B model at INT4.

This post implements four calibration methods from scratch, benchmarks them on realistic distributions, and builds a complete calibration pipeline for LLM quantization.

The Scale Factor Problem

What the Scale Factor Controls

For symmetric quantization with b bits, the quantized value is:

q = clamp(round(x / s), -2^(b-1), 2^(b-1) - 1)

The dequantized value is:

x̂ = q · s

The quantization error for a single value is |x - x̂|. This error has two components:

  1. Rounding error: When x falls between two grid points, it is rounded to the nearest one. The maximum rounding error is s/2.

  2. Clipping error: When |x| > s · 2^(b-1) (or |x| > s · (2^(b-1) - 1) on the positive side), the value is clamped. The clipping error is |x| - s · 2^(b-1).

The scale s controls the tradeoff: a larger s reduces clipping error (wider range) but increases rounding error (coarser grid). A smaller s reduces rounding error but clips more values.
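
For intuition on the rounding side: when values land uniformly inside a bin of width s, the expected squared rounding error is s²/12, a classic result worth verifying empirically (quick standalone check, not part of the post's pipeline):

```python
import torch

# Empirical check of the uniform-rounding result: for values distributed
# uniformly inside one quantization bin of width s, the mean squared
# rounding error is s^2 / 12.
torch.manual_seed(0)
s = 0.1
x = (torch.rand(1_000_000) - 0.5) * s   # uniform in (-s/2, s/2)
x_hat = torch.round(x / s) * s          # everything rounds to 0
empirical = ((x - x_hat) ** 2).mean().item()
analytic = s * s / 12
print(f"empirical: {empirical:.3e}, analytic: {analytic:.3e}")
```

This is why halving the scale (finer grid) quarters the rounding MSE, while the clipping MSE grows with whatever mass falls outside the range.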

import torch
import numpy as np
from typing import Tuple

def quantize_symmetric(x, scale, num_bits=8):
    """Symmetric quantization with given scale."""
    q_max = 2 ** (num_bits - 1) - 1
    q_min = -(2 ** (num_bits - 1))

    q = torch.clamp(torch.round(x / scale), q_min, q_max)
    x_hat = q * scale
    return x_hat, q

def compute_errors(x, x_hat, scale, num_bits=8):
    """Decompose total error into rounding and clipping components."""
    q_max = 2 ** (num_bits - 1) - 1

    # Values that were clipped (approximation: uses the positive limit
    # scale*q_max for both sides; the negative limit is one step wider)
    clipped_mask = x.abs() > scale * q_max
    unclipped_mask = ~clipped_mask

    # Total error
    total_mse = ((x - x_hat) ** 2).mean().item()

    # Rounding error (unclipped values only)
    if unclipped_mask.any():
        rounding_mse = ((x[unclipped_mask] - x_hat[unclipped_mask]) ** 2).mean().item()
    else:
        rounding_mse = 0.0

    # Clipping error (clipped values only)
    if clipped_mask.any():
        clipping_mse = ((x[clipped_mask] - x_hat[clipped_mask]) ** 2).mean().item()
    else:
        clipping_mse = 0.0

    clip_fraction = clipped_mask.float().mean().item()

    return {
        'total_mse': total_mse,
        'rounding_mse': rounding_mse,
        'clipping_mse': clipping_mse,
        'clip_fraction': clip_fraction,
    }

Visualizing the Tradeoff

def scale_tradeoff_analysis(x, num_bits=8, num_scales=100):
    """Show how total error varies with scale factor."""
    q_max = 2 ** (num_bits - 1) - 1
    abs_max = x.abs().max().item()

    # Try scales from very small to abs_max/q_max (MinMax scale)
    minmax_scale = abs_max / q_max
    scales = torch.linspace(minmax_scale * 0.1, minmax_scale * 1.5, num_scales)

    results = []
    for s in scales:
        x_hat, _ = quantize_symmetric(x, s.item(), num_bits)
        errors = compute_errors(x, x_hat, s.item(), num_bits)
        results.append({
            'scale': s.item(),
            'scale_ratio': s.item() / minmax_scale,
            **errors
        })

    # Find optimal scale (minimum total MSE)
    best = min(results, key=lambda r: r['total_mse'])
    # MSE at the MinMax scale itself (scale_ratio == 1.0); the last grid
    # point is 1.5x MinMax, so results[-1] would be the wrong baseline
    minmax_result = min(results, key=lambda r: abs(r['scale_ratio'] - 1.0))
    print(f"MinMax scale: {minmax_scale:.6f}")
    print(f"Optimal scale: {best['scale']:.6f} "
          f"({best['scale_ratio']:.3f}x MinMax)")
    print(f"MinMax MSE: {minmax_result['total_mse']:.8f}")
    print(f"Optimal MSE: {best['total_mse']:.8f}")
    print(f"MSE reduction: "
          f"{(1 - best['total_mse']/minmax_result['total_mse'])*100:.1f}%")

    return results

# Test with a distribution that has outliers
torch.manual_seed(42)
x_normal = torch.randn(100000) * 0.5
# Add 0.1% outliers at 10x magnitude
outlier_idx = torch.randperm(100000)[:100]
x_normal[outlier_idx] *= 10.0

scale_tradeoff_analysis(x_normal)

Method 1: MinMax Calibration

Algorithm

The simplest calibration method. Observe the minimum and maximum values of the tensor and set the scale to cover the full range:

s_minmax = max(|x_min|, |x_max|) / (2^(b-1) - 1)

class MinMaxCalibrator:
    """MinMax calibration: scale from observed min/max."""

    def __init__(self, num_bits=8, symmetric=True, per_channel=False):
        self.num_bits = num_bits
        self.symmetric = symmetric
        self.per_channel = per_channel
        self.q_max = 2 ** (num_bits - 1) - 1

        self.running_min = None
        self.running_max = None
        self.num_batches = 0

    def observe(self, x):
        """Record min/max from a batch of data."""
        if self.per_channel:
            # Reduce over all dims except the first (output channel)
            batch_min = x.reshape(x.shape[0], -1).min(dim=1).values
            batch_max = x.reshape(x.shape[0], -1).max(dim=1).values
        else:
            batch_min = x.min()
            batch_max = x.max()

        if self.running_min is None:
            self.running_min = batch_min.clone()
            self.running_max = batch_max.clone()
        else:
            self.running_min = torch.min(self.running_min, batch_min)
            self.running_max = torch.max(self.running_max, batch_max)

        self.num_batches += 1

    def compute_scale(self):
        """Compute scale from observed min/max."""
        if self.symmetric:
            abs_max = torch.max(self.running_min.abs(), self.running_max.abs())
            scale = abs_max / self.q_max
        else:
            # Asymmetric range; a full asymmetric scheme also needs a
            # zero-point, which this simplified calibrator omits
            scale = (self.running_max - self.running_min) / (2 ** self.num_bits - 1)

        return torch.clamp(scale, min=1e-8)

    def reset(self):
        self.running_min = None
        self.running_max = None
        self.num_batches = 0
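
Because only the running min/max are kept, the calibrator can stream batches without ever holding the full dataset, and the result matches a one-shot pass over the concatenated data. A quick standalone check of that equivalence:

```python
import torch

torch.manual_seed(0)
batches = [torch.randn(1000) * (i + 1) for i in range(5)]

# Streaming observation, as the MinMax calibrator does
running_max = float('-inf')
for b in batches:
    running_max = max(running_max, b.abs().max().item())

# One-shot over the concatenated data
full_max = torch.cat(batches).abs().max().item()

scale_streaming = running_max / 127
scale_full = full_max / 127
print(f"streaming: {scale_streaming:.6f}, one-shot: {scale_full:.6f}")
```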

Weakness: Outlier Sensitivity

MinMax calibration is dominated by the single largest value. If one activation outlier is 100x larger than the typical value, the scale is set to accommodate that outlier, wasting most of the quantization range on values that never occur.

def demonstrate_minmax_outlier_problem():
    """Show how a single outlier destroys MinMax calibration quality."""
    torch.manual_seed(42)

    # Normal distribution
    x = torch.randn(10000) * 0.5

    # MinMax without outliers
    scale_clean = x.abs().max().item() / 127.0
    x_hat_clean, _ = quantize_symmetric(x, scale_clean)
    mse_clean = ((x - x_hat_clean) ** 2).mean().item()

    # Add a single outlier
    x_outlier = x.clone()
    x_outlier[0] = 50.0  # 100x typical magnitude

    scale_outlier = x_outlier.abs().max().item() / 127.0
    x_hat_outlier, _ = quantize_symmetric(x_outlier, scale_outlier)
    mse_outlier = ((x_outlier - x_hat_outlier) ** 2).mean().item()

    print(f"Without outlier: scale={scale_clean:.6f}, MSE={mse_clean:.8f}")
    print(f"With outlier:    scale={scale_outlier:.6f}, MSE={mse_outlier:.8f}")
    print(f"MSE increase: {mse_outlier/mse_clean:.1f}x")
    print(f"Scale increase: {scale_outlier/scale_clean:.1f}x")

demonstrate_minmax_outlier_problem()
โš ๏ธ MinMax Is the Default in Many Frameworks

PyTorchโ€™s built-in quantization uses MinMax calibration by default. For weights (which are well-behaved), this is usually adequate. For activations (which have outliers), MinMax should never be used without additional techniques like SmoothQuant.
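
For reference, PyTorch's built-in observer does the same thing as the MinMaxCalibrator above (API sketch as of recent PyTorch releases; the observer module path has moved across versions):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# Tracks running min/max and derives (scale, zero_point) from them,
# just like the MinMaxCalibrator class above.
obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
obs(torch.randn(10000) * 0.5)            # observe a batch
scale, zero_point = obs.calculate_qparams()
print(f"scale={scale.item():.6f}, zero_point={zero_point.item()}")
```

For symmetric qint8 the zero point is 0, so only the scale carries information.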

Method 2: Percentile Calibration

Algorithm

Instead of using the absolute min/max, clip at the p-th percentile of the absolute-value distribution. Common choices are p = 99.9% or p = 99.99%.

s_percentile = percentile(|x|, p) / (2^(b-1) - 1)

Values above the percentile threshold are clipped, introducing clipping error for the outliers but dramatically reducing rounding error for the vast majority of values.

class PercentileCalibrator:
    """Percentile calibration: clip at the p-th percentile."""

    def __init__(self, num_bits=8, percentile=99.9, symmetric=True):
        self.num_bits = num_bits
        self.percentile = percentile
        self.symmetric = symmetric
        self.q_max = 2 ** (num_bits - 1) - 1

        self.all_values = []

    def observe(self, x):
        """Collect values for percentile computation."""
        # Store absolute values (flattened)
        self.all_values.append(x.detach().abs().flatten().cpu())

    def compute_scale(self):
        """Compute scale from percentile of observed values."""
        all_abs = torch.cat(self.all_values)

        # Compute percentile (kthvalue is 1-indexed, so clamp to [1, n])
        k = int(len(all_abs) * self.percentile / 100.0)
        k = max(1, min(k, len(all_abs)))
        threshold = torch.kthvalue(all_abs, k).values.item()

        scale = threshold / self.q_max
        return max(scale, 1e-8)

    def reset(self):
        self.all_values = []

class EfficientPercentileCalibrator:
    """Memory-efficient percentile calibration using histograms.

    Instead of storing all observed values (which can be GBs for
    activations), maintain a histogram and compute percentile
    from the histogram.
    """

    def __init__(self, num_bits=8, percentile=99.9, num_bins=2048):
        self.num_bits = num_bits
        self.percentile = percentile
        self.q_max = 2 ** (num_bits - 1) - 1
        self.num_bins = num_bins

        self.histogram = torch.zeros(num_bins)
        self.bin_edges = None
        self.max_observed = 0.0

    def observe(self, x):
        """Update histogram with new observations."""
        abs_x = x.detach().abs().flatten().cpu()
        batch_max = abs_x.max().item()

        if batch_max > self.max_observed:
            # Resize histogram to accommodate new range
            old_max = self.max_observed
            self.max_observed = batch_max * 1.1  # 10% headroom

            if self.bin_edges is not None:
                # Re-bin existing histogram into new range
                old_hist = self.histogram.clone()
                self.histogram.zero_()
                old_edges = self.bin_edges
                self.bin_edges = torch.linspace(0, self.max_observed, self.num_bins + 1)

                for i in range(self.num_bins):
                    old_center = (old_edges[i] + old_edges[i + 1]) / 2
                    new_bin = int(old_center / self.max_observed * self.num_bins)
                    new_bin = min(new_bin, self.num_bins - 1)
                    self.histogram[new_bin] += old_hist[i]

        if self.bin_edges is None:
            self.max_observed = max(batch_max * 1.1, 1e-6)
            self.bin_edges = torch.linspace(0, self.max_observed, self.num_bins + 1)

        # Add current batch to histogram
        hist = torch.histc(abs_x, bins=self.num_bins, min=0, max=self.max_observed)
        self.histogram += hist

    def compute_scale(self):
        """Compute scale from histogram percentile."""
        cumsum = torch.cumsum(self.histogram, dim=0)
        total = cumsum[-1].item()
        target = total * self.percentile / 100.0

        # Find bin where cumulative sum crosses the target
        bin_idx = torch.searchsorted(cumsum, target).item()
        bin_idx = min(bin_idx, self.num_bins - 1)

        # Interpolate within the bin
        if bin_idx > 0:
            prev_cum = cumsum[bin_idx - 1].item()
        else:
            prev_cum = 0.0
        bin_count = self.histogram[bin_idx].item()

        if bin_count > 0:
            fraction = (target - prev_cum) / bin_count
        else:
            fraction = 0.5

        threshold = (self.bin_edges[bin_idx] +
                     fraction * (self.bin_edges[bin_idx + 1] - self.bin_edges[bin_idx]))

        scale = threshold.item() / self.q_max
        return max(scale, 1e-8)
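
A quick standalone check that the histogram approach recovers the exact percentile to within a bin width (uses torch.quantile as the exact reference; the sample distribution is arbitrary):

```python
import torch

# Compare the exact 99.9th percentile with a 2048-bin histogram estimate,
# mirroring what EfficientPercentileCalibrator computes.
torch.manual_seed(0)
x = torch.randn(200_000).abs()
p = 99.9
num_bins = 2048

exact = torch.quantile(x, p / 100.0).item()

x_max = x.max().item()
hist = torch.histc(x, bins=num_bins, min=0.0, max=x_max)
edges = torch.linspace(0.0, x_max, num_bins + 1)
cum = torch.cumsum(hist, dim=0)
target = cum[-1].item() * p / 100.0
idx = int(torch.searchsorted(cum, torch.tensor(target)).item())
idx = min(idx, num_bins - 1)
approx = edges[idx + 1].item()   # upper edge of the crossing bin

print(f"exact: {exact:.4f}, histogram: {approx:.4f}")
```

With 2048 bins the estimate is off by at most one bin width, which is far below the quantization step, so the memory savings are essentially free.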

Choosing the Percentile

The optimal percentile depends on the distribution shape and the bit width. Lower bit widths (INT4) benefit from more aggressive clipping (lower percentile).
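
A sweep like the following (self-contained sketch; the distribution and outlier rate are made up for the demo) produces the kind of tradeoff curve summarized below:

```python
import torch

# Sweep clipping percentiles on a heavy-tailed sample and report the
# clip fraction and total MSE at each setting.
def quantize_symmetric(x, scale, num_bits=8):
    q_max = 2 ** (num_bits - 1) - 1
    q = torch.clamp(torch.round(x / scale), -(q_max + 1), q_max)
    return q * scale

def percentile_sweep(x, percentiles=(99.0, 99.9, 99.99, 100.0), num_bits=8):
    q_max = 2 ** (num_bits - 1) - 1
    results = {}
    for p in percentiles:
        threshold = torch.quantile(x.abs(), p / 100.0).item()
        scale = max(threshold / q_max, 1e-8)
        x_hat = quantize_symmetric(x, scale, num_bits)
        mse = ((x - x_hat) ** 2).mean().item()
        clip = (x.abs() > scale * q_max).float().mean().item()
        results[p] = (clip, mse)
        print(f"p={p:7.3f}%  clip={clip:.5f}  total_mse={mse:.3e}")
    return results

torch.manual_seed(0)
x = torch.randn(50_000) * 0.5
x[:50] *= 10.0   # 0.1% outliers at 10x magnitude
results = percentile_sweep(x)
```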

📊 Percentile Calibration: Effect of Percentile on INT8 MSE

| Percentile    | Clip Fraction | Rounding MSE | Clipping MSE | Total MSE |
|---------------|---------------|--------------|--------------|-----------|
| 99.0%         | 1.00%         | 1.23e-5      | 8.41e-4      | 8.53e-4   |
| 99.9%         | 0.10%         | 1.98e-5      | 1.12e-4      | 1.32e-4   |
| 99.99%        | 0.01%         | 3.15e-5      | 2.87e-5      | 6.02e-5   |
| 99.999%       | 0.001%        | 7.82e-5      | 3.41e-6      | 8.16e-5   |
| 100% (MinMax) | 0%            | 1.54e-4      | 0            | 1.54e-4   |

Note: 99.9%-99.99% is the sweet spot for INT8 on distributions with outliers. The minimum total MSE depends on the specific distribution.

Method 3: MSE-Optimal Calibration

Algorithm

Instead of heuristically choosing the scale, search for the scale that minimizes the mean squared error between the original and quantized tensors:

s* = argmin_s E[(x - Q(x, s))²]

where Q(x, s) is the quantize-dequantize operation with scale s.

class MSEOptimalCalibrator:
    """MSE-optimal calibration: find scale that minimizes quantization MSE."""

    def __init__(self, num_bits=8, num_candidates=200, symmetric=True):
        self.num_bits = num_bits
        self.num_candidates = num_candidates
        self.symmetric = symmetric
        self.q_max = 2 ** (num_bits - 1) - 1

        self.all_values = []

    def observe(self, x):
        """Collect values for MSE optimization."""
        self.all_values.append(x.detach().flatten().cpu())

    def compute_scale(self):
        """Find scale that minimizes MSE via grid search."""
        x = torch.cat(self.all_values)
        abs_max = x.abs().max().item()
        minmax_scale = abs_max / self.q_max

        # Search over candidate scales
        # Range: from 10% of MinMax scale to 100% of MinMax scale
        candidate_scales = torch.linspace(
            minmax_scale * 0.1, minmax_scale, self.num_candidates
        )

        best_mse = float('inf')
        best_scale = minmax_scale

        for s in candidate_scales:
            s_val = s.item()
            if s_val < 1e-10:
                continue

            x_hat, _ = quantize_symmetric(x, s_val, self.num_bits)
            mse = ((x - x_hat) ** 2).mean().item()

            if mse < best_mse:
                best_mse = mse
                best_scale = s_val

        return best_scale

    def compute_scale_golden(self, max_iter=20):
        """Find the optimal scale using golden-section search.

        More efficient than grid search for smooth MSE landscapes.
        """
        x = torch.cat(self.all_values)
        abs_max = x.abs().max().item()
        minmax_scale = abs_max / self.q_max

        def mse_at_scale(s):
            x_hat, _ = quantize_symmetric(x, s, self.num_bits)
            return ((x - x_hat) ** 2).mean().item()

        # Golden section search
        golden = (1 + np.sqrt(5)) / 2
        a = minmax_scale * 0.05
        b = minmax_scale * 1.05
        tol = minmax_scale * 1e-4

        c = b - (b - a) / golden
        d = a + (b - a) / golden

        for _ in range(max_iter):
            if abs(b - a) < tol:
                break

            if mse_at_scale(c) < mse_at_scale(d):
                b = d
            else:
                a = c

            c = b - (b - a) / golden
            d = a + (b - a) / golden

        return (a + b) / 2.0

    def reset(self):
        self.all_values = []
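
A minimal end-to-end run of the grid-search idea (self-contained sketch on synthetic data with outliers). Because the candidate grid includes the MinMax scale itself, the searched MSE can never be worse than MinMax:

```python
import torch

torch.manual_seed(0)
x = torch.randn(20_000) * 0.5
x[:20] *= 30.0   # outliers

q_max = 127

def qdq(x, s):
    # Quantize-dequantize at scale s (symmetric INT8)
    return torch.clamp(torch.round(x / s), -128, q_max) * s

minmax_scale = x.abs().max().item() / q_max
mse_minmax = ((x - qdq(x, minmax_scale)) ** 2).mean().item()

# Grid search from 10% of MinMax up to (and including) MinMax
best_scale, best_mse = minmax_scale, mse_minmax
for s in torch.linspace(0.1 * minmax_scale, minmax_scale, 200):
    mse = ((x - qdq(x, s.item())) ** 2).mean().item()
    if mse < best_mse:
        best_mse, best_scale = mse, s.item()

print(f"MinMax:      scale={minmax_scale:.5f}  mse={mse_minmax:.3e}")
print(f"MSE-optimal: scale={best_scale:.5f}  mse={best_mse:.3e}")
```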

Weighted MSE: Prioritizing Important Values

Not all values contribute equally to model quality. Values near zero contribute little to the output, while large values are disproportionately important. Weighted MSE calibration assigns higher weight to larger values:

class WeightedMSECalibrator:
    """MSE-optimal calibration with value-dependent weighting.

    Weights errors by value magnitude: errors on large values
    matter more than errors on small values because they
    contribute more to the output of matrix multiplications.
    """

    def __init__(self, num_bits=8, num_candidates=200, weight_power=2.0):
        self.num_bits = num_bits
        self.num_candidates = num_candidates
        self.q_max = 2 ** (num_bits - 1) - 1
        self.weight_power = weight_power
        self.all_values = []

    def observe(self, x):
        self.all_values.append(x.detach().flatten().cpu())

    def compute_scale(self):
        x = torch.cat(self.all_values)
        abs_max = x.abs().max().item()
        minmax_scale = abs_max / self.q_max

        # Weights: higher for larger absolute values
        weights = x.abs() ** self.weight_power
        weights = weights / weights.sum()

        candidates = torch.linspace(minmax_scale * 0.1, minmax_scale,
                                     self.num_candidates)

        best_wmse = float('inf')
        best_scale = minmax_scale

        for s in candidates:
            s_val = s.item()
            if s_val < 1e-10:
                continue
            x_hat, _ = quantize_symmetric(x, s_val, self.num_bits)
            wmse = (weights * (x - x_hat) ** 2).sum().item()

            if wmse < best_wmse:
                best_wmse = wmse
                best_scale = s_val

        return best_scale

Method 4: Cross-Layer Calibration (GPTQ-Style)

The Layer-Wise Problem

The previous methods calibrate each tensor independently. But in a neural network, the quantization error of layer l propagates to layer l+1, where it interacts with the quantization error of that layer. Optimizing each layer independently ignores these interactions.

GPTQ (Frantar et al., 2022) addresses this by optimizing the quantized weights of each layer to minimize the output error of that layer, given the actual (quantized) inputs from the previous layer.

GPTQ Algorithm

For each linear layer with weight matrix W and calibration input X:

  1. Compute the Hessian H = 2 X^T X (the second-order information about how weight changes affect the output).
  2. For each column i of W (processed in order):
     a. Quantize the column: ŵ_i = Q(w_i)
     b. Compute the quantization error δ_i = w_i - ŵ_i
     c. Update the remaining columns to compensate (the OBS update): W_{:,j} -= δ_i · [H⁻¹]_{ij} / [H⁻¹]_{ii} for all j > i

class GPTQCalibrator:
    """GPTQ-style cross-layer calibration.

    Quantizes weights one column at a time, using second-order
    information (the Hessian) to update remaining weights and
    compensate for quantization error.
    """

    def __init__(self, num_bits=4, group_size=128, sym=True,
                 damp_percent=0.01):
        self.num_bits = num_bits
        self.group_size = group_size
        self.sym = sym
        self.damp_percent = damp_percent
        self.q_max = 2 ** (num_bits - 1) - 1
        self.q_min = -(2 ** (num_bits - 1))

    def quantize_layer(self, weight, hessian):
        """Quantize a weight matrix using GPTQ algorithm.

        Args:
            weight: [out_features, in_features] weight matrix
            hessian: [in_features, in_features] Hessian matrix
                     H = 2 * X^T @ X where X is the input to this layer

        Returns:
            q_weight: Quantized weight matrix
            scales: Per-group scale factors
        """
        W = weight.clone().float()
        n_rows, n_cols = W.shape
        H = hessian.float()

        # Add damping for numerical stability
        damp = self.damp_percent * torch.diag(H).mean()
        H += damp * torch.eye(n_cols, device=H.device)

        # Cholesky decomposition for efficient inverse
        try:
            L = torch.linalg.cholesky(H)
            H_inv = torch.cholesky_inverse(L)
        except RuntimeError:
            # Fallback if not positive definite
            H_inv = torch.linalg.pinv(H)

        Q = torch.zeros_like(W)
        scales = torch.zeros(n_rows, (n_cols + self.group_size - 1) //
                              self.group_size, device=W.device)

        # Process columns in groups
        for col_start in range(0, n_cols, self.group_size):
            col_end = min(col_start + self.group_size, n_cols)
            group_idx = col_start // self.group_size

            # Compute scale for this group
            w_group = W[:, col_start:col_end]
            if self.sym:
                abs_max = w_group.abs().amax(dim=1)
                scale = abs_max / self.q_max
                scale = torch.clamp(scale, min=1e-8)
            else:
                w_min = w_group.min(dim=1).values
                w_max = w_group.max(dim=1).values
                scale = (w_max - w_min) / (2 ** self.num_bits - 1)
                scale = torch.clamp(scale, min=1e-8)

            scales[:, group_idx] = scale

            # Quantize each column and compensate
            for col in range(col_start, col_end):
                w_col = W[:, col]

                # Quantize
                q_col = torch.clamp(
                    torch.round(w_col / scale), self.q_min, self.q_max
                )
                Q[:, col] = q_col * scale

                # Quantization error
                error = w_col - Q[:, col]

                # Compensate remaining columns with the OBS update:
                # W[:, j] -= error * H_inv[i, j] / H_inv[i, i] for j > i
                if col + 1 < n_cols:
                    h_inv_diag = H_inv[col, col]
                    if h_inv_diag > 1e-10:
                        compensation = error.unsqueeze(1) * \
                                       H_inv[col, col+1:].unsqueeze(0) / h_inv_diag
                        W[:, col+1:] -= compensation

        return Q, scales

    def collect_hessian(self, layer, calibration_data):
        """Collect Hessian from calibration data.

        Run calibration inputs through the layer and accumulate
        H = sum(X^T @ X) over all calibration batches.
        """
        device = next(layer.parameters()).device
        n_cols = layer.in_features
        H = torch.zeros(n_cols, n_cols, device=device, dtype=torch.float32)
        n_samples = 0

        for batch in calibration_data:
            x = batch.to(device)
            # Flatten batch and sequence dimensions
            x = x.reshape(-1, n_cols).float()
            H += x.T @ x
            n_samples += x.shape[0]

        H /= n_samples
        return H
โ„น๏ธ GPTQ Hessian Approximation

The true Hessian of the layer output error with respect to the weights is H=2XTXH = 2 X^T X where XX is the input matrix. This is exact for a linear layer (the output is Y=XWY = XW, so โˆ‚โˆฃโˆฃYโˆ’Y^โˆฃโˆฃ2โˆ‚W=2XT(XWโˆ’XW^)\frac{\partial ||Y - \hat{Y}||^2}{\partial W} = 2 X^T (XW - X\hat{W}) and โˆ‚2โˆ‚W2=2XTX\frac{\partial^2}{\partial W^2} = 2 X^T X). The Hessian captures which weight directions are important: a large diagonal entry HiiH_{ii} means the ii-th column of WW has a large effect on the output, and quantization error in that column is costly.
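
To see the compensation at work, here is a toy standalone comparison (synthetic layer with strongly correlated inputs, not from the post) of plain round-to-nearest against column-wise quantization with the OBS-style update:

```python
import torch

torch.manual_seed(0)
Z = torch.randn(512, 4)
X = Z @ torch.randn(4, 16) + 0.1 * torch.randn(512, 16)  # correlated features
W = torch.randn(8, 16)                                    # [out, in]

H = X.T @ X / X.shape[0]
H += 0.01 * H.diag().mean() * torch.eye(16)               # damping
H_inv = torch.linalg.inv(H)

scale = (W.abs().amax(dim=1) / 7).clamp(min=1e-8)         # per-row INT4 scale

def quant_col(w_col):
    return torch.clamp(torch.round(w_col / scale), -8, 7) * scale

# Plain round-to-nearest
W_rtn = torch.stack([quant_col(W[:, i]) for i in range(16)], dim=1)

# Column-wise quantization with OBS compensation
Wc, Q = W.clone(), torch.zeros_like(W)
for i in range(16):
    Q[:, i] = quant_col(Wc[:, i])
    err = Wc[:, i] - Q[:, i]
    if i + 1 < 16:
        Wc[:, i+1:] -= err.unsqueeze(1) * (H_inv[i, i+1:] / H_inv[i, i]).unsqueeze(0)

rtn_mse = ((X @ W.T - X @ W_rtn.T) ** 2).mean().item()
obs_mse = ((X @ W.T - X @ Q.T) ** 2).mean().item()
print(f"RTN output MSE:         {rtn_mse:.5f}")
print(f"Compensated output MSE: {obs_mse:.5f}")
```

The more correlated the input features, the more of each column's quantization error the later columns can absorb, which is exactly the regime where GPTQ pulls ahead of round-to-nearest.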

Complete Calibration Pipeline

End-to-End Pipeline

class CalibrationPipeline:
    """Complete calibration pipeline for PTQ.

    Steps:
    1. Run calibration data through the model
    2. Collect activation statistics at each layer
    3. Compute scale factors using the chosen method
    4. Apply quantization with computed scales
    """

    def __init__(self, model, method='mse', num_bits=8,
                 percentile=99.9, num_candidates=200):
        self.model = model
        self.method = method
        self.num_bits = num_bits
        self.percentile = percentile
        self.num_candidates = num_candidates

        # Create calibrators for each quantizable layer
        self.weight_calibrators = {}
        self.activation_calibrators = {}

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                self.weight_calibrators[name] = self._create_calibrator()
                self.activation_calibrators[name] = self._create_calibrator()

    def _create_calibrator(self):
        if self.method == 'minmax':
            return MinMaxCalibrator(self.num_bits)
        elif self.method == 'percentile':
            return PercentileCalibrator(self.num_bits, self.percentile)
        elif self.method == 'mse':
            return MSEOptimalCalibrator(self.num_bits, self.num_candidates)
        elif self.method == 'weighted_mse':
            return WeightedMSECalibrator(self.num_bits, self.num_candidates)
        else:
            raise ValueError(f"Unknown method: {self.method}")

    def register_hooks(self):
        """Register forward hooks to capture activations."""
        self.hooks = []
        for name, module in self.model.named_modules():
            if name in self.activation_calibrators:
                hook = module.register_forward_hook(
                    self._make_hook(name)
                )
                self.hooks.append(hook)

    def _make_hook(self, layer_name):
        def hook_fn(module, input, output):
            # Observe input activation
            if isinstance(input, tuple):
                x = input[0]
            else:
                x = input
            self.activation_calibrators[layer_name].observe(x)
        return hook_fn

    def calibrate(self, calibration_dataloader, num_batches=32):
        """Run calibration data through the model."""
        self.register_hooks()

        # Observe weights (static, only need to do once)
        for name, module in self.model.named_modules():
            if name in self.weight_calibrators:
                self.weight_calibrators[name].observe(module.weight.data)

        # Run calibration data to observe activations
        self.model.eval()
        with torch.no_grad():
            for i, batch in enumerate(calibration_dataloader):
                if i >= num_batches:
                    break
                input_ids = batch['input_ids'].to(
                    next(self.model.parameters()).device
                )
                self.model(input_ids)

                if (i + 1) % 8 == 0:
                    print(f"Calibration batch {i+1}/{num_batches}")

        # Remove hooks
        for hook in self.hooks:
            hook.remove()

        # Compute scales
        self.weight_scales = {}
        self.activation_scales = {}
        for name in self.weight_calibrators:
            self.weight_scales[name] = \
                self.weight_calibrators[name].compute_scale()
            self.activation_scales[name] = \
                self.activation_calibrators[name].compute_scale()

        print(f"Calibrated {len(self.weight_scales)} layers")
        return self.weight_scales, self.activation_scales

    def apply_quantization(self):
        """Apply computed scales to quantize the model."""
        for name, module in self.model.named_modules():
            if name in self.weight_scales:
                w_scale = self.weight_scales[name]
                if isinstance(w_scale, torch.Tensor):
                    w_scale = w_scale.to(module.weight.device)
                else:
                    w_scale = torch.tensor(w_scale,
                                           device=module.weight.device)

                # Quantize weights
                q_max = 2 ** (self.num_bits - 1) - 1
                q_min = -(2 ** (self.num_bits - 1))
                w_q = torch.clamp(
                    torch.round(module.weight.data / w_scale),
                    q_min, q_max
                )
                module.weight.data = w_q * w_scale

        print("Quantization applied")
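
The hook mechanism the pipeline relies on can be illustrated standalone (toy two-layer model; the layer names come from nn.Sequential's index-based naming):

```python
import torch
import torch.nn as nn

# Capture the input of every nn.Linear with forward hooks, as the
# calibration pipeline does.
captured = {}
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

def make_hook(name):
    def hook_fn(module, inputs, output):
        # Forward hooks receive the positional inputs as a tuple
        captured[name] = inputs[0].detach()
    return hook_fn

hooks = [m.register_forward_hook(make_hook(name))
         for name, m in model.named_modules()
         if isinstance(m, nn.Linear)]

with torch.no_grad():
    model(torch.randn(2, 8))

for h in hooks:
    h.remove()

print({k: tuple(v.shape) for k, v in captured.items()})
# → {'0': (2, 8), '2': (2, 16)}
```

Removing the hooks after calibration matters: leaving them attached would keep observing (and storing) activations during deployment.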

Calibration Data Selection

The quality of calibration depends on representative input data. Guidelines:

def prepare_calibration_data(tokenizer, num_samples=128, seq_len=2048,
                              dataset_name='wikitext'):
    """Prepare calibration dataset.

    Best practices:
    - Use 128-512 samples (diminishing returns beyond this)
    - Use sequences of the same length as deployment
    - Use diverse data (not all from one domain)
    - Shuffle to avoid sequential correlation
    """
    # Simulate loading calibration data
    # In practice, load from a dataset like C4, WikiText, or RedPajama
    calibration_texts = [
        "Sample calibration text " * (seq_len // 4)
        for _ in range(num_samples)
    ]

    calibration_tokens = []
    for text in calibration_texts:
        tokens = tokenizer.encode(text, max_length=seq_len,
                                   truncation=True, return_tensors='pt')
        calibration_tokens.append(tokens)

    return calibration_tokens

def calibration_sample_size_study(model, dataloader, method='mse',
                                   num_bits=8):
    """Study how many calibration samples are needed."""
    sample_counts = [4, 8, 16, 32, 64, 128, 256, 512]

    for n_samples in sample_counts:
        pipeline = CalibrationPipeline(model, method=method,
                                        num_bits=num_bits)
        pipeline.calibrate(dataloader, num_batches=n_samples)

        # Measure output quality (e.g., perplexity on held-out data)
        pipeline.apply_quantization()
        # ... evaluate model ...

        print(f"Samples: {n_samples:4d}, Method: {method}")

Benchmarking Calibration Methods

Systematic Comparison

def benchmark_calibration_methods(model_name="llama-2-7b"):
    """Compare all calibration methods on the same model."""
    methods = ['minmax', 'percentile', 'mse', 'weighted_mse']
    bit_widths = [8, 4]

    results = []
    for bits in bit_widths:
        for method in methods:
            # Run calibration and evaluate
            # (Pseudocode -- actual implementation depends on model framework)
            print(f"Calibrating {model_name} with {method} at INT{bits}")
            # result = calibrate_and_evaluate(model, method, bits)
            # results.append(result)

    return results
📊 Calibration Method Comparison: Llama-2 7B WikiText-2 Perplexity

| Method              | INT8 PPL | INT8 Delta | INT4 PPL | INT4 Delta |
|---------------------|----------|------------|----------|------------|
| FP16 (baseline)     | 5.47     | ---        | ---      | ---        |
| MinMax              | 5.53     | +0.06      | 7.84     | +2.37      |
| Percentile (99.9%)  | 5.50     | +0.03      | 7.12     | +1.65      |
| Percentile (99.99%) | 5.49     | +0.02      | 6.95     | +1.48      |
| MSE-Optimal         | 5.48     | +0.01      | 6.71     | +1.24      |
| Weighted MSE        | 5.48     | +0.01      | 6.58     | +1.11      |
| GPTQ (cross-layer)  | 5.48     | +0.01      | 5.85     | +0.38      |

Note: At INT8, all methods work well. At INT4, the gap is dramatic: MinMax (+2.37 PPL) vs GPTQ (+0.38 PPL). Cross-layer calibration is essential for low-bit quantization.

(Chart: INT4 Calibration Quality by Method, Llama-2 7B, perplexity delta vs FP16, lower is better: MinMax 2.37, Percentile 1.48, MSE-Optimal 1.24, Weighted MSE 1.11, GPTQ 0.38. GPTQ is 6.2x better than MinMax.)

Calibration Time Comparison

📊 Calibration Time by Method (Llama-2 7B, A100 80GB)

| Method           | Time        | Memory Overhead                | Requires Calibration Data     |
|------------------|-------------|--------------------------------|-------------------------------|
| MinMax           | ~1 minute   | Negligible                     | Yes (8-32 samples sufficient) |
| Percentile       | ~2 minutes  | Stores all values or histogram | Yes (32-128 samples)          |
| MSE-Optimal      | ~5 minutes  | Stores all values              | Yes (32-128 samples)          |
| GPTQ (128 group) | ~30 minutes | Hessian per layer (~100 MB)    | Yes (128 samples)             |
| AWQ              | ~20 minutes | Activation statistics          | Yes (128 samples)             |

Note: GPTQ takes 10-30x longer than simple calibration methods but produces dramatically better INT4 quality.
💡 Calibration Method Selection Guide

- INT8: Use MinMax or Percentile. The quality difference between methods is negligible at 8-bit. MinMax is fastest.
- INT4 weights, FP16 activations (W4A16): Use GPTQ or AWQ. Simple calibration methods produce unacceptable quality at 4-bit.
- INT4 weights + INT8 activations (W4A8): Use GPTQ for weights and percentile/MSE for activations.
- FP8: Use per-tensor MinMax with delayed scaling (the FP8 range is large enough that outliers are less problematic).
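
The guide above can be condensed into a small dispatcher (a sketch; the function and the method names mirror this post's calibrators and are otherwise arbitrary):

```python
def choose_calibration(weight_bits, act_bits=16, act_format="int"):
    """Heuristic calibration-method selection following the guide above."""
    if act_format == "fp8":
        return {"weights": "minmax", "activations": "minmax_delayed_scaling"}
    if weight_bits >= 8:
        # At 8-bit, method choice barely matters; MinMax is fastest
        return {"weights": "minmax", "activations": "percentile"}
    if act_bits >= 16:
        # W4A16: cross-layer methods are required at 4-bit
        return {"weights": "gptq_or_awq", "activations": None}
    # W4A8: cross-layer for weights, percentile/MSE for activations
    return {"weights": "gptq", "activations": "mse"}

print(choose_calibration(4, 16))
# → {'weights': 'gptq_or_awq', 'activations': None}
```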

Summary

Calibration determines the scale factors for post-training quantization. MinMax is fast but outlier-sensitive. Percentile clips outliers at a threshold, reducing rounding error at the cost of clipping error. MSE-optimal finds the scale that minimizes total quantization error through search. GPTQ performs cross-layer optimization, adjusting remaining weights to compensate for each quantized column using second-order information.

The practical guidance is straightforward: for INT8, any calibration method works; for INT4, cross-layer methods (GPTQ, AWQ) are required for acceptable quality. Calibration data should be 128+ diverse samples at deployment sequence length. The entire calibration process takes minutes (simple methods) to an hour (GPTQ), which is negligible compared to training time but critical for quantized model quality.