Part 21 of 32 in the series CUDA Kernel Engineering

cuBLAS gives you matrix multiplication. CUTLASS gives you a composable template library for building GEMM variants that cuBLAS cannot express: FP16 input accumulating into INT32, GEMM with fused GroupNorm epilogue, sparse GEMM with 2:4 structured sparsity, batched GEMM with per-matrix scaling. The CUTLASS version achieves 95-99% of cuBLAS throughput for standard cases and enables fusion patterns that would otherwise require separate kernel launches. The cost: instead of a single cuBLAS call, you instantiate templates specifying threadblock shapes, warp shapes, instruction shapes, epilogue functors, and pipeline stages. The abstraction is leaky by design β€” you configure the hardware directly.


All examples in this post use CUTLASS 3.x with CUDA 12.x. The Hopper-specific sections (TMA, wgmma) target SM 9.0; the classic device::Gemm examples use the 2.x-style API, which CUTLASS 3.x still ships, and target Ampere (SM 8.0).

The CUTLASS GEMM Decomposition

Four Levels of Tiling

CUTLASS decomposes a GEMM C = Ξ±β‹…Aβ‹…B + Ξ²β‹…C into four hierarchical levels, each corresponding to a hardware abstraction:

Level 1: Problem Shape (Grid)
  The full GEMM: C[M, N] = A[M, K] * B[K, N]
  Partitioned across thread blocks in the grid

Level 2: Thread Block Tile (CTA)
  Each CTA computes a tile of C: shape [ThreadblockShape_M, ThreadblockShape_N]
  by iterating over K in chunks of ThreadblockShape_K
  Uses shared memory to stage tiles of A and B

Level 3: Warp Tile
  Each warp within the CTA computes a sub-tile of the CTA's output
  Shape: [WarpShape_M, WarpShape_N, WarpShape_K]
  Maps to tensor core MMA (Matrix Multiply-Accumulate) operations

Level 4: Instruction Tile (MMA)
  The fundamental hardware operation
  Hopper: wgmma (Warp Group MMA) - 64x128x16 for FP16
  Ampere: mma.sync - 16x8x16 for FP16
  Maps directly to a single tensor core instruction
// CUTLASS 2.x-style GEMM configuration (Ampere, SM 8.0)
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>

// Step 1: Define the problem types
using ElementA = cutlass::half_t;     // A matrix element type
using ElementB = cutlass::half_t;     // B matrix element type
using ElementC = cutlass::half_t;     // C matrix element type (output)
using ElementAccumulator = float;      // Accumulator type (internal)

// Step 2: Define layouts
using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;

// Step 3: Define tile shapes
// ThreadblockShape: the tile of C computed by one CTA
using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>;
// Each CTA computes a 128x128 output tile, iterating over K in chunks of 32

// WarpShape: the tile of C computed by one warp
using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>;
// Each warp computes a 64x64 sub-tile
// Warps per CTA: (128/64) * (128/64) = 4 warps

// InstructionShape: the tensor core MMA instruction
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// Ampere mma.sync: 16x8x16 for FP16

// Step 4: Define the GEMM operation
using Gemm = cutlass::gemm::device::Gemm<
    ElementA, LayoutA,           // A matrix
    ElementB, LayoutB,           // B matrix
    ElementC, LayoutC,           // C matrix
    ElementAccumulator,          // Accumulator type
    cutlass::arch::OpClassTensorOp,  // Use tensor cores
    cutlass::arch::Sm80,         // Target architecture
    ThreadblockShape,
    WarpShape,
    InstructionShape,
    cutlass::epilogue::thread::LinearCombination<
        ElementC,
        128 / cutlass::sizeof_bits<ElementC>::value,  // Elements per access
        ElementAccumulator,
        ElementAccumulator
    >,  // Epilogue: alpha*AB + beta*C
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3   // Pipeline stages (number of K tiles in flight)
>;

Understanding Tile Size Selection

The tile sizes determine occupancy, shared memory usage, and instruction-level parallelism:

