Post-training quantization (PTQ) takes a trained FP16/BF16 model and converts weights (and optionally activations) to lower precision without any additional training. It works well for INT8 and FP8, and it works reasonably well for INT4 weights with calibration (GPTQ, AWQ). But PTQ has a fundamental limitation: the model was never trained to be robust to quantization noise. The weights settled into positions optimized for full-precision arithmetic, and quantization shifts them to nearby grid points that the model has never seen during training.
Quantization-aware training (QAT) solves this by inserting fake quantization operations into the forward pass during training. The model sees quantized values during every forward pass, and the loss function reflects quantization error. Gradients flow back through the fake quantization ops (via the straight-through estimator), and the optimizer adjusts weights to positions that are both good for the task and robust to quantization. The result: QAT models at INT4 match or exceed PTQ models at INT8 in quality, at the cost of a full training run.
This post covers the mathematics of fake quantization, the straight-through estimator, a complete QAT implementation, and a systematic comparison of QAT vs PTQ across precision levels and model sizes.
The Core Problem: Rounding Is Not Differentiable
Why PTQ Loses Quality
Consider a single weight $w$ being quantized to INT8 with scale $s$ and zero-point $z$. The quantized representation is:

$$q = \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{w}{s}\right) + z,\ q_{\min},\ q_{\max}\right)$$

The dequantized value is:

$$\hat{w} = s \cdot (q - z)$$

The error is $|\hat{w} - w| \le s/2$ for values within the representable range. For a single weight, this is negligible. But in a matrix multiplication where $W$ is $4096 \times 4096$, errors accumulate. Each output element is a dot product of 4096 terms, each with independent rounding error. The total error scales as $\sqrt{d}$, where $d$ is the inner dimension.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def demonstrate_ptq_error_accumulation(hidden_sizes, num_trials=10):
    """Show how PTQ error grows with hidden dimension."""
    results = []
    for d in hidden_sizes:
        errors = []
        for _ in range(num_trials):
            W = torch.randn(d, d) * 0.02
            x = torch.randn(1, d)
            # FP16 reference
            y_ref = x @ W
            # Simulate INT8 PTQ
            w_max = W.abs().max()
            scale = w_max / 127.0
            W_q = torch.clamp(torch.round(W / scale), -128, 127)
            W_deq = W_q * scale
            y_ptq = x @ W_deq
            rel_error = ((y_ref - y_ptq).norm() / y_ref.norm()).item()
            errors.append(rel_error)
        mean_err = np.mean(errors)
        results.append((d, mean_err))
        print(f"d={d:5d}: relative error = {mean_err:.6f}")
    return results

# Error grows with sqrt(d)
demonstrate_ptq_error_accumulation([256, 512, 1024, 2048, 4096, 8192])
For INT8, this accumulated error is usually tolerable (relative error under 1%). For INT4, the quantization step size is 16x larger, and the accumulated error becomes significant enough to shift model outputs.
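To make that step-size gap concrete, a small sketch (the `compare_bit_widths` helper is illustrative, extending the experiment above to INT4 with the same per-tensor round-to-nearest scheme):

```python
import torch

def compare_bit_widths(d=4096, seed=0):
    """Relative matmul error for per-tensor INT8 vs INT4 weight rounding."""
    torch.manual_seed(seed)
    W = torch.randn(d, d) * 0.02
    x = torch.randn(1, d)
    y_ref = x @ W
    errors = {}
    for bits in (8, 4):
        q_max = 2 ** (bits - 1) - 1
        scale = W.abs().max() / q_max           # step size: ~16x larger at 4 bits
        W_deq = torch.clamp(torch.round(W / scale), -q_max - 1, q_max) * scale
        errors[bits] = ((y_ref - x @ W_deq).norm() / y_ref.norm()).item()
    return errors
```

Running this at `d=4096` typically shows the INT4 relative error an order of magnitude above INT8, matching the step-size ratio.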
The Gradient Problem
The natural response is: just train the model with quantized weights. The problem is that the round() function has zero gradient almost everywhere:

$$\frac{d}{dx}\,\mathrm{round}(x) = 0 \quad \text{for all non-half-integer } x$$
If you insert round() into the forward pass and try to backpropagate, gradients are zero, and the optimizer cannot update the weights. The model is frozen.
def gradient_through_round():
    """Demonstrate that round() kills gradients."""
    w = torch.tensor([0.137], requires_grad=True)
    scale = torch.tensor([0.01])
    # Forward with round
    q = torch.round(w / scale)
    loss = (q * scale - 0.15) ** 2
    loss.backward()
    print(f"w.grad = {w.grad}")  # tensor([0.]) -- gradient cannot flow
This is why QAT requires the straight-through estimator.
The Straight-Through Estimator (STE)
Definition
The straight-through estimator (Bengio et al., 2013) is a gradient approximation that replaces the true gradient of a non-differentiable function with the identity:
Forward pass:

$$\hat{w} = s \cdot \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{w}{s}\right),\ q_{\min},\ q_{\max}\right)$$

Backward pass (within the clamp range):

$$\frac{\partial \hat{w}}{\partial w} \approx 1$$

More precisely, the STE gradient is:

$$\frac{\partial \hat{w}}{\partial w} = \begin{cases} 1 & \text{if } q_{\min} \le \mathrm{round}(w/s) \le q_{\max} \\ 0 & \text{otherwise} \end{cases}$$

The gradient is 1 within the representable range and 0 outside it (where clamping occurs). This is a biased estimator (it ignores the rounding error), but it works well in practice because:
- The rounding error is small relative to the gradient magnitude
- Over many training steps, the optimizer moves weights toward quantization grid points where rounding error is minimal
- The clamp gradient of 0 for out-of-range values provides a useful signal: it tells the optimizer to stop pushing weights beyond the representable range
class StraightThroughRound(torch.autograd.Function):
    """Round with straight-through estimator for gradient."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pass gradient through unchanged
        return grad_output

def ste_round(x):
    return StraightThroughRound.apply(x)
Why STE Works: Intuition from Loss Landscapes
Consider a weight $w$ on a 1D loss landscape. The quantization grid imposes discrete points where the weight can land after deployment. During QAT with STE:

- The forward pass snaps $w$ to the nearest grid point, computing the loss at that grid point
- The backward pass computes the gradient as if $w$ were not snapped (STE approximation)
- The optimizer moves $w$ using this gradient
- If $w$ is between two grid points, the forward pass consistently chooses the closer one
- Over time, $w$ converges to a grid point that locally minimizes the task loss
The key insight: QAT does not just tolerate quantization error; it actively optimizes for quantized performance. Weights migrate to positions where rounding to the nearest grid point causes minimal task loss.
def visualize_ste_convergence(num_steps=200, lr=0.01):
    """Show weight convergence to quantization grid points during QAT."""
    # INT4 grid: 16 levels between -1 and 1
    scale = 2.0 / 15.0  # ~0.133
    w = torch.tensor([0.5], requires_grad=True)
    target = torch.tensor([0.47])  # Target is between grid points
    optimizer = torch.optim.SGD([w], lr=lr)
    trajectory = []
    for step in range(num_steps):
        optimizer.zero_grad()
        # Fake quantize with STE
        w_q = ste_round(w / scale) * scale
        w_q_clamped = torch.clamp(w_q, -1.0, 1.0)
        loss = (w_q_clamped - target) ** 2
        loss.backward()
        optimizer.step()
        trajectory.append((w.item(), w_q_clamped.item(), loss.item()))
    # Weight converges to grid point closest to target
    # Grid points near 0.47: 0.400 (3*scale), 0.533 (4*scale)
    print(f"Final w={trajectory[-1][0]:.4f}, "
          f"quantized={trajectory[-1][1]:.4f}")
    return trajectory
Fake Quantization: The Complete Module
Symmetric Fake Quantization
class FakeQuantize(nn.Module):
    """Fake quantization module for QAT.

    Simulates quantization during training:
    - Forward: quantize then dequantize (introduce quantization noise)
    - Backward: straight-through estimator

    Supports symmetric and asymmetric quantization,
    per-tensor and per-channel granularity.
    """

    def __init__(self, num_bits=8, symmetric=True, per_channel=False,
                 num_channels=1, learnable=False):
        super().__init__()
        self.num_bits = num_bits
        self.symmetric = symmetric
        self.per_channel = per_channel
        if symmetric:
            self.q_min = -(2 ** (num_bits - 1))
            self.q_max = 2 ** (num_bits - 1) - 1
        else:
            self.q_min = 0
            self.q_max = 2 ** num_bits - 1
        # Scale and zero-point
        shape = (num_channels, 1) if per_channel else (1,)
        if learnable:
            self.scale = nn.Parameter(torch.ones(shape))
            self.zero_point = nn.Parameter(torch.zeros(shape))
        else:
            self.register_buffer('scale', torch.ones(shape))
            self.register_buffer('zero_point', torch.zeros(shape))
        self.learnable = learnable
        self.calibrated = False

    def compute_scale_zp(self, x):
        """Compute scale and zero-point from observed tensor."""
        if self.per_channel:
            # Per output channel for weights
            x_flat = x.reshape(x.shape[0], -1)
            x_min = x_flat.min(dim=1, keepdim=True).values
            x_max = x_flat.max(dim=1, keepdim=True).values
        else:
            x_min = x.min()
            x_max = x.max()
        if self.symmetric:
            abs_max = torch.max(x_min.abs(), x_max.abs())
            scale = abs_max / ((self.q_max - self.q_min) / 2)
            scale = torch.clamp(scale, min=1e-8)
            zero_point = torch.zeros_like(scale)
        else:
            scale = (x_max - x_min) / (self.q_max - self.q_min)
            scale = torch.clamp(scale, min=1e-8)
            zero_point = self.q_min - torch.round(x_min / scale)
        return scale, zero_point

    def forward(self, x):
        if not self.calibrated and not self.learnable:
            # First forward: calibrate scale and zero-point
            scale, zp = self.compute_scale_zp(x.detach())
            self.scale.copy_(scale)
            self.zero_point.copy_(zp)
            self.calibrated = True
        # Fake quantize: quantize then immediately dequantize
        if self.symmetric:
            x_scaled = x / self.scale
            x_rounded = ste_round(x_scaled)
            x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
            x_fake = x_clamped * self.scale
        else:
            x_scaled = x / self.scale + self.zero_point
            x_rounded = ste_round(x_scaled)
            x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
            x_fake = (x_clamped - self.zero_point) * self.scale
        return x_fake
Per-Channel Weight Fake Quantization
For weight quantization, per-channel granularity (one scale per output channel) is standard. This dramatically reduces quantization error because each output channel has its own dynamic range.
class PerChannelFakeQuantize(nn.Module):
    """Per-channel fake quantization for weight tensors.

    Each output channel (row of the weight matrix) gets its own
    scale factor, matching deployment-time per-channel quantization.
    """

    def __init__(self, num_bits=4, num_channels=4096):
        super().__init__()
        self.num_bits = num_bits
        self.q_min = -(2 ** (num_bits - 1))
        self.q_max = 2 ** (num_bits - 1) - 1
        self.register_buffer('scale', torch.ones(num_channels, 1))
        self.register_buffer('observed', torch.tensor(False))

    def forward(self, weight):
        # weight shape: [out_channels, in_channels]
        if not self.observed:
            # Calibrate from first observation
            channel_max = weight.detach().abs().amax(dim=1, keepdim=True)
            self.scale.copy_(channel_max / self.q_max)
            self.scale.clamp_(min=1e-8)
            self.observed.fill_(True)
        w_scaled = weight / self.scale
        w_rounded = ste_round(w_scaled)
        w_clamped = torch.clamp(w_rounded, self.q_min, self.q_max)
        return w_clamped * self.scale
Dynamic Activation Fake Quantization
Activations are quantized dynamically: the scale is computed from each input tensor at runtime. During QAT, we simulate this by computing the scale on every forward pass.
class DynamicActivationFakeQuantize(nn.Module):
    """Dynamic per-tensor fake quantization for activations.

    Scale is computed from each input tensor (not stored).
    This matches deployment-time dynamic quantization.
    """

    def __init__(self, num_bits=8):
        super().__init__()
        self.num_bits = num_bits
        self.q_min = -(2 ** (num_bits - 1))
        self.q_max = 2 ** (num_bits - 1) - 1

    def forward(self, x):
        # Compute scale from current tensor
        abs_max = x.detach().abs().max()
        scale = abs_max / self.q_max
        scale = max(scale.item(), 1e-8)
        # Fake quantize
        x_scaled = x / scale
        x_rounded = ste_round(x_scaled)
        x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
        return x_clamped * scale
QAT-Enabled Linear Layer
Wrapping a Linear Layer with Fake Quantization
class QATLinear(nn.Module):
    """Linear layer with fake quantization for QAT.

    Inserts fake quantization on weights (per-channel, static scale)
    and activations (per-tensor, dynamic scale).
    """

    def __init__(self, in_features, out_features, weight_bits=4,
                 activation_bits=8, bias=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        self.weight_fq = PerChannelFakeQuantize(
            num_bits=weight_bits, num_channels=out_features
        )
        self.activation_fq = DynamicActivationFakeQuantize(
            num_bits=activation_bits
        )
        self.weight_bits = weight_bits
        self.activation_bits = activation_bits

    def forward(self, x):
        # Fake quantize activations
        x_q = self.activation_fq(x)
        # Fake quantize weights
        w_q = self.weight_fq(self.linear.weight)
        # Linear operation with fake-quantized tensors
        out = F.linear(x_q, w_q, self.linear.bias)
        return out

    @classmethod
    def from_float(cls, float_linear, weight_bits=4, activation_bits=8):
        """Convert a trained FP16 linear layer to QAT linear."""
        qat = cls(
            float_linear.in_features,
            float_linear.out_features,
            weight_bits=weight_bits,
            activation_bits=activation_bits,
            bias=float_linear.bias is not None
        )
        qat.linear.weight.data.copy_(float_linear.weight.data)
        if float_linear.bias is not None:
            qat.linear.bias.data.copy_(float_linear.bias.data)
        return qat
The bias term in a linear layer is typically kept in FP32 during both QAT and deployment. The bias has very few parameters relative to the weight matrix (e.g., 4096 vs 4096x4096 = 16M), so quantizing it saves negligible memory while introducing significant error, since the bias is added to every output element.
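To make that trade-off concrete, a quick back-of-the-envelope calculation (the 4096 dimension matches the example above; the byte accounting is a sketch assuming INT4 weights and an FP32 bias):

```python
hidden = 4096
weight_params = hidden * hidden       # 16,777,216 weights in one linear layer
bias_params = hidden                  # 4,096 bias values

weight_bytes = weight_params * 4 // 8  # 4 bits per weight after quantization
bias_bytes = bias_params * 4           # bias kept in FP32
print(f"bias is {bias_bytes / (weight_bytes + bias_bytes):.4%} of layer memory")
```

Even with the weights compressed to INT4, the FP32 bias is a fraction of a percent of the layer's memory.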
Converting a Full Model to QAT
def convert_model_to_qat(model, weight_bits=4, activation_bits=8,
                         skip_layers=None):
    """Replace all nn.Linear layers with QATLinear.

    Args:
        model: Pre-trained model
        weight_bits: Bit width for weight quantization
        activation_bits: Bit width for activation quantization
        skip_layers: List of layer name patterns to skip
            (e.g., ['lm_head', 'embed'] for LLMs)
    """
    skip_layers = skip_layers or []
    # Materialize the module list first: replacing modules while
    # iterating the named_modules() generator mutates the tree mid-iteration
    for name, module in list(model.named_modules()):
        if isinstance(module, nn.Linear):
            # Check if this layer should be skipped
            if any(skip in name for skip in skip_layers):
                continue
            # Replace with QAT version
            qat_linear = QATLinear.from_float(
                module,
                weight_bits=weight_bits,
                activation_bits=activation_bits
            )
            # Navigate to parent module and replace
            parts = name.split('.')
            parent = model
            for part in parts[:-1]:
                parent = getattr(parent, part)
            setattr(parent, parts[-1], qat_linear)
    return model

def count_qat_layers(model):
    """Count QAT-converted vs skipped layers."""
    qat_count = 0
    float_count = 0
    qat_names = set()
    for name, module in model.named_modules():
        if isinstance(module, QATLinear):
            qat_count += 1
            qat_names.add(name)
        elif isinstance(module, nn.Linear):
            # Skip the nn.Linear wrapped inside a QATLinear
            parent = name.rsplit('.', 1)[0] if '.' in name else ''
            if parent not in qat_names:
                float_count += 1
    print(f"QAT layers: {qat_count}, Float layers: {float_count}")
    return qat_count, float_count
The QAT Training Loop
Learning Rate and Training Duration
QAT is typically done as fine-tuning, not training from scratch. The standard approach:
- Start from a pre-trained FP16 model
- Insert fake quantization ops
- Fine-tune for 1-5% of the original training tokens
- Use a lower learning rate (10-100x lower than pre-training)
def qat_training_loop(model, train_dataloader, num_epochs=2,
                      lr=1e-5, warmup_steps=100):
    """QAT fine-tuning loop.

    Key differences from normal training:
    1. Lower learning rate (model is already trained)
    2. Fewer epochs (1-5% of pre-training data)
    3. Gradual quantization: optionally start with higher bits
       and reduce over training
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)

    # Linear warmup then cosine decay
    total_steps = len(train_dataloader) * num_epochs

    def lr_schedule(step):
        if step < warmup_steps:
            return step / warmup_steps
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        return 0.5 * (1 + np.cos(np.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    model.train()
    step = 0
    for epoch in range(num_epochs):
        total_loss = 0
        for batch in train_dataloader:
            input_ids = batch['input_ids'].cuda()
            labels = batch['labels'].cuda()
            outputs = model(input_ids, labels=labels)
            loss = outputs.loss
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping -- important for QAT stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()
            step += 1
            if step % 100 == 0:
                avg_loss = total_loss / 100
                current_lr = scheduler.get_last_lr()[0]
                print(f"Step {step}: loss={avg_loss:.4f}, lr={current_lr:.2e}")
                total_loss = 0
    return model
QAT gradients can spike during early training when the model first encounters quantization noise. The STE gradient approximation is biased, and the mismatch between forward (quantized) and backward (STE) can produce large gradient norms. Gradient clipping at 1.0 is standard practice.
Gradual Quantization: Progressive Bit Reduction
An advanced technique: start QAT at higher precision (INT8) and gradually reduce to the target (INT4) over training. This gives the model a smooth transition from FP16 to the target precision.
class GradualQuantScheduler:
    """Gradually reduce bit width during QAT.

    Example schedule for target INT4:
    - Steps 0-500: No quantization (warmup)
    - Steps 500-1000: INT8 fake quantization
    - Steps 1000-2000: INT6 fake quantization (interpolated)
    - Steps 2000+: INT4 fake quantization (target)
    """

    def __init__(self, model, schedule):
        """
        Args:
            model: QAT model
            schedule: List of (step, num_bits) tuples
                e.g., [(0, 16), (500, 8), (1000, 6), (2000, 4)]
        """
        self.model = model
        self.schedule = sorted(schedule, key=lambda x: x[0])

    def get_bits_for_step(self, step):
        """Linearly interpolate bit width based on schedule."""
        if step <= self.schedule[0][0]:
            return self.schedule[0][1]
        if step >= self.schedule[-1][0]:
            return self.schedule[-1][1]
        for i in range(len(self.schedule) - 1):
            s0, b0 = self.schedule[i]
            s1, b1 = self.schedule[i + 1]
            if s0 <= step < s1:
                progress = (step - s0) / (s1 - s0)
                return b0 + progress * (b1 - b0)
        return self.schedule[-1][1]

    def update(self, step):
        """Update fake quantization bit widths for current step."""
        target_bits = self.get_bits_for_step(step)
        # Round to nearest integer -- we cannot actually do fractional bits
        effective_bits = max(2, round(target_bits))
        for module in self.model.modules():
            if isinstance(module, (PerChannelFakeQuantize,
                                   DynamicActivationFakeQuantize)):
                module.num_bits = effective_bits
                module.q_min = -(2 ** (effective_bits - 1))
                module.q_max = 2 ** (effective_bits - 1) - 1
QAT vs PTQ: Systematic Comparison
Quality at Different Bit Widths
The quality gap between QAT and PTQ depends on the bit width. At INT8, PTQ works well and QAT provides minimal benefit. At INT4, PTQ suffers significant degradation and QAT provides substantial improvement. At INT3 and INT2, only QAT produces usable models.
QAT vs PTQ Perplexity: Llama-2 7B on WikiText-2
| Method | Bits (W/A) | Perplexity | Delta vs FP16 | Training Cost |
|---|---|---|---|---|
| FP16 Baseline | 16/16 | 5.47 | --- | --- |
| PTQ (RTN) | 8/16 | 5.49 | +0.02 | None |
| QAT | 8/16 | 5.48 | +0.01 | ~2 GPU-hours |
| PTQ (RTN) | 4/16 | 6.83 | +1.36 | None |
| PTQ (GPTQ) | 4/16 | 5.85 | +0.38 | ~30 min calibration |
| PTQ (AWQ) | 4/16 | 5.78 | +0.31 | ~30 min calibration |
| QAT | 4/16 | 5.56 | +0.09 | ~8 GPU-hours |
| PTQ (GPTQ) | 3/16 | 8.12 | +2.65 | ~30 min calibration |
| QAT | 3/16 | 6.31 | +0.84 | ~16 GPU-hours |
| PTQ (RTN) | 2/16 | 185.0 | +179.5 | None |
| QAT | 2/16 | 11.4 | +5.93 | ~32 GPU-hours |
[Chart: QAT vs PTQ perplexity gap by bit width (Llama-2 7B), shown as perplexity delta vs FP16]

Model Size Scaling
Larger models are more robust to quantization in general. The QAT advantage is most pronounced for smaller models at low bit widths.
QAT vs PTQ at INT4: Effect of Model Size
| Model | FP16 PPL | PTQ (GPTQ) PPL | QAT PPL | QAT Advantage |
|---|---|---|---|---|
| Llama-2 7B | 5.47 | 5.85 | 5.56 | 0.29 PPL |
| Llama-2 13B | 4.88 | 5.10 | 4.95 | 0.15 PPL |
| Llama-2 70B | 3.31 | 3.42 | 3.35 | 0.07 PPL |
| Llama-3 8B | 6.14 | 6.58 | 6.25 | 0.33 PPL |
| Llama-3 70B | 2.86 | 2.97 | 2.90 | 0.07 PPL |
When Is QAT Worth the Cost?
Decision framework:

- INT8 quantization: Use PTQ. QAT provides negligible benefit at 8-bit. The 2-8 GPU-hours of QAT training are not justified.
- INT4 quantization, models of 70B or larger: Use PTQ (GPTQ or AWQ). The quality gap is small (under 0.1 PPL) and the training cost is substantial (hundreds of GPU-hours for a 70B model).
- INT4 quantization, models under 13B: QAT is strongly recommended. The quality gap is 0.15-0.35 PPL, which can meaningfully impact downstream task performance. Training cost is manageable (8-32 GPU-hours).
- INT3 or INT2 quantization: QAT is required. PTQ produces near-unusable models at these precision levels.
- Task-critical deployments: If the quantized model will serve millions of users and quality matters (e.g., medical, legal), use QAT regardless of model size.
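The framework above can be condensed into a small helper. This is purely illustrative; the name `recommend_quant_method` and its thresholds just restate the bullets:

```python
def recommend_quant_method(weight_bits, model_params_b, task_critical=False):
    """Restate the QAT-vs-PTQ decision framework as code (illustrative only)."""
    if weight_bits <= 3:
        return "QAT"                     # PTQ is near-unusable at INT3/INT2
    if task_critical:
        return "QAT"                     # quality-critical deployments
    if weight_bits >= 8:
        return "PTQ"                     # QAT gives negligible benefit at 8-bit
    # INT4: PTQ is fine for very large models; QAT pays off below ~13B
    return "PTQ (GPTQ/AWQ)" if model_params_b >= 70 else "QAT"
```

For example, `recommend_quant_method(4, 7)` returns `"QAT"`, while `recommend_quant_method(8, 7)` returns `"PTQ"`.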
QAT fine-tuning typically requires 1-5% of the original pre-training compute. For Llama-2 7B (pre-trained on ~2T tokens), QAT uses ~10-20B tokens on 8-16 GPU-hours. For 70B models, this scales to 200-500 GPU-hours, which is still far less than pre-training but non-trivial.
Advanced QAT Techniques
Learned Step Size Quantization (LSQ)
Instead of computing the quantization scale from observed min/max, make the scale a learnable parameter that is optimized during training.
class LearnedStepSizeQuantize(nn.Module):
    """LSQ: Learned Step Size Quantization (Esser et al., 2020).

    The scale (step size) is a learnable parameter optimized
    jointly with the model weights. The gradient of the loss
    with respect to the scale is computed analytically.
    """

    def __init__(self, num_bits=4, per_channel=False, num_channels=1):
        super().__init__()
        self.num_bits = num_bits
        self.q_min = -(2 ** (num_bits - 1))
        self.q_max = 2 ** (num_bits - 1) - 1
        shape = (num_channels, 1) if per_channel else (1,)
        # Initialize scale -- will be set from first observation
        self.scale = nn.Parameter(torch.ones(shape))
        self.initialized = False

    def init_scale(self, x):
        """Initialize scale from first observed tensor."""
        if self.scale.shape[0] > 1:
            # Per-channel
            abs_max = x.detach().abs().amax(dim=1, keepdim=True)
        else:
            abs_max = x.detach().abs().max()
        init_scale = abs_max / self.q_max
        self.scale.data.copy_(init_scale.clamp(min=1e-8))
        self.initialized = True

    def forward(self, x):
        if not self.initialized:
            self.init_scale(x)
        # Ensure scale is positive
        scale = self.scale.abs().clamp(min=1e-8)
        # Quantize
        x_scaled = x / scale
        x_rounded = ste_round(x_scaled)
        x_clamped = torch.clamp(x_rounded, self.q_min, self.q_max)
        return x_clamped * scale

    def extra_repr(self):
        return f'num_bits={self.num_bits}, q_range=[{self.q_min}, {self.q_max}]'
The key advantage of LSQ: with $\hat{w} = s \cdot \mathrm{clamp}(\mathrm{round}(w/s),\ q_{\min},\ q_{\max})$ and the STE applied to the round, the gradient for the scale parameter is:

$$\frac{\partial \hat{w}}{\partial s} = \begin{cases} -\dfrac{w}{s} + \mathrm{round}\!\left(\dfrac{w}{s}\right) & \text{if } q_{\min} \le \dfrac{w}{s} \le q_{\max} \\[1ex] q_{\min} & \text{if } \dfrac{w}{s} < q_{\min} \\[1ex] q_{\max} & \text{if } \dfrac{w}{s} > q_{\max} \end{cases}$$

where $\partial \hat{w} / \partial s$ is computed analytically from the fake quantization formula. This allows the optimizer to find the scale that minimizes task loss, not just the scale that minimizes quantization error.
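The in-range case of that gradient can be sanity-checked against autograd through the STE. A self-contained sketch (`_STERound` restates the STE round from earlier so the snippet runs standalone):

```python
import torch

class _STERound(torch.autograd.Function):
    """Round with straight-through gradient (restated for self-containment)."""
    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def lsq_scale_grad_check(w_val=0.37, s_val=0.1):
    """Compare autograd's scale gradient with -w/s + round(w/s) (in-range case)."""
    w = torch.tensor(w_val)
    s = torch.tensor(s_val, requires_grad=True)
    w_hat = _STERound.apply(w / s) * s   # fake-quantized weight, no clamping
    w_hat.backward()
    analytic = -w_val / s_val + round(w_val / s_val)
    return s.grad.item(), analytic
```

For `w=0.37`, `s=0.1`, both the autograd and analytic values come out to roughly `0.3` (`round(3.7) - 3.7`).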
PACT: Parameterized Clipping Activation
PACT (Choi et al., 2018) learns the clipping threshold for activation quantization. Instead of using the full dynamic range of activations, it learns an upper bound that clips outlier activations before quantization.
class PACTActivation(nn.Module):
    """PACT: Parameterized Clipping Activation.

    Learns a clipping threshold alpha that is applied before
    quantization. This trades off clipping error (values above
    alpha are clipped) against quantization error (fewer bits
    to represent the remaining range).
    """

    def __init__(self, num_bits=8, initial_alpha=6.0):
        super().__init__()
        self.num_bits = num_bits
        self.q_levels = 2 ** num_bits
        self.alpha = nn.Parameter(torch.tensor(initial_alpha))

    def forward(self, x):
        # Clip to [0, alpha] for ReLU activations
        # or [-alpha, alpha] for symmetric
        alpha = self.alpha.abs()
        x_clipped = torch.clamp(x, -alpha, alpha)
        # Quantize the clipped range
        scale = (2 * alpha) / (self.q_levels - 1)
        x_scaled = (x_clipped + alpha) / scale
        x_rounded = ste_round(x_scaled)
        x_clamped = torch.clamp(x_rounded, 0, self.q_levels - 1)
        return x_clamped * scale - alpha
Knowledge Distillation with QAT
Combine QAT with knowledge distillation: use the full-precision model as a teacher and the QAT model as a student. This provides a stronger training signal than just the task loss.
def qat_with_distillation(teacher_model, student_model, dataloader,
                          temperature=4.0, alpha_kd=0.5, lr=1e-5,
                          num_steps=5000):
    """QAT with knowledge distillation from full-precision teacher.

    Loss = alpha * KD_loss + (1 - alpha) * task_loss

    The KD loss encourages the QAT model's output distribution
    to match the teacher's, which provides richer gradient
    information than the hard labels alone.
    """
    teacher_model.eval()
    student_model.train()
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=lr)
    for step, batch in enumerate(dataloader):
        if step >= num_steps:
            break
        input_ids = batch['input_ids'].cuda()
        labels = batch['labels'].cuda()
        # Student forward (with fake quantization)
        student_out = student_model(input_ids, labels=labels)
        task_loss = student_out.loss
        student_logits = student_out.logits
        # Teacher forward (no gradients needed)
        with torch.no_grad():
            teacher_out = teacher_model(input_ids)
            teacher_logits = teacher_out.logits
        # KL divergence loss with temperature scaling
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction='batchmean'
        ) * (temperature ** 2)
        # Combined loss
        loss = alpha_kd * kd_loss + (1 - alpha_kd) * task_loss
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(student_model.parameters(), 1.0)
        optimizer.step()
        if step % 100 == 0:
            print(f"Step {step}: task_loss={task_loss.item():.4f}, "
                  f"kd_loss={kd_loss.item():.4f}")
QAT + Distillation vs QAT Alone (Llama-2 7B, INT4)
| Method | PPL | Delta vs FP16 | Training Cost |
|---|---|---|---|
| FP16 Baseline | 5.47 | --- | --- |
| QAT only | 5.56 | +0.09 | 8 GPU-hrs |
| QAT + Distillation | 5.51 | +0.04 | 16 GPU-hrs (2x for teacher) |
| PTQ (GPTQ) | 5.85 | +0.38 | 0.5 GPU-hrs |
Converting QAT Models for Deployment
Folding Fake Quantization into Real Quantization
After QAT training, the fake quantization ops must be converted to real quantization for deployment. The weights are quantized to integers, and the scale factors are stored separately.
def convert_qat_to_quantized(qat_model):
    """Convert QAT model to deployment-ready quantized model.

    Replaces QATLinear modules with quantized linear modules
    that store integer weights and scale factors.
    """
    for name, module in qat_model.named_modules():
        if isinstance(module, QATLinear):
            # Extract the learned quantization parameters
            weight = module.linear.weight.data
            scale = module.weight_fq.scale
            # Actually quantize (no fake quantization)
            w_int = torch.clamp(
                torch.round(weight / scale),
                module.weight_fq.q_min,
                module.weight_fq.q_max
            ).to(torch.int8)
            # Store quantized weights and metadata
            module.register_buffer('weight_quantized', w_int)
            module.register_buffer('weight_scale', scale.squeeze())
            # Remove the original float weight to save memory.
            # Note: after this, the module's forward must be swapped for
            # a real quantized kernel -- the QAT forward still references
            # self.linear.weight.
            del module.linear.weight
            print(f"Converted {name}: "
                  f"scale range [{scale.min().item():.6f}, "
                  f"{scale.max().item():.6f}]")
    return qat_model

def verify_conversion(qat_model, test_input):
    """Verify that QAT and converted model produce identical outputs."""
    # QAT forward (with fake quantization)
    qat_model.eval()
    with torch.no_grad():
        y_qat = qat_model(test_input)
    # The converted model should produce numerically identical results
    # because fake quantization and real quantization apply the same
    # rounding to the same weights
    print("Conversion verification: outputs should be identical")
    return y_qat
Quantization-Aware Training for Specific Frameworks
Different deployment frameworks expect different formats:
def export_for_vllm(qat_model, output_path):
    """Export QAT model in format compatible with vLLM quantized inference.

    vLLM expects:
    - INT4 weights packed into INT32 (8 values per int32)
    - Per-channel scale factors in FP16
    - Group size metadata (if using group quantization)
    """
    state_dict = {}
    for name, module in qat_model.named_modules():
        if isinstance(module, QATLinear):
            weight = module.linear.weight.data
            scale = module.weight_fq.scale
            # Quantize to INT4
            w_int4 = torch.clamp(
                torch.round(weight / scale),
                -8, 7
            ).to(torch.int8)
            # Pack INT4 into INT32 (8 values per int32)
            w_packed = pack_int4_to_int32(w_int4)
            state_dict[f"{name}.qweight"] = w_packed
            state_dict[f"{name}.scales"] = scale.half()
            state_dict[f"{name}.zeros"] = torch.zeros_like(scale).half()
            if module.linear.bias is not None:
                state_dict[f"{name}.bias"] = module.linear.bias.data.half()
    torch.save(state_dict, output_path)
    print(f"Saved quantized model to {output_path}")

def pack_int4_to_int32(tensor):
    """Pack 8 INT4 values into each INT32.

    tensor: [..., N] where N is divisible by 8
    returns: [..., N//8] of dtype int32
    """
    assert tensor.shape[-1] % 8 == 0
    # Shift INT4 from [-8,7] to [0,15] for unsigned packing
    unsigned = (tensor + 8).to(torch.int32)
    # Pack 8 values per int32
    packed_shape = list(tensor.shape)
    packed_shape[-1] //= 8
    packed = torch.zeros(packed_shape, dtype=torch.int32)
    for i in range(8):
        packed |= (unsigned[..., i::8] & 0xF) << (i * 4)
    return packed
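A round-trip check is a cheap way to validate the packing. This sketch restates `pack_int4_to_int32` so it runs standalone and adds a hypothetical `unpack_int32_to_int4` inverse (not part of the export format, just for verification):

```python
import torch

def pack_int4_to_int32(tensor):
    """Pack 8 signed INT4 values per int32 (restated from above)."""
    assert tensor.shape[-1] % 8 == 0
    unsigned = (tensor + 8).to(torch.int32)   # [-8,7] -> [0,15]
    packed_shape = list(tensor.shape)
    packed_shape[-1] //= 8
    packed = torch.zeros(packed_shape, dtype=torch.int32)
    for i in range(8):
        packed |= (unsigned[..., i::8] & 0xF) << (i * 4)
    return packed

def unpack_int32_to_int4(packed):
    """Inverse of the packing: recover signed INT4 values from int32 words."""
    out_shape = list(packed.shape)
    out_shape[-1] *= 8
    out = torch.empty(out_shape, dtype=torch.int8)
    for i in range(8):
        # Arithmetic shift then mask extracts nibble i; -8 restores the sign
        out[..., i::8] = (((packed >> (i * 4)) & 0xF) - 8).to(torch.int8)
    return out
```

Packing then unpacking a random INT4 tensor should reproduce it exactly, which is worth asserting once before trusting an exported checkpoint.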
QAT and PTQ produce models with the same data format (INT4 weights, FP16 activations), so inference speed is identical; the difference is entirely in model quality. QAT simply finds better INT4 weight values through training. The quality improvement therefore costs nothing at inference time; you pay only in training compute.
Practical Considerations
Which Layers to Skip
Not all layers should be quantized during QAT:
SKIP_PATTERNS_LLM = [
    'embed',    # Embedding layers: quantization loses token discrimination
    'lm_head',  # Output projection: directly affects next-token probabilities
    'norm',     # LayerNorm: few parameters, high sensitivity
    'rotary',   # RoPE embeddings: positional encoding precision matters
]

# For vision transformers
SKIP_PATTERNS_VIT = [
    'patch_embed',  # Patch embedding: first layer, high sensitivity
    'head',         # Classification head
    'norm',
]
Monitoring QAT Training
class QATMonitor:
    """Monitor quantization-specific metrics during QAT."""

    def __init__(self, model, log_interval=100):
        self.model = model
        self.log_interval = log_interval
        self.step = 0

    def log(self):
        self.step += 1
        if self.step % self.log_interval != 0:
            return
        for name, module in self.model.named_modules():
            if isinstance(module, QATLinear):
                weight = module.linear.weight.data
                scale = module.weight_fq.scale
                # How many weights are at the clamp boundary?
                w_scaled = weight / scale
                at_min = (w_scaled <= module.weight_fq.q_min + 0.5).float().mean()
                at_max = (w_scaled >= module.weight_fq.q_max - 0.5).float().mean()
                # Weight range utilization
                w_q = torch.clamp(torch.round(w_scaled),
                                  module.weight_fq.q_min,
                                  module.weight_fq.q_max)
                unique_values = w_q.unique().numel()
                total_levels = module.weight_fq.q_max - module.weight_fq.q_min + 1
                if "layers.0" in name or "layers.15" in name:
                    print(f"[Step {self.step}] {name}: "
                          f"clamp_rate={at_min.item()+at_max.item():.4f}, "
                          f"levels_used={unique_values}/{total_levels}, "
                          f"scale={scale.mean().item():.6f}")
Common Failure Modes
- Scale explosion: If the learning rate is too high, QAT can cause weight magnitudes to grow, increasing the quantization scale and effectively reducing precision. Monitor the scale values.
- Clamp saturation: If too many weights sit at the clamp boundary (more than 5%), the quantization range is too narrow. Either increase bits or use per-channel quantization.
- Training divergence: QAT with STE can diverge if the initial quantization error is too large. Use gradual quantization (start at INT8, reduce to INT4) to avoid this.
- Activation outlier explosion: During QAT, activation outliers can grow larger as the model compensates for weight quantization. Monitor activation ranges and apply SmoothQuant before QAT if needed.
Summary
QAT inserts fake quantization into the forward pass so the model learns weights that are robust to quantization noise. The straight-through estimator enables gradient flow through the non-differentiable rounding operation. At INT8, PTQ is sufficient and QAT is unnecessary. At INT4, QAT provides 0.1-0.3 PPL improvement over the best PTQ methods for small models (under 13B). At INT3 and below, QAT is the only viable approach.
The decision is straightforward: if INT4/INT8 PTQ meets your quality requirements, use PTQ. If it does not, invest in QAT. The training cost is 1-5% of pre-training compute, and the resulting model deploys at the same speed as a PTQ model: the improvement is pure quality at zero inference cost.