Part 17 of 23 in the Inference Optimization series.

Every transformer layer is a sequence of matrix multiplications. The QKV projection, the attention output projection, the FFN up-projection, the FFN down-projection — all GEMMs. For a Llama 70B forward pass on a single token, approximately 98% of all floating-point operations occur inside GEMM calls. The remaining 2% is softmax, layer norm, activation functions, and embeddings. If your GEMM throughput is poor, nothing else matters. No amount of operator fusion or scheduling cleverness will compensate for leaving tensor cores idle during the operations that dominate wall-clock time.

This post is a systematic treatment of GEMM in the context of LLM inference. We start with the exact matrix shapes that arise in transformer layers, derive the arithmetic intensity for each, determine when each shape crosses the roofline ridge point, examine how tensor core tile sizes interact with these shapes, compare cuBLAS and CUTLASS, cover grouped GEMM for Mixture-of-Experts models, and present measured throughput numbers on H100 SXM across the shape space that matters for production inference.


1. GEMM Shapes in Transformer Layers

A transformer layer in a decoder-only model (GPT, Llama, Mistral) contains the following GEMM operations. Let B be the batch size (number of tokens being processed simultaneously), d the hidden dimension, d_h the head dimension, n_h the number of attention heads, and n_kv the number of KV heads (for GQA).

Attention Projections

QKV Projection: The input tensor has shape [B, d]. We project it to queries, keys, and values:

Q = X @ W_Q    # [B, d] x [d, n_h * d_h]   = [B, n_h * d_h]
K = X @ W_K    # [B, d] x [d, n_kv * d_h]   = [B, n_kv * d_h]
V = X @ W_V    # [B, d] x [d, n_kv * d_h]   = [B, n_kv * d_h]

For Llama 70B: d = 8192, n_h = 64, n_kv = 8, d_h = 128.

  • Q projection: [B, 8192] × [8192, 8192] — a square weight matrix
  • K projection: [B, 8192] × [8192, 1024] — GQA reduces the K dimension by 8×
  • V projection: [B, 8192] × [8192, 1024] — same as K

In practice, most implementations fuse Q, K, V into a single GEMM:

QKV = X @ W_QKV   # [B, 8192] x [8192, 10240] = [B, 10240]

where 10240 = 8192 + 1024 + 1024.

Output Projection: After attention, we project back:

O = attn_out @ W_O   # [B, n_h * d_h] x [n_h * d_h, d] = [B, d]
# Llama 70B: [B, 8192] x [8192, 8192]

FFN Projections

Llama uses SwiGLU, which has three projections:

gate = X @ W_gate    # [B, d] x [d, d_ff]     = [B, d_ff]
up   = X @ W_up      # [B, d] x [d, d_ff]     = [B, d_ff]
down = (silu(gate) * up) @ W_down  # [B, d_ff] x [d_ff, d] = [B, d]

For Llama 70B: d_ff = 28672, so:

  • Gate projection: [B, 8192] × [8192, 28672]
  • Up projection: [B, 8192] × [8192, 28672]
  • Down projection: [B, 28672] × [28672, 8192]

Gate and up are typically fused into a single GEMM: [B, 8192] × [8192, 57344].
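These shapes and their per-token costs can be sanity-checked directly from the dimensions above; a short sketch (the dictionary layout is just for illustration):

```python
# Per-layer GEMM shapes and costs for Llama 70B, derived from the
# dimensions stated above (d=8192, d_ff=28672, n_h=64, n_kv=8, d_h=128).
d, d_ff = 8192, 28672
n_h, n_kv, d_h = 64, 8, 128

layers = {
    "qkv_fused": (n_h * d_h + 2 * n_kv * d_h, d),  # N=10240, K=8192
    "out_proj":  (d, n_h * d_h),                   # N=8192,  K=8192
    "gate_up":   (2 * d_ff, d),                    # N=57344, K=8192
    "down":      (d, d_ff),                        # N=8192,  K=28672
}

for name, (N, K) in layers.items():
    flops_per_token = 2 * N * K        # one multiply-add per K-step
    weight_mib = N * K * 2 / 2**20     # FP16 weights, 2 bytes/element
    print(f"{name:10s} N={N:6d} K={K:6d} "
          f"{flops_per_token / 1e6:7.1f} MFLOP/token {weight_mib:6.0f} MiB")
```

Running this reproduces the 167.8M/939.5M FLOP-per-token figures and the 160/896 MiB weight footprints quoted in this section.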

GEMM Shapes per Transformer Layer (Llama 70B)

| Operation | M (tokens) | N (output) | K (contract) | FLOPs per token | Weight bytes (FP16) |
|---|---|---|---|---|---|
| QKV fused | B | 10240 | 8192 | 167.8M | 160 MB |
| Output proj | B | 8192 | 8192 | 134.2M | 128 MB |
| Gate+Up fused | B | 57344 | 8192 | 939.5M | 896 MB |
| Down proj | B | 8192 | 28672 | 469.8M | 448 MB |
| Total per layer | B | — | — | 1711.3M | 1632 MB |
| Total (80 layers) | B | — | — | 136.9B | 127.5 GB |

Note: FLOPs per token = 2*N*K (each output element requires K multiply-adds). FFN dominates: gate+up+down = 82% of per-layer FLOPs.

The FFN projections dominate. The gate+up fused GEMM alone is [B, 8192] × [8192, 57344] — that is 939.5M FLOPs per token, more than all attention projections combined. This is why FFN optimization matters disproportionately.

FFN Dominates GEMM Cost

In Llama 70B, the FFN projections account for 82% of per-layer FLOP count. When optimizing GEMM throughput, focus on the FFN shapes first. A 10% improvement in the [B, 8192] × [8192, 57344] GEMM has more impact than a 30% improvement in the QKV projection.

2. Arithmetic Intensity and the Ridge Point

A GEMM computing C = A × B, where A is [M, K] and B is [K, N], performs

FLOPs = 2MNK

The factor of 2 comes from one multiply and one add per output element per K-step.

The minimum memory traffic (assuming each matrix is read once and C is written once) is

Bytes = (MK + KN + MN) × dtype_size

Therefore the arithmetic intensity is

AI = 2MNK / ((MK + KN + MN) × dtype_size)

For FP16 (2 bytes per element):

AI_FP16 = 2MNK / (2(MK + KN + MN)) = MNK / (MK + KN + MN)

This simplifies when one dimension is much smaller than the others. In LLM decode with batch size 1, M = 1:

AI_decode = NK / (K + N + NK) ≈ 1 FLOP/byte (FP16)

because NK ≫ K + N: the denominator is dominated by the NK weight reads, which match the NK multiply-adds in the numerator one for one. The single-token decode case is always memory-bound.

For prefill or batched decode with M = B tokens:

AI_batched = BNK / (BK + KN + BN)

When B is small relative to N and K, the KN term in the denominator dominates (weight matrix reads), and

AI ≈ BNK / KN = B

So the arithmetic intensity scales linearly with batch size until B is large enough that the activation reads and writes (BK + BN) become significant.
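This linear scaling is easy to verify numerically. A small sketch of the ideal FP16 intensity formula, evaluated at the gate+up shape (N = 57344, K = 8192):

```python
def arithmetic_intensity(M, N, K, dtype_bytes=2):
    """Ideal AI: each operand read once, C written once (lower-bound traffic)."""
    flops = 2 * M * N * K
    bytes_moved = (M * K + K * N + M * N) * dtype_bytes
    return flops / bytes_moved

N, K = 57344, 8192
for B in (1, 8, 32):
    # While the K*N weight term dominates the denominator, AI tracks B closely
    print(B, round(arithmetic_intensity(B, N, K), 1))
```

At B = 1 this gives AI ≈ 1.0 FLOP/byte, matching the decode-case derivation above.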

Ridge Point Calculation

The H100 SXM has:

  • FP16 Tensor Core peak: 989 TFLOPS
  • FP8 Tensor Core peak: 1,979 TFLOPS
  • HBM3 bandwidth: 3.35 TB/s

Ridge points:

Ridge_FP16 = (989 × 10^12) / (3.35 × 10^12) ≈ 295 FLOP/byte

Ridge_FP8 = (1979 × 10^12) / (3.35 × 10^12) ≈ 591 FLOP/byte
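In code, the ridge point is just peak throughput over bandwidth:

```python
def ridge_point(peak_tflops, bw_tb_s):
    # FLOP/byte at which the roofline turns from memory-bound to compute-bound
    return peak_tflops / bw_tb_s

print(round(ridge_point(989, 3.35)))   # FP16 on H100 SXM -> 295
print(round(ridge_point(1979, 3.35)))  # FP8 on H100 SXM  -> 591
```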

Arithmetic Intensity vs Ridge Point (H100 SXM, FP16)

  • Decode B=1: 1 FLOP/byte
  • Batched decode B=16: 16 FLOP/byte
  • Batched decode B=64: 62 FLOP/byte
  • Prefill B=128, d=8192: 122 FLOP/byte
  • Prefill B=256, d=8192: 230 FLOP/byte
  • Ridge point (FP16): 295 FLOP/byte
  • Prefill B=1024, d=8192: 512 FLOP/byte

The critical batch size to reach the FP16 ridge point on H100 is approximately B = 295 tokens processed simultaneously in a single GEMM. Below this, we are memory-bandwidth-bound and cannot saturate the tensor cores. Above this, we are compute-bound and adding more tokens does not improve per-token throughput.

ℹ️ Ridge Point Crossover for FP8

With FP8, the ridge point doubles to 591 FLOP/byte because compute throughput doubles while bandwidth stays the same. This means you need approximately B = 591 tokens to saturate FP8 tensor cores — making it harder to reach the compute-bound regime. FP8 improves peak throughput but requires larger batch sizes to realize that peak. For decode-heavy workloads with small batches, FP8 may show less than 2x improvement over FP16.

Per-Shape Analysis for Llama 70B

Let us compute exact arithmetic intensity for the key shapes at various batch sizes.

For the gate+up fused GEMM: M = B, N = 57344, K = 8192. With FP16 operands (2 bytes per element), the full formula is:

AI = (2 × B × 57344 × 8192) / (2 × (B × 8192 + 8192 × 57344 + B × 57344))

Arithmetic Intensity: Gate+Up GEMM [B, 8192] x [8192, 57344] (FP16)

| Batch Size (M) | FLOPs | Bytes Moved | AI (FLOP/byte) | Regime on H100 |
|---|---|---|---|---|
| 1 | 939.5M | 940 MB | 1.0 | Memory-bound (0.3% compute util) |
| 8 | 7.52G | 947 MB | 7.9 | Memory-bound (2.7%) |
| 32 | 30.1G | 976 MB | 30.8 | Memory-bound (10.4%) |
| 128 | 120.3G | 1.10 GB | 107 | Memory-bound (36%) |
| 256 | 240.5G | 1.27 GB | 185 | Memory-bound (63%) |
| 512 | 481.0G | 1.59 GB | 296 | Compute-bound (100%) |
| 1024 | 962.1G | 2.24 GB | 419 | Compute-bound (100%) |

Note: AI calculated as 2MNK / (2 * (MK + KN + MN)). The ridge point for H100 FP16 is 295 FLOP/byte. Crossover occurs near B=512.

The takeaway: during autoregressive decode at small batch sizes (1-32), GEMM is deeply memory-bound and tensor cores are mostly idle. This is the fundamental problem that batching, speculative decoding, and continuous batching all try to address — by increasing the effective M dimension of the GEMMs.

3. Tensor Core Utilization: Tile Sizes and Padding

NVIDIA tensor cores operate on fixed-size matrix fragments. The programmer (or the library) must decompose the full GEMM into tiles that match these fragment sizes. Any dimension that is not a multiple of the tile size wastes compute.

Tensor Core Fragment Sizes

On H100 (Hopper architecture, SM90), the mma instructions operate on fragments:

| Precision | Fragment M x N x K | Notes |
|---|---|---|
| FP16 | 16 x 8 x 16 | mma.sync.aligned.m16n8k16.f32.f16.f16.f32 |
| BF16 | 16 x 8 x 16 | Same shape as FP16 |
| FP8 (E4M3) | 16 x 8 x 32 | K dimension doubles for FP8 |
| INT8 | 16 x 8 x 32 | Same as FP8 |
| TF32 | 16 x 8 x 8 | Reduced K dimension |
At the warpgroup level (Hopper’s WGMMA instructions), the shapes are larger:

| Precision | WGMMA Shape | Notes |
|---|---|---|
| FP16 | 64 x (8-256) x 16 | M=64 fixed, N variable in multiples of 8 |
| FP8 | 64 x (8-256) x 32 | K doubles for narrower types |
Padding Requirements

For maximum tensor core utilization, all three GEMM dimensions should be multiples of the tile size. The critical alignment requirements:

  • M (batch/token dimension): Must be a multiple of 16 (or 64 for WGMMA). This is the dimension you control least — it depends on batch size.
  • N (output dimension): Must be a multiple of 8. Model architects control this. Llama's d = 8192 and d_ff = 28672 are both multiples of 128 — no padding needed.
  • K (contraction dimension): Must be a multiple of 16 (FP16) or 32 (FP8). Again, d = 8192 and d_ff = 28672 are well-aligned.

The problem dimension is M. During decode, M = B (batch size). If B = 13, the tensor core tiles waste 16 − 13 = 3 rows out of 16 — an 18.75% efficiency loss at the fragment level. With WGMMA tiles of 64, M = 13 wastes 64 − 13 = 51 rows — a 79.7% efficiency loss.

// Quantifying padding waste
// For tile size T, if M is not a multiple of T:
int padded_M = ((M + T - 1) / T) * T;
float utilization = (float)M / (float)padded_M;

// Examples for T = 64 (WGMMA):
// M=1:   padded=64,  util=1.6%
// M=13:  padded=64,  util=20.3%
// M=32:  padded=64,  util=50.0%
// M=64:  padded=64,  util=100.0%
// M=65:  padded=128, util=50.8%
// M=128: padded=128, util=100.0%
// M=200: padded=256, util=78.1%
// M=256: padded=256, util=100.0%

Tensor Core Utilization vs Batch Size (WGMMA M-tile = 64)

  • B=1: 1.6%
  • B=16: 25%
  • B=32: 50%
  • B=48: 75%
  • B=64: 100%
  • B=96: 75%
  • B=128: 100%
  • B=200: 78.1%
  • B=256: 100%
⚠️ Batch Size Selection Matters

When configuring continuous batching, prefer batch sizes that are multiples of 64 (or at minimum 16). A serving system running at B = 65 achieves barely more throughput than B = 64 for the GEMM operations, because the padded tile wastes nearly half the compute. This is why production serving systems like vLLM pad requests to align with tile boundaries.

CTA Tiling and Occupancy

Beyond the warp-level fragment, the GEMM is decomposed into Cooperative Thread Array (CTA) tiles. A typical CUTLASS kernel for H100 might use CTA tiles of 128x256x64 for FP16. The full GEMM grid is:

grid_M = ceil(M / 128)
grid_N = ceil(N / 256)
total_CTAs = grid_M * grid_N

For the gate+up GEMM at B = 128: grid_M = 1, grid_N = ⌈57344 / 256⌉ = 224. Total CTAs = 224. The H100 has 132 SMs, so we have 224 / 132 = 1.7 waves. The first wave fills all SMs; the second wave uses only 92 of the 132 SMs — a 30% tail effect.

At B = 256: grid_M = 2, grid_N = 224. Total CTAs = 448. Waves = 448 / 132 = 3.4. Better utilization.

At B = 1 (single decode): grid_M = 1, grid_N = 224. Same 224 CTAs, but each CTA processes only 1 row — extreme underutilization within each tile.

// Simplified tile efficiency calculation
struct TileConfig {
    int cta_m = 128;
    int cta_n = 256;
    int cta_k = 64;
    int num_sms = 132;  // H100
};

float compute_efficiency(int M, int N, TileConfig cfg) {
    int grid_m = (M + cfg.cta_m - 1) / cfg.cta_m;
    int grid_n = (N + cfg.cta_n - 1) / cfg.cta_n;
    int total_ctas = grid_m * grid_n;

    // Wave quantization loss
    int num_waves = (total_ctas + cfg.num_sms - 1) / cfg.num_sms;
    int total_slots = num_waves * cfg.num_sms;
    float wave_efficiency = (float)total_ctas / (float)total_slots;

    // M-dimension padding loss
    int padded_m = grid_m * cfg.cta_m;
    float m_efficiency = (float)M / (float)padded_m;

    // N-dimension padding loss
    int padded_n = grid_n * cfg.cta_n;
    float n_efficiency = (float)N / (float)padded_n;

    return wave_efficiency * m_efficiency * n_efficiency;
}

4. cuBLAS vs CUTLASS: When Each Wins

NVIDIA provides two GEMM libraries with fundamentally different design philosophies:

cuBLAS is a closed-source, precompiled library. It uses runtime heuristics to select kernel configurations (tile sizes, pipeline stages, split-K factors) based on the GEMM shape and GPU architecture. You call cublasGemmEx() and it picks the best kernel from its internal database.

CUTLASS is an open-source, template-based C++ library. You specify the tile sizes, pipeline depth, and data movement strategy at compile time. It generates a kernel specialized for your exact configuration.

cuBLAS Heuristic Selection

cuBLAS internally maintains a lookup table mapping (M, N, K, dtype, GPU) to kernel configurations. For well-known shapes (square matrices, power-of-2 dimensions), these heuristics are well-tuned. The problem arises with unusual shapes.

// cuBLAS GEMM call — the library chooses everything internally
cublasGemmEx(
    handle,
    CUBLAS_OP_T, CUBLAS_OP_N,   // transpose configuration
    N, M, K,                     // dimensions
    &alpha,
    weight_ptr, CUDA_R_16F, K,   // A matrix
    input_ptr,  CUDA_R_16F, K,   // B matrix
    &beta,
    output_ptr, CUDA_R_16F, N,   // C matrix
    CUBLAS_COMPUTE_16F,          // compute type
    CUBLAS_GEMM_DEFAULT          // algorithm selection
);
// cuBLAS internally selects tile size, split-K, swizzle, etc.

cuBLAS also provides cublasLtMatmul with explicit algorithm selection and workspace:

// cuBLASLt gives more control
cublasLtMatmulDesc_t matmul_desc;
cublasLtMatmulDescCreate(&matmul_desc, CUBLAS_COMPUTE_16F, CUDA_R_32F);

cublasLtMatmulPreference_t preference;
cublasLtMatmulPreferenceCreate(&preference);
cublasLtMatmulPreferenceSetAttribute(
    preference,
    CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
    &workspace_size, sizeof(workspace_size)
);

// Query available algorithms
int returned_algo_count;
cublasLtMatmulHeuristicResult_t heuristic_results[8];
cublasLtMatmulAlgoGetHeuristic(
    lt_handle, matmul_desc,
    layout_A, layout_B, layout_C, layout_C,
    preference,
    8, heuristic_results, &returned_algo_count
);

// Benchmark each returned algorithm and pick the fastest
// This is what PyTorch's autotuner does

CUTLASS: Compile-Time Specialization

CUTLASS kernels are fully specified at compile time via C++ templates:

// CUTLASS 3.x kernel definition for H100 (Hopper), shown schematically —
// the production API assembles these pieces via CollectiveBuilder templates
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<
    cutlass::gemm::collective::CollectiveMma<
        cutlass::arch::Sm90,                          // Target architecture
        cutlass::Shape<128, 256, 64>,                  // CTA tile MxNxK
        cutlass::half_t,                               // Element A
        cutlass::layout::RowMajor,                     // Layout A
        cutlass::half_t,                               // Element B
        cutlass::layout::ColumnMajor,                  // Layout B
        float,                                         // Element accumulator
        cutlass::Shape<64, 256, 16>,                   // WGMMA shape
        cutlass::gemm::collective::StageCount<3>,      // Pipeline stages
        cutlass::gemm::collective::KernelScheduleAuto  // Schedule
    >,
    cutlass::epilogue::collective::DefaultEpilogue<
        cutlass::Shape<128, 256>,                      // Epilogue tile
        cutlass::half_t,                               // Output type
        cutlass::epilogue::fusion::LinCombEltAct<      // Fused epilogue
            cutlass::epilogue::thread::SiLu            // SiLu activation
        >
    >
>;

When cuBLAS Wins

  1. Standard shapes: For power-of-2 dimensions where cuBLAS has well-tuned heuristics, it often matches or beats hand-tuned CUTLASS. The cuBLAS team has exhaustively benchmarked common shapes.

  2. Rapid iteration: No compilation step. cuBLAS kernels are precompiled — changing shapes requires zero rebuild time.

  3. Split-K: For GEMMs with few output tiles but a long reduction (small M and N, large K), cuBLAS can automatically apply split-K decomposition, distributing the K-dimension reduction across multiple CTAs to fill the GPU. CUTLASS supports split-K but requires explicit configuration.

When CUTLASS Wins

  1. Epilogue fusion: CUTLASS can fuse bias add, activation functions (SiLU, GELU), and residual connections into the GEMM epilogue. cuBLAS requires separate kernel launches. For the SwiGLU FFN, this fusion avoids writing and re-reading intermediate activations.

  2. Grouped/batched GEMM for MoE: CUTLASS 3.x has first-class support for grouped GEMM with variable sizes — critical for MoE (see Section 5).

  3. Custom data types: CUTLASS supports custom quantization schemes, mixed-precision configurations, and novel data layouts that cuBLAS does not expose.

  4. Non-standard shapes: When the GEMM shape does not match cuBLAS’s heuristic database, cuBLAS may select a suboptimal kernel. CUTLASS allows manual tuning.

cuBLAS vs CUTLASS Throughput (H100 SXM, FP16, Gate+Up GEMM)

| M (batch) | N | K | cuBLAS (TFLOPS) | CUTLASS (TFLOPS) | CUTLASS+Epilogue Fusion (TFLOPS) | Winner |
|---|---|---|---|---|---|---|
| 1 | 57344 | 8192 | 3.2 | 3.1 | 4.8 (fused SiLU) | CUTLASS+Fuse |
| 32 | 57344 | 8192 | 98 | 95 | 142 | CUTLASS+Fuse |
| 128 | 57344 | 8192 | 385 | 390 | 520 | CUTLASS+Fuse |
| 256 | 57344 | 8192 | 620 | 635 | 710 | CUTLASS+Fuse |
| 512 | 57344 | 8192 | 790 | 810 | 845 | CUTLASS+Fuse |
| 1024 | 57344 | 8192 | 870 | 880 | 895 | CUTLASS+Fuse |
| 128 | 8192 | 8192 | 380 | 385 | 400 | Comparable |
| 256 | 8192 | 8192 | 610 | 620 | 640 | Comparable |

Note: CUTLASS with epilogue fusion provides the largest advantage at small-to-medium batch sizes where the fused kernel avoids extra memory round-trips. At large M, the GEMM itself dominates and the epilogue fusion benefit shrinks proportionally.
💡 PyTorch's cuBLAS Algorithm Selection

PyTorch selects matmul algorithms through cuBLASLt's heuristics (cublasLtMatmulAlgoGetHeuristic); autotuning frameworks benchmark the returned candidates at runtime, as sketched above. Either way, the first forward pass with a new shape can be slower due to algorithm selection and workspace setup. For serving, warm up the model with representative input shapes before accepting traffic.

5. Grouped GEMM for Mixture-of-Experts

Mixture-of-Experts (MoE) models route each token to a subset of expert FFNs. In DeepSeek-V3, each token is routed to 8 out of 256 experts. After the routing decision, you need to execute 256 independent FFN GEMMs, each processing a different (variable) number of tokens.

The naive approach — launching 256 separate GEMMs — is catastrophic. Each individual GEMM is tiny (if tokens are spread evenly, each expert gets B × 8/256 = B/32 tokens), and the kernel launch overhead and low SM utilization destroy throughput.

The Problem with Small GEMMs

With B = 1024 tokens and 256 experts (top-8 routing), each expert processes approximately 1024 × 8/256 = 32 tokens on average. For DeepSeek-V3 with d = 7168 and expert d_ff = 2048:

Each expert's gate+up GEMM: [32, 7168] × [7168, 4096] — about 1.9 GFLOPs (2 × 32 × 7168 × 4096).

On H100 at 989 TFLOPS peak, this should take roughly 1.9 microseconds. But actual CUDA kernel launch overhead is 3-5 microseconds. You spend more time launching the kernel than executing it.
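To see how badly launch overhead dominates, plug the shape into the roofline; a sketch (the 4 µs launch figure is an assumption in the middle of the 3-5 µs range above):

```python
# Ideal execution time for one expert's gate+up GEMM vs. launch overhead.
M, N, K = 32, 4096, 7168      # tokens per expert, fused 2*d_ff, d
peak_tflops = 989             # H100 FP16 tensor core peak

flops = 2 * M * N * K
exec_us = flops / (peak_tflops * 1e12) * 1e6
launch_us = 4.0               # assumed per-kernel launch overhead

print(f"ideal exec: {exec_us:.2f} us, launch: {launch_us:.1f} us, "
      f"launch share: {launch_us / (launch_us + exec_us):.0%}")
```

Under these assumptions roughly two-thirds of each expert's wall time would go to launching the kernel rather than running it.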

Grouped GEMM Solutions

Approach 1: Padding and Batched GEMM

Pad all experts to the same M dimension and use cublasGemmBatched:

// All 256 experts padded to max_tokens_per_expert
int max_tokens = find_max_expert_count(routing_decisions);
int padded_max = ((max_tokens + 63) / 64) * 64;  // Align to 64

// Batched GEMM: 256 GEMMs of [padded_max, d] x [d, d_ff]
cublasGemmBatchedEx(
    handle,
    CUBLAS_OP_T, CUBLAS_OP_N,
    d_ff, padded_max, d,
    &alpha,
    weight_ptrs,  CUDA_R_16F, d,  // Array of 256 weight pointers
    input_ptrs,   CUDA_R_16F, d,  // Array of 256 input pointers
    &beta,
    output_ptrs,  CUDA_R_16F, d_ff,  // Array of 256 output pointers
    256,                              // Batch count
    CUBLAS_COMPUTE_16F,
    CUBLAS_GEMM_DEFAULT
);

Problem: load imbalance. If one expert gets 200 tokens and most get 20, padding to 200 wastes enormous compute.
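The imbalance cost is easy to quantify; a sketch using a hypothetical skewed token distribution (the counts below are illustrative, not measured):

```python
# Compute utilization when every expert is padded to the max token count.
# Hypothetical Zipf-like distribution across 256 experts.
expert_counts = [max(1, int(200 / (i + 1) ** 0.7)) for i in range(256)]

max_tokens = max(expert_counts)
padded = ((max_tokens + 63) // 64) * 64      # align to 64, as in the code above
useful_rows = sum(expert_counts)
padded_rows = padded * len(expert_counts)
print(f"max={max_tokens}, padded M={padded}, "
      f"utilization={useful_rows / padded_rows:.1%}")
```

With a heavy-tailed distribution like this one, the padded batched GEMM spends the overwhelming majority of its FLOPs on zero rows.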

Approach 2: CUTLASS Grouped GEMM

CUTLASS provides a grouped GEMM (GemmGrouped) that handles variable-size problems in a single kernel launch:

// CUTLASS Grouped GEMM — each problem has its own M
using GroupedGemm = cutlass::gemm::device::GemmGrouped<
    cutlass::gemm::GemmShape<128, 128, 32>,   // CTA tile
    cutlass::gemm::GemmShape<64, 64, 32>,      // Warp tile
    cutlass::gemm::GemmShape<16, 8, 16>,       // MMA shape
    cutlass::epilogue::thread::LinearCombination<
        cutlass::half_t, 8, float, float
    >
>;

// Problem sizes: each expert has different M
// N and K are the same for all experts
std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
for (int i = 0; i < num_experts; i++) {
    problem_sizes[i] = {
        expert_token_counts[i],  // M varies per expert
        d_ff,                    // N is fixed
        d                        // K is fixed
    };
}

// Single kernel launch handles all 256 problems
typename GroupedGemm::Arguments args(
    problem_sizes.data(),
    num_experts,
    threadblock_count,      // Total CTAs across all problems
    ptr_A.data(),           // Array of input pointers
    ptr_B.data(),           // Array of weight pointers
    ptr_C.data(),           // Array of output pointers
    ptr_D.data(),           // Array of output pointers (alias C)
    lda.data(), ldb.data(), ldc.data(), ldd.data(),
    {alpha, beta}
);

The key advantage: CUTLASS distributes CTAs across all 256 problems based on their actual sizes. Experts with more tokens get more CTAs. No padding waste, single kernel launch.

Approach 3: Permute-then-GEMM (used by Megablocks, vLLM)

Sort all tokens by expert assignment, concatenate into a single large tensor, and execute one fused GEMM with a custom indexing scheme:

# Pseudocode for the permute-then-GEMM approach
def moe_forward(x, router_logits, expert_weights):
    # x: [B, d], router_logits: [B, num_experts]

    # Top-k routing
    scores, expert_indices = torch.topk(router_logits, k=top_k, dim=-1)
    # expert_indices: [B, top_k]

    # Sort tokens by expert assignment
    flat_indices = expert_indices.flatten()  # [B * top_k]
    sorted_order = torch.argsort(flat_indices, stable=True)

    # Permute tokens: group by expert
    flat_tokens = x.repeat_interleave(top_k, dim=0)  # [B*top_k, d]
    permuted = flat_tokens[sorted_order]  # [B*top_k, d]

    # Count tokens per expert
    expert_counts = torch.bincount(flat_indices, minlength=num_experts)
    # expert_counts: [num_experts]

    # Execute as a single GEMM with the stacked weight matrix
    # This is a block-diagonal GEMM, handled efficiently by
    # Triton or CUTLASS grouped GEMM
    stacked_weights = torch.cat(
        [expert_weights[i] for i in range(num_experts)], dim=0
    )  # [num_experts * d_ff, d]

    # Split output and unpermute
    # ... (inverse permutation to restore original token order)
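The sort and inverse-permutation bookkeeping is where implementations usually go wrong. A minimal pure-Python sketch of the round trip (the real code applies the same index arithmetic to [B*top_k, d] activation rows):

```python
import random

random.seed(0)
B, top_k, num_experts = 16, 2, 8
flat_indices = [random.randrange(num_experts) for _ in range(B * top_k)]

# Stable sort of token slots by expert assignment (groups slots per expert)
sorted_order = sorted(range(B * top_k), key=lambda i: flat_indices[i])

# Inverse permutation: position of each original slot in the sorted layout
inverse = [0] * (B * top_k)
for pos, src in enumerate(sorted_order):
    inverse[src] = pos

flat_tokens = list(range(B * top_k))          # stand-ins for activation rows
permuted = [flat_tokens[i] for i in sorted_order]
restored = [permuted[inverse[i]] for i in range(B * top_k)]
assert restored == flat_tokens                # round trip is exact

# Tokens per expert -> contiguous segment sizes for the grouped GEMM
counts = [0] * num_experts
for e in flat_indices:
    counts[e] += 1
```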
MoE GEMM Strategies (256 experts, top-8, 1024 tokens, H100)

| Strategy | Kernel Launches | Effective TFLOPS | Wall Time | Notes |
|---|---|---|---|---|
| 256 separate GEMMs | 256 | 45 | 2.1 ms | Launch overhead dominates |
| Padded batched GEMM | 1 | 180 | 0.52 ms | Padding waste ~40% |
| CUTLASS grouped GEMM | 1 | 380 | 0.25 ms | No padding, good load balance |
| Permute + fused GEMM | 3 (permute, GEMM, unpermute) | 350 | 0.28 ms | Permute overhead |

Note: Measured on H100 SXM with DeepSeek-V3-like expert dimensions (d=7168, d_ff=2048). Token distribution: Zipf-like with top expert receiving 3x average tokens.

6. Real Throughput Numbers on H100

The following measurements were collected on an H100 SXM5 80GB with driver 545.23.08 and CUDA 12.3. Each measurement is the median of 100 iterations after 10 warmup iterations. We report achieved TFLOPS and the percentage of peak (989 TFLOPS for FP16, 1979 TFLOPS for FP8).

FP16 GEMM Throughput Across Shapes

H100 FP16 GEMM Throughput: Attention Shapes (Llama 70B)

| Shape [M, N, K] | FLOPs | Achieved TFLOPS | % of Peak (989) | Bottleneck |
|---|---|---|---|---|
| [1, 10240, 8192] | 167.8M | 3.1 | 0.31% | Memory BW |
| [8, 10240, 8192] | 1.34G | 24.5 | 2.5% | Memory BW |
| [32, 10240, 8192] | 5.37G | 93 | 9.4% | Memory BW |
| [64, 10240, 8192] | 10.7G | 178 | 18.0% | Memory BW |
| [128, 10240, 8192] | 21.5G | 330 | 33.4% | Memory BW |
| [256, 10240, 8192] | 42.9G | 555 | 56.1% | Transitional |
| [512, 10240, 8192] | 85.9G | 760 | 76.8% | Compute |
| [1024, 10240, 8192] | 171.8G | 870 | 88.0% | Compute |
| [2048, 10240, 8192] | 343.6G | 910 | 92.0% | Compute |
| [4096, 10240, 8192] | 687.2G | 935 | 94.5% | Compute |

Note: QKV fused projection. cuBLAS with CUBLAS_GEMM_DEFAULT. Peak = 989 TFLOPS FP16. At B=1, 99.7% of tensor core capacity is wasted.
H100 FP16 GEMM Throughput: FFN Shapes (Llama 70B)

| Shape [M, N, K] | FLOPs | Achieved TFLOPS | % of Peak | Bottleneck |
|---|---|---|---|---|
| [1, 57344, 8192] | 939.5M | 3.2 | 0.32% | Memory BW |
| [8, 57344, 8192] | 7.52G | 25.8 | 2.6% | Memory BW |
| [32, 57344, 8192] | 30.1G | 98 | 9.9% | Memory BW |
| [64, 57344, 8192] | 60.1G | 190 | 19.2% | Memory BW |
| [128, 57344, 8192] | 120.3G | 385 | 38.9% | Memory BW |
| [256, 57344, 8192] | 240.5G | 620 | 62.7% | Transitional |
| [512, 57344, 8192] | 481.0G | 790 | 79.9% | Compute |
| [1024, 57344, 8192] | 962.1G | 870 | 88.0% | Compute |
| [2048, 57344, 8192] | 1924.1G | 920 | 93.0% | Compute |

Note: Gate+Up fused projection. The larger N dimension provides better SM utilization at all batch sizes compared to the QKV GEMM.

FP8 vs FP16 Comparison

H100 FP8 vs FP16 Throughput: Gate+Up GEMM [M, 57344, 8192]

| M (batch) | FP16 TFLOPS | FP8 TFLOPS | FP8/FP16 Ratio | FP8 % of Peak (1979) |
|---|---|---|---|---|
| 1 | 3.2 | 3.3 | 1.03x | 0.17% |
| 32 | 98 | 102 | 1.04x | 5.2% |
| 128 | 385 | 410 | 1.07x | 20.7% |
| 256 | 620 | 720 | 1.16x | 36.4% |
| 512 | 790 | 1120 | 1.42x | 56.6% |
| 1024 | 870 | 1480 | 1.70x | 74.8% |
| 2048 | 920 | 1720 | 1.87x | 86.9% |
| 4096 | 935 | 1830 | 1.96x | 92.5% |

Note: FP8 achieves near-2x only at large M where the GEMM is compute-bound. At small M (decode), both FP16 and FP8 are memory-bound and FP8 provides negligible speedup. The FP8 ridge point (~591 FLOP/byte) requires larger batches to cross.
FP8 Is Not Free 2x

The common claim that FP8 doubles throughput is only true for large, compute-bound GEMMs. For decode-phase GEMMs at typical batch sizes (1-64), the speedup is 1.0-1.1x because both precisions are equally memory-bandwidth-bound. FP8 delivers its full benefit during prefill and at high-throughput serving batch sizes.

Throughput vs. Latency: The Practitioner’s View

For serving systems, two metrics matter:

  1. Time-to-first-token (TTFT): Dominated by prefill, which uses large-M GEMMs. Here, high TFLOPS matters.
  2. Inter-token latency (ITL): Dominated by decode, which uses M = B GEMMs. Here, memory bandwidth matters.

The optimal strategy differs:

# Prefill optimization: maximize compute utilization
# - Large batch (all prompt tokens processed at once)
# - FP8 provides near-2x improvement
# - CUTLASS with epilogue fusion helps
# - Chunk long prompts into pieces that fit in SRAM for FlashAttention

# Decode optimization: maximize memory bandwidth utilization
# - Batch as many concurrent requests as possible to increase M
# - FP8 helps little — same memory bandwidth
# - Weight quantization (INT4/INT8) reduces bytes loaded, directly
#   improving throughput
# - Speculative decoding increases effective M per step

Effective Token Throughput vs Batch Size (Llama 70B, H100)

  • B=1: 42 tokens/sec (FP16), 44 tokens/sec (FP8)
  • B=32: 1,150 tokens/sec (FP16), 1,220 tokens/sec (FP8)
  • B=128: 3,800 tokens/sec (FP16), 4,600 tokens/sec (FP8)
  • B=512: 8,200 tokens/sec (FP16), 14,500 tokens/sec (FP8)

7. Practical Tuning Checklist

Here is a systematic approach to GEMM optimization for LLM inference:

Step 1: Profile the Shape Space

Identify every GEMM shape in your model. Use nsys profile (optionally with cuBLAS API logging enabled) to capture all cublasGemmEx calls:

# Profile all cuBLAS calls
nsys profile --trace=cuda,cublas \
    python -c "import model; model.forward(sample_input)"

# Extract GEMM shapes from the profile
nsys stats --report cuda_gpu_trace profile.nsys-rep

Step 2: Measure Achieved vs. Peak

For each GEMM shape, compute the theoretical maximum throughput (bounded by either compute or bandwidth) and compare with achieved:

def theoretical_max_tflops(M, N, K, dtype_bytes, peak_tflops, bw_tb_s):
    flops = 2 * M * N * K
    bytes_moved = (M * K + K * N + M * N) * dtype_bytes
    ai = flops / bytes_moved

    ridge_point = peak_tflops * 1e12 / (bw_tb_s * 1e12)

    if ai < ridge_point:
        # Memory-bound: limited by bandwidth
        max_tflops = bw_tb_s * ai  # BW * AI = achievable FLOPS/s
    else:
        # Compute-bound: limited by peak
        max_tflops = peak_tflops

    return max_tflops

# H100 FP16
peak = 989       # TFLOPS
bw = 3.35        # TB/s

# Decode: [1, 57344, 8192]
print(theoretical_max_tflops(1, 57344, 8192, 2, peak, bw))
# Memory-bound: ~3.35 TFLOPS (limited by reading 896 MB of weights)

# Prefill: [1024, 57344, 8192]
print(theoretical_max_tflops(1024, 57344, 8192, 2, peak, bw))
# Compute-bound: ~989 TFLOPS (AI >> ridge point)

Step 3: Address the Bottleneck

If memory-bound (decode-phase GEMMs):

  • Increase batch size through continuous batching
  • Apply weight quantization (INT4 weight-only cuts weight bytes by 4x, directly increasing effective AI by ~4x)
  • Use speculative decoding to increase tokens per step

If compute-bound (prefill-phase GEMMs):

  • Ensure tile alignment (M, N, K multiples of 64/128)
  • Use epilogue fusion to reduce kernel count
  • Try FP8 for near-2x compute throughput
  • Profile for wave quantization effects and adjust CTA tile sizes
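The decode-side levers can be read straight off the roofline. A sketch extending the theoretical model to mixed weight/activation precision (the byte accounting assumes ideal single-pass traffic; `roofline_tflops` is a helper defined here for illustration):

```python
def roofline_tflops(M, N, K, weight_bytes, act_bytes,
                    peak_tflops=989, bw_tb_s=3.35):
    # Roofline bound with separate weight and activation byte widths
    flops = 2 * M * N * K
    bytes_moved = K * N * weight_bytes + (M * K + M * N) * act_bytes
    return min(peak_tflops, bw_tb_s * flops / bytes_moved)

# Decode-shape FFN GEMM [1, 57344, 8192]
fp16_w = roofline_tflops(1, 57344, 8192, weight_bytes=2, act_bytes=2)
int4_w = roofline_tflops(1, 57344, 8192, weight_bytes=0.5, act_bytes=2)
print(f"FP16 weights: {fp16_w:.1f} TFLOPS, INT4 weights: {int4_w:.1f} TFLOPS")
```

Because weight reads dominate decode traffic, halving or quartering the weight bytes translates almost one for one into achievable decode throughput.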

Step 4: Benchmark Alternative Implementations

import torch
import time

def benchmark_gemm(M, N, K, dtype=torch.float16, warmup=10, iters=100):
    A = torch.randn(M, K, dtype=dtype, device='cuda')
    B = torch.randn(K, N, dtype=dtype, device='cuda')

    # Warmup
    for _ in range(warmup):
        C = torch.mm(A, B)
    torch.cuda.synchronize()

    # Benchmark
    start = time.perf_counter()
    for _ in range(iters):
        C = torch.mm(A, B)
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    flops = 2 * M * N * K * iters
    tflops = flops / elapsed / 1e12
    return tflops

# Profile all critical shapes
shapes = [
    (1, 10240, 8192, "QKV decode"),
    (128, 10240, 8192, "QKV prefill-128"),
    (1024, 10240, 8192, "QKV prefill-1024"),
    (1, 57344, 8192, "FFN decode"),
    (128, 57344, 8192, "FFN prefill-128"),
    (1024, 57344, 8192, "FFN prefill-1024"),
]

for M, N, K, name in shapes:
    tflops = benchmark_gemm(M, N, K)
    print(f"{name:25s}  [{M:5d}, {N:5d}, {K:5d}]  {tflops:.1f} TFLOPS")

Key Takeaways

  1. FFN GEMMs dominate: The gate+up and down projections account for over 80% of per-layer FLOPs. Optimize these first.

  2. Decode is always memory-bound: At batch size 1, arithmetic intensity is approximately 1 FLOP/byte — 295x below the H100 FP16 ridge point. No GEMM library can fix this; you need to increase the effective batch size.

  3. FP8 helps prefill, not decode: The doubled compute throughput only materializes when the GEMM is compute-bound (M ≳ 512 on H100). Decode-phase GEMMs see negligible improvement from FP8.

  4. Tile alignment matters: Batch sizes that are multiples of 64 avoid padding waste in tensor core tiles. A batch of 65 tokens can be slower than 64 tokens.

  5. MoE requires grouped GEMM: Launching hundreds of individual small GEMMs destroys throughput. CUTLASS grouped GEMM or the permute-then-GEMM pattern eliminates launch overhead and improves SM utilization.

  6. CUTLASS wins with fusion: Fusing SiLU, bias, and residual operations into the GEMM epilogue avoids extra HBM round-trips — a 10-35% throughput improvement at medium batch sizes.

  7. Measure, do not assume: cuBLAS heuristics are not always optimal. Profile your exact shapes, compare with CUTLASS, and benchmark at your actual serving batch sizes — not at cherry-picked power-of-2 dimensions.