// Shared memory requirement for one CTA:
// smem_A = ThreadblockShape_M * ThreadblockShape_K * sizeof(ElementA) * stages
// smem_B = ThreadblockShape_K * ThreadblockShape_N * sizeof(ElementB) * stages
//
// For ThreadblockShape<128, 128, 32>, FP16, stages=3:
// smem_A = 128 * 32 * 2 * 3 = 24,576 bytes = 24 KB
// smem_B = 32 * 128 * 2 * 3 = 24,576 bytes = 24 KB
// Total: 48 KB per CTA
//
// H100 has 228 KB shared memory per SM
// -> Can fit 4 CTAs per SM (48 * 4 = 192 KB < 228 KB)
// -> But register pressure may limit to 2-3 CTAs
//
// Larger tiles: higher arithmetic intensity, fewer CTAs per SM
// Smaller tiles: more CTAs (better latency hiding), lower efficiency per CTA

// Tile size selection heuristic:
// 1. Start with ThreadblockShape that gives 2-4 CTAs per SM
// 2. WarpShape should divide evenly into ThreadblockShape
// 3. InstructionShape is fixed by hardware
// 4. Pipeline stages: more is better (hides global memory latency)
//    but costs shared memory
ℹ️ Note

The number of pipeline stages is critical on Hopper. Hopper’s TMA (Tensor Memory Accelerator) allows asynchronous global-to-shared memory copies that overlap with computation. With 3-5 stages, the TMA can prefetch the next K-tile while the tensor cores process the current one, achieving near-perfect overlap. On Ampere (which uses cp.async instead of TMA), 2-3 stages are typical.

CUTLASS 3.x and the CuTe Layout Algebra

CUTLASS 3.x introduces CuTe (CUDA Templates for Tensors), a layout algebra that describes how tensor elements map to threads and memory:

// CuTe layout: describes the mapping from logical coordinates to physical memory
// A layout is a pair: (Shape, Stride)
//
// Example: a 128x32 tile of FP16 in row-major:
// Shape: (128, 32)
// Stride: (32, 1)  -- row-major: stride-1 along K dimension
//
// Element at (m, k) is at offset: m * 32 + k * 1

#include <cute/tensor.hpp>
#include <cute/layout.hpp>

using namespace cute;

// Define a layout for a 128x32 tile
auto layout_A = make_layout(
    make_shape(Int<128>{}, Int<32>{}),
    make_stride(Int<32>{}, Int<1>{})
);
// This is a compile-time layout - all values are known at compile time

// Thread-to-data mapping for loading A tile:
// 128 threads in a CTA, each loads a portion of the 128x32 tile
// 128*32 = 4096 elements / 128 threads = 32 elements per thread
auto thr_layout = make_layout(
    make_shape(Int<32>{}, Int<4>{}),  // 32 rows of threads, 4 columns
    make_stride(Int<4>{}, Int<1>{})
);

// Each thread loads: 128/32 = 4 rows, 32/4 = 8 columns = 32 elements
// Using vectorized loads: 8 FP16 = 16 bytes = 128-bit load per access

Hopper Kernel with TMA and wgmma

On Hopper, CUTLASS 3.x uses TMA for global memory loads and wgmma (Warp Group MMA) for tensor core operations:

// Simplified Hopper GEMM mainloop using CUTLASS 3.x concepts

