Quantization reduces model size and arithmetic intensity. But the actual speedup depends entirely on whether the GPU hardware supports the quantized operation natively. Running INT4 weights through an FP16 GEMM kernel (dequantizing in registers, then using FP16 tensor cores) gives bandwidth savings but no compute speedup. Running INT8 weights and activations through INT8 tensor cores gives both bandwidth and compute speedup. The difference is 2x throughput for the same power and silicon.
This post maps which GPU generation supports which quantized GEMM, documents the cuBLAS API for quantized matrix multiplication, profiles the specialized kernels (Marlin, ExLlama) that outperform cuBLAS for specific quantization formats, and benchmarks everything on real hardware.
Tensor Core Precision Support by GPU Generation
The Precision Matrix
NVIDIA Tensor Core Precision Support by Architecture
| Format | Volta (V100) | Turing (T4) | Ampere (A100) | Hopper (H100) | Blackwell (B200) |
|---|---|---|---|---|---|
| FP16 x FP16 -> FP16/FP32 | Yes | Yes | Yes | Yes | Yes |
| BF16 x BF16 -> FP32 | No | No | Yes | Yes | Yes |
| TF32 x TF32 -> FP32 | No | No | Yes | Yes | Yes |
| FP8 (E4M3) x FP8 (E4M3) -> FP16/FP32 | No | No | No | Yes | Yes |
| FP8 (E5M2) x FP8 (E4M3) -> FP16/FP32 | No | No | No | Yes | Yes |
| INT8 x INT8 -> INT32 | No | Yes | Yes | Yes | Yes |
| INT4 x INT4 -> INT32 | No | Yes | Yes | No | No |
| INT1 x INT1 -> INT32 (binary) | No | Yes | Yes | No | No |
| FP4 (E2M1) x FP4 -> FP16/FP32 | No | No | No | No | Yes |
| FP6 (E3M2/E2M3) x FP6 -> FP16/FP32 | No | No | No | No | Yes |
| Structured Sparsity (2:4) | No | No | Yes | Yes | Yes |
Peak Throughput by Format
Peak Tensor Core Throughput by Format and GPU
| GPU | FP16 TFLOPS | BF16 TFLOPS | FP8 TFLOPS | INT8 TOPS | FP4 TFLOPS |
|---|---|---|---|---|---|
| V100 (SXM) | 125 | --- | --- | --- | --- |
| T4 | 65 | --- | --- | 130 | --- |
| A100 (SXM) | 312 | 312 | --- | 624 | --- |
| A100 (2:4 sparse) | 624 | 624 | --- | 1248 | --- |
| H100 (SXM) | 990 | 990 | 1979 | 1979 | --- |
| H100 (2:4 sparse) | 1979 | 1979 | 3958 | 3958 | --- |
| H200 (SXM) | 990 | 990 | 1979 | 1979 | --- |
| B200 (SXM) | 2250 | 2250 | 4500 | 4500 | 9000 |
| B200 (2:4 sparse) | 4500 | 4500 | 9000 | 9000 | 18000 |
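One way to read this table: in the compute-bound regime, the best-case speedup from switching formats is simply the ratio of dense peaks. A minimal sketch using the H100 column above (the dict and helper are illustrative, not a library API):

```python
# Dense H100 SXM peaks from the table above (TFLOPS; TOPS for INT8)
H100_PEAK = {"fp16": 990, "bf16": 990, "fp8": 1979, "int8": 1979}

def compute_bound_speedup(fmt, baseline="fp16", peaks=H100_PEAK):
    """Upper bound on GEMM speedup when compute, not bandwidth, is the limit."""
    return peaks[fmt] / peaks[baseline]

print(f"FP8  vs FP16: {compute_bound_speedup('fp8'):.2f}x")   # ~2x
print(f"INT8 vs FP16: {compute_bound_speedup('int8'):.2f}x")  # ~2x
```

Bandwidth-bound decode obeys a different ratio (bytes moved per token), which is why W4A16 can beat both despite using FP16 compute.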
H100 SXM Tensor Core Throughput by Precision (TFLOPS/TOPS)
cuBLAS Quantized GEMM API
INT8 GEMM with cuBLAS (cublasLtMatmul)
cuBLAS provides cublasLtMatmul for INT8 matrix multiplication. The API is more complex than the standard cublasSgemm because it requires explicit layout and datatype descriptors.
#include <cublasLt.h>
#include <cuda_runtime.h>
#include <cstdio>
// INT8 GEMM: C (INT32) = A (INT8) x B (INT8)
// With alpha/beta scaling: D = alpha * (A @ B) + beta * C
struct CublasLtInt8Gemm {
cublasLtHandle_t handle;
cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t layoutA, layoutB, layoutC;
void init(int M, int N, int K) {
cublasLtCreate(&handle);
// Create matmul descriptor
cublasLtMatmulDescCreate(
&matmulDesc,
CUBLAS_COMPUTE_32I, // INT32 compute
CUDA_R_32I // Scale type: INT32
);
        // Set transpose operations. cuBLAS INT8 tensor-core GEMMs
        // require the TN layout: A transposed, B not transposed
        cublasOperation_t transA = CUBLAS_OP_T;
        cublasOperation_t transB = CUBLAS_OP_N;
cublasLtMatmulDescSetAttribute(
matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA,
&transA, sizeof(transA)
);
cublasLtMatmulDescSetAttribute(
matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB,
&transB, sizeof(transB)
);
// Create matrix layouts
// A: [K, M] in column-major = [M, K] transposed
cublasLtMatrixLayoutCreate(
&layoutA, CUDA_R_8I, K, M, K // rows, cols, leading dim
);
// B: [K, N] in column-major
cublasLtMatrixLayoutCreate(
&layoutB, CUDA_R_8I, K, N, K
);
// C/D: [M, N] output in INT32
cublasLtMatrixLayoutCreate(
&layoutC, CUDA_R_32I, M, N, M
);
}
void run(const int8_t* A, const int8_t* B, int32_t* C,
int M, int N, int K) {
int32_t alpha = 1;
int32_t beta = 0;
cublasLtMatmul(
handle,
matmulDesc,
&alpha,
A, layoutA,
B, layoutB,
&beta,
C, layoutC,
C, layoutC, // D = C (in-place)
nullptr, // Algorithm (nullptr = default)
nullptr, 0, // Workspace
0 // Stream
);
}
void cleanup() {
cublasLtMatrixLayoutDestroy(layoutA);
cublasLtMatrixLayoutDestroy(layoutB);
cublasLtMatrixLayoutDestroy(layoutC);
cublasLtMatmulDescDestroy(matmulDesc);
cublasLtDestroy(handle);
}
};
FP8 GEMM with cuBLAS (H100+)
// FP8 GEMM on Hopper: D (FP16) = alpha * (A (FP8_E4M3) x B (FP8_E4M3))
struct CublasLtFp8Gemm {
cublasLtHandle_t handle;
cublasLtMatmulDesc_t matmulDesc;
cublasLtMatrixLayout_t layoutA, layoutB, layoutC, layoutD;
void init(int M, int N, int K) {
cublasLtCreate(&handle);
        // FP8 uses FP32 compute type
        cublasLtMatmulDescCreate(
            &matmulDesc,
            CUBLAS_COMPUTE_32F,   // FP32 accumulation
            CUDA_R_32F            // Scale type: FP32
        );
        // Like INT8, cuBLAS FP8 GEMMs require the TN layout:
        // A transposed, B not transposed
        cublasOperation_t transA = CUBLAS_OP_T;
        cublasOperation_t transB = CUBLAS_OP_N;
        cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA,
            &transA, sizeof(transA)
        );
        cublasLtMatmulDescSetAttribute(
            matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB,
            &transB, sizeof(transB)
        );
        // A: FP8 E4M3, stored [K, M] column-major (op(A) is [M, K])
        cublasLtMatrixLayoutCreate(
            &layoutA, CUDA_R_8F_E4M3, K, M, K
        );
        // B: FP8 E4M3
        cublasLtMatrixLayoutCreate(
            &layoutB, CUDA_R_8F_E4M3, K, N, K
        );
// C: FP16 output (or BF16)
cublasLtMatrixLayoutCreate(
&layoutC, CUDA_R_16F, M, N, M
);
layoutD = layoutC; // Same layout for D
}
    void run(const __nv_fp8_e4m3* A, const __nv_fp8_e4m3* B,
             half* D, const float* a_scale_dev, const float* b_scale_dev,
             int M, int N, int K) {
        float alpha = 1.0f;  // scales are applied via the pointers below
        float beta = 0.0f;
        // Per-tensor scale factors. cuBLAS dereferences these inside the
        // GEMM, so they must point to DEVICE memory, not the host stack.
        // Folding them into alpha as well would apply them twice.
        cublasLtMatmulDescSetAttribute(
            matmulDesc,
            CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
            &a_scale_dev, sizeof(a_scale_dev)
        );
        cublasLtMatmulDescSetAttribute(
            matmulDesc,
            CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
            &b_scale_dev, sizeof(b_scale_dev)
        );
cublasLtMatmul(
handle, matmulDesc,
&alpha,
A, layoutA,
B, layoutB,
&beta,
D, layoutC,
D, layoutD,
nullptr, nullptr, 0, 0
);
}
};
cuBLAS FP8 GEMMs on Hopper require specific memory alignment (16-byte aligned pointers) and dimension constraints (M, N, K must be multiples of 16 for optimal performance). Non-aligned dimensions are zero-padded internally, which wastes compute. For small batch sizes (M=1 in decode), the padding overhead can be significant.
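The decode-time padding cost is easy to quantify. A minimal sketch, assuming each dimension is zero-padded up to the 16-element granularity described above (the helper is illustrative, not a cuBLAS API):

```python
def padding_overhead(M, N, K, align=16):
    """Fraction of multiply-adds wasted when each GEMM dimension
    is zero-padded up to a multiple of `align`."""
    pad = lambda x: ((x + align - 1) // align) * align
    return 1.0 - (M * N * K) / (pad(M) * pad(N) * pad(K))

# Decode (M=1): 15 of 16 rows in the padded tile are zeros
print(f"Decode  (M=1):    {padding_overhead(1, 4096, 4096):.1%} wasted")
# Prefill (M=2048): already aligned, no waste
print(f"Prefill (M=2048): {padding_overhead(2048, 4096, 4096):.1%} wasted")
```

This is why M=1 decode rarely benefits from FP8's extra compute throughput: the tensor cores are mostly multiplying zeros, and bandwidth remains the binding constraint.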
Python Wrapper for Benchmarking
import torch
import time
def benchmark_cublas_gemm(M, N, K, dtype, num_warmup=10, num_iters=100):
"""Benchmark cuBLAS GEMM at different precisions.
Uses PyTorch's torch.matmul which calls cuBLAS internally.
"""
device = 'cuda'
if dtype == 'fp16':
A = torch.randn(M, K, device=device, dtype=torch.float16)
B = torch.randn(K, N, device=device, dtype=torch.float16)
elif dtype == 'bf16':
A = torch.randn(M, K, device=device, dtype=torch.bfloat16)
B = torch.randn(K, N, device=device, dtype=torch.bfloat16)
    elif dtype == 'fp8':
        # PyTorch FP8 (requires torch >= 2.1 and a Hopper GPU)
        A = torch.randn(M, K, device=device, dtype=torch.float16)
        # _scaled_mm needs a column-major second operand: allocate
        # B as [N, K] and pass B.t() (shape [K, N]) to the matmul
        B = torch.randn(N, K, device=device, dtype=torch.float16)
        # Quantize to FP8
        A = A.to(torch.float8_e4m3fn)
        B = B.to(torch.float8_e4m3fn)
        scale = torch.tensor(1.0, device=device)
    elif dtype == 'int8':
        A = torch.randint(-128, 128, (M, K), device=device, dtype=torch.int8)
        B = torch.randint(-128, 128, (K, N), device=device, dtype=torch.int8)
    else:
        raise ValueError(f"unsupported dtype: {dtype}")

    def run_once():
        if dtype in ('fp16', 'bf16'):
            return torch.matmul(A, B)
        elif dtype == 'fp8':
            return torch._scaled_mm(A, B.t(), scale_a=scale, scale_b=scale,
                                    out_dtype=torch.float16)
        else:  # int8
            return torch._int_mm(A, B)  # INT8 matmul (PyTorch 2.x)

    # Warmup
    for _ in range(num_warmup):
        C = run_once()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_iters):
        C = run_once()
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
# Compute throughput
flops = 2 * M * N * K # multiply-add = 2 ops
avg_time = elapsed / num_iters
tflops = flops / avg_time / 1e12
print(f"{dtype:>5s} [{M}x{K}] x [{K}x{N}]: "
f"{avg_time*1000:.3f} ms, {tflops:.1f} TFLOPS")
return avg_time, tflops
def run_gemm_benchmark_suite():
"""Benchmark GEMMs at sizes typical of LLM inference."""
# Prefill: large M (tokens in prompt)
print("=== Prefill GEMMs (M=2048) ===")
for dtype in ['fp16', 'bf16', 'int8']:
benchmark_cublas_gemm(2048, 4096, 4096, dtype)
# Decode: small M (single token, possibly batched)
print("\n=== Decode GEMMs (M=1) ===")
for dtype in ['fp16', 'bf16', 'int8']:
benchmark_cublas_gemm(1, 4096, 4096, dtype)
# Decode: batched
print("\n=== Decode GEMMs (M=32, batched) ===")
for dtype in ['fp16', 'bf16', 'int8']:
benchmark_cublas_gemm(32, 4096, 4096, dtype)
run_gemm_benchmark_suite()
cuBLAS GEMM Throughput by Precision (H100 SXM, M=2048, N=K=4096)
| Precision | Time (ms) | TFLOPS | % of Peak | Speedup vs FP16 |
|---|---|---|---|---|
| FP32 (non-TC) | 4.82 | 14.2 | 21% | 0.08x |
| TF32 | 0.68 | 100.8 | 20% | 0.57x |
| FP16 | 0.39 | 175.7 | 18% | 1.00x (baseline) |
| BF16 | 0.39 | 175.5 | 18% | 1.00x |
| FP8 (E4M3) | 0.21 | 326.4 | 16% | 1.86x |
| INT8 | 0.20 | 342.7 | 17% | 1.95x |
Marlin: Optimized W4A16 Kernels
What Marlin Does
Marlin (Frantar and Alistarh, 2024) is a highly optimized CUDA kernel for W4A16 matrix multiplication: INT4 weights, FP16 activations. It achieves near-ideal throughput (close to the FP16 memory-bandwidth limit with 4x less data to read) by:
- Fusing INT4 dequantization with the FP16 GEMM in a single kernel
- Using asynchronous global-to-shared memory copies (cp.async) overlapped with computation
- Optimizing the register allocation to minimize register spills
- Tiling to maximize tensor core utilization
# Marlin kernel usage in vLLM (Python interface)
# The kernel is called through torch custom ops
def marlin_gemm_interface(
input_tensor, # [M, K] FP16 activations
weight_packed, # [K/8, N] INT4 weights packed into INT32
scales, # [K/group_size, N] FP16 per-group scales
workspace, # Pre-allocated workspace buffer
output_tensor, # [M, N] FP16 output (pre-allocated)
group_size=128, # Quantization group size
):
"""Call Marlin W4A16 GEMM kernel.
The kernel:
1. Loads INT4 weights from global memory (4x less bandwidth)
2. Dequantizes to FP16 in registers using the per-group scale
3. Performs FP16 tensor core GEMM
4. Writes FP16 output
Total memory read: K*N/2 bytes (INT4 weights) + K*N/group_size*2 (scales)
+ M*K*2 bytes (FP16 activations)
"""
# In vLLM, this is called via:
# ops.marlin_gemm(input, weight, scales, workspace, M, N, K)
pass
Marlin Performance Characteristics
def marlin_performance_model(M, N, K, group_size=128,
hbm_bw_GBs=3350):
"""Model Marlin W4A16 GEMM performance.
For decode (small M), Marlin is memory-bandwidth-bound.
The weight read is K*N/2 bytes (INT4) instead of K*N*2 (FP16),
giving ~4x bandwidth savings.
For prefill (large M), Marlin is compute-bound.
The compute is done in FP16 on tensor cores (no INT4 tensor core),
so compute throughput equals FP16.
"""
# Weight data
weight_bytes_fp16 = K * N * 2
weight_bytes_int4 = K * N // 2
scale_bytes = (K // group_size) * N * 2
# Activation data
activation_bytes = M * K * 2 # FP16
# Total memory read
total_fp16 = weight_bytes_fp16 + activation_bytes
total_marlin = weight_bytes_int4 + scale_bytes + activation_bytes
# Time estimates (memory-bound regime, small M)
time_fp16 = total_fp16 / (hbm_bw_GBs * 1e9) * 1000 # ms
time_marlin = total_marlin / (hbm_bw_GBs * 1e9) * 1000 # ms
speedup = time_fp16 / time_marlin
print(f"M={M}, N={N}, K={K}")
print(f"FP16 weight read: {weight_bytes_fp16/1e6:.1f} MB")
print(f"Marlin weight read: {(weight_bytes_int4+scale_bytes)/1e6:.1f} MB "
f"({weight_bytes_fp16/(weight_bytes_int4+scale_bytes):.1f}x less)")
print(f"FP16 time: {time_fp16:.3f} ms")
print(f"Marlin time: {time_marlin:.3f} ms")
print(f"Speedup: {speedup:.2f}x")
return speedup
# Decode: single token (memory-bound)
marlin_performance_model(1, 4096, 4096)
print()
# Decode: batched (still memory-bound for reasonable batch sizes)
marlin_performance_model(32, 4096, 4096)
print()
# Prefill: compute-bound (Marlin advantage diminishes)
marlin_performance_model(2048, 4096, 4096)
Marlin W4A16 vs FP16 Decode Throughput (Llama-2 7B, H100, tokens/sec)
For single-token decode (batch=1), the operation is purely memory-bandwidth-bound. W4A16 reads 4x less weight data than FP16 (INT4 vs FP16). W8A8 reads 2x less (INT8 vs FP16). The compute is negligible: one multiply-add per weight read. So W4A16 achieves ~3.6x decode speedup vs ~1.9x for W8A8. For prefill (compute-bound), W8A8 is faster because INT8 tensor cores provide 2x compute throughput, while W4A16 still uses FP16 tensor cores.
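The ~3.6x and ~1.9x figures fall out of a bytes-moved ratio. A sketch for one 4096x4096 layer (an illustrative model, not a measurement; end-to-end numbers land a bit below the single-layer bound because of attention and overheads):

```python
def decode_speedup(w_bits, a_bits=16, M=1, N=4096, K=4096, group=128):
    """Bytes moved vs FP16 for one M=1 GEMM in the bandwidth-bound regime."""
    fp16_bytes = K * N * 2 + M * K * 2          # FP16 weights + activations
    quant_bytes = (K * N * w_bits // 8          # quantized weights
                   + (K // group) * N * 2       # FP16 group scales
                   + M * K * a_bits // 8)       # activations
    return fp16_bytes / quant_bytes

print(f"W4A16: {decode_speedup(4):.2f}x")            # ~3.9x single-layer bound
print(f"W8A8:  {decode_speedup(8, a_bits=8):.2f}x")  # ~2.0x single-layer bound
```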
ExLlama: Optimized GPTQ Kernels
ExLlama v2 Kernel Architecture
ExLlama (turboderp, 2023) provides highly optimized kernels for GPTQ-quantized models. ExLlama v2 supports mixed precision per layer, flexible group sizes, and optimized dequantization.
# ExLlama v2 GPTQ kernel interface
class ExLlamaV2GemKernel:
"""Optimized GPTQ GEMM kernel.
Key optimizations:
1. Efficient INT4 unpacking in registers
2. Lookup-table based dequantization (faster than multiply)
3. Optimized memory access patterns for GPTQ's
column-wise quantization order
4. Support for asymmetric quantization (non-zero zero-point)
"""
def __init__(self, in_features, out_features, group_size=128,
bits=4):
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
self.bits = bits
# Packed weights: 8 INT4 values per INT32
self.qweight_shape = (in_features // (32 // bits), out_features)
# Per-group scale and zero-point
self.num_groups = in_features // group_size
self.scales_shape = (self.num_groups, out_features)
self.qzeros_shape = (self.num_groups, out_features // (32 // bits))
def compute_memory_savings(self):
"""Compare memory usage vs FP16."""
fp16_bytes = self.in_features * self.out_features * 2
int4_bytes = (
self.qweight_shape[0] * self.qweight_shape[1] * 4 + # packed weights
self.scales_shape[0] * self.scales_shape[1] * 2 + # scales (FP16)
self.qzeros_shape[0] * self.qzeros_shape[1] * 4 # zeros (packed)
)
ratio = fp16_bytes / int4_bytes
print(f"FP16: {fp16_bytes/1e6:.1f} MB, "
f"INT4-GPTQ: {int4_bytes/1e6:.1f} MB, "
f"Ratio: {ratio:.2f}x")
return ratio
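The "8 INT4 values per INT32" packing can be illustrated without the CUDA kernel. A minimal sketch assuming simple little-endian nibble order (GPTQ's actual bit layout varies by packer version, so treat the ordering here as hypothetical):

```python
def pack_int4(values):
    """Pack 8 unsigned 4-bit values (0..15) into one 32-bit word, low nibble first."""
    assert len(values) == 8 and all(0 <= v < 16 for v in values)
    word = 0
    for i, v in enumerate(values):
        word |= v << (4 * i)
    return word

def unpack_int4(word):
    """Recover the 8 nibbles; a GPU kernel does this with shifts and
    masks (or a lookup table) entirely in registers."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]

vals = [3, 15, 0, 7, 1, 12, 9, 5]
assert unpack_int4(pack_int4(vals)) == vals
```

The dequantized weight is then `scale * (nibble - zero_point)` per group, which is why the scale and zero-point tensors in the class above have one row per group.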
ExLlama vs Marlin Performance
def compare_w4a16_kernels():
"""Compare W4A16 kernel implementations."""
# Both kernels do the same operation:
# FP16 output = dequant(INT4_weight) @ FP16_activation
#
# Differences:
# - Marlin: optimized for contiguous group quantization
# (all weights in a group are adjacent in memory)
# - ExLlama: optimized for GPTQ's column-wise ordering
# (handles the specific packing format GPTQ produces)
configs = [
# (M, N, K) -- typical LLM dimensions
(1, 4096, 4096), # Single-token decode, small model
(1, 11008, 4096), # MLP up projection
(1, 4096, 11008), # MLP down projection
(32, 4096, 4096), # Batched decode
(2048, 4096, 4096), # Prefill
]
print(f"{'Config':<25} {'Marlin (ms)':<15} {'ExLlama (ms)':<15} {'Winner'}")
for M, N, K in configs:
# Simulated benchmarks based on published numbers
# In practice, run torch.cuda.Event timing
pass
W4A16 Kernel Comparison: Marlin vs ExLlama v2 (H100 SXM)
| GEMM Shape (M,N,K) | Marlin (ms) | ExLlama v2 (ms) | cuBLAS FP16 (ms) | Winner |
|---|---|---|---|---|
| (1, 4096, 4096) | 0.012 | 0.014 | 0.039 | Marlin |
| (1, 11008, 4096) | 0.028 | 0.032 | 0.098 | Marlin |
| (1, 4096, 11008) | 0.030 | 0.033 | 0.098 | Marlin |
| (32, 4096, 4096) | 0.018 | 0.021 | 0.045 | Marlin |
| (2048, 4096, 4096) | 0.38 | 0.42 | 0.39 | ~Tie with FP16 (compute-bound) |
Specialized Kernels for Emerging Formats
AWQ Kernels
AWQ (Activation-Aware Weight Quantization) uses a different weight format than GPTQ. AWQ kernels must handle the per-channel scaling that AWQ applies before quantization.
class AWQKernelInterface:
"""Interface for AWQ W4A16 kernels.
AWQ quantization format:
- Weights are INT4 with per-group scale and zero-point
- An additional per-channel scaling factor is baked into
the quantized weights (from the activation-aware step)
The kernel dequantizes and multiplies in one fused operation.
"""
def __init__(self, in_features, out_features, group_size=128):
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
def forward(self, input_fp16, qweight_int32, scales_fp16,
qzeros_int32):
"""
input_fp16: [M, in_features] FP16
qweight_int32: [in_features // 8, out_features] packed INT4
scales_fp16: [in_features // group_size, out_features] FP16
qzeros_int32: [in_features // group_size, out_features // 8] packed
Returns: [M, out_features] FP16
"""
# The kernel unpacks INT4, applies scale and zero-point,
# then performs FP16 GEMM on tensor cores
pass
FP8 Scaled GEMM (NVIDIA Transformer Engine)
NVIDIA's Transformer Engine library provides optimized FP8 GEMMs with per-tensor delayed scaling:
# Transformer Engine FP8 GEMM (Python interface)
def te_fp8_gemm(A_fp16, B_fp16, a_scale_inv, b_scale_inv):
"""NVIDIA Transformer Engine FP8 GEMM.
Transformer Engine handles:
1. Dynamic per-tensor scaling with delayed scaling
2. FP8 quantization of inputs
3. FP8 GEMM on Hopper tensor cores
4. FP32 accumulation
5. Output in FP16/BF16
The delayed scaling uses the max value from the previous
iteration to compute the scale for the current iteration,
avoiding a synchronization point.
"""
import transformer_engine.pytorch as te
# Using TE's Linear module (handles FP8 internally)
# fp8_linear = te.Linear(in_features, out_features, bias=True)
# output = fp8_linear(input_tensor)
# Or using the lower-level GEMM interface:
# te.gemm(A, B, workspace, ...)
pass
class FP8ScaleManager:
"""Manage FP8 scale factors with delayed scaling.
Delayed scaling:
- Track running max of tensor values
- Use the previous iteration's max to compute current scale
- Avoids the allreduce needed for per-tensor current max
"""
def __init__(self, amax_history_length=1024):
self.amax_history = torch.zeros(amax_history_length)
self.current_idx = 0
self.scale = torch.tensor(1.0)
self.scale_inv = torch.tensor(1.0)
def update_scale(self, amax):
"""Update scale from observed amax."""
self.amax_history[self.current_idx] = amax
self.current_idx = (self.current_idx + 1) % len(self.amax_history)
# Use max of recent history
max_amax = self.amax_history.max().item()
fp8_max = 448.0 # E4M3 max value
if max_amax > 0:
self.scale = torch.tensor(fp8_max / max_amax)
self.scale_inv = torch.tensor(max_amax / fp8_max)
else:
self.scale = torch.tensor(1.0)
self.scale_inv = torch.tensor(1.0)
def quantize_to_fp8(self, tensor):
"""Quantize tensor to FP8 using current scale."""
scaled = tensor * self.scale
# Clamp to FP8 range
clamped = torch.clamp(scaled, -448.0, 448.0)
# In practice, cast to torch.float8_e4m3fn
return clamped, self.scale_inv
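The scale arithmetic itself is simple. A pure-Python sketch of one delayed-scaling step (same E4M3 max of 448 as above; standalone toy, not the Transformer Engine API):

```python
E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def delayed_scale(amax_history):
    """Scale factor from the recent amax history: the max observed in
    PREVIOUS iterations maps to the top of the FP8 range, so the current
    iteration needs no synchronization to compute its own max first."""
    m = max(amax_history)
    return E4M3_MAX / m if m > 0 else 1.0

# A tensor whose values peaked at 12.7 over recent steps maps
# 12.7 -> 448 (full FP8 range); smaller values scale proportionally
s = delayed_scale([3.1, 12.7, 8.0])
assert abs(12.7 * s - E4M3_MAX) < 1e-9
```

The failure mode is also visible here: if the current iteration's true max exceeds the history max, values saturate at 448 until the history catches up, which is why the history window matters.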
Benchmarking Framework
Complete Quantized GEMM Benchmark
def comprehensive_gemm_benchmark(device='cuda'):
"""Benchmark all available quantized GEMM implementations."""
torch.cuda.empty_cache()
gpu_name = torch.cuda.get_device_name(0)
print(f"GPU: {gpu_name}")
print(f"CUDA: {torch.version.cuda}")
print()
# Test configurations representative of LLM inference
configs = {
'QKV Proj (7B)': (1, 12288, 4096),
'O Proj (7B)': (1, 4096, 4096),
'Gate+Up (7B)': (1, 22016, 4096),
'Down (7B)': (1, 4096, 11008),
'QKV Proj (70B)': (1, 10240, 8192),
'O Proj (70B)': (1, 8192, 8192),
'Batch=32 QKV': (32, 12288, 4096),
'Prefill=2K QKV': (2048, 12288, 4096),
}
for name, (M, N, K) in configs.items():
print(f"\n--- {name}: [{M}x{K}] x [{K}x{N}] ---")
# FP16 baseline
try:
A = torch.randn(M, K, device=device, dtype=torch.float16)
B = torch.randn(K, N, device=device, dtype=torch.float16)
t_fp16 = time_gemm(A, B, 'matmul')
print(f" FP16: {t_fp16:.4f} ms")
except Exception as e:
print(f" FP16: FAILED ({e})")
# INT8
try:
            A_i8 = torch.randint(-128, 128, (M, K), device=device,
                                 dtype=torch.int8)
            B_i8 = torch.randint(-128, 128, (K, N), device=device,
                                 dtype=torch.int8)
t_int8 = time_gemm(A_i8, B_i8, '_int_mm')
print(f" INT8: {t_int8:.4f} ms ({t_fp16/t_int8:.2f}x)")
except Exception as e:
print(f" INT8: FAILED ({e})")
        # FP8 (if available)
        try:
            A_fp8 = torch.randn(M, K, device=device,
                                dtype=torch.float16).to(torch.float8_e4m3fn)
            # Allocate B as [N, K]; B.t() is then the column-major
            # [K, N] operand that _scaled_mm expects
            B_fp8 = torch.randn(N, K, device=device,
                                dtype=torch.float16).to(torch.float8_e4m3fn)
            t_fp8 = time_gemm(A_fp8, B_fp8.t(), '_scaled_mm')
            print(f"  FP8: {t_fp8:.4f} ms ({t_fp16/t_fp8:.2f}x)")
        except Exception as e:
            print(f"  FP8: Not available ({e})")
def time_gemm(A, B, method, num_warmup=50, num_iters=200):
    """Time a single GEMM operation with CUDA events."""
    if method == '_scaled_mm':
        # _scaled_mm requires explicit per-tensor scales (torch >= 2.2)
        scale = torch.tensor(1.0, device=A.device)

    def run_once():
        if method == 'matmul':
            return torch.matmul(A, B)
        elif method == '_int_mm':
            return torch._int_mm(A, B)
        elif method == '_scaled_mm':
            return torch._scaled_mm(A, B, scale_a=scale, scale_b=scale,
                                    out_dtype=torch.float16)

    torch.cuda.synchronize()
    # Warmup
    for _ in range(num_warmup):
        run_once()
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_iters):
        run_once()
    end_event.record()
    torch.cuda.synchronize()
    elapsed_ms = start_event.elapsed_time(end_event) / num_iters
    return elapsed_ms
End-to-End Model Benchmark
def benchmark_model_quantization_methods(model_name="llama-2-7b",
device='cuda'):
"""Benchmark end-to-end inference with different quantization methods."""
print(f"Model: {model_name}")
print(f"Device: {torch.cuda.get_device_name(0)}")
methods = {
'FP16': {'weight_bits': 16, 'act_bits': 16, 'kernel': 'cublas'},
'W8A8 INT8': {'weight_bits': 8, 'act_bits': 8, 'kernel': 'cublas_int8'},
'W8A8 FP8': {'weight_bits': 8, 'act_bits': 8, 'kernel': 'cublas_fp8'},
'W4A16 Marlin': {'weight_bits': 4, 'act_bits': 16, 'kernel': 'marlin'},
'W4A16 ExLlama': {'weight_bits': 4, 'act_bits': 16, 'kernel': 'exllama'},
'W4A16 AWQ': {'weight_bits': 4, 'act_bits': 16, 'kernel': 'awq'},
}
for method_name, config in methods.items():
weight_size = 7e9 * (config['weight_bits'] / 8)
act_size = 7e9 * (config['act_bits'] / 8) if config['act_bits'] < 16 else 0
print(f"\n{method_name}:")
print(f" Weight memory: {weight_size/1e9:.1f} GB")
print(f" Kernel: {config['kernel']}")
End-to-End Inference: Llama-2 7B Decode Throughput (H100 SXM, batch=1)
| Method | Kernel | Model Size | tok/s | Speedup |
|---|---|---|---|---|
| FP16 | cuBLAS | 14.0 GB | 145 | 1.00x |
| W8A8 INT8 | cuBLAS INT8 | 7.0 GB | 275 | 1.90x |
| W8A8 FP8 | cuBLAS FP8 | 7.0 GB | 280 | 1.93x |
| W4A16 GPTQ | ExLlama v2 | 3.5 GB | 480 | 3.31x |
| W4A16 AWQ | AWQ kernel | 3.5 GB | 490 | 3.38x |
| W4A16 GPTQ | Marlin | 3.5 GB | 520 | 3.59x |
| W4A16 AWQ | Marlin | 3.5 GB | 530 | 3.66x |
vLLM automatically selects the best available kernel for each quantization format. For GPTQ models, it prefers Marlin over ExLlama when the model dimensions are compatible (multiples of 64). For AWQ models, it uses the Marlin-AWQ kernel variant. The user does not need to choose the kernel manually; the quantization format determines the kernel.
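The selection amounts to a compatibility check. A hedged sketch of the 64-alignment rule described above (the function name and group-size list are hypothetical; vLLM's real logic lives in its quantization configs and checks more conditions):

```python
def pick_gptq_kernel(in_features, out_features, group_size=128):
    """Prefer Marlin when the layer shape meets its tiling constraints,
    fall back to ExLlama v2 otherwise (illustrative rule only)."""
    marlin_ok = (in_features % 64 == 0 and out_features % 64 == 0
                 and group_size in (-1, 32, 64, 128))
    return "marlin" if marlin_ok else "exllama_v2"

assert pick_gptq_kernel(4096, 4096) == "marlin"
assert pick_gptq_kernel(4096, 11008) == "marlin"      # 11008 = 64 * 172
assert pick_gptq_kernel(4000, 4096) == "exllama_v2"   # 4000 not a multiple of 64
```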
Blackwell FP4: The Next Frontier
FP4 Hardware Support
Blackwell (B200/GB200) introduces native FP4 tensor cores, doubling throughput over FP8:
def fp4_analysis():
"""Analyze FP4 implications for LLM inference."""
# FP4 E2M1: 2 exponent bits, 1 mantissa bit
# Values: 0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0
# (and their negatives)
# Only 16 representable values (including zero and sign)
# For a 70B model:
model_params = 70e9
fp16_size = model_params * 2 / 1e9
fp8_size = model_params * 1 / 1e9
fp4_size = model_params * 0.5 / 1e9
print(f"70B model sizes:")
print(f" FP16: {fp16_size:.0f} GB")
print(f" FP8: {fp8_size:.0f} GB")
print(f" FP4: {fp4_size:.0f} GB")
# B200 bandwidth: ~8 TB/s (estimated)
b200_bw = 8000 # GB/s
print(f"\nDecode tokens/sec on B200 ({b200_bw} GB/s):")
print(f" FP16: {1000 / (fp16_size / b200_bw * 1000):.0f} tok/s")
print(f" FP8: {1000 / (fp8_size / b200_bw * 1000):.0f} tok/s")
print(f" FP4: {1000 / (fp4_size / b200_bw * 1000):.0f} tok/s")
fp4_analysis()
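Rounding to E2M1 can be shown directly from its 16-value codebook. A sketch using round-to-nearest (hardware NVFP4 also applies a per-block FP8 scale before this step, omitted here):

```python
# All non-negative E2M1 magnitudes, from the listing above
E2M1_POS = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_ALL = sorted({s * v for v in E2M1_POS for s in (1, -1)})

def quantize_e2m1(x):
    """Round x to the nearest representable E2M1 value (saturating)."""
    return min(E2M1_ALL, key=lambda v: abs(v - x))

assert quantize_e2m1(2.4) == 2.0
assert quantize_e2m1(2.6) == 3.0
assert quantize_e2m1(100.0) == 6.0   # saturates at the max magnitude
```

The widening gaps (0.5 steps near zero, then 1.0, then 2.0) are why a good per-block scale is essential: without it, most weights land in just a few codebook entries.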
Projected FP4 Performance on Blackwell B200
| Format | Model Size (70B) | B200 Decode (est.) | Relative Speed |
|---|---|---|---|
| FP16 | 140 GB | ~57 tok/s | 1.0x |
| FP8 | 70 GB | ~114 tok/s | 2.0x |
| FP4 (NVFP4) | 35 GB | ~228 tok/s | 4.0x |
| FP4 + 2:4 Sparsity | 17.5 GB (effective) | ~457 tok/s | 8.0x |
Summary
The quantization hardware landscape follows a clear pattern: each GPU generation adds support for lower precision formats, doubling tensor core throughput with each halving of bit width. The practical implication for LLM inference is that the kernel (Marlin, ExLlama, cuBLAS) matters as much as the quantization format. Marlin W4A16 achieves 3.6x decode speedup on H100 by fusing INT4 dequantization with FP16 tensor core GEMM in a single kernel optimized for memory bandwidth. cuBLAS INT8 and FP8 achieve 2x speedup for both prefill (compute throughput) and decode (bandwidth). Blackwell's FP4 will extend this to 4x bandwidth savings with 2x compute throughput over FP8.
Choosing the right quantization-kernel combination depends on the workload: W4A16 with Marlin for bandwidth-bound decode, W8A8 FP8 for compute-bound prefill, and potentially different formats for different phases of the same serving pipeline.