Quantization requires choosing a scale factor that maps floating-point values to integer grid points. For symmetric INT8 quantization, the scale determines the mapping: , where . The choice of determines the tradeoff between clipping error (values outside the representable range are clamped) and rounding error (values inside the range are rounded to the nearest grid point).
The simplest approach โ set from the observed min/max values โ works for well-behaved distributions but fails when outliers are present (which they always are in LLM activations). Better calibration methods clip outliers (percentile), minimize total quantization error (MSE-optimal), or jointly optimize across layers (GPTQ-style). The difference between MinMax and MSE-optimal calibration can be 0.5+ perplexity points on a 7B model at INT8.
This post implements four calibration methods from scratch, benchmarks them on realistic distributions, and builds a complete calibration pipeline for LLM quantization.
The Scale Factor Problem
What the Scale Factor Controls
For symmetric quantization with bits, the quantized value is:
The dequantized value is:
The quantization error for a single value is . This error has two components:
-
Rounding error: When falls between two grid points, it is rounded to the nearest one. Maximum rounding error is .
-
Clipping error: When (or for the positive side), the value is clamped. Clipping error is .
The scale controls the tradeoff: a larger reduces clipping error (wider range) but increases rounding error (coarser grid). A smaller reduces rounding error but clips more values.
import torch
import numpy as np
from typing import Tuple
def quantize_symmetric(x, scale, num_bits=8):
"""Symmetric quantization with given scale."""
q_max = 2 ** (num_bits - 1) - 1
q_min = -(2 ** (num_bits - 1))
q = torch.clamp(torch.round(x / scale), q_min, q_max)
x_hat = q * scale
return x_hat, q
def compute_errors(x, x_hat, scale, num_bits=8):
"""Decompose total error into rounding and clipping components."""
q_max = 2 ** (num_bits - 1) - 1
# Values that were clipped
clipped_mask = x.abs() > scale * q_max
unclipped_mask = ~clipped_mask
# Total error
total_mse = ((x - x_hat) ** 2).mean().item()
# Rounding error (unclipped values only)
if unclipped_mask.any():
rounding_mse = ((x[unclipped_mask] - x_hat[unclipped_mask]) ** 2).mean().item()
else:
rounding_mse = 0.0
# Clipping error (clipped values only)
if clipped_mask.any():
clipping_mse = ((x[clipped_mask] - x_hat[clipped_mask]) ** 2).mean().item()
else:
clipping_mse = 0.0
clip_fraction = clipped_mask.float().mean().item()
return {
'total_mse': total_mse,
'rounding_mse': rounding_mse,
'clipping_mse': clipping_mse,
'clip_fraction': clip_fraction,
}
Visualizing the Tradeoff
def scale_tradeoff_analysis(x, num_bits=8, num_scales=100):
"""Show how total error varies with scale factor."""
q_max = 2 ** (num_bits - 1) - 1
abs_max = x.abs().max().item()
# Try scales from very small to abs_max/q_max (MinMax scale)
minmax_scale = abs_max / q_max
scales = torch.linspace(minmax_scale * 0.1, minmax_scale * 1.5, num_scales)
results = []
for s in scales:
x_hat, _ = quantize_symmetric(x, s.item(), num_bits)
errors = compute_errors(x, x_hat, s.item(), num_bits)
results.append({
'scale': s.item(),
'scale_ratio': s.item() / minmax_scale,
**errors
})
# Find optimal scale (minimum total MSE)
best = min(results, key=lambda r: r['total_mse'])
print(f"MinMax scale: {minmax_scale:.6f}")
print(f"Optimal scale: {best['scale']:.6f} "
f"({best['scale_ratio']:.3f}x MinMax)")
print(f"MinMax MSE: {results[-1]['total_mse']:.8f}")
print(f"Optimal MSE: {best['total_mse']:.8f}")
print(f"MSE reduction: {(1 - best['total_mse']/results[-1]['total_mse'])*100:.1f}%")
return results
# Test with a distribution that has outliers
torch.manual_seed(42)
x_normal = torch.randn(100000) * 0.5
# Add 0.1% outliers at 10x magnitude
outlier_idx = torch.randperm(100000)[:100]
x_normal[outlier_idx] *= 10.0
scale_tradeoff_analysis(x_normal)
Method 1: MinMax Calibration
Algorithm
The simplest calibration method. Observe the minimum and maximum values of the tensor and set the scale to cover the full range:
class MinMaxCalibrator:
"""MinMax calibration: scale from observed min/max."""
def __init__(self, num_bits=8, symmetric=True, per_channel=False):
self.num_bits = num_bits
self.symmetric = symmetric
self.per_channel = per_channel
self.q_max = 2 ** (num_bits - 1) - 1
self.running_min = None
self.running_max = None
self.num_batches = 0
def observe(self, x):
"""Record min/max from a batch of data."""
if self.per_channel:
# Reduce over all dims except the first (output channel)
batch_min = x.reshape(x.shape[0], -1).min(dim=1).values
batch_max = x.reshape(x.shape[0], -1).max(dim=1).values
else:
batch_min = x.min()
batch_max = x.max()
if self.running_min is None:
self.running_min = batch_min.clone()
self.running_max = batch_max.clone()
else:
self.running_min = torch.min(self.running_min, batch_min)
self.running_max = torch.max(self.running_max, batch_max)
self.num_batches += 1
def compute_scale(self):
"""Compute scale from observed min/max."""
if self.symmetric:
abs_max = torch.max(self.running_min.abs(), self.running_max.abs())
scale = abs_max / self.q_max
else:
scale = (self.running_max - self.running_min) / (2 ** self.num_bits - 1)
return torch.clamp(scale, min=1e-8)
def reset(self):
self.running_min = None
self.running_max = None
self.num_batches = 0
Weakness: Outlier Sensitivity
MinMax calibration is dominated by the single largest value. If one activation outlier is 100x larger than the typical value, the scale is set to accommodate that outlier, wasting most of the quantization range on values that never occur.
def demonstrate_minmax_outlier_problem():
"""Show how a single outlier destroys MinMax calibration quality."""
torch.manual_seed(42)
# Normal distribution
x = torch.randn(10000) * 0.5
# MinMax without outliers
scale_clean = x.abs().max().item() / 127.0
x_hat_clean, _ = quantize_symmetric(x, scale_clean)
mse_clean = ((x - x_hat_clean) ** 2).mean().item()
# Add a single outlier
x_outlier = x.clone()
x_outlier[0] = 50.0 # 100x typical magnitude
scale_outlier = x_outlier.abs().max().item() / 127.0
x_hat_outlier, _ = quantize_symmetric(x_outlier, scale_outlier)
mse_outlier = ((x_outlier - x_hat_outlier) ** 2).mean().item()
print(f"Without outlier: scale={scale_clean:.6f}, MSE={mse_clean:.8f}")
print(f"With outlier: scale={scale_outlier:.6f}, MSE={mse_outlier:.8f}")
print(f"MSE increase: {mse_outlier/mse_clean:.1f}x")
print(f"Scale increase: {scale_outlier/scale_clean:.1f}x")
demonstrate_minmax_outlier_problem()
PyTorchโs built-in quantization uses MinMax calibration by default. For weights (which are well-behaved), this is usually adequate. For activations (which have outliers), MinMax should never be used without additional techniques like SmoothQuant.
Method 2: Percentile Calibration
Algorithm
Instead of using the absolute min/max, clip at the -th percentile of the absolute value distribution. Common choices are or .
Values above the percentile threshold are clipped, introducing clipping error for the outliers but dramatically reducing rounding error for the vast majority of values.
class PercentileCalibrator:
"""Percentile calibration: clip at the p-th percentile."""
def __init__(self, num_bits=8, percentile=99.9, symmetric=True):
self.num_bits = num_bits
self.percentile = percentile
self.symmetric = symmetric
self.q_max = 2 ** (num_bits - 1) - 1
self.all_values = []
def observe(self, x):
"""Collect values for percentile computation."""
# Store absolute values (flattened)
self.all_values.append(x.detach().abs().flatten().cpu())
def compute_scale(self):
"""Compute scale from percentile of observed values."""
all_abs = torch.cat(self.all_values)
# Compute percentile
k = int(len(all_abs) * self.percentile / 100.0)
k = min(k, len(all_abs) - 1)
threshold = torch.kthvalue(all_abs, k).values.item()
scale = threshold / self.q_max
return max(scale, 1e-8)
def reset(self):
self.all_values = []
class EfficientPercentileCalibrator:
"""Memory-efficient percentile calibration using histograms.
Instead of storing all observed values (which can be GBs for
activations), maintain a histogram and compute percentile
from the histogram.
"""
def __init__(self, num_bits=8, percentile=99.9, num_bins=2048):
self.num_bits = num_bits
self.percentile = percentile
self.q_max = 2 ** (num_bits - 1) - 1
self.num_bins = num_bins
self.histogram = torch.zeros(num_bins)
self.bin_edges = None
self.max_observed = 0.0
def observe(self, x):
"""Update histogram with new observations."""
abs_x = x.detach().abs().flatten().cpu()
batch_max = abs_x.max().item()
if batch_max > self.max_observed:
# Resize histogram to accommodate new range
old_max = self.max_observed
self.max_observed = batch_max * 1.1 # 10% headroom
if self.bin_edges is not None:
# Re-bin existing histogram into new range
old_hist = self.histogram.clone()
self.histogram.zero_()
old_edges = self.bin_edges
self.bin_edges = torch.linspace(0, self.max_observed, self.num_bins + 1)
for i in range(self.num_bins):
old_center = (old_edges[i] + old_edges[i + 1]) / 2
new_bin = int(old_center / self.max_observed * self.num_bins)
new_bin = min(new_bin, self.num_bins - 1)
self.histogram[new_bin] += old_hist[i]
if self.bin_edges is None:
self.max_observed = max(batch_max * 1.1, 1e-6)
self.bin_edges = torch.linspace(0, self.max_observed, self.num_bins + 1)
# Add current batch to histogram
hist = torch.histc(abs_x, bins=self.num_bins, min=0, max=self.max_observed)
self.histogram += hist
def compute_scale(self):
"""Compute scale from histogram percentile."""
cumsum = torch.cumsum(self.histogram, dim=0)
total = cumsum[-1].item()
target = total * self.percentile / 100.0
# Find bin where cumulative sum crosses the target
bin_idx = torch.searchsorted(cumsum, target).item()
bin_idx = min(bin_idx, self.num_bins - 1)
# Interpolate within the bin
if bin_idx > 0:
prev_cum = cumsum[bin_idx - 1].item()
else:
prev_cum = 0.0
bin_count = self.histogram[bin_idx].item()
if bin_count > 0:
fraction = (target - prev_cum) / bin_count
else:
fraction = 0.5
threshold = (self.bin_edges[bin_idx] +
fraction * (self.bin_edges[bin_idx + 1] - self.bin_edges[bin_idx]))
scale = threshold.item() / self.q_max
return max(scale, 1e-8)
Choosing the Percentile
The optimal percentile depends on the distribution shape and the bit width. Lower bit widths (INT4) benefit from more aggressive clipping (lower percentile).
Percentile Calibration: Effect of Percentile on INT8 MSE
| Percentile | Clip Fraction | Rounding MSE | Clipping MSE | Total MSE |
|---|---|---|---|---|
| 99.0% | 1.00% | 1.23e-5 | 8.41e-4 | 8.53e-4 |
| 99.9% | 0.10% | 1.98e-5 | 1.12e-4 | 1.32e-4 |
| 99.99% | 0.01% | 3.15e-5 | 2.87e-5 | 6.02e-5 |
| 99.999% | 0.001% | 7.82e-5 | 3.41e-6 | 8.16e-5 |
| 100% (MinMax) | 0% | 1.54e-4 | 0 | 1.54e-4 |
Method 3: MSE-Optimal Calibration
Algorithm
Instead of heuristically choosing the scale, search for the scale that minimizes the mean squared error between the original and quantized tensors:
where is the quantize-dequantize operation with scale .
class MSEOptimalCalibrator:
"""MSE-optimal calibration: find scale that minimizes quantization MSE."""
def __init__(self, num_bits=8, num_candidates=200, symmetric=True):
self.num_bits = num_bits
self.num_candidates = num_candidates
self.symmetric = symmetric
self.q_max = 2 ** (num_bits - 1) - 1
self.all_values = []
def observe(self, x):
"""Collect values for MSE optimization."""
self.all_values.append(x.detach().flatten().cpu())
def compute_scale(self):
"""Find scale that minimizes MSE via grid search."""
x = torch.cat(self.all_values)
abs_max = x.abs().max().item()
minmax_scale = abs_max / self.q_max
# Search over candidate scales
# Range: from 10% of MinMax scale to 100% of MinMax scale
candidate_scales = torch.linspace(
minmax_scale * 0.1, minmax_scale, self.num_candidates
)
best_mse = float('inf')
best_scale = minmax_scale
for s in candidate_scales:
s_val = s.item()
if s_val < 1e-10:
continue
x_hat, _ = quantize_symmetric(x, s_val, self.num_bits)
mse = ((x - x_hat) ** 2).mean().item()
if mse < best_mse:
best_mse = mse
best_scale = s_val
return best_scale
def compute_scale_newton(self, max_iter=20):
"""Find optimal scale using golden section search.
More efficient than grid search for smooth MSE landscapes.
"""
x = torch.cat(self.all_values)
abs_max = x.abs().max().item()
minmax_scale = abs_max / self.q_max
def mse_at_scale(s):
x_hat, _ = quantize_symmetric(x, s, self.num_bits)
return ((x - x_hat) ** 2).mean().item()
# Golden section search
golden = (1 + np.sqrt(5)) / 2
a = minmax_scale * 0.05
b = minmax_scale * 1.05
tol = minmax_scale * 1e-4
c = b - (b - a) / golden
d = a + (b - a) / golden
for _ in range(max_iter):
if abs(b - a) < tol:
break
if mse_at_scale(c) < mse_at_scale(d):
b = d
else:
a = c
c = b - (b - a) / golden
d = a + (b - a) / golden
return (a + b) / 2.0
def reset(self):
self.all_values = []
Weighted MSE: Prioritizing Important Values
Not all values contribute equally to model quality. Values near zero contribute little to the output, while large values are disproportionately important. Weighted MSE calibration assigns higher weight to larger values:
class WeightedMSECalibrator:
"""MSE-optimal calibration with value-dependent weighting.
Weights errors by value magnitude: errors on large values
matter more than errors on small values because they
contribute more to the output of matrix multiplications.
"""
def __init__(self, num_bits=8, num_candidates=200, weight_power=2.0):
self.num_bits = num_bits
self.num_candidates = num_candidates
self.q_max = 2 ** (num_bits - 1) - 1
self.weight_power = weight_power
self.all_values = []
def observe(self, x):
self.all_values.append(x.detach().flatten().cpu())
def compute_scale(self):
x = torch.cat(self.all_values)
abs_max = x.abs().max().item()
minmax_scale = abs_max / self.q_max
# Weights: higher for larger absolute values
weights = x.abs() ** self.weight_power
weights = weights / weights.sum()
candidates = torch.linspace(minmax_scale * 0.1, minmax_scale,
self.num_candidates)
best_wmse = float('inf')
best_scale = minmax_scale
for s in candidates:
s_val = s.item()
if s_val < 1e-10:
continue
x_hat, _ = quantize_symmetric(x, s_val, self.num_bits)
wmse = (weights * (x - x_hat) ** 2).sum().item()
if wmse < best_wmse:
best_wmse = wmse
best_scale = s_val
return best_scale
Method 4: Cross-Layer Calibration (GPTQ-Style)
The Layer-Wise Problem
The previous methods calibrate each tensor independently. But in a neural network, the quantization error of layer propagates to layer , where it interacts with the quantization error of that layer. Optimizing each layer independently ignores these interactions.
GPTQ (Frantar et al., 2022) addresses this by optimizing the quantized weights of each layer to minimize the output error of that layer, given the actual (quantized) inputs from the previous layer.
GPTQ Algorithm
For each linear layer with weight matrix and calibration input :
- Compute the Hessian (the second-order information about how weight changes affect output)
- For each column of (processed in order): a. Find the quantized value b. Compute the quantization error c. Update remaining columns to compensate:
class GPTQCalibrator:
"""GPTQ-style cross-layer calibration.
Quantizes weights one column at a time, using second-order
information (the Hessian) to update remaining weights and
compensate for quantization error.
"""
def __init__(self, num_bits=4, group_size=128, sym=True,
damp_percent=0.01):
self.num_bits = num_bits
self.group_size = group_size
self.sym = sym
self.damp_percent = damp_percent
self.q_max = 2 ** (num_bits - 1) - 1
self.q_min = -(2 ** (num_bits - 1))
def quantize_layer(self, weight, hessian):
"""Quantize a weight matrix using GPTQ algorithm.
Args:
weight: [out_features, in_features] weight matrix
hessian: [in_features, in_features] Hessian matrix
H = 2 * X^T @ X where X is the input to this layer
Returns:
q_weight: Quantized weight matrix
scales: Per-group scale factors
"""
W = weight.clone().float()
n_rows, n_cols = W.shape
H = hessian.float()
# Add damping for numerical stability
damp = self.damp_percent * torch.diag(H).mean()
H += damp * torch.eye(n_cols, device=H.device)
# Cholesky decomposition for efficient inverse
try:
L = torch.linalg.cholesky(H)
H_inv = torch.cholesky_inverse(L)
except RuntimeError:
# Fallback if not positive definite
H_inv = torch.linalg.pinv(H)
Q = torch.zeros_like(W)
scales = torch.zeros(n_rows, (n_cols + self.group_size - 1) //
self.group_size, device=W.device)
# Process columns in groups
for col_start in range(0, n_cols, self.group_size):
col_end = min(col_start + self.group_size, n_cols)
group_idx = col_start // self.group_size
# Compute scale for this group
w_group = W[:, col_start:col_end]
if self.sym:
abs_max = w_group.abs().amax(dim=1)
scale = abs_max / self.q_max
scale = torch.clamp(scale, min=1e-8)
else:
w_min = w_group.min(dim=1).values
w_max = w_group.max(dim=1).values
scale = (w_max - w_min) / (2 ** self.num_bits - 1)
scale = torch.clamp(scale, min=1e-8)
scales[:, group_idx] = scale
# Quantize each column and compensate
for col in range(col_start, col_end):
w_col = W[:, col]
# Quantize
q_col = torch.clamp(
torch.round(w_col / scale), self.q_min, self.q_max
)
Q[:, col] = q_col * scale
# Quantization error
error = w_col - Q[:, col]
# Compensate remaining columns
if col + 1 < n_cols:
h_inv_diag = H_inv[col, col]
if h_inv_diag > 1e-10:
compensation = error.unsqueeze(1) * \
H_inv[col, col+1:].unsqueeze(0) / h_inv_diag
W[:, col+1:] += compensation
return Q, scales
def collect_hessian(self, layer, calibration_data):
"""Collect Hessian from calibration data.
Run calibration inputs through the layer and accumulate
H = sum(X^T @ X) over all calibration batches.
"""
device = next(layer.parameters()).device
n_cols = layer.in_features
H = torch.zeros(n_cols, n_cols, device=device, dtype=torch.float32)
n_samples = 0
for batch in calibration_data:
x = batch.to(device)
# Flatten batch and sequence dimensions
x = x.reshape(-1, n_cols).float()
H += x.T @ x
n_samples += x.shape[0]
H /= n_samples
return H
The true Hessian of the layer output error with respect to the weights is where is the input matrix. This is exact for a linear layer (the output is , so and ). The Hessian captures which weight directions are important: a large diagonal entry means the -th column of has a large effect on the output, and quantization error in that column is costly.
Complete Calibration Pipeline
End-to-End Pipeline
class CalibrationPipeline:
"""Complete calibration pipeline for PTQ.
Steps:
1. Run calibration data through the model
2. Collect activation statistics at each layer
3. Compute scale factors using the chosen method
4. Apply quantization with computed scales
"""
def __init__(self, model, method='mse', num_bits=8,
percentile=99.9, num_candidates=200):
self.model = model
self.method = method
self.num_bits = num_bits
self.percentile = percentile
self.num_candidates = num_candidates
# Create calibrators for each quantizable layer
self.weight_calibrators = {}
self.activation_calibrators = {}
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
self.weight_calibrators[name] = self._create_calibrator()
self.activation_calibrators[name] = self._create_calibrator()
def _create_calibrator(self):
if self.method == 'minmax':
return MinMaxCalibrator(self.num_bits)
elif self.method == 'percentile':
return PercentileCalibrator(self.num_bits, self.percentile)
elif self.method == 'mse':
return MSEOptimalCalibrator(self.num_bits, self.num_candidates)
elif self.method == 'weighted_mse':
return WeightedMSECalibrator(self.num_bits, self.num_candidates)
else:
raise ValueError(f"Unknown method: {self.method}")
def register_hooks(self):
"""Register forward hooks to capture activations."""
self.hooks = []
for name, module in self.model.named_modules():
if name in self.activation_calibrators:
hook = module.register_forward_hook(
self._make_hook(name)
)
self.hooks.append(hook)
def _make_hook(self, layer_name):
def hook_fn(module, input, output):
# Observe input activation
if isinstance(input, tuple):
x = input[0]
else:
x = input
self.activation_calibrators[layer_name].observe(x)
return hook_fn
def calibrate(self, calibration_dataloader, num_batches=32):
"""Run calibration data through the model."""
self.register_hooks()
# Observe weights (static, only need to do once)
for name, module in self.model.named_modules():
if name in self.weight_calibrators:
self.weight_calibrators[name].observe(module.weight.data)
# Run calibration data to observe activations
self.model.eval()
with torch.no_grad():
for i, batch in enumerate(calibration_dataloader):
if i >= num_batches:
break
input_ids = batch['input_ids'].to(
next(self.model.parameters()).device
)
self.model(input_ids)
if (i + 1) % 8 == 0:
print(f"Calibration batch {i+1}/{num_batches}")
# Remove hooks
for hook in self.hooks:
hook.remove()
# Compute scales
self.weight_scales = {}
self.activation_scales = {}
for name in self.weight_calibrators:
self.weight_scales[name] = \
self.weight_calibrators[name].compute_scale()
self.activation_scales[name] = \
self.activation_calibrators[name].compute_scale()
print(f"Calibrated {len(self.weight_scales)} layers")
return self.weight_scales, self.activation_scales
def apply_quantization(self):
"""Apply computed scales to quantize the model."""
for name, module in self.model.named_modules():
if name in self.weight_scales:
w_scale = self.weight_scales[name]
if isinstance(w_scale, torch.Tensor):
w_scale = w_scale.to(module.weight.device)
else:
w_scale = torch.tensor(w_scale,
device=module.weight.device)
# Quantize weights
q_max = 2 ** (self.num_bits - 1) - 1
q_min = -(2 ** (self.num_bits - 1))
w_q = torch.clamp(
torch.round(module.weight.data / w_scale),
q_min, q_max
)
module.weight.data = w_q * w_scale
print("Quantization applied")
Calibration Data Selection
The quality of calibration depends on representative input data. Guidelines:
def prepare_calibration_data(tokenizer, num_samples=128, seq_len=2048,
dataset_name='wikitext'):
"""Prepare calibration dataset.
Best practices:
- Use 128-512 samples (diminishing returns beyond this)
- Use sequences of the same length as deployment
- Use diverse data (not all from one domain)
- Shuffle to avoid sequential correlation
"""
# Simulate loading calibration data
# In practice, load from a dataset like C4, WikiText, or RedPajama
calibration_texts = [
"Sample calibration text " * (seq_len // 4)
for _ in range(num_samples)
]
calibration_tokens = []
for text in calibration_texts:
tokens = tokenizer.encode(text, max_length=seq_len,
truncation=True, return_tensors='pt')
calibration_tokens.append(tokens)
return calibration_tokens
def calibration_sample_size_study(model, dataloader, method='mse',
num_bits=8):
"""Study how many calibration samples are needed."""
sample_counts = [4, 8, 16, 32, 64, 128, 256, 512]
for n_samples in sample_counts:
pipeline = CalibrationPipeline(model, method=method,
num_bits=num_bits)
pipeline.calibrate(dataloader, num_batches=n_samples)
# Measure output quality (e.g., perplexity on held-out data)
pipeline.apply_quantization()
# ... evaluate model ...
print(f"Samples: {n_samples:4d}, Method: {method}")
Benchmarking Calibration Methods
Systematic Comparison
def benchmark_calibration_methods(model_name="llama-2-7b"):
"""Compare all calibration methods on the same model."""
methods = ['minmax', 'percentile', 'mse', 'weighted_mse']
bit_widths = [8, 4]
results = []
for bits in bit_widths:
for method in methods:
# Run calibration and evaluate
# (Pseudocode -- actual implementation depends on model framework)
print(f"Calibrating {model_name} with {method} at INT{bits}")
# result = calibrate_and_evaluate(model, method, bits)
# results.append(result)
return results
Calibration Method Comparison: Llama-2 7B WikiText-2 Perplexity
| Method | INT8 PPL | INT8 Delta | INT4 PPL | INT4 Delta |
|---|---|---|---|---|
| FP16 (baseline) | 5.47 | --- | --- | --- |
| MinMax | 5.53 | +0.06 | 7.84 | +2.37 |
| Percentile (99.9%) | 5.50 | +0.03 | 7.12 | +1.65 |
| Percentile (99.99%) | 5.49 | +0.02 | 6.95 | +1.48 |
| MSE-Optimal | 5.48 | +0.01 | 6.71 | +1.24 |
| Weighted MSE | 5.48 | +0.01 | 6.58 | +1.11 |
| GPTQ (cross-layer) | 5.48 | +0.01 | 5.85 | +0.38 |
INT4 Calibration Quality by Method (Llama-2 7B, lower is better)
(perplexity delta vs FP16)Calibration Time Comparison
Calibration Time by Method (Llama-2 7B, A100 80GB)
| Method | Time | Memory Overhead | Requires Calibration Data |
|---|---|---|---|
| MinMax | ~1 minute | Negligible | Yes (8-32 samples sufficient) |
| Percentile | ~2 minutes | Stores all values or histogram | Yes (32-128 samples) |
| MSE-Optimal | ~5 minutes | Stores all values | Yes (32-128 samples) |
| GPTQ (128 group) | ~30 minutes | Hessian per layer (~100 MB) | Yes (128 samples) |
| AWQ | ~20 minutes | Activation statistics | Yes (128 samples) |
INT8: Use MinMax or Percentile. The quality difference between methods is negligible at 8-bit. MinMax is fastest. INT4 weights, FP16 activations: Use GPTQ or AWQ. Simple calibration methods produce unacceptable quality at 4-bit. INT4 weights + INT8 activations (W4A8): Use GPTQ for weights and percentile/MSE for activations. FP8: Use per-tensor MinMax with delayed scaling (the FP8 range is large enough that outliers are less problematic).
Summary
Calibration determines the scale factors for post-training quantization. MinMax is fast but outlier-sensitive. Percentile clips outliers at a threshold, reducing rounding error at the cost of clipping error. MSE-optimal finds the scale that minimizes total quantization error through search. GPTQ performs cross-layer optimization, adjusting remaining weights to compensate for each quantized column using second-order information.
The practical guidance is straightforward: for INT8, any calibration method works; for INT4, cross-layer methods (GPTQ, AWQ) are required for acceptable quality. Calibration data should be 128+ diverse samples at deployment sequence length. The entire calibration process takes minutes (simple methods) to an hour (GPTQ), which is negligible compared to training time but critical for quantized model quality.