Part 14 of 30 in the Quantization Masterclass series:

1. Number Formats for AI: FP32, BF16, FP16, FP8 E4M3, FP8 E5M2, NVFP4, MXFP4, INT8, INT4
2. Weight Quantization: GPTQ, AWQ, and Round-To-Nearest — Algorithms and Implementation
3. Activation Quantization: SmoothQuant, Per-Tensor Scaling, and W8A8 Inference
4. FP8 for Training and Inference: E4M3, E5M2, Transformer Engine, and Delayed Scaling
5. FP4 and MXFP4: The Blackwell Frontier — Sub-Byte Quantization for Next-Gen Inference
6. KV Cache Quantization: FP8, INT8, INT4, Per-Token Scaling, and the Quality-Memory Tradeoff
7. Quantization-Aware Training: Fake Quantization, Straight-Through Estimator, and QAT vs PTQ
8. Mixed Precision Inference: Which Ops Use Which Precision and Why
9. Calibration for Post-Training Quantization: MinMax, Percentile, MSE-Optimal, and Cross-Layer
10. Quantization Hardware Support: Tensor Core Precision Matrix, cuBLAS INT8, and Marlin Kernels
11. Per-Channel vs Per-Group vs Per-Tensor Scaling: Granularity Tradeoffs in Weight Quantization
12. The Outlier Channel Problem: Why LLM Activations Break Simple Quantization
13. W4A16 Inference: 4-Bit Weights with FP16 Activations and the Marlin Kernel
14. W8A8 INT8 Inference: cuBLAS INT8 GEMM, Per-Tensor Scaling, and When INT8 Beats FP8
15. GGUF Quantization Types: Q4_K_M, Q5_K_M, Q8_0 — How llama.cpp Quantizes for CPU
16. AWQ Deep Dive: Activation-Aware Weight Quantization — The Algorithm Step by Step
17. GPTQ Deep Dive: Hessian-Based One-Shot Quantization — OBS, Column-Wise Updates, and Lazy Batch
18. SqueezeLLM and Non-Uniform Quantization: Lookup Tables, Sparse Outliers, and Mixed Strategies
19. Quantization for Training: FP8 GEMM, Loss Scaling, and Why BF16 Remains the Default
20. Quantization Production Guide: Choosing the Right Method for Your Model, Hardware, and Latency SLO
21. Combining Sparsity and Quantization: 2:4 Structured Sparsity with INT8 for Maximum Throughput
22. Dynamic vs Static Quantization: Online Calibration, Offline Calibration, and When Each Wins
23. AQLM and Extreme Compression: 2-Bit Quantization with Additive Codebooks
24. Quantized Draft Models for Speculative Decoding: INT4 Drafters with FP16 Verification
25. Quantization Benchmarking: How to Properly Measure Quality Loss, Throughput, and Cost Impact
26. INT4 Weight Packing: Bit Manipulation, Dequantization Kernels, and Memory Layout
27. Serving Quantized Models: vLLM, TRT-LLM, and llama.cpp Integration
28. Debugging Quantization: Layer Sensitivity, Outlier Detection, and Quality Recovery
29. Future of Quantization: Sub-4-Bit, Ternary, and Binary Neural Networks
30. End-to-End Quantization Pipeline: From FP16 Checkpoint to Production INT4 Deployment

INT8 tensor cores on Ampere deliver 2x the throughput of FP16 tensor cores, but only if both matrix operands are quantized to INT8 simultaneously. That is the constraint that makes W8A8 (INT8 weights, INT8 activations) fundamentally harder than W4A16 (INT4 weights, FP16 activations). Weights are static: quantize once, serve forever. Activations are generated fresh for every token, and they contain outlier channels with magnitudes 100x larger than the median, making per-tensor INT8 activation quantization catastrophically lossy without SmoothQuant's channel migration trick. Get the scaling wrong and you lose 3+ perplexity points. Get it right and you get a true 2x tensor core speedup with near-lossless quality.

This post covers the complete W8A8 pipeline: offline weight quantization, online activation quantization with per-token scaling, the cuBLAS cublasLtMatmul INT8 API, the INT32 accumulation and FP16 dequantization math, SmoothQuant integration, and benchmarks comparing INT8 to FP8 on H100.

The INT8 GEMM Pipeline

The W8A8 GEMM computes Y = XW^T, where X (activations) and W (weights) are both quantized to INT8:

Y_{\text{fp}} = (s_x \cdot X_{\text{int8}}) \cdot (s_w \cdot W_{\text{int8}})^T = s_x \cdot s_w^T \cdot (X_{\text{int8}} \cdot W_{\text{int8}}^T)

The INT8 matmul produces INT32 accumulations, which are then dequantized to FP16/BF16 using the scale factors:

Input:  X_int8 (M x K), W_int8 (N x K)
Step 1: Y_int32 = X_int8 @ W_int8^T          [INT8 tensor cores, INT32 accumulation]
Step 2: Y_fp16 = Y_int32 * (s_x * s_w^T)     [Dequantize with scale factors]
Output: Y_fp16 (M x N)

The critical insight is that the scale factors are factored out of the matmul. The tensor cores operate on pure INT8 values with INT32 accumulation. Dequantization happens after the matmul.

import torch
import numpy as np

def w8a8_gemm_reference(X_fp, W_fp, x_scale_type='per_token', w_scale_type='per_channel'):
    """Reference W8A8 INT8 GEMM implementation.

    Args:
        X_fp: FP32 activations, shape (M, K)
        W_fp: FP32 weights, shape (N, K)
        x_scale_type: 'per_tensor' or 'per_token'
        w_scale_type: 'per_tensor' or 'per_channel'

    Returns:
        Y_fp: dequantized output, shape (M, N)
    """
    M, K = X_fp.shape
    N = W_fp.shape[0]
    qmax = 127

    # Step 1: Quantize activations
    if x_scale_type == 'per_tensor':
        x_scale = X_fp.abs().max() / qmax
        x_scale = x_scale.clamp(min=1e-10)
        X_q = torch.clamp(torch.round(X_fp / x_scale), -128, 127).to(torch.int8)
        # x_scale shape: scalar
    elif x_scale_type == 'per_token':
        x_scale = X_fp.abs().amax(dim=1, keepdim=True) / qmax
        x_scale = x_scale.clamp(min=1e-10)
        X_q = torch.clamp(torch.round(X_fp / x_scale), -128, 127).to(torch.int8)
        # x_scale shape: (M, 1)

    # Step 2: Quantize weights (offline)
    if w_scale_type == 'per_tensor':
        w_scale = W_fp.abs().max() / qmax
        w_scale = w_scale.clamp(min=1e-10)
        W_q = torch.clamp(torch.round(W_fp / w_scale), -128, 127).to(torch.int8)
        # w_scale shape: scalar
    elif w_scale_type == 'per_channel':
        w_scale = W_fp.abs().amax(dim=1, keepdim=True) / qmax
        w_scale = w_scale.clamp(min=1e-10)
        W_q = torch.clamp(torch.round(W_fp / w_scale), -128, 127).to(torch.int8)
        # w_scale shape: (N, 1)

    # Step 3: INT8 GEMM with INT32 accumulation
    Y_int32 = X_q.int() @ W_q.int().T  # (M, N)

    # Step 4: Dequantize
    # Y_fp = Y_int32 * (x_scale * w_scale^T)
    if x_scale_type == 'per_tensor' and w_scale_type == 'per_tensor':
        Y_fp = Y_int32.float() * (x_scale * w_scale)
    elif x_scale_type == 'per_token' and w_scale_type == 'per_channel':
        # x_scale: (M, 1), w_scale: (N, 1) -> outer product: (M, N)
        Y_fp = Y_int32.float() * (x_scale * w_scale.T)
    elif x_scale_type == 'per_token' and w_scale_type == 'per_tensor':
        Y_fp = Y_int32.float() * (x_scale * w_scale)
    elif x_scale_type == 'per_tensor' and w_scale_type == 'per_channel':
        Y_fp = Y_int32.float() * (x_scale * w_scale.T)

    return Y_fp

# Test
torch.manual_seed(42)
M, N, K = 32, 4096, 4096
X = torch.randn(M, K) * 0.5
W = torch.randn(N, K) * 0.02

Y_ref = X @ W.T
Y_int8 = w8a8_gemm_reference(X, W, 'per_token', 'per_channel')

mse = ((Y_ref - Y_int8) ** 2).mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
    Y_ref.flatten(), Y_int8.flatten(), dim=0
).item()
print(f"MSE: {mse:.6e}, Cosine sim: {cos_sim:.8f}")

Scale Factor Compatibility with INT8 GEMM

The scale factor strategy must be compatible with INT8 tensor core execution. The constraint is that the scale factor must be factorable out of the inner product:

Y_{ij} = \sum_k X_{ik} W_{jk} = \sum_k (s_x^{(i)} q_x^{(ik)}) (s_w^{(j)} q_w^{(jk)}) = s_x^{(i)} s_w^{(j)} \sum_k q_x^{(ik)} q_w^{(jk)}

This factoring works when:

  • Per-tensor scaling: s_x and s_w are scalars, so they trivially factor out
  • Per-token x Per-channel: s_x^{(i)} depends only on row i, s_w^{(j)} depends only on row j, so they factor out as an outer product
  • Per-channel x Per-channel: s_x^{(k)} and s_w^{(k)} both depend on the inner dimension k and DO NOT factor out

def verify_scale_compatibility(x_scale_shape, w_scale_shape, M, N, K):
    """Check if scale factors are compatible with INT8 GEMM.

    Compatible means the scales can be factored out of the inner sum.
    """
    x_dims = set()
    w_dims = set()

    if x_scale_shape == 'scalar':
        pass  # No dimension dependency
    elif x_scale_shape == 'per_token':
        x_dims.add('M')  # Depends on row index
    elif x_scale_shape == 'per_channel':
        x_dims.add('K')  # Depends on inner dimension

    if w_scale_shape == 'scalar':
        pass
    elif w_scale_shape == 'per_channel':
        w_dims.add('N')  # Depends on output channel
    elif w_scale_shape == 'per_input_channel':
        w_dims.add('K')  # Depends on inner dimension

    # Compatible only if neither scale depends on the inner dimension K
    k_dep = 'K' in x_dims or 'K' in w_dims
    compatible = not k_dep

    return compatible

# Check all combinations
combinations = [
    ('scalar', 'scalar', True),
    ('scalar', 'per_channel', True),
    ('per_token', 'scalar', True),
    ('per_token', 'per_channel', True),    # Standard W8A8
    ('per_channel', 'per_channel', False),  # Incompatible!
    ('per_channel', 'per_input_channel', False),
]

for x_s, w_s, expected in combinations:
    result = verify_scale_compatibility(x_s, w_s, 32, 4096, 4096)
    status = "OK" if result == expected else "MISMATCH"
    print(f"  X={x_s:>15s}, W={w_s:>20s}: "
          f"compatible={result} [{status}]")

ℹ️ The Per-Token x Per-Channel Convention

The standard W8A8 configuration uses per-token activation scaling (one scale per token/row of X) and per-channel weight scaling (one scale per output channel/row of W). This gives both operands good quantization quality while remaining compatible with INT8 tensor core GEMM. The dequantization is an outer product of the two scale vectors, applied element-wise to the INT32 output.
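A quick numerical check of this convention (an illustrative sketch with random operands, not the production path): dequantizing the INT32 product with the outer product of the two scale vectors gives exactly the same result as scaling each operand before the matmul.

```python
import torch

torch.manual_seed(0)
M, N, K = 4, 8, 16
q_x = torch.randint(-128, 128, (M, K), dtype=torch.int8)   # quantized activations
q_w = torch.randint(-128, 128, (N, K), dtype=torch.int8)   # quantized weights
s_x = torch.rand(M, 1) + 0.01                              # per-token scales
s_w = torch.rand(N, 1) + 0.01                              # per-channel scales

# Path 1: dequantize each operand first, then matmul in float
y_slow = (s_x * q_x.float()) @ (s_w * q_w.float()).T

# Path 2: integer matmul, then outer-product dequantization
y_fast = (q_x.int() @ q_w.int().T).float() * (s_x * s_w.T)

assert torch.allclose(y_slow, y_fast, rtol=1e-4, atol=0.1)
```

The two paths agree up to float32 rounding, which is exactly why the scales can be deferred past the tensor-core matmul.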

cuBLAS INT8 GEMM API

NVIDIA's cuBLAS library provides INT8 GEMM through the cublasLtMatmul API. The setup is more involved than FP16 GEMM because of the mixed-type accumulation and scaling:

#include <cublasLt.h>

// cuBLAS INT8 GEMM setup.
// Row-major Y (M x N) = X (M x K) @ W^T, with W stored row-major (N x K).
// cuBLAS is column-major, so we compute the equivalent column-major product
// Y^T (N x M) = W_cm^T (N x K) @ X_cm (K x M), where a row-major (R x C)
// buffer is viewed as a column-major (C x R) matrix with leading dim C.
void setup_int8_gemm(
    cublasLtHandle_t handle,
    int M, int N, int K,
    const int8_t* A,      // Activations, row-major (M x K)
    const int8_t* B,      // Weights, row-major (N x K)
    float* C,             // Output, row-major (M x N)
    float alpha,          // Global scale factor
    float beta            // Output accumulation factor
) {
    cublasLtMatmulDesc_t matmulDesc;
    cublasLtMatmulDescCreate(&matmulDesc, CUBLAS_COMPUTE_32I, CUDA_R_32F);

    // First operand (weights) is transposed; second (activations) is not
    cublasOperation_t transA = CUBLAS_OP_T;
    cublasOperation_t transB = CUBLAS_OP_N;
    cublasLtMatmulDescSetAttribute(
        matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA,
        &transA, sizeof(transA)
    );
    cublasLtMatmulDescSetAttribute(
        matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB,
        &transB, sizeof(transB)
    );

    // Create matrix layouts (column-major views of the row-major buffers)
    cublasLtMatrixLayout_t layoutW, layoutX, layoutY;

    // W: INT8, row-major (N x K) -> column-major (K x N), ld = K
    cublasLtMatrixLayoutCreate(&layoutW, CUDA_R_8I, K, N, K);

    // X: INT8, row-major (M x K) -> column-major (K x M), ld = K
    cublasLtMatrixLayoutCreate(&layoutX, CUDA_R_8I, K, M, K);

    // Y: FP32, row-major (M x N) -> column-major (N x M), ld = N
    cublasLtMatrixLayoutCreate(&layoutY, CUDA_R_32F, N, M, N);

    // Execute INT8 GEMM
    // Y = alpha * (X_int8 @ W_int8^T) + beta * C
    // Internally: INT8 x INT8 -> INT32 accumulation -> FP32 output
    cublasLtMatmul(
        handle, matmulDesc,
        &alpha, B, layoutW,   // weights as the (transposed) first operand
        A, layoutX,           // activations as the second operand
        &beta, C, layoutY,
        C, layoutY,
        NULL, NULL, 0, 0      // Algo, workspace, workspace size, stream
    );

    // Cleanup
    cublasLtMatmulDescDestroy(matmulDesc);
    cublasLtMatrixLayoutDestroy(layoutW);
    cublasLtMatrixLayoutDestroy(layoutX);
    cublasLtMatrixLayoutDestroy(layoutY);
}

Per-Token x Per-Channel Dequantization After cuBLAS

cuBLAS INT8 GEMM produces a single output with a global alpha scale. For per-token x per-channel scaling, we need a post-GEMM dequantization kernel:

// Post-GEMM dequantization kernel
// Y_fp = Y_int32 * x_scales[row] * w_scales[col]
__global__ void dequantize_int32_per_token_per_channel(
    const int32_t* __restrict__ Y_int32,  // (M, N)
    const float* __restrict__ x_scales,   // (M,)
    const float* __restrict__ w_scales,   // (N,)
    half* __restrict__ Y_fp16,            // (M, N)
    int M, int N
) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M && col < N) {
        float val = (float)Y_int32[row * N + col];
        val *= x_scales[row] * w_scales[col];
        Y_fp16[row * N + col] = __float2half(val);
    }
}

In practice, this dequantization is fused into the GEMM epilogue using cuBLAS epilogue functions or custom kernels.

Dynamic Activation Quantization

Weights are quantized offline, but activations must be quantized at runtime because their distribution depends on the input. The quantization kernel runs before each GEMM:

def dynamic_quantize_per_token(X_fp):
    """Quantize activations to INT8 per-token at runtime.

    This runs on every forward pass. Must be fast.

    Args:
        X_fp: FP16/BF16 activations, shape (M, K)

    Returns:
        X_int8: quantized activations, shape (M, K)
        scales: per-token scale factors, shape (M, 1)
    """
    # Find per-token maximum
    abs_max = X_fp.abs().amax(dim=-1, keepdim=True)  # (M, 1)

    # Compute scale
    scales = abs_max / 127.0
    scales = scales.clamp(min=1e-10)

    # Quantize
    X_int8 = (X_fp / scales).round().clamp(-128, 127).to(torch.int8)

    return X_int8, scales

The runtime overhead of dynamic quantization is the cost of computing the per-token max (a reduction) plus the division and round. On GPU, this takes approximately 10-15 microseconds for a typical (32, 4096) activation tensor, negligible compared to the GEMM itself.

def measure_quantization_overhead(M, K, num_iterations=1000):
    """Measure dynamic quantization kernel time."""
    X = torch.randn(M, K, dtype=torch.float16, device='cuda')

    # Warmup
    for _ in range(10):
        dynamic_quantize_per_token(X)
    torch.cuda.synchronize()

    import time
    start = time.perf_counter()
    for _ in range(num_iterations):
        dynamic_quantize_per_token(X)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / num_iterations

    return elapsed * 1e6  # microseconds

# Expected: ~12us for (32, 4096), ~18us for (128, 4096)

INT8 vs FP8: When INT8 Wins

H100 provides both INT8 and FP8 tensor core support at the same peak dense rate: 1979 TOPS for INT8 and 1979 TFLOPS for FP8. The question is when to prefer INT8 over FP8.

Precision Comparison

def compare_int8_fp8_precision():
    """Compare INT8 and FP8 E4M3 representable values."""

    # INT8: 256 uniform levels in [-128, 127]
    int8_levels = list(range(-128, 128))

    # FP8 E4M3: non-uniform levels in [-448, 448]
    # 4-bit exponent (bias 7), 3-bit mantissa; exp=15, mant=7 encodes NaN
    fp8_levels = []
    for sign in [1, -1]:
        for exp in range(16):  # 4-bit exponent
            for mant in range(8):  # 3-bit mantissa
                if exp == 0:
                    val = sign * (mant / 8) * (2 ** (-6))
                elif exp == 15 and mant == 7:
                    continue  # NaN
                elif exp == 15:
                    val = sign * (1 + mant / 8) * (2 ** 8)
                else:
                    val = sign * (1 + mant / 8) * (2 ** (exp - 7))
                fp8_levels.append(val)

    return {
        'int8_num_levels': len(int8_levels),
        'int8_range': (min(int8_levels), max(int8_levels)),
        'fp8_num_levels': len(set(fp8_levels)),
        'fp8_range': (min(fp8_levels), max(fp8_levels)),
    }
📊

INT8 vs FP8 E4M3 Precision Characteristics

Property               INT8                  FP8 E4M3
Total levels           256                   253 distinct finite values
Range                  [-128, 127]           [-448, 448]
Dynamic range          127:1                 ~229,000:1 (with subnormals)
Precision near 1.0     1.0 (uniform step)    0.125 (2^-3 mantissa step)
Precision near 0.01    1.0 (same step)       ~0.001 (finer near 0)
Uniform spacing        Yes                   No (logarithmic)
H100 throughput        1979 TOPS             1979 TFLOPS
Note: INT8 has uniform spacing (equal precision everywhere), FP8 has non-uniform spacing (finer near 0, coarser at extremes). INT8 is better for distributions concentrated in a narrow range after scaling. FP8 is better for wide dynamic range.
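To make the spacing tradeoff concrete, here is a small experiment (an illustrative sketch on synthetic data, reusing the simplified E4M3 grid from compare_int8_fp8_precision above rather than a bit-exact FP8 implementation): quantize a narrow Gaussian and a heavy-tailed tensor with absmax scaling and compare reconstruction error.

```python
import numpy as np

def e4m3_levels():
    # Simplified E4M3 grid: 4-bit exponent (bias 7), 3-bit mantissa
    vals = []
    for sign in (1, -1):
        for exp in range(16):
            for mant in range(8):
                if exp == 15 and mant == 7:
                    continue  # NaN encoding
                if exp == 0:
                    v = sign * (mant / 8) * 2.0 ** -6          # subnormals
                elif exp == 15:
                    v = sign * (1 + mant / 8) * 2.0 ** 8       # top binade
                else:
                    v = sign * (1 + mant / 8) * 2.0 ** (exp - 7)
                vals.append(v)
    return np.array(sorted(set(vals)))

def quantize_to_levels(x, levels):
    # Round each element to the nearest representable level
    idx = np.searchsorted(levels, x).clip(1, len(levels) - 1)
    lo, hi = levels[idx - 1], levels[idx]
    return np.where(x - lo < hi - x, lo, hi)

def int8_quantize(x):
    # Symmetric absmax INT8 quantize-dequantize
    s = np.abs(x).max() / 127.0
    return np.clip(np.round(x / s), -128, 127) * s

rng = np.random.default_rng(0)
levels = e4m3_levels()
results = {}
for name, x in [("narrow gaussian", rng.normal(0, 1, 100_000)),
                ("heavy-tailed", rng.standard_t(2, 100_000))]:
    absmax = np.abs(x).max()
    x_int8 = int8_quantize(x)
    x_fp8 = quantize_to_levels(x / absmax * 448.0, levels) / 448.0 * absmax
    results[name] = (((x - x_int8) ** 2).mean(), ((x - x_fp8) ** 2).mean())
    print(f"{name:>16s}: INT8 MSE={results[name][0]:.3e}, "
          f"FP8 MSE={results[name][1]:.3e}")
```

On the narrow distribution, INT8's 256 uniform levels come out ahead. Once heavy tails stretch the absmax, the uniform grid wastes levels on empty range and FP8's logarithmic spacing wins, which is exactly the situation SmoothQuant repairs.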

When INT8 Beats FP8

  1. After SmoothQuant: Once outlier channels are smoothed, activation distributions are concentrated in a narrow range. INT8's uniform spacing uses all 256 levels efficiently.

  2. On Ampere (A100): A100 has INT8 tensor cores but no FP8 support. INT8 is the only sub-FP16 compute option.

  3. Larger models with less sensitivity: Larger models (70B+) tend to be less sensitive to quantization, and INT8 provides sufficient precision.

  4. When per-token scaling is sufficient: If the per-token activation range (after SmoothQuant) is narrow enough that 256 uniform levels suffice, INT8 avoids the complexity of FP8 calibration.

def should_use_int8(
    gpu_generation,
    model_size_B,
    has_smoothquant,
    activation_outlier_ratio,
):
    """Decide between INT8 and FP8 for W8A8 inference."""

    if gpu_generation == 'ampere':
        return True  # No FP8 tensor cores on A100

    if not has_smoothquant and activation_outlier_ratio > 20:
        return False  # FP8's wider range handles outliers better

    if has_smoothquant:
        # After SmoothQuant, INT8 and FP8 give similar quality
        # INT8 has simpler calibration (no E4M3/E5M2 format selection)
        return True

    # Default: FP8 on Hopper for simplicity
    return False

W8A8 Perplexity: INT8 vs FP8 on Llama-2 7B (WikiText-2)

Configuration                 Perplexity
FP16 baseline                 5.47
FP8 per-tensor                5.56
INT8 per-tensor               6.81  (outlier damage)
INT8 per-token/per-channel    5.73
FP8 + SmoothQuant             5.49
INT8 + SmoothQuant            5.52  (nearly equal)

💡 After SmoothQuant, INT8 and FP8 Give Nearly Identical Results

Without SmoothQuant, FP8 is significantly better than INT8 for activations because its wider dynamic range accommodates outlier channels (5.56 vs 6.81 ppl). After SmoothQuant, the outliers are eliminated and INT8 matches FP8 quality (5.52 vs 5.49 ppl). If you use SmoothQuant, the choice between INT8 and FP8 comes down to hardware support and tooling, not quality.
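The mechanism is easy to reproduce in isolation. A minimal sketch (synthetic outlier channels, alpha = 0.5 as above): dividing activations by a per-channel factor s and folding s into the weights leaves X @ W^T mathematically unchanged while collapsing the activation outliers that per-token INT8 must cover.

```python
import torch

torch.manual_seed(0)
M, K, N = 16, 512, 256
X = torch.randn(M, K)
X[:, :4] *= 100.0                 # a few outlier channels, as in real LLMs
W = torch.randn(N, K) * 0.02
alpha = 0.5

# SmoothQuant migration: s_k = max|X_k|^alpha / max|W_k|^(1-alpha)
act_max = X.abs().amax(dim=0)
w_max = W.abs().amax(dim=0)
s = (act_max.pow(alpha) / w_max.clamp(min=1e-5).pow(1 - alpha)).clamp(min=1e-5)

X_s, W_s = X / s, W * s           # migrate quantization difficulty to weights

# The product is unchanged...
assert torch.allclose(X @ W.T, X_s @ W_s.T, rtol=1e-4, atol=1e-3)

# ...but the activation outlier ratio (max channel / median channel) collapses
def outlier_ratio(t):
    ch_max = t.abs().amax(dim=0)
    return (ch_max.max() / ch_max.median()).item()

print(f"outlier ratio before: {outlier_ratio(X):.1f}, "
      f"after: {outlier_ratio(X_s):.1f}")
```

After migration the worst channel is only modestly larger than the typical one, so 256 uniform INT8 levels cover the whole tensor without starving the bulk of the distribution.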

Complete W8A8 INT8 Linear Layer

import torch
import torch.nn as nn

class W8A8Int8Linear(nn.Module):
    """W8A8 INT8 quantized linear layer for inference.

    Weights: INT8 per-channel quantized (offline)
    Activations: INT8 per-token quantized (dynamic)
    GEMM: INT8 tensor cores with INT32 accumulation
    Output: dequantized to FP16
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            'weight_int8',
            torch.zeros(out_features, in_features, dtype=torch.int8)
        )
        self.register_buffer(
            'weight_scale',
            torch.zeros(out_features, 1, dtype=torch.float32)
        )

        if bias:
            self.register_buffer(
                'bias', torch.zeros(out_features, dtype=torch.float16)
            )
        else:
            self.bias = None

    @classmethod
    def from_float(cls, linear, smooth_scales=None):
        """Quantize a float linear layer to W8A8 INT8.

        Args:
            linear: FP16/FP32 nn.Linear
            smooth_scales: optional SmoothQuant scales, shape (in_features,)
        """
        in_f = linear.in_features
        out_f = linear.out_features

        layer = cls(in_f, out_f, bias=linear.bias is not None)

        W = linear.weight.data.float()  # (out_f, in_f)

        # Apply SmoothQuant scaling to weights (if provided)
        if smooth_scales is not None:
            W = W * smooth_scales.unsqueeze(0)

        # Per-channel quantization
        w_max = W.abs().amax(dim=1, keepdim=True)  # (out_f, 1)
        w_scale = w_max / 127.0
        w_scale = w_scale.clamp(min=1e-10)

        W_q = (W / w_scale).round().clamp(-128, 127).to(torch.int8)

        layer.weight_int8.copy_(W_q)
        layer.weight_scale.copy_(w_scale)

        if linear.bias is not None:
            layer.bias.copy_(linear.bias.data.to(torch.float16))

        return layer

    def forward(self, x):
        """Forward pass with dynamic activation quantization.

        Args:
            x: FP16 activations, shape (*, in_features)

        SmoothQuant activation scaling, if used, is assumed to be fused
        into the preceding LayerNorm, so it does not appear here.
        """
        original_shape = x.shape
        x = x.reshape(-1, self.in_features).float()

        # Dynamic per-token quantization of activations
        x_max = x.abs().amax(dim=1, keepdim=True)
        x_scale = x_max / 127.0
        x_scale = x_scale.clamp(min=1e-10)
        x_int8 = (x / x_scale).round().clamp(-128, 127).to(torch.int8)

        # INT8 GEMM (simulated -- real impl uses cuBLAS INT8)
        # Y_int32 = X_int8 @ W_int8^T
        y_int32 = x_int8.int() @ self.weight_int8.int().T

        # Dequantize: Y_fp = Y_int32 * (x_scale * w_scale^T)
        y_fp = y_int32.float() * (x_scale * self.weight_scale.T)

        if self.bias is not None:
            y_fp = y_fp + self.bias.float()

        y_fp = y_fp.half()
        return y_fp.reshape(*original_shape[:-1], self.out_features)

# Verify
torch.manual_seed(42)
linear = nn.Linear(4096, 4096, bias=False)
nn.init.normal_(linear.weight, std=0.02)

int8_layer = W8A8Int8Linear.from_float(linear)

x = torch.randn(1, 32, 4096)
with torch.no_grad():
    y_ref = linear(x)
    y_int8 = int8_layer(x).float()

mse = ((y_ref - y_int8) ** 2).mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
    y_ref.flatten(), y_int8.flatten(), dim=0
).item()

print(f"W8A8 INT8 MSE: {mse:.6e}")
print(f"Cosine similarity: {cos_sim:.8f}")

INT32 Accumulation Overflow Analysis

INT8 tensor cores accumulate into INT32. The maximum possible accumulation value depends on the matrix dimensions:

\max(|Y_{\text{int32}}|) = K \times 128^2 = 16384 \times K

For K = 4096:

\max(|Y_{\text{int32}}|) = 4096 \times 16384 = 67{,}108{,}864

INT32 range is [-2,147,483,648, 2,147,483,647]. The worst-case accumulation uses only 67M / 2.1B ≈ 3.1% of the INT32 range. Overflow is not a practical concern for typical LLM dimensions.

def check_int32_overflow(K, abs_max_per_element=128):
    """Check if INT32 accumulation can overflow for given K."""
    max_accumulation = K * abs_max_per_element * abs_max_per_element
    int32_max = 2 ** 31 - 1
    overflow_risk = max_accumulation > int32_max
    utilization = max_accumulation / int32_max * 100

    return {
        'max_accumulation': max_accumulation,
        'int32_max': int32_max,
        'overflow_risk': overflow_risk,
        'utilization_pct': utilization,
    }

for K in [4096, 8192, 16384, 32768, 131072]:
    result = check_int32_overflow(K)
    risk = "OVERFLOW!" if result['overflow_risk'] else "safe"
    print(f"  K={K:>6d}: max_accum={result['max_accumulation']:>15,}, "
          f"utilization={result['utilization_pct']:.1f}% [{risk}]")
  K=  4096: max_accum=     67,108,864, utilization=3.1% [safe]
  K=  8192: max_accum=    134,217,728, utilization=6.2% [safe]
  K= 16384: max_accum=    268,435,456, utilization=12.5% [safe]
  K= 32768: max_accum=    536,870,912, utilization=25.0% [safe]
  K=131072: max_accum=  2,147,483,648, utilization=100.0% [OVERFLOW!]

⚠️ INT32 Overflow at K=131072

For K dimensions of 131,072 and above, worst-case INT32 accumulation can overflow. This is not a concern for typical LLM hidden dimensions (4096-16384) but matters for very long sequence attention computations where K = seq_len. In such cases, the GEMM must be tiled along K, with partial sums combined in a wider accumulator (e.g., FP32 or INT64).
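A sketch of that tiling (illustrative PyTorch, not a kernel; INT64 partials keep the check exact here, where a real kernel might combine FP32 partials instead): each K-chunk is small enough that its INT32 partial sum cannot wrap, and the chunks are combined in a wider type.

```python
import torch

def tiled_int8_gemm(X_q, W_q, k_tile=65536):
    """INT8 GEMM with K tiled so each partial INT32 sum stays in range.

    Each tile contributes at most k_tile * 128 * 128 ~= 2^30 in magnitude,
    safely inside INT32; the partials are combined in INT64.
    """
    M, K = X_q.shape
    acc = torch.zeros(M, W_q.shape[0], dtype=torch.int64)
    for k0 in range(0, K, k_tile):
        x = X_q[:, k0:k0 + k_tile].int()
        w = W_q[:, k0:k0 + k_tile].int()
        acc += (x @ w.T).long()   # per-tile matmul is INT32-safe
    return acc

# Adversarial worst case: all elements at the extreme, K past the overflow point
K = 1 << 18                       # 262,144 > 131,072
X_q = torch.full((2, K), 127, dtype=torch.int8)
W_q = torch.full((3, K), 127, dtype=torch.int8)

y = tiled_int8_gemm(X_q, W_q)
assert (y == K * 127 * 127).all()  # exact: no INT32 wraparound occurred
```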

SmoothQuant + INT8 End-to-End

The full W8A8 INT8 pipeline with SmoothQuant:

class SmoothQuantInt8Model:
    """Apply SmoothQuant and quantize a model to W8A8 INT8."""

    def __init__(self, model, alpha=0.5):
        self.model = model
        self.alpha = alpha

    def calibrate_and_quantize(self, calibration_dataloader, num_samples=128):
        """Full pipeline: calibrate, smooth, quantize."""

        # Step 1: Collect activation statistics
        act_maxes = {}

        def make_hook(name):
            def hook(module, input, output):
                x = input[0].detach().float()
                x_flat = x.reshape(-1, x.shape[-1])
                batch_max = x_flat.abs().amax(dim=0)
                if name not in act_maxes:
                    act_maxes[name] = batch_max
                else:
                    act_maxes[name] = torch.max(act_maxes[name], batch_max)
            return hook

        hooks = []
        for name, mod in self.model.named_modules():
            if isinstance(mod, nn.Linear):
                hooks.append(mod.register_forward_hook(make_hook(name)))

        count = 0
        self.model.eval()
        with torch.no_grad():
            for batch in calibration_dataloader:
                if count >= num_samples:
                    break
                self.model(batch['input_ids'].cuda())
                count += batch['input_ids'].shape[0]

        for h in hooks:
            h.remove()

        # Step 2: Compute SmoothQuant scales and apply
        for name, mod in self.model.named_modules():
            if isinstance(mod, nn.Linear) and name in act_maxes:
                act_max = act_maxes[name].to(mod.weight.device)
                weight_max = mod.weight.data.abs().amax(dim=0)

                smooth_scale = (
                    act_max.pow(self.alpha) /
                    weight_max.clamp(min=1e-5).pow(1 - self.alpha)
                ).clamp(min=1e-5)

                # Step 3: Quantize weights with smooth scaling applied
                int8_mod = W8A8Int8Linear.from_float(
                    mod, smooth_scales=smooth_scale
                )

                # Store smooth scales for runtime activation division
                int8_mod.register_buffer('smooth_scales', smooth_scale)

                # Replace module in model
                parent_name = '.'.join(name.split('.')[:-1])
                child_name = name.split('.')[-1]
                parent = dict(self.model.named_modules())[parent_name]
                setattr(parent, child_name, int8_mod)

        return self.model

Throughput Benchmarks

📊

GEMM Throughput: cuBLAS FP16 vs INT8 (H100 SXM)

Matrix (M x N x K)        FP16 TFLOPS   INT8 TOPS   INT8/FP16 Speedup
1 x 4096 x 4096           0.034         0.059       1.7x
32 x 4096 x 4096          0.89          1.71        1.9x
128 x 4096 x 4096         3.21          6.18        1.9x
512 x 4096 x 4096         11.2          21.4        1.9x
2048 x 4096 x 4096        38.7          72.1        1.9x
4096 x 4096 x 4096        68.2          128.6       1.9x
Note: INT8 tensor cores on H100 provide ~1.9x speedup over FP16 for compute-bound GEMMs. For bandwidth-bound cases (M=1), the speedup is lower because the bottleneck is memory, not compute.

End-to-End Decode: Llama-2 7B Tokens/sec (H100 SXM)

Configuration                  Tokens/sec
FP16                           258
W4A16 Marlin                   980     (bandwidth-bound win)
W8A8 INT8                      485     (balanced)
W8A8 FP8                       502
W8A8 INT8, batch=64            18,200  (compute-bound win)

Key insight: For single-token decode, W4A16 (980 tok/s) beats W8A8 (485 tok/s) because decode is bandwidth-bound and W4A16 loads half as much data. For large-batch inference, W8A8 INT8 wins because the GEMM is compute-bound and INT8 tensor cores provide 2x throughput.
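The crossover follows from arithmetic intensity. A back-of-envelope roofline (hypothetical helper using H100 SXM peak numbers, 1979 TOPS and 3.35 TB/s; not a benchmark):

```python
def decode_gemm_time_us(M, N, K, weight_bytes, act_bytes=2,
                        peak_ops=1979e12, mem_bw=3.35e12):
    """Roofline estimate for one decode-step GEMM on an H100-class GPU."""
    flops = 2 * M * N * K
    bytes_moved = N * K * weight_bytes + M * K * act_bytes
    return max(flops / peak_ops, bytes_moved / mem_bw) * 1e6  # microseconds

N = K = 4096
# Batch-1 decode: weight traffic dominates, so halving weight bytes
# roughly halves the time regardless of compute precision
t_w8 = decode_gemm_time_us(1, N, K, weight_bytes=1)      # W8A8
t_w4 = decode_gemm_time_us(1, N, K, weight_bytes=0.5)    # W4A16
print(f"batch=1: W8A8 ~{t_w8:.1f} us, W4A16 ~{t_w4:.1f} us")

# Arithmetic intensity grows with batch: ~2*M ops per weight byte, so the
# GEMM turns compute-bound once 2*M exceeds the machine balance
balance = 1979e12 / 3.35e12
print(f"machine balance ~{balance:.0f} ops/byte, "
      f"crossover near M ~ {balance / 2:.0f} at peak rates")
```

At dense peak rates the crossover sits at a batch of a few hundred; achieved rates in real serving are below peak on both axes, so where the flip actually happens depends on the kernel, but the direction of the tradeoff (W4A16 for small-batch decode, W8A8 for large batches) is what the numbers above show.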

Implementation in vLLM and TensorRT-LLM

# vLLM W8A8 INT8 integration
# Uses cutlass INT8 GEMM kernels with SmoothQuant

# Illustrative configuration for SmoothQuant-style W8A8 in vLLM:
quantization_config = {
    'method': 'smoothquant',
    'weight_bits': 8,
    'activation_bits': 8,
    'alpha': 0.5,
    'per_token_activation': True,
    'per_channel_weight': True,
    'calibration_dataset': 'c4',
    'calibration_samples': 512,
}

# TensorRT-LLM W8A8 INT8:
# Uses cuBLAS INT8 GEMM with fused dequantization epilogue
# Supports per-tensor and per-channel weight scaling
# Activation quantization fused into preceding LayerNorm kernel

# Key difference: TensorRT-LLM fuses the activation quantization
# into the LayerNorm output, avoiding a separate kernel launch
# for dynamic quantization. This saves ~10us per layer.
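That fusion can be sketched at the PyTorch level (hypothetical QuantizingLayerNorm module, an illustration of the idea rather than TensorRT-LLM's actual kernel): the normalization op itself emits INT8 plus per-token scales, so the downstream INT8 GEMM needs no separate quantization launch.

```python
import torch
import torch.nn as nn

class QuantizingLayerNorm(nn.Module):
    """LayerNorm whose output is already INT8 plus per-token scales.

    Mimics, at the module level, the fusion TensorRT-LLM performs in one
    kernel: normalize, optionally fold in the SmoothQuant per-channel
    division, then quantize per token.
    """

    def __init__(self, hidden, smooth_scales=None):
        super().__init__()
        self.ln = nn.LayerNorm(hidden)
        self.register_buffer(
            'smooth_scales',
            smooth_scales if smooth_scales is not None else torch.ones(hidden)
        )

    def forward(self, x):
        h = self.ln(x) / self.smooth_scales        # normalize + smooth
        scales = h.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10) / 127.0
        h_int8 = (h / scales).round().clamp(-128, 127).to(torch.int8)
        return h_int8, scales                      # feed the INT8 GEMM directly

qln = QuantizingLayerNorm(4096)
h_int8, scales = qln(torch.randn(32, 4096))
assert h_int8.dtype == torch.int8 and scales.shape == (32, 1)
```

Because the quantized output and its scales leave the module together, the pattern composes directly with a W8A8 linear layer like the one defined earlier in this post.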