cuBLAS gives you matrix multiplication. CUTLASS gives you a composable template library for building GEMM variants that cuBLAS cannot express: INT8 inputs accumulating into INT32 with a custom epilogue, GEMM with a fused GroupNorm epilogue, sparse GEMM with 2:4 structured sparsity, batched GEMM with per-matrix scaling. A well-tuned CUTLASS kernel achieves 95-99% of cuBLAS throughput for standard cases and enables fusion patterns that would otherwise require separate kernel launches. The cost: instead of a single cuBLAS call, you instantiate templates specifying threadblock shapes, warp shapes, instruction shapes, epilogue functors, and pipeline stages. The abstraction is leaky by design: you configure the hardware directly.
All examples in this post use CUTLASS 3.x with CUDA 12.x. The tile-configuration examples use the 2.x-style device API targeting Ampere (SM 8.0); the TMA and wgmma material targets Hopper (SM 9.0).
The CUTLASS GEMM Decomposition
Four Levels of Tiling
CUTLASS decomposes a GEMM into four hierarchical levels, each corresponding to a hardware abstraction:
Level 1: Problem Shape (Grid)
The full GEMM: C[M, N] = A[M, K] * B[K, N]
Partitioned across thread blocks in the grid
Level 2: Thread Block Tile (CTA)
Each CTA computes a tile of C: shape [ThreadblockShape_M, ThreadblockShape_N]
by iterating over K in chunks of ThreadblockShape_K
Uses shared memory to stage tiles of A and B
Level 3: Warp Tile
Each warp within the CTA computes a sub-tile of the CTA's output
Shape: [WarpShape_M, WarpShape_N, WarpShape_K]
Maps to tensor core MMA (Matrix Multiply-Accumulate) operations
Level 4: Instruction Tile (MMA)
The fundamental hardware operation
Hopper: wgmma (Warp Group MMA) - 64x128x16 for FP16
Ampere: mma.sync - 16x8x16 for FP16
Maps directly to a single tensor core instruction
// CUTLASS 3.x GEMM configuration for Hopper
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/gemm/collective/collective_mma.hpp>
// Step 1: Define the problem types
using ElementA = cutlass::half_t; // A matrix element type
using ElementB = cutlass::half_t; // B matrix element type
using ElementC = cutlass::half_t; // C matrix element type (output)
using ElementAccumulator = float; // Accumulator type (internal)
// Step 2: Define layouts
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;
// Step 3: Define tile shapes
// ThreadblockShape: the tile of C computed by one CTA
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
// Each CTA computes a 128x128 output tile, iterating over K in chunks of 32
// WarpShape: the tile of C computed by one warp
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
// Each warp computes a 64x64 sub-tile
// Warps per CTA: (128/64) * (128/64) = 4 warps
// InstructionShape: the tensor core MMA instruction
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// Ampere mma.sync: 16x8x16 for FP16
// Step 4: Define the GEMM operation
using Gemm = cutlass::gemm::device::Gemm<
ElementA, LayoutA, // A matrix
ElementB, LayoutB, // B matrix
ElementC, LayoutC, // C matrix
ElementAccumulator, // Accumulator type
cutlass::arch::OpClassTensorOp, // Use tensor cores
cutlass::arch::Sm80, // Target architecture
ThreadblockShape,
WarpShape,
InstructionShape,
cutlass::epilogue::thread::LinearCombination<
ElementC,
128 / cutlass::sizeof_bits<ElementC>::value, // Elements per access
ElementAccumulator,
ElementAccumulator
>, // Epilogue: alpha*AB + beta*C
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
3 // Pipeline stages (number of K tiles in flight)
>;
Understanding Tile Size Selection
The tile sizes determine occupancy, shared memory usage, and instruction-level parallelism:
// Shared memory requirement for one CTA:
// smem_A = ThreadblockShape_M * ThreadblockShape_K * sizeof(ElementA) * stages
// smem_B = ThreadblockShape_K * ThreadblockShape_N * sizeof(ElementB) * stages
//
// For ThreadblockShape<128, 128, 32>, FP16, stages=3:
// smem_A = 128 * 32 * 2 * 3 = 24,576 bytes = 24 KB
// smem_B = 32 * 128 * 2 * 3 = 24,576 bytes = 24 KB
// Total: 48 KB per CTA
//
// H100 has 228 KB shared memory per SM
// -> Can fit 4 CTAs per SM (48 * 4 = 192 KB < 228 KB)
// -> But register pressure may limit to 2-3 CTAs
//
// Larger tiles: higher arithmetic intensity, fewer CTAs per SM
// Smaller tiles: more CTAs (better latency hiding), lower efficiency per CTA
// Tile size selection heuristic:
// 1. Start with ThreadblockShape that gives 2-4 CTAs per SM
// 2. WarpShape should divide evenly into ThreadblockShape
// 3. InstructionShape is fixed by hardware
// 4. Pipeline stages: more is better (hides global memory latency)
// but costs shared memory
The number of pipeline stages is critical on Hopper. Hopper's TMA (Tensor Memory Accelerator) allows asynchronous global-to-shared memory copies that overlap with computation. With 3-5 stages, the TMA can prefetch the next K-tile while the tensor cores process the current one, achieving near-perfect overlap. On Ampere (which uses cp.async instead of TMA), 2-3 stages are typical.
CUTLASS 3.x and the CuTe Layout Algebra
CUTLASS 3.x introduces CuTe (CUDA Templates for Tensors), a layout algebra that describes how tensor elements map to threads and memory:
// CuTe layout: describes the mapping from logical coordinates to physical memory
// A layout is a pair: (Shape, Stride)
//
// Example: a 128x32 tile of FP16 in row-major:
// Shape: (128, 32)
// Stride: (32, 1) -- row-major: stride-1 along K dimension
//
// Element at (m, k) is at offset: m * 32 + k * 1
#include <cute/tensor.hpp>
#include <cute/layout.hpp>
using namespace cute;
// Define a layout for a 128x32 tile
auto layout_A = make_layout(
make_shape(Int<128>{}, Int<32>{}),
make_stride(Int<32>{}, Int<1>{})
);
// This is a compile-time layout - all values are known at compile time
// Thread-to-data mapping for loading A tile:
// 128 threads in a CTA, each loads a portion of the 128x32 tile
// 128*32 = 4096 elements / 128 threads = 32 elements per thread
auto thr_layout = make_layout(
make_shape(Int<32>{}, Int<4>{}), // 32 rows of threads, 4 columns
make_stride(Int<4>{}, Int<1>{})
);
// Each thread loads: 128/32 = 4 rows, 32/4 = 8 columns = 32 elements
// Using vectorized loads: 8 FP16 = 16 bytes = 128-bit load per access
Hopper Kernel with TMA and wgmma
On Hopper, CUTLASS 3.x uses TMA for global memory loads and wgmma (Warp Group MMA) for tensor core operations:
// Simplified Hopper GEMM mainloop illustrating CUTLASS 3.x concepts.
// tma_load and wgmma below are stand-ins for the real SM90 copy and MMA
// atoms; NUM_STAGES and the TILE_* constants are assumed constexpr.
// TMA descriptor: hardware unit that performs bulk memory copies
// from global memory to shared memory without SM involvement
template <class TiledMma, class TmaA, class TmaB>
__global__ void hopper_gemm_kernel(
TmaA tma_a, // TMA descriptor for loading A tiles
TmaB tma_b, // TMA descriptor for loading B tiles
float* C,
int M, int N, int K
) {
// Shared memory for double/multi-buffered tiles
extern __shared__ char smem[];
half* smem_A = reinterpret_cast<half*>(smem);
half* smem_B = smem_A + NUM_STAGES * TILE_M * TILE_K; // past all staged A tiles
// Pipeline barriers for asynchronous TMA loads.
// The barriers must live in shared memory so that both the TMA engine
// and the consumer threads can arrive on and wait at them.
using Barrier = cutlass::arch::ClusterBarrier;
__shared__ Barrier barrier[NUM_STAGES];
// Initialize barriers
if (threadIdx.x == 0) {
for (int s = 0; s < NUM_STAGES; s++) {
barrier[s].init(1); // Expected arrivals = 1 (TMA)
}
}
__syncthreads();
// Mainloop: iterate over K dimension
int num_k_tiles = K / TILE_K;
int stage = 0;
// Producer: issue TMA loads
if (threadIdx.x == 0) {
for (int k = 0; k < num_k_tiles; k++) {
int s = k % NUM_STAGES;
barrier[s].arrive_and_expect_tx(/* bytes per tile */);
// TMA loads A[block_m : block_m+TILE_M, k*TILE_K : (k+1)*TILE_K]
// and B[k*TILE_K : (k+1)*TILE_K, block_n : block_n+TILE_N]
// directly into shared memory, no SM threads needed
cutlass::arch::tma_load(
smem_A + s * TILE_M * TILE_K,
tma_a, k, blockIdx.x
);
cutlass::arch::tma_load(
smem_B + s * TILE_K * TILE_N,
tma_b, k, blockIdx.y
);
}
}
// Consumer: perform wgmma on loaded tiles
float acc[TILE_M_PER_WARP * TILE_N_PER_WARP] = {0};
for (int k = 0; k < num_k_tiles; k++) {
int s = k % NUM_STAGES;
// Wait for TMA load to complete
barrier[s].wait(/* phase */);
// wgmma: Warp Group MMA
// A warp group is 4 warps (128 threads) that cooperate on one MMA
// Hopper wgmma shape: 64x128x16 for FP16 (much larger than Ampere's 16x8x16)
cutlass::arch::wgmma(
acc,
smem_A + s * TILE_M * TILE_K,
smem_B + s * TILE_K * TILE_N
);
// Signal that shared memory can be reused
if (threadIdx.x == 0) {
barrier[s].arrive();
}
}
// Epilogue: write accumulated results to global memory
// Apply alpha*acc + beta*C
store_epilogue(C, acc, blockIdx.x, blockIdx.y, M, N);
}
Epilogue Functors: Fusing Post-GEMM Operations
CUTLASS epilogues define what happens after the matrix multiply. The standard linear combination epilogue computes D = alpha * A * B + beta * C, but custom epilogues can fuse bias addition, activation functions, and even quantization:
// Custom epilogue: GEMM + bias + GELU activation
// Y = GELU(alpha * A * B + bias)
// CUTLASS provides composable epilogue visitors
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
cutlass::half_t, // Output type
128 / 16, // Elements per access (128-bit / 16-bit = 8)
float, // Accumulator type
float // Compute type for GELU
>;
// Or build a custom epilogue with the visitor pattern (CUTLASS 3.x):
// 1. Scale accumulator by alpha
// 2. Add per-channel bias
// 3. Apply GELU
// 4. Quantize to INT8 (optional)
// Custom epilogue functor
template <typename Element, typename Accumulator>
struct GemmBiasGeluQuantizeEpilogue {
Accumulator alpha;
const Accumulator* bias; // [N] per-channel bias
Accumulator quant_scale;
CUTLASS_DEVICE
void operator()(
int row, int col,
Accumulator acc, // Raw accumulator value
Element& output // Output to write
) {
// Step 1: Scale
Accumulator result = alpha * acc;
// Step 2: Add bias
result += bias[col];
// Step 3: GELU approximation
// GELU(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
Accumulator x = result;
Accumulator cdf = 0.5f * (1.0f + tanhf(
0.7978845608f * (x + 0.044715f * x * x * x)
));
result = x * cdf;
// Step 4: Optional quantize to INT8
// result = round(result / quant_scale)
// output = clamp(result, -128, 127)
output = static_cast<Element>(result);
}
};
Fusing bias, activation, and quantization into the GEMM epilogue avoids separate kernel launches and memory round-trips. A fused GEMM+bias+GELU saves two kernel launches, each of which would otherwise read and write the output matrix. For a 4096x4096 output in FP16, each pass over the output is 32 MB of memory traffic, roughly 10 microseconds at 3.35 TB/s of HBM3 bandwidth, so the four avoided passes save on the order of 40 microseconds. For small matrices in LLM decode, this epilogue fusion can improve end-to-end kernel time by 10-20%.
Mixed-Precision GEMM Configuration
CUTLASS excels at mixed-precision operations that cuBLAS does not directly support:
// Example 1: FP16 inputs, FP32 accumulation, FP16 output (standard)
using Gemm_FP16 = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::RowMajor, // A: FP16
cutlass::half_t, cutlass::layout::ColumnMajor, // B: FP16
cutlass::half_t, cutlass::layout::RowMajor, // C: FP16
float, // Accumulator: FP32
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>
>;
// Example 2: INT8 inputs, INT32 accumulation, FP16 output (W8A8 inference)
using Gemm_INT8 = cutlass::gemm::device::Gemm<
int8_t, cutlass::layout::RowMajor,
int8_t, cutlass::layout::ColumnMajor,
cutlass::half_t, cutlass::layout::RowMajor,
int32_t, // INT32 accumulator
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 64>, // Larger K tile for INT8
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<16, 8, 32> // INT8 MMA: 16x8x32
>;
// Example 3: FP8 inputs, FP32 accumulation (Hopper FP8 GEMM)
using Gemm_FP8 = cutlass::gemm::device::Gemm<
cutlass::float_e4m3_t, cutlass::layout::RowMajor, // A: E4M3
cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, // B: E4M3
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
cutlass::gemm::GemmShape<128, 256, 64>,
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<16, 8, 32>
>;
// Example 4: INT4 weights dequantized to FP16 for GEMM (W4A16)
// This requires a custom mainloop that dequantizes INT4 -> FP16 on the fly
// CUTLASS provides this through the "mixed input" GEMM:
using Gemm_W4A16 = cutlass::gemm::device::GemmUniversal<
cutlass::half_t, cutlass::layout::RowMajor, // A: FP16 activations
cutlass::int4b_t, cutlass::layout::ColumnMajor, // B: INT4 weights
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80
>;
// Internally: loads INT4 weights, dequantizes to FP16 in shared memory,
// then performs the standard FP16 tensor core MMA
Launching and Profiling CUTLASS GEMMs
// Complete example: launching a CUTLASS GEMM
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/util/host_tensor.h>
int main() {
int M = 4096, N = 4096, K = 4096;
// Define the GEMM type (FP16 with tensor cores)
using Gemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t, cutlass::layout::ColumnMajor,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<16, 8, 16>
>;
// Allocate matrices
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> A({M, K});
cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> B({K, N});
cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> C({M, N});
A.sync_device();
B.sync_device();
C.sync_device();
// Configure GEMM arguments
Gemm::Arguments args(
{M, N, K}, // Problem size
{A.device_data(), K}, // A tensor ref
{B.device_data(), K}, // B tensor ref (col-major, so stride = K)
{C.device_data(), N}, // C tensor ref (source)
{C.device_data(), N}, // D tensor ref (destination, can alias C)
{1.0f, 0.0f} // alpha, beta
);
// Instantiate and run
Gemm gemm_op;
cutlass::Status status = gemm_op.can_implement(args);
if (status != cutlass::Status::kSuccess) {
// Handle error: tile sizes don't work for this problem
return -1;
}
// Query workspace size
size_t workspace_size = Gemm::get_workspace_size(args);
cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);
status = gemm_op.initialize(args, workspace.get());
status = gemm_op(); // Launch!
// Check for errors
cudaDeviceSynchronize();
if (status != cutlass::Status::kSuccess) {
printf("GEMM failed: %s\n", cutlassGetStatusString(status));
}
return 0;
}
Performance Tuning with Tile Size Sweeps
// Systematic tile size sweep for optimal performance
// Different problem sizes favor different tile configurations
struct TileConfig {
int tb_m, tb_n, tb_k;
int warp_m, warp_n, warp_k;
int stages;
float achieved_tflops;
};
// Results of sweeping tile sizes for M=N=K=4096, FP16, H100:
//
// ThreadblockShape | WarpShape | Stages | TFLOPS | % of Peak
// 128x128x32 | 64x64x32 | 3 | 198 | 61%
// 128x256x32 | 64x128x32 | 3 | 256 | 79%
// 256x128x32 | 128x64x32 | 3 | 248 | 77%
// 128x256x64 | 64x128x64 | 4 | 298 | 92% <-- Good
// 256x128x64 | 64x64x64 | 4 | 289 | 89%
// 128x128x64 | 64x64x64 | 5 | 275 | 85%
//
// Observations:
// 1. Larger N dimension in tile -> better memory coalescing for B
// 2. Larger K dimension -> fewer mainloop iterations, less overhead
// 3. More stages -> better latency hiding, but more shared memory
// 4. The optimal tile depends on the GEMM shape (M, N, K)
[Figure: CUTLASS GEMM throughput (TFLOPS) by tile configuration, M=N=K=4096, FP16, H100]
Practical Use Cases in LLM Inference
Fused Attention Projection GEMM
In LLM inference, the QKV projection is a GEMM of shape [BS*seq, d_model] x [d_model, 3*d_model]. Fusing Q, K, V into a single GEMM is more efficient than three separate GEMMs:
// Fused QKV projection: one GEMM instead of three
// Input: X [BS*seq, d_model]
// Weight: W_qkv [d_model, 3*d_model] (Q, K, V weights concatenated)
// Output: QKV [BS*seq, 3*d_model]
// CUTLASS GEMM with custom epilogue that splits Q, K, V
// and applies RoPE to Q, K
using FusedQKVGemm = cutlass::gemm::device::Gemm<
cutlass::half_t, cutlass::layout::RowMajor,
cutlass::half_t, cutlass::layout::ColumnMajor,
cutlass::half_t, cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm80,
cutlass::gemm::GemmShape<128, 256, 32>, // Large N for 3*d_model
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<16, 8, 16>,
RoPEEpilogue<cutlass::half_t, float> // Custom epilogue applies RoPE
>;
// The RoPE epilogue:
// - Splits output into Q [BS*seq, d_model], K [BS*seq, d_kv], V [BS*seq, d_kv]
// - Applies rotary position embedding to Q and K
// - Writes Q, K, V to separate output buffers
// This saves one kernel launch and one read/write of Q and K tensors
Weight-Only Dequantization GEMM
For INT4 weight-only quantization (W4A16), CUTLASS provides mixed-input GEMM that dequantizes INT4 weights to FP16 on-the-fly in shared memory:
// W4A16 GEMM: FP16 activations * INT4 weights -> FP16 output
// The INT4 weights are stored in packed format (2 weights per byte)
// Dequantization: w_fp16 = (w_int4 - zero_point) * scale
// This is what vLLM's Marlin kernel and AWQ's GEMM kernel implement
// CUTLASS provides the building blocks:
// 1. Custom global memory loader that reads packed INT4
// 2. Shared memory dequantization stage
// 3. Standard FP16 tensor core MMA on dequantized values
// The dequantization in shared memory:
__device__ void dequant_int4_to_fp16(
const uint8_t* packed_int4, // Input: 2 INT4 values per byte
half* output_fp16, // Output: FP16 values
half scale,
half zero_point,
int num_elements
) {
for (int i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
uint8_t packed = packed_int4[i];
int lo = packed & 0x0F; // Low nibble, unsigned [0, 15]
int hi = (packed >> 4) & 0x0F; // High nibble, unsigned [0, 15]
// zero_point shifts the unsigned nibble into the signed weight range
output_fp16[2*i] = __hmul(__hsub(__int2half_rn(lo), zero_point), scale);
output_fp16[2*i + 1] = __hmul(__hsub(__int2half_rn(hi), zero_point), scale);
}
__syncthreads();
}
CUTLASS Custom GEMM vs cuBLAS Performance (H100)
| GEMM Type | M | N | K | CUTLASS TFLOPS | cuBLAS TFLOPS | Ratio |
|---|---|---|---|---|---|---|
| FP16 standard | 4096 | 4096 | 4096 | 298 | 305 | 97.7% |
| FP16 standard | 1 | 4096 | 4096 | 0.8 | 0.9 | 88.9% |
| FP16+bias+GELU fused | 4096 | 4096 | 4096 | 285 | N/A | Custom |
| INT8 W8A8 | 4096 | 4096 | 4096 | 580 | 595 | 97.5% |
| W4A16 (dequant) | 4096 | 4096 | 4096 | 310 | N/A | Custom |
| FP8 E4M3 | 4096 | 4096 | 4096 | 580 | 590 | 98.3% |
Debugging CUTLASS Compilation
CUTLASS templates produce notoriously long compilation errors. Common issues and solutions:
// Error: "No matching function for call to 'Gemm::Gemm'"
// Cause: Tile sizes are not divisible. WarpShape must divide ThreadblockShape.
// ThreadblockShape<128, 128, 32>, WarpShape<64, 64, 32>: OK (2x2 warps)
// ThreadblockShape<128, 128, 32>, WarpShape<32, 64, 32>: OK (4x2 warps)
// ThreadblockShape<128, 128, 32>, WarpShape<48, 64, 32>: ERROR (128/48 not integer)
// Error: "static_assert failed: Instruction shape must divide warp shape"
// InstructionShape must divide WarpShape along all dimensions
// WarpShape<64, 64, 32>, InstructionShape<16, 8, 16>: OK
// WarpShape<64, 64, 32>, InstructionShape<16, 8, 32>: OK
// WarpShape<64, 64, 32>, InstructionShape<16, 16, 16>: ERROR (16x16x16 is not a valid FP16 tensor core MMA shape)
// Error: Shared memory exceeds limit
// Reduce stages or tile size
// Check: stages * (TB_M * TB_K + TB_K * TB_N) * sizeof(element) <= 228 KB (H100)
// Performance issue: Low occupancy
// Use CUTLASS profiler to check:
// cutlass_profiler --operation=Gemm --m=4096 --n=4096 --k=4096 \
// --A=f16:row --B=f16:column --C=f16:row \
// --threadblock=128x256x64 --warp=64x128x64 --stages=4
Summary
CUTLASS provides the full hierarchical decomposition of GEMM as composable C++ templates: problem-level grid tiling, CTA-level shared memory staging, warp-level MMA dispatch, and instruction-level tensor core operations. This decomposition enables custom GEMMs (mixed-precision, fused epilogues, weight dequantization) that achieve 95-99% of cuBLAS performance while supporting operations cuBLAS cannot. The key configuration decisions are tile sizes (which determine shared memory usage and occupancy), pipeline stages (which determine latency hiding), and epilogue design (which determines fusion opportunities). For LLM inference, CUTLASS is the foundation for virtually all high-performance custom GEMM kernels including the Marlin W4A16 kernel, INT8 SmoothQuant GEMMs, and fused QKV projections.