Warp Primitives: Shuffle, Vote, Match, and Cooperative Reduction Without Shared Memory

Part 4 of 32 in the CUDA Kernel Engineering series.

When you call __shfl_xor_sync to perform a warp-level reduction, 32 threads exchange values through their registers in a single cycle — no shared memory, no synchronization barrier, just direct register-to-register transfer at 1 TB/s effective bandwidth. A naive reduction through shared memory requires 5 synchronization barriers and 160 bytes of memory traffic. The shuffle-based version requires zero barriers and zero memory — it completes in 5 instructions totaling under 10 nanoseconds on an A100. This 10-20x performance gap is why every production CUDA library uses warp primitives for sub-block reductions.

This post covers every warp primitive in the CUDA instruction set, explains the mask parameter introduced in CUDA 9.0, provides performance data, and implements warp-level reduction that completes in exactly 5 shuffle operations.

All code targets CC 7.0+ (Volta and later, which require explicit masks). Tested on A100 (CC 8.0), CUDA 12.x.

The Mask Parameter: Active Thread Specification

Starting with CUDA 9.0, all warp-level intrinsics require an explicit mask parameter specifying which threads participate. This replaced the implicit full-warp assumption of earlier CUDA versions.

// mask = 0xffffffff: all 32 threads participate
// mask = 0x0000ffff: only threads 0-15 participate
// mask = 0xaaaaaaaa: only odd-numbered threads (bits 1, 3, 5, ...) participate

// The mask serves two purposes:
// 1. Documents which threads are expected to be active
// 2. Hardware uses it for synchronization (threads wait until
//    all masked threads reach the instruction)

unsigned int full_mask = 0xffffffff;
unsigned int lower_half = 0x0000ffff;
⚠️ Mask Must Match Active Threads

The mask must be a subset of the currently active threads. Passing 0xffffffff when some threads have diverged away (e.g., inside an if block that only some threads entered) is undefined behavior. Use __activemask() to get the current active mask, or carefully construct masks from your branch conditions.
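Constructing a mask from a branch condition amounts to setting bit i for every lane i that takes the branch — the same bitmask __ballot_sync would compute on-device. A small host-side Python model of that lane arithmetic (a sketch, not device code):

```python
# Host-side model of building a shuffle mask from a branch condition.
def branch_mask(predicate):
    """Bitmask with bit `lane` set for every lane (0-31) taking the branch."""
    return sum(1 << lane for lane in range(32) if predicate(lane))

# Lanes entering `if (lane < 16)`:
lower_half = branch_mask(lambda lane: lane < 16)
assert lower_half == 0x0000FFFF

# Lanes entering `if (lane % 2 == 1)` -- the 0xaaaaaaaa pattern above:
odd_lanes = branch_mask(lambda lane: lane % 2 == 1)
assert odd_lanes == 0xAAAAAAAA
```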

Shuffle Instructions: __shfl_sync Family

Shuffle instructions move a 32-bit value between threads within a warp. There are four variants, each implementing a different communication pattern.

__shfl_sync: Indexed Read (Arbitrary Lane)

// T __shfl_sync(unsigned mask, T var, int srcLane, int width=warpSize);
// Returns the value of 'var' from thread 'srcLane'

// Example: all threads read the value from thread 0 (broadcast)
float val = __shfl_sync(0xffffffff, my_value, 0);
// Now all threads have thread 0's value of my_value

// Example: rotate left by 1 (each thread reads from the next thread)
int lane = threadIdx.x % 32;
float rotated = __shfl_sync(0xffffffff, my_value, (lane + 1) % 32);

__shfl_up_sync: Read from Lower Lane

// T __shfl_up_sync(unsigned mask, T var, unsigned delta, int width=warpSize);
// Thread i reads from thread (i - delta). Threads where (i - delta) < 0
// return their own value (identity for prefix operations).

// Example: shift up by 1 (each thread reads the value from the thread below it)
float shifted = __shfl_up_sync(0xffffffff, my_value, 1);
// Thread 0: gets its own value (no source below)
// Thread 1: gets thread 0's value
// Thread 2: gets thread 1's value
// ...
// Thread 31: gets thread 30's value

__shfl_down_sync: Read from Higher Lane

// T __shfl_down_sync(unsigned mask, T var, unsigned delta, int width=warpSize);
// Thread i reads from thread (i + delta). Threads where (i + delta) >= width
// return their own value.

// Example: shift down by 1
float shifted = __shfl_down_sync(0xffffffff, my_value, 1);
// Thread 0: gets thread 1's value
// Thread 1: gets thread 2's value
// ...
// Thread 30: gets thread 31's value
// Thread 31: gets its own value (no source above)

__shfl_xor_sync: Read from XOR Lane

// T __shfl_xor_sync(unsigned mask, T var, int laneMask, int width=warpSize);
// Thread i reads from thread (i XOR laneMask).
// XOR creates butterfly communication patterns used in parallel reductions.

// Example: XOR with 1 (swap adjacent pairs)
float swapped = __shfl_xor_sync(0xffffffff, my_value, 1);
// Thread 0 <-> Thread 1
// Thread 2 <-> Thread 3
// Thread 4 <-> Thread 5
// ...

// Example: XOR with 16 (swap upper and lower halves)
float half_swap = __shfl_xor_sync(0xffffffff, my_value, 16);
// Thread 0 <-> Thread 16
// Thread 1 <-> Thread 17
// ...
// Thread 15 <-> Thread 31
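The XOR pairing can be tabulated on the host to see why it forms a butterfly: the mapping is its own inverse, so every exchange is a mutual swap (a Python sketch of the lane arithmetic only):

```python
# Python model of XOR-lane pairing: thread i reads from thread i ^ laneMask.
def xor_partners(lane_mask):
    return [i ^ lane_mask for i in range(32)]

# laneMask = 1 swaps adjacent pairs; laneMask = 16 swaps the warp halves.
assert xor_partners(1)[:4] == [1, 0, 3, 2]
assert xor_partners(16)[0] == 16 and xor_partners(16)[31] == 15

# XOR pairing is symmetric: if i reads from j, then j reads from i.
assert all((i ^ 5) ^ 5 == i for i in range(32))
```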

Shuffle Data Types

Shuffles operate on 32-bit values. For other types:

// 32-bit types: direct support
int i = __shfl_sync(mask, int_val, src);
float f = __shfl_sync(mask, float_val, src);

// 64-bit types: split into two 32-bit halves (the CUDA headers also provide
// overloads that do this expansion for you; shown explicitly here)
__device__ double shfl_double(unsigned mask, double val, int src) {
    int lo = __shfl_sync(mask, __double2loint(val), src);
    int hi = __shfl_sync(mask, __double2hiint(val), src);
    return __hiloint2double(hi, lo);
}

// 16-bit types: pack two into one 32-bit shuffle
__device__ __half2 shfl_half2(unsigned mask, __half2 val, int src) {
    unsigned bits = __shfl_sync(mask, *(unsigned*)&val, src);
    return *(__half2*)&bits;
}
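The hi/lo splitting that __double2loint / __double2hiint / __hiloint2double perform can be modeled on the host with struct packing, to confirm the round trip is lossless (a Python sketch of the bit manipulation only):

```python
import struct

# Reinterpret a double as two 32-bit halves and back, mirroring the
# device-side split used for 64-bit shuffles.
def double_to_hilo(x):
    bits = struct.unpack("<Q", struct.pack("<d", x))[0]
    return bits >> 32, bits & 0xFFFFFFFF   # (hi, lo)

def hilo_to_double(hi, lo):
    return struct.unpack("<d", struct.pack("<Q", (hi << 32) | lo))[0]

hi, lo = double_to_hilo(3.141592653589793)
assert hilo_to_double(hi, lo) == 3.141592653589793
```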

Warp Reduction: 5 Shuffles, 0 Shared Memory

The most important application of shuffle instructions is parallel reduction within a warp. A warp reduction computes the sum (or min, max, etc.) of 32 values in exactly 5 shuffle-and-add steps using the butterfly pattern.

The Butterfly Reduction Pattern

Step 1: XOR mask = 16 (swap halves)
  Thread 0  += Thread 16    Thread 16 += Thread 0
  Thread 1  += Thread 17    Thread 17 += Thread 1
  ...                       ...
  Thread 15 += Thread 31    Thread 31 += Thread 15

Step 2: XOR mask = 8 (swap quarter-halves)
  Thread 0  += Thread 8     Thread 8  += Thread 0
  Thread 1  += Thread 9     Thread 9  += Thread 1
  ...

Step 3: XOR mask = 4
Step 4: XOR mask = 2
Step 5: XOR mask = 1

After 5 steps: every thread holds the sum of all 32 values
(because XOR is symmetric: both sides accumulate)

\log_2(32) = 5 steps. Each step halves the communication distance. After all 5 steps, every thread has the complete reduction.
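The butterfly schedule can be checked with a short host-side simulation, with a Python list standing in for the 32 lanes:

```python
# Simulate the 5-step XOR butterfly over 32 lanes. Each step, every lane
# reads its partner's (pre-step) value and adds it -- exactly what
# val += __shfl_xor_sync(mask, val, offset) does on the GPU.
vals = [float(i * i) for i in range(32)]   # arbitrary per-lane inputs

for lane_mask in (16, 8, 4, 2, 1):
    # Build a new list so all lanes read the old values simultaneously.
    vals = [vals[i] + vals[i ^ lane_mask] for i in range(32)]

# Every lane ends with the full sum: XOR pairing is symmetric, so both
# sides of each exchange accumulate.
total = sum(float(i * i) for i in range(32))
assert all(v == total for v in vals)
```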

Implementation: Full Warp Reduction

// Warp reduction using __shfl_xor_sync (symmetric: all threads get result)
__device__ __forceinline__
float warp_reduce_sum_xor(float val) {
    unsigned mask = 0xffffffff;
    val += __shfl_xor_sync(mask, val, 16);
    val += __shfl_xor_sync(mask, val, 8);
    val += __shfl_xor_sync(mask, val, 4);
    val += __shfl_xor_sync(mask, val, 2);
    val += __shfl_xor_sync(mask, val, 1);
    return val;  // All 32 threads have the sum
}

// Warp reduction using __shfl_down_sync (asymmetric: only lane 0 has result)
__device__ __forceinline__
float warp_reduce_sum_down(float val) {
    unsigned mask = 0xffffffff;
    val += __shfl_down_sync(mask, val, 16);
    val += __shfl_down_sync(mask, val, 8);
    val += __shfl_down_sync(mask, val, 4);
    val += __shfl_down_sync(mask, val, 2);
    val += __shfl_down_sync(mask, val, 1);
    return val;  // Only thread 0 has the correct sum
}

Generalized Reduction

// Template for arbitrary reduction operations
template <typename T, typename Op>
__device__ __forceinline__
T warp_reduce(T val, Op op) {
    unsigned mask = 0xffffffff;
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        T other = __shfl_xor_sync(mask, val, offset);
        val = op(val, other);
    }
    return val;
}

// Usage examples
struct MaxOp {
    __device__ float operator()(float a, float b) { return fmaxf(a, b); }
};

struct MinOp {
    __device__ float operator()(float a, float b) { return fminf(a, b); }
};

// In kernel:
float max_val = warp_reduce(my_val, MaxOp{});
float min_val = warp_reduce(my_val, MinOp{});

Warp Prefix Sum (Inclusive Scan)

Prefix sum computes the running sum across threads: thread k holds \sum_{i=0}^{k} \text{val}_i.

// Inclusive prefix sum within a warp using __shfl_up_sync
// After: thread k has sum of threads 0..k
__device__ __forceinline__
float warp_prefix_sum_inclusive(float val) {
    unsigned mask = 0xffffffff;

    // Step 1: add value from 1 position below
    float n = __shfl_up_sync(mask, val, 1);
    if ((threadIdx.x % 32) >= 1) val += n;

    // Step 2: add value from 2 positions below
    n = __shfl_up_sync(mask, val, 2);
    if ((threadIdx.x % 32) >= 2) val += n;

    // Step 3: add value from 4 positions below
    n = __shfl_up_sync(mask, val, 4);
    if ((threadIdx.x % 32) >= 4) val += n;

    // Step 4: add value from 8 positions below
    n = __shfl_up_sync(mask, val, 8);
    if ((threadIdx.x % 32) >= 8) val += n;

    // Step 5: add value from 16 positions below
    n = __shfl_up_sync(mask, val, 16);
    if ((threadIdx.x % 32) >= 16) val += n;

    return val;
}

// Exclusive prefix sum: shift the inclusive result down by 1
__device__ __forceinline__
float warp_prefix_sum_exclusive(float val) {
    float inclusive = warp_prefix_sum_inclusive(val);
    // Shift down: thread k gets thread (k-1)'s inclusive sum
    float exclusive = __shfl_up_sync(0xffffffff, inclusive, 1);
    return ((threadIdx.x % 32) == 0) ? 0.0f : exclusive;
}
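The scan above is a Hillis-Steele schedule; it can be verified on the host, with a Python list standing in for the 32 lanes:

```python
# Simulate the 5-step shfl_up inclusive scan: at step delta, lane i adds
# the (pre-step) value of lane i - delta, guarded by `lane >= delta`.
vals = [float(i + 1) for i in range(32)]   # arbitrary per-lane inputs

for delta in (1, 2, 4, 8, 16):
    # New list so all lanes read old values simultaneously, like a shuffle.
    vals = [vals[i] + vals[i - delta] if i >= delta else vals[i]
            for i in range(32)]

# Lane k now holds the inclusive prefix sum of lanes 0..k.
assert all(vals[k] == sum(range(1, k + 2)) for k in range(32))
```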

Vote Instructions: Collective Predicates

Vote instructions evaluate a boolean predicate across all threads in a warp and return collective results.

__ballot_sync: Predicate to Bitmask

// unsigned __ballot_sync(unsigned mask, int predicate);
// Returns a 32-bit mask where bit i is set if thread i's predicate is nonzero

// Example: find which threads have values greater than threshold
float val = data[idx];
unsigned active = __ballot_sync(0xffffffff, val > threshold);
// active bit i = 1 if thread i has val > threshold

// Count how many threads satisfy the condition
int count = __popc(active);  // Population count (number of set bits)

// Check if thread 5 specifically satisfies the condition
bool thread5_active = (active >> 5) & 1;
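The ballot/popc pairing is just bit manipulation, so it can be modeled on the host (a Python sketch with example per-lane values):

```python
# Host-side model of __ballot_sync + __popc over a warp of example values.
threshold = 10.0
lane_vals = [float(3 * i % 17) for i in range(32)]

# __ballot_sync: bit i is set iff lane i's predicate is nonzero.
active = sum(1 << i for i, v in enumerate(lane_vals) if v > threshold)

# __popc(active) counts the lanes that passed the test.
count = bin(active).count("1")
assert count == sum(v > threshold for v in lane_vals)

# Testing a single lane's bit, as in the snippet above:
thread5_active = (active >> 5) & 1
assert thread5_active == (lane_vals[5] > threshold)
```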

__any_sync and __all_sync: Warp-Wide Predicates

// int __any_sync(unsigned mask, int predicate);
// Returns nonzero if ANY thread in the warp has a nonzero predicate

// int __all_sync(unsigned mask, int predicate);
// Returns nonzero if ALL threads in the warp have a nonzero predicate

// Example: early exit if all threads are out of bounds
if (__all_sync(0xffffffff, idx >= n)) {
    return;  // ALL threads in this warp are out of bounds, skip entirely
}

// Example: check if any thread found an error
int error = validate(data[idx]);
if (__any_sync(0xffffffff, error)) {
    // At least one thread in the warp detected an error
    // Handle collectively
    if (threadIdx.x % 32 == 0) {
        atomicAdd(error_count, __popc(__ballot_sync(0xffffffff, error)));
    }
}

Practical Use: Warp-Level Control Flow

// Efficiently skip empty work using ballot
__global__ void sparse_compute(const float* values,
                                const int* flags,
                                float* output, int n) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    int flag = (idx < n) ? flags[idx] : 0;

    // Get mask of threads with work to do
    unsigned active_mask = __ballot_sync(0xffffffff, flag != 0);

    if (active_mask == 0) {
        return;  // No thread in this warp has work, skip entirely
    }

    // Compact: compute only for active threads
    if (flag != 0) {
        output[idx] = expensive_computation(values[idx]);
    }

    // Count active threads for load balancing statistics
    if ((threadIdx.x % 32) == 0) {
        atomicAdd(active_count, __popc(active_mask));
    }
}

Match Instructions: Finding Identical Values

Match instructions (CC 7.0+) identify threads that hold the same value, enabling efficient group operations.

__match_any_sync

// unsigned __match_any_sync(unsigned mask, T value);
// Returns a bitmask of all threads in the mask that have the same value
// as the calling thread.

// Example: group threads by their bucket
int bucket = data[idx] % NUM_BUCKETS;
unsigned peers = __match_any_sync(0xffffffff, bucket);

// 'peers' contains bits set for all threads with the same bucket value
// Thread can now do a warp-level reduction among just its peer group

int peer_count = __popc(peers);  // How many threads share my bucket

// Find my rank within the peer group (for scatter)
unsigned lower_peers = peers & ((1u << (threadIdx.x % 32)) - 1);
int my_rank = __popc(lower_peers);
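The rank trick (popcount of peer bits strictly below your lane) is worth seeing end to end; a host-side Python model with example bucket values:

```python
# Model of __match_any_sync peer grouping plus the rank computation.
NUM_BUCKETS = 4
buckets = [(i * 7) % NUM_BUCKETS for i in range(32)]   # example bucket per lane

def match_any(lane):
    # Bitmask of all lanes holding the same value as `lane`.
    return sum(1 << j for j in range(32) if buckets[j] == buckets[lane])

for lane in range(32):
    peers = match_any(lane)
    lower_peers = peers & ((1 << lane) - 1)
    my_rank = bin(lower_peers).count("1")
    # Rank = number of earlier lanes in the same bucket, giving each peer
    # a unique, dense index for scatter.
    assert my_rank == sum(buckets[j] == buckets[lane] for j in range(lane))
```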

__match_all_sync

// unsigned __match_all_sync(unsigned mask, T value, int* pred);
// Returns mask if ALL threads have the same value, 0 otherwise.
// Sets *pred to 1 if all match, 0 if not.

int pred;
unsigned result = __match_all_sync(0xffffffff, my_value, &pred);

if (pred) {
    // All 32 threads have the same value -- can broadcast/deduplicate
    if ((threadIdx.x % 32) == 0) {
        // Only one thread needs to do the work
        process(my_value);
    }
}

Application: Warp-Level Histogram

// Compute histogram within a warp using match + ballot
__device__ void warp_histogram(int bin, int* histogram, int num_bins) {
    for (int b = 0; b < num_bins; b++) {
        unsigned match = __ballot_sync(0xffffffff, bin == b);
        if ((threadIdx.x % 32) == 0) {
            atomicAdd(&histogram[b], __popc(match));
        }
    }
}

// More efficient version using __match_any_sync
__device__ void warp_histogram_fast(int bin, int* histogram) {
    // Find all threads with the same bin value
    unsigned peers = __match_any_sync(0xffffffff, bin);
    int count = __popc(peers);

    // Only the lowest-numbered thread in each group does the atomic
    unsigned lower = peers & ((1u << (threadIdx.x % 32)) - 1);
    if (__popc(lower) == 0) {
        // I am the first thread in my peer group
        atomicAdd(&histogram[bin], count);
    }
}
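The leader-election logic in warp_histogram_fast — one atomic per distinct bin, issued by the lowest lane in each peer group — can be checked against a direct count on the host:

```python
# Host-side check of the match-based histogram: lowest lane per peer
# group adds the whole group's count in a single atomicAdd.
bins = [(i * i) % 5 for i in range(32)]    # example bin per lane
histogram = [0] * 5

for lane in range(32):
    peers = sum(1 << j for j in range(32) if bins[j] == bins[lane])
    count = bin(peers).count("1")          # __popc(peers)
    lower = peers & ((1 << lane) - 1)
    if bin(lower).count("1") == 0:         # I am first in my peer group
        histogram[bins[lane]] += count     # the single atomicAdd

assert histogram == [bins.count(b) for b in range(5)]
```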

Block-Level Reduction Using Warp Primitives

Combining warp-level reduction with a small amount of shared memory gives an efficient block-level reduction:

#define WARP_SIZE 32

__device__ __forceinline__
float warp_reduce_sum(float val) {
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xffffffff, val, offset);
    }
    return val;
}

// Block-level reduction: warp reduces first, then shared memory
// for inter-warp communication
__device__ float block_reduce_sum(float val) {
    __shared__ float warp_sums[32];  // Max 32 warps per block (1024 threads)

    int lane = threadIdx.x % WARP_SIZE;
    int warp_id = threadIdx.x / WARP_SIZE;
    int num_warps = blockDim.x / WARP_SIZE;

    // Step 1: warp-level reduction (no shared memory)
    val = warp_reduce_sum(val);

    // Step 2: warp leaders write to shared memory
    if (lane == 0) {
        warp_sums[warp_id] = val;
    }
    __syncthreads();

    // Step 3: first warp reduces the warp sums
    // (only need num_warps values, which is <= 32)
    if (warp_id == 0) {
        val = (lane < num_warps) ? warp_sums[lane] : 0.0f;
        val = warp_reduce_sum(val);
    }

    return val;  // Only thread 0 has the final result
}

// Complete kernel: reduce N elements to a single sum
__global__ void reduce_kernel(const float* input, float* output, int n) {
    int idx = threadIdx.x + blockIdx.x * blockDim.x;
    int stride = blockDim.x * gridDim.x;

    // Grid-stride loop: accumulate per-thread sum
    float sum = 0.0f;
    for (int i = idx; i < n; i += stride) {
        sum += input[i];
    }

    // Block-level reduction
    sum = block_reduce_sum(sum);

    // Block leader writes to global memory
    if (threadIdx.x == 0) {
        atomicAdd(output, sum);
    }
}

Performance: Warp Shuffle vs Shared Memory Reduction

📊

Reduction Throughput (A100, N=16M float32)

Method                          Bandwidth (GB/s)   Kernel Time (us)   Notes
Shared memory tree reduction    1680               38.1               Classic approach: log2(N) steps in smem
Warp shuffle + smem (hybrid)    1850               34.6               Shuffle within warps, smem between warps
Warp shuffle only (CUB)         1870               34.2               CUB DeviceReduce, highly optimized
Theoretical peak (read only)    2039               31.4               Memory bandwidth limited
Note: The hybrid approach (shuffle within warps, shared memory between warps) achieves 91% of peak bandwidth. The improvement over pure shared memory comes from eliminating bank conflicts and synchronization barriers within warps.


Advanced Patterns

Pattern: Warp-Level Broadcast

// Broadcast a value from a specific lane to all lanes
__device__ float warp_broadcast(float val, int src_lane) {
    return __shfl_sync(0xffffffff, val, src_lane);
}

// Broadcast the maximum value and its lane index
__device__ float warp_broadcast_max(float val, int* max_lane) {
    float max_val = val;
    int my_lane = threadIdx.x % 32;

    // Reduce to find max
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        float other = __shfl_xor_sync(0xffffffff, max_val, offset);
        max_val = fmaxf(max_val, other);
    }

    // Find which lane had the maximum
    unsigned has_max = __ballot_sync(0xffffffff, val == max_val);
    *max_lane = __ffs(has_max) - 1;  // First set bit (lowest lane with max)

    return max_val;
}

Pattern: Warp-Level Compaction (Stream Compaction)

// Compact active elements within a warp: remove "holes"
// Input:  [a, _, b, _, _, c, d, _, ...]  (some elements invalid)
// Output: [a, b, c, d, ...]  (compacted to front)
__device__ int warp_compact(float val, bool active, float* output,
                             int output_offset) {
    unsigned active_mask = __ballot_sync(0xffffffff, active);
    int active_count = __popc(active_mask);

    // Compute my position in the compacted output
    unsigned lower_mask = active_mask & ((1u << (threadIdx.x % 32)) - 1);
    int my_pos = __popc(lower_mask);

    if (active) {
        output[output_offset + my_pos] = val;
    }

    return active_count;  // Return number of active elements
}
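A host-side model confirms the key property of this compaction: active lanes land in consecutive, order-preserving output slots (Python sketch with example validity flags):

```python
# Model of warp_compact: ballot + popcount-of-lower-bits gives each
# active lane a dense, order-preserving output position.
active = [i % 3 != 0 for i in range(32)]   # example validity flags
vals = [float(i) for i in range(32)]

active_mask = sum(1 << i for i, a in enumerate(active) if a)
out = {}
for lane in range(32):
    if active[lane]:
        my_pos = bin(active_mask & ((1 << lane) - 1)).count("1")
        out[my_pos] = vals[lane]

compacted = [out[k] for k in range(len(out))]
# Compacted output keeps the original order and drops the inactive lanes.
assert compacted == [v for v, a in zip(vals, active) if a]
```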

Pattern: Segmented Reduction

// Reduce within segments of a warp (variable-length groups)
// flags[i] = 1 marks the START of a new segment.
// After the reduction, each segment head holds its segment's sum;
// lane i in general holds the sum of lanes i..(end of its segment).
__device__ float segmented_reduce_sum(float val, int flag) {
    unsigned full_mask = 0xffffffff;
    unsigned segment_heads = __ballot_sync(full_mask, flag);
    int lane = threadIdx.x % 32;

    // Last lane of my segment: one below the next segment head above me,
    // or lane 31 if no head lies above.
    // (Lane 31 is guarded separately: 1u << 32 is undefined behavior.)
    unsigned above = (lane == 31)
                   ? 0u
                   : (segment_heads & ~((1u << (lane + 1)) - 1));
    int seg_end = (above != 0) ? (__ffs(above) - 2) : 31;

    // Tree reduction clipped to the segment: only accumulate the partner's
    // value when it comes from a lane inside my own segment.
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        float other = __shfl_down_sync(full_mask, val, offset);
        if (lane + offset <= seg_end) {
            val += other;
        }
    }

    return val;
}
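The segment-clipped tree can be simulated on the host to check that each segment head ends with its segment's sum (a Python sketch of a shfl_down-based schedule, with example segment boundaries):

```python
# Simulate a segmented warp reduction: a shfl_down-style tree where lane i
# only adds its partner when lane (i + offset) is still inside i's segment.
flags = [1 if i in (0, 5, 11, 20) else 0 for i in range(32)]
vals = [float(i + 1) for i in range(32)]

heads = [i for i in range(32) if flags[i]]
def seg_end(lane):
    nxt = [h for h in heads if h > lane]
    return (nxt[0] - 1) if nxt else 31

for offset in (16, 8, 4, 2, 1):
    # New list so all lanes read pre-step values, like a shuffle.
    vals = [vals[i] + vals[i + offset]
            if i + offset <= seg_end(i) else vals[i]
            for i in range(32)]

# Every segment head now holds its segment total.
bounds = heads + [32]
for h, e in zip(bounds, bounds[1:]):
    assert vals[h] == sum(range(h + 1, e + 1))
```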

Full Implementation: Softmax with Warp Primitives

Softmax is a natural fit for warp primitives: it requires a max reduction, an exponentiation, and a sum reduction over each row.

// Warp-level softmax: one warp processes one row of 32 elements
// For rows wider than 32, each thread handles multiple elements
__global__ void warp_softmax(const float* __restrict__ input,
                              float* __restrict__ output,
                              int rows, int cols) {
    int row = blockIdx.x * (blockDim.x / 32) + (threadIdx.x / 32);
    int lane = threadIdx.x % 32;

    if (row >= rows) return;

    const float* row_input = input + row * cols;
    float* row_output = output + row * cols;

    // Step 1: find max (for numerical stability)
    float max_val = -INFINITY;
    for (int c = lane; c < cols; c += 32) {
        max_val = fmaxf(max_val, row_input[c]);
    }
    // Warp reduce to find row maximum
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        max_val = fmaxf(max_val, __shfl_xor_sync(0xffffffff, max_val, offset));
    }
    // Now all 32 threads have the row maximum

    // Step 2: compute exp(x - max) and sum
    float sum = 0.0f;
    for (int c = lane; c < cols; c += 32) {
        float exp_val = expf(row_input[c] - max_val);
        row_output[c] = exp_val;  // Store intermediate
        sum += exp_val;
    }
    // Warp reduce to find sum of exponentials
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        sum += __shfl_xor_sync(0xffffffff, sum, offset);
    }

    // Step 3: normalize
    float inv_sum = 1.0f / sum;
    for (int c = lane; c < cols; c += 32) {
        row_output[c] *= inv_sum;
    }
}

Optimized: Online Softmax (Single Pass)

The online softmax algorithm computes max and sum in a single pass using the identity:

\sum_i e^{x_i - m_{\text{new}}} = e^{m_{\text{old}} - m_{\text{new}}} \sum_i e^{x_i - m_{\text{old}}}
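A quick numeric check of this merge rule: combining two partial (max, sum) pairs with the rescaling identity must match the (max, sum) computed directly over the union of both inputs.

```python
import math

# Verify the online-softmax merge rule on two arbitrary chunks.
def partial(xs):
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)

a = [0.3, -1.2, 4.0]
b = [2.5, 0.1]

ma, sa = partial(a)
mb, sb = partial(b)
m = max(ma, mb)
s = sa * math.exp(ma - m) + sb * math.exp(mb - m)   # rescale-and-add

m_ref, s_ref = partial(a + b)
assert m == m_ref
assert abs(s - s_ref) < 1e-12
```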

// Online softmax: single-pass max + sum computation
__global__ void warp_softmax_online(const float* __restrict__ input,
                                     float* __restrict__ output,
                                     int rows, int cols) {
    int row = blockIdx.x * (blockDim.x / 32) + (threadIdx.x / 32);
    int lane = threadIdx.x % 32;

    if (row >= rows) return;

    const float* row_in = input + row * cols;
    float* row_out = output + row * cols;

    // Single-pass: track running max and compensated sum
    float max_val = -INFINITY;
    float sum_exp = 0.0f;

    for (int c = lane; c < cols; c += 32) {
        float x = row_in[c];
        float old_max = max_val;
        max_val = fmaxf(max_val, x);
        // Rescale existing sum to new max
        sum_exp = sum_exp * expf(old_max - max_val) + expf(x - max_val);
    }

    // Warp-level online reduction: combine (max, sum) pairs
    // When merging two partial results (max_a, sum_a) and (max_b, sum_b):
    // new_max = max(max_a, max_b)
    // new_sum = sum_a * exp(max_a - new_max) + sum_b * exp(max_b - new_max)
    #pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
        float other_max = __shfl_xor_sync(0xffffffff, max_val, offset);
        float other_sum = __shfl_xor_sync(0xffffffff, sum_exp, offset);
        float new_max = fmaxf(max_val, other_max);
        sum_exp = sum_exp * expf(max_val - new_max)
                + other_sum * expf(other_max - new_max);
        max_val = new_max;
    }

    // Second pass: write normalized values
    float inv_sum = 1.0f / sum_exp;
    for (int c = lane; c < cols; c += 32) {
        row_out[c] = expf(row_in[c] - max_val) * inv_sum;
    }
}
📊

Softmax Performance (A100, 4096 rows x 1024 cols, FP32)

Implementation                    Kernel Time (us)   Bandwidth (GB/s)   Notes
Shared memory (3 passes)          89                 564                max + exp_sum + normalize
Warp shuffle (2 passes)           62                 812                Separate max and sum reductions, then normalize
Warp shuffle (online, 2 passes)   55                 915                Fused max+sum, coalesced output
PyTorch torch.softmax             48                 1048               Fused kernel, vectorized loads
Note: Warp primitives eliminate shared memory overhead for the reduction phases. The remaining gap to PyTorch is due to vectorized memory access (float4 loads) and more aggressive instruction scheduling.

Performance Characteristics and Limits

📊

Warp Primitive Latency (Volta+ architectures)

Instruction                        Latency (cycles)   Throughput (per SM per cycle)   Notes
__shfl_sync (any variant)          1                  32 (one per warp scheduler)     Register-to-register, no memory
__ballot_sync                      1                  32                              Predicate-to-bitmask
__any_sync / __all_sync            1                  32                              Predicate reduction
__match_any_sync                   1-2                16-32                           Value comparison across warp
__popc (bit count)                 1                  16 (SFU)                        Special function unit
Shared memory load (no conflict)   ~20                32                              On-chip SRAM
Global memory load (L1 hit)        ~30                32                              Cache hit
Global memory load (L2 hit)        ~200               varies                          Cross-SM cache
Note: Warp shuffles are 20x faster than shared memory and 200x faster than uncached global memory. Use them wherever the communication pattern fits within a warp.

Common Mistakes

Mistake 1: Using Shuffle with Divergent Threads

// BAD: Some threads may be inactive
if (threadIdx.x < 16) {
    // Only 16 threads reach here, but mask says 32
    float val = __shfl_sync(0xffffffff, data, 0);  // UB!
}

// GOOD: Use correct mask
if (threadIdx.x < 16) {
    float val = __shfl_sync(0x0000ffff, data, 0);  // Only lower 16
}

Mistake 2: Assuming __shfl_down Returns 0 for Out-of-Range

// __shfl_down_sync returns the CALLING thread's value when
// (lane + delta) >= width, NOT zero

float val = __shfl_down_sync(0xffffffff, my_val, 16);
// Thread 31: val = my_val (NOT 0), because 31+16 >= 32

// If you need 0 for out-of-range:
int lane = threadIdx.x % 32;
float val = (lane + 16 < 32) ?
    __shfl_down_sync(0xffffffff, my_val, 16) : 0.0f;

Mistake 3: Reducing Across Block Without __syncthreads

// BAD: warp sums written to shared memory without sync
float warp_sum = warp_reduce_sum(val);
if (lane == 0) shared[warp_id] = warp_sum;
// Missing __syncthreads() here!
if (warp_id == 0) {
    float total = warp_reduce_sum(shared[lane]);  // Reads stale data
}
ℹ️ Series Navigation

This is Part 4 of the CUDA Kernel Engineering series. Part 5 covers tensor cores — WMMA, MMA, and WGMMA interfaces for matrix multiply at hardware speed.