You quantized your model to INT4, the perplexity increased by 12%, and some downstream tasks degraded by 20%. Now what? The naive response is to increase precision to INT8 or FP8, but that doubles your memory and halves your throughput gain. The systematic response is to identify which layers cause the degradation, what weight or activation patterns trigger quantization error, and apply targeted fixes that preserve most of the compression benefit.
This post covers the debugging methodology: layer sensitivity analysis, outlier detection and mitigation, calibration dataset selection, mixed-precision strategies, and quality recovery techniques.
## The Quantization Quality Pipeline
Quality degradation from quantization follows a predictable chain: rounding weights to a low-bit grid introduces per-layer reconstruction error, that error perturbs each layer's outputs, and the perturbations compound through the forward pass until they surface as perplexity and task-accuracy loss.
This per-layer reconstruction error propagates through the network. Layers early in the network propagate error to all subsequent layers. Layers with large weight magnitudes amplify error. Layers with high sensitivity (small weight changes cause large output changes) are the critical targets.
```python
import torch
import numpy as np

def layer_reconstruction_error(weight_fp16, weight_int4, scale, zero_point, group_size=128):
    """Compute per-layer reconstruction error after quantization."""
    K, N = weight_fp16.shape
    num_groups = K // group_size
    # Dequantize
    weight_dequant = torch.zeros_like(weight_fp16)
    for g in range(num_groups):
        start = g * group_size
        end = start + group_size
        w_q = weight_int4[start:end, :]
        s = scale[g, :]
        z = zero_point[g, :]
        weight_dequant[start:end, :] = s * (w_q.float() - z.float())
    # Frobenius norm of error
    error = torch.norm(weight_fp16 - weight_dequant).item()
    relative_error = error / torch.norm(weight_fp16).item()
    # Max absolute error (worst case)
    max_error = torch.max(torch.abs(weight_fp16 - weight_dequant)).item()
    return {
        'frobenius_error': error,
        'relative_error': relative_error,
        'max_abs_error': max_error,
        'mean_abs_error': torch.mean(torch.abs(weight_fp16 - weight_dequant)).item(),
    }
```
## Layer Sensitivity Analysis
Not all layers are equally sensitive to quantization. The standard approach is to quantize one layer at a time while keeping all others at FP16, then measure the impact on a validation metric (perplexity, accuracy).
### Per-Layer Sensitivity Sweep
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def layer_sensitivity_sweep(model_name, calibration_data, eval_data):
    """Quantize each layer independently and measure perplexity impact."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    model.eval()
    # Baseline perplexity (FP16)
    baseline_ppl = compute_perplexity(model, eval_data, tokenizer)
    print(f"Baseline FP16 perplexity: {baseline_ppl:.4f}")
    results = {}
    for name, param in model.named_parameters():
        if 'weight' not in name or param.dim() != 2:
            continue  # Skip biases and non-matrix params
        # Save original weight
        original_weight = param.data.clone()
        # Quantize this layer to INT4
        quantized, scale, zero = quantize_to_int4(param.data, group_size=128)
        dequantized = dequantize_int4(quantized, scale, zero, group_size=128)
        param.data = dequantized
        # Measure perplexity with this layer quantized
        ppl = compute_perplexity(model, eval_data, tokenizer)
        delta_ppl = ppl - baseline_ppl
        relative_delta = delta_ppl / baseline_ppl
        results[name] = {
            'ppl': ppl,
            'delta_ppl': delta_ppl,
            'relative_delta': relative_delta,
        }
        print(f"{name}: ppl={ppl:.4f}, delta={delta_ppl:+.4f} ({relative_delta:+.2%})")
        # Restore original weight
        param.data = original_weight
    # Sort by sensitivity (highest delta first)
    sorted_results = sorted(results.items(), key=lambda x: x[1]['delta_ppl'], reverse=True)
    return sorted_results
```
**Layer Sensitivity Analysis: Llama 70B INT4 (Top 10 Most Sensitive Layers)**
| Layer | Delta Perplexity | Relative Impact | Category | Recommendation |
|---|---|---|---|---|
| model.layers.0.self_attn.q_proj | +0.42 | +12.7% | First layer attention | Keep FP16 or FP8 |
| model.layers.0.self_attn.k_proj | +0.38 | +11.4% | First layer attention | Keep FP16 or FP8 |
| model.layers.79.mlp.down_proj | +0.31 | +9.3% | Last layer MLP | Keep FP16 or FP8 |
| model.layers.0.mlp.gate_proj | +0.22 | +6.6% | First layer MLP | INT8 with group=64 |
| model.layers.79.self_attn.o_proj | +0.18 | +5.4% | Last layer attention | INT8 with group=64 |
| model.layers.1.self_attn.q_proj | +0.12 | +3.6% | Second layer attention | INT4 with group=64 |
| model.layers.78.mlp.down_proj | +0.09 | +2.7% | Near-last MLP | INT4 with group=64 |
| model.layers.40.self_attn.v_proj | +0.04 | +1.2% | Middle layer attention | INT4 (default) |
| model.layers.40.mlp.up_proj | +0.02 | +0.6% | Middle layer MLP | INT4 (default) |
| model.layers.40.mlp.gate_proj | +0.01 | +0.3% | Middle layer MLP | INT4 (default) |
Across Llama, Mistral, Qwen, and other transformer architectures, the first 1-2 layers and last 1-2 layers consistently show 5-10x higher quantization sensitivity than middle layers. Keeping these 4-8 layers at FP16 or FP8 while quantizing the remaining 76+ layers to INT4 recovers 60-80% of the quality loss with minimal memory impact (4 layers out of 80 = 5% of weights at higher precision).
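This pattern translates directly into a skip-list for your quantization config. A minimal sketch of generating that list (the helper name and the seven projection names per block follow the Llama layout shown in the table above; adapt the `prefix` to your model):

```python
def sensitive_layer_names(num_layers, k=2, prefix="model.layers"):
    """Return the projection names in the first and last k transformer
    blocks, which sensitivity sweeps typically flag for FP16/FP8."""
    projs = ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
             "self_attn.o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]
    edge_blocks = list(range(k)) + list(range(num_layers - k, num_layers))
    return [f"{prefix}.{i}.{p}" for i in edge_blocks for p in projs]

# For an 80-block model: keep 4 of 80 blocks (28 weight matrices) at higher precision
keep_fp16 = sensitive_layer_names(80, k=2)
```

Most quantization toolchains accept such a list as an "ignore" or "modules to not convert" option; check your library's config for the exact field name.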
## Outlier Detection
Weight and activation outliers are the primary cause of quantization degradation. A single outlier value can dominate the scale factor for an entire group, forcing all other values into a narrow range of quantization bins.
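A toy example makes the mechanism concrete (pure Python, 4-bit min-max asymmetric quantization of a single group; the numbers are illustrative, not from a real model):

```python
def quant_rmse(values, bits=4):
    """Round-trip a group through min-max asymmetric quantization
    and return the RMSE against the original values."""
    lo, hi = min(values), max(values)
    levels = 2 ** bits - 1
    scale = (hi - lo) / levels or 1.0
    err2 = 0.0
    for v in values:
        q = round((v - lo) / scale)          # quantize to 0..15
        err2 += (lo + q * scale - v) ** 2    # dequantize and compare
    return (err2 / len(values)) ** 0.5

group = [i / 127 * 0.02 - 0.01 for i in range(128)]  # well-behaved weights in [-0.01, 0.01]
clean = quant_rmse(group)
group[0] = 0.5                                       # one outlier, 50x the typical magnitude
dirty = quant_rmse(group)
print(f"RMSE without outlier: {clean:.6f}, with outlier: {dirty:.6f}")
# The single outlier stretches the scale so far that the other
# 127 values collapse into just one or two quantization bins.
```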
### Weight Outlier Analysis
```python
def analyze_weight_outliers(weight, threshold_sigma=3.0):
    """Detect outlier weights that will cause quantization issues."""
    mean = weight.mean()
    std = weight.std()
    threshold = threshold_sigma * std
    outliers = torch.abs(weight - mean) > threshold
    num_outliers = outliers.sum().item()
    total = weight.numel()
    # Outlier magnitudes
    outlier_values = weight[outliers]
    max_outlier = torch.max(torch.abs(outlier_values)).item() if num_outliers > 0 else 0
    # Dynamic range analysis
    weight_range = weight.max().item() - weight.min().item()
    non_outlier_range = weight[~outliers].max().item() - weight[~outliers].min().item()
    range_ratio = weight_range / non_outlier_range
    return {
        'num_outliers': num_outliers,
        'outlier_fraction': num_outliers / total,
        'max_outlier_magnitude': max_outlier,
        'weight_range': weight_range,
        'non_outlier_range': non_outlier_range,
        'range_ratio': range_ratio,
        'quantization_waste': 1 - 1 / range_ratio,  # Fraction of bins wasted
    }

# Example: check all layers
for name, param in model.named_parameters():
    if param.dim() == 2:
        stats = analyze_weight_outliers(param.data, threshold_sigma=4.0)
        if stats['range_ratio'] > 2.0:
            print(f"WARNING: {name}")
            print(f"  Outliers: {stats['outlier_fraction']:.4%}")
            print(f"  Range ratio: {stats['range_ratio']:.2f}x")
            print(f"  Quantization waste: {stats['quantization_waste']:.1%}")
```
### Activation Outlier Analysis
Activation outliers are values in the intermediate tensors during forward pass that are much larger than typical values. These are particularly problematic because they affect the quantization scale of subsequent operations.
```python
def analyze_activation_outliers(model, calibration_loader, num_batches=32):
    """Profile activation magnitudes across all layers."""
    activation_stats = {}
    hooks = []

    def make_hook(name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                x = output.detach()
                if name not in activation_stats:
                    activation_stats[name] = {
                        'max_vals': [],
                        'mean_vals': [],
                        'std_vals': [],
                        'outlier_channels': [],
                    }
                stats = activation_stats[name]
                stats['max_vals'].append(x.abs().max().item())
                stats['mean_vals'].append(x.abs().mean().item())
                stats['std_vals'].append(x.std().item())
                # Per-channel analysis
                if x.dim() >= 2:
                    channel_max = x.abs().amax(dim=list(range(x.dim() - 1)))
                    channel_mean = channel_max.mean().item()
                    channel_outliers = (channel_max > 6 * channel_mean).sum().item()
                    stats['outlier_channels'].append(channel_outliers)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))
    # Run calibration data through the model
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(calibration_loader):
            if i >= num_batches:
                break
            model(**batch)
    # Remove hooks
    for h in hooks:
        h.remove()
    # Summarize
    for name, stats in activation_stats.items():
        max_val = max(stats['max_vals'])
        mean_val = np.mean(stats['mean_vals'])
        if max_val / mean_val > 100:
            print(f"OUTLIER ALERT: {name}")
            print(f"  Max activation: {max_val:.1f}")
            print(f"  Mean activation: {mean_val:.4f}")
            print(f"  Ratio: {max_val/mean_val:.0f}x")
    return activation_stats
```
**Activation Outlier Magnitude by Layer (Llama 70B, sample input)** (figure: max/mean activation ratio per layer; omitted)

## Outlier Mitigation Techniques
### Technique 1: Smaller Group Size
Reducing group size from 128 to 32 means each scale factor covers fewer values, reducing the impact of any single outlier.
```python
def quantize_with_group_size(weight, group_size):
    """Quantize with the specified group size, return RMSE."""
    K, N = weight.shape
    num_groups = K // group_size
    total_error = 0.0
    for g in range(num_groups):
        group = weight[g*group_size:(g+1)*group_size, :]
        gmin = group.min(dim=0).values
        gmax = group.max(dim=0).values
        scale = (gmax - gmin) / 15.0  # 4-bit range: 0-15
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)  # guard constant columns
        zero = torch.round(-gmin / scale).clamp(0, 15)
        quantized = torch.round(group / scale + zero).clamp(0, 15)
        dequantized = (quantized - zero) * scale
        total_error += torch.sum((group - dequantized) ** 2).item()
    rmse = (total_error / weight.numel()) ** 0.5
    return rmse

# Compare group sizes
for gs in [32, 64, 128, 256]:
    error = quantize_with_group_size(sample_weight, gs)
    overhead = (2 + 0.5) / gs  # scale (2B) + zero (0.5B) per group, per weight
    effective_bits = 4 + overhead * 8
    print(f"Group size {gs}: RMSE={error:.6f}, "
          f"effective bits={effective_bits:.2f}")
```
**Group Size Impact on Quality and Size (Llama 70B)**
| Group Size | Effective Bits | Model Size (GB) | Perplexity | RMSE vs FP16 |
|---|---|---|---|---|
| 32 | 4.62 | 38.4 | 3.41 | 0.00312 |
| 64 | 4.31 | 35.8 | 3.49 | 0.00387 |
| 128 | 4.16 | 34.5 | 3.58 | 0.00465 |
| 256 | 4.08 | 33.9 | 3.72 | 0.00548 |
| Channel-wise | 4.00 | 33.2 | 4.15 | 0.00812 |
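The "Effective Bits" column is simple accounting: each group of `group_size` weights carries one FP16 scale (2 bytes) and one packed 4-bit zero point (0.5 bytes) of metadata. A quick sketch of that arithmetic:

```python
def effective_bits(group_size, weight_bits=4, scale_bytes=2.0, zero_bytes=0.5):
    """Average bits per weight, including per-group scale and zero-point overhead."""
    overhead_bytes_per_weight = (scale_bytes + zero_bytes) / group_size
    return weight_bits + overhead_bytes_per_weight * 8

for gs in [32, 64, 128, 256]:
    print(f"group_size={gs}: {effective_bits(gs):.2f} effective bits")
# group_size=32 -> 4.62, 64 -> 4.31, 128 -> 4.16, 256 -> 4.08 (matches the table)
```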
### Technique 2: Clipping (Outlier Suppression)
Instead of letting outliers dictate the scale, clip the weight range to optimize for the majority of values:
```python
def quantize_with_clipping(weight, bits=4, clip_ratio=0.999):
    """Quantize with percentile-based clipping."""
    # Find the clip threshold from the percentile
    sorted_abs = weight.abs().flatten().sort().values
    clip_idx = min(int(len(sorted_abs) * clip_ratio), len(sorted_abs) - 1)  # clamp so ratio=1.0 works
    clip_val = sorted_abs[clip_idx].item()
    # Clip weights
    weight_clipped = weight.clamp(-clip_val, clip_val)
    # Quantize clipped weights
    wmin = weight_clipped.min()
    wmax = weight_clipped.max()
    num_levels = 2 ** bits - 1
    scale = (wmax - wmin) / num_levels
    zero_point = torch.round(-wmin / scale)
    quantized = torch.round(weight_clipped / scale + zero_point).clamp(0, num_levels)
    dequantized = (quantized - zero_point) * scale
    # Error analysis
    clip_error = torch.sum((weight - weight_clipped) ** 2)        # Error from clipping
    quant_error = torch.sum((weight_clipped - dequantized) ** 2)  # Quantization error
    total_error = torch.sum((weight - dequantized) ** 2)
    return {
        'clip_error': clip_error.item(),
        'quant_error': quant_error.item(),
        'total_error': total_error.item(),
        'num_clipped': (weight.abs() > clip_val).sum().item(),
    }

# Sweep clip ratios to find the optimum
for ratio in [0.99, 0.995, 0.999, 0.9995, 1.0]:
    result = quantize_with_clipping(sample_weight, clip_ratio=ratio)
    print(f"Clip {ratio}: total_error={result['total_error']:.6f}, "
          f"clipped={result['num_clipped']} values")
```
### Technique 3: SmoothQuant (Activation Migration)
SmoothQuant migrates quantization difficulty from activations to weights by applying per-channel scaling:
The per-channel scaling factor $s_j$ is chosen to balance the outlier ranges between the activations $\mathbf{X}$ and the weights $\mathbf{W}$:

$$s_j = \frac{\max(|\mathbf{X}_j|)^{\alpha}}{\max(|\mathbf{W}_j|)^{1-\alpha}}$$

where $\alpha \in [0, 1]$ controls how much difficulty is migrated from activations to weights. Typically $\alpha = 0.5$.
```python
def compute_smooth_scales(activation_max, weight_max, alpha=0.5):
    """Compute SmoothQuant scaling factors."""
    # activation_max: [hidden_dim] - max absolute activation per channel
    # weight_max: [hidden_dim] - max absolute weight per input channel
    scales = (activation_max.pow(alpha) /
              weight_max.pow(1 - alpha)).clamp(min=1e-5)
    return scales

def apply_smoothquant(linear_layer, activation_max, alpha=0.5):
    """Apply SmoothQuant to a linear layer."""
    weight = linear_layer.weight.data  # [out, in]
    weight_max = weight.abs().amax(dim=0)  # [in]
    scales = compute_smooth_scales(activation_max, weight_max, alpha)
    # Scale weights: W_smooth = W * diag(scales)
    linear_layer.weight.data = weight * scales.unsqueeze(0)
    # The inverse scaling diag(scales)^{-1} is applied to activations
    # at runtime (fused into the previous layer's output or layernorm)
    return scales
```
SmoothQuant was designed for INT8 weight + INT8 activation quantization (W8A8). For INT4 weight-only quantization, the activation migration is less relevant because activations remain in FP16. However, the per-channel analysis from SmoothQuant is valuable for identifying which channels have the worst outlier behavior.
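The reason the migration is mathematically free: scaling weights per input channel while inverse-scaling the activations leaves the layer's output unchanged. A tiny pure-Python check of that identity (2-dimensional toy numbers, not real model values):

```python
def matvec(W, x):
    """y = W @ x for a list-of-rows matrix and a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

W = [[0.5, -1.2], [2.0, 0.3]]   # [out, in] weight
x = [4.0, 0.25]                 # activation with one outlier channel
s = [2.0, 0.5]                  # per-input-channel smoothing scales

W_smooth = [[w * sj for w, sj in zip(row, s)] for row in W]  # W * diag(s)
x_smooth = [v / sj for v, sj in zip(x, s)]                   # diag(s)^-1 * x

y = matvec(W, x)
y_smooth = matvec(W_smooth, x_smooth)
assert all(abs(a - b) < 1e-12 for a, b in zip(y, y_smooth))
# x_smooth = [2.0, 0.5]: the outlier channel's range has moved into the weights
```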
## Calibration Dataset Selection
The calibration dataset used for quantization (GPTQ, AWQ) significantly affects quality. A poor calibration set produces poor scale factors.
```python
def evaluate_calibration_quality(model, quant_method, calib_datasets, eval_data):
    """Compare calibration datasets by measuring downstream quality."""
    results = {}
    for name, calib_data in calib_datasets.items():
        # Quantize with this calibration set
        quantized_model = quant_method(model, calib_data, bits=4, group_size=128)
        # Evaluate on held-out data
        ppl = compute_perplexity(quantized_model, eval_data)
        results[name] = ppl
        print(f"Calibration: {name}, Eval PPL: {ppl:.4f}")
    return results

# Common calibration datasets
calib_datasets = {
    'c4_128': load_c4(num_samples=128, seq_len=2048),
    'c4_512': load_c4(num_samples=512, seq_len=2048),
    'wikitext': load_wikitext(num_samples=128, seq_len=2048),
    'pile_sample': load_pile(num_samples=128, seq_len=2048),
    'domain_data': load_custom_domain(num_samples=128, seq_len=2048),
}
```
**Calibration Dataset Impact on INT4 Quality (Llama 70B GPTQ)**
| Calibration Set | Samples | Seq Length | WikiText PPL | MMLU Acc | Notes |
|---|---|---|---|---|---|
| C4 (128 samples) | 128 | 2048 | 3.58 | 63.2% | Standard default |
| C4 (512 samples) | 512 | 2048 | 3.55 | 63.5% | Marginal improvement |
| C4 (32 samples) | 32 | 2048 | 3.71 | 62.1% | Too few samples |
| WikiText-2 | 128 | 2048 | 3.52 | 63.1% | Slightly better PPL, same acc |
| The Pile | 128 | 2048 | 3.56 | 63.8% | Good diversity |
| Code only | 128 | 2048 | 3.85 | 61.4% | Poor for general use |
| Random noise | 128 | 2048 | 4.92 | 55.8% | Worst case |
### Calibration Best Practices
```python
# Best practices for calibration data:
# - Use at least 128 samples
# - Use sequence length >= 2048 (longer sequences exercise more activation patterns)
# - Match the domain of your deployment (code model -> code calibration)
# - Include diverse content (not all short prompts or all long documents)
# - Avoid repetitive or degenerate text

def create_calibration_dataset(tokenizer, texts, num_samples=128, seq_len=2048):
    """Create a properly formatted calibration dataset."""
    encodings = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) >= seq_len:
            # Take a random window (+1 so exact-length texts are usable)
            start = np.random.randint(0, len(tokens) - seq_len + 1)
            encodings.append(tokens[start:start + seq_len])
        if len(encodings) >= num_samples:
            break
    if len(encodings) < num_samples:
        print(f"Warning: only {len(encodings)} samples (requested {num_samples})")
    return torch.tensor(encodings[:num_samples])
```
## Mixed-Precision Quantization
The most effective quality recovery technique is mixed-precision: keeping sensitive layers at higher precision while quantizing the majority to INT4.
```python
def create_mixed_precision_config(sensitivity_results, budget_bits=4.5):
    """Generate a mixed-precision config from sensitivity analysis.

    Args:
        sensitivity_results: sorted list of (layer_name, sensitivity_dict),
            most sensitive first; each dict must include 'num_params'
        budget_bits: target average bits per parameter
    """
    total_params = sum(r[1]['num_params'] for r in sensitivity_results)
    remaining_budget = budget_bits * total_params
    config = {}
    # Assign INT4 (4 bits) to all layers initially
    for name, info in sensitivity_results:
        config[name] = 4
        remaining_budget -= 4 * info['num_params']
    # Upgrade the most sensitive layers until the budget is spent
    # Extra bits from INT4 to FP16 = 12 bits per param
    # Extra bits from INT4 to INT8 = 4 bits per param
    for name, info in sensitivity_results:
        if remaining_budget <= 0:
            break
        extra_bits = 12  # Upgrade to FP16 (16 - 4 = 12 extra)
        cost = extra_bits * info['num_params']
        if cost <= remaining_budget:
            config[name] = 16
            remaining_budget -= cost
        else:
            # Try INT8 instead
            extra_bits = 4
            cost = extra_bits * info['num_params']
            if cost <= remaining_budget:
                config[name] = 8
                remaining_budget -= cost
    actual_avg_bits = sum(
        config[name] * info['num_params']
        for name, info in sensitivity_results
    ) / total_params
    return config, actual_avg_bits
```
**Mixed-Precision Quality Recovery (Llama 70B, average 4.5 bits)** (figure: WikiText-2 perplexity by configuration; omitted)

## Diagnostic Tools
### Weight Distribution Visualization
```python
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def plot_weight_distribution(model, output_path="weight_dist.png"):
    """Plot weight distributions for all linear layers."""
    fig, axes = plt.subplots(10, 8, figsize=(40, 25))
    axes = axes.flatten()
    idx = 0
    for name, param in model.named_parameters():
        if param.dim() != 2 or idx >= len(axes):
            continue
        weights = param.data.float().cpu().flatten().numpy()
        ax = axes[idx]
        ax.hist(weights, bins=100, density=True, alpha=0.7)
        ax.set_title(name.split('.')[-2] + '.' + name.split('.')[-1],
                     fontsize=6)
        ax.set_xlim(-0.1, 0.1)
        # Mark the 3-sigma outlier threshold
        std = np.std(weights)
        ax.axvline(3 * std, color='r', linestyle='--', linewidth=0.5)
        ax.axvline(-3 * std, color='r', linestyle='--', linewidth=0.5)
        idx += 1
    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    print(f"Saved to {output_path}")
```
### Quantization Error Heatmap
```python
def quantization_error_heatmap(weight, group_size=128):
    """Compute a per-group quantization error heatmap."""
    K, N = weight.shape
    num_groups = K // group_size
    error_map = torch.zeros(num_groups, N)
    for g in range(num_groups):
        group = weight[g*group_size:(g+1)*group_size, :]
        gmin = group.min(dim=0).values
        gmax = group.max(dim=0).values
        scale = (gmax - gmin) / 15.0
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        zero = torch.round(-gmin / scale).clamp(0, 15)
        quantized = torch.round(group / scale + zero).clamp(0, 15)
        dequantized = (quantized - zero) * scale
        group_error = torch.mean((group - dequantized) ** 2, dim=0)
        error_map[g, :] = group_error
    return error_map
```
## Quality Recovery Without Increasing Bits
### GPTQ with Activation Reordering (desc_act)
GPTQ’s desc_act=True reorders columns by their activation magnitude (descending). This ensures that the most important weights — those multiplied by the largest activations — are quantized first, while quantization error accumulates in less important columns.
```python
# GPTQ with desc_act=True
# Note: incompatible with the Marlin kernel; falls back to the ExLlama v2 kernel
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=True,  # Activation-ordered quantization
    damp_percent=0.01,
    sym=False,
    true_sequential=True,
)
model = AutoGPTQForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantize_config=quantize_config,
)
model.quantize(calibration_data)
```
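Under the hood, desc_act amounts to a column permutation: visit weight columns in order of decreasing activation magnitude, then undo the permutation at inference time. A minimal pure-Python sketch of the ordering step (`act_magnitude` is a hypothetical per-input-channel statistic collected during calibration, not an auto_gptq API):

```python
def activation_order(act_magnitude):
    """Column indices sorted by descending activation magnitude,
    i.e. the order in which desc_act visits weight columns."""
    return sorted(range(len(act_magnitude)), key=lambda j: -act_magnitude[j])

def permute_columns(W, perm):
    """Reorder the columns of a list-of-rows matrix."""
    return [[row[j] for j in perm] for row in W]

act_magnitude = [0.1, 3.2, 0.8, 1.5]    # per-input-channel mean |activation|
perm = activation_order(act_magnitude)  # -> [1, 3, 2, 0]
W = [[1, 2, 3, 4], [5, 6, 7, 8]]
W_perm = permute_columns(W, perm)       # columns now in importance order
```

Because GPTQ quantizes columns sequentially and pushes residual error onto not-yet-quantized columns, this ordering concentrates the accumulated error in the columns that matter least.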
### AWQ Salient Weight Protection
AWQ identifies “salient” weights (those multiplied by large activations) and applies per-channel scaling to protect them from quantization error:
```python
# AWQ's core insight: protect salient channels
def awq_scale_search(weight, activation_distribution, bits=4):
    """Search for per-channel scales that minimize activation-weighted error."""
    # activation_distribution: [hidden_dim] mean absolute activation per channel
    _, in_features = weight.shape
    best_scales = torch.ones(in_features)
    best_error = float('inf')
    # Grid search over the scale exponent
    for alpha in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        scales = activation_distribution.pow(alpha)
        scales = scales / scales.mean()  # Normalize
        # Apply scale and quantize
        scaled_weight = weight * scales.unsqueeze(0)
        q_weight = quantize(scaled_weight, bits)
        deq_weight = dequantize(q_weight, bits) / scales.unsqueeze(0)
        error = torch.sum((weight - deq_weight) ** 2 *
                          activation_distribution.unsqueeze(0)).item()
        if error < best_error:
            best_error = error
            best_scales = scales.clone()
    return best_scales
```
**Quality Recovery Techniques Comparison (Llama 70B INT4)**
| Technique | Perplexity | Delta vs FP16 | Extra Cost | Compatible With |
|---|---|---|---|---|
| Baseline INT4 (GPTQ, g=128) | 3.58 | +7.8% | None | Marlin, ExLlama |
| Smaller group (g=64) | 3.49 | +5.1% | +3% model size | Marlin, ExLlama |
| desc_act=True | 3.44 | +3.6% | Slower kernel (ExLlama only) | ExLlama only |
| AWQ (activation-aware) | 3.42 | +3.0% | Calibration time | Marlin, AWQ kernel |
| Mixed precision (4.5 avg bits) | 3.36 | +1.2% | +12% model size | Custom config |
| GPTQ + Hessian tuning | 3.40 | +2.4% | Higher calibration cost | Marlin, ExLlama |
## Summary
Debugging quantization quality requires systematic analysis: measure per-layer sensitivity to identify the 5-10% of layers that cause 80% of degradation, detect weight and activation outliers that waste quantization bins, choose calibration data that matches your deployment domain, and apply targeted fixes (smaller group size, mixed precision, SmoothQuant, AWQ scaling) rather than increasing bit width uniformly. The first-and-last layer pattern is consistent across model families — keeping these layers at higher precision is the highest-leverage single fix. For production deployments, AWQ with the Marlin kernel provides the best combination of quality preservation and inference speed.