Part 32 of 36 in the series Transformer Anatomy.

A 70B parameter model has 70 billion weights. Not all of them matter. Pruning removes the ones that contribute least to output quality, producing a sparser model that uses less memory and (with hardware support) runs faster. The challenge is doing this without retraining — large language models cost millions to train, so any pruning method that requires retraining is impractical.

Two methods solve this problem at scale. SparseGPT (Frantar and Alistarh, 2023) uses second-order information (the Hessian) to optimally update surviving weights as pruned weights are removed. Wanda (Sun et al., 2024) uses a simpler criterion — weight magnitude times input activation norm — that requires no weight updates at all. Both achieve 50% unstructured sparsity on Llama-class models with minimal quality degradation, processing the entire model in minutes on a single GPU.

This post covers the mathematics behind both methods, their implementation, the difference between structured and unstructured pruning, N:M sparsity for Ampere GPUs, and empirical quality-vs-sparsity curves.

Why Pruning Works

1.1 Weight Distribution in Trained Models

After training, the weight matrices in a transformer exhibit a characteristic distribution: most weights cluster near zero, with a long tail of high-magnitude outliers. For a typical Llama 7B weight matrix $W \in \mathbb{R}^{4096 \times 4096}$, the distribution looks like:

  • 50% of weights have magnitude less than 0.005
  • 90% have magnitude less than 0.02
  • 99% have magnitude less than 0.08
  • The remaining 1% have magnitudes up to 0.5 or higher

This concentration near zero suggests that many weights can be removed (set to zero) without significantly changing the model’s outputs. The question is which weights to remove and how to compensate for their removal.
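
One can check this concentration directly on any checkpoint by computing magnitude quantiles. A minimal sketch; the synthetic Gaussian matrix here is just a stand-in for a real trained layer, which is typically heavier-tailed:

```python
import torch

def magnitude_quantiles(weight, qs=(0.5, 0.9, 0.99)):
    """Magnitude below which the given fraction of weights falls."""
    mags = weight.abs().flatten().float()
    return {q: torch.quantile(mags, q).item() for q in qs}

# Stand-in for a trained layer; load a real checkpoint to reproduce
# the numbers quoted above.
W = torch.randn(1024, 1024) * 0.02
print(magnitude_quantiles(W))
```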

1.2 The Pruning Problem

Given a weight matrix $W \in \mathbb{R}^{d_{out} \times d_{in}}$ and a calibration dataset of inputs $X \in \mathbb{R}^{d_{in} \times n}$ (one column per sample), we want to find a sparse weight matrix $\hat{W}$ that minimizes the reconstruction error:

$$\min_{\hat{W}} \| WX - \hat{W}X \|_2^2 \quad \text{subject to} \quad \|\hat{W}\|_0 \leq k$$

where $k$ is the number of non-zero weights we want to keep. This is a combinatorial optimization problem — NP-hard in general. Practical pruning methods approximate it.
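
The objective is easy to evaluate for any candidate mask, which makes it useful for sanity-checking a pruning method. A small sketch on synthetic data, comparing a random 50% mask against a magnitude-based one (activations stored as rows, so the error is written $\|XW^T - X\hat{W}^T\|_F^2$, the same quantity transposed):

```python
import torch

def reconstruction_error(W, W_hat, X):
    """Squared reconstruction error ||X W^T - X W_hat^T||_F^2,
    with X: [n, d_in] holding one calibration sample per row."""
    return ((X @ W.T - X @ W_hat.T) ** 2).sum().item()

torch.manual_seed(0)
W = torch.randn(64, 128) * 0.02
X = torch.randn(256, 128)

rand_mask = torch.rand_like(W) > 0.5     # random 50% mask
mag_mask = W.abs() >= W.abs().median()   # keep the largest 50%

err_rand = reconstruction_error(W, W * rand_mask, X)
err_mag = reconstruction_error(W, W * mag_mask, X)
print(f"random mask: {err_rand:.3f}  magnitude mask: {err_mag:.3f}")
```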

1.3 Magnitude Pruning: The Baseline

The simplest approach: remove weights with the smallest absolute values.

import torch

def magnitude_prune(weight, sparsity):
    """Remove the smallest-magnitude weights.

    Args:
        weight: [d_out, d_in] weight matrix
        sparsity: fraction of weights to remove (0.0 to 1.0)
    Returns:
        pruned weight matrix with zeros in pruned positions
    """
    num_params = weight.numel()
    num_prune = int(num_params * sparsity)

    if num_prune == 0:
        return weight.clone()

    # Find the threshold: the num_prune-th smallest magnitude
    magnitudes = weight.abs().flatten()
    threshold = torch.kthvalue(magnitudes, num_prune).values

    # Prune everything at or below the threshold (ties included)
    mask = weight.abs() > threshold
    return weight * mask

Magnitude pruning works surprisingly well at low sparsity (less than 30%). But it degrades sharply beyond 50% because it ignores a critical factor: the input distribution. A weight with magnitude 0.001 connected to an input feature with activation magnitude 1000 contributes more than a weight with magnitude 0.1 connected to a feature with activation magnitude 0.001.

SparseGPT: Optimal One-Shot Pruning

2.1 The Optimal Brain Surgeon Framework

SparseGPT builds on Optimal Brain Surgeon (OBS), a second-order pruning method from 1993. The key insight: when you prune weight $w_{ij}$, you should update the remaining weights to compensate. The optimal update depends on the inverse Hessian of the loss with respect to the weights.

For a linear layer $y = Wx$ with squared error loss, the Hessian with respect to row $w_i$ of $W$ is:

$$H = \frac{2}{n} X X^T$$

where $X \in \mathbb{R}^{d_{in} \times n}$ is the matrix of input activations across $n$ calibration samples. The factor of 2 comes from the squared error derivative.

When we prune weight $w_{ip}$ (row $i$, column $p$), the optimal update to the remaining weights in row $i$ is:

$$\delta_i = -\frac{w_{ip}}{[H^{-1}]_{pp}} \cdot H^{-1}_{:,p}$$

This update minimizes the increase in squared error caused by removing $w_{ip}$. The pruning error (increase in loss) is:

$$\Delta \mathcal{L}_p = \frac{w_{ip}^2}{2\,[H^{-1}]_{pp}}$$

ℹ️ Note

The inverse Hessian tells us two things: (1) which weight to prune (the one with the smallest $\frac{w_p^2}{[H^{-1}]_{pp}}$), and (2) how to update the remaining weights to compensate ($\delta_i = -\frac{w_{ip}}{[H^{-1}]_{pp}} \cdot H^{-1}_{:,p}$).
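
These formulas can be verified numerically on a single row: prune the lowest-saliency coordinate with and without the OBS compensation and compare the resulting error. A toy check, not SparseGPT itself:

```python
import torch

torch.manual_seed(0)
n, d = 512, 8
X = torch.randn(n, d)              # calibration inputs, one sample per row
w = torch.randn(d)                 # one row of W
H = 2.0 / n * X.T @ X              # Hessian of the squared error
H_inv = torch.linalg.inv(H)

# OBS saliency: error incurred by pruning each coordinate
saliency = w ** 2 / (2 * H_inv.diag())
p = saliency.argmin().item()

w_naive = w.clone()
w_naive[p] = 0.0                   # zero without compensation

# OBS update: delta = -(w_p / [H^-1]_pp) * H^-1[:, p]
w_obs = w - (w[p] / H_inv[p, p]) * H_inv[:, p]
w_obs[p] = 0.0                     # exactly zero (the update lands there up to rounding)

def mse(w_hat):
    return ((X @ w - X @ w_hat) ** 2).mean().item()

print(mse(w_naive), mse(w_obs))    # compensated error is lower
```

Because the loss is exactly quadratic here, the compensated solution is guaranteed to be at least as good as naive zeroing.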

2.2 The Scalability Problem

Classical OBS requires computing $H^{-1}$, which costs $O(d_{in}^3)$. For a Llama 7B layer with $d_{in} = 4096$, this means inverting a $4096 \times 4096$ matrix — feasible but slow. The real problem is that after pruning one weight, $H^{-1}$ changes, requiring a rank-1 update before selecting the next weight. Pruning $s \cdot d_{in}$ weights per row (where $s$ is the sparsity ratio) requires $s \cdot d_{in}$ rank-1 updates, each costing $O(d_{in}^2)$, for $O(s \cdot d_{in}^3)$ per row and $O(s \cdot d_{in}^3 \cdot d_{out})$ across the full matrix. For $d_{in} = d_{out} = 4096$ and $s = 0.5$, that is roughly $1.4 \times 10^{14}$ operations per layer. Unacceptable.

2.3 SparseGPT’s Column-Wise Algorithm

SparseGPT’s key contribution is an efficient algorithm that processes weights column by column, amortizing the Hessian updates. Instead of selecting the globally optimal weight to prune next, SparseGPT processes columns left-to-right and makes pruning decisions for each column using the current Hessian inverse.

import torch
import torch.nn as nn

def sparsegpt_prune(weight, hessian_inv, sparsity, blocksize=128):
    """SparseGPT: one-shot pruning with Hessian-based weight updates.

    Args:
        weight: [d_out, d_in] weight matrix
        hessian_inv: [d_in, d_in] inverse Hessian (precomputed)
        sparsity: fraction of weights to prune
        blocksize: number of columns to process at once
    Returns:
        pruned weight matrix with compensated surviving weights
    """
    W = weight.clone()
    d_out, d_in = W.shape

    # Determine number of weights to prune per row
    num_prune_per_row = int(d_in * sparsity)

    # Process columns in blocks
    for col_start in range(0, d_in, blocksize):
        col_end = min(col_start + blocksize, d_in)
        block_cols = col_end - col_start

        # Extract the block of columns and corresponding Hessian inverse
        W_block = W[:, col_start:col_end].clone()
        H_inv_block = hessian_inv[col_start:col_end, col_start:col_end]

        # Error accumulator for compensating later columns
        Err = torch.zeros_like(W_block)

        for j in range(block_cols):
            col_idx = col_start + j
            w_col = W_block[:, j]       # [d_out]
            h_inv_jj = H_inv_block[j, j] # scalar

            # SparseGPT saliency: magnitude^2 / diagonal of the inverse
            # Hessian (shown for reference; the simplified selection
            # below ranks by plain magnitude instead)
            scores = w_col.abs() ** 2 / h_inv_jj

            # Determine which weights in this column to prune
            # (simplified: prune if this weight is among the smallest
            #  across all columns for this row)
            prune_mask = _should_prune(W, col_idx, num_prune_per_row)

            # For pruned weights: compute the error
            Err[:, j] = w_col * prune_mask.float()
            W_block[:, j] = w_col * (~prune_mask).float()

            # Compensate remaining columns in this block
            if j < block_cols - 1:
                update = Err[:, j:j+1] / h_inv_jj
                h_inv_row = H_inv_block[j, j+1:block_cols]
                W_block[:, j+1:block_cols] -= update @ h_inv_row.unsqueeze(0)

        # Write back the pruned and compensated block
        W[:, col_start:col_end] = W_block

        # Compensate all remaining columns (after this block)
        if col_end < d_in:
            h_inv_cross = hessian_inv[col_start:col_end, col_end:]
            W[:, col_end:] -= (Err / H_inv_block.diag().unsqueeze(0)) @ h_inv_cross

    return W

def _should_prune(W, col_idx, num_prune_per_row):
    """Determine if weights at col_idx should be pruned.
    Returns boolean mask of shape [d_out]."""
    row_magnitudes = W.abs()
    thresholds = torch.kthvalue(
        row_magnitudes, num_prune_per_row, dim=1
    ).values
    return row_magnitudes[:, col_idx] <= thresholds

2.4 Computing the Hessian Inverse

The Hessian $H = X X^T$ is computed from calibration data (typically 128 samples from C4 or WikiText); constant factors like the $2/n$ from Section 2.1 cancel inside the OBS update, so implementations simply normalize by the sample count. The inverse is computed via Cholesky decomposition:

def compute_hessian_inverse(activations, damp=0.01):
    """Compute inverse Hessian from calibration activations.

    Args:
        activations: [n_samples, d_in] input activations
        damp: damping factor for numerical stability
    Returns:
        [d_in, d_in] inverse Hessian
    """
    n, d = activations.shape
    H = (activations.T @ activations) / n  # [d_in, d_in]

    # Add damping for numerical stability
    H += damp * torch.eye(d, device=H.device) * H.diag().mean()

    # Cholesky decomposition: H = L L^T
    L = torch.linalg.cholesky(H)

    # Inverse via triangular solve: H^{-1} = (L^T)^{-1} L^{-1}
    H_inv = torch.cholesky_inverse(L)

    return H_inv

Performance

The Cholesky decomposition costs $O(d_{in}^3 / 3)$. For $d_{in} = 4096$, that is roughly $2.3 \times 10^{10}$ operations — about 10ms on an A100. This is computed once per layer, so the total cost for all layers in a 7B model is under 1 second.

2.5 Full SparseGPT Pipeline

The complete pipeline processes the model layer-by-layer:

def sparsegpt_full(model, calibration_loader, sparsity=0.5, blocksize=128):
    """Apply SparseGPT to all linear layers in the model.

    Args:
        model: transformer model
        calibration_loader: dataloader yielding calibration inputs
        sparsity: target sparsity ratio
        blocksize: column block size for processing
    """
    # Collect calibration activations by running forward pass
    hooks = {}
    activations = {}

    def make_hook(name):
        def hook_fn(module, input_args, output):
            # Storing raw activations for every layer at once is memory-
            # hungry; production implementations accumulate the running
            # sum X^T X per layer instead.
            if name not in activations:
                activations[name] = []
            activations[name].append(
                input_args[0].detach().reshape(-1, input_args[0].shape[-1])
            )
        return hook_fn

    # Register hooks on all linear layers
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            hooks[name] = module.register_forward_hook(make_hook(name))

    # Run calibration data through model
    model.eval()
    with torch.no_grad():
        for batch in calibration_loader:
            model(batch["input_ids"].cuda())

    # Remove hooks
    for h in hooks.values():
        h.remove()

    # Prune each linear layer
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Stack calibration activations
            X = torch.cat(activations[name], dim=0)  # [n_total, d_in]

            # Compute Hessian inverse
            H_inv = compute_hessian_inverse(X)

            # Prune with SparseGPT
            W_pruned = sparsegpt_prune(
                module.weight.data, H_inv, sparsity, blocksize
            )
            module.weight.data = W_pruned

            print(f"Pruned {name}: {sparsity*100:.0f}% sparse, "
                  f"nnz={W_pruned.count_nonzero().item()}")

Wanda: Pruning Without Weight Updates

3.1 The Core Insight

Wanda (Weights AND Activations) makes a simpler observation: the importance of a weight depends on both its magnitude and the magnitude of the input it processes. A small weight connected to a large activation can be more important than a large weight connected to a near-zero activation.

The Wanda score for weight $w_{ij}$ is:

$$S_{ij} = |w_{ij}| \cdot \|X_j\|_2$$

where $\|X_j\|_2$ is the $\ell_2$ norm of the $j$-th input feature across all calibration samples. Weights with the smallest scores are pruned. No weight updates. No Hessian computation.

3.2 Why This Works

Consider the squared error from pruning weight $w_{ij}$:

$$\Delta \mathcal{L}_{ij} = \frac{1}{n} \sum_{k=1}^{n} (w_{ij} \cdot x_{kj})^2 = w_{ij}^2 \cdot \frac{1}{n} \|X_j\|_2^2$$

The Wanda score $|w_{ij}| \cdot \|X_j\|_2$ is the square root of this error (up to a constant). So Wanda is a first-order approximation of the pruning error — it ranks weights by the error caused by their removal, but without the Hessian correction that SparseGPT uses to compensate surviving weights.
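
This equivalence is easy to verify: the squared Wanda score equals the per-weight removal error, so the two quantities produce the same ranking. A small numeric check:

```python
import torch

torch.manual_seed(0)
n, d_out, d_in = 256, 4, 32
X = torch.randn(n, d_in) * (torch.rand(d_in) * 3)   # features at very different scales
W = torch.randn(d_out, d_in) * 0.02

# Exact removal error (no compensation): sum_k (w_ij * x_kj)^2
removal_err = W ** 2 * (X ** 2).sum(dim=0)          # [d_out, d_in]

# Wanda score: |w_ij| * ||X_j||_2
wanda = W.abs() * X.norm(dim=0)

# Squared Wanda scores match the removal error up to rounding
print(torch.allclose(removal_err, wanda ** 2, rtol=1e-4))
```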

3.3 Implementation

def wanda_prune(weight, activations, sparsity, per_row=True):
    """Wanda pruning: magnitude * activation norm.

    Args:
        weight: [d_out, d_in] weight matrix
        activations: [n_samples, d_in] calibration activations
        sparsity: fraction of weights to prune
        per_row: if True, prune per-row (preserves structure per neuron)
    Returns:
        pruned weight matrix (no weight updates applied)
    """
    d_out, d_in = weight.shape

    # Compute per-feature activation norms
    act_norms = activations.norm(dim=0)  # [d_in]

    # Wanda scores: |w| * ||x||
    scores = weight.abs() * act_norms.unsqueeze(0)  # [d_out, d_in]

    if per_row:
        # Prune independently per row (output neuron)
        num_prune = int(d_in * sparsity)
        if num_prune == 0:
            return weight.clone()

        # Per-row threshold: the num_prune-th smallest score in the row
        sorted_scores, _ = scores.sort(dim=1)
        thresholds = sorted_scores[:, num_prune - 1].unsqueeze(1)

        mask = scores > thresholds  # Keep weights strictly above threshold
    else:
        # Global pruning across the entire matrix
        num_prune = int(weight.numel() * sparsity)
        if num_prune == 0:
            return weight.clone()
        flat_scores = scores.flatten()
        threshold = torch.kthvalue(flat_scores, num_prune).values
        mask = scores > threshold

    return weight * mask

def wanda_full(model, calibration_loader, sparsity=0.5):
    """Apply Wanda to all linear layers in the model."""
    hooks = {}
    activations = {}

    def make_hook(name):
        def hook_fn(module, input_args, output):
            if name not in activations:
                activations[name] = []
            activations[name].append(
                input_args[0].detach().reshape(-1, input_args[0].shape[-1])
            )
        return hook_fn

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            hooks[name] = module.register_forward_hook(make_hook(name))

    model.eval()
    with torch.no_grad():
        for batch in calibration_loader:
            model(batch["input_ids"].cuda())

    for h in hooks.values():
        h.remove()

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            X = torch.cat(activations[name], dim=0)
            W_pruned = wanda_prune(module.weight.data, X, sparsity)
            module.weight.data = W_pruned

💡 Tip

Wanda is roughly 10x faster than SparseGPT because it skips Hessian computation and weight updates. For a 7B model, Wanda takes about 30 seconds vs. SparseGPT’s 5 minutes on a single A100. The quality difference at 50% unstructured sparsity is typically less than 0.5 perplexity points.

Unstructured vs. Structured Sparsity

4.1 Unstructured Sparsity

Unstructured pruning removes individual weights anywhere in the matrix. The resulting sparse matrix has no regular pattern — zeros are scattered randomly. This gives maximum flexibility for choosing which weights to remove, but provides no speedup on standard hardware. A sparse matrix with 50% zeros still requires the same number of memory accesses on a GPU unless the hardware has explicit sparse support.

def demonstrate_unstructured(weight, sparsity=0.5):
    """Show the random pattern of unstructured sparsity."""
    mask = torch.rand_like(weight) > sparsity
    sparse_weight = weight * mask

    # Count non-zeros
    nnz = sparse_weight.count_nonzero().item()
    total = weight.numel()
    print(f"Non-zeros: {nnz}/{total} ({nnz/total*100:.1f}%)")

    # No regular pattern -- sparse operations needed for speedup
    # Standard matmul still touches all elements
    return sparse_weight

4.2 Structured Sparsity

Structured pruning removes entire rows, columns, or blocks. This directly reduces the matrix dimensions, giving real speedup on any hardware.

Row pruning removes entire output neurons:

def structured_row_prune(weight, activations, sparsity):
    """Remove entire rows (output neurons) from weight matrix.

    Args:
        weight: [d_out, d_in]
        activations: [n_samples, d_in]
        sparsity: fraction of rows to remove
    Returns:
        pruned weight [d_out * (1-sparsity), d_in], kept indices
    """
    d_out, d_in = weight.shape
    num_keep = int(d_out * (1 - sparsity))

    # Score each row by its expected output magnitude
    act_norms = activations.norm(dim=0)  # [d_in]
    row_scores = (weight.abs() * act_norms.unsqueeze(0)).sum(dim=1)  # [d_out]

    # Keep the highest-scoring rows
    _, keep_indices = row_scores.topk(num_keep)
    keep_indices = keep_indices.sort().values

    return weight[keep_indices], keep_indices

def structured_column_prune(weight, activations, sparsity):
    """Remove entire columns (input features) from weight matrix.

    Args:
        weight: [d_out, d_in]
        activations: [n_samples, d_in]
        sparsity: fraction of columns to remove
    Returns:
        pruned weight [d_out, d_in * (1-sparsity)], kept indices
    """
    d_out, d_in = weight.shape
    num_keep = int(d_in * (1 - sparsity))

    act_norms = activations.norm(dim=0)  # [d_in]
    col_scores = (weight.abs() * act_norms.unsqueeze(0)).sum(dim=0)  # [d_in]

    _, keep_indices = col_scores.topk(num_keep)
    keep_indices = keep_indices.sort().values

    return weight[:, keep_indices], keep_indices

Block pruning removes rectangular blocks:

def block_prune(weight, activations, sparsity, block_size=64):
    """Remove blocks of weights.

    Args:
        weight: [d_out, d_in]
        activations: [n_samples, d_in]
        sparsity: fraction of blocks to remove
        block_size: size of square blocks
    """
    d_out, d_in = weight.shape
    act_norms = activations.norm(dim=0)

    # Score each block
    block_scores = []
    block_positions = []

    for i in range(0, d_out, block_size):
        for j in range(0, d_in, block_size):
            i_end = min(i + block_size, d_out)
            j_end = min(j + block_size, d_in)

            block = weight[i:i_end, j:j_end]
            block_act = act_norms[j:j_end]
            score = (block.abs() * block_act.unsqueeze(0)).sum().item()

            block_scores.append(score)
            block_positions.append((i, j, i_end, j_end))

    # Prune lowest-scoring blocks
    num_blocks = len(block_scores)
    num_prune = int(num_blocks * sparsity)

    scores_tensor = torch.tensor(block_scores)
    _, prune_indices = scores_tensor.topk(num_prune, largest=False)

    pruned = weight.clone()
    for idx in prune_indices:
        i, j, i_end, j_end = block_positions[idx.item()]
        pruned[i:i_end, j:j_end] = 0

    return pruned

4.3 Quality-Speed Tradeoff

📊

Structured vs Unstructured Pruning at 50% Sparsity

Method                           WikiText-2 PPL   Delta vs Dense
Dense baseline                   5.68             0%
Unstructured 50% (SparseGPT)     5.97             +5.1%
Unstructured 50% (Wanda)         6.12             +7.7%
Row-structured 50%               7.84             +38.0%
Column-structured 50%            8.21             +44.5%
Block-structured 50% (64x64)     6.89             +21.3%
Note: Llama 7B, WikiText-2 perplexity (lower is better). Dense baseline is 0% sparse.

Structured pruning hurts quality significantly more than unstructured pruning at the same sparsity level. The constraint of removing entire rows or columns means you cannot avoid removing important weights — if a row has one critical weight and 4095 unimportant ones, you still lose the critical one. Block pruning is a middle ground: finer granularity than rows/columns, but coarser than individual weights.

N:M Sparsity on Ampere GPUs

5.1 The Hardware Constraint

NVIDIA Ampere (A100) and later GPUs include hardware support for a specific sparsity pattern: N:M sparsity, where out of every M consecutive weights, exactly N are non-zero. The most common pattern is 2:4 sparsity: out of every 4 consecutive weights, 2 are non-zero (50% sparsity). The hardware contains a sparse tensor core that skips the zero multiplications, achieving roughly 2x throughput.

5.2 Why 2:4 Specifically

The 2:4 pattern is a hardware design choice. Each sparse tensor core operation carries metadata recording which 2 of the 4 weights in each group survive, stored as a 2-bit position index per kept weight (4 bits per group of 4, about 1 bit per original weight). The hardware uses this metadata to select the corresponding input activations, performing only 2 multiplications instead of 4. This gives a 2x speedup with minimal metadata overhead.
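
A software sketch of this packing, assuming the weight is already 2:4 sparse. Illustrative only: the real cuSPARSELt format packs the 2-bit indices densely rather than one byte each, and the function names here are made up:

```python
import torch

def pack_24(weight):
    """Pack a 2:4-sparse [d_out, d_in] matrix into kept values plus
    per-value position metadata (0-3 within each group of 4)."""
    d_out, d_in = weight.shape
    groups = weight.reshape(d_out, d_in // 4, 4)
    # Positions of the 2 non-zeros in each group (largest |.| = the kept ones)
    idx = groups.abs().topk(2, dim=2).indices.sort(dim=2).values
    vals = torch.gather(groups, 2, idx)
    return vals.reshape(d_out, -1), idx.reshape(d_out, -1).to(torch.uint8)

def unpack_24(values, indices, d_in):
    d_out = values.shape[0]
    out = torch.zeros(d_out, d_in // 4, 4, dtype=values.dtype)
    out.scatter_(2, indices.reshape(d_out, -1, 2).long(),
                 values.reshape(d_out, -1, 2))
    return out.reshape(d_out, d_in)

# Round trip on a 2:4-sparse matrix
W = torch.randn(8, 16)
g = W.reshape(8, 4, 4)
keep = g.abs().topk(2, dim=2).indices
W24 = (g * torch.zeros_like(g).scatter(2, keep, 1.0)).reshape(8, 16)

vals, idx = pack_24(W24)
assert torch.equal(unpack_24(vals, idx, 16), W24)
```

The `idx` tensor here is exactly what the hardware's metadata registers hold, just stored one byte per entry instead of two bits.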

def enforce_nm_sparsity(weight, n=2, m=4):
    """Enforce an N:M sparsity pattern.

    For every group of M consecutive weights (along d_in),
    keep the N largest-magnitude weights and zero out the rest.

    Args:
        weight: [d_out, d_in]
        n: number of non-zeros to keep per group
        m: group size
    Returns:
        weight with N:M sparsity pattern
    """
    d_out, d_in = weight.shape
    assert d_in % m == 0, f"d_in ({d_in}) must be divisible by m ({m})"

    # Reshape to [d_out, d_in/m, m]
    W = weight.reshape(d_out, d_in // m, m)

    # Find the top n weights in each group (the ones to keep)
    _, top_indices = W.abs().topk(n, dim=2)

    # Create mask
    mask = torch.zeros_like(W)
    mask.scatter_(2, top_indices, 1.0)

    # Apply mask
    result = (W * mask).reshape(d_out, d_in)
    return result

5.3 Combining SparseGPT/Wanda with N:M

Both SparseGPT and Wanda can be adapted to produce N:M patterns instead of arbitrary unstructured sparsity:

def wanda_nm_prune(weight, activations, n=2, m=4):
    """Wanda pruning with an N:M sparsity constraint.

    Instead of pruning the globally lowest-scoring weights, keep the
    N highest-scoring weights in each group of M.
    """
    d_out, d_in = weight.shape
    assert d_in % m == 0, f"d_in ({d_in}) must be divisible by m ({m})"

    # Compute Wanda scores
    act_norms = activations.norm(dim=0)  # [d_in]
    scores = weight.abs() * act_norms.unsqueeze(0)  # [d_out, d_in]

    # Reshape to groups of M
    S = scores.reshape(d_out, d_in // m, m)
    W = weight.reshape(d_out, d_in // m, m)

    # In each group, keep the n highest-scoring weights
    _, top_indices = S.topk(n, dim=2)

    mask = torch.zeros_like(W)
    mask.scatter_(2, top_indices, 1.0)

    result = (W * mask).reshape(d_out, d_in)
    return result

def sparsegpt_nm_prune(weight, hessian_inv, n=2, m=4):
    """SparseGPT with an N:M constraint.

    Process columns in groups of M. Within each group, keep the N
    columns whose removal would cost the most, prune the rest, and
    compensate surviving weights.

    Simplification: the mask is shared across rows (whole columns of
    each group are pruned); full SparseGPT selects per-row masks.
    """
    d_out, d_in = weight.shape
    assert d_in % m == 0, f"d_in ({d_in}) must be divisible by m ({m})"
    W = weight.clone()

    for group_start in range(0, d_in, m):
        group_end = group_start + m
        W_group = W[:, group_start:group_end].clone()
        H_inv_group = hessian_inv[group_start:group_end, group_start:group_end]

        # Saliency of each column in the group: sum_i w_ij^2 / [H^-1]_jj
        diag = H_inv_group.diag()
        col_scores = (W_group ** 2).sum(dim=0) / diag  # [m]

        # Keep the n highest-saliency columns
        _, keep_idx = col_scores.topk(n)
        prune_idx = [j for j in range(m) if j not in keep_idx]

        # Zero out pruned columns and compensate within the group,
        # recording each error at the moment of pruning
        errors = {}
        for j in prune_idx:
            err = W_group[:, j].clone()
            errors[j] = err
            W_group[:, j] = 0
            h_jj = H_inv_group[j, j]
            for k in keep_idx:
                W_group[:, k] -= (err / h_jj) * H_inv_group[j, k]

        W[:, group_start:group_end] = W_group

        # Compensate future groups using the recorded errors
        if group_end < d_in:
            for j, err in errors.items():
                h_jj = hessian_inv[group_start + j, group_start + j]
                h_cross = hessian_inv[group_start + j, group_end:]
                W[:, group_end:] -= (err / h_jj).unsqueeze(1) * h_cross.unsqueeze(0)

    return W

5.4 N:M Performance on Ampere

📊

2:4 Sparsity Performance on A100

Configuration                        Throughput (tok/s)   Delta vs Dense
Dense (cuBLAS)                       2,847                baseline
2:4 Sparse (cuSPARSELt)              4,952                +73.9%
2:4 Sparse + INT8 weights            7,208                +153%
Unstructured 50% (no HW support)     2,891                +1.5%

Note: Llama 7B inference throughput on A100 80GB.

⚠️ Warning

Unstructured 50% sparsity gives almost no speedup without dedicated hardware support. The GPU still loads the full matrix from memory and multiplies by zeros. N:M sparsity with cuSPARSELt achieves real speedups because the hardware skips zero entries at the tensor core level.

Quality vs. Sparsity Curves

6.1 Measuring Degradation

The standard evaluation protocol for pruning: measure perplexity on WikiText-2 and zero-shot accuracy on downstream tasks (ARC, HellaSwag, WinoGrande, PIQA) at increasing sparsity levels.

def evaluate_pruning_sweep(model, tokenizer, calibration_loader,
                           eval_dataset, sparsity_levels):
    """Evaluate model quality across sparsity levels."""
    import copy
    from lm_eval import evaluator
    from lm_eval.models.huggingface import HFLM

    results = []

    for sparsity in sparsity_levels:
        # Clone the model
        pruned_model = copy.deepcopy(model)

        # Apply Wanda pruning
        wanda_full(pruned_model, calibration_loader, sparsity=sparsity)

        # Measure perplexity
        ppl = compute_perplexity(pruned_model, tokenizer, eval_dataset)

        # Measure zero-shot accuracy with the lm-eval harness.
        # (Recent lm-eval versions require wrapping the model in HFLM;
        #  metric key names such as "acc,none" vary by harness version.)
        eval_out = evaluator.simple_evaluate(
            model=HFLM(pretrained=pruned_model, tokenizer=tokenizer),
            tasks=["arc_easy", "hellaswag", "winogrande", "piqa"],
            batch_size=32,
        )
        task_accs = [v["acc,none"] for v in eval_out["results"].values()]
        avg_acc = sum(task_accs) / len(task_accs)

        results.append({
            "sparsity": sparsity,
            "perplexity": ppl,
            "avg_accuracy": avg_acc,
        })

        print(f"Sparsity: {sparsity:.0%} | PPL: {ppl:.2f} | "
              f"Avg Acc: {avg_acc:.1%}")

        del pruned_model

    return results

6.2 Empirical Results

Llama 7B: Perplexity vs Sparsity (WikiText-2)

Method                     0%     10%    20%    30%    40%    50%    60%    70%    80%
SparseGPT (unstructured)   5.68   5.69   5.71   5.76   5.84   5.97   6.42   7.89   14.2
Wanda (unstructured)       5.68   5.70   5.73   5.80   5.93   6.12   6.78   8.45   17.8
Magnitude (unstructured)   5.68   5.72   5.82   6.05   6.48   7.34   9.21   15.6   42.1
2:4 N:M (Wanda)            5.68   -      -      -      -      6.25   -      -      -

(2:4 is a fixed 50% pattern, so it contributes a single pruned data point.)

Key observations from the empirical curves:

  1. Up to 30% sparsity: All methods perform similarly. There is genuine redundancy in the model that any method can find.

  2. 30-50% sparsity: SparseGPT’s Hessian compensation provides a measurable advantage (0.15-0.3 perplexity points over Wanda, 0.5-1.5 over magnitude).

  3. 50-60% sparsity: The gap widens significantly. SparseGPT’s weight updates become critical for maintaining quality.

  4. Beyond 70%: All one-shot methods degrade rapidly. At this level, retraining or iterative pruning is necessary.

6.3 Task-Specific Analysis

Llama 7B: Zero-Shot Accuracy at 50% Sparsity

Method          ARC-Easy   ARC-Challenge   HellaSwag   WinoGrande   PIQA   Average
Dense           75.2       46.3            76.1        70.0         79.1   69.3
SparseGPT 50%   73.8       44.1            74.2        68.5         78.0   67.7
Wanda 50%       72.9       43.2            73.4        67.8         77.2   66.9

At 50% sparsity, both methods retain roughly 96-97% of the dense model’s average accuracy. The degradation is relatively uniform across tasks, with ARC-Challenge (the hardest reasoning task) showing the largest relative drop.

Layer-Wise Sensitivity

7.1 Not All Layers Are Equal

Different layers have different sensitivity to pruning. Early layers (near the embedding) and late layers (near the output head) are typically more sensitive than middle layers.

def layer_sensitivity_analysis(model, calibration_loader,
                               eval_dataset, tokenizer):
    """Measure per-layer sensitivity to pruning.

    For each layer, prune only that layer to 50% while keeping
    all other layers dense. Measure the perplexity increase.
    """
    import copy

    base_ppl = compute_perplexity(model, tokenizer, eval_dataset)
    sensitivities = {}

    for name, module in model.named_modules():
        if not isinstance(module, nn.Linear):
            continue

        # Clone model, prune only this layer
        test_model = copy.deepcopy(model)
        target = dict(test_model.named_modules())[name]

        # Quick magnitude prune for sensitivity analysis
        pruned_weight = magnitude_prune(target.weight.data, sparsity=0.5)
        target.weight.data = pruned_weight

        ppl = compute_perplexity(test_model, tokenizer, eval_dataset)
        sensitivities[name] = ppl - base_ppl

        del test_model

    # Sort by sensitivity
    sorted_layers = sorted(sensitivities.items(), key=lambda x: x[1], reverse=True)
    for name, delta_ppl in sorted_layers[:10]:
        print(f"{name}: +{delta_ppl:.2f} perplexity")

    return sensitivities

7.2 Non-Uniform Sparsity

The sensitivity analysis motivates non-uniform sparsity: prune sensitive layers less aggressively and insensitive layers more aggressively, keeping the total parameter count the same.

def allocate_sparsity(sensitivities, target_avg_sparsity=0.5,
                      min_sparsity=0.2, max_sparsity=0.8):
    """Allocate per-layer sparsity inversely proportional to sensitivity.

    High-sensitivity layers get low sparsity.
    Low-sensitivity layers get high sparsity.
    Average across all layers equals target_avg_sparsity.
    """
    layers = list(sensitivities.keys())
    sens_values = torch.tensor([sensitivities[l] for l in layers])

    # Invert: high sensitivity -> low sparsity
    inv_sens = 1.0 / (sens_values + 1e-8)

    # Normalize to mean = target
    sparsity_ratios = inv_sens / inv_sens.mean() * target_avg_sparsity

    # Clip to valid range
    sparsity_ratios = sparsity_ratios.clamp(min_sparsity, max_sparsity)

    # Adjust to hit target average
    current_avg = sparsity_ratios.mean().item()
    sparsity_ratios *= target_avg_sparsity / current_avg
    sparsity_ratios = sparsity_ratios.clamp(min_sparsity, max_sparsity)

    return {layer: s.item() for layer, s in zip(layers, sparsity_ratios)}

Performance

Non-uniform sparsity allocations typically improve perplexity by 0.1-0.3 points over uniform sparsity at the same average sparsity level. The first and last 2-3 transformer blocks are consistently the most sensitive — they should be pruned at 20-30% while middle blocks can tolerate 60-70%.
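To see the allocation behave as intended, here is a standalone toy run of the same inverse-sensitivity logic; the sensitivity values are hypothetical, chosen to mimic sensitive first/last blocks and tolerant middle blocks:

```python
import torch

# Hypothetical per-layer perplexity deltas: first/last blocks sensitive,
# middle blocks not (illustrative values only).
sens = torch.tensor([4.0, 3.0, 0.5, 0.3, 0.4, 0.5, 2.5, 3.5])

target = 0.5
inv = 1.0 / (sens + 1e-8)                      # high sensitivity -> low sparsity
ratios = (inv / inv.mean() * target).clamp(0.2, 0.8)
ratios *= target / ratios.mean()               # re-hit the average after clipping
ratios = ratios.clamp(0.2, 0.8)

print([f"{r:.2f}" for r in ratios])  # ends sparse ~0.21, middle ~0.75-0.80
```

The sensitive edge layers land near the 20% floor while the middle layers absorb 70-80% sparsity, and the mean stays close to the 50% target.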

8 Pruning + Quantization: Stacking Compression

8.1 Orthogonal Techniques

Pruning (removing weights) and quantization (reducing bit-width) are largely orthogonal. A 50% sparse model quantized to 4 bits uses roughly 0.5 × 4/16 = 12.5% of the original model’s memory. The quality loss compounds, but less than you might expect, because the two techniques remove different types of redundancy.
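The arithmetic generalizes to any sparsity/bit-width combination. A tiny helper (idealized, ignoring sparse-index and quantization-scale overhead) makes the trade-off explicit:

```python
def memory_fraction(sparsity, bits, base_bits=16):
    """Fraction of the original FP16 footprint after pruning + quantization.

    Idealized: ignores sparse-index and quantization-scale overhead.
    """
    return (1 - sparsity) * bits / base_bits

print(f"{memory_fraction(0.5, 4):.1%}")   # 12.5% -- 50% sparse + INT4
print(f"{memory_fraction(0.5, 8):.1%}")   # 25.0% -- 50% sparse + INT8
print(f"{memory_fraction(0.0, 4):.1%}")   # 25.0% -- dense INT4
```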

def prune_then_quantize(model, calibration_loader, prune_sparsity=0.5):
    """Apply Wanda pruning followed by GPTQ-style quantization."""
    # Step 1: Prune (skipped if the model is already sparse)
    if prune_sparsity > 0:
        wanda_full(model, calibration_loader, sparsity=prune_sparsity)

    # Step 2: Quantize surviving weights to 4-bit
    # (round-to-nearest here for clarity -- GPTQ would add error compensation)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            W = module.weight.data

            # Set the scale from non-zero weights only
            mask = W != 0
            if not mask.any():
                continue

            # Symmetric 4-bit quantization
            max_val = W[mask].abs().max()
            scale = max_val / 7  # 4-bit signed range: -8 to 7

            W_quant = torch.round(W / scale).clamp(-8, 7) * scale
            W_quant[~mask] = 0  # Keep pruned weights at zero

            module.weight.data = W_quant

def measure_compression(model):
    """Measure effective compression ratio."""
    total_params = 0
    nonzero_params = 0

    for p in model.parameters():
        total_params += p.numel()
        nonzero_params += p.count_nonzero().item()

    sparsity = 1 - nonzero_params / total_params
    # Assuming 4-bit for non-zero, 0 bits for zero
    effective_bits = (1 - sparsity) * 4
    compression = 16 / effective_bits  # vs FP16 baseline

    print(f"Sparsity: {sparsity:.1%}")
    print(f"Effective bits per param: {effective_bits:.2f}")
    print(f"Compression ratio: {compression:.1f}x")

8.2 Combined Results


Compression Stack: Pruning + Quantization on Llama 7B

| Configuration            | Perplexity | Memory  | Compression     |
|--------------------------|------------|---------|-----------------|
| Dense FP16 (baseline)    | 5.68       | 14.0 GB | 1.0x            |
| 50% sparse FP16 (Wanda)  | 6.12       | 7.0 GB  | 2.0x            |
| Dense INT4 (GPTQ)        | 5.85       | 3.5 GB  | 4.0x            |
| 50% sparse + INT4        | 6.38       | 1.75 GB | 8.0x            |
| 2:4 sparse + INT8        | 6.01       | 3.5 GB  | 4.0x (2x speed) |
Note: WikiText-2 perplexity. Memory is model weight storage only.

The 2:4 sparse + INT8 combination is particularly attractive: it achieves 4x memory compression with actual 2x inference speedup (via sparse tensor cores), and the quality degradation is only 0.33 perplexity points.
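The memory column follows from simple arithmetic (idealized, ignoring sparse-index and quantization-scale overhead):

```python
# Idealized weight-memory check against the table (FP16 dense = 14.0 GB)
dense_gb = 14.0
rows = {
    "50% sparse FP16": (0.5, 16),
    "Dense INT4":      (0.0, 4),
    "50% sparse INT4": (0.5, 4),
    "2:4 sparse INT8": (0.5, 8),
}
results = {name: dense_gb * (1 - s) * bits / 16
           for name, (s, bits) in rows.items()}
for name, gb in results.items():
    print(f"{name}: {gb:.2f} GB ({dense_gb / gb:.1f}x)")

# Quality cost of the 2:4 + INT8 configuration relative to dense
print(f"delta perplexity = {6.01 - 5.68:.2f}")  # 0.33
```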

9 Iterative and Recovery-Based Pruning

9.1 Sparse Fine-Tuning

One-shot pruning is fast but leaves quality on the table. If you can afford some training budget, sparse fine-tuning recovers much of the lost quality:

def sparse_finetune(model, train_loader, optimizer, mask, epochs=2):
    """Fine-tune a pruned model while maintaining the sparsity pattern.

    Args:
        model: pruned model
        train_loader: training data
        optimizer: optimizer (typically AdamW with low LR)
        mask: dict mapping parameter names to binary masks
        epochs: number of fine-tuning epochs
    """
    model.train()

    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            input_ids = batch["input_ids"].cuda()
            labels = batch["labels"].cuda()

            outputs = model(input_ids, labels=labels)
            loss = outputs.loss

            loss.backward()

            # Zero out gradients for pruned weights
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.grad *= mask[name]

            optimizer.step()
            optimizer.zero_grad()

            # Re-apply mask to handle numerical drift
            with torch.no_grad():
                for name, param in model.named_parameters():
                    if name in mask:
                        param.data *= mask[name]

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")

9.2 Iterative Magnitude Pruning (IMP)

The lottery ticket hypothesis approach: train, prune a small fraction, retrain, prune more, repeat. This finds better sparse networks than one-shot methods but costs much more compute:

def iterative_magnitude_pruning(model, train_loader, eval_loader,
                                 tokenizer=None, target_sparsity=0.9,
                                 prune_steps=10, finetune_epochs=2):
    """Iterative magnitude pruning with retraining.

    Prune in small steps, retraining between each step.
    """
    # Each step removes the same fraction of surviving weights,
    # so sparsity compounds geometrically toward the target.
    sparsity_per_step = 1 - (1 - target_sparsity) ** (1 / prune_steps)

    mask = {}
    for name, param in model.named_parameters():
        if "weight" in name and param.dim() == 2:
            mask[name] = torch.ones_like(param, dtype=torch.bool)

    for step in range(prune_steps):
        # Prune: remove smallest-magnitude surviving weights
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in mask:
                    continue

                surviving = param[mask[name]]  # 1-D tensor of surviving weights
                num_prune = int(surviving.numel() * sparsity_per_step)
                if num_prune < 1:
                    continue

                threshold = torch.kthvalue(surviving.abs(), num_prune).values

                new_prune = (param.abs() < threshold) & mask[name]
                mask[name] &= ~new_prune
                param.data *= mask[name].to(param.dtype)

        # Compute current sparsity
        total = sum(m.numel() for m in mask.values())
        nonzero = sum(m.sum().item() for m in mask.values())
        current_sparsity = 1 - nonzero / total

        # Retrain with the mask held fixed
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
        sparse_finetune(model, train_loader, optimizer, mask, finetune_epochs)

        ppl = compute_perplexity(model, tokenizer, eval_loader)
        print(f"Step {step+1}/{prune_steps}: "
              f"sparsity={current_sparsity:.1%}, ppl={ppl:.2f}")
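With the defaults above (90% target over 10 steps), the geometric schedule works out to pruning about 20.6% of the surviving weights at each step:

```python
# Geometric pruning schedule: each step removes the same fraction of the
# weights that survived the previous step.
target, steps = 0.9, 10
per_step = 1 - (1 - target) ** (1 / steps)
print(f"prune {per_step:.1%} of surviving weights per step")  # ~20.6%

# Cumulative sparsity after each step
cumulative = [1 - (1 - per_step) ** (k + 1) for k in range(steps)]
print(f"after step 1: {cumulative[0]:.1%}, final: {cumulative[-1]:.1%}")
```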
ℹ️ Note

Iterative pruning at 90% sparsity achieves perplexity close to one-shot pruning at 50%. The cost is 10-20x more training compute. For LLMs, this is rarely practical — the one-shot methods (SparseGPT, Wanda) dominate in practice because they require zero training.

10 Practical Deployment

10.1 Sparse Format Conversion

For deployment, sparse weights need to be stored in a compressed format:

def convert_to_csr(weight):
    """Convert dense weight to CSR (Compressed Sparse Row) format.

    CSR stores only non-zero values plus index arrays.
    Memory: nnz * (value_bytes + index_bytes) + d_out * pointer_bytes
    """
    sparse = weight.to_sparse_csr()

    crow_indices = sparse.crow_indices()  # [d_out + 1] row pointers
    col_indices = sparse.col_indices()    # [nnz] column indices
    values = sparse.values()              # [nnz] non-zero values

    # Memory comparison (assumes int32 indices for deployment; torch's
    # CSR tensors store int64 indices by default)
    dense_bytes = weight.numel() * weight.element_size()
    sparse_bytes = (values.numel() * values.element_size() +
                    col_indices.numel() * 4 +  # int32 column indices
                    crow_indices.numel() * 4)  # int32 row pointers

    print(f"Dense: {dense_bytes / 1e6:.1f} MB")
    print(f"CSR:   {sparse_bytes / 1e6:.1f} MB")
    print(f"Ratio: {sparse_bytes / dense_bytes:.2f}x")

    return sparse

def convert_to_nm_format(weight, n=2, m=4):
    """Convert to N:M sparse format for cuSPARSELt.

    Stores non-zero values (compressed) + 2-bit metadata per group.
    """
    d_out, d_in = weight.shape
    num_groups = d_in // m
    keep_per_group = n  # N:M keeps n non-zeros per group of m (2 for 2:4)

    # Extract non-zero values and metadata
    W_groups = weight.reshape(d_out, num_groups, m)
    _, top_idx = W_groups.abs().topk(keep_per_group, dim=2)

    # Compressed values: only store non-zeros
    values = torch.gather(W_groups, 2, top_idx)  # [d_out, num_groups, keep]
    # Metadata: which positions are non-zero (2 bits per group for 2:4)
    metadata = top_idx  # [d_out, num_groups, keep]

    compressed_bytes = values.numel() * values.element_size()
    metadata_bytes = d_out * num_groups * keep_per_group * 2 / 8  # 2-bit index per kept value
    dense_bytes = weight.numel() * weight.element_size()

    print(f"Dense:      {dense_bytes / 1e6:.1f} MB")
    print(f"2:4 format: {(compressed_bytes + metadata_bytes) / 1e6:.1f} MB")

    return values, metadata
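One caveat worth making explicit: with FP16 values and int32 indices, unstructured CSR is only smaller than the dense matrix above roughly 67% sparsity, which is why the 2:4 format with its compact 2-bit metadata is preferred at moderate sparsity. A quick back-of-envelope check:

```python
# CSR stores each surviving FP16 value (2 bytes) plus an int32 column
# index (4 bytes), so it only wins at high sparsity.
def csr_ratio(sparsity, value_bytes=2, index_bytes=4):
    """CSR size relative to dense, ignoring the small row-pointer array."""
    return (1 - sparsity) * (value_bytes + index_bytes) / value_bytes

print(f"50% sparse: {csr_ratio(0.5):.2f}x dense")  # 1.50x -- larger than dense!
print(f"80% sparse: {csr_ratio(0.8):.2f}x dense")

# Break-even: (1 - s) * 6 / 2 = 1  ->  s = 2/3
print(f"break-even sparsity: {2 / 3:.1%}")
```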

10.2 End-to-End Pruning Pipeline

def full_pruning_pipeline(model_name, sparsity=0.5, method="wanda",
                           nm_sparsity=False, quantize=False):
    """Complete pruning pipeline from HuggingFace model to deployment."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from datasets import load_dataset

    # Load model
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Prepare calibration data
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    # Filter out the many empty lines in raw WikiText before sampling
    calibration_texts = [t for t in dataset["text"] if t.strip()][:128]
    calibration_tokens = tokenizer(
        calibration_texts, return_tensors="pt",
        padding=True, truncation=True, max_length=2048
    )
    calibration_loader = [{"input_ids": calibration_tokens["input_ids"]}]

    # Prune
    if method == "wanda":
        if nm_sparsity:
            wanda_nm_full(model, calibration_loader)
        else:
            wanda_full(model, calibration_loader, sparsity=sparsity)
    elif method == "sparsegpt":
        sparsegpt_full(model, calibration_loader, sparsity=sparsity)

    # Optional quantization (the model is already pruned above, so pass
    # prune_sparsity=0 so that only the quantization step runs)
    if quantize:
        prune_then_quantize(model, calibration_loader, prune_sparsity=0)

    # Evaluate
    eval_dataset = load_dataset(
        "wikitext", "wikitext-2-raw-v1", split="test"
    )
    ppl = compute_perplexity(model, tokenizer, eval_dataset)
    print(f"Final perplexity: {ppl:.2f}")

    # Save (flatten any HF org prefix, e.g. "org/model", into the dir name)
    output_dir = f"{model_name.replace('/', '-')}-{method}-{sparsity}"
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model

11 Comparison Summary


Pruning Methods Comparison (Llama 7B, 50% Sparsity)

| Method            | Perplexity | Time    | Weight Updates | Best For           |
|-------------------|------------|---------|----------------|--------------------|
| Magnitude pruning | 7.34       | 0 min   | None           | Worst quality      |
| SparseGPT         | 5.97       | 5 min   | Hessian-based  | Best quality       |
| Wanda             | 6.12       | 0.5 min | None           | Best speed/quality |
| Wanda 2:4 (N:M)   | 6.25       | 0.5 min | None           | Best deployment    |

When to use each method:

  • SparseGPT: When quality is paramount and you can afford 5-10 minutes of compute per model. Best for high sparsity (60%+).
  • Wanda: Default choice. Nearly as good as SparseGPT, 10x faster, simpler implementation.
  • Wanda 2:4: When deploying to Ampere/Hopper GPUs and you need actual inference speedup, not just compression.
  • Structured pruning: When you need speedup on hardware without sparse tensor core support and can tolerate more quality loss.

The field is converging on a practical recipe: Wanda or SparseGPT for initial pruning, 2:4 N:M format for deployment on NVIDIA GPUs, optional sparse fine-tuning for recovering the last fraction of quality, and stacking with INT4/INT8 quantization for maximum compression.

References

  1. Frantar, E. and Alistarh, D. “SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot.” ICML 2023.
  2. Sun, M. et al. “A Simple and Effective Pruning Approach for Large Language Models.” ICLR 2024.
  3. Mishra, A. et al. “Accelerating Sparse Deep Neural Networks.” arXiv 2021.
  4. Pool, J. and Yu, C. “Channel Permutations for N:M Sparsity.” NeurIPS 2021.
  5. Frankle, J. and Carbin, M. “The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks.” ICLR 2019.