A 70B parameter model has 70 billion weights. Not all of them matter. Pruning removes the ones that contribute least to output quality, producing a sparser model that uses less memory and (with hardware support) runs faster. The challenge is doing this without retraining — large language models cost millions to train, so any pruning method that requires retraining is impractical.
Two methods solve this problem at scale. SparseGPT (Frantar and Alistarh, 2023) uses second-order information (the Hessian) to optimally update surviving weights as pruned weights are removed. Wanda (Sun et al., 2024) uses a simpler criterion — weight magnitude times input activation norm — that requires no weight updates at all. Both achieve 50% unstructured sparsity on Llama-class models with minimal quality degradation, processing the entire model in minutes on a single GPU.
This post covers the mathematics behind both methods, their implementation, the difference between structured and unstructured pruning, N:M sparsity for Ampere GPUs, and empirical quality-vs-sparsity curves.
Why Pruning Works
1.1 Weight Distribution in Trained Models
After training, the weight matrices in a transformer exhibit a characteristic distribution: most weights cluster near zero, with a long tail of high-magnitude outliers. For a typical Llama 7B weight matrix, the distribution looks like:
- 50% of weights have magnitude less than 0.005
- 90% have magnitude less than 0.02
- 99% have magnitude less than 0.08
- The remaining 1% have magnitudes up to 0.5 or higher
This concentration near zero suggests that many weights can be removed (set to zero) without significantly changing the model’s outputs. The question is which weights to remove and how to compensate for their removal.
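A summary like the one above is easy to reproduce for any weight matrix with a few quantile queries. A minimal sketch, using a synthetic Gaussian matrix as a stand-in (real trained weights have heavier tails, which is exactly the long-tail structure described above):

```python
import torch

def magnitude_profile(weight, quantiles=(0.5, 0.9, 0.99)):
    """Magnitude below which each given fraction of the weights falls."""
    mags = weight.abs().flatten().float()
    return {q: torch.quantile(mags, q).item() for q in quantiles}

torch.manual_seed(0)
W = torch.randn(512, 512) * 0.02   # stand-in for a trained weight matrix
profile = magnitude_profile(W)     # {0.5: ..., 0.9: ..., 0.99: ...}
```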
1.2 The Pruning Problem
Given a weight matrix $W \in \mathbb{R}^{d_{out} \times d_{in}}$ and a calibration dataset of inputs $X \in \mathbb{R}^{n \times d_{in}}$, we want to find a sparse weight matrix $\widehat{W}$ that minimizes the reconstruction error:

$$\min_{\widehat{W}} \ \lVert X W^{\top} - X \widehat{W}^{\top} \rVert_F^2 \quad \text{s.t.} \quad \lVert \widehat{W} \rVert_0 \le k$$

where $k$ is the number of non-zero weights we want to keep. This is a combinatorial optimization problem — NP-hard in general. Practical pruning methods approximate it.
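For a fixed sparsity mask, the objective is cheap to evaluate; the combinatorial hardness is entirely in choosing the mask. A sketch (the name `reconstruction_error` is illustrative, not from either paper):

```python
import torch

def reconstruction_error(W, W_hat, X):
    """Squared Frobenius norm of the output change: ||X W^T - X W_hat^T||^2."""
    return ((X @ W.T - X @ W_hat.T) ** 2).sum().item()

torch.manual_seed(0)
W = torch.randn(8, 16)
X = torch.randn(32, 16)

# One candidate solution: zero the smaller-magnitude half of the weights
mask = W.abs() >= W.abs().median()
err = reconstruction_error(W, W * mask, X)
```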
1.3 Magnitude Pruning: The Baseline
The simplest approach: remove weights with the smallest absolute values.
import torch
def magnitude_prune(weight, sparsity):
"""Remove the smallest-magnitude weights.
Args:
weight: [d_out, d_in] weight matrix
sparsity: fraction of weights to remove (0.0 to 1.0)
Returns:
pruned weight matrix with zeros in pruned positions
"""
    num_params = weight.numel()
    num_prune = int(num_params * sparsity)
    if num_prune == 0:
        return weight.clone()  # nothing to prune at this sparsity
    # Find the threshold: magnitude below which we prune
    magnitudes = weight.abs().flatten()
    threshold = torch.kthvalue(magnitudes, num_prune).values
# Create mask and apply
mask = weight.abs() >= threshold
return weight * mask
Magnitude pruning works surprisingly well at low sparsity (less than 30%). But it degrades sharply beyond 50% because it ignores a critical factor: the input distribution. A weight with magnitude 0.001 connected to an input feature with activation magnitude 1000 contributes more than a weight with magnitude 0.1 connected to a feature with activation magnitude 0.001.
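A two-number toy example, using the figures from the paragraph above, makes the failure mode concrete:

```python
# Magnitude alone misranks these two weights.
w_small, act_large = 0.001, 1000.0   # tiny weight, huge input feature
w_large, act_small = 0.1, 0.001      # larger weight, near-dead input feature

contribution_small = abs(w_small * act_large)   # ~1.0
contribution_large = abs(w_large * act_small)   # ~0.0001
```

Magnitude pruning keeps `w_large` and drops `w_small`, yet `w_small` contributes about 10,000x more to the output. Both SparseGPT and Wanda fix this by folding input statistics into the pruning criterion.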
SparseGPT: Optimal One-Shot Pruning
2.1 The Optimal Brain Surgeon Framework
SparseGPT builds on Optimal Brain Surgeon (OBS), a second-order pruning method from 1993. The key insight: when you prune weight $w_q$, you should update the remaining weights to compensate. The optimal update depends on the inverse Hessian of the loss with respect to the weights.

For a linear layer with squared error loss, the Hessian with respect to row $i$ of $W$ is:

$$H = 2 X^{\top} X$$

where $X$ is the matrix of input activations across calibration samples. The factor of 2 comes from the squared error derivative.

When we prune weight $w_{iq}$ (row $i$, column $q$), the optimal update to the remaining weights in row $i$ is:

$$\delta w = -\frac{w_{iq}}{[H^{-1}]_{qq}} \, [H^{-1}]_{:,q}$$

This update minimizes the increase in squared error caused by removing $w_{iq}$. The pruning error (increase in loss) is:

$$\varepsilon_q = \frac{w_{iq}^2}{2\,[H^{-1}]_{qq}}$$

The inverse Hessian tells us two things: (1) which weight to prune (the one with the smallest $\varepsilon_q$), and (2) how to update the remaining weights to compensate ($\delta w$).
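A worked example on a 2-dimensional input shows the compensation at work. This is a hand-rolled sketch of the OBS update, not SparseGPT code; with correlated input features, the compensated weights incur strictly less error than plain zeroing:

```python
import torch

torch.manual_seed(0)
X = torch.randn(1000, 2)
X[:, 1] = 0.8 * X[:, 0] + 0.2 * X[:, 1]        # correlated input features
w = torch.tensor([0.5, 1.0])

H = 2.0 * X.T @ X / X.shape[0]                 # Hessian of the squared error
H_inv = torch.linalg.inv(H)

q = 0                                          # prune the first weight
w_naive = w.clone()
w_naive[q] = 0.0                               # zero it with no compensation

delta = -(w[q] / H_inv[q, q]) * H_inv[:, q]    # OBS update
w_obs = w + delta                              # zeros w_q AND adjusts w_1

def mse(w_hat):
    """Mean squared output error versus the dense weights."""
    return ((X @ (w - w_hat)) ** 2).mean().item()
```

Because the second feature carries most of the first feature's signal, shifting some of $w_0$'s contribution onto $w_1$ absorbs part of the damage that naive zeroing leaves behind.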
2.2 The Scalability Problem
Classical OBS requires computing $H^{-1}$ for each row, which costs $O(d_{in}^3)$. For a Llama 7B layer with $d_{in} = 4096$, this means inverting a $4096 \times 4096$ matrix — feasible but slow. The real problem is that after pruning one weight, $H^{-1}$ changes, requiring a rank-1 update before selecting the next weight. Pruning $s \cdot d_{in}$ weights per row (where $s$ is the sparsity ratio) requires $s \cdot d_{in}$ rank-1 updates, each costing $O(d_{in}^2)$. Total: $O(s \cdot d_{in}^3)$. For $s = 0.5$ and $d_{in} = 4096$, that is roughly $3 \times 10^{10}$ operations per row. Unacceptable.
2.3 SparseGPT’s Column-Wise Algorithm
SparseGPT’s key contribution is an efficient algorithm that processes weights column by column, amortizing the Hessian updates. Instead of selecting the globally optimal weight to prune next, SparseGPT processes columns left-to-right and makes pruning decisions for each column using the current Hessian inverse.
import torch
import torch.nn as nn
def sparsegpt_prune(weight, hessian_inv, sparsity, blocksize=128):
"""SparseGPT: one-shot pruning with Hessian-based weight updates.
Args:
weight: [d_out, d_in] weight matrix
hessian_inv: [d_in, d_in] inverse Hessian (precomputed)
sparsity: fraction of weights to prune
blocksize: number of columns to process at once
Returns:
pruned weight matrix with compensated surviving weights
"""
W = weight.clone()
d_out, d_in = W.shape
# Determine number of weights to prune per row
num_prune_per_row = int(d_in * sparsity)
# Process columns in blocks
for col_start in range(0, d_in, blocksize):
col_end = min(col_start + blocksize, d_in)
block_cols = col_end - col_start
# Extract the block of columns and corresponding Hessian inverse
W_block = W[:, col_start:col_end].clone()
H_inv_block = hessian_inv[col_start:col_end, col_start:col_end]
# Error accumulator for compensating later columns
Err = torch.zeros_like(W_block)
for j in range(block_cols):
col_idx = col_start + j
w_col = W_block[:, j] # [d_out]
h_inv_jj = H_inv_block[j, j] # scalar
            # SparseGPT criterion: w^2 / [H^-1]_jj (smaller = cheaper to prune);
            # the simplified helper below ranks by magnitude instead
            scores = w_col ** 2 / h_inv_jj
# Determine which weights in this column to prune
# (simplified: prune if this weight is among the smallest
# across all columns for this row)
prune_mask = _should_prune(W, col_idx, num_prune_per_row)
# For pruned weights: compute the error
Err[:, j] = w_col * prune_mask.float()
W_block[:, j] = w_col * (~prune_mask).float()
# Compensate remaining columns in this block
if j < block_cols - 1:
update = Err[:, j:j+1] / h_inv_jj
h_inv_row = H_inv_block[j, j+1:block_cols]
W_block[:, j+1:block_cols] -= update @ h_inv_row.unsqueeze(0)
# Write back the pruned and compensated block
W[:, col_start:col_end] = W_block
# Compensate all remaining columns (after this block)
if col_end < d_in:
h_inv_cross = hessian_inv[col_start:col_end, col_end:]
W[:, col_end:] -= (Err / H_inv_block.diag().unsqueeze(0)) @ h_inv_cross
return W
def _should_prune(W, col_idx, num_prune_per_row):
"""Determine if weights at col_idx should be pruned.
Returns boolean mask of shape [d_out]."""
row_magnitudes = W.abs()
thresholds = torch.kthvalue(
row_magnitudes, num_prune_per_row, dim=1
).values
return row_magnitudes[:, col_idx] <= thresholds
2.4 Computing the Hessian Inverse
The Hessian is computed from calibration data (typically 128 samples from C4 or WikiText). The inverse is computed via Cholesky decomposition:
def compute_hessian_inverse(activations, damp=0.01):
"""Compute inverse Hessian from calibration activations.
Args:
activations: [n_samples, d_in] input activations
damp: damping factor for numerical stability
Returns:
[d_in, d_in] inverse Hessian
"""
n, d = activations.shape
H = (activations.T @ activations) / n # [d_in, d_in]
# Add damping for numerical stability
H += damp * torch.eye(d, device=H.device) * H.diag().mean()
# Cholesky decomposition: H = L L^T
L = torch.linalg.cholesky(H)
# Inverse via triangular solve: H^{-1} = (L^T)^{-1} L^{-1}
H_inv = torch.cholesky_inverse(L)
return H_inv
The Cholesky decomposition costs $O(d_{in}^3)$. For $d_{in} = 4096$, that is roughly $2 \times 10^{10}$ operations — about 10ms on an A100. This is computed once per layer, so the total cost for all layers in a 7B model is a few seconds at most.
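A quick sanity check of this routine on synthetic activations; the damped Hessian is rebuilt inline so the snippet stands alone:

```python
import torch

torch.manual_seed(0)
X = torch.randn(256, 64)                    # synthetic calibration activations
n, d = X.shape
damp = 0.01

H = X.T @ X / n
H += damp * torch.eye(d) * H.diag().mean()  # same damping as above
H_inv = torch.cholesky_inverse(torch.linalg.cholesky(H))

# The result should invert the damped H to numerical precision
residual = (H_inv @ H - torch.eye(d)).abs().max().item()
```

Note the damping deliberately perturbs $H$, so `H_inv` is the exact inverse of the damped matrix, not of the raw $X^\top X / n$.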
2.5 Full SparseGPT Pipeline
The complete pipeline processes the model layer-by-layer:
def sparsegpt_full(model, calibration_loader, sparsity=0.5, blocksize=128):
"""Apply SparseGPT to all linear layers in the model.
Args:
model: transformer model
calibration_loader: dataloader yielding calibration inputs
sparsity: target sparsity ratio
blocksize: column block size for processing
"""
# Collect calibration activations by running forward pass
hooks = {}
activations = {}
def make_hook(name):
def hook_fn(module, input_args, output):
if name not in activations:
activations[name] = []
activations[name].append(input_args[0].detach().reshape(-1, input_args[0].shape[-1]))
return hook_fn
# Register hooks on all linear layers
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
hooks[name] = module.register_forward_hook(make_hook(name))
# Run calibration data through model
model.eval()
with torch.no_grad():
for batch in calibration_loader:
model(batch["input_ids"].cuda())
# Remove hooks
for h in hooks.values():
h.remove()
# Prune each linear layer
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
# Stack calibration activations
X = torch.cat(activations[name], dim=0) # [n_total, d_in]
# Compute Hessian inverse
H_inv = compute_hessian_inverse(X)
# Prune with SparseGPT
W_pruned = sparsegpt_prune(
module.weight.data, H_inv, sparsity, blocksize
)
module.weight.data = W_pruned
print(f"Pruned {name}: {sparsity*100:.0f}% sparse, "
f"nnz={W_pruned.count_nonzero().item()}")
Wanda: Pruning Without Weight Updates
3.1 The Core Insight
Wanda (Weights AND Activations) makes a simpler observation: the importance of a weight depends on both its magnitude and the magnitude of the input it processes. A small weight connected to a large activation can be more important than a large weight connected to a near-zero activation.
The Wanda score for weight $W_{ij}$ is:

$$S_{ij} = |W_{ij}| \cdot \lVert X_j \rVert_2$$

where $\lVert X_j \rVert_2$ is the $\ell_2$ norm of the $j$-th input feature across all calibration samples. Weights with the smallest scores are pruned. No weight updates. No Hessian computation.
3.2 Why This Works
Consider the squared error from zeroing weight $W_{ij}$ in isolation:

$$E_{ij} = W_{ij}^2 \cdot \lVert X_j \rVert_2^2$$

The Wanda score is exactly the square root of this error. So Wanda ranks weights by the error their individual removal would cause; equivalently, it is the SparseGPT criterion under a diagonal approximation of the Hessian, without the compensation of surviving weights that SparseGPT performs.
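This relationship is easy to verify numerically. A small sketch on synthetic data: squaring the Wanda score recovers the per-weight removal error exactly (up to float rounding):

```python
import torch

torch.manual_seed(0)
X = torch.randn(128, 16)                 # calibration activations
W = torch.randn(4, 16)

act_norms = X.norm(dim=0)                # ||X_j||_2 per input feature
wanda = W.abs() * act_norms.unsqueeze(0) # Wanda score S_ij

# Exact squared error from zeroing each weight in isolation:
# E_ij = W_ij^2 * ||X_j||^2
exact_err = (W ** 2) * (act_norms ** 2).unsqueeze(0)
```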
3.3 Implementation
def wanda_prune(weight, activations, sparsity, per_row=True):
"""Wanda pruning: magnitude * activation norm.
Args:
weight: [d_out, d_in] weight matrix
activations: [n_samples, d_in] calibration activations
sparsity: fraction of weights to prune
per_row: if True, prune per-row (preserves structure per neuron)
Returns:
pruned weight matrix (no weight updates applied)
"""
d_out, d_in = weight.shape
# Compute per-feature activation norms
act_norms = activations.norm(dim=0) # [d_in]
# Wanda scores: |w| * ||x||
scores = weight.abs() * act_norms.unsqueeze(0) # [d_out, d_in]
if per_row:
# Prune independently per row (output neuron)
        num_prune = int(d_in * sparsity)
        if num_prune == 0:
            return weight.clone()
        # For each row, find the threshold
        sorted_scores, _ = scores.sort(dim=1)
        thresholds = sorted_scores[:, num_prune - 1].unsqueeze(1)
        mask = scores > thresholds  # Keep weights above threshold
else:
# Global pruning across the entire matrix
num_prune = int(weight.numel() * sparsity)
flat_scores = scores.flatten()
threshold = torch.kthvalue(flat_scores, num_prune).values
mask = scores > threshold
return weight * mask
def wanda_full(model, calibration_loader, sparsity=0.5):
"""Apply Wanda to all linear layers in the model."""
hooks = {}
activations = {}
def make_hook(name):
def hook_fn(module, input_args, output):
if name not in activations:
activations[name] = []
activations[name].append(
input_args[0].detach().reshape(-1, input_args[0].shape[-1])
)
return hook_fn
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
hooks[name] = module.register_forward_hook(make_hook(name))
model.eval()
with torch.no_grad():
for batch in calibration_loader:
model(batch["input_ids"].cuda())
for h in hooks.values():
h.remove()
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
X = torch.cat(activations[name], dim=0)
W_pruned = wanda_prune(module.weight.data, X, sparsity)
module.weight.data = W_pruned
Wanda is roughly 10x faster than SparseGPT because it skips Hessian computation and weight updates. For a 7B model, Wanda takes about 30 seconds vs. SparseGPT’s 5 minutes on a single A100. The quality difference at 50% unstructured sparsity is typically less than 0.5 perplexity points.
Unstructured vs. Structured Sparsity
4.1 Unstructured Sparsity
Unstructured pruning removes individual weights anywhere in the matrix. The resulting sparse matrix has no regular pattern — zeros are scattered randomly. This gives maximum flexibility for choosing which weights to remove, but provides no speedup on standard hardware. A sparse matrix with 50% zeros still requires the same number of memory accesses on a GPU unless the hardware has explicit sparse support.
def demonstrate_unstructured(weight, sparsity=0.5):
"""Show the random pattern of unstructured sparsity."""
mask = torch.rand_like(weight) > sparsity
sparse_weight = weight * mask
# Count non-zeros
nnz = sparse_weight.count_nonzero().item()
total = weight.numel()
print(f"Non-zeros: {nnz}/{total} ({nnz/total*100:.1f}%)")
# No regular pattern -- sparse operations needed for speedup
# Standard matmul still touches all elements
return sparse_weight
4.2 Structured Sparsity
Structured pruning removes entire rows, columns, or blocks. This directly reduces the matrix dimensions, giving real speedup on any hardware.
Row pruning removes entire output neurons:
def structured_row_prune(weight, activations, sparsity):
"""Remove entire rows (output neurons) from weight matrix.
Args:
weight: [d_out, d_in]
activations: [n_samples, d_in]
sparsity: fraction of rows to remove
Returns:
pruned weight [d_out * (1-sparsity), d_in], kept indices
"""
d_out, d_in = weight.shape
num_keep = int(d_out * (1 - sparsity))
# Score each row by its expected output magnitude
act_norms = activations.norm(dim=0) # [d_in]
row_scores = (weight.abs() * act_norms.unsqueeze(0)).sum(dim=1) # [d_out]
# Keep the highest-scoring rows
_, keep_indices = row_scores.topk(num_keep)
keep_indices = keep_indices.sort().values
return weight[keep_indices], keep_indices
def structured_column_prune(weight, activations, sparsity):
"""Remove entire columns (input features) from weight matrix.
Args:
weight: [d_out, d_in]
activations: [n_samples, d_in]
sparsity: fraction of columns to remove
Returns:
pruned weight [d_out, d_in * (1-sparsity)], kept indices
"""
d_out, d_in = weight.shape
num_keep = int(d_in * (1 - sparsity))
act_norms = activations.norm(dim=0) # [d_in]
col_scores = (weight.abs() * act_norms.unsqueeze(0)).sum(dim=0) # [d_in]
_, keep_indices = col_scores.topk(num_keep)
keep_indices = keep_indices.sort().values
return weight[:, keep_indices], keep_indices
Block pruning removes rectangular blocks:
def block_prune(weight, activations, sparsity, block_size=64):
"""Remove blocks of weights.
Args:
weight: [d_out, d_in]
activations: [n_samples, d_in]
sparsity: fraction of blocks to remove
block_size: size of square blocks
"""
d_out, d_in = weight.shape
act_norms = activations.norm(dim=0)
# Score each block
block_scores = []
block_positions = []
for i in range(0, d_out, block_size):
for j in range(0, d_in, block_size):
i_end = min(i + block_size, d_out)
j_end = min(j + block_size, d_in)
block = weight[i:i_end, j:j_end]
block_act = act_norms[j:j_end]
score = (block.abs() * block_act.unsqueeze(0)).sum().item()
block_scores.append(score)
block_positions.append((i, j, i_end, j_end))
# Prune lowest-scoring blocks
num_blocks = len(block_scores)
num_prune = int(num_blocks * sparsity)
scores_tensor = torch.tensor(block_scores)
_, prune_indices = scores_tensor.topk(num_prune, largest=False)
pruned = weight.clone()
for idx in prune_indices:
i, j, i_end, j_end = block_positions[idx.item()]
pruned[i:i_end, j:j_end] = 0
return pruned
4.3 Quality-Speed Tradeoff
Structured vs Unstructured Pruning at 50% Sparsity
| Method | WikiText-2 PPL | Delta vs Dense |
|---|---|---|
| Dense baseline | 5.68 | 0% |
| Unstructured 50% (SparseGPT) | 5.97 | +5.1% |
| Unstructured 50% (Wanda) | 6.12 | +7.7% |
| Row-structured 50% | 7.84 | +38% |
| Column-structured 50% | 8.21 | +44.5% |
| Block-structured 50% (64x64) | 6.89 | +21.3% |
Structured pruning hurts quality significantly more than unstructured pruning at the same sparsity level. The constraint of removing entire rows or columns means you cannot avoid removing important weights — if a row has one critical weight and 4095 unimportant ones, you still lose the critical one. Block pruning is a middle ground: finer granularity than rows/columns, but coarser than individual weights.
N:M Sparsity on Ampere GPUs
5.1 The Hardware Constraint
NVIDIA Ampere (A100) and later GPUs include hardware support for a specific sparsity pattern: N:M sparsity, where out of every M consecutive weights, exactly N are zero. The most common pattern is 2:4 sparsity: out of every 4 consecutive weights, 2 are zero (50% sparsity). The hardware contains a sparse tensor core that skips the zero multiplications, achieving roughly 2x throughput.
5.2 Why 2:4 Specifically
The 2:4 pattern is a hardware design choice. For each group of 4 weights, the sparse tensor core stores a 2-bit index for each of the 2 surviving values. The hardware uses this metadata to select the corresponding input activations, performing only 2 multiplications instead of 4. This gives a 2x speedup with minimal metadata overhead (4 bits per group of 4, i.e. 1 bit per weight).
def enforce_nm_sparsity(weight, n=2, m=4):
"""Enforce N:M sparsity pattern.
For every group of M consecutive weights (along d_in),
    keep the M - N largest and zero out the rest (N zeros per group).
Args:
weight: [d_out, d_in]
n: number of zeros per group
m: group size
Returns:
weight with N:M sparsity pattern
"""
d_out, d_in = weight.shape
assert d_in % m == 0, f"d_in ({d_in}) must be divisible by m ({m})"
# Reshape to [d_out, d_in/m, m]
W = weight.reshape(d_out, d_in // m, m)
# Find the top (m-n) weights in each group (the ones to keep)
keep_count = m - n # 2 for 2:4 pattern
_, top_indices = W.abs().topk(keep_count, dim=2)
# Create mask
mask = torch.zeros_like(W)
mask.scatter_(2, top_indices, 1.0)
# Apply mask
result = (W * mask).reshape(d_out, d_in)
return result
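A quick check that this produces a valid pattern. The function is condensed inline so the snippet runs standalone; with the document's convention (N zeros per group of M), every group of 4 should end up with exactly 2 zeros:

```python
import torch

def enforce_nm(weight, n=2, m=4):
    """Keep the m-n largest-magnitude weights in each group of m."""
    d_out, d_in = weight.shape
    W = weight.reshape(d_out, d_in // m, m)
    _, top = W.abs().topk(m - n, dim=2)
    mask = torch.zeros_like(W).scatter_(2, top, 1.0)
    return (W * mask).reshape(d_out, d_in)

torch.manual_seed(0)
W = torch.randn(8, 32)
W24 = enforce_nm(W, n=2, m=4)

# Every group of 4 consecutive weights has exactly 2 zeros (50% sparsity)
zeros_per_group = (W24.reshape(8, 8, 4) == 0).sum(dim=2)
```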
5.3 Combining SparseGPT/Wanda with N:M
Both SparseGPT and Wanda can be adapted to produce N:M patterns instead of arbitrary unstructured sparsity:
def wanda_nm_prune(weight, activations, n=2, m=4):
"""Wanda pruning with N:M sparsity constraint.
Instead of pruning the globally lowest-scoring weights,
prune the N lowest-scoring weights in each group of M.
"""
d_out, d_in = weight.shape
# Compute Wanda scores
act_norms = activations.norm(dim=0) # [d_in]
scores = weight.abs() * act_norms.unsqueeze(0) # [d_out, d_in]
# Reshape to groups of M
S = scores.reshape(d_out, d_in // m, m)
W = weight.reshape(d_out, d_in // m, m)
# In each group, keep the (m-n) highest-scoring weights
keep_count = m - n
_, top_indices = S.topk(keep_count, dim=2)
mask = torch.zeros_like(W)
mask.scatter_(2, top_indices, 1.0)
result = (W * mask).reshape(d_out, d_in)
return result
def sparsegpt_nm_prune(weight, hessian_inv, n=2, m=4):
"""SparseGPT with N:M constraint.
Process in groups of M columns. Within each group,
prune the N columns with highest error, then compensate.
"""
d_out, d_in = weight.shape
W = weight.clone()
for group_start in range(0, d_in, m):
group_end = group_start + m
W_group = W[:, group_start:group_end].clone()
H_inv_group = hessian_inv[group_start:group_end, group_start:group_end]
# Score each column in the group
diag = H_inv_group.diag()
col_scores = (W_group ** 2).sum(dim=0) / diag # [m]
# Keep top (m-n) scoring columns
keep_count = m - n
_, keep_idx = col_scores.topk(keep_count)
prune_idx = torch.tensor([i for i in range(m) if i not in keep_idx])
        # Zero out pruned columns and compensate, recording each error at
        # the moment of pruning (W may already reflect earlier groups' updates)
        errs = {}
        for j in prune_idx.tolist():
            err = W_group[:, j].clone()
            errs[j] = err
            W_group[:, j] = 0
            # Compensate surviving columns within the group
            h_jj = H_inv_group[j, j]
            for k in keep_idx.tolist():
                W_group[:, k] -= (err / h_jj) * H_inv_group[j, k]
        W[:, group_start:group_end] = W_group
        # Compensate future groups using the recorded pruning errors
        if group_end < d_in:
            for j, err in errs.items():
                h_jj = hessian_inv[group_start + j, group_start + j]
                h_cross = hessian_inv[group_start + j, group_end:]
                W[:, group_end:] -= (err / h_jj).unsqueeze(1) * h_cross.unsqueeze(0)
return W
5.4 N:M Performance on Ampere
2:4 Sparsity Performance on A100
| Configuration | Throughput (tok/s) | Delta vs Dense |
|---|---|---|
| Dense (cuBLAS) | 2,847 | baseline |
| 2:4 Sparse (cuSPARSELt) | 4,952 | +73.9% |
| 2:4 Sparse + INT8 weights | 7,208 | +153% |
| Unstructured 50% (no HW support) | 2,891 | +1.5% |
Unstructured 50% sparsity gives almost no speedup without dedicated hardware support. The GPU still loads the full matrix from memory and multiplies by zeros. N:M sparsity with cuSPARSELt achieves real speedups because the hardware skips zero entries at the tensor core level.
Quality vs. Sparsity Curves
6.1 Measuring Degradation
The standard evaluation protocol for pruning: measure perplexity on WikiText-2 and zero-shot accuracy on downstream tasks (ARC, HellaSwag, WinoGrande, PIQA) at increasing sparsity levels.
def evaluate_pruning_sweep(model, tokenizer, calibration_loader,
eval_dataset, sparsity_levels):
"""Evaluate model quality across sparsity levels."""
import copy
from lm_eval import evaluator
results = []
for sparsity in sparsity_levels:
# Clone the model
pruned_model = copy.deepcopy(model)
# Apply Wanda pruning
wanda_full(pruned_model, calibration_loader, sparsity=sparsity)
# Measure perplexity
ppl = compute_perplexity(pruned_model, tokenizer, eval_dataset)
        # Measure zero-shot accuracy (lm-eval reports per-task results;
        # the "acc" key name varies across lm-eval versions)
        tasks = ["arc_easy", "hellaswag", "winogrande", "piqa"]
        eval_out = evaluator.simple_evaluate(
            model=pruned_model, tasks=tasks, batch_size=32
        )
        avg_acc = sum(eval_out["results"][t]["acc"] for t in tasks) / len(tasks)
        results.append({
            "sparsity": sparsity,
            "perplexity": ppl,
            "avg_accuracy": avg_acc,
        })
        print(f"Sparsity: {sparsity:.0%} | PPL: {ppl:.2f} | "
              f"Avg Acc: {avg_acc:.1%}")
del pruned_model
return results
6.2 Empirical Results
Llama 7B: Perplexity vs Sparsity

(Table: WikiText-2 perplexity at 0-80% sparsity for SparseGPT, Wanda, and magnitude pruning, all unstructured, plus 2:4 N:M Wanda; table values not preserved.)
Key observations from the empirical curves:

- Up to 30% sparsity: All methods perform similarly. There is genuine redundancy in the model that any method can find.
- 30-50% sparsity: SparseGPT’s Hessian compensation provides a measurable advantage (0.15-0.3 perplexity points over Wanda, 0.5-1.5 over magnitude).
- 50-60% sparsity: The gap widens significantly. SparseGPT’s weight updates become critical for maintaining quality.
- Beyond 70%: All one-shot methods degrade rapidly. At this level, retraining or iterative pruning is necessary.
6.3 Task-Specific Analysis
Llama 7B: Zero-Shot Accuracy at 50% Sparsity

(Table: zero-shot accuracy on ARC-Easy, ARC-Challenge, HellaSwag, WinoGrande, and PIQA, plus the average, for the dense model, SparseGPT 50%, and Wanda 50%; table values not preserved.)
At 50% sparsity, both methods retain roughly 96-97% of the dense model’s average accuracy. The degradation is relatively uniform across tasks, with ARC-Challenge (the hardest reasoning task) showing the largest relative drop.
Layer-Wise Sensitivity
7.1 Not All Layers Are Equal
Different layers have different sensitivity to pruning. Early layers (near the embedding) and late layers (near the output head) are typically more sensitive than middle layers.
def layer_sensitivity_analysis(model, calibration_loader,
eval_dataset, tokenizer):
"""Measure per-layer sensitivity to pruning.
For each layer, prune only that layer to 50% while keeping
all other layers dense. Measure the perplexity increase.
"""
import copy
base_ppl = compute_perplexity(model, tokenizer, eval_dataset)
sensitivities = {}
for name, module in model.named_modules():
if not isinstance(module, nn.Linear):
continue
# Clone model, prune only this layer
test_model = copy.deepcopy(model)
target = dict(test_model.named_modules())[name]
# Quick magnitude prune for sensitivity analysis
mask = magnitude_prune(target.weight.data, sparsity=0.5)
target.weight.data = mask
ppl = compute_perplexity(test_model, tokenizer, eval_dataset)
sensitivities[name] = ppl - base_ppl
del test_model
# Sort by sensitivity
sorted_layers = sorted(sensitivities.items(), key=lambda x: x[1], reverse=True)
for name, delta_ppl in sorted_layers[:10]:
print(f"{name}: +{delta_ppl:.2f} perplexity")
return sensitivities
7.2 Non-Uniform Sparsity
The sensitivity analysis motivates non-uniform sparsity: prune sensitive layers less aggressively and insensitive layers more aggressively, keeping the total parameter count the same.
def allocate_sparsity(sensitivities, target_avg_sparsity=0.5,
min_sparsity=0.2, max_sparsity=0.8):
"""Allocate per-layer sparsity inversely proportional to sensitivity.
High-sensitivity layers get low sparsity.
Low-sensitivity layers get high sparsity.
Average across all layers equals target_avg_sparsity.
"""
layers = list(sensitivities.keys())
sens_values = torch.tensor([sensitivities[l] for l in layers])
# Invert: high sensitivity -> low sparsity
inv_sens = 1.0 / (sens_values + 1e-8)
# Normalize to mean = target
sparsity_ratios = inv_sens / inv_sens.mean() * target_avg_sparsity
# Clip to valid range
sparsity_ratios = sparsity_ratios.clamp(min_sparsity, max_sparsity)
# Adjust to hit target average
current_avg = sparsity_ratios.mean().item()
sparsity_ratios *= target_avg_sparsity / current_avg
sparsity_ratios = sparsity_ratios.clamp(min_sparsity, max_sparsity)
return {layer: s.item() for layer, s in zip(layers, sparsity_ratios)}
Non-uniform sparsity allocations typically improve perplexity by 0.1-0.3 points over uniform sparsity at the same average sparsity level. The first and last 2-3 transformer blocks are consistently the most sensitive — they should be pruned at 20-30% while middle blocks can tolerate 60-70%.
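The allocation rule can be exercised on made-up sensitivity numbers (the values below are purely illustrative). A condensed inline version of the logic above, with high-sensitivity layers at the ends of the stack:

```python
import torch

# Hypothetical per-layer sensitivities (perplexity increase when that layer
# alone is pruned to 50%); first and last layers are the most sensitive
sens = torch.tensor([2.0, 0.5, 0.1, 0.1, 0.4, 1.8])
target = 0.5

inv = 1.0 / (sens + 1e-8)                       # invert: sensitive -> low sparsity
ratios = (inv / inv.mean() * target).clamp(0.2, 0.8)
ratios *= target / ratios.mean()                # renormalize toward the target
ratios = ratios.clamp(0.2, 0.8)
```

Note the clamp after renormalizing means the final average can still drift a little from the target; the function in the post has the same property.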
Pruning + Quantization: Stacking Compression
8.1 Orthogonal Techniques
Pruning (removing weights) and quantization (reducing bit-width) are largely orthogonal. A 50% sparse model quantized to 4 bits uses roughly $0.5 \times 4/16 = 12.5\%$ of the original FP16 model’s memory, before sparse-index overhead. The quality loss compounds, but less than you might expect because they remove different types of redundancy.
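The arithmetic behind that figure, as a check (the parameter count is approximate, and sparse-index overhead is ignored):

```python
params = 7_000_000_000                  # Llama 7B, roughly
dense_fp16_gb = params * 16 / 8 / 1e9   # 16 bits per weight -> 14 GB

sparsity, bits = 0.5, 4
sparse_int4_gb = params * (1 - sparsity) * bits / 8 / 1e9  # 1.75 GB
fraction = sparse_int4_gb / dense_fp16_gb                  # 0.125
```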
def prune_then_quantize(model, calibration_loader, prune_sparsity=0.5):
"""Apply Wanda pruning followed by GPTQ quantization."""
# Step 1: Prune
wanda_full(model, calibration_loader, sparsity=prune_sparsity)
# Step 2: Quantize surviving weights to 4-bit
# (using GPTQ or similar -- simplified here)
for name, module in model.named_modules():
if isinstance(module, nn.Linear):
W = module.weight.data
# Only quantize non-zero weights
mask = W != 0
W_nonzero = W[mask]
# Symmetric 4-bit quantization
max_val = W_nonzero.abs().max()
scale = max_val / 7 # 4-bit signed: -8 to 7
W_quant = torch.round(W / scale).clamp(-8, 7) * scale
W_quant[~mask] = 0 # Keep pruned weights at zero
module.weight.data = W_quant
def measure_compression(model):
"""Measure effective compression ratio."""
total_params = 0
nonzero_params = 0
for p in model.parameters():
total_params += p.numel()
nonzero_params += p.count_nonzero().item()
sparsity = 1 - nonzero_params / total_params
# Assuming 4-bit for non-zero, 0 bits for zero
effective_bits = (1 - sparsity) * 4
compression = 16 / effective_bits # vs FP16 baseline
print(f"Sparsity: {sparsity:.1%}")
print(f"Effective bits per param: {effective_bits:.2f}")
print(f"Compression ratio: {compression:.1f}x")
8.2 Combined Results
Compression Stack: Pruning + Quantization on Llama 7B
| Configuration | Perplexity | Memory | Compression |
|---|---|---|---|
| Dense FP16 (baseline) | 5.68 | 14.0 GB | 1.0x |
| 50% sparse FP16 (Wanda) | 6.12 | 7.0 GB | 2.0x |
| Dense INT4 (GPTQ) | 5.85 | 3.5 GB | 4.0x |
| 50% sparse + INT4 | 6.38 | 1.75 GB | 8.0x |
| 2:4 sparse + INT8 | 6.01 | 3.5 GB | 4.0x (2x speed) |
The 2:4 sparse + INT8 combination is particularly attractive: it achieves 4x memory compression with actual 2x inference speedup (via sparse tensor cores), and the quality degradation is only 0.33 perplexity points.
Iterative and Recovery-Based Pruning
9.1 Sparse Fine-Tuning
One-shot pruning is fast but leaves quality on the table. If you can afford some training budget, sparse fine-tuning recovers much of the lost quality:
def sparse_finetune(model, train_loader, optimizer, mask, epochs=2):
"""Fine-tune a pruned model while maintaining the sparsity pattern.
Args:
model: pruned model
train_loader: training data
optimizer: optimizer (typically AdamW with low LR)
mask: dict mapping parameter names to binary masks
epochs: number of fine-tuning epochs
"""
model.train()
for epoch in range(epochs):
total_loss = 0
for batch in train_loader:
input_ids = batch["input_ids"].cuda()
labels = batch["labels"].cuda()
outputs = model(input_ids, labels=labels)
loss = outputs.loss
loss.backward()
# Zero out gradients for pruned weights
with torch.no_grad():
for name, param in model.named_parameters():
if name in mask:
param.grad *= mask[name]
optimizer.step()
optimizer.zero_grad()
# Re-apply mask to handle numerical drift
with torch.no_grad():
for name, param in model.named_parameters():
if name in mask:
param.data *= mask[name]
total_loss += loss.item()
avg_loss = total_loss / len(train_loader)
print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")
9.2 Iterative Magnitude Pruning (IMP)
The lottery ticket hypothesis approach: train, prune a small fraction, retrain, prune more, repeat. This finds better sparse networks than one-shot methods but costs much more compute:
def iterative_magnitude_pruning(model, train_loader, eval_loader,
target_sparsity=0.9, prune_steps=10,
finetune_epochs=2):
"""Iterative magnitude pruning with retraining.
Prune in small steps, retraining between each step.
"""
current_sparsity = 0.0
sparsity_per_step = 1 - (1 - target_sparsity) ** (1 / prune_steps)
mask = {}
for name, param in model.named_parameters():
if "weight" in name and param.dim() == 2:
mask[name] = torch.ones_like(param, dtype=torch.bool)
for step in range(prune_steps):
# Prune: remove smallest magnitude surviving weights
with torch.no_grad():
for name, param in model.named_parameters():
if name not in mask:
continue
surviving = param[mask[name]]
if surviving.numel() == 0:
continue
                num_prune = int(surviving.numel() * sparsity_per_step)
                if num_prune == 0:
                    continue
                threshold = torch.kthvalue(surviving.abs().flatten(), num_prune).values
new_prune = (param.abs() < threshold) & mask[name]
mask[name] &= ~new_prune
param.data *= mask[name].float()
# Compute current sparsity
total = sum(m.numel() for m in mask.values())
nonzero = sum(m.sum().item() for m in mask.values())
current_sparsity = 1 - nonzero / total
# Retrain
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
sparse_finetune(model, train_loader, optimizer, mask, finetune_epochs)
ppl = compute_perplexity(model, None, eval_loader)
print(f"Step {step+1}/{prune_steps}: "
f"sparsity={current_sparsity:.1%}, ppl={ppl:.2f}")
Iterative pruning at 90% sparsity achieves perplexity close to one-shot pruning at 50%. The cost is 10-20x more training compute. For LLMs, this is rarely practical — the one-shot methods (SparseGPT, Wanda) dominate in practice because they require zero training.
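The per-step rate in the code above comes from solving $(1 - s_{step})^{T} = 1 - s_{target}$ for $s_{step}$, so that pruning the same fraction of the survivors at each of the $T$ steps lands on the target. A quick check of the geometric schedule:

```python
target_sparsity, prune_steps = 0.9, 10
step = 1 - (1 - target_sparsity) ** (1 / prune_steps)   # fraction pruned per step

# Applying the per-step rate T times reaches the target
remaining = 1.0
for _ in range(prune_steps):
    remaining *= 1 - step
reached = 1 - remaining
```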
Practical Deployment
10.1 Sparse Format Conversion
For deployment, sparse weights need to be stored in a compressed format:
def convert_to_csr(weight):
    """Convert a dense weight matrix to CSR (Compressed Sparse Row) format.

    CSR stores only the non-zero values plus two index arrays.
    Memory: nnz * (value_bytes + index_bytes) + (d_out + 1) * pointer_bytes
    """
    sparse = weight.to_sparse_csr()
    crow_indices = sparse.crow_indices()  # [d_out + 1] row pointers
    col_indices = sparse.col_indices()    # [nnz] column indices
    values = sparse.values()              # [nnz] non-zero values

    # Memory comparison (assumes int32 indices; PyTorch defaults to int64)
    dense_bytes = weight.numel() * weight.element_size()
    sparse_bytes = (values.numel() * values.element_size() +
                    col_indices.numel() * 4 +   # int32 column indices
                    crow_indices.numel() * 4)   # int32 row pointers
    print(f"Dense: {dense_bytes / 1e6:.1f} MB")
    print(f"CSR:   {sparse_bytes / 1e6:.1f} MB")
    print(f"Ratio: {sparse_bytes / dense_bytes:.2f}x")
    return sparse
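One consequence of the CSR byte math is easy to miss: at 50% unstructured sparsity with fp16 weights, CSR is *larger* than dense, because each surviving value costs 2 bytes plus a 4-byte int32 column index, versus 2 bytes per element dense. A standalone check on a random fp16 matrix (a stand-in for real pruned weights):

```python
import torch

torch.manual_seed(0)
w = torch.randn(1024, 1024, dtype=torch.float16)

# Zero out the 50% smallest-magnitude entries (magnitude pruning)
k = w.numel() // 2
thresh = w.abs().float().flatten().kthvalue(k).values
keep = w.abs().float() > thresh
w = w * keep

nnz = int(keep.sum())
dense_bytes = w.numel() * w.element_size()              # 2 bytes/elem
csr_bytes = nnz * (2 + 4) + (w.shape[0] + 1) * 4        # fp16 vals + int32 idx
print(f"CSR/dense ratio at 50% sparsity: {csr_bytes / dense_bytes:.2f}")
# ratio is ~1.5: CSR costs MORE than dense at this sparsity level
```

This is why 50%-sparse LLMs are usually deployed in the 2:4 N:M format (fixed 4-bit metadata per group) rather than CSR; general sparse formats only start winning well above ~70% sparsity.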
def convert_to_nm_format(weight, n=2, m=4):
    """Convert to N:M sparse format for cuSPARSELt.

    N:M sparsity keeps n non-zero values in every group of m consecutive
    weights. Stores the compressed non-zeros plus 2-bit position metadata
    per kept value (4 bits per group for 2:4).
    """
    d_out, d_in = weight.shape
    num_groups = d_in // m
    keep_per_group = n  # n survivors per group of m (2 for 2:4)

    # Keep the largest-magnitude positions within each group
    W_groups = weight.reshape(d_out, num_groups, m)
    _, top_idx = W_groups.abs().topk(keep_per_group, dim=2)

    # Compressed values: only store the non-zeros
    values = torch.gather(W_groups, 2, top_idx)  # [d_out, num_groups, keep]
    # Metadata: which positions within each group are non-zero
    metadata = top_idx                           # [d_out, num_groups, keep]

    compressed_bytes = values.numel() * values.element_size()
    metadata_bytes = d_out * num_groups // 2  # 4 bits per group, packed
    dense_bytes = weight.numel() * weight.element_size()
    print(f"Dense: {dense_bytes / 1e6:.1f} MB")
    print(f"2:4 format: {(compressed_bytes + metadata_bytes) / 1e6:.1f} MB")
    return values, metadata
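To make the metadata arithmetic concrete: each kept position within a group of 4 is a number in 0..3, so it fits in 2 bits, and the two positions of a 2:4 group fit in a single nibble. A toy sketch of the packing on a small random matrix (this illustrates the bit budget, not the exact cuSPARSELt memory layout):

```python
import torch

torch.manual_seed(0)
w = torch.randn(4, 8)                 # toy [d_out=4, d_in=8] weight
groups = w.reshape(4, 2, 4)           # groups of m=4 along d_in
_, idx = groups.abs().topk(2, dim=2)  # keep n=2 positions per group
idx, _ = idx.sort(dim=2)              # store positions in ascending order

# Two 2-bit positions per group -> 4 bits, so each group's metadata
# fits in one nibble and two groups share a byte when fully packed.
packed = (idx[..., 0] | (idx[..., 1] << 2)).to(torch.uint8)
print(packed.shape)                   # [4, 2]: one nibble per group
assert int(packed.max()) <= 0b1111
```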
10.2 End-to-End Pruning Pipeline
def full_pruning_pipeline(model_name, sparsity=0.5, method="wanda",
                          nm_sparsity=False, quantize=False):
    """Complete pruning pipeline from a HuggingFace model to deployment."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import load_dataset

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:  # Llama tokenizers ship without one
        tokenizer.pad_token = tokenizer.eos_token

    # Prepare calibration data (skip empty wikitext lines)
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    calibration_texts = [t for t in dataset["text"] if t.strip()][:128]
    calibration_tokens = tokenizer(
        calibration_texts, return_tensors="pt",
        padding=True, truncation=True, max_length=2048
    )
    calibration_loader = [{"input_ids": calibration_tokens["input_ids"]}]

    # Prune
    if method == "wanda":
        if nm_sparsity:
            wanda_nm_full(model, calibration_loader)
        else:
            wanda_full(model, calibration_loader, sparsity=sparsity)
    elif method == "sparsegpt":
        sparsegpt_full(model, calibration_loader, sparsity=sparsity)

    # Optional quantization on top of the pruned weights
    if quantize:
        prune_then_quantize(model, calibration_loader, prune_sparsity=0)

    # Evaluate
    eval_dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    ppl = compute_perplexity(model, tokenizer, eval_dataset)
    print(f"Final perplexity: {ppl:.2f}")

    # Save
    output_dir = f"{model_name}-{method}-{sparsity}"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return model
Comparison Summary
Pruning Methods Comparison (Llama 7B, 50% Sparsity)
| Method | Perplexity | Time | Weight Updates | Best For |
|---|---|---|---|---|
| Magnitude pruning | 7.34 | 0 min | None | Worst quality |
| SparseGPT | 5.97 | 5 min | Hessian-based | Best quality |
| Wanda | 6.12 | 0.5 min | None | Best speed/quality |
| Wanda 2:4 N:M | 6.25 | 0.5 min | None | Best deployment |
When to use each method:
- SparseGPT: When quality is paramount and you can afford 5-10 minutes of compute per model. Best for high sparsity (60%+).
- Wanda: Default choice. Nearly as good as SparseGPT, 10x faster, simpler implementation.
- Wanda 2:4: When deploying to Ampere/Hopper GPUs and you need actual inference speedup, not just compression.
- Structured pruning: When you need speedup on hardware without sparse tensor core support and can tolerate more quality loss.
The field is converging on a practical recipe: Wanda or SparseGPT for initial pruning, 2:4 N:M format for deployment on NVIDIA GPUs, optional sparse fine-tuning for recovering the last fraction of quality, and stacking with INT4/INT8 quantization for maximum compression.
References
- Frantar, E. and Alistarh, D. “SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot.” ICML 2023.
- Sun, M. et al. “A Simple and Effective Pruning Approach for Large Language Models.” ICLR 2024.
- Mishra, A. et al. “Accelerating Sparse Deep Neural Networks.” arXiv 2021.
- Pool, J. and Yu, C. “Channel Permutations for N:M Sparsity.” NeurIPS 2021.
- Frankle, J. and Carbin, M. “The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks.” ICLR 2019.