Part 58 of 60 in the Inference Optimization Timeline series.

Sparse attention was supposed to solve the quadratic bottleneck and unlock million-token context windows. Longformer, BigBird, Sparse Transformer — all delivered impressive results on paper with hand-designed sparsity patterns that preserved quality while cutting compute by 10-100x. Then FlashAttention made dense attention fast enough for 32K-64K contexts and the entire sparse attention research program collapsed overnight. The reason: sparse attention requires custom kernels that are harder to optimize than dense attention, and modern accelerators (H100 tensor cores) are so fast at dense matmuls that sparsity only wins at extreme sequence lengths where memory becomes the bottleneck anyway. Sparse attention survives in exactly two niches: sliding window attention (Mistral) for simplicity, and extreme long-context serving where Ring Attention distributes dense computation across GPUs. This post covers the full arc from sparse patterns to why they lost.

The field’s response to this problem has evolved dramatically over five years:

  • 2019-2020: Sparse attention patterns (Longformer, BigBird, Sparse Transformer) traded quality for efficiency with hand-designed patterns.
  • 2022: FlashAttention made dense attention fast enough for moderate sequence lengths (up to 32K-64K), largely killing the sparse attention approach for mainstream use.
  • 2023: Sliding window attention (Mistral) offered a simpler alternative: each token attends to only the last W tokens.
  • 2023-2024: Ring Attention and context parallelism distributed long sequences across multiple GPUs, enabling 128K-1M+ contexts.
  • 2025: The long-context landscape combines RoPE scaling, continued pretraining, FlashAttention, and distributed attention to deliver production-grade million-token contexts.

This post traces the full arc, from the original quadratic bottleneck to the techniques that power today’s long-context models.

The O(n^2) Problem

Why Attention Is Quadratic

The standard scaled dot-product attention computes:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

where Q, K, V are n x d_k matrices. The matrix multiplication QK^T produces an n x n matrix. Every token's query must be compared against every token's key. This is the source of the quadratic cost.

import math
import torch

def standard_attention(Q, K, V):
    """Standard O(n^2) attention."""
    d_k = Q.shape[-1]

    # QK^T: [batch, heads, n, d_k] x [batch, heads, d_k, n] = [batch, heads, n, n]
    # This n x n matrix is the bottleneck
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Softmax over the key dimension
    attn_weights = torch.softmax(scores, dim=-1)

    # Weighted sum of values: [batch, heads, n, n] x [batch, heads, n, d_k]
    output = torch.matmul(attn_weights, V)
    return output

The Memory Wall at Scale

The attention matrix must be materialized in GPU memory (at least partially) for both the forward pass and backward pass. During training with gradient checkpointing disabled, you need to store the attention weights for backpropagation.


Attention Memory Requirements (Single Layer, Single Head, FP16)

| Sequence Length | Attention Matrix Size | Memory (FP16) | Feasible on 80GB GPU? |
|---|---|---|---|
| 2,048 | 4.2M entries | 8 MB | Yes -- trivial |
| 8,192 | 67.1M entries | 128 MB | Yes -- comfortable |
| 32,768 | 1.07B entries | 2 GB | Yes -- tight with many layers/heads |
| 131,072 (128K) | 17.2B entries | 32 GB | No -- exceeds memory for multi-layer model |
| 524,288 (512K) | 275B entries | 512 GB | No -- impossible on single GPU |
| 1,048,576 (1M) | 1.1T entries | 2 TB | No -- requires distributed approach |
The Real Numbers Are Worse

The table above shows memory for a single layer and single head. A 70B model has 80 layers and 64 heads (8 KV heads with GQA). Even with GQA reducing the KV heads, the total attention memory for a full forward pass is the single-head cost multiplied by the number of layers and query heads. At 128K tokens, this makes naive attention completely infeasible without memory-efficient techniques.
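The arithmetic behind the table and this multiplier is simple enough to sanity-check in a few lines (binary units, FP16 = 2 bytes per entry; the helper name is ours):

```python
def attn_matrix_bytes(seq_len, bytes_per_entry=2):
    """Memory to materialize one n x n attention matrix (single layer, single head)."""
    return seq_len * seq_len * bytes_per_entry

GiB = 1024 ** 3
assert attn_matrix_bytes(32_768) == 2 * GiB    # the 32K row of the table
assert attn_matrix_bytes(131_072) == 32 * GiB  # the 128K row

# Scaling by layers and query heads (70B-class: 80 layers, 64 heads)
total = attn_matrix_bytes(131_072) * 80 * 64
assert total == 160 * 1024 * GiB  # 160 TiB -- naive attention is hopeless at 128K
```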

The quadratic scaling also affects computation time. At 2K tokens, attention is a small fraction of total model FLOPS. At 128K tokens, attention dominates:


Attention FLOPs as Fraction of Total Model FLOPs (70B Model)

| Sequence Length | Attention FLOPs | FFN FLOPs | Attention Share |
|---|---|---|---|
| 2,048 | 0.34 TFLOPs | 4.6 TFLOPs | 7% |
| 8,192 | 5.5 TFLOPs | 18.4 TFLOPs | 23% |
| 32,768 | 88 TFLOPs | 73.7 TFLOPs | 54% |
| 131,072 | 1,408 TFLOPs | 294.9 TFLOPs | 83% |

At 128K tokens, attention accounts for 83% of all computation. Any optimization that reduces attention cost has massive impact on total throughput.
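The crossover follows directly from attention scaling quadratically in sequence length while the FFN scales linearly. A rough per-layer model makes the shape visible; the dimensions (d_model = 8192, d_ff = 28672, Llama-70B-like) and the FLOP formulas here are our own simplifications, so the absolute numbers will not match the table exactly:

```python
def attention_share(n, d_model=8192, d_ff=28672):
    """Fraction of per-layer forward FLOPs spent in attention score/value matmuls."""
    attn_flops = 4 * n * n * d_model    # QK^T plus attn @ V
    ffn_flops = 6 * n * d_model * d_ff  # three projections in a gated FFN
    return attn_flops / (attn_flops + ffn_flops)

shares = [attention_share(n) for n in (2_048, 8_192, 32_768, 131_072)]
assert all(a < b for a, b in zip(shares, shares[1:]))  # share grows monotonically
assert shares[0] < 0.10 and shares[-1] > 0.70          # negligible at 2K, dominant at 128K
```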

Sparse Attention Approaches (2019-2021)

The first generation of solutions attacked the quadratic problem by making the attention matrix sparse. Instead of every token attending to every other token, restrict attention to a subset of positions.

Sparse Transformer (2019)

OpenAI’s Sparse Transformer introduced two key sparse patterns:

  1. Strided attention: Token i attends to tokens at positions {i - W, i - W + 1, ..., i} (local window) and to tokens at positions {j : j mod s = 0} (every s-th position globally).
  2. Fixed attention: Alternate between local attention layers and layers that attend to fixed global positions.

The complexity reduces from O(n^2) to O(n sqrt(n)) by combining local and strided patterns.

def sparse_transformer_pattern(seq_len, window_size=256, stride=256):
    """Create Sparse Transformer attention pattern."""
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

    for i in range(seq_len):
        # Local window: attend to nearby tokens
        local_start = max(0, i - window_size)
        local_end = min(seq_len, i + 1)  # causal
        mask[i, local_start:local_end] = True

        # Strided: attend to every stride-th token
        strided_positions = torch.arange(0, i + 1, stride)
        mask[i, strided_positions] = True

    return mask
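Counting the nonzero entries of this pattern confirms the O(n sqrt(n)) scaling. A pure-Python count with window = stride = sqrt(n) (the counting helper is ours, for illustration):

```python
def sparse_pattern_entries(seq_len, window, stride):
    """Count attended positions under the strided + local causal pattern."""
    total = 0
    for i in range(seq_len):
        cols = set(range(max(0, i - window), i + 1))  # local window (causal)
        cols |= set(range(0, i + 1, stride))          # strided positions
        total += len(cols)
    return total

n = 1024
sparse = sparse_pattern_entries(n, window=32, stride=32)  # window = stride = sqrt(n)
dense = n * (n + 1) // 2                                  # causal dense attention
assert sparse * 5 < dense  # roughly an order of magnitude fewer entries at n = 1024
```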

Longformer (2020)

Longformer from Allen AI refined the sparse approach with three attention patterns combined:

  1. Local sliding window: Each token attends to W tokens on each side (a window of 2W + 1).
  2. Dilated sliding window: Like sliding window but with gaps, increasing receptive field.
  3. Global attention: Designated tokens (like [CLS] or question tokens) attend to all positions and are attended by all positions.

The complexity is O(n W) for the local component (linear in n) plus O(n g) for the g global tokens.

def longformer_attention(Q, K, V, window_size=512, global_indices=None):
    """Simplified Longformer attention pattern."""
    seq_len = Q.shape[1]
    output = torch.zeros_like(Q)

    for i in range(seq_len):
        # Local window
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)

        # Keys to attend to: local window + global tokens
        attend_to = list(range(start, end))
        if global_indices is not None:
            attend_to = sorted(set(attend_to) | set(global_indices))

        K_subset = K[:, attend_to, :]
        V_subset = V[:, attend_to, :]

        scores = torch.matmul(Q[:, i:i+1, :], K_subset.transpose(-2, -1))
        scores = scores / math.sqrt(Q.shape[-1])
        weights = torch.softmax(scores, dim=-1)
        output[:, i:i+1, :] = torch.matmul(weights, V_subset)

    return output

BigBird (2020)

Google’s BigBird added random attention to the mix, combining three components:

  1. Local windowed attention: Same as Longformer.
  2. Global tokens: A set of tokens that attend to all positions.
  3. Random attention: Each token randomly attends to r additional tokens.

The theoretical contribution was proving that this combination is a universal approximator of sequence-to-sequence functions, maintaining the expressiveness of full attention while being sparse.


Sparse Attention Methods Comparison

| Method | Pattern | Complexity | Memory | Quality vs Dense |
|---|---|---|---|---|
| Dense (baseline) | Full n x n | O(n^2 d) | O(n^2) | 100% (baseline) |
| Sparse Transformer | Strided + local | O(n sqrt(n) d) | O(n sqrt(n)) | 98-99% |
| Longformer | Local + global | O(n W d) | O(n W) | 97-99% |
| BigBird | Local + global + random | O(n W d) | O(n W) | 97-99% |
| Linformer | Low-rank projection | O(n k d) | O(n k) | 94-97% |
| Performer | Random features (FAVOR+) | O(n d^2) | O(n d) | 90-95% |

Linear Attention Variants

A separate line of research tried to eliminate the quadratic term entirely:

Linformer projected the key and value matrices to a lower dimension k << n, reducing the attention matrix from n x n to n x k. The complexity becomes O(n k d), linear in n when k is fixed.

Performer (FAVOR+) approximated the softmax kernel with random features, enabling attention to be computed as Q'(K'^T V) instead of (Q' K'^T) V; changing the association order avoids materializing the n x n matrix. Complexity: O(n d^2).

These linear variants were elegant mathematically but suffered quality degradation on real tasks, particularly for tasks requiring precise long-range attention patterns. The approximation error compounded across layers, leading to noticeable performance drops.

ℹ️ The Approximation Quality Problem

Linear attention methods like Performer and Random Feature Attention approximate the softmax attention kernel. But softmax attention has a crucial property: it produces sharp attention distributions that focus on a few relevant positions. The random feature approximation tends to produce blurry distributions that spread attention more uniformly. This matters most for tasks requiring precise retrieval from context — exactly the tasks where long context is most valuable.
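The reassociation trick itself is exact; only the feature map is an approximation. With any positive feature map phi (here elu + 1, standing in for Performer's random features — an assumption for illustration), Q'(K'^T V) equals (Q' K'^T) V to floating-point precision:

```python
import torch

torch.manual_seed(0)
n, d = 64, 16
Q, K, V = torch.randn(3, n, d).unbind(0)

def phi(X):
    """A positive feature map (elu + 1), a stand-in for FAVOR+ random features."""
    return torch.nn.functional.elu(X) + 1

Qp, Kp = phi(Q), phi(K)

# Quadratic path: materialize the n x n kernel matrix
attn = Qp @ Kp.T                                   # [n, n]
quad = (attn @ V) / attn.sum(-1, keepdim=True)

# Linear path: reassociate -- never build the n x n matrix
kv = Kp.T @ V                                      # [d, d]
z = Kp.sum(0)                                      # [d] normalizer
lin = (Qp @ kv) / (Qp @ z).unsqueeze(-1)

assert torch.allclose(quad, lin, atol=1e-5)
```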

Why Sparse Attention Lost to FlashAttention

In 2022, Tri Dao’s FlashAttention paper changed the landscape entirely. Rather than approximating or sparsifying the attention computation, FlashAttention computed exact dense attention but reorganized the computation to be IO-aware — minimizing reads and writes to GPU high-bandwidth memory (HBM).

The Key Insight: Attention Is Memory-Bound

The standard attention implementation materializes the full n x n attention matrix in HBM, reads it back for softmax, writes the result back, then reads it again for the value multiplication. Each of these reads and writes goes through HBM, which is 10-100x slower than the GPU’s compute units.

FlashAttention’s insight: you can compute attention in tiles that fit in SRAM (the GPU’s fast on-chip memory), never materializing the full n x n matrix in HBM. The algorithm uses the online softmax trick to compute exact softmax incrementally across tiles.

def flash_attention_conceptual(Q, K, V, block_size=256):
    """
    Conceptual FlashAttention: tiled attention without materializing n x n matrix.
    Real implementation is a fused CUDA kernel.
    """
    n, d = Q.shape[0], Q.shape[1]
    output = torch.zeros_like(Q)
    row_max = torch.full((n,), float('-inf'))  # running max for online softmax
    row_sum = torch.zeros(n)                   # running sum for online softmax

    # Process in blocks -- never create the full n x n matrix
    for j_start in range(0, n, block_size):
        j_end = min(j_start + block_size, n)
        K_block = K[j_start:j_end]  # [block_size, d]
        V_block = V[j_start:j_end]  # [block_size, d]

        # Compute scores for this block: [n, block_size]
        scores = Q @ K_block.T / math.sqrt(d)

        # Online softmax update
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)

        # Rescale previous accumulator
        scale_old = torch.exp(row_max - new_max)
        scale_new = torch.exp(block_max - new_max)

        # Update running statistics
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        new_sum = scale_old * row_sum + exp_scores.sum(dim=-1)

        # Update output: rescale old output + add new contribution
        output = output * (scale_old * row_sum / new_sum).unsqueeze(-1)
        output += (exp_scores @ V_block) / new_sum.unsqueeze(-1)

        row_max = new_max
        row_sum = new_sum

    return output
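The online softmax trick at the heart of this loop can be isolated. A single streaming pass over blocks, carrying only a running max and a rescaled running sum, recovers the exact softmax normalizer — this is what lets FlashAttention (and, later, Ring Attention) accumulate exact results tile by tile:

```python
import math
import torch

torch.manual_seed(0)
x = torch.randn(1_000)

# Streaming pass: rescale the running sum whenever a new block raises the max
m, s = float('-inf'), 0.0
for block in x.split(128):
    block_max = float(block.max())
    new_max = max(m, block_max)
    s = s * math.exp(m - new_max) + float(torch.exp(block - new_max).sum())
    m = new_max

# Two-pass reference: global max first, then the sum
ref = float(torch.exp(x - x.max()).sum())
assert m == float(x.max())
assert abs(s - ref) / ref < 1e-5
```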

FlashAttention Performance

The results were dramatic. FlashAttention was 2-4x faster than standard PyTorch attention and used O(n) memory instead of O(n^2), while computing exact attention.


FlashAttention vs Standard vs Sparse (A100, FP16, Causal)

| Sequence Length | Standard Attn (ms) | FlashAttention-2 (ms) | Sparse (Longformer, ms) | FA Speedup vs Standard |
|---|---|---|---|---|
| 2,048 | 1.2 | 0.5 | 0.8 | 2.4x |
| 4,096 | 4.8 | 1.4 | 1.5 | 3.4x |
| 8,192 | 19.2 | 4.2 | 2.9 | 4.6x |
| 16,384 | 76.8 | 14.1 | 5.7 | 5.4x |
| 32,768 | 307.2 | 48.3 | 11.2 | 6.4x |
| 65,536 | OOM | 172.0 | 22.1 | N/A (OOM vs works) |
Why FlashAttention Killed Sparse Attention for Most Use Cases

Before FlashAttention, the choice was: (1) dense attention that is exact but OOMs at long sequences, or (2) sparse attention that works at long sequences but loses quality. FlashAttention created option (3): exact dense attention that works at long sequences. At sequence lengths up to 32K-64K, FlashAttention is fast enough and has zero quality compromise. Sparse attention’s complexity — custom attention patterns, quality monitoring, pattern-specific bugs — was no longer worth the tradeoff for these sequence lengths.

FlashAttention-2 and FlashAttention-3

FlashAttention-2 (2023) improved on the original with better parallelism across sequence length and attention heads, achieving close to the theoretical maximum throughput on A100 GPUs (up to 230 TFLOPS in FP16/BF16, compared to the peak 312 TFLOPS).

FlashAttention-3 (2024) targets H100 GPUs with FP8 support, asynchronous block-wise softmax, and warp specialization, pushing throughput even higher and enabling efficient attention at 128K+ sequence lengths on a single GPU.
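In application code you rarely call the FlashAttention kernels directly: since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention dispatches to a FlashAttention backend when the hardware, dtype, and mask allow it, and falls back to the math backend otherwise. A quick CPU sanity check against the explicit formula:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# [batch, heads, seq, head_dim]
q, k, v = torch.randn(3, 1, 4, 128, 64).unbind(0)

# Dispatches to a fused FlashAttention kernel on supported GPUs;
# on CPU this runs the exact math backend.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Manual reference with an explicit causal mask
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
mask = torch.triu(torch.ones(128, 128, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float('-inf'))
ref = torch.softmax(scores, dim=-1) @ v

assert torch.allclose(out, ref, atol=1e-5)
```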


FlashAttention Evolution

| Version | Year | Key Innovation | Throughput (A100, 16K seq) | Memory |
|---|---|---|---|---|
| Standard PyTorch | Pre-2022 | Baseline | ~60 TFLOPS | O(n^2) |
| FlashAttention v1 | 2022 | Tiled IO-aware attention | ~120 TFLOPS | O(n) |
| FlashAttention v2 | 2023 | Better parallelism, reduced non-matmul FLOPs | ~200 TFLOPS | O(n) |
| FlashAttention v3 | 2024 | H100 FP8, async softmax, warp specialization | ~300 TFLOPS (H100) | O(n) |

The progression tells a clear story: instead of reducing the amount of computation (sparse attention), FlashAttention made the computation faster by optimizing how data moves through the GPU memory hierarchy. This is an IO-first rather than FLOP-first optimization.

Sliding Window Attention (Mistral, 2023)

While FlashAttention solved the memory problem for dense attention, some models opted for a simpler architectural approach: just limit how far back each token can attend.

The Design

Sliding window attention (SWA) is exactly what it sounds like. Each token at position i attends only to tokens in the range [i - W, i], where W is the window size. The attention matrix is banded rather than full lower-triangular.

def sliding_window_attention(Q, K, V, window_size=4096):
    """Each token attends to at most window_size previous tokens."""
    n = Q.shape[1]
    output = torch.zeros_like(Q)

    for i in range(n):
        start = max(0, i - window_size + 1)
        # Only attend to tokens within the window
        K_window = K[:, start:i+1, :]
        V_window = V[:, start:i+1, :]

        scores = torch.matmul(Q[:, i:i+1, :], K_window.transpose(-2, -1))
        scores = scores / math.sqrt(Q.shape[-1])
        weights = torch.softmax(scores, dim=-1)
        output[:, i:i+1, :] = torch.matmul(weights, V_window)

    return output

The complexity is O(n W d): linear in sequence length for fixed W. Memory for the KV cache is bounded at W tokens per layer regardless of total sequence length.

Why It Works: Information Flow Through Layers

The key insight is that information propagates through layers. If the window size is W = 4,096 and the model has L = 32 layers, then at the final layer, token i has an effective receptive field of W x L = 131,072 tokens. Information from distant tokens reaches the current position by hopping through intermediate tokens across layers.

This is analogous to how CNNs build large receptive fields from small convolution kernels stacked in deep networks. Each layer adds W tokens of direct context, and the indirect context grows linearly with depth.


Sliding Window Attention: Effective Context

| Window Size | Layers | Direct Context | Effective Receptive Field | KV Cache per Layer |
|---|---|---|---|---|
| 4,096 | 32 | 4,096 tokens | 131,072 tokens | 32 MB (7B model, FP16) |
| 4,096 | 80 | 4,096 tokens | 327,680 tokens | 80 MB (70B model, FP16) |
| 8,192 | 32 | 8,192 tokens | 262,144 tokens | 64 MB (7B model, FP16) |
| 32,768 | 32 | 32,768 tokens | 1,048,576 tokens | 256 MB (7B model, FP16) |
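The receptive-field column is just W x L, capped at the actual sequence length (the helper name is ours):

```python
def effective_receptive_field(window, n_layers, seq_len=None):
    """Farthest token whose information can (indirectly) reach the current position."""
    rf = window * n_layers
    return min(rf, seq_len) if seq_len is not None else rf

assert effective_receptive_field(4_096, 32) == 131_072
assert effective_receptive_field(4_096, 80) == 327_680
assert effective_receptive_field(32_768, 32) == 1_048_576
```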

Mistral’s Implementation

Mistral 7B (2023) used sliding window attention with W = 4,096 as a core architectural choice. Combined with Grouped-Query Attention (GQA, 8 KV heads), this dramatically reduced memory requirements during inference:

  • Full attention KV cache at 32K tokens: 32 heads x 2 (K and V) x 32,768 tokens x 128 dims x 2 bytes = 512 MB per layer
  • SWA KV cache at 32K tokens: 32 x 2 x 4,096 x 128 x 2 = 64 MB per layer (with GQA, much less)

The 8x reduction in KV cache memory translates directly to 8x higher batch sizes during serving, which means 8x higher throughput for the same hardware.
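The KV-cache arithmetic above generalizes to a one-liner (binary units; the helper is ours):

```python
def kv_cache_bytes(n_kv_heads, head_dim, cached_tokens, bytes_per=2):
    """Per-layer KV cache: K and V tensors for every cached token."""
    return 2 * n_kv_heads * cached_tokens * head_dim * bytes_per

MB = 1024 ** 2
# Full causal attention at 32K context, 32 KV heads (MHA):
assert kv_cache_bytes(32, 128, 32_768) == 512 * MB
# Sliding window caps the cache at W = 4,096 tokens:
assert kv_cache_bytes(32, 128, 4_096) == 64 * MB
# Adding GQA (8 KV heads) cuts another 4x:
assert kv_cache_bytes(8, 128, 4_096) == 16 * MB
```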

💡 SWA + FlashAttention = Best of Both

In practice, sliding window attention is implemented using FlashAttention with a causal+window mask. FlashAttention handles the tiling and IO optimization; the window mask just means fewer tiles are non-zero. This combination gives you both the memory savings of SWA and the computational efficiency of FlashAttention. Mistral’s implementation does exactly this.

Limitations of Sliding Window

SWA has a fundamental limitation: the information propagation through layers is lossy. Each hop through an intermediate token involves passing information through a nonlinear transformation. Fine-grained details from distant tokens degrade as they propagate. This means:

  • Tasks requiring exact retrieval from early in a long document (needle-in-a-haystack) degrade with distance.
  • Copy-paste style tasks (reproduce a specific passage from the context) fail for passages outside the direct window.
  • The effective receptive field is large but the effective attention is much smaller — most influence comes from nearby tokens.

For these reasons, models targeting very long contexts (128K+) typically use full attention rather than pure sliding window, or interleave the two: Gemma 2, for example, alternates sliding-window layers with full-attention layers so that a few global layers preserve exact long-range retrieval.

Ring Attention: Distributing Across GPUs

For context lengths beyond what a single GPU can handle (roughly 64K-128K for large models with FlashAttention), the solution is to distribute the sequence across multiple GPUs. Ring Attention (Liu et al., 2023) is the most elegant approach.

The Core Mechanism

Ring Attention distributes the sequence across P GPUs (or “hosts”), with each GPU holding n/P tokens. The attention computation proceeds in P rounds. In each round, each GPU computes attention between its local query chunk and the current key-value chunk, then passes the KV chunk to the next GPU in a ring topology.

def ring_attention(Q_local, K_local, V_local, comm_group, num_gpus, causal=True):
    """
    Ring Attention: distribute sequence across GPUs in a ring.
    Each GPU holds n/P tokens. get_rank, create_causal_mask, and
    ring_send_recv stand in for the communication library's primitives.
    """
    chunk_size = Q_local.shape[1]  # n/P tokens per GPU
    d_k = Q_local.shape[-1]

    # Initialize output and softmax statistics
    output = torch.zeros_like(Q_local)
    row_max = torch.full((chunk_size,), float('-inf'))
    row_sum = torch.zeros(chunk_size)

    # Current KV chunk to process (starts as local)
    K_recv = K_local.clone()
    V_recv = V_local.clone()

    for step in range(num_gpus):
        # Determine which chunk of the sequence this KV belongs to
        source_rank = (rank - step) % num_gpus

        # Compute attention between local Q and received KV
        # Use online softmax for incremental accumulation
        scores = Q_local @ K_recv.transpose(-2, -1) / math.sqrt(d_k)

        # Apply causal mask if needed (based on position offsets)
        if causal:
            mask = create_causal_mask(rank, source_rank, chunk_size)
            scores = scores.masked_fill(~mask, float('-inf'))

        # Online softmax update (same as FlashAttention)
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)
        scale_old = torch.exp(row_max - new_max)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        new_sum = scale_old * row_sum + exp_scores.sum(dim=-1)

        output = output * (scale_old * row_sum / new_sum).unsqueeze(-1)
        output += (exp_scores @ V_recv) / new_sum.unsqueeze(-1)

        row_max = new_max
        row_sum = new_sum

        # OVERLAP: send current KV to next GPU while receiving from previous
        # This is the key -- communication is overlapped with computation
        K_recv, V_recv = ring_send_recv(K_recv, V_recv, comm_group)

    return output
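The ring accumulation can be verified without any GPUs by playing all P ranks in a loop. This single-process sketch (non-causal and batch-free, our own simplification) reproduces full softmax attention exactly:

```python
import math
import torch

def reference_attention(Q, K, V):
    return torch.softmax(Q @ K.T / math.sqrt(Q.shape[-1]), dim=-1) @ V

def simulated_ring_attention(Q, K, V, P):
    n, d = Q.shape
    C = n // P
    Qc, Kc, Vc = (list(X.split(C)) for X in (Q, K, V))
    outputs = []
    for rank in range(P):                      # each iteration plays one GPU
        out = torch.zeros(C, d)
        m = torch.full((C,), float('-inf'))    # running max
        s = torch.zeros(C)                     # running sum
        for step in range(P):                  # P ring rotations
            src = (rank - step) % P            # whose KV chunk we hold this step
            scores = Qc[rank] @ Kc[src].T / math.sqrt(d)
            block_max = scores.max(dim=-1).values
            new_max = torch.maximum(m, block_max)
            scale_old = torch.exp(m - new_max)
            exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
            new_sum = scale_old * s + exp_scores.sum(dim=-1)
            out = out * (scale_old * s / new_sum).unsqueeze(-1)
            out = out + (exp_scores @ Vc[src]) / new_sum.unsqueeze(-1)
            m, s = new_max, new_sum
        outputs.append(out)
    return torch.cat(outputs)

torch.manual_seed(0)
Q, K, V = torch.randn(3, 64, 16).unbind(0)
assert torch.allclose(simulated_ring_attention(Q, K, V, P=4),
                      reference_attention(Q, K, V), atol=1e-5)
```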

Why Ring Attention Works

The brilliance of ring attention is that communication overlaps with computation. While GPU i is computing attention with KV chunk j, it is simultaneously sending chunk j to GPU i+1 and receiving chunk j-1 from GPU i-1. If the computation time exceeds the communication time (which it does for large enough chunks), the communication is entirely hidden.

The memory per GPU is O(n/P): each GPU only stores its local chunk of Q, K, V, and the output. The total computation is the same as standard attention, O(n^2 d), but distributed across P GPUs with near-zero communication overhead.


Ring Attention: Scaling with GPUs

| Sequence Length | GPUs | Tokens per GPU | Memory per GPU | Effective Throughput |
|---|---|---|---|---|
| 128K | 1 | 128K | OOM (70B model) | N/A |
| 128K | 4 | 32K | ~40 GB | 3.8x of 1 GPU |
| 128K | 8 | 16K | ~20 GB | 7.5x of 1 GPU |
| 512K | 8 | 64K | ~40 GB | 7.2x of 1 GPU |
| 1M | 16 | 64K | ~40 GB | 14.8x of 1 GPU |
| 1M | 64 | 16K | ~10 GB | 58x of 1 GPU |
Ring Attention + FlashAttention

In practice, ring attention uses FlashAttention for the local attention computation within each step of the ring. The online softmax trick in FlashAttention is the same trick used to accumulate attention across ring steps. This means ring attention inherits FlashAttention’s IO efficiency while distributing across GPUs. The combination is the standard approach for training models with 128K+ context.

Causal Masking in Ring Attention

For autoregressive (causal) models, ring attention needs special handling. In causal attention, token i can only attend to tokens j <= i. When GPU k holds queries for positions [kC, (k+1)C) and receives KV for positions [mC, (m+1)C), the causal mask means:

  • If m > k: The entire block is masked (all queries come before all keys). Skip computation entirely.
  • If m < k: The entire block is unmasked (all queries come after all keys). Compute full attention.
  • If m = k: Apply the standard causal mask within the block.

This means roughly half the ring steps can be skipped entirely for causal attention, effectively doubling throughput compared to bidirectional ring attention.
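The three cases translate into a tiny scheduling function, and counting them shows exactly P(P-1)/2 of the P^2 query/KV block pairs can be skipped:

```python
def classify_block(k, m):
    """Causal ring attention: what to do for query block k against KV block m."""
    if m > k:
        return "skip"        # all keys come after all queries: fully masked
    if m < k:
        return "full"        # all keys precede all queries: dense block, no mask
    return "triangular"      # same block: apply the standard causal mask

P = 8
counts = {"skip": 0, "full": 0, "triangular": 0}
for k in range(P):
    for m in range(P):
        counts[classify_block(k, m)] += 1

assert counts["skip"] == counts["full"] == P * (P - 1) // 2
assert counts["triangular"] == P
```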

Context Parallelism: Ulysses and Beyond

Ring attention is not the only approach to distributing attention. DeepSpeed-Ulysses (2023) takes a different approach based on all-to-all communication on the sequence dimension.

Ulysses: All-to-All on the Head Dimension

Ulysses distributes the sequence across GPUs but, before computing attention, performs an all-to-all transpose so that each GPU holds all tokens for a subset of attention heads. Attention is then computed locally on each GPU (full sequence, subset of heads), and another all-to-all transposes back.

def ulysses_attention(Q, K, V, comm_group, num_gpus):
    """
    DeepSpeed-Ulysses: all-to-all communication for context parallelism.

    Input: each GPU has all heads but n/P tokens
    After all-to-all: each GPU has all tokens but H/P heads
    """
    # Step 1: All-to-all -- redistribute from [seq_chunk, all_heads] to [all_seq, head_chunk]
    Q_full_seq = all_to_all(Q, split_dim=SEQ, gather_dim=HEAD, group=comm_group)
    K_full_seq = all_to_all(K, split_dim=SEQ, gather_dim=HEAD, group=comm_group)
    V_full_seq = all_to_all(V, split_dim=SEQ, gather_dim=HEAD, group=comm_group)

    # Step 2: Standard attention on full sequence, subset of heads
    # Each GPU computes exact attention for its heads
    output = flash_attention(Q_full_seq, K_full_seq, V_full_seq)

    # Step 3: All-to-all -- redistribute back to [seq_chunk, all_heads]
    output = all_to_all(output, split_dim=HEAD, gather_dim=SEQ, group=comm_group)

    return output
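The two all-to-alls are pure re-partitionings of the same activation tensor, which can be checked on a single device: rank r sends the head-slice destined for rank s out of its sequence shard, and reassembling those pieces yields exactly the head-sharded layout (dimensions here are arbitrary toy values):

```python
import torch

P, n, H, d = 4, 32, 8, 16
x = torch.randn(n, H, d)               # full activation: [seq, heads, head_dim]

seq_shards = list(x.chunk(P, dim=0))   # before all-to-all: n/P tokens, all heads
head_shards = list(x.chunk(P, dim=1))  # after all-to-all: all tokens, H/P heads

# Simulate the all-to-all: rank r sends piece s of its sequence shard to rank s
rebuilt = [
    torch.cat([seq_shards[r].chunk(P, dim=1)[s] for r in range(P)], dim=0)
    for s in range(P)
]
for got, want in zip(rebuilt, head_shards):
    assert torch.equal(got, want)
```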

Ring Attention vs Ulysses

The two approaches have different communication patterns and different strengths:


Ring Attention vs Ulysses

| Property | Ring Attention | Ulysses (All-to-All) | Winner |
|---|---|---|---|
| Communication pattern | Point-to-point in ring | All-to-all collective | Depends on network |
| Communication volume | O(n d / P) per step | O(n d) total (two all-to-alls) | Ring (lower volume) |
| Overlap with compute | Yes (pipelined) | No (synchronous) | Ring |
| Implementation complexity | Higher (ring scheduling) | Lower (standard collectives) | Ulysses |
| Load balance (causal) | Unbalanced (triangle mask) | Balanced (each GPU has full seq) | Ulysses |
| Parallelism constraint | Any P >= 2 | P must divide H (number of heads) | Ring (more flexible) |
| Works with GQA | Yes | Limited by KV heads | Ring (more flexible) |
ℹ️ The Load Balance Problem

Ring attention with causal masking has a load balance issue. The GPU holding the earliest sequence chunk does almost no computation (most KV chunks are masked), while the GPU holding the latest chunk does nearly full computation. Ulysses avoids this because each GPU processes all sequence positions (for a subset of heads). In practice, the load imbalance in ring attention is addressed by overdecomposing — using more ring steps than GPUs and scheduling them to balance work.

Hybrid Approaches

Modern training frameworks combine ring attention and Ulysses-style parallelism with tensor parallelism and data parallelism. For example:

  • Tensor parallelism (TP) within a node: split the model’s weight matrices across GPUs in a single node (fast NVLink).
  • Context parallelism (CP) across a ring: distribute the sequence across nodes for long context.
  • Data parallelism (DP): replicate the model across groups, each processing different batches.

A typical 128K-context training setup for a 70B model might use:

  • TP = 8 (one node of 8 GPUs)
  • CP = 4 (4 nodes in a ring for context)
  • DP = 16 (16 replicas for data parallelism)
  • Total: 512 GPUs

Parallelism Strategy for Long-Context Training

| Model Size | Context Length | TP | CP | DP | Total GPUs |
|---|---|---|---|---|---|
| 7B | 32K | 1 | 1 | 64 | 64 |
| 7B | 128K | 2 | 4 | 32 | 256 |
| 70B | 32K | 8 | 1 | 32 | 256 |
| 70B | 128K | 8 | 4 | 16 | 512 |
| 405B | 128K | 8 | 8 | 8 | 512 |
| 405B | 1M | 8 | 32 | 4 | 1024 |

The Long-Context Landscape in 2025

Modern frontier models support remarkable context lengths:

  • Gemini 1.5 Pro: 1M tokens (2M in research previews)
  • Claude 3.5: 200K tokens
  • GPT-4 Turbo/GPT-4o: 128K tokens
  • Llama 3.1: 128K tokens
  • Mistral Large: 128K tokens
  • DeepSeek-V3: 128K tokens
  • Qwen-2.5: 128K tokens (1M with YaRN scaling)

How do they achieve this? The answer is a combination of techniques, not any single innovation.

Recipe 1: RoPE Scaling

Rotary Position Embeddings (RoPE) encode position information through rotation matrices applied to queries and keys. The base frequency θ determines how position information is encoded:

q_m^(i) = q^(i) · e^(i·m·θ_i),    k_n^(i) = k^(i) · e^(i·n·θ_i)

where θ_i = 10000^(-2i/d) for dimension pair i.
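The key property of this rotation form is that the query-key dot product depends only on the relative offset m - n, since rotating both vectors shifts their phases by the same per-position angle. A pure-Python sketch (helper names are mine) that verifies this directly:

```python
import cmath

def rope_freqs(head_dim, base=10000.0):
    # One frequency per channel pair: theta_i = base^(-2i / head_dim)
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

def apply_rope(vec, pos, freqs):
    # Treat consecutive channel pairs as complex numbers and rotate
    # pair i by angle pos * theta_i (the e^(i*m*theta_i) factor above).
    out = []
    for i, theta in enumerate(freqs):
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.exp(1j * pos * theta)
        out.extend([z.real, z.imag])
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Same relative offset (m - n = -4) gives the same attention score:
freqs = rope_freqs(8)
q = [0.3, -1.2, 0.7, 2.0, -0.5, 0.1, 1.1, -0.9]
k = [1.0, 0.4, -0.8, 0.2, 0.6, -1.5, 0.3, 0.9]
s1 = dot(apply_rope(q, 3, freqs), apply_rope(k, 7, freqs))
s2 = dot(apply_rope(q, 13, freqs), apply_rope(k, 17, freqs))
assert abs(s1 - s2) < 1e-9
```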

To extend context length beyond the training length, you can scale the frequencies:

NTK-aware scaling modifies the base frequency: θ_new = θ · α^(d/(d-2)), where α is the scaling factor. This interpolates smoothly between the original and extended contexts.

YaRN (Yet another RoPE extensioN) combines NTK-aware scaling with attention temperature scaling and fine-tuning, achieving the best quality for extended contexts:

import math
import torch

def yarn_rope_scaling(dim, original_max=4096, scale_factor=32):
    """YaRN RoPE scaling for context extension."""
    beta_fast = 32  # dims rotating at least this often in the original context stay unscaled
    beta_slow = 1   # dims rotating at most this often are fully interpolated
    base = 10000.0

    # Per-dimension-pair base frequencies: theta_i = base^(-2i/dim)
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

    # Full rotations each dimension completes over the original context
    rotations = original_max * freqs / (2 * math.pi)

    # Linear ramp between the low- and high-frequency regions
    ramp = ((rotations - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)

    # Interpolate between position-interpolated and original frequencies
    scaled_freqs = freqs / scale_factor
    new_freqs = (1 - ramp) * scaled_freqs + ramp * freqs

    # Attention temperature scaling to compensate for the longer context
    attn_scale = 0.1 * math.log(scale_factor) + 1.0

    return new_freqs, attn_scale

Recipe 2: Continued Pre-Training on Long Documents

RoPE scaling alone is not sufficient. The model also needs to see long documents during training. The standard approach is:

  1. Pre-train the model at a shorter context (e.g., 4K-8K tokens) for most of training.
  2. Continue pre-training at the target context length (e.g., 128K tokens) for a small fraction of total training (1-5% of tokens).
  3. Use a dataset enriched with long documents (books, code repositories, long articles).

Llama 3.1’s technical report describes this process: after pre-training Llama 3.1 405B at 8K context on 15T tokens, they continued training at 128K context on approximately 800B additional tokens. The context extension phase is roughly 5% of total training compute.

Recipe 3: Infrastructure (FlashAttention + Ring Attention)

The training infrastructure must support the long context lengths efficiently:

  • FlashAttention v2/v3 for IO-efficient attention within each GPU.
  • Ring attention or context parallelism to distribute the sequence across GPUs.
  • Sequence packing to minimize padding waste when mixing different-length documents.
  • Gradient checkpointing to reduce memory for activations (at the cost of recomputation).
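Sequence packing in particular is simple to sketch. A hedged illustration of greedy first-fit-decreasing packing (real packers also build block-diagonal attention masks so packed documents cannot attend to each other; that part is omitted here):

```python
def pack_documents(doc_lengths, max_seq_len):
    # Greedy first-fit-decreasing packing of documents into fixed-length
    # training sequences. Documents longer than max_seq_len are assumed
    # to be pre-truncated.
    sequences = []
    for length in sorted(doc_lengths, reverse=True):
        for seq in sequences:
            if sum(seq) + length <= max_seq_len:
                seq.append(length)
                break
        else:
            sequences.append([length])
    return sequences

# Five documents into 128K-token sequences:
seqs = pack_documents([100_000, 20_000, 8_000, 60_000, 3_000], 131_072)
# -> two packed sequences (~27% padding) instead of five padded
#    sequences (~71% padding)
```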
📊

How Frontier Models Achieve Long Context

| Model | Context Length | Position Encoding | Attention Method | Training Approach |
|---|---|---|---|---|
| Llama 3.1 405B | 128K | RoPE (scaled) | FlashAttention + CP | Continued pretraining at 128K |
| Claude 3.5 Sonnet | 200K | Unknown (proprietary) | Unknown | Unknown (proprietary) |
| GPT-4 Turbo | 128K | Unknown | Unknown | Unknown (proprietary) |
| Gemini 1.5 Pro | 1M | Unknown (likely RoPE variant) | Likely ring attention variant | Trained on long documents |
| Mistral Large | 128K | RoPE | SWA + full attention layers | Mixed-length training |
| Qwen-2.5 72B | 128K (1M w/ YaRN) | RoPE + YaRN | FlashAttention | Continued pretraining + YaRN |

Recipe 4: Inference Optimizations for Long Context

Serving models at long context lengths requires additional optimizations beyond training:

KV cache quantization: Compress the KV cache from FP16 to FP8 or INT4, halving or quartering memory. Recent work (KIVI, KVQuant) shows this can be done with minimal quality loss, especially for older KV entries that have been “attended over” many times.

KV cache eviction: Discard KV entries for tokens that receive low attention. H2O (Heavy Hitter Oracle) keeps only the tokens with the highest cumulative attention scores. StreamingLLM keeps the first few tokens (attention sinks) plus a sliding window of recent tokens.
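StreamingLLM's retention policy is simple enough to sketch in a few lines (the function name and parameter defaults are illustrative, not taken from the paper's implementation):

```python
def streaming_kv_keep(total_len, num_sinks=4, window=4096):
    # StreamingLLM-style retention: keep the first num_sinks
    # "attention sink" tokens plus a sliding window of the most
    # recent tokens, evicting everything in between.
    if total_len <= num_sinks + window:
        return list(range(total_len))
    return list(range(num_sinks)) + list(range(total_len - window, total_len))

# After 100K generated tokens, only 4,100 KV entries remain resident:
kept = streaming_kv_keep(100_000)
```

KV memory is bounded at num_sinks + window entries no matter how long the stream runs, which is what makes fixed-memory streaming generation possible.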

Prefix caching: For repeated prefixes (system prompts, RAG contexts), cache the KV values and reuse them across requests. This avoids recomputing the prefill for shared context.

def kv_cache_management_strategies():
    """Compare KV cache optimization strategies for long context."""
    strategies = {
        "full_kv_cache": {
            "memory": "O(n * L * d)",
            "quality": "100% (baseline)",
            "max_context_80GB": "~32K tokens (70B model)",
        },
        "gqa_kv_cache": {
            "memory": "O(n * L * d * (kv_heads / q_heads))",
            "quality": "~99.5%",
            "max_context_80GB": "~128K tokens (70B model, 8 KV heads)",
        },
        "quantized_kv_fp8": {
            "memory": "O(n * L * d / 2)",
            "quality": "~99%",
            "max_context_80GB": "~64K tokens (70B model)",
        },
        "quantized_kv_int4": {
            "memory": "O(n * L * d / 4)",
            "quality": "~97%",
            "max_context_80GB": "~128K tokens (70B model)",
        },
        "sliding_window_eviction": {
            "memory": "O(W * L * d)",
            "quality": "~95% (degrades with distance)",
            "max_context_80GB": "Unlimited (fixed memory)",
        },
        "h2o_eviction": {
            "memory": "O(budget * L * d)",
            "quality": "~97% (keeps important tokens)",
            "max_context_80GB": "Unlimited (fixed memory)",
        },
    }
    return strategies
📊

KV Cache Optimization: Memory vs Quality (70B Model, 128K Context)

| Strategy | KV Cache Memory | Needle-in-Haystack Accuracy | Throughput (tok/s) |
|---|---|---|---|
| Full KV (FP16) | 42 GB | 99.2% | 450 |
| GQA (8 KV heads, FP16) | 5.2 GB | 99.0% | 1,200 |
| GQA + FP8 quantization | 2.6 GB | 98.5% | 1,800 |
| GQA + INT4 quantization | 1.3 GB | 96.8% | 2,400 |
| Sliding window (32K) | 1.3 GB | 85% (at 128K distance) | 2,800 |
| H2O (budget=4K) | 0.16 GB | 92% | 3,200 |

When Sparse Attention Still Wins

Despite FlashAttention’s dominance, there are scenarios where sparse attention remains the right approach.

Very Long Sequences (greater than 1M Tokens)

At sequence lengths beyond 1M tokens, even FlashAttention’s O(n²) computation becomes prohibitive. The attention FLOPs for a single layer at 1M tokens with head dimension d = 128:

2 × 1,000,000² × 128 = 2.56 × 10¹⁴ FLOPs

Even on an H100 at 1 PFLOPS, that is 256 milliseconds per layer for attention alone. With 80 layers, attention takes 20 seconds per forward pass. Sparse attention with O(n log n) or O(n√n) complexity would reduce this to under 1 second.
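The arithmetic above generalizes into a two-line estimator. This mirrors the text's simplified 2·n²·d per-layer count (real layers add projections and per-head sums) and assumes the accelerator sustains its peak rate, so it is a lower bound:

```python
def attention_flops(seq_len, head_dim=128):
    # Simplified per-layer count: 2 * n^2 * d
    return 2 * seq_len ** 2 * head_dim

def attention_seconds(seq_len, num_layers=80, peak_flops=1e15):
    # Lower-bound wall time for attention across the full forward pass.
    return num_layers * attention_flops(seq_len) / peak_flops

# attention_flops(1_000_000)   -> 2.56e14 FLOPs per layer
# attention_seconds(1_000_000) -> ~20.5 s per forward pass
```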

Emerging use cases that need very long contexts:

  • Genomics: DNA sequences of 1M-10M base pairs. Each base is a token.
  • Whole-codebase analysis: Entire repositories as context (10M+ tokens).
  • Long-form video understanding: 1 hour of video at 1 frame/second with 256 tokens per frame = 921K tokens.
  • Multi-document reasoning: 100+ documents in context for complex question answering.

Resource-Constrained Deployment

For edge and mobile deployment, even FlashAttention does not help if the device cannot store the full KV cache. Sparse or windowed attention bounds memory consumption:

  • On-device LLMs (phones, laptops): Limited to 4-8 GB of memory for the entire model + KV cache.
  • Embedded systems: Even more constrained. Sliding window with W = 512 may be necessary.
  • Cost-sensitive serving: When serving millions of users, every GB of KV cache per request translates to hardware costs. Sparse attention or aggressive eviction directly reduces serving costs.
📊

When to Use Which Attention Method

| Scenario | Best Method | Why | Quality Trade-off |
|---|---|---|---|
| Seq length less than 8K | FlashAttention (dense) | Fast enough, exact attention | None |
| 8K-64K, single GPU | FlashAttention (dense) | Manageable memory with GQA | None |
| 64K-256K, multi-GPU | Ring Attention + FlashAttention | Distribute across GPUs | None (exact) |
| 256K-1M, multi-GPU | Ring Attention + FlashAttention | More GPUs for context parallelism | None (exact) |
| Greater than 1M tokens | Sparse + Ring Attention hybrid | Quadratic cost too high even distributed | Small (1-3%) |
| Edge / mobile | Sliding window or sparse | Memory bounded | Moderate (5-10% at long range) |
| Cost-optimized serving | SWA + KV eviction | Minimize memory per request | Moderate (context-dependent) |

Specific Task Types

Some tasks have inherent locality that sparse attention exploits:

  • Language modeling: Most dependencies are local. Sliding window captures 95%+ of the relevant context.
  • DNA/protein sequences: Local structure matters most. Sparse attention with biological priors outperforms dense attention at the same compute budget.
  • Time series: Temporal locality is strong. Local + periodic attention patterns match the data structure.
  • Structured data (tables, code with nesting): Sparse patterns aligned with document structure outperform random dense attention.
💡 The Practitioner's Rule of Thumb

Start with FlashAttention (dense). If your sequence length exceeds what your hardware can handle, add context parallelism (ring attention). If your sequences exceed 1M tokens or you are severely memory-constrained, consider sparse attention. Do not reach for sparse attention as a first resort — the quality-complexity trade-off is rarely worth it when FlashAttention works.

The Evolution of the Attention Landscape

The trajectory of attention efficiency tells a clear story about how systems engineering trumps algorithmic cleverness when hardware constraints shift.

Phase 1: Algorithmic Innovation (2019-2021)

The initial response to the quadratic bottleneck was algorithmic: design new attention patterns that avoid the O(n²) cost. This produced Sparse Transformer, Longformer, BigBird, Performer, Linformer, and many others. Each paper proposed a different sparsity pattern with different theoretical properties.

The problem: each pattern required custom implementation, custom CUDA kernels, custom backward passes, and careful validation that the quality trade-off was acceptable for the target task. The ecosystem was fragmented. No single sparse attention method became a standard.

Phase 2: Systems Optimization (2022-2023)

FlashAttention shifted the conversation from “how to avoid computing attention” to “how to compute attention efficiently on actual hardware.” The insight was that the bottleneck was not FLOPS but memory bandwidth — and you could solve that without approximating the computation.

This was a systems engineering insight, not an algorithmic one. The same O(n²) computation, reorganized for the memory hierarchy, became 2-4x faster and used O(n) memory. The entire sparse attention research direction became less relevant for moderate sequence lengths.

Phase 3: Distributed Computation (2023-2025)

For truly long contexts, the solution was distributed systems: ring attention, context parallelism, sequence sharding. These methods keep exact attention but distribute it across GPUs.

The evolution mirrors a common pattern in computing: when a problem seems to require a clever algorithm, often a better systems approach (faster hardware, better memory management, more parallelism) makes the naive algorithm fast enough.

Long-Context Capability Over Time

📊 Line chart: Maximum Context Length (tokens) over time
📊

The Attention Efficiency Timeline

| Year | Key Development | Maximum Practical Context | Approach |
|---|---|---|---|
| 2017 | Original Transformer | ~512 tokens | Dense attention, no optimization |
| 2019 | Sparse Transformer, Transformer-XL | ~8K tokens | Sparse patterns, recurrence |
| 2020 | Longformer, BigBird | ~16K tokens | Local + global sparse attention |
| 2021 | ALiBi, RoPE | ~8K tokens (better extrapolation) | Position encoding improvements |
| 2022 | FlashAttention v1 | ~16K-32K tokens | IO-aware dense attention |
| 2023 | FlashAttention v2, Ring Attention | ~128K tokens | Better FA + distributed attention |
| 2024 | FA3, YaRN, context parallelism | ~1M tokens | H100 optimization + RoPE scaling + CP |
| 2025 | Production 1M+ context | ~2M tokens | Full stack: hardware + FA + CP + RoPE |

Performance Benchmarks: Long Context in Practice

How well do modern long-context models actually use their context? The needle-in-a-haystack test measures this: insert a specific fact at various positions in a long document and test retrieval accuracy.

📊

Needle-in-a-Haystack Performance (2025 Models)

| Model | Context Claimed | 95%+ Accuracy Up To | Attention Method | Notes |
|---|---|---|---|---|
| GPT-4 Turbo | 128K | ~120K | Unknown (likely FA variant) | Some degradation at end of context |
| Claude 3.5 Sonnet | 200K | ~195K | Unknown (proprietary) | Near-perfect across full context |
| Llama 3.1 70B | 128K | ~120K | FlashAttention + CP | Slight degradation near boundaries |
| Gemini 1.5 Pro | 1M | ~1M | Unknown | Near-perfect even at 1M tokens |
| Mistral Large | 128K | ~100K | SWA + full attention | Degrades past SWA window distance |
| Qwen-2.5 72B | 128K | ~110K | FlashAttention | Good performance with YaRN scaling |
ℹ️ Context Length vs Context Usage

There is an important distinction between a model’s nominal context length and how effectively it uses that context. Many models can accept 128K tokens but show degraded performance for information in the middle of the context (“lost in the middle” phenomenon). True long-context capability means uniform performance regardless of where the relevant information appears. This is still an active area of improvement.

Throughput at Long Context

Long context dramatically affects serving throughput because the KV cache consumes memory that would otherwise be used for batching:

📊

Serving Throughput vs Context Length (Llama 3.1 70B, 8xH100)

| Context Length | Max Batch Size | Throughput (tok/s) | Cost per 1M tokens | KV Cache per Request |
|---|---|---|---|---|
| 4K | 128 | 24,000 | $0.12 | 0.16 GB |
| 16K | 64 | 14,400 | $0.20 | 0.65 GB |
| 32K | 32 | 8,800 | $0.33 | 1.3 GB |
| 128K | 4 | 1,600 | $1.80 | 5.2 GB |
| 512K (theoretical) | 1 | 320 | $9.00 | 20.8 GB |

The economics are stark: going from 4K to 128K context increases cost per token by 15x, primarily because the KV cache limits batch size. This is why long-context API pricing is significantly higher and why KV cache optimization is critical for production viability.
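The KV cache footprint driving these numbers follows directly from the model shape. A sketch of the formula (the 70B shape below, 80 layers with 8 KV heads of dimension 128, is assumed from Llama's public configuration; served figures also vary with precision, window settings, and allocator overhead):

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    # One K vector and one V vector of size num_kv_heads * head_dim,
    # per layer, per token.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

# Llama 3.1 70B shape: 80 layers, 8 KV heads (GQA), head_dim 128, FP16
per_token = kv_bytes_per_token(80, 8, 128)   # 327,680 bytes ≈ 0.31 MiB/token
```

Multiply the per-token figure by context length and batch size to see how quickly HBM fills, which is exactly the mechanism that caps batch size at long context.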

Future Directions

Attention-Free Architectures

State space models (Mamba, RWKV) and linear attention variants avoid the quadratic bottleneck entirely. Mamba processes sequences in O(n) time with a constant-size recurrent state, regardless of sequence length. If these architectures can match transformer quality at scale, the entire attention optimization story becomes moot.

Current status: Mamba-2 and Jamba show promise at 7B-52B scale, but no attention-free model has yet demonstrated frontier-level quality at 400B+ parameters. The transformer’s attention mechanism appears to provide something that linear models struggle to replicate — precise, content-based retrieval over long contexts.

Native Long-Context Training

Current models train primarily on short sequences and extend to long context via continued pre-training. Future models may train natively at long context from the start, enabled by cheaper hardware and better distributed training frameworks. This could improve long-context quality significantly.

Hybrid Sparse-Dense Approaches

Some emerging approaches use dense attention for nearby tokens and sparse attention for distant tokens, with the transition controlled by a learned boundary or a fixed heuristic. This matches the empirical observation that attention weights are typically concentrated locally with sparse long-range connections.

Hardware-Software Co-Design

Future accelerators may have hardware support for attention-specific operations: on-chip softmax units, dedicated KV cache memory, or attention-aware memory hierarchies. Groq’s LPU and Cerebras’s wafer-scale engine are early examples of hardware designed around the attention bottleneck.

Conclusion

The long-context story is a case study in how systems engineering solves problems that algorithms alone cannot. The original response to the O(n²) attention bottleneck was algorithmic: sparse attention patterns that traded quality for efficiency. This worked, but the trade-off was uncomfortable — every sparse pattern had failure modes, and the implementation complexity was high.

FlashAttention disrupted this by showing that the bottleneck was memory IO, not computation. By reorganizing the same O(n²) computation to be IO-aware, FlashAttention made dense attention fast enough for moderate context lengths (up to 32K-64K) with zero quality loss. Sparse attention’s complexity was no longer justified for these lengths.

For truly long contexts (128K-1M+), the solution shifted to distributed systems. Ring attention and context parallelism distribute the sequence across GPUs, keeping exact attention while scaling memory linearly with GPU count. Combined with RoPE scaling and continued pre-training on long documents, this stack enables production-grade million-token contexts.

Sparse attention is not dead. It retains clear advantages for very long sequences (greater than 1M tokens), resource-constrained deployment, and domain-specific tasks with inherent locality. But it is no longer the primary tool for extending context length. The primary tools are now FlashAttention for single-GPU efficiency and ring attention for multi-GPU distribution.

The practical lesson is clear: when you need longer context, first try FlashAttention with your existing architecture. If that is not enough, add context parallelism with more GPUs. Only reach for sparse attention when you have exhausted the dense-attention approaches or have specific constraints (memory, latency, domain structure) that sparse attention uniquely addresses. The field has learned that exact computation with good systems engineering beats approximate computation with clever algorithms — at least until the sequences get long enough that even the best systems cannot make quadratic scaling work.