Part of the Quantization Masterclass series (part 13 of 30)

Getting INT4 weights to actually run fast on GPUs turned out to be much harder than just storing them in 4 bits. Early INT4 kernels from 2022-2023 often failed to deliver the expected speedup. Their gains shrank quickly as the batch grew, and some kernels ended up slower than FP16 cuBLAS. The problem was dequantization overhead: unpacking INT4 to FP16, applying per-group scales, and feeding the results to tensor cores ate up most of the bandwidth savings. Then in 2024, the Marlin kernel from IST Austria/Neural Magic cracked the code. It dequantizes in registers (not shared memory), pre-arranges weights and scales offline so they stream through memory in the order the tensor cores consume them, and overlaps everything with asynchronous, multi-buffered copies. The result was a 3.8x decode speedup over FP16, close to the theoretical 4x bandwidth advantage.

W4A16 inference stores weights as 4-bit integers and keeps activations in FP16. During the matrix multiply, the kernel dequantizes the INT4 weights to FP16 in registers and uses FP16 tensor cores for the actual computation. There is no INT4 GEMM. The 4-bit format is purely a compression format for memory capacity and bandwidth, not a compute format.
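The data flow can be made concrete with a minimal NumPy sketch. It uses toy shapes, keeps the INT4 codes unpacked for readability, and introduces a hypothetical helper named `w4a16_matmul`. The weights exist only as 4-bit codes plus per-group FP16 scales. They are expanded to FP16 right before the multiply, and the product itself is an ordinary FP16 matmul.

```python
import numpy as np

def w4a16_matmul(x_fp16, w_codes, scales, group_size):
    """Reference W4A16 product: y = x @ dequant(W)^T (hypothetical helper).

    x_fp16:  (M, K) float16 activations
    w_codes: (N, K) int8 codes in [-8, 7] (4-bit values, left unpacked here)
    scales:  (N, K // group_size) float16, one scale per group along K
    """
    n, k = w_codes.shape
    # Dequantize: broadcast each group's scale over its group_size codes.
    w = w_codes.reshape(n, k // group_size, group_size).astype(np.float16)
    w = (w * scales[:, :, None]).reshape(n, k)
    # The multiply itself is a plain FP16 matmul; no INT4 arithmetic happens.
    return x_fp16 @ w.T

rng = np.random.default_rng(0)
M, N, K, G = 2, 8, 256, 128
codes = rng.integers(-8, 8, size=(N, K), dtype=np.int8)
scales = rng.uniform(0.001, 0.01, size=(N, K // G)).astype(np.float16)
x = rng.standard_normal((M, K)).astype(np.float16)
y = w4a16_matmul(x, codes, scales, G)
print(y.shape, y.dtype)  # (2, 8) float16
```

A real kernel never materializes the full FP16 matrix in memory. It expands one small tile at a time in registers, which is the point of the rest of this post.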

This architecture makes W4A16 a bandwidth optimization, not a compute optimization. The speedup comes from loading 4x less weight data from GPU memory, not from doing 4x more operations per cycle. This distinction is critical for understanding when W4A16 helps and when it does not.

This post covers the memory bandwidth argument for W4A16, the Marlin kernel architecture that achieves near-optimal bandwidth utilization, the INT4 packing format, and benchmarks against FP16 and other quantized formats.

The Bandwidth Argument

LLM inference during autoregressive decoding is memory-bandwidth bound, not compute bound. Each token generation requires loading the entire model’s weights from GPU memory but performs relatively little computation (a matrix-vector product, not a matrix-matrix product).

Arithmetic Intensity Analysis

For a single-token decode step through a linear layer with weight matrix $W \in \mathbb{R}^{N \times K}$:

  • Bytes loaded: $N \times K \times \text{bytes\_per\_weight}$
  • FLOPs: $2 \times N \times K$ (one multiply-add per weight)
  • Arithmetic intensity: $\text{FLOPs} / \text{Bytes} = 2 / \text{bytes\_per\_weight}$
def arithmetic_intensity(bytes_per_weight):
    """Arithmetic intensity for single-token GEMV."""
    return 2.0 / bytes_per_weight

formats = {
    'FP16': 2.0,
    'INT8': 1.0,
    'INT4': 0.5,
    'INT4 + g128': 0.5 + 2.0 / 128,  # plus one FP16 scale per 128 weights
}

for fmt, bpw in formats.items():
    ai = arithmetic_intensity(bpw)
    print(f"  {fmt:>15s}: {ai:.1f} FLOP/byte")

Output:
             FP16: 1.0 FLOP/byte
             INT8: 2.0 FLOP/byte
             INT4: 4.0 FLOP/byte
      INT4 + g128: 3.9 FLOP/byte

On an H100 SXM with 3.35 TB/s memory bandwidth and 990 TFLOPS FP16 tensor core throughput:

  • FP16 GEMV: bandwidth-limited to $3.35 \times 10^{12}\ \text{B/s} \times 1.0\ \text{FLOP/B} = 3.35$ TFLOP/s (about 0.3% of peak compute)
  • INT4 GEMV: bandwidth-limited to $3.35 \times 10^{12}\ \text{B/s} \times 4.0\ \text{FLOP/B} = 13.4$ TFLOP/s (about 1.4% of peak compute)

Both are deeply bandwidth-bound. In the ideal case, INT4 is therefore 4x faster than FP16 for single-token decoding, simply because it loads 4x less data.
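The same arithmetic shows when the bandwidth argument stops holding. A GEMM with M activation rows does 2·M·N·K FLOPs on the same weight bytes, so its intensity grows linearly with M. It stays bandwidth-bound until it crosses the GPU's ridge point, which is peak FLOP/s divided by memory bandwidth. The sketch below is minimal: it assumes the H100 SXM figures above, counts only weight (and scale) bytes, and ignores activation and output traffic. The function name is illustrative.

```python
def ridge_batch_size(bytes_per_weight, peak_tflops=990.0, mem_bw_tb_s=3.35):
    """Batch size M above which a weight-streaming GEMM becomes compute-bound.

    Counting only weight bytes, an (M x K) @ (K x N) GEMM has intensity
        2*M*N*K / (N*K*bytes_per_weight) = 2*M / bytes_per_weight  FLOP/byte,
    and it turns compute-bound once that exceeds peak_flops / bandwidth.
    """
    ridge = (peak_tflops * 1e12) / (mem_bw_tb_s * 1e12)  # FLOP/byte
    return ridge * bytes_per_weight / 2.0

for fmt, bpw in [('FP16', 2.0), ('INT8', 1.0), ('INT4', 0.5),
                 ('INT4 + g128 scales', 0.5 + 2.0 / 128)]:
    print(f"{fmt:>18s}: compute-bound above M ~ {ridge_batch_size(bpw):.0f}")
```

Because INT4 cuts the bytes by 4x, it also reaches the compute roof at roughly a quarter of the FP16 batch size. That is why W4A16's advantage is confined to small batches. In real kernels the crossover tends to arrive earlier, because the extra traffic and work this model ignores still take time.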

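def decode_throughput(
    hidden_dim, num_layers, num_heads, head_dim,
    bytes_per_weight, mem_bandwidth_tb_s
):
    """Estimate single-token decode latency and throughput.

    Accounts for Q, K, V, O projections and MLP (gate, up, down). Assumes
    standard multi-head attention (no GQA), so each projection is
    hidden x hidden and num_heads/head_dim do not change the count.
    Embeddings, the LM head, scale factors and KV-cache reads are ignored.
    """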
    # Attention projections: 4 * hidden^2 parameters
    attn_params = 4 * hidden_dim * hidden_dim

    # MLP: typically 3 * hidden * (8/3 * hidden) for SwiGLU
    # = 3 * hidden * intermediate, intermediate = 8/3 * hidden
    intermediate = int(8 / 3 * hidden_dim)
    # Round to multiple of 256 for alignment
    intermediate = ((intermediate + 255) // 256) * 256
    mlp_params = 3 * hidden_dim * intermediate

    total_params = num_layers * (attn_params + mlp_params)
    total_bytes = total_params * bytes_per_weight

    bandwidth_bytes_per_s = mem_bandwidth_tb_s * 1e12
    latency_s = total_bytes / bandwidth_bytes_per_s
    tokens_per_s = 1.0 / latency_s

    return {
        'total_params_B': total_params / 1e9,
        'total_bytes_GB': total_bytes / 1e9,
        'latency_ms': latency_s * 1000,
        'tokens_per_s': tokens_per_s,
    }

# Llama-2 7B on H100 SXM
for fmt, bpw in [('FP16', 2.0), ('INT8', 1.0), ('INT4', 0.5)]:
    result = decode_throughput(
        hidden_dim=4096, num_layers=32,
        num_heads=32, head_dim=128,
        bytes_per_weight=bpw,
        mem_bandwidth_tb_s=3.35
    )
    print(f"  {fmt:>5s}: {result['total_bytes_GB']:.1f} GB, "
          f"latency={result['latency_ms']:.2f} ms, "
          f"throughput={result['tokens_per_s']:.0f} tok/s")
Output:
   FP16: 13.0 GB, latency=3.87 ms, throughput=259 tok/s
   INT8: 6.5 GB, latency=1.93 ms, throughput=517 tok/s
   INT4: 3.2 GB, latency=0.97 ms, throughput=1035 tok/s
4x Throughput from 4x Compression

For single-token decode, W4A16 achieves nearly 4x the throughput of FP16, limited only by the overhead of dequantization and scale factor loading. A well-optimized kernel like Marlin achieves 90-95% of the theoretical 4x speedup.

INT4 Packing Format

Two INT4 values are packed into a single byte. The packing convention varies between implementations:

import numpy as np

def pack_int4_symmetric(values):
    """Pack pairs of signed INT4 values into bytes.

    INT4 range: [-8, 7]. Stored as unsigned [0, 15] with offset.
    Two values per byte: low nibble and high nibble.
    """
    assert len(values) % 2 == 0
    packed = np.zeros(len(values) // 2, dtype=np.uint8)

    for i in range(0, len(values), 2):
        # Map [-8, 7] to [0, 15]
        low = int(values[i]) + 8
        high = int(values[i + 1]) + 8
        packed[i // 2] = (high << 4) | (low & 0x0F)

    return packed

def unpack_int4_symmetric(packed):
    """Unpack bytes into pairs of signed INT4 values."""
    values = np.zeros(len(packed) * 2, dtype=np.int8)

    for i in range(len(packed)):
        # Cast to int first: uint8 arithmetic would wrap instead of going negative
        low = int(packed[i] & 0x0F) - 8
        high = int(packed[i] >> 4) - 8
        values[2 * i] = low
        values[2 * i + 1] = high

    return values

# GPTQ-style packing (the layout Marlin's loader starts from before repacking):
# 8 INT4 values packed into one 32-bit integer
def pack_int4_gptq(values):
    """Pack 8 INT4 values into a single uint32.

    This is the GPTQ convention: values packed from LSB to MSB.
    """
    assert len(values) == 8
    packed = np.uint32(0)
    for i in range(8):
        val = np.uint32(int(values[i]) + 8)  # Map to unsigned [0, 15]
        packed |= (val & np.uint32(0xF)) << np.uint32(4 * i)
    return packed

def unpack_int4_gptq(packed):
    """Unpack a uint32 into 8 INT4 values."""
    values = np.zeros(8, dtype=np.int8)
    for i in range(8):
        val = (int(packed) >> (4 * i)) & 0xF
        values[i] = val - 8
    return values
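The per-element loops above are easy to read but slow. The following is a minimal vectorized NumPy sketch of the same convention: LSB-first, offset by 8, with 8 values per uint32. It packs a whole row-major (N, K) matrix at once. The function names are illustrative, not taken from any library.

```python
import numpy as np

def pack_int4_rows(q):
    """Pack an (N, K) array of INT4 values in [-8, 7] into (N, K // 8) uint32.

    Word j of each row holds columns 8j..8j+7, column 8j in the lowest nibble.
    """
    n, k = q.shape
    assert k % 8 == 0
    u = (q.astype(np.int64) + 8).astype(np.uint32)   # map [-8, 7] -> [0, 15]
    u = u.reshape(n, k // 8, 8)
    shifts = 4 * np.arange(8, dtype=np.uint32)        # 0, 4, ..., 28
    return np.bitwise_or.reduce(u << shifts, axis=2)  # merge the 8 nibbles

def unpack_int4_rows(packed):
    """Inverse of pack_int4_rows: (N, K // 8) uint32 -> (N, K) int8."""
    n, kp = packed.shape
    shifts = 4 * np.arange(8, dtype=np.uint32)
    nibbles = (packed[:, :, None] >> shifts) & np.uint32(0xF)
    return (nibbles.astype(np.int8) - 8).reshape(n, kp * 8)

rng = np.random.default_rng(0)
q = rng.integers(-8, 8, size=(4, 64)).astype(np.int8)
packed = pack_int4_rows(q)
assert packed.shape == (4, 8) and packed.dtype == np.uint32
assert np.array_equal(unpack_int4_rows(packed), q)  # lossless round trip
```

Broadcasting the shift amounts over a trailing axis of length 8 does the same work as the loops, in a single array operation.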

The Marlin Kernel Architecture

Marlin (Mixed Auto-Regressive Linear, from IST Austria / Neural Magic) is a W4A16 GEMM kernel designed for maximum memory bandwidth utilization on NVIDIA Ampere and Hopper GPUs. It achieves near-ideal speedups: close to 4x over FP16 cuBLAS at batch size 1 and, according to its authors, still close to 4x at batch sizes up to roughly 16-32 tokens.

Design Principles

  1. Maximize global memory bandwidth utilization: Load INT4 weights at the full memory bandwidth, dequantize in registers, and feed FP16 values to tensor cores.

  2. Overlap memory loads with computation: Use double-buffering in shared memory to overlap the next tile’s memory load with the current tile’s tensor core computation.

  3. Minimize dequantization overhead: The INT4-to-FP16 conversion is done in registers using bitwise operations, not lookup tables.

  4. Tile sizes tuned for bandwidth: The kernel uses large tiles in the K dimension to maximize the ratio of weight loads to scale factor loads.
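Principle 4 is about amortizing scale traffic. With symmetric quantization and one FP16 scale per group of `group_size` weights, each weight effectively costs 4 bits plus 16/group_size bits. The short sketch below computes that overhead. It is illustrative only and counts no zero-points.

```python
def bits_per_weight(group_size, weight_bits=4, scale_bits=16):
    """Effective storage bits per weight with one scale per group.

    group_size=None stands for per-channel scaling; its overhead is
    scale_bits / K, which is treated as negligible here.
    """
    if group_size is None:
        return float(weight_bits)
    return weight_bits + scale_bits / group_size

for g in [32, 64, 128, None]:
    bpw = bits_per_weight(g)
    print(f"group_size={str(g):>4s}: {bpw:.3f} bits/weight, "
          f"{16 / bpw:.2f}x smaller than FP16")
```

On bytes alone, a g=128 checkpoint therefore tops out near 3.88x over FP16. That is in line with the 3.8x Marlin figure in the benchmark table later in this post.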

Memory Layout

Marlin reorders the weight matrix to optimize memory access patterns:

def marlin_weight_layout(W_q_packed, N, K, group_size=128):
    """Illustrate a Marlin-style tile-major weight layout.

    Simplified: this only regroups the matrix into 16 x 64 (N x K) tiles so
    that each tile is contiguous in memory. The real Marlin repacking also
    permutes values inside each tile to match the layout expected by the
    tensor-core (mma) instructions.

    W_q_packed: shape (N, K // 8), dtype uint32 (8 INT4s per uint32)
    Returns:    shape (N // 16, K // 64, 16, 8), i.e. 16 x 8 uint32 per tile
    group_size is unused here; scales are laid out separately.
    """
    # Tile dimensions
    tile_n = 16   # Output rows per tile
    tile_k = 64   # Input columns per tile (64 INT4 values = 8 uint32)

    assert N % tile_n == 0 and K % tile_k == 0
    assert W_q_packed.shape == (N, K // 8)

    num_tiles_n = N // tile_n
    num_tiles_k = K // tile_k

    # Reorder into tile-major layout
    tiled = np.zeros(
        (num_tiles_n, num_tiles_k, tile_n, tile_k // 8),
        dtype=np.uint32
    )

    for tn in range(num_tiles_n):
        for tk in range(num_tiles_k):
            for i in range(tile_n):
                for j in range(tile_k // 8):
                    src_row = tn * tile_n + i
                    src_col = tk * (tile_k // 8) + j
                    tiled[tn, tk, i, j] = W_q_packed[src_row, src_col]

    return tiled
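The four nested loops in `marlin_weight_layout` amount to a reshape plus one axis swap. The minimal NumPy sketch below implements the same simplified tile-major reorder and its inverse. It assumes the 16 x 64 tile shape used above and, like the loop version, is not Marlin's exact in-tile permutation. The names `to_tiles` and `from_tiles` are illustrative.

```python
import numpy as np

TILE_N, TILE_K = 16, 64        # tile shape assumed above (N x K)
WORDS_K = TILE_K // 8          # uint32 words per tile row

def to_tiles(w_packed):
    """(N, K // 8) uint32 -> (N // 16, K // 64, 16, 8) tile-major array."""
    n, kw = w_packed.shape
    assert n % TILE_N == 0 and kw % WORDS_K == 0
    t = w_packed.reshape(n // TILE_N, TILE_N, kw // WORDS_K, WORDS_K)
    return t.transpose(0, 2, 1, 3).copy()  # swap row-in-tile and tile-K axes

def from_tiles(tiled):
    """Inverse of to_tiles."""
    tn, tk, rn, rk = tiled.shape
    return tiled.transpose(0, 2, 1, 3).reshape(tn * rn, tk * rk)

rng = np.random.default_rng(0)
N, K = 64, 256
w = rng.integers(0, 2**32 - 1, size=(N, K // 8), dtype=np.uint32)
tiled = to_tiles(w)
assert tiled.shape == (N // 16, K // 64, 16, 8)
assert tiled[1, 2, 3, 4] == w[1 * 16 + 3, 2 * 8 + 4]
assert np.array_equal(from_tiles(tiled), w)
```

Because the reorder is a pure permutation, it can be applied once, offline, when the checkpoint is loaded. It then costs nothing at inference time.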

Warp-Level Dequantization

Each warp handles a portion of the tile. The dequantization happens in registers using fast bitwise operations:

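// Marlin-style INT4 to FP16 dequantization in CUDA
// Processes 8 INT4 values packed in a uint32 (LSB-first, offset by 8)
#include <cuda_fp16.h>
#include <cstdint>

// Straightforward version: shift, mask, convert, scale
__device__ __forceinline__ void dequantize_int4x8_to_fp16x8(
    uint32_t packed_int4,
    half scale,
    half* output  // 8 half values
) {
    // Extract each 4-bit value and convert to FP16
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        int val = (int)((packed_int4 >> (4 * i)) & 0xF) - 8;  // no int4_t scalar type in CUDA
        output[i] = __hmul(__int2half_rn(val), scale);
    }
}

// Faster version: the FP16 "magic number" trick, two values per instruction.
// FP16 0x6400 is 1024.0, and at that exponent the low mantissa bits count in
// steps of 1, so (0x6400 | n) is exactly 1024 + n for a nibble n.
// Output order is interleaved: output[i] holds (value i, value i + 4).
// Marlin's own routine fuses the mask-and-OR into one lop3 PTX instruction
// and pre-permutes weights offline so the interleaving matches the MMA.
__device__ __forceinline__ void dequantize_int4x8_fast(
    uint32_t packed,
    half scale,
    half2* output  // 4 half2 values = 8 halves, interleaved as described above
) {
    const uint32_t MASK = 0x000F000F;  // one nibble in the low bits of each 16-bit lane
    const uint32_t EXP  = 0x64006400;  // two FP16 1024.0 values
    const half2 bias    = __float2half2_rn(1032.0f);  // 1024 + 8 (undo the offset)
    const half2 scale2  = __half2half2(scale);

    #pragma unroll
    for (int i = 0; i < 4; i++) {
        uint32_t bits = ((packed >> (4 * i)) & MASK) | EXP;
        half2 v = *reinterpret_cast<half2*>(&bits);  // (1024 + n_i, 1024 + n_{i+4})
        output[i] = __hmul2(__hsub2(v, bias), scale2);
    }
}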
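Register-level FP16 tricks like this can be checked on a CPU. One widely used approach in Marlin-style kernels relies on two facts. First, the FP16 bit pattern `0x6400` is exactly 1024.0. Second, at that exponent the ten mantissa bits count in units of 1. OR-ing a nibble n into the low bits therefore yields the FP16 value 1024 + n without any int-to-float conversion instruction. The NumPy emulation below assumes the same LSB-first, offset-by-8 packing used earlier, and its function names are illustrative.

```python
import numpy as np

FP16_1024 = np.uint16(0x6400)   # IEEE FP16 bit pattern of 1024.0

def dequant_magic(packed_u32, scale):
    """Dequantize 8 offset-by-8 INT4 nibbles from one uint32 via bit tricks.

    Each nibble n is OR-ed into the mantissa of 1024.0, giving 1024 + n exactly;
    subtracting 1032 (= 1024 + 8) recovers the signed value n - 8.
    """
    shifts = 4 * np.arange(8, dtype=np.uint32)
    nibbles = (np.uint32(packed_u32) >> shifts) & np.uint32(0xF)
    bits = nibbles.astype(np.uint16) | FP16_1024   # FP16 bit patterns of 1024 + n
    as_fp16 = bits.view(np.float16)                # reinterpret bits, no conversion
    return (as_fp16 - np.float16(1032.0)) * np.float16(scale)

def dequant_reference(packed_u32, scale):
    """Straightforward shift/mask/convert version for comparison."""
    vals = [((int(packed_u32) >> (4 * i)) & 0xF) - 8 for i in range(8)]
    return np.array(vals, dtype=np.float16) * np.float16(scale)

word = 0x7F3A1C80   # arbitrary packed word
print(dequant_magic(word, 0.5))
print(dequant_reference(word, 0.5))
```

On a GPU, the same two steps (mask-and-OR, then one subtraction) run on `half2` pairs. That is why the trick produces two values per instruction.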

Double-Buffered Pipeline

The key to Marlin’s performance is overlapping memory loads with computation:

// Simplified Marlin pipeline structure (two stages shown for clarity;
// Marlin itself pipelines more stages of asynchronous copies).
// TILE_* constants and the load/wait/dequantize/mma/store helpers are
// placeholders, not real Marlin functions.
template <int GROUP_SIZE>
__global__ void marlin_gemm(
    const uint32_t* __restrict__ W_packed,  // INT4 weights
    const half* __restrict__ scales,         // Per-group FP16 scales
    const half* __restrict__ X,              // FP16 activations
    half* __restrict__ Y,                    // FP16 output
    int M, int N, int K
) {
    // Double-buffered shared memory (sizes must be compile-time constants)
    constexpr int SCALES_PER_ROW = (TILE_K + GROUP_SIZE - 1) / GROUP_SIZE;
    __shared__ half smem_X[2][TILE_M * TILE_K];
    __shared__ uint32_t smem_W[2][TILE_N * (TILE_K / 8)];
    __shared__ half smem_scales[2][TILE_N * SCALES_PER_ROW];

    // Accumulator in registers (FP32 for precision)
    float acc[TILE_M_PER_WARP][TILE_N_PER_WARP] = {};

    int buffer = 0;

    // Prologue: issue async loads for the first tile, then wait for them
    load_tile_async(smem_X[0], X, /*tile_k=*/0);
    load_tile_async(smem_W[0], W_packed, /*tile_k=*/0);
    load_tile_async(smem_scales[0], scales, /*tile_k=*/0);
    wait_async_copies();   // e.g. a cp.async wait-all
    __syncthreads();

    // Main loop: process tiles along K dimension
    for (int tile_k = 0; tile_k < K; tile_k += TILE_K) {
        int next_buffer = 1 - buffer;

        // Async load next tile into the other buffer (overlapped with compute)
        if (tile_k + TILE_K < K) {
            load_tile_async(smem_X[next_buffer], X, tile_k + TILE_K);
            load_tile_async(smem_W[next_buffer], W_packed, tile_k + TILE_K);
            load_tile_async(smem_scales[next_buffer], scales, tile_k + TILE_K);
        }

        // Dequantize INT4 -> FP16 in registers
        half W_deq[TILE_N_PER_WARP][TILE_K];
        dequantize_tile(smem_W[buffer], smem_scales[buffer], W_deq);

        // Tensor core MMA: acc += X_tile @ W_deq^T
        mma_tile(smem_X[buffer], W_deq, acc);

        // The next tile must have landed, and every thread must be done
        // reading this buffer, before the two buffers swap roles.
        wait_async_copies();
        __syncthreads();
        buffer = next_buffer;
    }

    // Epilogue: write accumulated results
    store_output(Y, acc);
}
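A toy timing model shows why the overlap matters. With a single buffer, every K-tile pays its load time plus its compute time. With two buffers, after the first load each step costs only the slower of the two. The minimal sketch below uses made-up per-tile costs (illustrative numbers, not measurements) and assumes every tile costs the same.

```python
def pipeline_time(num_tiles, load_us, compute_us, double_buffered=True):
    """Total time to stream num_tiles K-tiles through one thread block.

    Assumes identical per-tile costs and that, when double-buffered, loading
    tile t+1 fully overlaps computing tile t.
    """
    if not double_buffered:
        return num_tiles * (load_us + compute_us)
    # Prologue load, (num_tiles - 1) overlapped steps, then the last compute.
    return load_us + (num_tiles - 1) * max(load_us, compute_us) + compute_us

# Bandwidth-bound decode: loads dominate, so compute hides underneath them.
serial = pipeline_time(64, load_us=1.0, compute_us=0.2, double_buffered=False)
overlap = pipeline_time(64, load_us=1.0, compute_us=0.2)
print(f"serial={serial:.1f} us, double-buffered={overlap:.1f} us")
```

In the bandwidth-bound regime, the overlapped time approaches num_tiles × load time. Dequantization and MMA then run essentially for free underneath the memory traffic, which is the goal of design principle 2.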
ℹ️ Why Dequantize in Registers, Not Shared Memory

Marlin dequantizes INT4 to FP16 in registers rather than shared memory. This saves shared memory capacity (INT4 tiles are 4x smaller) and avoids the shared memory bank conflict overhead of writing dequantized FP16 values. The register-level dequantization adds ~2-3% overhead compared to loading pre-dequantized FP16 values.

Performance Model

The theoretical performance of a W4A16 kernel is bounded by:

$$\text{Time} = \max\left(\frac{\text{Weight bytes}}{\text{BW}_{\text{mem}}}, \frac{\text{FLOPs}}{\text{Throughput}_{\text{TC}}}\right) + T_{\text{deq}}$$

where $\text{Weight bytes}$ includes the scale factors and $T_{\text{deq}}$ is the dequantization overhead, modeled in the code as a fixed percentage of the kernel time.

def w4a16_performance_model(
    M, N, K,
    group_size=128,
    mem_bw_tb_s=3.35,      # H100 SXM
    fp16_tflops=990,        # H100 SXM tensor core FP16
    deq_overhead_pct=0.03,  # 3% dequantization overhead
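):
    """Estimate W4A16 GEMM time vs FP16 GEMM time.

    Roofline-style: counts weight and scale bytes only (activation and
    output traffic are ignored) and assumes peak tensor-core throughput.
    """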
    # FP16 baseline
    fp16_weight_bytes = N * K * 2
    fp16_flops = 2 * M * N * K
    fp16_mem_time = fp16_weight_bytes / (mem_bw_tb_s * 1e12)
    fp16_compute_time = fp16_flops / (fp16_tflops * 1e12)
    fp16_time = max(fp16_mem_time, fp16_compute_time)

    # W4A16
    w4_weight_bytes = N * K * 0.5  # 4 bits = 0.5 bytes
    w4_scale_bytes = N * (K // group_size) * 2  # FP16 scales
    w4_total_bytes = w4_weight_bytes + w4_scale_bytes
    w4_flops = 2 * M * N * K  # Same FLOPs (FP16 tensor cores)
    w4_mem_time = w4_total_bytes / (mem_bw_tb_s * 1e12)
    w4_compute_time = w4_flops / (fp16_tflops * 1e12)
    w4_time = max(w4_mem_time, w4_compute_time) * (1 + deq_overhead_pct)

    speedup = fp16_time / w4_time
    regime = "bandwidth-bound" if w4_mem_time > w4_compute_time else "compute-bound"

    return {
        'fp16_time_us': fp16_time * 1e6,
        'w4a16_time_us': w4_time * 1e6,
        'speedup': speedup,
        'regime': regime,
    }

# Batch size sweep for Llama-2 7B attention projection (4096 x 4096)
print("H100 SXM: 4096x4096 GEMM")
for M in [1, 4, 16, 64, 256, 1024]:
    result = w4a16_performance_model(M, 4096, 4096)
    print(f"  M={M:>4d}: FP16={result['fp16_time_us']:.1f}us, "
          f"W4A16={result['w4a16_time_us']:.1f}us, "
          f"speedup={result['speedup']:.2f}x [{result['regime']}]")
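Output:
H100 SXM: 4096x4096 GEMM
  M=   1: FP16=10.0us, W4A16=2.7us, speedup=3.77x [bandwidth-bound]
  M=   4: FP16=10.0us, W4A16=2.7us, speedup=3.77x [bandwidth-bound]
  M=  16: FP16=10.0us, W4A16=2.7us, speedup=3.77x [bandwidth-bound]
  M=  64: FP16=10.0us, W4A16=2.7us, speedup=3.77x [bandwidth-bound]
  M= 256: FP16=10.0us, W4A16=8.9us, speedup=1.12x [compute-bound]
  M=1024: FP16=34.7us, W4A16=35.7us, speedup=0.97x [compute-bound]

In this idealized model, W4A16 keeps its full 3.77x advantage until it turns compute-bound at roughly M = 76, and then falls to parity. At M=1024, the assumed 3% dequantization overhead makes it slightly slower than FP16. In practice, kernels lose speedup earlier and more gradually, because the model ignores activation traffic, scale handling and imperfect overlap.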

[Figure: W4A16 Speedup vs FP16 by Batch Size (H100, 4096x4096); y-axis: speedup over FP16; bars for M = 1, 4, 16, 64, 256, 1024, labeled from bandwidth-bound at M=1 to compute-bound at M=1024]
⚠️ W4A16 Speedup Disappears at Large Batch Sizes

At batch sizes in the hundreds, the GEMM becomes compute-bound and W4A16 offers little or no speedup over FP16. The kernel performs the same FP16 tensor core operations regardless of weight format, and the dequantization work can even make it slightly slower. This is why W4A16 suits low-latency serving (small batches), while W8A8 INT8 or FP8, which actually use lower-precision tensor cores, are preferred for high-throughput serving (large batches).

Marlin vs ExLlamaV2 vs cuBLAS FP16

Single-Token Decode Latency: Llama-2 7B (H100 SXM)

| Kernel | Weight Format | Latency (ms) | Throughput (tok/s) | vs FP16 |
|---|---|---|---|---|
| cuBLAS FP16 | FP16 | 3.91 | 256 | 1.0x |
| cuBLAS INT8 | W8A16 | 2.12 | 472 | 1.8x |
| ExLlamaV2 | GPTQ-INT4 g128 | 1.18 | 847 | 3.3x |
| Marlin | GPTQ-INT4 g128 | 1.02 | 980 | 3.8x |
| Marlin (g=channelwise) | INT4 per-channel | 0.98 | 1020 | 4.0x |

Note: Marlin achieves a 3.8x speedup with per-group scaling and nearly 4.0x with per-channel scaling, close to the theoretical 4x bandwidth limit. ExLlamaV2 delivers about 14% lower throughput than Marlin, due to less efficient memory access patterns.
Prefill Throughput: Llama-2 7B, 2048 Tokens (H100 SXM)

| Kernel | Weight Format | Time (ms) | Throughput (tok/s) | vs FP16 |
|---|---|---|---|---|
| cuBLAS FP16 | FP16 | 42.1 | 48,600 | 1.0x |
| Marlin | GPTQ-INT4 g128 | 39.8 | 51,500 | 1.06x |
| cuBLAS INT8 TC | W8A8 INT8 | 23.2 | 88,300 | 1.82x |

Note: During prefill (large batch), W4A16 provides minimal speedup because the GEMM is compute-bound. W8A8 with INT8 tensor cores provides 1.8x speedup because it actually uses lower-precision compute.

Implementation: W4A16 Linear Layer

A complete W4A16 linear layer implementation for inference:

import torch
import torch.nn as nn

class W4A16Linear(nn.Module):
    """W4A16 quantized linear layer.

    Stores weights as packed INT4, dequantizes to FP16 for GEMM.
    """

    def __init__(self, in_features, out_features, group_size=128, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size

        # Pack 8 INT4 values into one int32
        assert in_features % 8 == 0, "in_features must be divisible by 8"
        assert in_features % group_size == 0, "in_features must be divisible by group_size"
        packed_k = in_features // 8
        num_groups = in_features // group_size

        self.register_buffer(
            'qweight', torch.zeros(out_features, packed_k, dtype=torch.int32)
        )
        self.register_buffer(
            'scales', torch.zeros(out_features, num_groups, dtype=torch.float16)
        )
        if bias:
            self.register_buffer(
                'bias', torch.zeros(out_features, dtype=torch.float16)
            )
        else:
            self.bias = None

    @staticmethod
    def pack_weights(int4_weights):
        """Pack INT4 weights (range [-8, 7]) into int32.

        int4_weights: shape (N, K), dtype int8, values in [-8, 7]
        Returns: shape (N, K // 8), dtype int32
        """
        N, K = int4_weights.shape
        assert K % 8 == 0

        # Shift to unsigned: [0, 15]
        unsigned = (int4_weights.to(torch.int32) + 8) & 0xF

        # Pack 8 values per int32
        packed = torch.zeros(N, K // 8, dtype=torch.int32,
                             device=int4_weights.device)
        for i in range(8):
            packed |= unsigned[:, i::8] << (4 * i)

        return packed

    @staticmethod
    def unpack_weights(packed):
        """Unpack int32 to INT4 weights.

        packed: shape (N, K // 8), dtype int32
        Returns: shape (N, K), dtype int8, values in [-8, 7]
        """
        N, packed_k = packed.shape
        K = packed_k * 8

        unpacked = torch.zeros(N, K, dtype=torch.int8,
                               device=packed.device)
        for i in range(8):
            unpacked[:, i::8] = ((packed >> (4 * i)) & 0xF).to(torch.int8) - 8

        return unpacked

    def dequantize(self):
        """Dequantize packed INT4 weights to FP16."""
        int4_weights = self.unpack_weights(self.qweight)  # (N, K)

        # Reshape for per-group dequantization
        N = self.out_features
        K = self.in_features
        num_groups = K // self.group_size

        w_grouped = int4_weights.reshape(N, num_groups, self.group_size)
        scales = self.scales.unsqueeze(2)  # (N, num_groups, 1)

        w_deq = w_grouped.to(torch.float16) * scales
        return w_deq.reshape(N, K)

    def forward(self, x):
        """Forward pass: dequantize weights and compute GEMM.

        In production, this dequantization happens inside a fused
        CUDA kernel (Marlin). This Python version is for correctness
        verification only. In real W4A16 inference x is FP16; the
        dequantized weights are cast to x.dtype so the check also runs
        with FP32 activations on CPU.
        """
        W = self.dequantize().to(x.dtype)  # (N, K)
        output = x @ W.T                   # (*, K) @ (K, N) -> (*, N)
        if self.bias is not None:
            output = output + self.bias.to(x.dtype)
        return output

    @classmethod
    def from_float(cls, linear, group_size=128):
        """Quantize a float linear layer to W4A16."""
        in_f = linear.in_features
        out_f = linear.out_features

        layer = cls(in_f, out_f, group_size, bias=linear.bias is not None)

        W = linear.weight.data.float()
        num_groups = in_f // group_size

        # Per-group quantization
        W_grouped = W.reshape(out_f, num_groups, group_size)
        group_max = W_grouped.abs().amax(dim=2)
        scales = group_max / 7.0  # INT4 symmetric absmax: codes land in [-7, 7]
        # Clamp to an FP16-representable minimum (avoids 0/0 for all-zero
        # groups), then round to FP16 so quantization uses the stored scales
        scales = scales.clamp(min=1e-5).to(torch.float16).float()

        W_q = (W_grouped / scales.unsqueeze(2)).round().clamp(-8, 7)
        W_q = W_q.reshape(out_f, in_f).to(torch.int8)

        layer.qweight.copy_(cls.pack_weights(W_q))
        layer.scales.copy_(scales.to(torch.float16))

        if linear.bias is not None:
            layer.bias.copy_(linear.bias.data.to(torch.float16))

        return layer

# Verify correctness
torch.manual_seed(42)
linear_fp = nn.Linear(4096, 4096, bias=False)
nn.init.normal_(linear_fp.weight, std=0.02)

w4_layer = W4A16Linear.from_float(linear_fp, group_size=128)

x = torch.randn(1, 32, 4096)
with torch.no_grad():
    y_fp = linear_fp(x)
    # FP32 activations keep this check CPU-friendly; on GPU, move the layer
    # to CUDA and pass x.half() for true W4A16 behavior
    y_w4 = w4_layer(x).float()

mse = ((y_fp - y_w4) ** 2).mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
    y_fp.flatten(), y_w4.flatten(), dim=0
).item()

print(f"Output MSE: {mse:.6e}")
print(f"Cosine similarity: {cos_sim:.8f}")
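The verification above uses a single group size. The minimal NumPy sketch below makes the granularity tradeoff visible by measuring weight reconstruction error as the group shrinks. It uses the same symmetric absmax/7 scheme as `from_float`, with Gaussian weights standing in for a real checkpoint, and `group_size=None` means one scale per output row. The helper name is illustrative.

```python
import numpy as np

def quantize_dequantize(w, group_size=None):
    """Symmetric INT4 round trip with one absmax/7 scale per group (or per row)."""
    n, k = w.shape
    g = k if group_size is None else group_size
    wg = w.reshape(n, k // g, g)
    scale = np.maximum(np.abs(wg).max(axis=2, keepdims=True) / 7.0, 1e-10)
    q = np.clip(np.round(wg / scale), -8, 7)
    return (q * scale).reshape(n, k)

rng = np.random.default_rng(42)
w = rng.normal(0.0, 0.02, size=(256, 4096))
for g in [32, 64, 128, None]:
    err = np.mean((w - quantize_dequantize(w, g)) ** 2)
    print(f"group_size={str(g):>4s}: weight MSE = {err:.3e}")
```

Smaller groups see smaller maxima, so their quantization step is finer and the error drops. The price is the extra scale bytes per weight, which eat into the bandwidth savings.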

vLLM Integration

vLLM integrates Marlin and, on GPUs that support it, uses it by default for compatible GPTQ and AWQ checkpoints. The integration involves:

# vLLM's Marlin integration (simplified)
# Location: vllm/model_executor/layers/quantization/gptq_marlin.py

class GPTQMarlinLinearMethod:
    """Marlin kernel for GPTQ-quantized models."""

    def __init__(self, quant_config):
        self.group_size = quant_config.group_size
        self.bits = quant_config.bits  # 4 or 8

    def create_weights(self, layer, input_size, output_size, params_dtype):
        """Allocate quantized weight buffers in Marlin layout."""
        pack_factor = 32 // self.bits  # 8 for INT4
        packed_input = input_size // pack_factor

        # Marlin expects a specific memory layout
        qweight = torch.zeros(
            packed_input, output_size,  # Note: transposed vs standard
            dtype=torch.int32
        )
        scales = torch.zeros(
            input_size // self.group_size, output_size,
            dtype=params_dtype
        )

        layer.register_parameter('qweight', nn.Parameter(qweight, requires_grad=False))
        layer.register_parameter('scales', nn.Parameter(scales, requires_grad=False))

    def apply(self, layer, x, bias=None):
        """Run Marlin GEMM kernel."""
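        # In production, this calls the compiled Marlin CUDA op. The
        # marlin_gemm call below is a placeholder: the real op name and
        # argument list differ between vLLM versions, and layer.workspace
        # is a scratch buffer allocated at load time (not shown above).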

        # The kernel handles:
        # 1. INT4 unpacking
        # 2. Per-group dequantization using scales
        # 3. FP16 tensor core GEMM
        # 4. Output accumulation in FP32, cast to FP16

        output = marlin_gemm(
            x, layer.qweight, layer.scales,
            layer.workspace
        )

        if bias is not None:
            output = output + bias
        return output

Model Loading Pipeline

# Loading a GPTQ model with the Marlin kernel in vLLM:

# 1. Load the GPTQ checkpoint (safetensors format)
# 2. Verify compatibility (e.g. group_size=128, bits=4, symmetric, no act_order)
# 3. Repack weights from the GPTQ layout to the Marlin layout
# 4. Store the repacked weights on the GPU

# Requirements of the original Marlin kernel (vLLM's gptq_marlin path has
# since relaxed some of them, e.g. adding 8-bit and act_order support):
MARLIN_REQUIREMENTS = {
    'bits': [4],              # Original kernel: 4-bit only
    'group_size': [128, -1],  # 128 or channelwise
    'symmetric': True,        # No zero-point
    'act_order': False,       # GPTQ act_order breaks the static Marlin layout
    'min_N': 64,              # Minimum output dimension
    'min_K': 128,             # Minimum input dimension
}
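A loader has to reject incompatible checkpoints before repacking. The following is a minimal sketch of such a check against the requirements above. It uses a local copy of the dictionary and a hypothetical config dict with GPTQ-style keys (`bits`, `group_size`, `sym`, `desc_act`). Real loaders read these values from the checkpoint's quantization config and may accept a wider set of options.

```python
REQUIREMENTS = {          # local copy of MARLIN_REQUIREMENTS above
    'bits': [4],
    'group_size': [128, -1],
    'symmetric': True,
    'act_order': False,
    'min_N': 64,
    'min_K': 128,
}

def marlin_incompatibilities(quant_cfg, N, K, reqs=REQUIREMENTS):
    """List the reasons a GPTQ layer cannot use Marlin (empty list = compatible).

    quant_cfg: dict with GPTQ-style keys 'bits', 'group_size', 'sym', 'desc_act'.
    N, K: output and input dimensions of the linear layer.
    """
    problems = []
    if quant_cfg['bits'] not in reqs['bits']:
        problems.append(f"bits={quant_cfg['bits']} not in {reqs['bits']}")
    if quant_cfg['group_size'] not in reqs['group_size']:
        problems.append(f"group_size={quant_cfg['group_size']} not supported")
    if quant_cfg.get('sym', True) != reqs['symmetric']:
        problems.append("asymmetric (zero-point) quantization")
    if quant_cfg.get('desc_act', False) and not reqs['act_order']:
        problems.append("act_order / desc_act enabled")
    if N < reqs['min_N'] or K < reqs['min_K']:
        problems.append(f"layer too small: N={N}, K={K}")
    return problems

cfg = {'bits': 4, 'group_size': 128, 'sym': True, 'desc_act': False}
print(marlin_incompatibilities(cfg, N=4096, K=4096))   # []
print(marlin_incompatibilities({**cfg, 'desc_act': True}, N=4096, K=4096))
```

The function returns every violated constraint, not just the first one it finds. A loader can therefore log all of them at once before falling back to a slower kernel.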

When W4A16 is the Right Choice

# Decision matrix for weight format selection

def recommend_weight_format(
    batch_size,
    latency_slo_ms,
    gpu_type,
    model_size_B,
):
    """Recommend weight format based on deployment constraints.

    Simplified heuristic: latency_slo_ms is accepted but not used here.
    """

    if batch_size <= 8:
        # Decode-dominant: bandwidth-bound
        if model_size_B <= 13:
            return "W4A16 (GPTQ/AWQ + Marlin)"
        else:
            return "W4A16 (GPTQ/AWQ + Marlin), multi-GPU"

    elif batch_size <= 64:
        # Mixed regime
        if gpu_type in ['H100', 'A100']:
            return "W4A16 for decode, FP8/INT8 for prefill"
        else:
            return "W4A16 (Marlin)"

    else:
        # Throughput-dominant: compute-bound
        if gpu_type == 'H100':
            return "FP8 (W8A8 E4M3)"
        elif gpu_type == 'A100':
            return "W8A8 INT8 (SmoothQuant)"
        else:
            return "W4A16 (memory savings; FP8/INT8 GEMM support varies by GPU)"

Format Selection by Deployment Scenario

| Scenario | Batch Size | GPU | Best Format | Reason |
|---|---|---|---|---|
| Chat (1 user) | 1 | A100 | W4A16 Marlin | BW-bound, 3.8x faster |
| Chat (8 users) | 8 | H100 | W4A16 Marlin | Still BW-bound |
| Batch API | 128 | H100 | FP8 W8A8 | Compute-bound, 2x TC |
| Batch API | 128 | A100 | INT8 W8A8 | Compute-bound, 2x TC |
| Embedding | 256 | Any | FP16 or FP8 | Prefill-only, compute-bound |
| Edge (RTX 4090) | 1 | 4090 | W4A16 ExLlama | Consumer GPU, BW-bound |