Part of series: Quantization Masterclass (12 of 30)

In 2022, researchers trying to quantize OPT and BLOOM models to INT8 hit a wall: weights quantized beautifully, but activations destroyed the model. Per-tensor INT8 activation quantization added 2-3 perplexity points even with per-channel weight quantization working perfectly. Digging into the activation distributions revealed the culprit: about 0.1% of channels (literally 4 channels out of 4096 in some layers) had magnitudes 100x larger than the median. When you set the INT8 scale to accommodate those outliers, a typical channel's values land on only the first one or two of INT8's 127 positive levels. For most channels, naive INT8 activations were effectively 1-2 bit activations, and the model couldn't recover. This outlier channel problem turned out to be systematic, persistent across tokens, and baked into the trained weights: not a data artifact but a learned structure.

Weight quantization to INT4 or INT8 is largely a solved problem: GPTQ, AWQ, and even round-to-nearest with per-group scaling produce near-lossless results. Activation quantization is a different story. The activations flowing through a transformer have a pathological structure: a handful of channels consistently produce values 10-100x larger than the rest. These outlier channels make per-tensor activation quantization catastrophically lossy and per-channel activation quantization impractical for GEMM efficiency.

This post documents the outlier phenomenon empirically, explains why it emerges during training, quantifies the damage it causes to quantization, and implements the two major solutions: SmoothQuant (channel-wise scaling migration) and rotation-based methods.

Measuring the Problem

To understand outlier channels, we need to profile the activation magnitudes in a real transformer. The following code hooks into every linear layer of a model and records per-channel activation statistics:

import torch
import numpy as np
from collections import defaultdict

class ActivationProfiler:
    """Profile per-channel activation magnitudes in a transformer."""

    def __init__(self, model):
        self.model = model
        self.hooks = []
        self.stats = defaultdict(lambda: {
            'max': [],
            'mean': [],
            'abs_max_per_channel': [],
        })

    def _hook_fn(self, name):
        def hook(module, input, output):
            x = input[0].detach().float()
            # x shape: (batch, seq_len, hidden_dim)
            # Collapse batch and seq dimensions
            x_flat = x.reshape(-1, x.shape[-1])

            self.stats[name]['max'].append(x_flat.abs().max().item())
            self.stats[name]['mean'].append(x_flat.abs().mean().item())
            self.stats[name]['abs_max_per_channel'].append(
                x_flat.abs().max(dim=0).values.cpu().numpy()
            )
        return hook

    def attach(self):
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                hook = module.register_forward_hook(self._hook_fn(name))
                self.hooks.append(hook)

    def remove(self):
        for hook in self.hooks:
            hook.remove()

    def get_channel_stats(self, name):
        """Get per-channel max activations averaged over samples."""
        per_channel = np.stack(
            self.stats[name]['abs_max_per_channel'], axis=0
        )
        return per_channel.mean(axis=0)  # (hidden_dim,)
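To see the hook mechanics without loading a 7B model, here is a minimal self-contained sketch of the same pattern on a hypothetical toy block: a LayerNorm whose gain is deliberately inflated on channels 7 and 40 (standing in for learned outliers) feeding a Linear layer. Unlike the profiler above, it keeps a running per-channel maximum instead of a list of per-batch maxima.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 64

# Hypothetical toy block: inflating the LayerNorm gain on channels 7 and 40
# imitates the systematic outlier channels seen in real LLMs.
model = nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, hidden))
with torch.no_grad():
    model[0].weight[7] = 60.0
    model[0].weight[40] = 25.0

channel_max = {}  # layer name -> running per-channel max |input|

def make_hook(name):
    def hook(module, inputs, output):
        x = inputs[0].detach().float()
        x = x.reshape(-1, x.shape[-1])          # collapse batch and seq dims
        batch_max = x.abs().amax(dim=0)
        prev = channel_max.get(name)
        channel_max[name] = batch_max if prev is None else torch.maximum(prev, batch_max)
    return hook

handles = [
    m.register_forward_hook(make_hook(n))
    for n, m in model.named_modules()
    if isinstance(m, nn.Linear)
]
with torch.no_grad():
    for _ in range(8):                          # 8 synthetic "calibration" batches
        model(torch.randn(4, 16, hidden))       # (batch, seq_len, hidden)
for handle in handles:
    handle.remove()

stats = channel_max["1"]                        # the Linear is module "1"
top2 = torch.topk(stats, 2).indices.tolist()
ratio = (stats.max() / stats.median()).item()
print(f"top-2 channels: {top2}, max/median ratio: {ratio:.1f}x")
```

The inflated channels show up as the two largest per-channel maxima, far above the median, which is the signature the real profile below exhibits.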

Running this on Llama-2 7B with 128 calibration samples from C4:

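# Assumes profiler = ActivationProfiler(model), profiler.attach(), the
# calibration forward passes, and profiler.remove() have already run.
# Analyze a specific layer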
layer_name = "model.layers.0.self_attn.q_proj"
channel_maxes = profiler.get_channel_stats(layer_name)

# Sort channels by magnitude
sorted_idx = np.argsort(channel_maxes)[::-1]
top_10 = sorted_idx[:10]
median_val = np.median(channel_maxes)

print(f"Layer: {layer_name}")
print(f"  Median channel max: {median_val:.2f}")
print(f"  Top 10 channels:")
for i, idx in enumerate(top_10):
    ratio = channel_maxes[idx] / median_val
    print(f"    Channel {idx:4d}: max={channel_maxes[idx]:.2f} "
          f"({ratio:.1f}x median)")

Typical output for an early attention layer:

Layer: model.layers.0.self_attn.q_proj
  Median channel max: 0.83
  Top 10 channels:
    Channel 2046: max=72.31 (87.1x median)
    Channel 2047: max=68.94 (83.1x median)
    Channel 1023: max=45.22 (54.5x median)
    Channel 3071: max=41.87 (50.4x median)
    Channel 4095: max=38.19 (46.0x median)
    Channel  511: max=12.44 (15.0x median)
    Channel 1535: max=11.92 (14.4x median)
    Channel 2559: max=10.31 (12.4x median)
    Channel 3583: max= 9.87 (11.9x median)
    Channel  255: max= 8.41 (10.1x median)
🚨 87x the Median

The top outlier channel has an activation magnitude 87x the median. If we quantize to INT8 with a per-tensor scale factor, the scale is set by this 72.31 maximum: $72.31 / 127 \approx 0.569$. The median channel (0.83) then maps to $\text{round}(0.83 / 0.569) = \text{round}(1.46) = 1$, effectively a 1-bit representation. Most of the INT8 range is wasted on a few extreme channels.
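The arithmetic is easy to check from the numbers in the profile. This sketch uses only the median (0.83) and the ten largest channel maxima listed above, and counts effective bits as $\log_2(\text{levels} + 1)$, the same convention the effective-bits function later in this post uses for channels that span at least one level.

```python
import math

qmax = 127                      # symmetric INT8: levels -127..127
global_max = 72.31              # channel 2046, the largest outlier
median_max = 0.83               # median per-channel max from the profile
top10 = [72.31, 68.94, 45.22, 41.87, 38.19, 12.44, 11.92, 10.31, 9.87, 8.41]

scale = global_max / qmax                       # per-tensor scale
median_levels = median_max / scale              # ~1.46 levels
median_q = round(median_levels)                 # what the median max rounds to
median_bits = math.log2(median_levels + 1)      # effective bits of the median channel

# 4 effective bits needs >= 15 levels (log2(15 + 1) = 4). Every channel
# outside the top 10 is at most 8.41, so only these can qualify.
four_bit_threshold = 15 * scale                 # ~8.54
channels_with_4_bits = sum(m >= four_bit_threshold for m in top10)

print(f"scale={scale:.4f}, median -> {median_levels:.2f} levels -> q={median_q}")
print(f"median effective bits ~ {median_bits:.2f}")
print(f"channels with >= 4 effective bits: {channels_with_4_bits} of 4096")
```

Under this convention the median channel gets about 1.3 effective bits, and only the nine largest of the 4096 channels reach 4 bits.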

The Structure of Outlier Channels

The outlier channels are not random. They exhibit three key properties:

Property 1: Persistence Across Tokens

The same channels are outliers for every token in every sequence. This is not a data-dependent phenomenon — it is baked into the model weights.

def measure_channel_consistency(profiler, layer_name, num_samples):
    """Check if the same channels are outliers across samples."""
    per_channel_all = np.stack(
        profiler.stats[layer_name]['abs_max_per_channel'], axis=0
    )  # (num_samples, hidden_dim)

    # For each sample, identify the top-k outlier channels
    k = 10
    top_k_per_sample = []
    for i in range(per_channel_all.shape[0]):
        top_k = set(np.argsort(per_channel_all[i])[-k:])
        top_k_per_sample.append(top_k)

    # Compute pairwise Jaccard similarity
    similarities = []
    for i in range(len(top_k_per_sample)):
        for j in range(i + 1, len(top_k_per_sample)):
            intersection = len(top_k_per_sample[i] & top_k_per_sample[j])
            union = len(top_k_per_sample[i] | top_k_per_sample[j])
            similarities.append(intersection / union)

    return np.mean(similarities)

# Expected: Jaccard similarity > 0.95 for top-10 channels
# The same channels are consistently the largest

Property 2: Systematic Positions

The outlier channels tend to appear at specific positions in the hidden dimension. In Llama-2 7B (hidden_dim=4096), the largest outliers are at channels 2046, 2047, 1023, 3071, and 4095. Apart from 2046, each sits one position below a multiple of 1024 (so 1023, 2047, and 4095 are also just below powers of 2), i.e., at the boundaries of blocks of eight 128-wide attention heads.

def analyze_outlier_positions(channel_maxes, block=1024):
    """Analyze the positional pattern of outlier channels.

    Reports each outlier's distance to the nearest multiple of `block`
    (1024 fits the Llama-2 7B profile above).
    """
    threshold = np.percentile(channel_maxes, 99)  # Top 1%
    outlier_mask = channel_maxes > threshold
    outlier_indices = np.where(outlier_mask)[0]

    print(f"Number of outlier channels (top 1%): {len(outlier_indices)}")
    print(f"Outlier positions: {outlier_indices.tolist()}")

    # Check proximity of (idx + 1) to multiples of `block`
    for idx in outlier_indices:
        nearest = int(round((idx + 1) / block)) * block
        distance = abs(idx + 1 - nearest)
        print(f"  Channel {idx}: nearest {block}-multiple = {nearest}, "
              f"distance = {distance}")

Property 3: Growth During Training

The outlier magnitudes grow during training and stabilize. Models trained for more steps tend to have larger outlier magnitudes. This suggests the outliers serve a functional role — they may act as implicit scaling factors that the model learns to use for precise attention computation.

# Magnitude of largest outlier channel at different training checkpoints
# (Measured on OPT-6.7B training run)
checkpoint_data = {
    'step_10k': 12.3,
    'step_50k': 28.7,
    'step_100k': 45.1,
    'step_200k': 62.8,
    'step_300k': 71.4,
    'step_final': 72.3,
}

# The outlier magnitude grows quickly (faster than log(steps) over most of
# this range) and then saturates near its final value

Quantifying the Damage

Let us compute exactly how much information is lost when quantizing activations with outlier channels present.

def compute_effective_bits_per_channel(channel_maxes, total_bits=8):
    """Compute effective quantization bits per channel under per-tensor scaling.

    With per-tensor scaling, the scale is set by the largest channel.
    A channel whose max spans L integer levels uses about log2(L + 1) bits
    of the grid (the sign bit is not counted), so smaller channels use
    fewer effective bits. The formula is continuous, so channels below one
    level get a fractional value rather than jumping to 0.
    """
    qmax = 2 ** (total_bits - 1) - 1
    global_max = np.max(channel_maxes)
    scale = global_max / qmax

    # Number of integer levels spanned by each channel's max magnitude
    num_levels = np.asarray(channel_maxes) / scale
    return np.log2(num_levels + 1)

# Using the Llama-2 7B channel_maxes from above
eff_bits = compute_effective_bits_per_channel(channel_maxes, total_bits=8)

print(f"Nominal bits: 8")
print(f"Mean effective bits: {np.mean(eff_bits):.2f}")
print(f"Median effective bits: {np.median(eff_bits):.2f}")
print(f"Min effective bits: {np.min(eff_bits):.2f}")
print(f"Channels with < 4 effective bits: "
      f"{np.sum(eff_bits < 4)} / {len(eff_bits)}")

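Expected output for the median and for the 4-bit count, both of which follow from the statistics above (median channel max 0.83, global max 72.31, and only nine channels above the 4-bit threshold of $15 \times 72.31/127 \approx 8.54$). The mean and minimum depend on the full 4096-channel distribution, which is not listed here.

Nominal bits: 8
Median effective bits: 1.30
Channels with < 4 effective bits: 4087 / 4096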
⚠️ INT8 Activations Become 1-2 Bit Activations

With per-tensor scaling, the median channel gets only about 1.3 effective bits out of 8, and just nine of the 4096 channels reach 4 effective bits. The outlier channels consume the dynamic range that should be shared across all channels. This is why naive W8A8 quantization with per-tensor activation scaling degrades perplexity by 1-3 points on 7B models.

Per-Token vs Per-Tensor Activation Scaling

One partial mitigation is per-token scaling: compute a separate scale factor for each token position rather than for the entire activation tensor.

def quantize_activation_per_tensor(X, bits=8):
    """Per-tensor activation quantization."""
    qmax = 2 ** (bits - 1) - 1
    scale = (X.abs().max() / qmax).clamp(min=1e-10)
    X_q = (X / scale).round().clamp(-qmax - 1, qmax)
    return X_q, scale

def quantize_activation_per_token(X, bits=8):
    """Per-token activation quantization.
    X shape: (batch, seq_len, hidden_dim) or (tokens, hidden_dim)
    """
    qmax = 2 ** (bits - 1) - 1
    # One scale per token: reduce over the last (hidden) dim, which works
    # for both 2-D and 3-D inputs
    scale = X.abs().amax(dim=-1, keepdim=True) / qmax
    scale = scale.clamp(min=1e-10)
    X_q = (X / scale).round().clamp(-qmax - 1, qmax)
    return X_q, scale

def quantize_activation_per_channel(X, bits=8):
    """Per-channel activation quantization.
    NOT compatible with INT8 GEMM -- requires per-channel dequantization
    inside the matmul, breaking the integer accumulation.
    """
    qmax = 2 ** (bits - 1) - 1
    # Scale per channel
    if X.dim() == 3:
        scale = X.abs().amax(dim=(0, 1), keepdim=True) / qmax
    else:
        scale = X.abs().amax(dim=0, keepdim=True) / qmax
    scale = scale.clamp(min=1e-10)
    X_q = (X / scale).round().clamp(-qmax - 1, qmax)
    return X_q, scale

Per-token scaling helps because different tokens may have different outlier magnitudes, but the outlier channels are the same for every token. Per-token scaling reduces the cross-token variance but does not address the cross-channel variance.
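A small synthetic experiment (random data, not Llama activations) makes this concrete: one channel is 100x larger on every token, and the round-trip error on the remaining channels is compared across the three granularities. The quantizer is written inline so the block runs on its own; it uses the same symmetric round-and-clamp scheme as the functions above.

```python
import torch

torch.manual_seed(0)

def fake_quant(X, scale, bits=8):
    """Symmetric quantize-dequantize with the given (broadcastable) scale."""
    qmax = 2 ** (bits - 1) - 1
    scale = scale.clamp(min=1e-10)
    return (X / scale).round().clamp(-qmax - 1, qmax) * scale

# 64 tokens x 256 channels; channel 0 is a persistent outlier on every token
X = torch.randn(64, 256)
X[:, 0] *= 100.0

qmax = 127
per_tensor = X.abs().max() / qmax
per_token = X.abs().amax(dim=-1, keepdim=True) / qmax     # (64, 1)
per_channel = X.abs().amax(dim=0, keepdim=True) / qmax    # (1, 256)

def rel_err_normal_channels(X_hat):
    """Relative error measured on the non-outlier channels only."""
    diff = (X_hat - X)[:, 1:]
    return (diff.norm() / X[:, 1:].norm()).item()

errors = {
    "per-tensor": rel_err_normal_channels(fake_quant(X, per_tensor)),
    "per-token": rel_err_normal_channels(fake_quant(X, per_token)),
    "per-channel": rel_err_normal_channels(fake_quant(X, per_channel)),
}
for name, err in errors.items():
    print(f"{name:>11}: relative error on normal channels = {err:.3f}")
```

Per-token scales help only on tokens where the outlier happens to be small; per-channel scales are what actually remove the cross-channel spread.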

Per-channel activation scaling would solve the problem but breaks INT8 GEMM. In $Y = XW^T$, suppose $X$ is quantized with per-channel (per input feature) scales $s_x^{(k)}$ and $W$ with the usual per-output-channel scales $s_w^{(j)}$:

$$Y_{ij} = \sum_k X_{ik} W_{jk} \approx \sum_k \left(s_x^{(k)} q^x_{ik}\right)\left(s_w^{(j)} q^w_{jk}\right) = s_w^{(j)} \sum_k s_x^{(k)} \, q^x_{ik} \, q^w_{jk}$$

The weight scale $s_w^{(j)}$ factors out of the sum, but $s_x^{(k)}$ varies with the summation index $k$, so it cannot. The products can no longer be accumulated purely in integers and rescaled once at the end; each term needs its own floating-point rescaling, which destroys the INT8 throughput advantage. Per-token scales $s_x^{(i)}$ do not have this problem, since they depend only on the output row.
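The factoring argument can be checked numerically. In this sketch (synthetic data; an integer matmul on CPU int64 tensors stands in for an INT8 tensor-core GEMM with INT32 accumulation), per-token activation scales and per-output-channel weight scales are applied once to the integer accumulator, and the result matches dequantize-then-multiply. Per-input-channel activation scales, by contrast, have to be applied inside the sum.

```python
import torch

torch.manual_seed(0)
qmax = 127

X = torch.randn(8, 64)          # (tokens, C_in)
W = torch.randn(32, 64)         # (C_out, C_in)

def quantize(T, scale):
    """Symmetric INT8 quantization, stored as int64 so the matmul is exact."""
    return (T / scale).round().clamp(-qmax - 1, qmax).to(torch.int64)

# Per-token activation scales and per-output-channel weight scales
s_x = X.abs().amax(dim=1, keepdim=True) / qmax    # (tokens, 1)
s_w = W.abs().amax(dim=1, keepdim=True) / qmax    # (C_out, 1)
q_x, q_w = quantize(X, s_x), quantize(W, s_w)

acc = q_x @ q_w.T                                  # integer-only accumulation
Y_factored = acc.float() * s_x * s_w.T             # one rescale per output element
Y_reference = (q_x.float() * s_x) @ (q_w.float() * s_w).T
factored_ok = torch.allclose(Y_factored, Y_reference, rtol=1e-4, atol=1e-4)
print(f"scales factor out of the integer sum: {factored_ok}")

# Per-input-channel activation scales vary along the summed index k, so every
# product must be rescaled before it is summed (a float einsum, not an int GEMM)
s_xc = X.abs().amax(dim=0) / qmax                  # (C_in,)
q_xc = quantize(X, s_xc)
Y_per_channel = torch.einsum("ik,k,jk->ij", q_xc.float(), s_xc, q_w.float()) * s_w.T
```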

SmoothQuant: Migrating Quantization Difficulty

SmoothQuant (Xiao et al., 2023) observes that the difficulty is asymmetric: activations are hard to quantize (outlier channels), but weights are easy (smooth distribution). The idea is to mathematically migrate the quantization difficulty from activations to weights by scaling channels.

Given $Y = XW^T$, introduce a diagonal scaling matrix $S = \text{diag}(s_1, \ldots, s_{C_{\text{in}}})$:

$$Y = XW^T = (X S^{-1})(SW^T) = \hat{X} \hat{W}^T$$

where $\hat{X} = X S^{-1}$ and $\hat{W} = WS$. The mathematical result is identical, but the per-channel magnitudes have shifted: dividing $X$ by $s_j$ shrinks the outlier channels in the activations, while multiplying $W$ by $s_j$ grows the corresponding channels in the weights.

SmoothQuant chooses $s_j$ to balance the quantization difficulty between $\hat{X}$ and $\hat{W}$:

$$s_j = \frac{\max(|X_j|)^\alpha}{\max(|W_j|)^{1-\alpha}}$$

where $X_j$ is the $j$-th channel of the activation, $W_j$ is the $j$-th input channel of the weight, and $\alpha \in [0, 1]$ controls the migration strength.

import torch

def compute_smoothquant_scales(
    activation_channel_maxes,  # shape: (C_in,)
    weight_channel_maxes,      # shape: (C_in,)
    alpha=0.5
):
    """Compute SmoothQuant per-channel scaling factors.

    Args:
        activation_channel_maxes: max |x| per input channel (from calibration)
        weight_channel_maxes: max |w| per input channel
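        alpha: migration strength (0 = all difficulty left on activations,
               1 = all difficulty moved to weights)

    Returns:
        scales: per-channel scaling factors, shape (C_in,)
    """
    # Clamp the denominator so all-zero weight columns do not give inf
    weight_channel_maxes = weight_channel_maxes.clamp(min=1e-5)
    scales = (
        activation_channel_maxes.pow(alpha) /
        weight_channel_maxes.pow(1 - alpha)
    ).clamp(min=1e-5)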
    return scales

def apply_smoothquant(X, W, scales):
    """Apply SmoothQuant transformation.

    Args:
        X: activations, shape (*, C_in)
        W: weights, shape (C_out, C_in)
        scales: per-channel scales, shape (C_in,)

    Returns:
        X_smooth: X / scales, shape (*, C_in)
        W_smooth: W * scales, shape (C_out, C_in)
    """
    X_smooth = X / scales.unsqueeze(0)   # Divide activations
    W_smooth = W * scales.unsqueeze(0)   # Multiply weights
    return X_smooth, W_smooth
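A quick numerical check, on hypothetical random data with one injected outlier channel, that the transformation leaves $XW^T$ unchanged while shrinking the activation range ratio. Both helpers are restated in compact form so the block runs on its own (this version also clamps the weight maxima before dividing).

```python
import torch

torch.manual_seed(0)

def compute_smoothquant_scales(act_max, w_max, alpha=0.5):
    return (act_max.pow(alpha) / w_max.clamp(min=1e-5).pow(1 - alpha)).clamp(min=1e-5)

def apply_smoothquant(X, W, scales):
    return X / scales.unsqueeze(0), W * scales.unsqueeze(0)

X = torch.randn(128, 512)       # (tokens, C_in)
X[:, 3] *= 80.0                 # synthetic outlier channel
W = torch.randn(256, 512) * 0.02

scales = compute_smoothquant_scales(X.abs().amax(dim=0), W.abs().amax(dim=0), alpha=0.5)
X_s, W_s = apply_smoothquant(X, W, scales)

def range_ratio(T):
    """Max over median of the per-channel (column) max magnitudes."""
    m = T.abs().amax(dim=0)
    return (m.max() / m.median()).item()

same_output = torch.allclose(X @ W.T, X_s @ W_s.T, rtol=1e-4, atol=1e-4)
print(f"outputs match: {same_output}")
print(f"activation range ratio: {range_ratio(X):.1f}x -> {range_ratio(X_s):.1f}x")
```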

Choosing Alpha

The alpha parameter controls the trade-off:

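  • $\alpha = 0$: $s_j = 1/\max(|W_j|)$. Every weight channel is normalized to a max of 1, and all of the difficulty stays on the activations. This is not the same as no smoothing ($s_j = 1$).
  • $\alpha = 0.5$: Balanced. Each channel's max in both $\hat{X}$ and $\hat{W}$ becomes $\sqrt{\max(|X_j|)\max(|W_j|)}$, the geometric mean of the activation and weight ranges.
  • $\alpha = 1.0$: Maximum smoothing. $s_j = \max(|X_j|)$, so every activation channel is normalized to a max of 1 and all of the difficulty is pushed to the weights.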
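The extreme settings are easy to confirm numerically. On hypothetical random data, $\alpha = 1$ normalizes every activation channel to a max of 1, $\alpha = 0$ normalizes every weight column instead, and $\alpha = 0.5$ makes each channel's activation max equal its weight max:

```python
import torch

torch.manual_seed(0)

X = torch.randn(64, 32)
X[:, 5] *= 40.0                     # synthetic outlier channel
W = torch.randn(16, 32) * 0.05

act_max = X.abs().amax(dim=0)       # max |X_j| per input channel
w_max = W.abs().amax(dim=0)         # max |W_j| per input channel

def smoothed_maxes(alpha):
    """Per-channel maxes of X_hat = X / s and W_hat = W * s for a given alpha."""
    s = act_max.pow(alpha) / w_max.pow(1 - alpha)
    return (X / s).abs().amax(dim=0), (W * s).abs().amax(dim=0)

x1, _ = smoothed_maxes(1.0)
_, w0 = smoothed_maxes(0.0)
xh, wh = smoothed_maxes(0.5)

flat_act = torch.allclose(x1, torch.ones_like(x1))   # alpha=1: activations flat
flat_w = torch.allclose(w0, torch.ones_like(w0))     # alpha=0: weights flat
balanced = torch.allclose(xh, wh)                    # alpha=0.5: equal per channel
print(flat_act, flat_w, balanced)
```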
def evaluate_alpha(X_calibration, W, bits=8):
    """Evaluate different alpha values for SmoothQuant.

    X_calibration: (tokens, C_in); W: (C_out, C_in).
    """
    # Compute channel-wise maxes from calibration data
    act_max = X_calibration.abs().amax(dim=0)  # (C_in,)
    weight_max = W.abs().amax(dim=0)            # (C_in,)

    results = []
    # None = unsmoothed baseline (s_j = 1); alpha=0 is NOT "no smoothing"
    for alpha in [None, 0.25, 0.5, 0.75, 1.0]:
        if alpha is None:
            scales = torch.ones_like(act_max)
        else:
            scales = compute_smoothquant_scales(act_max, weight_max, alpha)
        X_smooth, W_smooth = apply_smoothquant(X_calibration, W, scales)

        # Measure activation quantization difficulty
        act_range_ratio = (
            X_smooth.abs().amax(dim=0).max() /
            X_smooth.abs().amax(dim=0).median()
        ).item()

        # Measure weight quantization difficulty
        w_range_ratio = (
            W_smooth.abs().amax(dim=0).max() /
            W_smooth.abs().amax(dim=0).median()
        ).item()

        results.append({
            'alpha': alpha,
            'act_range_ratio': act_range_ratio,
            'weight_range_ratio': w_range_ratio,
        })

        label = "no smoothing" if alpha is None else f"alpha={alpha:.2f}"
        print(f"  {label}: activation range ratio={act_range_ratio:.1f}x, "
              f"weight range ratio={w_range_ratio:.1f}x")

    return results

Expected output:

  no smoothing: activation range ratio=87.1x, weight range ratio=2.3x
  alpha=0.25: activation range ratio=23.4x, weight range ratio=3.8x
  alpha=0.50: activation range ratio=9.3x, weight range ratio=6.2x
  alpha=0.75: activation range ratio=4.1x, weight range ratio=10.8x
  alpha=1.00: activation range ratio=1.0x, weight range ratio=87.1x

Figure: SmoothQuant alpha vs. activation and weight channel range ratio (lower is better).

At $\alpha = 0.5$, both activations and weights have moderate range ratios (9.3x and 6.2x), making both quantizable. The original paper found $\alpha = 0.5$ a well-balanced choice for the OPT and BLOOM models, while models with more severe activation outliers (such as GLM-130B) preferred a larger value like $\alpha = 0.75$.

Full SmoothQuant Implementation

Here is a complete SmoothQuant implementation that smooths all linear layers in a transformer block:

class SmoothQuantCalibrator:
    """Calibrate and apply SmoothQuant to a transformer model."""

    def __init__(self, model, alpha=0.5):
        self.model = model
        self.alpha = alpha
        self.act_scales = {}  # layer_name -> per-channel max activations
        self.hooks = []

    def calibrate(self, dataloader, num_samples=128):
        """Run calibration data through the model to collect activation stats."""
        self.model.eval()

        def make_hook(name):
            def hook(module, input, output):
                x = input[0].detach().float()
                x_flat = x.reshape(-1, x.shape[-1])
                batch_max = x_flat.abs().amax(dim=0)

                if name not in self.act_scales:
                    self.act_scales[name] = batch_max
                else:
                    self.act_scales[name] = torch.max(
                        self.act_scales[name], batch_max
                    )
            return hook

        # Attach hooks to all linear layers
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                hook = module.register_forward_hook(make_hook(name))
                self.hooks.append(hook)

        # Run calibration
        count = 0
        with torch.no_grad():
            for batch in dataloader:
                if count >= num_samples:
                    break
                self.model(batch['input_ids'].cuda())
                count += batch['input_ids'].shape[0]

        # Remove hooks
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

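    def smooth(self):
        """Apply SmoothQuant scaling to all linear layers.

        Note: nn.Linear.forward ignores `smooth_scales`, so the model's outputs
        only stay correct once 1/scales is applied to each layer's input, either
        by fusing it into the preceding norm or at runtime as in quantize_layer.
        """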
        for name, module in self.model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name not in self.act_scales:
                    continue

                act_max = self.act_scales[name].to(module.weight.device)
                weight_max = module.weight.abs().amax(dim=0)

                scales = compute_smoothquant_scales(
                    act_max, weight_max, self.alpha
                )

                # Scale weights: W_smooth = W * diag(scales)
                module.weight.data.mul_(scales.unsqueeze(0))

                # Store scales for runtime activation scaling
                module.register_buffer(
                    'smooth_scales', scales.to(module.weight.dtype)
                )

    def quantize_layer(self, module, x, w_bits=8, a_bits=8):
        """Quantize a smoothed layer's activations and weights."""
        # Apply activation smoothing at runtime
        if hasattr(module, 'smooth_scales'):
            x = x / module.smooth_scales

        # Quantize activations per-token
        x_q, x_scale = quantize_activation_per_token(x, bits=a_bits)

        # Quantize weights per-channel (already smoothed)
        w = module.weight.data
        w_max = w.abs().amax(dim=1)
        w_scale = w_max / (2 ** (w_bits - 1) - 1)
        w_scale = w_scale.clamp(min=1e-10)
        w_q = (w / w_scale.unsqueeze(1)).round().clamp(
            -(2 ** (w_bits - 1)), 2 ** (w_bits - 1) - 1
        )

        # Simulated INT8 GEMM (float32 matmul on integer-valued tensors):
        # Y = (X_q @ W_q^T) * (x_scale * w_scale^T)
        y_q = x_q.float() @ w_q.float().T
        y = y_q * (x_scale * w_scale.unsqueeze(0))
        if module.bias is not None:
            y = y + module.bias

        return y
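To see the effect end to end without a real model, the following sketch simulates W8A8 on a single synthetic layer (random weights and one persistent outlier channel; both are assumptions) with per-token activation and per-output-channel weight quantization, with and without $\alpha = 0.5$ smoothing. The matmul runs in float32 on integer-valued tensors, the same simulation shortcut used in quantize_layer.

```python
import torch

torch.manual_seed(0)
qmax = 127

def fake_quant(T, scale):
    """Symmetric INT8 quantize-dequantize with a broadcastable scale."""
    return (T / scale.clamp(min=1e-10)).round().clamp(-qmax - 1, qmax) * scale

def w8a8_matmul(X, W):
    """Per-token activation scales, per-output-channel weight scales."""
    Xq = fake_quant(X, X.abs().amax(dim=-1, keepdim=True) / qmax)
    Wq = fake_quant(W, W.abs().amax(dim=1, keepdim=True) / qmax)
    return Xq @ Wq.T

X = torch.randn(256, 256)
X[:, 0] *= 50.0                          # persistent outlier channel
W = torch.randn(256, 256)
Y_ref = X @ W.T

# SmoothQuant with alpha = 0.5, calibrated on the same X (a simplification)
s = (X.abs().amax(dim=0).sqrt() / W.abs().amax(dim=0).sqrt()).clamp(min=1e-5)
Y_plain = w8a8_matmul(X, W)
Y_smooth = w8a8_matmul(X / s, W * s)

def rel_err(Y):
    return ((Y - Y_ref).norm() / Y_ref.norm()).item()

print(f"W8A8 without smoothing: {rel_err(Y_plain):.4f}")
print(f"W8A8 with smoothing:    {rel_err(Y_smooth):.4f}")
```

On this synthetic layer smoothing lowers the output error; how much it helps on a real model depends on its outlier structure and on $\alpha$.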

Rotation-Based Methods: QuaRot

QuaRot (Ashkboos et al., 2024) takes a different approach. Instead of per-channel scaling, it applies an orthogonal rotation matrix $R$ to the activations and weights:

$$Y = XW^T = (XR^T)(RW^T)$$

Because $R^T R = I$, the product is unchanged, and an orthogonal rotation also preserves norms ($\|XR^T\|_F = \|X\|_F$). What changes is how that energy is distributed: if $R$ is a (randomized) Hadamard matrix, the energy of the outlier channels is spread roughly evenly across all channels, so after rotation no single channel dominates.

def hadamard_matrix(n):
    """Generate a normalized Hadamard matrix of size n x n.
    n must be a power of 2.
    """
    if n == 1:
        return torch.tensor([[1.0]])

    half = hadamard_matrix(n // 2)
    H = torch.cat([
        torch.cat([half, half], dim=1),
        torch.cat([half, -half], dim=1),
    ], dim=0)
    return H

def apply_hadamard_rotation(X, W, hidden_dim):
    """Apply Hadamard rotation to activations and weights.

    Args:
        X: activations, shape (*, hidden_dim)
        W: weights, shape (C_out, hidden_dim)
        hidden_dim: must be a power of 2

    Returns:
        X_rot: rotated activations
        W_rot: rotated weights
    """
    H = hadamard_matrix(hidden_dim).to(device=X.device, dtype=X.dtype) / (hidden_dim ** 0.5)
    # H is orthogonal: H @ H^T = I

    X_rot = X @ H.T  # Rotate activations
    W_rot = W @ H.T  # Rotate weights (same rotation on input dim)

    return X_rot, W_rot
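The function above uses a plain Sylvester Hadamard matrix. QuaRot uses randomized Hadamard transforms; one way to randomize, sketched here, is $R = HD/\sqrt{n}$ with $D$ a random diagonal of $\pm 1$ signs. $R$ stays orthogonal ($RR^T = HDDH^T/n = I$), so $(XR^T)(WR^T)^T = XW^T$ still holds. The Kronecker construction and the outlier injected into channel 9 are this example's choices.

```python
import torch

torch.manual_seed(0)

def sylvester_hadamard(n):
    """Unnormalized Sylvester Hadamard matrix (entries +/-1); n a power of 2."""
    H = torch.ones(1, 1)
    base = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
    while H.shape[0] < n:
        H = torch.kron(H, base)
    return H

n = 128
signs = torch.randint(0, 2, (n,)).float() * 2 - 1          # random +/-1 diagonal D
R = sylvester_hadamard(n) * signs.unsqueeze(0) / n ** 0.5  # R = H D / sqrt(n)

X = torch.randn(16, n)
X[:, 9] *= 40.0                                            # synthetic outlier channel
W = torch.randn(32, n)

X_rot, W_rot = X @ R.T, W @ R.T
orthogonal = torch.allclose(R @ R.T, torch.eye(n), atol=1e-5)
same_output = torch.allclose(X_rot @ W_rot.T, X @ W.T, rtol=1e-4, atol=1e-3)

def range_ratio(T):
    """Max over median of the per-channel (column) max magnitudes."""
    m = T.abs().amax(dim=0)
    return (m.max() / m.median()).item()

print(orthogonal, same_output)
print(f"range ratio: {range_ratio(X):.1f}x -> {range_ratio(X_rot):.1f}x")
```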

# Demonstrate outlier elimination
hidden_dim = 256  # Small for demonstration

# Simulate activations with outlier channels
X = torch.randn(32, hidden_dim) * 0.5
X[:, 0] *= 50   # Outlier channel 0
X[:, 127] *= 30  # Outlier channel 127

print("Before rotation:")
channel_max = X.abs().amax(dim=0)
print(f"  Max channel magnitude: {channel_max.max():.2f}")
print(f"  Min channel magnitude: {channel_max.min():.2f}")
print(f"  Ratio: {channel_max.max() / channel_max.median():.1f}x")

H = hadamard_matrix(hidden_dim) / (hidden_dim ** 0.5)
X_rot = X @ H.T

print("\nAfter Hadamard rotation:")
channel_max_rot = X_rot.abs().amax(dim=0)
print(f"  Max channel magnitude: {channel_max_rot.max():.2f}")
print(f"  Min channel magnitude: {channel_max_rot.min():.2f}")
print(f"  Ratio: {channel_max_rot.max() / channel_max_rot.median():.1f}x")

Representative output (X is random and unseeded, so exact values vary from run to run):

Before rotation:
  Max channel magnitude: 28.41
  Min channel magnitude: 0.23
  Ratio: 54.2x

After Hadamard rotation:
  Max channel magnitude: 3.12
  Min channel magnitude: 1.87
  Ratio: 1.4x
From 54x to 1.4x Range Ratio

The Hadamard rotation reduces the channel range ratio from 54x to 1.4x. After rotation, per-tensor quantization works nearly as well as per-channel quantization on the original activations. The rotation is also computationally cheap: a Hadamard transform on a vector of length $n$ costs $O(n \log n)$ operations using the fast Walsh-Hadamard transform.

Fast Walsh-Hadamard Transform

The naive rotation $XH^T$ costs $O(n^2)$ per token. The fast Walsh-Hadamard transform (FWHT) reduces this to $O(n \log n)$:
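For scale, take the Llama-2 7B hidden size $n = 4096$. Counting one multiply-add per matrix entry for the dense rotation, and one addition or subtraction per element per butterfly stage for the FWHT ($\log_2 n$ stages of $n$ operations each, with the $1/\sqrt{n}$ normalization folded into a neighboring scale):

$$n^2 = 4096^2 = 16{,}777{,}216 \qquad \text{vs.} \qquad n \log_2 n = 4096 \times 12 = 49{,}152$$

That is roughly 341x fewer operations per token.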

def fast_walsh_hadamard(x):
    """Fast Walsh-Hadamard transform along the last dimension.
    x: tensor of shape (*, n) where n is a power of 2 (x is not modified).
    Returns: x @ H / sqrt(n), where H is the Sylvester Hadamard matrix.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "n must be a power of 2"

    x = x.clone()
    h = 1
    while h < n:
        # Butterfly: pair index j with j + h inside each block of size 2h.
        # Clone both halves first; otherwise the second update would read
        # the already-overwritten first half.
        for i in range(h):
            left = x[..., i::2*h].clone()
            right = x[..., i+h::2*h].clone()
            x[..., i::2*h] = left + right
            x[..., i+h::2*h] = left - right
        h *= 2

    return x / (n ** 0.5)

# Vectorized version for GPU
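def fwht_gpu(x):
    """GPU-friendly fast Walsh-Hadamard transform (same result as above)."""
    n = x.shape[-1]
    assert n & (n - 1) == 0, "n must be a power of 2"
    original_shape = x.shape
    x = x.reshape(-1, n).contiguous()

    h = 1
    while h < n:
        x = x.view(-1, n // (2 * h), 2, h)
        a = x[:, :, 0, :]
        b = x[:, :, 1, :]
        # Out-of-place butterfly: writing a + b into x before computing
        # a - b would corrupt the result, since a is a view of x
        x = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2

    return (x / (n ** 0.5)).reshape(original_shape)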
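An in-place butterfly that overwrites one half of a pair before reading it for the other half silently produces wrong results, so it is worth testing any fast transform against the explicit matrix. A self-contained sketch: a vectorized FWHT (the same out-of-place butterfly idea) compared with $XH/\sqrt{n}$, where $H$ is built by repeated Kronecker products.

```python
import torch

def fwht(x):
    """Vectorized fast Walsh-Hadamard transform along the last dim (n = 2^k)."""
    n = x.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "n must be a power of 2"
    shape = x.shape
    y = x.reshape(-1, n).contiguous()
    h = 1
    while h < n:
        y = y.view(-1, n // (2 * h), 2, h)
        a, b = y[:, :, 0, :], y[:, :, 1, :]
        # Build the new tensor out-of-place so a and b are both read first
        y = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    return (y / n ** 0.5).reshape(shape)

def hadamard(n):
    """Unnormalized Sylvester Hadamard matrix via Kronecker products."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.kron(H, torch.tensor([[1.0, 1.0], [1.0, -1.0]]))
    return H

torch.manual_seed(0)
n = 256
X = torch.randn(4, 3, n)
expected = X @ hadamard(n) / n ** 0.5
matches = torch.allclose(fwht(X), expected, atol=1e-4)
print(matches)
```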

Quantitative Comparison: SmoothQuant vs QuaRot vs Baseline


W8A8 Perplexity on Llama-2 7B (WikiText-2)

| Method | Activation Scaling | Weight Scaling | Perplexity | Delta vs FP16 |
|---|---|---|---|---|
| FP16 baseline | --- | --- | 5.47 | --- |
| Naive W8A8 | Per-tensor | Per-channel | 6.81 | +1.34 |
| Per-token W8A8 | Per-token | Per-channel | 5.92 | +0.45 |
| SmoothQuant (alpha=0.5) | Per-token | Per-channel | 5.54 | +0.07 |
| SmoothQuant (alpha=0.75) | Per-token | Per-channel | 5.52 | +0.05 |
| QuaRot + Per-tensor | Per-tensor | Per-channel | 5.56 | +0.09 |
| QuaRot + Per-token | Per-token | Per-channel | 5.49 | +0.02 |

Note: Naive per-tensor activation quantization adds 1.34 ppl. SmoothQuant reduces this to 0.05-0.07. QuaRot with per-token scaling achieves near-FP16 quality at 0.02 ppl degradation.

Figure: W8A8 perplexity degradation vs FP16 on Llama-2 7B (perplexity increase, lower is better), plotting the deltas from the table above.

Layer-by-Layer Outlier Analysis

Not all layers have the same outlier severity. The following code computes per-layer outlier statistics:

def per_layer_outlier_analysis(profiler, model):
    """Compute outlier severity for each layer."""
    results = []
    for name in sorted(profiler.stats.keys()):
        channel_max = profiler.get_channel_stats(name)
        median_max = np.median(channel_max)
        top_max = np.max(channel_max)
        ratio = top_max / median_max if median_max > 0 else 0

        # Count channels above 10x median
        num_outliers = np.sum(channel_max > 10 * median_max)

        results.append({
            'layer': name,
            'max_activation': top_max,
            'median_activation': median_max,
            'outlier_ratio': ratio,
            'num_outlier_channels': num_outliers,
        })

    return sorted(results, key=lambda x: x['outlier_ratio'], reverse=True)

# Top-5 worst layers by outlier ratio:
# layers.0.self_attn.q_proj: 87.1x (first layer worst)
# layers.0.self_attn.k_proj: 82.3x
# layers.1.self_attn.q_proj: 65.7x
# layers.0.self_attn.v_proj: 59.2x
# layers.0.mlp.gate_proj:    48.6x

Early layers consistently have worse outliers than later layers. This has implications for mixed-precision strategies: early layers may need higher precision or more aggressive smoothing.
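A simple way to act on per-layer statistics, sketched under assumptions: a hypothetical helper (pick_high_precision_layers is not from any library) that keeps layers above an outlier-ratio threshold in higher precision, or flags them for a larger $\alpha$. The 50x threshold is an illustrative choice, and the ratios are the five listed in the comment above.

```python
def pick_high_precision_layers(outlier_ratios, threshold=50.0):
    """Hypothetical policy: layers whose max/median channel ratio exceeds
    `threshold` stay in higher precision; the rest get W8A8.
    Returns the selected layer names, worst first."""
    return sorted(
        (name for name, r in outlier_ratios.items() if r > threshold),
        key=lambda name: outlier_ratios[name],
        reverse=True,
    )

ratios = {
    "layers.0.self_attn.q_proj": 87.1,
    "layers.0.self_attn.k_proj": 82.3,
    "layers.1.self_attn.q_proj": 65.7,
    "layers.0.self_attn.v_proj": 59.2,
    "layers.0.mlp.gate_proj": 48.6,
}
high_precision = pick_high_precision_layers(ratios)
print(high_precision)
```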

The FP8 Escape Hatch

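On Hopper and Blackwell GPUs, FP8 (E4M3) provides a partial solution. E4M3's largest finite value is 448 (vs 127 for INT8), and its exponent bits give it non-uniform spacing that is fine near zero and coarse near the maximum, so it accommodates larger outlier ratios more gracefully. The simple model below credits FP8 only with its larger maximum: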

def effective_bits_fp8_vs_int8(outlier_ratio):
    """Compare effective bits for non-outlier channels.

    INT8 range: [-127, 127], 8 bits nominal
    FP8 E4M3 range: [-448, 448], 8 bits nominal but non-uniform spacing
    """
    # INT8: scale set by outlier
    int8_levels_for_normal = 127 / outlier_ratio
    int8_eff_bits = max(0, np.log2(int8_levels_for_normal + 1))

    # FP8: non-uniform spacing means outliers use high-exponent range
    # Normal channels use low-exponent range with fine spacing
    # Approximate: FP8 tolerates ~3.5x more range than INT8
    fp8_levels_for_normal = 448 / outlier_ratio
    fp8_eff_bits = max(0, np.log2(fp8_levels_for_normal + 1))

    return int8_eff_bits, fp8_eff_bits

for ratio in [10, 50, 100]:
    int8_b, fp8_b = effective_bits_fp8_vs_int8(ratio)
    print(f"  Outlier ratio {ratio:3d}x: "
          f"INT8 eff={int8_b:.1f} bits, FP8 eff={fp8_b:.1f} bits")

Output:

  Outlier ratio  10x: INT8 eff=3.8 bits, FP8 eff=5.5 bits
  Outlier ratio  50x: INT8 eff=1.8 bits, FP8 eff=3.3 bits
  Outlier ratio 100x: INT8 eff=1.2 bits, FP8 eff=2.5 bits

In this simplified model, FP8 buys roughly 1.3-1.7 additional effective bits at these outlier ratios (approaching $\log_2(448/127) \approx 1.8$ when the ratio is small). This is often enough to make W8A8 (with FP8 activations) work without SmoothQuant, though SmoothQuant + FP8 still gives the best results.
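The model above treats 448 as if it were 448 uniform steps, but E4M3 spaces its values logarithmically, which matters most for exactly the small, non-outlier channels. The following sketch emulates E4M3 rounding by hand (finite values only, saturating at ±448, no NaN handling; round_to_e4m3 is this example's own helper, not a library API) and compares round-trip error on an assumed Gaussian channel with standard deviation 0.3, when both formats use a per-tensor scale set by the 72.31 outlier from the profile above.

```python
import torch

def round_to_e4m3(v):
    """Emulate rounding to FP8 E4M3 (finite values only, saturating at +/-448).

    E4M3 has 3 mantissa bits, so each power-of-two interval (binade) holds
    8 evenly spaced values; below 2**-6 it is subnormal with step 2**-9.
    """
    a = v.abs().clamp(max=448.0)
    e = torch.floor(torch.log2(a.clamp(min=2.0 ** -6)))  # binade exponent
    step = 2.0 ** (e - 3)
    q = (torch.round(a / step) * step).clamp(max=448.0)
    return torch.sign(v) * q

torch.manual_seed(0)
outlier_max = 72.31                 # largest channel max from the profile above
x = 0.3 * torch.randn(100_000)      # assumed typical (non-outlier) channel

int8_scale = outlier_max / 127      # per-tensor scales set by the outlier
fp8_scale = outlier_max / 448

x_int8 = (x / int8_scale).round().clamp(-128, 127) * int8_scale
x_fp8 = round_to_e4m3(x / fp8_scale) * fp8_scale

def rel_err(x_hat):
    return ((x_hat - x).norm() / x.norm()).item()

print(f"INT8 relative error: {rel_err(x_int8):.3f}")
print(f"FP8  relative error: {rel_err(x_fp8):.3f}")
```

Under this emulation FP8 keeps the small channel's relative error well under 10% while INT8 loses most of the signal, a larger gap than the uniform-grid model suggests. Real deployments still need a calibrated scale, and values beyond the scaled range saturate.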

Implementation Considerations

Fusing SmoothQuant into Layer Norms

The activation scaling $X \cdot S^{-1}$ can be fused into the preceding LayerNorm (or RMSNorm, as in Llama). Since the norm already applies a per-channel scale ($\gamma$), we multiply $\gamma$ (and the bias $\beta$, if present) by $S^{-1}$:

def fuse_smoothquant_into_layernorm(ln_module, smooth_scales):
    """Fuse SmoothQuant scaling into LayerNorm gamma.

    LayerNorm output: y = gamma * (x - mean) / std + beta
    After smoothing: y_smooth = y / smooth_scales
                   = (gamma / smooth_scales) * (x - mean) / std + beta / smooth_scales

    We can absorb the division into gamma and beta.
    """
    ln_module.weight.data /= smooth_scales
    # RMSNorm modules (e.g. in Llama) may have no bias attribute at all
    bias = getattr(ln_module, 'bias', None)
    if bias is not None:
        bias.data /= smooth_scales

# For linear layers fed directly by a norm, this fusion means SmoothQuant
# has ZERO runtime overhead: the scaling is absorbed into the norm
# parameters, and the smoothed weights are pre-computed offline.
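In an attention block, one norm output feeds q_proj, k_proj, and v_proj, so a single scale vector has to serve all three before it can be folded into $\gamma$ and $\beta$. A common choice, and the one the SmoothQuant reference code makes, is to take the elementwise max of the three layers' per-input-channel weight maxima. A self-contained sketch on a toy block (the sizes, the inflated $\gamma$ entry, and $\alpha = 0.5$ are assumptions):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = 64
ln = nn.LayerNorm(d)
projs = [nn.Linear(d, d) for _ in range(3)]      # stand-ins for q/k/v
with torch.no_grad():
    ln.weight[11] = 40.0                          # synthetic outlier channel

x = torch.randn(32, d)
with torch.no_grad():
    ref = [p(ln(x)) for p in projs]
    act_max = ln(x).abs().amax(dim=0)             # calibration statistic

    # One shared scale for every layer reading this LayerNorm's output
    w_max = torch.stack([p.weight.abs().amax(dim=0) for p in projs]).amax(dim=0)
    alpha = 0.5
    s = (act_max.pow(alpha) / w_max.clamp(min=1e-5).pow(1 - alpha)).clamp(min=1e-5)

    ln.weight.div_(s)                             # gamma / s
    ln.bias.div_(s)                               # beta / s
    for p in projs:
        p.weight.mul_(s)                          # W * diag(s), broadcast over rows

    out = [p(ln(x)) for p in projs]
    smoothed_max = ln(x).abs().amax(dim=0)
    ratio_after = (smoothed_max.max() / smoothed_max.median()).item()

ratio_before = (act_max.max() / act_max.median()).item()
fused_ok = all(torch.allclose(a, b, rtol=1e-4, atol=1e-4) for a, b in zip(ref, out))
print(fused_ok, f"{ratio_before:.1f}x -> {ratio_after:.1f}x")
```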

Calibration Data Requirements

SmoothQuant requires calibration data to compute $\max(|X_j|)$ per channel. The calibration set affects the quality of the scales:

# Calibration data requirements for SmoothQuant:
# - 128 samples is sufficient (diminishing returns beyond 256)
# - Must be representative of the target distribution
# - Random C4 or Pile samples work well for general-purpose models
# - For domain-specific models, use domain-specific calibration data

calibration_configs = {
    'minimum': {'samples': 32, 'seq_len': 512},
    'standard': {'samples': 128, 'seq_len': 2048},
    'thorough': {'samples': 512, 'seq_len': 2048},
}