Part of Series: CUDA Kernel Engineering (5 of 32)

An A100 delivers 19.5 TFLOPS on FP32 CUDA cores. The same chip delivers 312 TFLOPS using tensor cores for FP16 matrix multiplication, a 16x throughput multiplier. A single tensor core instruction computes a 16x8x16 matrix product across all 32 threads of a warp, performing 2048 FMAs per instruction where CUDA cores would need 64 warp-wide FMA instructions. Every GPU since Volta (2017) has included tensor cores, yet many production kernels still use CUDA cores for GEMM because the tensor core programming model is non-obvious. This post demystifies it.

This post covers the hardware architecture, the programming model at each abstraction level, the supported precision matrix, and a complete WMMA GEMM kernel that achieves meaningful tensor core utilization.

All benchmarks target A100-80GB SXM (CC 8.0, third-generation tensor cores) unless stated otherwise.

Why Tensor Cores Exist

Consider FP32 GEMM on CUDA cores. Each thread computes one fused multiply-add (FMA) per cycle. An A100 has 6912 FP32 CUDA cores at ~1.41 GHz, yielding:

\text{FP32 CUDA core peak} = 6912 \times 2 \times 1.41 \times 10^9 \approx 19.5 \text{ TFLOPS}

Now consider the same computation on tensor cores. Volta's original tensor cores each performed a 4x4x4 matrix multiply-accumulate per cycle: 64 FMAs. The A100's third-generation tensor cores are wider: 432 tensor cores (4 per SM, 108 SMs), each sustaining 256 FP16 FMAs per cycle:

\text{FP16 tensor core peak} = 432 \times 256 \times 2 \times 1.41 \times 10^9 \approx 312 \text{ TFLOPS}

The actual peak is 312 TFLOPS for FP16 accumulating into FP32, compared to 19.5 TFLOPS for FP32 on CUDA cores. That is a 16x throughput ratio.

📊

Peak Throughput: Tensor Cores vs CUDA Cores (A100)

Precision        CUDA Core TFLOPS    Tensor Core TFLOPS    Speedup
FP32             19.5                - (use TF32)          1x (8x as TF32)
TF32 (19-bit)    -                   156                   8x vs FP32 CUDA
FP16 / BF16      78 (FP16 CUDA)      312                   4x vs FP16 CUDA
INT8             39 (INT8 CUDA)      624                   16x vs INT8 CUDA
FP64             9.7                 19.5                  2x vs FP64 CUDA

Note: Tensor core throughput scales with lower precision. INT8 tensor cores deliver 624 TOPS (trillions of operations per second) on A100.

Supported Precisions Across Generations

The precision matrix has expanded with each generation:

📊

Tensor Core Precision Support by GPU Generation

Precision (A x B -> C)      Volta (V100)   Turing (T4)   Ampere (A100)   Ada (L40S)   Hopper (H100)
FP16 x FP16 -> FP16         Yes            Yes           Yes             Yes          Yes
FP16 x FP16 -> FP32         Yes            Yes           Yes             Yes          Yes
BF16 x BF16 -> FP32         No             No            Yes             Yes          Yes
TF32 x TF32 -> FP32         No             No            Yes             Yes          Yes
INT8 x INT8 -> INT32        No             Yes           Yes             Yes          Yes
INT4 x INT4 -> INT32        No             Yes           Yes             Yes          Yes
INT1 x INT1 -> INT32        No             Yes           No              No           No
FP8 (E4M3) x FP8 -> FP32    No             No            No              Yes          Yes
FP8 (E5M2) x FP8 -> FP32    No             No            No              Yes          Yes
FP64 x FP64 -> FP64         No             No            Yes             No           Yes

Note: FP8 tensor cores (Hopper, Ada) are critical for LLM inference. TF32 (Ampere+) provides FP32-like accuracy at 8x FP32 speed by truncating the mantissa of FP32 inputs to 10 bits.

TF32: The Transparent Speedup

TF32 (TensorFloat-32) is a 19-bit format: 1 sign bit, 8 exponent bits (same as FP32), 10 mantissa bits (same as FP16). Tensor cores automatically truncate FP32 inputs to TF32 when enabled:

