Weight quantization (Part 2) gives you smaller models and faster memory transfers. But the matrix multiplication itself still runs in FP16 — the GPU dequantizes INT4 weights to FP16, then performs an FP16 GEMM with FP16 activations. To get true INT8 or FP8 compute speedup (not just bandwidth savings), you must also quantize the activations. This is W8A8: both weights and activations in INT8, with the GEMM executed using INT8 tensor cores.
The problem: activations are dramatically harder to quantize than weights. Weights are static, well-behaved, and roughly Gaussian. Activations are dynamic (different for every input), and they contain outliers — individual channels with values 10-100x larger than the rest. These outliers destroy INT8 quantization quality if not handled correctly.
This post covers the outlier problem in detail, then implements SmoothQuant (Xiao et al., 2022), the algorithm that solved it by migrating quantization difficulty from activations to weights using a mathematically equivalent transformation.
## Why Activations Are Harder to Quantize Than Weights

### The Weight Distribution
Trained LLM weights are well-behaved. They have a roughly symmetric distribution centered around zero, with most values within a few standard deviations of the mean. The ratio between the maximum and median absolute weight value is typically 3-5x.
```python
import torch
import torch.nn as nn

# Simulate typical LLM weight distributions
torch.manual_seed(42)
weight = torch.randn(4096, 4096) * 0.02  # Typical LLM weight scale

max_w = weight.abs().max().item()
median_w = weight.abs().median().item()
print(f"Weight max/median ratio: {max_w / median_w:.1f}x")
# Typically 3-5x -- easy to quantize
```
### The Activation Distribution
Activations after layer normalization and linear projections are a different story. First documented in the OPT-6.7B model (and since observed in Llama, GPT, and virtually all large transformers), specific channels in the activation tensor consistently produce values 10-100x larger than the other channels. These are called activation outliers or emergent features.
```python
def simulate_activation_outliers(batch_size=32, seq_len=128, hidden=4096,
                                 outlier_fraction=0.01, outlier_magnitude=50.0):
    """Simulate realistic LLM activations with channel-wise outliers.

    In real LLMs, outliers appear in fixed channels across all tokens
    and all inputs. They emerge during training around 6B parameters
    and persist in all larger models.
    """
    # Base activation: roughly Gaussian
    activations = torch.randn(batch_size, seq_len, hidden) * 0.5
    # Outlier channels: fixed positions, large magnitude
    num_outliers = int(hidden * outlier_fraction)
    outlier_channels = torch.randperm(hidden)[:num_outliers]
    activations[:, :, outlier_channels] *= outlier_magnitude
    return activations, outlier_channels

activations, outlier_ch = simulate_activation_outliers()
flat = activations.reshape(-1, 4096)

# Per-channel statistics
channel_max = flat.abs().max(dim=0).values
sorted_max, _ = channel_max.sort(descending=True)
print(f"Top 5 channel max values: {sorted_max[:5].tolist()}")
print(f"Median channel max: {channel_max.median().item():.2f}")
print(f"Max/median ratio: {sorted_max[0].item() / channel_max.median().item():.1f}x")
# Max/median ratio: ~50-100x -- this destroys naive quantization
```
When a single channel has values 100x larger than the rest, symmetric INT8 quantization must set the scale factor to accommodate that channel. The 99% of channels with normal magnitudes are then effectively quantized to only 1-2 bits of precision (mapping to a handful of INT8 levels around zero), wasting 6-7 bits of the INT8 range. The result: catastrophic quality loss.
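A tiny sketch makes the range waste concrete (random data; the channel count and outlier position are illustrative, not from a real model): with one 100x outlier channel, the remaining channels collapse onto a handful of INT8 levels.

```python
import torch

# Sketch: how a single outlier channel wastes the INT8 range under
# per-tensor symmetric quantization. (Illustrative toy data.)
torch.manual_seed(0)
x = torch.randn(1000, 8)   # 8 channels of well-behaved activations
x[:, 3] *= 100.0           # channel 3 is a 100x outlier

scale = x.abs().max() / 127.0                    # scale must cover the outlier
q = (x / scale).round().clamp(-128, 127)

# Distinct INT8 levels actually used by the 7 normal channels
normal = q[:, [0, 1, 2, 4, 5, 6, 7]]
levels_used = normal.unique().numel()
print(f"scale = {scale:.3f}")
print(f"normal channels use {levels_used} of 256 INT8 levels")
```

The normal channels end up mapped almost entirely to {-1, 0, 1}, which is exactly the 1-2 bits of effective precision described above.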
### Quantifying the Damage
```python
def quantize_per_tensor_int8(tensor):
    """Per-tensor symmetric INT8 quantization."""
    amax = tensor.abs().max()
    scale = amax / 127.0
    q = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

def quantize_per_token_int8(tensor):
    """Per-token symmetric INT8 quantization.

    tensor: (batch * seq, hidden)
    One scale per row (token).
    """
    amax = tensor.abs().amax(dim=1, keepdim=True)
    scale = amax / 127.0
    scale = scale.clamp(min=1e-10)
    q = (tensor / scale).round().clamp(-128, 127).to(torch.int8)
    return q, scale

# Compare quantization quality with and without outliers
weight = torch.randn(4096, 4096) * 0.02
activations, _ = simulate_activation_outliers()
flat_act = activations.reshape(-1, 4096)

# Reference output
y_ref = flat_act @ weight.T

# Quantize activations per-tensor (naive)
q_act, s_act = quantize_per_tensor_int8(flat_act)
act_recon = q_act.float() * s_act
y_naive = act_recon @ weight.T
mse_naive = ((y_ref - y_naive) ** 2).mean().item()

# Quantize activations per-token
q_act_pt, s_act_pt = quantize_per_token_int8(flat_act)
act_recon_pt = q_act_pt.float() * s_act_pt
y_per_token = act_recon_pt @ weight.T
mse_per_token = ((y_ref - y_per_token) ** 2).mean().item()

print(f"Naive per-tensor INT8 activation MSE: {mse_naive:.6f}")
print(f"Per-token INT8 activation MSE: {mse_per_token:.6f}")
print(f"Improvement: {mse_naive / mse_per_token:.1f}x")
```
Activation Quantization Error by Granularity (relative MSE)

Per-token quantization helps because each token gets its own scale factor, so a single outlier token does not ruin the scale for other tokens. But the outlier channels still consume most of the INT8 range within each token. Per-channel quantization would solve this, but it is incompatible with efficient INT8 GEMM kernels: the GEMM accumulation requires a single scale per output element, which is only possible with per-tensor or per-token activation scaling combined with per-channel weight scaling.
## SmoothQuant: The Solution
SmoothQuant (Xiao et al., 2022) resolves the outlier problem with a mathematically equivalent transformation that migrates quantization difficulty from activations to weights. The key equation:

$$\mathbf{Y} = \mathbf{X}\mathbf{W} = \underbrace{\big(\mathbf{X}\operatorname{diag}(\mathbf{s})^{-1}\big)}_{\hat{\mathbf{X}}} \underbrace{\big(\operatorname{diag}(\mathbf{s})\,\mathbf{W}\big)}_{\hat{\mathbf{W}}}$$

where $\mathbf{s}$ is a per-channel smoothing factor. This transformation:
- Divides each activation channel by $s_j$ (shrinking outlier channels)
- Multiplies each weight input channel by $s_j$ (growing the corresponding weight channels)
- Preserves the output exactly: $\hat{\mathbf{X}}\hat{\mathbf{W}} = \mathbf{X}\operatorname{diag}(\mathbf{s})^{-1}\operatorname{diag}(\mathbf{s})\mathbf{W} = \mathbf{X}\mathbf{W}$
After smoothing, the activations have a much more uniform distribution across channels (outliers are reduced), while the weights absorb the extra magnitude. Since weights are static, they can tolerate wider dynamic range without quality loss (they are quantized offline with full knowledge of the distribution).
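A quick sanity check with random tensors confirms the equivalence (a sketch, not model code; the shapes are arbitrary):

```python
import torch

# The smoothing transform changes nothing about the layer's output,
# only how magnitude is split between activations and weights.
torch.manual_seed(0)
x = torch.randn(16, 64)            # activations: (tokens, in_features)
w = torch.randn(32, 64) * 0.02     # weight: (out_features, in_features)
s = torch.rand(64) + 0.5           # arbitrary positive per-channel factors

y_ref = x @ w.T                    # original output
y_smooth = (x / s) @ (w * s).T     # smoothed activations x smoothed weights

print(torch.allclose(y_ref, y_smooth, atol=1e-5))  # True up to float rounding
```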
### Computing the Smoothing Factor
The smoothing factor balances quantization difficulty between activations and weights:

$$s_j = \frac{\max\left(|\mathbf{X}_j|\right)^{\alpha}}{\max\left(|\mathbf{W}_j|\right)^{1-\alpha}}$$

where $\alpha$ controls how aggressively to smooth. The default $\alpha = 0.5$ means an equal difficulty split; $\alpha$ closer to 1.0 smooths activations more aggressively (good when outliers are extreme). The original SmoothQuant paper found $\alpha = 0.5$ works well for most models, with $\alpha = 0.75$ needed for models with very strong outliers (e.g., GLM-130B).
The smoothing operates per input channel. A channel $j$ with a large activation outlier gets a large $s_j$, which divides down that channel's activations and multiplies up that channel's weights. Channels without outliers get $s_j \approx 1$ and are barely affected.
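A worked example with hypothetical channel statistics (an outlier channel at max|X| = 50 and a normal channel at 0.5, both with weight max 0.5) shows how $\alpha = 0.5$ equalizes the two sides:

```python
import torch

# Hypothetical per-channel statistics for two channels:
#   channel A: outlier, max|X_A| = 50.0, max|W_A| = 0.5
#   channel B: normal,  max|X_B| = 0.5,  max|W_B| = 0.5
act_max = torch.tensor([50.0, 0.5])
w_max = torch.tensor([0.5, 0.5])
alpha = 0.5

s = act_max.pow(alpha) / w_max.pow(1 - alpha)
print(s.tolist())  # outlier channel gets s = 10, normal channel s = 1

# After smoothing, each channel's activation and weight ranges are equal
# (both become the geometric mean of the originals, ~[5.0, 0.5] here):
print((act_max / s).tolist())
print((w_max * s).tolist())
```

The outlier channel's activation range shrinks 10x while its weight range grows 10x; the normal channel is untouched.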
### Complete SmoothQuant Implementation
```python
class SmoothQuant:
    """SmoothQuant: migrate activation quantization difficulty to weights."""

    def __init__(self, alpha=0.5):
        self.alpha = alpha

    def compute_smoothing_factors(self, activation_scales, weight):
        """Compute per-channel smoothing factors.

        activation_scales: (in_features,) max absolute activation per channel
                           (computed from calibration data)
        weight: (out_features, in_features) original weight matrix
        Returns: (in_features,) smoothing factors
        """
        # Per-channel weight scale: max absolute value per input channel
        weight_scales = weight.abs().amax(dim=0)  # (in_features,)
        # Smoothing factor: s_j = act_max_j^alpha / weight_max_j^(1 - alpha)
        s = activation_scales.pow(self.alpha) / weight_scales.pow(1 - self.alpha)
        # Clamp for numerical stability
        s = s.clamp(min=1e-5)
        return s

    def smooth_layer(self, weight, ln_weight, ln_bias, activation_scales):
        """Apply the SmoothQuant transformation to a layer.

        In a transformer, the division by s is absorbed into the preceding
        LayerNorm parameters, avoiding any runtime cost.

        weight: (out_features, in_features) linear layer weight
        ln_weight: (in_features,) LayerNorm gamma
        ln_bias: (in_features,) LayerNorm beta (can be None)
        activation_scales: (in_features,) from calibration
        Returns: (smoothed_weight, new_ln_weight, new_ln_bias, s)
        """
        s = self.compute_smoothing_factors(activation_scales, weight)
        # Scale each weight input channel up by s
        smoothed_weight = weight * s.unsqueeze(0)  # Broadcast: (out, in) * (1, in)
        # Absorb the division by s into the LayerNorm: gamma_new = gamma / s
        new_ln_weight = ln_weight / s
        new_ln_bias = ln_bias / s if ln_bias is not None else None
        return smoothed_weight, new_ln_weight, new_ln_bias, s

    def calibrate(self, model_forward_fn, calibration_data, layer_names):
        """Run calibration to collect per-channel activation max values.

        model_forward_fn: function that runs the model and returns
                          a dict mapping layer_name to input activations
        calibration_data: list of input tensors
        layer_names: list of layer names to calibrate
        Returns: dict mapping layer_name to (in_features,) activation scales
        """
        scales = {}
        for data in calibration_data:
            layer_activations = model_forward_fn(data)
            for name in layer_names:
                act = layer_activations[name]  # (batch, seq, hidden)
                if act.dim() == 3:
                    act = act.reshape(-1, act.shape[-1])
                # Per-channel max absolute value, merged across batches
                ch_max = act.abs().amax(dim=0)
                if name not in scales:
                    scales[name] = ch_max
                else:
                    scales[name] = torch.maximum(scales[name], ch_max)
        return scales
```
### End-to-End SmoothQuant Application
```python
def apply_smoothquant_to_transformer_block(block, activation_scales, alpha=0.5):
    """Apply SmoothQuant to all linear layers in a transformer block.

    A typical transformer block has:
      - LayerNorm -> Q, K, V projections (self-attention)
      - LayerNorm -> Up, Gate, Down projections (FFN)
    SmoothQuant is applied between each LayerNorm and the subsequent
    linear layers. The smoothing factors are absorbed into the LayerNorm
    parameters, so there is ZERO runtime cost.
    """
    sq = SmoothQuant(alpha=alpha)

    # --- Attention projections ---
    # Q, K, V share the same post-LayerNorm input, so they must share a
    # single smoothing factor. Compute it from the concatenated weights.
    attn_act_scale = activation_scales['attn_input']
    qkv_weight = torch.cat([
        getattr(block.self_attn, name).weight.data
        for name in ('q_proj', 'k_proj', 'v_proj')
    ], dim=0)  # (3 * out_features, in_features)
    s_attn = sq.compute_smoothing_factors(attn_act_scale, qkv_weight)

    for proj_name in ('q_proj', 'k_proj', 'v_proj'):
        proj = getattr(block.self_attn, proj_name)
        proj.weight.data *= s_attn.unsqueeze(0)
    # Absorb 1/s into the shared LayerNorm (once, not per projection)
    block.input_layernorm.weight.data /= s_attn
    if getattr(block.input_layernorm, 'bias', None) is not None:
        block.input_layernorm.bias.data /= s_attn

    # --- FFN projections ---
    # up_proj and gate_proj share the post-attention LayerNorm input.
    ffn_act_scale = activation_scales['ffn_input']
    ffn_weight = torch.cat([
        getattr(block.mlp, name).weight.data
        for name in ('up_proj', 'gate_proj')
    ], dim=0)
    s_ffn = sq.compute_smoothing_factors(ffn_act_scale, ffn_weight)

    for proj_name in ('up_proj', 'gate_proj'):
        proj = getattr(block.mlp, proj_name)
        proj.weight.data *= s_ffn.unsqueeze(0)
    block.post_attention_layernorm.weight.data /= s_ffn  # RMSNorm: no bias

    return block
```
## Per-Tensor Dynamic Scaling for Online Quantization
After SmoothQuant smoothing, the activations have a much more uniform distribution. Now we can apply standard per-tensor or per-token INT8 quantization at runtime with acceptable quality.
Static quantization uses fixed scale factors determined during calibration. The scale does not change at inference time. This is faster (no per-token max computation) but less accurate for inputs that differ from the calibration distribution.
Dynamic quantization computes the scale factor from each actual input at runtime. This adds a small overhead (computing the per-tensor or per-token max) but adapts to any input distribution.
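For contrast with the dynamic quantizer below, a static per-tensor quantizer might look like this (a minimal sketch; the class name and structure are illustrative, not from a specific library):

```python
import torch

class StaticInt8Quantizer:
    """Static per-tensor INT8 quantization (illustrative sketch).

    The scale is frozen at calibration time; inference does no reductions.
    """
    def __init__(self):
        self.scale = None

    def calibrate(self, calibration_tensors):
        # One pass over calibration data to find the global max
        amax = max(t.abs().max().item() for t in calibration_tensors)
        self.scale = max(amax / 127.0, 1e-10)

    def quantize(self, x):
        # No per-input max computation -- just scale, round, clamp
        return (x / self.scale).round().clamp(-128, 127).to(torch.int8)

quantizer = StaticInt8Quantizer()
quantizer.calibrate([torch.randn(64, 1024) for _ in range(8)])
q = quantizer.quantize(torch.randn(64, 1024))
print(q.dtype, f"scale={quantizer.scale:.5f}")
```

Note the failure mode: an inference input whose max exceeds the calibration max is silently clipped to ±127, which is exactly the out-of-distribution risk described above.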
```python
class DynamicInt8Quantizer:
    """Runtime dynamic INT8 quantization for activations."""

    @staticmethod
    def quantize_per_tensor(x):
        """Per-tensor dynamic quantization.

        x: (batch * seq, hidden) FP16 tensor
        Returns: (q_x, scale)
        """
        amax = x.abs().max()
        scale = max(amax.item() / 127.0, 1e-10)
        q_x = (x / scale).round().clamp(-128, 127).to(torch.int8)
        return q_x, scale

    @staticmethod
    def quantize_per_token(x):
        """Per-token dynamic quantization.

        x: (tokens, hidden) FP16 tensor
        Returns: (q_x, scales) where scales is (tokens, 1)
        """
        amax = x.abs().amax(dim=1, keepdim=True)
        scales = (amax / 127.0).clamp(min=1e-10)
        q_x = (x / scales).round().clamp(-128, 127).to(torch.int8)
        return q_x, scales

    @staticmethod
    def quantize_per_channel(x):
        """Per-channel dynamic quantization.

        x: (tokens, hidden) FP16 tensor
        Returns: (q_x, scales) where scales is (1, hidden)
        """
        amax = x.abs().amax(dim=0, keepdim=True)
        scales = (amax / 127.0).clamp(min=1e-10)
        q_x = (x / scales).round().clamp(-128, 127).to(torch.int8)
        return q_x, scales
```
## W8A8: Full INT8 Inference
With SmoothQuant-smoothed weights and dynamic activation quantization, we can implement complete W8A8 inference. Both the weights and activations are in INT8, and the GEMM uses INT8 tensor cores with INT32 accumulation.
### The W8A8 GEMM
The INT8 GEMM computes:

$$\mathbf{Y} = \big(\mathbf{X}_{\text{INT8}}\,\mathbf{W}_{\text{INT8}}^{\top}\big)_{\text{INT32}} \cdot s_x \cdot s_w$$

The accumulation in INT32 is critical: INT8 products can reach $127 \times 127 = 16{,}129$, and summing thousands of these per output element requires 32-bit precision.
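The arithmetic behind this is easy to check (a small sketch, not real kernel code; the reduction depth matches a 4096-wide hidden layer):

```python
import torch

# Why INT32 accumulation: a single INT8*INT8 product fits in 16 bits
# (at most 127 * 127 = 16,129), but a GEMM sums thousands of them.
k = 4096  # reduction depth of a 4096-wide hidden layer
worst_case = 127 * 127 * k
print(f"Worst-case accumulator: {worst_case:,}")  # 66,064,384
print(f"INT16 max: {2**15 - 1:,}   INT32 max: {2**31 - 1:,}")

# Simulated reduction: int64 math is exact; truncating to int16 wraps
a = torch.full((k,), 127, dtype=torch.int64)
exact = int((a * a).sum())
wrapped = (a * a).sum().to(torch.int16).item()
print(f"exact = {exact:,}, wrapped-to-int16 = {wrapped}")
```

The worst case sits comfortably inside INT32 but overflows INT16 by three orders of magnitude, which is why the tensor cores accumulate in 32 bits.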
```python
class W8A8Linear:
    """INT8 linear layer with SmoothQuant preprocessing."""

    def __init__(self, weight_int8, weight_scale, per_channel=True):
        """
        weight_int8: (out_features, in_features) INT8
        weight_scale: (out_features, 1) or scalar, FP32
        per_channel: whether the weight was quantized per-channel
        """
        self.weight_int8 = weight_int8
        self.weight_scale = weight_scale
        self.per_channel = per_channel
        self.out_features = weight_int8.shape[0]
        self.in_features = weight_int8.shape[1]

    @classmethod
    def from_float(cls, linear_layer):
        """Quantize a float linear layer to W8A8."""
        weight = linear_layer.weight.data.float()
        # Per-channel symmetric quantization for weights
        amax = weight.abs().amax(dim=1, keepdim=True)
        scale = (amax / 127.0).clamp(min=1e-10)
        weight_int8 = (weight / scale).round().clamp(-128, 127).to(torch.int8)
        return cls(weight_int8, scale, per_channel=True)

    def forward(self, x):
        """W8A8 forward pass with dynamic per-token activation quantization.

        x: (batch, seq_len, hidden) or (tokens, hidden)
        """
        orig_shape = x.shape
        if x.dim() == 3:
            x = x.reshape(-1, x.shape[-1])
        # Dynamic per-token activation quantization
        act_amax = x.abs().amax(dim=1, keepdim=True)
        act_scale = (act_amax / 127.0).clamp(min=1e-10)
        x_int8 = (x / act_scale).round().clamp(-128, 127).to(torch.int8)
        # INT8 GEMM with INT32 accumulation
        # On real hardware, this uses INT8 tensor cores via cuBLAS
        y_int32 = torch.matmul(
            x_int8.float(),  # Simulated -- a real kernel stays in INT8
            self.weight_int8.float().T
        ).to(torch.int32)
        # Dequantize: y_fp = y_int32 * act_scale * weight_scale
        y_fp = y_int32.float() * act_scale * self.weight_scale.T
        if len(orig_shape) == 3:
            y_fp = y_fp.reshape(orig_shape[0], orig_shape[1], self.out_features)
        return y_fp

def benchmark_w8a8(hidden_size=4096, batch_size=32, seq_len=128):
    """Compare W8A8 against the float linear layer it was quantized from."""
    layer = nn.Linear(hidden_size, hidden_size, bias=False).float()
    x = torch.randn(batch_size, seq_len, hidden_size)
    # Float reference
    y_ref = layer(x)
    # W8A8
    w8a8 = W8A8Linear.from_float(layer)
    y_w8a8 = w8a8.forward(x)
    mse = ((y_ref - y_w8a8) ** 2).mean().item()
    cos_sim = torch.nn.functional.cosine_similarity(
        y_ref.flatten().unsqueeze(0),
        y_w8a8.flatten().unsqueeze(0)
    ).item()
    print(f"W8A8 vs float: MSE={mse:.8f}, cosine_sim={cos_sim:.6f}")
    return mse, cos_sim

benchmark_w8a8()
```
W8A8 Quality: SmoothQuant + Per-Token Dynamic Scaling
| Model | FP16 PPL | W8A8 Naive PPL | W8A8 SmoothQuant PPL | Degradation |
|---|---|---|---|---|
| OPT-6.7B | 10.86 | 23.54 | 10.93 | +0.07 |
| OPT-13B | 10.13 | 18.91 | 10.19 | +0.06 |
| OPT-30B | 9.56 | 67.82 | 9.64 | +0.08 |
| OPT-66B | 9.34 | 940+ | 9.41 | +0.07 |
| Llama 7B | 5.68 | 7.92 | 5.73 | +0.05 |
| Llama 13B | 5.09 | 6.18 | 5.13 | +0.04 |
W8A8 Throughput Improvement Over FP16 (A100, Llama 7B) (tokens/sec)

W8A8 and W4A16 serve different needs. W8A8 is better for prefill (compute-bound: INT8 tensor cores are 2x faster than FP16). W4A16 is better for decode (memory-bound: 4x less weight data to read). The best serving systems use W8A8 for prefill and W4A16 for decode, or use FP8 (Part 4), which combines the benefits.
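The prefill/decode asymmetry follows from back-of-envelope arithmetic. The peak and bandwidth figures below are rough, assumed A100-class numbers, not measurements:

```python
# Back-of-envelope: why W8A8 helps prefill and W4A16 helps decode.
# All figures are rough, assumed A100-class numbers, not measurements.
fp16_tflops = 312          # dense FP16 tensor-core peak, TFLOPS
int8_tops = 624            # dense INT8 tensor-core peak, TOPS
hbm_bw = 2.0e12            # ~2 TB/s memory bandwidth
params = 7e9               # Llama-7B-scale weight count

# Decode (batch 1) is memory-bound: time ~ weight bytes / bandwidth
for fmt, bytes_per_param in [("FP16", 2.0), ("INT8", 1.0), ("INT4", 0.5)]:
    ms = params * bytes_per_param / hbm_bw * 1e3
    print(f"Decode, {fmt} weights: ~{ms:.1f} ms/token just reading weights")

# Prefill (2048-token prompt) is compute-bound: time ~ FLOPs / peak
flops = 2 * params * 2048  # ~2*P FLOPs per token of prompt
print(f"Prefill, FP16 compute: ~{flops / (fp16_tflops * 1e12) * 1e3:.0f} ms")
print(f"Prefill, INT8 compute: ~{flops / (int8_tops * 1e12) * 1e3:.0f} ms")
```

Halving the weight bytes halves decode time; doubling the compute peak halves prefill time. Each format attacks a different bound.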
## The Complete SmoothQuant Pipeline
Here is the end-to-end pipeline for quantizing a model with SmoothQuant:
```python
class SmoothQuantPipeline:
    """End-to-end SmoothQuant + W8A8 quantization pipeline."""

    def __init__(self, model, alpha=0.5):
        self.model = model
        self.alpha = alpha
        self.activation_scales = {}
        self.smoothing_factors = {}

    def calibrate(self, calibration_loader, num_batches=128):
        """Phase 1: Collect activation statistics."""
        hooks = []
        act_max = {}

        def make_hook(name):
            def hook_fn(module, inp, out):
                x = inp[0]
                if x.dim() == 3:
                    x = x.reshape(-1, x.shape[-1])
                ch_max = x.abs().amax(dim=0).detach()
                if name in act_max:
                    act_max[name] = torch.maximum(act_max[name], ch_max)
                else:
                    act_max[name] = ch_max
            return hook_fn

        # Register hooks on all linear layers
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                hooks.append(module.register_forward_hook(make_hook(name)))

        # Run calibration
        self.model.eval()
        count = 0
        with torch.no_grad():
            for batch in calibration_loader:
                self.model(batch)
                count += 1
                if count >= num_batches:
                    break

        # Remove hooks
        for h in hooks:
            h.remove()
        self.activation_scales = act_max
        print(f"Calibrated {len(act_max)} layers over {count} batches")

    def smooth(self):
        """Phase 2: Apply the SmoothQuant transformation.

        NOTE: this sketch only scales the linear weights. A full
        implementation must also divide the preceding LayerNorm (or the
        producing layer's output) by s, as in SmoothQuant.smooth_layer,
        to keep the network output unchanged.
        """
        sq = SmoothQuant(self.alpha)
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear) and name in self.activation_scales:
                act_scale = self.activation_scales[name]
                s = sq.compute_smoothing_factors(act_scale, module.weight.data)
                # Apply smoothing to the weight's input channels
                module.weight.data *= s.unsqueeze(0)
                self.smoothing_factors[name] = s
        print(f"Smoothed {len(self.smoothing_factors)} layers (alpha={self.alpha})")

    def quantize_weights(self):
        """Phase 3: Quantize smoothed weights to INT8."""
        quantized = {}
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                quantized[name] = W8A8Linear.from_float(module)
        # A full pipeline would swap these modules into the model
        print(f"Quantized {len(quantized)} linear layers to INT8")
        return quantized

    def run(self, calibration_loader):
        """Run the full pipeline."""
        print("Phase 1: Calibrating...")
        self.calibrate(calibration_loader)
        print("Phase 2: Smoothing...")
        self.smooth()
        print("Phase 3: Quantizing...")
        return self.quantize_weights()
```
## Advanced: Per-Channel Smoothing Factor Analysis
Understanding which channels get smoothed and by how much gives insight into the model’s internal structure:
```python
def analyze_smoothing(activation_scales, weight, alpha=0.5):
    """Analyze the smoothing factors for a layer."""
    sq = SmoothQuant(alpha)
    s = sq.compute_smoothing_factors(activation_scales, weight)

    print("Smoothing factor statistics:")
    print(f"  Min: {s.min().item():.4f}")
    print(f"  Max: {s.max().item():.4f}")
    print(f"  Mean: {s.mean().item():.4f}")
    print(f"  Median: {s.median().item():.4f}")
    print(f"  Std: {s.std().item():.4f}")
    print(f"  Max/Min ratio: {s.max().item() / s.min().item():.1f}x")

    # Channels with large smoothing factors are the outlier channels
    outlier_threshold = s.mean() + 3 * s.std()
    outlier_mask = s > outlier_threshold
    n_outliers = outlier_mask.sum().item()
    print(f"  Outlier channels (3-sigma): {n_outliers} "
          f"({100 * n_outliers / len(s):.1f}%)")

    # Channel range before and after smoothing
    act_range_before = activation_scales.max() / activation_scales.median()
    act_range_after = (activation_scales / s).max() / (activation_scales / s).median()
    print(f"  Activation range before: {act_range_before:.1f}x")
    print(f"  Activation range after: {act_range_after:.1f}x")
    return s

# Demonstrate on synthetic statistics: ~1% outlier channels at 50x
act_scales = torch.ones(4096) * 0.5
outlier_idx = torch.randperm(4096)[:40]  # ~1% outliers
act_scales[outlier_idx] = 25.0           # 50x larger than normal channels
weight = torch.randn(4096, 4096) * 0.02
s = analyze_smoothing(act_scales, weight, alpha=0.5)
```
## Alpha Selection: Tuning the Smoothing Strength
The parameter $\alpha$ controls how aggressively SmoothQuant migrates quantization difficulty from activations to weights. Choosing the right $\alpha$ is critical for quality.
- $\alpha = 0$: No smoothing of activations. $s_j = \max(|\mathbf{W}_j|)^{-1}$, so all quantization difficulty stays on the activations.
- $\alpha = 0.5$: Equal split. Difficulty is balanced between activations and weights.
- $\alpha = 1$: Maximum smoothing. $s_j = \max(|\mathbf{X}_j|)$, so all difficulty migrates to the weights.
In practice, $\alpha$ between 0.5 and 0.75 works for most models. Models with stronger outliers need a higher $\alpha$.
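A one-channel sweep (hypothetical magnitudes, as in the earlier worked example) shows how $\alpha$ shifts magnitude between the two sides:

```python
# How alpha moves the smoothing factor for a single hypothetical outlier
# channel with max|X| = 50 and max|W| = 0.5:
act_max, w_max = 50.0, 0.5
for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
    s = act_max ** alpha / w_max ** (1 - alpha)
    print(f"alpha={alpha:.2f}  s={s:7.2f}  "
          f"act_after={act_max / s:6.2f}  weight_after={w_max * s:6.2f}")
```

At $\alpha = 0.5$ the post-smoothing activation and weight ranges are equal; pushing $\alpha$ toward 1 drives the activation range toward 1 while inflating the weight range.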
```python
def search_optimal_alpha(layer, calibration_inputs, alphas=None):
    """Grid search for the optimal SmoothQuant alpha.

    Tests each alpha value and returns the one that minimizes
    output error after W8A8 quantization.
    """
    if alphas is None:
        alphas = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
    weight = layer.weight.data.clone().float()

    # Collect calibration activations into one matrix
    all_inputs = []
    for inp in calibration_inputs:
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])
        all_inputs.append(inp)
    all_inputs = torch.cat(all_inputs, dim=0)

    # Per-channel activation max
    act_scales = all_inputs.abs().amax(dim=0)
    # Reference output
    y_ref = all_inputs @ weight.T

    best_alpha = 0.5
    best_mse = float('inf')
    for alpha in alphas:
        # Compute smoothing factors
        w_scales = weight.abs().amax(dim=0)
        s = act_scales.pow(alpha) / w_scales.pow(1 - alpha)
        s = s.clamp(min=1e-5)
        s = s / s.mean()
        # Apply smoothing
        w_smooth = weight * s.unsqueeze(0)
        x_smooth = all_inputs / s.unsqueeze(0)
        # Quantize the weight per-channel...
        w_amax = w_smooth.abs().amax(dim=1, keepdim=True)
        w_scale = (w_amax / 127.0).clamp(min=1e-10)
        w_q = (w_smooth / w_scale).round().clamp(-128, 127)
        # ...and the activations per-token
        x_amax = x_smooth.abs().amax(dim=1, keepdim=True)
        x_scale = (x_amax / 127.0).clamp(min=1e-10)
        x_q = (x_smooth / x_scale).round().clamp(-128, 127)
        # Simulated W8A8 output
        y_q = (x_q * x_scale) @ (w_q * w_scale).T
        mse = ((y_ref - y_q) ** 2).mean().item()
        if mse < best_mse:
            best_mse = mse
            best_alpha = alpha
    return best_alpha, best_mse

# Usage
torch.manual_seed(42)
layer = nn.Linear(4096, 4096, bias=False).float()
calib = [torch.randn(1, 128, 4096) for _ in range(64)]
# Inject outliers into fixed channels of the calibration data
for c in calib:
    c[:, :, 42] *= 30.0
    c[:, :, 1337] *= 50.0

alpha, mse = search_optimal_alpha(layer, calib)
print(f"Optimal alpha: {alpha}, MSE: {mse:.8f}")
```
Effect of Alpha on W8A8 Perplexity (OPT-66B, WikiText-2)
| Alpha | Perplexity | Degradation vs FP16 |
|---|---|---|
| 0.0 (no smoothing) | 940+ | catastrophic |
| 0.25 | 12.81 | +3.47 |
| 0.50 (default) | 9.41 | +0.07 |
| 0.75 | 9.43 | +0.09 |
| 1.0 (max smoothing) | 9.89 | +0.55 |
## The Outlier Emergence Phenomenon
Activation outliers are not present in small models. They emerge as models scale beyond approximately 6 billion parameters. This was first documented by Dettmers et al. (2022) in the “LLM.int8()” paper, which showed that OPT models below 6.7B have no significant outliers, while models at 6.7B and above develop persistent outlier channels.
The outliers have distinctive properties:
- Fixed channels: The same channel indices produce outliers across all inputs, all tokens, and all layers. Channel 42 might always be an outlier in layer 15.
- Consistent magnitude: The outlier magnitude is relatively stable — it does not vary wildly between inputs.
- Small number: Typically 0.1-1% of channels are outliers. In a 4096-dimensional hidden state, that is 4-40 channels.
- Critical for function: Zeroing out outlier channels catastrophically degrades model quality. They encode important information despite being numerically extreme.
```python
def detect_outlier_channels(activation_scales, threshold_sigma=3.0):
    """Detect outlier channels in activation statistics.

    activation_scales: (hidden_dim,) per-channel max absolute values
    threshold_sigma: channels above mean + threshold_sigma * std are outliers
    Returns: outlier channel indices
    """
    mean_scale = activation_scales.mean()
    std_scale = activation_scales.std()
    threshold = mean_scale + threshold_sigma * std_scale
    outlier_mask = activation_scales > threshold
    outlier_indices = torch.where(outlier_mask)[0]

    print("Detection results:")
    print(f"  Total channels: {len(activation_scales)}")
    print(f"  Outlier threshold: {threshold:.4f}")
    print(f"  Outlier channels: {len(outlier_indices)} "
          f"({100 * len(outlier_indices) / len(activation_scales):.2f}%)")
    print(f"  Max outlier magnitude: {activation_scales[outlier_mask].max():.4f}")
    print(f"  Median normal magnitude: "
          f"{activation_scales[~outlier_mask].median():.4f}")
    print(f"  Outlier/normal ratio: "
          f"{activation_scales[outlier_mask].max() / activation_scales[~outlier_mask].median():.1f}x")
    return outlier_indices

# Simulate fixed outlier channels and detect them
act_scales = torch.ones(4096) * 0.5
outlier_idx = torch.tensor([42, 137, 256, 512, 1024, 1337, 2048, 3000, 3500, 4000])
act_scales[outlier_idx] = torch.tensor([25.0, 30.0, 18.0, 22.0, 35.0, 50.0, 28.0, 15.0, 20.0, 40.0])
detected = detect_outlier_channels(act_scales)
```
## Static vs Dynamic Quantization Tradeoffs
Beyond SmoothQuant, the choice between static and dynamic activation quantization has significant implications for both quality and latency.
Static quantization computes scale factors during calibration and fixes them at deployment. Every inference request uses the same pre-computed scales. This is faster (no per-request max computation) but assumes the calibration distribution matches production traffic.
Dynamic quantization computes the scale factor from each input at runtime. This adds one reduction operation per linear layer (computing the per-token or per-tensor max) but perfectly adapts to any input distribution.
In practice, the latency cost of dynamic quantization is negligible for large models. The reduction to compute the per-token max of a 4096-dimensional vector is a few microseconds — invisible compared to the GEMM latency. Dynamic quantization is therefore preferred for production serving where input distributions are unpredictable.
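A rough micro-benchmark illustrates the point. This sketch runs on CPU, so the absolute numbers are illustrative only, but the reduction-vs-GEMM gap is the same in kind on a GPU:

```python
import time
import torch

# Rough CPU micro-benchmark (illustrative only): the per-token amax
# reduction is tiny next to the GEMM it precedes.
n = 2048
x = torch.randn(n, n)
w = torch.randn(n, n)

t0 = time.perf_counter()
for _ in range(5):
    amax = x.abs().amax(dim=1, keepdim=True)  # per-token max reduction
t_reduce = (time.perf_counter() - t0) / 5

t0 = time.perf_counter()
for _ in range(5):
    y = x @ w.T                               # the GEMM the scale feeds
t_gemm = (time.perf_counter() - t0) / 5

print(f"amax reduction: {t_reduce * 1e3:.3f} ms")
print(f"matmul:         {t_gemm * 1e3:.3f} ms")
print(f"relative overhead: {100 * t_reduce / t_gemm:.1f}%")
```

The reduction touches O(n²) elements while the GEMM does O(n³) work, so the gap widens further at larger hidden sizes.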
Static vs Dynamic Activation Quantization (Llama 13B)
| Method | Calibration Data | In-Domain PPL | Out-of-Domain PPL | Latency Overhead |
|---|---|---|---|---|
| Static (C4 calib) | C4 | 5.15 | 6.42 | 0% |
| Static (Wiki calib) | WikiText | 5.13 | 6.89 | 0% |
| Dynamic per-token | None needed | 5.14 | 5.14 | ~1% |
| Dynamic per-tensor | None needed | 5.18 | 5.18 | ~0.5% |
## When SmoothQuant Is Not Enough
SmoothQuant works well for W8A8 but has limitations:
W4A4 or W4A8: At 4-bit precision, even smoothed activations have too much quantization error. SmoothQuant can reduce outlier impact, but the fundamental precision is too low for general activation quantization at INT4. This is why W4A16 (INT4 weights, FP16 activations) remains the dominant 4-bit inference format.
Extreme outliers: Some models (notably GLM-130B) have outliers so extreme that no value of $\alpha$ can fully smooth them. In these cases, a mixed-precision approach is needed: quantize most channels to INT8 and keep the handful of outlier channels in FP16.
Dynamic outlier patterns: SmoothQuant assumes outliers appear in fixed channels. If a model has input-dependent outlier patterns (rare but possible), the smoothing factors computed during calibration may not generalize. Per-token dynamic scaling partially mitigates this.
```python
def mixed_precision_quantize(activations, weight, outlier_threshold=3.0):
    """Mixed precision: INT8 for normal channels, FP16 for outlier channels."""
    ch_max = activations.abs().amax(dim=0)
    median_max = ch_max.median()

    # Identify outlier channels
    outlier_mask = ch_max > outlier_threshold * median_max
    normal_mask = ~outlier_mask
    n_outlier = outlier_mask.sum().item()
    n_normal = normal_mask.sum().item()
    print(f"Mixed precision: {n_normal} INT8 channels, {n_outlier} FP16 channels")

    # Split the computation by channel group
    act_normal = activations[:, normal_mask]
    act_outlier = activations[:, outlier_mask]
    w_normal = weight[:, normal_mask]
    w_outlier = weight[:, outlier_mask]

    # INT8 path for the normal channels
    q_act, act_scale = quantize_per_tensor_int8(act_normal)
    w_amax = w_normal.abs().amax(dim=1, keepdim=True)
    w_scale = (w_amax / 127.0).clamp(min=1e-10)
    q_w = (w_normal / w_scale).round().clamp(-128, 127).to(torch.int8)
    y_int8 = (q_act.float() @ q_w.float().T) * (act_scale * w_scale.T)

    # FP16 path for the outlier channels
    y_fp16 = act_outlier @ w_outlier.T
    return y_int8 + y_fp16
```
## Summary
Activation quantization is fundamentally harder than weight quantization because of outliers: specific channels in LLM activations carry values 10-100x larger than the rest, and these outliers destroy uniform quantization quality.
SmoothQuant solves this by applying a per-channel transformation that divides down outlier activations and multiplies up the corresponding weights. The transformation is mathematically equivalent — the output does not change. After smoothing, standard per-tensor or per-token INT8 quantization works with minimal quality loss.
W8A8 inference quantizes both weights and activations to INT8, enabling INT8 tensor cores for 2x compute throughput over FP16. The key enabler is SmoothQuant for activation smoothing combined with dynamic per-token scaling at inference time.
The next post covers FP8, which provides a more elegant solution to the activation quantization problem by using floating-point representation (better for non-uniform distributions) instead of integer representation.