// TMA descriptor: hardware unit that performs bulk memory copies
// from global memory to shared memory without SM involvement
template <class TiledMma, class TmaA, class TmaB>
__global__ void hopper_gemm_kernel(
    TmaA tma_a,   // TMA descriptor for loading A tiles
    TmaB tma_b,   // TMA descriptor for loading B tiles
    float* C,
    int M, int N, int K
) {
    // Shared memory for double/multi-buffered tiles
    extern __shared__ char smem[];
    half* smem_A = reinterpret_cast<half*>(smem);
    half* smem_B = smem_A + /* A tile size */;

    // Pipeline barriers for asynchronous TMA loads
    using Barrier = cutlass::arch::ClusterBarrier;
    __shared__ Barrier barrier[NUM_STAGES];  // Barriers must live in shared memory

    // Initialize barriers
    if (threadIdx.x == 0) {
        for (int s = 0; s < NUM_STAGES; s++) {
            barrier[s].init(1);  // Expected arrivals = 1 (TMA)
        }
    }
    __syncthreads();

    // Mainloop: iterate over K dimension
    int num_k_tiles = K / TILE_K;
    int stage = 0;

    // Producer: issue TMA loads
    if (threadIdx.x == 0) {
        for (int k = 0; k < num_k_tiles; k++) {
            int s = k % NUM_STAGES;
            barrier[s].arrive_and_expect_tx(/* bytes per tile */);

            // TMA loads A[block_m : block_m+TILE_M, k*TILE_K : (k+1)*TILE_K]
            // and     B[k*TILE_K : (k+1)*TILE_K, block_n : block_n+TILE_N]
            // directly into shared memory, no SM threads needed
            cutlass::arch::tma_load(
                smem_A + s * TILE_M * TILE_K,
                tma_a, k, blockIdx.x
            );
            cutlass::arch::tma_load(
                smem_B + s * TILE_K * TILE_N,
                tma_b, k, blockIdx.y
            );
        }
    }

    // Consumer: perform wgmma on loaded tiles
    float acc[TILE_M_PER_WARP * TILE_N_PER_WARP] = {0};

    for (int k = 0; k < num_k_tiles; k++) {
        int s = k % NUM_STAGES;

        // Wait for TMA load to complete
        barrier[s].wait(/* phase */);

        // wgmma: Warp Group MMA
        // A warp group is 4 warps (128 threads) that cooperate on one MMA
        // Hopper wgmma shape: 64x128x16 for FP16 (much larger than Ampere's 16x8x16)
        cutlass::arch::wgmma(
            acc,
            smem_A + s * TILE_M * TILE_K,
            smem_B + s * TILE_K * TILE_N
        );

        // Signal that shared memory can be reused
        if (threadIdx.x == 0) {
            barrier[s].arrive();
        }
    }

    // Epilogue: write accumulated results to global memory
    // Apply alpha*acc + beta*C
    store_epilogue(C, acc, blockIdx.x, blockIdx.y, M, N);
}

Epilogue Functors: Fusing Post-GEMM Operations

CUTLASS epilogues define what happens after the matrix multiply. The standard linear combination epilogue computes Ξ±β‹…AB + Ξ²β‹…C, but custom epilogues can fuse bias addition, activation functions, and even quantization:

// Custom epilogue: GEMM + bias + GELU activation
// Y = GELU(alpha * A * B + bias)

// CUTLASS provides composable epilogue visitors
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
    cutlass::half_t,     // Output type
    128 / 16,            // Elements per access (128-bit / 16-bit = 8)
    float,               // Accumulator type
    float                // Compute type for GELU
>;

// Or build a custom epilogue with the visitor pattern (CUTLASS 3.x):
// 1. Scale accumulator by alpha
// 2. Add per-channel bias
// 3. Apply GELU
// 4. Quantize to INT8 (optional)

// Custom epilogue functor
template <typename Element, typename Accumulator>
struct GemmBiasGeluQuantizeEpilogue {
    Accumulator alpha;
    const Accumulator* bias;   // [N] per-channel bias
    Accumulator quant_scale;

    CUTLASS_DEVICE
    void operator()(
        int row, int col,
        Accumulator acc,       // Raw accumulator value
        Element& output        // Output to write
    ) {
        // Step 1: Scale
        Accumulator result = alpha * acc;

        // Step 2: Add bias
        result += bias[col];

        // Step 3: GELU approximation
        // GELU(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        Accumulator x = result;
        Accumulator cdf = 0.5f * (1.0f + tanhf(
            0.7978845608f * (x + 0.044715f * x * x * x)
        ));
        result = x * cdf;

        // Step 4: Optional quantize to INT8
        // result = round(result / quant_scale)
        // output = clamp(result, -128, 127)

        output = static_cast<Element>(result);
    }
};
⚑ Performance

Fusing bias, activation, and quantization into the GEMM epilogue avoids separate kernel launches and memory round-trips. A fused GEMM+bias+GELU saves two kernel launches plus two extra read-and-write passes over the output matrix. A 4096x4096 FP16 output is 32 MB, so each avoided pass over it takes ~10 microseconds at H100's 3.35 TB/s. For the small matrices of LLM decode, this epilogue fusion can improve end-to-end kernel time by 10-20%.