// cuBLAS does NOT use TF32 by default; opt in per handle:
cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);
// Revert to exact IEEE FP32 (CUDA cores):
cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH);
// Frameworks may enable TF32 on their own (e.g., PyTorch's
// torch.backends.cuda.matmul.allow_tf32); setting the environment
// variable NVIDIA_TF32_OVERRIDE=0 disables TF32 globally.

The accuracy impact is small: TF32 has the same dynamic range as FP32 and sufficient mantissa precision for neural network training. Most deep learning workloads see zero accuracy degradation with TF32 enabled.

WMMA API: C++ Warp-Level Matrix Operations (Volta+)

WMMA (Warp-level Matrix Multiply-Accumulate) is the highest-level tensor core API, accessible from CUDA C++. It operates on fragments — warp-distributed matrix tiles stored across the 32 threads’ registers.

Fragment Types

#include <mma.h>
using namespace nvcuda;

// Fragment types for a 16x16x16 FP16 operation:
// A matrix: M x K = 16 x 16, stored as FP16
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> a_frag;

// B matrix: K x N = 16 x 16, stored as FP16
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> b_frag;

// Accumulator: M x N = 16 x 16, stored as FP32 (higher precision)
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;

Each fragment is distributed across all 32 threads of the warp. Thread t holds a subset of the matrix elements in its registers. The exact mapping is opaque (implementation-defined), which is why you must use WMMA APIs to load, store, and compute with fragments.

Supported WMMA Shapes

📊

WMMA Fragment Shapes (Ampere, CC 8.0)

Precision   M    N    K    Accumulator
FP16        16   16   16   FP16 or FP32
FP16        32   8    16   FP16 or FP32
FP16        8    32   16   FP16 or FP32
BF16        16   16   16   FP32
TF32        16   16   8    FP32
INT8        16   16   16   INT32
INT4        8    8    32   INT32
FP64        8    8    4    FP64

Note: The 16x16x16 FP16 shape is the most commonly used. Larger shapes (32x8x16) can be more efficient for specific matrix aspect ratios.

WMMA Operations

// 1. Initialize accumulator to zero
wmma::fill_fragment(c_frag, 0.0f);

// 2. Load A fragment from global/shared memory
// ptr: pointer to first element
// ldm: leading dimension (stride between rows/columns)
wmma::load_matrix_sync(a_frag, ptr_a, ldm_a);

// 3. Load B fragment
wmma::load_matrix_sync(b_frag, ptr_b, ldm_b);

// 4. Matrix multiply-accumulate: C += A * B
wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);

// 5. Store result
wmma::store_matrix_sync(ptr_c, c_frag, ldm_c, wmma::mem_row_major);

Each mma_sync call is warp-synchronous: all 32 threads in the warp must execute it together. The _sync suffix guarantees that all threads have completed before any thread proceeds.

Complete WMMA GEMM Implementation

This implementation computes C = A x B using WMMA tensor core operations with shared memory tiling.

#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

// Tile dimensions
#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

#define BLOCK_M 128  // Block tile rows
#define BLOCK_N 128  // Block tile cols
#define BLOCK_K 32   // Block tile K-depth

// Each warp computes a grid of WMMA_M x WMMA_N tiles.
// A full 128x128 block tile contains (BLOCK_M/WMMA_M) x (BLOCK_N/WMMA_N)
// = 8 x 8 = 64 WMMA tiles. One warp per tile would need 64 warps =
// 2048 threads, over the 1024-thread-per-block limit.
// Use 256 threads (8 warps) instead; each warp computes a 4 x 2 grid
// of WMMA tiles (a 64 x 32 patch of the output).

#define WARPS_PER_BLOCK 8    // 256 threads / 32
#define WARP_TILES_M 4       // Each warp computes 4 x 2 WMMA tiles
#define WARP_TILES_N 2

