Getting INT4 weights to actually run fast on GPUs turned out to be much harder than just storing them in 4 bits. Early INT4 kernels from 2022-2023 were slower than FP16 cuBLAS at batch size 1—the very workload where 4x compression should matter most. The problem was dequantization overhead: unpacking INT4 to FP16, applying per-group scales, and feeding tensor cores ate up all the bandwidth savings. Then in 2024, the Marlin kernel from IST Austria/Neural Magic cracked the code: dequantize in registers (not shared memory), absorb scale factors into the memory access pattern, and overlap everything with double-buffered async copies. The result was 3.8x decode speedup over FP16—finally delivering on the theoretical 4x bandwidth advantage.
W4A16 inference stores weights in 4-bit integers and activations in FP16. During the matrix multiply, the kernel dequantizes INT4 weights to FP16 in registers and uses FP16 tensor cores for the actual computation. There is no INT4 GEMM — the 4-bit format is purely a compression format for memory and bandwidth, not a compute format.
This architecture makes W4A16 a bandwidth optimization, not a compute optimization. The speedup comes from loading 4x less weight data from GPU memory, not from doing 4x more operations per cycle. This distinction is critical for understanding when W4A16 helps and when it does not.
This post covers the memory bandwidth argument for W4A16, the Marlin kernel architecture that achieves near-optimal bandwidth utilization, the INT4 packing format, and benchmarks against FP16 and other quantized formats.
The Bandwidth Argument
LLM inference during autoregressive decoding is memory-bandwidth bound, not compute bound. Each token generation requires loading the entire model’s weights from GPU memory but performs relatively little computation (a matrix-vector product, not a matrix-matrix product).
Arithmetic Intensity Analysis
For a single-token decode step through a linear layer with weight matrix W of shape (N, K), stored at b bytes per weight:
- Bytes loaded: b · N · K
- FLOPs: 2 · N · K (one multiply-add per weight)
- Arithmetic intensity: (2 · N · K) / (b · N · K) = 2/b FLOP/byte
def arithmetic_intensity(bytes_per_weight):
"""Arithmetic intensity for single-token GEMV."""
return 2.0 / bytes_per_weight
formats = {
'FP16': 2.0,
'INT8': 1.0,
'INT4': 0.5,
'INT4 (packed)': 0.5,
}
for fmt, bpw in formats.items():
ai = arithmetic_intensity(bpw)
print(f" {fmt:>15s}: {ai:.1f} FLOP/byte")
FP16: 1.0 FLOP/byte
INT8: 2.0 FLOP/byte
INT4: 4.0 FLOP/byte
INT4 (packed): 4.0 FLOP/byte
On an H100 SXM with 3.35 TB/s memory bandwidth and 990 TFLOPS FP16 tensor core throughput:
- FP16 GEMV: limited by bandwidth at 1.0 FLOP/byte × 3.35 TB/s ≈ 3.35 TFLOP/s (uses 0.3% of compute)
- INT4 GEMV: limited by bandwidth at 4.0 FLOP/byte × 3.35 TB/s ≈ 13.4 TFLOP/s (uses 1.4% of compute)
Both are deeply bandwidth-bound. INT4 is 4x faster than FP16 for single-token decoding because it loads 4x less data.
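Bandwidth-bound only holds while the batch stays small. As batch size M grows, FLOPs scale with M but the weights are still loaded once, so there is a critical batch size M* = (peak FLOPs × bytes per weight) / (2 × bandwidth) beyond which the GEMM turns compute-bound. A sketch using the H100 SXM figures above:

```python
def critical_batch_size(bytes_per_weight,
                        mem_bw_tb_s=3.35,    # H100 SXM memory bandwidth
                        fp16_tflops=990):    # H100 SXM FP16 tensor cores
    """Batch size M* at which a GEMM stops being weight-bandwidth-bound.

    Compute time 2*M*N*K / peak_flops exceeds weight-load time
    bytes_per_weight*N*K / bandwidth once M > M*.
    """
    return (fp16_tflops * 1e12 * bytes_per_weight) / (2 * mem_bw_tb_s * 1e12)

for fmt, bpw in [('FP16', 2.0), ('INT8', 1.0), ('INT4', 0.5)]:
    print(f" {fmt:>5s}: compute-bound above M ≈ {critical_batch_size(bpw):.0f}")
```

For INT4 this lands near M ≈ 74, consistent with the batch-size sweep later in this post, where the compute-bound transition falls between M=64 and M=256.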
def decode_throughput(
hidden_dim, num_layers, num_heads, head_dim,
bytes_per_weight, mem_bandwidth_tb_s
):
"""Estimate single-token decode latency and throughput.
Accounts for Q, K, V, O projections and MLP (gate, up, down).
"""
# Attention projections: 4 * hidden^2 parameters
attn_params = 4 * hidden_dim * hidden_dim
# MLP: typically 3 * hidden * (8/3 * hidden) for SwiGLU
# = 3 * hidden * intermediate, intermediate = 8/3 * hidden
intermediate = int(8 / 3 * hidden_dim)
# Round to multiple of 256 for alignment
intermediate = ((intermediate + 255) // 256) * 256
mlp_params = 3 * hidden_dim * intermediate
total_params = num_layers * (attn_params + mlp_params)
total_bytes = total_params * bytes_per_weight
bandwidth_bytes_per_s = mem_bandwidth_tb_s * 1e12
latency_s = total_bytes / bandwidth_bytes_per_s
tokens_per_s = 1.0 / latency_s
return {
'total_params_B': total_params / 1e9,
'total_bytes_GB': total_bytes / 1e9,
'latency_ms': latency_s * 1000,
'tokens_per_s': tokens_per_s,
}
# Llama-2 7B on H100 SXM
for fmt, bpw in [('FP16', 2.0), ('INT8', 1.0), ('INT4', 0.5)]:
result = decode_throughput(
hidden_dim=4096, num_layers=32,
num_heads=32, head_dim=128,
bytes_per_weight=bpw,
mem_bandwidth_tb_s=3.35
)
print(f" {fmt:>5s}: {result['total_bytes_GB']:.1f} GB, "
f"latency={result['latency_ms']:.2f} ms, "
f"throughput={result['tokens_per_s']:.0f} tok/s")
FP16: 13.0 GB, latency=3.88 ms, throughput=258 tok/s
INT8: 6.5 GB, latency=1.94 ms, throughput=515 tok/s
INT4: 3.3 GB, latency=0.97 ms, throughput=1031 tok/s
For single-token decode, W4A16 achieves nearly 4x the throughput of FP16, limited only by the overhead of dequantization and scale factor loading. A well-optimized kernel like Marlin achieves 90-95% of the theoretical 4x speedup.
INT4 Packing Format
In the simplest scheme, two INT4 values are packed into a single byte; production formats like GPTQ pack eight into a uint32. The packing convention varies between implementations:
import numpy as np
def pack_int4_symmetric(values):
"""Pack pairs of signed INT4 values into bytes.
INT4 range: [-8, 7]. Stored as unsigned [0, 15] with offset.
Two values per byte: low nibble and high nibble.
"""
assert len(values) % 2 == 0
packed = np.zeros(len(values) // 2, dtype=np.uint8)
for i in range(0, len(values), 2):
# Map [-8, 7] to [0, 15]
low = int(values[i]) + 8
high = int(values[i + 1]) + 8
packed[i // 2] = (high << 4) | (low & 0x0F)
return packed
def unpack_int4_symmetric(packed):
"""Unpack bytes into pairs of signed INT4 values."""
values = np.zeros(len(packed) * 2, dtype=np.int8)
for i in range(len(packed)):
low = (packed[i] & 0x0F) - 8
high = (packed[i] >> 4) - 8
values[2 * i] = low
values[2 * i + 1] = high
return values
# GPTQ packing format (used by Marlin):
# 8 INT4 values packed into one 32-bit integer
def pack_int4_gptq(values):
"""Pack 8 INT4 values into a single uint32.
This is the GPTQ convention: values packed from LSB to MSB.
"""
assert len(values) == 8
packed = np.uint32(0)
for i in range(8):
val = np.uint32(int(values[i]) + 8) # Map to unsigned [0, 15]
packed |= (val & np.uint32(0xF)) << np.uint32(4 * i)
return packed
def unpack_int4_gptq(packed):
"""Unpack a uint32 into 8 INT4 values."""
values = np.zeros(8, dtype=np.int8)
for i in range(8):
val = (int(packed) >> (4 * i)) & 0xF
values[i] = val - 8
return values
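The loops above are clear but slow. A vectorized NumPy equivalent of the byte-packing scheme (a self-contained sketch that redefines the helpers rather than reusing the loop versions) also makes the lossless round-trip easy to verify over the full INT4 range:

```python
import numpy as np

def pack_sym_vec(vals):
    """Vectorized byte packing: even index -> low nibble, odd -> high."""
    u = (vals.astype(np.int16) + 8).astype(np.uint8)  # map [-8,7] -> [0,15]
    return u[0::2] | (u[1::2] << 4)

def unpack_sym_vec(packed):
    """Vectorized inverse: recover signed INT4 values from packed bytes."""
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2] = (packed & 0x0F).astype(np.int8) - 8
    out[1::2] = (packed >> 4).astype(np.int8) - 8
    return out

vals = np.arange(-8, 8, dtype=np.int8)  # every INT4 code point
assert np.array_equal(unpack_sym_vec(pack_sym_vec(vals)), vals)
print("lossless round-trip over [-8, 7]")
```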
The Marlin Kernel Architecture
Marlin (Mixed Auto-Regressive Linear, from IST Austria / Neural Magic) is a W4A16 GEMM kernel designed for maximum memory bandwidth utilization on NVIDIA Ampere and Hopper GPUs. It achieves near-ideal speedups (close to 4x over FP16 cuBLAS for batch size 1).
Design Principles
- **Maximize global memory bandwidth utilization**: Load INT4 weights at the full memory bandwidth, dequantize in registers, and feed FP16 values to tensor cores.
- **Overlap memory loads with computation**: Use double-buffering in shared memory to overlap the next tile’s memory load with the current tile’s tensor core computation.
- **Minimize dequantization overhead**: The INT4-to-FP16 conversion is done in registers using bitwise operations, not lookup tables.
- **Tile sizes tuned for bandwidth**: The kernel uses large tiles in the K dimension to maximize the ratio of weight loads to scale factor loads.
Memory Layout
Marlin reorders the weight matrix to optimize memory access patterns:
def marlin_weight_layout(W_q_packed, N, K, group_size=128):
"""Illustrate Marlin's weight memory layout.
Marlin tiles: 16 x 64 (N x K) per tile.
Within each tile, weights are arranged for coalesced 128-byte loads
by warps.
W_q_packed: shape (N, K // 8), dtype uint32 (8 INT4s per uint32)
Marlin reorders to: (N // 16, K // 64, 16, 64 // 8)
= (N // 16, K // 64, 16, 8) uint32 values per tile
"""
# Tile dimensions
tile_n = 16 # Output rows per tile
tile_k = 64 # Input columns per tile (64 INT4 values = 8 uint32)
num_tiles_n = N // tile_n
num_tiles_k = K // tile_k
packed_k = K // 8 # 8 INT4s per uint32
# Reorder into tile-major layout
tiled = np.zeros(
(num_tiles_n, num_tiles_k, tile_n, tile_k // 8),
dtype=np.uint32
)
for tn in range(num_tiles_n):
for tk in range(num_tiles_k):
for i in range(tile_n):
for j in range(tile_k // 8):
src_row = tn * tile_n + i
src_col = tk * (tile_k // 8) + j
tiled[tn, tk, i, j] = W_q_packed[src_row, src_col]
return tiled
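The four nested loops above are just a tiled reshape, so the same layout can be produced (and checked) in one vectorized step. A self-contained sketch with a small random packed matrix:

```python
import numpy as np

def marlin_layout_fast(W_q_packed, N, K, tile_n=16, tile_k=64):
    """Equivalent of the loop version: tile-major (N//16, K//64, 16, 8)."""
    pk = tile_k // 8  # uint32 words per tile along K (8 INT4s per word)
    return (W_q_packed
            .reshape(N // tile_n, tile_n, K // tile_k, pk)
            .transpose(0, 2, 1, 3)   # bring the tile indices to the front
            .copy())

rng = np.random.default_rng(0)
N, K = 64, 256
W = rng.integers(0, 2**32, size=(N, K // 8), dtype=np.uint32)
fast = marlin_layout_fast(W, N, K)
assert fast.shape == (N // 16, K // 64, 16, 8)
# Spot-check against the index mapping used by the loop version:
assert fast[1, 2, 3, 4] == W[1 * 16 + 3, 2 * 8 + 4]
```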
Warp-Level Dequantization
Each warp handles a portion of the tile. The dequantization happens in registers using fast bitwise operations:
// Marlin-style INT4 to FP16 dequantization in CUDA
// Processes 8 INT4 values packed in a uint32
__device__ __forceinline__ void dequantize_int4x8_to_fp16x8(
uint32_t packed_int4,
half scale,
half* output // 8 half values
) {
// Extract each 4-bit value and convert to FP16
#pragma unroll
for (int i = 0; i < 8; i++) {
        int val = (int)((packed_int4 >> (4 * i)) & 0xF) - 8;
output[i] = __hmul(__int2half_rn(val), scale);
}
}
// Optimized version using vectorized operations
__device__ __forceinline__ void dequantize_int4x8_fast(
uint32_t packed,
half scale,
uint4* output_vec // Vectorized output
) {
// Process pairs of INT4 -> FP16
half2 scale2 = __half2half2(scale);
// Extract low and high nibbles simultaneously
uint32_t low_mask = 0x0F0F0F0F;
uint32_t lows = packed & low_mask;
uint32_t highs = (packed >> 4) & low_mask;
    // Each byte now holds one nibble. Real Marlin expands these to FP16
    // pairs at the PTX level (lop3/prmt plus an exponent-bias trick) and
    // applies scale2, avoiding per-value int-to-float conversions.
    // Sketch only: output_vec is intentionally left unwritten here.
}
Double-Buffered Pipeline
The key to Marlin’s performance is overlapping memory loads with computation:
// Simplified Marlin pipeline structure
__global__ void marlin_gemm(
const uint32_t* __restrict__ W_packed, // INT4 weights
const half* __restrict__ scales, // Per-group FP16 scales
const half* __restrict__ X, // FP16 activations
half* __restrict__ Y, // FP16 output
int M, int N, int K,
int group_size
) {
// Double-buffered shared memory
__shared__ half smem_X[2][TILE_M * TILE_K];
__shared__ uint32_t smem_W[2][TILE_N * (TILE_K / 8)];
__shared__ half smem_scales[2][TILE_N * (TILE_K / group_size)];
// Accumulator in registers (FP32 for precision)
float acc[TILE_M_PER_WARP][TILE_N_PER_WARP] = {0};
int buffer = 0;
// Prologue: load first tile
load_tile_async(smem_X[0], X, /*tile_k=*/0);
load_tile_async(smem_W[0], W_packed, /*tile_k=*/0);
load_tile_async(smem_scales[0], scales, /*tile_k=*/0);
__syncthreads();
// Main loop: process tiles along K dimension
for (int tile_k = 0; tile_k < K; tile_k += TILE_K) {
int next_buffer = 1 - buffer;
// Async load next tile (overlapped with compute)
if (tile_k + TILE_K < K) {
load_tile_async(smem_X[next_buffer], X, tile_k + TILE_K);
load_tile_async(smem_W[next_buffer], W_packed, tile_k + TILE_K);
load_tile_async(smem_scales[next_buffer], scales, tile_k + TILE_K);
}
// Dequantize INT4 -> FP16 in registers
half W_deq[TILE_N_PER_WARP][TILE_K];
dequantize_tile(smem_W[buffer], smem_scales[buffer], W_deq);
// Tensor core MMA: acc += X_tile @ W_deq^T
mma_tile(smem_X[buffer], W_deq, acc);
__syncthreads();
buffer = next_buffer;
}
// Epilogue: write accumulated results
store_output(Y, acc);
}
Marlin dequantizes INT4 to FP16 in registers rather than shared memory. This saves shared memory capacity (INT4 tiles are 4x smaller) and avoids the shared memory bank conflict overhead of writing dequantized FP16 values. The register-level dequantization adds ~2-3% overhead compared to loading pre-dequantized FP16 values.
Performance Model
The theoretical performance of a W4A16 kernel is bounded by:

t_W4A16 = max(t_mem, t_compute) × (1 + ε)

where t_mem = (weight bytes + scale bytes) / memory bandwidth, t_compute = 2MNK / FP16 tensor core throughput, and ε is the dequantization overhead.
def w4a16_performance_model(
M, N, K,
group_size=128,
mem_bw_tb_s=3.35, # H100 SXM
fp16_tflops=990, # H100 SXM tensor core FP16
deq_overhead_pct=0.03, # 3% dequantization overhead
):
"""Estimate W4A16 GEMM time vs FP16 GEMM time."""
# FP16 baseline
fp16_weight_bytes = N * K * 2
fp16_flops = 2 * M * N * K
fp16_mem_time = fp16_weight_bytes / (mem_bw_tb_s * 1e12)
fp16_compute_time = fp16_flops / (fp16_tflops * 1e12)
fp16_time = max(fp16_mem_time, fp16_compute_time)
# W4A16
w4_weight_bytes = N * K * 0.5 # 4 bits = 0.5 bytes
w4_scale_bytes = N * (K // group_size) * 2 # FP16 scales
w4_total_bytes = w4_weight_bytes + w4_scale_bytes
w4_flops = 2 * M * N * K # Same FLOPs (FP16 tensor cores)
w4_mem_time = w4_total_bytes / (mem_bw_tb_s * 1e12)
w4_compute_time = w4_flops / (fp16_tflops * 1e12)
w4_time = max(w4_mem_time, w4_compute_time) * (1 + deq_overhead_pct)
speedup = fp16_time / w4_time
regime = "bandwidth-bound" if w4_mem_time > w4_compute_time else "compute-bound"
return {
'fp16_time_us': fp16_time * 1e6,
'w4a16_time_us': w4_time * 1e6,
'speedup': speedup,
'regime': regime,
}
# Batch size sweep for Llama-2 7B attention projection (4096 x 4096)
print("H100 SXM: 4096x4096 GEMM")
for M in [1, 4, 16, 64, 256, 1024]:
result = w4a16_performance_model(M, 4096, 4096)
print(f" M={M:>4d}: FP16={result['fp16_time_us']:.1f}us, "
f"W4A16={result['w4a16_time_us']:.1f}us, "
f"speedup={result['speedup']:.2f}x [{result['regime']}]")
H100 SXM: 4096x4096 GEMM
M= 1: FP16=10.0us, W4A16=2.6us, speedup=3.82x [bandwidth-bound]
M= 4: FP16=10.0us, W4A16=2.6us, speedup=3.82x [bandwidth-bound]
M= 16: FP16=10.0us, W4A16=2.8us, speedup=3.58x [bandwidth-bound]
M= 64: FP16=10.0us, W4A16=3.9us, speedup=2.56x [bandwidth-bound]
M= 256: FP16=10.0us, W4A16=9.1us, speedup=1.10x [compute-bound]
M=1024: FP16=33.8us, W4A16=33.8us, speedup=1.00x [compute-bound]
(Chart: W4A16 speedup over FP16 vs batch size, H100, 4096x4096 GEMM.)

At batch size 256+, the GEMM becomes compute-bound and W4A16 offers no speedup over FP16: the kernel is doing the same FP16 tensor core operations regardless of weight format. This is why W4A16 is optimal for low-latency serving (small batches) and W8A8 or FP8 is preferred for high-throughput serving (large batches).
Marlin vs ExLlamaV2 vs cuBLAS FP16
Single-Token Decode Latency: Llama-2 7B (H100 SXM)
| Kernel | Weight Format | Latency (ms) | Throughput (tok/s) | vs FP16 |
|---|---|---|---|---|
| cuBLAS FP16 | FP16 | 3.91 | 256 | 1.0x |
| cuBLAS INT8 | W8A16 | 2.12 | 472 | 1.8x |
| ExLlamaV2 | GPTQ-INT4 g128 | 1.18 | 847 | 3.3x |
| Marlin | GPTQ-INT4 g128 | 1.02 | 980 | 3.8x |
| Marlin (g=channelwise) | INT4 per-channel | 0.98 | 1020 | 4.0x |
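As a sanity check on these numbers (a back-of-envelope sketch, assuming the ~6.48B weight-loaded parameters from the decode model earlier and ignoring scale-factor bytes), the measured latencies imply achieved bandwidth close to peak for Marlin:

```python
def achieved_bandwidth(weight_gb, latency_ms, peak_tb_s=3.35):
    """Effective TB/s (and fraction of peak) implied by a decode latency."""
    tb_s = (weight_gb * 1e9) / (latency_ms * 1e-3) / 1e12
    return tb_s, tb_s / peak_tb_s

weight_gb = 6.48 * 0.5  # ~6.48B params at 4 bits/weight, scales ignored
for kernel, lat_ms in [('cuBLAS FP16', 3.91), ('ExLlamaV2', 1.18),
                       ('Marlin', 1.02)]:
    gb = weight_gb * 4 if kernel == 'cuBLAS FP16' else weight_gb
    bw, frac = achieved_bandwidth(gb, lat_ms)
    print(f" {kernel:>12s}: {bw:.2f} TB/s ({frac:.0%} of peak)")
```

Marlin works out to roughly 95% of the H100's 3.35 TB/s, consistent with the 90-95% efficiency claim earlier; ExLlamaV2 lands near 82%.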
Prefill Throughput: Llama-2 7B, 2048 Tokens (H100 SXM)
| Kernel | Weight Format | Time (ms) | Throughput (tok/s) | vs FP16 |
|---|---|---|---|---|
| cuBLAS FP16 | FP16 | 42.1 | 48,600 | 1.0x |
| Marlin | GPTQ-INT4 g128 | 39.8 | 51,500 | 1.06x |
| cuBLAS INT8 TC | W8A8 INT8 | 23.2 | 88,300 | 1.82x |
Implementation: W4A16 Linear Layer
A complete W4A16 linear layer implementation for inference:
import torch
import torch.nn as nn
class W4A16Linear(nn.Module):
"""W4A16 quantized linear layer.
Stores weights as packed INT4, dequantizes to FP16 for GEMM.
"""
def __init__(self, in_features, out_features, group_size=128, bias=False):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
# Pack 8 INT4 values into one int32
assert in_features % 8 == 0, "in_features must be divisible by 8"
packed_k = in_features // 8
num_groups = in_features // group_size
self.register_buffer(
'qweight', torch.zeros(out_features, packed_k, dtype=torch.int32)
)
self.register_buffer(
'scales', torch.zeros(out_features, num_groups, dtype=torch.float16)
)
if bias:
self.register_buffer(
'bias', torch.zeros(out_features, dtype=torch.float16)
)
else:
self.bias = None
@staticmethod
def pack_weights(int4_weights):
"""Pack INT4 weights (range [-8, 7]) into int32.
int4_weights: shape (N, K), dtype int8, values in [-8, 7]
Returns: shape (N, K // 8), dtype int32
"""
N, K = int4_weights.shape
assert K % 8 == 0
# Shift to unsigned: [0, 15]
unsigned = (int4_weights.to(torch.int32) + 8) & 0xF
# Pack 8 values per int32
packed = torch.zeros(N, K // 8, dtype=torch.int32,
device=int4_weights.device)
for i in range(8):
packed |= unsigned[:, i::8] << (4 * i)
return packed
@staticmethod
def unpack_weights(packed):
"""Unpack int32 to INT4 weights.
packed: shape (N, K // 8), dtype int32
Returns: shape (N, K), dtype int8, values in [-8, 7]
"""
N, packed_k = packed.shape
K = packed_k * 8
unpacked = torch.zeros(N, K, dtype=torch.int8,
device=packed.device)
for i in range(8):
unpacked[:, i::8] = ((packed >> (4 * i)) & 0xF).to(torch.int8) - 8
return unpacked
def dequantize(self):
"""Dequantize packed INT4 weights to FP16."""
int4_weights = self.unpack_weights(self.qweight) # (N, K)
# Reshape for per-group dequantization
N = self.out_features
K = self.in_features
num_groups = K // self.group_size
w_grouped = int4_weights.reshape(N, num_groups, self.group_size)
scales = self.scales.unsqueeze(2) # (N, num_groups, 1)
w_deq = w_grouped.to(torch.float16) * scales
return w_deq.reshape(N, K)
def forward(self, x):
"""Forward pass: dequantize weights and compute GEMM.
In production, this dequantization happens inside a fused
CUDA kernel (Marlin). This Python version is for correctness
verification only.
"""
W_fp16 = self.dequantize() # (N, K)
output = x @ W_fp16.T # (*, K) @ (K, N) -> (*, N)
if self.bias is not None:
output = output + self.bias
return output
@classmethod
def from_float(cls, linear, group_size=128):
"""Quantize a float linear layer to W4A16."""
in_f = linear.in_features
out_f = linear.out_features
layer = cls(in_f, out_f, group_size, bias=linear.bias is not None)
W = linear.weight.data.float()
num_groups = in_f // group_size
# Per-group quantization
W_grouped = W.reshape(out_f, num_groups, group_size)
group_max = W_grouped.abs().amax(dim=2)
scales = group_max / 7.0 # INT4 symmetric: [-8, 7], qmax=7
scales = scales.clamp(min=1e-10)
W_q = (W_grouped / scales.unsqueeze(2)).round().clamp(-8, 7)
W_q = W_q.reshape(out_f, in_f).to(torch.int8)
layer.qweight.copy_(cls.pack_weights(W_q))
layer.scales.copy_(scales.to(torch.float16))
if linear.bias is not None:
layer.bias.copy_(linear.bias.data.to(torch.float16))
return layer
# Verify correctness
torch.manual_seed(42)
linear_fp = nn.Linear(4096, 4096, bias=False)
nn.init.normal_(linear_fp.weight, std=0.02)
w4_layer = W4A16Linear.from_float(linear_fp, group_size=128)
x = torch.randn(1, 32, 4096)
with torch.no_grad():
y_fp = linear_fp(x)
y_w4 = w4_layer(x.half()).float()
mse = ((y_fp - y_w4) ** 2).mean().item()
cos_sim = torch.nn.functional.cosine_similarity(
y_fp.flatten(), y_w4.flatten(), dim=0
).item()
print(f"Output MSE: {mse:.6e}")
print(f"Cosine similarity: {cos_sim:.8f}")
vLLM Integration
Marlin is integrated into vLLM as the default kernel for GPTQ and AWQ models. The integration involves:
# vLLM's Marlin integration (simplified)
# Location: vllm/model_executor/layers/quantization/gptq_marlin.py
class GPTQMarlinLinearMethod:
"""Marlin kernel for GPTQ-quantized models."""
def __init__(self, quant_config):
self.group_size = quant_config.group_size
self.bits = quant_config.bits # 4 or 8
def create_weights(self, layer, input_size, output_size, params_dtype):
"""Allocate quantized weight buffers in Marlin layout."""
pack_factor = 32 // self.bits # 8 for INT4
packed_input = input_size // pack_factor
# Marlin expects a specific memory layout
qweight = torch.zeros(
packed_input, output_size, # Note: transposed vs standard
dtype=torch.int32
)
scales = torch.zeros(
input_size // self.group_size, output_size,
dtype=params_dtype
)
layer.register_parameter('qweight', nn.Parameter(qweight, requires_grad=False))
layer.register_parameter('scales', nn.Parameter(scales, requires_grad=False))
def apply(self, layer, x, bias=None):
"""Run Marlin GEMM kernel."""
# In production, this calls the Marlin CUDA kernel:
# marlin.marlin_gemm(x, layer.qweight, layer.scales,
# layer.workspace, x.shape[0], ...)
# The kernel handles:
# 1. INT4 unpacking
# 2. Per-group dequantization using scales
# 3. FP16 tensor core GEMM
# 4. Output accumulation in FP32, cast to FP16
output = marlin_gemm(
x, layer.qweight, layer.scales,
layer.workspace
)
if bias is not None:
output = output + bias
return output
Model Loading Pipeline
# Loading a GPTQ model with the Marlin kernel in vLLM:
# 1. Load the GPTQ checkpoint (safetensors format)
# 2. Verify compatibility: group_size=128, bits=4, symmetric, no act_order
# 3. Repack weights from the GPTQ layout to the Marlin layout
# 4. Store the repacked weights on GPU
# Marlin has specific requirements:
MARLIN_REQUIREMENTS = {
'bits': [4], # Only 4-bit currently
'group_size': [128, -1], # 128 or channelwise
'symmetric': True, # No zero-point
'act_order': False, # GPTQ act_order breaks Marlin layout
'min_N': 64, # Minimum output dimension
'min_K': 128, # Minimum input dimension
}
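A small helper (hypothetical, not vLLM's actual API) can screen a quantization config against these constraints before attempting the Marlin path:

```python
def check_marlin_compatible(bits, group_size, symmetric, act_order, N, K):
    """Screen a quantization config against the Marlin constraints above.

    Returns (ok, reasons); `reasons` lists every violated requirement.
    """
    reasons = []
    if bits != 4:
        reasons.append(f"bits={bits} (only 4-bit supported)")
    if group_size not in (128, -1):
        reasons.append(f"group_size={group_size} (need 128 or -1 channelwise)")
    if not symmetric:
        reasons.append("asymmetric quantization (zero-points) unsupported")
    if act_order:
        reasons.append("GPTQ act_order reordering breaks the Marlin layout")
    if N < 64 or K < 128:
        reasons.append(f"layer too small: N={N} (min 64), K={K} (min 128)")
    return (not reasons), reasons

ok, why = check_marlin_compatible(4, 128, True, False, N=4096, K=4096)
print(ok, why)  # a standard symmetric GPTQ g128 layer passes
```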
When W4A16 is the Right Choice
# Decision matrix for weight format selection
def recommend_weight_format(
batch_size,
latency_slo_ms,
gpu_type,
model_size_B,
):
"""Recommend weight format based on deployment constraints."""
if batch_size <= 8:
# Decode-dominant: bandwidth-bound
if model_size_B <= 13:
return "W4A16 (GPTQ/AWQ + Marlin)"
else:
return "W4A16 (GPTQ/AWQ + Marlin), multi-GPU"
elif batch_size <= 64:
# Mixed regime
if gpu_type in ['H100', 'A100']:
return "W4A16 for decode, FP8/INT8 for prefill"
else:
return "W4A16 (Marlin)"
else:
# Throughput-dominant: compute-bound
if gpu_type == 'H100':
return "FP8 (W8A8 E4M3)"
elif gpu_type == 'A100':
return "W8A8 INT8 (SmoothQuant)"
else:
return "W4A16 (still BW-bound on older GPUs)"
Format Selection by Deployment Scenario
| Scenario | Batch Size | GPU | Best Format | Reason |
|---|---|---|---|---|
| Chat (1 user) | 1 | A100 | W4A16 Marlin | BW-bound, 3.8x faster |
| Chat (8 users) | 8 | H100 | W4A16 Marlin | Still BW-bound |
| Batch API | 128 | H100 | FP8 W8A8 | Compute-bound, 2x TC |
| Batch API | 128 | A100 | INT8 W8A8 | Compute-bound, 2x TC |
| Embedding | 256 | Any | FP16 or FP8 | Prefill-only, compute-bound |
| Edge (RTX 4090) | 1 | 4090 | W4A16 ExLlama | Consumer GPU, BW-bound |