Every transformer layer is a sequence of matrix multiplications. The QKV projection, the attention output projection, the FFN up-projection, the FFN down-projection — all GEMMs. For a Llama 70B forward pass on a single token, approximately 98% of all floating-point operations occur inside GEMM calls. The remaining 2% is softmax, layer norm, activation functions, and embeddings. If your GEMM throughput is poor, nothing else matters. No amount of operator fusion or scheduling cleverness will compensate for leaving tensor cores idle during the operations that dominate wall-clock time.
This post is a systematic treatment of GEMM in the context of LLM inference. We start with the exact matrix shapes that arise in transformer layers, derive the arithmetic intensity for each, determine when each shape crosses the roofline ridge point, examine how tensor core tile sizes interact with these shapes, compare cuBLAS and CUTLASS, cover grouped GEMM for Mixture-of-Experts models, and present measured throughput numbers on H100 SXM across the shape space that matters for production inference.
1. GEMM Shapes in Transformer Layers
A transformer layer in a decoder-only model (GPT, Llama, Mistral) contains the following GEMM operations. Let B be the batch size (number of tokens being processed simultaneously), d the hidden dimension, d_h the head dimension, n_h the number of attention heads, and n_kv the number of KV heads (for GQA).
Attention Projections
QKV Projection: The input tensor X has shape [B, d]. We project it to queries, keys, and values:
Q = X @ W_Q # [B, d] x [d, n_h * d_h] = [B, n_h * d_h]
K = X @ W_K # [B, d] x [d, n_kv * d_h] = [B, n_kv * d_h]
V = X @ W_V # [B, d] x [d, n_kv * d_h] = [B, n_kv * d_h]
For Llama 70B: d = 8192, n_h = 64, d_h = 128, n_kv = 8.
- Q projection: [B, 8192] x [8192, 8192] — a square weight matrix
- K projection: [B, 8192] x [8192, 1024] — GQA reduces the K/V width by n_h / n_kv = 8x
- V projection: [B, 8192] x [8192, 1024] — same as K
In practice, most implementations fuse Q, K, V into a single GEMM:
QKV = X @ W_QKV # [B, 8192] x [8192, 10240] = [B, 10240]
where 10240 = 8192 + 2 x 1024 (Q at full width plus K and V at GQA width).
Output Projection: After attention, we project back:
O = attn_out @ W_O # [B, n_h * d_h] x [n_h * d_h, d] = [B, d]
# Llama 70B: [B, 8192] x [8192, 8192]
FFN Projections
Llama uses SwiGLU, which has three projections:
gate = X @ W_gate # [B, d] x [d, d_ff] = [B, d_ff]
up = X @ W_up # [B, d] x [d, d_ff] = [B, d_ff]
down = (silu(gate) * up) @ W_down # [B, d_ff] x [d_ff, d] = [B, d]
For Llama 70B, d_ff = 28672, so:
- Gate projection: [B, 8192] x [8192, 28672]
- Up projection: [B, 8192] x [8192, 28672]
- Down projection: [B, 28672] x [28672, 8192]
Gate and up are typically fused into a single GEMM: [B, 8192] x [8192, 57344].
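A minimal NumPy sketch of the fused layout (function and weight names here are illustrative, not Llama's actual parameter names): one [d, 2*d_ff] weight replaces the separate gate and up weights, so a single GEMM with N = 57344 replaces two GEMMs with N = 28672.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def swiglu_ffn_fused(x, w_gate_up, w_down):
    """SwiGLU FFN with gate and up projections fused into one GEMM.

    x: [B, d], w_gate_up: [d, 2*d_ff], w_down: [d_ff, d]."""
    d_ff = w_down.shape[0]
    gate_up = x @ w_gate_up                    # [B, 2*d_ff] in one GEMM
    gate, up = gate_up[:, :d_ff], gate_up[:, d_ff:]
    return (silu(gate) * up) @ w_down          # [B, d]
```

The fusion is free at the framework level — it only requires concatenating the two weight matrices along the output dimension.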
GEMM Shapes per Transformer Layer (Llama 70B)
| Operation | M (tokens) | N (output) | K (contract) | FLOPs per token | Weight size (FP16) |
|---|---|---|---|---|---|
| QKV Fused | B | 10240 | 8192 | 167.8M | 160 MiB |
| Output Proj | B | 8192 | 8192 | 134.2M | 128 MiB |
| Gate+Up Fused | B | 57344 | 8192 | 939.5M | 896 MiB |
| Down Proj | B | 8192 | 28672 | 469.8M | 448 MiB |
| Total per layer | B | — | — | 1711.3M | 1632 MiB |
| Total (80 layers) | B | — | — | 136.9G | 127.5 GiB |
The FFN projections dominate. The gate+up fused GEMM alone is [B, 8192] x [8192, 57344] — that is 939.5M FLOPs per token, more than all attention projections combined. This is why FFN optimization matters disproportionately.
In Llama 70B, the FFN projections account for 82% of the per-layer FLOP count. When optimizing GEMM throughput, focus on the FFN shapes first. A 10% improvement in the gate+up GEMM has more impact than a 30% improvement in the QKV projection.
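The share can be checked directly from the shapes above (each GEMM is 2 * N * K FLOPs per token):

```python
# Per-token FLOPs for each per-layer GEMM in Llama 70B, from the table above.
flops = {
    "qkv":     2 * 10240 * 8192,
    "o_proj":  2 * 8192 * 8192,
    "gate_up": 2 * 57344 * 8192,
    "down":    2 * 8192 * 28672,
}
total = sum(flops.values())
ffn_share = (flops["gate_up"] + flops["down"]) / total
print(f"FFN share of per-layer FLOPs: {ffn_share:.1%}")  # ~82.4%
```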
2. Arithmetic Intensity and the Ridge Point
A GEMM computing C = A @ B, where A is [M, K] and B is [K, N], performs

FLOPs = 2 * M * N * K

The factor of 2 comes from one multiply and one add per output element per K-step.
The minimum memory traffic (assuming each matrix is read once and C is written once) is

elements moved = M*K + K*N + M*N

Therefore the arithmetic intensity is:

AI = FLOPs / bytes = 2*M*N*K / (bytes_per_element * (M*K + K*N + M*N))

For FP16 (2 bytes per element):

AI = 2*M*N*K / (2 * (M*K + K*N + M*N)) = M*N*K / (M*K + K*N + M*N) FLOP/byte
This simplifies when one dimension is much smaller than the others. In LLM decode with batch size 1, M = 1:

AI = N*K / (K + K*N + N) ≈ 1 FLOP/byte

because the FLOP count is 2*N*K and we divide by approximately 2*K*N bytes of weight reads. The single-token decode case is always memory-bound.
For prefill or batched decode with M tokens:

AI = M*N*K / (M*K + K*N + M*N)

When M is small relative to N and K, the K*N term in the denominator dominates (weight matrix reads), and:

AI ≈ M*N*K / (K*N) = M FLOP/byte

So the arithmetic intensity scales linearly with batch size until M is large enough that activation reads/writes (the M*K and M*N terms) become significant.
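A small helper makes the scaling concrete, using the formula and H100 numbers from this section:

```python
# Roofline arithmetic intensity for an FP16 GEMM C[M,N] = A[M,K] @ B[K,N].
def arithmetic_intensity(M, N, K, bytes_per_elt=2):
    flops = 2 * M * N * K
    bytes_moved = bytes_per_elt * (M * K + K * N + M * N)
    return flops / bytes_moved

RIDGE_FP16 = 989 / 3.35  # H100 SXM: ~295 FLOP/byte

# Decode (M=1) on the gate+up shape: AI is ~1 FLOP/byte, ~295x below ridge.
ai = arithmetic_intensity(1, 57344, 8192)
print(f"AI at M=1: {ai:.2f} FLOP/byte; ridge point: {RIDGE_FP16:.0f}")
```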
Ridge Point Calculation
The H100 SXM has:
- FP16 Tensor Core peak: 989 TFLOPS
- FP8 Tensor Core peak: 1,979 TFLOPS
- HBM3 bandwidth: 3.35 TB/s
Ridge points:
- FP16: 989 TFLOPS / 3.35 TB/s = 295 FLOP/byte
- FP8: 1979 TFLOPS / 3.35 TB/s = 591 FLOP/byte
Arithmetic Intensity vs Ridge Point (H100 SXM, FP16)
The critical batch size to reach the FP16 ridge point on H100 is approximately 295 tokens processed simultaneously in a single GEMM (since AI ≈ M for small M). Below this, we are memory-bandwidth-bound and cannot saturate the tensor cores. Above this, we are compute-bound and adding more tokens does not improve per-token throughput.
With FP8, the ridge point doubles to 591 FLOP/byte because compute throughput doubles while bandwidth stays the same. This means you need approximately 591 tokens to saturate FP8 tensor cores — making it harder to reach the compute-bound regime. FP8 improves peak throughput but requires larger batch sizes to realize that peak. For decode-heavy workloads with small batches, FP8 may show less than 2x improvement over FP16.
Per-Shape Analysis for Llama 70B
Let us compute exact arithmetic intensity for the key shapes at various batch sizes.
For the gate+up fused GEMM: N = 57344, K = 8192, and M = B. Accounting for FP16 (2 bytes per element), the full formula is:

AI = (2 * M * 57344 * 8192) / (2 * (M * 8192 + 8192 * 57344 + M * 57344)) FLOP/byte
Arithmetic Intensity: Gate+Up GEMM [B, 8192] x [8192, 57344] (FP16)
| Batch Size (M) | FLOPs | Bytes Moved | AI (FLOP/byte) | Regime on H100 |
|---|---|---|---|---|
| 1 | 939.5M | 940 MB | 1.0 | Memory-bound (0.3% compute util) |
| 8 | 7.52G | 947 MB | 7.9 | Memory-bound (2.7%) |
| 32 | 30.1G | 976 MB | 30.8 | Memory-bound (10.4%) |
| 128 | 120.3G | 1.10 GB | 107 | Memory-bound (36%) |
| 256 | 240.5G | 1.27 GB | 185 | Memory-bound (63%) |
| 512 | 481.0G | 1.59 GB | 296 | Compute-bound (100%) |
| 1024 | 962.1G | 2.24 GB | 419 | Compute-bound (100%) |
The takeaway: during autoregressive decode at small batch sizes (1-32), GEMM is deeply memory-bound and tensor cores are mostly idle. This is the fundamental problem that batching, speculative decoding, and continuous batching all try to address — by increasing the effective M dimension of the GEMMs.
3. Tensor Core Utilization: Tile Sizes and Padding
NVIDIA tensor cores operate on fixed-size matrix fragments. The programmer (or the library) must decompose the full GEMM into tiles that match these fragment sizes. Any dimension that is not a multiple of the tile size wastes compute.
Tensor Core Fragment Sizes
On H100 (Hopper architecture, SM90), the mma instructions operate on fragments:
| Precision | Fragment M x N x K | Notes |
|---|---|---|
| FP16 | 16 x 8 x 16 | mma.sync.aligned.m16n8k16.f32.f16.f16.f32 |
| BF16 | 16 x 8 x 16 | Same shape as FP16 |
| FP8 (E4M3) | 16 x 8 x 32 | K dimension doubles for FP8 |
| INT8 | 16 x 8 x 32 | Same as FP8 |
| TF32 | 16 x 8 x 8 | Reduced K dimension |
At the warpgroup level (Hopper’s WGMMA instructions), the shapes are larger:
| Precision | WGMMA Shape | Notes |
|---|---|---|
| FP16 | 64 x (8-256) x 16 | M=64 fixed, N variable in multiples of 8 |
| FP8 | 64 x (8-256) x 32 | K doubles for narrower types |
Padding Requirements
For maximum tensor core utilization, all three GEMM dimensions should be multiples of the tile size. The critical alignment requirements:
- M (batch/token dimension): Must be a multiple of 16 (or 64 for WGMMA). This is the dimension you control least — it depends on batch size.
- N (output dimension): Must be a multiple of 8. Model architects control this. Llama’s d = 8192 and d_ff = 28672 are both multiples of 128 — no padding needed.
- K (contraction dimension): Must be a multiple of 16 (FP16) or 32 (FP8). Again, d = 8192 and d_ff = 28672 are well-aligned.
The problem dimension is M. During decode, M = B (the batch size). If B = 13, the tensor core tiles waste 3 rows out of 16 — an 18.75% efficiency loss at the fragment level. With WGMMA tiles of M = 64, B = 13 wastes 51 rows out of 64 — a 79.7% efficiency loss.
// Quantifying padding waste
// For tile size T, if M is not a multiple of T:
int padded_M = ((M + T - 1) / T) * T;
float utilization = (float)M / (float)padded_M;
// Examples for T = 64 (WGMMA):
// M=1: padded=64, util=1.6%
// M=13: padded=64, util=20.3%
// M=32: padded=64, util=50.0%
// M=64: padded=64, util=100.0%
// M=65: padded=128, util=50.8%
// M=128: padded=128, util=100.0%
// M=200: padded=256, util=78.1%
// M=256: padded=256, util=100.0%
Tensor Core Utilization vs Batch Size (WGMMA M-tile = 64)
When configuring continuous batching, prefer batch sizes that are multiples of 64 (or at minimum 16). A serving system running at M = 65 achieves barely more GEMM throughput than one at M = 64, because the padded tile wastes nearly half the compute. This is why production serving systems like vLLM pad requests to align with tile boundaries.
CTA Tiling and Occupancy
Beyond the warp-level fragment, the GEMM is decomposed into Cooperative Thread Array (CTA) tiles. A typical CUTLASS kernel for H100 might use CTA tiles of 128x256x64 for FP16. The full GEMM grid is:
grid_M = ceil(M / 128)
grid_N = ceil(N / 256)
total_CTAs = grid_M * grid_N
For the gate+up GEMM at M = 128: grid_M = 1, grid_N = ceil(57344 / 256) = 224. Total CTAs = 224. The H100 has 132 SMs, so we have 224 / 132 = 1.7 waves. The first wave fills all SMs; the second wave uses only 92/132 SMs — a 30% tail effect.
At M = 256: grid_M = 2, grid_N = 224. Total CTAs = 448. Waves = 448 / 132 = 3.4. Better utilization.
At M = 1 (single decode): grid_M = 1, grid_N = 224. Same 224 CTAs, but each CTA processes only 1 row of its 128-row tile — extreme underutilization within each tile.
// Simplified tile efficiency calculation
struct TileConfig {
int cta_m = 128;
int cta_n = 256;
int cta_k = 64;
int num_sms = 132; // H100
};
float compute_efficiency(int M, int N, TileConfig cfg) {
int grid_m = (M + cfg.cta_m - 1) / cfg.cta_m;
int grid_n = (N + cfg.cta_n - 1) / cfg.cta_n;
int total_ctas = grid_m * grid_n;
// Wave quantization loss
int num_waves = (total_ctas + cfg.num_sms - 1) / cfg.num_sms;
int total_slots = num_waves * cfg.num_sms;
float wave_efficiency = (float)total_ctas / (float)total_slots;
// M-dimension padding loss
int padded_m = grid_m * cfg.cta_m;
float m_efficiency = (float)M / (float)padded_m;
// N-dimension padding loss
int padded_n = grid_n * cfg.cta_n;
float n_efficiency = (float)N / (float)padded_n;
return wave_efficiency * m_efficiency * n_efficiency;
}
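The same model in Python makes it easy to scan batch sizes; this mirrors the C++ helper above under the same CTA-tile assumptions (128x256 tile, 132 SMs):

```python
import math

# Combined wave-quantization and padding efficiency for a given GEMM shape.
def compute_efficiency(M, N, cta_m=128, cta_n=256, num_sms=132):
    grid_m, grid_n = math.ceil(M / cta_m), math.ceil(N / cta_n)
    total_ctas = grid_m * grid_n
    num_waves = math.ceil(total_ctas / num_sms)
    wave_eff = total_ctas / (num_waves * num_sms)   # tail-wave loss
    m_eff = M / (grid_m * cta_m)                    # M-padding loss
    n_eff = N / (grid_n * cta_n)                    # N-padding loss
    return wave_eff * m_eff * n_eff

# Gate+up shape (N = 57344) at decode, prefill-128, prefill-1024:
for M in (1, 128, 1024):
    print(f"M={M:5d}: {compute_efficiency(M, 57344):.1%}")
```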
4. cuBLAS vs CUTLASS: When Each Wins
NVIDIA provides two GEMM libraries with fundamentally different design philosophies:
cuBLAS is a closed-source, precompiled library. It uses runtime heuristics to select kernel configurations (tile sizes, pipeline stages, split-K factors) based on the GEMM shape and GPU architecture. You call cublasGemmEx() and it picks the best kernel from its internal database.
CUTLASS is an open-source, template-based C++ library. You specify the tile sizes, pipeline depth, and data movement strategy at compile time. It generates a kernel specialized for your exact configuration.
cuBLAS Heuristic Selection
cuBLAS internally maintains a lookup table mapping (M, N, K, dtype, GPU) to kernel configurations. For well-known shapes (square matrices, power-of-2 dimensions), these heuristics are well-tuned. The problem arises with unusual shapes.
// cuBLAS GEMM call — the library chooses everything internally
cublasGemmEx(
handle,
CUBLAS_OP_T, CUBLAS_OP_N, // transpose configuration
N, M, K, // dimensions
&alpha,
weight_ptr, CUDA_R_16F, K, // A matrix
input_ptr, CUDA_R_16F, K, // B matrix
&beta,
output_ptr, CUDA_R_16F, N, // C matrix
CUBLAS_COMPUTE_16F, // compute type
CUBLAS_GEMM_DEFAULT // algorithm selection
);
// cuBLAS internally selects tile size, split-K, swizzle, etc.
cuBLAS also provides cublasLtMatmul with explicit algorithm selection and workspace:
// cuBLASLt gives more control
cublasLtMatmulDesc_t matmul_desc;
cublasLtMatmulDescCreate(&matmul_desc, CUBLAS_COMPUTE_16F, CUDA_R_32F);
cublasLtMatmulPreference_t preference;
cublasLtMatmulPreferenceCreate(&preference);
cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)
);
// Query available algorithms
int returned_algo_count;
cublasLtMatmulHeuristicResult_t heuristic_results[8];
cublasLtMatmulAlgoGetHeuristic(
lt_handle, matmul_desc,
layout_A, layout_B, layout_C, layout_C,
preference,
8, heuristic_results, &returned_algo_count
);
// Benchmark each returned algorithm and pick the fastest
// This is what PyTorch's autotuner does
CUTLASS: Compile-Time Specialization
CUTLASS kernels are fully specified at compile time via C++ templates:
// CUTLASS 3.x kernel definition for H100 (Hopper)
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<
cutlass::gemm::collective::CollectiveMma<
cutlass::arch::Sm90, // Target architecture
cutlass::Shape<128, 256, 64>, // CTA tile MxNxK
cutlass::half_t, // Element A
cutlass::layout::RowMajor, // Layout A
cutlass::half_t, // Element B
cutlass::layout::ColumnMajor, // Layout B
float, // Element accumulator
cutlass::Shape<64, 256, 16>, // WGMMA shape
cutlass::gemm::collective::StageCount<3>, // Pipeline stages
cutlass::gemm::collective::KernelScheduleAuto // Schedule
>,
cutlass::epilogue::collective::DefaultEpilogue<
cutlass::Shape<128, 256>, // Epilogue tile
cutlass::half_t, // Output type
cutlass::epilogue::fusion::LinCombEltAct< // Fused epilogue
cutlass::epilogue::thread::SiLu // SiLu activation
>
>
>;
When cuBLAS Wins
-
Standard shapes: For power-of-2 dimensions where cuBLAS has well-tuned heuristics, it often matches or beats hand-tuned CUTLASS. The cuBLAS team has exhaustively benchmarked common shapes.
-
Rapid iteration: No compilation step. cuBLAS kernels are precompiled — changing shapes requires zero rebuild time.
-
Split-K: For tall-skinny GEMMs (large M, small N), cuBLAS can automatically apply split-K decomposition, distributing the K-dimension reduction across multiple CTAs. CUTLASS supports split-K but requires explicit configuration.
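Conceptually, split-K is just a partitioned reduction over the K dimension — each slice produces a partial product, and the partials are summed afterward. A minimal NumPy sketch (illustrative only, not the cuBLAS implementation):

```python
import numpy as np

# Split-K: partition K into `splits` slices, compute partial GEMMs (in a
# real kernel, one CTA per slice per output tile), then reduce the partials.
def split_k_gemm(A, B, splits=4):
    K = A.shape[1]
    bounds = np.linspace(0, K, splits + 1, dtype=int)
    partials = [A[:, s:e] @ B[s:e, :]
                for s, e in zip(bounds[:-1], bounds[1:])]
    return np.sum(partials, axis=0)  # the cross-CTA reduction step

A = np.random.rand(4, 64)
B = np.random.rand(64, 8)
assert np.allclose(split_k_gemm(A, B), A @ B)
```

The benefit is parallelism: when grid_M * grid_N alone yields too few CTAs to fill the SMs, split-K multiplies the CTA count at the cost of an extra reduction.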
When CUTLASS Wins
-
Epilogue fusion: CUTLASS can fuse bias add, activation functions (SiLU, GELU), and residual connections into the GEMM epilogue. cuBLAS requires separate kernel launches. For the SwiGLU FFN, this fusion avoids writing and re-reading intermediate activations.
-
Grouped/batched GEMM for MoE: CUTLASS 3.x has first-class support for grouped GEMM with variable sizes — critical for MoE (see Section 5).
-
Custom data types: CUTLASS supports custom quantization schemes, mixed-precision configurations, and novel data layouts that cuBLAS does not expose.
-
Non-standard shapes: When the GEMM shape does not match cuBLAS’s heuristic database, cuBLAS may select a suboptimal kernel. CUTLASS allows manual tuning.
cuBLAS vs CUTLASS Throughput (H100 SXM, FP16, Gate+Up GEMM)
| M (batch) | N | K | cuBLAS (TFLOPS) | CUTLASS (TFLOPS) | CUTLASS+Epilogue Fusion (TFLOPS) | Winner |
|---|---|---|---|---|---|---|
| 1 | 57344 | 8192 | 3.2 | 3.1 | 4.8 (fused SiLU) | CUTLASS+Fuse |
| 32 | 57344 | 8192 | 98 | 95 | 142 | CUTLASS+Fuse |
| 128 | 57344 | 8192 | 385 | 390 | 520 | CUTLASS+Fuse |
| 256 | 57344 | 8192 | 620 | 635 | 710 | CUTLASS+Fuse |
| 512 | 57344 | 8192 | 790 | 810 | 845 | CUTLASS+Fuse |
| 1024 | 57344 | 8192 | 870 | 880 | 895 | CUTLASS+Fuse |
| 128 | 8192 | 8192 | 380 | 385 | 400 | Comparable |
| 256 | 8192 | 8192 | 610 | 620 | 640 | Comparable |
PyTorch calls cublasLtMatmulAlgoGetHeuristic to get candidate algorithms, then benchmarks the top candidates at runtime. The first forward pass is slower due to this autotuning. For serving, warm up the model with representative input shapes before accepting traffic.
5. Grouped GEMM for Mixture-of-Experts
Mixture-of-Experts (MoE) models route each token to a subset of expert FFNs. In DeepSeek-V3, each token is routed to 8 out of 256 experts. After the routing decision, you need to execute 256 independent FFN GEMMs, each processing a different (variable) number of tokens.
The naive approach — launching 256 separate GEMMs — is catastrophic. Each individual GEMM is tiny (if tokens are spread evenly, each expert gets B * 8 / 256 = B / 32 tokens), and the kernel launch overhead and low SM utilization destroy throughput.
The Problem with Small GEMMs
With B = 1024 tokens and 256 experts (top-8 routing), each expert processes approximately 1024 * 8 / 256 = 32 tokens on average. For DeepSeek-V3 with d = 7168 and expert FFN dimension d_ff = 2048:
Each expert’s gate projection is [32, 7168] x [7168, 2048] — a 939.5M FLOP operation.
On H100 at 989 TFLOPS peak, this should take 0.95 microseconds. But actual CUDA kernel launch overhead is 3-5 microseconds. You spend more time launching the kernel than executing it.
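The arithmetic behind that claim, using the per-expert shape above and a typical 3-5 microsecond launch cost:

```python
# Launch overhead vs. compute time for one expert's GEMM.
flops = 2 * 32 * 7168 * 2048          # ~939.5M FLOPs (32 tokens per expert)
kernel_us = flops / 989e12 * 1e6      # at H100 FP16 peak
launch_us = 4.0                       # typical CUDA launch overhead (3-5 us)
print(f"compute: {kernel_us:.2f} us vs launch: ~{launch_us} us")
# Across 256 separate launches, launch overhead alone costs ~1 ms per
# MoE layer — far more than the useful compute.
```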
Grouped GEMM Solutions
Approach 1: Padding and Batched GEMM
Pad all experts to the same M dimension and use cublasGemmBatched:
// All 256 experts padded to max_tokens_per_expert
int max_tokens = find_max_expert_count(routing_decisions);
int padded_max = ((max_tokens + 63) / 64) * 64; // Align to 64
// Batched GEMM: 256 GEMMs of [padded_max, d] x [d, d_ff]
cublasGemmBatchedEx(
handle,
CUBLAS_OP_T, CUBLAS_OP_N,
d_ff, padded_max, d,
&alpha,
weight_ptrs, CUDA_R_16F, d, // Array of 256 weight pointers
input_ptrs, CUDA_R_16F, d, // Array of 256 input pointers
&beta,
output_ptrs, CUDA_R_16F, d_ff, // Array of 256 output pointers
256, // Batch count
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT
);
Problem: load imbalance. If one expert gets 200 tokens and most get 20, padding to 200 wastes enormous compute.
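The waste is easy to quantify. A sketch with an illustrative skewed distribution (one hot expert, the rest lightly loaded — the token counts are made up):

```python
# Fraction of padded-batched-GEMM compute that is wasted: every expert is
# padded to the largest expert's aligned token count, so cost scales with
# the max, not the mean.
def padding_waste(expert_token_counts, align=64):
    max_tokens = max(expert_token_counts)
    padded = (max_tokens + align - 1) // align * align
    useful = sum(expert_token_counts)
    total = padded * len(expert_token_counts)
    return 1.0 - useful / total

counts = [200] + [20] * 255   # one expert with 200 tokens, 255 with 20
print(f"wasted compute: {padding_waste(counts):.0%}")
```

With this distribution, every expert is padded to 256 rows and over 90% of the batched GEMM’s compute is spent on padding.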
Approach 2: CUTLASS Grouped GEMM
CUTLASS 3.x provides GroupedGemm that handles variable-size problems in a single kernel launch:
// CUTLASS Grouped GEMM — each problem has its own M
using GroupedGemm = cutlass::gemm::device::GemmGrouped<
cutlass::gemm::GemmShape<128, 128, 32>, // CTA tile
cutlass::gemm::GemmShape<64, 64, 32>, // Warp tile
cutlass::gemm::GemmShape<16, 8, 16>, // MMA shape
cutlass::epilogue::thread::LinearCombination<
cutlass::half_t, 8, float, float
>
>;
// Problem sizes: each expert has different M
// N and K are the same for all experts
std::vector<cutlass::gemm::GemmCoord> problem_sizes(num_experts);
for (int i = 0; i < num_experts; i++) {
problem_sizes[i] = {
expert_token_counts[i], // M varies per expert
d_ff, // N is fixed
d // K is fixed
};
}
// Single kernel launch handles all 256 problems
typename GroupedGemm::Arguments args(
problem_sizes.data(),
num_experts,
threadblock_count, // Total CTAs across all problems
ptr_A.data(), // Array of input pointers
ptr_B.data(), // Array of weight pointers
ptr_C.data(), // Array of output pointers
ptr_D.data(), // Array of output pointers (alias C)
lda.data(), ldb.data(), ldc.data(), ldd.data(),
{alpha, beta}
);
The key advantage: CUTLASS distributes CTAs across all 256 problems based on their actual sizes. Experts with more tokens get more CTAs. No padding waste, single kernel launch.
Approach 3: Permute-then-GEMM (used by Megablocks, vLLM)
Sort all tokens by expert assignment, concatenate into a single large tensor, and execute one fused GEMM with a custom indexing scheme:
# Pseudocode for the permute-then-GEMM approach
def moe_forward(x, router_logits, expert_weights):
# x: [B, d], router_logits: [B, num_experts]
# Top-k routing
scores, expert_indices = torch.topk(router_logits, k=top_k, dim=-1)
# expert_indices: [B, top_k]
# Sort tokens by expert assignment
flat_indices = expert_indices.flatten() # [B * top_k]
sorted_order = torch.argsort(flat_indices, stable=True)
# Permute tokens: group by expert
flat_tokens = x.repeat_interleave(top_k, dim=0) # [B*top_k, d]
permuted = flat_tokens[sorted_order] # [B*top_k, d]
# Count tokens per expert
expert_counts = torch.bincount(flat_indices, minlength=num_experts)
# expert_counts: [num_experts]
# Execute as a single GEMM with the stacked weight matrix
# This is a block-diagonal GEMM, handled efficiently by
# Triton or CUTLASS grouped GEMM
stacked_weights = torch.cat(
[expert_weights[i] for i in range(num_experts)], dim=0
) # [num_experts * d_ff, d]
# Split output and unpermute
# ... (inverse permutation to restore original token order)
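The permutation round-trip can be verified at toy sizes; a runnable NumPy sketch of the sort/permute/unpermute logic (the routing table here is made up for illustration):

```python
import numpy as np

# Toy sizes: 4 tokens, d=8, 4 experts, top-2 routing.
B, d, num_experts, top_k = 4, 8, 4, 2
x = np.random.rand(B, d)
expert_indices = np.array([[0, 2], [1, 2], [0, 3], [1, 3]])  # [B, top_k]

flat_indices = expert_indices.flatten()                # [B * top_k]
sorted_order = np.argsort(flat_indices, kind="stable")
flat_tokens = np.repeat(x, top_k, axis=0)              # [B * top_k, d]
permuted = flat_tokens[sorted_order]                   # grouped by expert

# Contiguous per-expert segment sizes, consumed by the grouped GEMM:
counts = np.bincount(flat_indices, minlength=num_experts)

# The inverse permutation restores original token order after the GEMM.
inverse = np.argsort(sorted_order)
assert np.allclose(permuted[inverse], flat_tokens)
```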
MoE GEMM Strategies (256 experts, top-8, 1024 tokens, H100)
| Strategy | Kernel Launches | Effective TFLOPS | Wall Time | Notes |
|---|---|---|---|---|
| 256 separate GEMMs | 256 | 45 | 2.1 ms | Launch overhead dominates |
| Padded batched GEMM | 1 | 180 | 0.52 ms | Padding waste ~40% |
| CUTLASS grouped GEMM | 1 | 380 | 0.25 ms | No padding, good load balance |
| Permute + fused GEMM | 3 (permute, GEMM, unpermute) | 350 | 0.28 ms | Permute overhead |
6. Real Throughput Numbers on H100
The following measurements were collected on an H100 SXM5 80GB with driver 545.23.08 and CUDA 12.3. Each measurement is the median of 100 iterations after 10 warmup iterations. We report achieved TFLOPS, percentage of peak (989 TFLOPS for FP16, 1979 TFLOPS for FP8).
FP16 GEMM Throughput Across Shapes
H100 FP16 GEMM Throughput: Attention Shapes (Llama 70B)
| Shape [M, N, K] | FLOPs | Achieved TFLOPS | % of Peak (989) | Bottleneck |
|---|---|---|---|---|
| [1, 10240, 8192] | 167.8M | 3.1 | 0.31% | Memory BW |
| [8, 10240, 8192] | 1.34G | 24.5 | 2.5% | Memory BW |
| [32, 10240, 8192] | 5.37G | 93 | 9.4% | Memory BW |
| [64, 10240, 8192] | 10.7G | 178 | 18.0% | Memory BW |
| [128, 10240, 8192] | 21.5G | 330 | 33.4% | Memory BW |
| [256, 10240, 8192] | 42.9G | 555 | 56.1% | Transitional |
| [512, 10240, 8192] | 85.9G | 760 | 76.8% | Compute |
| [1024, 10240, 8192] | 171.8G | 870 | 88.0% | Compute |
| [2048, 10240, 8192] | 343.6G | 910 | 92.0% | Compute |
| [4096, 10240, 8192] | 687.2G | 935 | 94.5% | Compute |
H100 FP16 GEMM Throughput: FFN Shapes (Llama 70B)
| Shape [M, N, K] | FLOPs | Achieved TFLOPS | % of Peak | Bottleneck |
|---|---|---|---|---|
| [1, 57344, 8192] | 939.5M | 3.2 | 0.32% | Memory BW |
| [8, 57344, 8192] | 7.52G | 25.8 | 2.6% | Memory BW |
| [32, 57344, 8192] | 30.1G | 98 | 9.9% | Memory BW |
| [64, 57344, 8192] | 60.1G | 190 | 19.2% | Memory BW |
| [128, 57344, 8192] | 120.3G | 385 | 38.9% | Memory BW |
| [256, 57344, 8192] | 240.5G | 620 | 62.7% | Transitional |
| [512, 57344, 8192] | 481.0G | 790 | 79.9% | Compute |
| [1024, 57344, 8192] | 962.1G | 870 | 88.0% | Compute |
| [2048, 57344, 8192] | 1924.1G | 920 | 93.0% | Compute |
FP8 vs FP16 Comparison
H100 FP8 vs FP16 Throughput: Gate+Up GEMM [M, 57344, 8192]
| M (batch) | FP16 TFLOPS | FP8 TFLOPS | FP8/FP16 Ratio | FP8 % of Peak (1979) |
|---|---|---|---|---|
| 1 | 3.2 | 3.3 | 1.03x | 0.17% |
| 32 | 98 | 102 | 1.04x | 5.2% |
| 128 | 385 | 410 | 1.07x | 20.7% |
| 256 | 620 | 720 | 1.16x | 36.4% |
| 512 | 790 | 1120 | 1.42x | 56.6% |
| 1024 | 870 | 1480 | 1.70x | 74.8% |
| 2048 | 920 | 1720 | 1.87x | 86.9% |
| 4096 | 935 | 1830 | 1.96x | 92.5% |
The common claim that FP8 doubles throughput is only true for large, compute-bound GEMMs. For decode-phase GEMMs at typical batch sizes (1-64), the speedup is 1.0-1.1x because both precisions are equally memory-bandwidth-bound. FP8 delivers its full benefit during prefill and at high-throughput serving batch sizes.
Throughput vs. Latency: The Practitioner’s View
For serving systems, two metrics matter:
- Time-to-first-token (TTFT): Dominated by prefill, which uses large-M GEMMs. Here, high TFLOPS matters.
- Inter-token latency (ITL): Dominated by decode, which uses small-M GEMMs (M equal to the concurrent batch size). Here, memory bandwidth matters.
The optimal strategy differs:
# Prefill optimization: maximize compute utilization
# - Large batch (all prompt tokens processed at once)
# - FP8 provides near-2x improvement
# - CUTLASS with epilogue fusion helps
# - Chunk long prompts into pieces that fit in SRAM for FlashAttention
# Decode optimization: maximize memory bandwidth utilization
# - Batch as many concurrent requests as possible to increase M
# - FP8 helps little — same memory bandwidth
# - Weight quantization (INT4/INT8) reduces bytes loaded, directly
# improving throughput
# - Speculative decoding increases effective M per step
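A back-of-envelope bound ties these together. Every decode step must stream all GEMM weights from HBM once, so (using the per-layer weight totals from Section 1):

```python
# Upper bound on Llama 70B FP16 decode throughput per sequence on H100 SXM.
weight_bytes = 1632 * 2**20 * 80      # ~127.5 GiB of per-layer GEMM weights
hbm_bw = 3.35e12                      # HBM3 bandwidth, bytes/s
tokens_per_sec = hbm_bw / weight_bytes
print(f"~{tokens_per_sec:.0f} tokens/s per sequence at batch size 1")
# Batching B sequences reuses each weight read across B tokens, multiplying
# this bound by ~B until the GEMMs become compute-bound.
```

This ignores KV-cache traffic and attention, so it is an optimistic ceiling, not a prediction.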
Effective Token Throughput vs Batch Size (Llama 70B, H100)
7. Practical Tuning Checklist
Here is a systematic approach to GEMM optimization for LLM inference:
Step 1: Profile the Shape Space
Identify every GEMM shape in your model. Use nsys profile to capture all cublasGemmEx calls:
# Profile all cuBLAS calls
nsys profile --trace=cuda,cublas \
python -c "import model; model.forward(sample_input)"
# Extract GEMM shapes from the profile
nsys stats --report cuda_gpu_trace profile.nsys-rep
Step 2: Measure Achieved vs. Peak
For each GEMM shape, compute the theoretical maximum throughput (bounded by either compute or bandwidth) and compare with achieved:
def theoretical_max_tflops(M, N, K, dtype_bytes, peak_tflops, bw_tb_s):
flops = 2 * M * N * K
bytes_moved = (M * K + K * N + M * N) * dtype_bytes
ai = flops / bytes_moved
ridge_point = peak_tflops * 1e12 / (bw_tb_s * 1e12)
if ai < ridge_point:
# Memory-bound: limited by bandwidth
max_tflops = bw_tb_s * ai # BW * AI = achievable FLOPS/s
else:
# Compute-bound: limited by peak
max_tflops = peak_tflops
return max_tflops
# H100 FP16
peak = 989 # TFLOPS
bw = 3.35 # TB/s
# Decode: [1, 57344, 8192]
print(theoretical_max_tflops(1, 57344, 8192, 2, peak, bw))
# Memory-bound: ~3.35 TFLOPS (limited by reading ~940 MB of weights)
# Prefill: [1024, 57344, 8192]
print(theoretical_max_tflops(1024, 57344, 8192, 2, peak, bw))
# Compute-bound: ~989 TFLOPS (AI >> ridge point)
Step 3: Address the Bottleneck
If memory-bound (decode-phase GEMMs):
- Increase batch size through continuous batching
- Apply weight quantization (INT4W reduces bytes by 4x, directly increasing effective AI by 4x)
- Use speculative decoding to increase tokens per step
If compute-bound (prefill-phase GEMMs):
- Ensure tile alignment (M, N, K multiples of 64/128)
- Use epilogue fusion to reduce kernel count
- Try FP8 for near-2x compute throughput
- Profile for wave quantization effects and adjust CTA tile sizes
Step 4: Benchmark Alternative Implementations
import torch
import time
def benchmark_gemm(M, N, K, dtype=torch.float16, warmup=10, iters=100):
A = torch.randn(M, K, dtype=dtype, device='cuda')
B = torch.randn(K, N, dtype=dtype, device='cuda')
# Warmup
for _ in range(warmup):
C = torch.mm(A, B)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(iters):
C = torch.mm(A, B)
torch.cuda.synchronize()
elapsed = time.perf_counter() - start
flops = 2 * M * N * K * iters
tflops = flops / elapsed / 1e12
return tflops
# Profile all critical shapes
shapes = [
(1, 10240, 8192, "QKV decode"),
(128, 10240, 8192, "QKV prefill-128"),
(1024, 10240, 8192, "QKV prefill-1024"),
(1, 57344, 8192, "FFN decode"),
(128, 57344, 8192, "FFN prefill-128"),
(1024, 57344, 8192, "FFN prefill-1024"),
]
for M, N, K, name in shapes:
tflops = benchmark_gemm(M, N, K)
print(f"{name:25s} [{M:5d}, {N:5d}, {K:5d}] {tflops:.1f} TFLOPS")
Key Takeaways
-
FFN GEMMs dominate: The gate+up and down projections account for over 80% of per-layer FLOPs. Optimize these first.
-
Decode is always memory-bound: At batch size 1, arithmetic intensity is approximately 1 FLOP/byte — 295x below the H100 FP16 ridge point. No GEMM library can fix this; you need to increase the effective batch size.
-
FP8 helps prefill, not decode: The doubled compute throughput only materializes when the GEMM is compute-bound (AI > 591 FLOP/byte on H100). Decode-phase GEMMs see negligible improvement from FP8.
-
Tile alignment matters: Batch sizes that are multiples of 64 avoid padding waste in tensor core tiles. A batch of 65 tokens can be slower than 64 tokens.
-
MoE requires grouped GEMM: Launching hundreds of individual small GEMMs destroys throughput. CUTLASS grouped GEMM or the permute-then-GEMM pattern eliminates launch overhead and improves SM utilization.
-
CUTLASS wins with fusion: Fusing SiLU, bias, and residual operations into the GEMM epilogue avoids extra HBM round-trips — a 10-35% throughput improvement at medium batch sizes.
-
Measure, do not assume: cuBLAS heuristics are not always optimal. Profile your exact shapes, compare with CUTLASS, and benchmark at your actual serving batch sizes — not at cherry-picked power-of-2 dimensions.