Quantization reduces the number of bits per weight. Sparsity reduces the number of weights altogether. When combined, the two techniques compound: a 2:4 sparse INT8 model uses roughly 4x less memory than a dense FP16 model and can execute up to 2-3x faster on NVIDIA Ampere and Hopper GPUs, because the sparse tensor cores process only the non-zero elements at double the throughput of their dense counterparts.
This is not a free lunch. The accuracy cost of simultaneous pruning and quantization is higher than either technique alone, and the pruning must follow NVIDIA’s rigid 2:4 structural constraint: exactly 2 out of every 4 contiguous weights must be zero. The algorithms that produce these patterns (SparseGPT, Wanda, NVIDIA ASP) differ significantly in cost, quality, and ease of integration.
This post covers the full pipeline: the hardware execution model, the pruning algorithms, the joint sparsity+quantization calibration, and production benchmarks.
The 2:4 Structured Sparsity Hardware Model
What 2:4 Means at the Hardware Level
NVIDIA’s sparse tensor cores (Ampere SM 8.0 and later) support a specific sparsity pattern: in every group of 4 contiguous elements along the reduction dimension of a matrix multiply, exactly 2 must be zero. The hardware stores only the 2 non-zero values plus, for each of them, a 2-bit index that records its position within the group.
// Dense representation (4 elements):
// [a, 0, b, 0] -> stored as: values=[a, b], index=0b0101 (positions 0, 2)
// [0, a, 0, b] -> stored as: values=[a, b], index=0b1010 (positions 1, 3)
// [a, b, 0, 0] -> stored as: values=[a, b], index=0b0011 (positions 0, 1)
//
// There are C(4,2) = 6 valid patterns per group of 4.
// Four metadata bits per group identify the pattern: a 2-bit
// position index per surviving value (shown above as the
// equivalent 4-bit occupancy mask).
//
// Memory layout for a sparse matrix (K x N):
// - Compressed values: (K/2) x N elements (50% of original)
// - Metadata: (K/4) x N groups x 4 bits = (K/8) x N bytes
// - Total: ~50% + ~6.25% overhead = ~56% of dense size
Relative to the dense path, the sparse tensor core instruction doubles the K dimension it consumes: the sparse INT8 MMA shape is 16x8x64 (M x N x K) versus 16x8x32 dense, and the sparse FP16 shape is 16x8x32 versus 16x8x16 dense. The hardware fetches only the K/2 non-zero values plus their metadata, so each instruction completes twice the effective work in the same cycle count, doubling throughput.
// Sparse tensor core throughput (per SM, per clock):
//
// Ampere (SM 8.0 - A100):
// Dense INT8: 256 ops/clock/SM
// Sparse INT8: 512 ops/clock/SM (2x)
// Dense FP16: 128 ops/clock/SM
// Sparse FP16: 256 ops/clock/SM (2x)
//
// Hopper (SM 9.0 - H100):
// Dense INT8: 512 ops/clock/SM
// Sparse INT8: 1024 ops/clock/SM (2x)
// Dense FP16: 256 ops/clock/SM
// Sparse FP16: 512 ops/clock/SM (2x)
//
// The 2x is theoretical peak. Real speedup depends on:
// 1. Whether the kernel is compute-bound (not memory-bound)
// 2. Whether the sparse metadata decoding overlaps with computation
// 3. The matrix dimensions (small matrices do not saturate tensor cores)
The 2:4 sparsity constraint is inflexible. You cannot use 3:4 or 1:4 patterns. You cannot mix sparse and dense rows within the same GEMM call through the sparse tensor core path. The entire weight matrix must be uniformly 2:4 sparse along the K dimension. If even a single group of 4 violates the pattern, the hardware path cannot be used.
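Before dispatching a matrix through the sparse path, it is worth validating the pattern in software. A minimal pure-Python sketch of the check (the helper name is ours; in production this is what `cusparseLtSpMMAPruneCheck` does on-device):

```python
def is_valid_24(row, group=4, max_nonzeros=2):
    """Return True if every `group` contiguous elements contain at most
    `max_nonzeros` non-zeros. Groups with fewer than 2 non-zeros are still
    representable: the stored pair simply includes an explicit zero."""
    for i in range(0, len(row), group):
        chunk = row[i:i + group]
        if sum(1 for v in chunk if v != 0.0) > max_nonzeros:
            return False
    return True

print(is_valid_24([1.0, 0.0, 2.0, 0.0]))  # True
print(is_valid_24([1.0, 2.0, 3.0, 0.0]))  # False: 3 non-zeros in one group
```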
Storage Format
The compressed sparse format stores the non-zero values contiguously and uses a separate metadata array to record the positions:
// cuSPARSELt storage format for 2:4 sparse matrix
// Original dense matrix A: shape [M, K] in row-major
// Compressed matrix A_sparse: shape [M, K/2]
// Metadata: [M, K/4] groups, 4 bits per group (packed into [M, K/8] bytes)
// Example: compressing a 4x8 matrix
// Dense:
// [1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0]
// [0.0, 5.0, 0.0, 6.0, 0.0, 7.0, 0.0, 8.0]
// [1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 3.0, 4.0]
// [0.0, 0.0, 5.0, 6.0, 7.0, 8.0, 0.0, 0.0]
//
// Compressed values (M=4, K/2=4):
// [1.0, 2.0, 3.0, 4.0]
// [5.0, 6.0, 7.0, 8.0]
// [1.0, 2.0, 3.0, 4.0]
// [5.0, 6.0, 7.0, 8.0]
//
// Metadata encodes which 2 of 4 positions are non-zero:
// Row 0: positions (0,2), (0,2) -> binary: 0101, 0101
// Row 1: positions (1,3), (1,3) -> binary: 1010, 1010
// Row 2: positions (0,1), (2,3) -> binary: 0011, 1100
// Row 3: positions (2,3), (0,1) -> binary: 1100, 0011
struct SparseMatrix {
half* compressed_values; // [M, K/2]
uint8_t* metadata; // [M, K/8] bytes (2-bit position index per non-zero, packed)
int M, K; // Original dimensions
};
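To make the values/metadata split concrete, here is a hedged pure-Python sketch that compresses one 2:4-sparse row into its kept values and per-value 2-bit position indices (the final packing of indices into bytes is omitted; the function name is ours):

```python
def compress_24(row):
    """Split one 2:4-sparse row into (values, positions): for each group
    of 4, keep the (up to) 2 non-zero values and record each one's 2-bit
    position index within its group."""
    values, positions = [], []
    for i in range(0, len(row), 4):
        nz = [(j, v) for j, v in enumerate(row[i:i + 4]) if v != 0.0]
        assert len(nz) <= 2, "row is not 2:4 sparse"
        while len(nz) < 2:            # pad groups with fewer non-zeros
            nz.append((3, 0.0))
        values.extend(v for _, v in nz)
        positions.extend(j for j, _ in nz)
    return values, positions

# Row 0 of the example matrix above
vals, pos = compress_24([1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0])
print(vals, pos)  # [1.0, 2.0, 3.0, 4.0] [0, 2, 0, 2]
```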
Pruning Algorithms for 2:4 Sparsity
NVIDIA ASP (Automatic Sparsity)
NVIDIA’s baseline approach: train the model to convergence, prune to 2:4 using magnitude-based selection, then fine-tune for a few epochs to recover accuracy.
# NVIDIA ASP (Automatic SParsity) - Magnitude-based pruning
# From nvidia/apex library
import torch
from apex.contrib.sparsity import ASP
model = load_pretrained_model("llama-7b")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
# Step 1: Compute the 2:4 mask based on weight magnitudes
# For each group of 4 weights, keep the 2 with largest absolute value
ASP.prune_trained_model(model, optimizer)
# Step 2: Fine-tune with the mask applied
# The mask is fixed - pruned weights stay zero during training
for epoch in range(fine_tune_epochs):
for batch in dataloader:
loss = model(batch)
loss.backward()
optimizer.step()
optimizer.zero_grad()
# ASP automatically re-applies the mask after each optimizer step
# to ensure pruned weights remain zero
# ASP internals for mask computation:
def compute_asp_mask(weight, N=2, M=4):
"""For each group of M weights, keep top-N by magnitude."""
# weight shape: [out_features, in_features]
# Groups of M run along in_features (the K dimension for GEMM),
# so view the matrix as rows of M-element groups and mark the
# top-N entries of each group
reshaped = weight.view(-1, M)
_, indices = reshaped.abs().topk(N, dim=1)
mask_flat = torch.zeros_like(reshaped)
mask_flat.scatter_(1, indices, 1.0)
mask = mask_flat.view_as(weight)
return mask
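The top-2-of-4 magnitude rule itself needs no framework; a pure-Python sketch for a single row (names ours) makes the selection explicit:

```python
def magnitude_mask_24(row, keep=2, group=4):
    """Keep the `keep` largest-magnitude entries in each group of `group`."""
    mask = [0.0] * len(row)
    for i in range(0, len(row), group):
        chunk = list(enumerate(row[i:i + group]))
        chunk.sort(key=lambda jv: abs(jv[1]), reverse=True)
        for j, _ in chunk[:keep]:     # mark the top-`keep` positions
            mask[i + j] = 1.0
    return mask

row = [0.1, -3.0, 0.5, 2.0, 1.0, -0.2, -0.3, 0.9]
print(magnitude_mask_24(row))
# keeps -3.0 and 2.0 in group 1, 1.0 and 0.9 in group 2:
# [0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]
```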
The problem with ASP for LLMs: it requires fine-tuning. For a 70B model, even a few epochs of fine-tuning costs thousands of GPU-hours and requires the full training dataset.
SparseGPT: One-Shot Pruning Without Retraining
SparseGPT applies the principles from the Optimal Brain Surgeon framework: when you prune a weight, update the remaining weights to compensate for the information loss, using the inverse Hessian of the layer’s loss function.
# SparseGPT algorithm - prunes to 2:4 sparsity without fine-tuning
# Processes one layer at a time, using calibration data to compute Hessians
def sparsegpt_prune_layer(W, H_inv, sparsity_pattern="2:4"):
"""
W: weight matrix [out_features, in_features]
H_inv: inverse Hessian of the layer loss w.r.t. weights [in_features, in_features]
The key insight: when pruning weight W[i,j], the optimal update to
remaining weights W[i, S] (where S is the set of surviving weights) is:
delta_W[i, S] = -W[i,j] * H_inv[S, j] / H_inv[j, j]
This minimizes the squared error ||WX - W_pruned X||^2 where X is
the input activations.
"""
out_features, in_features = W.shape
W_sparse = W.clone()
# Process columns in groups of 4 (for 2:4 pattern)
group_size = 4
keep = 2
for col_start in range(0, in_features, group_size):
col_end = min(col_start + group_size, in_features)
group_cols = list(range(col_start, col_end))
# For each row, determine which 2 of 4 to prune
for row in range(out_features):
group_weights = W_sparse[row, group_cols]
# Compute pruning error for each possible mask
# Error for pruning column j: (W[row,j])^2 / (2 * H_inv[j,j])
errors = torch.zeros(len(group_cols))
for idx, col in enumerate(group_cols):
errors[idx] = (W_sparse[row, col] ** 2) / (2 * H_inv[col, col])
# Keep the 2 with highest pruning cost (prune the 2 with lowest cost)
_, prune_indices = errors.topk(keep, largest=False)
# Apply weight updates for pruned columns
for prune_idx in prune_indices:
prune_col = group_cols[prune_idx]
w_val = W_sparse[row, prune_col]
# Update remaining weights in this row
for surv_idx in range(len(group_cols)):
if surv_idx not in prune_indices:
surv_col = group_cols[surv_idx]
W_sparse[row, surv_col] -= (
w_val * H_inv[surv_col, prune_col] / H_inv[prune_col, prune_col]
)
# Zero out the pruned weight
W_sparse[row, prune_col] = 0.0
return W_sparse
SparseGPT requires computing the Hessian (or a factored approximation) per layer, which costs memory and compute. For a layer with in_features = 8192 (Llama-70B’s hidden size), the Hessian is 8192 x 8192 x 4 bytes = 256 MB in FP32. This is manageable on a single GPU. The full pruning of Llama-70B takes roughly 4 hours on a single A100.
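The Hessian SparseGPT needs is layer-local and depends only on the layer’s inputs: H is proportional to X Xᵀ accumulated over calibration tokens, plus a small damping term before inversion. A hedged numpy sketch of the accumulation (toy sizes; the function name and damping scheme are ours):

```python
import numpy as np

def accumulate_hessian(batches, d_in, damp=1e-4):
    """H = average of x x^T over calibration inputs, plus damping.
    `batches` yields activation arrays of shape [n_tokens, d_in]."""
    H = np.zeros((d_in, d_in))
    n_tokens = 0
    for X in batches:
        H += X.T @ X              # accumulate outer products
        n_tokens += X.shape[0]
    H /= max(n_tokens, 1)
    # Damping keeps H invertible when calibration data is limited
    H += damp * (np.trace(H) / d_in) * np.eye(d_in)
    return H

rng = np.random.default_rng(0)
batches = [rng.standard_normal((64, 8)) for _ in range(4)]
H = accumulate_hessian(batches, d_in=8)
H_inv = np.linalg.inv(H)
print(H.shape)  # (8, 8)
```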
Wanda: Pruning Without Any Weight Update
Wanda (Weights AND Activations) observes that the importance of a weight depends not just on its magnitude but also on the magnitude of the activations it multiplies. The pruning metric is:

S[i,j] = |W[i,j]| * ||X_j||_2

where ||X_j||_2 is the L2 norm of the j-th input feature across the calibration set.
# Wanda pruning - no weight updates, no Hessian computation
def wanda_prune_layer(W, X_norms, N=2, M=4):
"""
W: weight matrix [out_features, in_features]
X_norms: L2 norms of input activations per feature [in_features]
computed as sqrt(sum(X[:, j]^2)) over calibration samples
"""
# Compute importance scores
# score[i,j] = |W[i,j]| * ||X_j||_2
scores = W.abs() * X_norms.unsqueeze(0) # [out_features, in_features]
# For each group of M, keep top-N by score
out_features, in_features = W.shape
mask = torch.zeros_like(W)
for col_start in range(0, in_features, M):
col_end = min(col_start + M, in_features)
group_scores = scores[:, col_start:col_end]
_, topk = group_scores.topk(N, dim=1)
for k in range(N):
rows = torch.arange(out_features)
mask[rows, col_start + topk[:, k]] = 1.0
W_pruned = W * mask
return W_pruned, mask
# Cost comparison:
# ASP: O(training_cost * fine_tune_epochs) - most expensive
# SparseGPT: O(n_layers * d_in^2 * n_calib) - hours on 1 GPU
# Wanda: O(n_layers * d_out * d_in) - minutes on 1 GPU
Pruning Algorithm Comparison (Llama-7B, 2:4 Sparsity)
| Algorithm | Calibration Time | Perplexity (Dense: 5.68) | Requires Fine-tuning |
|---|---|---|---|
| Magnitude (ASP, no finetune) | 0 min | 10.42 (+4.74) | No (but recommended) |
| SparseGPT | ~60 min (1x A100) | 6.56 (+0.88) | No |
| Wanda | ~2 min (1x A100) | 6.72 (+1.04) | No |
| ASP + 2 epoch finetune | ~100 GPU-hrs | 5.95 (+0.27) | Yes |
Joint Sparsity + Quantization Pipeline
The critical question: do you prune first and then quantize, or quantize first and then prune, or do both simultaneously?
Order Matters
# Three possible orderings:
# Option 1: Prune -> Quantize (most common)
# Prune weights to 2:4, then quantize non-zero values to INT8
# Pro: pruning operates on full-precision weights (better decisions)
# Con: quantization error added on top of pruning error
W_pruned = sparsegpt_prune(W_fp16, H_inv) # 2:4 sparse FP16
W_final = quantize_per_channel_int8(W_pruned) # 2:4 sparse INT8
# Option 2: Quantize -> Prune (rare, usually worse)
# Quantize to INT8 first, then prune
# Pro: pruning decisions account for quantization noise
# Con: quantization of dense matrix is suboptimal (doesn't know what gets pruned)
W_quant = quantize_per_channel_int8(W_fp16) # dense INT8
W_final = magnitude_prune_24(W_quant) # 2:4 sparse INT8
# Option 3: Joint optimization (best quality, most expensive)
# Simultaneously determine pruning mask and quantization parameters
# Used by SparseGPT with quantization extension
W_final = sparsegpt_prune_and_quantize(W_fp16, H_inv) # 2:4 sparse INT8
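The `quantize_per_channel_int8` helper used above is a placeholder; one plausible implementation is symmetric per-output-channel fake quantization, sketched here in numpy (names and details are ours):

```python
import numpy as np

def quantize_per_channel_int8(W):
    """Symmetric per-output-channel INT8 fake quantization: each row of W
    gets its own scale = max|row| / 127; values are rounded, clamped to
    [-128, 127], and dequantized back to float."""
    scale = np.abs(W).max(axis=1, keepdims=True) / 127.0
    scale = np.maximum(scale, 1e-12)          # guard all-zero rows
    q = np.clip(np.round(W / scale), -128, 127)
    return q * scale

W = np.array([[0.5, -1.0, 0.25, 0.0],
              [2.0,  0.1, -2.0, 1.0]])
W_q = quantize_per_channel_int8(W)
# Round-trip error is bounded by half a quantization step per row
step_half = np.abs(W).max(axis=1, keepdims=True) / 254.0
print(bool((np.abs(W - W_q) <= step_half + 1e-9).all()))  # True
```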
SparseGPT Joint Pruning + Quantization
SparseGPT can be extended to perform joint sparsity and quantization in a single pass. The algorithm alternates between pruning decisions and quantization within the Hessian-based update framework:
def sparsegpt_joint(W, H_inv, group_size=128, bits=8):
"""
Joint 2:4 pruning + group quantization.
Key idea: process columns left-to-right in blocks.
For each block:
1. Determine 2:4 pruning mask (which 2 of 4 to keep)
2. Quantize the surviving weights
3. Compute the quantization + pruning error
4. Update remaining columns using H_inv to compensate
"""
out_features, in_features = W.shape
W_sparse_quant = W.clone()
block_size = 128 # Process 128 columns at a time
for block_start in range(0, in_features, block_size):
block_end = min(block_start + block_size, in_features)
block_cols = list(range(block_start, block_end))
# Extract the block
W_block = W_sparse_quant[:, block_start:block_end].clone()
H_block = H_inv[block_start:block_end, block_start:block_end]
# Step 1: Determine 2:4 pruning mask for this block
mask = compute_24_mask(W_block, H_block)
# Step 2: Apply pruning
W_block *= mask
# Step 3: Quantize non-zero values (group quantization)
for g_start in range(0, block_end - block_start, group_size):
g_end = min(g_start + group_size, block_end - block_start)
group = W_block[:, g_start:g_end]
nonzero_mask = mask[:, g_start:g_end]
# Compute scale from non-zero values only
max_val = group[nonzero_mask.bool()].abs().max()
scale = max_val / (2**(bits-1) - 1)
# Quantize and dequantize
group_quant = torch.round(group / scale).clamp(
-(2**(bits-1)), 2**(bits-1) - 1
) * scale
W_block[:, g_start:g_end] = group_quant * nonzero_mask
# Step 4: Compute error and update remaining columns
error = W_sparse_quant[:, block_start:block_end] - W_block
W_sparse_quant[:, block_start:block_end] = W_block
# Propagate error to remaining columns
if block_end < in_features:
W_sparse_quant[:, block_end:] -= (
error @ H_inv[block_start:block_end, block_end:]
)
return W_sparse_quant
Joint optimization consistently produces 0.1-0.3 perplexity points better results than sequential prune-then-quantize, because the Hessian-based weight update compensates for both the pruning error and the quantization error simultaneously. The additional compute cost is negligible since the Hessian is already computed.
Kernel-Level Implementation with cuSPARSELt
NVIDIA’s cuSPARSELt library provides the API for executing 2:4 sparse GEMMs on sparse tensor cores.
#include <cusparseLt.h>
#include <cuda_runtime.h>
// cuSPARSELt workflow for sparse GEMM:
// C = alpha * A_sparse * B + beta * C
// where A is 2:4 structured sparse
void sparse_gemm_int8(
int M, int N, int K,
const int8_t* A_dense, // Dense weight matrix [M, K]
const int8_t* B, // Activation matrix [K, N]
int32_t* C, // Output [M, N]
float alpha, float beta
) {
cusparseLtHandle_t handle;
cusparseLtInit(&handle);
// 1. Create matrix descriptors
cusparseLtMatDescriptor_t matA, matB, matC;
// A is the structured sparse matrix
cusparseLtStructuredDescriptorInit(
&handle, &matA,
M, K, // dimensions
K, // leading dimension
8, // alignment (bytes)
CUDA_R_8I, // data type: INT8
CUSPARSE_ORDER_ROW,
CUSPARSELT_SPARSITY_50_PERCENT // 2:4 sparsity
);
// B and C are dense
cusparseLtDenseDescriptorInit(
&handle, &matB, K, N, N, 8, CUDA_R_8I, CUSPARSE_ORDER_ROW
);
cusparseLtDenseDescriptorInit(
&handle, &matC, M, N, N, 8, CUDA_R_32I, CUSPARSE_ORDER_ROW
);
// 2. Create the matmul descriptor
cusparseLtMatmulDescriptor_t matmul;
cusparseLtMatmulDescriptorInit(
&handle, &matmul,
CUSPARSE_OPERATION_NON_TRANSPOSE, // opA
CUSPARSE_OPERATION_NON_TRANSPOSE, // opB
&matA, &matB, &matC, &matC,
CUSPARSE_COMPUTE_32I // compute type
);
// 3. Select an algorithm and build the execution plan
// (cuSPARSELt needs the plan before compression)
cusparseLtMatmulAlgSelection_t alg_sel;
cusparseLtMatmulAlgSelectionInit(
&handle, &alg_sel, &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT
);
cusparseLtMatmulPlan_t plan;
cusparseLtMatmulPlanInit(&handle, &plan, &matmul, &alg_sel);
// 4. Prune the weight matrix to the 2:4 pattern
// cuSPARSELt can do this, but we typically have pre-pruned weights
int8_t* A_pruned;
cudaMalloc(&A_pruned, M * K * sizeof(int8_t));
// Prune with magnitude-based selection
cusparseLtSpMMAPrune(
&handle, &matmul,
A_dense, A_pruned,
CUSPARSELT_PRUNE_SPMMA_STRIP, // 2:4 strip pruning
nullptr // stream
);
// Verify the pruning pattern is valid. The result lands in device
// memory: *d_valid == 0 when every group satisfies the 2:4 constraint
int* d_valid;
cudaMalloc(&d_valid, sizeof(int));
cusparseLtSpMMAPruneCheck(
&handle, &matmul, A_pruned, d_valid, nullptr
);
// 5. Compress the pruned matrix (50% values + metadata)
size_t compressed_size, compress_buffer_size;
cusparseLtSpMMACompressedSize(
&handle, &plan, &compressed_size, &compress_buffer_size
);
int8_t *A_compressed, *compress_buffer;
cudaMalloc(&A_compressed, compressed_size);
cudaMalloc(&compress_buffer, compress_buffer_size);
cusparseLtSpMMACompress(
&handle, &plan,
A_pruned, A_compressed,
compress_buffer,
nullptr // stream
);
// 6. Allocate workspace and execute the sparse GEMM
size_t workspace_size;
cusparseLtMatmulGetWorkspace(&handle, &plan, &workspace_size);
void* workspace;
cudaMalloc(&workspace, workspace_size);
// Execute: C = alpha * A_compressed * B + beta * C
cusparseLtMatmul(
&handle, &plan,
&alpha, A_compressed, B, &beta, C, C,
workspace, nullptr, 0
);
// Cleanup
cusparseLtMatmulPlanDestroy(&plan);
cudaFree(A_compressed);
cudaFree(compress_buffer);
cudaFree(A_pruned);
cudaFree(d_valid);
cudaFree(workspace);
cusparseLtDestroy(&handle);
}
Sparse GEMM Performance Model
The theoretical speedup from 2:4 sparsity in a GEMM is:
But real speedup depends on whether the GEMM is compute-bound or memory-bound. For memory-bound cases (small batch sizes in LLM decode), the 50% reduction in weight data is the primary benefit:
# Performance model for sparse vs dense GEMM
def sparse_gemm_speedup(M, N, K, dtype_bytes):
"""
M: batch * seq_len (activation rows)
N: output features
K: input features
"""
# Compute FLOPs
dense_flops = 2 * M * N * K
sparse_flops = 2 * M * N * (K // 2) # Only non-zero elements
# Memory traffic
# Dense: read A[M,K] + W[K,N] + write C[M,N]
dense_bytes = (M * K + K * N + M * N) * dtype_bytes
# Sparse: read A[M,K] + W_compressed[K/2,N] + metadata + write C[M,N]
metadata_bytes = (K // 8) * N # 2 bits per non-zero value, already in bytes
sparse_bytes = (M * K + (K // 2) * N + M * N) * dtype_bytes + metadata_bytes
# Arithmetic intensity (FLOPs per byte)
dense_ai = dense_flops / dense_bytes
sparse_ai = sparse_flops / sparse_bytes
# H100 SXM roofline parameters (FP16 tensor core figures;
# INT8 doubles both peaks, so the dense/sparse ratio is unchanged)
peak_dense_tflops = 990 # FP16 dense
peak_sparse_tflops = 1979 # FP16 sparse (2x)
mem_bw_tb = 3.35 # TB/s HBM3
# Effective throughput (min of compute and memory bound)
dense_tput = min(
peak_dense_tflops * 1e12,
mem_bw_tb * 1e12 * dense_ai
)
sparse_tput = min(
peak_sparse_tflops * 1e12,
mem_bw_tb * 1e12 * sparse_ai
)
dense_time = dense_flops / dense_tput
sparse_time = sparse_flops / sparse_tput
return dense_time / sparse_time
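To see why batch size decides the regime, compare the arithmetic intensity of a decode-shaped GEMM (M = 1) against a prefill-shaped one, using Llama-7B’s hidden size of 4096 and 1-byte INT8 elements. Decode sits near 2 FLOPs/byte, far below any tensor-core ridge point, so it is memory-bound and benefits mainly from the 50% weight-traffic reduction:

```python
def arithmetic_intensity(M, N, K, dtype_bytes=1):
    """FLOPs per byte moved for a dense GEMM C[M,N] = A[M,K] @ B[K,N]."""
    flops = 2 * M * N * K
    bytes_moved = (M * K + K * N + M * N) * dtype_bytes
    return flops / bytes_moved

decode = arithmetic_intensity(M=1, N=4096, K=4096)      # single-token decode
prefill = arithmetic_intensity(M=4096, N=4096, K=4096)  # large prefill
print(round(decode, 1), round(prefill, 1))  # 2.0 2730.7
```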
Sparse INT8 vs Dense FP16 Speedup by Batch Size (Llama-7B, H100)
Accuracy Impact and Recovery
Layer Sensitivity Analysis
Not all layers tolerate sparsity equally. Attention projection layers (Q, K, V, O) tend to be more sensitive than FFN layers. The first and last layers are the most sensitive.
# Layer sensitivity analysis for 2:4 sparsity
def measure_layer_sensitivity(model, calibration_data, activation_norms, metric="perplexity"):
"""
Prune one layer at a time to 2:4 sparsity and measure quality impact.
This identifies which layers need special treatment.
"""
baseline = evaluate(model, calibration_data, metric)
sensitivities = {}
for name, param in model.named_parameters():
if "weight" not in name or param.dim() != 2:
continue
# Save original weight
original = param.data.clone()
# Apply 2:4 pruning to this layer only
mask = compute_wanda_mask(param.data, activation_norms[name])
param.data *= mask
# Measure degradation
score = evaluate(model, calibration_data, metric)
sensitivities[name] = score - baseline
# Restore
param.data = original
return sensitivities
# Typical sensitivity ranking (Llama-7B, perplexity delta):
# Layer 0 (embed projection): +0.35 (most sensitive)
# Layer 31 (final layer): +0.28
# Attention Q/K projections: +0.08-0.15 average
# Attention V/O projections: +0.05-0.10 average
# FFN gate/up projections: +0.03-0.07 average
# FFN down projections: +0.04-0.08 average
Selective Sparsity: Skip Sensitive Layers
A practical strategy: keep the most sensitive layers dense and apply 2:4 sparsity only to the rest.
def selective_sparse_quantize(model, sensitivity_threshold=0.15):
"""
Apply 2:4 sparse INT8 to tolerant layers, dense INT8 to sensitive layers.
"""
sensitivities = measure_layer_sensitivity(model, calib_data)
sparse_layers = []
dense_layers = []
for name, sensitivity in sensitivities.items():
if sensitivity > sensitivity_threshold:
dense_layers.append(name)
# Apply only INT8 quantization (no sparsity)
quantize_int8(model, name)
else:
sparse_layers.append(name)
# Apply joint 2:4 sparsity + INT8 quantization
prune_24_and_quantize_int8(model, name)
# Typically: ~85-90% of layers get sparse treatment
# Only first layer, last layer, and a few attention projections stay dense
return model, sparse_layers, dense_layers
Accuracy vs Compression (Llama-2-7B, WikiText-2 PPL)
| Method | Bits | Sparsity | Perplexity | Model Size |
|---|---|---|---|---|
| Dense FP16 (baseline) | 16 | 0% | 5.47 | 13.5 GB |
| Dense INT8 | 8 | 0% | 5.51 (+0.04) | 6.8 GB |
| 2:4 Sparse FP16 (Wanda) | 16 | 50% | 6.08 (+0.61) | 6.8 GB |
| 2:4 Sparse INT8 (Wanda + PTQ) | 8 | 50% | 6.35 (+0.88) | 3.4 GB |
| 2:4 Sparse INT8 (SparseGPT joint) | 8 | 50% | 6.12 (+0.65) | 3.4 GB |
| Selective sparse INT8 (skip sensitive) | 8 | ~45% | 5.85 (+0.38) | 3.8 GB |
Production Integration with vLLM and TensorRT-LLM
TensorRT-LLM Sparse Engine
TensorRT-LLM has native support for 2:4 sparse tensor core execution through its cuSPARSELt integration:
# TensorRT-LLM: building a sparse INT8 engine
# Step 1: Export sparse weights from the pruning pipeline
import tensorrt_llm
from tensorrt_llm.quantization import QuantMode
# Configure quantization + sparsity
quant_mode = QuantMode.from_description(
quantize_weights=True,
quantize_activations=True,
per_token=True,
per_channel=True,
use_weight_only=False, # W8A8, not weight-only
)
# Load the model with sparse weights
config = {
"architecture": "LlamaForCausalLM",
"dtype": "float16",
"quantization": {
"quant_algo": "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN",
"has_zero_point": False,
},
"sparsity": {
"enabled": True,
"pattern": "2:4",
# Layers to apply sparsity
"sparse_layers": [
"*.mlp.gate_proj.weight",
"*.mlp.up_proj.weight",
"*.mlp.down_proj.weight",
"*.self_attn.q_proj.weight",
"*.self_attn.k_proj.weight",
"*.self_attn.v_proj.weight",
"*.self_attn.o_proj.weight",
],
# Layers to keep dense (first/last are sensitive)
"dense_layers": [
"model.layers.0.*",
"model.layers.31.*",
"model.embed_tokens.*",
"lm_head.*",
]
}
}
Memory Budget Comparison
# Memory analysis for Llama-2-70B serving
def memory_budget(model_params_B, method):
"""Weight memory (GB) for different compression methods, plus the
KV-cache cost per token for Llama-2-70B (80 layers, d_model 8192,
GQA with 8 KV heads -> 8192/8 = 1024 KV channels per layer)."""
params = model_params_B * 1e9
if method == "dense_fp16":
weight_bytes = params * 2
kv_bytes_per_token = 2 * 80 * (8192 // 8) * 2 # 2 (K,V) * n_layers * kv_dim * fp16 bytes
overhead_factor = 1.1
elif method == "dense_int8":
weight_bytes = params * 1
kv_bytes_per_token = 2 * 80 * (8192 // 8) * 1 # INT8 KV cache
overhead_factor = 1.15 # Scales, zero points
elif method == "sparse_int8":
# 50% of weights are zero -> 50% values + 6.25% metadata
weight_bytes = params * 1 * 0.5 + params * 0.0625
kv_bytes_per_token = 2 * 80 * (8192 // 8) * 1
overhead_factor = 1.15
total_weight_gb = weight_bytes * overhead_factor / (1024**3)
return total_weight_gb, kv_bytes_per_token
# Results for Llama-2-70B:
# Dense FP16: ~140 GB (2x H100-80GB, tensor parallel)
# Dense INT8: ~75 GB (1x H100-80GB, tight)
# 2:4 Sparse INT8: ~42 GB (1x H100-80GB, plenty of room for KV cache)
Llama-2-70B Throughput (tokens/sec, H100 SXM)
Practical Calibration Workflow
The complete pipeline for producing a deployment-ready 2:4 sparse INT8 model:
# End-to-end pipeline: dense FP16 model -> 2:4 sparse INT8
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Step 1: Load model and calibration data
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16
).cuda()
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# 128 calibration samples, 2048 tokens each
calib_dataset = load_calibration_data(
dataset="wikitext",
tokenizer=tokenizer,
n_samples=128,
seq_len=2048
)
# Step 2: Collect activation statistics (for Wanda and SmoothQuant)
activation_norms = {}
def collect_hook(name):
def hook(module, input, output):
x = input[0]
# L2 norm per feature across batch and sequence dimensions
norm = x.float().pow(2).sum(dim=(0, 1)).sqrt()
if name in activation_norms:
activation_norms[name] += norm
else:
activation_norms[name] = norm
return hook
hooks = []
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
hooks.append(module.register_forward_hook(collect_hook(name)))
# Run calibration forward passes
with torch.no_grad():
for batch in calib_dataset:
model(batch.cuda())
for h in hooks:
h.remove()
# Normalize
for name in activation_norms:
activation_norms[name] /= len(calib_dataset)
# Step 3: Apply SmoothQuant migration (optional, helps INT8 quality)
smooth_alpha = 0.5
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and name in activation_norms:
act_scales = activation_norms[name]
weight_scales = module.weight.abs().max(dim=0).values
# Migration factor
s = (act_scales.pow(smooth_alpha) / weight_scales.pow(1 - smooth_alpha)).clamp(min=1e-5)
# Scale weights up, activations down
module.weight.data *= s.unsqueeze(0)
# (activation scaling applied at runtime or folded into preceding LayerNorm)
# Step 4: Joint 2:4 pruning + INT8 quantization with SparseGPT
for layer in model.model.layers:
for proj_name in ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
"self_attn.o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]:
linear = layer.get_submodule(proj_name)
W = linear.weight.data.float()
# Compute Hessian approximation from calibration
H = compute_hessian(linear, calib_dataset)
H_inv = torch.linalg.inv(H + 1e-4 * torch.eye(H.shape[0]).cuda())
# Joint prune + quantize
W_sparse_int8 = sparsegpt_joint(W, H_inv, group_size=128, bits=8)
linear.weight.data = W_sparse_int8.half()
# Step 5: Export to deployment format
export_sparse_int8_model(model, output_dir="llama-7b-sparse-int8")
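Step 3’s SmoothQuant migration is output-preserving by construction: multiplying weight input-channels by s while dividing the matching activation features by s leaves every matmul product unchanged. A small numpy check of that identity (toy shapes; variable names are ours):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((5, 8))      # activations [tokens, in_features]
W = rng.standard_normal((16, 8))     # weight [out_features, in_features]

act_scales = np.abs(X).max(axis=0)
weight_scales = np.abs(W).max(axis=0)
alpha = 0.5
s = np.clip(act_scales**alpha / weight_scales**(1 - alpha), 1e-5, None)

W_smooth = W * s[None, :]            # migrate difficulty into weights
X_smooth = X / s[None, :]            # activations become easier to quantize
print(bool(np.allclose(X @ W.T, X_smooth @ W_smooth.T)))  # True
```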
The calibration dataset should be representative of your deployment distribution. Using WikiText for calibration but deploying on code generation tasks can yield misleading quality estimates. Always validate on your target domain before deploying a sparse+quantized model.
Limitations and When Not to Use Sparsity
2:4 structured sparsity is not universally beneficial:
When sparsity + quantization works well:
- Large models (>= 7B parameters) where accuracy is more robust
- Compute-bound workloads (large batch prefill)
- Memory-constrained deployments (fit on fewer GPUs)
- FFN-heavy architectures (MLP layers tolerate sparsity better)
When to avoid it:
- Small models (< 3B) where 50% weight pruning causes significant degradation
- Tasks requiring high factual precision (knowledge-intensive QA)
- When INT4/INT8 weight-only quantization already meets your throughput target
- Hardware without sparse tensor core support (pre-Ampere, AMD, Intel)
- Workloads dominated by attention (sparsity helps linear layers, not attention)
When Sparsity Hurts: Small Model Results
| Model | Dense FP16 PPL | 2:4 Sparse INT8 PPL | Delta | Verdict |
|---|---|---|---|---|
| Llama-2-70B | 3.32 | 3.61 | +0.29 | Acceptable |
| Llama-2-13B | 4.88 | 5.42 | +0.54 | Marginal |
| Llama-2-7B | 5.47 | 6.12 | +0.65 | Use with care |
| Llama-2-3B | 7.05 | 8.94 | +1.89 | Too much degradation |
| Llama-2-1B | 11.2 | 17.8 | +6.6 | Not viable |
Summary
The combination of 2:4 structured sparsity and INT8 quantization provides a compound compression and throughput benefit: roughly 4x memory reduction and 2-3x throughput improvement over dense FP16 on Ampere and Hopper GPUs. The key decisions are: (1) which pruning algorithm to use (SparseGPT for best quality, Wanda for speed), (2) whether to do joint or sequential pruning+quantization (joint is better), and (3) which layers to make sparse (skip the most sensitive ones). For models at 7B parameters and above, this combination is one of the most effective deployment optimization strategies available on NVIDIA hardware.