Part 3 of 30 in the Quantization Masterclass series

Weight quantization (Part 2) gives you smaller models and faster memory transfers. But the matrix multiplication itself still runs in FP16 — the GPU dequantizes INT4 weights to FP16, then performs an FP16 GEMM with FP16 activations. To get true INT8 or FP8 compute speedup (not just bandwidth savings), you must also quantize the activations. This is W8A8: both weights and activations in INT8, with the GEMM executed using INT8 tensor cores.

The problem: activations are dramatically harder to quantize than weights. Weights are static, well-behaved, and roughly Gaussian. Activations are dynamic (different for every input), and they contain outliers — individual channels with values 10-100x larger than the rest. These outliers destroy INT8 quantization quality if not handled correctly.

This post covers the outlier problem in detail, then implements SmoothQuant (Xiao et al., 2022), the algorithm that solved it by migrating quantization difficulty from activations to weights using a mathematically equivalent transformation.

Why Activations Are Harder to Quantize Than Weights

The Weight Distribution

Trained LLM weights are well-behaved. They have a roughly symmetric distribution centered around zero, with most values within a few standard deviations of the mean. The ratio between the maximum and median absolute weight value is typically in the single digits.

import torch
import torch.nn as nn
import numpy as np

# Simulate typical LLM weight distributions
torch.manual_seed(42)
weight = torch.randn(4096, 4096) * 0.02  # Typical LLM weight scale

max_w = weight.abs().max().item()
median_w = weight.abs().median().item()
print(f"Weight max/median ratio: {max_w / median_w:.1f}x")
# Single-digit ratio -- easy to quantize

The Activation Distribution

Activations after layer normalization and linear projections are a different story. First characterized in the OPT-6.7B model (and since observed in Llama, GPT, and virtually all large transformers), specific channels in the activation tensor consistently produce values that are 10-100x larger than the other channels. These are called activation outliers or emergent features.

def simulate_activation_outliers(batch_size=32, seq_len=128, hidden=4096,
                                  outlier_fraction=0.01, outlier_magnitude=50.0):
    """Simulate realistic LLM activations with channel-wise outliers.

    In real LLMs, outliers appear in fixed channels across all tokens
    and all inputs. They emerge during training around 6B parameters
    and persist in all larger models.
    """
    # Base activation: roughly Gaussian
    activations = torch.randn(batch_size, seq_len, hidden) * 0.5

    # Outlier channels: fixed positions, large magnitude
    num_outliers = int(hidden * outlier_fraction)
    outlier_channels = torch.randperm(hidden)[:num_outliers]
    activations[:, :, outlier_channels] *= outlier_magnitude

    return activations, outlier_channels

activations, outlier_ch = simulate_activation_outliers()
flat = activations.reshape(-1, 4096)

# Per-channel statistics
channel_max = flat.abs().max(dim=0).values
sorted_max, _ = channel_max.sort(descending=True)
print(f"Top 5 channel max values: {sorted_max[:5].tolist()}")
print(f"Median channel max: {channel_max.median().item():.2f}")
print(f"Max/median ratio: {sorted_max[0].item() / channel_max.median().item():.1f}x")
# Max/median ratio: ~50-100x -- this destroys naive quantization

🚨 The Outlier Problem

When a single channel has values 100x larger than the rest, symmetric INT8 quantization must set the scale factor to accommodate that channel. This means the 99% of channels with normal magnitudes are effectively quantized to only 1-2 bits of precision (values 0 or 1 in INT8), wasting 6-7 bits of the INT8 range. The result: catastrophic quality loss.
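A minimal numeric sketch of the effect (synthetic values, not from any real model): one 100x outlier in an otherwise typical tensor forces the per-tensor scale so high that ordinary values collapse onto the bottom few of the 127 available integer levels.

```python
import torch

torch.manual_seed(0)
# A single 100x outlier forces the per-tensor scale up, so normal-magnitude
# values collapse onto the first few integer levels of the INT8 range.
x = torch.randn(1000) * 0.5   # "normal" values, |x| mostly below ~1.5
x[0] = 50.0                   # one outlier, ~100x the typical magnitude
scale = x.abs().max() / 127.0 # per-tensor symmetric scale, ~0.39
q = (x / scale).round().clamp(-128, 127)
print(f"scale = {scale:.3f}")
print(f"integer levels used by normal values: up to {int(q[1:].abs().max())}")
```

The outlier maps to level 127, while everything else lands on the lowest handful of levels, i.e. 2-3 effective bits.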

Quantifying the Damage

def quantize_per_tensor_int8(tensor):
    """Per-tensor symmetric INT8 quantization."""
    amax = tensor.abs().max()
    scale = amax / 127.0
    q = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def quantize_per_token_int8(tensor):
    """Per-token symmetric INT8 quantization.

    tensor: (batch * seq, hidden)
    One scale per row (token).
    """
    amax = tensor.abs().amax(dim=1, keepdim=True)
    scale = amax / 127.0
    scale = scale.clamp(min=1e-10)
    q = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

# Compare quantization quality with and without outliers
weight = torch.randn(4096, 4096) * 0.02
activations, _ = simulate_activation_outliers()
flat_act = activations.reshape(-1, 4096)

# Reference output
y_ref = flat_act @ weight.T

# Quantize activations per-tensor (naive)
q_act, s_act = quantize_per_tensor_int8(flat_act)
act_recon = q_act.float() * s_act
y_naive = act_recon @ weight.T
mse_naive = ((y_ref - y_naive) ** 2).mean().item()

# Quantize activations per-token
q_act_pt, s_act_pt = quantize_per_token_int8(flat_act)
act_recon_pt = q_act_pt.float() * s_act_pt
y_per_token = act_recon_pt @ weight.T
mse_per_token = ((y_ref - y_per_token) ** 2).mean().item()

print(f"Naive per-tensor INT8 activation MSE: {mse_naive:.6f}")
print(f"Per-token INT8 activation MSE: {mse_per_token:.6f}")
print(f"Improvement: {mse_naive / mse_per_token:.1f}x")

Activation Quantization Error by Granularity

| Granularity | Relative MSE | vs per-tensor |
|---|---|---|
| Per-tensor | 100 | baseline (outliers dominate) |
| Per-token | 25 | 4x better |
| Per-channel | 5 | 20x better |
| SmoothQuant | 3 | 33x better |

Per-token quantization helps because each token gets its own scale factor, so a single outlier token does not ruin the scale for other tokens. But the outlier channels still consume most of the INT8 range within each token. Per-channel quantization would solve this, but it is incompatible with efficient INT8 GEMM kernels (the GEMM accumulation requires a single scale per output element, which is only possible with per-tensor or per-token activation scaling combined with per-channel weight scaling).
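The factorization argument can be checked numerically. In the sketch below (synthetic integer tensors), per-token activation scales and per-output-channel weight scales pull out of the integer accumulation, so rescaling after the INT8 GEMM reproduces the dequantized result exactly; a per-input-channel activation scale varies with the summation index and cannot be pulled out.

```python
import torch

torch.manual_seed(0)
x_q = torch.randint(-128, 128, (4, 16)).float()  # simulated INT8 activations
w_q = torch.randint(-128, 128, (8, 16)).float()  # simulated INT8 weights
sx = torch.rand(4, 1) * 0.01   # per-token activation scales
sw = torch.rand(8, 1) * 0.01   # per-output-channel weight scales

# Dequantize-then-multiply equals INT8-GEMM-then-rescale:
direct = (x_q * sx) @ (w_q * sw).T        # scale first, then matmul
factored = (x_q @ w_q.T) * sx * sw.T      # integer matmul, scale afterwards
assert torch.allclose(direct, factored, atol=1e-3)

# A per-INPUT-channel activation scale sits inside sum_k x[i,k] * w[j,k]
# and cannot be recovered by any post-GEMM rescaling:
sc = torch.rand(1, 16) * 0.01
inside = (x_q * sc) @ w_q.T   # scale varies with the reduction index k
```

This is exactly why W8A8 kernels pair per-token activation scales with per-channel weight scales.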

SmoothQuant: The Solution

SmoothQuant (Xiao et al., 2022) resolves the outlier problem with a mathematically equivalent transformation that migrates quantization difficulty from activations to weights. The key equation:

$$Y = XW = (X\,\mathrm{diag}(s)^{-1})(\mathrm{diag}(s)\,W) = \hat{X}\hat{W}$$

where $s \in \mathbb{R}^{C}$ is a per-channel smoothing factor. This transformation:

  1. Divides each activation channel by $s_j$ (shrinking outlier channels)
  2. Multiplies each weight input channel by $s_j$ (growing corresponding weight channels)
  3. Preserves the output exactly: $\hat{X}\hat{W} = XW$

After smoothing, the activations have a much more uniform distribution across channels (outliers are reduced), while the weights absorb the extra magnitude. Since weights are static, they can tolerate wider dynamic range without quality loss (they are quantized offline with full knowledge of the distribution).
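The exactness claim is easy to verify numerically. A quick sketch, using the (out_features, in_features) weight layout from the implementations below, so the product is written as `X @ W.T`:

```python
import torch

torch.manual_seed(0)
X = torch.randn(64, 16)        # activations, 16 input channels
W = torch.randn(32, 16)        # weight (out_features, in_features)
s = torch.rand(16) + 0.5       # arbitrary positive per-channel factors

Y = X @ W.T                    # original output
Y_hat = (X / s) @ (W * s).T    # smoothed activations x smoothed weights
print(torch.allclose(Y, Y_hat, atol=1e-5))  # True
```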

Computing the Smoothing Factor

The smoothing factor balances quantization difficulty between activations and weights:

$$s_j = \frac{\max(|X_j|)^{\alpha}}{\max(|W_j|)^{1-\alpha}}$$

where $\alpha \in [0, 1]$ controls how aggressively to smooth. $\alpha = 0.5$ is the default, meaning an equal difficulty split. $\alpha$ closer to 1.0 smooths activations more aggressively (good when outliers are extreme). The original SmoothQuant paper found $\alpha = 0.5$ works well for most models, with $\alpha = 0.75$ needed for models with very strong outliers (e.g., GLM-130B).
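A worked example of the default case, with hypothetical channel statistics: at $\alpha = 0.5$ the smoothing factor is $s = \sqrt{\max|X_j| / \max|W_j|}$, which equalizes the two maxima after the transform.

```python
import math

# Hypothetical outlier channel: activation max 50, weight max 0.02.
# With alpha = 0.5, s = sqrt(act_max / w_max) and after smoothing both
# channel maxima become sqrt(act_max * w_max).
act_max, w_max = 50.0, 0.02
s = math.sqrt(act_max) / math.sqrt(w_max)   # = 50.0
print(act_max / s, w_max * s)               # both ~1.0 = sqrt(50 * 0.02)
```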

ℹ️ Per-Channel is Key

The smoothing operates per input channel. Channel $j$ with a large activation outlier gets a large $s_j$, which divides down that channel's activations and multiplies up that channel's weights. Channels without outliers get $s_j \approx 1$ and are barely affected.

Complete SmoothQuant Implementation

class SmoothQuant:
    """SmoothQuant: migrate activation difficulty to weights."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha

    def compute_smoothing_factors(self, activation_scales, weight):
        """Compute per-channel smoothing factors.

        activation_scales: (in_features,) max absolute activation per channel
                          (computed from calibration data)
        weight: (out_features, in_features) original weight matrix

        Returns: (in_features,) smoothing factors
        """
        # Per-channel weight scale: max absolute value per input channel
        weight_scales = weight.abs().amax(dim=0)  # (in_features,)

        # Smoothing factor
        s = activation_scales.pow(self.alpha) / weight_scales.pow(1 - self.alpha)

        # Clamp for numerical stability
        s = s.clamp(min=1e-5)

        return s

    def smooth_layer(self, weight, ln_weight, ln_bias, activation_scales):
        """Apply SmoothQuant transformation to a layer.

        In a transformer, the smoothing is absorbed into the preceding
        LayerNorm parameters, avoiding any runtime cost.

        weight: (out_features, in_features) linear layer weight
        ln_weight: (in_features,) LayerNorm gamma
        ln_bias: (in_features,) LayerNorm beta (can be None)
        activation_scales: (in_features,) from calibration

        Returns: (smoothed_weight, new_ln_weight, new_ln_bias, s)
        """
        s = self.compute_smoothing_factors(activation_scales, weight)

        # Apply to weight: W_hat = diag(s) * W
        smoothed_weight = weight * s.unsqueeze(0)  # Broadcast: (out, in) * (1, in)

        # Absorb into LayerNorm: LN_weight_new = LN_weight / s
        new_ln_weight = ln_weight / s

        new_ln_bias = None
        if ln_bias is not None:
            new_ln_bias = ln_bias / s

        return smoothed_weight, new_ln_weight, new_ln_bias, s

    def calibrate(self, model_forward_fn, calibration_data, layer_names):
        """Run calibration to collect per-channel activation max values.

        model_forward_fn: function that runs the model and returns
                         a dict mapping layer_name to input activations
        calibration_data: list of input tensors
        layer_names: list of layer names to calibrate

        Returns: dict mapping layer_name to (in_features,) activation scales
        """
        scales = {}

        for data in calibration_data:
            layer_activations = model_forward_fn(data)
            for name in layer_names:
                act = layer_activations[name]  # (batch, seq, hidden)
                if act.dim() == 3:
                    act = act.reshape(-1, act.shape[-1])

                # Per-channel max absolute value
                ch_max = act.abs().amax(dim=0)

                if name not in scales:
                    scales[name] = ch_max
                else:
                    scales[name] = torch.maximum(scales[name], ch_max)

        return scales

End-to-End SmoothQuant Application

def apply_smoothquant_to_transformer_block(block, activation_scales, alpha=0.5):
    """Apply SmoothQuant to all linear layers in a transformer block.

    A typical transformer block has:
    - LayerNorm -> Q, K, V projections (self-attention)
    - LayerNorm -> Up, Gate, Down projections (FFN)

    SmoothQuant is applied between each LayerNorm and the subsequent
    linear layers. The smoothing factors are absorbed into the LayerNorm
    parameters, so there is ZERO runtime cost.
    """
    sq = SmoothQuant(alpha=alpha)

    # Smooth attention projections
    # All Q, K, V projections share the same input (post-LayerNorm)
    # Use the activation scale for this input
    attn_act_scale = activation_scales['attn_input']

    for proj_name in ['q_proj', 'k_proj', 'v_proj']:
        proj = getattr(block.self_attn, proj_name)
        smoothed_w, new_ln_w, new_ln_b, s = sq.smooth_layer(
            weight=proj.weight.data,
            ln_weight=block.input_layernorm.weight.data,
            ln_bias=getattr(block.input_layernorm, 'bias', None),
            activation_scales=attn_act_scale,
        )
        proj.weight.data = smoothed_w

    # Update LayerNorm parameters (shared across Q, K, V)
    block.input_layernorm.weight.data = new_ln_w
    if new_ln_b is not None and hasattr(block.input_layernorm, 'bias'):
        block.input_layernorm.bias.data = new_ln_b

    # Smooth FFN projections similarly
    ffn_act_scale = activation_scales['ffn_input']
    for proj_name in ['up_proj', 'gate_proj']:
        proj = getattr(block.mlp, proj_name)
        smoothed_w, new_ln_w, new_ln_b, s = sq.smooth_layer(
            weight=proj.weight.data,
            ln_weight=block.post_attention_layernorm.weight.data,
            ln_bias=None,  # RMSNorm has no bias
            activation_scales=ffn_act_scale,
        )
        proj.weight.data = smoothed_w

    block.post_attention_layernorm.weight.data = new_ln_w

    return block

Per-Tensor Dynamic Scaling for Online Quantization

After SmoothQuant smoothing, the activations have a much more uniform distribution. Now we can apply standard per-tensor or per-token INT8 quantization at runtime with acceptable quality.

Static quantization uses fixed scale factors determined during calibration. The scale does not change at inference time. This is faster (no per-token max computation) but less accurate for inputs that differ from the calibration distribution.

Dynamic quantization computes the scale factor from each actual input at runtime. This adds a small overhead (computing the per-tensor or per-token max) but adapts to any input distribution.

class DynamicInt8Quantizer:
    """Runtime dynamic INT8 quantization for activations."""

    @staticmethod
    def quantize_per_tensor(x):
        """Per-tensor dynamic quantization.

        x: (batch * seq, hidden) FP16 tensor
        Returns: (q_x, scale)
        """
        amax = x.abs().max()
        scale = amax / 127.0
        scale = max(scale.item(), 1e-10)
        q_x = (x / scale).round().clamp(-128, 127).to(torch.int8)
        return q_x, scale

    @staticmethod
    def quantize_per_token(x):
        """Per-token dynamic quantization.

        x: (tokens, hidden) FP16 tensor
        Returns: (q_x, scales) where scales is (tokens, 1)
        """
        amax = x.abs().amax(dim=1, keepdim=True)
        scales = amax / 127.0
        scales = scales.clamp(min=1e-10)
        q_x = (x / scales).round().clamp(-128, 127).to(torch.int8)
        return q_x, scales

    @staticmethod
    def quantize_per_channel(x):
        """Per-channel dynamic quantization.

        x: (tokens, hidden) FP16 tensor
        Returns: (q_x, scales) where scales is (1, hidden)
        """
        amax = x.abs().amax(dim=0, keepdim=True)
        scales = amax / 127.0
        scales = scales.clamp(min=1e-10)
        q_x = (x / scales).round().clamp(-128, 127).to(torch.int8)
        return q_x, scales

W8A8: Full INT8 Inference

With SmoothQuant-smoothed weights and dynamic activation quantization, we can implement complete W8A8 inference. Both the weights and activations are in INT8, and the GEMM uses INT8 tensor cores with INT32 accumulation.

The W8A8 GEMM

The INT8 GEMM computes:

$$Y_{\text{int32}} = X_{\text{int8}} \cdot W_{\text{int8}}^T$$

$$Y_{\text{fp16}} = Y_{\text{int32}} \cdot s_x \cdot s_w$$

The accumulation in INT32 is critical: INT8 × INT8 products can reach $127 \times 127 = 16129$, and summing thousands of these requires 32-bit precision.
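The arithmetic behind that claim, for a hidden size of 4096:

```python
# Worst-case INT8 GEMM accumulator magnitude when reducing over K = 4096:
K = 4096
max_product = 127 * 127        # largest INT8 x INT8 product: 16,129
worst_case = K * max_product   # 66,064,384 after summing the K dimension
print(f"worst-case accumulator: {worst_case:,}")
assert worst_case > 2**15 - 1  # would overflow an INT16 accumulator
assert worst_case < 2**31 - 1  # fits comfortably in INT32
```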

class W8A8Linear:
    """INT8 linear layer with SmoothQuant preprocessing."""

    def __init__(self, weight_int8, weight_scale, per_channel=True):
        """
        weight_int8: (out_features, in_features) INT8
        weight_scale: (out_features, 1) or scalar, FP32
        per_channel: whether weight was quantized per-channel
        """
        self.weight_int8 = weight_int8
        self.weight_scale = weight_scale
        self.per_channel = per_channel
        self.out_features = weight_int8.shape[0]
        self.in_features = weight_int8.shape[1]

    @classmethod
    def from_float(cls, linear_layer):
        """Quantize a float linear layer to W8A8."""
        weight = linear_layer.weight.data.float()

        # Per-channel symmetric quantization for weights
        amax = weight.abs().amax(dim=1, keepdim=True)
        scale = amax / 127.0
        scale = scale.clamp(min=1e-10)
        weight_int8 = (weight / scale).round().clamp(-128, 127).to(torch.int8)

        return cls(weight_int8, scale, per_channel=True)

    def forward(self, x):
        """W8A8 forward pass with dynamic per-token activation quantization.

        x: (batch, seq_len, hidden) or (tokens, hidden) FP16
        """
        orig_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.shape[-1])

        # Dynamic per-token activation quantization
        act_amax = x.abs().amax(dim=1, keepdim=True)
        act_scale = act_amax / 127.0
        act_scale = act_scale.clamp(min=1e-10)
        x_int8 = (x / act_scale).round().clamp(-128, 127).to(torch.int8)

        # INT8 GEMM with INT32 accumulation
        # On real hardware, this uses INT8 tensor cores via cuBLAS
        y_int32 = torch.matmul(
            x_int8.float(),  # Simulated -- real kernel stays in INT8
            self.weight_int8.float().T
        ).to(torch.int32)

        # Dequantize: y_fp = y_int32 * act_scale * weight_scale
        y_fp = y_int32.float() * act_scale * self.weight_scale.T

        if len(orig_shape) == 3:
            y_fp = y_fp.reshape(orig_shape[0], orig_shape[1], self.out_features)

        return y_fp

def benchmark_w8a8(hidden_size=4096, batch_size=32, seq_len=128):
    """Benchmark W8A8 vs FP16 linear layer."""
    layer = nn.Linear(hidden_size, hidden_size, bias=False).float()
    x = torch.randn(batch_size, seq_len, hidden_size)

    # Full-precision reference (FP32 here stands in for FP16)
    y_ref = layer(x)

    # W8A8
    w8a8 = W8A8Linear.from_float(layer)
    y_w8a8 = w8a8.forward(x)

    mse = ((y_ref - y_w8a8) ** 2).mean().item()
    cos_sim = torch.nn.functional.cosine_similarity(
        y_ref.flatten().unsqueeze(0),
        y_w8a8.flatten().unsqueeze(0)
    ).item()

    print(f"W8A8 vs FP16: MSE={mse:.8f}, cosine_sim={cos_sim:.6f}")
    return mse, cos_sim

benchmark_w8a8()

📊 W8A8 Quality: SmoothQuant + Per-Token Dynamic Scaling

| Model | FP16 PPL | W8A8 Naive PPL | W8A8 SmoothQuant PPL | Degradation |
|---|---|---|---|---|
| OPT-6.7B | 10.86 | 23.54 | 10.93 | +0.07 |
| OPT-13B | 10.13 | 18.91 | 10.19 | +0.06 |
| OPT-30B | 9.56 | 67.82 | 9.64 | +0.08 |
| OPT-66B | 9.34 | 940+ | 9.41 | +0.07 |
| Llama 7B | 5.68 | 7.92 | 5.73 | +0.05 |
| Llama 13B | 5.09 | 6.18 | 5.13 | +0.04 |

Note: Without SmoothQuant, W8A8 is catastrophic for large models (OPT-66B perplexity exceeds 940). With SmoothQuant, degradation is consistently under 0.1 PPL points.

W8A8 Throughput Improvement Over FP16 (A100, Llama 7B)

| Configuration | Throughput (tokens/sec) | Speedup |
|---|---|---|
| FP16 baseline | 1,200 | 1.00x |
| W8A8 (cuBLAS INT8) | 1,950 | 1.63x |
| W4A16 (GPTQ) | 2,100 | 1.75x |

W8A8 vs W4A16

W8A8 and W4A16 serve different needs. W8A8 is better for prefill (compute-bound: INT8 tensor cores are 2x faster than FP16). W4A16 is better for decode (memory-bound: 4x less weight data to read). The best serving systems use W8A8 for prefill and W4A16 for decode, or FP8 (Part 4), which combines the benefits.
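The decode side can be sketched with back-of-envelope weight-traffic arithmetic (assumed numbers: a 7B-parameter model and ~2 TB/s of HBM bandwidth; real latencies also include activations, KV cache reads, and kernel overheads):

```python
# Decode is memory-bound: per-token latency is roughly weight bytes / bandwidth.
params = 7e9          # assumed 7B-parameter model
bandwidth = 2.0e12    # assumed ~2 TB/s HBM bandwidth (A100-class)

latencies_ms = {}
for name, bytes_per_param in [("FP16", 2.0), ("INT8", 1.0), ("INT4", 0.5)]:
    latencies_ms[name] = params * bytes_per_param / bandwidth * 1e3
    print(f"{name}: {latencies_ms[name]:.2f} ms per token (weight reads only)")
```

Halving the weight bytes halves the floor on decode latency, which is why INT4 weights win the decode phase even when the math still runs in FP16.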

The Complete SmoothQuant Pipeline

Here is the end-to-end pipeline for quantizing a model with SmoothQuant:

class SmoothQuantPipeline:
    """End-to-end SmoothQuant + W8A8 quantization pipeline."""

    def __init__(self, model, alpha=0.5):
        self.model = model
        self.alpha = alpha
        self.activation_scales = {}
        self.smoothing_factors = {}

    def calibrate(self, calibration_loader, num_batches=128):
        """Phase 1: Collect activation statistics."""
        hooks = []
        act_max = {}

        def make_hook(name):
            def hook_fn(module, inp, out):
                x = inp[0]
                if x.dim() == 3:
                    x = x.reshape(-1, x.shape[-1])
                ch_max = x.abs().amax(dim=0).detach()
                if name in act_max:
                    act_max[name] = torch.maximum(act_max[name], ch_max)
                else:
                    act_max[name] = ch_max
            return hook_fn

        # Register hooks on all linear layers
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                hooks.append(module.register_forward_hook(make_hook(name)))

        # Run calibration
        self.model.eval()
        count = 0
        with torch.no_grad():
            for batch in calibration_loader:
                self.model(batch)
                count += 1
                if count >= num_batches:
                    break

        # Remove hooks
        for h in hooks:
            h.remove()

        self.activation_scales = act_max
        print(f"Calibrated {len(act_max)} layers over {count} batches")

    def smooth(self):
        """Phase 2: Apply SmoothQuant transformation."""
        sq = SmoothQuant(self.alpha)

        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear) and name in self.activation_scales:
                act_scale = self.activation_scales[name]
                s = sq.compute_smoothing_factors(act_scale, module.weight.data)

                # Apply smoothing to weight
                module.weight.data *= s.unsqueeze(0)

                self.smoothing_factors[name] = s

        print(f"Smoothed {len(self.smoothing_factors)} layers (alpha={self.alpha})")

    def quantize_weights(self):
        """Phase 3: Quantize smoothed weights to INT8."""
        quantized = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                w8a8 = W8A8Linear.from_float(module)
                quantized[name] = w8a8

        print(f"Quantized {len(quantized)} linear layers to INT8")
        return quantized

    def run(self, calibration_loader):
        """Run the full pipeline."""
        print("Phase 1: Calibrating...")
        self.calibrate(calibration_loader)

        print("Phase 2: Smoothing...")
        self.smooth()

        print("Phase 3: Quantizing...")
        return self.quantize_weights()

Advanced: Per-Channel Smoothing Factor Analysis

Understanding which channels get smoothed and by how much gives insight into the model’s internal structure:

def analyze_smoothing(activation_scales, weight, alpha=0.5):
    """Analyze the smoothing factors for a layer."""
    sq = SmoothQuant(alpha)
    s = sq.compute_smoothing_factors(activation_scales, weight)

    print(f"Smoothing factor statistics:")
    print(f"  Min:    {s.min().item():.4f}")
    print(f"  Max:    {s.max().item():.4f}")
    print(f"  Mean:   {s.mean().item():.4f}")
    print(f"  Median: {s.median().item():.4f}")
    print(f"  Std:    {s.std().item():.4f}")
    print(f"  Max/Min ratio: {s.max().item() / s.min().item():.1f}x")

    # Channels with large smoothing factors are the outlier channels
    outlier_threshold = s.mean() + 3 * s.std()
    outlier_mask = s > outlier_threshold
    n_outliers = outlier_mask.sum().item()
    print(f"  Outlier channels (3-sigma): {n_outliers} "
          f"({100 * n_outliers / len(s):.1f}%)")

    # Before and after smoothing
    act_range_before = activation_scales.max() / activation_scales.median()
    act_range_after = (activation_scales / s).max() / (activation_scales / s).median()
    print(f"  Activation range before: {act_range_before:.1f}x")
    print(f"  Activation range after:  {act_range_after:.1f}x")

    return s

# Demonstrate
act_scales = torch.ones(4096) * 0.5
outlier_idx = torch.randperm(4096)[:40]  # ~1% outliers
act_scales[outlier_idx] = 25.0  # 50x larger
weight = torch.randn(4096, 4096) * 0.02

s = analyze_smoothing(act_scales, weight, alpha=0.5)

Alpha Selection: Tuning the Smoothing Strength

The $\alpha$ parameter controls how aggressively SmoothQuant migrates quantization difficulty from activations to weights. Choosing the right $\alpha$ is critical for quality.

$$s_j = \frac{\max(|X_j|)^{\alpha}}{\max(|W_j|)^{1-\alpha}}$$

  • $\alpha = 0$: No smoothing. $s_j = 1 / \max(|W_j|)$. All difficulty stays on activations.
  • $\alpha = 0.5$: Equal split. Difficulty is balanced between activations and weights.
  • $\alpha = 1.0$: Maximum smoothing. $s_j = \max(|X_j|)$. All difficulty migrates to weights.

In practice, $\alpha$ between 0.5 and 0.75 works for most models. Models with stronger outliers need higher $\alpha$.

def search_optimal_alpha(layer, calibration_inputs, alphas=None):
    """Grid search for optimal SmoothQuant alpha.

    Tests each alpha value and returns the one that minimizes
    output error after W8A8 quantization.
    """
    if alphas is None:
        alphas = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    weight = layer.weight.data.clone().float()

    # Collect activation statistics
    all_inputs = []
    for inp in calibration_inputs:
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        all_inputs.append(inp)
    all_inputs = torch.cat(all_inputs, dim=0)

    # Per-channel activation max
    act_scales = all_inputs.abs().amax(dim=0)

    # Reference output
    y_ref = all_inputs @ weight.T

    best_alpha = 0.5
    best_mse = float('inf')

    for alpha in alphas:
        # Compute smoothing factors
        w_scales = weight.abs().amax(dim=0)
        s = act_scales.pow(alpha) / w_scales.pow(1 - alpha)
        s = s.clamp(min=1e-5)
        s = s / s.mean()

        # Apply smoothing
        w_smooth = weight * s.unsqueeze(0)
        x_smooth = all_inputs / s.unsqueeze(0)

        # Quantize both to INT8
        # Weight: per-channel
        w_amax = w_smooth.abs().amax(dim=1, keepdim=True)
        w_scale = w_amax / 127.0
        w_scale = w_scale.clamp(min=1e-10)
        w_q = (w_smooth / w_scale).round().clamp(-128, 127)

        # Activation: per-token
        x_amax = x_smooth.abs().amax(dim=1, keepdim=True)
        x_scale = x_amax / 127.0
        x_scale = x_scale.clamp(min=1e-10)
        x_q = (x_smooth / x_scale).round().clamp(-128, 127)

        # Compute output
        y_q = (x_q * x_scale) @ (w_q * w_scale).T
        mse = ((y_ref - y_q) ** 2).mean().item()

        if mse < best_mse:
            best_mse = mse
            best_alpha = alpha

    return best_alpha, best_mse

# Usage
torch.manual_seed(42)
layer = nn.Linear(4096, 4096, bias=False).float()
calib = [torch.randn(1, 128, 4096) for _ in range(64)]
# Inject outliers in calibration data
for c in calib:
    c[:, :, 42] *= 30.0
    c[:, :, 1337] *= 50.0

alpha, mse = search_optimal_alpha(layer, calib)
print(f"Optimal alpha: {alpha}, MSE: {mse:.8f}")

📊 Effect of Alpha on W8A8 Perplexity (OPT-66B, WikiText-2)

| Alpha | Perplexity | Degradation vs FP16 |
|---|---|---|
| 0.0 (no smoothing) | 940+ | catastrophic |
| 0.25 | 12.81 | +3.47 |
| 0.50 (default) | 9.41 | +0.07 |
| 0.75 | 9.43 | +0.09 |
| 1.0 (max smoothing) | 9.89 | +0.55 |

Note: Alpha=0.5 is optimal for OPT-66B. Alpha=0.0 is catastrophic because outlier channels consume the entire INT8 range. Alpha=1.0 over-smooths, pushing too much difficulty onto weights.

The Outlier Emergence Phenomenon

Activation outliers are not present in small models. They emerge as models scale beyond approximately 6 billion parameters. This was first documented by Dettmers et al. (2022) in the “LLM.int8()” paper, which showed that OPT models below 6.7B have no significant outliers, while models at 6.7B and above develop persistent outlier channels.

The outliers have distinctive properties:

  1. Fixed channels: The same channel indices produce outliers across all inputs, all tokens, and all layers. Channel 42 might always be an outlier in layer 15.
  2. Consistent magnitude: The outlier magnitude is relatively stable — it does not vary wildly between inputs.
  3. Small number: Typically 0.1-1% of channels are outliers. In a 4096-dimensional hidden state, that is 4-40 channels.
  4. Critical for function: Zeroing out outlier channels catastrophically degrades model quality. They encode important information despite being numerically extreme.

def detect_outlier_channels(activation_scales, threshold_sigma=3.0):
    """Detect outlier channels in activation statistics.

    activation_scales: (hidden_dim,) per-channel max absolute values
    threshold_sigma: channels above mean + threshold * std are outliers

    Returns: outlier channel indices
    """
    mean_scale = activation_scales.mean()
    std_scale = activation_scales.std()
    threshold = mean_scale + threshold_sigma * std_scale

    outlier_mask = activation_scales > threshold
    outlier_indices = torch.where(outlier_mask)[0]

    print(f"Detection results:")
    print(f"  Total channels: {len(activation_scales)}")
    print(f"  Outlier threshold: {threshold:.4f}")
    print(f"  Outlier channels: {len(outlier_indices)} "
          f"({100 * len(outlier_indices) / len(activation_scales):.2f}%)")
    print(f"  Max outlier magnitude: {activation_scales[outlier_mask].max():.4f}")
    print(f"  Median normal magnitude: "
          f"{activation_scales[~outlier_mask].median():.4f}")
    print(f"  Outlier/normal ratio: "
          f"{activation_scales[outlier_mask].max() / activation_scales[~outlier_mask].median():.1f}x")

    return outlier_indices

# Simulate and detect
act_scales = torch.ones(4096) * 0.5
outlier_idx = torch.tensor([42, 137, 256, 512, 1024, 1337, 2048, 3000, 3500, 4000])
act_scales[outlier_idx] = torch.tensor([25.0, 30.0, 18.0, 22.0, 35.0, 50.0, 28.0, 15.0, 20.0, 40.0])
detected = detect_outlier_channels(act_scales)

Static vs Dynamic Quantization Tradeoffs

Beyond SmoothQuant, the choice between static and dynamic activation quantization has significant implications for both quality and latency.

Static quantization computes scale factors during calibration and fixes them at deployment. Every inference request uses the same pre-computed scales. This is faster (no per-request max computation) but assumes the calibration distribution matches production traffic.

Dynamic quantization computes the scale factor from each input at runtime. This adds one reduction operation per linear layer (computing the per-token or per-tensor max) but perfectly adapts to any input distribution.

In practice, the latency cost of dynamic quantization is negligible for large models. The reduction to compute the per-token max of a 4096-dimensional vector is a few microseconds — invisible compared to the GEMM latency. Dynamic quantization is therefore preferred for production serving where input distributions are unpredictable.
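The two modes differ only in where the scale factor comes from. A minimal sketch of both paths, using hypothetical helper names and symmetric INT8 as elsewhere in this post:

```python
import torch

def static_quantize(x, scale):
    """Static: use a scale precomputed from calibration data, frozen at deployment."""
    q = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def dynamic_quantize_per_token(x):
    """Dynamic: compute a per-token scale from the live input at runtime."""
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0  # one small reduction per token
    q = (x / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

torch.manual_seed(0)
calib = torch.randn(512, 4096)                # calibration batch
static_scale = calib.abs().max() / 127.0      # computed once, reused forever

x = torch.randn(4, 4096) * 3.0                # production input, wider than calibration
q_s, _ = static_quantize(x, static_scale)     # clips: input exceeds the calibrated range
q_d, d_scale = dynamic_quantize_per_token(x)  # adapts: scale tracks each token
```

On this deliberately shifted input, the static path clips hard while the dynamic path stays within one quantization step of the input, which is exactly the domain-shift effect in the table below.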

Static vs Dynamic Activation Quantization (Llama 13B)

Method                 Calibration Data   In-Domain PPL   Out-of-Domain PPL   Latency Overhead
Static (C4 calib)      C4                 5.15            6.42                0%
Static (Wiki calib)    WikiText           5.13            6.89                0%
Dynamic per-token      None needed        5.14            5.14                ~1%
Dynamic per-tensor     None needed        5.18            5.18                ~0.5%
Note: Dynamic quantization shows no domain shift degradation because scales adapt to each input. Static quantization degrades when production data differs from calibration data. Latency overhead of dynamic quantization is negligible.

When SmoothQuant Is Not Enough

SmoothQuant works well for W8A8 but has limitations:

W4A4 or W4A8: At 4-bit precision, even smoothed activations have too much quantization error. SmoothQuant can reduce outlier impact, but the fundamental precision is too low for general activation quantization at INT4. This is why W4A16 (INT4 weights, FP16 activations) remains the dominant 4-bit inference format.
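The precision gap is easy to see numerically: even on an outlier-free Gaussian tensor, the best case smoothing can produce, dropping from 8 to 4 bits widens the quantization step by a factor of 127/7 ≈ 18, and the round-trip error grows with it. A self-contained sketch:

```python
import torch

def quant_error(x, bits):
    """Relative round-trip error of symmetric uniform quantization at a bit width."""
    qmax = 2 ** (bits - 1) - 1
    scale = x.abs().max() / qmax
    x_hat = (x / scale).round().clamp(-qmax - 1, qmax) * scale
    return ((x_hat - x).norm() / x.norm()).item()

torch.manual_seed(0)
smoothed = torch.randn(100_000)  # outlier-free, roughly Gaussian activations
print(f"INT8 relative error: {quant_error(smoothed, 8):.4f}")
print(f"INT4 relative error: {quant_error(smoothed, 4):.4f}")
```

The INT4 error lands roughly 18x higher than INT8 even in this idealized setting, which is why activations stay in FP16 (W4A16) at 4-bit weight precision.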

Extreme outliers: Some models (notably GLM-130B) have outliers so extreme that no value of the smoothing exponent α can fully smooth them. In these cases, a mixed-precision approach is needed: quantize most channels to INT8 and keep the handful of outlier channels in FP16.

Dynamic outlier patterns: SmoothQuant assumes outliers appear in fixed channels. If a model has input-dependent outlier patterns (rare but possible), the smoothing factors computed during calibration may not generalize. Per-token dynamic scaling partially mitigates this.

import torch

def mixed_precision_quantize(activations, weight, outlier_threshold=3.0):
    """Mixed-precision approach: INT8 for normal channels, FP16 for outliers."""
    ch_max = activations.abs().amax(dim=0)
    median_max = ch_max.median()

    # Identify outlier channels
    outlier_mask = ch_max > outlier_threshold * median_max
    normal_mask = ~outlier_mask

    n_outlier = outlier_mask.sum().item()
    n_normal = normal_mask.sum().item()
    print(f"Mixed precision: {n_normal} INT8 channels, {n_outlier} FP16 channels")

    # Split computation
    act_normal = activations[:, normal_mask]
    act_outlier = activations[:, outlier_mask]
    w_normal = weight[:, normal_mask]
    w_outlier = weight[:, outlier_mask]

    # INT8 path for normal channels
    q_act, act_scale = quantize_per_tensor_int8(act_normal)
    q_w_amax = w_normal.abs().amax(dim=1, keepdim=True)
    w_scale = q_w_amax / 127.0
    q_w = (w_normal / w_scale).round().clamp(-128, 127).to(torch.int8)

    y_int8 = (q_act.float() @ q_w.float().T) * (act_scale * w_scale.T)

    # FP16 path for outlier channels
    y_fp16 = act_outlier @ w_outlier.T

    return y_int8 + y_fp16

Summary

Activation quantization is fundamentally harder than weight quantization because of outliers: specific channels in LLM activations carry values 10-100x larger than the rest, and these outliers destroy uniform quantization quality.

SmoothQuant solves this by applying a per-channel transformation that divides down outlier activations and multiplies up the corresponding weights. The transformation is mathematically equivalent — the output does not change. After smoothing, standard per-tensor or per-token INT8 quantization works with minimal quality loss.
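The equivalence is easy to verify directly. A small sketch using the α = 1 case, where the smoothing factor s is simply the per-channel activation max, for illustration:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 4096)
x[:, 42] *= 50.0                   # one outlier channel, as in real LLMs
W = torch.randn(1024, 4096) * 0.02

# Divide each activation channel by s, multiply the matching weight column by s.
# Here s is the per-channel activation max (the alpha = 1 case); SmoothQuant
# proper interpolates between activation and weight ranges via alpha.
s = x.abs().amax(dim=0).clamp(min=1e-5)
y_ref = x @ W.T
y_smooth = (x / s) @ (W * s).T     # identical output up to float rounding

print(f"max |difference|: {(y_ref - y_smooth).abs().max():.2e}")
print(f"activation max before: {x.abs().max():.1f}, after: {(x / s).abs().max():.1f}")
```

After smoothing, every activation channel has max magnitude 1, so the outlier no longer dominates the quantization range, while the layer output is unchanged.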

W8A8 inference quantizes both weights and activations to INT8, enabling INT8 tensor cores for 2x compute throughput over FP16. The key enabler is SmoothQuant for activation smoothing combined with dynamic per-token scaling at inference time.

The next post covers FP8, which provides a more elegant solution to the activation quantization problem by using floating-point representation (better for non-uniform distributions) instead of integer representation.