Part 7 of 32 in the series CUDA Kernel Engineering

A transformer MLP layer with separate kernels for GEMM, bias, GELU, and dropout executes in 340 microseconds on an H100. Fuse bias+GELU into the GEMM epilogue and time drops to 280 microseconds. Fuse dropout into the same kernel as well and time drops to 240 microseconds — a 30% reduction in runtime from eliminating two kernel launches and two HBM round-trips. The arithmetic is unchanged. The difference is memory traffic: the unfused version writes 32 MB of intermediate results to HBM after bias and 32 MB after GELU, then reads them back. The fused version writes once. HBM traffic drops from 128 MB to 64 MB.

Kernel fusion eliminates kernel launch overhead and memory round-trips by combining multiple operations into a single kernel. The fused kernel reads the input once, applies all operations in registers, and writes the final output once. This post covers four fusion patterns that appear everywhere in LLM inference: elementwise fusion, reduction fusion, GEMM epilogue fusion, and attention fusion.

Why Fusion Matters: The Bandwidth Wall

Memory Traffic Dominates Kernel Time

For elementwise operations (add, multiply, activation functions, dropout), the arithmetic intensity is O(1) — a handful of FLOPs per element loaded, independent of tensor size. The roofline model puts these operations deep in the memory-bandwidth-bound regime:

import torch
import time

def measure_kernel_overhead():
    """Measure the cost of separate vs fused operations."""
    device = 'cuda'
    M, N = 2048, 4096
    dtype = torch.float16

    x = torch.randn(M, N, device=device, dtype=dtype)
    bias = torch.randn(N, device=device, dtype=dtype)
    dropout_mask = torch.bernoulli(torch.full((M, N), 0.9,
                                   device=device)).to(dtype)

    # Separate operations
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(1000):
        y = x + bias         # Kernel 1: bias add
        y = torch.nn.functional.gelu(y)  # Kernel 2: GELU
        y = y * dropout_mask  # Kernel 3: dropout
    torch.cuda.synchronize()
    t_separate = (time.perf_counter() - start) / 1000

    # Fused (using torch.compile)
    @torch.compile
    def fused_bias_gelu_dropout(x, bias, mask):
        y = x + bias
        y = torch.nn.functional.gelu(y)
        y = y * mask
        return y

    # Warmup
    for _ in range(10):
        fused_bias_gelu_dropout(x, bias, dropout_mask)

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(1000):
        y = fused_bias_gelu_dropout(x, bias, dropout_mask)
    torch.cuda.synchronize()
    t_fused = (time.perf_counter() - start) / 1000

    print(f"Separate kernels: {t_separate*1e6:.1f} us")
    print(f"Fused kernel:     {t_fused*1e6:.1f} us")
    print(f"Speedup:          {t_separate/t_fused:.2f}x")

measure_kernel_overhead()
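
To see just how far below the compute roofline these operations sit, here is a back-of-envelope check. The FLOP count per element and the H100 peak figures are ballpark assumptions, not measurements:

```python
def arithmetic_intensity_check(flops_per_elem=12, dtype_bytes=2):
    """Rough roofline position for a fused bias+GELU+dropout op."""
    peak_flops = 989e12   # assumed H100 dense FP16 peak, FLOP/s
    peak_bw = 3.35e12     # assumed H100 HBM3 bandwidth, bytes/s
    ridge = peak_flops / peak_bw             # ~295 FLOP/byte

    # Each element: one read + one write of dtype_bytes
    ai = flops_per_elem / (2 * dtype_bytes)  # FLOP/byte

    print(f"Arithmetic intensity: {ai:.1f} FLOP/byte")
    print(f"Ridge point:          {ridge:.0f} FLOP/byte")
    print(f"Below the ridge by:   {ridge / ai:.0f}x")
    return ai, ridge

arithmetic_intensity_check()
```

Even with a generous FLOP count for the GELU, the operation sits roughly two orders of magnitude below the ridge point: only memory traffic matters.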

Quantifying Memory Traffic

def memory_traffic_analysis(M=2048, N=4096, dtype_bytes=2):
    """Calculate memory traffic for separate vs fused operations."""
    tensor_bytes = M * N * dtype_bytes

    # Separate: bias_add + gelu + dropout
    # Each op: read input + write output = 2 * tensor_bytes
    # bias_add also reads bias (N * dtype_bytes, negligible)
    separate_traffic = 3 * 2 * tensor_bytes  # 3 ops, each read+write

    # Fused: read input once, write output once
    fused_traffic = 2 * tensor_bytes

    print(f"Tensor size: {tensor_bytes / 1e6:.1f} MB")
    print(f"Separate traffic: {separate_traffic / 1e6:.1f} MB "
          f"(3 read + 3 write)")
    print(f"Fused traffic:    {fused_traffic / 1e6:.1f} MB "
          f"(1 read + 1 write)")
    print(f"Traffic reduction: {separate_traffic / fused_traffic:.1f}x")

    # Time estimate on H100 (3350 GB/s)
    hbm_bw = 3350e9  # bytes/sec
    t_separate = separate_traffic / hbm_bw * 1e6  # microseconds
    t_fused = fused_traffic / hbm_bw * 1e6

    print(f"Estimated time (H100): separate={t_separate:.1f} us, "
          f"fused={t_fused:.1f} us")

memory_traffic_analysis()

Pattern 1: Elementwise Fusion

The Simplest Fusion

Elementwise operations apply a function independently to each element (or corresponding elements from multiple tensors). They can always be fused because there are no inter-element dependencies.

// Unfused: three separate kernels
__global__ void bias_add_kernel(float* out, const float* in,
                                 const float* bias, int N, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        out[idx] = in[idx] + bias[idx % N];
    }
}

__global__ void gelu_kernel(float* out, const float* in, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        float x = in[idx];
        // GELU approximation
        out[idx] = 0.5f * x * (1.0f + tanhf(0.7978845608f *
                   (x + 0.044715f * x * x * x)));
    }
}

__global__ void dropout_kernel(float* out, const float* in,
                                const float* mask, float scale, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        out[idx] = in[idx] * mask[idx] * scale;
    }
}

// Fused: single kernel
__global__ void fused_bias_gelu_dropout_kernel(
    float* out,
    const float* in,
    const float* bias,
    const float* dropout_mask,
    float dropout_scale,
    int N,
    int total
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        // Load input (one HBM read)
        float x = in[idx];

        // Bias add (bias is small, likely in L2/L1)
        x += bias[idx % N];

        // GELU (compute in registers)
        x = 0.5f * x * (1.0f + tanhf(0.7978845608f *
            (x + 0.044715f * x * x * x)));

        // Dropout (mask load + multiply)
        x *= dropout_mask[idx] * dropout_scale;

        // Store output (one HBM write)
        out[idx] = x;
    }
}

Vectorized Fused Kernel

For half-precision (FP16), use vectorized loads/stores with half2 or float4 to maximize memory bandwidth utilization:

#include <cuda_fp16.h>

__global__ void fused_bias_gelu_dropout_fp16(
    half* __restrict__ out,
    const half* __restrict__ in,
    const half* __restrict__ bias,
    const uint8_t* __restrict__ dropout_mask,
    half dropout_scale,
    int N,
    int total
) {
    // Process 8 elements per thread (4 half2 loads)
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
    // Tail handling omitted here for brevity; a production kernel needs a
    // scalar fallback for the final partial vector (see the complete
    // implementation later in this post)
    if (idx + 7 >= total) return;

    // Vectorized load: read 8 halfs = 16 bytes
    float4 in_vec = *reinterpret_cast<const float4*>(&in[idx]);
    half2* in_h2 = reinterpret_cast<half2*>(&in_vec);

    // Load bias (8 elements)
    int bias_start = idx % N;
    float4 bias_vec;
    if (bias_start + 7 < N) {
        bias_vec = *reinterpret_cast<const float4*>(&bias[bias_start]);
    } else {
        // Handle wrap-around (rare for typical dimensions)
        __align__(16) half bias_buf[8];  // 16-byte aligned for the float4 cast
        for (int i = 0; i < 8; i++) {
            bias_buf[i] = bias[(idx + i) % N];
        }
        bias_vec = *reinterpret_cast<float4*>(bias_buf);
    }
    half2* bias_h2 = reinterpret_cast<half2*>(&bias_vec);

    // Load dropout mask: one bit per element, packed 8 per byte
    uint8_t mask_byte = dropout_mask[idx / 8];

    // Process 4 half2 pairs
    half2 scale_h2 = __half2half2(dropout_scale);
    float4 out_vec;
    half2* out_h2 = reinterpret_cast<half2*>(&out_vec);

    #pragma unroll
    for (int i = 0; i < 4; i++) {
        // Bias add
        half2 val = __hadd2(in_h2[i], bias_h2[i]);

        // GELU approximation in FP16
        float2 f = __half22float2(val);
        f.x = 0.5f * f.x * (1.0f + tanhf(0.7978845608f *
              (f.x + 0.044715f * f.x * f.x * f.x)));
        f.y = 0.5f * f.y * (1.0f + tanhf(0.7978845608f *
              (f.y + 0.044715f * f.y * f.y * f.y)));
        val = __float22half2_rn(f);

        // Dropout: zero out masked elements, then rescale the survivors
        half2 keep = __halves2half2(
            __int2half_rn((mask_byte >> (2 * i)) & 1),
            __int2half_rn((mask_byte >> (2 * i + 1)) & 1));
        val = __hmul2(__hmul2(val, keep), scale_h2);

        out_h2[i] = val;
    }

    // Vectorized store: write 8 halfs = 16 bytes
    *reinterpret_cast<float4*>(&out[idx]) = out_vec;
}

Vectorized Loads Are Essential

On A100/H100, a single thread issuing 2-byte half loads achieves only ~40% of peak HBM bandwidth. Using float4 (16-byte) loads pushes this to >90%. The fused kernel must use vectorized loads to actually realize the theoretical bandwidth savings from fusion.
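
The instruction-count arithmetic behind that claim can be sketched as follows (the 32-byte sector size is the usual NVIDIA L2 granularity; the ~40%/90% figures above are empirical):

```python
def warp_load_stats(bytes_per_thread, total_bytes):
    """Per-warp memory instruction counts for a given access width."""
    warp_bytes = 32 * bytes_per_thread        # 32 threads per warp
    sectors = warp_bytes // 32                # 32-byte L2 sectors touched
    instructions = total_bytes // warp_bytes  # load instructions issued
    return warp_bytes, sectors, instructions

tensor_bytes = 2048 * 4096 * 2  # the FP16 tensor from earlier examples
for name, width in [("half (2B)", 2), ("half2 (4B)", 4), ("float4 (16B)", 16)]:
    wb, sec, ins = warp_load_stats(width, tensor_bytes)
    print(f"{name:13s}: {wb:4d} B per warp instruction, "
          f"{ins:,} load instructions")
```

With float4, each warp instruction moves 512 bytes instead of 64, so the kernel issues 8x fewer load instructions for the same data and keeps the load/store units from becoming the bottleneck.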

Pattern 2: Reduction Fusion

LayerNorm as a Single Kernel

LayerNorm computes mean, variance, normalize, scale, and shift — five logical operations. Unfused, this requires multiple kernel launches and multiple HBM round-trips:

// Unfused LayerNorm: multiple kernels
// Kernel 1: Compute mean (reduction over hidden dim)
// Kernel 2: Compute variance (reduction over hidden dim)
// Kernel 3: Normalize + scale + shift (elementwise)
// Total: 3 kernel launches, ~4 full-tensor HBM reads/writes (3 reads + 1 write)

// Fused norm: single kernel, one block per row
// (RMSNorm shown; LayerNorm adds a mean pass but has the same structure)
__global__ void fused_rmsnorm_kernel(
    float* __restrict__ out,
    const float* __restrict__ in,
    const float* __restrict__ weight,
    int hidden_dim,
    float eps
) {
    // Each block processes one row (one token)
    int row = blockIdx.x;
    const float* row_in = in + row * hidden_dim;
    float* row_out = out + row * hidden_dim;

    // Step 1: Compute sum of squares (warp-level reduction)
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        float val = row_in[i];
        sum_sq += val * val;
    }

    // Warp-level reduction
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        sum_sq += __shfl_xor_sync(0xFFFFFFFF, sum_sq, offset);
    }

    // Block-level reduction using shared memory
    __shared__ float warp_sums[32];
    int warp_id = threadIdx.x / warpSize;
    int lane = threadIdx.x % warpSize;

    if (lane == 0) {
        warp_sums[warp_id] = sum_sq;
    }
    __syncthreads();

    // First warp reduces across warps
    if (warp_id == 0) {
        sum_sq = (lane < blockDim.x / warpSize) ?
                  warp_sums[lane] : 0.0f;
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            sum_sq += __shfl_xor_sync(0xFFFFFFFF, sum_sq, offset);
        }
    }

    // Broadcast the RMS value
    __shared__ float rms_inv;
    if (threadIdx.x == 0) {
        rms_inv = rsqrtf(sum_sq / hidden_dim + eps);
    }
    __syncthreads();

    // Step 2: Normalize and apply weight
    // Second pass over the data (read from HBM or L2 cache)
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        row_out[i] = row_in[i] * rms_inv * weight[i];
    }
}
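
For reference, the per-row math the kernel implements, as a plain NumPy sketch (semantics only; the two passes map directly to the two loops above):

```python
import numpy as np

def rmsnorm_reference(x, weight, eps=1e-6):
    """x: [tokens, hidden_dim]. Pass 1: sum of squares; pass 2: scale."""
    sum_sq = (x.astype(np.float32) ** 2).sum(axis=-1, keepdims=True)
    rms_inv = 1.0 / np.sqrt(sum_sq / x.shape[-1] + eps)
    return x * rms_inv * weight

x = np.random.randn(4, 4096).astype(np.float32)
w = np.ones(4096, dtype=np.float32)
out = rmsnorm_reference(x, w)
print(out.shape)  # each output row has RMS ~= 1
```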

Why Reduction Fusion Is Harder

Unlike elementwise fusion, reductions have inter-element dependencies: computing the mean requires reading all elements. This forces a two-pass pattern:

  1. First pass: Read all elements, compute reduction (mean, variance, sum of squares)
  2. Second pass: Read all elements again, apply the normalized transformation

The fused kernel still reads the data twice from memory, but:

  • Both passes are in the same kernel (one launch instead of three)
  • On the second pass, the data may be in L2 cache (if the tensor fits)
  • The intermediate mean/variance are in shared memory/registers, not HBM

def reduction_fusion_traffic(hidden_dim=4096, batch_tokens=2048,
                              dtype_bytes=2):
    """Compare memory traffic for fused vs unfused LayerNorm."""
    tensor_bytes = batch_tokens * hidden_dim * dtype_bytes

    # Unfused: 3 kernel launches
    # K1 (mean): read input, write per-row means (tiny)
    # K2 (var): read input + means, write per-row vars (tiny)
    # K3 (norm): read input + means + vars + weight, write output
    unfused_traffic = 4 * tensor_bytes  # 3 full reads + 1 full write

    # Fused: 1 kernel, 2 passes over data
    # Pass 1: read input (may stay in L2)
    # Pass 2: read input (from L2 if fits), read weight, write output
    fused_traffic_cold = 2 * tensor_bytes + tensor_bytes  # 2 reads + 1 write
    fused_traffic_l2 = tensor_bytes + tensor_bytes  # 1 HBM read + 1 write (2nd from L2)

    print(f"Unfused traffic: {unfused_traffic/1e6:.1f} MB")
    print(f"Fused (cold):    {fused_traffic_cold/1e6:.1f} MB")
    print(f"Fused (L2 hit):  {fused_traffic_l2/1e6:.1f} MB")

reduction_fusion_traffic()

Pattern 3: GEMM Epilogue Fusion

Fusing Operations After Matrix Multiply

The most impactful fusion in transformer inference is fusing the bias, activation, and possibly residual addition into the GEMM epilogue. cuBLAS and CUTLASS support this natively:

// CUTLASS GEMM with fused epilogue
// Instead of: Y = A @ B; Y = Y + bias; Y = gelu(Y)
// Computes:   Y = gelu(A @ B + bias) in one kernel

#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/epilogue/thread/linear_combination_gelu.h>

// Define the GEMM with GELU epilogue
using GemmGelu = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t,                    // Element A
    cutlass::layout::RowMajor,          // Layout A
    cutlass::half_t,                    // Element B
    cutlass::layout::ColumnMajor,       // Layout B
    cutlass::half_t,                    // Element C/D
    cutlass::layout::RowMajor,          // Layout C/D
    float,                              // Accumulator
    cutlass::arch::OpClassTensorOp,     // Operator class
    cutlass::arch::Sm80,                // Architecture (Ampere)
    cutlass::gemm::GemmShape<128, 128, 32>,  // Tile shape
    cutlass::gemm::GemmShape<64, 64, 32>,    // Warp shape
    cutlass::gemm::GemmShape<16, 8, 16>,     // MMA shape
    // Epilogue with GELU activation
    cutlass::epilogue::thread::LinearCombinationGELU<
        cutlass::half_t,    // Output type
        8,                  // Elements per access
        float,              // Accumulator type
        float               // Compute type
    >
>;

Why GEMM Epilogue Fusion Is Especially Effective

The GEMM produces a large output matrix in register/shared memory tiles. Without epilogue fusion, these tiles are written to HBM and then read back by the next kernel (bias add). With epilogue fusion, the bias is added and the activation function is applied while the data is still in registers — zero additional HBM traffic.

def gemm_epilogue_savings(M=2048, N=4096, K=4096, dtype_bytes=2):
    """Quantify GEMM epilogue fusion savings."""
    output_bytes = M * N * dtype_bytes

    # Without epilogue fusion:
    # GEMM writes output to HBM: M*N*2 bytes
    # Bias add reads output + writes: 2 * M*N*2 bytes
    # GELU reads + writes: 2 * M*N*2 bytes
    unfused_extra = 4 * output_bytes  # 4 extra read/writes

    # With epilogue fusion:
    # GEMM applies bias+GELU in registers before writing
    # Extra traffic: 0 bytes (bias is tiny)
    fused_extra = 0

    hbm_bw = 3350e9  # H100 GB/s
    time_saved_us = unfused_extra / hbm_bw * 1e6

    print(f"Output tensor: {output_bytes/1e6:.1f} MB")
    print(f"Extra traffic without fusion: {unfused_extra/1e6:.1f} MB")
    print(f"Time saved per layer: {time_saved_us:.1f} us")
    print(f"Time saved per 80 layers: {time_saved_us * 80:.0f} us = "
          f"{time_saved_us * 80 / 1000:.1f} ms")

gemm_epilogue_savings()

GEMM Epilogue Fusion Impact (H100, M=2048, N=K=4096, FP16)

Configuration                                  GEMM Time (ms)  Post-ops Time (ms)  Total (ms)  Speedup
Unfused: GEMM + bias + GELU                        0.39            0.08               0.47      1.00x
Fused: GEMM w/ bias+GELU epilogue                  0.40            0.00               0.40      1.18x
Unfused: GEMM + bias + GELU + residual             0.39            0.12               0.51      1.00x
Fused: GEMM w/ bias+GELU+residual epilogue         0.41            0.00               0.41      1.24x
Note: The GEMM itself takes slightly longer with epilogue fusion (more register pressure), but eliminating the post-op kernels provides a net speedup.
torch.compile Generates Epilogue Fusions

PyTorch’s torch.compile with the inductor backend can automatically fuse elementwise operations after GEMMs into CUTLASS/Triton epilogues. For custom kernels, you must implement epilogue fusion manually using CUTLASS or write a Triton kernel that combines the GEMM and post-ops.
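
The reference semantics of the fused epilogue, independent of any particular library (a NumPy sketch using the same tanh GELU approximation as the kernels in this post, with FP32 accumulation as the epilogue performs):

```python
import numpy as np

def gelu_tanh(x):
    # Tanh-approximation GELU, matching the CUDA kernels above
    return 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

def gemm_gelu_reference(A, B, bias):
    """Y = gelu(A @ B + bias), accumulated in FP32 before the epilogue."""
    acc = A.astype(np.float32) @ B.astype(np.float32)
    return gelu_tanh(acc + bias.astype(np.float32))

A = np.random.randn(8, 16).astype(np.float16)
B = np.random.randn(16, 4).astype(np.float16)
bias = np.zeros(4, dtype=np.float16)
print(gemm_gelu_reference(A, B, bias).shape)  # (8, 4)
```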

Pattern 4: Attention Fusion (FlashAttention)

The Most Complex Fusion

FlashAttention fuses the entire attention computation — QK^T scoring, softmax, dropout, and the score × V product — into a single kernel. This is not just elementwise fusion; it requires the online softmax algorithm to avoid materializing the full N × N attention matrix.
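
The online softmax trick can be sketched for a single row (NumPy; `block` is an illustrative tile width, and for the demo we re-read the scores at the end, whereas the real kernel instead rescales its output accumulator):

```python
import numpy as np

def online_softmax_row(scores, block=4):
    """Blockwise softmax: running max m and running sum s are rescaled
    as each new block of scores arrives, so the full row never needs to
    be reduced in one shot."""
    m, s = -np.inf, 0.0
    for i in range(0, len(scores), block):
        blk = scores[i:i + block]
        m_new = max(m, float(blk.max()))
        # Rescale the old sum to the new max before adding this block
        s = s * np.exp(m - m_new) + float(np.exp(blk - m_new).sum())
        m = m_new
    return np.exp(scores - m) / s

scores = np.random.randn(16)
ref = np.exp(scores - scores.max())
ref /= ref.sum()
print(np.allclose(online_softmax_row(scores), ref))  # True
```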

def attention_memory_analysis(batch=1, heads=32, seq_len=4096,
                               head_dim=128, dtype_bytes=2):
    """Compare memory for standard vs fused attention."""
    # Standard attention:
    # 1. S = Q @ K^T -> [B, H, N, N] attention scores
    # 2. P = softmax(S) -> [B, H, N, N]
    # 3. O = P @ V -> [B, H, N, D]

    attention_matrix_bytes = batch * heads * seq_len * seq_len * dtype_bytes
    qkv_bytes = batch * heads * seq_len * head_dim * dtype_bytes * 3
    output_bytes = batch * heads * seq_len * head_dim * dtype_bytes

    standard_peak = (qkv_bytes + 2 * attention_matrix_bytes +
                     output_bytes)

    # FlashAttention: never materializes NxN matrix
    # Processes in blocks, keeping running softmax statistics in SRAM
    flash_peak = qkv_bytes + output_bytes  # No NxN matrix

    print(f"Sequence length: {seq_len}")
    print(f"Attention matrix: {attention_matrix_bytes/1e9:.2f} GB")
    print(f"Standard peak memory: {standard_peak/1e9:.2f} GB")
    print(f"FlashAttention peak:  {flash_peak/1e9:.4f} GB")
    print(f"Memory reduction: {standard_peak/flash_peak:.0f}x")

attention_memory_analysis(seq_len=4096)
print()
attention_memory_analysis(seq_len=32768)
print()
attention_memory_analysis(seq_len=131072)

FlashAttention Tiling Strategy

def flash_attention_tiling(seq_len=4096, head_dim=128, block_size=256,
                            sram_bytes=192*1024):
    """Analyze FlashAttention tiling parameters.

    FlashAttention processes the attention computation in blocks:
    - Outer loop: iterate over blocks of Q (rows of output)
    - Inner loop: iterate over blocks of K,V (columns of attention)
    - For each (Q_block, K_block): compute partial attention scores
      in shared memory, update running softmax statistics
    """
    # SRAM budget per block:
    # Q block: block_size * head_dim * 2 bytes
    # K block: block_size * head_dim * 2 bytes
    # V block: block_size * head_dim * 2 bytes
    # Output accumulator: block_size * head_dim * 4 bytes (FP32)
    # Softmax statistics: block_size * 4 * 2 (max and sum per row)

    q_sram = block_size * head_dim * 2
    k_sram = block_size * head_dim * 2
    v_sram = block_size * head_dim * 2
    out_sram = block_size * head_dim * 4  # FP32 accumulator
    stats_sram = block_size * 4 * 2  # max + sum per row

    total_sram = q_sram + k_sram + v_sram + out_sram + stats_sram

    num_q_blocks = (seq_len + block_size - 1) // block_size
    num_kv_blocks = (seq_len + block_size - 1) // block_size

    # HBM traffic
    # Q: read once per outer iteration
    # K,V: read num_q_blocks times (once per outer iteration)
    # Output: write once
    q_reads = seq_len * head_dim * 2
    kv_reads = 2 * seq_len * head_dim * 2 * num_q_blocks
    output_writes = seq_len * head_dim * 2

    total_hbm = q_reads + kv_reads + output_writes

    # Standard attention HBM traffic
    standard_hbm = (3 * seq_len * head_dim * 2 +  # Q, K, V reads
                    2 * seq_len * seq_len * 2 +      # S write + P read
                    seq_len * head_dim * 2)           # Output

    print(f"Block size: {block_size}")
    print(f"SRAM per block: {total_sram/1024:.1f} KB "
          f"(limit: {sram_bytes/1024:.0f} KB)")
    print(f"Q blocks x KV blocks: {num_q_blocks} x {num_kv_blocks}")
    print(f"FlashAttention HBM traffic: {total_hbm/1e6:.1f} MB")
    print(f"Standard attention HBM:     {standard_hbm/1e6:.1f} MB")
    print(f"Traffic reduction: {standard_hbm/total_hbm:.1f}x")

flash_attention_tiling()

Attention HBM Traffic: Standard vs FlashAttention

Sequence length   Standard (MB)   FlashAttention (MB)   Reduction
4K                        134                  18           7.4x
16K                     2,048                  70            29x
128K                  131,072                 544           240x

Implementation: Fused Bias + GELU Kernel

Complete CUDA Implementation

#include <cuda_fp16.h>
#include <cuda_runtime.h>

// GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
__device__ __forceinline__ float gelu_forward(float x) {
    const float kSqrt2OverPi = 0.7978845608028654f;
    const float kCoeff = 0.044715f;
    float cube = x * x * x;
    float inner = kSqrt2OverPi * (x + kCoeff * cube);
    return 0.5f * x * (1.0f + tanhf(inner));
}

// FP16 vectorized fused bias+GELU
// Processes 8 FP16 elements per thread (128 bits = float4 load)
__global__ void fused_bias_gelu_half(
    half* __restrict__ output,
    const half* __restrict__ input,
    const half* __restrict__ bias,
    const int hidden_dim,
    const int total_elements
) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int vec_idx = tid * 8;  // 8 elements per thread

    if (vec_idx + 7 >= total_elements) {
        // Scalar fallback for tail elements
        for (int i = vec_idx; i < total_elements && i < vec_idx + 8; i++) {
            float val = __half2float(input[i]);
            val += __half2float(bias[i % hidden_dim]);
            val = gelu_forward(val);
            output[i] = __float2half(val);
        }
        return;
    }

    // Vectorized load: 16 bytes
    float4 in_vec = *reinterpret_cast<const float4*>(&input[vec_idx]);
    half* in_half = reinterpret_cast<half*>(&in_vec);

    // Load bias (assumes hidden_dim is large enough for aligned access)
    int bias_offset = vec_idx % hidden_dim;
    float4 bias_vec;
    if (bias_offset + 7 < hidden_dim) {
        bias_vec = *reinterpret_cast<const float4*>(&bias[bias_offset]);
    } else {
        __align__(16) half bias_buf[8];  // 16-byte aligned for the float4 cast
        for (int i = 0; i < 8; i++) {
            bias_buf[i] = bias[(vec_idx + i) % hidden_dim];
        }
        bias_vec = *reinterpret_cast<float4*>(bias_buf);
    }
    half* bias_half = reinterpret_cast<half*>(&bias_vec);

    // Compute: bias_add + GELU
    float4 out_vec;
    half* out_half = reinterpret_cast<half*>(&out_vec);

    #pragma unroll
    for (int i = 0; i < 8; i++) {
        float val = __half2float(in_half[i]) + __half2float(bias_half[i]);
        val = gelu_forward(val);
        out_half[i] = __float2half(val);
    }

    // Vectorized store: 16 bytes
    *reinterpret_cast<float4*>(&output[vec_idx]) = out_vec;
}

// Launch configuration
void launch_fused_bias_gelu(
    half* output, const half* input, const half* bias,
    int batch_tokens, int hidden_dim, cudaStream_t stream
) {
    int total = batch_tokens * hidden_dim;
    int threads_needed = (total + 7) / 8;  // 8 elements per thread
    int block_size = 256;
    int grid_size = (threads_needed + block_size - 1) / block_size;

    fused_bias_gelu_half<<<grid_size, block_size, 0, stream>>>(
        output, input, bias, hidden_dim, total
    );
}

Benchmark the Fused Kernel

def benchmark_fused_vs_unfused(M=2048, N=4096, num_iters=1000):
    """Benchmark fused bias+GELU vs separate operations."""
    device = 'cuda'
    x = torch.randn(M, N, device=device, dtype=torch.float16)
    bias = torch.randn(N, device=device, dtype=torch.float16)

    # Unfused
    def unfused(x, bias):
        y = x + bias.unsqueeze(0)
        y = torch.nn.functional.gelu(y, approximate='tanh')
        return y

    # torch.compile fused
    fused = torch.compile(unfused)

    # Warmup
    for _ in range(50):
        unfused(x, bias)
        fused(x, bias)

    # Benchmark
    for name, fn in [('Unfused', unfused), ('Fused', fused)]:
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        for _ in range(num_iters):
            fn(x, bias)
        end.record()

        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end) / num_iters
        bandwidth = (M * N * 2 * 2) / (elapsed / 1000) / 1e9

        print(f"{name:10s}: {elapsed:.4f} ms, "
              f"effective bandwidth: {bandwidth:.0f} GB/s")

Fused vs Unfused Bias+GELU (H100, FP16)

Shape            Unfused (us)   Fused (us)   Speedup   Fused BW (GB/s)
[512, 4096]          15.2           6.8        2.24x        1190
[2048, 4096]         36.4          15.1        2.41x        2180
[2048, 8192]         68.1          28.3        2.41x        2370
[2048, 11008]        89.2          37.5        2.38x        2410
[1, 4096]             8.1           4.2        1.93x         3.9
Note: Fusion provides 2-2.4x speedup for large tensors. For tiny tensors (decode, M=1), the speedup is lower because kernel launch overhead dominates.
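
A simple latency model explains the [1, 4096] row. The ~4 us per-launch overhead used here is an assumed ballpark, not a measured constant:

```python
def elementwise_latency_us(M, N, num_kernels, launch_us=4.0,
                           dtype_bytes=2, hbm_gbps=3350):
    """Each kernel pays launch overhead plus one read + one write
    of the full tensor at HBM bandwidth."""
    traffic_us = num_kernels * 2 * M * N * dtype_bytes / (hbm_gbps * 1e3)
    return num_kernels * launch_us + traffic_us

for M in (1, 2048):
    unfused = elementwise_latency_us(M, 4096, num_kernels=2)
    fused = elementwise_latency_us(M, 4096, num_kernels=1)
    print(f"M={M:5d}: unfused={unfused:.1f} us, fused={fused:.1f} us, "
          f"speedup={unfused/fused:.2f}x")
```

For M=1, traffic is negligible and the model reduces to the launch-overhead ratio (2 launches vs 1); for M=2048, traffic dominates and fusion wins by roughly the traffic reduction.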

Fusion Opportunities in a Transformer Layer

Mapping All Fusible Operations

def transformer_layer_fusion_map():
    """Identify all fusion opportunities in a decoder layer."""
    operations = [
        # Pre-attention norm
        ("RMSNorm", "fused: reduction + elementwise scale"),

        # QKV projection
        ("GEMM (QKV)", "fused epilogue: + bias"),
        ("RoPE", "standalone (trigonometric, cannot fuse with GEMM)"),

        # Attention
        ("QK^T + scale + mask + softmax + dropout + V@",
         "FlashAttention: all fused into one kernel"),

        # Output projection
        ("GEMM (O_proj)", "fused epilogue: + residual add"),

        # Post-attention norm
        ("RMSNorm", "fused: reduction + elementwise scale"),

        # MLP
        ("GEMM (gate_proj)", "standalone"),
        ("GEMM (up_proj)", "standalone"),
        ("SiLU(gate) * up", "fused elementwise: SiLU + multiply"),
        ("GEMM (down_proj)", "fused epilogue: + residual add"),
    ]

    print("=== Transformer Layer Fusion Map ===")
    for op, fusion_status in operations:
        print(f"  {op:45s} -> {fusion_status}")

transformer_layer_fusion_map()

Summary

Kernel fusion eliminates HBM round-trips between operations that would otherwise each launch a separate kernel. The four patterns cover the entire space: elementwise fusion combines independent per-element operations (bias+GELU+dropout), reduction fusion combines dependent operations that require collective computation (LayerNorm), GEMM epilogue fusion applies post-GEMM operations while output tiles are still in registers, and attention fusion (FlashAttention) eliminates the O(N^2) attention matrix entirely.

The implementation hierarchy: torch.compile handles elementwise fusion automatically, CUTLASS handles GEMM epilogue fusion, FlashAttention handles attention fusion, and custom CUDA kernels handle anything that does not fit these patterns. For LLM inference, the combined effect of all fusion patterns reduces per-layer HBM traffic by 2-4x and total latency by 1.5-2x.