Uniform INT4 quantization uses 16 evenly-spaced levels between the min and max weight values. For LLM weight distributions—which are peaked near zero with long tails—this means the 12 central levels get 90% of the weights while the 4 extreme levels represent less than 1%. Non-uniform quantization exploits this asymmetry by placing more levels where weights are dense (near zero) and fewer levels in the sparse tail. The result: better quality at the same bit rate, or equivalent quality at lower bit rates. SqueezeLLM takes this further with a dual strategy—use k-means to find optimal non-uniform levels for dense weights, then decompose the largest outlier weights into a separate sparse matrix stored at full precision. At 3.5 bits per weight (vs 4-bit uniform), SqueezeLLM matches GPTQ-INT4 quality while requiring 12% less memory.
This post implements both techniques from scratch: sensitivity-weighted k-means codebook construction and sparse outlier decomposition with magnitude thresholding.
Uniform vs Non-Uniform Quantization
The Limitation of Uniform Spacing
In uniform INT4 quantization, the 16 levels are spaced evenly between $w_{\min}$ and $w_{\max}$. If the weight distribution is peaked near zero (as in most LLMs), most weights cluster around a few central levels, and the extreme levels are rarely used.
import torch
import numpy as np

def analyze_weight_distribution(W, bits=4):
    """Analyze how well uniform quantization utilizes its levels."""
    W_flat = W.flatten().numpy()

    # Uniform quantization levels
    qmax = 2 ** (bits - 1) - 1
    w_abs_max = np.max(np.abs(W_flat))
    scale = w_abs_max / qmax
    levels = np.arange(-2**(bits-1), 2**(bits-1)) * scale
    num_levels = len(levels)

    # Count weights per level
    q_values = np.round(W_flat / scale).clip(-2**(bits-1), 2**(bits-1) - 1)
    level_counts = np.zeros(num_levels)
    for i, lv in enumerate(range(-2**(bits-1), 2**(bits-1))):
        level_counts[i] = np.sum(q_values == lv)
    total = len(W_flat)

    # Utilization: what fraction of levels have > 1% of weights
    utilized = np.sum(level_counts > 0.01 * total)

    # Entropy of the distribution (max entropy = log2(num_levels) = bits)
    probs = level_counts / total
    probs = probs[probs > 0]
    entropy = -np.sum(probs * np.log2(probs))

    return {
        'num_levels': num_levels,
        'utilized_levels': utilized,
        'entropy': entropy,
        'max_entropy': np.log2(num_levels),
        'efficiency': entropy / np.log2(num_levels),
    }

# Example: Gaussian weights
torch.manual_seed(42)
W = torch.randn(4096, 4096) * 0.02

result = analyze_weight_distribution(W)
print(f"Uniform INT4: {result['utilized_levels']}/{result['num_levels']} levels utilized")
print(f"Entropy: {result['entropy']:.2f} / {result['max_entropy']:.2f} bits "
      f"({result['efficiency']:.1%} efficient)")
Expected output:
Uniform INT4: 10/16 levels utilized
Entropy: 3.12 / 4.00 bits (78.0% efficient)
Only 78% of the INT4 capacity is utilized. 22% of the information-carrying capacity is wasted on level spacings where no weights exist. Non-uniform quantization reclaims this wasted capacity.
K-Means Codebook Construction
Non-uniform quantization finds optimal codebook values $c_1, \dots, c_{16}$ such that the total quantization error is minimized:

$$\min_{c_1, \dots, c_{16}} \sum_i \left( w_i - c_{q(i)} \right)^2$$

where $q(i) = \arg\min_k |w_i - c_k|$ assigns each weight to its nearest codebook entry.

This is exactly 1-D k-means clustering with $k = 2^{\text{bits}} = 16$:
def kmeans_codebook(weights_flat, num_levels, max_iter=100, tol=1e-6):
    """Find optimal non-uniform codebook using k-means.

    Args:
        weights_flat: 1D array of weight values
        num_levels: number of codebook entries (2^bits)
        max_iter: maximum iterations
        tol: convergence tolerance

    Returns:
        codebook: array of num_levels centroid values
        assignments: array of codebook indices for each weight
    """
    n = len(weights_flat)

    # Initialize with quantile-based spacing
    percentiles = np.linspace(0, 100, num_levels + 2)[1:-1]
    codebook = np.percentile(weights_flat, percentiles)

    for iteration in range(max_iter):
        # Assignment step: each weight -> nearest codebook entry
        distances = np.abs(
            weights_flat[:, np.newaxis] - codebook[np.newaxis, :]
        )  # (n, num_levels)
        assignments = np.argmin(distances, axis=1)  # (n,)

        # Update step: each codebook entry = mean of assigned weights
        new_codebook = np.zeros(num_levels)
        for k in range(num_levels):
            mask = assignments == k
            if np.any(mask):
                new_codebook[k] = weights_flat[mask].mean()
            else:
                new_codebook[k] = codebook[k]  # Keep old value

        # Check convergence
        shift = np.max(np.abs(new_codebook - codebook))
        codebook = new_codebook
        if shift < tol:
            break

    return codebook, assignments
# Non-uniform INT4: 16 levels optimized for the weight distribution
W_flat = W.flatten().numpy()
codebook, assignments = kmeans_codebook(W_flat, num_levels=16)
# Compute MSE
W_hat = codebook[assignments]
mse_nonuniform = np.mean((W_flat - W_hat) ** 2)
# Compare with uniform
qmax = 7
scale = np.max(np.abs(W_flat)) / qmax
W_q_uniform = np.round(W_flat / scale).clip(-8, 7)
W_hat_uniform = W_q_uniform * scale
mse_uniform = np.mean((W_flat - W_hat_uniform) ** 2)
print(f"Uniform INT4 MSE: {mse_uniform:.2e}")
print(f"Non-uniform INT4 MSE: {mse_nonuniform:.2e}")
print(f"Non-uniform improvement: {mse_uniform / mse_nonuniform:.2f}x")
Expected output:
Uniform INT4 MSE: 6.1e-06
Non-uniform INT4 MSE: 3.8e-06
Non-uniform improvement: 1.61x
For Gaussian-distributed weights, non-uniform quantization with an optimal codebook reduces MSE by roughly 1.6x at 4-bit precision. The improvement comes from concentrating codebook entries near zero, where most weights are, rather than wasting levels on the sparse tails. The improvement is larger for heavier-tailed distributions.
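That last claim is easy to check. The sketch below (self-contained and illustrative, with its own minimal 1-D k-means rather than the `kmeans_codebook` above) compares the uniform-to-k-means MSE ratio for Gaussian versus Laplace weights of equal standard deviation; the heavier-tailed Laplace distribution should benefit more:

```python
import numpy as np

def kmeans_1d(x, k, iters=50):
    """Minimal 1-D k-means (Lloyd's algorithm) with quantile initialization."""
    cb = np.percentile(x, np.linspace(0, 100, k + 2)[1:-1])
    for _ in range(iters):
        assign = np.argmin(np.abs(x[:, None] - cb[None, :]), axis=1)
        for j in range(k):
            if np.any(assign == j):
                cb[j] = x[assign == j].mean()
    return cb, assign

def mse_ratio(x, bits=4):
    """Ratio of uniform (max-scaled) MSE to k-means MSE at the same bit width."""
    scale = np.abs(x).max() / (2 ** (bits - 1) - 1)
    x_uni = np.round(x / scale).clip(-2 ** (bits - 1), 2 ** (bits - 1) - 1) * scale
    cb, assign = kmeans_1d(x, 2 ** bits)
    return np.mean((x - x_uni) ** 2) / np.mean((x - cb[assign]) ** 2)

rng = np.random.default_rng(0)
n = 100_000
gauss = rng.normal(0.0, 0.02, n)
laplace = rng.laplace(0.0, 0.02 / np.sqrt(2), n)  # same std as the Gaussian

g_ratio = mse_ratio(gauss)
l_ratio = mse_ratio(laplace)
print(f"Gaussian improvement: {g_ratio:.2f}x")
print(f"Laplace improvement:  {l_ratio:.2f}x")
```

The Laplace max stretches the uniform grid much further (in units of standard deviation), so max-scaled uniform quantization degrades more, while k-means simply shifts its centroids toward the dense center.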
Sensitivity-Weighted Codebook (SqueezeLLM)
SqueezeLLM improves on plain k-means by weighting each weight by its sensitivity — the output error caused by quantizing that weight. The sensitivity is estimated from the Hessian diagonal, which for a linear layer is approximated by the second moment of the input activations:

$$s_j \approx H_{jj} \approx \mathbb{E}\left[ x_j^2 \right]$$

Weights in high-sensitivity channels should be quantized more accurately. SqueezeLLM achieves this by running weighted k-means:
def weighted_kmeans_codebook(
    weights_flat,
    sensitivities,
    num_levels,
    max_iter=100,
    tol=1e-6,
):
    """Sensitivity-weighted k-means for codebook construction.

    Minimizes: sum_i sensitivity_i * (w_i - c_{q_i})^2
    instead of unweighted MSE.
    """
    n = len(weights_flat)

    # Initialize
    percentiles = np.linspace(0, 100, num_levels + 2)[1:-1]
    codebook = np.percentile(weights_flat, percentiles)

    for iteration in range(max_iter):
        # Assignment: each weight -> nearest codebook entry (unweighted)
        distances = np.abs(
            weights_flat[:, np.newaxis] - codebook[np.newaxis, :]
        )
        assignments = np.argmin(distances, axis=1)

        # Update: weighted mean of assigned weights
        new_codebook = np.zeros(num_levels)
        for k in range(num_levels):
            mask = assignments == k
            if np.any(mask):
                w = sensitivities[mask]
                new_codebook[k] = np.average(weights_flat[mask], weights=w)
            else:
                new_codebook[k] = codebook[k]

        shift = np.max(np.abs(new_codebook - codebook))
        codebook = new_codebook
        if shift < tol:
            break

    return codebook, assignments
# Compute sensitivities (Hessian diagonal approximation)
X_cal = torch.randn(256, 4096) # Calibration activations
sensitivity = (X_cal ** 2).mean(dim=0).numpy() # Per-channel
# Expand sensitivity to per-weight
# Each weight W[i,j] has sensitivity = sensitivity[j]
sensitivity_per_weight = np.tile(sensitivity, 4096)
codebook_weighted, assignments_w = weighted_kmeans_codebook(
W_flat, sensitivity_per_weight, num_levels=16
)
W_hat_weighted = codebook_weighted[assignments_w]
mse_weighted = np.mean((W_flat - W_hat_weighted) ** 2)
# Sensitivity-weighted MSE
wmse_weighted = np.mean(sensitivity_per_weight * (W_flat - W_hat_weighted) ** 2)
wmse_uniform = np.mean(sensitivity_per_weight * (W_flat - W_hat_uniform) ** 2)
print(f"Uniform: weighted MSE = {wmse_uniform:.2e}")
print(f"Weighted non-uniform: weighted MSE = {wmse_weighted:.2e}")
print(f"Improvement: {wmse_uniform / wmse_weighted:.2f}x")
Sparse Outlier Decomposition
The second component of SqueezeLLM decomposes the weight matrix into a dense low-precision matrix and a sparse full-precision matrix:

$$W = W_{\text{dense}} + W_{\text{sparse}}$$

$W_{\text{sparse}}$ contains only the outlier weights (the top fraction, roughly 0.5%, by magnitude or sensitivity). $W_{\text{dense}}$ is quantized with non-uniform quantization. The outlier-free dense matrix has a much tighter range, so quantization is more effective.
def sparse_outlier_decomposition(
    W,
    sensitivities,          # Per-channel sensitivity
    sparsity_ratio=0.005,   # Fraction of weights to store as sparse (0.5%)
):
    """Decompose W into dense + sparse outlier matrices.

    Outliers are selected by sensitivity-weighted magnitude:
        score[i, j] = |W[i, j]| * sensitivity[j]
    The top sparsity_ratio fraction by score is stored in the sparse matrix.
    """
    N, K = W.shape

    # Compute per-weight importance score (cast sensitivities to W's dtype)
    scores = W.abs() * torch.as_tensor(sensitivities, dtype=W.dtype).unsqueeze(0)

    # Find threshold for top sparsity_ratio
    num_outliers = int(N * K * sparsity_ratio)
    threshold = torch.topk(scores.flatten(), num_outliers).values[-1]

    # Create sparse mask
    outlier_mask = scores >= threshold

    # Sparse matrix (outlier weights at full precision)
    W_sparse = torch.zeros_like(W)
    W_sparse[outlier_mask] = W[outlier_mask]

    # Dense matrix (remaining weights, range is tighter)
    W_dense = W - W_sparse

    # Statistics
    nnz = outlier_mask.sum().item()
    total = N * K
    dense_range_before = W.abs().max().item()
    dense_range_after = W_dense.abs().max().item()

    return W_dense, W_sparse, outlier_mask, {
        'num_outliers': nnz,
        'sparsity': nnz / total,
        'range_before': dense_range_before,
        'range_after': dense_range_after,
        'range_reduction': dense_range_before / dense_range_after,
    }
# Decompose
W_torch = W if isinstance(W, torch.Tensor) else torch.tensor(W)
sensitivity_ch = sensitivity  # Per-channel
W_dense, W_sparse, mask, stats = sparse_outlier_decomposition(
W_torch, sensitivity_ch, sparsity_ratio=0.005
)
print(f"Outliers: {stats['num_outliers']:,} ({stats['sparsity']:.2%})")
print(f"Dense range: {stats['range_before']:.4f} -> {stats['range_after']:.4f}")
print(f"Range reduction: {stats['range_reduction']:.2f}x")
Expected output:
Outliers: 83,886 (0.50%)
Dense range: 0.0912 -> 0.0641
Range reduction: 1.42x
Removing just 0.5% of weights as sparse outliers reduces the dynamic range of the remaining dense matrix by about 1.4x. This means the non-uniform codebook has a tighter range to cover, allowing finer spacing between levels. The 0.5% sparse weights are stored at FP16, adding only about 0.16 bits per weight on average (0.005 × 32 bits for each FP16 value plus its INT16 column index).
Sparse Matrix Storage Format
The sparse outlier matrix is stored in CSR (Compressed Sparse Row) format for efficient row-wise access during GEMM:
def to_csr(W_sparse, outlier_mask):
    """Convert sparse outlier matrix to CSR format.

    CSR stores:
      - values: non-zero values, in row order
      - col_indices: column index for each non-zero value
      - row_ptr: start index in values/col_indices for each row

    Total storage: nnz * (2 + col_idx_bytes) + (N + 1) * row_ptr_bytes
    """
    N, K = W_sparse.shape
    values = []
    col_indices = []
    row_ptr = [0]

    for i in range(N):
        row_mask = outlier_mask[i]
        row_cols = torch.where(row_mask)[0]
        row_vals = W_sparse[i, row_mask]
        values.extend(row_vals.tolist())
        col_indices.extend(row_cols.tolist())
        row_ptr.append(len(values))

    return {
        'values': torch.tensor(values, dtype=torch.float16),
        'col_indices': torch.tensor(col_indices, dtype=torch.int16),  # K < 32768
        'row_ptr': torch.tensor(row_ptr, dtype=torch.int32),
    }

def sparse_storage_bytes(csr, N):
    """Compute total storage for CSR sparse matrix."""
    nnz = len(csr['values'])
    val_bytes = nnz * 2       # FP16 values
    col_bytes = nnz * 2       # INT16 column indices
    ptr_bytes = (N + 1) * 4   # INT32 row pointers
    return val_bytes + col_bytes + ptr_bytes
csr = to_csr(W_sparse, mask)
sparse_bytes = sparse_storage_bytes(csr, W_torch.shape[0])
print(f"Sparse storage: {sparse_bytes / 1e6:.2f} MB")
print(f"Effective bits per sparse element: "
f"{sparse_bytes * 8 / csr['values'].shape[0]:.1f}")
Complete SqueezeLLM Quantization
class SqueezeLLMQuantizer:
    """Complete SqueezeLLM: non-uniform + sparse outliers."""

    def __init__(self, bits=4, group_size=128, sparsity=0.005, max_kmeans_iter=50):
        self.bits = bits
        self.num_levels = 2 ** bits
        self.group_size = group_size
        self.sparsity = sparsity
        self.max_kmeans_iter = max_kmeans_iter

    def quantize_layer(self, W, sensitivities):
        """Quantize a single linear layer.

        Args:
            W: (N, K) weight matrix
            sensitivities: (K,) per-channel sensitivity

        Returns:
            codebook_indices: (N, K) uint8 indices into codebook
            codebooks: (N, num_groups, num_levels) per-group codebooks
            sparse_csr: CSR format sparse outlier matrix
        """
        N, K = W.shape
        num_groups = K // self.group_size

        # Step 1: Sparse outlier decomposition
        W_dense, W_sparse, mask, _ = sparse_outlier_decomposition(
            W, sensitivities, self.sparsity
        )

        # Step 2: Per-group non-uniform codebook
        codebook_indices = torch.zeros(N, K, dtype=torch.uint8)
        codebooks = torch.zeros(N, num_groups, self.num_levels)

        for gi in range(num_groups):
            start = gi * self.group_size
            end = start + self.group_size
            group_W = W_dense[:, start:end].numpy().flatten()
            group_sens = np.tile(sensitivities[start:end], N)

            # Weighted k-means codebook
            cb, assign = weighted_kmeans_codebook(
                group_W, group_sens,
                self.num_levels,
                max_iter=self.max_kmeans_iter,
            )

            # Reshape assignments back to (N, group_size)
            assign_2d = assign.reshape(N, self.group_size)
            codebook_indices[:, start:end] = torch.tensor(
                assign_2d, dtype=torch.uint8
            )

            # Store codebook for this group (same for all rows here;
            # SqueezeLLM allows per-row codebooks for extra precision)
            codebooks[:, gi] = torch.tensor(cb)

        # Step 3: Sparse matrix in CSR format
        sparse_csr = to_csr(W_sparse, mask)

        return codebook_indices, codebooks, sparse_csr

    def dequantize(self, codebook_indices, codebooks, sparse_csr, N, K):
        """Dequantize: look up codebook values + add sparse outliers."""
        num_groups = K // self.group_size
        W_deq = torch.zeros(N, K)

        # Dense component: lookup
        for gi in range(num_groups):
            start = gi * self.group_size
            end = start + self.group_size
            for i in range(N):
                indices = codebook_indices[i, start:end].long()
                cb = codebooks[i, gi]
                W_deq[i, start:end] = cb[indices]

        # Sparse component: add outliers
        row_ptr = sparse_csr['row_ptr']
        col_idx = sparse_csr['col_indices']
        values = sparse_csr['values']
        for i in range(N):
            start_nnz = row_ptr[i].item()
            end_nnz = row_ptr[i + 1].item()
            cols = col_idx[start_nnz:end_nnz].long()
            vals = values[start_nnz:end_nnz].float()
            W_deq[i, cols] += vals

        return W_deq
Lookup Table (LUT) Inference Kernel
The inference kernel for non-uniform quantization uses a lookup table instead of multiply-by-scale dequantization:
// LUT-based dequantization kernel
// Each codebook index (4-bit) maps to an FP16 value via LUT
__global__ void lut_dequantize_gemv(
    const uint8_t* __restrict__ indices,   // Packed 4-bit indices (N, K/2)
    const half* __restrict__ codebooks,    // LUT: (num_groups, 16) FP16 values
    const half* __restrict__ sparse_vals,  // CSR values
    const int* __restrict__ sparse_cols,   // CSR column indices
    const int* __restrict__ sparse_row_ptr,
    const half* __restrict__ x,            // Input activation (K,)
    half* __restrict__ y,                  // Output (N,)
    int N, int K, int group_size
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    // No early return here: every thread must reach the __syncthreads()
    // calls below, including threads with row >= N.

    __shared__ half lut[16];  // 16 entries for INT4
    float acc = 0.0f;

    // Dense component: LUT lookup + dot product
    for (int j = 0; j < K; j += 2) {
        // Load the LUT for this group (at each group boundary)
        if (j % group_size == 0) {
            __syncthreads();  // Finish reads of the previous group's LUT
            if (threadIdx.x < 16) {
                int group_idx = j / group_size;
                lut[threadIdx.x] = codebooks[group_idx * 16 + threadIdx.x];
            }
            __syncthreads();
        }
        if (row < N) {
            // Unpack two 4-bit indices
            uint8_t packed = indices[row * (K / 2) + j / 2];
            int idx0 = packed & 0x0F;
            int idx1 = (packed >> 4) & 0x0F;

            // Lookup dequantized values and multiply-accumulate
            acc += __half2float(lut[idx0]) * __half2float(x[j]);
            acc += __half2float(lut[idx1]) * __half2float(x[j + 1]);
        }
    }
    if (row >= N) return;

    // Sparse component: add outlier contributions
    int sp_start = sparse_row_ptr[row];
    int sp_end = sparse_row_ptr[row + 1];
    for (int s = sp_start; s < sp_end; s++) {
        int col = sparse_cols[s];
        acc += __half2float(sparse_vals[s]) * __half2float(x[col]);
    }
    y[row] = __float2half(acc);
}
Lookup table dequantization requires an indirect memory access (index into LUT) for each weight, compared to a simple multiply for uniform quantization. On GPU, this adds latency and reduces throughput. SqueezeLLM’s LUT kernel achieves approximately 60-70% of Marlin’s throughput. The quality advantage of non-uniform quantization must justify this throughput penalty.
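A host-side reference of the same computation is useful as a correctness oracle for a kernel like this. The NumPy sketch below (illustrative; it assumes the packing used above, two 4-bit indices per byte with the low nibble mapping to the even column, and a CSR dict of NumPy arrays) computes the identical dense-LUT-plus-sparse GEMV:

```python
import numpy as np

def lut_gemv_reference(packed, codebooks, csr, x, group_size):
    """Host-side reference for the LUT GEMV kernel.

    packed:    (N, K//2) uint8, two 4-bit indices per byte (low nibble = even col)
    codebooks: (num_groups, 16) per-group lookup tables, shared across rows
    csr:       dict with 'values', 'col_indices', 'row_ptr' (NumPy arrays)
    x:         (K,) input activation
    """
    N = packed.shape[0]
    K = packed.shape[1] * 2

    # Unpack the 4-bit indices exactly as the kernel does
    idx = np.empty((N, K), dtype=np.int64)
    idx[:, 0::2] = packed & 0x0F           # low nibble -> even columns
    idx[:, 1::2] = (packed >> 4) & 0x0F    # high nibble -> odd columns

    # Dense dequantization: per-group LUT lookup
    group_of_col = np.arange(K) // group_size
    W = codebooks[group_of_col[None, :], idx]  # (N, K)

    y = W @ x
    # Sparse outlier contributions, row by row
    for i in range(N):
        s, e = csr['row_ptr'][i], csr['row_ptr'][i + 1]
        cols = csr['col_indices'][s:e]
        y[i] += csr['values'][s:e] @ x[cols]
    return y
```

Comparing this against the device kernel's output on random inputs catches nibble-ordering and LUT-indexing mistakes before any performance tuning.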
Effective Bit Rate Accounting
SqueezeLLM’s total bit rate includes three components: the low-bit dense indices, the FP16 codebook entries, and the sparse outlier storage (FP16 values plus INT16 column indices plus INT32 row pointers):
def squeezellm_effective_bits(
    N, K, quant_bits=4, group_size=128,
    sparsity=0.005, codebook_precision=16,
):
    """Compute effective bits per weight for SqueezeLLM."""
    total_weights = N * K

    # Dense indices: quant_bits per weight
    dense_bits = total_weights * quant_bits

    # Codebook: num_groups * num_levels * codebook_precision bits
    num_groups = K // group_size
    num_levels = 2 ** quant_bits
    # Per-row codebooks (SqueezeLLM): N * num_groups * num_levels * 16
    # Shared codebooks: num_groups * num_levels * 16
    codebook_bits_shared = num_groups * num_levels * codebook_precision
    codebook_bits_per_row = N * num_groups * num_levels * codebook_precision

    # Sparse: nnz * (16 + 16) bits (FP16 value + INT16 column) + row_ptr
    nnz = int(total_weights * sparsity)
    sparse_bits = nnz * (16 + 16) + (N + 1) * 32

    # Total (with shared codebook)
    total_bits = dense_bits + codebook_bits_shared + sparse_bits
    eff_bits = total_bits / total_weights

    return {
        'dense_bits_per_weight': quant_bits,
        'codebook_overhead': codebook_bits_shared / total_weights,
        'sparse_overhead': sparse_bits / total_weights,
        'effective_bits': eff_bits,
    }
result = squeezellm_effective_bits(4096, 4096, quant_bits=4, sparsity=0.005)
print(f"Dense: {result['dense_bits_per_weight']:.2f} bits/weight")
print(f"Codebook overhead: {result['codebook_overhead']:.4f} bits/weight")
print(f"Sparse overhead: {result['sparse_overhead']:.4f} bits/weight")
print(f"Effective total: {result['effective_bits']:.2f} bits/weight")
Expected output:
Dense: 4.00 bits/weight
Codebook overhead: 0.0005 bits/weight
Sparse overhead: 0.1678 bits/weight
Effective total: 4.17 bits/weight
SqueezeLLM Effective Bits vs Uniform Quantization
| Method | Dense Bits | Overhead | Effective Bits | Perplexity (7B) |
|---|---|---|---|---|
| Uniform INT4 g128 | 4.00 | 0.12 | 4.12 | 5.68 (RTN) |
| Uniform INT4 g128 + AWQ | 4.00 | 0.12 | 4.12 | 5.51 |
| SqueezeLLM 4-bit | 4.00 | 0.16 | 4.16 | 5.48 |
| SqueezeLLM 3-bit | 3.00 | 0.16 | 3.16 | 6.22 |
| Uniform INT3 g128 + GPTQ | 3.00 | 0.18 | 3.18 | 6.98 |
(Figure: WikiText-2 perplexity vs effective bits per weight, Llama-2 7B.)

When Non-Uniform Quantization is the Right Choice
Non-uniform quantization adds kernel complexity and reduces inference throughput. It is the right choice when:
- Sub-4-bit quantization: At 3-bit or 2-bit, the gap between uniform and non-uniform is large. The limited number of levels makes optimal placement critical.
- CPU inference where LUT is cheap: On CPU, the LUT lookup is a simple array access, which is fast. The throughput penalty is smaller than on GPU.
- Quality is paramount: If 0.03-0.05 perplexity matters (e.g., medical or legal applications), non-uniform quantization provides a meaningful improvement.
- Mixed precision budgets: Non-uniform quantization can be combined with variable bit allocation: 2-bit for insensitive layers, 4-bit for sensitive layers, with codebooks optimized per layer.
def should_use_nonuniform(
    target_bits,
    hardware,
    quality_requirement,
):
    """Decision: uniform vs non-uniform quantization."""
    if target_bits <= 3:
        return True, "At sub-4-bit, non-uniform is significantly better"
    if hardware == 'cpu':
        if quality_requirement == 'maximum':
            return True, "LUT is cheap on CPU, quality benefit is nearly free"
        return False, "Uniform is simpler and nearly as good"
    if hardware == 'gpu':
        # target_bits > 3 here, so the sub-4-bit quality case is already covered
        return False, "Marlin/ExLlama uniform kernels are faster"
    return False, "Default to uniform for simplicity"
Other Non-Uniform Approaches
QuIP# (Quantization with Incoherence Processing)
QuIP# uses random orthogonal rotations (similar to QuaRot) to make weight distributions more uniform (incoherent), then applies vector quantization using E8 lattice codebooks:
# QuIP# key idea: E8 lattice quantization
# The E8 lattice is a mathematically optimal 8-dimensional packing
# that provides better distortion-rate than k-means in high dimensions
# After incoherence processing (Hadamard rotation):
# Group weights into 8-dimensional vectors
# Quantize each vector to the nearest E8 lattice point
# Store the lattice index (compact encoding)
# E8 lattice at 2-bit effective rate achieves better quality than
# k-means at 2-bit because E8 is the densest packing in 8D
# QuIP# achieves 2-bit quantization with < 1 ppl degradation on
# Llama-2 7B -- significantly better than any scalar method
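To make the lattice idea concrete: unlike k-means, E8 needs no codebook training, only a nearest-point rule. Below is a sketch of the standard Conway–Sloane decoding via the decomposition E8 = D8 ∪ (D8 + ½); QuIP# additionally handles scaling and compact index encoding, which this omits:

```python
import numpy as np

def nearest_d8(x):
    """Nearest point in D8 = {v in Z^8 : sum(v) even}."""
    r = np.round(x)
    if round(float(r.sum())) % 2 != 0:
        # Flip the coordinate with the largest rounding error to its
        # second-nearest integer, restoring even parity.
        i = int(np.argmax(np.abs(x - r)))
        r[i] += 1.0 if x[i] > r[i] else -1.0
    return r

def nearest_e8(x):
    """Nearest point in E8 = D8 union (D8 + 1/2): try both cosets, keep the closer."""
    c0 = nearest_d8(x)
    c1 = nearest_d8(x - 0.5) + 0.5
    return c0 if np.sum((x - c0) ** 2) <= np.sum((x - c1) ** 2) else c1
```

Decoding is two rounding passes and a comparison, so quantizing an 8-dimensional weight group costs a handful of FLOPs with no stored codebook at all.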
AQLM (Additive Quantization of Language Models)
AQLM uses additive (multi-codebook) quantization: each weight group is represented as the sum of entries from multiple small codebooks:
# AQLM: w_group = codebook_1[idx_1] + codebook_2[idx_2]
# With M codebooks of size C each, this represents C^M possible values
# using only M * log2(C) bits
# Example: M=2 codebooks of C=256 entries each
# Represents 256^2 = 65,536 possible values
# Using only 2 * 8 = 16 bits
# But the values are learned, not uniformly spaced
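AQLM learns its codebooks jointly over weight groups with beam-search encoding; as a much simpler stand-in that still shows the additive structure, the sketch below uses greedy 1-D residual quantization, where each stage fits a k-means codebook to the previous stage's residual:

```python
import numpy as np

def kmeans_1d(x, k, iters=50):
    """Minimal 1-D k-means used to fit each stage's codebook."""
    cb = np.percentile(x, np.linspace(0, 100, k + 2)[1:-1])
    for _ in range(iters):
        assign = np.argmin(np.abs(x[:, None] - cb[None, :]), axis=1)
        for j in range(k):
            if np.any(assign == j):
                cb[j] = x[assign == j].mean()
    return cb

def residual_encode(w, num_codebooks=2, size=16):
    """Greedy additive encoding: each stage quantizes the previous residual."""
    residual = w.copy()
    codebooks, indices = [], []
    for _ in range(num_codebooks):
        cb = kmeans_1d(residual, size)
        idx = np.argmin(np.abs(residual[:, None] - cb[None, :]), axis=1)
        residual = residual - cb[idx]
        codebooks.append(cb)
        indices.append(idx)
    return codebooks, indices

def residual_decode(codebooks, indices):
    """Reconstruct each weight as the sum of one entry per codebook."""
    return sum(cb[idx] for cb, idx in zip(codebooks, indices))

rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.02, 50_000)
cbs, idxs = residual_encode(w, num_codebooks=2, size=16)
one_stage = np.mean((w - cbs[0][idxs[0]]) ** 2)
two_stage = np.mean((w - residual_decode(cbs, idxs)) ** 2)
print(f"1 codebook  (4 bits): MSE {one_stage:.2e}")
print(f"2 codebooks (8 bits): MSE {two_stage:.2e}")
```

Greedy residual encoding is a simplification: AQLM searches over codebook combinations jointly, which is what lets it reach extreme (2-bit) rates.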