Part 21 of 30 in the Quantization Masterclass series

Quantization reduces the number of bits per weight. Sparsity reduces the number of weights altogether. When combined, the two techniques compound: a 2:4 sparse INT8 model uses roughly 3-4x less memory than a dense FP16 model (4x on the values alone; sparsity metadata and quantization scales claw a little back) and can execute up to 2-3x faster on NVIDIA Ampere and Hopper GPUs, because the sparse tensor cores process only the non-zero elements at double the throughput of their dense counterparts.
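The compression arithmetic per parameter, including the 2:4 position metadata (2 bits per surviving value), works out as follows:

```python
# Bytes per parameter for each storage format
dense_fp16 = 2.0
dense_int8 = 1.0
# 2:4 sparse INT8: half the values survive at 1 byte each,
# plus 2 bits of position metadata per surviving value
sparse_int8 = 0.5 * 1.0 + 0.5 * (2 / 8)

print(sparse_int8)               # 0.625
print(dense_fp16 / sparse_int8)  # 3.2
```

So the honest end-to-end weight compression versus dense FP16 is about 3.2x; the "4x" figure holds only if metadata is ignored.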

This is not a free lunch. The accuracy cost of simultaneous pruning and quantization is higher than either technique alone, and the pruning must follow NVIDIA’s rigid 2:4 structural constraint: exactly 2 out of every 4 contiguous weights must be zero. The algorithms that produce these patterns (SparseGPT, Wanda, NVIDIA ASP) differ significantly in cost, quality, and ease of integration.

This post covers the full pipeline: the hardware execution model, the pruning algorithms, the joint sparsity+quantization calibration, and production benchmarks.

The 2:4 Structured Sparsity Hardware Model

What 2:4 Means at the Hardware Level

NVIDIA’s sparse tensor cores (Ampere SM 8.0 and later) support a specific sparsity pattern: in every group of 4 contiguous elements along the reduction dimension of a matrix multiply, exactly 2 must be zero. The hardware stores only the 2 non-zero values, plus a 2-bit index per non-zero value that records its position within the group.

// Dense representation (4 elements):
// [a, 0, b, 0]  -> stored as: values=[a, b], positions=[0, 2]
// [0, a, 0, b]  -> stored as: values=[a, b], positions=[1, 3]
// [a, b, 0, 0]  -> stored as: values=[a, b], positions=[0, 1]
//
// There are C(4,2) = 6 valid patterns per group of 4.
// Each non-zero value carries a 2-bit position index (0-3),
// so the metadata for a group of 4 is 2 x 2 = 4 bits.
//
// Memory layout for a sparse matrix (K x N):
// - Compressed values: (K/2) x N elements (50% of original)
// - Metadata: (K/2) x N x 2 bits = (K/8) x N bytes
// - Total: ~50% values + ~6.25% metadata = ~56% of dense FP16 size
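To make the value/index split concrete, here is a NumPy sketch of compressing one 2:4 sparse row (`compress_24_row` is an illustrative helper; the real hardware format interleaves metadata differently):

```python
import numpy as np

def compress_24_row(row):
    """Compress a 1-D row (length divisible by 4) that is already 2:4 sparse.

    Returns (values, indices): per group of 4, the 2 non-zero values and
    their 2-bit positions (0-3), mirroring the hardware metadata.
    """
    values, indices = [], []
    for g in range(0, len(row), 4):
        group = row[g:g + 4]
        nz = np.flatnonzero(group)
        assert len(nz) == 2, "row is not 2:4 sparse"
        values.extend(group[nz])
        indices.extend(nz)
    return np.array(values), np.array(indices)

row = np.array([1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 4.0])
vals, idx = compress_24_row(row)
print(vals)  # [1. 2. 3. 4.]
print(idx)   # [0 2 1 3]
```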

The sparse tensor core executes a 16x8x32 (M x N x K) matrix multiply-accumulate instruction for INT8, processing the 32 sparse K elements as 16 stored non-zero values. For FP16, the shape is 16x8x16, with the 16 sparse K elements stored as 8 non-zero values. In both cases, throughput doubles compared to the dense tensor core operation because the hardware performs the same useful math while reading half the data.

// Peak tensor core throughput (whole GPU, from NVIDIA datasheets):
//
// Ampere (SM 8.0 - A100 SXM):
//   Dense INT8:  624 TOPS
//   Sparse INT8: 1248 TOPS   (2x)
//   Dense FP16:  312 TFLOPS
//   Sparse FP16: 624 TFLOPS  (2x)
//
// Hopper (SM 9.0 - H100 SXM):
//   Dense INT8:  1979 TOPS
//   Sparse INT8: 3958 TOPS   (2x)
//   Dense FP16:  990 TFLOPS
//   Sparse FP16: 1979 TFLOPS (2x)
//
// The 2x is theoretical peak. Real speedup depends on:
//   1. Whether the kernel is compute-bound (not memory-bound)
//   2. Whether the sparse metadata decoding overlaps with computation
//   3. The matrix dimensions (small matrices do not saturate tensor cores)
⚠️ Warning

The 2:4 sparsity constraint is inflexible. You cannot use 3:4 or 1:4 patterns. You cannot mix sparse and dense rows within the same GEMM call through the sparse tensor core path. The entire weight matrix must be uniformly 2:4 sparse along the K dimension. If even a single group of 4 violates the pattern, the hardware path cannot be used.
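Given that constraint, it is worth validating a weight matrix in software before attempting the sparse path. A small NumPy checker (an illustrative helper, not part of the cuSPARSELt API) might look like:

```python
import numpy as np

def is_24_sparse(W):
    """True if every contiguous group of 4 along the last (K) axis has
    at most 2 non-zeros, i.e. satisfies the 2:4 constraint."""
    K = W.shape[-1]
    if K % 4 != 0:
        return False
    groups = W.reshape(-1, 4)
    return bool(((groups != 0).sum(axis=1) <= 2).all())

W_ok  = np.array([[1.0, 0.0, 2.0, 0.0, 0.0, 3.0, 0.0, 4.0]])
W_bad = np.array([[1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0]])
print(is_24_sparse(W_ok))   # True
print(is_24_sparse(W_bad))  # False
```

Groups with fewer than 2 non-zeros are fine: the hardware simply stores zeros in the value slots.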

Storage Format

The compressed sparse format stores the non-zero values contiguously and uses a separate metadata array to record the positions:

// cuSPARSELt storage format for 2:4 sparse matrix
// Original dense matrix A: shape [M, K] in row-major
// Compressed matrix A_sparse: shape [M, K/2]
// Metadata: shape [M, K/4] with 2 bits per element

// Example: compressing a 4x8 matrix
// Dense:
// [1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0]
// [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]
// [1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 4.0]
// [0.0, 0.0, 5.0, 6.0, 7.0, 8.0, 0.0, 0.0]
//
// Compressed values (M=4, K/2=4):
// [1.0, 2.0, 3.0, 4.0]
// [5.0, 6.0, 7.0, 8.0]
// [1.0, 2.0, 3.0, 4.0]
// [5.0, 6.0, 7.0, 8.0]
//
// Metadata encodes the position of each non-zero within its group of 4
// (one 2-bit index per non-zero value):
// Row 0: positions (0,2), (0,2)
// Row 1: positions (1,3), (1,3)
// Row 2: positions (0,1), (2,3)
// Row 3: positions (2,3), (0,1)

struct SparseMatrix {
    half* compressed_values;  // [M, K/2]
    uint8_t* metadata;        // [M, K/8] bytes (2 bits per non-zero value, packed)
    int M, K;                 // Original dimensions
};
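The 2-bit position indices can be packed into the metadata bytes of the struct above roughly like this (an illustrative NumPy sketch; cuSPARSELt's actual metadata ordering is interleaved and architecture-specific):

```python
import numpy as np

def pack_metadata(indices):
    """Pack 2-bit position indices (values 0-3) into uint8 bytes,
    4 indices per byte, lowest bits first."""
    assert len(indices) % 4 == 0
    packed = np.zeros(len(indices) // 4, dtype=np.uint8)
    for i, pos in enumerate(indices):
        packed[i // 4] |= (pos & 0b11) << (2 * (i % 4))
    return packed

# Row with non-zeros at positions (0,2) and (1,3) in its two groups:
meta = pack_metadata(np.array([0, 2, 1, 3]))
print(meta)  # [216] -> 0b11011000
```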

Pruning Algorithms for 2:4 Sparsity

NVIDIA ASP (Automatic Sparsity)

NVIDIA’s baseline approach: train the model to convergence, prune to 2:4 using magnitude-based selection, then fine-tune for a few epochs to recover accuracy.

# NVIDIA ASP (Automatic SParsity) - Magnitude-based pruning
# From nvidia/apex library

import torch
from apex.contrib.sparsity import ASP

model = load_pretrained_model("llama-7b")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Step 1: Compute the 2:4 mask based on weight magnitudes
# For each group of 4 weights, keep the 2 with largest absolute value
ASP.prune_trained_model(model, optimizer)

# Step 2: Fine-tune with the mask applied
# The mask is fixed - pruned weights stay zero during training
for epoch in range(fine_tune_epochs):
    for batch in dataloader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # ASP automatically re-applies the mask after each optimizer step
        # to ensure pruned weights remain zero

# ASP internals for mask computation:
def compute_asp_mask(weight, N=2, M=4):
    """For each group of M weights along in_features, keep the top-N by magnitude."""
    # weight shape: [out_features, in_features]
    # Groups run along in_features (the K dimension of the GEMM),
    # which must be divisible by M
    reshaped = weight.view(-1, M)               # one row per group of M
    _, indices = reshaped.abs().topk(N, dim=1)  # top-N positions per group
    mask_flat = torch.zeros_like(reshaped)
    mask_flat.scatter_(1, indices, 1.0)         # 1.0 at kept positions
    return mask_flat.view_as(weight)

The problem with ASP for LLMs: it requires fine-tuning. For a 70B model, even a few epochs of fine-tuning costs thousands of GPU-hours and requires the full training dataset.

SparseGPT: One-Shot Pruning Without Retraining

SparseGPT applies the principles of the Optimal Brain Surgeon framework: when you prune a weight, update the remaining weights to compensate for the loss, using the inverse Hessian of the layer-wise reconstruction error (for the squared error ||WX - W_pruned X||^2, the Hessian is proportional to XXᵀ over the calibration inputs).

# SparseGPT algorithm - prunes to 2:4 sparsity without fine-tuning
# Processes one layer at a time, using calibration data to compute Hessians

def sparsegpt_prune_layer(W, H_inv, sparsity_pattern="2:4"):
    """
    W: weight matrix [out_features, in_features]
    H_inv: inverse Hessian of the layer loss w.r.t. weights [in_features, in_features]

    The key insight: when pruning weight W[i,j], the optimal update to
    remaining weights W[i, S] (where S is the set of surviving weights) is:

    delta_W[i, S] = -W[i,j] * H_inv[S, j] / H_inv[j, j]

    This minimizes the squared error ||WX - W_pruned X||^2 where X is
    the input activations.
    """
    out_features, in_features = W.shape
    W_sparse = W.clone()

    # Process columns in groups of 4 (for 2:4 pattern)
    group_size = 4
    n_prune = 2  # prune 2 out of every 4

    for col_start in range(0, in_features, group_size):
        col_end = min(col_start + group_size, in_features)
        group_cols = list(range(col_start, col_end))

        # For each row, determine which 2 of 4 to prune
        for row in range(out_features):
            group_weights = W_sparse[row, group_cols]

            # Compute pruning error for each possible mask
            # Error for pruning column j: (W[row,j])^2 / (2 * H_inv[j,j])
            errors = torch.zeros(len(group_cols))
            for idx, col in enumerate(group_cols):
                errors[idx] = (W_sparse[row, col] ** 2) / (2 * H_inv[col, col])

            # Prune the 2 with the lowest cost (keep the 2 with the highest)
            _, prune_indices = errors.topk(n_prune, largest=False)

            # Apply weight updates for pruned columns
            for prune_idx in prune_indices:
                prune_col = group_cols[prune_idx]
                w_val = W_sparse[row, prune_col]

                # Update remaining weights in this row
                for surv_idx in range(len(group_cols)):
                    if surv_idx not in prune_indices:
                        surv_col = group_cols[surv_idx]
                        W_sparse[row, surv_col] -= (
                            w_val * H_inv[surv_col, prune_col] / H_inv[prune_col, prune_col]
                        )

                # Zero out the pruned weight
                W_sparse[row, prune_col] = 0.0

    return W_sparse
ℹ️ Note

SparseGPT requires computing the Hessian (or a factored approximation) per layer, which costs O(d_in^2) memory and O(d_in^2 · n_calib) compute. For a layer with d_in = 8192 (Llama-70B), the Hessian in FP32 is 8192^2 × 4 bytes ≈ 256 MB. This is manageable on a single GPU. The full pruning of Llama-70B takes roughly 4 hours on a single A100.
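Accumulating that Hessian from calibration batches is straightforward. A minimal NumPy sketch (the `accumulate_hessian` helper and the damping choice are illustrative, not SparseGPT's exact recipe):

```python
import numpy as np

def accumulate_hessian(batches, d_in, damp=0.01):
    """H ~ 2 * sum_b X_b^T X_b over calibration batches, with diagonal
    damping (a fraction of the mean diagonal) so the inverse is stable."""
    H = np.zeros((d_in, d_in))
    n = 0
    for X in batches:          # X: [n_tokens, d_in] layer inputs
        H += 2.0 * X.T @ X
        n += X.shape[0]
    H /= n
    H += damp * np.mean(np.diag(H)) * np.eye(d_in)
    return H

rng = np.random.default_rng(0)
batches = [rng.standard_normal((64, 8)) for _ in range(4)]
H = accumulate_hessian(batches, d_in=8)
H_inv = np.linalg.inv(H)   # feeds the SparseGPT update rule
print(H.shape)  # (8, 8)
```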

Wanda: Pruning Without Any Weight Update

Wanda (Weights AND Activations) observes that the importance of a weight depends not just on its magnitude but also on the magnitude of the activations it multiplies. The pruning metric is:

\text{score}(W_{ij}) = |W_{ij}| \cdot \|X_j\|_2

where \|X_j\|_2 is the L2 norm of the j-th input feature across the calibration set.

# Wanda pruning - no weight updates, no Hessian computation
def wanda_prune_layer(W, X_norms, N=2, M=4):
    """
    W: weight matrix [out_features, in_features]
    X_norms: L2 norms of input activations per feature [in_features]
             computed as sqrt(sum(X[:, j]^2)) over calibration samples
    """
    # Compute importance scores
    # score[i,j] = |W[i,j]| * ||X_j||_2
    scores = W.abs() * X_norms.unsqueeze(0)  # [out_features, in_features]

    # For each group of M, keep top-N by score
    out_features, in_features = W.shape
    mask = torch.zeros_like(W)

    for col_start in range(0, in_features, M):
        col_end = min(col_start + M, in_features)
        group_scores = scores[:, col_start:col_end]
        _, topk = group_scores.topk(N, dim=1)

        for k in range(N):
            rows = torch.arange(out_features)
            mask[rows, col_start + topk[:, k]] = 1.0

    W_pruned = W * mask
    return W_pruned, mask

# Cost comparison:
# ASP:       O(training_cost * fine_tune_epochs) - most expensive
# SparseGPT: O(n_layers * d_in^2 * n_calib) - hours on 1 GPU
# Wanda:     O(n_layers * d_out * d_in) - minutes on 1 GPU
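To see why the activation term matters, here is a toy group of 4 where magnitude and Wanda disagree (NumPy sketch; the values are invented for illustration):

```python
import numpy as np

W = np.array([[0.5, 0.4, 0.3, 0.2]])        # one group of 4 weights
x_norms = np.array([0.1, 1.0, 1.0, 1.0])    # feature 0 is rarely active

mag_keep = np.argsort(np.abs(W[0]))[-2:]              # magnitude keeps 0, 1
wanda_keep = np.argsort(np.abs(W[0]) * x_norms)[-2:]  # wanda keeps 1, 2

print(sorted(mag_keep.tolist()))    # [0, 1]
print(sorted(wanda_keep.tolist()))  # [1, 2]
```

The largest weight multiplies a nearly-dead feature, so Wanda prunes it while magnitude pruning keeps it.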
📊 Pruning Algorithm Comparison (Llama-7B, 2:4 Sparsity)

| Algorithm | Calibration Time | Perplexity (Dense: 5.68) | Requires Fine-tuning |
|---|---|---|---|
| Magnitude (ASP, no finetune) | 0 min | 10.42 (+4.74) | No (but recommended) |
| SparseGPT | ~60 min (1x A100) | 6.56 (+0.88) | No |
| Wanda | ~2 min (1x A100) | 6.72 (+1.04) | No |
| ASP + 2 epoch finetune | ~100 GPU-hrs | 5.95 (+0.27) | Yes |

Joint Sparsity + Quantization Pipeline

The critical question: do you prune first and then quantize, or quantize first and then prune, or do both simultaneously?

Order Matters

# Three possible orderings:

# Option 1: Prune -> Quantize (most common)
# Prune weights to 2:4, then quantize non-zero values to INT8
# Pro: pruning operates on full-precision weights (better decisions)
# Con: quantization error added on top of pruning error
W_pruned = sparsegpt_prune(W_fp16, H_inv)          # 2:4 sparse FP16
W_final = quantize_per_channel_int8(W_pruned)        # 2:4 sparse INT8

# Option 2: Quantize -> Prune (rare, usually worse)
# Quantize to INT8 first, then prune
# Pro: pruning decisions account for quantization noise
# Con: quantization of dense matrix is suboptimal (doesn't know what gets pruned)
W_quant = quantize_per_channel_int8(W_fp16)          # dense INT8
W_final = magnitude_prune_24(W_quant)                # 2:4 sparse INT8

# Option 3: Joint optimization (best quality, most expensive)
# Simultaneously determine pruning mask and quantization parameters
# Used by SparseGPT with quantization extension
W_final = sparsegpt_prune_and_quantize(W_fp16, H_inv)  # 2:4 sparse INT8
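Option 1 can be sketched end to end in NumPy with simple magnitude-based 2:4 pruning and symmetric per-channel INT8 fake quantization (the helper names are illustrative stand-ins for the sparsegpt_prune / quantize_per_channel_int8 calls above):

```python
import numpy as np

def magnitude_prune_24(W):
    """Keep the 2 largest-magnitude weights in every group of 4 along K."""
    out_f, in_f = W.shape
    groups = np.abs(W).reshape(out_f, in_f // 4, 4)
    order = np.argsort(groups, axis=-1)                    # ascending
    mask = np.ones_like(groups)
    np.put_along_axis(mask, order[..., :2], 0.0, axis=-1)  # zero the 2 smallest
    return W * mask.reshape(out_f, in_f)

def fake_quant_int8_per_channel(W):
    """Symmetric per-output-channel INT8 quantize-dequantize."""
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    scale[scale == 0] = 1.0
    return np.clip(np.round(W / scale), -128, 127) * scale

rng = np.random.default_rng(0)
W = rng.standard_normal((4, 8)).astype(np.float32)
W_final = fake_quant_int8_per_channel(magnitude_prune_24(W))

# Every group of 4 still has at most 2 non-zeros after quantization:
print(((W_final.reshape(4, 2, 4) != 0).sum(-1) <= 2).all())  # True
```

Note the order guarantee: quantization can only round more values to zero, so the 2:4 constraint established by pruning is never violated by the quantization step.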

SparseGPT Joint Pruning + Quantization

SparseGPT can be extended to perform joint sparsity and quantization in a single pass. The algorithm alternates between pruning decisions and quantization within the Hessian-based update framework:

def sparsegpt_joint(W, H_inv, group_size=128, bits=8):
    """
    Joint 2:4 pruning + group quantization.

    Key idea: process columns left-to-right in blocks.
    For each block:
      1. Determine 2:4 pruning mask (which 2 of 4 to keep)
      2. Quantize the surviving weights
      3. Compute the quantization + pruning error
      4. Update remaining columns using H_inv to compensate
    """
    out_features, in_features = W.shape
    W_sparse_quant = W.clone()

    block_size = 128  # Process 128 columns at a time

    for block_start in range(0, in_features, block_size):
        block_end = min(block_start + block_size, in_features)
        block_cols = list(range(block_start, block_end))

        # Extract the block
        W_block = W_sparse_quant[:, block_start:block_end].clone()
        H_block = H_inv[block_start:block_end, block_start:block_end]

        # Step 1: Determine 2:4 pruning mask for this block
        mask = compute_24_mask(W_block, H_block)

        # Step 2: Apply pruning
        W_block *= mask

        # Step 3: Quantize non-zero values (group quantization)
        for g_start in range(0, block_end - block_start, group_size):
            g_end = min(g_start + group_size, block_end - block_start)
            group = W_block[:, g_start:g_end]
            nonzero_mask = mask[:, g_start:g_end]

            # Compute scale from non-zero values only
            max_val = group[nonzero_mask.bool()].abs().max()
            scale = max_val / (2**(bits-1) - 1)

            # Quantize and dequantize
            group_quant = torch.round(group / scale).clamp(
                -(2**(bits-1)), 2**(bits-1) - 1
            ) * scale
            W_block[:, g_start:g_end] = group_quant * nonzero_mask

        # Step 4: Compute error and update remaining columns
        error = W_sparse_quant[:, block_start:block_end] - W_block
        W_sparse_quant[:, block_start:block_end] = W_block

        # Propagate error to remaining columns
        if block_end < in_features:
            W_sparse_quant[:, block_end:] -= (
                error @ H_inv[block_start:block_end, block_end:]
            )

    return W_sparse_quant
Performance

Joint optimization consistently produces 0.1-0.3 perplexity points better results than sequential prune-then-quantize, because the Hessian-based weight update compensates for both the pruning error and the quantization error simultaneously. The additional compute cost is negligible since the Hessian is already computed.

Kernel-Level Implementation with cuSPARSELt

NVIDIA’s cuSPARSELt library provides the API for executing 2:4 sparse GEMMs on sparse tensor cores.

#include <cusparseLt.h>
#include <cuda_runtime.h>

// cuSPARSELt workflow for sparse GEMM:
// C = alpha * A_sparse * B + beta * C
// where A is 2:4 structured sparse

void sparse_gemm_int8(
    int M, int N, int K,
    const int8_t* A_dense,  // Dense weight matrix [M, K]
    const int8_t* B,         // Activation matrix [K, N]
    int32_t* C,              // Output [M, N]
    float alpha, float beta
) {
    cusparseLtHandle_t handle;
    cusparseLtInit(&handle);

    // 1. Create matrix descriptors
    cusparseLtMatDescriptor_t matA, matB, matC;

    // A is the structured sparse matrix
    cusparseLtStructuredDescriptorInit(
        &handle, &matA,
        M, K,           // dimensions
        K,              // leading dimension
        8,              // alignment (bytes)
        CUDA_R_8I,      // data type: INT8
        CUSPARSE_ORDER_ROW,
        CUSPARSELT_SPARSITY_50_PERCENT  // 2:4 sparsity
    );

    // B and C are dense
    cusparseLtDenseDescriptorInit(
        &handle, &matB, K, N, N, 8, CUDA_R_8I, CUSPARSE_ORDER_ROW
    );
    cusparseLtDenseDescriptorInit(
        &handle, &matC, M, N, N, 8, CUDA_R_32I, CUSPARSE_ORDER_ROW
    );

    // 2. Create the matmul descriptor
    cusparseLtMatmulDescriptor_t matmul;
    cusparseLtMatmulDescriptorInit(
        &handle, &matmul,
        CUSPARSE_OPERATION_NON_TRANSPOSE,  // opA
        CUSPARSE_OPERATION_NON_TRANSPOSE,  // opB
        &matA, &matB, &matC, &matC,
        CUSPARSE_COMPUTE_32I               // compute type
    );

    // 3. Prune the weight matrix to 2:4 pattern
    // cuSPARSELt can do this, but we typically have pre-pruned weights
    int8_t* A_pruned;
    cudaMalloc(&A_pruned, M * K * sizeof(int8_t));

    // Prune with magnitude-based selection
    cusparseLtSpMMAPrune(
        &handle, &matmul,
        A_dense, A_pruned,
        CUSPARSELT_PRUNE_SPMMA_STRIP,  // 2:4 strip pruning
        nullptr  // stream
    );

    // Verify the pruning pattern is valid (the result must land in
    // device-accessible memory)
    int* d_valid;
    cudaMalloc(&d_valid, sizeof(int));
    cusparseLtSpMMAPruneCheck(
        &handle, &matmul, A_pruned, d_valid, nullptr
    );
    // *d_valid == 0 if all groups satisfy the 2:4 constraint

    // 4. Select an algorithm and build the execution plan
    //    (cuSPARSELt needs the plan before compression)
    cusparseLtMatmulAlgSelection_t alg_sel;
    cusparseLtMatmulAlgSelectionInit(
        &handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT
    );

    cusparseLtMatmulPlan_t plan;
    cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel);

    // 5. Compress the pruned matrix (50% values + metadata)
    size_t compressed_size, compressed_buffer_size;
    cusparseLtSpMMACompressedSize(
        &handle, &plan, &compressed_size, &compressed_buffer_size
    );

    int8_t *A_compressed, *compress_buffer;
    cudaMalloc(&A_compressed, compressed_size);
    cudaMalloc(&compress_buffer, compressed_buffer_size);

    cusparseLtSpMMACompress(
        &handle, &plan,
        A_pruned, A_compressed, compress_buffer,
        nullptr  // stream
    );

    // 6. Allocate workspace and execute the sparse GEMM
    size_t workspace_size;
    cusparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size);
    void* workspace;
    cudaMalloc(&workspace, workspace_size);

    // Execute: C = alpha * A_compressed * B + beta * C
    cusparseLtMatmul(
        &handle, &plan,
        &alpha, A_compressed, B, &beta, C, C,
        workspace, nullptr, 0
    );

    // Cleanup
    cusparseLtMatmulPlanDestroy(&plan);
    cudaFree(A_compressed);
    cudaFree(compress_buffer);
    cudaFree(A_pruned);
    cudaFree(d_valid);
    cudaFree(workspace);
    cusparseLtDestroy(&handle);
}

Sparse GEMM Performance Model

The theoretical speedup from 2:4 sparsity in a GEMM is:

\text{Speedup}_{\text{compute}} = 2\times

But real speedup depends on whether the GEMM is compute-bound or memory-bound. For memory-bound cases (small batch sizes in LLM decode), the 50% reduction in weight data is the primary benefit:

# Performance model for sparse vs dense GEMM
def sparse_gemm_speedup(M, N, K, dtype_bytes):
    """
    M: batch * seq_len (activation rows)
    N: output features
    K: input features
    """
    # Compute FLOPs
    dense_flops = 2 * M * N * K
    sparse_flops = 2 * M * N * (K // 2)  # Only non-zero elements

    # Memory traffic
    # Dense: read A[M,K] + W[K,N] + write C[M,N]
    dense_bytes = (M * K + K * N + M * N) * dtype_bytes

    # Sparse: read A[M,K] + W_compressed[K/2,N] + metadata + write C[M,N]
    metadata_bytes = (K // 2) * N // 4  # 2 bits per non-zero value, packed
    sparse_bytes = (M * K + (K // 2) * N + M * N) * dtype_bytes + metadata_bytes

    # Arithmetic intensity (FLOPs per byte)
    dense_ai = dense_flops / dense_bytes
    sparse_ai = sparse_flops / sparse_bytes

    # H100 SXM roofline parameters
    peak_dense_tflops = 990    # FP16 dense
    peak_sparse_tflops = 1979  # FP16 sparse (2x)
    mem_bw_tb = 3.35           # TB/s HBM3

    # Effective throughput (min of compute and memory bound)
    dense_tput = min(
        peak_dense_tflops * 1e12,
        mem_bw_tb * 1e12 * dense_ai
    )
    sparse_tput = min(
        peak_sparse_tflops * 1e12,
        mem_bw_tb * 1e12 * sparse_ai
    )

    dense_time = dense_flops / dense_tput
    sparse_time = sparse_flops / sparse_tput

    return dense_time / sparse_time

Sparse INT8 vs Dense FP16 Speedup by Batch Size (Llama-7B, H100)

| Batch Size | Regime | Speedup |
|---|---|---|
| BS=1 (decode) | Memory-bound: weight size helps | 1.45x |
| BS=8 | Transitioning | 1.78x |
| BS=32 | Compute benefit kicks in | 2.15x |
| BS=128 (prefill) | Near 2x compute + smaller weights | 2.42x |
| BS=512 | Compute-dominant | 2.55x |

Accuracy Impact and Recovery

Layer Sensitivity Analysis

Not all layers tolerate sparsity equally. Attention projection layers (Q, K, V, O) tend to be more sensitive than FFN layers. The first and last layers are the most sensitive.

# Layer sensitivity analysis for 2:4 sparsity
def measure_layer_sensitivity(model, calibration_data, metric="perplexity"):
    """
    Prune one layer at a time to 2:4 sparsity and measure quality impact.
    This identifies which layers need special treatment.
    """
    baseline = evaluate(model, calibration_data, metric)
    sensitivities = {}

    for name, param in model.named_parameters():
        if "weight" not in name or param.dim() != 2:
            continue

        # Save original weight
        original = param.data.clone()

        # Apply 2:4 pruning to this layer only
        mask = compute_wanda_mask(param.data, activation_norms[name])
        param.data *= mask

        # Measure degradation
        score = evaluate(model, calibration_data, metric)
        sensitivities[name] = score - baseline

        # Restore
        param.data = original

    return sensitivities

# Typical sensitivity ranking (Llama-7B, perplexity delta):
# Layer 0 (first decoder block):  +0.35  (most sensitive)
# Layer 31 (final layer):      +0.28
# Attention Q/K projections:    +0.08-0.15 average
# Attention V/O projections:    +0.05-0.10 average
# FFN gate/up projections:      +0.03-0.07 average
# FFN down projections:         +0.04-0.08 average

Selective Sparsity: Skip Sensitive Layers

A practical strategy: keep the most sensitive layers dense and apply 2:4 sparsity only to the rest.

def selective_sparse_quantize(model, sensitivity_threshold=0.15):
    """
    Apply 2:4 sparse INT8 to tolerant layers, dense INT8 to sensitive layers.
    """
    sensitivities = measure_layer_sensitivity(model, calib_data)

    sparse_layers = []
    dense_layers = []

    for name, sensitivity in sensitivities.items():
        if sensitivity > sensitivity_threshold:
            dense_layers.append(name)
            # Apply only INT8 quantization (no sparsity)
            quantize_int8(model, name)
        else:
            sparse_layers.append(name)
            # Apply joint 2:4 sparsity + INT8 quantization
            prune_24_and_quantize_int8(model, name)

    # Typically: ~85-90% of layers get sparse treatment
    # Only first layer, last layer, and a few attention projections stay dense
    return model, sparse_layers, dense_layers
📊 Accuracy vs Compression (Llama-2-7B, WikiText-2 PPL)

| Method | Bits | Sparsity | Perplexity | Model Size |
|---|---|---|---|---|
| Dense FP16 (baseline) | 16 | 0% | 5.47 | 13.5 GB |
| Dense INT8 | 8 | 0% | 5.51 (+0.04) | 6.8 GB |
| 2:4 Sparse FP16 (Wanda) | 16 | 50% | 6.08 (+0.61) | 6.8 GB |
| 2:4 Sparse INT8 (Wanda + PTQ) | 8 | 50% | 6.35 (+0.88) | 3.4 GB |
| 2:4 Sparse INT8 (SparseGPT joint) | 8 | 50% | 6.12 (+0.65) | 3.4 GB |
| Selective sparse INT8 (skip sensitive) | 8 | ~45% | 5.85 (+0.38) | 3.8 GB |

Production Integration with vLLM and TensorRT-LLM

TensorRT-LLM Sparse Engine

TensorRT-LLM has native support for 2:4 sparse tensor core execution through its cuSPARSELt integration:

# TensorRT-LLM: building a sparse INT8 engine
# Step 1: Export sparse weights from the pruning pipeline
import tensorrt_llm
from tensorrt_llm.quantization import QuantMode

# Configure quantization + sparsity
quant_mode = QuantMode.from_description(
    quantize_weights=True,
    quantize_activations=True,
    per_token=True,
    per_channel=True,
    use_weight_only=False,  # W8A8, not weight-only
)

# Load the model with sparse weights
config = {
    "architecture": "LlamaForCausalLM",
    "dtype": "float16",
    "quantization": {
        "quant_algo": "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN",
        "has_zero_point": False,
    },
    "sparsity": {
        "enabled": True,
        "pattern": "2:4",
        # Layers to apply sparsity
        "sparse_layers": [
            "*.mlp.gate_proj.weight",
            "*.mlp.up_proj.weight",
            "*.mlp.down_proj.weight",
            "*.self_attn.q_proj.weight",
            "*.self_attn.k_proj.weight",
            "*.self_attn.v_proj.weight",
            "*.self_attn.o_proj.weight",
        ],
        # Layers to keep dense (first/last are sensitive)
        "dense_layers": [
            "model.layers.0.*",
            "model.layers.31.*",
            "model.embed_tokens.*",
            "lm_head.*",
        ]
    }
}

Memory Budget Comparison

# Memory analysis for Llama-2-70B serving (weights only)
def memory_budget(model_params_B, method):
    """Approximate GPU weight memory (GiB) for each compression method."""
    params = model_params_B * 1e9

    if method == "dense_fp16":
        weight_bytes = params * 2
        overhead_factor = 1.1
    elif method == "dense_int8":
        weight_bytes = params * 1
        overhead_factor = 1.15  # Scales, zero points
    elif method == "sparse_int8":
        # 50% of values survive at 1 byte each, plus 2 bits of position
        # metadata per survivor (0.125 bytes per original weight)
        weight_bytes = params * 0.5 + params * 0.125
        overhead_factor = 1.15

    return weight_bytes * overhead_factor / (1024**3)

# Results for Llama-2-70B:
# Dense FP16:      ~143 GB (2x H100-80GB, tensor parallel)
# Dense INT8:      ~75 GB  (1x H100-80GB, tight)
# 2:4 Sparse INT8: ~47 GB  (1x H100-80GB, plenty of room for KV cache)

Llama-2-70B Throughput (tokens/sec, H100 SXM)

| Configuration | Notes | Tokens/sec |
|---|---|---|
| Dense FP16 (2x H100 TP) | Baseline, 2 GPUs | 2,150 |
| Dense INT8 (1x H100) | Fits on 1 GPU | 3,400 |
| 2:4 Sparse INT8 (1x H100) | 50% more headroom | 5,100 |
| Dense FP16 (1x H100, OOM at BS>4) | Does not fit well | 890 |

Practical Calibration Workflow

The complete pipeline for producing a deployment-ready 2:4 sparse INT8 model:

# End-to-end pipeline: dense FP16 model -> 2:4 sparse INT8

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Step 1: Load model and calibration data
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).cuda()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# 128 calibration samples, 2048 tokens each
calib_dataset = load_calibration_data(
    dataset="wikitext",
    tokenizer=tokenizer,
    n_samples=128,
    seq_len=2048
)

# Step 2: Collect activation statistics (for Wanda and SmoothQuant)
activation_norms = {}
def collect_hook(name):
    def hook(module, input, output):
        x = input[0]
        # L2 norm per feature across batch and sequence dimensions
        norm = x.float().pow(2).sum(dim=(0, 1)).sqrt()
        if name in activation_norms:
            activation_norms[name] += norm
        else:
            activation_norms[name] = norm
    return hook

hooks = []
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        hooks.append(module.register_forward_hook(collect_hook(name)))

# Run calibration forward passes
with torch.no_grad():
    for batch in calib_dataset:
        model(batch.cuda())

for h in hooks:
    h.remove()

# Normalize
for name in activation_norms:
    activation_norms[name] /= len(calib_dataset)

# Step 3: Apply SmoothQuant migration (optional, helps INT8 quality)
smooth_alpha = 0.5
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear) and name in activation_norms:
        act_scales = activation_norms[name]
        weight_scales = module.weight.abs().max(dim=0).values

        # Migration factor
        s = (act_scales.pow(smooth_alpha) / weight_scales.pow(1 - smooth_alpha)).clamp(min=1e-5)

        # Scale weights up, activations down
        module.weight.data *= s.unsqueeze(0)
        # (activation scaling applied at runtime or folded into preceding LayerNorm)

# Step 4: Joint 2:4 pruning + INT8 quantization with SparseGPT
for layer_idx in range(model.config.num_hidden_layers):
    layer = model.model.layers[layer_idx]

    for parent, proj_name in (
        [(layer.self_attn, n) for n in ("q_proj", "k_proj", "v_proj", "o_proj")]
        + [(layer.mlp, n) for n in ("gate_proj", "up_proj", "down_proj")]
    ):
        linear = getattr(parent, proj_name)
        W = linear.weight.data.float()

        # Compute Hessian approximation from calibration
        H = compute_hessian(linear, calib_dataset)
        H_inv = torch.linalg.inv(H + 1e-4 * torch.eye(H.shape[0]).cuda())

        # Joint prune + quantize
        W_sparse_int8 = sparsegpt_joint(W, H_inv, group_size=128, bits=8)

        linear.weight.data = W_sparse_int8.half()

# Step 5: Export to deployment format
export_sparse_int8_model(model, output_dir="llama-7b-sparse-int8")
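One sanity check worth running on Step 3: before quantization, the SmoothQuant migration is mathematically a no-op, because scaling weights up by s per input channel while dividing activations by s leaves the layer output unchanged. A quick NumPy verification:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((16, 8))   # activations [tokens, in_features]
W = rng.standard_normal((4, 8))    # weights [out_features, in_features]
s = 1.0 + rng.random(8)            # per-input-channel migration factors

Y = X @ W.T
Y_migrated = (X / s) @ (W * s).T   # scale activations down, weights up

print(np.allclose(Y, Y_migrated))  # True
```

The quality benefit appears only after quantization: the division by s shrinks activation outliers, so the INT8 grid wastes less range on them.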
💡 Tip

The calibration dataset should be representative of your deployment distribution. Using WikiText for calibration but deploying on code generation tasks can yield misleading quality estimates. Always validate on your target domain before deploying a sparse+quantized model.

Limitations and When Not to Use Sparsity

2:4 structured sparsity is not universally beneficial:

When sparsity + quantization works well:
  - Large models (>= 7B parameters) where accuracy is more robust
  - Compute-bound workloads (large batch prefill)
  - Memory-constrained deployments (fit on fewer GPUs)
  - FFN-heavy architectures (MLP layers tolerate sparsity better)

When to avoid it:
  - Small models (< 3B) where 50% weight pruning causes significant degradation
  - Tasks requiring high factual precision (knowledge-intensive QA)
  - When INT4/INT8 weight-only quantization already meets your throughput target
  - Hardware without sparse tensor core support (pre-Ampere, AMD, Intel)
  - Workloads dominated by attention (sparsity helps linear layers, not attention)
📊 When Sparsity Hurts: Small Model Results

| Model | Dense FP16 PPL | 2:4 Sparse INT8 PPL | Delta | Verdict |
|---|---|---|---|---|
| Llama-2-70B | 3.32 | 3.61 | +0.29 | Acceptable |
| Llama-2-13B | 4.88 | 5.42 | +0.54 | Marginal |
| Llama-2-7B | 5.47 | 6.12 | +0.65 | Use with care |
| Llama-2-3B | 7.05 | 8.94 | +1.89 | Too much degradation |
| Llama-2-1B | 11.2 | 17.8 | +6.6 | Not viable |

Summary

The combination of 2:4 structured sparsity and INT8 quantization provides a compound compression and throughput benefit: roughly 3-4x memory reduction and 2-3x throughput improvement over dense FP16 on Ampere and Hopper GPUs. The key decisions are: (1) which pruning algorithm to use (SparseGPT for best quality, Wanda for speed), (2) whether to do joint or sequential pruning+quantization (joint is better), and (3) which layers to make sparse (skip the most sensitive ones). For models at 7B parameters and above, this combination is one of the most effective deployment optimization strategies available on NVIDIA hardware.