__global__ void wmma_gemm(const half* __restrict__ A,
                           const half* __restrict__ B,
                           float* __restrict__ C,
                           int M, int N, int K) {
    // Shared memory for A and B tiles
    __shared__ half As[BLOCK_K][BLOCK_M + 8];   // Padded to avoid bank conflicts
    __shared__ half Bs[BLOCK_K][BLOCK_N + 8];

    int warp_id = threadIdx.x / 32;

    // Map warps to output tile positions
    int warp_row = (warp_id / (BLOCK_N / WMMA_N / WARP_TILES_N)) * WARP_TILES_M;
    int warp_col = (warp_id % (BLOCK_N / WMMA_N / WARP_TILES_N)) * WARP_TILES_N;

    // Declare accumulator fragments (one per WMMA tile this warp computes)
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float>
        c_frag[WARP_TILES_M][WARP_TILES_N];

    // Initialize accumulators to zero
    #pragma unroll
    for (int wm = 0; wm < WARP_TILES_M; wm++) {
        #pragma unroll
        for (int wn = 0; wn < WARP_TILES_N; wn++) {
            wmma::fill_fragment(c_frag[wm][wn], 0.0f);
        }
    }

    // Loop over K dimension in BLOCK_K chunks
    for (int bk = 0; bk < K; bk += BLOCK_K) {
        // Collaborative load: all threads in block load A and B tiles
        // into shared memory. Each thread loads multiple elements.
        int elements_per_thread_A = (BLOCK_K * BLOCK_M) / 256;
        for (int i = 0; i < elements_per_thread_A; i++) {
            int idx = threadIdx.x + i * 256;
            int k_idx = idx / BLOCK_M;
            int m_idx = idx % BLOCK_M;
            int global_m = blockIdx.y * BLOCK_M + m_idx;
            int global_k = bk + k_idx;

            half val = (global_m < M && global_k < K)
                       ? A[global_m * K + global_k]
                       : __float2half(0.0f);
            As[k_idx][m_idx] = val;
        }

        int elements_per_thread_B = (BLOCK_K * BLOCK_N) / 256;
        for (int i = 0; i < elements_per_thread_B; i++) {
            int idx = threadIdx.x + i * 256;
            int k_idx = idx / BLOCK_N;
            int n_idx = idx % BLOCK_N;
            int global_k = bk + k_idx;
            int global_n = blockIdx.x * BLOCK_N + n_idx;

            half val = (global_k < K && global_n < N)
                       ? B[global_k * N + global_n]
                       : __float2half(0.0f);
            Bs[k_idx][n_idx] = val;
        }

        __syncthreads();

        // Compute: iterate over WMMA_K chunks within the BLOCK_K tile
        for (int k = 0; k < BLOCK_K; k += WMMA_K) {
            #pragma unroll
            for (int wm = 0; wm < WARP_TILES_M; wm++) {
                #pragma unroll
                for (int wn = 0; wn < WARP_TILES_N; wn++) {
                    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N,
                                   WMMA_K, half, wmma::col_major> a_frag;
                    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N,
                                   WMMA_K, half, wmma::row_major> b_frag;

                    // Load A fragment from shared memory
                    int a_row = (warp_row + wm) * WMMA_M;
                    wmma::load_matrix_sync(
                        a_frag,
                        &As[k][a_row],
                        BLOCK_M + 8  // leading dimension (with padding)
                    );

                    // Load B fragment from shared memory
                    int b_col = (warp_col + wn) * WMMA_N;
                    wmma::load_matrix_sync(
                        b_frag,
                        &Bs[k][b_col],
                        BLOCK_N + 8
                    );

                    // Matrix multiply-accumulate
                    wmma::mma_sync(c_frag[wm][wn], a_frag, b_frag,
                                   c_frag[wm][wn]);
                }
            }
        }

        __syncthreads();
    }

    // Store accumulator fragments to global memory
    #pragma unroll
    for (int wm = 0; wm < WARP_TILES_M; wm++) {
        #pragma unroll
        for (int wn = 0; wn < WARP_TILES_N; wn++) {
            int c_row = blockIdx.y * BLOCK_M + (warp_row + wm) * WMMA_M;
            int c_col = blockIdx.x * BLOCK_N + (warp_col + wn) * WMMA_N;

            // store_matrix_sync writes a full 16x16 tile, so skip partial
            // tiles at the edge; pad C when M or N is not a multiple of 16
            if (c_row + WMMA_M <= M && c_col + WMMA_N <= N) {
                wmma::store_matrix_sync(
                    &C[c_row * N + c_col],
                    c_frag[wm][wn],
                    N,  // leading dimension
                    wmma::mem_row_major
                );
            }
        }
    }
}

// Launch
void launch_wmma_gemm(const half* A, const half* B, float* C,
                       int M, int N, int K) {
    dim3 block(256);  // 8 warps
    dim3 grid((N + BLOCK_N - 1) / BLOCK_N,
              (M + BLOCK_M - 1) / BLOCK_M);

    wmma_gemm<<<grid, block>>>(A, B, C, M, N, K);
}

