Part of Series Inference Optimization Timeline 57 of 60

Neural network training is 80-95% GEMM operations. Every fully-connected layer, every convolutional layer (via im2col), every attention projection: GEMM. If your GEMM is slow, your training is slow, period. By 2019, the choice was clear: use cuBLAS and accept its generic optimizations for arbitrary matrix sizes, or write custom kernels tuned for the specific shapes neural networks actually use. The tradeoff: cuBLAS gives you portability and correctness with zero effort. Custom kernels give you 1.5-3x speedup on the shapes that matter (batch sizes of 32-256, hidden dimensions of 512-8192) at the cost of maintenance hell. This post covers the BLAS landscape, the performance gaps that motivated custom kernels, and why CUTLASS templates now give you the best of both worlds.

GEMM in Neural Networks

The Fundamental Operation

GEMM operations appear throughout neural networks:

import numpy as np
import torch

def dense_layer_forward(input_tensor, weight_matrix, bias_vector):
    """
    Standard dense layer computation: Y = XW + b
    This is essentially a GEMM operation followed by bias addition
    """
    # GEMM operation: input_tensor @ weight_matrix
    output = torch.mm(input_tensor, weight_matrix)
    
    # Add bias
    if bias_vector is not None:
        output += bias_vector
    
    return output

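def convolution_as_gemm(input_tensor, weight_tensor, stride=1, padding=0):
    """
    Convolution can be expressed as GEMM through im2col transformation.
    input_tensor: (N, C_in, H, W); weight_tensor: (C_out, C_in, kH, kW)
    """
    n, _, h, w = input_tensor.shape
    c_out, _, kh, kw = weight_tensor.shape
    
    # Convert convolution to matrix multiplication via im2col:
    # input_col has shape (N, C_in * kH * kW, L), L = number of output positions
    input_col = im2col(input_tensor, (kh, kw), stride, padding)
    weight_row = weight_tensor.view(c_out, -1)  # (C_out, C_in * kH * kW)
    
    # GEMM operation, broadcast over the batch: result is (N, C_out, L)
    output_col = torch.matmul(weight_row, input_col)
    
    # Reshape the columns back to the convolution output shape
    h_out = (h + 2 * padding - kh) // stride + 1
    w_out = (w + 2 * padding - kw) // stride + 1
    return output_col.view(n, c_out, h_out, w_out)

def im2col(input_tensor, kernel_size, stride, padding):
    """
    Transform input for convolution-as-matrix-multiplication.
    torch.nn.functional.unfold extracts every sliding patch as a column.
    """
    return torch.nn.functional.unfold(
        input_tensor, kernel_size=kernel_size, stride=stride, padding=padding
    )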

# GEMM is fundamental: C = alpha * A * B + beta * C
def gemm_basic(alpha, A, B, beta, C):
    """
    Basic GEMM operation: C = alpha * A * B + beta * C
    """
    return alpha * torch.mm(A, B) + beta * C

GEMM Operations in Neural Networks

| Layer Type | Operation | GEMM Equivalence | FLOPs per Output Element |
|---|---|---|---|
| Dense/Linear | Y = XW + b | Direct (matrix mult) | 2 * input_size |
| Conv2D | Convolution | Via im2col transform | 2 * in_channels * kernel_size² |
| Attention | QK^T, AV | Multiple GEMMs | 2 * head_dim (QK^T), 2 * sequence_length (AV) |
| RNN | W_ih * x + W_hh * h | Multiple GEMMs | 2 * (input_size + hidden_size) |
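
To make the table concrete, here is a minimal sketch that maps a dense layer and an im2col convolution to their (M, N, K) GEMM shapes and counts FLOPs. The helper names and the layer sizes are hypothetical, chosen only for illustration, and each multiply-add is counted as 2 FLOPs.

```python
def gemm_flops(m: int, n: int, k: int) -> int:
    """FLOPs for C[m, n] = A[m, k] @ B[k, n], counting each multiply-add as 2."""
    return 2 * m * n * k


def dense_gemm_shape(batch: int, in_features: int, out_features: int):
    """(M, N, K) of the GEMM behind Y = XW + b."""
    return batch, out_features, in_features


def conv2d_gemm_shape(batch, c_in, c_out, h, w, kernel, stride=1, padding=0):
    """(M, N, K) of the GEMM produced by im2col for a square-kernel Conv2D."""
    h_out = (h + 2 * padding - kernel) // stride + 1
    w_out = (w + 2 * padding - kernel) // stride + 1
    # One GEMM row per output position, one column per output channel
    return batch * h_out * w_out, c_out, c_in * kernel * kernel


# Hypothetical layer sizes, for illustration only
dense = dense_gemm_shape(batch=64, in_features=1024, out_features=4096)
conv = conv2d_gemm_shape(batch=8, c_in=64, c_out=128, h=56, w=56, kernel=3, padding=1)

print("dense (M, N, K):", dense, "FLOPs:", gemm_flops(*dense))
print("conv  (M, N, K):", conv, "FLOPs:", gemm_flops(*conv))
```

Note how the convolution's K dimension is in_channels * kernel_size², which is why the per-output-element cost scales with both.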

BLAS Libraries for GEMM

Industry-Standard Libraries

// Example of using BLAS for GEMM operations
#include <stdlib.h>  // malloc, free
extern "C" {
    #include <cblas.h>
}

void blas_gemm_example() {
    const int M = 1024;  // Rows of A and C
    const int N = 512;   // Columns of B and C
    const int K = 768;   // Columns of A, Rows of B
    
    // Allocate matrices
    float *A = (float*)malloc(M * K * sizeof(float));
    float *B = (float*)malloc(K * N * sizeof(float));
    float *C = (float*)malloc(M * N * sizeof(float));
    
    // Initialize matrices
    // ... initialization code ...
    
    // Perform GEMM: C = alpha * A * B + beta * C
    const float alpha = 1.0f;
    const float beta = 0.0f;
    
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                M, N, K,
                alpha,
                A, K,    // lda = K for row-major
                B, N,    // ldb = N for row-major
                beta,
                C, N);   // ldc = N for row-major
    
    free(A); free(B); free(C);
}

import time

import numpy as np
import torch

def compare_blas_implementations():
    """
    Compare GPU GEMM (cuBLAS via PyTorch) with CPU GEMM (NumPy's BLAS backend)
    """
    # Matrix dimensions
    M, N, K = 1024, 1024, 1024
    
    # Create random matrices
    A = torch.randn(M, K).cuda()
    B = torch.randn(K, N).cuda()
    
    # Warm up so CUDA context creation and cuBLAS initialization are not timed
    torch.mm(A, B)
    torch.cuda.synchronize()
    
    # PyTorch uses optimized BLAS under the hood (MKL, cuBLAS, etc.)
    start_time = time.perf_counter()
    C_pytorch = torch.mm(A, B)
    torch.cuda.synchronize()  # Ensure GPU computation completes
    pytorch_time = time.perf_counter() - start_time
    
    # Using NumPy (typically linked to MKL, OpenBLAS, etc.)
    A_np = A.cpu().numpy()
    B_np = B.cpu().numpy()
    np.dot(A_np, B_np)  # Warm up the CPU BLAS threads as well
    start_time = time.perf_counter()
    C_numpy = np.dot(A_np, B_np)
    numpy_time = time.perf_counter() - start_time
    
    return {
        'pytorch_cublas': pytorch_time,
        'numpy_blas': numpy_time,
        'size_gflops': (M * N * K * 2) / 1e9,  # 2 ops per multiply-add
        'pytorch_gflops': (M * N * K * 2) / (pytorch_time * 1e9),
        'numpy_gflops': (M * N * K * 2) / (numpy_time * 1e9)
    }

# November 2019 BLAS landscape
blas_implementations = {
    'Intel MKL': {
        'vendor': 'Intel',
        'target': 'x86_64 CPUs',
        'optimization': 'AVX-512, threading',
        'typical_performance': '100-300 GFLOPS on modern CPUs'
    },
    'OpenBLAS': {
        'vendor': 'Community',
        'target': 'Various CPUs',
        'optimization': 'Architecture-specific, threading',
        'typical_performance': '80-250 GFLOPS'
    },
    'cuBLAS': {
        'vendor': 'NVIDIA',
        'target': 'NVIDIA GPUs',
        'optimization': 'Tensor Cores, memory coalescing',
        'typical_performance': '5-100 TFLOPS on modern GPUs'
    },
    'clBLAS': {
        'vendor': 'AMD (open source)',
        'target': 'OpenCL devices',
        'optimization': 'Heterogeneous computing',
        'typical_performance': 'Variable'
    }
}

BLAS Library Performance Comparison (Nov 2019)

| Library | Platform | Peak Performance | Efficiency | Optimization Level |
|---|---|---|---|---|
| Intel MKL | Skylake-X | 2.5 TFLOPS | 90% | High |
| OpenBLAS | Skylake-X | 2.2 TFLOPS | 80% | Medium |
| cuBLAS | V100 | 125 TFLOPS (FP16 Tensor Core) | 95% | Very High |
| cuBLAS | T4 | 65 TFLOPS (FP16 Tensor Core) | 85% | High |
| clBLAS | AMD Vega | 18 TFLOPS | 70% | Medium |
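
The efficiency column is achieved throughput divided by theoretical peak. As a rough sketch of where a CPU peak comes from, the snippet below multiplies cores, clock, and FLOPs per cycle. The core count, clock, and achieved figure are hypothetical; the 64 FLOPs per cycle assumes two AVX-512 FMA units per core (2 units x 16 FP32 lanes x 2 FLOPs per FMA), and real clocks usually drop under sustained AVX-512 load.

```python
def peak_gflops(cores: int, clock_ghz: float, flops_per_cycle: int) -> float:
    """Theoretical peak in GFLOPS: cores x clock (GHz) x FLOPs per cycle per core."""
    return cores * clock_ghz * flops_per_cycle


def efficiency(achieved_gflops: float, peak: float) -> float:
    """Fraction of theoretical peak actually achieved."""
    return achieved_gflops / peak


# Assumption: two AVX-512 FMA units per core, FP32
AVX512_FP32_FLOPS_PER_CYCLE = 2 * 16 * 2

# Hypothetical 10-core CPU at a sustained 3.0 GHz
peak = peak_gflops(cores=10, clock_ghz=3.0, flops_per_cycle=AVX512_FP32_FLOPS_PER_CYCLE)
eff = efficiency(achieved_gflops=1728.0, peak=peak)

print(f"peak = {peak:.0f} GFLOPS, efficiency = {eff:.0%}")
```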

Custom GEMM Kernels

Why Custom Kernels?

While BLAS libraries are highly optimized, custom kernels can provide benefits for specific neural network patterns:

// Custom GEMM kernel optimized for neural network workloads
__global__ void custom_gemm_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B, 
    float* __restrict__ C,
    const int M, const int N, const int K,
    const float alpha, const float beta) {
    
    // Tile size for shared-memory tiling (one thread per output element)
    const int TILE_SIZE = 16;
    
    // Thread block indices
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    
    // Shared memory for tiling
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
    
    // Initialize accumulator
    float acc = 0.0f;
    
    // Loop over tiles
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; ++t) {
        // Load tile of A into shared memory
        int a_row = by * TILE_SIZE + ty;
        int a_col = t * TILE_SIZE + tx;
        As[ty][tx] = (a_row < M && a_col < K) ? A[a_row * K + a_col] : 0.0f;
        
        // Load tile of B into shared memory
        int b_row = t * TILE_SIZE + ty;
        int b_col = bx * TILE_SIZE + tx;
        Bs[ty][tx] = (b_row < K && b_col < N) ? B[b_row * N + b_col] : 0.0f;
        
        // Synchronize to ensure loading is complete
        __syncthreads();
        
        // Compute partial dot product
        for (int k = 0; k < TILE_SIZE; ++k) {
            acc += As[ty][k] * Bs[k][tx];
        }
        
        // Synchronize before next tile
        __syncthreads();
    }
    
    // Write result
    int c_row = by * TILE_SIZE + ty;
    int c_col = bx * TILE_SIZE + tx;
    if (c_row < M && c_col < N) {
        C[c_row * N + c_col] = alpha * acc + beta * C[c_row * N + c_col];
    }
}

// Host function to launch custom kernel
void launch_custom_gemm(const float* A, const float* B, float* C, 
                       int M, int N, int K, float alpha, float beta) {
    const int TILE_SIZE = 16;
    dim3 block_size(TILE_SIZE, TILE_SIZE);
    dim3 grid_size((N + TILE_SIZE - 1) / TILE_SIZE, 
                   (M + TILE_SIZE - 1) / TILE_SIZE);
    
    custom_gemm_kernel<<<grid_size, block_size>>>(
        A, B, C, M, N, K, alpha, beta
    );
    cudaDeviceSynchronize();
}
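
The tiling logic of the kernel above can be checked on the CPU. This is a NumPy sketch of the same loop structure, not a performance implementation: the two outer loops play the role of the thread-block grid, and the slices stand in for the shared-memory tiles. The function name and matrix sizes are illustrative; the sizes are deliberately not multiples of the tile to exercise the edge handling.

```python
import numpy as np


def tiled_gemm(A: np.ndarray, B: np.ndarray, tile: int = 16) -> np.ndarray:
    """Blocked C = A @ B that mirrors the CUDA kernel's tile loop."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.zeros((M, N), dtype=np.float32)

    for i0 in range(0, M, tile):          # plays the role of blockIdx.y
        for j0 in range(0, N, tile):      # plays the role of blockIdx.x
            rows = min(tile, M - i0)
            cols = min(tile, N - j0)
            acc = np.zeros((rows, cols), dtype=np.float32)
            for k0 in range(0, K, tile):  # loop over tiles along K
                As = A[i0:i0 + tile, k0:k0 + tile]  # "shared memory" tile of A
                Bs = B[k0:k0 + tile, j0:j0 + tile]  # "shared memory" tile of B
                acc += As @ Bs                      # partial products for this tile
            C[i0:i0 + tile, j0:j0 + tile] = acc
    return C


rng = np.random.default_rng(0)
A = rng.standard_normal((50, 70)).astype(np.float32)
B = rng.standard_normal((70, 33)).astype(np.float32)
C = tiled_gemm(A, B)
max_err = float(np.max(np.abs(C - A @ B)))
print("max abs difference vs A @ B:", max_err)
```

Each tile of A and B is read once per output tile instead of once per output element, which is the data reuse the shared-memory version exploits.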

Optimized Memory Access Patterns

// Register-blocked GEMM: each thread computes a TM x TN micro-tile of C.
// Launch with a 2D grid of ceil(N / (blockDim.x * TN)) by
// ceil(M / (blockDim.y * TM)) thread blocks.
__global__ void optimized_gemm_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int M, const int N, const int K) {
    
    constexpr int TM = 4;  // Rows of C computed per thread
    constexpr int TN = 4;  // Columns of C computed per thread
    
    // Top-left corner of this thread's micro-tile
    const int row_start = (blockIdx.y * blockDim.y + threadIdx.y) * TM;
    const int col_start = (blockIdx.x * blockDim.x + threadIdx.x) * TN;
    
    // Register blocking for computation
    float reg_A[TM];
    float reg_B[TN];
    float acc[TM][TN] = {};
    
    // Walk the full K dimension, one rank-1 update per step
    for (int k = 0; k < K; ++k) {
        // Load a column strip of A
        #pragma unroll
        for (int i = 0; i < TM; ++i) {
            int row = row_start + i;
            reg_A[i] = (row < M) ? A[row * K + k] : 0.0f;
        }
        
        // Load a row strip of B
        #pragma unroll
        for (int j = 0; j < TN; ++j) {
            int col = col_start + j;
            reg_B[j] = (col < N) ? B[k * N + col] : 0.0f;
        }
        
        // Outer product: TM * TN multiply-adds from TM + TN loads
        #pragma unroll
        for (int i = 0; i < TM; ++i) {
            #pragma unroll
            for (int j = 0; j < TN; ++j) {
                acc[i][j] += reg_A[i] * reg_B[j];
            }
        }
    }
    
    // Write results
    #pragma unroll
    for (int i = 0; i < TM; ++i) {
        int row = row_start + i;
        #pragma unroll
        for (int j = 0; j < TN; ++j) {
            int col = col_start + j;
            if (row < M && col < N) {
                C[row * N + col] = acc[i][j];
            }
        }
    }
}
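
Why register blocking helps can be seen with a little arithmetic. In a register-blocked kernel, each step along K has a thread load one strip of A and one strip of B and then perform an outer product. The sketch below, with hypothetical micro-tile sizes, shows how the loads-per-FMA ratio falls as the micro-tile grows, while the accumulator register count rises.

```python
def loads_per_fma(tm: int, tn: int) -> float:
    """Per k step a thread loads tm values of A and tn values of B,
    then performs tm * tn fused multiply-adds."""
    return (tm + tn) / (tm * tn)


def accumulator_registers(tm: int, tn: int) -> int:
    """One FP32 register per element of the micro-tile."""
    return tm * tn


for tm, tn in [(1, 1), (2, 2), (4, 4), (8, 8)]:
    print(f"{tm}x{tn}: {loads_per_fma(tm, tn):.2f} loads per FMA, "
          f"{accumulator_registers(tm, tn)} accumulator registers")
```

Larger micro-tiles reduce memory pressure per FMA but consume registers, which limits how many threads can be resident at once, so the best size is a balance rather than a maximum.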

[Figure: Custom vs BLAS GEMM Performance (bar chart, GFLOPS)]

Performance Analysis and Comparison

Benchmarking Methodology

import time
import torch
import numpy as np

def benchmark_gemm_implementations():
    """
    Comprehensive benchmark of GEMM implementations
    """
    # Test different matrix sizes typical in neural networks
    test_sizes = [
        (512, 512, 512),    # Small layer
        (1024, 1024, 1024), # Medium layer
        (2048, 2048, 2048), # Large layer
        (4096, 512, 4096),  # Tall output, M > N (attention)
        (512, 4096, 4096),  # Wide output, N > M (projection)
    ]
    
    results = {}
    
    for M, N, K in test_sizes:
        print(f"Benchmarking size: ({M}, {N}, {K})")
        
        # Create random matrices
        A = torch.randn(M, K, device='cuda', dtype=torch.float32)
        B = torch.randn(K, N, device='cuda', dtype=torch.float32)
        
        # Warm up GPU
        for _ in range(5):
            _ = torch.mm(A, B)
        torch.cuda.synchronize()
        
        # Benchmark PyTorch/cuBLAS
        times = []
        for _ in range(10):
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            
            start.record()
            C_blas = torch.mm(A, B)
            end.record()
            
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))
        
        blas_avg_time = np.mean(times[2:])  # Skip first few for warmup
        blas_gflops = (2.0 * M * N * K) / (blas_avg_time * 1e6)  # Convert to GFLOPS
        
        # Store results
        size_key = f"{M}x{N}x{K}"
        results[size_key] = {
            'blas_time_ms': blas_avg_time,
            'blas_gflops': blas_gflops,
            'flops_calculated': 2 * M * N * K  # Multiply-adds
        }
    
    return results

def analyze_memory_bandwidth_requirements(matrix_size):
    """
    Analyze memory bandwidth requirements for GEMM
    """
    M, N, K = matrix_size
    
    # Memory operations for GEMM: A(M*K) + B(K*N) + C(M*N) 
    total_memory_bytes = (M * K + K * N + M * N) * 4  # 4 bytes per float32
    arithmetic_intensity = (2 * M * N * K) / total_memory_bytes  # FLOPs per byte
    
    # Bandwidth (GB/s) needed to stream every operand once while computing at peak:
    # bytes / (flops / peak) = peak / arithmetic_intensity
    peak_gflops = 100  # Example peak performance
    required_bandwidth_gbps = peak_gflops / arithmetic_intensity
    
    return {
        'memory_bytes': total_memory_bytes,
        'arithmetic_intensity': arithmetic_intensity,
        'required_bandwidth_gbps': required_bandwidth_gbps,
        # Crude heuristic: the real threshold is the machine balance
        # (peak FLOP/s divided by peak bytes/s), which is well above 1 on GPUs
        'is_memory_bound': arithmetic_intensity < 1.0
    }

GEMM Performance by Matrix Size (Nov 2019)

| Matrix Size (MxNxK) | cuBLAS GFLOPS | Custom Kernel GFLOPS | Efficiency (custom / cuBLAS) | Memory Bound |
|---|---|---|---|---|
| 512x512x512 | 850 | 780 | 92% | No |
| 1024x1024x1024 | 870 | 820 | 94% | No |
| 2048x2048x2048 | 880 | 850 | 97% | No |
| 4096x512x4096 | 820 | 750 | 91% | Yes |
| 512x4096x4096 | 780 | 720 | 92% | Yes |

Specialized Optimizations

Tensor Core Utilization (NVIDIA GPUs)

// Example of using NVIDIA Tensor Cores for GEMM through the WMMA API.
// Assumes row-major A (M x K) and B (K x N), M, N, K multiples of 16,
// and blockDim.x a multiple of warpSize so each warp owns one 16x16 tile of C.
#include <cuda_fp16.h>
#include <mma.h>

using namespace nvcuda;

__global__ void tensor_core_gemm(
    const half* __restrict__ A,
    const half* __restrict__ B,
    float* __restrict__ C,
    const int M, const int N, const int K) {
    
    // Tensor Core operations use 16x16x16 tiles
    const int BLOCK_M = 16;
    const int BLOCK_N = 16; 
    const int BLOCK_K = 16;
    
    // Warp-level matrix fragments (both inputs are read as row-major)
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> frag_b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> frag_c;
    
    // Tile position of this warp: all 32 threads of a warp share these indices
    int warp_m = blockIdx.y * blockDim.y + threadIdx.y;
    int warp_n = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    
    // Bounds checking (uniform across the warp)
    if (warp_m * BLOCK_M >= M || warp_n * BLOCK_N >= N) return;
    
    // Initialize accumulator to zero
    wmma::fill_fragment(frag_c, 0.0f);
    
    // Loop over K dimension
    for (int k = 0; k < K; k += BLOCK_K) {
        // Load A fragment; the third argument is the leading dimension
        wmma::load_matrix_sync(frag_a, &A[(warp_m * BLOCK_M) * K + k], K);
        
        // Load B fragment
        wmma::load_matrix_sync(frag_b, &B[k * N + warp_n * BLOCK_N], N);
        
        // Matrix multiply-accumulate
        wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);
    }
    
    // Store result
    wmma::store_matrix_sync(&C[warp_m * BLOCK_M * N + warp_n * BLOCK_N],
                           frag_c,
                           N, wmma::mem_row_major);
}
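
The kernel above takes FP16 inputs but accumulates in FP32. A small NumPy experiment suggests why that matters numerically. This is a CPU emulation under stated assumptions, not a Tensor Core measurement: the matrix sizes are arbitrary, and the "pure FP16" path rounds the accumulator to FP16 after every rank-1 update.

```python
import numpy as np

rng = np.random.default_rng(0)
M, N, K = 64, 64, 1024
A = rng.standard_normal((M, K))
B = rng.standard_normal((K, N))
reference = A @ B  # float64 reference

A16 = A.astype(np.float16)
B16 = B.astype(np.float16)

# Tensor-Core-style: FP16 inputs, FP32 accumulation
mixed = A16.astype(np.float32) @ B16.astype(np.float32)

# Pure FP16: the accumulator is rounded to FP16 after every update
pure = np.zeros((M, N), dtype=np.float16)
for k in range(K):
    pure += np.outer(A16[:, k], B16[k, :])

mixed_err = float(np.mean(np.abs(mixed - reference)))
pure_err = float(np.mean(np.abs(pure.astype(np.float64) - reference)))

print(f"mean abs error, FP32 accumulate: {mixed_err:.5f}")
print(f"mean abs error, FP16 accumulate: {pure_err:.5f}")
```

The FP32-accumulate error comes only from rounding the inputs, while the FP16 accumulator adds a rounding error at every step along K.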

Quantized GEMM Operations

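import torch

def quantized_gemm(A_int8, B_int8, A_scale, B_scale, A_zero_point, B_zero_point):
    """
    Quantized GEMM for neural networks (INT8 operations).
    Real values are modeled as scale * (q - zero_point).
    """
    K = A_int8.shape[1]
    
    # Integer GEMM, emulated in float32 here; a real kernel accumulates in INT32
    A_f, B_f = A_int8.float(), B_int8.float()
    C_int32 = torch.mm(A_f, B_f)
    
    # Row sums of A and column sums of B for the zero-point correction terms
    A_sum = A_f.sum(dim=1, keepdim=True)  # Shape (M, 1)
    B_sum = B_f.sum(dim=0, keepdim=True)  # Shape (1, N)
    
    # Output scale is the product of the input scales
    C_scale = A_scale * B_scale
    
    # Expand sum_k (a - zA)(b - zB) = sum(ab) - zB*sum(a) - zA*sum(b) + K*zA*zB
    C_float = C_scale * (C_int32 - B_zero_point * A_sum - A_zero_point * B_sum
                         + A_zero_point * B_zero_point * K)
    
    return C_float

class QuantizedGEMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, A_scale, B_scale, A_zero_point, B_zero_point):
        # Quantize inputs
        A_quant = ((A / A_scale) + A_zero_point).round().clamp(-128, 127).char()
        B_quant = ((B / B_scale) + B_zero_point).round().clamp(-128, 127).char()
        
        # Perform quantized GEMM on zero-point-corrected values
        C_quant = torch.mm(A_quant.float() - A_zero_point,
                           B_quant.float() - B_zero_point)
        
        # Save the unquantized inputs for the backward pass
        ctx.save_for_backward(A, B)
        
        return C_quant * A_scale * B_scale
    
    @staticmethod
    def backward(ctx, grad_output):
        # Simplified backward pass: a straight-through estimator that treats
        # quantization as the identity. Actual implementations would involve
        # more complex quantization-aware gradients.
        A, B = ctx.saved_tensors
        grad_A = torch.mm(grad_output, B.t())
        grad_B = torch.mm(A.t(), grad_output)
        # One gradient per forward input; scales and zero points get None
        return grad_A, grad_B, None, None, None, None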
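
The zero-point correction is easy to get wrong, so it helps to verify the algebra on small matrices. The NumPy sketch below assumes the usual affine scheme, real = scale * (q - zero_point); the scales, zero points, and sizes are made up for illustration. It checks that an integer GEMM plus correction terms reproduces the dequantize-then-multiply result.

```python
import numpy as np

rng = np.random.default_rng(0)
M, K, N = 4, 8, 5

# Hypothetical quantized operands and quantization parameters
qA = rng.integers(-128, 128, size=(M, K)).astype(np.int8)
qB = rng.integers(-128, 128, size=(K, N)).astype(np.int8)
zA, zB = 3, -5          # zero points
sA, sB = 0.02, 0.05     # scales

# Reference: dequantize first, then do a float GEMM
C_ref = (sA * (qA.astype(np.float64) - zA)) @ (sB * (qB.astype(np.float64) - zB))

# Integer path: INT8 inputs widened to INT32 for accumulation
qA32 = qA.astype(np.int32)
qB32 = qB.astype(np.int32)
acc = qA32 @ qB32
row_sum_A = qA32.sum(axis=1, keepdims=True)   # shape (M, 1)
col_sum_B = qB32.sum(axis=0, keepdims=True)   # shape (1, N)

# sum_k (a - zA)(b - zB) = sum(ab) - zB*sum(a) - zA*sum(b) + K*zA*zB
corrected = acc - zB * row_sum_A - zA * col_sum_B + K * zA * zB
C_quant = sA * sB * corrected

print("max abs difference:", float(np.max(np.abs(C_quant - C_ref))))
```

The row and column sums cost O(MK + KN) and can be precomputed for weights, so the correction adds little on top of the O(MNK) integer GEMM.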

[Figure: Quantized vs Float GEMM Performance (bar chart, GFLOPS)]

Hardware-Specific Optimizations

CPU Optimizations

// Optimized GEMM for CPU with vectorization
#include <immintrin.h>

void cpu_optimized_gemm(float* A, float* B, float* C, int M, int N, int K) {
    // Use AVX/FMA instructions for vectorized computation
    const int UNROLL_FACTOR = 8;
    
    for (int i = 0; i < M; ++i) {
        for (int j = 0; j + UNROLL_FACTOR <= N; j += UNROLL_FACTOR) {  // Full 8-wide vectors only
            // Vectorized accumulation using AVX registers
            __m256 acc = _mm256_setzero_ps();
            
            for (int k = 0; k < K; ++k) {
                __m256 a_val = _mm256_broadcast_ss(&A[i * K + k]);
                __m256 b_vals = _mm256_loadu_ps(&B[k * N + j]);
                acc = _mm256_fmadd_ps(a_val, b_vals, acc);
            }
            
            // Store results
            _mm256_storeu_ps(&C[i * N + j], acc);
        }
        
        // Handle remaining elements
        for (int j = (N / UNROLL_FACTOR) * UNROLL_FACTOR; j < N; ++j) {
            float sum = 0.0f;
            for (int k = 0; k < K; ++k) {
                sum += A[i * K + k] * B[k * N + j];
            }
            C[i * N + j] = sum;
        }
    }
}

// Cache-optimized GEMM
#include <algorithm>  // std::min
using std::min;

void cache_optimized_gemm(float* A, float* B, float* C, int M, int N, int K) {
    const int MC = 256;  // Panel of A
    const int NC = 128;  // Panel of B
    const int KC = 128;  // Inner dimension panel
    
    for (int mc = 0; mc < M; mc += MC) {
        int mc_size = min(M - mc, MC);
        
        for (int nc = 0; nc < N; nc += NC) {
            int nc_size = min(N - nc, NC);
            
            // Initialize C panel
            for (int i = mc; i < mc + mc_size; ++i) {
                for (int j = nc; j < nc + nc_size; ++j) {
                    C[i * N + j] = 0.0f;
                }
            }
            
            for (int kc = 0; kc < K; kc += KC) {
                int kc_size = min(K - kc, KC);
                
                // Compute panel: A(mc:mc+MC, kc:kc+KC) * B(kc:kc+KC, nc:nc+NC)
                for (int i = mc; i < mc + mc_size; ++i) {
                    for (int j = nc; j < nc + nc_size; ++j) {
                        float sum = 0.0f;
                        for (int k = kc; k < kc + kc_size; ++k) {
                            sum += A[i * K + k] * B[k * N + j];
                        }
                        C[i * N + j] += sum;
                    }
                }
            }
        }
    }
}
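
Panel sizes such as MC, NC, and KC are chosen so the working set stays in cache. A quick way to sanity-check a choice is to add up the panel footprints, as in this sketch. The 1 MiB cache size is an assumption for illustration (Skylake-X has a private L2 of that size per core); substitute the cache level you are blocking for.

```python
def panel_footprint_bytes(mc: int, nc: int, kc: int, dtype_bytes: int = 4) -> dict:
    """Bytes occupied by the A, B, and C panels touched in the inner loops."""
    return {
        "A_panel": mc * kc * dtype_bytes,
        "B_panel": kc * nc * dtype_bytes,
        "C_panel": mc * nc * dtype_bytes,
    }


# Panel sizes from cache_optimized_gemm above
footprint = panel_footprint_bytes(mc=256, nc=128, kc=128)
total = sum(footprint.values())

CACHE_BYTES = 1024 * 1024  # Assumed cache capacity: 1 MiB
fits = total <= CACHE_BYTES

for name, size in footprint.items():
    print(f"{name}: {size // 1024} KiB")
print(f"total: {total // 1024} KiB, fits in assumed cache: {fits}")
```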

GPU Memory Optimization

// GEMM kernel with one thread per output element and a 4-way unrolled K loop
__global__ void coalesced_gemm_kernel(
    const float* __restrict__ A,
    const float* __restrict__ B,
    float* __restrict__ C,
    const int M, const int N, const int K) {
    
    // Adjacent threads (threadIdx.x) map to adjacent columns, so reads of B
    // and writes of C are coalesced; reads of A are broadcast within a row
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (row < M && col < N) {
        float sum = 0.0f;
        
        // Unroll the K loop by 4 to expose independent loads
        for (int k = 0; k < K; k += 4) {
            // The float4 values are assembled from scalar loads; true 128-bit
            // vector loads would need aligned access through a float4 pointer
            float4 a_vec, b_vec;
            
            if (k + 3 < K) {
                // Load 4 consecutive A elements (same row)
                a_vec = make_float4(
                    A[row * K + k],
                    A[row * K + k + 1], 
                    A[row * K + k + 2],
                    A[row * K + k + 3]
                );
                
                // Load 4 B elements from same column but different rows
                b_vec = make_float4(
                    B[(k + 0) * N + col],
                    B[(k + 1) * N + col],
                    B[(k + 2) * N + col], 
                    B[(k + 3) * N + col]
                );
                
                sum += a_vec.x * b_vec.x + 
                       a_vec.y * b_vec.y + 
                       a_vec.z * b_vec.z + 
                       a_vec.w * b_vec.w;
            } else {
                // Handle remaining elements
                for (int kk = k; kk < K; ++kk) {
                    sum += A[row * K + kk] * B[kk * N + col];
                }
            }
        }
        
        C[row * N + col] = sum;
    }
}

Hardware-Specific GEMM Optimizations

| Platform | Baseline | Optimized | Improvement | Optimization Type |
|---|---|---|---|---|
| V100 GPU | 500 GFLOPS | 900 GFLOPS | 80% | Tensor Cores |
| T4 GPU | 250 GFLOPS | 450 GFLOPS | 80% | INT8 Quantization |
| Skylake CPU | 50 GFLOPS | 180 GFLOPS | 260% | AVX-512 |
| ARM CPU | 15 GFLOPS | 45 GFLOPS | 200% | NEON SIMD |
| TPU | 15 TFLOPS | 15 TFLOPS | 0% | Specialized |

Performance Bottleneck Analysis

Identifying Performance Limits

def analyze_gemm_bottlenecks(M, N, K):
    """
    Analyze potential bottlenecks in GEMM operations
    """
    # Calculate arithmetic intensity
    flops = 2 * M * N * K  # Multiply-add operations
    bytes_loaded = (M * K + K * N + M * N) * 4  # 4 bytes per float32
    arithmetic_intensity = flops / bytes_loaded  # FLOPs per byte
    
    # Theoretical peak performance (example values)
    peak_flops_gpu = 10e12  # 10 TFLOPS (V100 example)
    peak_bandwidth_gpu = 900e9  # 900 GB/s (V100 example)
    
    # Compute bounds
    compute_bound_gflops = peak_flops_gpu / 1e9
    memory_bound_gflops = peak_bandwidth_gpu * arithmetic_intensity / 1e9
    
    bottleneck = "compute" if compute_bound_gflops < memory_bound_gflops else "memory"
    
    return {
        'arithmetic_intensity': arithmetic_intensity,
        'compute_bound_gflops': compute_bound_gflops,
        'memory_bound_gflops': memory_bound_gflops,
        'predicted_performance': min(compute_bound_gflops, memory_bound_gflops),
        'bottleneck': bottleneck,
        'optimization_priority': "memory" if bottleneck == "memory" else "compute"
    }

def profile_gemm_performance():
    """
    Profile GEMM performance to identify bottlenecks
    """
    import torch
    import torch.profiler as profiler
    
    M, N, K = 2048, 2048, 2048
    A = torch.randn(M, K, device='cuda', dtype=torch.float32)
    B = torch.randn(K, N, device='cuda', dtype=torch.float32)
    
    with profiler.profile(
        activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
        with_stack=True
    ) as prof:
        for _ in range(10):
            C = torch.mm(A, B)
            torch.cuda.synchronize()
    
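    # Analyze results: sum GPU time of the aten::mm operator only. Matching on
    # the substring 'mm' would also count the underlying kernels (whose names
    # often contain "gemm") and inflate the total. Newer PyTorch releases
    # expose this field as self_device_time_total.
    events = prof.events()
    total_us = sum(e.self_cuda_time_total for e in events if e.name == 'aten::mm')
    avg_duration_us = total_us / 10
    
    return {
        'average_duration_ms': avg_duration_us / 1000,  # Convert microseconds to milliseconds
        'gflops_achieved': (2 * M * N * K) / (avg_duration_us * 1e3),  # FLOPs / (us * 1e-6 s) / 1e9
        'memory_utilization': 'profile_memory option shows memory usage',
        'kernel_launch_overhead': 'measured in profiler'
    }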
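
To see what the roofline model in analyze_gemm_bottlenecks predicts for concrete shapes, the sketch below evaluates the same formulas with the same example peaks (10 TFLOPS, 900 GB/s). It assumes FP32 operands and ideal reuse, meaning each matrix crosses the memory bus exactly once; the function name and the chosen shapes are illustrative.

```python
def roofline(m, n, k, peak_flops=10e12, peak_bw=900e9, dtype_bytes=4):
    """Return (arithmetic intensity, attainable FLOP/s, predicted bottleneck)."""
    flops = 2 * m * n * k
    bytes_moved = (m * k + k * n + m * n) * dtype_bytes
    intensity = flops / bytes_moved          # FLOPs per byte
    balance = peak_flops / peak_bw           # machine balance, FLOPs per byte
    attainable = min(peak_flops, peak_bw * intensity)
    bottleneck = "compute" if intensity >= balance else "memory"
    return intensity, attainable, bottleneck


shapes = [(512, 512, 512), (4096, 512, 4096), (128, 128, 8192), (1, 4096, 4096)]
for shape in shapes:
    intensity, attainable, bottleneck = roofline(*shape)
    print(f"{shape}: {intensity:8.2f} FLOPs/byte, "
          f"{attainable / 1e12:5.2f} TFLOPS attainable, {bottleneck}-bound")
```

Under these assumptions a square FP32 GEMM of size n has an intensity of n/6 FLOPs per byte, so only very skinny shapes such as the batch-1 case fall below the machine balance of about 11 FLOPs per byte.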
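
GEMM Bottleneck Analysis

Arithmetic intensity follows the formula in analyze_gemm_bottlenecks (FP32, each matrix moved once), and the predicted bottleneck uses its example peaks of 10 TFLOPS and 900 GB/s, a machine balance of about 11 FLOPs/byte.

| Size (MxNxK) | Arithmetic Intensity (FLOPs/byte) | Predicted Bottleneck | Achieved Performance |
|---|---|---|---|
| 512x512x512 | 85.3 | Compute | 85% of peak |
| 1024x1024x1024 | 170.7 | Compute | 87% of peak |
| 2048x2048x2048 | 341.3 | Compute | 88% of peak |
| 4096x512x4096 | 204.8 | Compute | 75% of peak |
| 128x128x8192 | 31.8 | Compute | 95% of peak |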
Practical Implementation Guidelines

When to Use Custom Kernels

Custom Kernel Selection

Use custom kernels when you need (1) specific hardware features (Tensor Cores, special instructions), (2) memory access patterns unique to your workload, (3) quantized arithmetic, or (4) specialized fused operations. Otherwise, optimized BLAS libraries typically provide the best performance with less development effort.

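
Custom vs BLAS Selection Criteria

| Scenario | Recommendation | Rationale | Performance Impact |
|---|---|---|---|
| Standard dense layers | BLAS | Well-optimized, less work | Best for general use |
| Quantized inference | Custom | The standard BLAS API has no integer GEMM; vendor extensions cover only some cases | Required for INT8 |
| Tensor Core usage | Custom | Need specialized code for direct control; cuBLAS uses Tensor Cores only for eligible FP16 GEMMs | 2-4x improvement |
| Fused operations | Custom | The standard BLAS API can't fuse | Reduces memory traffic |
| Research/prototyping | BLAS | Faster development | Productivity gain |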
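
The "reduces memory" entry for fused operations can be quantified with simple byte counting. The sketch below compares global-memory traffic for the output matrix when bias and ReLU run as separate kernels versus fused into the GEMM epilogue. It is a simplified model under stated assumptions: FP32 data, no cache reuse between kernels, and an illustrative 4096x4096 output.

```python
def epilogue_traffic_bytes(m: int, n: int, fused: bool, dtype_bytes: int = 4) -> int:
    """Global-memory bytes moved for C (and bias) when computing relu(A @ B + bias)."""
    c_bytes = m * n * dtype_bytes
    bias_bytes = n * dtype_bytes
    if fused:
        # Bias and ReLU are applied in registers; C is written exactly once
        return c_bytes + bias_bytes
    gemm_write = c_bytes                            # GEMM writes C
    bias_pass = c_bytes + bias_bytes + c_bytes      # read C and bias, write C
    relu_pass = c_bytes + c_bytes                   # read C, write C
    return gemm_write + bias_pass + relu_pass


M, N = 4096, 4096
unfused = epilogue_traffic_bytes(M, N, fused=False)
fused = epilogue_traffic_bytes(M, N, fused=True)

print(f"unfused: {unfused / 2**20:.1f} MiB")
print(f"fused:   {fused / 2**20:.1f} MiB")
print(f"traffic ratio: {unfused / fused:.2f}x")
```

The elementwise passes do almost no arithmetic, so their cost is nearly all memory traffic, which is exactly what fusion removes.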

Optimization Strategies

def gemm_optimization_strategies():
    """
    Different optimization strategies for GEMM operations
    """
    strategies = {
        'algorithmic': [
            'Use Strassen\'s algorithm for very large matrices (>10000x10000)',
            'Apply blocking/tiling for cache efficiency',
            'Exploit sparsity when present in matrices'
        ],
        'memory': [
            'Align memory to cache line boundaries (64-byte)',
            'Use packed formats to improve memory access',
            'Minimize memory copies between host and device'
        ],
        'computation': [
            'Fuse multiple operations when possible',
            'Use appropriate precision for the task',
            'Exploit symmetry in matrices when present'
        ],
        'parallelization': [
            'Use threading for CPU implementations',
            'Optimize thread block sizes for GPU',
            'Consider wave quantization for large models'
        ]
    }
    
    return strategies

def select_optimal_gemm_implementation(problem_size, hardware, precision):
    """
    Select optimal GEMM implementation based on parameters
    """
    M, N, K = problem_size
    
    if precision == 'int8' or precision == 'int4':
        return 'custom_quantized'
    elif hardware.vendor == 'NVIDIA' and 512 <= min(M, N, K) <= 8192:
        if hardware.supports_tensor_cores and precision == 'fp16':
            return 'custom_tensor_core'
        else:
            return 'cublas'
    elif hardware.architecture == 'ARM' and precision == 'fp16':
        return 'custom_neon'
    elif max(M, N, K) > 10000:
        if hardware.supports_advanced_simd:
            return 'custom_strassen'
        else:
            return 'blas_scaled'
    else:
        return 'standard_blas'  # Safe fallback
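
The wave quantization item in the strategies list deserves a worked example. A GEMM is split into output tiles, and tiles execute in "waves" across the GPU's streaming multiprocessors; a partially filled last wave leaves SMs idle. The sketch below assumes 128x128 output tiles and one resident tile per SM, both simplifications, and uses 80 SMs as on a V100.

```python
import math


def wave_efficiency(m: int, n: int, tile_m: int = 128, tile_n: int = 128,
                    num_sms: int = 80):
    """Return (tiles, waves, fraction of SM slots doing useful work)."""
    tiles = math.ceil(m / tile_m) * math.ceil(n / tile_n)
    waves = math.ceil(tiles / num_sms)
    return tiles, waves, tiles / (waves * num_sms)


for m, n in [(1280, 1024), (1408, 1024), (2048, 2048)]:
    tiles, waves, eff = wave_efficiency(m, n)
    print(f"{m}x{n}: {tiles} tiles, {waves} wave(s), {eff:.0%} SM utilization")
```

Growing M from 1280 to 1408 adds only eight tiles, yet it forces a second, mostly empty wave. This kind of cliff is one reason shape-specific tile sizes can beat a generic choice.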

Limitations and Considerations

BLAS Limitations

def blas_limitations_analysis():
    """
    Analyze limitations of standard BLAS implementations
    """
    limitations = {
        'specialized_operations': {
            'issue': 'The standard BLAS interface has no fused operations',
            'impact': 'Extra memory transfers between operations',
            'workaround': 'Custom kernels with fused operations'
        },
        'quantization': {
            'issue': 'The standard BLAS interface defines no integer GEMM; support is limited to vendor extensions',
            'impact': 'Quantized models require custom implementations',
            'workaround': 'Specialized quantized GEMM libraries'
        },
        'hardware_specific': {
            'issue': 'BLAS libraries may not use latest hardware features optimally',
            'impact': 'Suboptimal performance on new architectures',
            'workaround': 'Custom kernels targeting specific hardware'
        },
        'small_matrices': {
            'issue': 'BLAS overhead significant for small matrices',
            'impact': 'Poor performance for small operations',
            'workaround': 'Specialized small matrix kernels'
        }
    }
    
    return limitations

def custom_kernel_challenges():
    """
    Challenges with custom kernel development
    """
    challenges = {
        'development_complexity': {
            'difficulty': 'High',
            'time_investment': 'Months for complex optimizations',
            'expertise_required': 'GPU/CPU architecture knowledge'
        },
        'portability': {
            'issue': 'Custom kernels are hardware-specific',
            'impact': 'Need different versions for different platforms',
            'solution': 'Abstract interfaces, multiple implementations'
        },
        'maintenance': {
            'issue': 'Hard to maintain and debug',
            'impact': 'Increased development costs',
            'solution': 'Comprehensive testing, documentation'
        },
        'optimization_validation': {
            'issue': 'Difficult to verify optimization correctness',
            'impact': 'Potential numerical errors',
            'solution': 'Extensive numerical testing, precision analysis'
        }
    }
    
    return challenges
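
For the optimization_validation challenge, one workable approach is to compare every custom GEMM against a float64 reference over a set of awkward shapes. The harness below is a sketch: the function name, shape list, and tolerance are assumptions to adapt, and gemm_fn stands for whatever callable wraps the kernel under test. The error is scaled by |A| @ |B| so the tolerance does not depend on cancellation in the result.

```python
import numpy as np


def validate_gemm(gemm_fn, shapes, rtol=1e-4, seed=0):
    """Return a list of (shape, error) for every shape where gemm_fn is off."""
    rng = np.random.default_rng(seed)
    failures = []
    for m, n, k in shapes:
        A = rng.standard_normal((m, k)).astype(np.float32)
        B = rng.standard_normal((k, n)).astype(np.float32)
        A64, B64 = A.astype(np.float64), B.astype(np.float64)
        reference = A64 @ B64
        result = np.asarray(gemm_fn(A, B), dtype=np.float64)
        # Scale by the sum of absolute products accumulated per output element
        scale = np.abs(A64) @ np.abs(B64)
        err = float(np.max(np.abs(result - reference) / scale))
        if not err <= rtol:  # also catches NaN
            failures.append(((m, n, k), err))
    return failures


# Include sizes that are not multiples of typical tile sizes
shapes = [(1, 1, 1), (17, 33, 65), (128, 128, 128), (100, 3, 257)]

good = validate_gemm(lambda A, B: A @ B, shapes)
bad = validate_gemm(lambda A, B: (A @ B) * 1.01, shapes)  # deliberately wrong
print("failures for correct GEMM:", len(good))
print("failures for broken GEMM:", len(bad))
```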

Future Developments

By November 2019, GEMM optimization was evolving rapidly:


GEMM Optimization Evolution

| Year | Development | Performance Impact | Adoption Timeline |
|---|---|---|---|
| 2015 | cuBLAS optimizations | 2x improvement | Immediate |
| 2017 | Tensor Core introduction | 4-8x improvement | 2018-2019 |
| 2018 | Quantized GEMM | 2-4x inference speed | 2019-2020 |
| 2019 | Sparse GEMM | Variable | 2020+ |
| 2019 | Custom kernels for transformers | 1.5-3x improvement | Ongoing |

Conclusion

As of November 2019, GEMM performance was a critical factor in neural network performance. The choice between BLAS libraries and custom kernels comes down to a tradeoff:

  • BLAS libraries provide well-tested, portable, and generally well-optimized implementations that work well for most use cases
  • Custom kernels offer potential for significant performance improvements when targeting specific hardware features or operation patterns

The key insights for November 2019 were:

  1. Tensor Cores provided substantial performance gains for half-precision operations on NVIDIA GPUs
  2. Quantized GEMM operations became increasingly important for efficient inference
  3. Memory access patterns remained critical for achieving peak performance
  4. Fused operations offered opportunities for reducing memory overhead

The optimal approach often involves using BLAS libraries as a baseline and developing custom kernels only when specific hardware features or performance requirements justify the additional complexity. This balance between performance and development efficiency has continued to guide GEMM optimization strategies in the years since.