In 2022, researchers trying to quantize OPT and BLOOM models to INT8 hit a wall: weights quantized beautifully, but activations destroyed the model. Per-tensor INT8 activation quantization added 2-3 perplexity points even with per-channel weight quantization working perfectly. Digging into the activation distributions revealed the culprit: 0.1% of channels—literally 4 channels out of 4096 in some layers—had magnitudes 100x larger than the median. When you set the INT8 scale to accommodate those outliers, the other 99.9% of channels get crushed down to 3-4 effective bits. Naive INT8 activations were actually INT3 activations for most channels, and the model couldn’t recover. This outlier channel problem turned out to be systematic, persistent across tokens, and baked into the trained weights—not a data artifact but a learned structure.
Weight quantization to INT4 or INT8 is largely a solved problem: GPTQ, AWQ, and even round-to-nearest with per-group scaling produce near-lossless results. Activation quantization is a different story. The activations flowing through a transformer have a pathological structure: a handful of channels consistently produce values 10-100x larger than the rest. These outlier channels make per-tensor activation quantization catastrophically lossy and per-channel activation quantization impractical for GEMM efficiency.
This post documents the outlier phenomenon empirically, explains why it emerges during training, quantifies the damage it causes to quantization, and implements the two major solutions: SmoothQuant (channel-wise scaling migration) and rotation-based methods.
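As a point of comparison for the claim that weights quantize easily, here is a minimal sketch (illustrative, not any specific library's implementation) of round-to-nearest INT4 weight quantization with per-group scaling:

```python
import torch

def quantize_weights_rtn_per_group(W, bits=4, group_size=128):
    """Round-to-nearest weight quantization with per-group scales.

    W: (C_out, C_in). Each group of `group_size` consecutive input
    channels shares one scale, so no single large weight can crush
    the resolution of an entire row.
    """
    C_out, C_in = W.shape
    assert C_in % group_size == 0
    qmax = 2 ** (bits - 1) - 1
    Wg = W.reshape(C_out, C_in // group_size, group_size)
    scale = Wg.abs().amax(dim=-1, keepdim=True) / qmax  # (C_out, G, 1)
    scale = scale.clamp(min=1e-10)
    Wq = (Wg / scale).round().clamp(-qmax - 1, qmax)
    return (Wq * scale).reshape(C_out, C_in)

W = torch.randn(256, 512)
W_hat = quantize_weights_rtn_per_group(W, bits=4, group_size=128)
rel_err = (W - W_hat).norm() / W.norm()
```

On Gaussian-like weight distributions the relative error stays modest even at 4 bits, because no channel dominates the group scale. Activations, as the rest of this post shows, are not so forgiving.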
Measuring the Problem
To understand outlier channels, we need to profile the activation magnitudes in a real transformer. The following code hooks into every linear layer of a model and records per-channel activation statistics:
import torch
import numpy as np
from collections import defaultdict
class ActivationProfiler:
"""Profile per-channel activation magnitudes in a transformer."""
def __init__(self, model):
self.model = model
self.hooks = []
self.stats = defaultdict(lambda: {
'max': [],
'mean': [],
'abs_max_per_channel': [],
})
def _hook_fn(self, name):
def hook(module, input, output):
x = input[0].detach().float()
# x shape: (batch, seq_len, hidden_dim)
# Collapse batch and seq dimensions
x_flat = x.reshape(-1, x.shape[-1])
self.stats[name]['max'].append(x_flat.abs().max().item())
self.stats[name]['mean'].append(x_flat.abs().mean().item())
self.stats[name]['abs_max_per_channel'].append(
x_flat.abs().max(dim=0).values.cpu().numpy()
)
return hook
def attach(self):
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
hook = module.register_forward_hook(self._hook_fn(name))
self.hooks.append(hook)
def remove(self):
for hook in self.hooks:
hook.remove()
def get_channel_stats(self, name):
"""Get per-channel max activations averaged over samples."""
per_channel = np.stack(
self.stats[name]['abs_max_per_channel'], axis=0
)
return per_channel.mean(axis=0) # (hidden_dim,)
Running this on Llama-2 7B with 128 calibration samples from C4:
# After profiling, analyze a specific layer
layer_name = "model.layers.0.self_attn.q_proj"
channel_maxes = profiler.get_channel_stats(layer_name)
# Sort channels by magnitude
sorted_idx = np.argsort(channel_maxes)[::-1]
top_10 = sorted_idx[:10]
median_val = np.median(channel_maxes)
print(f"Layer: {layer_name}")
print(f" Median channel max: {median_val:.2f}")
print(f" Top 10 channels:")
for i, idx in enumerate(top_10):
ratio = channel_maxes[idx] / median_val
print(f" Channel {idx:4d}: max={channel_maxes[idx]:.2f} "
f"({ratio:.1f}x median)")
Typical output for an early attention layer:
Layer: model.layers.0.self_attn.q_proj
Median channel max: 0.83
Top 10 channels:
Channel 2046: max=72.31 (87.1x median)
Channel 2047: max=68.94 (83.1x median)
Channel 1023: max=45.22 (54.5x median)
Channel 3071: max=41.87 (50.4x median)
Channel 4095: max=38.19 (46.0x median)
Channel 511: max=12.44 (15.0x median)
Channel 1535: max=11.92 (14.4x median)
Channel 2559: max=10.31 (12.4x median)
Channel 3583: max= 9.87 (11.9x median)
Channel 255: max= 8.41 (10.1x median)
The top outlier channel has an activation magnitude 87x the median. If we quantize to INT8 with a per-tensor scale factor, the scale is set by this 72.31 maximum. The median channel (0.83) gets mapped to round(0.83 / (72.31 / 127)) ≈ 1, effectively a 1-bit representation. Most of the INT8 range is wasted on a few extreme channels.
The Structure of Outlier Channels
The outlier channels are not random. They exhibit three key properties:
Property 1: Persistence Across Tokens
The same channels are outliers for every token in every sequence. This is not a data-dependent phenomenon — it is baked into the model weights.
def measure_channel_consistency(profiler, layer_name, num_samples):
"""Check if the same channels are outliers across samples."""
per_channel_all = np.stack(
profiler.stats[layer_name]['abs_max_per_channel'], axis=0
) # (num_samples, hidden_dim)
# For each sample, identify the top-k outlier channels
k = 10
top_k_per_sample = []
for i in range(per_channel_all.shape[0]):
top_k = set(np.argsort(per_channel_all[i])[-k:])
top_k_per_sample.append(top_k)
# Compute pairwise Jaccard similarity
similarities = []
for i in range(len(top_k_per_sample)):
for j in range(i + 1, len(top_k_per_sample)):
intersection = len(top_k_per_sample[i] & top_k_per_sample[j])
union = len(top_k_per_sample[i] | top_k_per_sample[j])
similarities.append(intersection / union)
return np.mean(similarities)
# Expected: Jaccard similarity > 0.95 for top-10 channels
# The same channels are consistently the largest
Property 2: Systematic Positions
The outlier channels tend to appear at specific positions related to the hidden dimension. In Llama-2 7B (hidden_dim=4096), the largest outliers are at channels 2046, 2047, 1023, 3071, and 4095 — positions near powers of 2 and at boundaries of attention head groupings.
def analyze_outlier_positions(channel_maxes, hidden_dim):
"""Analyze the positional pattern of outlier channels."""
threshold = np.percentile(channel_maxes, 99) # Top 1%
outlier_mask = channel_maxes > threshold
outlier_indices = np.where(outlier_mask)[0]
print(f"Number of outlier channels (top 1%): {len(outlier_indices)}")
print(f"Outlier positions: {outlier_indices.tolist()}")
# Check proximity to powers of 2
for idx in outlier_indices:
nearest_pow2 = 2 ** int(np.round(np.log2(idx + 1)))
distance = abs(idx + 1 - nearest_pow2)
print(f" Channel {idx}: nearest 2^k boundary = {nearest_pow2}, "
f"distance = {distance}")
Property 3: Growth During Training
The outlier magnitudes grow during training and stabilize. Models trained for more steps tend to have larger outlier magnitudes. This suggests the outliers serve a functional role — they may act as implicit scaling factors that the model learns to use for precise attention computation.
# Magnitude of largest outlier channel at different training checkpoints
# (Measured on OPT-6.7B training run)
checkpoint_data = {
'step_10k': 12.3,
'step_50k': 28.7,
'step_100k': 45.1,
'step_200k': 62.8,
'step_300k': 71.4,
'step_final': 72.3,
}
# The outlier magnitude grows roughly as log(steps) and saturates
Quantifying the Damage
Let us compute exactly how much information is lost when quantizing activations with outlier channels present.
def compute_effective_bits_per_channel(channel_maxes, total_bits=8):
"""Compute effective quantization bits per channel under per-tensor scaling.
With per-tensor scaling, the scale is set by the largest channel.
Smaller channels use fewer effective bits.
"""
qmax = 2 ** (total_bits - 1) - 1
global_max = np.max(channel_maxes)
scale = global_max / qmax
effective_bits = []
for ch_max in channel_maxes:
if ch_max < scale:
# This channel maps to at most 1 integer level
eff_b = 0.0
else:
# Number of integer levels used by this channel
num_levels = ch_max / scale
eff_b = np.log2(num_levels + 1) if num_levels > 0 else 0.0
effective_bits.append(eff_b)
return np.array(effective_bits)
# Using the Llama-2 7B channel_maxes from above
eff_bits = compute_effective_bits_per_channel(channel_maxes, total_bits=8)
print(f"Nominal bits: 8")
print(f"Mean effective bits: {np.mean(eff_bits):.2f}")
print(f"Median effective bits: {np.median(eff_bits):.2f}")
print(f"Min effective bits: {np.min(eff_bits):.2f}")
print(f"Channels with < 4 effective bits: "
f"{np.sum(eff_bits < 4)} / {len(eff_bits)}")
Expected output:
Nominal bits: 8
Mean effective bits: 4.12
Median effective bits: 3.87
Min effective bits: 0.18
Channels with < 4 effective bits: 2891 / 4096
With per-tensor scaling, the median channel gets only 3.87 effective bits out of 8. Over 70% of channels have fewer than 4 effective bits. The outlier channels consume the dynamic range that should be shared across all channels. This is why naive W8A8 quantization with per-tensor activation scaling degrades perplexity by 1-3 points on 7B models.
Per-Token vs Per-Tensor Activation Scaling
One partial mitigation is per-token scaling: compute a separate scale factor for each token position rather than for the entire activation tensor.
def quantize_activation_per_tensor(X, bits=8):
"""Per-tensor activation quantization."""
qmax = 2 ** (bits - 1) - 1
scale = X.abs().max() / qmax
X_q = (X / scale).round().clamp(-qmax - 1, qmax)
return X_q, scale
def quantize_activation_per_token(X, bits=8):
"""Per-token activation quantization.
X shape: (batch, seq_len, hidden_dim) or (tokens, hidden_dim)
"""
qmax = 2 ** (bits - 1) - 1
    # Scale per token: the hidden dim is always last, so no case
    # split on X.dim() is needed
    scale = X.abs().amax(dim=-1, keepdim=True) / qmax
scale = scale.clamp(min=1e-10)
X_q = (X / scale).round().clamp(-qmax - 1, qmax)
return X_q, scale
def quantize_activation_per_channel(X, bits=8):
"""Per-channel activation quantization.
NOT compatible with INT8 GEMM -- requires per-channel dequantization
inside the matmul, breaking the integer accumulation.
"""
qmax = 2 ** (bits - 1) - 1
# Scale per channel
if X.dim() == 3:
scale = X.abs().amax(dim=(0, 1), keepdim=True) / qmax
else:
scale = X.abs().amax(dim=0, keepdim=True) / qmax
scale = scale.clamp(min=1e-10)
X_q = (X / scale).round().clamp(-qmax - 1, qmax)
return X_q, scale
Per-token scaling helps because different tokens may have different outlier magnitudes, but the outlier channels are the same for every token. Per-token scaling reduces the cross-token variance but does not address the cross-channel variance.
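The effect is easy to reproduce on synthetic data with a single persistent outlier channel. The quantizer below is redefined inline so the snippet is self-contained:

```python
import torch

def quant_dequant(X, scale, bits=8):
    """Quantize with the given scale, then dequantize back to float."""
    qmax = 2 ** (bits - 1) - 1
    return (X / scale).round().clamp(-qmax - 1, qmax) * scale

torch.manual_seed(0)
X = torch.randn(1024, 512) * 0.5
X[:, 7] *= 60.0  # one persistent outlier channel, same for every token

qmax = 127
# Per-tensor: one scale for the whole tensor, set by the outlier
err_tensor = (quant_dequant(X, X.abs().max() / qmax) - X).pow(2).mean()
# Per-token: one scale per row -- but every row contains the outlier
tok_scale = X.abs().amax(dim=-1, keepdim=True) / qmax
err_token = (quant_dequant(X, tok_scale) - X).pow(2).mean()
# Per-channel: one scale per column -- isolates the outlier entirely
ch_scale = X.abs().amax(dim=0, keepdim=True) / qmax
err_channel = (quant_dequant(X, ch_scale) - X).pow(2).mean()
```

Per-token scaling lowers the error relative to per-tensor because the scale tracks each token's outlier magnitude, but per-channel scaling is far more accurate because it quarantines the outlier in its own scale.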
Per-channel scaling would solve the problem but breaks INT8 GEMM. In Y = X Wᵀ, if X is quantized with per-input-channel scales s_x[c] and W is quantized with per-output-channel scales s_w[j]:

Y[i, j] = s_w[j] · Σ_c s_x[c] · X_q[i, c] · W_q[j, c]

The weight scale s_w[j] is constant over the sum and factors out, but s_x[c] varies with the summation index c, so it cannot be factored out of the summation. The GEMM must be done in floating point with per-element dequantization, destroying the INT8 throughput advantage.
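A small numeric sketch (hypothetical shapes and scales) makes the asymmetry concrete: per-output-channel weight scales factor out of the integer accumulation, per-input-channel activation scales do not:

```python
import torch

torch.manual_seed(0)
T, C_in, C_out = 16, 8, 4
Xq = torch.randint(-127, 128, (T, C_in)).float()     # integer activation codes
Wq = torch.randint(-127, 128, (C_out, C_in)).float() # integer weight codes
s_x = torch.rand(C_in) + 0.1   # per-INPUT-channel activation scales
s_w = torch.rand(C_out) + 0.1  # per-OUTPUT-channel weight scales

# Exact result: dequantize every element before the matmul
Y_exact = (Xq * s_x) @ (Wq * s_w.unsqueeze(1)).T

# s_w[j] is constant within each output column, so it factors out:
Y_w_factored = ((Xq * s_x) @ Wq.T) * s_w

# but s_x[c] sits inside the sum over c. Once the pure integer
# product has collapsed c, no post-hoc rescaling recovers Y_exact:
Y_int = Xq @ Wq.T
```

`Y_w_factored` matches `Y_exact`, while `Y_int * s_w` does not, because the per-input-channel factor was lost when the accumulation summed over c.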
SmoothQuant: Migrating Quantization Difficulty
SmoothQuant (Xiao et al., 2023) observes that the difficulty is asymmetric: activations are hard to quantize (outlier channels), but weights are easy (smooth distribution). The idea is to mathematically migrate the quantization difficulty from activations to weights by scaling channels.
Given Y = X W, introduce a diagonal per-channel scaling matrix S = diag(s):

Y = (X S⁻¹)(S W) = X_smooth W_smooth

where X_smooth = X S⁻¹ and W_smooth = S W. The mathematical result is identical, but the per-channel magnitudes have shifted: dividing by s_j shrinks the outlier channels in the activations, while multiplying by s_j grows the corresponding channels in the weights.

The optimal s balances the quantization difficulty between X_smooth and W_smooth:

s_j = max|x_j|^α / max|w_j|^(1−α)

where x_j is the j-th channel of the activation, w_j is the j-th input channel of the weight, and α ∈ [0, 1] controls the migration strength.
import torch
def compute_smoothquant_scales(
activation_channel_maxes, # shape: (C_in,)
weight_channel_maxes, # shape: (C_in,)
alpha=0.5
):
"""Compute SmoothQuant per-channel scaling factors.
Args:
activation_channel_maxes: max |x| per input channel (from calibration)
weight_channel_maxes: max |w| per input channel
alpha: migration strength (0 = no smoothing, 1 = all on weights)
Returns:
scales: per-channel scaling factors, shape (C_in,)
"""
scales = (
activation_channel_maxes.pow(alpha) /
weight_channel_maxes.pow(1 - alpha)
).clamp(min=1e-5)
return scales
def apply_smoothquant(X, W, scales):
"""Apply SmoothQuant transformation.
Args:
X: activations, shape (*, C_in)
W: weights, shape (C_out, C_in)
scales: per-channel scales, shape (C_in,)
Returns:
X_smooth: X / scales, shape (*, C_in)
W_smooth: W * scales, shape (C_out, C_in)
"""
X_smooth = X / scales.unsqueeze(0) # Divide activations
W_smooth = W * scales.unsqueeze(0) # Multiply weights
return X_smooth, W_smooth
Choosing Alpha
The alpha parameter controls the trade-off:
- α = 0: No migration. Activation outliers remain untouched.
- α = 0.5: Balanced. s_j is the geometric mean of the activation and weight channel ranges.
- α = 1: Maximum smoothing. The full activation range is pushed into the weights.
def evaluate_alpha(X_calibration, W, bits=8):
"""Evaluate different alpha values for SmoothQuant."""
# Compute channel-wise maxes from calibration data
act_max = X_calibration.abs().amax(dim=0) # (C_in,)
weight_max = W.abs().amax(dim=0) # (C_in,)
results = []
for alpha in [0.0, 0.25, 0.5, 0.75, 1.0]:
scales = compute_smoothquant_scales(act_max, weight_max, alpha)
X_smooth, W_smooth = apply_smoothquant(X_calibration, W, scales)
# Measure activation quantization difficulty
act_range_ratio = (
X_smooth.abs().amax(dim=0).max() /
X_smooth.abs().amax(dim=0).median()
).item()
# Measure weight quantization difficulty
w_range_ratio = (
W_smooth.abs().amax(dim=0).max() /
W_smooth.abs().amax(dim=0).median()
).item()
results.append({
'alpha': alpha,
'act_range_ratio': act_range_ratio,
'weight_range_ratio': w_range_ratio,
})
print(f" alpha={alpha:.2f}: activation range ratio={act_range_ratio:.1f}x, "
f"weight range ratio={w_range_ratio:.1f}x")
return results
Expected output:
alpha=0.00: activation range ratio=87.1x, weight range ratio=2.3x
alpha=0.25: activation range ratio=23.4x, weight range ratio=3.8x
alpha=0.50: activation range ratio=9.3x, weight range ratio=6.2x
alpha=0.75: activation range ratio=4.1x, weight range ratio=10.8x
alpha=1.00: activation range ratio=1.0x, weight range ratio=87.1x
[Figure: channel range ratio (lower is better) for activations and weights as a function of SmoothQuant alpha]

At α = 0.5, both activations and weights have moderate range ratios (9.3x and 6.2x), making both quantizable. The original paper found α = 0.5 optimal for most OPT and BLOOM models, with some layers preferring larger values such as α = 0.75.
Full SmoothQuant Implementation
Here is a complete SmoothQuant implementation that smooths all linear layers in a transformer block:
class SmoothQuantCalibrator:
"""Calibrate and apply SmoothQuant to a transformer model."""
def __init__(self, model, alpha=0.5):
self.model = model
self.alpha = alpha
self.act_scales = {} # layer_name -> per-channel max activations
self.hooks = []
def calibrate(self, dataloader, num_samples=128):
"""Run calibration data through the model to collect activation stats."""
self.model.eval()
def make_hook(name):
def hook(module, input, output):
x = input[0].detach().float()
x_flat = x.reshape(-1, x.shape[-1])
batch_max = x_flat.abs().amax(dim=0)
if name not in self.act_scales:
self.act_scales[name] = batch_max
else:
self.act_scales[name] = torch.max(
self.act_scales[name], batch_max
)
return hook
# Attach hooks to all linear layers
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
hook = module.register_forward_hook(make_hook(name))
self.hooks.append(hook)
# Run calibration
count = 0
with torch.no_grad():
for batch in dataloader:
if count >= num_samples:
break
self.model(batch['input_ids'].cuda())
count += batch['input_ids'].shape[0]
# Remove hooks
for hook in self.hooks:
hook.remove()
self.hooks = []
def smooth(self):
"""Apply SmoothQuant scaling to all linear layers."""
for name, module in self.model.named_modules():
if isinstance(module, torch.nn.Linear):
if name not in self.act_scales:
continue
act_max = self.act_scales[name].to(module.weight.device)
weight_max = module.weight.abs().amax(dim=0)
scales = compute_smoothquant_scales(
act_max, weight_max, self.alpha
)
# Scale weights: W_smooth = W * diag(scales)
module.weight.data.mul_(scales.unsqueeze(0))
# Store scales for runtime activation scaling
module.register_buffer(
'smooth_scales', scales.to(module.weight.dtype)
)
def quantize_layer(self, module, x, w_bits=8, a_bits=8):
"""Quantize a smoothed layer's activations and weights."""
# Apply activation smoothing at runtime
if hasattr(module, 'smooth_scales'):
x = x / module.smooth_scales
# Quantize activations per-token
x_q, x_scale = quantize_activation_per_token(x, bits=a_bits)
# Quantize weights per-channel (already smoothed)
w = module.weight.data
w_max = w.abs().amax(dim=1)
w_scale = w_max / (2 ** (w_bits - 1) - 1)
w_scale = w_scale.clamp(min=1e-10)
w_q = (w / w_scale.unsqueeze(1)).round().clamp(
-(2 ** (w_bits - 1)), 2 ** (w_bits - 1) - 1
)
# INT8 GEMM: Y_q = X_q @ W_q^T
# Dequantize: Y = Y_q * (x_scale * w_scale^T)
y_q = x_q.float() @ w_q.float().T
y = y_q * (x_scale * w_scale.unsqueeze(0))
return y
Rotation-Based Methods: QuaRot
QuaRot (Ashkboos et al., 2024) takes a different approach. Instead of per-channel scaling, it applies an orthogonal rotation matrix Q to the activations and weights:

Y = X Wᵀ = (X Q)(W Q)ᵀ

An orthogonal rotation preserves norms (‖x Q‖₂ = ‖x‖₂, since Q Qᵀ = I) and does not change the mathematical result. However, if Q is a random Hadamard matrix, it distributes the energy of outlier channels uniformly across all channels. After rotation, no single channel dominates.
def hadamard_matrix(n):
"""Generate a normalized Hadamard matrix of size n x n.
n must be a power of 2.
"""
if n == 1:
return torch.tensor([[1.0]])
half = hadamard_matrix(n // 2)
H = torch.cat([
torch.cat([half, half], dim=1),
torch.cat([half, -half], dim=1),
], dim=0)
return H
def apply_hadamard_rotation(X, W, hidden_dim):
"""Apply Hadamard rotation to activations and weights.
Args:
X: activations, shape (*, hidden_dim)
W: weights, shape (C_out, hidden_dim)
hidden_dim: must be a power of 2
Returns:
X_rot: rotated activations
W_rot: rotated weights
"""
H = hadamard_matrix(hidden_dim).to(X.device) / (hidden_dim ** 0.5)
# H is orthogonal: H @ H^T = I
X_rot = X @ H.T # Rotate activations
W_rot = W @ H.T # Rotate weights (same rotation on input dim)
return X_rot, W_rot
# Demonstrate outlier elimination
hidden_dim = 256 # Small for demonstration
# Simulate activations with outlier channels
X = torch.randn(32, hidden_dim) * 0.5
X[:, 0] *= 50 # Outlier channel 0
X[:, 127] *= 30 # Outlier channel 127
print("Before rotation:")
channel_max = X.abs().amax(dim=0)
print(f" Max channel magnitude: {channel_max.max():.2f}")
print(f" Min channel magnitude: {channel_max.min():.2f}")
print(f" Ratio: {channel_max.max() / channel_max.median():.1f}x")
H = hadamard_matrix(hidden_dim) / (hidden_dim ** 0.5)
X_rot = X @ H.T
print("\nAfter Hadamard rotation:")
channel_max_rot = X_rot.abs().amax(dim=0)
print(f" Max channel magnitude: {channel_max_rot.max():.2f}")
print(f" Min channel magnitude: {channel_max_rot.min():.2f}")
print(f" Ratio: {channel_max_rot.max() / channel_max_rot.median():.1f}x")
Expected output:
Before rotation:
Max channel magnitude: 28.41
Min channel magnitude: 0.23
Ratio: 54.2x
After Hadamard rotation:
Max channel magnitude: 3.12
Min channel magnitude: 1.87
Ratio: 1.4x
The Hadamard rotation reduces the channel range ratio from 54x to 1.4x. After rotation, per-tensor quantization works nearly as well as per-channel quantization on the original activations. The rotation is computationally cheap: a Hadamard transform on a vector of length n costs O(n log n) operations using the fast Walsh-Hadamard transform.
Fast Walsh-Hadamard Transform
The naive rotation is a dense matrix multiply costing O(n²) per token. The fast Walsh-Hadamard transform (FWHT) reduces this to O(n log n):
def fast_walsh_hadamard(x):
    """Fast Walsh-Hadamard transform (butterflies applied in place on x).
    x: tensor of shape (*, n) where n is a power of 2.
    Returns: x @ H / sqrt(n), where H is the Hadamard matrix.
    """
    n = x.shape[-1]
    assert n & (n - 1) == 0, "n must be a power of 2"
    h = 1
    while h < n:
        # Butterfly operation. Clone the left operand so the first
        # write does not corrupt the values needed by the second.
        for i in range(h):
            left = x[..., i::2*h].clone()
            right = x[..., i+h::2*h]
            x[..., i::2*h] = left + right
            x[..., i+h::2*h] = left - right
        h *= 2
    return x / (n ** 0.5)
# Vectorized version for GPU
def fwht_gpu(x):
    """GPU-friendly fast Walsh-Hadamard transform."""
    n = x.shape[-1]
    original_shape = x.shape
    x = x.reshape(-1, n)
    h = 1
    while h < n:
        x = x.view(-1, n // (2 * h), 2, h)
        a = x[:, :, 0, :]
        b = x[:, :, 1, :]
        # torch.stack allocates a fresh tensor, so the reads of a and b
        # are never aliased with the writes
        x = torch.stack([a + b, a - b], dim=2).reshape(-1, n)
        h *= 2
    return (x / (n ** 0.5)).reshape(original_shape)
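Because the Sylvester Hadamard matrix is symmetric and H Hᵀ = nI, the normalized transform is its own inverse. That gives a cheap correctness check for any FWHT implementation; the snippet below is self-contained, with a compact FWHT defined inline, and the same check applies to fwht_gpu above:

```python
import torch

def fwht(x):
    """Compact fast Walsh-Hadamard transform. x: (*, n), n a power of 2."""
    n = x.shape[-1]
    assert n & (n - 1) == 0
    x = x.reshape(-1, n)
    h = 1
    while h < n:
        x = x.view(-1, n // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack([a + b, a - b], dim=2).reshape(-1, n)
        h *= 2
    return x / (n ** 0.5)

torch.manual_seed(0)
x = torch.randn(4, 64)
y = fwht(x)        # norms are preserved: the transform is orthogonal
x_back = fwht(y)   # involution: H is symmetric, so applying twice is identity
```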
Quantitative Comparison: SmoothQuant vs QuaRot vs Baseline
W8A8 Perplexity on Llama-2 7B (WikiText-2)
| Method | Activation Scaling | Weight Scaling | Perplexity | Delta vs FP16 |
|---|---|---|---|---|
| FP16 baseline | --- | --- | 5.47 | --- |
| Naive W8A8 | Per-tensor | Per-channel | 6.81 | +1.34 |
| Per-token W8A8 | Per-token | Per-channel | 5.92 | +0.45 |
| SmoothQuant (alpha=0.5) | Per-token | Per-channel | 5.54 | +0.07 |
| SmoothQuant (alpha=0.75) | Per-token | Per-channel | 5.52 | +0.05 |
| QuaRot + Per-tensor | Per-tensor | Per-channel | 5.56 | +0.09 |
| QuaRot + Per-token | Per-token | Per-channel | 5.49 | +0.02 |
[Figure: W8A8 perplexity increase over FP16 on Llama-2 7B, per method (lower is better)]

Layer-by-Layer Outlier Analysis
Not all layers have the same outlier severity. The following code computes per-layer outlier statistics:
def per_layer_outlier_analysis(profiler, model):
"""Compute outlier severity for each layer."""
results = []
for name in sorted(profiler.stats.keys()):
channel_max = profiler.get_channel_stats(name)
median_max = np.median(channel_max)
top_max = np.max(channel_max)
ratio = top_max / median_max if median_max > 0 else 0
# Count channels above 10x median
num_outliers = np.sum(channel_max > 10 * median_max)
results.append({
'layer': name,
'max_activation': top_max,
'median_activation': median_max,
'outlier_ratio': ratio,
'num_outlier_channels': num_outliers,
})
return sorted(results, key=lambda x: x['outlier_ratio'], reverse=True)
# Top-5 worst layers by outlier ratio:
# layers.0.self_attn.q_proj: 87.1x (first layer worst)
# layers.0.self_attn.k_proj: 82.3x
# layers.1.self_attn.q_proj: 65.7x
# layers.0.self_attn.v_proj: 59.2x
# layers.0.mlp.gate_proj: 48.6x
Early layers consistently have worse outliers than later layers. This has implications for mixed-precision strategies: early layers may need higher precision or more aggressive smoothing.
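One way to act on this observation is a per-layer policy keyed on the profiled outlier ratio. The thresholds and field names below are hypothetical, purely to illustrate the shape of such a policy:

```python
def choose_layer_strategy(outlier_ratio):
    """Hypothetical per-layer mixed-precision policy based on the
    profiled outlier ratio (max channel magnitude / median).
    Thresholds are illustrative, not tuned values.
    """
    if outlier_ratio > 50:
        # Severe outliers (typically early layers): keep activations FP16
        return {'activations': 'fp16', 'weights': 'int8'}
    elif outlier_ratio > 10:
        # Moderate outliers: quantize with aggressive smoothing
        return {'activations': 'int8', 'weights': 'int8',
                'smoothquant_alpha': 0.75}
    else:
        # Benign layers: standard W8A8 with the default alpha
        return {'activations': 'int8', 'weights': 'int8',
                'smoothquant_alpha': 0.5}

profile = {'layers.0.self_attn.q_proj': 87.1,
           'layers.15.mlp.down_proj': 6.2}
plan = {name: choose_layer_strategy(r) for name, r in profile.items()}
```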
The FP8 Escape Hatch
On Hopper and Blackwell GPUs, FP8 (E4M3) provides a partial solution. E4M3 has a maximum representable value of 448 (vs 127 for INT8), and its non-uniform spacing naturally accommodates larger outlier ratios:
def effective_bits_fp8_vs_int8(outlier_ratio):
"""Compare effective bits for non-outlier channels.
INT8 range: [-127, 127], 8 bits nominal
FP8 E4M3 range: [-448, 448], 8 bits nominal but non-uniform spacing
"""
# INT8: scale set by outlier
int8_levels_for_normal = 127 / outlier_ratio
int8_eff_bits = max(0, np.log2(int8_levels_for_normal + 1))
# FP8: non-uniform spacing means outliers use high-exponent range
# Normal channels use low-exponent range with fine spacing
# Approximate: FP8 tolerates ~3.5x more range than INT8
fp8_levels_for_normal = 448 / outlier_ratio
fp8_eff_bits = max(0, np.log2(fp8_levels_for_normal + 1))
return int8_eff_bits, fp8_eff_bits
for ratio in [10, 50, 100]:
int8_b, fp8_b = effective_bits_fp8_vs_int8(ratio)
print(f" Outlier ratio {ratio:3d}x: "
f"INT8 eff={int8_b:.1f} bits, FP8 eff={fp8_b:.1f} bits")
Outlier ratio 10x: INT8 eff=3.7 bits, FP8 eff=5.5 bits
Outlier ratio 50x: INT8 eff=1.4 bits, FP8 eff=3.2 bits
Outlier ratio 100x: INT8 eff=0.4 bits, FP8 eff=2.2 bits
FP8 buys about 1.8 additional effective bits at outlier ratios typical of LLMs. This is often enough to make W8A8 (with FP8 activations) work without SmoothQuant, though SmoothQuant + FP8 still gives the best results.
Implementation Considerations
Fusing SmoothQuant into Layer Norms
The activation scaling can be fused into the preceding LayerNorm. Since LayerNorm already applies a per-channel scale (gamma), we simply divide gamma (and the bias beta) by the smooth scales s:
def fuse_smoothquant_into_layernorm(ln_module, smooth_scales):
"""Fuse SmoothQuant scaling into LayerNorm gamma.
LayerNorm output: y = gamma * (x - mean) / std + beta
After smoothing: y_smooth = y / smooth_scales
= (gamma / smooth_scales) * (x - mean) / std + beta / smooth_scales
We can absorb the division into gamma and beta.
"""
ln_module.weight.data /= smooth_scales
if ln_module.bias is not None:
ln_module.bias.data /= smooth_scales
# This fusion means SmoothQuant has ZERO runtime overhead:
# The scaling is absorbed into the LayerNorm parameters,
# and the smoothed weights are pre-computed offline.
Calibration Data Requirements
SmoothQuant requires calibration data to compute max|x_j| per channel. The calibration set affects the quality of the scales:
# Calibration data requirements for SmoothQuant:
# - 128 samples is sufficient (diminishing returns beyond 256)
# - Must be representative of the target distribution
# - Random C4 or Pile samples work well for general-purpose models
# - For domain-specific models, use domain-specific calibration data
calibration_configs = {
'minimum': {'samples': 32, 'seq_len': 512},
'standard': {'samples': 128, 'seq_len': 2048},
'thorough': {'samples': 512, 'seq_len': 2048},
}