WMMA GEMM Performance

📊

WMMA GEMM Performance (A100, M=N=K=4096, FP16 input, FP32 accum)

Implementation                      TFLOPS   % of Peak (312 TFLOPS)   Notes
CUDA core FP32 (optimized)          7.1      2.3% (of TC peak)        No tensor cores
CUDA core FP16 (optimized)          14.8     4.7%                     No tensor cores
WMMA naive (no shared mem)          28       9.0%                     Global memory bound
WMMA tiled (above implementation)   95       30.4%                    Shared memory tiling
WMMA tiled + double buffer          135      43.3%                    Overlap load/compute
CUTLASS (optimized WMMA)            240      76.9%                    Multi-stage pipeline, swizzled smem
cuBLAS FP16                         290      93.0%                    Fully optimized

Note: The gap between our WMMA kernel and cuBLAS is due to multi-stage pipelining, epilogue fusion, split-K, and micro-architectural tuning. CUTLASS provides a good middle ground.


MMA PTX: Fine-Grained Control (Ampere+)

The MMA PTX interface provides lower-level access to tensor cores through inline PTX assembly. It exposes more shapes, explicit register mapping, and control over data layout that WMMA abstracts away.

MMA Instruction Format

// PTX mma instruction for m16n8k16 FP16
// This is the native Ampere shape (one HMMA instruction)
__device__ void mma_m16n8k16_fp16(
    float* d,           // 4 output registers (FP32 accumulator)
    const uint32_t* a,  // 4 input registers (A matrix, FP16 packed)
    const uint32_t* b,  // 2 input registers (B matrix, FP16 packed)
    const float* c      // 4 input registers (C accumulator)
) {
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, "   // D (output)
        "{%4, %5, %6, %7}, "   // A (input)
        "{%8, %9}, "           // B (input)
        "{%10, %11, %12, %13};\n"  // C (accumulator)
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])
    );
}

MMA Shapes on Ampere

The native MMA shapes on Ampere differ from WMMA shapes:

mma.sync.aligned.m16n8k8    // FP16, BF16
mma.sync.aligned.m16n8k16   // FP16, BF16
mma.sync.aligned.m16n8k4    // TF32
mma.sync.aligned.m16n8k8    // TF32
mma.sync.aligned.m16n8k16   // INT8
mma.sync.aligned.m16n8k32   // INT8
mma.sync.aligned.m16n8k32   // INT4
mma.sync.aligned.m16n8k64   // INT4
mma.sync.aligned.m8n8k4     // FP64

The asymmetric m16n8kX shapes reflect the hardware: each tensor core is physically a 16x8 matrix unit, and wider operations are composed from multiple invocations.

Register Layout for m16n8k16

Each thread in the warp holds a specific set of matrix elements. For the m16n8k16 FP16 operation:

A matrix (16 x 16, FP16, row-major):
  Thread t holds 8 elements of A, packed into 4 uint32_t registers
  Each uint32_t holds 2 FP16 values

B matrix (16 x 8, FP16, col-major):
  Thread t holds 4 elements of B, packed into 2 uint32_t registers

C/D accumulator (16 x 8, FP32):
  Thread t holds 4 FP32 elements in 4 float registers
💡 WMMA vs MMA: When to Use Which

Use WMMA for prototyping and when portability across GPU generations matters. WMMA code compiles for Volta through Hopper. Use MMA PTX when you need maximum performance and are targeting a specific architecture. MMA gives you explicit control over register layout, enabling tighter integration with shared memory load patterns and custom epilogues. CUTLASS uses MMA PTX internally.

WGMMA: Warp-Group Level Operations (Hopper)

Hopper introduces Warp-Group Matrix Multiply-Accumulate (WGMMA), which operates across a warp group of 128 threads (4 warps). This is a fundamental architectural shift.

Why Warp Groups?

On Hopper, WGMMA treats the 128-thread warp group (4 contiguous warps) as its unit of work, which lets it:

  1. Compute larger tiles per instruction: up to 64x256x16 in a single WGMMA operation
  2. Access shared memory asynchronously: one operand can come from shared memory via the Tensor Memory Accelerator (TMA), overlapping the load with the multiply
  3. Pipeline deeply: multiple WGMMA operations can be in flight simultaneously

WGMMA Shapes

