A transformer MLP layer with separate kernels for GEMM, bias, GELU, and dropout executes in 340 microseconds on an H100. Fuse bias+GELU into the GEMM epilogue and time drops to 280 microseconds. Fuse dropout into the same kernel and time drops to 240 microseconds — a 30% speedup from eliminating two kernel launches and two HBM round-trips. The arithmetic is unchanged. The difference is memory traffic: the unfused version writes 32 MB of intermediate results to HBM after bias and 32 MB after GELU, then reads them back. The fused version writes once. HBM traffic drops from 128 MB to 64 MB.
Kernel fusion eliminates kernel launch overhead and memory round-trips by combining multiple operations into a single kernel. The fused kernel reads the input once, applies all operations in registers, and writes the final output once. This post covers four fusion patterns that appear everywhere in LLM inference: elementwise fusion, reduction fusion, GEMM epilogue fusion, and attention fusion.
Why Fusion Matters: The Bandwidth Wall
Memory Traffic Dominates Kernel Time
For elementwise operations (add, multiply, activation functions, dropout), the arithmetic intensity is O(1) — one or two FLOPs per element loaded. The roofline model shows these operations are deep in the memory-bandwidth-bound regime:
```python
import torch
import time

def measure_kernel_overhead():
    """Measure the cost of separate vs fused operations."""
    device = 'cuda'
    M, N = 2048, 4096
    dtype = torch.float16
    x = torch.randn(M, N, device=device, dtype=dtype)
    bias = torch.randn(N, device=device, dtype=dtype)
    dropout_mask = torch.bernoulli(torch.full((M, N), 0.9,
                                              device=device)).to(dtype)

    # Separate operations
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(1000):
        y = x + bias                     # Kernel 1: bias add
        y = torch.nn.functional.gelu(y)  # Kernel 2: GELU
        y = y * dropout_mask             # Kernel 3: dropout
    torch.cuda.synchronize()
    t_separate = (time.perf_counter() - start) / 1000

    # Fused (using torch.compile)
    @torch.compile
    def fused_bias_gelu_dropout(x, bias, mask):
        y = x + bias
        y = torch.nn.functional.gelu(y)
        y = y * mask
        return y

    # Warmup (triggers compilation)
    for _ in range(10):
        fused_bias_gelu_dropout(x, bias, dropout_mask)
    torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(1000):
        y = fused_bias_gelu_dropout(x, bias, dropout_mask)
    torch.cuda.synchronize()
    t_fused = (time.perf_counter() - start) / 1000

    print(f"Separate kernels: {t_separate*1e6:.1f} us")
    print(f"Fused kernel:     {t_fused*1e6:.1f} us")
    print(f"Speedup:          {t_separate/t_fused:.2f}x")

measure_kernel_overhead()
```
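The same point can be made analytically with the roofline model. A quick sketch (the ~989 TFLOP/s dense FP16 and ~3.35 TB/s figures are H100 SXM spec-sheet numbers, taken here as assumptions):

```python
def roofline_regime(flops_per_byte, peak_tflops=989.0, peak_bw_tbs=3.35):
    """Classify an op by comparing its arithmetic intensity (FLOPs per
    byte of HBM traffic) to the machine balance point."""
    balance = peak_tflops * 1e12 / (peak_bw_tbs * 1e12)  # FLOPs/byte
    return "compute-bound" if flops_per_byte >= balance else "memory-bound"

# Bias+GELU+dropout: roughly 10 FLOPs per element against 4 bytes of
# traffic (FP16 read + write) -> intensity ~2.5, far below the H100
# balance point of ~295 FLOPs/byte
print(roofline_regime(2.5))    # memory-bound
print(roofline_regime(500.0))  # compute-bound (large GEMMs live here)
```

Any elementwise op sits two orders of magnitude below the balance point, which is why only memory traffic, not FLOPs, determines its runtime.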
Quantifying Memory Traffic
```python
def memory_traffic_analysis(M=2048, N=4096, dtype_bytes=2):
    """Calculate memory traffic for separate vs fused operations."""
    tensor_bytes = M * N * dtype_bytes

    # Separate: bias_add + gelu + dropout
    # Each op: read input + write output = 2 * tensor_bytes
    # bias_add also reads bias (N * dtype_bytes, negligible)
    separate_traffic = 3 * 2 * tensor_bytes  # 3 ops, each read+write

    # Fused: read input once, write output once
    fused_traffic = 2 * tensor_bytes

    print(f"Tensor size: {tensor_bytes / 1e6:.1f} MB")
    print(f"Separate traffic: {separate_traffic / 1e6:.1f} MB "
          f"(3 reads + 3 writes)")
    print(f"Fused traffic: {fused_traffic / 1e6:.1f} MB "
          f"(1 read + 1 write)")
    print(f"Traffic reduction: {separate_traffic / fused_traffic:.1f}x")

    # Time estimate on H100 (3350 GB/s)
    hbm_bw = 3350e9  # bytes/sec
    t_separate = separate_traffic / hbm_bw * 1e6  # microseconds
    t_fused = fused_traffic / hbm_bw * 1e6
    print(f"Estimated time (H100): separate={t_separate:.1f} us, "
          f"fused={t_fused:.1f} us")

memory_traffic_analysis()
```
Pattern 1: Elementwise Fusion
The Simplest Fusion
Elementwise operations apply a function independently to each element (or corresponding elements from multiple tensors). They can always be fused because there are no inter-element dependencies.
```cuda
// Unfused: three separate kernels
__global__ void bias_add_kernel(float* out, const float* in,
                                const float* bias, int N, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        out[idx] = in[idx] + bias[idx % N];
    }
}

__global__ void gelu_kernel(float* out, const float* in, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        float x = in[idx];
        // GELU approximation
        out[idx] = 0.5f * x * (1.0f + tanhf(0.7978845608f *
                   (x + 0.044715f * x * x * x)));
    }
}

__global__ void dropout_kernel(float* out, const float* in,
                               const float* mask, float scale, int total) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        out[idx] = in[idx] * mask[idx] * scale;
    }
}

// Fused: single kernel
__global__ void fused_bias_gelu_dropout_kernel(
    float* out,
    const float* in,
    const float* bias,
    const float* dropout_mask,
    float dropout_scale,
    int N,
    int total
) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < total) {
        // Load input (one HBM read)
        float x = in[idx];
        // Bias add (bias is small, likely in L2/L1)
        x += bias[idx % N];
        // GELU (compute in registers)
        x = 0.5f * x * (1.0f + tanhf(0.7978845608f *
            (x + 0.044715f * x * x * x)));
        // Dropout (mask load + multiply)
        x *= dropout_mask[idx] * dropout_scale;
        // Store output (one HBM write)
        out[idx] = x;
    }
}
```
Vectorized Fused Kernel
For half-precision (FP16), use vectorized loads/stores with half2 or float4 to maximize memory bandwidth utilization:
```cuda
#include <cuda_fp16.h>

__global__ void fused_bias_gelu_dropout_fp16(
    half* __restrict__ out,
    const half* __restrict__ in,
    const half* __restrict__ bias,
    const uint8_t* __restrict__ dropout_mask,  // 1 bit per element
    half dropout_scale,
    int N,
    int total
) {
    // Process 8 elements per thread (one 16-byte load)
    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 8;
    // Assumes total is a multiple of 8; see the full implementation
    // later in this post for a scalar tail fallback
    if (idx + 7 >= total) return;

    // Vectorized load: read 8 halfs = 16 bytes
    float4 in_vec = *reinterpret_cast<const float4*>(&in[idx]);
    half2* in_h2 = reinterpret_cast<half2*>(&in_vec);

    // Load bias (8 elements)
    int bias_start = idx % N;
    float4 bias_vec;
    if (bias_start + 7 < N) {
        bias_vec = *reinterpret_cast<const float4*>(&bias[bias_start]);
    } else {
        // Handle wrap-around (rare for typical dimensions)
        alignas(16) half bias_buf[8];
        for (int i = 0; i < 8; i++) {
            bias_buf[i] = bias[(idx + i) % N];
        }
        bias_vec = *reinterpret_cast<float4*>(bias_buf);
    }
    half2* bias_h2 = reinterpret_cast<half2*>(&bias_vec);

    // Load dropout mask: one byte carries the keep/drop bits for
    // this thread's 8 elements
    uint8_t mask_byte = dropout_mask[idx / 8];

    // Process 4 half2 pairs
    float4 out_vec;
    half2* out_h2 = reinterpret_cast<half2*>(&out_vec);
    const half zero = __float2half(0.0f);
    #pragma unroll
    for (int i = 0; i < 4; i++) {
        // Bias add
        half2 val = __hadd2(in_h2[i], bias_h2[i]);
        // GELU approximation (compute in FP32 for accuracy)
        float2 f = __half22float2(val);
        f.x = 0.5f * f.x * (1.0f + tanhf(0.7978845608f *
              (f.x + 0.044715f * f.x * f.x * f.x)));
        f.y = 0.5f * f.y * (1.0f + tanhf(0.7978845608f *
              (f.y + 0.044715f * f.y * f.y * f.y)));
        val = __float22half2_rn(f);
        // Dropout: bit j of mask_byte keeps (1) or drops (0) element j
        half2 keep = __halves2half2(
            ((mask_byte >> (2 * i))     & 1) ? dropout_scale : zero,
            ((mask_byte >> (2 * i + 1)) & 1) ? dropout_scale : zero);
        val = __hmul2(val, keep);
        out_h2[i] = val;
    }

    // Vectorized store: write 8 halfs = 16 bytes
    *reinterpret_cast<float4*>(&out[idx]) = out_vec;
}
```
On A100/H100, a single thread issuing 2-byte half loads achieves only ~40% of peak HBM bandwidth. Using float4 (16-byte) loads pushes this to >90%. The fused kernel must use vectorized loads to actually realize the theoretical bandwidth savings from fusion.
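The intuition behind those percentages is Little's law: to keep HBM saturated, bandwidth times memory latency worth of bytes must be in flight at all times. A rough sketch (the ~3.35 TB/s bandwidth and ~600 ns HBM latency are assumed round numbers, not measured values):

```python
def little_law_loads(bytes_per_load, bw_bytes_per_s=3.35e12,
                     latency_s=600e-9):
    """Little's law: outstanding bytes = bandwidth x latency.
    Returns how many loads of a given width must be in flight
    simultaneously to keep HBM saturated."""
    bytes_in_flight = bw_bytes_per_s * latency_s
    return int(bytes_in_flight / bytes_per_load)

# With 2-byte scalar half loads, roughly a million requests must be
# outstanding; 16-byte float4 loads cut the required concurrency by 8x
print(little_law_loads(bytes_per_load=2))
print(little_law_loads(bytes_per_load=16))
```

Wider loads mean each thread covers more of the required bytes-in-flight, so the GPU reaches peak bandwidth with far fewer concurrent memory requests.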
Pattern 2: Reduction Fusion
LayerNorm as a Single Kernel
LayerNorm computes mean, variance, normalize, scale, and shift — five logical operations. Unfused, this requires multiple kernel launches and multiple HBM round-trips:
```cuda
// Unfused LayerNorm: multiple kernels
// Kernel 1: Compute mean (reduction over hidden dim)
// Kernel 2: Compute variance (reduction over hidden dim)
// Kernel 3: Normalize + scale + shift (elementwise)
// Total: 3 kernel launches, ~5 HBM read/writes of the full tensor

// Fused norm: single kernel per row. Shown here for RMSNorm, the
// LayerNorm variant most modern LLMs use (it skips the mean
// subtraction, so only one reduction is needed).
__global__ void fused_rmsnorm_kernel(
    float* __restrict__ out,
    const float* __restrict__ in,
    const float* __restrict__ weight,
    int hidden_dim,
    float eps
) {
    // Each block processes one row (one token)
    int row = blockIdx.x;
    const float* row_in = in + row * hidden_dim;
    float* row_out = out + row * hidden_dim;

    // Step 1: Compute sum of squares (warp-level reduction)
    float sum_sq = 0.0f;
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        float val = row_in[i];
        sum_sq += val * val;
    }
    // Warp-level reduction
    for (int offset = warpSize / 2; offset > 0; offset /= 2) {
        sum_sq += __shfl_xor_sync(0xFFFFFFFF, sum_sq, offset);
    }

    // Block-level reduction using shared memory
    __shared__ float warp_sums[32];
    int warp_id = threadIdx.x / warpSize;
    int lane = threadIdx.x % warpSize;
    if (lane == 0) {
        warp_sums[warp_id] = sum_sq;
    }
    __syncthreads();

    // First warp reduces across warps
    if (warp_id == 0) {
        sum_sq = (lane < blockDim.x / warpSize) ?
                 warp_sums[lane] : 0.0f;
        for (int offset = warpSize / 2; offset > 0; offset /= 2) {
            sum_sq += __shfl_xor_sync(0xFFFFFFFF, sum_sq, offset);
        }
    }

    // Broadcast the RMS value
    __shared__ float rms_inv;
    if (threadIdx.x == 0) {
        rms_inv = rsqrtf(sum_sq / hidden_dim + eps);
    }
    __syncthreads();

    // Step 2: Normalize and apply weight
    // Second pass over the data (read from HBM or L2 cache)
    for (int i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
        row_out[i] = row_in[i] * rms_inv * weight[i];
    }
}
```
Why Reduction Fusion Is Harder
Unlike elementwise fusion, reductions have inter-element dependencies: computing the mean requires reading all elements. This forces a two-pass pattern:
- First pass: Read all elements, compute reduction (mean, variance, sum of squares)
- Second pass: Read all elements again, apply the normalized transformation
The fused kernel still reads the data twice from memory, but:
- Both passes are in the same kernel (one launch instead of three)
- On the second pass, the data may be in L2 cache (if the tensor fits)
- The intermediate mean/variance are in shared memory/registers, not HBM
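For reference, the per-row computation the fused RMSNorm kernel performs can be sketched in plain Python, making the two-pass structure explicit:

```python
import math

def rmsnorm_row(row, weight, eps=1e-6):
    """Two-pass RMSNorm, mirroring the fused kernel: pass 1 reduces
    the sum of squares; pass 2 normalizes and applies the weight."""
    sum_sq = sum(v * v for v in row)                       # pass 1: reduction
    rms_inv = 1.0 / math.sqrt(sum_sq / len(row) + eps)
    return [v * rms_inv * w for v, w in zip(row, weight)]  # pass 2: normalize

print(rmsnorm_row([3.0, 4.0], [1.0, 1.0], eps=0.0))
```

Only `sum_sq` and `rms_inv` bridge the two passes, which is exactly why the kernel can hold them in shared memory instead of writing them to HBM.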
```python
def reduction_fusion_traffic(hidden_dim=4096, batch_tokens=2048,
                             dtype_bytes=2):
    """Compare memory traffic for fused vs unfused LayerNorm."""
    tensor_bytes = batch_tokens * hidden_dim * dtype_bytes

    # Unfused: 3 kernel launches
    # K1 (mean): read input, write per-row means
    # K2 (var):  read input + means, write per-row vars
    # K3 (norm): read input + means + vars + weight, write output
    # Per-row stats are negligible: ~3 full reads + 1 full write
    unfused_traffic = 3 * tensor_bytes + tensor_bytes

    # Fused: 1 kernel, 2 passes over data
    # Pass 1: read input (may stay in L2)
    # Pass 2: read input (from L2 if it fits), read weight, write output
    fused_traffic_cold = 2 * tensor_bytes + tensor_bytes  # 2 reads + 1 write
    fused_traffic_l2 = tensor_bytes + tensor_bytes  # 1 HBM read + 1 write

    print(f"Unfused traffic: {unfused_traffic/1e6:.1f} MB")
    print(f"Fused (cold):    {fused_traffic_cold/1e6:.1f} MB")
    print(f"Fused (L2 hit):  {fused_traffic_l2/1e6:.1f} MB")

reduction_fusion_traffic()
```
Pattern 3: GEMM Epilogue Fusion
Fusing Operations After Matrix Multiply
The most impactful fusion in transformer inference is fusing the bias, activation, and possibly residual addition into the GEMM epilogue. cuBLAS and CUTLASS support this natively:
```cuda
// CUTLASS GEMM with fused epilogue
// Instead of: Y = A @ B; Y = Y + bias; Y = gelu(Y)
// Computes:   Y = gelu(A @ B + bias) in one kernel
#include <cutlass/gemm/device/gemm_universal.h>
#include <cutlass/epilogue/thread/linear_combination_gelu.h>

// Define the GEMM with GELU epilogue
using GemmGelu = cutlass::gemm::device::GemmUniversal<
    cutlass::half_t,                         // Element A
    cutlass::layout::RowMajor,               // Layout A
    cutlass::half_t,                         // Element B
    cutlass::layout::ColumnMajor,            // Layout B
    cutlass::half_t,                         // Element C/D
    cutlass::layout::RowMajor,               // Layout C/D
    float,                                   // Accumulator
    cutlass::arch::OpClassTensorOp,          // Operator class
    cutlass::arch::Sm80,                     // Architecture (Ampere)
    cutlass::gemm::GemmShape<128, 128, 32>,  // Threadblock tile shape
    cutlass::gemm::GemmShape<64, 64, 32>,    // Warp tile shape
    cutlass::gemm::GemmShape<16, 8, 16>,     // MMA instruction shape
    // Epilogue with GELU activation
    cutlass::epilogue::thread::LinearCombinationGELU<
        cutlass::half_t,                     // Output type
        8,                                   // Elements per access
        float,                               // Accumulator type
        float                                // Compute type
    >
>;
```
Why GEMM Epilogue Fusion Is Especially Effective
The GEMM produces a large output matrix in register/shared memory tiles. Without epilogue fusion, these tiles are written to HBM and then read back by the next kernel (bias add). With epilogue fusion, the bias is added and the activation function is applied while the data is still in registers — zero additional HBM traffic.
```python
def gemm_epilogue_savings(M=2048, N=4096, K=4096, dtype_bytes=2):
    """Quantify GEMM epilogue fusion savings."""
    output_bytes = M * N * dtype_bytes

    # Without epilogue fusion:
    # GEMM writes output to HBM: M*N*2 bytes
    # Bias add reads output + writes: 2 * M*N*2 bytes
    # GELU reads + writes: 2 * M*N*2 bytes
    unfused_extra = 4 * output_bytes  # 4 extra read/writes

    # With epilogue fusion:
    # GEMM applies bias+GELU in registers before writing
    # Extra traffic: 0 bytes (bias is tiny)
    fused_extra = 0

    hbm_bw = 3350e9  # H100 HBM bandwidth, bytes/sec
    time_saved_us = unfused_extra / hbm_bw * 1e6
    print(f"Output tensor: {output_bytes/1e6:.1f} MB")
    print(f"Extra traffic without fusion: {unfused_extra/1e6:.1f} MB")
    print(f"Time saved per layer: {time_saved_us:.1f} us")
    print(f"Time saved per 80 layers: {time_saved_us * 80:.0f} us = "
          f"{time_saved_us * 80 / 1000:.1f} ms")

gemm_epilogue_savings()
```
GEMM Epilogue Fusion Impact (H100, M=2048, N=K=4096, FP16)
| Configuration | GEMM Time (ms) | Post-ops Time (ms) | Total (ms) | Speedup |
|---|---|---|---|---|
| Unfused: GEMM + bias + GELU | 0.39 | 0.08 | 0.47 | 1.00x |
| Fused: GEMM w/ bias+GELU epilogue | 0.40 | 0.00 | 0.40 | 1.18x |
| Unfused: GEMM + bias + GELU + residual | 0.39 | 0.12 | 0.51 | 1.00x |
| Fused: GEMM w/ bias+GELU+residual epilogue | 0.41 | 0.00 | 0.41 | 1.24x |
PyTorch’s torch.compile with the inductor backend can automatically fuse elementwise operations after GEMMs into CUTLASS/Triton epilogues. For custom kernels, you must implement epilogue fusion manually using CUTLASS or write a Triton kernel that combines the GEMM and post-ops.
Pattern 4: Attention Fusion (FlashAttention)
The Most Complex Fusion
FlashAttention fuses the entire attention computation — scoring, softmax, dropout, and the weighted sum with V — into a single kernel. This is not just elementwise fusion; it requires the online softmax algorithm to avoid materializing the full attention matrix.
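The enabling trick is the online softmax recurrence: keep a running max and a running denominator, and rescale the denominator whenever the max grows. A minimal scalar sketch (FlashAttention applies the same update blockwise):

```python
import math

def online_softmax(scores):
    """One-pass softmax: maintain a running max m and running
    denominator d, rescaling d whenever the max increases. This is
    what lets FlashAttention consume scores block by block without
    ever storing the full row."""
    m, d = float('-inf'), 0.0
    for s in scores:
        m_new = max(m, s)
        d = d * math.exp(m - m_new) + math.exp(s - m_new)
        m = m_new
    return [math.exp(s - m) / d for s in scores]

print(online_softmax([1.0, 3.0, 2.0]))
```

The result is identical to a standard two-pass softmax; the only state carried between blocks is (m, d) plus the partially accumulated output, all of which fits in registers and SRAM.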
```python
def attention_memory_analysis(batch=1, heads=32, seq_len=4096,
                              head_dim=128, dtype_bytes=2):
    """Compare memory for standard vs fused attention."""
    # Standard attention:
    # 1. S = Q @ K^T    -> [B, H, N, N] attention scores
    # 2. P = softmax(S) -> [B, H, N, N]
    # 3. O = P @ V      -> [B, H, N, D]
    attention_matrix_bytes = batch * heads * seq_len * seq_len * dtype_bytes
    qkv_bytes = batch * heads * seq_len * head_dim * dtype_bytes * 3
    output_bytes = batch * heads * seq_len * head_dim * dtype_bytes
    standard_peak = (qkv_bytes + 2 * attention_matrix_bytes +
                     output_bytes)

    # FlashAttention: never materializes the NxN matrix.
    # Processes in blocks, keeping running softmax statistics in SRAM.
    flash_peak = qkv_bytes + output_bytes  # No NxN matrix

    print(f"Sequence length: {seq_len}")
    print(f"Attention matrix: {attention_matrix_bytes/1e9:.2f} GB")
    print(f"Standard peak memory: {standard_peak/1e9:.2f} GB")
    print(f"FlashAttention peak: {flash_peak/1e9:.4f} GB")
    print(f"Memory reduction: {standard_peak/flash_peak:.0f}x")

attention_memory_analysis(seq_len=4096)
print()
attention_memory_analysis(seq_len=32768)
print()
attention_memory_analysis(seq_len=131072)
```
FlashAttention Tiling Strategy
```python
def flash_attention_tiling(seq_len=4096, head_dim=128, block_size=256,
                           sram_bytes=192*1024):
    """Analyze FlashAttention tiling parameters.

    FlashAttention processes the attention computation in blocks:
    - Outer loop: iterate over blocks of Q (rows of output)
    - Inner loop: iterate over blocks of K,V (columns of attention)
    - For each (Q_block, K_block): compute partial attention scores
      in shared memory, update running softmax statistics
    """
    # SRAM budget per block:
    # Q block: block_size * head_dim * 2 bytes
    # K block: block_size * head_dim * 2 bytes
    # V block: block_size * head_dim * 2 bytes
    # Output accumulator: block_size * head_dim * 4 bytes (FP32)
    # Softmax statistics: block_size * 4 * 2 (max and sum per row, FP32)
    q_sram = block_size * head_dim * 2
    k_sram = block_size * head_dim * 2
    v_sram = block_size * head_dim * 2
    out_sram = block_size * head_dim * 4  # FP32 accumulator
    stats_sram = block_size * 4 * 2       # max + sum per row
    total_sram = q_sram + k_sram + v_sram + out_sram + stats_sram

    num_q_blocks = (seq_len + block_size - 1) // block_size
    num_kv_blocks = (seq_len + block_size - 1) // block_size

    # HBM traffic
    # Q: each block read once
    # K,V: read num_q_blocks times (once per outer iteration)
    # Output: write once
    q_reads = seq_len * head_dim * 2
    kv_reads = 2 * seq_len * head_dim * 2 * num_q_blocks
    output_writes = seq_len * head_dim * 2
    total_hbm = q_reads + kv_reads + output_writes

    # Standard attention HBM traffic
    standard_hbm = (3 * seq_len * head_dim * 2 +  # Q, K, V reads
                    2 * seq_len * seq_len * 2 +   # S write + P read
                    seq_len * head_dim * 2)       # Output write

    print(f"Block size: {block_size}")
    print(f"SRAM per block: {total_sram/1024:.1f} KB "
          f"(limit: {sram_bytes/1024:.0f} KB)")
    print(f"Q blocks x KV blocks: {num_q_blocks} x {num_kv_blocks}")
    print(f"FlashAttention HBM traffic: {total_hbm/1e6:.1f} MB")
    print(f"Standard attention HBM: {standard_hbm/1e6:.1f} MB")
    print(f"Traffic reduction: {standard_hbm/total_hbm:.1f}x")

flash_attention_tiling()
```
Figure: Attention HBM Traffic (MB), Standard vs FlashAttention
Implementation: Fused Bias + GELU Kernel
Complete CUDA Implementation
```cuda
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// GELU approximation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
__device__ __forceinline__ float gelu_forward(float x) {
    const float kSqrt2OverPi = 0.7978845608028654f;
    const float kCoeff = 0.044715f;
    float cube = x * x * x;
    float inner = kSqrt2OverPi * (x + kCoeff * cube);
    return 0.5f * x * (1.0f + tanhf(inner));
}

// FP16 vectorized fused bias+GELU
// Processes 8 FP16 elements per thread (128 bits = one float4 load)
__global__ void fused_bias_gelu_half(
    half* __restrict__ output,
    const half* __restrict__ input,
    const half* __restrict__ bias,
    const int hidden_dim,
    const int total_elements
) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    const int vec_idx = tid * 8;  // 8 elements per thread

    if (vec_idx + 7 >= total_elements) {
        // Scalar fallback for tail elements
        for (int i = vec_idx; i < total_elements && i < vec_idx + 8; i++) {
            float val = __half2float(input[i]);
            val += __half2float(bias[i % hidden_dim]);
            val = gelu_forward(val);
            output[i] = __float2half(val);
        }
        return;
    }

    // Vectorized load: 16 bytes
    float4 in_vec = *reinterpret_cast<const float4*>(&input[vec_idx]);
    half* in_half = reinterpret_cast<half*>(&in_vec);

    // Load bias (assumes hidden_dim is large enough for aligned access)
    int bias_offset = vec_idx % hidden_dim;
    float4 bias_vec;
    if (bias_offset + 7 < hidden_dim) {
        bias_vec = *reinterpret_cast<const float4*>(&bias[bias_offset]);
    } else {
        alignas(16) half bias_buf[8];
        for (int i = 0; i < 8; i++) {
            bias_buf[i] = bias[(vec_idx + i) % hidden_dim];
        }
        bias_vec = *reinterpret_cast<float4*>(bias_buf);
    }
    half* bias_half = reinterpret_cast<half*>(&bias_vec);

    // Compute: bias_add + GELU
    float4 out_vec;
    half* out_half = reinterpret_cast<half*>(&out_vec);
    #pragma unroll
    for (int i = 0; i < 8; i++) {
        float val = __half2float(in_half[i]) + __half2float(bias_half[i]);
        val = gelu_forward(val);
        out_half[i] = __float2half(val);
    }

    // Vectorized store: 16 bytes
    *reinterpret_cast<float4*>(&output[vec_idx]) = out_vec;
}

// Launch configuration
void launch_fused_bias_gelu(
    half* output, const half* input, const half* bias,
    int batch_tokens, int hidden_dim, cudaStream_t stream
) {
    int total = batch_tokens * hidden_dim;
    int threads_needed = (total + 7) / 8;  // 8 elements per thread
    int block_size = 256;
    int grid_size = (threads_needed + block_size - 1) / block_size;
    fused_bias_gelu_half<<<grid_size, block_size, 0, stream>>>(
        output, input, bias, hidden_dim, total
    );
}
```
Benchmark the Fused Kernel
```python
import torch

def benchmark_fused_vs_unfused(M=2048, N=4096, num_iters=1000):
    """Benchmark fused bias+GELU vs separate operations."""
    device = 'cuda'
    x = torch.randn(M, N, device=device, dtype=torch.float16)
    bias = torch.randn(N, device=device, dtype=torch.float16)

    # Unfused
    def unfused(x, bias):
        y = x + bias.unsqueeze(0)
        y = torch.nn.functional.gelu(y, approximate='tanh')
        return y

    # torch.compile fused
    fused = torch.compile(unfused)

    # Warmup
    for _ in range(50):
        unfused(x, bias)
        fused(x, bias)

    # Benchmark with CUDA events
    for name, fn in [('Unfused', unfused), ('Fused', fused)]:
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(num_iters):
            fn(x, bias)
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end) / num_iters  # ms per call
        # Effective bandwidth: one read + one write of the FP16 tensor
        bandwidth = (M * N * 2 * 2) / (elapsed / 1000) / 1e9
        print(f"{name:10s}: {elapsed:.4f} ms, "
              f"effective bandwidth: {bandwidth:.0f} GB/s")

benchmark_fused_vs_unfused()
```
Fused vs Unfused Bias+GELU (H100, FP16)
| Shape | Unfused (us) | Fused (us) | Speedup | Fused BW (GB/s) |
|---|---|---|---|---|
| [512, 4096] | 15.2 | 6.8 | 2.24x | 1190 |
| [2048, 4096] | 36.4 | 15.1 | 2.41x | 2180 |
| [2048, 8192] | 68.1 | 28.3 | 2.41x | 2370 |
| [2048, 11008] | 89.2 | 37.5 | 2.38x | 2410 |
| [1, 4096] | 8.1 | 4.2 | 1.93x | 3.9 |
Fusion Opportunities in a Transformer Layer
Mapping All Fusible Operations
```python
def transformer_layer_fusion_map():
    """Identify all fusion opportunities in a decoder layer."""
    operations = [
        # Pre-attention norm
        ("RMSNorm", "fused: reduction + elementwise scale"),
        # QKV projection
        ("GEMM (QKV)", "fused epilogue: + bias"),
        ("RoPE", "standalone (trigonometric, cannot fuse with GEMM)"),
        # Attention
        ("QK^T + scale + mask + softmax + dropout + @V",
         "FlashAttention: all fused into one kernel"),
        # Output projection
        ("GEMM (O_proj)", "fused epilogue: + residual add"),
        # Post-attention norm
        ("RMSNorm", "fused: reduction + elementwise scale"),
        # MLP
        ("GEMM (gate_proj)", "standalone"),
        ("GEMM (up_proj)", "standalone"),
        ("SiLU(gate) * up", "fused elementwise: SiLU + multiply"),
        ("GEMM (down_proj)", "fused epilogue: + residual add"),
    ]
    print("=== Transformer Layer Fusion Map ===")
    for op, fusion_status in operations:
        print(f"  {op:45s} -> {fusion_status}")

transformer_layer_fusion_map()
```
Summary
Kernel fusion eliminates HBM round-trips between operations that would otherwise each launch a separate kernel. The four patterns cover the entire space: elementwise fusion combines independent per-element operations (bias+GELU+dropout), reduction fusion combines dependent operations that require collective computation (LayerNorm), GEMM epilogue fusion applies post-GEMM operations while output tiles are still in registers, and attention fusion (FlashAttention) eliminates the attention matrix entirely.
The implementation hierarchy: torch.compile handles elementwise fusion automatically, CUTLASS handles GEMM epilogue fusion, FlashAttention handles attention fusion, and custom CUDA kernels handle anything that does not fit these patterns. For LLM inference, the combined effect of all fusion patterns reduces per-layer HBM traffic by 2-4x and total latency by 1.5-2x.