Writing a fused softmax kernel in CUDA requires 150 lines of pointer arithmetic, shared memory padding, and bank conflict avoidance. The same kernel in Triton requires 35 lines of Python-like code that runs at 90% of the hand-tuned CUDA version. The Triton compiler handles thread mapping, memory coalescing, and shared memory allocation automatically, then auto-tunes over 50+ kernel configurations to find the fastest block size and tiling strategy. For memory-bound kernels and fusion patterns, Triton matches CUDA’s performance in one-third the development time.
Triton does not replace CUDA for all workloads. It achieves 80-95% of hand-written CUDA performance for memory-bound and fusion-heavy kernels. For compute-bound kernels requiring precise control over tensor core scheduling or register allocation, CUDA still wins. This post covers the programming model, the auto-tuning system, a complete fused softmax implementation, and a rigorous performance comparison.
All benchmarks run on A100-80GB SXM, Triton 2.x (nightly, pip install from the OpenAI Triton repository), PyTorch 2.x.
The Triton Programming Model
Block-Level Programs
A Triton kernel is a program that operates on blocks of data. You do not write per-thread code. Instead, you write code that processes a 1D or 2D block of elements, and the compiler maps it to warps and threads.
import triton
import triton.language as tl
@triton.jit
def vector_add_kernel(
    a_ptr,       # Pointer to first input
    b_ptr,       # Pointer to second input
    c_ptr,       # Pointer to output
    n_elements,  # Total number of elements
    BLOCK_SIZE: tl.constexpr,  # Number of elements per program instance
):
    # Program ID: which block of data this instance processes
    pid = tl.program_id(axis=0)
    # Compute the range of offsets this program handles
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask for out-of-bounds elements
    mask = offsets < n_elements
    # Load blocks of data (automatic coalescing)
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    # Compute
    c = a + b
    # Store results
    tl.store(c_ptr + offsets, c, mask=mask)
Key differences from CUDA:
- No `threadIdx`, `blockIdx`, `blockDim` — you have `tl.program_id` and `tl.arange`
- No explicit shared memory allocation — the compiler decides when to use it
- No `__syncthreads` — the compiler inserts barriers where needed
- No manual coalescing — pointer arithmetic with `tl.arange` generates coalesced patterns
- The `mask` parameter handles boundary conditions (no manual `if (idx < n)`)
Launching Triton Kernels
import torch
def vector_add(a, b):
    assert a.is_cuda and b.is_cuda
    c = torch.empty_like(a)
    n = a.numel()
    # Grid: number of program instances
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']),)
    # Launch
    vector_add_kernel[grid](a, b, c, n, BLOCK_SIZE=1024)
    return c
The grid is a callable that takes the kernel’s constexpr parameters and returns the grid dimensions. triton.cdiv is ceiling division. The BLOCK_SIZE parameter is a compile-time constant that Triton can specialize on.
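The grid arithmetic is plain ceiling division; a pure-Python model (illustrative only, mirroring what `triton.cdiv` computes):

```python
# Ceiling division, mirroring triton.cdiv: the number of BLOCK_SIZE-wide
# program instances needed to cover n elements.
def cdiv(n: int, block_size: int) -> int:
    return (n + block_size - 1) // block_size

# 10_000 elements with BLOCK_SIZE=1024 need 10 programs; the 10th
# program's mask disables its final 240 lanes.
grid_size = cdiv(10_000, 1024)  # -> 10
```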
Triton’s Compilation Pipeline
Understanding the compiler is essential for writing efficient Triton code.
From Python to PTX
- Triton IR: your Python code is traced and lowered to Triton's intermediate representation
- Triton IR optimizations: block-level optimizations (fusion, redundant load elimination)
- Triton IR to LLVM IR: block operations are lowered to thread-level operations. The compiler decides:
  - How many threads per block
  - Which loads go through shared memory
  - Where to insert `__syncthreads` barriers
  - How to map block operations to warp shuffles
- LLVM IR to PTX: standard LLVM backend for NVIDIA GPUs
- PTX to SASS: NVIDIA's `ptxas` assembler (same as nvcc)
The Triton compiler caches the compiled kernel. The first call incurs compilation overhead (hundreds of milliseconds to seconds); subsequent calls with the same constexpr parameters hit the cache.
What the Compiler Does Automatically
- Memory coalescing: `tl.load(ptr + tl.arange(0, BLOCK_SIZE))` generates stride-1 access
- Shared memory tiling: for reductions and operations requiring data reuse, the compiler allocates shared memory and manages synchronization
- Vectorized loads: the compiler emits `ld.global.v2` or `ld.global.v4` instructions when alignment permits
- Warp-level reductions: `tl.sum`, `tl.max`, etc. compile to warp shuffle trees, not shared memory reductions
What the Compiler Does NOT Do
- Cross-block communication: each program instance is independent (like CUDA blocks)
- Persistent kernel patterns: no direct support for kernels that loop over multiple tiles within one block
- Tensor core scheduling: the compiler can use tensor cores via `tl.dot`, but you do not control the MMA instruction selection or register layout
- Custom memory access patterns: if your algorithm requires a non-standard access pattern (e.g., gather from a hash table), you may fight the compiler
Core Triton Operations
Loads and Stores
# 1D block load
offsets = tl.arange(0, BLOCK_SIZE)
x = tl.load(ptr + offsets, mask=offsets < n, other=0.0)
# 2D block load (for matrices)
row_offsets = tl.arange(0, BLOCK_M)[:, None] # Column vector
col_offsets = tl.arange(0, BLOCK_N)[None, :] # Row vector
ptrs = base_ptr + row_offsets * stride + col_offsets
data = tl.load(ptrs, mask=(row_offsets < M) & (col_offsets < N), other=0.0)
# Atomic operations
tl.atomic_add(ptr + offsets, values, mask=mask)
Reductions
# Sum reduction across an axis
x = tl.load(ptr + offsets, mask=mask, other=0.0)
total = tl.sum(x, axis=0) # Scalar result
# Max reduction
max_val = tl.max(x, axis=0)
# Combined: stable softmax numerics
max_val = tl.max(x, axis=0)
x = x - max_val
exp_x = tl.exp(x)
sum_exp = tl.sum(exp_x, axis=0)
softmax = exp_x / sum_exp
Matrix Multiply (tl.dot)
# Block-level matrix multiply: compiles to tensor core instructions
# a: (BLOCK_M, BLOCK_K), b: (BLOCK_K, BLOCK_N)
# result: (BLOCK_M, BLOCK_N)
c = tl.dot(a, b)
# Accumulate into existing result
c += tl.dot(a_tile, b_tile)
tl.dot is the primary mechanism for tensor core utilization in Triton. The compiler selects appropriate MMA instructions based on the data types and block dimensions.
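The K-loop accumulation pattern that `tl.dot` kernels use can be modeled in plain Python (an illustrative sketch, not Triton code): each BLOCK_K slice contributes a partial product to the accumulator.

```python
def blocked_matmul(a, b, block_k):
    """Accumulate C += A[:, k0:k0+block_k] @ B[k0:k0+block_k, :] one
    K-slice at a time, mirroring the acc += tl.dot(a_tile, b_tile) loop."""
    m, k, n = len(a), len(a[0]), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for k0 in range(0, k, block_k):
        for i in range(m):
            for j in range(n):
                c[i][j] += sum(a[i][kk] * b[kk][j]
                               for kk in range(k0, min(k0 + block_k, k)))
    return c
```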
Math Operations
# Elementwise operations (all operate on blocks)
y = tl.exp(x)
y = tl.log(x)
y = tl.sqrt(x)
y = tl.sigmoid(x) # 1 / (1 + exp(-x))
y = tl.where(cond, a, b) # Conditional select
y = tl.maximum(a, b)
y = tl.minimum(a, b)
# Type casting
x_fp16 = x.to(tl.float16)
x_fp32 = x.to(tl.float32)
Auto-Tuning with @triton.autotune
Triton’s auto-tuning system searches over kernel configurations to find the fastest one for a specific problem size and GPU.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 128}, num_warps=2, num_stages=2),
        triton.Config({'BLOCK_SIZE': 256}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE': 512}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=4),
    ],
    key=['n_elements'],  # Re-tune when this argument changes
)
@triton.jit
def vector_add_autotuned(
    a_ptr, b_ptr, c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    tl.store(c_ptr + offsets, a + b, mask=mask)
Auto-Tune Parameters
- `configs`: list of `triton.Config` objects specifying `constexpr` values, `num_warps`, and `num_stages`
- `num_warps`: number of warps per program instance (threads = num_warps * 32)
- `num_stages`: number of software pipeline stages for memory loads (higher = more latency hiding, more shared memory)
- `key`: list of kernel argument names that trigger re-tuning when their values change
The auto-tuner compiles and benchmarks each configuration, then caches the winner. On the first call with a new key, it runs all configs. Subsequent calls use the cached best config.
Keep the search space small (5-15 configurations) and meaningful. Vary BLOCK_SIZE by powers of 2, num_warps from 2 to 8, and num_stages from 1 to 4. Do not create a combinatorial explosion (e.g., 10 block sizes x 4 warp counts x 4 stages = 160 configs is too many). Each config requires a full kernel compilation and benchmark run.
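One way to keep the space small is to generate it with a pruning heuristic instead of a full cross product. This is an illustrative sketch: each dict would be wrapped as `triton.Config({'BLOCK_SIZE': block}, num_warps=warps, num_stages=stages)` before being handed to `@triton.autotune`, and the thresholds are assumptions, not Triton rules.

```python
def pruned_configs():
    """Build a small, meaningful search space instead of a cross product."""
    space = []
    for block in [128, 256, 512, 1024, 2048]:
        for warps in [2, 4, 8]:
            # Keep 64-512 elements per warp: enough work per thread
            # without starving occupancy (illustrative heuristic).
            if not 64 <= block // warps <= 512:
                continue
            # Deeper software pipelines only pay off for larger blocks.
            stages = 2 if block <= 512 else 3
            space.append({'BLOCK_SIZE': block,
                          'num_warps': warps,
                          'num_stages': stages})
    return space
```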
Implementation: Fused Softmax Kernel
Softmax is the ideal Triton showcase: it requires a max reduction, subtraction, exponentiation, sum reduction, and division — all of which can be fused into a single kernel, eliminating intermediate global memory traffic.
PyTorch Baseline (Unfused)
# PyTorch: 5 separate kernels, 5 global memory round-trips
def softmax_pytorch(x):
    max_val = x.max(dim=-1, keepdim=True).values  # Kernel 1: max reduction
    x_shifted = x - max_val                       # Kernel 2: subtraction
    exp_x = torch.exp(x_shifted)                  # Kernel 3: exp
    sum_exp = exp_x.sum(dim=-1, keepdim=True)     # Kernel 4: sum reduction
    return exp_x / sum_exp                        # Kernel 5: division
Each kernel reads from and writes to HBM. For a matrix of shape (4096, 4096), that is 5 read + 5 write passes = 10 HBM round-trips of 64 MB each = 640 MB of traffic. A fused kernel does 1 read + 1 write = 128 MB.
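Those traffic figures follow from simple arithmetic (this sketch counts only full-tensor passes and ignores the small reduction outputs):

```python
# Back-of-envelope HBM traffic for softmax on a (4096, 4096) FP32 matrix.
rows, cols, bytes_per_elem = 4096, 4096, 4
tensor_mb = rows * cols * bytes_per_elem // 2**20  # 64 MB per full pass

unfused_mb = (5 + 5) * tensor_mb  # 5 kernels, each reads + writes: 640 MB
fused_mb = (1 + 1) * tensor_mb    # one read, one write: 128 MB
```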
Triton Fused Softmax
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=2, num_stages=2),
        triton.Config({}, num_warps=4, num_stages=2),
        triton.Config({}, num_warps=8, num_stages=3),
        triton.Config({}, num_warps=16, num_stages=4),
    ],
    key=['n_cols'],
)
@triton.jit
def softmax_kernel(
    input_ptr,
    output_ptr,
    n_cols,
    input_row_stride,
    output_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program instance processes one row
    row_idx = tl.program_id(0)
    # Pointers to the start of the current row
    row_start_input = input_ptr + row_idx * input_row_stride
    row_start_output = output_ptr + row_idx * output_row_stride
    # Load the entire row into SRAM. BLOCK_SIZE must be >= n_cols,
    # which the wrapper below guarantees.
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # Load input row
    row = tl.load(row_start_input + col_offsets, mask=mask, other=-float('inf'))
    # Step 1: Numerically stable max
    row_max = tl.max(row, axis=0)
    # Step 2: Subtract max and exponentiate
    row = row - row_max
    numerator = tl.exp(row)
    # Step 3: Sum of exponentials
    denominator = tl.sum(numerator, axis=0)
    # Step 4: Normalize
    softmax_output = numerator / denominator
    # Store result
    tl.store(row_start_output + col_offsets, softmax_output, mask=mask)
Note that BLOCK_SIZE is deliberately not part of the search space: a config with BLOCK_SIZE smaller than n_cols would silently drop the tail of every row, so the auto-tuner searches only over num_warps and num_stages, and the wrapper picks a BLOCK_SIZE that covers the whole row.
Python Wrapper
def triton_softmax(x):
    n_rows, n_cols = x.shape
    output = torch.empty_like(x)
    # BLOCK_SIZE must cover the full row: round n_cols up to a power of 2
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    # Grid: one program per row
    grid = (n_rows,)
    softmax_kernel[grid](
        x, output,
        n_cols,
        x.stride(0),
        output.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
Wide-Row Softmax: Handling Rows Wider Than BLOCK_SIZE
When n_cols exceeds BLOCK_SIZE, we cannot load the entire row at once. Instead, a single kernel loops over the row in blocks: one pass to find the max, one to accumulate the sum of exponentials, and one to normalize and write the output.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 1024}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE': 2048}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_SIZE': 4096}, num_warps=8, num_stages=4),
    ],
    key=['n_cols'],
)
@triton.jit
def softmax_wide_kernel(
    input_ptr,
    output_ptr,
    n_cols,
    input_row_stride,
    output_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    row_input = input_ptr + row_idx * input_row_stride
    row_output = output_ptr + row_idx * output_row_stride
    # Pass 1: Find max across the entire row (online max)
    row_max = -float('inf')
    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        block_data = tl.load(row_input + col_offsets, mask=mask,
                             other=-float('inf'))
        block_max = tl.max(block_data, axis=0)
        row_max = tl.maximum(row_max, block_max)
    # Pass 2: Compute exp(x - max) and accumulate sum
    sum_exp = 0.0
    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        block_data = tl.load(row_input + col_offsets, mask=mask,
                             other=-float('inf'))
        block_exp = tl.exp(block_data - row_max)
        sum_exp += tl.sum(block_exp, axis=0)
    # Pass 3: Normalize and write output
    for block_start in range(0, n_cols, BLOCK_SIZE):
        col_offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = col_offsets < n_cols
        block_data = tl.load(row_input + col_offsets, mask=mask,
                             other=-float('inf'))
        block_exp = tl.exp(block_data - row_max)
        softmax_out = block_exp / sum_exp
        tl.store(row_output + col_offsets, softmax_out, mask=mask)
For the wide-row case, you can combine Passes 1 and 2 using the online softmax algorithm (track running max and compensated sum simultaneously). This reduces the number of global memory passes from 3 to 2. The single-block version above already achieves this implicitly since the entire row fits in SRAM.
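A pure-Python reference for that combined pass (illustrative; a real kernel would process `tl` blocks rather than scalars) keeps a running max and rescales the accumulated sum whenever the max grows:

```python
import math

def online_softmax(row, block_size):
    """Single-pass running max with a rescaled running sum (online
    softmax), combining Passes 1 and 2 of the wide-row kernel."""
    running_max = -math.inf
    running_sum = 0.0
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        new_max = max(running_max, max(block))
        # Rescale the sum accumulated under the old max, then add
        # this block's contributions under the new max.
        running_sum = (running_sum * math.exp(running_max - new_max)
                       + sum(math.exp(v - new_max) for v in block))
        running_max = new_max
    return [math.exp(v - running_max) / running_sum for v in row]
```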
Benchmark: Triton vs PyTorch vs CUDA
Softmax Benchmark
import torch
import triton
import time
def benchmark(fn, x, warmup=25, rep=100):
    for _ in range(warmup):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(rep):
        fn(x)
    torch.cuda.synchronize()
    end = time.perf_counter()
    return (end - start) / rep * 1000  # ms
# Test sizes
sizes = [(4096, 256), (4096, 1024), (4096, 4096),
(4096, 8192), (4096, 16384)]
for rows, cols in sizes:
    x = torch.randn(rows, cols, device='cuda', dtype=torch.float32)
    # PyTorch native (calls cuDNN/internal fused kernel)
    t_pytorch = benchmark(lambda x: torch.softmax(x, dim=-1), x)
    # Triton
    t_triton = benchmark(triton_softmax, x)
    # Naive PyTorch (unfused, 5 kernels)
    t_naive = benchmark(softmax_pytorch, x)
    print(f"{rows}x{cols}: "
          f"PyTorch={t_pytorch:.3f}ms "
          f"Triton={t_triton:.3f}ms "
          f"Naive={t_naive:.3f}ms "
          f"Triton/PyTorch={t_triton/t_pytorch:.2f}x")
Softmax Latency (A100, 4096 rows, FP32)
| Columns | Naive PyTorch (ms) | torch.softmax (ms) | Triton (ms) | Triton vs torch |
|---|---|---|---|---|
| 256 | 0.089 | 0.021 | 0.023 | 1.10x |
| 1024 | 0.142 | 0.041 | 0.039 | 0.95x |
| 4096 | 0.385 | 0.112 | 0.098 | 0.88x |
| 8192 | 0.710 | 0.210 | 0.185 | 0.88x |
| 16384 | 1.380 | 0.415 | 0.365 | 0.88x |
[Figure: Softmax latency in ms, Triton vs PyTorch, 4096 rows; same data as the table above.]
GEMM Benchmark
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32,
                       'GROUP_M': 8}, num_warps=4, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32,
                       'GROUP_M': 8}, num_warps=8, num_stages=3),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32,
                       'GROUP_M': 8}, num_warps=4, num_stages=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32,
                       'GROUP_M': 8}, num_warps=4, num_stages=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
):
    pid = tl.program_id(0)
    # Swizzle: group adjacent blocks for better L2 cache locality
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    # Take pid modulo the group before modulo group_size_m so that
    # partial last groups are handled correctly
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    # Pointers to first blocks of A and B
    offs_am = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_bn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
                      offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                      offs_bn[None, :] * stride_bn)
    # Accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop over K dimension
    for k in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs, mask=(offs_am[:, None] < M) &
                    (offs_k[None, :] + k < K), other=0.0)
        b = tl.load(b_ptrs, mask=(offs_k[:, None] + k < K) &
                    (offs_bn[None, :] < N), other=0.0)
        # tl.dot compiles to tensor core instructions
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    # Store output
    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = c_ptr + (offs_cm[:, None] * stride_cm +
                      offs_cn[None, :] * stride_cn)
    mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc, mask=mask)
GEMM Performance: Triton vs cuBLAS (A100, FP16 input, FP32 accum)
| M=N=K | cuBLAS (TFLOPS) | Triton (TFLOPS) | Triton / cuBLAS |
|---|---|---|---|
| 1024 | 105 | 82 | 78% |
| 2048 | 225 | 195 | 87% |
| 4096 | 280 | 248 | 89% |
| 8192 | 295 | 272 | 92% |
| 16384 | 302 | 285 | 94% |
When Triton vs When CUDA
Triton vs CUDA: Use Case Decision Matrix
| Workload | Triton | CUDA | Recommendation |
|---|---|---|---|
| Fused elementwise ops | 95%+ of CUDA perf | Baseline | Triton: less code, auto-tuning |
| Softmax / LayerNorm | 90-100% of CUDA perf | Baseline | Triton: excellent fusion |
| GEMM (large matrices) | 85-94% of cuBLAS | cuBLAS: 100% | cuBLAS unless custom epilogue |
| GEMM (small/batched) | 75-85% of cuBLAS | cuBLAS: 100% | cuBLAS |
| FlashAttention | 85-90% of hand-tuned | Hand-tuned: 100% | CUDA for production, Triton for prototyping |
| Custom sparse ops | 70-80% | Hand-tuned: 100% | CUDA: Triton lacks sparse primitives |
| Reduction kernels | 90-95% | Hand-tuned: 100% | Triton: compiler handles warp shuffles |
| Quantized inference | 80-90% | 100% (INT4/INT8 intrinsics) | CUDA for sub-byte quantization |
Where Triton Excels
- Kernel fusion: combining multiple elementwise operations into a single kernel that reads and writes HBM once
- Rapid prototyping: 3-5x less code than CUDA, Python debugging, auto-tuning
- Custom attention variants: modifying attention patterns (e.g., sliding window, causal mask, ALiBi) is straightforward in Triton
- Cross-platform: Triton is adding AMD GPU (ROCm) support, which CUDA cannot target
Where CUDA is Still Required
- Maximum tensor core utilization: CUDA (or CUTLASS) gives explicit control over MMA instruction selection, register layout, and shared memory swizzling
- Warp-level programming: explicit `__shfl_sync`, `__ballot_sync`, custom warp-cooperative algorithms
- Sub-warp granularity: algorithms that need per-thread control (e.g., warp-divergent hash table probes)
- Persistent kernels: long-running kernels that process multiple tiles without relaunching
- Multi-GPU communication kernels: NCCL-style all-reduce kernels with direct NVLink access
Advanced Triton Patterns
Fused Residual + LayerNorm + Dropout
@triton.jit
def fused_residual_layernorm_dropout(
    x_ptr, residual_ptr, weight_ptr, bias_ptr, output_ptr,
    n_cols, eps,
    p_drop,  # Dropout probability
    seed,    # RNG seed
    stride,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # Load input and residual
    x = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0)
    res = tl.load(residual_ptr + row * stride + cols, mask=mask, other=0.0)
    # Fused residual add
    x = x + res
    # LayerNorm: compute mean and variance (zero out masked lanes so
    # they do not pollute the variance when BLOCK_SIZE > n_cols)
    mean = tl.sum(x, axis=0) / n_cols
    x_centered = tl.where(mask, x - mean, 0.0)
    var = tl.sum(x_centered * x_centered, axis=0) / n_cols
    inv_std = 1.0 / tl.sqrt(var + eps)
    # Normalize
    x_norm = x_centered * inv_std
    # Scale and bias
    w = tl.load(weight_ptr + cols, mask=mask, other=1.0)
    b = tl.load(bias_ptr + cols, mask=mask, other=0.0)
    x_out = x_norm * w + b
    # Dropout (using Triton's PRNG)
    rng_offsets = row * n_cols + cols
    random = tl.rand(seed, rng_offsets)
    drop_mask = random > p_drop
    x_out = tl.where(drop_mask, x_out / (1.0 - p_drop), 0.0)
    # Store
    tl.store(output_ptr + row * stride + cols, x_out, mask=mask)
This fuses 6 operations (residual add, mean, variance, normalize, scale+bias, dropout) into a single kernel with 2 global memory reads and 1 write, compared to 6 separate PyTorch kernels with 12+ global memory accesses.
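The deterministic part of the fusion (everything except dropout) is easy to express as a pure-Python reference for correctness checking; this is an illustrative sketch operating on plain lists:

```python
import math

def fused_reference(x, residual, weight, bias, eps=1e-5):
    """Reference for the fused kernel's math, dropout omitted since it
    is stochastic: residual add, LayerNorm, then scale and bias."""
    h = [xi + ri for xi, ri in zip(x, residual)]
    n = len(h)
    mean = sum(h) / n
    var = sum((v - mean) ** 2 for v in h) / n
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * inv_std * w + b
            for v, w, b in zip(h, weight, bias)]
```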
L2 Cache Optimization: Grouped Program Ordering
# Without grouping: programs iterate row by row through the output
# pid 0 -> (0,0), pid 1 -> (0,1), pid 2 -> (0,2), ...
# Adjacent programs access distant B columns -> poor L2 reuse
# With grouping: programs are ordered to maximize L2 locality
# GROUP_M programs in the M direction share the same B tile
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
pid_n = (pid % num_pid_in_group) // group_size_m
This swizzle pattern ensures that GROUP_M consecutive programs in the M direction share the same column range of B, increasing L2 cache hit rate for B’s tiles.
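A pure-Python model of the mapping makes the traversal order easy to inspect (illustrative; it reduces pid modulo the group first so a partial last group is handled correctly):

```python
def swizzled_tile(pid, num_pid_m, num_pid_n, group_m):
    """Map a linear program id to (pid_m, pid_n) tile coordinates
    using grouped ordering."""
    num_pid_in_group = group_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    group_size_m = min(num_pid_m - first_pid_m, group_m)
    pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# With a 4x4 grid of tiles and GROUP_M=2, the first 8 programs cover
# tile rows 0-1 across all 4 columns before any program touches row 2,
# so the B tiles for those columns stay hot in L2.
order = [swizzled_tile(p, 4, 4, 2) for p in range(16)]
```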
Debugging and Profiling Triton Kernels
Printing from Kernels
@triton.jit
def debug_kernel(x_ptr, n, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + offs, mask=offs < n)
    # tl.debug_barrier()  # Force synchronization for debugging
    # tl.static_print prints once at compile time; use tl.device_print
    # to print runtime values from the GPU
    if pid == 0:
        tl.device_print("block 0 x", x)
Viewing Generated PTX
# Compile and inspect
compiled = softmax_kernel.warmup(
    torch.empty(1, device='cuda'), torch.empty(1, device='cuda'),
    1024, 1024, 1024,
    BLOCK_SIZE=1024, grid=(1,),
)
print(compiled.asm['ptx'])   # PTX assembly
print(compiled.asm['ttir'])  # Triton IR
Profiling with Nsight
# Profile Triton kernels with Nsight Compute
ncu --set full python my_triton_script.py
# Or with Nsight Systems for timeline
nsys profile python my_triton_script.py
Triton kernels appear in Nsight profiles as regular CUDA kernels. The kernel names are auto-generated but include the function name (e.g., softmax_kernel_0d1d2d3d4d).
Summary and Decision Framework
- Triton is production-ready for memory-bound and fusion-heavy kernels. It is used in production at Meta (as the backend for PyTorch's Inductor compiler), OpenAI, and other organizations running LLM inference.
- The 80-95% rule: Triton achieves 80-95% of hand-written CUDA performance for most workloads. The remaining 5-20% requires explicit thread-level control that Triton's block-level model does not expose.
- Start with Triton, drop to CUDA when needed: write the kernel in Triton first. If profiling shows it is the bottleneck and you need that last 10%, rewrite it in CUDA. The Triton implementation serves as a correct reference.
- Auto-tuning is not optional: always use `@triton.autotune`. The difference between the best and worst configuration for the same kernel can be 3-5x.
- The compiler is evolving rapidly: each Triton release improves code generation. A kernel that achieves 85% of cuBLAS today may achieve 92% with the next compiler release.
This is Part 6, the final post in the CUDA Kernel Engineering series. The series covered the full stack: thread hierarchy (Part 1), memory coalescing (Part 2), shared memory (Part 3), warp primitives (Part 4), tensor cores (Part 5), and Triton (Part 6).