wgmma.mma_async.sync.aligned.shape.dtype.dtype.dtype
  Shapes: m64n8k16, m64n16k16, m64n24k16, ..., m64n256k16 (FP16/BF16)
          m64n8k8, m64n16k8, ..., m64n256k8 (TF32)
          m64n8k32, ..., m64n256k32 (FP8, INT8)

All WGMMA shapes have M = 64: the 128 threads of the warp group together produce 64 output rows. The N dimension varies from 8 to 256, and the K dimension depends on precision.

WGMMA Programming Model

WGMMA is inherently asynchronous. The instruction initiates the computation but does not wait for it to complete. You must explicitly fence and wait:

// Pseudocode for WGMMA usage (actual API is PTX)
// 1. Load B matrix into shared memory via TMA
// 2. Issue WGMMA: A from registers, B from shared memory
// 3. Insert fence
// 4. Wait for completion
// 5. Use results

// PTX:
// wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16
//   {d0..d63}, {a0..a7}, desc_b, scale_d, imm_scale_a, imm_scale_b, ...
// wgmma.wait_group.sync.aligned N  // wait for N outstanding groups

TMA: Tensor Memory Accelerator

WGMMA’s power comes from tight integration with TMA (Hopper only). TMA is a hardware unit that:

  1. Copies multi-dimensional tiles from global to shared memory (or vice versa) in one instruction
  2. Handles address computation in hardware — no per-thread index math
  3. Supports swizzled layouts to eliminate bank conflicts
  4. Is asynchronous: the copy runs independently while threads compute
// Create a TMA descriptor (host code)
CUtensorMap tensor_map;
cuTensorMapEncodeTiled(
    &tensor_map,
    CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
    2,                    // 2D tensor
    (void*)device_ptr,    // base address
    dims,                 // tensor dimensions
    strides,              // tensor strides
    box_dims,             // tile dimensions to copy
    element_strides,      // element strides within tile
    CU_TENSOR_MAP_INTERLEAVE_NONE,
    CU_TENSOR_MAP_SWIZZLE_128B,    // 128-byte swizzle for bank-conflict-free
    CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
    CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE
);

// Device code: one instruction copies an entire tile
// cp.async.bulk.tensor.2d.shared::cluster.global.tile
//   [smem_ptr], [tensor_map, {coord_x, coord_y}], mbar

Structured Sparsity: 2:4 Pattern (Ampere+)

Ampere tensor cores support structured sparsity: if a matrix has at most 2 nonzero values per group of 4 consecutive elements (2:4 sparsity), the tensor core can skip the zero multiplications, effectively doubling throughput.

Dense:     [1.5, 0.0, -0.3, 2.1, 0.7, -1.2, 0.0, 0.0]
2:4 Sparse: [1.5, 0.0, -0.3, 0.0, 0.7, -1.2, 0.0, 0.0]
             ^^^^^^^^^^^^^^^^^^  ^^^^^^^^^^^^^^^^^^
             2 nonzeros per 4    2 nonzeros per 4

Metadata: a 2-bit index per nonzero, indicating which of the 4 positions
          it occupies. Compressed storage: only the nonzeros + metadata.

Sparse GEMM Throughput

📊

Dense vs Sparse Tensor Core Throughput (A100)

Precision   Dense TFLOPS   2:4 Sparse TFLOPS   Speedup
FP16        312            624                 2x
BF16        312            624                 2x
TF32        156            312                 2x
INT8        624            1248                2x

Note: Structured sparsity doubles throughput at every precision. The 2:4 pattern must be enforced during training via pruning.
⚠️ Sparsity Must Be Structural, Not Arbitrary

The 2:4 pattern requirement is strict: exactly 2 of every 4 consecutive elements must be zero, and this must be true for every group across the entire matrix. Random sparsity (even at 50%) does not qualify. You must use structured pruning during training (e.g., NVIDIA’s ASP — Automatic Sparsity) to produce conforming weight matrices.

Precision Formats in Detail

FP16 (IEEE 754 Half-Precision)

  • 1 sign + 5 exponent + 10 mantissa = 16 bits
  • Range: ±6.55×10^4, precision: ~3 decimal digits
  • Standard format for inference and mixed-precision training

