A tiled matrix multiply without padding achieves 4 TFLOPS on an A100. Add one column of padding to the shared memory tile and throughput jumps to 12 TFLOPS — a 3x speedup from a single extra integer in the array dimension. The cause: shared memory is divided into 32 banks, and when multiple threads in a warp access the same bank, the hardware serializes the accesses. The padding trick shifts addresses such that consecutive threads hit different banks. This conflict-free access costs 3% extra memory but eliminates 67% of the memory stalls.
This post covers the physical bank organization, the exact rules for when bank conflicts occur, every technique for eliminating them (padding, swizzling, address permutation), the cp.async mechanism for overlapping shared memory loads with computation, and a complete tiled matrix multiply implementation that puts it all together.
All measurements target A100-80GB SXM (CC 8.0). CUDA 12.x.
Shared Memory Physical Organization
Banks and Bandwidth
Shared memory is divided into 32 equally-sized banks. Each bank is 4 bytes wide (one 32-bit word). Consecutive 4-byte words map to consecutive banks:
Address 0- 3 (word 0) -> Bank 0
Address 4- 7 (word 1) -> Bank 1
Address 8-11 (word 2) -> Bank 2
...
Address 124-127 (word 31) -> Bank 31
Address 128-131 (word 32) -> Bank 0 (wraps around)
Address 132-135 (word 33) -> Bank 1
...
The bank index for a 4-byte access at byte address `addr` is `bank = (addr / 4) % 32`.
Each bank can service one 4-byte read and one 4-byte write per cycle. When all 32 threads in a warp access 32 different banks, the entire warp's access completes in a single cycle. When N threads access different addresses in the same bank, those accesses serialize into N cycles.
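The mapping can be checked with a few lines of host-side arithmetic (plain C++, mirroring the address table above):

```cpp
#include <cassert>
#include <cstdint>

// Bank index for a 4-byte access: word index modulo the 32 banks.
constexpr int bank_of(uint32_t byte_addr) {
    return (byte_addr / 4) % 32;
}
```

For example, `bank_of(124)` is 31 (word 31) and `bank_of(128)` wraps back to bank 0 (word 32).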
Shared Memory Capacity
The shared memory and L1 data cache share a combined on-chip memory pool:
Shared Memory Capacity by Architecture
| Architecture | Combined L1+Smem | Max Shared Memory per SM | Max Shared Memory per Block |
|---|---|---|---|
| Volta (CC 7.0) | 128 KB | 96 KB | 96 KB |
| Turing (CC 7.5) | 96 KB | 64 KB | 64 KB |
| Ampere (CC 8.0) | 192 KB | 164 KB | 163 KB |
| Ada Lovelace (CC 8.9) | 128 KB | 100 KB | 99 KB |
| Hopper (CC 9.0) | 256 KB | 228 KB | 227 KB |
To use more than 48 KB of shared memory per block (the default limit), you must opt in:
__global__ void large_smem_kernel() {
extern __shared__ float smem[];
// Use up to the configured maximum
}
// Before launch:
cudaFuncSetAttribute(
large_smem_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
163 * 1024 // Request the 163 KB per-block maximum on Ampere
);
// Launch with dynamic shared memory
large_smem_kernel<<<grid, block, 163 * 1024>>>();
Bank Conflict Rules
No Conflict: 32 Threads, 32 Different Banks
__shared__ float data[256];
// Thread k reads data[k] -> Bank k (for k = 0..31)
// All 32 banks accessed, no conflicts
float val = data[threadIdx.x]; // Stride-1 access
No Conflict: Broadcast
If multiple threads read the exact same address, the hardware broadcasts the value in a single transaction. This is not a bank conflict.
__shared__ float data[256];
// All 32 threads read data[0] -> Bank 0, same address
// Hardware broadcasts: 1 cycle, no conflict
float val = data[0]; // Broadcast, NOT a bank conflict
2-Way Bank Conflict
__shared__ float data[256];
// Stride-2 access: thread k reads data[2*k]
// Thread 0 -> data[0] -> Bank 0
// Thread 1 -> data[2] -> Bank 2
// Thread 16 -> data[32] -> Bank 0 (conflict with thread 0!)
// Thread 17 -> data[34] -> Bank 2 (conflict with thread 1!)
// Result: 2-way bank conflict, 2 cycles instead of 1
float val = data[threadIdx.x * 2];
N-Way Bank Conflict
More precisely, if each thread accesses a word at index `threadIdx.x * stride`, the conflict degree equals the number of threads mapping to the same bank.
For stride `s`, threads `i` and `j` conflict if `(i * s) mod 32 = (j * s) mod 32`.
This simplifies to `(i - j) * s ≡ 0 (mod 32)`, which means `i - j` must be a multiple of `32 / gcd(s, 32)`. The number of threads per bank (conflict degree) is therefore `gcd(s, 32)`.
Bank Conflict Degree by Stride
| Stride (words) | Conflict Degree | Cycles per Access | Effective Bandwidth |
|---|---|---|---|
| 1 | 1 (no conflict) | 1 | 100% |
| 2 | 2-way | 2 | 50% |
| 3 | 1 (no conflict) | 1 | 100% |
| 4 | 4-way | 4 | 25% |
| 5 | 1 (no conflict) | 1 | 100% |
| 8 | 8-way | 8 | 12.5% |
| 16 | 16-way | 16 | 6.25% |
| 32 | 32-way (full) | 32 | 3.1% |
| 33 | 1 (no conflict) | 1 | 100% |
For stride 33, `gcd(33, 32) = 1`, so the conflict degree is 1. Thread `k` accesses word `33k`, which maps to bank `(33k) mod 32 = k mod 32`. Each thread hits a unique bank. This is the basis of the padding trick.
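The whole table can be reproduced by brute force: for each stride, count how many of the 32 threads land on the most-contended bank, and check that it matches `gcd(stride, 32)`. A host-side sketch:

```cpp
#include <algorithm>
#include <cassert>
#include <numeric> // std::gcd (C++17)

// Worst-case number of threads hitting a single bank when thread t
// accesses word t * stride.
int conflict_degree(int stride) {
    int count[32] = {0};
    int worst = 0;
    for (int t = 0; t < 32; ++t) {
        int bank = (t * stride) % 32;
        worst = std::max(worst, ++count[bank]);
    }
    return worst;
}
```

Running this for strides 1 through 64 reproduces the table: strides 2, 4, 8, 16, 32 give degrees 2, 4, 8, 16, 32, while every odd stride (1, 3, 5, 33, …) is conflict-free.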
32-Way Bank Conflict: The Worst Case
__shared__ float data[1024]; // 32 x 32
// Column access of a 32x32 matrix stored row-major
// stride = 32 words (one row = 32 floats)
// Thread k reads data[k * 32] -> Bank (k * 32) % 32 = 0 for all k
// ALL 32 threads hit Bank 0: 32-way conflict
float val = data[threadIdx.x * 32]; // 32 cycles instead of 1
The Padding Trick
The padding trick adds one extra element per row to change the stride from 32 to 33, eliminating all bank conflicts.
Before Padding
// 32x32 tile stored in shared memory
__shared__ float tile[32][32];
// Row access: stride-1, no conflicts
float row_val = tile[threadIdx.y][threadIdx.x]; // OK
// Column access: stride-32, 32-way conflict
float col_val = tile[threadIdx.x][threadIdx.y]; // BAD: 32-way conflict
// Thread 0 -> tile[0][j] -> Bank (0*32 + j) % 32 = j
// Thread 1 -> tile[1][j] -> Bank (1*32 + j) % 32 = j SAME BANK
After Padding
// Add 1 padding element per row: 33 elements per row instead of 32
__shared__ float tile[32][32 + 1]; // 32 x 33
// Row access: stride-1, no conflicts (unchanged)
float row_val = tile[threadIdx.y][threadIdx.x]; // OK
// Column access: stride-33, NO conflicts!
float col_val = tile[threadIdx.x][threadIdx.y]; // OK
// Thread 0 -> tile[0][j] -> word index = 0*33 + j -> Bank j
// Thread 1 -> tile[1][j] -> word index = 1*33 + j -> Bank (33+j) % 32 = (j+1)
// Thread 2 -> tile[2][j] -> word index = 2*33 + j -> Bank (66+j) % 32 = (j+2)
// All different banks!
The memory cost of padding is minimal: one extra float (4 bytes) per row, or 32 extra floats per 32x32 tile = 128 bytes overhead on a 4096-byte tile (3.1%).
Padding for Arbitrary Tile Sizes
For a tile of width `W` stored in shared memory:
- If `W` is odd (`gcd(W, 32) = 1`): no padding needed (already conflict-free for column access)
- If `W` is even: pad to `W + 1` (or any value that makes `gcd(W + pad, 32) = 1`)
// General rule: ensure row stride is odd (or coprime with 32)
#define TILE_W 64
#define PAD ((TILE_W % 2 == 0) ? 1 : 0) // Pad if even width
__shared__ float tile[TILE_H][TILE_W + PAD];
Swizzled Shared Memory Access
An alternative to padding is address swizzling: XOR-based permutation of the bank index that distributes accesses across banks without wasting memory.
// Swizzled store: XOR row index into column index
__shared__ float tile[32][32]; // No padding needed
// Store phase
int row = threadIdx.y;
int col = threadIdx.x;
int swizzled_col = col ^ row; // XOR row into column
tile[row][swizzled_col] = global_data[...];
// Load phase (column access): reverse the swizzle
int load_row = threadIdx.x; // transposed
int load_col = threadIdx.y; // transposed
int swizzled_load_col = load_col ^ load_row;
float val = tile[load_row][swizzled_load_col];
The XOR swizzle works because:
- For a fixed `col`, varying `row` from 0 to 31 maps to `col ^ 0`, `col ^ 1`, …, `col ^ 31` — 32 distinct values spanning all banks.
- No wasted memory (unlike padding, which wastes one word per row).
- Same number of operations (one XOR per address computation).
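The distinct-banks claim is easy to verify on the host. For the column (transposed) access pattern, thread `row` reads `tile[row][col ^ row]`, whose bank reduces to `col ^ row` because the row offset `row * 32` vanishes mod 32:

```cpp
#include <cassert>
#include <set>

// For a fixed logical column of an unpadded 32x32 tile, do the 32
// swizzled column-access reads hit 32 distinct banks?
bool column_access_conflict_free(int col) {
    std::set<int> banks;
    for (int row = 0; row < 32; ++row) {
        int word_index = row * 32 + (col ^ row); // tile[row][col ^ row]
        banks.insert(word_index % 32);
    }
    return banks.size() == 32;
}
```

The check passes for every column 0 through 31, so both row and column phases of the swizzled tile are conflict-free without any padding.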
Padding is simpler and works everywhere. Swizzling saves the 3% memory overhead and is used in high-performance libraries like CUTLASS. For most kernels, padding is sufficient. Use swizzling when shared memory capacity is the bottleneck and you cannot afford the padding overhead (e.g., very large tiles at the shared memory limit).
Asynchronous Copy: cp.async (Ampere+)
On Volta and earlier, loading data from global to shared memory required two steps:
// Pre-Ampere: global -> registers -> shared memory
float val = global_ptr[idx]; // Load from global to register
__syncthreads();
shared_tile[threadIdx.x] = val; // Store from register to shared
__syncthreads();
Ampere introduced cp.async, which copies directly from global to shared memory without staging in registers. This frees registers and enables overlapping the copy with computation.
#include <cuda_pipeline.h> // __pipeline_memcpy_async, __pipeline_commit, __pipeline_wait_prior
__global__ void cp_async_example(const float* global_data,
float* output, int n) {
__shared__ float smem[256];
// Stage 1: initiate async copy from global to shared
// Does NOT consume registers
__pipeline_memcpy_async(
&smem[threadIdx.x], // shared memory destination
&global_data[blockIdx.x * 256 + threadIdx.x], // global source
sizeof(float) // bytes to copy
);
// Commit the copy group
__pipeline_commit();
// Stage 2: do other work while copy is in flight
// ... computation using previously loaded data ...
// Stage 3: wait for the copy to complete
__pipeline_wait_prior(0);
__syncthreads();
// Stage 4: use the data in shared memory
output[blockIdx.x * 256 + threadIdx.x] = smem[threadIdx.x] * 2.0f;
}
Multi-Stage Pipeline with cp.async
The real power of cp.async is enabling multi-stage software pipelines where you overlap loading the next tile with computing on the current tile:
#define NUM_STAGES 3
#define TILE_SIZE 256
__global__ void pipelined_kernel(const float* input, float* output,
int n_tiles) {
__shared__ float smem[NUM_STAGES][TILE_SIZE];
int tid = threadIdx.x;
// Fill the pipeline: issue initial async copies
for (int stage = 0; stage < NUM_STAGES - 1; stage++) {
if (stage < n_tiles) {
__pipeline_memcpy_async(
&smem[stage][tid],
&input[stage * TILE_SIZE + tid],
sizeof(float)
);
}
__pipeline_commit(); // Commit even when empty so the group count stays fixed
}
// Main loop: compute on tile[i] while loading tile[i + NUM_STAGES - 1]
for (int i = 0; i < n_tiles; i++) {
int compute_stage = i % NUM_STAGES;
int load_stage = (i + NUM_STAGES - 1) % NUM_STAGES;
// Issue next async load (if within bounds); always commit, so that
// __pipeline_wait_prior below sees exactly one group per iteration
if (i + NUM_STAGES - 1 < n_tiles) {
__pipeline_memcpy_async(
&smem[load_stage][tid],
&input[(i + NUM_STAGES - 1) * TILE_SIZE + tid],
sizeof(float)
);
}
__pipeline_commit();
// Wait for the current compute tile to be ready
__pipeline_wait_prior(NUM_STAGES - 1);
__syncthreads();
// Compute on current tile
float val = smem[compute_stage][tid];
output[i * TILE_SIZE + tid] = val * val + 1.0f;
__syncthreads();
}
}
Hopper (CC 9.0) extends cp.async with cp.async.bulk through the Tensor Memory Accelerator (TMA). TMA can copy entire 2D or higher-dimensional tiles from global to shared memory in a single instruction, handling address computation, swizzling, and out-of-bounds clamping in hardware. This is the mechanism WGMMA uses (covered in Part 5 of this series).
Implementation: Tiled Matrix Multiply (GEMM)
Tiled GEMM is the canonical application of shared memory. The algorithm:
- Divide the output matrix `C` into tiles of size `TILE_M x TILE_N`
- For each tile of `C`, iterate over the `K` dimension in chunks of `TILE_K`
- Load the corresponding tiles of `A` and `B` into shared memory (coalesced reads)
- Compute partial products from shared memory (fast, on-chip)
- Accumulate into registers
- Write the final result to global memory (coalesced writes)
Why Shared Memory Matters for GEMM
For a naive GEMM computing `C = A * B` (`A` is `M x K`, `B` is `K x N`), each element of `A` and `B` is loaded from global memory once for every element of `C` it contributes to (`N` times for `A`, `M` times for `B`). With tiling, each element is loaded from global memory into shared memory once per tile it participates in, then reused `TILE` times from shared memory.
For `TILE = 32`, shared memory tiling reduces global traffic by 32x.
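A back-of-envelope count for M = N = K = 4096, ignoring caches: every FMA in the naive kernel reads one element of A and one of B from global memory, so read traffic is 2 * M * N * K * 4 bytes = 512 GiB, and 32-wide tiling cuts that to 16 GiB:

```cpp
#include <cassert>
#include <cstdint>

// Global-memory read traffic in bytes for C = A * B, ignoring caches.
int64_t naive_read_bytes(int64_t M, int64_t N, int64_t K) {
    return 2 * M * N * K * 4; // one A element + one B element per FMA
}

int64_t tiled_read_bytes(int64_t M, int64_t N, int64_t K, int64_t tile) {
    return naive_read_bytes(M, N, K) / tile; // each loaded element reused `tile` times
}
```

At A100's roughly 2 TB/s of HBM bandwidth, 512 GiB of traffic alone caps the naive kernel far below compute peak, which is why the naive row of the performance table later sits at about 1.6% of peak.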
Complete Implementation
#define TILE_M 32
#define TILE_N 32
#define TILE_K 32
#define PAD 1 // Padding to eliminate bank conflicts
// Tiled GEMM: C = A * B
// A: M x K, B: K x N, C: M x N (all row-major)
__global__ void gemm_tiled(const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K) {
// Shared memory tiles with padding
__shared__ float As[TILE_M][TILE_K + PAD];
__shared__ float Bs[TILE_K][TILE_N + PAD];
// Thread position within the tile
int tx = threadIdx.x; // 0..TILE_N-1 (column within tile)
int ty = threadIdx.y; // 0..TILE_M-1 (row within tile)
// Global position of the output element
int row = blockIdx.y * TILE_M + ty;
int col = blockIdx.x * TILE_N + tx;
// Accumulator in registers
float c_val = 0.0f;
// Loop over K dimension in tiles
int n_tiles = (K + TILE_K - 1) / TILE_K;
for (int t = 0; t < n_tiles; t++) {
// Load A tile: coalesced read (threads in a warp read
// consecutive columns of A)
int a_col = t * TILE_K + tx;
if (row < M && a_col < K) {
As[ty][tx] = A[row * K + a_col];
} else {
As[ty][tx] = 0.0f;
}
// Load B tile: coalesced read (threads in a warp read
// consecutive columns of B)
int b_row = t * TILE_K + ty;
if (b_row < K && col < N) {
Bs[ty][tx] = B[b_row * N + col];
} else {
Bs[ty][tx] = 0.0f;
}
__syncthreads();
// Compute: all reads from shared memory (fast, on-chip)
// As[ty][k]: all threads in a warp share ty -> same address, broadcast
// Bs[k][tx]: consecutive tx -> stride-1, no bank conflict
// (Neither access pattern causes serialization)
#pragma unroll
for (int k = 0; k < TILE_K; k++) {
c_val += As[ty][k] * Bs[k][tx];
}
__syncthreads();
}
// Write result: coalesced (consecutive threads write
// consecutive columns)
if (row < M && col < N) {
C[row * N + col] = c_val;
}
}
Launch Configuration
void launch_gemm(const float* A, const float* B, float* C,
int M, int N, int K) {
dim3 block(TILE_N, TILE_M); // 32 x 32 = 1024 threads
dim3 grid((N + TILE_N - 1) / TILE_N,
(M + TILE_M - 1) / TILE_M);
gemm_tiled<<<grid, block>>>(A, B, C, M, N, K);
}
A 32x32 thread block uses 1024 threads (32 warps). On A100, the maximum is 2048 threads per SM, so at most 2 such blocks can be resident. If register pressure is high, this drops to 1 block per SM (50% or worse occupancy). Consider using a smaller thread block (e.g., 16x16 = 256 threads) and having each thread compute multiple output elements.
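The occupancy arithmetic can be sketched on the host; the register counts below are assumptions for illustration (A100 has 2048 max resident threads and 65536 registers per SM):

```cpp
#include <algorithm>
#include <cassert>

// Resident blocks per SM, limited by the thread and register budgets.
// (Shared memory can impose a third limit, omitted here for brevity.)
int resident_blocks(int threads_per_block, int regs_per_thread) {
    const int max_threads_per_sm = 2048;  // A100 (CC 8.0)
    const int regs_per_sm = 65536;
    int by_threads = max_threads_per_sm / threads_per_block;
    int by_regs = regs_per_sm / (threads_per_block * regs_per_thread);
    return std::min(by_threads, by_regs);
}
```

With 1024-thread blocks, 32 registers per thread still allows 2 resident blocks, but 64 registers per thread drops this to 1; a 256-thread block at 64 registers per thread keeps 4 blocks resident.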
Optimized: Smaller Block, Multiple Elements per Thread
#define BM 64 // Block tile M
#define BN 64 // Block tile N
#define BK 16 // Block tile K
#define TM 4 // Thread tile M (each thread computes TM x TN output elements)
#define TN 4 // Thread tile N
// Block size: (BN/TN) x (BM/TM) = 16 x 16 = 256 threads
// Each thread computes a 4x4 sub-tile of C
__global__ void gemm_optimized(const float* __restrict__ A,
const float* __restrict__ B,
float* __restrict__ C,
int M, int N, int K) {
__shared__ float As[BK][BM + 1]; // Stored transposed so the compute loop reads As[k][m] stride-1
__shared__ float Bs[BK][BN + 1];
// Thread position
int tx = threadIdx.x; // 0..15 (column of thread tiles)
int ty = threadIdx.y; // 0..15 (row of thread tiles)
// Each thread accumulates a TM x TN tile in registers
float c_reg[TM][TN] = {};
int block_row = blockIdx.y * BM;
int block_col = blockIdx.x * BN;
for (int bk = 0; bk < K; bk += BK) {
// Collaborative load of A tile
// 256 threads load BK * BM = 16 * 64 = 1024 elements
// 1024 / 256 = 4 elements per thread
int linear_tid = ty * (BN / TN) + tx;
for (int i = 0; i < (BK * BM) / 256; i++) {
int idx = linear_tid + i * 256;
int k_idx = idx / BM;
int m_idx = idx % BM;
int global_m = block_row + m_idx;
int global_k = bk + k_idx;
As[k_idx][m_idx] = (global_m < M && global_k < K)
? A[global_m * K + global_k] : 0.0f;
}
// Collaborative load of B tile
for (int i = 0; i < (BK * BN) / 256; i++) {
int idx = linear_tid + i * 256;
int k_idx = idx / BN;
int n_idx = idx % BN;
int global_k = bk + k_idx;
int global_n = block_col + n_idx;
Bs[k_idx][n_idx] = (global_k < K && global_n < N)
? B[global_k * N + global_n] : 0.0f;
}
__syncthreads();
// Compute TM x TN output sub-tile per thread
#pragma unroll
for (int k = 0; k < BK; k++) {
// Load TM elements of A into registers
float a_reg[TM];
#pragma unroll
for (int m = 0; m < TM; m++) {
a_reg[m] = As[k][ty * TM + m];
}
// Load TN elements of B into registers
float b_reg[TN];
#pragma unroll
for (int n = 0; n < TN; n++) {
b_reg[n] = Bs[k][tx * TN + n];
}
// Outer product: TM * TN FMAs
#pragma unroll
for (int m = 0; m < TM; m++) {
#pragma unroll
for (int n = 0; n < TN; n++) {
c_reg[m][n] += a_reg[m] * b_reg[n];
}
}
}
__syncthreads();
}
// Write TM x TN results to global memory
for (int m = 0; m < TM; m++) {
for (int n = 0; n < TN; n++) {
int global_m = block_row + ty * TM + m;
int global_n = block_col + tx * TN + n;
if (global_m < M && global_n < N) {
C[global_m * N + global_n] = c_reg[m][n];
}
}
}
}
GEMM Performance Comparison
GEMM Performance (A100, M=N=K=4096, FP32)
| Implementation | GFLOPS | % of Peak (19.5 TFLOPS) | Notes |
|---|---|---|---|
| Naive (no tiling) | 310 | 1.6% | Global memory bound |
| Tiled 32x32 (TILE_K=32) | 2800 | 14.4% | Shared memory, high register pressure |
| Optimized 64x64, TM=TN=4 | 6200 | 31.8% | Register tiling, 256 threads/block |
| Optimized + cp.async pipeline | 7100 | 36.4% | Overlap load and compute |
| cuBLAS (reference) | 18500 | 94.9% | Fully optimized, tensor cores eligible |
Profiling Bank Conflicts with Nsight Compute
Nsight Compute directly reports bank conflict metrics:
ncu --metrics \
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum,\
l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st.sum,\
l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum,\
l1tex__data_pipe_lsu_wavefronts_mem_shared_op_st.sum \
./my_gemm
Key metrics:
- `l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum`: Total bank conflicts on shared memory loads. Should be 0 for a conflict-free kernel.
- `l1tex__data_pipe_lsu_wavefronts_mem_shared_op_ld.sum`: Total wavefronts (pipeline passes) for shared memory loads. More wavefronts per request indicates serialization.
Conflict Degree Calculation from Metrics
A conflict-free kernel has 0 bank conflicts. A fully-conflicted (32-way) kernel has 31 bank conflicts per access (32 wavefronts where 1 would suffice, so 31 extra).
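From those two counters an average conflict degree can be derived: every request ideally costs one wavefront and every bank conflict adds one more, so the request count is wavefronts minus conflicts. A sketch (exact counter semantics can vary slightly across architectures):

```cpp
#include <cassert>

// Average shared memory conflict degree from Nsight Compute sums.
double avg_conflict_degree(long long wavefronts, long long conflicts) {
    long long requests = wavefronts - conflicts; // one wavefront per request minimum
    return static_cast<double>(wavefronts) / static_cast<double>(requests);
}
```

For the fully-conflicted case (32 wavefronts, 31 conflicts per request) this gives 32; a conflict-free kernel gives exactly 1.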
Double Buffering: Overlapping Load and Compute
Double buffering uses two shared memory buffers: while the kernel computes on one, it loads the next tile into the other.
#define TILE 32
__global__ void gemm_double_buffer(const float* A, const float* B,
float* C, int M, int N, int K) {
__shared__ float As[2][TILE][TILE + 1];
__shared__ float Bs[2][TILE][TILE + 1];
int tx = threadIdx.x, ty = threadIdx.y;
int row = blockIdx.y * TILE + ty;
int col = blockIdx.x * TILE + tx;
float acc = 0.0f;
int n_tiles = (K + TILE - 1) / TILE;
int buf = 0;
// Preload first tile into buffer 0
int ak = tx, bk = ty;
As[0][ty][tx] = (row < M && ak < K) ? A[row * K + ak] : 0.0f;
Bs[0][ty][tx] = (bk < K && col < N) ? B[bk * N + col] : 0.0f;
__syncthreads();
for (int t = 0; t < n_tiles; t++) {
int next_buf = 1 - buf;
// Initiate load of NEXT tile into next_buf
if (t + 1 < n_tiles) {
int next_ak = (t + 1) * TILE + tx;
int next_bk = (t + 1) * TILE + ty;
As[next_buf][ty][tx] = (row < M && next_ak < K)
? A[row * K + next_ak] : 0.0f;
Bs[next_buf][ty][tx] = (next_bk < K && col < N)
? B[next_bk * N + col] : 0.0f;
}
// Compute on the current buffer. The loads above target next_buf, so
// they cannot race with these reads; the __syncthreads at the end of
// the loop separates iterations that reuse a buffer
#pragma unroll
for (int k = 0; k < TILE; k++) {
acc += As[buf][ty][k] * Bs[buf][k][tx];
}
buf = next_buf;
__syncthreads();
}
if (row < M && col < N) {
C[row * N + col] = acc;
}
}
Double buffering doubles shared memory usage. For 32x33 tiles (with padding), each buffer is 32 * 33 * 4 = 4224 bytes. Two buffers each for A and B: 4 * 4224 = 16896 bytes (about 16.5 KB). This is well within the Ampere limit of 164 KB. For larger tiles, check that total shared memory fits within the SM's capacity.
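The budget can be enforced at compile time; a minimal sketch of the arithmetic for the padded double-buffered tiles above:

```cpp
#include <cassert>

constexpr int TILE = 32;
// Two buffers each for As and Bs, each TILE x (TILE + 1) floats.
constexpr int kSmemBytes = 2 /*buffers*/ * 2 /*A and B*/
                         * TILE * (TILE + 1) * 4 /*sizeof(float)*/;
static_assert(kSmemBytes == 16896, "4 padded 32x33 tiles");
static_assert(kSmemBytes <= 164 * 1024, "fits the Ampere per-SM limit");
```

Scaling TILE up makes the check earn its keep: a failed `static_assert` at compile time is far cheaper than a launch failure or silent occupancy drop at run time.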
Common Mistakes with Shared Memory
Mistake 1: Missing __syncthreads
// BAD: Race condition
__shared__ float smem[256];
smem[threadIdx.x] = input[idx];
// Other threads may not have finished writing!
float val = smem[255 - threadIdx.x]; // Reads stale data
// GOOD: Synchronize before reading
smem[threadIdx.x] = input[idx];
__syncthreads(); // All writes complete before any read
float val = smem[255 - threadIdx.x];
Mistake 2: __syncthreads Inside Divergent Branch
// BAD: Deadlock potential if not all threads reach the sync
if (threadIdx.x < 16) {
smem[threadIdx.x] = compute();
__syncthreads(); // Not all threads in the block reach this
}
// GOOD: Keep sync outside the divergent branch
if (threadIdx.x < 16) {
smem[threadIdx.x] = compute();
}
__syncthreads(); // All threads participate
Mistake 3: Ignoring Padding When Memory is Tight
// If your tile just barely fits shared memory, adding padding
// might push it over the limit. Always check:
// 32 * 33 * 4 = 4224 bytes (padded) vs 32 * 32 * 4 = 4096 bytes
// For large tiles at the capacity boundary, consider swizzling instead
Summary
- Shared memory = 32 banks, 4 bytes wide. Access to the same bank from multiple threads in a warp serializes.
- Stride-1 access has no conflicts. Column access of a 32-wide array has 32-way conflicts (worst case).
- The padding trick (`[32][33]` instead of `[32][32]`) changes the column stride from 32 to 33, eliminating all conflicts.
- `cp.async` (Ampere+) copies global to shared without register staging. Enables multi-stage pipelines that overlap load and compute.
- Register tiling (each thread computes multiple output elements) is as important as shared memory tiling for GEMM performance.
- Profile with Nsight Compute bank conflict metrics to verify your kernel is conflict-free.
This is Part 3 of the CUDA Kernel Engineering series. Part 4 covers warp-level primitives — shuffle, vote, and match instructions that enable register-to-register communication without shared memory.