Part of Series Quantization Masterclass 10 of 30

Quantization reduces model size and memory traffic. But the actual speedup depends entirely on whether the GPU hardware supports the quantized operation natively. Running INT4 weights through an FP16 GEMM kernel (dequantizing in registers, then using FP16 tensor cores) gives bandwidth savings but no compute speedup. Running INT8 weights and activations through INT8 tensor cores gives both: on A100 and H100, INT8 tensor cores deliver 2x the peak throughput of FP16 on the same silicon and power budget.

This post maps which GPU generation supports which quantized GEMM, documents the cuBLAS API for quantized matrix multiplication, profiles the specialized kernels (Marlin, ExLlama) that outperform cuBLAS for specific quantization formats, and benchmarks them on H100, with Blackwell figures given as projections.

Tensor Core Precision Support by GPU Generation

The Precision Matrix


NVIDIA Tensor Core Precision Support by Architecture

Format | Volta (V100) | Turing (T4) | Ampere (A100) | Hopper (H100) | Blackwell (B200)
FP16 x FP16 -> FP16/FP32 | Yes | Yes | Yes | Yes | Yes
BF16 x BF16 -> FP32 | No | No | Yes | Yes | Yes
TF32 x TF32 -> FP32 | No | No | Yes | Yes | Yes
FP8 (E4M3) x FP8 (E4M3) -> FP16/FP32 | No | No | No | Yes | Yes
FP8 (E5M2) x FP8 (E4M3) -> FP16/FP32 | No | No | No | Yes | Yes
INT8 x INT8 -> INT32 | No | Yes | Yes | Yes | Yes
INT4 x INT4 -> INT32 | No | Yes | Yes | No | No
INT1 x INT1 -> INT32 (binary) | No | Yes | Yes | No | No
FP4 (E2M1) x FP4 -> FP16/FP32 | No | No | No | No | Yes
FP6 (E3M2/E2M3) x FP6 -> FP16/FP32 | No | No | No | No | Yes
Structured Sparsity (2:4) | No | No | Yes | Yes | Yes
Note: Each generation adds lower-precision floating-point formats. Blackwell is the first to support FP4 natively. INT8 support starts with Turing (2018). The INT4 and binary tensor-core paths of Turing and Ampere are not offered on Hopper or Blackwell, so 4-bit weights on those GPUs are dequantized to a higher precision before the matrix multiply.
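As a quick sanity check, the table can be turned into a lookup keyed by CUDA compute capability (the tuple `torch.cuda.get_device_capability()` returns). This is a minimal sketch that transcribes only the floating-point and INT8 rows for the five GPUs in the table (V100 = 7.0, T4 = 7.5, A100 = 8.0, H100 = 9.0, B200 = 10.0); the dictionary and helper names are illustrative, and other GPUs are deliberately left out rather than guessed.

```python
# Native tensor-core formats per compute capability, transcribed from the
# FP16/BF16/TF32/FP8/FP6/FP4/INT8 rows of the table. Other GPUs are not covered.
TENSOR_CORE_FORMATS = {
    (7, 0): {"fp16"},                                                # V100 (Volta)
    (7, 5): {"fp16", "int8"},                                        # T4 (Turing)
    (8, 0): {"fp16", "bf16", "tf32", "int8"},                        # A100 (Ampere)
    (9, 0): {"fp16", "bf16", "tf32", "fp8", "int8"},                 # H100 (Hopper)
    (10, 0): {"fp16", "bf16", "tf32", "fp8", "fp6", "fp4", "int8"},  # B200 (Blackwell)
}


def has_native(cc, fmt):
    """True if the tensor cores of compute capability `cc` run `fmt` natively."""
    if cc not in TENSOR_CORE_FORMATS:
        raise KeyError(f"compute capability {cc} is not in the table")
    return fmt in TENSOR_CORE_FORMATS[cc]


def native_w8a8_formats(cc):
    """8-bit weight+activation GEMM formats with native tensor-core support."""
    return sorted(f for f in ("fp8", "int8") if has_native(cc, f))
```

For example, `native_w8a8_formats((8, 0))` returns `['int8']` for A100, H100 gives `['fp8', 'int8']`, and V100 gives an empty list, so W8A8 on V100 has no tensor-core path at all.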

Peak Throughput by Format


Peak Tensor Core Throughput by Format and GPU

GPU | FP16 TFLOPS | BF16 TFLOPS | FP8 TFLOPS | INT8 TOPS | FP4 TFLOPS
V100 (SXM) | 125 | --- | --- | --- | ---
T4 | 65 | --- | --- | 130 | ---
A100 (SXM) | 312 | 312 | --- | 624 | ---
A100 (2:4 sparse) | 624 | 624 | --- | 1248 | ---
H100 (SXM) | 990 | 990 | 1979 | 1979 | ---
H100 (2:4 sparse) | 1979 | 1979 | 3958 | 3958 | ---
H200 (SXM) | 990 | 990 | 1979 | 1979 | ---
B200 (SXM) | 2250 | 2250 | 4500 | 4500 | 9000
B200 (2:4 sparse) | 4500 | 4500 | 9000 | 9000 | 18000
Note: INT8 on A100 gives 2x the throughput of FP16. FP8 on H100 gives 2x over FP16 (with sparsity: 4x). FP4 on B200 gives 2x over FP8. Each halving of precision doubles throughput.
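The doubling only pays off when a GEMM is compute-bound. A simple roofline estimate shows where that happens: time is the larger of FLOPs divided by peak throughput and bytes moved divided by memory bandwidth. This sketch assumes each operand is read from HBM exactly once and the output written once (ideal on-chip reuse), and uses the H100 SXM dense FP16 peak from the table together with an HBM bandwidth of 3,350 GB/s; the function names and defaults are illustrative.

```python
def gemm_roofline(M, N, K, peak_tflops, bw_gbs,
                  a_bytes=2.0, w_bytes=2.0, out_bytes=2.0):
    """Roofline estimate for an [M, K] x [K, N] GEMM.

    Returns (time_ms, regime). Assumes every operand crosses HBM once, so the
    time is an optimistic lower bound.
    """
    flops = 2 * M * N * K                                  # one multiply-add = 2 ops
    bytes_moved = M * K * a_bytes + K * N * w_bytes + M * N * out_bytes
    t_compute = flops / (peak_tflops * 1e12)
    t_memory = bytes_moved / (bw_gbs * 1e9)
    regime = "compute" if t_compute > t_memory else "memory"
    return max(t_compute, t_memory) * 1e3, regime


def ridge_point(peak_tflops, bw_gbs):
    """FLOPs per byte needed before a kernel becomes compute-bound."""
    return peak_tflops * 1e12 / (bw_gbs * 1e9)


# H100 SXM, FP16: ridge point is about 296 FLOP/byte
for M in (1, 32, 2048):
    ms, regime = gemm_roofline(M, 4096, 4096, peak_tflops=990, bw_gbs=3350)
    print(f"M={M:5d}: {ms:.4f} ms ({regime}-bound)")
```

With FP16 weights and small M, the arithmetic intensity of these shapes is roughly M FLOPs per byte, so decode (M=1 or 32) stays far below the ridge point and only prefill-sized M crosses it. That is why faster low-precision tensor cores mainly help prefill, while smaller weights mainly help decode.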

H100 SXM Tensor Core Throughput by Precision (dense unless noted)

Precision | TFLOPS/TOPS | vs FP16
FP32 (non-tensor-core) | 67 | 0.07x
TF32 | 495 | 0.5x
FP16/BF16 | 990 | 1x
FP8 | 1,979 | 2x
INT8 | 1,979 | 2x
FP8 (2:4 sparse) | 3,958 | 4x

cuBLAS Quantized GEMM API

INT8 GEMM with cuBLAS (cublasLtMatmul)

cuBLAS provides cublasLtMatmul for INT8 matrix multiplication. The API is more complex than the standard cublasSgemm because it requires explicit layout and datatype descriptors.

#include <cublasLt.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Abort with a message if a cuBLASLt call fails
#define CHECK_CUBLAS(call)                                               \
    do {                                                                 \
        cublasStatus_t status_ = (call);                                 \
        if (status_ != CUBLAS_STATUS_SUCCESS) {                          \
            std::fprintf(stderr, "cuBLASLt error %d at %s:%d\n",         \
                         static_cast<int>(status_), __FILE__, __LINE__); \
            std::exit(EXIT_FAILURE);                                     \
        }                                                                \
    } while (0)

// INT8 GEMM: D (INT32) = alpha * (A (INT8) x B (INT8)) + beta * C
// A is logically [M, K], B is [K, N], C/D are [M, N]; all column-major.

struct CublasLtInt8Gemm {
    cublasLtHandle_t handle;
    cublasLtMatmulDesc_t matmulDesc;
    cublasLtMatrixLayout_t layoutA, layoutB, layoutC;

    void init(int M, int N, int K) {
        CHECK_CUBLAS(cublasLtCreate(&handle));

        // Create matmul descriptor
        CHECK_CUBLAS(cublasLtMatmulDescCreate(
            &matmulDesc,
            CUBLAS_COMPUTE_32I,    // INT32 compute
            CUDA_R_32I             // Scale type: INT32 alpha/beta
        ));

        // cuBLASLt's INT8 tensor-core kernels expect the "TN" form:
        // A transposed, B not transposed
        cublasOperation_t transA = CUBLAS_OP_T;
        cublasOperation_t transB = CUBLAS_OP_N;
        CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA,
            &transA, sizeof(transA)
        ));
        CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB,
            &transB, sizeof(transB)
        ));

        // A: stored column-major as [K, M] (== row-major [M, K]);
        // TRANSA turns it into the logical [M, K] operand
        CHECK_CUBLAS(cublasLtMatrixLayoutCreate(
            &layoutA, CUDA_R_8I, K, M, K  // rows, cols, leading dim
        ));

        // B: [K, N] in column-major
        CHECK_CUBLAS(cublasLtMatrixLayoutCreate(
            &layoutB, CUDA_R_8I, K, N, K
        ));

        // C/D: [M, N] output in INT32
        CHECK_CUBLAS(cublasLtMatrixLayoutCreate(
            &layoutC, CUDA_R_32I, M, N, M
        ));
    }

    // For best performance keep M, N, K multiples of 16 and all pointers
    // 16-byte aligned.
    void run(const int8_t* A, const int8_t* B, int32_t* C,
             cudaStream_t stream = 0) {
        int32_t alpha = 1;
        int32_t beta = 0;

        CHECK_CUBLAS(cublasLtMatmul(
            handle,
            matmulDesc,
            &alpha,
            A, layoutA,
            B, layoutB,
            &beta,
            C, layoutC,
            C, layoutC,    // D = C (in-place)
            nullptr,       // Algorithm (nullptr = heuristic default)
            nullptr, 0,    // Workspace
            stream
        ));
    }

    void cleanup() {
        cublasLtMatrixLayoutDestroy(layoutA);
        cublasLtMatrixLayoutDestroy(layoutB);
        cublasLtMatrixLayoutDestroy(layoutC);
        cublasLtMatmulDescDestroy(matmulDesc);
        cublasLtDestroy(handle);
    }
};
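The GEMM above returns raw INT32 accumulators; turning them back into real numbers is left to the caller (or to a fused epilogue). The arithmetic is easy to check on a toy example. This pure-Python sketch assumes symmetric per-tensor INT8 quantization with scale `amax / 127`; the helper names are illustrative, and the nested-list matmul stands in for what `cublasLtMatmul` computes with `CUBLAS_COMPUTE_32I`.

```python
def quantize_sym_int8(values):
    """Per-tensor symmetric INT8: q = round(x / s), with s = max|x| / 127."""
    amax = max(abs(v) for row in values for v in row)
    scale = amax / 127.0 if amax > 0 else 1.0
    q = [[max(-127, min(127, round(v / scale))) for v in row] for row in values]
    return q, scale


def int8_gemm_int32(qa, qb):
    """Exact integer dot products, as accumulated in INT32 by the GEMM."""
    K, N = len(qb), len(qb[0])
    return [[sum(row[k] * qb[k][n] for k in range(K)) for n in range(N)]
            for row in qa]


def dequantize(acc, scale_a, scale_b):
    """Map INT32 accumulators back to real values: y ~= acc * s_a * s_b."""
    return [[v * scale_a * scale_b for v in row] for row in acc]


A = [[0.6, -1.2], [0.3, 2.0]]
B = [[1.0, 0.0], [0.0, -0.4]]
qa, s_a = quantize_sym_int8(A)
qb, s_b = quantize_sym_int8(B)
acc = int8_gemm_int32(qa, qb)        # exact integer sums
approx = dequantize(acc, s_a, s_b)   # close to A @ B = [[0.6, 0.48], [0.3, -0.8]]
print(acc, approx)
```

Because INT32 accumulation is exact, all of the error comes from rounding A and B to INT8; the final `acc * s_a * s_b` rescale is where the per-tensor (or per-channel) scales enter.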

FP8 GEMM with cuBLAS (H100+)

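#include <cublasLt.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

// FP8 GEMM on Hopper:
// D (FP16) = scale_a * scale_b * (A (FP8 E4M3) x B (FP8 E4M3))
// A is logically [M, K], B is [K, N], D is [M, N]; all column-major.
// The per-tensor dequantization scales live in device memory and reach
// cuBLASLt through A_SCALE_POINTER / B_SCALE_POINTER, so alpha stays 1.
struct CublasLtFp8Gemm {
    cublasLtHandle_t handle;
    cublasLtMatmulDesc_t matmulDesc;
    cublasLtMatrixLayout_t layoutA, layoutB, layoutC;

    // d_a_scale, d_b_scale: device pointers to one float each
    void init(int M, int N, int K,
              const float* d_a_scale, const float* d_b_scale) {
        cublasLtCreate(&handle);

        // FP8 uses FP32 accumulation and an FP32 scale type
        cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

        // FP8 GEMMs in cuBLASLt require the TN layout (A transposed)
        cublasOperation_t transA = CUBLAS_OP_T;
        cublasOperation_t transB = CUBLAS_OP_N;
        cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA,
                                       &transA, sizeof(transA));
        cublasLtMatmulDescSetAttribute(matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                       &transB, sizeof(transB));

        // Per-tensor scales, applied inside the GEMM. The attribute value is
        // the device pointer itself, so pass its address and sizeof(pointer).
        cublasLtMatmulDescSetAttribute(matmulDesc,
                                       CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                       &d_a_scale, sizeof(d_a_scale));
        cublasLtMatmulDescSetAttribute(matmulDesc,
                                       CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                       &d_b_scale, sizeof(d_b_scale));

        // A: FP8 E4M3 stored column-major as [K, M] (== row-major [M, K])
        cublasLtMatrixLayoutCreate(&layoutA, CUDA_R_8F_E4M3, K, M, K);
        // B: FP8 E4M3, column-major [K, N]
        cublasLtMatrixLayoutCreate(&layoutB, CUDA_R_8F_E4M3, K, N, K);
        // C/D: FP16 output (BF16 also works), column-major [M, N]
        cublasLtMatrixLayoutCreate(&layoutC, CUDA_R_16F, M, N, M);
    }

    void run(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B, half* D,
             cudaStream_t stream = 0) {
        float alpha = 1.0f;  // scaling comes from the scale pointers, not alpha
        float beta = 0.0f;

        cublasLtMatmul(
            handle, matmulDesc,
            &alpha,
            A, layoutA,
            B, layoutB,
            &beta,
            D, layoutC,    // C (ignored since beta = 0)
            D, layoutC,    // D
            nullptr, nullptr, 0, stream
        );
    }

    void cleanup() {
        cublasLtMatrixLayoutDestroy(layoutA);
        cublasLtMatrixLayoutDestroy(layoutB);
        cublasLtMatrixLayoutDestroy(layoutC);
        cublasLtMatmulDescDestroy(matmulDesc);
        cublasLtDestroy(handle);
    }
};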
⚠️ cuBLAS FP8 Layout Requirements

cuBLAS FP8 GEMMs on Hopper require the TN layout (A transposed, B not), 16-byte-aligned pointers, and leading dimensions that are multiples of 16. Shapes that violate these constraints are rejected or have to be padded by the caller, and padded rows and columns are wasted compute. For small batch sizes (M=1 in decode), the padding overhead can be significant.
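To see why small M hurts, count the padded work. This sketch assumes a hypothetical policy of rounding every GEMM dimension up to a multiple of 16 and reports the fraction of tensor-core FLOPs spent on zeros; it does not model what any particular cuBLAS kernel does internally.

```python
def round_up(x, multiple=16):
    """Smallest multiple of `multiple` that is >= x."""
    return (x + multiple - 1) // multiple * multiple


def padding_waste(M, N, K, multiple=16):
    """Fraction of GEMM FLOPs spent on zero padding when M, N and K are
    each rounded up to `multiple` (hypothetical padding policy)."""
    useful = M * N * K
    padded = round_up(M, multiple) * round_up(N, multiple) * round_up(K, multiple)
    return 1.0 - useful / padded


for M in (1, 8, 17, 32, 2048):
    print(f"M={M:5d}: {padding_waste(M, 4096, 4096):.1%} of FLOPs are padding")
```

At M=1, over 90% of the tensor-core work is padding. Because M=1 decode is memory-bound anyway, those wasted FLOPs cost less time than the fraction suggests; the cost is more visible when M sits just above a multiple of 16 (M=17 pads to 32).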

Python Wrapper for Benchmarking

import time
import torch

def benchmark_cublas_gemm(M, N, K, dtype, num_warmup=10, num_iters=100):
    """Benchmark cuBLAS GEMM at different precisions.

    Uses PyTorch ops that call cuBLAS/cuBLASLt internally.
    """
    device = 'cuda'

    if dtype in ('fp16', 'bf16'):
        torch_dtype = torch.float16 if dtype == 'fp16' else torch.bfloat16
        A = torch.randn(M, K, device=device, dtype=torch_dtype)
        B = torch.randn(K, N, device=device, dtype=torch_dtype)
        def gemm():
            return torch.matmul(A, B)
    elif dtype == 'fp8':
        # Needs an FP8-capable GPU (Hopper) and PyTorch >= 2.1.
        # _scaled_mm wants mat1 row-major [M, K] and mat2 column-major [K, N],
        # so B is built as row-major [N, K] and passed transposed.
        A = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
        B = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
        one = torch.tensor(1.0, device=device)  # unit per-tensor scales
        def gemm():
            # PyTorch 2.1-2.3 return an (out, amax) tuple; later versions
            # return a single tensor and require scale_a/scale_b.
            return torch._scaled_mm(A, B.t(), scale_a=one, scale_b=one,
                                    out_dtype=torch.float16)
    elif dtype == 'int8':
        # torch._int_mm has required M > 16 and K, N divisible by 8
        A = torch.randint(-128, 128, (M, K), device=device, dtype=torch.int8)
        B = torch.randint(-128, 128, (K, N), device=device, dtype=torch.int8)
        def gemm():
            return torch._int_mm(A, B)  # INT8 x INT8 -> INT32
    else:
        raise ValueError(f"unknown dtype: {dtype}")

    # Warmup
    for _ in range(num_warmup):
        gemm()

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        gemm()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    # Compute throughput (TOPS for INT8)
    flops = 2 * M * N * K  # multiply-add = 2 ops
    avg_time = elapsed / num_iters
    tflops = flops / avg_time / 1e12

    print(f"{dtype:>5s} [{M}x{K}] x [{K}x{N}]: "
          f"{avg_time*1000:.3f} ms, {tflops:.1f} TFLOPS")
    return avg_time, tflops

def run_gemm_benchmark_suite():
    """Benchmark GEMMs at sizes typical of LLM inference."""
    dtypes = ['fp16', 'bf16', 'int8']
    if torch.cuda.get_device_capability() >= (9, 0):
        dtypes.append('fp8')  # FP8 tensor cores (Hopper and newer)

    cases = [
        ("Prefill GEMMs (M=2048)", 2048),       # large M: tokens in prompt
        ("Decode GEMMs (M=1)", 1),              # single token
        ("Decode GEMMs (M=32, batched)", 32),   # batched decode
    ]
    for label, M in cases:
        print(f"\n=== {label} ===")
        for dtype in dtypes:
            try:
                benchmark_cublas_gemm(M, 4096, 4096, dtype)
            except RuntimeError as e:  # e.g. shape unsupported by _int_mm
                print(f"{dtype:>5s} [{M}x4096] x [4096x4096]: failed ({e})")

if __name__ == "__main__":
    run_gemm_benchmark_suite()

cuBLAS GEMM Throughput by Precision (H100 SXM, M=2048, N=K=4096)

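Precision | Time (ms) | TFLOPS | % of Dense Peak | Speedup vs FP16
FP32 (non-TC) | 4.82 | 14.2 | 21% | 0.08x
TF32 | 0.68 | 100.8 | 20% | 0.57x
FP16 | 0.39 | 175.7 | 18% | 1.00x (baseline)
BF16 | 0.39 | 175.5 | 18% | 1.00x
FP8 (E4M3) | 0.21 | 326.4 | 16% | 1.86x
INT8 | 0.20 | 342.7 | 17% | 1.95x
Note: FP8 and INT8 achieve ~2x speedup over FP16 for this compute-bound prefill GEMM. % of peak is measured against the dense H100 SXM peaks listed earlier (67, 495, 990 and 1,979 TFLOPS/TOPS); every precision reaches a similar fraction, so the relative speedups roughly track the ratio of peaks. For smaller matrices, memory bandwidth limits the achievable fraction of peak further.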
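The TFLOPS, % of peak and speedup columns are all derived from the measured time, so they can be recomputed directly. This sketch does that for the FP16 and INT8 rows; the helper names are illustrative, and the peaks are the dense H100 SXM figures from the throughput table earlier in this post. Small differences from the table come from the rounded times.

```python
def gemm_tflops(M, N, K, time_ms):
    """Achieved TFLOPS (TOPS for integer GEMMs) of one [M, K] x [K, N] GEMM."""
    return 2 * M * N * K / (time_ms * 1e-3) / 1e12


def pct_of_peak(achieved_tflops, peak_tflops):
    return 100.0 * achieved_tflops / peak_tflops


def speedup(baseline_ms, time_ms):
    return baseline_ms / time_ms


M, N, K = 2048, 4096, 4096
fp16 = gemm_tflops(M, N, K, 0.39)   # ~176 TFLOPS
int8 = gemm_tflops(M, N, K, 0.20)   # ~344 TOPS
print(f"FP16: {fp16:.1f} TFLOPS, {pct_of_peak(fp16, 990):.0f}% of 990")
print(f"INT8: {int8:.1f} TOPS, {pct_of_peak(int8, 1979):.0f}% of 1979")
print(f"INT8 speedup: {speedup(0.39, 0.20):.2f}x")
```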

Marlin: Optimized W4A16 Kernels

What Marlin Does

Marlin (Frantar and Alistarh, 2024) is a highly optimized CUDA kernel for W4A16 matrix multiplication: INT4 weights, FP16 activations. For memory-bound shapes it comes close to the ideal ~4x speedup over FP16 that reading 4x less weight data allows, by:

  1. Fusing INT4 dequantization with the FP16 GEMM in a single kernel
  2. Using asynchronous global-to-shared memory copies (cp.async) overlapped with computation
  3. Optimizing the register allocation to minimize register spills
  4. Tiling to maximize tensor core utilization
# Marlin kernel usage in vLLM (Python interface)
# The kernel is called through torch custom ops; the op name and argument
# list depend on the vLLM version.

def marlin_gemm_interface(
    input_tensor,    # [M, K] FP16 activations
    weight_packed,   # [K/16, N*16/8] INT32: INT4 weights in Marlin's tiled, permuted layout
    scales,          # [K/group_size, N] FP16 per-group scales
    workspace,       # Pre-allocated workspace buffer
    output_tensor,   # [M, N] FP16 output (pre-allocated)
    group_size=128,  # Quantization group size
):
    """Call Marlin W4A16 GEMM kernel.

    The kernel:
    1. Loads INT4 weights from global memory (4x less bandwidth than FP16)
    2. Dequantizes to FP16 in registers using the per-group scale
    3. Performs FP16 tensor core GEMM
    4. Writes FP16 output

    Total memory traffic: K*N/2 bytes (INT4 weights)
                        + (K/group_size)*N*2 bytes (FP16 scales)
                        + M*K*2 bytes (FP16 activations)
                        + M*N*2 bytes (FP16 output)
    """
    # Older vLLM releases expose this as:
    #   ops.marlin_gemm(input, weight, scales, workspace, M, N, K)
    # Newer releases route GPTQ/AWQ checkpoints through ops.gptq_marlin_gemm,
    # which takes extra arguments (zero-points, act-order permutation, ...).
    pass
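The four steps in the docstring can be written out as a slow reference in pure Python, which is handy for checking a fast kernel's output on small shapes. This sketch assumes a simple GPTQ-style packing (eight 4-bit codes per 32-bit word along K, low nibble first) and symmetric quantization with an implicit zero-point of 8; Marlin itself stores weights in a tiled, permuted layout, so its buffers cannot be fed to this function directly. All names are illustrative.

```python
def pack_int4(codes):
    """Pack 8 unsigned 4-bit codes (0..15) into one 32-bit word, low nibble first."""
    assert len(codes) == 8
    word = 0
    for i, c in enumerate(codes):
        word |= (c & 0xF) << (4 * i)
    return word


def unpack_int4(word):
    """Inverse of pack_int4."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def w4a16_reference(x, packed, scales, group_size):
    """y[m][n] = sum_k x[m][k] * (code[k][n] - 8) * scales[k // group_size][n]

    x:      M x K activations (Python floats stand in for FP16)
    packed: (K // 8) x N words, 8 codes along K per word
    scales: (K // group_size) x N per-group scales
    """
    K, N = len(packed) * 8, len(packed[0])
    # Steps 1-2: unpack and dequantize (a fused kernel does this per tile, in registers)
    w = [[0.0] * N for _ in range(K)]
    for kb, row in enumerate(packed):
        for n, word in enumerate(row):
            for i, code in enumerate(unpack_int4(word)):
                k = kb * 8 + i
                w[k][n] = (code - 8) * scales[k // group_size][n]
    # Steps 3-4: plain matmul, then return the output
    return [[sum(xm[k] * w[k][n] for k in range(K)) for n in range(N)] for xm in x]


# One output column, K = 8, two groups of 4 with scales 0.5 and 2.0
codes = [8, 9, 10, 11, 7, 6, 5, 4]          # weights 0, 1, 2, 3, -1, -2, -3, -4
y = w4a16_reference([[1.0] * 8], [[pack_int4(codes)]], [[0.5], [2.0]], group_size=4)
print(y)                                     # -> [[-17.0]]
```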

Marlin Performance Characteristics

def marlin_performance_model(M, N, K, group_size=128,
                              hbm_bw_GBs=3350, fp16_tflops=990):
    """Roofline model of Marlin W4A16 vs FP16 GEMM on H100 SXM.

    Memory: Marlin reads K*N/2 bytes of INT4 weights (plus scales) instead
    of K*N*2 bytes of FP16 weights, ~4x less.

    Compute: Marlin dequantizes to FP16 and runs FP16 tensor cores (no INT4
    tensor core), so both kernels share the same compute ceiling. Each
    kernel's time is max(memory time, compute time): small M (decode) is
    memory-bound, large M (prefill) is compute-bound.
    """
    # Weight data
    weight_bytes_fp16 = K * N * 2
    weight_bytes_int4 = K * N // 2
    scale_bytes = (K // group_size) * N * 2

    # Activation reads and output writes (both FP16)
    io_bytes = M * K * 2 + M * N * 2

    # Total memory traffic
    total_fp16 = weight_bytes_fp16 + io_bytes
    total_marlin = weight_bytes_int4 + scale_bytes + io_bytes

    # Same FLOPs for both kernels, on FP16 tensor cores
    compute_ms = 2 * M * N * K / (fp16_tflops * 1e12) * 1000

    time_fp16 = max(total_fp16 / (hbm_bw_GBs * 1e9) * 1000, compute_ms)
    time_marlin = max(total_marlin / (hbm_bw_GBs * 1e9) * 1000, compute_ms)

    speedup = time_fp16 / time_marlin

    print(f"M={M}, N={N}, K={K}")
    print(f"FP16 weight read:   {weight_bytes_fp16/1e6:.1f} MB")
    print(f"Marlin weight read: {(weight_bytes_int4+scale_bytes)/1e6:.1f} MB "
          f"({weight_bytes_fp16/(weight_bytes_int4+scale_bytes):.1f}x less)")
    print(f"Compute floor: {compute_ms:.4f} ms")
    print(f"FP16 time:   {time_fp16:.4f} ms")
    print(f"Marlin time: {time_marlin:.4f} ms")
    print(f"Speedup: {speedup:.2f}x")

    return speedup

# Decode: single token (memory-bound)
marlin_performance_model(1, 4096, 4096)
print()
# Decode: batched (still memory-bound for reasonable batch sizes)
marlin_performance_model(32, 4096, 4096)
print()
# Prefill: compute-bound, so the bandwidth advantage disappears (~1.0x)
marlin_performance_model(2048, 4096, 4096)

Marlin W4A16 vs FP16 Decode Throughput (Llama-2 7B, H100)

Method | Decode (tokens/sec) | vs FP16
FP16 baseline | 145 | 1.0x
Marlin W4A16 | 520 | 3.6x
cuBLAS INT8 | 280 | 1.9x
FP8 | 280 | 1.9x
⚡ Why W4A16 Beats W8A8 for Decode

For single-token decode (batch=1), the operation is purely memory-bandwidth-bound. W4A16 reads 4x less weight data than FP16 (INT4 vs FP16); W8A8 reads 2x less (INT8 vs FP16). Compute is negligible: one multiply-add per weight read. So W4A16 achieves ~3.6x decode speedup versus ~1.9x for W8A8. For prefill (compute-bound), W8A8 is faster because INT8 and FP8 tensor cores provide 2x the compute throughput of FP16, while W4A16 still runs on FP16 tensor cores.
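The decode ceilings behind this argument take one line each: at batch 1, every weight is read once per generated token, so tokens/s cannot exceed bandwidth divided by weight bytes. This sketch uses the 7e9 parameter count used elsewhere in this post and an H100 SXM bandwidth of 3,350 GB/s, counts one FP16 scale per 128 INT4 weights, and ignores KV-cache reads, activations and kernel overhead, so the results are upper bounds rather than predictions.

```python
def decode_tok_s_upper_bound(num_params, bits_per_weight, bw_gbs):
    """Batch-1 decode ceiling: all weights are streamed once per token."""
    bytes_per_token = num_params * bits_per_weight / 8
    return bw_gbs * 1e9 / bytes_per_token


params, bw = 7e9, 3350
formats = {
    "FP16": 16,
    "W8A8 (INT8/FP8)": 8,
    "W4A16, g128": 4 + 16 / 128,   # INT4 + one FP16 scale per 128 weights
}
base = decode_tok_s_upper_bound(params, 16, bw)
for name, bits in formats.items():
    bound = decode_tok_s_upper_bound(params, bits, bw)
    print(f"{name:16s} <= {bound:6.0f} tok/s  ({bound / base:.2f}x FP16)")
```

The ideal ratios come out at 2.00x and about 3.88x; the measured 1.9x and 3.6x fall somewhat short of them, which is consistent with the traffic and overheads this model leaves out.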

ExLlama: Optimized GPTQ Kernels

ExLlama v2 Kernel Architecture

ExLlama (turboderp, 2023) provides highly optimized kernels for GPTQ-quantized models. ExLlama v2 supports mixed precision per layer, flexible group sizes, and optimized dequantization.

# ExLlama v2 GPTQ kernel interface
class ExLlamaV2GemmKernel:
    """Optimized GPTQ GEMM kernel.

    Key optimizations:
    1. Efficient INT4 unpacking in registers
    2. Lookup-table based dequantization (faster than multiply)
    3. Optimized memory access patterns for GPTQ's
       column-wise quantization order
    4. Support for asymmetric quantization (non-zero zero-point)
    """

    def __init__(self, in_features, out_features, group_size=128,
                 bits=4):
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size
        self.bits = bits

        # Packed weights: 8 INT4 values per INT32
        self.qweight_shape = (in_features // (32 // bits), out_features)
        # Per-group scale and zero-point
        self.num_groups = in_features // group_size
        self.scales_shape = (self.num_groups, out_features)
        self.qzeros_shape = (self.num_groups, out_features // (32 // bits))

    def compute_memory_savings(self):
        """Compare memory usage vs FP16."""
        fp16_bytes = self.in_features * self.out_features * 2
        int4_bytes = (
            self.qweight_shape[0] * self.qweight_shape[1] * 4 +  # packed weights
            self.scales_shape[0] * self.scales_shape[1] * 2 +     # scales (FP16)
            self.qzeros_shape[0] * self.qzeros_shape[1] * 4       # zeros (packed)
        )
        ratio = fp16_bytes / int4_bytes
        print(f"FP16: {fp16_bytes/1e6:.1f} MB, "
              f"INT4-GPTQ: {int4_bytes/1e6:.1f} MB, "
              f"Ratio: {ratio:.2f}x")
        return ratio
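The ratio printed by `compute_memory_savings` lands below 4x because every group also stores a scale and a zero-point. Averaging that metadata over the group gives the effective bits per weight; this sketch assumes, as in the class above, one FP16 scale and one 4-bit packed zero-point per group of `group_size` weights.

```python
def effective_bits(bits=4, group_size=128, scale_bits=16, zero_bits=4):
    """Average storage per weight, including per-group scale and zero-point."""
    return bits + (scale_bits + zero_bits) / group_size


for g in (32, 64, 128):
    eb = effective_bits(group_size=g)
    print(f"group_size={g:3d}: {eb:.3f} bits/weight, {16 / eb:.2f}x smaller than FP16")
```

For `group_size=128` this gives 4.156 bits per weight and a 3.85x reduction, the same ratio the class reports for a 4096x4096 layer; smaller groups spend more memory on metadata in exchange for finer-grained scales.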

ExLlama vs Marlin Performance

def compare_w4a16_kernels(time_kernel_ms):
    """Compare W4A16 kernel implementations.

    time_kernel_ms(kernel, M, N, K) -> float is a hypothetical, caller-supplied
    hook that launches the named kernel ('marlin' or 'exllama') on random data
    and returns its time in ms (e.g. torch.cuda.Event timing as in time_gemm).
    """
    # Both kernels do the same operation:
    # FP16 output = FP16_activation @ dequant(INT4_weight)
    #
    # Differences:
    # - Marlin: optimized for contiguous group quantization
    #   (all weights in a group are adjacent in memory)
    # - ExLlama: optimized for GPTQ's column-wise ordering
    #   (handles the specific packing format GPTQ produces)

    configs = [
        # (M, N, K) -- typical LLM dimensions
        (1, 4096, 4096),      # Single-token decode, small model
        (1, 11008, 4096),     # MLP up projection
        (1, 4096, 11008),     # MLP down projection
        (32, 4096, 4096),     # Batched decode
        (2048, 4096, 4096),   # Prefill
    ]

    print(f"{'Config':<25} {'Marlin (ms)':<15} {'ExLlama (ms)':<15} {'Winner'}")
    for M, N, K in configs:
        t_marlin = time_kernel_ms('marlin', M, N, K)
        t_exllama = time_kernel_ms('exllama', M, N, K)
        winner = 'Marlin' if t_marlin < t_exllama else 'ExLlama'
        print(f"{str((M, N, K)):<25} {t_marlin:<15.3f} {t_exllama:<15.3f} {winner}")

W4A16 Kernel Comparison: Marlin vs ExLlama v2 (H100 SXM)

GEMM Shape (M, N, K) | Marlin (ms) | ExLlama v2 (ms) | cuBLAS FP16 (ms) | Fastest
(1, 4096, 4096) | 0.012 | 0.014 | 0.039 | Marlin
(1, 11008, 4096) | 0.028 | 0.032 | 0.098 | Marlin
(1, 4096, 11008) | 0.030 | 0.033 | 0.098 | Marlin
(32, 4096, 4096) | 0.018 | 0.021 | 0.045 | Marlin
(2048, 4096, 4096) | 0.38 | 0.42 | 0.39 | Marlin ≈ FP16 (compute-bound)
Note: Marlin is faster than ExLlama v2 at every shape measured. At large M (prefill), the operation becomes compute-bound and the bandwidth advantage vanishes, so Marlin only matches cuBLAS FP16 and ExLlama v2 falls slightly behind it.

Specialized Kernels for Emerging Formats

AWQ Kernels

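AWQ (Activation-Aware Weight Quantization) produces grouped INT4 weights like GPTQ, but in a different packed format. Its activation-aware per-channel scaling is applied before quantization and folded into the weights, with the inverse usually folded into the preceding operator, so the kernel itself only needs to handle per-group scales and zero-points and AWQ's packing order.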
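The per-channel scaling is an exact identity, which is why it can be folded away before the kernel runs: multiplying input channel k of W by s[k] and dividing activation k by s[k] leaves W x unchanged. AWQ chooses s from activation statistics so that salient channels lose less precision once W * s is quantized; that search is not shown here. This pure-Python sketch only checks the identity, with illustrative names and values.

```python
def matvec(W, x):
    """y[n] = sum_k W[n][k] * x[k]  (W is out_features x in_features)."""
    return [sum(w * xk for w, xk in zip(row, x)) for row in W]


def awq_fold(W, x, s):
    """Scale input channel k of W by s[k] and divide activation k by s[k].

    In AWQ, 1/s is normally folded into the preceding operator (a norm or
    linear layer), so only W * s reaches the quantizer.
    """
    W_scaled = [[w * sk for w, sk in zip(row, s)] for row in W]
    x_scaled = [xk / sk for xk, sk in zip(x, s)]
    return W_scaled, x_scaled


W = [[0.5, -1.0], [2.0, 0.25]]
x = [4.0, 1.0]
s = [2.0, 0.5]
W_s, x_s = awq_fold(W, x, s)
print(matvec(W, x), matvec(W_s, x_s))   # identical before quantization
```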

class AWQKernelInterface:
    """Interface for AWQ W4A16 kernels.

    AWQ quantization format:
    - Weights are INT4 with per-group scale and zero-point
    - The activation-aware per-channel scale is already baked into the
      quantized weights (its inverse lives in the preceding operator),
      so the kernel never applies it explicitly
    - Unlike GPTQ, the 8 INT4 values in each INT32 are packed along the
      output dimension, in an interleaved order

    The kernel dequantizes and multiplies in one fused operation.
    """

    def __init__(self, in_features, out_features, group_size=128):
        self.in_features = in_features
        self.out_features = out_features
        self.group_size = group_size

    def forward(self, input_fp16, qweight_int32, scales_fp16,
                qzeros_int32):
        """
        input_fp16:    [M, in_features] FP16
        qweight_int32: [in_features, out_features // 8] packed INT4
        scales_fp16:   [in_features // group_size, out_features] FP16
        qzeros_int32:  [in_features // group_size, out_features // 8] packed

        Returns: [M, out_features] FP16
        """
        # The kernel unpacks INT4, applies scale and zero-point,
        # then performs FP16 GEMM on tensor cores
        pass

FP8 Scaled GEMM (NVIDIA Transformer Engine)

NVIDIA’s Transformer Engine library provides optimized FP8 GEMMs with per-tensor delayed scaling:

import torch

# Transformer Engine FP8 linear layer (Python interface)
def te_fp8_linear(input_tensor, in_features, out_features):
    """FP8 linear layer via NVIDIA Transformer Engine (Hopper or newer).

    Inside fp8_autocast, Transformer Engine handles:
    1. Per-tensor scaling with delayed scaling
    2. FP8 quantization of inputs and weights
    3. FP8 GEMM on Hopper tensor cores
    4. FP32 accumulation
    5. Output in higher precision (FP16/BF16/FP32)

    Delayed scaling computes the current scale from amax values recorded
    in previous iterations, so the cast does not have to wait for the
    current tensor's amax.
    """
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.DelayedScaling(
        fp8_format=recipe.Format.HYBRID,  # E4M3 forward, E5M2 backward
        amax_history_len=1024,
        amax_compute_algo="max",
    )
    # In a real model the layer is created once, not on every call
    fp8_linear = te.Linear(in_features, out_features, bias=True)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        return fp8_linear(input_tensor)

class FP8ScaleManager:
    """Manage FP8 scale factors with delayed scaling.

    Delayed scaling:
    - Keep a history of observed amax values
    - Compute the current scale from that history (previous iterations)
    - Record the current tensor's amax for later iterations, which keeps
      the amax computation (and, in distributed training, its reduction)
      off the critical path of the cast
    """

    FP8_MAX = 448.0  # E4M3 max finite value

    def __init__(self, amax_history_length=1024):
        self.amax_history = torch.zeros(amax_history_length)
        self.current_idx = 0
        self.scale = torch.tensor(1.0)
        self.scale_inv = torch.tensor(1.0)

    def update_scale(self, amax):
        """Record an observed amax and recompute the scale from the history."""
        self.amax_history[self.current_idx] = amax
        self.current_idx = (self.current_idx + 1) % len(self.amax_history)

        # Use max of recent history
        max_amax = self.amax_history.max().item()

        if max_amax > 0:
            self.scale = torch.tensor(self.FP8_MAX / max_amax)
            self.scale_inv = torch.tensor(max_amax / self.FP8_MAX)
        else:
            self.scale = torch.tensor(1.0)
            self.scale_inv = torch.tensor(1.0)

    def quantize_to_fp8(self, tensor):
        """Quantize with the current (delayed) scale, then record this amax."""
        scaled = tensor * self.scale.to(tensor.device)
        # Clamp to FP8 range before the cast
        clamped = torch.clamp(scaled, -self.FP8_MAX, self.FP8_MAX)
        fp8 = clamped.to(torch.float8_e4m3fn)   # requires PyTorch >= 2.1
        scale_inv = self.scale_inv              # dequant: fp8.float() * scale_inv
        self.update_scale(tensor.abs().max().item())
        return fp8, scale_inv
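The clamp in `quantize_to_fp8` only bounds the range; the cast itself snaps each value onto the coarse E4M3 grid (3 mantissa bits, exponent bias 7). This pure-Python sketch reproduces that rounding for scalars so the effect of a delayed scale can be checked by hand. It assumes round-to-nearest-even and saturates at ±448, mirroring the clamp above; it is an illustration, not PyTorch's implementation.

```python
import math

E4M3_MAX = 448.0               # largest finite E4M3 value
E4M3_MIN_NORMAL = 2.0 ** -6    # smallest normal value (exponent bias 7)
E4M3_SUBNORMAL_STEP = 2.0 ** -9


def round_to_e4m3(x):
    """Round a float to the nearest E4M3 value, saturating at +/-448.

    With 3 mantissa bits, values in [2**e, 2**(e+1)) are spaced 2**(e-3) apart;
    below the smallest normal the spacing is fixed at 2**-9 (subnormals).
    """
    if x == 0.0 or math.isnan(x):
        return x
    sign = -1.0 if x < 0 else 1.0
    a = abs(x)
    if a >= E4M3_MAX:
        return sign * E4M3_MAX
    if a < E4M3_MIN_NORMAL:
        step = E4M3_SUBNORMAL_STEP
    else:
        _, e = math.frexp(a)           # a = m * 2**e with 0.5 <= m < 1
        step = 2.0 ** (e - 1 - 3)      # binade starts at 2**(e-1)
    q = round(a / step) * step         # Python's round() is round-half-to-even
    return sign * min(q, E4M3_MAX)


# Delayed scaling: the scale comes from an earlier amax (3.5 here)
scale = E4M3_MAX / 3.5                 # 128.0
x = 1.7
x_fp8 = round_to_e4m3(x * scale)       # 217.6 rounds to 224 (spacing 16 there)
print(x_fp8, x_fp8 / scale)            # dequantized: 1.75
```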

Benchmarking Framework

Complete Quantized GEMM Benchmark

import torch

def comprehensive_gemm_benchmark(device='cuda'):
    """Benchmark all available quantized GEMM implementations."""
    torch.cuda.empty_cache()

    gpu_name = torch.cuda.get_device_name(0)
    print(f"GPU: {gpu_name}")
    print(f"CUDA: {torch.version.cuda}")
    print()

    # Test configurations representative of LLM inference
    configs = {
        'QKV Proj (7B)':  (1, 12288, 4096),
        'O Proj (7B)':    (1, 4096, 4096),
        'Gate+Up (7B)':   (1, 22016, 4096),
        'Down (7B)':      (1, 4096, 11008),
        'QKV Proj (70B)': (1, 10240, 8192),
        'O Proj (70B)':   (1, 8192, 8192),
        'Batch=32 QKV':   (32, 12288, 4096),
        'Prefill=2K QKV': (2048, 12288, 4096),
    }

    for name, (M, N, K) in configs.items():
        print(f"\n--- {name}: [{M}x{K}] x [{K}x{N}] ---")
        t_fp16 = None

        # FP16 baseline
        try:
            A = torch.randn(M, K, device=device, dtype=torch.float16)
            B = torch.randn(K, N, device=device, dtype=torch.float16)
            t_fp16 = time_gemm(A, B, 'matmul')
            print(f"  FP16:    {t_fp16:.4f} ms")
        except Exception as e:
            print(f"  FP16: FAILED ({e})")

        # INT8 via torch._int_mm (it has required M > 16, so M=1 rows fail)
        try:
            A_i8 = torch.randint(-128, 128, (M, K), device=device,
                                 dtype=torch.int8)
            B_i8 = torch.randint(-128, 128, (K, N), device=device,
                                 dtype=torch.int8)
            t_int8 = time_gemm(A_i8, B_i8, '_int_mm')
            ratio = f" ({t_fp16 / t_int8:.2f}x)" if t_fp16 else ""
            print(f"  INT8:    {t_int8:.4f} ms{ratio}")
        except Exception as e:
            print(f"  INT8: FAILED ({e})")

        # FP8 (if available). _scaled_mm expects a column-major second
        # operand, so build it as [N, K] and pass the [K, N] transpose.
        try:
            A_fp8 = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
            B_fp8 = torch.randn(N, K, device=device).to(torch.float8_e4m3fn)
            t_fp8 = time_gemm(A_fp8, B_fp8.t(), '_scaled_mm')
            ratio = f" ({t_fp16 / t_fp8:.2f}x)" if t_fp16 else ""
            print(f"  FP8:     {t_fp8:.4f} ms{ratio}")
        except Exception as e:
            print(f"  FP8: Not available ({e})")

def time_gemm(A, B, method, num_warmup=50, num_iters=200):
    """Time a single GEMM operation with CUDA events (ms per call)."""
    if method == 'matmul':
        run = lambda: torch.matmul(A, B)
    elif method == '_int_mm':
        run = lambda: torch._int_mm(A, B)
    elif method == '_scaled_mm':
        # Unit per-tensor scales; recent PyTorch requires scale_a/scale_b
        one = torch.tensor(1.0, device=A.device)
        run = lambda: torch._scaled_mm(A, B, scale_a=one, scale_b=one,
                                       out_dtype=torch.float16)
    else:
        raise ValueError(f"unknown method: {method}")

    torch.cuda.synchronize()

    # Warmup
    for _ in range(num_warmup):
        run()

    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(num_iters):
        run()
    end_event.record()

    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event) / num_iters
    return elapsed_ms

End-to-End Model Benchmark

def benchmark_model_quantization_methods(model_name="llama-2-7b",
                                          num_params=7e9, device='cuda'):
    """Summarize candidate quantization methods for end-to-end inference.

    This only prints static properties (weight footprint, kernel). Measuring
    tok/s requires loading each variant in a serving engine and timing decode.
    """
    print(f"Model: {model_name}")
    print(f"Device: {torch.cuda.get_device_name(0)}")

    methods = {
        'FP16': {'weight_bits': 16, 'act_bits': 16, 'kernel': 'cublas'},
        'W8A8 INT8': {'weight_bits': 8, 'act_bits': 8, 'kernel': 'cublas_int8'},
        'W8A8 FP8': {'weight_bits': 8, 'act_bits': 8, 'kernel': 'cublas_fp8'},
        'W4A16 Marlin': {'weight_bits': 4, 'act_bits': 16, 'kernel': 'marlin'},
        'W4A16 ExLlama': {'weight_bits': 4, 'act_bits': 16, 'kernel': 'exllama'},
        'W4A16 AWQ': {'weight_bits': 4, 'act_bits': 16, 'kernel': 'awq'},
    }

    for method_name, config in methods.items():
        # Raw weight bytes; group-quantized formats add a few percent
        # for scales and zero-points
        weight_size = num_params * config['weight_bits'] / 8

        print(f"\n{method_name}:")
        print(f"  Weight memory: {weight_size/1e9:.1f} GB")
        print(f"  Activations: {config['act_bits']}-bit")
        print(f"  Kernel: {config['kernel']}")

End-to-End Inference: Llama-2 7B Decode Throughput (H100 SXM, batch=1)

Method | Kernel | Model Size | tok/s | Speedup
FP16 | cuBLAS | 14.0 GB | 145 | 1.00x
W8A8 INT8 | cuBLAS INT8 | 7.0 GB | 275 | 1.90x
W8A8 FP8 | cuBLAS FP8 | 7.0 GB | 280 | 1.93x
W4A16 GPTQ | ExLlama v2 | 3.5 GB | 480 | 3.31x
W4A16 AWQ | AWQ kernel | 3.5 GB | 490 | 3.38x
W4A16 GPTQ | Marlin | 3.5 GB | 520 | 3.59x
W4A16 AWQ | Marlin | 3.5 GB | 530 | 3.66x
Note: Marlin consistently leads for W4A16 decode. The kernel implementation matters as much as the quantization format: Marlin is 8-10% faster than ExLlama for the same INT4 weights.
ℹ️ Kernel Choice in vLLM

vLLM automatically selects a kernel for each quantization format. For GPTQ models, it prefers Marlin over ExLlama when the GPU supports it (compute capability 8.0 or newer) and the layer dimensions are compatible (for example, multiples of 64). For AWQ models, it uses the Marlin-AWQ kernel variant under the same conditions. In the common case the user does not need to choose the kernel manually: the quantization format and the hardware determine it.

Blackwell FP4: The Next Frontier

FP4 Hardware Support

Blackwell (B200/GB200) introduces native FP4 tensor cores, doubling throughput over FP8:

def fp4_analysis():
    """Analyze FP4 implications for LLM inference."""
    # FP4 E2M1: 1 sign bit, 2 exponent bits, 1 mantissa bit
    # Magnitudes: 0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
    #         (and their negatives)
    # 16 bit patterns, but +0 and -0 coincide: 15 distinct values

    # For a 70B model:
    model_params = 70e9
    fp16_size = model_params * 2 / 1e9
    fp8_size = model_params * 1 / 1e9
    fp4_size = model_params * 0.5 / 1e9

    print(f"70B model sizes:")
    print(f"  FP16: {fp16_size:.0f} GB")
    print(f"  FP8:  {fp8_size:.0f} GB")
    print(f"  FP4:  {fp4_size:.0f} GB")

    # B200 HBM3e bandwidth: 8 TB/s
    b200_bw = 8000  # GB/s
    print(f"\nDecode tokens/sec on B200 ({b200_bw} GB/s):")
    print(f"  FP16: {1000 / (fp16_size / b200_bw * 1000):.0f} tok/s")
    print(f"  FP8:  {1000 / (fp8_size / b200_bw * 1000):.0f} tok/s")
    print(f"  FP4:  {1000 / (fp4_size / b200_bw * 1000):.0f} tok/s")

fp4_analysis()

Projected FP4 Performance on Blackwell B200

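Format | Model Size (70B) | B200 Decode (est.) | Relative Speed
FP16 | 140 GB | ~57 tok/s | 1.0x
FP8 | 70 GB | ~114 tok/s | 2.0x
FP4 (NVFP4) | 35 GB | ~228 tok/s | 4.0x
FP4 + 2:4 Sparsity | 17.5 GB (effective) | ~457 tok/s | 8.0x
Note: Estimates divide 8 TB/s by the weight bytes read per token at batch 1, ignoring KV-cache traffic; sizes exclude block scales (NVFP4 adds 0.5 bit per weight) and sparsity metadata. FP4 + structured sparsity on Blackwell could theoretically deliver 8x the decode throughput of FP16. Achieving this requires FP4 quantization with acceptable quality, which is an active research area (NVIDIA's NVFP4 uses 16-element blocks with FP8 E4M3 block scales plus a per-tensor FP32 scale; the OCP MXFP4 format instead uses 32-element blocks with E8M0 scales).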
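The FP4 grid and NVFP4-style block scaling fit in a few lines. This sketch decodes all 16 E2M1 bit patterns (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit) and quantizes a block with a shared scale `amax / 6`, so the largest element maps to the top grid value. NVFP4 uses 16-element blocks and stores the block scale in FP8 E4M3 (plus a per-tensor FP32 scale); to stay short, the sketch accepts any block length, keeps the scale as a Python float, and breaks ties toward the smaller magnitude, so it approximates rather than reproduces the real format.

```python
def e2m1_decode(code):
    """Decode a 4-bit E2M1 code: 1 sign, 2 exponent (bias 1), 1 mantissa bit."""
    sign = -1.0 if code & 0b1000 else 1.0
    exp = (code >> 1) & 0b11
    man = code & 0b1
    if exp == 0:                          # subnormal: 0 or 0.5
        return sign * man * 0.5
    return sign * (1 + man / 2) * 2.0 ** (exp - 1)


E2M1_GRID = sorted({abs(e2m1_decode(c)) for c in range(16)})
# [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def quantize_block_fp4(block):
    """Quantize one block onto the E2M1 grid with shared scale amax / 6.

    Returns the dequantized values and the scale.
    """
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0
    out = []
    for v in block:
        # nearest grid point; min() keeps the first (smaller) one on ties
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        out.append(mag * scale if v >= 0 else -mag * scale)
    return out, scale


values, s = quantize_block_fp4([0.0, 1.0, -2.9, 6.0])
print(values, s)
```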

Summary

The quantization hardware landscape follows a clear pattern: each GPU generation adds support for lower precision formats, doubling tensor core throughput with each halving of bit width. The practical implication for LLM inference is that the kernel (Marlin, ExLlama, cuBLAS) matters as much as the quantization format. Marlin W4A16 achieves 3.6x decode speedup on H100 by fusing INT4 dequantization with FP16 tensor core GEMM in a single kernel optimized for memory bandwidth. cuBLAS INT8 and FP8 achieve ~2x speedup for both prefill (compute throughput) and decode (bandwidth). Blackwell's FP4 extends this further: 4x less weight traffic than FP16 and 2x the tensor core throughput of FP8.

Choosing the right quantization-kernel combination depends on the workload: W4A16 with Marlin for bandwidth-bound decode, W8A8 FP8 for compute-bound prefill, and potentially different formats for different phases of the same serving pipeline.