Quantization Masterclass, Part 7 of 30

Post-training quantization (PTQ) takes a trained FP16/BF16 model and converts weights (and optionally activations) to lower precision without any additional training. It works well for INT8 and FP8, and it works reasonably well for INT4 weights with calibration (GPTQ, AWQ). But PTQ has a fundamental limitation: the model was never trained to be robust to quantization noise. The weights settled into positions optimized for full-precision arithmetic, and quantization shifts them to nearby grid points that the model has never seen during training.

Quantization-aware training (QAT) solves this by inserting fake quantization operations into the forward pass during training. The model sees quantized values during every forward pass, and the loss function reflects quantization error. Gradients flow back through the fake quantization ops (via the straight-through estimator), and the optimizer adjusts weights to positions that are both good for the task and robust to quantization. The result: QAT models at INT4 match or exceed PTQ models at INT8 in quality, at the cost of a full training run.

This post covers the mathematics of fake quantization, the straight-through estimator, a complete QAT implementation, and a systematic comparison of QAT vs PTQ across precision levels and model sizes.

The Core Problem: Rounding Is Not Differentiable

Why PTQ Loses Quality

Consider a single weight w = 0.137 being quantized to INT8 with scale s = 0.01 and zero-point z = 128. The quantized representation is:

q = \text{clamp}(\text{round}(w/s + z),\, 0,\, 255) = \text{clamp}(\text{round}(13.7 + 128),\, 0,\, 255) = 142

The dequantized value is:

\hat{w} = (q - z) \cdot s = (142 - 128) \cdot 0.01 = 0.14

The error is \hat{w} - w = 0.003. For a single weight, this is negligible. But in a matrix multiplication Y = XW where W is 4096 \times 4096, errors accumulate. Each output element is a dot product of 4096 terms, each with independent rounding error. The total error scales as O(\sqrt{d}), where d is the inner dimension.
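The quantize-dequantize round trip can be checked numerically; a quick sketch of the worked example above:

```python
import torch

# Worked example: w = 0.137, s = 0.01, z = 128, unsigned INT8 range [0, 255]
w = torch.tensor(0.137)
s, z = 0.01, 128

q = torch.clamp(torch.round(w / s + z), 0, 255)    # quantize
w_hat = (q - z) * s                                # dequantize
print(q.item(), w_hat.item(), (w_hat - w).item())  # 142.0, ~0.14, ~0.003
```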

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def demonstrate_ptq_error_accumulation(hidden_sizes, num_trials=10):
    """Show how PTQ error grows with hidden dimension."""
    results = []
    for d in hidden_sizes:
        errors = []
        for _ in range(num_trials):
            W = torch.randn(d, d) * 0.02
            x = torch.randn(1, d)

            # Full-precision reference
            y_ref = x @ W

            # Simulate INT8 PTQ
            w_max = W.abs().max()
            scale = w_max / 127.0
            W_q = torch.clamp(torch.round(W / scale), -128, 127)
            W_deq = W_q * scale

            y_ptq = x @ W_deq
            rel_error = ((y_ref - y_ptq).norm() / y_ref.norm()).item()
            errors.append(rel_error)

        mean_err = np.mean(errors)
        results.append((d, mean_err))
        print(f"d={d:5d}: relative error = {mean_err:.6f}")

    return results

# Error grows with sqrt(d)
demonstrate_ptq_error_accumulation([256, 512, 1024, 2048, 4096, 8192])

For INT8, this accumulated error is usually tolerable (relative error under 1%). For INT4, the quantization step size is 16x larger, and the accumulated error becomes significant enough to shift model outputs.
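Repeating the same experiment across bit widths makes the INT8 vs INT4 gap concrete. A standalone sketch (not reusing the function above; `matmul_quant_error` is a hypothetical helper for illustration):

```python
import torch

def matmul_quant_error(num_bits, d=4096, seed=0):
    """Relative matmul output error under symmetric weight quantization."""
    torch.manual_seed(seed)
    W = torch.randn(d, d) * 0.02
    x = torch.randn(1, d)
    q_max = 2 ** (num_bits - 1) - 1              # 127 for INT8, 7 for INT4
    scale = W.abs().max() / q_max
    W_deq = torch.clamp(torch.round(W / scale), -q_max - 1, q_max) * scale
    y_ref, y_q = x @ W, x @ W_deq
    return ((y_ref - y_q).norm() / y_ref.norm()).item()

err8 = matmul_quant_error(8)
err4 = matmul_quant_error(4)
print(f"INT8: {err8:.4f}  INT4: {err4:.4f}  ratio: {err4 / err8:.1f}x")
```

The error ratio lands near the step-size ratio between the two formats, which is why INT4 weights need calibration or QAT while INT8 often does not.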

The Gradient Problem

The natural response is: just train the model with quantized weights. The problem is that the round() function has zero gradient almost everywhere:

\frac{d}{dw}\,\text{round}(w) = 0 \quad \text{for } w \notin \mathbb{Z}

If you insert round() into the forward pass and try to backpropagate, gradients are zero, and the optimizer cannot update the weights. The model is frozen.

def gradient_through_round():
    """Demonstrate that round() kills gradients."""
    w = torch.tensor([0.137], requires_grad=True)
    scale = torch.tensor([0.01])

    # Forward with round
    q = torch.round(w / scale)
    loss = (q * scale - 0.15) ** 2

    loss.backward()
    print(f"w.grad = {w.grad}")  # tensor([0.]) -- round() has zero gradient, nothing flows

This is why QAT requires the straight-through estimator.

The Straight-Through Estimator (STE)

Definition

The straight-through estimator (Bengio et al., 2013) is a gradient approximation that replaces the true gradient of a non-differentiable function with the identity:

Forward pass: \hat{w} = \text{FakeQuant}(w) = s \cdot \left(\text{clamp}(\text{round}(w/s + z),\, q_{\min},\, q_{\max}) - z\right)

Backward pass: \frac{\partial \hat{w}}{\partial w} \approx 1 (within the clamp range)

More precisely, the STE gradient is:

\frac{\partial \hat{w}}{\partial w} = \begin{cases} 1 & \text{if } q_{\min} \leq w/s + z \leq q_{\max} \\ 0 & \text{otherwise} \end{cases}

The gradient is 1 within the representable range and 0 outside it (where clamping occurs). This is a biased estimator (it ignores the rounding error), but it works well in practice because:

  1. The rounding error is small relative to the gradient magnitude
  2. Over many training steps, the optimizer moves weights toward quantization grid points where rounding error is minimal
  3. The clamp gradient of 0 for out-of-range values provides a useful signal: it tells the optimizer to stop pushing weights beyond the representable range

class StraightThroughRound(torch.autograd.Function):
    """Round with straight-through estimator for gradient."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pass gradient through unchanged
        return grad_output

def ste_round(x):
    return StraightThroughRound.apply(x)
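An equivalent formulation uses the standard detach trick instead of a custom autograd.Function. A self-contained check that the gradient now reaches the weight (contrast with gradient_through_round above):

```python
import torch

def ste_round_detach(x):
    # Forward: round(x). Backward: identity -- the rounding residual
    # is detached, so autograd sees only the identity path.
    return x + (torch.round(x) - x).detach()

w = torch.tensor([0.137], requires_grad=True)
scale = 0.01

q = ste_round_detach(w / scale)
loss = (q * scale - 0.15) ** 2
loss.backward()
print(w.grad)  # tensor([-0.0200]): 2 * (0.14 - 0.15), the gradient flows
```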

Why STE Works: Intuition from Loss Landscapes

Consider a weight ww on a 1D loss landscape. The quantization grid imposes discrete points where the weight can land after deployment. During QAT with STE:

  1. The forward pass snaps ww to the nearest grid point, computing the loss at that grid point
  2. The backward pass computes the gradient as if ww were not snapped (STE approximation)
  3. The optimizer moves ww using this gradient
  4. If ww is between two grid points, the forward pass consistently chooses the closer one
  5. Over time, ww converges to a grid point that locally minimizes the task loss

The key insight: QAT does not just tolerate quantization error; it actively optimizes for quantized performance. Weights migrate to positions where rounding to the nearest grid point causes minimal task loss.

def visualize_ste_convergence(num_steps=200, lr=0.01):
    """Show weight convergence to quantization grid points during QAT."""
    # Toy INT4-like grid: uniform steps of 2/15 over roughly [-1, 1]
    scale = 2.0 / 15.0  # ~0.133

    w = torch.tensor([0.5], requires_grad=True)
    target = torch.tensor([0.47])  # Target is between grid points
    optimizer = torch.optim.SGD([w], lr=lr)

    trajectory = []
    for step in range(num_steps):
        optimizer.zero_grad()

        # Fake quantize with STE
        w_q = ste_round(w / scale) * scale
        w_q_clamped = torch.clamp(w_q, -1.0, 1.0)

        loss = (w_q_clamped - target) ** 2
        loss.backward()
        optimizer.step()

        trajectory.append((w.item(), w_q_clamped.item(), loss.item()))

    # In this toy setup w drifts toward the boundary between the two grid
    # points nearest the target (0.400 = 3*scale, 0.533 = 4*scale), and the
    # forward pass snaps to whichever side it currently sits on
    print(f"Final w={trajectory[-1][0]:.4f}, "
          f"quantized={trajectory[-1][1]:.4f}")
    return trajectory

Fake Quantization: The Complete Module

Symmetric Fake Quantization

class FakeQuantize(nn.Module):
    """Fake quantization module for QAT.

    Simulates quantization during training:
    - Forward: quantize then dequantize (introduce quantization noise)
    - Backward: straight-through estimator

    Supports symmetric and asymmetric quantization,
    per-tensor and per-channel granularity.
    """

    def __init__(self, num_bits=8, symmetric=True, per_channel=False,
                 num_channels=1, learnable=False):
        super().__init__()
        self.num_bits = num_bits
        self.symmetric = symmetric
        self.per_channel = per_channel

        if symmetric:
            self.q_min = -(2 ** (num_bits - 1))
            self.q_max = 2 ** (num_bits - 1) - 1
        else:
            self.q_min = 0
            self.q_max = 2 ** num_bits - 1

        # Scale and zero-point
        shape = (num_channels, 1) if per_channel else (1,)
        if learnable:
            self.scale = nn.Parameter(torch.ones(shape))
            self.zero_point = nn.Parameter(torch.zeros(shape))
        else:
            self.register_buffer('scale', torch.ones(shape))
            self.register_buffer('zero_point', torch.zeros(shape))

        self.learnable = learnable
        self.calibrated = False

    def compute_scale_zp(self, x):
        """Compute scale and zero-point from observed tensor."""
        if self.per_channel:
            # Per output channel for weights
            x_flat = x.reshape(x.shape[0], -1)
            x_min = x_flat.min(dim=1, keepdim=True).values
            x_max = x_flat.max(dim=1, keepdim=True).values
        else:
            x_min = x.min()
            x_max = x.max()

        if self.symmetric:
            abs_max = torch.max(x_min.abs(), x_max.abs())
            scale = abs_max / ((self.q_max - self.q_min) / 2)
            scale = torch.clamp(scale, min=1e-8)
            zero_point = torch.zeros_like(scale)
        else:
            scale = (x_max - x_min) / (self.q_max - self.q_min)
            scale = torch.clamp(scale, min=1e-8)
            zero_point = self.q_min - torch.round(x_min / scale)

        return scale, zero_point

    def forward(self, x):
        if not self.calibrated and not self.learnable:
            # First forward: calibrate scale and zero-point
            scale, zp = self.compute_scale_zp(x.detach())
            self.scale.copy_(scale)
            self.zero_point.copy_(zp)
            self.calibrated = True

        # Fake quantize: quantize then immediately dequantize
        if self.symmetric:
            x_scaled = x / self.scale
            x_rounded = ste_round(x_scaled)
            x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
            x_fake = x_clamped * self.scale
        else:
            x_scaled = x / self.scale + self.zero_point
            x_rounded = ste_round(x_scaled)
            x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
            x_fake = (x_clamped - self.zero_point) * self.scale

        return x_fake

Per-Channel Weight Fake Quantization

For weight quantization, per-channel granularity (one scale per output channel) is standard. This dramatically reduces quantization error because each output channel has its own dynamic range.

class PerChannelFakeQuantize(nn.Module):
    """Per-channel fake quantization for weight tensors.

    Each output channel (row of the weight matrix) gets its own
    scale factor, matching deployment-time per-channel quantization.
    """

    def __init__(self, num_bits=4, num_channels=4096):
        super().__init__()
        self.num_bits = num_bits
        self.q_min = -(2 ** (num_bits - 1))
        self.q_max = 2 ** (num_bits - 1) - 1

        self.register_buffer('scale', torch.ones(num_channels, 1))
        self.register_buffer('observed', torch.tensor(False))

    def forward(self, weight):
        # weight shape: [out_channels, in_channels]
        if not self.observed:
            # Calibrate from first observation
            channel_max = weight.detach().abs().amax(dim=1, keepdim=True)
            self.scale.copy_(channel_max / self.q_max)
            self.scale.clamp_(min=1e-8)
            self.observed.fill_(True)

        w_scaled = weight / self.scale
        w_rounded = ste_round(w_scaled)
        w_clamped = torch.clamp(w_rounded, self.q_min, self.q_max)
        return w_clamped * self.scale
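To quantify how much per-channel scaling helps, here is a standalone comparison (independent of the module above) on a weight matrix whose rows have very different magnitudes, as is common in trained networks:

```python
import torch

def quant_mse(W, scale):
    """Mean squared error after a symmetric INT4 round-trip."""
    W_q = torch.clamp(torch.round(W / scale), -8, 7)
    return ((W_q * scale - W) ** 2).mean().item()

torch.manual_seed(0)
# Rows with widely different dynamic ranges (0.01x to 1x)
row_scales = torch.logspace(-2, 0, steps=256).unsqueeze(1)
W = torch.randn(256, 1024) * row_scales

per_tensor_scale = W.abs().max() / 7                       # one scale for everything
per_channel_scale = W.abs().amax(dim=1, keepdim=True) / 7  # one scale per row

mse_tensor = quant_mse(W, per_tensor_scale)
mse_channel = quant_mse(W, per_channel_scale)
print(f"per-tensor MSE: {mse_tensor:.2e}  per-channel MSE: {mse_channel:.2e}")
```

With a single per-tensor scale, the large rows dictate the step size and the small rows are crushed toward zero; per-channel scales let every row use the full INT4 grid.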

Dynamic Activation Fake Quantization

Activations are quantized dynamically: the scale is computed from each input tensor at runtime. During QAT, we simulate this by computing the scale on every forward pass.

class DynamicActivationFakeQuantize(nn.Module):
    """Dynamic per-tensor fake quantization for activations.

    Scale is computed from each input tensor (not stored).
    This matches deployment-time dynamic quantization.
    """

    def __init__(self, num_bits=8):
        super().__init__()
        self.num_bits = num_bits
        self.q_min = -(2 ** (num_bits - 1))
        self.q_max = 2 ** (num_bits - 1) - 1

    def forward(self, x):
        # Compute scale from current tensor
        abs_max = x.detach().abs().max()
        scale = abs_max / self.q_max
        scale = max(scale.item(), 1e-8)

        # Fake quantize
        x_scaled = x / scale
        x_rounded = ste_round(x_scaled)
        x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
        return x_clamped * scale

QAT-Enabled Linear Layer

Wrapping a Linear Layer with Fake Quantization

class QATLinear(nn.Module):
    """Linear layer with fake quantization for QAT.

    Inserts fake quantization on weights (per-channel, static scale)
    and activations (per-tensor, dynamic scale).
    """

    def __init__(self, in_features, out_features, weight_bits=4,
                 activation_bits=8, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.weight_fq = PerChannelFakeQuantize(
            num_bits=weight_bits, num_channels=out_features
        )
        self.activation_fq = DynamicActivationFakeQuantize(
            num_bits=activation_bits
        )
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits

    def forward(self, x):
        # Fake quantize activations
        x_q = self.activation_fq(x)

        # Fake quantize weights
        w_q = self.weight_fq(self.linear.weight)

        # Linear operation with fake-quantized tensors
        out = F.linear(x_q, w_q, self.linear.bias)
        return out

    @classmethod
    def from_float(cls, float_linear, weight_bits=4, activation_bits=8):
        """Convert a trained FP16 linear layer to QAT linear."""
        qat = cls(
            float_linear.in_features,
            float_linear.out_features,
            weight_bits=weight_bits,
            activation_bits=activation_bits,
            bias=float_linear.bias is not None
        )
        qat.linear.weight.data.copy_(float_linear.weight.data)
        if float_linear.bias is not None:
            qat.linear.bias.data.copy_(float_linear.bias.data)
        return qat
โš ๏ธ Bias Is Not Quantized

The bias term in a linear layer is typically kept in FP32 during both QAT and deployment. The bias has very few parameters relative to the weight matrix (e.g., 4096 vs 4096x4096 = 16M), so quantizing it saves negligible memory while introducing significant error, since the bias is added to every output element.

Converting a Full Model to QAT

def convert_model_to_qat(model, weight_bits=4, activation_bits=8,
                          skip_layers=None):
    """Replace all nn.Linear layers with QATLinear.

    Args:
        model: Pre-trained model
        weight_bits: Bit width for weight quantization
        activation_bits: Bit width for activation quantization
        skip_layers: List of layer name patterns to skip
            (e.g., ['lm_head', 'embed'] for LLMs)
    """
    skip_layers = skip_layers or []

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Check if this layer should be skipped
            if any(skip in name for skip in skip_layers):
                continue

            # Replace with QAT version
            qat_linear = QATLinear.from_float(
                module,
                weight_bits=weight_bits,
                activation_bits=activation_bits
            )

            # Navigate to parent module and replace
            parts = name.split('.')
            parent = model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1], qat_linear)

    return model

def count_qat_layers(model):
    """Count QAT-converted vs skipped layers."""
    qat_count = 0
    float_count = 0
    for name, module in model.named_modules():
        if isinstance(module, QATLinear):
            qat_count += 1
        elif isinstance(module, nn.Linear):
            float_count += 1
    print(f"QAT layers: {qat_count}, Float layers: {float_count}")
    return qat_count, float_count

The QAT Training Loop

Learning Rate and Training Duration

QAT is typically done as fine-tuning, not training from scratch. The standard approach:

  1. Start from a pre-trained FP16 model
  2. Insert fake quantization ops
  3. Fine-tune for 1-5% of the original training tokens
  4. Use a lower learning rate (10-100x lower than pre-training)

def qat_training_loop(model, train_dataloader, num_epochs=2,
                       lr=1e-5, warmup_steps=100):
    """QAT fine-tuning loop.

    Key differences from normal training:
    1. Lower learning rate (model is already trained)
    2. Fewer epochs (1-5% of pre-training data)
    3. Gradual quantization: optionally start with higher bits
       and reduce over training
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # Linear warmup then cosine decay
    total_steps = len(train_dataloader) * num_epochs

    def lr_schedule(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)

    model.train()
    step = 0
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in train_dataloader:
            input_ids = batch['input_ids'].cuda()
            labels = batch['labels'].cuda()

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping -- important for QAT stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()

            total_loss += loss.item()
            step += 1

            if step % 100 == 0:
                avg_loss = total_loss / 100
                current_lr = scheduler.get_last_lr()[0]
                print(f"Step {step}: loss={avg_loss:.4f}, lr={current_lr:.2e}")
                total_loss = 0

    return model

💡 Gradient Clipping Is Critical

QAT gradients can spike during early training when the model first encounters quantization noise. The STE gradient approximation is biased, and the mismatch between forward (quantized) and backward (STE) can produce large gradient norms. Gradient clipping at 1.0 is standard practice.

Gradual Quantization: Progressive Bit Reduction

An advanced technique: start QAT at higher precision (INT8) and gradually reduce to the target (INT4) over training. This gives the model a smooth transition from FP16 to the target precision.

class GradualQuantScheduler:
    """Gradually reduce bit width during QAT.

    Example schedule for target INT4:
    - Steps 0-500: No quantization (warmup)
    - Steps 500-1000: INT8 fake quantization
    - Steps 1000-2000: INT6 fake quantization (interpolated)
    - Steps 2000+: INT4 fake quantization (target)
    """

    def __init__(self, model, schedule):
        """
        Args:
            model: QAT model
            schedule: List of (step, num_bits) tuples
                e.g., [(0, 16), (500, 8), (1000, 6), (2000, 4)]
        """
        self.model = model
        self.schedule = sorted(schedule, key=lambda x: x[0])

    def get_bits_for_step(self, step):
        """Linearly interpolate bit width based on schedule."""
        if step <= self.schedule[0][0]:
            return self.schedule[0][1]
        if step >= self.schedule[-1][0]:
            return self.schedule[-1][1]

        for i in range(len(self.schedule) - 1):
            s0, b0 = self.schedule[i]
            s1, b1 = self.schedule[i + 1]
            if s0 <= step < s1:
                progress = (step - s0) / (s1 - s0)
                return b0 + progress * (b1 - b0)

        return self.schedule[-1][1]

    def update(self, step):
        """Update fake quantization bit widths for current step."""
        target_bits = self.get_bits_for_step(step)
        # Round to nearest integer -- we cannot actually do fractional bits
        effective_bits = max(2, round(target_bits))

        for module in self.model.modules():
            if isinstance(module, (PerChannelFakeQuantize,
                                   DynamicActivationFakeQuantize)):
                module.num_bits = effective_bits
                module.q_min = -(2 ** (effective_bits - 1))
                module.q_max = 2 ** (effective_bits - 1) - 1
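The interpolation logic is easy to check in isolation. A minimal standalone version of the same schedule math (mirroring get_bits_for_step, without the class; the class's update() additionally rounds the result to an integer bit width):

```python
def bits_for_step(schedule, step):
    """Linear interpolation over (step, bits) breakpoints, clamped at the ends."""
    schedule = sorted(schedule)
    if step <= schedule[0][0]:
        return schedule[0][1]
    if step >= schedule[-1][0]:
        return schedule[-1][1]
    for (s0, b0), (s1, b1) in zip(schedule, schedule[1:]):
        if s0 <= step < s1:
            return b0 + (step - s0) / (s1 - s0) * (b1 - b0)

schedule = [(0, 16), (500, 8), (1000, 6), (2000, 4)]
print(bits_for_step(schedule, 0))     # 16
print(bits_for_step(schedule, 750))   # 7.0 (halfway between 8 and 6)
print(bits_for_step(schedule, 3000))  # 4
```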

QAT vs PTQ: Systematic Comparison

Quality at Different Bit Widths

The quality gap between QAT and PTQ depends on the bit width. At INT8, PTQ works well and QAT provides minimal benefit. At INT4, PTQ suffers significant degradation and QAT provides substantial improvement. At INT3 and INT2, only QAT produces usable models.

📊 QAT vs PTQ Perplexity: Llama-2 7B on WikiText-2

| Method | Bits (W/A) | Perplexity | Delta vs FP16 | Training Cost |
|---|---|---|---|---|
| FP16 Baseline | 16/16 | 5.47 | --- | --- |
| PTQ (RTN) | 8/16 | 5.49 | +0.02 | None |
| QAT | 8/16 | 5.48 | +0.01 | ~2 GPU-hours |
| PTQ (RTN) | 4/16 | 6.83 | +1.36 | None |
| PTQ (GPTQ) | 4/16 | 5.85 | +0.38 | ~30 min calibration |
| PTQ (AWQ) | 4/16 | 5.78 | +0.31 | ~30 min calibration |
| QAT | 4/16 | 5.56 | +0.09 | ~8 GPU-hours |
| PTQ (GPTQ) | 3/16 | 8.12 | +2.65 | ~30 min calibration |
| QAT | 3/16 | 6.31 | +0.84 | ~16 GPU-hours |
| PTQ (RTN) | 2/16 | 185.0 | +179.5 | None |
| QAT | 2/16 | 11.4 | +5.93 | ~32 GPU-hours |

Note: QAT advantage is negligible at INT8 but dramatic at INT4 and below. At INT2, PTQ produces unusable models while QAT retains meaningful language modeling ability.

QAT vs PTQ perplexity gap by bit width (Llama-2 7B, perplexity delta vs FP16):

- PTQ INT8: +0.02 (negligible)
- QAT INT8: +0.01
- PTQ-GPTQ INT4: +0.38
- QAT INT4: +0.09 (4.2x better)
- PTQ-GPTQ INT3: +2.65
- QAT INT3: +0.84 (3.2x better)

Model Size Scaling

Larger models are more robust to quantization in general. The QAT advantage is most pronounced for smaller models at low bit widths.

📊 QAT vs PTQ at INT4: Effect of Model Size

| Model | FP16 PPL | PTQ (GPTQ) PPL | QAT PPL | QAT Advantage |
|---|---|---|---|---|
| Llama-2 7B | 5.47 | 5.85 | 5.56 | 0.29 PPL |
| Llama-2 13B | 4.88 | 5.10 | 4.95 | 0.15 PPL |
| Llama-2 70B | 3.31 | 3.42 | 3.35 | 0.07 PPL |
| Llama-3 8B | 6.14 | 6.58 | 6.25 | 0.33 PPL |
| Llama-3 70B | 2.86 | 2.97 | 2.90 | 0.07 PPL |

Note: QAT advantage shrinks with model size. For 70B+ models, GPTQ/AWQ are usually sufficient at INT4.

When Is QAT Worth the Cost?

Decision framework:

  1. INT8 quantization: Use PTQ. QAT provides negligible benefit at 8-bit. The 2-8 GPU-hours of QAT training are not justified.

  2. INT4 quantization, models greater than or equal to 70B: Use PTQ (GPTQ or AWQ). The quality gap is small (under 0.1 PPL) and the training cost is substantial (hundreds of GPU-hours for a 70B model).

  3. INT4 quantization, models under 13B: QAT is strongly recommended. The quality gap is 0.15-0.35 PPL, which can meaningfully impact downstream task performance. Training cost is manageable (8-32 GPU-hours).

  4. INT3 or INT2 quantization: QAT is required. PTQ produces near-unusable models at these precision levels.

  5. Task-critical deployments: If the quantized model will serve millions of users and quality matters (e.g., medical, legal), use QAT regardless of model size.
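The framework above can be captured in a small helper. This is an illustrative sketch only (the function name and thresholds paraphrase the five rules; the gap between 13B and 70B is a judgment call, treated here as favoring QAT):

```python
def recommend_quantization(weight_bits, model_params_b, task_critical=False):
    """Illustrative mapping of the decision framework above (not exhaustive)."""
    if weight_bits <= 3:
        return "QAT"                      # PTQ is near-unusable at INT3/INT2
    if weight_bits >= 8:
        return "PTQ"                      # QAT gains are negligible at 8-bit
    if task_critical:
        return "QAT"                      # quality-critical INT4 deployments
    # INT4: depends on model size
    return "PTQ (GPTQ/AWQ)" if model_params_b >= 70 else "QAT"

print(recommend_quantization(4, 7))    # QAT
print(recommend_quantization(4, 70))   # PTQ (GPTQ/AWQ)
print(recommend_quantization(8, 7))    # PTQ
print(recommend_quantization(2, 70))   # QAT
```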

โ„น๏ธ QAT Training Cost

QAT fine-tuning typically requires 1-5% of the original pre-training compute. For Llama-2 7B (pre-trained with ~2T tokens on ~1000 GPU-hours), QAT uses ~10-20B tokens on 8-16 GPU-hours. For 70B models, this scales to 200-500 GPU-hours โ€” still much less than pre-training but non-trivial.

Advanced QAT Techniques

Learned Step Size Quantization (LSQ)

Instead of computing the quantization scale from observed min/max, make the scale a learnable parameter that is optimized during training.

class LearnedStepSizeQuantize(nn.Module):
    """LSQ: Learned Step Size Quantization (Esser et al., 2020).

    The scale (step size) is a learnable parameter optimized
    jointly with the model weights. The gradient of the loss
    with respect to the scale is computed analytically.
    """

    def __init__(self, num_bits=4, per_channel=False, num_channels=1):
        super().__init__()
        self.num_bits = num_bits
        self.q_min = -(2 ** (num_bits - 1))
        self.q_max = 2 ** (num_bits - 1) - 1

        shape = (num_channels, 1) if per_channel else (1,)
        # Initialize scale -- will be set from first observation
        self.scale = nn.Parameter(torch.ones(shape))
        self.initialized = False

    def init_scale(self, x):
        """Initialize scale from first observed tensor."""
        if self.scale.shape[0] > 1:
            # Per-channel
            abs_max = x.detach().abs().amax(dim=1, keepdim=True)
        else:
            abs_max = x.detach().abs().max()
        init_scale = abs_max / self.q_max
        self.scale.data.copy_(init_scale.clamp(min=1e-8))
        self.initialized = True

    def forward(self, x):
        if not self.initialized:
            self.init_scale(x)

        # Ensure scale is positive
        scale = self.scale.abs().clamp(min=1e-8)

        # Quantize
        x_scaled = x / scale
        x_rounded = ste_round(x_scaled)
        x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)

        return x_clamped * scale

    def extra_repr(self):
        return f'num_bits={self.num_bits}, q_range=[{self.q_min}, {self.q_max}]'
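A quick standalone check (independent of the class above, using the detach-trick form of the STE round) that a learnable step size actually improves under gradient descent:

```python
import torch

def int4_mse(x, s):
    """Reconstruction MSE of a symmetric INT4 round-trip at step size s."""
    return ((torch.clamp(torch.round(x / s), -8, 7) * s - x) ** 2).mean().item()

torch.manual_seed(0)
x = torch.randn(4096)
scale = torch.tensor(1.5, requires_grad=True)   # deliberately poor init
initial_mse = int4_mse(x, 1.5)

opt = torch.optim.Adam([scale], lr=2e-2)
for _ in range(500):
    opt.zero_grad()
    s = scale.abs().clamp(min=1e-8)
    x_s = x / s
    x_r = x_s + (torch.round(x_s) - x_s).detach()   # STE round
    x_hat = torch.clamp(x_r, -8, 7) * s
    loss = ((x_hat - x) ** 2).mean()                # INT4 reconstruction error
    loss.backward()
    opt.step()

final_mse = int4_mse(x, abs(scale.item()))
print(f"scale: 1.5 -> {abs(scale.item()):.3f}, MSE: {initial_mse:.4f} -> {final_mse:.4f}")
```

The optimizer drives the step size down from its oversized initialization toward a range where rounding error and clipping error balance.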

The key advantage of LSQ: the gradient for the scale parameter is:

\frac{\partial L}{\partial s} = \frac{\partial L}{\partial \hat{w}} \cdot \frac{\partial \hat{w}}{\partial s}

where \frac{\partial \hat{w}}{\partial s} is computed analytically from the fake quantization formula. This allows the optimizer to find the scale that minimizes task loss, not just the scale that minimizes quantization error.
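Concretely, for the symmetric case \hat{w} = s \cdot \text{clamp}(\text{round}(w/s),\, q_{\min},\, q_{\max}), the LSQ paper derives (with the STE applied to the round):

\frac{\partial \hat{w}}{\partial s} = \begin{cases} \text{round}(w/s) - w/s & \text{if } q_{\min} \leq w/s \leq q_{\max} \\ q_{\min} & \text{if } w/s < q_{\min} \\ q_{\max} & \text{if } w/s > q_{\max} \end{cases}

In-range weights nudge the scale through their rounding residual, while clamped weights pull it toward expanding (or shrinking) the representable range.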

PACT: Parameterized Clipping Activation

PACT (Choi et al., 2018) learns the clipping threshold for activation quantization. Instead of using the full dynamic range of activations, it learns an upper bound \alpha that clips outlier activations before quantization.

class PACTActivation(nn.Module):
    """PACT: Parameterized Clipping Activation.

    Learns a clipping threshold alpha that is applied before
    quantization. This trades off clipping error (values above
    alpha are clipped) against quantization error (fewer bits
    to represent the remaining range).
    """

    def __init__(self, num_bits=8, initial_alpha=6.0):
        super().__init__()
        self.num_bits = num_bits
        self.q_levels = 2 ** num_bits
        self.alpha = nn.Parameter(torch.tensor(initial_alpha))

    def forward(self, x):
        # Clip to [0, alpha] for ReLU activations
        # or [-alpha, alpha] for symmetric
        alpha = self.alpha.abs()
        x_clipped = torch.clamp(x, -alpha, alpha)

        # Quantize the clipped range
        scale = (2 * alpha) / (self.q_levels - 1)
        x_scaled = (x_clipped + alpha) / scale
        x_rounded = ste_round(x_scaled)
        x_clamped = torch.clamp(x_rounded, 0, self.q_levels - 1)

        return x_clamped * scale - alpha

Knowledge Distillation with QAT

Combine QAT with knowledge distillation: use the full-precision model as a teacher and the QAT model as a student. This provides a stronger training signal than just the task loss.

def qat_with_distillation(teacher_model, student_model, dataloader,
                           temperature=4.0, alpha_kd=0.5, lr=1e-5,
                           num_steps=5000):
    """QAT with knowledge distillation from full-precision teacher.

    Loss = alpha * KD_loss + (1 - alpha) * task_loss

    The KD loss encourages the QAT model's output distribution
    to match the teacher's, which provides richer gradient
    information than the hard labels alone.
    """
    teacher_model.eval()
    student_model.train()
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)

    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break

        input_ids = batch['input_ids'].cuda()
        labels = batch['labels'].cuda()

        # Student forward (with fake quantization)
        student_out = student_model(input_ids, labels=labels)
        task_loss = student_out.loss
        student_logits = student_out.logits

        # Teacher forward (no gradients needed)
        with torch.no_grad():
            teacher_out = teacher_model(input_ids)
            teacher_logits = teacher_out.logits

        # KL divergence loss with temperature scaling
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction='batchmean'
        ) * (temperature ** 2)

        # Combined loss
        loss = alpha_kd * kd_loss + (1 - alpha_kd) * task_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}: task_loss={task_loss.item():.4f}, "
                  f"kd_loss={kd_loss.item():.4f}")
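The temperature**2 factor compensates for the 1/T**2 shrinkage of soft-target gradients, keeping the KD term's magnitude comparable across temperatures. A toy check of the same loss formula (the kd_loss helper name is illustrative), confirming it vanishes when student and teacher agree:

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, T):
    # KL(teacher || student) on temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes stable across T.
    return F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction='batchmean',
    ) * (T ** 2)

same = torch.tensor([[2.0, 1.0, 0.0]])
shifted = torch.tensor([[0.0, 1.0, 2.0]])
print(kd_loss(same, same, T=4.0).item())     # ~0: identical distributions
print(kd_loss(same, shifted, T=4.0).item())  # > 0: mismatch is penalized
```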
📊 QAT + Distillation vs QAT Alone (Llama-2 7B, INT4)

Method | PPL | Delta vs FP16 | Training Cost
FP16 Baseline | 5.47 | --- | ---
QAT only | 5.56 | +0.09 | 8 GPU-hrs
QAT + Distillation | 5.51 | +0.04 | 16 GPU-hrs (2x for teacher)
PTQ (GPTQ) | 5.85 | +0.38 | 0.5 GPU-hrs

Note: Distillation halves the QAT quality gap but doubles the training cost (the teacher model must be loaded alongside the student). Worth it for production deployments.

Converting QAT Models for Deployment

Folding Fake Quantization into Real Quantization

After QAT training, the fake quantization ops must be converted to real quantization for deployment. The weights are quantized to integers, and the scale factors are stored separately.

def convert_qat_to_quantized(qat_model):
    """Convert QAT model to deployment-ready quantized model.

    Replaces QATLinear modules with quantized linear modules
    that store integer weights and scale factors.
    """
    for name, module in qat_model.named_modules():
        if isinstance(module, QATLinear):
            # Extract the learned quantization parameters
            weight = module.linear.weight.data
            scale = module.weight_fq.scale

            # Actually quantize (no fake quantization)
            w_int = torch.clamp(
                torch.round(weight / scale),
                module.weight_fq.q_min,
                module.weight_fq.q_max
            ).to(torch.int8)

            # Store quantized weights and metadata
            module.register_buffer('weight_quantized', w_int)
            module.register_buffer('weight_scale', scale.squeeze())

            # Remove the original float weight to save memory
            del module.linear.weight

            print(f"Converted {name}: "
                  f"scale range [{scale.min().item():.6f}, "
                  f"{scale.max().item():.6f}]")

    return qat_model

def verify_conversion(qat_model, converted_model, test_input):
    """Verify that the QAT and converted models produce matching outputs.

    Fake quantization and real quantization apply the same rounding to
    the same weights, so outputs should agree up to the float error of
    dequantization. Since conversion mutates the model in place, convert
    a deep copy so the original QAT model remains available to compare.
    """
    qat_model.eval()
    converted_model.eval()
    with torch.no_grad():
        y_qat = qat_model(test_input)
        y_conv = converted_model(test_input)

    max_diff = (y_qat - y_conv).abs().max().item()
    print(f"Max output difference after conversion: {max_diff:.2e}")
    return max_diff

Quantization-Aware Training for Specific Frameworks

Different deployment frameworks expect different formats:

def export_for_vllm(qat_model, output_path):
    """Export QAT model in format compatible with vLLM quantized inference.

    vLLM expects:
    - INT4 weights packed into INT32 (8 values per int32)
    - Per-channel scale factors in FP16
    - Group size metadata (if using group quantization)
    """
    state_dict = {}
    for name, module in qat_model.named_modules():
        if isinstance(module, QATLinear):
            weight = module.linear.weight.data
            scale = module.weight_fq.scale

            # Quantize to INT4
            w_int4 = torch.clamp(
                torch.round(weight / scale),
                -8, 7
            ).to(torch.int8)

            # Pack INT4 into INT32 (8 values per int32)
            w_packed = pack_int4_to_int32(w_int4)

            state_dict[f"{name}.qweight"] = w_packed
            state_dict[f"{name}.scales"] = scale.half()
            state_dict[f"{name}.zeros"] = torch.zeros_like(scale).half()

            if module.linear.bias is not None:
                state_dict[f"{name}.bias"] = module.linear.bias.data.half()

    torch.save(state_dict, output_path)
    print(f"Saved quantized model to {output_path}")

def pack_int4_to_int32(tensor):
    """Pack 8 INT4 values into each INT32.

    tensor: [..., N] where N is divisible by 8
    returns: [..., N//8] of dtype int32
    """
    assert tensor.shape[-1] % 8 == 0
    # Shift INT4 from [-8,7] to [0,15] for unsigned packing
    unsigned = (tensor + 8).to(torch.int32)

    # Pack 8 values per int32
    packed_shape = list(tensor.shape)
    packed_shape[-1] //= 8

    packed = torch.zeros(packed_shape, dtype=torch.int32,
                         device=tensor.device)
    for i in range(8):
        packed |= (unsigned[..., i::8] & 0xF) << (i * 4)

    return packed
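To check the packing, the inverse operation should round-trip: extract each nibble of every int32 and shift it back into the signed range. A sketch of a matching unpacker (the function name is illustrative):

```python
import torch

def unpack_int32_to_int4(packed):
    # Inverse of pack_int4_to_int32: extract nibble i of each int32,
    # then shift [0, 15] back down to the signed range [-8, 7].
    shape = list(packed.shape)
    shape[-1] *= 8
    out = torch.empty(shape, dtype=torch.int8)
    for i in range(8):
        out[..., i::8] = (((packed >> (i * 4)) & 0xF) - 8).to(torch.int8)
    return out
```

Packing random INT4 tensors and unpacking them should return the originals unchanged; masking with 0xF after the right shift makes the extraction safe even when the int32 is negative (arithmetic shift fills with sign bits).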
⚡ QAT INT4 vs PTQ INT4: Inference Speed Is Identical

QAT and PTQ produce models with the same data format (INT4 weights, FP16 activations), so inference speed is identical; the difference is entirely in model quality. QAT simply finds better INT4 weight values through training, making it a pure quality upgrade: the cost is training compute, not inference latency.

Practical Considerations

Which Layers to Skip

Not all layers should be quantized during QAT:

SKIP_PATTERNS_LLM = [
    'embed',       # Embedding layers: quantization loses token discrimination
    'lm_head',     # Output projection: directly affects next-token probabilities
    'norm',        # LayerNorm: few parameters, high sensitivity
    'rotary',      # RoPE embeddings: positional encoding precision matters
]

# For vision transformers
SKIP_PATTERNS_VIT = [
    'patch_embed',  # Patch embedding: first layer, high sensitivity
    'head',         # Classification head
    'norm',
]
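A minimal sketch of how such patterns might be applied when deciding which modules to wrap (the helper name is hypothetical; a real harness would test the names from model.named_modules()):

```python
SKIP_PATTERNS_LLM = ['embed', 'lm_head', 'norm', 'rotary']

def should_quantize(module_name, skip_patterns=SKIP_PATTERNS_LLM):
    # Quantize a layer only if no skip pattern appears in its name.
    return not any(pat in module_name for pat in skip_patterns)

print(should_quantize('model.layers.0.self_attn.q_proj'))  # True
print(should_quantize('model.embed_tokens'))               # False
print(should_quantize('model.layers.0.input_layernorm'))   # False
```

Substring matching catches variants like 'layernorm' and 'rmsnorm' via the 'norm' pattern, which is why the patterns are kept deliberately short.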

Monitoring QAT Training

class QATMonitor:
    """Monitor quantization-specific metrics during QAT."""

    def __init__(self, model, log_interval=100):
        self.model = model
        self.log_interval = log_interval
        self.step = 0

    def log(self):
        self.step += 1
        if self.step % self.log_interval != 0:
            return

        for name, module in self.model.named_modules():
            if isinstance(module, QATLinear):
                weight = module.linear.weight.data
                scale = module.weight_fq.scale

                # How many weights are at the clamp boundary?
                w_scaled = weight / scale
                at_min = (w_scaled <= module.weight_fq.q_min + 0.5).float().mean()
                at_max = (w_scaled >= module.weight_fq.q_max - 0.5).float().mean()

                # Weight range utilization
                w_q = torch.clamp(torch.round(w_scaled),
                                   module.weight_fq.q_min,
                                   module.weight_fq.q_max)
                unique_values = w_q.unique().numel()
                total_levels = module.weight_fq.q_max - module.weight_fq.q_min + 1

                if "layers.0" in name or "layers.15" in name:
                    print(f"[Step {self.step}] {name}: "
                          f"clamp_rate={at_min.item()+at_max.item():.4f}, "
                          f"levels_used={unique_values}/{total_levels}, "
                          f"scale={scale.mean().item():.6f}")

Common Failure Modes

  1. Scale explosion: If the learning rate is too high, QAT can cause weight magnitudes to grow, increasing the quantization scale and effectively reducing precision. Monitor the scale values.

  2. Clamp saturation: If too many weights are at the clamp boundary (greater than 5%), the quantization range is too narrow. Either increase bits or use per-channel quantization.

  3. Training divergence: QAT with STE can diverge if the initial quantization error is too large. Use gradual quantization (start at INT8, reduce to INT4) to avoid this.

  4. Activation outlier explosion: During QAT, activation outliers can grow larger as the model compensates for weight quantization. Monitor activation ranges and apply SmoothQuant before QAT if needed.
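The gradual quantization mentioned in item 3 can be expressed as a simple bit-width schedule; a sketch under the assumption that the fake quantizers expose a mutable num_bits (the function name is illustrative):

```python
def bitwidth_schedule(step, total_steps, start_bits=8, end_bits=4):
    # Ramp the bit width down linearly over the first half of training,
    # then hold at the target width for the remainder.
    ramp_steps = max(total_steps // 2, 1)
    if step >= ramp_steps:
        return end_bits
    frac = step / ramp_steps
    return round(start_bits + frac * (end_bits - start_bits))

print([bitwidth_schedule(s, 1000) for s in (0, 125, 250, 500, 999)])
# [8, 7, 6, 4, 4]
```

Each time the schedule lowers the bit width, the quantization error jumps but stays small enough for STE gradients to recover, which is the whole point of the gradual approach.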

Summary

QAT inserts fake quantization into the forward pass so the model learns weights that are robust to quantization noise. The straight-through estimator enables gradient flow through the non-differentiable rounding operation. At INT8, PTQ is sufficient and QAT is unnecessary. At INT4, QAT provides 0.1-0.3 PPL improvement over the best PTQ methods for small models (under 13B). At INT3 and below, QAT is the only viable approach.

The decision is straightforward: if INT4/INT8 PTQ meets your quality requirements, use PTQ. If it does not, invest in QAT. The training cost is 1-5% of pre-training compute, and the resulting model deploys at the same speed as a PTQ model: the improvement is pure quality at zero inference cost.