Part 28 of 30 in the Quantization Masterclass series

You quantized your model to INT4, the perplexity increased by 12%, and some downstream tasks degraded by 20%. Now what? The naive response is to increase precision to INT8 or FP8, but that doubles your memory and halves your throughput gain. The systematic response is to identify which layers cause the degradation, what weight or activation patterns trigger quantization error, and apply targeted fixes that preserve most of the compression benefit.

This post covers the debugging methodology: layer sensitivity analysis, outlier detection and mitigation, calibration dataset selection, mixed-precision strategies, and quality recovery techniques.

The Quantization Quality Pipeline

Quality degradation from quantization follows a predictable chain:

\text{FP16 weights} \xrightarrow{\text{quantize}} \text{INT4 weights} \xrightarrow{\text{dequantize}} \hat{W}_{\text{dequant}} \approx W_{\text{fp16}}

\text{Error} = \|W_{\text{fp16}} - \hat{W}_{\text{dequant}}\|_F

This per-layer reconstruction error propagates through the network. Layers early in the network propagate error to all subsequent layers. Layers with large weight magnitudes amplify error. Layers with high sensitivity (small weight changes cause large output changes) are the critical targets.
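A toy sketch makes the propagation concrete. Injecting roughly 1% relative Frobenius error into each weight of a small linear stack (random matrices with illustrative sizes, no nonlinearity; none of this comes from a real model) shows the output error growing with depth:

```python
import torch

torch.manual_seed(0)

def relative_error(a, b):
    return ((a - b).norm() / a.norm()).item()

# Toy stack of 4 linear layers (sizes are illustrative)
x = torch.randn(64, 256)
layers = [torch.randn(256, 256) / 256**0.5 for _ in range(4)]

# Inject ~1% relative Frobenius error into every weight, mimicking
# independent per-layer quantization error
noise_scale = 0.01
noisy_layers = []
for W in layers:
    E = torch.randn_like(W)
    noisy_layers.append(W + E * (noise_scale * W.norm() / E.norm()))

# Propagate through the stack (no nonlinearity, worst case for error growth)
y_clean, y_noisy = x, x
for W, Wn in zip(layers, noisy_layers):
    y_clean = y_clean @ W
    y_noisy = y_noisy @ Wn

print(f"per-layer weight error: {noise_scale:.1%}")
print(f"output error after 4 layers: {relative_error(y_clean, y_noisy):.1%}")
```

Independent per-layer errors accumulate roughly in quadrature here; correlated errors, or errors in sensitive directions, compound faster.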

import torch
import numpy as np

def layer_reconstruction_error(weight_fp16, weight_int4, scale, zero_point, group_size=128):
    """Compute per-layer reconstruction error after quantization.

    scale and zero_point have shape [num_groups, N], one row per group
    of group_size input rows (asymmetric INT4).
    """
    K, N = weight_fp16.shape
    num_groups = K // group_size

    # Dequantize
    weight_dequant = torch.zeros_like(weight_fp16)
    for g in range(num_groups):
        start = g * group_size
        end = start + group_size
        w_q = weight_int4[start:end, :]
        s = scale[g, :]
        z = zero_point[g, :]
        weight_dequant[start:end, :] = s * (w_q.float() - z.float())

    # Frobenius norm of error
    error = torch.norm(weight_fp16 - weight_dequant).item()
    relative_error = error / torch.norm(weight_fp16).item()

    # Max absolute error (worst case)
    max_error = torch.max(torch.abs(weight_fp16 - weight_dequant)).item()

    return {
        'frobenius_error': error,
        'relative_error': relative_error,
        'max_abs_error': max_error,
        'mean_abs_error': torch.mean(torch.abs(weight_fp16 - weight_dequant)).item()
    }

Layer Sensitivity Analysis

Not all layers are equally sensitive to quantization. The standard approach is to quantize one layer at a time while keeping all others at FP16, then measure the impact on a validation metric (perplexity, accuracy).

Per-Layer Sensitivity Sweep

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

def layer_sensitivity_sweep(model_name, calibration_data, eval_data):
    """Quantize each layer independently and measure perplexity impact.

    Assumes compute_perplexity, quantize_to_int4, and dequantize_int4
    (standard RTN group-quantization helpers) are defined elsewhere.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
    model.eval()

    # Baseline perplexity (FP16)
    baseline_ppl = compute_perplexity(model, eval_data, tokenizer)
    print(f"Baseline FP16 perplexity: {baseline_ppl:.4f}")

    results = {}

    for name, param in model.named_parameters():
        if 'weight' not in name or param.dim() != 2:
            continue  # Skip biases and non-matrix params

        # Save original weight
        original_weight = param.data.clone()

        # Quantize this layer to INT4
        quantized, scale, zero = quantize_to_int4(param.data, group_size=128)
        dequantized = dequantize_int4(quantized, scale, zero, group_size=128)
        param.data = dequantized

        # Measure perplexity with this layer quantized
        ppl = compute_perplexity(model, eval_data, tokenizer)
        delta_ppl = ppl - baseline_ppl
        relative_delta = delta_ppl / baseline_ppl

        results[name] = {
            'ppl': ppl,
            'delta_ppl': delta_ppl,
            'relative_delta': relative_delta
        }

        print(f"{name}: ppl={ppl:.4f}, delta={delta_ppl:+.4f} ({relative_delta:+.2%})")

        # Restore original weight
        param.data = original_weight

    # Sort by sensitivity (highest delta first)
    sorted_results = sorted(results.items(), key=lambda x: x[1]['delta_ppl'], reverse=True)
    return sorted_results

Layer Sensitivity Analysis: Llama 70B INT4 (Top 10 Most Sensitive Layers)

| Layer | Delta Perplexity | Relative Impact | Category | Recommendation |
|---|---|---|---|---|
| model.layers.0.self_attn.q_proj | +0.42 | +12.7% | First-layer attention | Keep FP16 or FP8 |
| model.layers.0.self_attn.k_proj | +0.38 | +11.4% | First-layer attention | Keep FP16 or FP8 |
| model.layers.79.mlp.down_proj | +0.31 | +9.3% | Last-layer MLP | Keep FP16 or FP8 |
| model.layers.0.mlp.gate_proj | +0.22 | +6.6% | First-layer MLP | INT8 with group=64 |
| model.layers.79.self_attn.o_proj | +0.18 | +5.4% | Last-layer attention | INT8 with group=64 |
| model.layers.1.self_attn.q_proj | +0.12 | +3.6% | Second-layer attention | INT4 with group=64 |
| model.layers.78.mlp.down_proj | +0.09 | +2.7% | Near-last MLP | INT4 with group=64 |
| model.layers.40.self_attn.v_proj | +0.04 | +1.2% | Middle-layer attention | INT4 (default) |
| model.layers.40.mlp.up_proj | +0.02 | +0.6% | Middle-layer MLP | INT4 (default) |
| model.layers.40.mlp.gate_proj | +0.01 | +0.3% | Middle-layer MLP | INT4 (default) |
Note: First and last layers are consistently the most sensitive. Middle layers tolerate aggressive quantization. This pattern holds across model families.
The First-and-Last Rule

Across Llama, Mistral, Qwen, and other transformer architectures, the first 1-2 layers and last 1-2 layers consistently show 5-10x higher quantization sensitivity than middle layers. Keeping these 4-8 layers at FP16 or FP8 while quantizing the remaining 76+ layers to INT4 recovers 60-80% of the quality loss with minimal memory impact (4 layers out of 80 = 5% of weights at higher precision).
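The memory cost of this rule is easy to bound. A back-of-envelope sketch, assuming all 80 transformer blocks carry equal parameter counts (roughly true for Llama-style models; embeddings and the LM head are ignored here):

```python
# Back-of-envelope cost of the first-and-last rule, assuming all 80
# transformer blocks carry equal parameter counts
num_layers = 80
fp16_layers = 4                      # first 2 + last 2 kept at 16-bit
int4_layers = num_layers - fp16_layers

avg_bits = (fp16_layers * 16 + int4_layers * 4) / num_layers
overhead = avg_bits / 4 - 1          # extra memory vs uniform INT4

print(f"average bits per parameter: {avg_bits:.2f}")       # 4.60
print(f"memory overhead vs uniform INT4: {overhead:.0%}")  # 15%
```

So the rule costs about 15% more memory than uniform INT4 while recovering most of the quality gap, which is why it is usually the first fix to try.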

Outlier Detection

Weight and activation outliers are the primary cause of quantization degradation. A single outlier value can dominate the scale factor for an entire group, forcing all other values into a narrow range of quantization bins.
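A synthetic example shows the effect: adding one large value to an otherwise well-behaved group inflates the round-to-nearest INT4 scale and multiplies the RMSE for every other value in the group. The group size and outlier magnitude below are illustrative choices:

```python
import torch

torch.manual_seed(0)

def rtn_int4_rmse(group):
    """Round-to-nearest asymmetric INT4 over a single group."""
    gmin, gmax = group.min(), group.max()
    scale = (gmax - gmin) / 15.0
    zero = torch.round(-gmin / scale).clamp(0, 15)
    q = torch.round(group / scale + zero).clamp(0, 15)
    deq = (q - zero) * scale
    return ((group - deq) ** 2).mean().sqrt().item()

group = torch.randn(128) * 0.02        # typical, well-behaved weight group
clean_rmse = rtn_int4_rmse(group)

outlier_group = group.clone()
outlier_group[0] = 0.5                 # one injected outlier (~25 sigma)
outlier_rmse = rtn_int4_rmse(outlier_group)

print(f"RMSE without outlier: {clean_rmse:.5f}")
print(f"RMSE with one outlier: {outlier_rmse:.5f} "
      f"({outlier_rmse / clean_rmse:.1f}x worse)")
```

The outlier itself quantizes fine; it is the 127 normal values, now sharing 16 bins stretched over a range dominated by the outlier, that pay the price.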

Weight Outlier Analysis

def analyze_weight_outliers(weight, threshold_sigma=3.0):
    """Detect outlier weights that will cause quantization issues."""
    mean = weight.mean()
    std = weight.std()
    threshold = threshold_sigma * std

    outliers = torch.abs(weight - mean) > threshold
    num_outliers = outliers.sum().item()
    total = weight.numel()

    # Outlier magnitudes
    outlier_values = weight[outliers]
    max_outlier = torch.max(torch.abs(outlier_values)).item() if num_outliers > 0 else 0

    # Dynamic range analysis
    weight_range = weight.max().item() - weight.min().item()
    non_outlier_range = weight[~outliers].max().item() - weight[~outliers].min().item()
    range_ratio = weight_range / non_outlier_range

    return {
        'num_outliers': num_outliers,
        'outlier_fraction': num_outliers / total,
        'max_outlier_magnitude': max_outlier,
        'weight_range': weight_range,
        'non_outlier_range': non_outlier_range,
        'range_ratio': range_ratio,
        'quantization_waste': 1 - 1 / range_ratio  # Fraction of bins wasted
    }

# Example: check all layers
for name, param in model.named_parameters():
    if param.dim() == 2:
        stats = analyze_weight_outliers(param.data, threshold_sigma=4.0)
        if stats['range_ratio'] > 2.0:
            print(f"WARNING: {name}")
            print(f"  Outliers: {stats['outlier_fraction']:.4%}")
            print(f"  Range ratio: {stats['range_ratio']:.2f}x")
            print(f"  Quantization waste: {stats['quantization_waste']:.1%}")

Activation Outlier Analysis

Activation outliers are values in the intermediate tensors during forward pass that are much larger than typical values. These are particularly problematic because they affect the quantization scale of subsequent operations.

def analyze_activation_outliers(model, calibration_loader, num_batches=32):
    """Profile activation magnitudes across all layers."""
    activation_stats = {}

    hooks = []
    def make_hook(name):
        def hook(module, input, output):
            if isinstance(output, torch.Tensor):
                x = output.detach()
                if name not in activation_stats:
                    activation_stats[name] = {
                        'max_vals': [],
                        'mean_vals': [],
                        'std_vals': [],
                        'outlier_channels': []
                    }
                stats = activation_stats[name]
                stats['max_vals'].append(x.abs().max().item())
                stats['mean_vals'].append(x.abs().mean().item())
                stats['std_vals'].append(x.std().item())

                # Per-channel analysis
                if x.dim() >= 2:
                    channel_max = x.abs().amax(dim=list(range(x.dim()-1)))
                    channel_mean = channel_max.mean().item()
                    channel_outliers = (channel_max > 6 * channel_mean).sum().item()
                    stats['outlier_channels'].append(channel_outliers)

        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            hooks.append(module.register_forward_hook(make_hook(name)))

    # Run calibration data through model
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(calibration_loader):
            if i >= num_batches:
                break
            model(**batch)

    # Remove hooks
    for h in hooks:
        h.remove()

    # Summarize
    for name, stats in activation_stats.items():
        max_val = max(stats['max_vals'])
        mean_val = np.mean(stats['mean_vals'])
        if max_val / mean_val > 100:
            print(f"OUTLIER ALERT: {name}")
            print(f"  Max activation: {max_val:.1f}")
            print(f"  Mean activation: {mean_val:.4f}")
            print(f"  Ratio: {max_val/mean_val:.0f}x")

    return activation_stats

Activation Outlier Magnitude by Layer (Llama 70B, sample input)

| Layer | Max/mean activation ratio | Notes |
|---|---|---|
| Layer 0 (attn) | 280x | Severe outliers |
| Layer 0 (MLP) | 45x | |
| Layer 20 (attn) | 12x | |
| Layer 40 (attn) | 15x | |
| Layer 60 (attn) | 22x | |
| Layer 79 (attn) | 95x | Severe outliers |

Outlier Mitigation Techniques

Technique 1: Smaller Group Size

Reducing group size from 128 to 32 means each scale factor covers fewer values, reducing the impact of any single outlier.

def quantize_with_group_size(weight, group_size):
    """Quantize with specified group size, return error."""
    K, N = weight.shape
    num_groups = K // group_size
    total_error = 0.0

    for g in range(num_groups):
        group = weight[g*group_size:(g+1)*group_size, :]
        gmin = group.min(dim=0).values
        gmax = group.max(dim=0).values
        scale = (gmax - gmin) / 15.0  # 4-bit range: 0-15
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)  # guard constant columns
        zero = torch.round(-gmin / scale).clamp(0, 15)

        quantized = torch.round(group / scale + zero).clamp(0, 15)
        dequantized = (quantized - zero) * scale

        total_error += torch.sum((group - dequantized) ** 2).item()

    rmse = (total_error / weight.numel()) ** 0.5
    return rmse

# Compare group sizes
for gs in [32, 64, 128, 256]:
    error = quantize_with_group_size(sample_weight, gs)
    # Metadata bytes per weight: FP16 scale (2 B) + packed 4-bit zero (0.5 B),
    # amortized over group_size weights; * 8 converts bytes to bits
    overhead = (2 + 0.5) / gs
    effective_bits = 4 + overhead * 8
    print(f"Group size {gs}: RMSE={error:.6f}, "
          f"effective bits={effective_bits:.2f}")

Group Size Impact on Quality and Size (Llama 70B)

| Group Size | Effective Bits | Model Size (GB) | Perplexity | RMSE vs FP16 |
|---|---|---|---|---|
| 32 | 4.62 | 38.4 | 3.41 | 0.00312 |
| 64 | 4.31 | 35.8 | 3.49 | 0.00387 |
| 128 | 4.16 | 34.5 | 3.58 | 0.00465 |
| 256 | 4.08 | 33.9 | 3.72 | 0.00548 |
| Channel-wise | 4.00 | 33.2 | 4.15 | 0.00812 |
Note: Group size 128 is the standard trade-off. Group size 32 recovers significant quality at modest size increase.

Technique 2: Clipping (Outlier Suppression)

Instead of letting outliers dictate the scale, clip the weight range to optimize for the majority of values:

def quantize_with_clipping(weight, bits=4, clip_ratio=0.999):
    """Quantize with percentile-based clipping."""
    # Find clip thresholds based on percentile
    sorted_abs = weight.abs().flatten().sort().values
    clip_idx = min(int(len(sorted_abs) * clip_ratio), len(sorted_abs) - 1)
    clip_val = sorted_abs[clip_idx].item()

    # Clip weights
    weight_clipped = weight.clamp(-clip_val, clip_val)

    # Quantize clipped weights
    wmin = weight_clipped.min()
    wmax = weight_clipped.max()
    num_levels = 2 ** bits - 1
    scale = (wmax - wmin) / num_levels
    zero_point = torch.round(-wmin / scale)
    quantized = torch.round(weight_clipped / scale + zero_point).clamp(0, num_levels)
    dequantized = (quantized - zero_point) * scale

    # Error analysis
    clip_error = torch.sum((weight - weight_clipped) ** 2)  # Error from clipping
    quant_error = torch.sum((weight_clipped - dequantized) ** 2)  # Quantization error
    total_error = torch.sum((weight - dequantized) ** 2)

    return {
        'clip_error': clip_error.item(),
        'quant_error': quant_error.item(),
        'total_error': total_error.item(),
        'num_clipped': (weight.abs() > clip_val).sum().item()
    }

# Sweep clip ratios to find optimal
for ratio in [0.99, 0.995, 0.999, 0.9995, 1.0]:
    result = quantize_with_clipping(sample_weight, clip_ratio=ratio)
    print(f"Clip {ratio}: total_error={result['total_error']:.6f}, "
          f"clipped={result['num_clipped']} values")

Technique 3: SmoothQuant (Activation Migration)

SmoothQuant migrates quantization difficulty from activations to weights by applying per-channel scaling:

Y = (X \cdot \text{diag}(s)^{-1}) \cdot (\text{diag}(s) \cdot W)

The scaling factor s is chosen to balance the outlier ranges between X and W:

s_j = \frac{\max(|X_j|)^\alpha}{\max(|W_j|)^{1-\alpha}}

where \alpha \in [0, 1] controls how much difficulty is migrated from activations to weights; \alpha = 0.5 is the typical default.

def compute_smooth_scales(activation_max, weight_max, alpha=0.5):
    """Compute SmoothQuant scaling factors."""
    # activation_max: [hidden_dim] - max absolute activation per channel
    # weight_max: [hidden_dim] - max absolute weight per input channel
    scales = (activation_max.pow(alpha) /
              weight_max.pow(1 - alpha)).clamp(min=1e-5)
    return scales

def apply_smoothquant(linear_layer, activation_max, alpha=0.5):
    """Apply SmoothQuant to a linear layer."""
    weight = linear_layer.weight.data  # [out, in]
    weight_max = weight.abs().amax(dim=0)  # [in]
    scales = compute_smooth_scales(activation_max, weight_max, alpha)

    # Scale weights: W_smooth = W * diag(scales)
    linear_layer.weight.data = weight * scales.unsqueeze(0)

    # The inverse scaling diag(scales)^{-1} is applied to activations
    # at runtime (fused into previous layer's output or layernorm)
    return scales
SmoothQuant Works Best for INT8

SmoothQuant was designed for INT8 weight + INT8 activation quantization (W8A8). For INT4 weight-only quantization, the activation migration is less relevant because activations remain in FP16. However, the per-channel analysis from SmoothQuant is valuable for identifying which channels have the worst outlier behavior.
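A minimal numeric check (restating the scale formula so the snippet stands alone, on synthetic per-channel maxima) confirms the balancing behavior: with alpha = 0.5, the post-migration activation and weight maxima coincide at sqrt(max|X_j| * max|W_j|) per channel.

```python
import torch

torch.manual_seed(0)

def smooth_scales(act_max, w_max, alpha=0.5):
    # Same formula as compute_smooth_scales, restated for self-containment
    return (act_max.pow(alpha) / w_max.pow(1 - alpha)).clamp(min=1e-5)

# Synthetic per-channel maxima with one severe activation outlier channel
act_max = torch.rand(8) * 2 + 0.5
act_max[3] = 120.0
w_max = torch.rand(8) * 0.05 + 0.01

s = smooth_scales(act_max, w_max, alpha=0.5)
act_after = act_max / s    # activations are divided by s at runtime
w_after = w_max * s        # weights absorb s offline

# With alpha = 0.5 both sides meet at sqrt(act_max * w_max) per channel
print(torch.allclose(act_after, w_after))  # True
print(f"activation range ratio before: {(act_max.max() / act_max.min()):.0f}x, "
      f"after: {(act_after.max() / act_after.min()):.1f}x")
```

The outlier channel's dynamic range collapses toward the rest, which is exactly what makes the migrated activations quantizable with a per-tensor INT8 scale.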

Calibration Dataset Selection

The calibration dataset used for quantization (GPTQ, AWQ) significantly affects quality. A poor calibration set produces poor scale factors.

def evaluate_calibration_quality(model, quant_method, calib_datasets, eval_data):
    """Compare calibration datasets by measuring downstream quality."""
    results = {}

    for name, calib_data in calib_datasets.items():
        # Quantize with this calibration set
        quantized_model = quant_method(model, calib_data, bits=4, group_size=128)

        # Evaluate on held-out data
        ppl = compute_perplexity(quantized_model, eval_data)
        results[name] = ppl
        print(f"Calibration: {name}, Eval PPL: {ppl:.4f}")

    return results

# Common calibration datasets
calib_datasets = {
    'c4_128':       load_c4(num_samples=128, seq_len=2048),
    'c4_512':       load_c4(num_samples=512, seq_len=2048),
    'wikitext':     load_wikitext(num_samples=128, seq_len=2048),
    'pile_sample':  load_pile(num_samples=128, seq_len=2048),
    'domain_data':  load_custom_domain(num_samples=128, seq_len=2048),
}

Calibration Dataset Impact on INT4 Quality (Llama 70B GPTQ)

| Calibration Set | Samples | Seq Length | WikiText PPL | MMLU Acc | Notes |
|---|---|---|---|---|---|
| C4 (128 samples) | 128 | 2048 | 3.58 | 63.2% | Standard default |
| C4 (512 samples) | 512 | 2048 | 3.55 | 63.5% | Marginal improvement |
| C4 (32 samples) | 32 | 2048 | 3.71 | 62.1% | Too few samples |
| WikiText-2 | 128 | 2048 | 3.52 | 63.1% | Slightly better PPL, same acc |
| The Pile | 128 | 2048 | 3.56 | 63.8% | Good diversity |
| Code only | 128 | 2048 | 3.85 | 61.4% | Poor for general use |
| Random noise | 128 | 2048 | 4.92 | 55.8% | Worst case |
Note: C4 with 128 samples is the standard. The Pile offers better diversity. Domain-specific calibration data helps for domain-specific deployment.

Calibration Best Practices

# Best practices for calibration data:
# Use at least 128 samples
# Use sequence length >= 2048 (longer sequences exercise more activation patterns)
# Match the domain of your deployment (code model -> code calibration)
# Include diverse content (not all short prompts or all long documents)
# Avoid repetitive or degenerate text

def create_calibration_dataset(tokenizer, texts, num_samples=128, seq_len=2048):
    """Create properly formatted calibration dataset."""
    encodings = []
    for text in texts:
        tokens = tokenizer.encode(text, add_special_tokens=False)
        if len(tokens) >= seq_len:
            # Take a random window
            start = np.random.randint(0, len(tokens) - seq_len + 1)  # high is exclusive
            encodings.append(tokens[start:start + seq_len])

        if len(encodings) >= num_samples:
            break

    if len(encodings) < num_samples:
        print(f"Warning: only {len(encodings)} samples (requested {num_samples})")

    return torch.tensor(encodings[:num_samples])

Mixed-Precision Quantization

The most effective quality recovery technique is mixed-precision: keeping sensitive layers at higher precision while quantizing the majority to INT4.

def create_mixed_precision_config(sensitivity_results, budget_bits=4.5):
    """Generate mixed-precision config from sensitivity analysis.

    Args:
        sensitivity_results: list of (layer_name, info_dict) sorted most
            sensitive first; each info_dict must include 'num_params'
            alongside the sensitivity metrics
        budget_bits: target average bits per parameter
    """
    total_params = sum(r[1]['num_params'] for r in sensitivity_results)
    remaining_budget = budget_bits * total_params

    config = {}
    # Assign INT4 (4 bits) to all layers initially
    for name, info in sensitivity_results:
        config[name] = 4
        remaining_budget -= 4 * info['num_params']

    # Upgrade the most sensitive layers until budget is spent
    # Extra bits from INT4 to FP16 = 12 bits per param
    # Extra bits from INT4 to INT8 = 4 bits per param
    for name, info in sensitivity_results:
        if remaining_budget <= 0:
            break
        extra_bits = 12  # Upgrade to FP16 (16 - 4 = 12 extra)
        cost = extra_bits * info['num_params']
        if cost <= remaining_budget:
            config[name] = 16
            remaining_budget -= cost
        else:
            # Try INT8 instead
            extra_bits = 4
            cost = extra_bits * info['num_params']
            if cost <= remaining_budget:
                config[name] = 8
                remaining_budget -= cost

    actual_avg_bits = sum(
        config[name] * info['num_params']
        for name, info in sensitivity_results
    ) / total_params

    return config, actual_avg_bits

Mixed-Precision Quality Recovery (Llama 70B, WikiText-2 perplexity)

| Configuration | Avg Bits | Perplexity | Quality Recovery |
|---|---|---|---|
| FP16 baseline | 16 | 3.32 | n/a |
| Uniform INT4 | 4 | 3.58 | n/a |
| Mixed: first/last FP16 | 4.5 | 3.38 | 77% |
| Mixed: sensitivity-guided | 4.5 | 3.36 | 85% |
| Uniform INT8 (reference) | 8 | 3.34 | n/a |

Diagnostic Tools

Weight Distribution Visualization

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def plot_weight_distribution(model, output_path="weight_dist.png"):
    """Plot weight distributions for all linear layers."""
    fig, axes = plt.subplots(10, 8, figsize=(40, 25))
    axes = axes.flatten()

    idx = 0
    for name, param in model.named_parameters():
        if param.dim() != 2:
            continue  # Skip biases and non-matrix params
        if idx >= len(axes):
            break

        weights = param.data.float().cpu().flatten().numpy()
        ax = axes[idx]
        ax.hist(weights, bins=100, density=True, alpha=0.7)
        ax.set_title(name.split('.')[-2] + '.' + name.split('.')[-1],
                     fontsize=6)
        ax.set_xlim(-0.1, 0.1)

        # Mark 3-sigma outlier threshold
        std = np.std(weights)
        ax.axvline(3*std, color='r', linestyle='--', linewidth=0.5)
        ax.axvline(-3*std, color='r', linestyle='--', linewidth=0.5)

        idx += 1

    plt.tight_layout()
    plt.savefig(output_path, dpi=150)
    print(f"Saved to {output_path}")

Quantization Error Heatmap

def quantization_error_heatmap(weight, group_size=128):
    """Compute per-group quantization error heatmap."""
    K, N = weight.shape
    num_groups = K // group_size
    error_map = torch.zeros(num_groups, N)

    for g in range(num_groups):
        group = weight[g*group_size:(g+1)*group_size, :]
        gmin = group.min(dim=0).values
        gmax = group.max(dim=0).values
        scale = (gmax - gmin) / 15.0
        scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        zero = torch.round(-gmin / scale).clamp(0, 15)

        quantized = torch.round(group / scale + zero).clamp(0, 15)
        dequantized = (quantized - zero) * scale
        group_error = torch.mean((group - dequantized) ** 2, dim=0)
        error_map[g, :] = group_error

    return error_map

Quality Recovery Without Increasing Bits

GPTQ with Activation Reordering (desc_act)

GPTQ’s desc_act=True reorders columns by their activation magnitude (descending). This ensures that the most important weights — those multiplied by the largest activations — are quantized first, while quantization error accumulates in less important columns.

# GPTQ with desc_act=True
# Note: incompatible with Marlin kernel, uses ExLlama v2 kernel
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

quantize_config = BaseQuantizeConfig(
    bits=4,
    group_size=128,
    desc_act=True,       # Activation-ordered quantization
    damp_percent=0.01,
    sym=False,
    true_sequential=True,
)

model = AutoGPTQForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantize_config=quantize_config
)
model.quantize(calibration_data)

AWQ Salient Weight Protection

AWQ identifies “salient” weights (those multiplied by large activations) and applies per-channel scaling to protect them from quantization error:

# AWQ's core insight: protect salient channels
def awq_scale_search(weight, activation_distribution, bits=4):
    """Search for per-channel scales that minimize activation-weighted error.

    quantize/dequantize below stand in for any group-wise quantization
    round trip (e.g., the RTN helpers above); they are not defined here.
    """
    # activation_distribution: [hidden_dim] mean absolute activation per channel
    _, in_features = weight.shape

    best_scales = torch.ones(in_features)
    best_error = float('inf')

    # Grid search over scale factors
    for alpha in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]:
        scales = activation_distribution.pow(alpha)
        scales = scales / scales.mean()  # Normalize

        # Apply scale and quantize
        scaled_weight = weight * scales.unsqueeze(0)
        q_weight = quantize(scaled_weight, bits)
        deq_weight = dequantize(q_weight, bits) / scales.unsqueeze(0)

        error = torch.sum((weight - deq_weight) ** 2 *
                          activation_distribution.unsqueeze(0)).item()
        if error < best_error:
            best_error = error
            best_scales = scales.clone()

    return best_scales

Quality Recovery Techniques Comparison (Llama 70B INT4)

| Technique | Perplexity | Delta vs FP16 | Extra Cost | Compatible With |
|---|---|---|---|---|
| Baseline INT4 (GPTQ, g=128) | 3.58 | +7.8% | None | Marlin, ExLlama |
| Smaller group (g=64) | 3.49 | +5.1% | +3% model size | Marlin, ExLlama |
| desc_act=True | 3.44 | +3.6% | Slower kernel | ExLlama only |
| AWQ (activation-aware) | 3.42 | +3.0% | Calibration time | Marlin, AWQ kernel |
| Mixed precision (4.5 avg bits) | 3.36 | +1.2% | +12% model size | Custom config |
| GPTQ + Hessian tuning | 3.40 | +2.4% | Higher calibration cost | Marlin, ExLlama |
Note: AWQ with Marlin kernel offers the best quality-performance trade-off. Mixed precision offers the best quality but requires custom integration.

Summary

Debugging quantization quality requires systematic analysis: measure per-layer sensitivity to identify the 5-10% of layers that cause 80% of degradation, detect weight and activation outliers that waste quantization bins, choose calibration data that matches your deployment domain, and apply targeted fixes (smaller group size, mixed precision, SmoothQuant, AWQ scaling) rather than increasing bit width uniformly. The first-and-last layer pattern is consistent across model families — keeping these layers at higher precision is the highest-leverage single fix. For production deployments, AWQ with the Marlin kernel provides the best combination of quality preservation and inference speed.