Part of Series: Inference Optimization Timeline (18 of 60)

The same model can run 2x faster just by letting the compiler choose better kernel parameters. This seems impossible — we’re not changing the algorithm, not quantizing, not fusing operations. We’re running the exact same GEMM operations on the same hardware. But a single matrix multiplication can be executed by hundreds of distinct CUDA kernel implementations, and the performance gap between the best and worst can exceed 3x. The GEMM that processes your FFN gate projection might take 0.65ms with the default cuBLAS heuristic or 0.21ms after TensorRT profiles every candidate and picks the optimal tile size, pipeline depth, and memory access pattern. Multiply that across 80 layers and thousands of requests per second, and autotuning becomes the difference between running at 70% utilization and 95% utilization on the same hardware you already paid for.

This post covers the full autotuning landscape for LLM inference: why so many kernel variants exist, how TensorRT exhaustively profiles them during its engine build phase, how torch.compile’s Inductor backend generates Triton kernels and selects optimal configurations, how cuBLAS uses heuristic tables instead of profiling, and when each approach is appropriate.

Why Hundreds of Kernels Exist for One Operation

A matrix multiplication C = A × B, where A is [M, K] and B is [K, N], is decomposed into tiles. Each thread block computes one tile of the output matrix C. The choices involved in this decomposition create a combinatorial explosion of kernel variants.

Tile Size Selection

The output matrix C is [M, N]. We partition it into tiles of size [T_M, T_N]; each thread block computes one tile. The K dimension is iterated in chunks of size T_K.

Common tile sizes on Hopper (sm_90):

Tile (T_M x T_N x T_K):
  256x128x64   — large tiles, high register pressure, fewer thread blocks
  128x256x64   — same total work per tile, different aspect ratio
  128x128x64   — balanced, good occupancy
  64x256x64    — tall-and-skinny output tiles
  256x64x64    — wide-and-short output tiles
  64x128x64    — small tiles, low register pressure, many thread blocks
  64x64x64     — smallest typical tile

Each tile size implies a different number of thread blocks. For M = 4096, N = 11008:

  • Tile 256x128: ⌈4096/256⌉ × ⌈11008/128⌉ = 16 × 86 = 1376 blocks
  • Tile 128x256: ⌈4096/128⌉ × ⌈11008/256⌉ = 32 × 43 = 1376 blocks
  • Tile 64x64: ⌈4096/64⌉ × ⌈11008/64⌉ = 64 × 172 = 11008 blocks

The H100 has 132 SMs. With 1376 blocks, each SM processes ~10 blocks (good occupancy, moderate wave quantization). With 11008 blocks, each SM processes ~83 blocks (very high parallelism but smaller tiles mean lower compute intensity per block).
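The block counts and wave behavior above can be checked with a few lines of arithmetic. A sketch (`wave_stats` is a made-up helper; 132 SMs assumes an H100 SXM):

```python
import math

def wave_stats(M, N, tile_m, tile_n, num_sms=132):
    """Thread-block count and wave quantization for one tile choice."""
    blocks = math.ceil(M / tile_m) * math.ceil(N / tile_n)
    full_waves = blocks // num_sms
    tail = blocks % num_sms  # blocks in the final, partial wave
    tail_util = 1.0 if tail == 0 else tail / num_sms
    return blocks, full_waves, tail_util

# The FFN projection from the text: M=4096, N=11008
for tm, tn in [(256, 128), (128, 256), (64, 64)]:
    blocks, waves, tail = wave_stats(4096, 11008, tm, tn)
    print(f"tile {tm}x{tn}: {blocks} blocks, {waves} full waves, "
          f"tail utilization {tail:.0%}")
```

Running this reproduces the 1376- and 11008-block figures and also exposes the tail effect: the last, partially filled wave is where tile choice quietly costs utilization.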

Thread Block Shape and Warp Layout

Within a thread block, warps (groups of 32 threads) are arranged in a 2D grid. For a 256x128 tile with 256 threads (8 warps):

Warp layout options for 8 warps computing [256, 128]:
  Layout A: 4 warps in M x 2 warps in N  — each warp handles [64, 64]
  Layout B: 2 warps in M x 4 warps in N  — each warp handles [128, 32]
  Layout C: 8 warps in M x 1 warp  in N  — each warp handles [32, 128]
  Layout D: 1 warp  in M x 8 warps in N  — each warp handles [256, 16]

Each layout has different implications for shared memory bank conflicts, register usage, and instruction-level parallelism. Layout A is balanced; Layout D minimizes shared memory reads for A but increases them for B.
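The per-warp fragments in the listing follow mechanically from the warp grid; a quick sketch to verify them (`warp_fragments` is a hypothetical helper):

```python
def warp_fragments(tile_m, tile_n, warps_m, warps_n):
    """Per-warp output fragment for a given warp layout.
    warps_m * warps_n must equal the warps per block (8 here)."""
    assert tile_m % warps_m == 0 and tile_n % warps_n == 0
    return tile_m // warps_m, tile_n // warps_n

# The four layouts A-D for a 256x128 tile with 8 warps
for wm, wn in [(4, 2), (2, 4), (8, 1), (1, 8)]:
    fm, fn = warp_fragments(256, 128, wm, wn)
    print(f"{wm}x{wn} warps -> each warp computes [{fm}, {fn}]")
```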

Pipeline Stages

Modern GEMM kernels overlap data loading with computation using software pipelining. The number of pipeline stages determines how much data is “in flight” at any time:

Stages=2: load tile k+1 while computing tile k
           Minimum latency hiding, minimum shared memory usage
           Shared memory: 2 * (T_M * T_K + T_K * T_N) * dtype_size

Stages=3: load tile k+2 while computing tile k, tile k+1 in buffer
           Better latency hiding, 50% more shared memory
           Shared memory: 3 * (T_M * T_K + T_K * T_N) * dtype_size

Stages=4: load tile k+3 while computing tile k
           Best latency hiding for high-latency memory
           Shared memory: 4 * (T_M * T_K + T_K * T_N) * dtype_size

For tile 128x128x64 in FP16 (2 bytes):

smem per stage = (128 × 64 + 64 × 128) × 2 = 32768 bytes = 32 KB

  • 2 stages: 64 KB shared memory
  • 3 stages: 96 KB shared memory
  • 4 stages: 128 KB shared memory

The H100 has 228 KB of shared memory per SM. At 128 KB (4 stages), only 1 thread block can run per SM. At 64 KB (2 stages), 3 thread blocks can run per SM. More concurrent blocks means better latency hiding through warp scheduling, but fewer stages means worse pipelining within each block.
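The stage-count tradeoff is just this shared-memory arithmetic. A small helper makes it explicit (a sketch; `smem_budget` is a made-up name, and the 228 KB budget assumes an H100):

```python
def smem_budget(tile_m, tile_n, tile_k, stages,
                dtype_size=2, smem_per_sm=228 * 1024):
    """Shared memory used by a pipelined GEMM tile, and how many
    thread blocks fit per SM under that budget."""
    per_stage = (tile_m * tile_k + tile_k * tile_n) * dtype_size
    total = stages * per_stage
    blocks_per_sm = smem_per_sm // total
    return per_stage, total, blocks_per_sm

# The 128x128x64 FP16 tile from the text
for stages in (2, 3, 4):
    per_stage, total, blocks = smem_budget(128, 128, 64, stages)
    print(f"{stages} stages: {total // 1024} KB smem, {blocks} block(s)/SM")
```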

ℹ️ The Combinatorial Explosion

For a single GEMM shape, the tuning space includes: 7+ tile sizes, 4+ warp layouts per tile, 3+ stage counts, 2+ swizzle patterns, 2+ epilogue variants (with/without bias fusion). That gives 7 × 4 × 3 × 2 × 2 = 336 candidate kernels. In practice, CUTLASS enumerates 200-500 valid configurations per GEMM shape, and TensorRT’s internal library contains thousands.
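The 336-candidate count can be reproduced by enumerating the space directly. A sketch with illustrative parameter lists; real libraries also prune by register pressure and alignment, which this ignores:

```python
from itertools import product

tiles = [(256, 128, 64), (128, 256, 64), (128, 128, 64), (64, 256, 64),
         (256, 64, 64), (64, 128, 64), (64, 64, 64)]
warp_layouts = [(4, 2), (2, 4), (8, 1), (1, 8)]
stage_counts = [2, 3, 4]
swizzles = ["linear", "swizzled"]
epilogues = ["none", "bias"]

SMEM_LIMIT = 228 * 1024  # H100 shared memory per SM

def smem(tile, stages, dtype_size=2):
    tm, tn, tk = tile
    return stages * (tm * tk + tk * tn) * dtype_size

# Keep only configs that fit in shared memory
candidates = [
    (t, w, s, sw, e)
    for t, w, s, sw, e in product(tiles, warp_layouts, stage_counts,
                                  swizzles, epilogues)
    if smem(t, s) <= SMEM_LIMIT
]
print(len(candidates))  # 336 on H100: every combination fits in 228 KB
```

On a GPU with a smaller shared-memory budget (e.g. A100's 164 KB), the same filter would start pruning the deepest pipelines on the largest tiles.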

Memory Access Patterns

How thread blocks traverse the output matrix also affects L2 cache behavior:

Linear order:     block(0,0), block(0,1), block(0,2), ..., block(1,0), ...
  — Accesses column tiles of B sequentially, poor B reuse in L2

Swizzled order:   block(0,0), block(1,0), block(0,1), block(1,1), ...
  — 2x2 block clusters, better L2 reuse for both A and B

Grouped order:    block(0,0), block(0,1), ..., block(0,G), block(1,0), ...
  — Group G columns, maximizes A-tile reuse within group

The L2 cache on H100 is 50 MB. For a Llama 70B FFN weight matrix of shape [8192, 57344] in FP16, the total weight size is 896 MB — far exceeding L2. But during a single “wave” of thread blocks (132 blocks for 132 SMs), each block reads a different row-strip of A and a different column-strip of B. Swizzled scheduling ensures that adjacent SMs read adjacent strips, improving L2 hit rates.
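The grouped traversal can be expressed as a remapping from a linear block index to (pid_m, pid_n) coordinates, in the spirit of the swizzle in Triton's matmul tutorial (a sketch; `grouped_block_order` is a made-up helper):

```python
def grouped_block_order(num_blocks_m, num_blocks_n, group_m):
    """Remap linear block indices into grouped launch order: walk
    group_m rows of output tiles before advancing to the next column,
    so concurrently running blocks share A- and B-strips in L2."""
    order = []
    for pid in range(num_blocks_m * num_blocks_n):
        blocks_per_group = group_m * num_blocks_n
        group_start = (pid // blocks_per_group) * group_m
        group_size = min(num_blocks_m - group_start, group_m)
        pid_m = group_start + (pid % group_size)
        pid_n = (pid % blocks_per_group) // group_size
        order.append((pid_m, pid_n))
    return order

# 4x4 grid of output tiles, groups of 2 rows:
# the first blocks cluster in rows 0-1 instead of sweeping row 0
print(grouped_block_order(4, 4, 2)[:4])
```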

TensorRT: Exhaustive Profiling During Build

TensorRT’s approach to kernel selection is straightforward: try every candidate kernel, time each one, pick the fastest.

The Build Phase

When you build a TensorRT engine, the builder:

  1. Parses the model graph (ONNX or TensorRT network definition)
  2. Applies graph optimizations (layer fusion, constant folding)
  3. For each resulting operation, enumerates all candidate kernels
  4. Profiles each candidate kernel on the target GPU
  5. Selects the fastest kernel for each operation
  6. Serializes the engine with the selected kernels
import tensorrt as trt

logger = trt.Logger(trt.Logger.VERBOSE)
builder = trt.Builder(logger)
config = builder.create_builder_config()

# Control autotuning behavior
config.set_flag(trt.BuilderFlag.FP16)

# Timing iterations: more samples per tactic reduce measurement noise
# at the cost of a longer build
config.avg_timing_iterations = 8

# Builder optimization level: 0-5
# Level 3 (default): standard autotuning
# Level 5: exhaustive search including less common kernels
config.builder_optimization_level = 5

network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

# Parse ONNX model
parser = trt.OnnxParser(network, logger)
with open("model.onnx", "rb") as f:
    if not parser.parse(f.read()):
        raise RuntimeError(f"ONNX parse failed: {parser.get_error(0)}")

# Build engine — this is where autotuning happens
# For a 7B model: 5-30 minutes
# For a 70B model: 30-120 minutes
engine = builder.build_serialized_network(network, config)

# Save for deployment — no need to rebuild on same GPU
with open("model.engine", "wb") as f:
    f.write(engine)

What Gets Profiled

During the build, TensorRT’s verbose log reveals the profiling process. For a single linear layer:

[TensorRT] VERBOSE: Tactic: 0x0000000000000001
  Inputs: {FP16[1,4096]} Outputs: {FP16[1,11008]}
  Kernel: sm90_xmma_gemm_f16_f16_f16_f32_256x128x64_3stage_nn
  Time: 0.0234ms

[TensorRT] VERBOSE: Tactic: 0x0000000000000002
  Inputs: {FP16[1,4096]} Outputs: {FP16[1,11008]}
  Kernel: sm90_xmma_gemm_f16_f16_f16_f32_128x256x64_3stage_nn
  Time: 0.0241ms

[TensorRT] VERBOSE: Tactic: 0x0000000000000003
  Inputs: {FP16[1,4096]} Outputs: {FP16[1,11008]}
  Kernel: sm90_xmma_gemm_f16_f16_f16_f32_128x128x64_4stage_nn
  Time: 0.0219ms     <-- fastest

... (200+ more tactics)

[TensorRT] VERBOSE: Selected tactic 0x0000000000000003

Each “tactic” is a complete kernel configuration. TensorRT runs each one multiple times (controlled by avg_timing_iterations), discards outliers, and selects the one with the lowest median time.
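The selection logic itself is simple to sketch in Python: warm up, sample repeatedly, compare medians. Illustrative only; TensorRT times tactics internally with GPU timers, not host-side wall clocks, and the helper names here are made up:

```python
import statistics
import time

def time_tactic(fn, iters=8, warmup=3):
    """Median-of-iters wall-clock timing: warm up first, then sample
    repeatedly so one noisy run cannot change the winner."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

def pick_fastest(tactics):
    """tactics: dict mapping tactic name -> callable. Returns the
    name with the lowest median time."""
    return min(tactics, key=lambda name: time_tactic(tactics[name]))
```

Usage would look like `pick_fastest({"128x128x64_4stage": run_a, "256x128x64_3stage": run_b})`, mirroring the tactic log above.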

TensorRT Timing Cache

Because profiling takes minutes to hours, TensorRT supports a timing cache that stores profiling results across builds:

# Save timing cache after build
timing_cache = config.get_timing_cache()
with open("timing_cache.bin", "wb") as f:
    f.write(timing_cache.serialize())

# Load timing cache for next build (same GPU only)
with open("timing_cache.bin", "rb") as f:
    config.set_timing_cache(
        config.create_timing_cache(f.read()),
        ignore_mismatch=False
    )

The timing cache is GPU-specific. A cache built on an H100 SXM is invalid on an H100 PCIe (different memory bandwidth, different clock speeds). Even two H100 SXMs can have slightly different optimal kernels due to silicon variation, though in practice the differences are negligible.
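A practical pattern that follows from this: key the timing-cache file by the exact device, so a cache from one GPU variant is never loaded on another. A sketch (`timing_cache_path` is a made-up helper; the name and compute capability can come from `torch.cuda.get_device_name(0)` and `torch.cuda.get_device_capability(0)`):

```python
import hashlib

def timing_cache_path(device_name, compute_capability,
                      base_dir="/var/cache/trt"):
    """Derive a per-GPU timing-cache filename, e.g. so a cache built
    on an H100 SXM is never loaded on an H100 PCIe."""
    tag = f"{device_name}-sm{compute_capability[0]}{compute_capability[1]}"
    key = hashlib.sha1(tag.encode()).hexdigest()[:12]
    return f"{base_dir}/timing_cache_{key}.bin"
```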

Dynamic Shapes and Autotuning

LLM inference involves dynamic shapes: the batch size and sequence length change every iteration. TensorRT handles this with optimization profiles:

profile = builder.create_optimization_profile()
profile.set_shape(
    "input_ids",
    min=(1, 1),      # minimum shape
    opt=(32, 512),    # optimal shape (autotuned for this)
    max=(64, 2048)    # maximum shape
)
config.add_optimization_profile(profile)

TensorRT autotunes primarily for the opt shape. For shapes far from opt, it falls back to a heuristic that interpolates between profiled results. This is why TensorRT users often create multiple engines for different batch size ranges.

📊

TensorRT Build Time vs Optimization Level

Optimization Level      | Build Time (Llama 7B) | Build Time (Llama 70B) | Inference Speedup vs Level 3
Level 0 (no autotuning) | 15 seconds            | 2 minutes              | 0.82x (slower)
Level 3 (default)       | 8 minutes             | 45 minutes             | 1.00x (baseline)
Level 4 (extended)      | 18 minutes            | 90 minutes             | 1.03x
Level 5 (exhaustive)    | 35 minutes            | 150 minutes            | 1.05x

Note: Measured on H100 SXM with FP16. Build time is a one-time cost. Level 5 explores uncommon kernel configurations that occasionally yield 3-5% improvement on specific layers.

The diminishing returns beyond level 3 explain why most production deployments use the default. The additional 2-5% from levels 4-5 is real but rarely justifies 3-5x longer build times, especially during iterative development.

torch.compile: Inductor and Triton Kernel Generation

torch.compile takes a fundamentally different approach. Instead of choosing from a pre-existing library of handwritten CUDA kernels, the Inductor backend generates Triton kernels on the fly and benchmarks them.

The Compilation Pipeline

import torch

model = load_llm("llama-7b")

# Default mode: moderate autotuning
compiled_model = torch.compile(model, mode="default")

# Reduce-overhead mode: CUDA graph integration
compiled_model = torch.compile(model, mode="reduce-overhead")

# Max-autotune mode: exhaustive kernel search
compiled_model = torch.compile(model, mode="max-autotune")

When you call torch.compile, the following stages execute on the first forward pass:

Stage 1: Dynamo (Python → FX graph)
  — Traces Python code into an intermediate representation
  — Time: 2-10 seconds for a 7B model

Stage 2: AOTAutograd (FX graph → Aten IR)
  — Decomposes high-level ops into primitive operations
  — Generates backward graph if needed (not for inference)
  — Time: 1-5 seconds

Stage 3: Inductor (Aten IR → Triton/C++ kernels)
  — Generates Triton kernel code for each fused subgraph
  — Applies operator fusion (element-wise chains, reductions)
  — Time: 5-30 seconds

Stage 4: Autotuning (Triton kernels → optimal configs)
  — Benchmarks multiple configurations for each generated kernel
  — Time: 20-300 seconds (the dominant cost with max-autotune)

Stage 5: Code caching
  — Serializes generated code and selected configs to disk
  — Subsequent runs skip stages 1-4

Inductor Kernel Generation

Inductor does not call cuBLAS for every GEMM. For standalone GEMMs, it delegates to cuBLAS or a Triton GEMM template. For fused operations (GEMM + bias + activation), it generates custom Triton kernels:

# What the user writes:
def fused_ffn(x, w_gate, w_up, w_down):
    gate = torch.mm(x, w_gate)
    up = torch.mm(x, w_up)
    hidden = torch.nn.functional.silu(gate) * up
    return torch.mm(hidden, w_down)

# What Inductor generates (simplified Triton for the silu*up fusion):
"""
@triton.jit
def fused_silu_mul_kernel(
    gate_ptr, up_ptr, out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    gate = tl.load(gate_ptr + offsets, mask=mask)
    up = tl.load(up_ptr + offsets, mask=mask)

    # SiLU = x * sigmoid(x)
    silu_gate = gate * tl.sigmoid(gate)
    result = silu_gate * up

    tl.store(out_ptr + offsets, result, mask=mask)
"""

The GEMMs themselves go through a different path. Inductor’s GEMM handling:

# Inductor's decision tree for GEMM (simplified):
def select_gemm_backend(M, N, K, dtype):
    if dtype in (torch.float16, torch.bfloat16):
        if M * N * K > 1_000_000:  # large enough for cuBLAS
            return "cublas"
        else:
            return "triton_gemm_template"
    elif dtype == torch.float8_e4m3fn:
        return "cublas_lt"  # cuBLASLt (lightweight GEMM API) for FP8
    else:
        return "triton_gemm_template"

For large GEMMs (the FFN projections), Inductor typically delegates to cuBLAS. For smaller GEMMs or fused GEMM+activation patterns, it generates Triton kernels.

Triton Autotuning

When Inductor generates a Triton kernel, it creates multiple configurations and benchmarks them:

import triton
import triton.language as tl

# Inductor generates this autotune decorator
@triton.autotune(
    configs=[
        # num_stages and num_warps are keyword arguments to Config,
        # not tile-size meta-parameters
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64},
                      num_stages=3, num_warps=8),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32},
                      num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32},
                      num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32},
                      num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32},
                      num_stages=4, num_warps=4),
        triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32},
                      num_stages=5, num_warps=2),
    ],
    key=["M", "N", "K"],  # re-autotune when these change
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=offs_k[None, :] + k < K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] + k < K, other=0.0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    c = accumulator.to(tl.float16)
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, c, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))

The @triton.autotune decorator benchmarks each config by running the kernel multiple times and measuring wall-clock time. The winning config is cached.

max-autotune Mode

With mode="max-autotune", Inductor expands the search space:

# Default mode: ~6 Triton configs per kernel, cuBLAS for large GEMMs
# max-autotune mode: ~20 Triton configs per kernel + cuBLAS comparison

# max-autotune also enables:
# CUDA graph capture (eliminates kernel launch overhead)
# cuBLAS vs Triton comparison for every GEMM
# Epilogue fusion search (try fusing bias/activation into GEMM)

The max-autotune compilation for a 7B model takes 2-5 minutes. For a 70B model (which typically uses tensor parallelism, so each rank compiles its shard), 3-8 minutes per rank.

📊

torch.compile Compilation Time by Mode

Mode               | Compile Time (7B) | Compile Time (70B/rank) | Speedup vs Eager
eager (no compile) | 0 seconds         | 0 seconds               | 1.00x
default            | 30-60 seconds     | 45-90 seconds           | 1.05-1.15x
reduce-overhead    | 40-80 seconds     | 60-120 seconds          | 1.10-1.25x
max-autotune       | 120-300 seconds   | 180-480 seconds         | 1.15-1.30x

Note: Measured on H100 SXM. Compile time is dominated by Triton kernel compilation and autotuning. Speedup ranges depend on model architecture and batch size. The reduce-overhead mode gains come primarily from CUDA graph capture, not kernel selection.

max-autotune + reduce-overhead Stack

You can combine modes: torch.compile(model, mode="max-autotune-no-cudagraphs") followed by manual CUDA graph capture, or just mode="max-autotune" which includes CUDA graphs. The combination delivers both optimal kernels and eliminated launch overhead. For LLM inference where the same shapes repeat every decode step, this combination is highly effective.

cuBLAS Heuristics: Fast Selection Without Profiling

cuBLAS takes a third approach: no profiling at all. It uses precomputed heuristic tables that map (M, N, K, dtype, GPU architecture) to a kernel selection.

How cuBLAS Selects Kernels

# Pseudocode for cuBLAS kernel selection:
def cublas_select_gemm(M, N, K, dtype, gpu_arch):
    # Step 1: Filter by compatibility
    candidates = get_kernels_for_arch(gpu_arch, dtype)

    # Step 2: Heuristic scoring based on shape
    for kernel in candidates:
        tile_m, tile_n, tile_k = kernel.tile_size

        # Wave quantization: how many "waves" of thread blocks?
        blocks_m = ceil(M / tile_m)
        blocks_n = ceil(N / tile_n)
        total_blocks = blocks_m * blocks_n
        num_sms = get_sm_count(gpu_arch)  # 132 for H100
        num_waves = ceil(total_blocks / num_sms)

        # Tail effect: last wave may have poor SM utilization
        tail_utilization = (total_blocks % num_sms) / num_sms
        if total_blocks % num_sms == 0:
            tail_utilization = 1.0

        # Score based on tile efficiency and wave utilization
        kernel.score = compute_heuristic_score(
            tile_efficiency=tile_m * tile_n * tile_k,
            wave_utilization=tail_utilization,
            num_waves=num_waves,
            shared_memory=kernel.smem_usage,
        )

    # Step 3: Return highest-scoring kernel
    return max(candidates, key=lambda k: k.score)

The heuristic is tuned by NVIDIA engineers using profiling data from representative shapes. It works well for common shapes (powers of 2, standard transformer dimensions) but can be suboptimal for unusual shapes.

Where Heuristics Fail

cuBLAS heuristics have known weaknesses:

1. Non-power-of-2 dimensions. When N is not a multiple of common tile sizes, wave quantization causes SM underutilization. For example, with N = 11008 (Llama 7B FFN) and tile size 256:

⌈11008 / 256⌉ = 43 column blocks

With M = 1 (single-token decode): 1 × 43 = 43 total blocks on 132 SMs, which is only 43/132 = 33% SM utilization. A tile size of 128 gives 86 blocks (65% utilization). Tile size 64 gives 172 blocks (a full wave plus a partial second wave). The heuristic must weigh these tradeoffs without actually measuring.

2. Small M (decode case). When M = 1, the GEMM is really a matrix-vector multiply. cuBLAS may select a GEMM kernel instead of a specialized GEMV kernel. The specialized kernel is typically 10-20% faster because it avoids tile setup overhead.

3. Grouped/batched GEMM. For Mixture-of-Experts models, each expert processes a different number of tokens. cuBLAS grouped GEMM heuristics are less mature than single-GEMM heuristics.

import torch

# Demonstrate cuBLAS suboptimality on non-standard shapes
# Standard shape: cuBLAS heuristic is well-tuned
M, N, K = 1, 4096, 4096
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(K, N, dtype=torch.float16, device="cuda")

# Warm up
for _ in range(10):
    torch.mm(A, B)
torch.cuda.synchronize()

import time
start = time.perf_counter()
for _ in range(1000):
    torch.mm(A, B)
torch.cuda.synchronize()
standard_time = (time.perf_counter() - start) / 1000

# Non-standard shape: cuBLAS heuristic may be suboptimal
M, N, K = 1, 11008, 4096
A = torch.randn(M, K, dtype=torch.float16, device="cuda")
B = torch.randn(K, N, dtype=torch.float16, device="cuda")

for _ in range(10):
    torch.mm(A, B)
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(1000):
    torch.mm(A, B)
torch.cuda.synchronize()
nonstandard_time = (time.perf_counter() - start) / 1000

# The non-standard shape may show lower TFLOPS utilization
print(f"Standard shape TFLOPS: {2*1*4096*4096/standard_time/1e12:.2f}")
print(f"Non-standard shape TFLOPS: {2*1*11008*4096/nonstandard_time/1e12:.2f}")

cuBLAS vs Autotuned Kernel: TFLOPS by Shape (H100 SXM, FP16)

Shape                       | cuBLAS     | Autotuned
[1, 4096] x [4096, 4096]    | 42 TFLOPS  | 45 TFLOPS
[1, 4096] x [4096, 11008]   | 38 TFLOPS  | 48 TFLOPS
[32, 4096] x [4096, 11008]  | 285 TFLOPS | 310 TFLOPS
[256, 4096] x [4096, 11008] | 780 TFLOPS | 810 TFLOPS

The gap between the cuBLAS heuristic and autotuned selection is largest for small M with non-power-of-2 N. At large batch sizes, where the GEMM is compute-bound, both selections approach peak TFLOPS and the gap narrows.

When Autotuning Matters: Quantifying the Improvement

Autotuning improvement varies by workload. Here is a systematic breakdown.

Shape-Dependent Improvement

The improvement from autotuning correlates with how well the GEMM shape fits common tile sizes:

📊

Autotuning Improvement Over cuBLAS Heuristic by Shape Category

Shape Category      | Example                     | cuBLAS TFLOPS | Autotuned TFLOPS | Improvement
Power-of-2, large M | [512, 4096] x [4096, 4096]  | 850           | 870              | 2.4%
Power-of-2, small M | [1, 4096] x [4096, 4096]    | 42            | 45               | 7.1%
Non-PoT N, large M  | [512, 4096] x [4096, 11008] | 820           | 860              | 4.9%
Non-PoT N, small M  | [1, 4096] x [4096, 11008]   | 38            | 48               | 26.3%
GQA shapes          | [1, 8192] x [8192, 1024]    | 18            | 24               | 33.3%
MoE expert          | [17, 4096] x [4096, 11008]  | 195           | 240              | 23.1%

Note: Measured on H100 SXM with FP16. Autotuned = best of TensorRT tactic selection. Non-standard shapes (non-PoT, small M, GQA, MoE) benefit most from autotuning.

The pattern is clear: autotuning matters most when the GEMM shape is “weird” from the perspective of standard tile sizes. Single-token decode (M = 1), GQA key-value projections (small N = 1024), and MoE expert routing (irregular M) all benefit substantially.

End-to-End Throughput Impact

The per-kernel improvement translates to end-to-end throughput improvement, attenuated by Amdahl’s law:

End-to-end speedup = 1 / (1 - fraction_GEMM + fraction_GEMM / kernel_speedup)

For Llama 70B decode (GEMM is ~85% of time):
  5% kernel speedup  -> 1 / (0.15 + 0.85/1.05) = 1.042 -> 4.2% end-to-end
  15% kernel speedup -> 1 / (0.15 + 0.85/1.15) = 1.125 -> 12.5% end-to-end
  30% kernel speedup -> 1 / (0.15 + 0.85/1.30) = 1.244 -> 24.4% end-to-end

In practice, autotuning delivers 5-15% end-to-end throughput improvement for standard transformer models. For MoE models with irregular expert routing, the improvement can reach 20-30%.
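The Amdahl arithmetic above generalizes to a one-line function (a sketch; the default 0.85 is the GEMM fraction assumed for Llama 70B decode in the text):

```python
def end_to_end_speedup(kernel_speedup, gemm_fraction=0.85):
    """Amdahl's law applied to kernel autotuning: only the GEMM
    fraction of step time benefits from faster kernels."""
    return 1.0 / ((1.0 - gemm_fraction) + gemm_fraction / kernel_speedup)

for s in (1.05, 1.15, 1.30):
    print(f"{(s - 1) * 100:.0f}% kernel speedup -> "
          f"{(end_to_end_speedup(s) - 1) * 100:.1f}% end-to-end")
```

Lowering `gemm_fraction` (e.g. for prefill-heavy workloads where attention takes a larger share) shows how quickly the end-to-end benefit shrinks when GEMM is not the dominant cost.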

Implementation: torch.compile with max-autotune

Here is a complete implementation of autotuned LLM inference using torch.compile:

import torch
import os
import time

# Enable Inductor's GEMM autotuning
# This tells Inductor to benchmark cuBLAS against Triton for each GEMM
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE"] = "1"
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS"] = "TRITON,ATen"

# Cache directory for compiled kernels
# First compile takes 2-5 minutes; subsequent runs use cache
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/inductor_cache"


class AutotunedInferenceEngine:
    """LLM inference engine with torch.compile autotuning."""

    def __init__(self, model_path, max_batch_size=64, max_seq_len=4096):
        self.device = torch.device("cuda")

        # Load model
        self.model = self._load_model(model_path)
        self.model.eval()
        self.model.to(self.device)

        # Compile with max-autotune
        self.compiled_model = None
        self.compile_time = None
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len

    def _load_model(self, path):
        """Load model weights. Implementation depends on model format."""
        from transformers import AutoModelForCausalLM
        model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
        )
        return model

    def compile(self):
        """Compile model with max-autotune. Call once before inference."""
        start = time.perf_counter()

        self.compiled_model = torch.compile(
            self.model,
            mode="max-autotune",
            fullgraph=False,
            dynamic=True,  # support variable batch sizes
        )

        # Warm up: trigger compilation with representative shapes
        # Inductor compiles lazily on first call
        warmup_shapes = [
            (1, 1),     # single token decode
            (1, 128),   # short prefill
            (1, 512),   # medium prefill
            (32, 1),    # batched decode
        ]

        for batch_size, seq_len in warmup_shapes:
            dummy_input = torch.randint(
                0, 32000,
                (batch_size, seq_len),
                device=self.device,
            )
            with torch.no_grad():
                self.compiled_model(dummy_input)
            torch.cuda.synchronize()

        self.compile_time = time.perf_counter() - start
        print(f"Compilation complete in {self.compile_time:.1f}s")

    @torch.inference_mode()
    def generate(self, input_ids, max_new_tokens=128):
        """Generate tokens using compiled model."""
        if self.compiled_model is None:
            raise RuntimeError("Call compile() before generate()")

        batch_size = input_ids.shape[0]
        generated = input_ids.clone()

        # KV cache: the HF forward returns past_key_values, which we
        # thread back in manually on each decode step
        past_key_values = None

        for step in range(max_new_tokens):
            if past_key_values is None:
                # Prefill: process all input tokens
                outputs = self.compiled_model(
                    input_ids=generated,
                    use_cache=True,
                )
            else:
                # Decode: process only the last token
                outputs = self.compiled_model(
                    input_ids=generated[:, -1:],
                    past_key_values=past_key_values,
                    use_cache=True,
                )

            past_key_values = outputs.past_key_values
            logits = outputs.logits[:, -1, :]

            # Greedy sampling
            next_token = logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)

            # Check for EOS
            if (next_token == self.model.config.eos_token_id).all():
                break

        return generated

def benchmark_autotuning(model_path):
    """Compare eager vs compiled inference throughput."""
    engine = AutotunedInferenceEngine(model_path)

    # Benchmark eager mode
    input_ids = torch.randint(
        0, 32000, (1, 128), device="cuda"
    )

    # Eager baseline
    engine.model.eval()
    with torch.inference_mode():
        # Warm up
        for _ in range(3):
            engine.model(input_ids)
        torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(50):
            engine.model(input_ids)
        torch.cuda.synchronize()
        eager_time = (time.perf_counter() - start) / 50

    # Compiled with max-autotune
    engine.compile()

    with torch.inference_mode():
        # Warm up (compilation already done)
        for _ in range(3):
            engine.compiled_model(input_ids)
        torch.cuda.synchronize()

        start = time.perf_counter()
        for _ in range(50):
            engine.compiled_model(input_ids)
        torch.cuda.synchronize()
        compiled_time = (time.perf_counter() - start) / 50

    speedup = eager_time / compiled_time
    print(f"Eager:    {eager_time*1000:.2f} ms/forward")
    print(f"Compiled: {compiled_time*1000:.2f} ms/forward")
    print(f"Speedup:  {speedup:.2f}x")

    return eager_time, compiled_time

Inspecting Selected Kernels

After compilation, you can inspect which kernels Inductor selected:

# Enable Inductor debug logging
import torch._inductor.config as inductor_config
inductor_config.debug = True
inductor_config.trace.enabled = True
inductor_config.trace.log_dir = "/tmp/inductor_traces"

# After compilation, check the trace directory:
# /tmp/inductor_traces/
#   model__0_forward/
#     output_code.py        <-- generated Triton/C++ code
#     fx_graph_readable.py  <-- the FX graph before codegen
#     ir_post_fusion.txt    <-- IR after operator fusion

The output_code.py file contains the actual generated Triton kernels with the selected configurations:

# Example excerpt from output_code.py (auto-generated by Inductor):

# Selected config for fused_silu_mul:
#   BLOCK_SIZE=1024, num_warps=4, num_stages=3
# Autotuning tried 6 configs, selected in 0.34s

# Selected config for mm (4096x11008):
#   Backend: cuBLAS (faster than Triton by 8%)
#   Algorithm: CUBLAS_GEMM_DEFAULT_TENSOR_OP

# Selected config for mm (4096x4096):
#   Backend: Triton
#   BLOCK_M=128, BLOCK_N=128, BLOCK_K=32, num_stages=4, num_warps=4
#   Autotuning tried 12 configs, selected in 2.1s

Handling Dynamic Shapes

LLM inference has dynamic shapes (varying batch sizes, sequence lengths). torch.compile handles this through dynamic shape support:

# Option 1: Dynamic shapes (recompiles when shape changes dramatically)
compiled = torch.compile(model, dynamic=True)

# Option 2: Explicit dynamic dimensions
from torch._dynamo import mark_dynamic
input_ids = torch.randint(0, 32000, (1, 128), device="cuda")
mark_dynamic(input_ids, 0)  # batch dimension is dynamic
mark_dynamic(input_ids, 1)  # sequence dimension is dynamic
compiled(input_ids)

# Option 3: Multiple compilations for different shape ranges
# (manual, but avoids recompilation overhead)
compiled_decode = torch.compile(model, mode="max-autotune")
compiled_prefill = torch.compile(model, mode="max-autotune")

# Warm up each for its target shape range
with torch.no_grad():
    compiled_decode(torch.randint(0, 32000, (32, 1), device="cuda"))
    compiled_prefill(torch.randint(0, 32000, (1, 512), device="cuda"))
⚠️ Dynamic Shape Recompilation

When dynamic=True, torch.compile generates shape-generic code that works for any shape. This is slightly slower than shape-specific code because the kernel cannot specialize on exact dimensions. With the default (dynamic=None), the first compile is shape-specialized, and the first new shape triggers one recompile with dynamic shapes enabled; with dynamic=False, every new shape triggers a full recompilation. For LLM serving with variable batch sizes, dynamic=True is usually the better choice: the 2-5% loss from shape-generic kernels is cheaper than repeated recompilation.
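The tradeoff in the callout can be made concrete with back-of-the-envelope math. The numbers below (recompile time, 3% shape-generic penalty, request mix) are illustrative assumptions, not measurements:

```python
def total_cost_static(n_distinct_shapes, recompile_s, n_forwards, fwd_ms):
    """Shape-specialized compilation: every unseen shape pays a full recompile."""
    return n_distinct_shapes * recompile_s + n_forwards * fwd_ms / 1000

def total_cost_dynamic(n_forwards, fwd_ms, generic_penalty=0.03):
    """dynamic=True: compile once, every forward ~3% slower (assumed penalty)."""
    compile_once_s = 120  # assumed one-time max-autotune compile
    return compile_once_s + n_forwards * fwd_ms * (1 + generic_penalty) / 1000

# A serving day: 50 distinct batch shapes, 1M forwards at 20 ms each
static = total_cost_static(50, recompile_s=120, n_forwards=1_000_000, fwd_ms=20)
dynamic = total_cost_dynamic(n_forwards=1_000_000, fwd_ms=20)

print(f"per-shape recompiles: {static / 3600:.2f} GPU-hours")
print(f"dynamic=True:         {dynamic / 3600:.2f} GPU-hours")
```

Under these assumptions the shape-generic penalty is recovered many times over; the balance only tips the other way when the shape set is tiny and stable.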

Autotuning for Quantized Kernels

Quantized inference (INT4, INT8, FP8) introduces additional autotuning dimensions because the dequantization can be fused with the GEMM or done separately.

W4A16 (INT4 weights, FP16 activations)

# INT4 GEMM autotuning dimensions:
# Dequant location: in registers vs shared memory
# Group size: per-channel vs group-128 vs group-32
# Tile size: same as FP16 but with different optimal points
# Inner loop: dequant-then-accumulate vs packed-accumulate

# Example: Marlin kernel (optimized INT4 GEMM) configuration space
marlin_configs = [
    # (thread_m, thread_n, thread_k, stages, group_size)
    (16, 256, 64, 4, 128),   # wide output tiles
    (16, 128, 64, 4, 128),   # balanced
    (16, 64, 128, 3, 128),   # deep K tiles for better dequant amortization
    (16, 256, 64, 4, -1),    # per-channel quantization (no groups)
]

For W4A16 GEMMs, the autotuning improvement over heuristic selection is typically 10-25%, larger than for FP16 GEMMs. The reason: the dequantization overhead makes tile size selection more shape-sensitive. A tile that is optimal for FP16 may be suboptimal for INT4 because the dequantization cost changes the compute-to-memory ratio.
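The shift in compute-to-memory ratio is easy to see with a back-of-the-envelope arithmetic-intensity calculation. This is a sketch: the shapes are Llama-7B-like, and it counts only ideal data movement, ignoring the dequant instructions themselves:

```python
def gemm_arithmetic_intensity(m, n, k, weight_bytes_per_elem, act_bytes_per_elem=2):
    """FLOPs per byte moved for an (M x K) @ (K x N) GEMM.

    Higher intensity -> more compute-bound; lower -> more memory-bound.
    """
    flops = 2 * m * n * k                      # one multiply-accumulate = 2 FLOPs
    bytes_moved = (
        m * k * act_bytes_per_elem             # activations (FP16)
        + k * n * weight_bytes_per_elem        # weights
        + m * n * act_bytes_per_elem           # output (FP16)
    )
    return flops / bytes_moved

# Decode-shaped GEMM: 32-token batch, 11008 -> 4096 projection
fp16 = gemm_arithmetic_intensity(32, 4096, 11008, weight_bytes_per_elem=2)
int4 = gemm_arithmetic_intensity(32, 4096, 11008, weight_bytes_per_elem=0.5)

print(f"FP16 intensity: {fp16:.1f} FLOPs/byte")
print(f"INT4 intensity: {int4:.1f} FLOPs/byte")
```

With 4-bit weights the same GEMM moves roughly a quarter of the bytes, so its arithmetic intensity is about 4x higher. A kernel sitting at a different point on the roofline favors different tile shapes, which is exactly why the FP16 heuristic's pick is often wrong for INT4.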

FP8 Autotuning

FP8 GEMMs have the additional dimension of scaling granularity:

# FP8 scaling options that affect kernel selection:
# Per-tensor scaling: one scale factor per matrix
#    — Simplest, fastest kernel, lowest accuracy
# Per-channel scaling: one scale factor per output channel
#    — Requires modified epilogue to apply channel-wise scales
# Per-block scaling: one scale factor per tile
#    — Most accurate, most complex kernel, NVIDIA's "deep learning" format

# TensorRT selects among FP8 tactics:
# sm90_xmma_gemm_e4m3_e4m3_f32_f32_128x128x64_3stage
# sm90_xmma_gemm_e4m3_e4m3_f32_f32_256x128x64_3stage
# ... plus per-channel and per-block variants

Production Deployment Patterns

TensorRT-LLM Autotuning

TensorRT-LLM combines TensorRT’s engine building with LLM-specific optimizations:

# TensorRT-LLM build command with autotuning
# trtllm-build --model_dir ./llama-7b-hf \
#   --dtype float16 \
#   --max_batch_size 64 \
#   --max_input_len 2048 \
#   --max_seq_len 4096 \
#   --gemm_plugin float16 \
#   --builder_opt_level 4

# The --gemm_plugin flag enables TensorRT-LLM's custom
# GEMM plugin which has its own autotuning:
# Profiles cuBLAS tactics
# Profiles CUTLASS tactics
# Profiles custom fused GEMM+dequant tactics (for quantized models)
# Selects best per shape

Persistent Autotuning Results

Both TensorRT and torch.compile support caching autotuning results:

# TensorRT: engine file IS the cached result
# Just deploy the .engine file, no re-autotuning needed

# torch.compile (Inductor): cache directory, same env var as earlier
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/persistent/storage/inductor_cache"
# First run: compiles and autotunes (slow)
# Subsequent runs: loads from cache (fast)

# Triton: cache directory
os.environ["TRITON_CACHE_DIR"] = "/persistent/storage/triton_cache"
📊 Autotuning Strategy Comparison

| Strategy | Setup Time | Per-Shape Overhead | Kernel Quality | Dynamic Shape Support |
|---|---|---|---|---|
| TensorRT (build) | 5-150 min | 0 (offline) | Best | Limited (opt profiles) |
| torch.compile max-autotune | 2-8 min | 0.1-2s (lazy) | Very Good | Good (dynamic=True) |
| torch.compile default | 30-90s | 0 (lazy) | Good | Good |
| cuBLAS heuristic | 0 | 0 | Good for standard shapes | Excellent |
| Triton manual autotune | Variable | 1-10s per kernel | Depends on configs | Manual |

Note: Setup time is a one-time cost (amortized with caching). Kernel quality is relative to the best achievable on each shape.
💡 Decision Framework for Autotuning Strategy

Use TensorRT when you have fixed shapes (batch size ranges known in advance), maximum throughput matters, and build time is acceptable. Use torch.compile max-autotune when you need Python-level flexibility, dynamic shapes, or rapid iteration. Use cuBLAS heuristics (no autotuning) when development speed matters and 5-15% throughput loss is acceptable. For MoE models with irregular expert routing, always autotune — the heuristic penalty is too large.
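The framework above can be encoded as a small helper. This is a sketch of the decision logic in this post, not any official API, and the 5% tolerance threshold is a judgment call:

```python
def choose_autotune_strategy(
    shapes_known_in_advance: bool,
    build_time_acceptable: bool,
    needs_python_flexibility: bool,
    throughput_loss_tolerance: float,  # acceptable fractional loss vs best kernel
    is_moe: bool = False,
) -> str:
    """Map serving constraints to an autotuning strategy."""
    if is_moe:
        # Irregular expert routing: heuristics are weakest here, always autotune
        return "torch.compile max-autotune (or TensorRT with MoE support)"
    if shapes_known_in_advance and build_time_acceptable and not needs_python_flexibility:
        return "TensorRT (offline engine build)"
    if throughput_loss_tolerance >= 0.05:
        # 5-15% loss acceptable: skip autotuning entirely
        return "cuBLAS heuristics (no autotuning)"
    return "torch.compile max-autotune"

# Fixed-shape, throughput-critical deployment
print(choose_autotune_strategy(True, True, False, 0.0))
# Rapid iteration where a 10% throughput loss is fine
print(choose_autotune_strategy(False, False, True, 0.10))
```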

Measuring Autotuning Effectiveness

To validate that autotuning is working and quantify its benefit, profile the compiled model and compare kernel performance:

import torch
from torch.profiler import profile, ProfilerActivity, schedule

def profile_autotuned_model(model, compiled_model, input_ids):
    """Profile eager vs compiled to measure autotuning benefit."""

    # Profile eager
    with profile(
        activities=[ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
    ) as prof_eager:
        with torch.inference_mode():
            for _ in range(10):
                model(input_ids)

    # Profile compiled
    with profile(
        activities=[ProfilerActivity.CUDA],
        record_shapes=True,
        with_stack=True,
    ) as prof_compiled:
        with torch.inference_mode():
            for _ in range(10):
                compiled_model(input_ids)

    # Print kernel-level comparison
    print("=== Eager (cuBLAS heuristic) ===")
    print(prof_eager.key_averages().table(
        sort_by="cuda_time_total", row_limit=10
    ))

    print("=== Compiled (autotuned) ===")
    print(prof_compiled.key_averages().table(
        sort_by="cuda_time_total", row_limit=10
    ))

    # Extract GEMM kernel times
    eager_gemm_time = sum(
        e.cuda_time_total
        for e in prof_eager.key_averages()
        if "gemm" in e.key.lower() or "mm" in e.key.lower()
    )
    compiled_gemm_time = sum(
        e.cuda_time_total
        for e in prof_compiled.key_averages()
        if "gemm" in e.key.lower()
        or "mm" in e.key.lower()
        # "triton" also matches fused elementwise kernels, so this slightly
        # overcounts GEMM time in the compiled profile; treat it as an upper bound
        or "triton" in e.key.lower()
    )

    print(f"Eager GEMM time:    {eager_gemm_time/1000:.2f} ms")
    print(f"Compiled GEMM time: {compiled_gemm_time/1000:.2f} ms")
    print(f"GEMM speedup: {eager_gemm_time/compiled_gemm_time:.2f}x")