Mixed-Precision GEMM Configuration

CUTLASS makes precision an explicit template parameter, including mixed-precision combinations that cuBLAS does not directly expose:

// Example 1: FP16 inputs, FP32 accumulation, FP16 output (standard)
using Gemm_FP16 = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,     // A: FP16
    cutlass::half_t, cutlass::layout::ColumnMajor,   // B: FP16
    cutlass::half_t, cutlass::layout::RowMajor,      // C: FP16
    float,                                             // Accumulator: FP32
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 32>,
    cutlass::gemm::GemmShape<64, 64, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>
>;

// Example 2: INT8 inputs, INT32 accumulation, FP16 output (W8A8 inference)
using Gemm_INT8 = cutlass::gemm::device::Gemm<
    int8_t, cutlass::layout::RowMajor,
    int8_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::RowMajor,
    int32_t,  // INT32 accumulator
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 128, 64>,  // Larger K tile for INT8
    cutlass::gemm::GemmShape<64, 64, 64>,
    cutlass::gemm::GemmShape<16, 8, 32>      // INT8 MMA: 16x8x32
>;

// Example 3: FP8 inputs, FP32 accumulation (Hopper FP8 GEMM)
using Gemm_FP8 = cutlass::gemm::device::Gemm<
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,   // A: E4M3
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor, // B: E4M3
    cutlass::half_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm90,
    cutlass::gemm::GemmShape<128, 256, 64>,
    cutlass::gemm::GemmShape<64, 128, 64>,
    cutlass::gemm::GemmShape<16, 8, 32>
>;

// Example 4: INT4 weights dequantized to FP16 for GEMM (W4A16)
// This requires a custom mainloop that dequantizes INT4 -> FP16 on the fly
// CUTLASS provides this through the "mixed input" GEMM:
using Gemm_W4A16 = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t, cutlass::layout::RowMajor,     // A: FP16 activations
    cutlass::int4b_t, cutlass::layout::ColumnMajor,  // B: INT4 weights
    cutlass::half_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80
>;
// Internally: loads INT4 weights, dequantizes to FP16 in shared memory,
// then performs the standard FP16 tensor core MMA

Launching and Profiling CUTLASS GEMMs

// Complete example: launching a CUTLASS GEMM

#include <cutlass/gemm/device/gemm.h>
#include <cutlass/util/host_tensor.h>
#include <cutlass/util/device_memory.h>   // cutlass::DeviceAllocation

int main() {
    int M = 4096, N = 4096, K = 4096;

    // Define the GEMM type (FP16 with tensor cores)
    using Gemm = cutlass::gemm::device::Gemm<
        cutlass::half_t, cutlass::layout::RowMajor,
        cutlass::half_t, cutlass::layout::ColumnMajor,
        cutlass::half_t, cutlass::layout::RowMajor,
        float,
        cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm80,
        cutlass::gemm::GemmShape<128, 128, 32>,
        cutlass::gemm::GemmShape<64, 64, 32>,
        cutlass::gemm::GemmShape<16, 8, 16>
    >;

    // Allocate matrices
    cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> A({M, K});
    cutlass::HostTensor<cutlass::half_t, cutlass::layout::ColumnMajor> B({K, N});
    cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> C({M, N});

    A.sync_device();
    B.sync_device();
    C.sync_device();

    // Configure GEMM arguments
    Gemm::Arguments args(
        {M, N, K},               // Problem size
        {A.device_data(), K},     // A tensor ref
        {B.device_data(), K},     // B tensor ref (col-major, so stride = K)
        {C.device_data(), N},     // C tensor ref (source)
        {C.device_data(), N},     // D tensor ref (destination, can alias C)
        {1.0f, 0.0f}             // alpha, beta
    );

    // Instantiate and run
    Gemm gemm_op;
    cutlass::Status status = gemm_op.can_implement(args);
    if (status != cutlass::Status::kSuccess) {
        // Handle error: tile sizes don't work for this problem
        return -1;
    }

    // Query workspace size
    size_t workspace_size = Gemm::get_workspace_size(args);
    cutlass::DeviceAllocation<uint8_t> workspace(workspace_size);

    status = gemm_op.initialize(args, workspace.get());
    status = gemm_op();  // Launch!

    // Check for errors
    cudaDeviceSynchronize();
    if (status != cutlass::Status::kSuccess) {
        printf("GEMM failed: %s\n", cutlassGetStatusString(status));
    }

    return 0;
}