BF16 (Brain Floating Point)

  • 1 sign + 8 exponent + 7 mantissa = 16 bits
  • Range: ±3.4×10^38 (same as FP32), precision: ~2 decimal digits
  • Preferred for training because its range matches FP32, avoiding the overflows that force loss scaling with FP16

TF32 (TensorFloat-32)

  • 1 sign + 8 exponent + 10 mantissa = 19 bits
  • Not a storage format — only exists inside tensor cores
  • Input FP32 values are truncated to TF32 on-the-fly

FP8 (E4M3 and E5M2)

  • E4M3: 1 sign + 4 exponent + 3 mantissa. Range: ±448, precision: ~1.5 digits
  • E5M2: 1 sign + 5 exponent + 2 mantissa. Range: ±57344, precision: ~1 digit
  • E4M3 for weights and activations (more precision needed)
  • E5M2 for gradients (more range needed)
// FP8 GEMM via cuBLAS (Hopper)
cublasLtMatmul(
    handle, operationDesc,
    &alpha,
    A_desc,  // FP8 E4M3
    B_desc,  // FP8 E4M3
    &beta,
    C_desc,  // BF16 or FP32 accumulator
    D_desc,  // BF16 or FP32 output
    ...
);

CUTLASS: The Template Library

CUTLASS (CUDA Templates for Linear Algebra Subroutines) is NVIDIA’s open-source template library that provides optimized building blocks for GEMM at every level. It abstracts the complexity of MMA/WGMMA while still allowing fine-grained control.

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>

// Define a GEMM operation type
using Gemm = cutlass::gemm::device::Gemm<
    cutlass::half_t,              // A element type
    cutlass::layout::RowMajor,    // A layout
    cutlass::half_t,              // B element type
    cutlass::layout::ColumnMajor, // B layout
    float,                        // C element type
    cutlass::layout::RowMajor,    // C layout
    float,                        // Accumulator type
    cutlass::arch::OpClassTensorOp,  // Use tensor cores
    cutlass::arch::Sm80,          // Target Ampere
    cutlass::gemm::GemmShape<128, 128, 32>,  // Thread block tile
    cutlass::gemm::GemmShape<64, 64, 32>,    // Warp tile
    cutlass::gemm::GemmShape<16, 8, 16>      // MMA instruction shape
>;

// Launch
Gemm gemm_op;
Gemm::Arguments args(
    {M, N, K},       // problem size
    {A, K},          // A tensor ref (row-major, lda = K)
    {B, K},          // B tensor ref (column-major, ldb = K)
    {C, N},          // C tensor ref (row-major, ldc = N)
    {D, N},          // D tensor ref
    {1.0f, 0.0f}     // alpha, beta
);
gemm_op(args);

CUTLASS GEMM performance is typically within 5-10% of cuBLAS for standard shapes and often matches or exceeds cuBLAS for non-standard shapes or fused epilogues.

Practical Guidelines

When to Use Tensor Cores

  1. Always for GEMM: any matrix multiply with M, N, K >= 16 benefits from tensor cores
  2. Convolution: convolutions lowered to GEMM (implicit or explicit) use tensor cores via cuDNN
  3. Attention: FlashAttention uses tensor cores for the QK^T and PV matmuls
  4. Reductions/elementwise: tensor cores do NOT help; these are not matrix operations

Common Pitfalls

// Pitfall 1: Matrix dimensions not multiples of WMMA tile size
// WMMA requires M, N, K to be multiples of the tile dimensions
// Solution: pad matrices to the nearest multiple
int M_padded = ((M + WMMA_M - 1) / WMMA_M) * WMMA_M;

// Pitfall 2: Wrong layout specification
// A must be row-major or col-major as declared in the fragment type
// Mismatched layout produces silently wrong results

// Pitfall 3: Accumulator overflow with FP16 accumulator
// FP16 max value is 65504. If your partial sums exceed this, use FP32 accum
wmma::fragment<wmma::accumulator, 16, 16, 16, float> c_frag;  // FP32 accum

// Pitfall 4: Not checking alignment
// WMMA load_matrix_sync requires the base pointer to be aligned
// to 16 bytes (128 bits) for FP16 data
// cudaMalloc guarantees 256-byte alignment, so device allocations are safe
ℹ️ Series Navigation

This is Part 5 of the CUDA Kernel Engineering series. Part 6 covers Triton kernel development — writing GPU kernels in Python with automatic tuning, and when Triton is (and is not) a suitable alternative to CUDA.