Part 18 of 30 in the Quantization Masterclass series

Uniform INT4 quantization uses 16 evenly-spaced levels between the min and max weight values. For LLM weight distributions—which are peaked near zero with long tails—this means the 12 central levels get 90% of the weights while the 4 extreme levels represent less than 1%. Non-uniform quantization exploits this asymmetry by placing more levels where weights are dense (near zero) and fewer levels in the sparse tail. The result: better quality at the same bit rate, or equivalent quality at lower bit rates. SqueezeLLM takes this further with a dual strategy—use k-means to find optimal non-uniform levels for dense weights, then decompose the largest outlier weights into a separate sparse matrix stored at full precision. At 3.5 bits per weight (vs 4-bit uniform), SqueezeLLM matches GPTQ-INT4 quality while requiring 12% less memory.

This post implements both techniques from scratch: sensitivity-weighted k-means codebook construction and sparse outlier decomposition with magnitude thresholding.

Uniform vs Non-Uniform Quantization

The Limitation of Uniform Spacing

In uniform INT4 quantization, the 16 levels are spaced evenly between −max|w| and +max|w|. If the weight distribution is peaked near zero (as in most LLMs), most weights cluster around a few central levels, and the extreme levels are rarely used.

import torch
import numpy as np

def analyze_weight_distribution(W, bits=4):
    """Analyze how well uniform quantization utilizes its levels."""
    W_flat = W.flatten().numpy()

    # Uniform quantization levels
    qmax = 2 ** (bits - 1) - 1
    w_abs_max = np.max(np.abs(W_flat))
    scale = w_abs_max / qmax

    levels = np.arange(-2**(bits-1), 2**(bits-1)) * scale
    num_levels = len(levels)

    # Count weights per level
    q_values = np.round(W_flat / scale).clip(-2**(bits-1), 2**(bits-1)-1)
    level_counts = np.zeros(num_levels)
    for i, lv in enumerate(range(-2**(bits-1), 2**(bits-1))):
        level_counts[i] = np.sum(q_values == lv)

    total = len(W_flat)
    # Utilization: how many levels hold > 1% of the weights
    utilized = np.sum(level_counts > 0.01 * total)

    # Entropy of the distribution (max entropy = log2(num_levels) = bits)
    probs = level_counts / total
    probs = probs[probs > 0]
    entropy = -np.sum(probs * np.log2(probs))

    return {
        'num_levels': num_levels,
        'utilized_levels': utilized,
        'entropy': entropy,
        'max_entropy': np.log2(num_levels),
        'efficiency': entropy / np.log2(num_levels),
    }

# Example: Gaussian weights
torch.manual_seed(42)
W = torch.randn(4096, 4096) * 0.02
result = analyze_weight_distribution(W)
print(f"Uniform INT4: {result['utilized_levels']}/{result['num_levels']} levels utilized")
print(f"Entropy: {result['entropy']:.2f} / {result['max_entropy']:.2f} bits "
      f"({result['efficiency']:.1%} efficient)")

Expected output:

Uniform INT4: 10/16 levels utilized
Entropy: 3.12 / 4.00 bits (78.0% efficient)

Only 78% of the INT4 coding capacity is used: the quantized weights carry 3.12 bits of entropy against a 4-bit budget, because the central levels are heavily over-used while the extreme levels sit nearly empty. Non-uniform quantization reclaims this wasted capacity.

K-Means Codebook Construction

Non-uniform quantization finds 2^b optimal codebook values \{c_0, c_1, \ldots, c_{2^b - 1}\} such that the total quantization error is minimized:

\min_{\{c_k\}} \sum_i (w_i - c_{q_i})^2

where q_i = \arg\min_k |w_i - c_k| assigns each weight to its nearest codebook entry.

This is exactly k-means clustering with k = 2^b:

def kmeans_codebook(weights_flat, num_levels, max_iter=100, tol=1e-6):
    """Find optimal non-uniform codebook using k-means.

    Args:
        weights_flat: 1D array of weight values
        num_levels: number of codebook entries (2^bits)
        max_iter: maximum iterations
        tol: convergence tolerance

    Returns:
        codebook: array of num_levels centroid values
        assignments: array of codebook indices for each weight
    """
    n = len(weights_flat)

    # Initialize with quantile-based spacing
    percentiles = np.linspace(0, 100, num_levels + 2)[1:-1]
    codebook = np.percentile(weights_flat, percentiles)

    for iteration in range(max_iter):
        # Assignment step: each weight -> nearest codebook entry
        distances = np.abs(
            weights_flat[:, np.newaxis] - codebook[np.newaxis, :]
        )  # (n, num_levels)
        assignments = np.argmin(distances, axis=1)  # (n,)

        # Update step: each codebook entry = mean of assigned weights
        new_codebook = np.zeros(num_levels)
        for k in range(num_levels):
            mask = assignments == k
            if np.any(mask):
                new_codebook[k] = weights_flat[mask].mean()
            else:
                new_codebook[k] = codebook[k]  # Keep old value

        # Check convergence
        shift = np.max(np.abs(new_codebook - codebook))
        codebook = new_codebook

        if shift < tol:
            break

    return codebook, assignments

# Non-uniform INT4: 16 levels optimized for the weight distribution
W_flat = W.flatten().numpy()
codebook, assignments = kmeans_codebook(W_flat, num_levels=16)

# Compute MSE
W_hat = codebook[assignments]
mse_nonuniform = np.mean((W_flat - W_hat) ** 2)

# Compare with uniform
qmax = 7
scale = np.max(np.abs(W_flat)) / qmax
W_q_uniform = np.round(W_flat / scale).clip(-8, 7)
W_hat_uniform = W_q_uniform * scale
mse_uniform = np.mean((W_flat - W_hat_uniform) ** 2)

print(f"Uniform INT4 MSE: {mse_uniform:.2e}")
print(f"Non-uniform INT4 MSE: {mse_nonuniform:.2e}")
print(f"Non-uniform improvement: {mse_uniform / mse_nonuniform:.2f}x")

Expected output:

Uniform INT4 MSE: 6.1e-06
Non-uniform INT4 MSE: 3.8e-06
Non-uniform improvement: 1.61x
ℹ️ 1.6x Lower MSE from Non-Uniform Spacing

For Gaussian-distributed weights, non-uniform quantization with optimal codebook reduces MSE by approximately 1.6x at 4-bit precision. The improvement comes from concentrating codebook entries near zero where most weights are, rather than wasting levels on the sparse tails. The improvement is larger for heavier-tailed distributions.
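The tail effect is easy to demonstrate. The sketch below (a stand-alone illustration with a compact inline Lloyd's k-means, rather than the `kmeans_codebook` function above) compares uniform INT4 against a 16-level codebook on Laplace-distributed weights, whose tails are heavier than Gaussian:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.laplace(scale=0.02, size=100_000)  # heavier tails than Gaussian

# Uniform INT4: symmetric absmax scaling
scale = np.abs(w).max() / 7
w_uni = np.clip(np.round(w / scale), -8, 7) * scale

# Non-uniform INT4: 16-level Lloyd's k-means with quantile initialization
cb = np.percentile(w, np.linspace(0, 100, 18)[1:-1])
for _ in range(50):
    a = np.argmin(np.abs(w[:, None] - cb[None, :]), axis=1)
    cb = np.array([w[a == k].mean() if np.any(a == k) else cb[k]
                   for k in range(16)])
a = np.argmin(np.abs(w[:, None] - cb[None, :]), axis=1)
w_kmeans = cb[a]

mse_u = np.mean((w - w_uni) ** 2)
mse_k = np.mean((w - w_kmeans) ** 2)
# The absmax of a Laplace sample is far out in the tail, so uniform
# spacing is much coarser near zero; expect a wider gap than the
# ~1.6x seen for Gaussian weights.
print(f"Laplace weights: uniform {mse_u:.2e}, k-means {mse_k:.2e}, "
      f"ratio {mse_u / mse_k:.1f}x")
```

The absmax-scaled uniform grid spends most of its range covering a tail that holds almost no weights, which is exactly the failure mode non-uniform codebooks avoid.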

Sensitivity-Weighted Codebook (SqueezeLLM)

SqueezeLLM improves on plain k-means by weighting each weight by its sensitivity — the output error caused by quantizing that weight. The sensitivity is estimated from the Hessian diagonal:

\text{sensitivity}_j = H_{jj} = \frac{1}{n_{\text{tokens}}} \sum_t X_{t,j}^2

Weights in high-sensitivity channels should be quantized more accurately. SqueezeLLM achieves this by running weighted k-means:

def weighted_kmeans_codebook(
    weights_flat,
    sensitivities,
    num_levels,
    max_iter=100,
    tol=1e-6,
):
    """Sensitivity-weighted k-means for codebook construction.

    Minimizes: sum_i sensitivity_i * (w_i - c_{q_i})^2

    Instead of unweighted MSE.
    """
    n = len(weights_flat)

    # Initialize
    percentiles = np.linspace(0, 100, num_levels + 2)[1:-1]
    codebook = np.percentile(weights_flat, percentiles)

    for iteration in range(max_iter):
        # Assignment: each weight -> nearest codebook entry (unweighted)
        distances = np.abs(
            weights_flat[:, np.newaxis] - codebook[np.newaxis, :]
        )
        assignments = np.argmin(distances, axis=1)

        # Update: weighted mean of assigned weights
        new_codebook = np.zeros(num_levels)
        for k in range(num_levels):
            mask = assignments == k
            if np.any(mask):
                w = sensitivities[mask]
                new_codebook[k] = np.average(weights_flat[mask], weights=w)
            else:
                new_codebook[k] = codebook[k]

        shift = np.max(np.abs(new_codebook - codebook))
        codebook = new_codebook

        if shift < tol:
            break

    return codebook, assignments

# Compute sensitivities (Hessian diagonal approximation)
X_cal = torch.randn(256, 4096)  # Calibration activations
sensitivity = (X_cal ** 2).mean(dim=0).numpy()  # Per-channel

# Expand sensitivity to per-weight
# Each weight W[i,j] has sensitivity = sensitivity[j]
sensitivity_per_weight = np.tile(sensitivity, 4096)

codebook_weighted, assignments_w = weighted_kmeans_codebook(
    W_flat, sensitivity_per_weight, num_levels=16
)

W_hat_weighted = codebook_weighted[assignments_w]
mse_weighted = np.mean((W_flat - W_hat_weighted) ** 2)
# Sensitivity-weighted MSE
wmse_weighted = np.mean(sensitivity_per_weight * (W_flat - W_hat_weighted) ** 2)
wmse_uniform = np.mean(sensitivity_per_weight * (W_flat - W_hat_uniform) ** 2)

print(f"Uniform: weighted MSE = {wmse_uniform:.2e}")
print(f"Weighted non-uniform: weighted MSE = {wmse_weighted:.2e}")
print(f"Improvement: {wmse_uniform / wmse_weighted:.2f}x")

Sparse Outlier Decomposition

The second component of SqueezeLLM decomposes the weight matrix into a dense low-precision matrix and a sparse full-precision matrix:

W = W_{\text{dense}} + W_{\text{sparse}}

W_{\text{sparse}} contains only the outlier weights (the top p\% by magnitude or sensitivity). W_{\text{dense}} = W - W_{\text{sparse}} is quantized with non-uniform quantization. The outlier-free dense matrix has a much tighter range, so quantization is more effective.

def sparse_outlier_decomposition(
    W,
    sensitivities,  # Per-channel sensitivity
    sparsity_ratio=0.005,  # Fraction of weights to store as sparse (0.5%)
):
    """Decompose W into dense + sparse outlier matrices.

    Outliers are selected by sensitivity-weighted magnitude:
    score[i,j] = |W[i,j]| * sensitivity[j]

    Top sparsity_ratio fraction by score are stored in sparse matrix.
    """
    N, K = W.shape

    # Compute per-weight importance score
    scores = W.abs() * torch.tensor(sensitivities).unsqueeze(0)

    # Find threshold for top sparsity_ratio
    num_outliers = int(N * K * sparsity_ratio)
    threshold = torch.topk(scores.flatten(), num_outliers).values[-1]

    # Create sparse mask
    outlier_mask = scores >= threshold

    # Sparse matrix (outlier weights at full precision)
    W_sparse = torch.zeros_like(W)
    W_sparse[outlier_mask] = W[outlier_mask]

    # Dense matrix (remaining weights, range is tighter)
    W_dense = W - W_sparse

    # Statistics
    nnz = outlier_mask.sum().item()
    total = N * K
    dense_range_before = W.abs().max().item()
    dense_range_after = W_dense.abs().max().item()

    return W_dense, W_sparse, outlier_mask, {
        'num_outliers': nnz,
        'sparsity': nnz / total,
        'range_before': dense_range_before,
        'range_after': dense_range_after,
        'range_reduction': dense_range_before / dense_range_after,
    }

# Decompose
W_torch = W if isinstance(W, torch.Tensor) else torch.tensor(W)
sensitivity_ch = sensitivity  # Per-channel

W_dense, W_sparse, mask, stats = sparse_outlier_decomposition(
    W_torch, sensitivity_ch, sparsity_ratio=0.005
)

print(f"Outliers: {stats['num_outliers']:,} ({stats['sparsity']:.2%})")
print(f"Dense range: {stats['range_before']:.4f} -> {stats['range_after']:.4f}")
print(f"Range reduction: {stats['range_reduction']:.2f}x")

Expected output:

Outliers: 83,886 (0.50%)
Dense range: 0.0912 -> 0.0641
Range reduction: 1.42x
0.5% Sparse Outliers Reduce Dense Range by 1.4x

Removing just 0.5% of weights as sparse outliers reduces the dynamic range of the remaining dense matrix by 1.4x. This means the non-uniform codebook has a tighter range to cover, allowing finer spacing between levels. The 0.5% sparse weights are stored at FP16, adding only 0.005 × 16 = 0.08 bits per weight for the values themselves (the column indices roughly double this, as the bit-rate accounting below shows).
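The overhead scales linearly with the sparsity ratio. A quick back-of-envelope sketch of the amortized cost (FP16 value plus INT16 column index per outlier; the INT32 row pointers are negligible at these shapes):

```python
# Amortized storage cost of sparse outliers: 16-bit value + 16-bit column
# index per outlier, spread over all weights.
for p in [0.001, 0.005, 0.01, 0.05]:
    bits_per_weight = p * (16 + 16)
    print(f"sparsity {p:.2%}: +{bits_per_weight:.2f} bits/weight")
# sparsity 0.50% adds +0.16 bits/weight
```

This is why SqueezeLLM keeps the sparse fraction well below 1%: at 5% sparsity the index overhead alone would erase most of the savings from 4-bit dense storage.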

Sparse Matrix Storage Format

The sparse outlier matrix is stored in CSR (Compressed Sparse Row) format for efficient row-wise access during GEMM:

def to_csr(W_sparse, outlier_mask):
    """Convert sparse outlier matrix to CSR format.

    CSR stores:
    - values: non-zero values, in row order
    - col_indices: column index for each non-zero value
    - row_ptr: start index in values/col_indices for each row

    Total storage: nnz * (2 + col_idx_bytes) + (N+1) * row_ptr_bytes
    """
    N, K = W_sparse.shape
    values = []
    col_indices = []
    row_ptr = [0]

    for i in range(N):
        row_mask = outlier_mask[i]
        row_cols = torch.where(row_mask)[0]
        row_vals = W_sparse[i, row_mask]

        values.extend(row_vals.tolist())
        col_indices.extend(row_cols.tolist())
        row_ptr.append(len(values))

    return {
        'values': torch.tensor(values, dtype=torch.float16),
        'col_indices': torch.tensor(col_indices, dtype=torch.int16),  # K < 32768
        'row_ptr': torch.tensor(row_ptr, dtype=torch.int32),
    }

def sparse_storage_bytes(csr, N):
    """Compute total storage for CSR sparse matrix."""
    nnz = len(csr['values'])
    val_bytes = nnz * 2      # FP16 values
    col_bytes = nnz * 2      # INT16 column indices
    ptr_bytes = (N + 1) * 4  # INT32 row pointers
    return val_bytes + col_bytes + ptr_bytes

csr = to_csr(W_sparse, mask)
sparse_bytes = sparse_storage_bytes(csr, W_torch.shape[0])
print(f"Sparse storage: {sparse_bytes / 1e6:.2f} MB")
print(f"Effective bits per sparse element: "
      f"{sparse_bytes * 8 / csr['values'].shape[0]:.1f}")
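During inference, the sparse component is applied as a CSR sparse matrix–vector product added on top of the dense GEMV. A minimal NumPy sketch of that SpMV, over a hand-built toy CSR matrix (illustrative, not tied to the `to_csr` output above):

```python
import numpy as np

def csr_spmv(values, col_indices, row_ptr, x):
    """y[i] = sum of values[s] * x[col_indices[s]] over row i's non-zeros."""
    N = len(row_ptr) - 1
    y = np.zeros(N, dtype=x.dtype)
    for i in range(N):
        s, e = row_ptr[i], row_ptr[i + 1]
        y[i] = values[s:e] @ x[col_indices[s:e]]  # empty rows contribute 0
    return y

# Toy 3x4 sparse matrix with 3 non-zeros:
#   row 0: 2.0 at col 1;  row 1: empty;  row 2: -1.0 at col 0, 0.5 at col 3
values = np.array([2.0, -1.0, 0.5])
col_indices = np.array([1, 0, 3])
row_ptr = np.array([0, 1, 1, 3])
x = np.array([1.0, 2.0, 3.0, 4.0])
print(csr_spmv(values, col_indices, row_ptr, x))  # [4. 0. 1.]
```

Because row_ptr bounds each row's slice, rows with no outliers cost nothing beyond the pointer lookup — which is why per-row outlier counts can be very uneven without hurting the format.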

Complete SqueezeLLM Quantization

class SqueezeLLMQuantizer:
    """Complete SqueezeLLM: non-uniform + sparse outliers."""

    def __init__(self, bits=4, group_size=128, sparsity=0.005, max_kmeans_iter=50):
        self.bits = bits
        self.num_levels = 2 ** bits
        self.group_size = group_size
        self.sparsity = sparsity
        self.max_kmeans_iter = max_kmeans_iter

    def quantize_layer(self, W, sensitivities):
        """Quantize a single linear layer.

        Args:
            W: (N, K) weight matrix
            sensitivities: (K,) per-channel sensitivity

        Returns:
            codebook_indices: (N, K) uint8 indices into codebook
            codebooks: (N, num_groups, num_levels) per-group codebooks
            sparse_csr: CSR format sparse outlier matrix
        """
        N, K = W.shape
        num_groups = K // self.group_size

        # Step 1: Sparse outlier decomposition
        W_dense, W_sparse, mask, _ = sparse_outlier_decomposition(
            W, sensitivities, self.sparsity
        )

        # Step 2: Per-group non-uniform codebook
        codebook_indices = torch.zeros(N, K, dtype=torch.uint8)
        codebooks = torch.zeros(N, num_groups, self.num_levels)

        for gi in range(num_groups):
            start = gi * self.group_size
            end = start + self.group_size

            group_W = W_dense[:, start:end].numpy().flatten()
            group_sens = np.tile(sensitivities[start:end], N)

            # Weighted k-means codebook
            cb, assign = weighted_kmeans_codebook(
                group_W, group_sens,
                self.num_levels,
                max_iter=self.max_kmeans_iter,
            )

            # Reshape assignments back to (N, group_size)
            assign_2d = assign.reshape(N, self.group_size)
            codebook_indices[:, start:end] = torch.tensor(
                assign_2d, dtype=torch.uint8
            )

            # Store codebook for this group (same for all rows in practice,
            # but SqueezeLLM allows per-row codebooks for extra precision)
            for i in range(N):
                codebooks[i, gi] = torch.tensor(cb)

        # Step 3: Sparse matrix in CSR format
        sparse_csr = to_csr(W_sparse, mask)

        return codebook_indices, codebooks, sparse_csr

    def dequantize(self, codebook_indices, codebooks, sparse_csr, N, K):
        """Dequantize: look up codebook values + add sparse outliers."""
        num_groups = K // self.group_size
        W_deq = torch.zeros(N, K)

        # Dense component: lookup
        for gi in range(num_groups):
            start = gi * self.group_size
            end = start + self.group_size
            for i in range(N):
                indices = codebook_indices[i, start:end].long()
                cb = codebooks[i, gi]
                W_deq[i, start:end] = cb[indices]

        # Sparse component: add outliers
        row_ptr = sparse_csr['row_ptr']
        col_idx = sparse_csr['col_indices']
        values = sparse_csr['values']

        for i in range(N):
            start_nnz = row_ptr[i].item()
            end_nnz = row_ptr[i + 1].item()
            cols = col_idx[start_nnz:end_nnz].long()
            vals = values[start_nnz:end_nnz].float()
            W_deq[i, cols] += vals

        return W_deq

Lookup Table (LUT) Inference Kernel

The inference kernel for non-uniform quantization uses a lookup table instead of multiply-by-scale dequantization:

// LUT-based dequantization kernel
// Each codebook index (4-bit) maps to an FP16 value via LUT

__global__ void lut_dequantize_gemv(
    const uint8_t* __restrict__ indices,  // Packed 4-bit indices (N, K/2)
    const half* __restrict__ codebooks,   // LUT: (num_groups, 16) FP16 values (shared per group)
    const half* __restrict__ sparse_vals, // CSR values
    const int* __restrict__ sparse_cols,  // CSR column indices
    const int* __restrict__ sparse_row_ptr,
    const half* __restrict__ x,           // Input activation (K,)
    half* __restrict__ y,                 // Output (N,)
    int N, int K, int group_size
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    // No early return: every thread must reach the __syncthreads() calls
    // below, so out-of-range threads stay in the loop but skip the work.
    bool active = (row < N);

    // Shared LUT for the current group: 16 entries for INT4
    __shared__ half lut[16];

    float acc = 0.0f;

    // Dense component: LUT lookup + dot product
    for (int j = 0; j < K; j += 2) {
        // Reload the LUT at each group boundary
        if (j % group_size == 0) {
            __syncthreads();  // ensure prior LUT reads have finished
            int group_idx = j / group_size;
            if (threadIdx.x < 16) {
                lut[threadIdx.x] = codebooks[group_idx * 16 + threadIdx.x];
            }
            __syncthreads();
        }

        if (active) {
            // Unpack two 4-bit indices
            uint8_t packed = indices[row * (K/2) + j/2];
            int idx0 = packed & 0x0F;
            int idx1 = (packed >> 4) & 0x0F;

            // Lookup dequantized values and multiply-accumulate
            acc += __half2float(lut[idx0]) * __half2float(x[j]);
            acc += __half2float(lut[idx1]) * __half2float(x[j + 1]);
        }
    }

    if (!active) return;

    // Sparse component: add outlier contributions
    int sp_start = sparse_row_ptr[row];
    int sp_end = sparse_row_ptr[row + 1];
    for (int s = sp_start; s < sp_end; s++) {
        int col = sparse_cols[s];
        acc += __half2float(sparse_vals[s]) * __half2float(x[col]);
    }

    y[row] = __float2half(acc);
}
⚠️ LUT Kernels Are Slower Than Uniform Dequantization

Lookup table dequantization requires an indirect memory access (index into LUT) for each weight, compared to a simple multiply for uniform quantization. On GPU, this adds latency and reduces throughput. SqueezeLLM’s LUT kernel achieves approximately 60-70% of Marlin’s throughput. The quality advantage of non-uniform quantization must justify this throughput penalty.
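When writing a kernel like this, a host-side reference is useful for validation. Below is a NumPy sketch of the same packing convention the kernel assumes (two 4-bit indices per byte, low nibble first) and the LUT lookup, with a toy codebook:

```python
import numpy as np

def pack_nibbles(idx):
    """Pack an even-length array of 4-bit indices, low nibble first."""
    idx = np.asarray(idx, dtype=np.uint8)
    return (idx[0::2] | (idx[1::2] << 4)).astype(np.uint8)

def lut_dequant(packed, lut):
    """Unpack nibbles and look up codebook values."""
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    idx = np.stack([lo, hi], axis=1).reshape(-1)  # restore original order
    return lut[idx]

lut = np.linspace(-0.06, 0.06, 16).astype(np.float16)  # toy 16-entry codebook
idx = np.array([0, 15, 7, 8], dtype=np.uint8)
packed = pack_nibbles(idx)  # 2 bytes encode 4 weights
assert np.array_equal(lut_dequant(packed, lut), lut[idx])  # exact round trip
print(packed)  # [240 135]
```

The low-nibble-first convention must match between the packer and the kernel's `packed & 0x0F` / `packed >> 4` unpacking, or every other weight is silently swapped.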

Effective Bit Rate Accounting

SqueezeLLM’s total bit rate includes three components:

def squeezellm_effective_bits(
    N, K, quant_bits=4, group_size=128,
    sparsity=0.005, codebook_precision=16,
):
    """Compute effective bits per weight for SqueezeLLM."""
    total_weights = N * K

    # Dense indices: quant_bits per weight
    dense_bits = total_weights * quant_bits

    # Codebook: num_groups * num_levels * codebook_precision bits
    num_groups = K // group_size
    num_levels = 2 ** quant_bits
    # Per-row codebooks (SqueezeLLM): N * num_groups * num_levels * 16
    # Shared codebooks: num_groups * num_levels * 16
    codebook_bits_shared = num_groups * num_levels * codebook_precision
    codebook_bits_per_row = N * num_groups * num_levels * codebook_precision

    # Sparse: nnz * (16 + 16) bits (FP16 value + INT16 column) + row_ptr
    nnz = int(total_weights * sparsity)
    sparse_bits = nnz * (16 + 16) + (N + 1) * 32

    # Total (with shared codebook)
    total_bits = dense_bits + codebook_bits_shared + sparse_bits
    eff_bits = total_bits / total_weights

    return {
        'dense_bits_per_weight': quant_bits,
        'codebook_overhead': codebook_bits_shared / total_weights,
        'sparse_overhead': sparse_bits / total_weights,
        'effective_bits': eff_bits,
    }

result = squeezellm_effective_bits(4096, 4096, quant_bits=4, sparsity=0.005)
print(f"Dense: {result['dense_bits_per_weight']:.2f} bits/weight")
print(f"Codebook overhead: {result['codebook_overhead']:.4f} bits/weight")
print(f"Sparse overhead: {result['sparse_overhead']:.4f} bits/weight")
print(f"Effective total: {result['effective_bits']:.2f} bits/weight")
Expected output:

Dense: 4.00 bits/weight
Codebook overhead: 0.0005 bits/weight
Sparse overhead: 0.1678 bits/weight
Effective total: 4.17 bits/weight
📊 SqueezeLLM Effective Bits vs Uniform Quantization

| Method | Dense Bits | Overhead | Effective Bits | Perplexity (7B) |
|---|---|---|---|---|
| Uniform INT4 g128 (RTN) | 4.00 | 0.12 | 4.12 | 5.68 |
| Uniform INT4 g128 + AWQ | 4.00 | 0.12 | 4.12 | 5.51 |
| SqueezeLLM 4-bit | 4.00 | 0.16 | 4.16 | 5.48 |
| SqueezeLLM 3-bit | 3.00 | 0.16 | 3.16 | 6.22 |
| Uniform INT3 g128 + GPTQ | 3.00 | 0.18 | 3.18 | 6.98 |

Note: SqueezeLLM's advantage is most pronounced at lower bit rates. At 3-bit, SqueezeLLM beats GPTQ INT3 by 0.76 ppl. At 4-bit, the advantage over AWQ is marginal (0.03 ppl) and comes at a throughput cost.

Perplexity vs Effective Bits (Llama-2 7B)

(Chart summary, WikiText-2 perplexity: GPTQ INT3 at 3.18 effective bits scores 6.98; SqueezeLLM 3-bit at 3.16 scores 6.22, 0.76 ppl better; RTN INT4 at 4.12 scores 5.68; AWQ INT4 at 4.12 scores 5.51; SqueezeLLM 4-bit at 4.16 scores 5.48; the FP16 baseline is 5.47.)

When Non-Uniform Quantization is the Right Choice

Non-uniform quantization adds kernel complexity and reduces inference throughput. It is the right choice when:

  1. Sub-4-bit quantization: At 3-bit or 2-bit, the gap between uniform and non-uniform is large. The limited number of levels makes optimal placement critical.

  2. CPU inference where LUT is cheap: On CPU, the LUT lookup is a simple array access, which is fast. The throughput penalty is smaller than on GPU.

  3. Quality is paramount: If 0.03-0.05 ppl matters (e.g., medical or legal applications), non-uniform quantization provides a meaningful improvement.

  4. Mixed precision budgets: Non-uniform quantization can be combined with variable bit allocation: 2-bit for insensitive layers, 4-bit for sensitive layers, with codebooks optimized per layer.

def should_use_nonuniform(
    target_bits,
    hardware,
    quality_requirement,
):
    """Decision: uniform vs non-uniform quantization."""
    if target_bits <= 3:
        return True, "At sub-4-bit, non-uniform is significantly better"

    if hardware == 'cpu':
        if quality_requirement == 'maximum':
            return True, "LUT is cheap on CPU, quality benefit is free"
        return False, "Uniform is simpler and nearly as good"

    if hardware == 'gpu':
        # Sub-4-bit already returned True above, so this branch is 4-bit+
        return False, "Marlin/ExLlama uniform kernels are faster"

    return False, "Default to uniform for simplicity"

Other Non-Uniform Approaches

QuIP# (Quantization with Incoherence Processing)

QuIP# uses random orthogonal rotations (similar to QuaRot) to make weight distributions more uniform (incoherent), then applies vector quantization using E8 lattice codebooks:

# QuIP# key idea: E8 lattice quantization
# The E8 lattice is a mathematically optimal 8-dimensional packing
# that provides better distortion-rate than k-means in high dimensions

# After incoherence processing (Hadamard rotation):
# Group weights into 8-dimensional vectors
# Quantize each vector to the nearest E8 lattice point
# Store the lattice index (compact encoding)

# E8 lattice at 2-bit effective rate achieves better quality than
# k-means at 2-bit because E8 is the densest packing in 8D

# QuIP# achieves 2-bit quantization with < 1 ppl degradation on
# Llama-2 7B -- significantly better than any scalar method
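The effect of incoherence processing can be illustrated with a random orthogonal rotation (an illustrative stand-in for QuIP#'s structured Hadamard transforms): rotating a heavy-tailed vector spreads outlier mass across dimensions, pulling the empirical excess kurtosis toward the Gaussian value of 0.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1024
w = rng.standard_t(df=5, size=n)  # heavy-tailed "weights" (excess kurtosis > 0)

# Random orthogonal matrix via QR decomposition of a Gaussian matrix
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
w_rot = Q @ w

def excess_kurtosis(v):
    v = v - v.mean()
    return np.mean(v**4) / np.mean(v**2) ** 2 - 3.0

print(f"excess kurtosis before rotation: {excess_kurtosis(w):.2f}")   # well above 0
print(f"excess kurtosis after rotation:  {excess_kurtosis(w_rot):.2f}")  # near 0
```

QuIP# uses structured (Hadamard) rotations instead of a dense random Q because they apply in O(n log n) at inference time, but the statistical effect on the weight distribution is the same.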

AQLM (Additive Quantization of Language Models)

AQLM uses additive (multi-codebook) quantization: each weight group is represented as the sum of entries from multiple small codebooks:

# AQLM: w_group = codebook_1[idx_1] + codebook_2[idx_2]
# With M codebooks of size C each, this represents C^M possible values
# using only M * log2(C) bits

# Example: M=2 codebooks of C=256 entries each
# Represents 256^2 = 65,536 possible values
# Using only 2 * 8 = 16 bits
# But the values are learned, not uniformly spaced
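A toy sketch of the additive idea via greedy residual quantization (AQLM learns its codebooks jointly over weight groups with beam search; here each stage is plain 1-D k-means, purely for illustration): a second codebook fit to the residual of the first sharply reduces reconstruction error.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(scale=0.02, size=50_000)

def kmeans_1d(x, k, iters=50):
    """Lloyd's k-means on scalars with quantile initialization."""
    cb = np.percentile(x, np.linspace(0, 100, k + 2)[1:-1])
    for _ in range(iters):
        a = np.argmin(np.abs(x[:, None] - cb[None, :]), axis=1)
        cb = np.array([x[a == k_].mean() if np.any(a == k_) else cb[k_]
                       for k_ in range(k)])
    return cb

# Stage 1: quantize weights with codebook 1 (16 entries, 4 bits)
cb1 = kmeans_1d(w, 16)
w1 = cb1[np.argmin(np.abs(w[:, None] - cb1[None, :]), axis=1)]

# Stage 2: quantize the residual with codebook 2 (another 4 bits)
r = w - w1
cb2 = kmeans_1d(r, 16)
r2 = cb2[np.argmin(np.abs(r[:, None] - cb2[None, :]), axis=1)]

mse1 = np.mean((w - w1) ** 2)
mse2 = np.mean((w - (w1 + r2)) ** 2)  # additive reconstruction: cb1 + cb2
print(f"1 codebook  (4 bits): MSE {mse1:.2e}")
print(f"2 codebooks (8 bits, additive): MSE {mse2:.2e}")
```

Greedy residual stages are suboptimal compared to AQLM's joint training, but they capture why additive codebooks scale: each stage multiplies the number of representable values while adding only log2(C) bits.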