Performance Tuning with Tile Size Sweeps

// Systematic tile size sweep for optimal performance
// Different problem sizes favor different tile configurations

struct TileConfig {
    int tb_m, tb_n, tb_k;
    int warp_m, warp_n, warp_k;
    int stages;
    float achieved_tflops;
};

// Results of sweeping tile sizes for M=N=K=4096, FP16, H100:
//
// ThreadblockShape  | WarpShape     | Stages | TFLOPS | % of Peak
// 128x128x32        | 64x64x32      | 3      | 198    | 61%
// 128x256x32        | 64x128x32     | 3      | 256    | 79%
// 256x128x32        | 128x64x32     | 3      | 248    | 77%
// 128x256x64        | 64x128x64     | 4      | 298    | 92%  <-- Good
// 256x128x64        | 64x64x64      | 4      | 289    | 89%
// 128x128x64        | 64x64x64      | 5      | 275    | 85%
//
// Observations:
// 1. Larger N dimension in tile -> better memory coalescing for B
// 2. Larger K dimension -> fewer mainloop iterations, less overhead
// 3. More stages -> better latency hiding, but more shared memory
// 4. The optimal tile depends on the GEMM shape (M, N, K)

CUTLASS GEMM Throughput by Tile Configuration (M=N=K=4096, FP16, H100)

Configuration                | TFLOPS | % of Peak
128x128x32 (3 stages)        | 198    | 61%
128x256x32 (3 stages)        | 256    | 79%
128x256x64 (4 stages)        | 298    | 92%
cuBLAS (reference)           | 305    | 94%
H100 FP16 peak (theoretical) | 324    | 100%

Practical Use Cases in LLM Inference

Fused Attention Projection GEMM

In LLM inference, the QKV projection is a GEMM of shape [batchΓ—seq, 3Γ—d_model] = [batchΓ—seq, d_model] Γ— [d_model, 3Γ—d_model]. Fusing Q, K, V into a single GEMM is more efficient than three separate GEMMs:

// Fused QKV projection: one GEMM instead of three
// Input:  X [BS*seq, d_model]
// Weight: W_qkv [d_model, 3*d_model]  (Q, K, V weights concatenated)
// Output: QKV [BS*seq, 3*d_model]

// CUTLASS GEMM with custom epilogue that splits Q, K, V
// and applies RoPE to Q, K
using FusedQKVGemm = cutlass::gemm::device::Gemm<
    cutlass::half_t, cutlass::layout::RowMajor,
    cutlass::half_t, cutlass::layout::ColumnMajor,
    cutlass::half_t, cutlass::layout::RowMajor,
    float,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 256, 32>,  // Large N for 3*d_model
    cutlass::gemm::GemmShape<64, 128, 32>,
    cutlass::gemm::GemmShape<16, 8, 16>,
    RoPEEpilogue<cutlass::half_t, float>  // Custom epilogue applies RoPE
>;

// The RoPE epilogue:
// - Splits the output into Q, K, V slices (each [BS*seq, d_model] here;
//   K and V narrow to d_kv under grouped-query attention)
// - Applies rotary position embedding to Q and K
// - Writes Q, K, V to separate output buffers
// This saves one kernel launch and one read/write of the Q and K tensors

Weight-Only Dequantization GEMM

For INT4 weight-only quantization (W4A16), CUTLASS provides mixed-input GEMM that dequantizes INT4 weights to FP16 on-the-fly in shared memory:

// W4A16 GEMM: FP16 activations * INT4 weights -> FP16 output
// The INT4 weights are stored in packed format (2 weights per byte)
// Dequantization: w_fp16 = (w_int4 - zero_point) * scale

