The same model can run 2x faster just by letting the compiler choose better kernel parameters. This seems impossible — we’re not changing the algorithm, not quantizing, not fusing operations. We’re running the exact same GEMM operations on the same hardware. But a single matrix multiplication can be executed by hundreds of distinct CUDA kernel implementations, and the performance gap between the best and worst can exceed 3x. The GEMM that processes your FFN gate projection might take 0.65ms with the default cuBLAS heuristic or 0.21ms after TensorRT profiles every candidate and picks the optimal tile size, pipeline depth, and memory access pattern. Multiply that across 80 layers and thousands of requests per second, and autotuning becomes the difference between running at 70% utilization and 95% utilization on the same hardware you already paid for.
This post covers the full autotuning landscape for LLM inference: why so many kernel variants exist, how TensorRT exhaustively profiles them during its engine build phase, how torch.compile’s Inductor backend generates Triton kernels and selects optimal configurations, how cuBLAS uses heuristic tables instead of profiling, and when each approach is appropriate.
Why Hundreds of Kernels Exist for One Operation
A matrix multiplication C = A x B, where A is [M, K] and B is [K, N], is decomposed into tiles. Each thread block computes one tile of the output matrix C. The choices involved in this decomposition create a combinatorial explosion of kernel variants.
Tile Size Selection
The output matrix C is [M, N]. We partition it into tiles of size T_M x T_N. Each thread block computes one tile. The K dimension is iterated in chunks of size T_K.
Common tile sizes on Hopper (sm_90):
Tile (T_M x T_N x T_K):
256x128x64 — large tiles, high register pressure, fewer thread blocks
128x256x64 — same total work per tile, different aspect ratio
128x128x64 — balanced, good occupancy
64x256x64 — tall-and-skinny output tiles
256x64x64 — wide-and-short output tiles
64x128x64 — small tiles, low register pressure, many thread blocks
64x64x64 — smallest typical tile
Each tile size implies a different number of thread blocks. For M = 4096, N = 11008:
- Tile 256x128: 16 x 86 = 1376 blocks
- Tile 128x256: 32 x 43 = 1376 blocks
- Tile 64x64: 64 x 172 = 11008 blocks
The H100 has 132 SMs. With 1376 blocks, each SM processes ~10 blocks (good occupancy, moderate wave quantization). With 11008 blocks, each SM processes ~83 blocks (very high parallelism but smaller tiles mean lower compute intensity per block).
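These block counts are easy to sanity-check. The sketch below (pure Python; `launch_stats` is a throwaway helper, not a library call) reproduces the numbers above and shows how many waves each tile choice launches across the H100's 132 SMs:

```python
import math

def launch_stats(M, N, tile_m, tile_n, num_sms=132):
    """Thread-block count and wave count for one tile choice."""
    blocks = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    waves = blocks / num_sms
    return blocks, waves

# Llama 7B FFN gate projection at M = 4096 tokens, N = 11008
for tile in [(256, 128), (128, 256), (64, 64)]:
    blocks, waves = launch_stats(4096, 11008, *tile)
    print(f"tile {tile}: {blocks} blocks, {waves:.1f} waves")
```

Note that both large-tile choices launch the same 1376 blocks; they differ in per-block compute intensity and memory access pattern, not in total parallelism.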
Thread Block Shape and Warp Layout
Within a thread block, warps (groups of 32 threads) are arranged in a 2D grid. For a 256x128 tile with 256 threads (8 warps):
Warp layout options for 8 warps computing [256, 128]:
Layout A: 4 warps in M x 2 warps in N — each warp handles [64, 64]
Layout B: 2 warps in M x 4 warps in N — each warp handles [128, 32]
Layout C: 8 warps in M x 1 warp in N — each warp handles [32, 128]
Layout D: 1 warp in M x 8 warps in N — each warp handles [256, 16]
Each layout has different implications for shared memory bank conflicts, register usage, and instruction-level parallelism. Layout A is balanced; Layout D minimizes shared memory reads of B (each B fragment is read by exactly one warp) but multiplies reads of A (every warp reads the full A tile).
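The tradeoff can be made concrete by counting operand re-reads. In this sketch (helper name mine), warps that share the same row-slice re-read the same A fragment from shared memory, so A-read redundancy equals the number of warps along N, and symmetrically for B:

```python
def warp_layout_stats(tile_m, tile_n, warps_m, warps_n):
    """Per-warp output tile and shared-memory read redundancy
    for a warps_m x warps_n warp grid over a tile_m x tile_n tile."""
    return {
        "warp_tile": (tile_m // warps_m, tile_n // warps_n),
        "a_read_redundancy": warps_n,  # warps along N re-read A
        "b_read_redundancy": warps_m,  # warps along M re-read B
    }

# The four layouts from the text, for a 256x128 tile with 8 warps
for name, (wm, wn) in {"A": (4, 2), "B": (2, 4), "C": (8, 1), "D": (1, 8)}.items():
    print(name, warp_layout_stats(256, 128, wm, wn))
```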
Pipeline Stages
Modern GEMM kernels overlap data loading with computation using software pipelining. The number of pipeline stages determines how much data is “in flight” at any time:
Stages=2: load tile k+1 while computing tile k
Minimum latency hiding, minimum shared memory usage
Shared memory: 2 * (T_M * T_K + T_K * T_N) * dtype_size
Stages=3: load tile k+2 while computing tile k, tile k+1 in buffer
Better latency hiding, 50% more shared memory
Shared memory: 3 * (T_M * T_K + T_K * T_N) * dtype_size
Stages=4: load tile k+3 while computing tile k
Best latency hiding for high-latency memory
Shared memory: 4 * (T_M * T_K + T_K * T_N) * dtype_size
For tile 128x128x64 in FP16 (2 bytes):
- 2 stages: 64 KB shared memory
- 3 stages: 96 KB shared memory
- 4 stages: 128 KB shared memory
The H100 has 228 KB of shared memory per SM. At 128 KB (4 stages), only 1 thread block can run per SM. At 64 KB (2 stages), 3 thread blocks can run per SM. More concurrent blocks means better latency hiding through warp scheduling, but fewer stages means worse pipelining within each block.
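The occupancy arithmetic above is mechanical; a minimal sketch (helper names mine, shared-memory limits from the text):

```python
def smem_per_block(tile_m, tile_n, tile_k, stages, dtype_size=2):
    """Shared memory (bytes) for a software-pipelined GEMM tile."""
    return stages * (tile_m * tile_k + tile_k * tile_n) * dtype_size

def blocks_per_sm(smem_bytes, smem_per_sm=228 * 1024):
    """Concurrent thread blocks per SM, limited only by shared memory."""
    return smem_per_sm // smem_bytes

# Tile 128x128x64 in FP16, as in the text
for stages in (2, 3, 4):
    smem = smem_per_block(128, 128, 64, stages)
    print(f"{stages} stages: {smem // 1024} KB -> {blocks_per_sm(smem)} blocks/SM")
```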
For a single GEMM shape, the tuning space includes: 7+ tile sizes, 4+ warp layouts per tile, 3+ stage counts, 2+ swizzle patterns, 2+ epilogue variants (with/without bias fusion). That gives at least 7 x 4 x 3 x 2 x 2 = 336 candidate kernels. In practice, CUTLASS enumerates 200-500 valid configurations per GEMM shape, and TensorRT’s internal library contains thousands.
Memory Access Patterns
How thread blocks traverse the output matrix also affects L2 cache behavior:
Linear order: block(0,0), block(0,1), block(0,2), ..., block(1,0), ...
— Accesses column tiles of B sequentially, poor B reuse in L2
Swizzled order: block(0,0), block(1,0), block(0,1), block(1,1), ...
— 2x2 block clusters, better L2 reuse for both A and B
Grouped order: block(0,0), block(0,1), ..., block(0,G), block(1,0), ...
— Group G columns, maximizes A-tile reuse within group
The L2 cache on H100 is 50 MB. For a Llama 70B FFN weight matrix of shape [8192, 57344] (the fused gate and up projections) in FP16, the total weight size is 896 MB — far exceeding L2. But during a single “wave” of thread blocks (132 blocks for 132 SMs), each block reads a different row-strip of A and a different column-strip of B. Swizzled scheduling ensures that adjacent SMs read adjacent strips, improving L2 hit rates.
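The grouped traversal is typically implemented as a remapping from the linear block index to (pid_m, pid_n) coordinates, as in Triton-style GEMM templates. A pure-Python sketch of that remapping (G = group_m rows per group):

```python
def grouped_order(grid_m, grid_n, group_m):
    """Remap linear block indices into grouped (pid_m, pid_n) order
    so that group_m consecutive row-tiles are scheduled together."""
    order = []
    pids_per_group = group_m * grid_n
    for pid in range(grid_m * grid_n):
        group = pid // pids_per_group
        first_m = group * group_m
        size_m = min(grid_m - first_m, group_m)  # last group may be short
        pid_m = first_m + (pid % size_m)
        pid_n = (pid % pids_per_group) // size_m
        order.append((pid_m, pid_n))
    return order

print(grouped_order(4, 4, 2)[:4])  # [(0, 0), (1, 0), (0, 1), (1, 1)]
```

With group_m = 2 on a 4x4 grid, the first blocks come out in exactly the 2x2-clustered "swizzled order" shown above, so neighboring SMs load overlapping strips of A and B.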
TensorRT: Exhaustive Profiling During Build
TensorRT’s approach to kernel selection is straightforward: try every candidate kernel, time each one, pick the fastest.
The Build Phase
When you build a TensorRT engine, the builder:
- Parses the model graph (ONNX or TensorRT network definition)
- Applies graph optimizations (layer fusion, constant folding)
- For each resulting operation, enumerates all candidate kernels
- Profiles each candidate kernel on the target GPU
- Selects the fastest kernel for each operation
- Serializes the engine with the selected kernels
import tensorrt as trt
logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)
config = builder.create_builder_config()
# Control autotuning behavior
config.set_flag(trt.BuilderFlag.FP16)
# Timing iterations: more iterations = more accurate selection
# but longer build time
config.avg_timing_iterations = 8 # default is 8
# Builder optimization level: 0-5
# Level 3 (default): standard autotuning
# Level 5: exhaustive search including less common kernels
config.builder_optimization_level = 5
network = builder.create_network(
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
# Parse ONNX model
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
parser.parse(f.read())
# Build engine — this is where autotuning happens
# For a 7B model: 5-30 minutes
# For a 70B model: 30-120 minutes
engine = builder.build_serialized_network(network, config)
# Save for deployment — no need to rebuild on same GPU
with open("model.engine", "wb") as f:
f.write(engine)
What Gets Profiled
During the build, TensorRT’s verbose log reveals the profiling process. For a single linear layer:
[TensorRT] VERBOSE: Tactic: 0x0000000000000001
Inputs: {FP16[1,4096]} Outputs: {FP16[1,11008]}
Kernel: sm90_xmma_gemm_f16_f16_f16_f32_256x128x64_3stage_nn
Time: 0.0234ms
[TensorRT] VERBOSE: Tactic: 0x0000000000000002
Inputs: {FP16[1,4096]} Outputs: {FP16[1,11008]}
Kernel: sm90_xmma_gemm_f16_f16_f16_f32_128x256x64_3stage_nn
Time: 0.0241ms
[TensorRT] VERBOSE: Tactic: 0x0000000000000003
Inputs: {FP16[1,4096]} Outputs: {FP16[1,11008]}
Kernel: sm90_xmma_gemm_f16_f16_f16_f32_128x128x64_4stage_nn
Time: 0.0219ms <-- fastest
... (200+ more tactics)
[TensorRT] VERBOSE: Selected tactic 0x0000000000000003
Each “tactic” is a complete kernel configuration. TensorRT runs each one multiple times (controlled by avg_timing_iterations), discards outliers, and selects the one with the lowest median time.
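The selection loop itself is simple. Here is a pure-Python sketch of median-of-N tactic selection; `pick_tactic` and the hard-coded timings are mine, standing in for real kernel launches, and `iterations` plays the role of avg_timing_iterations:

```python
import statistics

def pick_tactic(tactics, time_tactic, iterations=8):
    """Return (median_time, tactic) for the tactic with the lowest
    median over `iterations` timed runs."""
    timed = []
    for tactic in tactics:
        samples = [time_tactic(tactic) for _ in range(iterations)]
        timed.append((statistics.median(samples), tactic))
    return min(timed)

# Fake timer standing in for real kernel launches (ms)
fake_times = {"256x128_3stage": 0.0234, "128x256_3stage": 0.0241,
              "128x128_4stage": 0.0219}
best_time, best = pick_tactic(fake_times, lambda t: fake_times[t])
print(best, best_time)  # 128x128_4stage 0.0219
```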
TensorRT Timing Cache
Because profiling takes minutes to hours, TensorRT supports a timing cache that stores profiling results across builds:
# Save timing cache after build
timing_cache = config.get_timing_cache()
with open("timing_cache.bin", "wb") as f:
f.write(timing_cache.serialize())
# Load timing cache for next build (same GPU only)
with open("timing_cache.bin", "rb") as f:
config.set_timing_cache(
config.create_timing_cache(f.read()),
ignore_mismatch=False
)
The timing cache is GPU-specific. A cache built on an H100 SXM is invalid on an H100 PCIe (different memory bandwidth, different clock speeds). Even two H100 SXMs can have slightly different optimal kernels due to silicon variation, though in practice the differences are negligible.
Dynamic Shapes and Autotuning
LLM inference involves dynamic shapes: the batch size and sequence length change every iteration. TensorRT handles this with optimization profiles:
profile = builder.create_optimization_profile()
profile.set_shape(
"input_ids",
min=(1, 1), # minimum shape
opt=(32, 512), # optimal shape (autotuned for this)
max=(64, 2048) # maximum shape
)
config.add_optimization_profile(profile)
TensorRT autotunes primarily for the opt shape. For shapes far from opt, it falls back to a heuristic that interpolates between profiled results. This is why TensorRT users often create multiple engines for different batch size ranges.
TensorRT Build Time vs Optimization Level
| Optimization Level | Build Time (Llama 7B) | Build Time (Llama 70B) | Inference Speedup vs Level 3 |
|---|---|---|---|
| Level 0 (no autotuning) | 15 seconds | 2 minutes | 0.82x (slower) |
| Level 3 (default) | 8 minutes | 45 minutes | 1.00x (baseline) |
| Level 4 (extended) | 18 minutes | 90 minutes | 1.03x |
| Level 5 (exhaustive) | 35 minutes | 150 minutes | 1.05x |
The diminishing returns beyond level 3 explain why most production deployments use the default. The additional 2-5% from levels 4-5 is real but rarely justifies 3-5x longer build times, especially during iterative development.
torch.compile: Inductor and Triton Kernel Generation
torch.compile takes a fundamentally different approach. Instead of choosing from a pre-existing library of handwritten CUDA kernels, the Inductor backend generates Triton kernels on the fly and benchmarks them.
The Compilation Pipeline
import torch
model = load_llm("llama-7b")
# Default mode: moderate autotuning
compiled_model = torch.compile(model, mode="default")
# Reduce-overhead mode: CUDA graph integration
compiled_model = torch.compile(model, mode="reduce-overhead")
# Max-autotune mode: exhaustive kernel search
compiled_model = torch.compile(model, mode="max-autotune")
When you call torch.compile, the following stages execute on the first forward pass:
Stage 1: Dynamo (Python → FX graph)
— Traces Python code into an intermediate representation
— Time: 2-10 seconds for a 7B model
Stage 2: AOTAutograd (FX graph → Aten IR)
— Decomposes high-level ops into primitive operations
— Generates backward graph if needed (not for inference)
— Time: 1-5 seconds
Stage 3: Inductor (Aten IR → Triton/C++ kernels)
— Generates Triton kernel code for each fused subgraph
— Applies operator fusion (element-wise chains, reductions)
— Time: 5-30 seconds
Stage 4: Autotuning (Triton kernels → optimal configs)
— Benchmarks multiple configurations for each generated kernel
— Time: 20-300 seconds (the dominant cost with max-autotune)
Stage 5: Code caching
— Serializes generated code and selected configs to disk
— Subsequent runs skip stages 1-4
Inductor Kernel Generation
Inductor does not call cuBLAS for every GEMM. For standalone GEMMs, it delegates to cuBLAS or a Triton GEMM template. For fused operations (GEMM + bias + activation), it generates custom Triton kernels:
# What the user writes:
def fused_ffn(x, w_gate, w_up, w_down):
gate = torch.mm(x, w_gate)
up = torch.mm(x, w_up)
hidden = torch.nn.functional.silu(gate) * up
return torch.mm(hidden, w_down)
# What Inductor generates (simplified Triton for the silu*up fusion):
"""
@triton.jit
def fused_silu_mul_kernel(
gate_ptr, up_ptr, out_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
gate = tl.load(gate_ptr + offsets, mask=mask)
up = tl.load(up_ptr + offsets, mask=mask)
# SiLU = x * sigmoid(x)
silu_gate = gate * tl.sigmoid(gate)
result = silu_gate * up
tl.store(out_ptr + offsets, result, mask=mask)
"""
The GEMMs themselves go through a different path. Inductor’s GEMM handling:
# Inductor's decision tree for GEMM (simplified):
def select_gemm_backend(M, N, K, dtype):
if dtype in (torch.float16, torch.bfloat16):
if M * N * K > 1_000_000: # large enough for cuBLAS
return "cublas"
else:
return "triton_gemm_template"
elif dtype == torch.float8_e4m3fn:
return "cublas_lt" # cuBLAS Light for FP8
else:
return "triton_gemm_template"
For large GEMMs (the FFN projections), Inductor typically delegates to cuBLAS. For smaller GEMMs or fused GEMM+activation patterns, it generates Triton kernels.
Triton Autotuning
When Inductor generates a Triton kernel, it creates multiple configurations and benchmarks them:
import triton
import triton.language as tl
# Inductor generates this autotune decorator
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "num_stages": 3, "num_warps": 8},
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "num_stages": 4, "num_warps": 4},
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "num_stages": 4, "num_warps": 4},
),
triton.Config(
{"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "num_stages": 4, "num_warps": 4},
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "num_stages": 4, "num_warps": 4},
),
triton.Config(
{"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32, "num_stages": 5, "num_warps": 2},
),
],
key=["M", "N", "K"], # re-autotune when these change
)
@triton.jit
def matmul_kernel(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
a = tl.load(a_ptrs, mask=offs_k[None, :] + k < K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] + k < K, other=0.0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float16)
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
tl.store(c_ptrs, c, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
The @triton.autotune decorator benchmarks each config by running the kernel multiple times and measuring wall-clock time. The winning config is cached per key tuple, so the benchmark cost is paid once per distinct shape.
max-autotune Mode
With mode="max-autotune", Inductor expands the search space:
# Default mode: ~6 Triton configs per kernel, cuBLAS for large GEMMs
# max-autotune mode: ~20 Triton configs per kernel + cuBLAS comparison
# max-autotune also enables:
# CUDA graph capture (eliminates kernel launch overhead)
# cuBLAS vs Triton comparison for every GEMM
# Epilogue fusion search (try fusing bias/activation into GEMM)
The max-autotune compilation for a 7B model takes 2-5 minutes. For a 70B model (which typically uses tensor parallelism, so each rank compiles its shard), 3-8 minutes per rank.
torch.compile Compilation Time by Mode
| Mode | Compile Time (7B) | Compile Time (70B/rank) | Speedup vs Eager |
|---|---|---|---|
| eager (no compile) | 0 seconds | 0 seconds | 1.00x |
| default | 30-60 seconds | 45-90 seconds | 1.05-1.15x |
| reduce-overhead | 40-80 seconds | 60-120 seconds | 1.10-1.25x |
| max-autotune | 120-300 seconds | 180-480 seconds | 1.15-1.30x |
You can combine modes: torch.compile(model, mode="max-autotune-no-cudagraphs") followed by manual CUDA graph capture, or just mode="max-autotune" which includes CUDA graphs. The combination delivers both optimal kernels and eliminated launch overhead. For LLM inference where the same shapes repeat every decode step, this combination is highly effective.
cuBLAS Heuristics: Fast Selection Without Profiling
cuBLAS takes a third approach: no profiling at all. It uses precomputed heuristic tables that map (M, N, K, dtype, GPU architecture) to a kernel selection.
How cuBLAS Selects Kernels
# Pseudocode for cuBLAS kernel selection:
def cublas_select_gemm(M, N, K, dtype, gpu_arch):
# Step 1: Filter by compatibility
candidates = get_kernels_for_arch(gpu_arch, dtype)
# Step 2: Heuristic scoring based on shape
for kernel in candidates:
tile_m, tile_n, tile_k = kernel.tile_size
# Wave quantization: how many "waves" of thread blocks?
blocks_m = ceil(M / tile_m)
blocks_n = ceil(N / tile_n)
total_blocks = blocks_m * blocks_n
num_sms = get_sm_count(gpu_arch) # 132 for H100
num_waves = ceil(total_blocks / num_sms)
# Tail effect: last wave may have poor SM utilization
tail_utilization = (total_blocks % num_sms) / num_sms
if total_blocks % num_sms == 0:
tail_utilization = 1.0
# Score based on tile efficiency and wave utilization
kernel.score = compute_heuristic_score(
tile_efficiency=tile_m * tile_n * tile_k,
wave_utilization=tail_utilization,
num_waves=num_waves,
shared_memory=kernel.smem_usage,
)
# Step 3: Return highest-scoring kernel
return max(candidates, key=lambda k: k.score)
The heuristic is tuned by NVIDIA engineers using profiling data from representative shapes. It works well for common shapes (powers of 2, standard transformer dimensions) but can be suboptimal for unusual shapes.
Where Heuristics Fail
cuBLAS heuristics have known weaknesses:
1. Non-power-of-2 dimensions. When N is not a multiple of common tile sizes, wave quantization causes SM underutilization. For example, with N = 11008 (Llama 7B FFN) and tile size 256 along N:
With M = 1 (single-token decode): ceil(11008 / 256) = 43 total blocks on 132 SMs. Only 43/132 = 33% SM utilization. A tile size of 128 gives 86 blocks (65% utilization). Tile size 64 gives 172 blocks (full wave + partial second wave). The heuristic must weigh these tradeoffs without actually measuring.
2. Small M (decode case). When M = 1, the GEMM is really a matrix-vector multiply. cuBLAS may select a GEMM kernel instead of a specialized GEMV kernel. The specialized kernel is typically 10-20% faster because it avoids tile setup overhead.
3. Grouped/batched GEMM. For Mixture-of-Experts models, each expert processes a different number of tokens. cuBLAS grouped GEMM heuristics are less mature than single-GEMM heuristics.
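The wave-quantization numbers in case 1 can be checked directly. A small sketch (helper name mine) computing last-wave SM utilization for the decode-shape GEMM against N = 11008:

```python
import math

def tail_utilization(M, N, tile_m, tile_n, num_sms=132):
    """Block count and fraction of SMs busy during the last wave."""
    blocks = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    tail = blocks % num_sms
    return blocks, 1.0 if tail == 0 else tail / num_sms

# Single-token decode against the Llama 7B FFN: M = 1, N = 11008
for tile_n in (256, 128, 64):
    blocks, util = tail_utilization(1, 11008, 1, tile_n)
    print(f"tile_n={tile_n}: {blocks} blocks, last-wave util {util:.0%}")
```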
import torch
# Demonstrate cuBLAS suboptimality on non-standard shapes
# Standard shape: cuBLAS heuristic is well-tuned
M, N, K = 1, 4096, 4096
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(K, N, dtype=torch.float16, device="cuda")
# Warm up
for _ in range(10):
torch.mm(A, B)
torch.cuda.synchronize()
import time
start = time.perf_counter()
for _ in range(1000):
torch.mm(A, B)
torch.cuda.synchronize()
standard_time = (time.perf_counter() - start) / 1000
# Non-standard shape: cuBLAS heuristic may be suboptimal
M, N, K = 1, 11008, 4096
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(K, N, dtype=torch.float16, device="cuda")
for _ in range(10):
torch.mm(A, B)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(1000):
torch.mm(A, B)
torch.cuda.synchronize()
nonstandard_time = (time.perf_counter() - start) / 1000
# The non-standard shape may show lower TFLOPS utilization
print(f"Standard shape TFLOPS: {2*1*4096*4096/standard_time/1e12:.2f}")
print(f"Non-standard shape TFLOPS: {2*1*11008*4096/nonstandard_time/1e12:.2f}")
cuBLAS vs Autotuned Kernel: TFLOPS by Shape (H100 SXM, FP16)
The gap between the cuBLAS heuristic and autotuned selection is largest for small M with non-power-of-2 N. At large batch sizes, where the GEMM is compute-bound, both approaches approach peak TFLOPS and the gap narrows.
When Autotuning Matters: Quantifying the Improvement
Autotuning improvement varies by workload. Here is a systematic breakdown.
Shape-Dependent Improvement
The improvement from autotuning correlates with how well the GEMM shape fits common tile sizes:
Autotuning Improvement Over cuBLAS Heuristic by Shape Category
| Shape Category | Example | cuBLAS TFLOPS | Autotuned TFLOPS | Improvement |
|---|---|---|---|---|
| Power-of-2, large M | [512, 4096] x [4096, 4096] | 850 | 870 | 2.4% |
| Power-of-2, small M | [1, 4096] x [4096, 4096] | 42 | 45 | 7.1% |
| Non-PoT N, large M | [512, 4096] x [4096, 11008] | 820 | 860 | 4.9% |
| Non-PoT N, small M | [1, 4096] x [4096, 11008] | 38 | 48 | 26.3% |
| GQA shapes | [1, 8192] x [8192, 1024] | 18 | 24 | 33.3% |
| MoE expert | [17, 4096] x [4096, 11008] | 195 | 240 | 23.1% |
The pattern is clear: autotuning matters most when the GEMM shape is “weird” from the perspective of standard tile sizes. Single-token decode (M = 1), GQA key-value projections (small N), and MoE expert routing (irregular M) all benefit substantially.
End-to-End Throughput Impact
The per-kernel improvement translates to end-to-end throughput improvement, attenuated by Amdahl’s law:
End-to-end speedup = 1 / (1 - fraction_GEMM + fraction_GEMM / kernel_speedup)
For Llama 70B decode (GEMM is ~85% of time):
5% kernel speedup -> 1 / (0.15 + 0.85/1.05) = 1.042 -> 4.2% end-to-end
15% kernel speedup -> 1 / (0.15 + 0.85/1.15) = 1.125 -> 12.5% end-to-end
30% kernel speedup -> 1 / (0.15 + 0.85/1.30) = 1.244 -> 24.4% end-to-end
In practice, autotuning delivers 5-15% end-to-end throughput improvement for standard transformer models. For MoE models with irregular expert routing, the improvement can reach 20-30%.
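The Amdahl arithmetic above, wrapped in one reusable function (a sketch; the 0.85 GEMM fraction is the text's Llama 70B decode estimate):

```python
def end_to_end_speedup(gemm_fraction, kernel_speedup):
    """Amdahl's law: only the GEMM fraction of runtime is accelerated."""
    return 1.0 / ((1.0 - gemm_fraction) + gemm_fraction / kernel_speedup)

for s in (1.05, 1.15, 1.30):
    print(f"{s:.2f}x kernels -> {end_to_end_speedup(0.85, s):.3f}x end-to-end")
```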
Implementation: torch.compile with max-autotune
Here is a complete implementation of autotuned LLM inference using torch.compile:
import torch
import os
import time
# Enable Inductor's GEMM autotuning
# This tells Inductor to benchmark cuBLAS against Triton for each GEMM
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE"] = "1"
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS"] = "TRITON,ATen"
# Cache directory for compiled kernels
# First compile takes 2-5 minutes; subsequent runs use cache
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/inductor_cache"
# Optional: increase Triton autotune configs
os.environ["TRITON_MAX_AUTOTUNE_CONFIGS"] = "20"
class AutotunedInferenceEngine:
"""LLM inference engine with torch.compile autotuning."""
def __init__(self, model_path, max_batch_size=64, max_seq_len=4096):
self.device = torch.device("cuda")
# Load model
self.model = self._load_model(model_path)
self.model.eval()
self.model.to(self.device)
# Compile with max-autotune
self.compiled_model = None
self.compile_time = None
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
def _load_model(self, path):
"""Load model weights. Implementation depends on model format."""
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
path,
torch_dtype=torch.float16,
device_map="auto",
)
return model
def compile(self):
"""Compile model with max-autotune. Call once before inference."""
start = time.perf_counter()
self.compiled_model = torch.compile(
self.model,
mode="max-autotune",
fullgraph=False,
dynamic=True, # support variable batch sizes
)
# Warm up: trigger compilation with representative shapes
# Inductor compiles lazily on first call
warmup_shapes = [
(1, 1), # single token decode
(1, 128), # short prefill
(1, 512), # medium prefill
(32, 1), # batched decode
]
for batch_size, seq_len in warmup_shapes:
dummy_input = torch.randint(
0, 32000,
(batch_size, seq_len),
device=self.device,
)
with torch.no_grad():
self.compiled_model(dummy_input)
torch.cuda.synchronize()
self.compile_time = time.perf_counter() - start
print(f"Compilation complete in {self.compile_time:.1f}s")
@torch.inference_mode()
def generate(self, input_ids, max_new_tokens=128):
"""Generate tokens using compiled model."""
if self.compiled_model is None:
raise RuntimeError("Call compile() before generate()")
batch_size = input_ids.shape[0]
generated = input_ids.clone()
# KV cache is managed internally by HuggingFace
past_key_values = None
for step in range(max_new_tokens):
if past_key_values is None:
# Prefill: process all input tokens
outputs = self.compiled_model(
input_ids=generated,
use_cache=True,
)
else:
# Decode: process only the last token
outputs = self.compiled_model(
input_ids=generated[:, -1:],
past_key_values=past_key_values,
use_cache=True,
)
past_key_values = outputs.past_key_values
logits = outputs.logits[:, -1, :]
# Greedy sampling
next_token = logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=-1)
# Check for EOS
if (next_token == self.model.config.eos_token_id).all():
break
return generated
def benchmark_autotuning(model_path):
"""Compare eager vs compiled inference throughput."""
engine = AutotunedInferenceEngine(model_path)
# Benchmark eager mode
input_ids = torch.randint(
0, 32000, (1, 128), device="cuda"
)
# Eager baseline
engine.model.eval()
with torch.inference_mode():
# Warm up
for _ in range(3):
engine.model(input_ids)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(50):
engine.model(input_ids)
torch.cuda.synchronize()
eager_time = (time.perf_counter() - start) / 50
# Compiled with max-autotune
engine.compile()
with torch.inference_mode():
# Warm up (compilation already done)
for _ in range(3):
engine.compiled_model(input_ids)
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(50):
engine.compiled_model(input_ids)
torch.cuda.synchronize()
compiled_time = (time.perf_counter() - start) / 50
speedup = eager_time / compiled_time
print(f"Eager: {eager_time*1000:.2f} ms/forward")
print(f"Compiled: {compiled_time*1000:.2f} ms/forward")
print(f"Speedup: {speedup:.2f}x")
return eager_time, compiled_time
Inspecting Selected Kernels
After compilation, you can inspect which kernels Inductor selected:
# Enable Inductor debug logging
import torch._inductor.config as inductor_config
inductor_config.debug = True
inductor_config.trace.enabled = True
inductor_config.trace.log_dir = "/tmp/inductor_traces"
# After compilation, check the trace directory:
# /tmp/inductor_traces/
# model__0_forward/
# output_code.py <-- generated Triton/C++ code
# fx_graph_readable.py <-- the FX graph before codegen
# ir_post_fusion.txt <-- IR after operator fusion
The output_code.py file contains the actual generated Triton kernels with the selected configurations:
# Example excerpt from output_code.py (auto-generated by Inductor):
# Selected config for fused_silu_mul:
# BLOCK_SIZE=1024, num_warps=4, num_stages=3
# Autotuning tried 6 configs, selected in 0.34s
# Selected config for mm (4096x11008):
# Backend: cuBLAS (faster than Triton by 8%)
# Algorithm: CUBLAS_GEMM_DEFAULT_TENSOR_OP
# Selected config for mm (4096x4096):
# Backend: Triton
# BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, num_stages=4, num_warps=4
# Autotuning tried 12 configs, selected in 2.1s
Handling Dynamic Shapes
LLM inference has dynamic shapes (varying batch sizes, sequence lengths). torch.compile handles this through dynamic shape support:
# Option 1: Dynamic shapes (recompiles when shape changes dramatically)
compiled = torch.compile(model, dynamic=True)
# Option 2: Explicit dynamic dimensions
from torch._dynamo import mark_dynamic
input_ids = torch.randint(0, 32000, (1, 128), device="cuda")
mark_dynamic(input_ids, 0) # batch dimension is dynamic
mark_dynamic(input_ids, 1) # sequence dimension is dynamic
compiled(input_ids)
# Option 3: Multiple compilations for different shape ranges
# (manual, but avoids recompilation overhead)
compiled_decode = torch.compile(model, mode="max-autotune")
compiled_prefill = torch.compile(model, mode="max-autotune")
# Warm up each for its target shape range
with torch.no_grad():
compiled_decode(torch.randint(0, 32000, (32, 1), device="cuda"))
compiled_prefill(torch.randint(0, 32000, (1, 512), device="cuda"))
When dynamic=True, torch.compile generates shape-generic code that works for any shape. This is slightly slower than shape-specific code because the kernel cannot specialize on exact dimensions. When dynamic=False (default), a new shape triggers recompilation. For LLM serving with variable batch sizes, dynamic=True is usually the better choice — the 2-5% loss from shape-generic kernels is cheaper than repeated recompilation.
Autotuning for Quantized Kernels
Quantized inference (INT4, INT8, FP8) introduces additional autotuning dimensions because the dequantization can be fused with the GEMM or done separately.
W4A16 (INT4 weights, FP16 activations)
# INT4 GEMM autotuning dimensions:
# Dequant location: in registers vs shared memory
# Group size: per-channel vs group-128 vs group-32
# Tile size: same as FP16 but with different optimal points
# Inner loop: dequant-then-accumulate vs packed-accumulate
# Example: Marlin kernel (optimized INT4 GEMM) configuration space
marlin_configs = [
# (thread_m, thread_n, thread_k, stages, group_size)
(16, 256, 64, 4, 128), # wide output tiles
(16, 128, 64, 4, 128), # balanced
(16, 64, 128, 3, 128), # deep K tiles for better dequant amortization
(16, 256, 64, 4, -1), # per-channel quantization (no groups)
]
For W4A16 GEMMs, the autotuning improvement over heuristic selection is typically 10-25%, larger than for FP16 GEMMs. The reason: the dequantization overhead makes tile size selection more shape-sensitive. A tile that is optimal for FP16 may be suboptimal for INT4 because the dequantization cost changes the compute-to-memory ratio.
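The shifted compute-to-memory ratio is easy to quantify with a roofline-style estimate. This sketch (helper name mine; it ignores group-scale traffic, which adds a few percent) compares FLOPs per byte of DRAM traffic for FP16 vs INT4 weights on the decode-shape FFN GEMM:

```python
def arithmetic_intensity(M, N, K, weight_bits, act_bytes=2):
    """FLOPs per byte of DRAM traffic for one GEMM, counting weights
    at `weight_bits` per element and activations at FP16."""
    flops = 2 * M * N * K
    traffic = N * K * weight_bits / 8 + (M * K + M * N) * act_bytes
    return flops / traffic

# Decode-shape GEMV against a Llama 7B FFN weight: M=1, N=11008, K=4096
fp16 = arithmetic_intensity(1, 11008, 4096, weight_bits=16)
int4 = arithmetic_intensity(1, 11008, 4096, weight_bits=4)
print(f"FP16: {fp16:.2f} FLOPs/byte, INT4: {int4:.2f} FLOPs/byte")
```

At M = 1 the weights dominate traffic, so INT4 roughly quadruples arithmetic intensity — which is exactly why the optimal tile shape shifts relative to FP16.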
FP8 Autotuning
FP8 GEMMs have the additional dimension of scaling granularity:
# FP8 scaling options that affect kernel selection:
# Per-tensor scaling: one scale factor per matrix
# — Simplest, fastest kernel, lowest accuracy
# Per-channel scaling: one scale factor per output channel
# — Requires modified epilogue to apply channel-wise scales
# Per-block scaling: one scale factor per tile
# — Most accurate, most complex kernel (block-scaled formats)
# TensorRT selects among FP8 tactics:
# sm90_xmma_gemm_e4m3_e4m3_f32_f32_128x128x64_3stage
# sm90_xmma_gemm_e4m3_e4m3_f32_f32_256x128x64_3stage
# ... plus per-channel and per-block variants
Production Deployment Patterns
TensorRT-LLM Autotuning
TensorRT-LLM combines TensorRT’s engine building with LLM-specific optimizations:
# TensorRT-LLM build command with autotuning
# trtllm-build --model_dir ./llama-7b-hf \
# --dtype float16 \
# --max_batch_size 64 \
# --max_input_len 2048 \
# --max_seq_len 4096 \
# --gemm_plugin float16 \
# --builder_opt_level 4
# The --gemm_plugin flag enables TensorRT-LLM's custom
# GEMM plugin which has its own autotuning:
# Profiles cuBLAS tactics
# Profiles CUTLASS tactics
# Profiles custom fused GEMM+dequant tactics (for quantized models)
# Selects best per shape
Persistent Autotuning Results
Both TensorRT and torch.compile support caching autotuning results:
# TensorRT: engine file IS the cached result
# Just deploy the .engine file, no re-autotuning needed
# torch.compile / Inductor: cache directory (set before importing torch)
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/persistent/storage/inductor_cache"
# First run: compiles and autotunes (slow)
# Subsequent runs: loads from cache (fast)
# Triton: cache directory
os.environ["TRITON_CACHE_DIR"] = "/persistent/storage/triton_cache"
Autotuning Strategy Comparison
| Strategy | Setup Time | Per-Shape Overhead | Kernel Quality | Dynamic Shape Support |
|---|---|---|---|---|
| TensorRT (build) | 5-150 min | 0 (offline) | Best | Limited (opt profiles) |
| torch.compile max-autotune | 2-8 min | 0.1-2s (lazy) | Very Good | Good (dynamic=True) |
| torch.compile default | 30-90s | 0 (lazy) | Good | Good |
| cuBLAS heuristic | 0 | 0 | Good for standard shapes | Excellent |
| Triton manual autotune | Variable | 1-10s per kernel | Depends on configs | Manual |
Use TensorRT when you have fixed shapes (batch size ranges known in advance), maximum throughput matters, and build time is acceptable. Use torch.compile max-autotune when you need Python-level flexibility, dynamic shapes, or rapid iteration. Use cuBLAS heuristics (no autotuning) when development speed matters and 5-15% throughput loss is acceptable. For MoE models with irregular expert routing, always autotune — the heuristic penalty is too large.
Measuring Autotuning Effectiveness
To validate that autotuning is working and quantify its benefit, profile the compiled model and compare kernel performance:
import torch
from torch.profiler import profile, ProfilerActivity, schedule
def profile_autotuned_model(model, compiled_model, input_ids):
"""Profile eager vs compiled to measure autotuning benefit."""
# Profile eager
with profile(
activities=[ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
) as prof_eager:
with torch.inference_mode():
for _ in range(10):
model(input_ids)
# Profile compiled
with profile(
activities=[ProfilerActivity.CUDA],
record_shapes=True,
with_stack=True,
) as prof_compiled:
with torch.inference_mode():
for _ in range(10):
compiled_model(input_ids)
# Print kernel-level comparison
print("=== Eager (cuBLAS heuristic) ===")
print(prof_eager.key_averages().table(
sort_by="cuda_time_total", row_limit=10
))
print("=== Compiled (autotuned) ===")
print(prof_compiled.key_averages().table(
sort_by="cuda_time_total", row_limit=10
))
# Extract GEMM kernel times
eager_gemm_time = sum(
e.cuda_time_total
for e in prof_eager.key_averages()
if "gemm" in e.key.lower() or "mm" in e.key.lower()
)
compiled_gemm_time = sum(
e.cuda_time_total
for e in prof_compiled.key_averages()
if "gemm" in e.key.lower()
or "mm" in e.key.lower()
or "triton" in e.key.lower()
)
print(f"Eager GEMM time: {eager_gemm_time/1000:.2f} ms")
print(f"Compiled GEMM time: {compiled_gemm_time/1000:.2f} ms")
print(f"GEMM speedup: {eager_gemm_time/compiled_gemm_time:.2f}x")