Part 2 of 30 in the Quantization Masterclass series.

Naive round-to-nearest quantization of a 7B model to INT4 yields 8-12 perplexity points of degradation—completely unusable for production. GPTQ drops that to 0.3-0.8 perplexity points. AWQ matches GPTQ quality while requiring 10x less calibration time. The difference is not hardware—same GPU, same kernels—it’s the algorithm deciding which direction to round each weight. At INT4 precision, every weight occupies one of 16 discrete values, and the gap between “this weight rounds to 7” versus “this weight rounds to 8” can propagate through 80 transformer layers into catastrophic output quality loss.

This post implements the three dominant weight quantization algorithms from scratch: Round-To-Nearest (RTN) as the baseline that fails at INT4, GPTQ with its Hessian-based error compensation, and AWQ with activation-aware salient channel protection. By the end, you will understand exactly how each algorithm decides which rounding direction to choose for every weight, and why GPTQ and AWQ produce dramatically better results.

Round-To-Nearest (RTN): The Baseline

RTN is the simplest possible quantization algorithm. For each weight, compute the scale factor, divide by it, round to the nearest integer, and clamp to the representable range.

Symmetric RTN

In symmetric quantization, zero maps to zero, and the scale is determined by the maximum absolute value:

$$s = \frac{\max(|W|)}{2^{b-1} - 1}$$

$$W_q = \text{clamp}\left(\text{round}\left(\frac{W}{s}\right),\; -2^{b-1},\; 2^{b-1} - 1\right)$$

$$\hat{W} = s \cdot W_q$$

import torch
import torch.nn as nn
import numpy as np

def quantize_rtn_symmetric(weight, bits=4, group_size=128):
    """
    Round-To-Nearest symmetric quantization.

    Args:
        weight: (out_features, in_features) FP16 weight matrix
        bits: target precision (4 or 8)
        group_size: number of elements per quantization group
            -1 means per-channel (one scale per output row)

    Returns:
        q_weight: quantized integers
        scales: per-group scale factors
    """
    out_f, in_f = weight.shape
    qmin = -(1 << (bits - 1))
    qmax = (1 << (bits - 1)) - 1

    if group_size == -1:
        # Per-channel: one scale per row
        amax = weight.abs().amax(dim=1, keepdim=True)
        scales = amax / qmax
        scales = scales.clamp(min=1e-10)
        q_weight = (weight / scales).round().clamp(qmin, qmax).to(torch.int8)
        return q_weight, scales

    # Per-group quantization
    assert in_f % group_size == 0, f"in_features ({in_f}) must be divisible by group_size ({group_size})"
    num_groups = in_f // group_size

    weight_grouped = weight.reshape(out_f, num_groups, group_size)
    amax = weight_grouped.abs().amax(dim=2, keepdim=True)
    scales = amax / qmax
    scales = scales.clamp(min=1e-10)

    q_weight = (weight_grouped / scales).round().clamp(qmin, qmax).to(torch.int8)
    return q_weight.reshape(out_f, in_f), scales.squeeze(-1)

def dequantize_rtn(q_weight, scales, group_size=128):
    """Dequantize RTN-quantized weights back to FP16."""
    out_f, in_f = q_weight.shape
    if group_size == -1:
        return q_weight.float() * scales

    num_groups = in_f // group_size
    q_grouped = q_weight.reshape(out_f, num_groups, group_size)
    scales_expanded = scales.unsqueeze(-1)
    return (q_grouped.float() * scales_expanded).reshape(out_f, in_f)

Why RTN Fails at INT4

RTN works acceptably at INT8 (256 levels) but degrades badly at INT4 (16 levels). The problem is that each weight is quantized independently, ignoring the impact of rounding errors on the layer’s output. When you round 1000 weights in a single row, the individual rounding errors accumulate and can shift the output by a significant amount.

# Demonstrate RTN quality degradation
torch.manual_seed(42)
weight = torch.randn(256, 1024) * 0.02  # Typical LLM weight scale

# Reference output
x = torch.randn(32, 1024)  # Batch of 32 inputs
y_ref = x @ weight.T

# INT8 RTN
q8, s8 = quantize_rtn_symmetric(weight, bits=8, group_size=128)
w8 = dequantize_rtn(q8, s8, group_size=128)
y_int8 = x @ w8.T
err_int8 = ((y_ref - y_int8) ** 2).mean().item()

# INT4 RTN
q4, s4 = quantize_rtn_symmetric(weight, bits=4, group_size=128)
w4 = dequantize_rtn(q4, s4, group_size=128)
y_int4 = x @ w4.T
err_int4 = ((y_ref - y_int4) ** 2).mean().item()

print(f"INT8 RTN output MSE: {err_int8:.8f}")
print(f"INT4 RTN output MSE: {err_int4:.8f}")
print(f"INT4/INT8 error ratio: {err_int4/err_int8:.1f}x")
# Typical output: INT4 error is a few hundred times worse than INT8
# (the quantization step is ~18x larger, and MSE scales with the step squared)

Output MSE by Quantization Method (INT4, Llama-style Linear Layer), relative MSE:

- RTN (baseline): 100
- GPTQ: 8 (12.5x better than RTN)
- AWQ: 10 (10x better than RTN)

Per-Channel vs Per-Group Scaling

Before diving into GPTQ and AWQ, we need to understand quantization granularity, because it determines the quality ceiling for any algorithm.

Per-tensor scaling: One scale factor for the entire weight matrix. Cheapest in metadata overhead, worst in quality. Rarely used for weight quantization.

Per-channel scaling: One scale factor per output channel (row of the weight matrix). Each row has its own dynamic range. This is the standard for INT8 weight quantization.

Per-group scaling: One scale factor per group of $g$ consecutive elements within a row. With $g = 128$, a row of 4096 weights has 32 scale factors. This is the standard for INT4 weight quantization.

The smaller the group size, the better the quality (each group has a tighter dynamic range), but the higher the metadata overhead (more scale factors to store).
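The effect of granularity is easy to see numerically. The sketch below (synthetic weights with a sprinkling of artificial outliers; a minimal standalone RTN helper rather than the one defined above) measures INT4 reconstruction MSE at each granularity:

```python
import torch

def rtn_error(weight, bits=4, group=None):
    """INT4 RTN reconstruction MSE at a given granularity.

    group=None -> per-tensor (one scale for the whole matrix)
    group=-1   -> per-channel (one scale per output row)
    group=g    -> one scale per g consecutive elements
    """
    qmax = (1 << (bits - 1)) - 1
    if group is None:
        w = weight.reshape(1, -1)
    elif group == -1:
        w = weight.reshape(weight.shape[0], -1)
    else:
        w = weight.reshape(-1, group)
    s = (w.abs().amax(dim=1, keepdim=True) / qmax).clamp(min=1e-10)
    w_hat = (w / s).round().clamp(-qmax - 1, qmax) * s
    return ((w - w_hat) ** 2).mean().item()

torch.manual_seed(0)
# Mostly small weights plus rare large entries: coarse scaling suffers most
W = torch.randn(512, 1024) * 0.02 + (torch.rand(512, 1024) < 0.001) * 0.3

for name, g in [("per-tensor", None), ("per-channel", -1),
                ("g=128", 128), ("g=32", 32)]:
    print(f"{name:12s} MSE: {rtn_error(W, group=g):.3e}")
```

Finer granularity gives tighter per-group dynamic ranges, so the MSE drops monotonically from per-tensor down to g=32.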

📊

Scale Factor Overhead by Group Size (INT4, 4096x4096 Matrix)

| Group Size | Scales per Row | Total Scales | Scale Storage (FP16) | Overhead vs Weight |
|---|---|---|---|---|
| Per-tensor | 1 | 1 | 2 bytes | 0.00% |
| Per-channel | 1 | 4,096 | 8 KB | 0.10% |
| g=256 | 16 | 65,536 | 128 KB | 1.56% |
| g=128 | 32 | 131,072 | 256 KB | 3.13% |
| g=64 | 64 | 262,144 | 512 KB | 6.25% |
| g=32 | 128 | 524,288 | 1 MB | 12.5% |

Note: Weight matrix at INT4 = 4096 * 4096 * 0.5 bytes = 8 MB. g=128 adds only 3.13% overhead while significantly improving quality over per-channel.

Group size 128 is the most common choice in practice. It strikes a good balance between quality and overhead, and it aligns well with GPU memory access patterns (128 INT4 elements = 64 bytes = one cache line on many architectures).
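The overhead column in the table above follows directly from counting scale factors; a small helper (parameter names are my own) reproduces the numbers:

```python
def scale_overhead(rows=4096, cols=4096, group_size=128, w_bits=4, s_bytes=2):
    """Scale-factor storage as a fraction of packed INT-weight storage."""
    weight_bytes = rows * cols * w_bits / 8
    n_scales = rows * (cols // group_size)     # one scale per group per row
    return n_scales * s_bytes / weight_bytes

for g in (256, 128, 64, 32):
    print(f"g={g:3d}: {scale_overhead(group_size=g):.2%}")
```

At g=128 this gives 2 bytes of FP16 scale per 64 bytes of packed INT4 weights, i.e. 3.125% overhead, matching the table.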

GPTQ: Optimal Brain Quantization for LLMs

GPTQ (Frantar et al., 2022) is based on Optimal Brain Surgeon / Optimal Brain Compression. The key insight: when you quantize one weight, you can adjust the remaining unquantized weights to compensate for the quantization error. This is not heuristic — it is the mathematically optimal compensation given a quadratic approximation to the loss.

The GPTQ Algorithm

For a single linear layer with weight matrix $W \in \mathbb{R}^{m \times n}$ and a calibration dataset producing inputs $X \in \mathbb{R}^{n \times p}$, GPTQ minimizes:

$$\|WX - \hat{W}X\|_F^2$$

where $\hat{W}$ is the quantized weight matrix. The algorithm processes columns of $W$ one at a time (left to right), and for each column:

  1. Quantize the column using RTN
  2. Compute the quantization error
  3. Distribute the error to all remaining (not yet quantized) columns using the inverse Hessian

The Hessian of the output error with respect to the weights is:

$$H = XX^T + \lambda I$$

where $\lambda$ is a small damping term for numerical stability. The error compensation formula for updating column $j$ after quantizing column $q$ is:

$$\delta_j = -\frac{w_q - \hat{w}_q}{[H^{-1}]_{qq}} \cdot [H^{-1}]_{qj}$$

where $H^{-1}$ is the inverse Hessian and $w_q$, $\hat{w}_q$ are the original and quantized values of column $q$.

ℹ️ Why the Hessian?

The Hessian captures the curvature of the output error surface. If the Hessian entry $H_{qj}$ is large, columns $q$ and $j$ are strongly correlated in their effect on the output. Quantizing column $q$ and compensating through column $j$ is then highly effective. If $H_{qj}$ is small, the columns are independent and compensation has little effect.
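Before the full implementation, the compensation formula can be checked on a toy two-column example (synthetic correlated inputs, and a crude fixed-step quantizer standing in for the INT4 grid): rounding column 0, then pushing its error into column 1 via the inverse Hessian, beats rounding both columns independently.

```python
import torch

torch.manual_seed(0)
# Two input channels, strongly correlated; one output row
X = torch.randn(2, 256)
X[1] = 0.9 * X[0] + 0.1 * torch.randn(256)
W = torch.tensor([[0.83, -0.41]])

H = X @ X.T / X.shape[1]          # Hessian of the output error
H_inv = torch.linalg.inv(H)

def q(w, step=0.25):
    """Crude uniform quantizer standing in for the INT4 grid."""
    return (w / step).round() * step

# RTN: round both columns independently
W_rtn = q(W)

# GPTQ-style: round column 0, push its error into column 1, then round column 1
W_g = W.clone()
w0_q = q(W_g[:, 0])
err = (W_g[:, 0] - w0_q) / H_inv[0, 0]
W_g[:, 1] -= err * H_inv[0, 1]    # compensation via the inverse Hessian
W_g[:, 0] = w0_q
W_g[:, 1] = q(W_g[:, 1])

e_rtn = ((W @ X - W_rtn @ X) ** 2).mean().item()
e_gptq = ((W @ X - W_g @ X) ** 2).mean().item()
print(f"RTN output MSE:  {e_rtn:.5f}")
print(f"GPTQ output MSE: {e_gptq:.5f}")
```

Because the two channels are nearly collinear, column 1 can absorb most of column 0's rounding error, and the GPTQ-style output error comes out well below RTN's.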

Complete GPTQ Implementation

import torch
import torch.nn as nn

class GPTQ:
    """GPTQ quantizer for a single linear layer."""

    def __init__(self, layer, bits=4, group_size=128, damp_percent=0.01):
        self.layer = layer
        self.bits = bits
        self.group_size = group_size
        self.damp = damp_percent

        self.rows = layer.weight.shape[0]  # out_features
        self.cols = layer.weight.shape[1]  # in_features

        # Hessian accumulator
        self.H = torch.zeros(self.cols, self.cols, device=layer.weight.device,
                             dtype=torch.float32)
        self.nsamples = 0

    def add_batch(self, inp):
        """Accumulate Hessian from a batch of inputs.

        inp: (batch_size, in_features) or (batch_size, seq_len, in_features)
        """
        if inp.dim() == 3:
            inp = inp.reshape(-1, inp.shape[-1])

        batch_size = inp.shape[0]
        inp = inp.float()

        # H += X^T X  (accumulated over batches)
        self.H += inp.T @ inp
        self.nsamples += batch_size

    def quantize(self):
        """Run GPTQ quantization. Returns (q_weight, scales)."""
        W = self.layer.weight.data.clone().float()
        H = self.H / self.nsamples  # Average Hessian

        # Add damping for numerical stability
        damp = self.damp * torch.diag(H).mean()
        diag_indices = torch.arange(self.cols, device=H.device)
        H[diag_indices, diag_indices] += damp

        # GPTQ works with the upper Cholesky factor of H^{-1}: its rows encode
        # the compensation coefficients for each column in sequence, so the
        # inverse never needs to be recomputed after a column is quantized
        try:
            L = torch.linalg.cholesky(H)
        except RuntimeError:
            # If Cholesky fails, add more damping
            H[diag_indices, diag_indices] += 10 * damp
            L = torch.linalg.cholesky(H)

        H_inv_chol = torch.linalg.cholesky(torch.cholesky_inverse(L), upper=True)

        qmin = -(1 << (self.bits - 1))
        qmax = (1 << (self.bits - 1)) - 1

        q_weight = torch.zeros_like(W, dtype=torch.int32)
        scales = torch.zeros(self.rows, self.cols // self.group_size,
                             device=W.device, dtype=torch.float32)

        # Process in blocks of group_size columns (lazy batch updates:
        # compensate inside the block immediately, defer the rest)
        for block_start in range(0, self.cols, self.group_size):
            block_end = min(block_start + self.group_size, self.cols)
            block_size = block_end - block_start
            group_idx = block_start // self.group_size

            # Extract the block of columns we are quantizing
            W_block = W[:, block_start:block_end].clone()

            # Compute per-row scale for this group
            amax = W_block.abs().amax(dim=1, keepdim=True)
            scale = (amax / qmax).clamp(min=1e-10)
            scales[:, group_idx] = scale.squeeze(1)

            # Cholesky block for within-block compensation
            Hc = H_inv_chol[block_start:block_end, block_start:block_end]

            # Scaled errors, kept for the deferred cross-block update
            Err = torch.zeros_like(W_block)
            for col in range(block_size):
                w = W_block[:, col]

                # Quantize this column
                w_q = (w / scale.squeeze(1)).round().clamp(qmin, qmax)
                q_weight[:, block_start + col] = w_q.to(torch.int32)

                # Dequantize to get the quantized value
                w_deq = w_q * scale.squeeze(1)

                # Scaled quantization error for this column
                err = (w - w_deq) / Hc[col, col]
                Err[:, col] = err

                # Compensate remaining columns in this block
                if col < block_size - 1:
                    W_block[:, col + 1:] -= err.unsqueeze(1) * Hc[col, col + 1:].unsqueeze(0)

            # Lazy batch update: compensate all columns after this block at once
            if block_end < self.cols:
                W[:, block_end:] -= Err @ H_inv_chol[block_start:block_end, block_end:]

        return q_weight.to(torch.int8), scales

def gptq_quantize_layer(layer, calibration_inputs, bits=4, group_size=128):
    """Quantize a linear layer using GPTQ.

    Args:
        layer: nn.Linear to quantize
        calibration_inputs: list of input tensors from calibration data
        bits: target precision
        group_size: quantization group size

    Returns:
        q_weight, scales
    """
    gptq = GPTQ(layer, bits=bits, group_size=group_size)

    # Phase 1: Accumulate Hessian from calibration data
    for inp in calibration_inputs:
        gptq.add_batch(inp)

    # Phase 2: Quantize with error compensation
    q_weight, scales = gptq.quantize()
    return q_weight, scales

GPTQ Usage Example

# Create a sample layer and calibration data
torch.manual_seed(42)
layer = nn.Linear(4096, 4096, bias=False).float()

# Generate calibration data (typically 128-512 samples from training data)
calib_data = [torch.randn(1, 128, 4096) for _ in range(128)]

# Quantize
q_weight, scales = gptq_quantize_layer(layer, calib_data, bits=4, group_size=128)

# Measure quality
x_test = torch.randn(32, 4096)
y_ref = layer(x_test)

# Dequantize and compute output
num_groups = 4096 // 128
q_float = q_weight.float().reshape(4096, num_groups, 128)
scales_expanded = scales.unsqueeze(-1)
w_deq = (q_float * scales_expanded).reshape(4096, 4096)
y_gptq = x_test @ w_deq.T

mse = ((y_ref - y_gptq) ** 2).mean().item()
print(f"GPTQ INT4 output MSE: {mse:.8f}")
GPTQ Runtime

GPTQ quantization is a one-time offline cost. For a 7B model, GPTQ with 128 calibration samples takes approximately 10-30 minutes on a single GPU. For a 70B model, expect 2-4 hours. The Hessian computation dominates the cost. Once quantized, the model serves at full INT4 speed with no additional overhead.

AWQ: Activation-Aware Weight Quantization

AWQ (Lin et al., 2023) takes a different approach from GPTQ. Instead of compensating for quantization errors using the Hessian, AWQ identifies salient weight channels — the channels that matter most for output quality — and protects them by scaling before quantization.

The Core Insight

Not all weight channels contribute equally to the output. Some channels carry activations with large magnitudes, and errors in these channels have an outsized impact on quality. AWQ finds these channels by examining activation statistics from calibration data, then applies a per-channel scaling that effectively gives salient channels more quantization resolution.

The process:

  1. Run calibration data through the model, recording per-channel activation magnitudes
  2. Identify salient channels: those with large average activation magnitude
  3. Apply a per-channel scale $s_j$ that multiplies weights in channel $j$ by $s_j$ and divides activations in channel $j$ by $s_j$ (preserving the mathematical output)
  4. After scaling, quantize using standard RTN

The scaling does not change the layer’s mathematical function — it just redistributes where quantization error falls. Salient channels get larger scale factors, concentrating more of the quantization range on the values that matter.
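The equivalence is easy to verify numerically. This small sketch (arbitrary shapes and scales of my choosing) checks that scaling weights up and activations down by the same per-channel factors leaves the output unchanged up to floating-point rounding:

```python
import torch

torch.manual_seed(0)
W = torch.randn(8, 16)            # (out_features, in_features)
x = torch.randn(16)               # one input vector
s = torch.rand(16) + 0.5          # per-input-channel scales, all positive

y_ref = W @ x
y_scaled = (W * s) @ (x / s)      # scale weights up, activations down

print(torch.allclose(y_ref, y_scaled, atol=1e-5))
```

Quantization breaks this exact equivalence, which is the point: Q(W * diag(s)) spends its rounding budget differently than Q(W), and AWQ picks s so that the error lands on the channels that matter least.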

The optimal scale for each channel minimizes the quantization error on the output. AWQ uses a grid search over scale factors:

$$s^* = \arg\min_{s} \left\| Q(W \cdot \text{diag}(s)) \left(\text{diag}(s)^{-1} X\right) - W X \right\|_F^2$$

where $Q(\cdot)$ denotes the quantization operator and $s$ is the vector of per-channel scales.

class AWQ:
    """Activation-Aware Weight Quantization for a single linear layer."""

    def __init__(self, layer, bits=4, group_size=128):
        self.layer = layer
        self.bits = bits
        self.group_size = group_size
        self.rows = layer.weight.shape[0]
        self.cols = layer.weight.shape[1]

    def compute_activation_scales(self, calibration_inputs):
        """Compute per-channel activation magnitudes from calibration data.

        Returns: (in_features,) tensor of average absolute activation per channel.
        """
        act_sum = torch.zeros(self.cols, device=self.layer.weight.device)
        n_samples = 0

        for inp in calibration_inputs:
            if inp.dim() == 3:
                inp = inp.reshape(-1, inp.shape[-1])
            act_sum += inp.abs().sum(dim=0).float()
            n_samples += inp.shape[0]

        return act_sum / n_samples

    def search_scales(self, act_scales, n_grid=20, alpha_range=(0.0, 1.0)):
        """Search for optimal per-channel scaling factors.

        The scale for channel j is:
            s_j = act_scales_j^alpha

        We search over alpha in [0, 1] to find the best trade-off.
        alpha=0 means no scaling (all channels equal).
        alpha=1 means full activation-proportional scaling.
        """
        W = self.layer.weight.data.clone().float()
        qmin = -(1 << (self.bits - 1))
        qmax = (1 << (self.bits - 1)) - 1

        best_alpha = 0.0
        best_error = float('inf')

        for i in range(n_grid + 1):
            alpha = alpha_range[0] + (alpha_range[1] - alpha_range[0]) * i / n_grid

            # Compute per-channel scales: s_j = act_j^alpha
            scales = act_scales.clamp(min=1e-5).pow(alpha)
            scales = scales / scales.mean()  # Normalize to preserve magnitude

            # Apply scaling to weights: W_scaled = W * diag(scales)
            W_scaled = W * scales.unsqueeze(0)

            # Quantize the scaled weights using RTN
            W_grouped = W_scaled.reshape(self.rows, -1, self.group_size)
            amax = W_grouped.abs().amax(dim=2, keepdim=True)
            q_scale = amax / qmax
            q_scale = q_scale.clamp(min=1e-10)

            W_q = (W_grouped / q_scale).round().clamp(qmin, qmax)
            W_deq = (W_q * q_scale).reshape(self.rows, self.cols)

            # Undo the channel scaling: W_final = W_deq / diag(scales)
            W_final = W_deq / scales.unsqueeze(0)

            # Output-error proxy: weight each channel's squared error by its
            # activation energy (diagonal approximation of ||(W - W_final) X||_F^2),
            # so the search is actually activation-aware
            error = (((W - W_final) ** 2) * (act_scales ** 2).unsqueeze(0)).sum().item()

            if error < best_error:
                best_error = error
                best_alpha = alpha

        return best_alpha

    def quantize(self, calibration_inputs):
        """Run AWQ quantization.

        Returns: (q_weight, scales, channel_scales, alpha)
        """
        # Step 1: Compute activation scales
        act_scales = self.compute_activation_scales(calibration_inputs)

        # Step 2: Search for optimal alpha
        alpha = self.search_scales(act_scales)

        # Step 3: Compute channel scaling factors
        channel_scales = act_scales.clamp(min=1e-5).pow(alpha)
        channel_scales = channel_scales / channel_scales.mean()

        # Step 4: Apply scaling and quantize
        W = self.layer.weight.data.clone().float()
        W_scaled = W * channel_scales.unsqueeze(0)

        qmin = -(1 << (self.bits - 1))
        qmax = (1 << (self.bits - 1)) - 1

        num_groups = self.cols // self.group_size
        W_grouped = W_scaled.reshape(self.rows, num_groups, self.group_size)
        amax = W_grouped.abs().amax(dim=2, keepdim=True)
        q_scales = amax / qmax
        q_scales = q_scales.clamp(min=1e-10)

        q_weight = (W_grouped / q_scales).round().clamp(qmin, qmax).to(torch.int8)
        q_weight = q_weight.reshape(self.rows, self.cols)
        q_scales = q_scales.squeeze(-1)

        return q_weight, q_scales, channel_scales, alpha

def awq_quantize_layer(layer, calibration_inputs, bits=4, group_size=128):
    """Quantize a linear layer using AWQ."""
    awq = AWQ(layer, bits=bits, group_size=group_size)
    return awq.quantize(calibration_inputs)

AWQ Usage and Comparison

# Compare RTN, GPTQ, and AWQ on the same layer
torch.manual_seed(42)
layer = nn.Linear(4096, 4096, bias=False).float()
calib_data = [torch.randn(1, 128, 4096) for _ in range(128)]
x_test = torch.randn(32, 4096)
y_ref = layer(x_test)

# RTN
q_rtn, s_rtn = quantize_rtn_symmetric(layer.weight.data, bits=4, group_size=128)
w_rtn = dequantize_rtn(q_rtn, s_rtn, group_size=128)
mse_rtn = ((y_ref - x_test @ w_rtn.T) ** 2).mean().item()

# AWQ
q_awq, s_awq, ch_scales, alpha = awq_quantize_layer(
    layer, calib_data, bits=4, group_size=128
)
# Dequantize AWQ
num_groups = 4096 // 128
q_float = q_awq.float().reshape(4096, num_groups, 128)
w_awq = (q_float * s_awq.unsqueeze(-1)).reshape(4096, 4096) / ch_scales.unsqueeze(0)
mse_awq = ((y_ref - x_test @ w_awq.T) ** 2).mean().item()

print(f"RTN  INT4 MSE: {mse_rtn:.8f}")
print(f"AWQ  INT4 MSE: {mse_awq:.8f}")
print(f"AWQ optimal alpha: {alpha:.2f}")
print(f"Improvement: {mse_rtn / mse_awq:.1f}x")
💡 AWQ vs GPTQ: When to Choose Which

AWQ is faster to run (no Hessian computation), simpler to implement, and produces models that are slightly more robust to different input distributions. GPTQ produces slightly lower quantization error on the calibration distribution because it uses second-order information. In practice, the quality difference is small at INT4 g128. AWQ is preferred when quantization speed matters. GPTQ is preferred when you have good calibration data and want the absolute best quality.

GPTQ vs AWQ vs RTN: Quality Benchmarks

📊

Perplexity on WikiText-2 (Lower is Better)

| Model | FP16 | RTN INT4 g128 | GPTQ INT4 g128 | AWQ INT4 g128 |
|---|---|---|---|---|
| Llama 2 7B | 5.47 | 5.96 | 5.61 | 5.60 |
| Llama 2 13B | 4.88 | 5.22 | 4.98 | 4.97 |
| Llama 2 70B | 3.32 | 3.58 | 3.41 | 3.40 |
| Mistral 7B | 5.25 | 5.68 | 5.38 | 5.36 |
Note: GPTQ and AWQ achieve similar quality, both significantly better than RTN. Calibration: 128 samples from C4. Group size 128 for all INT4 methods.

Perplexity Degradation from FP16 (Llama 2 70B, WikiText-2):

- RTN INT4: +0.26
- GPTQ INT4: +0.09
- AWQ INT4: +0.08

The key observation: GPTQ and AWQ reduce the perplexity degradation from RTN by 3-4x. At 70B scale, the degradation from AWQ INT4 is only 0.08 perplexity points over FP16 — negligible for most applications.

The Dequantization Kernel: How Quantized Weights Are Used at Inference

During inference, quantized weights must be dequantized before the matrix multiplication. There are two approaches:

Weight-only quantization (W4A16): Weights are stored in INT4, dequantized to FP16 on-the-fly during the GEMM. The activations remain in FP16. This is the standard approach for GPTQ and AWQ models.

Weight-and-activation quantization (W4A4 or W8A8): Both weights and activations are quantized, and the GEMM is performed in integer arithmetic. This requires activation quantization (covered in Part 3).

For W4A16, the dequantization kernel runs as part of the GEMM:

def w4a16_linear_forward(x, q_weight, scales, group_size=128):
    """
    Simulated W4A16 forward pass.

    In practice, this is fused into a single CUDA kernel that:
    1. Loads INT4 weights from global memory (half the bandwidth)
    2. Dequantizes to FP16 in registers
    3. Performs the FP16 GEMM

    x: (batch, in_features) FP16
    q_weight: (out_features, in_features) INT4 packed
    scales: (out_features, in_features // group_size) FP16
    """
    out_f, in_f = q_weight.shape
    num_groups = in_f // group_size

    # Dequantize
    q_float = q_weight.float().reshape(out_f, num_groups, group_size)
    w_deq = (q_float * scales.unsqueeze(-1).float()).reshape(out_f, in_f)

    # GEMM in FP16
    return x.float() @ w_deq.T
Memory Bandwidth is the Bottleneck

For autoregressive decode (batch size 1), the GEMM is memory-bandwidth-bound. Reading INT4 weights requires exactly half the bandwidth of INT8 and one quarter of FP16. The dequantization arithmetic is negligible compared to the memory transfer time. This is why INT4 weight quantization delivers nearly 4x decode speedup on bandwidth-limited GPUs.
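Back-of-the-envelope arithmetic makes the bandwidth argument concrete. The sketch below (my own helper; it counts only weight and scale traffic, ignoring activations and the KV cache, so it is an upper bound on the speedup) estimates bytes streamed per decode step for a 7B model:

```python
def decode_bytes_per_token(params=7e9, bits=4, group_size=128, scale_bytes=2):
    """Approximate weight + scale bytes streamed per decode step (batch 1)."""
    weight = params * bits / 8
    scales = (params / group_size) * scale_bytes if group_size else 0
    return weight + scales

fp16 = decode_bytes_per_token(bits=16, group_size=None)   # no scales needed
int4 = decode_bytes_per_token(bits=4, group_size=128)
print(f"FP16: {fp16 / 1e9:.1f} GB/token, INT4 g128: {int4 / 1e9:.2f} GB/token")
print(f"speedup upper bound: {fp16 / int4:.2f}x")
```

The g=128 scale metadata shaves the ideal 4x down to roughly 3.9x, which is why "nearly 4x" is the honest claim for weight-only INT4 decode.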

Calibration Data Selection

Both GPTQ and AWQ require calibration data to compute the Hessian (GPTQ) or activation statistics (AWQ). The choice of calibration data matters:

How much data: 128-512 samples is typical. More data gives diminishing returns. GPTQ is more sensitive to calibration data quality than AWQ.

Which data: Use data similar to your target distribution. For general-purpose models, a sample from C4 or WikiText works well. For domain-specific models, use domain data.

Sequence length: Match your inference sequence length. Using short calibration sequences for a model that will process long sequences can lead to suboptimal scale factors.

def prepare_calibration_data(tokenizer, dataset_text, n_samples=128, seq_len=2048):
    """Prepare calibration data for GPTQ/AWQ."""
    import random
    random.seed(42)

    samples = []
    for text in dataset_text:
        tokens = tokenizer.encode(text)
        if len(tokens) >= seq_len:
            start = random.randint(0, len(tokens) - seq_len)
            samples.append(torch.tensor(tokens[start:start + seq_len]).unsqueeze(0))
        if len(samples) >= n_samples:
            break

    return samples

Practical Integration: Quantizing a Full Model

Here is the complete workflow for quantizing all linear layers in a transformer model:

def quantize_model(model, calibration_loader, method='awq', bits=4, group_size=128):
    """Quantize all linear layers in a model.

    Args:
        model: the transformer model
        calibration_loader: dataloader yielding calibration inputs
        method: 'rtn', 'gptq', or 'awq'
        bits: target precision
        group_size: quantization group size
    """
    quantized_layers = {}

    # Hook to capture inputs to each linear layer
    hooks = {}
    layer_inputs = {}

    def make_hook(name):
        def hook_fn(module, inp, out):
            if name not in layer_inputs:
                layer_inputs[name] = []
            layer_inputs[name].append(inp[0].detach().cpu())
        return hook_fn

    # Register hooks on all linear layers
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            hooks[name] = module.register_forward_hook(make_hook(name))

    # Run calibration data through model
    model.eval()
    with torch.no_grad():
        for batch in calibration_loader:
            model(batch)

    # Remove hooks
    for h in hooks.values():
        h.remove()

    # Quantize each layer
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and name in layer_inputs:
            # Move captured inputs back to the layer's device before quantizing
            inputs = [t.to(module.weight.device) for t in layer_inputs[name]]

            if method == 'rtn':
                q_w, scales = quantize_rtn_symmetric(
                    module.weight.data, bits=bits, group_size=group_size
                )
                quantized_layers[name] = (q_w, scales)

            elif method == 'gptq':
                q_w, scales = gptq_quantize_layer(
                    module, inputs, bits=bits, group_size=group_size
                )
                quantized_layers[name] = (q_w, scales)

            elif method == 'awq':
                q_w, scales, ch_scales, alpha = awq_quantize_layer(
                    module, inputs, bits=bits, group_size=group_size
                )
                quantized_layers[name] = (q_w, scales, ch_scales)

            print(f"Quantized {name}: {module.weight.shape}")

    return quantized_layers
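As a minimal end-to-end sanity check of the whole-model flow, here is a self-contained toy (a two-layer MLP rather than a transformer, RTN only, and fake quantization that stores dequantized FP weights in place; `rtn_quantize_` is a hypothetical helper, not part of the pipeline above):

```python
import torch
import torch.nn as nn

def rtn_quantize_(linear, bits=4, group=128):
    """In-place fake-quantize an nn.Linear: round to the INT4 grid, store FP."""
    W = linear.weight.data
    qmax = (1 << (bits - 1)) - 1
    Wg = W.reshape(W.shape[0], -1, group)
    s = (Wg.abs().amax(dim=2, keepdim=True) / qmax).clamp(min=1e-10)
    linear.weight.data = ((Wg / s).round().clamp(-qmax - 1, qmax) * s).reshape_as(W)

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(256, 512, bias=False), nn.ReLU(),
                      nn.Linear(512, 256, bias=False))
x = torch.randn(16, 256)
y_fp = model(x)

# Quantize every linear layer, then re-run the same inputs
for m in model.modules():
    if isinstance(m, nn.Linear):
        rtn_quantize_(m)

y_q = model(x)
mse = ((y_fp - y_q) ** 2).mean().item()
print(f"end-to-end MSE after INT4 RTN fake quant: {mse:.2e}")
```

Swapping `rtn_quantize_` for a GPTQ or AWQ pass in the same loop is exactly what the `quantize_model` workflow above does at transformer scale.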

Summary

Three algorithms, one goal: compress weights from 16-bit to 4-bit with minimal quality loss.

RTN rounds each weight independently. Simple, fast, but the accumulated rounding errors degrade quality significantly at INT4.

GPTQ uses the Hessian (second-order curvature information from calibration data) to compensate each rounding decision by adjusting the remaining weights. This is the mathematically optimal approach under a quadratic approximation.

AWQ identifies channels with large activations and scales them to receive more quantization resolution. Conceptually simpler than GPTQ, nearly equivalent quality, and faster to run.

All three methods produce INT4 weights with per-group scale factors. The quantized model is served using W4A16 kernels that dequantize weights on-the-fly during the GEMM. The memory bandwidth savings translate directly to faster decode throughput.

In the next post, we tackle the harder problem: quantizing activations, where the challenge shifts from static weight distributions to dynamic, outlier-heavy activation distributions.