// This is what vLLM's Marlin kernel and AWQ's GEMM kernel implement
// CUTLASS provides the building blocks:

// 1. Custom global memory loader that reads packed INT4
// 2. Shared memory dequantization stage
// 3. Standard FP16 tensor core MMA on dequantized values

// The dequantization in shared memory:
__device__ void dequant_int4_to_fp16(
    const uint8_t* packed_int4,  // Input: 2 unsigned INT4 values per byte
    half* output_fp16,            // Output: FP16 values
    half scale,
    half zero_point,              // Typically 8: centers the unsigned 0..15 range
    int num_elements
) {
    for (int i = threadIdx.x; i < num_elements / 2; i += blockDim.x) {
        uint8_t packed = packed_int4[i];
        int lo = packed & 0x0F;          // Low nibble (plain int here --
        int hi = (packed >> 4) & 0x0F;   // CUDA's int4 is a vector type)

        // w_fp16 = (q - zero_point) * scale
        output_fp16[2*i]     = __hmul(__hsub(__int2half_rn(lo), zero_point), scale);
        output_fp16[2*i + 1] = __hmul(__hsub(__int2half_rn(hi), zero_point), scale);
    }
    __syncthreads();
}
πŸ“Š

CUTLASS Custom GEMM vs cuBLAS Performance (H100)

GEMM TypeMNKCUTLASS TFLOPScuBLAS TFLOPSRatio
FP16 standard 4096 4096 4096 298 305 97.7%
FP16 standard 1 4096 4096 0.8 0.9 88.9%
FP16+bias+GELU fused 4096 4096 4096 285 N/A Custom
INT8 W8A8 4096 4096 4096 580 595 97.5%
W4A16 (dequant) 4096 4096 4096 310 N/A Custom
FP8 E4M3 4096 4096 4096 580 590 98.3%

Debugging CUTLASS Compilation

CUTLASS templates produce notoriously long compilation errors. Common issues and solutions:

// Error: "No matching function for call to 'Gemm::Gemm'"
// Cause: Tile sizes are not divisible. WarpShape must divide ThreadblockShape.
// ThreadblockShape<128, 128, 32>, WarpShape<64, 64, 32>: OK (2x2 warps)
// ThreadblockShape<128, 128, 32>, WarpShape<32, 64, 32>: OK (4x2 warps)
// ThreadblockShape<128, 128, 32>, WarpShape<48, 64, 32>: ERROR (128/48 not integer)

// Error: "static_assert failed: Instruction shape must divide warp shape"
// InstructionShape must divide WarpShape along all dimensions
// WarpShape<64, 64, 32>, InstructionShape<16, 8, 16>: OK
// WarpShape<64, 64, 32>, InstructionShape<16, 8, 32>: OK
// WarpShape<64, 64, 32>, InstructionShape<16, 16, 16>: ERROR (16x16x16 is not a valid FP16 mma.sync shape)

// Error: Shared memory exceeds limit
// Reduce stages or tile size
// Check: stages * (TB_M * TB_K + TB_K * TB_N) * sizeof(element) <= 228 KB (H100)

// Performance issue: Low occupancy
// Use the CUTLASS profiler to sweep candidate configurations:
// cutlass_profiler --operation=Gemm --m=4096 --n=4096 --k=4096 \
//   --A=f16:row --B=f16:column --C=f16:row \
//   --cta_m=128 --cta_n=256 --cta_k=64 --stages=4

Summary

CUTLASS provides the full hierarchical decomposition of GEMM as composable C++ templates: problem-level grid tiling, CTA-level shared memory staging, warp-level MMA dispatch, and instruction-level tensor core operations. This decomposition enables custom GEMMs (mixed-precision, fused epilogues, weight dequantization) that achieve 95-99% of cuBLAS performance while supporting operations cuBLAS cannot. The key configuration decisions are tile sizes (which determine shared memory usage and occupancy), pipeline stages (which determine latency hiding), and epilogue design (which determines fusion opportunities). For LLM inference, CUTLASS is the foundation for virtually all high-performance custom GEMM kernels including the Marlin W4A16 kernel, INT8 SmoothQuant GEMMs, and fused QKV projections.