Sparse attention was supposed to solve the quadratic bottleneck and unlock million-token context windows. Longformer, BigBird, Sparse Transformer — all delivered impressive results on paper with hand-designed sparsity patterns that preserved quality while cutting compute by 10-100x. Then FlashAttention made dense attention fast enough for 32K-64K contexts and the entire sparse attention research program collapsed overnight. The reason: sparse attention requires custom kernels that are harder to optimize than dense attention, and modern accelerators (H100 tensor cores) are so fast at dense matmuls that sparsity only wins at extreme sequence lengths where memory becomes the bottleneck anyway. Sparse attention survives in exactly two niches: sliding window attention (Mistral) for simplicity, and extreme long-context serving where Ring Attention distributes dense computation across GPUs. This post covers the full arc from sparse patterns to why they lost.
The field’s response to this problem has evolved dramatically over five years:
- 2019-2020: Sparse attention patterns (Longformer, BigBird, Sparse Transformer) traded quality for efficiency with hand-designed patterns.
- 2022: FlashAttention made dense attention fast enough for moderate sequence lengths (up to 32K-64K), largely killing the sparse attention approach for mainstream use.
- 2023: Sliding window attention (Mistral) offered a simpler alternative: each token attends to only the last W tokens.
- 2023-2024: Ring Attention and context parallelism distributed long sequences across multiple GPUs, enabling 128K-1M+ contexts.
- 2025: The long-context landscape combines RoPE scaling, continued pretraining, FlashAttention, and distributed attention to deliver production-grade million-token contexts.
This post traces the full arc, from the original quadratic bottleneck to the techniques that power today’s long-context models.
The Problem
Why Attention Is Quadratic
The standard scaled dot-product attention computes:

Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V

where Q, K, and V are n x d_k matrices for a sequence of n tokens. The matrix multiplication Q K^T produces an n x n matrix. Every token's query must be compared against every token's key. This is the source of the quadratic cost.
```python
import math
import torch

def standard_attention(Q, K, V):
    """Standard O(n^2) attention."""
    d_k = Q.shape[-1]
    # QK^T: [batch, heads, n, d_k] x [batch, heads, d_k, n] = [batch, heads, n, n]
    # This n x n matrix is the bottleneck
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Softmax over the key dimension
    attn_weights = torch.softmax(scores, dim=-1)
    # Weighted sum of values: [batch, heads, n, n] x [batch, heads, n, d_k]
    output = torch.matmul(attn_weights, V)
    return output
```
The Memory Wall at Scale
The attention matrix must be materialized in GPU memory (at least partially) for both the forward pass and backward pass. During training with gradient checkpointing disabled, you need to store the attention weights for backpropagation.
Attention Memory Requirements (Single Layer, Single Head, FP16)
| Sequence Length | Attention Matrix Size | Memory (FP16) | Feasible on 80GB GPU? |
|---|---|---|---|
| 2,048 | 4.2M entries | 8 MB | Yes -- trivial |
| 8,192 | 67.1M entries | 128 MB | Yes -- comfortable |
| 32,768 | 1.07B entries | 2 GB | Yes -- tight with many layers/heads |
| 131,072 (128K) | 17.2B entries | 32 GB | No -- exceeds memory for multi-layer model |
| 524,288 (512K) | 275B entries | 512 GB | No -- impossible on single GPU |
| 1,048,576 (1M) | 1.1T entries | 2 TB | No -- requires distributed approach |
The table above shows memory for a single layer and single head. A 70B model has 80 layers and 64 heads (8 KV heads with GQA). Even with GQA reducing the KV heads, the total attention memory for a full forward pass is the single-head cost multiplied by the number of layers and query heads. At 128K tokens, this makes naive attention completely infeasible without memory-efficient techniques.
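The per-head entries in the table can be checked with a few lines (FP16 at 2 bytes per element; this counts only the raw n x n score matrix, so real peak memory is higher):

```python
def attn_matrix_bytes(seq_len: int, bytes_per_el: int = 2) -> int:
    """Memory for one n x n attention score matrix (single layer, single head)."""
    return seq_len * seq_len * bytes_per_el

# 32K context: exactly 2 GiB for a single head's score matrix in FP16
print(attn_matrix_bytes(32_768) / 2**30)  # 2.0
```

At 128K (131,072 tokens) the same formula gives 2^35 bytes, the 32 GB row in the table.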
The quadratic scaling also affects computation time. At 2K tokens, attention is a small fraction of total model FLOPS. At 128K tokens, attention dominates:
Attention FLOPS as Fraction of Total Model FLOPS (70B Model)
| Sequence Length | Attention FLOPS | FFN FLOPS | Attention Share |
|---|---|---|---|
| 2,048 | 0.34 TFLOPS | 4.6 TFLOPS | 7% |
| 8,192 | 5.5 TFLOPS | 18.4 TFLOPS | 23% |
| 32,768 | 88 TFLOPS | 73.7 TFLOPS | 54% |
| 131,072 | 1,408 TFLOPS | 294.9 TFLOPS | 83% |
At 128K tokens, attention accounts for 83% of all computation. Any optimization that reduces attention cost has massive impact on total throughput.
Sparse Attention Approaches (2019-2021)
The first generation of solutions attacked the quadratic problem by making the attention matrix sparse. Instead of every token attending to every other token, restrict attention to a subset of positions.
Sparse Transformer (2019)
OpenAI’s Sparse Transformer introduced two key sparse patterns:
- Strided attention: token i attends to the previous w positions (local window) and to every s-th earlier position (strided global pattern).
- Fixed attention: Alternate between local attention layers and layers that attend to fixed global positions.
The complexity reduces from O(n^2 d) to O(n sqrt(n) d) by choosing the stride close to sqrt(n) and combining local and strided patterns.
```python
import torch

def sparse_transformer_pattern(seq_len, window_size=256, stride=256):
    """Create Sparse Transformer attention pattern (local + strided)."""
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        # Local window: attend to nearby tokens
        local_start = max(0, i - window_size)
        local_end = min(seq_len, i + 1)  # causal
        mask[i, local_start:local_end] = True
        # Strided: attend to every stride-th earlier token
        strided_positions = torch.arange(0, i + 1, stride)
        mask[i, strided_positions] = True
    return mask
```
Longformer (2020)
Longformer from Allen AI refined the sparse approach with three attention patterns combined:
- Local sliding window: Each token attends to w/2 tokens on each side (window of w).
- Dilated sliding window: Like sliding window but with gaps, increasing receptive field.
- Global attention: Designated tokens (like [CLS] or question tokens) attend to all positions and are attended by all positions.
The complexity is O(n w d) for the local component (linear in n) plus O(n g d) for the g global tokens.
```python
import math
import torch

def longformer_attention(Q, K, V, window_size=512, global_indices=None):
    """Simplified Longformer attention pattern."""
    seq_len = Q.shape[1]
    output = torch.zeros_like(Q)
    for i in range(seq_len):
        # Local window
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        # Keys to attend to: local window + global tokens
        attend_to = list(range(start, end))
        if global_indices is not None:
            attend_to = sorted(set(attend_to) | set(global_indices))
        K_subset = K[:, attend_to, :]
        V_subset = V[:, attend_to, :]
        scores = torch.matmul(Q[:, i:i+1, :], K_subset.transpose(-2, -1))
        scores = scores / math.sqrt(Q.shape[-1])
        weights = torch.softmax(scores, dim=-1)
        output[:, i:i+1, :] = torch.matmul(weights, V_subset)
    return output
```
BigBird (2020)
Google’s BigBird added random attention to the mix, combining three components:
- Local windowed attention: Same as Longformer.
- Global tokens: A set of tokens that attend to all positions.
- Random attention: Each token randomly attends to r additional tokens.
The theoretical contribution was proving that this combination is a universal approximator of sequence-to-sequence functions, maintaining the expressiveness of full attention while being sparse.
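A toy sketch of the combined BigBird mask follows; the window, global, and random counts here are illustrative defaults, not the paper's configuration:

```python
import torch

def bigbird_mask(seq_len, window=3, num_global=2, num_random=2, seed=0):
    """Boolean [seq_len, seq_len] mask: True = position may be attended."""
    g = torch.Generator().manual_seed(seed)
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
    for i in range(seq_len):
        # Local windowed attention
        mask[i, max(0, i - window):min(seq_len, i + window + 1)] = True
        # Random attention: a few extra positions per token
        rand = torch.randint(0, seq_len, (num_random,), generator=g)
        mask[i, rand] = True
    # Global tokens attend everywhere and are attended by everyone
    mask[:num_global, :] = True
    mask[:, :num_global] = True
    return mask

m = bigbird_mask(16)
print(m.sum().item(), "of", 16 * 16, "positions active")
```

The sparsity benefit appears as soon as the active fraction drops well below n^2, which happens quickly as `seq_len` grows with fixed window and global/random budgets.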
Sparse Attention Methods Comparison
| Method | Pattern | Complexity | Memory | Quality vs Dense |
|---|---|---|---|---|
| Dense (baseline) | Full n x n | O(n^2 d) | O(n^2) | 100% (baseline) |
| Sparse Transformer | Strided + local | O(n sqrt(n) d) | O(n sqrt(n)) | 98-99% |
| Longformer | Local + global | O(n W d) | O(n W) | 97-99% |
| BigBird | Local + global + random | O(n W d) | O(n W) | 97-99% |
| Linformer | Low-rank projection | O(n k d) | O(n k) | 94-97% |
| Performer | Random features (FAVOR+) | O(n d^2) | O(n d) | 90-95% |
Linear Attention Variants
A separate line of research tried to eliminate the quadratic term entirely:
Linformer projected the key and value matrices from length n down to a fixed dimension k, reducing the attention matrix from n x n to n x k. The complexity becomes O(n k d), linear in n when k is fixed.
Performer (FAVOR+) approximated the softmax kernel with random feature maps phi, enabling attention to be computed as phi(Q) (phi(K)^T V) instead of (phi(Q) phi(K)^T) V, changing the association order to avoid materializing the n x n matrix. Complexity: O(n d^2).
These linear variants were elegant mathematically but suffered quality degradation on real tasks, particularly for tasks requiring precise long-range attention patterns. The approximation error compounded across layers, leading to noticeable performance drops.
Linear attention methods like Performer and Random Feature Attention approximate the softmax attention kernel. But softmax attention has a crucial property: it produces sharp attention distributions that focus on a few relevant positions. The random feature approximation tends to produce blurry distributions that spread attention more uniformly. This matters most for tasks requiring precise retrieval from context — exactly the tasks where long context is most valuable.
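To make the reassociation concrete, here is a toy linear attention using the elu+1 feature map (the map is illustrative; Performer instead uses random softmax features, which is what introduces the approximation error discussed above):

```python
import torch

def linear_attention(Q, K, V):
    """phi(Q) @ (phi(K)^T V): O(n d^2) instead of O(n^2 d)."""
    phi = lambda x: torch.nn.functional.elu(x) + 1  # positive feature map
    Qp, Kp = phi(Q), phi(K)
    KV = Kp.transpose(-2, -1) @ V  # [d, d] -- the n x n matrix never exists
    # Per-row normalizer: phi(q_i) . sum_j phi(k_j)
    Z = Qp @ Kp.sum(dim=-2, keepdim=True).transpose(-2, -1)
    return (Qp @ KV) / Z

Q = torch.randn(128, 64); K = torch.randn(128, 64); V = torch.randn(128, 64)
out = linear_attention(Q, K, V)
print(out.shape)  # torch.Size([128, 64])
```

With this particular feature map the reassociated computation is exactly equal to the explicit n x n version; the quality loss in practice comes from approximating softmax's sharp distributions, not from the reassociation itself.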
Why Sparse Attention Lost to FlashAttention
In 2022, Tri Dao’s FlashAttention paper changed the landscape entirely. Rather than approximating or sparsifying the attention computation, FlashAttention computed exact dense attention but reorganized the computation to be IO-aware — minimizing reads and writes to GPU high-bandwidth memory (HBM).
The Key Insight: Attention Is Memory-Bound
The standard attention implementation materializes the full attention matrix in HBM, reads it back for softmax, writes the result back, then reads it again for the value multiplication. Each of these reads and writes goes through HBM, which is 10-100x slower than the GPU’s compute units.
FlashAttention’s insight: you can compute attention in tiles that fit in SRAM (the GPU’s fast on-chip memory), never materializing the full matrix in HBM. The algorithm uses the online softmax trick to compute exact softmax incrementally across tiles.
```python
import math
import torch

def flash_attention_conceptual(Q, K, V, block_size=256):
    """
    Conceptual FlashAttention: tiled attention without materializing the
    n x n matrix. The real implementation is a fused CUDA kernel.
    """
    n, d = Q.shape
    output = torch.zeros_like(Q)
    row_max = torch.full((n,), float('-inf'))  # running max for online softmax
    row_sum = torch.zeros(n)                   # running sum for online softmax
    # Process K/V in blocks -- never create the full n x n matrix
    for j_start in range(0, n, block_size):
        j_end = min(j_start + block_size, n)
        K_block = K[j_start:j_end]  # [block_size, d]
        V_block = V[j_start:j_end]  # [block_size, d]
        # Compute scores for this block only: [n, block_size]
        scores = Q @ K_block.T / math.sqrt(d)
        # Online softmax update
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)
        scale_old = torch.exp(row_max - new_max)  # rescales the old accumulator
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        new_sum = scale_old * row_sum + exp_scores.sum(dim=-1)
        # Update output: rescale old output + add the new block's contribution
        output = output * (scale_old * row_sum / new_sum).unsqueeze(-1)
        output += (exp_scores @ V_block) / new_sum.unsqueeze(-1)
        row_max = new_max
        row_sum = new_sum
    return output
```
FlashAttention Performance
The results were dramatic. FlashAttention was 2-4x faster than standard PyTorch attention and used O(n) memory instead of O(n^2), while computing exact attention.
FlashAttention vs Standard vs Sparse: Wall Clock Time in ms (A100, FP16, Causal)
| Sequence Length | Standard Attn (ms) | FlashAttention-2 (ms) | Sparse (Longformer, ms) | FA Speedup vs Standard |
|---|---|---|---|---|
| 2,048 | 1.2 | 0.5 | 0.8 | 2.4x |
| 4,096 | 4.8 | 1.4 | 1.5 | 3.4x |
| 8,192 | 19.2 | 4.2 | 2.9 | 4.6x |
| 16,384 | 76.8 | 14.1 | 5.7 | 5.4x |
| 32,768 | 307.2 | 48.3 | 11.2 | 6.4x |
| 65,536 | OOM | 172.0 | 22.1 | N/A (OOM vs works) |
Before FlashAttention, the choice was: (1) dense attention that is exact but OOMs at long sequences, or (2) sparse attention that works at long sequences but loses quality. FlashAttention created option (3): exact dense attention that works at long sequences. At sequence lengths up to 32K-64K, FlashAttention is fast enough and has zero quality compromise. Sparse attention’s complexity — custom attention patterns, quality monitoring, pattern-specific bugs — was no longer worth the tradeoff for these sequence lengths.
FlashAttention-2 and FlashAttention-3
FlashAttention-2 (2023) improved on the original with better parallelism across sequence length and attention heads, achieving close to the theoretical maximum throughput on A100 GPUs (up to 230 TFLOPS in FP16/BF16, compared to the peak 312 TFLOPS).
FlashAttention-3 (2024) targets H100 GPUs with FP8 support, asynchronous block-wise softmax, and warp specialization, pushing throughput even higher and enabling efficient attention at 128K+ sequence lengths on a single GPU.
FlashAttention Evolution
| Version | Year | Key Innovation | Throughput (A100, 16K seq) | Memory |
|---|---|---|---|---|
| Standard PyTorch | Pre-2022 | Baseline | ~60 TFLOPS | O(n^2) |
| FlashAttention v1 | 2022 | Tiled IO-aware attention | ~120 TFLOPS | O(n) |
| FlashAttention v2 | 2023 | Better parallelism, reduced non-matmul FLOPs | ~200 TFLOPS | O(n) |
| FlashAttention v3 | 2024 | H100 FP8, async softmax, warp specialization | ~300 TFLOPS (H100) | O(n) |
The progression tells a clear story: instead of reducing the amount of computation (sparse attention), FlashAttention made the computation faster by optimizing how data moves through the GPU memory hierarchy. This is an IO-first rather than FLOP-first optimization.
Sliding Window Attention (Mistral, 2023)
While FlashAttention solved the memory problem for dense attention, some models opted for a simpler architectural approach: just limit how far back each token can attend.
The Design
Sliding window attention (SWA) is exactly what it sounds like. Each token at position i attends only to tokens in the range [i - W + 1, i], where W is the window size. The attention matrix is banded rather than full lower-triangular.
```python
import math
import torch

def sliding_window_attention(Q, K, V, window_size=4096):
    """Each token attends to at most window_size previous tokens."""
    n = Q.shape[1]
    output = torch.zeros_like(Q)
    for i in range(n):
        start = max(0, i - window_size + 1)
        # Only attend to tokens within the window
        K_window = K[:, start:i+1, :]
        V_window = V[:, start:i+1, :]
        scores = torch.matmul(Q[:, i:i+1, :], K_window.transpose(-2, -1))
        scores = scores / math.sqrt(Q.shape[-1])
        weights = torch.softmax(scores, dim=-1)
        output[:, i:i+1, :] = torch.matmul(weights, V_window)
    return output
```
The complexity is O(n W d), linear in sequence length for fixed W. Memory for the KV cache is bounded at W tokens per layer regardless of total sequence length.
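The bounded KV cache can be implemented as a ring buffer; a sketch (the class and modular indexing scheme here are illustrative, not Mistral's implementation):

```python
import torch

class RollingKVCache:
    """Fixed-size KV cache for sliding window attention (one layer)."""
    def __init__(self, window: int, d: int):
        self.window, self.count = window, 0
        self.K = torch.zeros(window, d)
        self.V = torch.zeros(window, d)

    def append(self, k: torch.Tensor, v: torch.Tensor):
        slot = self.count % self.window  # overwrite the oldest entry
        self.K[slot], self.V[slot] = k, v
        self.count += 1

    def size(self) -> int:
        return min(self.count, self.window)

cache = RollingKVCache(window=4, d=8)
for _ in range(10):
    cache.append(torch.randn(8), torch.randn(8))
print(cache.size())  # 4 -- memory stays bounded regardless of sequence length
```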
Why It Works: Information Flow Through Layers
The key insight is that information propagates through layers. If the window size is W and the model has L layers, then at the final layer a token has an effective receptive field of W x L tokens. Information from distant tokens reaches the current position by hopping through intermediate tokens across layers.
This is analogous to how CNNs build large receptive fields from small convolution kernels stacked in deep networks. Each layer adds W tokens of direct context, and the indirect context grows linearly with depth.
Sliding Window Attention: Effective Context
| Window Size | Layers | Direct Context | Effective Receptive Field | KV Cache per Layer |
|---|---|---|---|---|
| 4,096 | 32 | 4,096 tokens | 131,072 tokens | 32 MB (7B model, FP16) |
| 4,096 | 80 | 4,096 tokens | 327,680 tokens | 80 MB (70B model, FP16) |
| 8,192 | 32 | 8,192 tokens | 262,144 tokens | 64 MB (7B model, FP16) |
| 32,768 | 32 | 32,768 tokens | 1,048,576 tokens | 256 MB (7B model, FP16) |
Mistral’s Implementation
Mistral 7B (2023) used sliding window attention with W = 4096 as a core architectural choice. Combined with Grouped-Query Attention (GQA, 8 KV heads), this dramatically reduced memory requirements during inference:
- Full attention KV cache at 32K tokens: 32,768 cached positions per layer, roughly 512 MB per layer (FP16, d_model = 4096)
- SWA KV cache at 32K tokens: capped at W = 4,096 cached positions, roughly 64 MB per layer (with GQA, much less)
The 8x reduction in KV cache memory translates directly to 8x higher batch sizes during serving, which means 8x higher throughput for the same hardware.
In practice, sliding window attention is implemented using FlashAttention with a causal+window mask. FlashAttention handles the tiling and IO optimization; the window mask just means fewer tiles are non-zero. This combination gives you both the memory savings of SWA and the computational efficiency of FlashAttention. Mistral’s implementation does exactly this.
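The mask-based formulation can be sketched with PyTorch's built-in scaled_dot_product_attention (a functional reference for the math, not the fused kernel path):

```python
import torch
import torch.nn.functional as F

def swa_mask(n: int, window: int) -> torch.Tensor:
    """True where attention is allowed: causal AND within the window."""
    i = torch.arange(n).unsqueeze(1)
    j = torch.arange(n).unsqueeze(0)
    return (j <= i) & (j > i - window)

n, d, W = 64, 32, 16
Q = torch.randn(1, 1, n, d); K = torch.randn(1, 1, n, d); V = torch.randn(1, 1, n, d)
out = F.scaled_dot_product_attention(Q, K, V, attn_mask=swa_mask(n, W))
print(out.shape)  # torch.Size([1, 1, 64, 32])
```

The banded boolean mask is exactly the "causal + window" mask described above; fused kernels get the same result without materializing it.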
Limitations of Sliding Window
SWA has a fundamental limitation: the information propagation through layers is lossy. Each hop through an intermediate token involves passing information through a nonlinear transformation. Fine-grained details from distant tokens degrade as they propagate. This means:
- Tasks requiring exact retrieval from early in a long document (needle-in-a-haystack) degrade with distance.
- Copy-paste style tasks (reproduce a specific passage from the context) fail for passages outside the direct window.
- The effective receptive field is large but the effective attention is much smaller — most influence comes from nearby tokens.
For these reasons, models targeting very long contexts (128K+) typically use full attention rather than pure sliding window. Some models (like Mixtral and Jamba) use SWA in most layers with a few full-attention layers interspersed.
Ring Attention: Distributing Across GPUs
For context lengths beyond what a single GPU can handle (roughly 64K-128K for large models with FlashAttention), the solution is to distribute the sequence across multiple GPUs. Ring Attention (Liu et al., 2023) is the most elegant approach.
The Core Mechanism
Ring Attention distributes the sequence across P GPUs (or “hosts”), with each GPU holding n/P tokens. The attention computation proceeds in rounds. In each round, each GPU computes attention between its local query chunk and the current key-value chunk, then passes the KV chunk to the next GPU in a ring topology.
```python
import math
import torch

def ring_attention(Q_local, K_local, V_local, comm_group, num_gpus, causal=True):
    """
    Ring Attention (conceptual): distribute the sequence across GPUs in a ring.
    Each GPU holds n/P tokens. `get_rank`, `create_causal_mask`, and
    `ring_send_recv` stand in for the framework's communication helpers.
    """
    chunk_size = Q_local.shape[1]  # n/P tokens per GPU
    d_k = Q_local.shape[-1]
    rank = get_rank(comm_group)
    # Initialize output and online-softmax statistics
    output = torch.zeros_like(Q_local)
    row_max = torch.full((chunk_size,), float('-inf'))
    row_sum = torch.zeros(chunk_size)
    # Current KV chunk to process (starts as the local chunk)
    K_recv = K_local.clone()
    V_recv = V_local.clone()
    for step in range(num_gpus):
        # Determine which chunk of the sequence this KV belongs to
        source_rank = (rank - step) % num_gpus
        # Compute attention between local Q and received KV,
        # accumulated with the same online softmax as FlashAttention
        scores = Q_local @ K_recv.transpose(-2, -1) / math.sqrt(d_k)
        if causal:
            # Mask based on the position offsets of the two chunks
            mask = create_causal_mask(rank, source_rank, chunk_size)
            scores = scores.masked_fill(~mask, float('-inf'))
        block_max = scores.max(dim=-1).values
        new_max = torch.maximum(row_max, block_max)
        scale_old = torch.exp(row_max - new_max)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        new_sum = scale_old * row_sum + exp_scores.sum(dim=-1)
        output = output * (scale_old * row_sum / new_sum).unsqueeze(-1)
        output += (exp_scores @ V_recv) / new_sum.unsqueeze(-1)
        row_max = new_max
        row_sum = new_sum
        # OVERLAP: send current KV to the next GPU while receiving from the
        # previous one -- communication is hidden behind computation
        K_recv, V_recv = ring_send_recv(K_recv, V_recv, comm_group)
    return output
```
Why Ring Attention Works
The brilliance of ring attention is that communication overlaps with computation. While GPU p is computing attention between its queries and KV chunk c, it is simultaneously sending chunk c to GPU p+1 and receiving chunk c-1 from GPU p-1. If the computation time exceeds the communication time (which it does for large enough chunks), the communication is entirely hidden.
The memory per GPU is O(n d / P): each GPU only stores its local chunk of Q, K, V, and the output. The total computation is the same as standard attention (O(n^2 d)), but distributed across P GPUs with near-zero communication overhead.
Ring Attention: Scaling with GPUs
| Sequence Length | GPUs | Tokens per GPU | Memory per GPU | Effective Throughput |
|---|---|---|---|---|
| 128K | 1 | 128K | OOM (70B model) | N/A |
| 128K | 4 | 32K | ~40 GB | 3.8x of 1 GPU |
| 128K | 8 | 16K | ~20 GB | 7.5x of 1 GPU |
| 512K | 8 | 64K | ~40 GB | 7.2x of 1 GPU |
| 1M | 16 | 64K | ~40 GB | 14.8x of 1 GPU |
| 1M | 64 | 16K | ~10 GB | 58x of 1 GPU |
In practice, ring attention uses FlashAttention for the local attention computation within each step of the ring. The online softmax trick in FlashAttention is the same trick used to accumulate attention across ring steps. This means ring attention inherits FlashAttention’s IO efficiency while distributing across GPUs. The combination is the standard approach for training models with 128K+ context.
Causal Masking in Ring Attention
For autoregressive (causal) models, ring attention needs special handling. In causal attention, token i can only attend to tokens j <= i. When a GPU holds queries for positions [q_start, q_end) and receives KV for positions [k_start, k_end), the causal mask means:
- If q_end <= k_start: the entire block is masked (all queries come before all keys). Skip the computation entirely.
- If q_start >= k_end: the entire block is unmasked (all queries come after all keys). Compute full attention.
- Otherwise (the ranges overlap): apply the standard causal mask within the block.
This means roughly half the ring steps can be skipped entirely for causal attention, effectively doubling throughput compared to bidirectional ring attention.
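The block-level decision can be written as a tiny classifier (a hypothetical helper; it assumes rank r owns the r-th contiguous chunk of the sequence):

```python
def classify_block(q_rank: int, kv_rank: int) -> str:
    """How to handle one ring step under causal masking.

    Assumes rank r owns tokens [r*chunk, (r+1)*chunk), so all of a lower
    rank's positions precede all of a higher rank's positions.
    """
    if kv_rank > q_rank:
        return "skip"      # all keys come after all queries: fully masked
    if kv_rank < q_rank:
        return "full"      # all keys precede all queries: no mask needed
    return "diagonal"      # same chunk: apply the standard causal mask

steps = [classify_block(q, kv) for q in range(4) for kv in range(4)]
print(steps.count("skip"), steps.count("full"), steps.count("diagonal"))  # 6 6 4
```

With 4 GPUs, 6 of the 16 blocks are skipped outright and only 4 need a real mask, which is where the near-2x causal speedup comes from.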
Context Parallelism: Ulysses and Beyond
Ring attention is not the only approach to distributing attention. DeepSpeed-Ulysses (2023) takes a different approach based on all-to-all communication on the sequence dimension.
Ulysses: All-to-All on the Head Dimension
Ulysses distributes the sequence across GPUs but, before computing attention, performs an all-to-all transpose so that each GPU holds all tokens for a subset of attention heads. Attention is then computed locally on each GPU (full sequence, subset of heads), and another all-to-all transposes back.
```python
def ulysses_attention(Q, K, V, comm_group, num_gpus):
    """
    DeepSpeed-Ulysses (conceptual): all-to-all communication for context
    parallelism. `all_to_all`, `SEQ`, `HEAD`, and `flash_attention` stand in
    for the framework's collectives and kernels.
    Input: each GPU has all heads but n/P tokens.
    After the first all-to-all: each GPU has all tokens but H/P heads.
    """
    # Step 1: All-to-all -- redistribute [seq_chunk, all_heads] to [all_seq, head_chunk]
    Q_full_seq = all_to_all(Q, split_dim=SEQ, gather_dim=HEAD, group=comm_group)
    K_full_seq = all_to_all(K, split_dim=SEQ, gather_dim=HEAD, group=comm_group)
    V_full_seq = all_to_all(V, split_dim=SEQ, gather_dim=HEAD, group=comm_group)
    # Step 2: Exact attention on the full sequence for this GPU's subset of heads
    output = flash_attention(Q_full_seq, K_full_seq, V_full_seq)
    # Step 3: All-to-all -- redistribute back to [seq_chunk, all_heads]
    output = all_to_all(output, split_dim=HEAD, gather_dim=SEQ, group=comm_group)
    return output
```
Ring Attention vs Ulysses
The two approaches have different communication patterns and different strengths:
Ring Attention vs Ulysses
| Property | Ring Attention | Ulysses (All-to-All) | Winner |
|---|---|---|---|
| Communication pattern | Point-to-point in ring | All-to-all collective | Depends on network |
| Communication volume | O(n d / P) per step | O(n d) total (two all-to-alls) | Ring (lower volume) |
| Overlap with compute | Yes (pipelined) | No (synchronous) | Ring |
| Implementation complexity | Higher (ring scheduling) | Lower (standard collectives) | Ulysses |
| Load balance (causal) | Unbalanced (triangle mask) | Balanced (each GPU has full seq) | Ulysses |
| Maximum GPUs | Unbounded | Limited to number of heads (H) | Ring (more flexible) |
| Works with GQA | Yes | Limited by KV heads | Ring (more flexible) |
Ring attention with causal masking has a load balance issue. The GPU holding the earliest sequence chunk does almost no computation (most KV chunks are masked), while the GPU holding the latest chunk does nearly full computation. Ulysses avoids this because each GPU processes all sequence positions (for a subset of heads). In practice, the load imbalance in ring attention is addressed by overdecomposing — using more ring steps than GPUs and scheduling them to balance work.
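A related balancing trick used by some context-parallel implementations pairs early and late chunks on each rank; a sketch of such a "zigzag" assignment (the exact scheme varies by framework):

```python
def zigzag_assignment(num_chunks: int, num_ranks: int):
    """Pair chunk i with chunk (2P - 1 - i): each rank gets one light
    (early, mostly-masked) and one heavy (late, mostly-unmasked) chunk,
    so total causal work per rank is roughly constant."""
    assert num_chunks == 2 * num_ranks
    return {r: (r, num_chunks - 1 - r) for r in range(num_ranks)}

print(zigzag_assignment(8, 4))  # {0: (0, 7), 1: (1, 6), 2: (2, 5), 3: (3, 4)}
```

Since causal work for chunk i grows roughly linearly with i, each pair sums to the same total, balancing the ring.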
Hybrid Approaches
Modern training frameworks combine ring attention and Ulysses-style parallelism with tensor parallelism and data parallelism. For example:
- Tensor parallelism (TP) within a node: split the model’s weight matrices across GPUs in a single node (fast NVLink).
- Context parallelism (CP) across a ring: distribute the sequence across nodes for long context.
- Data parallelism (DP): replicate the model across groups, each processing different batches.
A typical 128K-context training setup for a 70B model might use:
- TP = 8 (one node of 8 GPUs)
- CP = 4 (4 nodes in a ring for context)
- DP = 16 (16 replicas for data parallelism)
- Total: 512 GPUs
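The total GPU count is simply the product of the parallelism degrees, which makes sizing easy to sanity-check:

```python
def total_gpus(tp: int, cp: int, dp: int) -> int:
    """World size = tensor-parallel x context-parallel x data-parallel."""
    return tp * cp * dp

print(total_gpus(tp=8, cp=4, dp=16))  # 512
```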
Parallelism Strategy for Long-Context Training
| Model Size | Context Length | TP | CP | DP | Total GPUs |
|---|---|---|---|---|---|
| 7B | 32K | 1 | 1 | 64 | 64 |
| 7B | 128K | 2 | 4 | 32 | 256 |
| 70B | 32K | 8 | 1 | 32 | 256 |
| 70B | 128K | 8 | 4 | 16 | 512 |
| 405B | 128K | 8 | 8 | 8 | 512 |
| 405B | 1M | 8 | 32 | 4 | 1024 |
The Long-Context Landscape in 2025
Modern frontier models support remarkable context lengths:
- Gemini 1.5 Pro: 1M tokens (2M in research previews)
- Claude 3.5: 200K tokens
- GPT-4 Turbo/GPT-4o: 128K tokens
- Llama 3.1: 128K tokens
- Mistral Large: 128K tokens
- DeepSeek-V3: 128K tokens
- Qwen-2.5: 128K tokens (1M with YaRN scaling)
How do they achieve this? The answer is a combination of techniques, not any single innovation.
Recipe 1: RoPE Scaling
Rotary Position Embeddings (RoPE) encode position information through rotation matrices applied to queries and keys. The base frequency determines how position information is encoded:

theta_i = base^(-2i / d), with base = 10000

where theta_i is the rotation frequency for dimension pair i = 0, ..., d/2 - 1. Low i (high frequency) encodes fine-grained local position; high i (low frequency) encodes coarse global position.
To extend context length beyond the training length, you can scale the frequencies:
NTK-aware scaling modifies the base frequency: base' = base x s^(d / (d - 2)), where s is the scaling factor (target length / original length). This interpolates smoothly between original and extended contexts.
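The base adjustment is a one-liner (the helper name is ours; the formula is the standard NTK-aware one):

```python
def ntk_scaled_base(base: float, scale: float, dim: int) -> float:
    """NTK-aware RoPE: base' = base * s^(d / (d - 2))."""
    return base * scale ** (dim / (dim - 2))

# Extending a d=128 head by a factor of s = 8 (e.g., 4K -> 32K):
print(ntk_scaled_base(10000.0, 8.0, 128))
```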
YaRN (Yet another RoPE extensioN) combines NTK-aware scaling with attention temperature scaling and fine-tuning, achieving the best quality for extended contexts:
```python
import math
import torch

def yarn_rope_scaling(dim, original_max=4096, scale_factor=32):
    """YaRN RoPE scaling for context extension (simplified)."""
    beta_fast = 32  # rotations above this: keep original frequency
    beta_slow = 1   # rotations below this: fully interpolate
    base = 10000.0
    # Per-dimension base frequencies
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Full rotations each dimension completes over the original context
    rotations = original_max * freqs / (2 * math.pi)
    # Linear ramp between the low- and high-frequency regions
    ramp = ((rotations - beta_slow) / (beta_fast - beta_slow)).clamp(0, 1)
    # Interpolate between scaled (long-range) and original (local) frequencies
    scaled_freqs = freqs / scale_factor
    new_freqs = (1 - ramp) * scaled_freqs + ramp * freqs
    # Attention temperature scaling
    attn_scale = 0.1 * math.log(scale_factor) + 1.0
    return new_freqs, attn_scale
```
Recipe 2: Continued Pre-Training on Long Documents
RoPE scaling alone is not sufficient. The model also needs to see long documents during training. The standard approach is:
- Pre-train the model at a shorter context (e.g., 4K-8K tokens) for most of training.
- Continue pre-training at the target context length (e.g., 128K tokens) for a small fraction of total training (1-5% of tokens).
- Use a dataset enriched with long documents (books, code repositories, long articles).
Llama 3.1’s technical report describes this process: after pre-training Llama 3.1 405B at 8K context on 15T tokens, they continued training at 128K context on approximately 800B additional tokens. The context extension phase is roughly 5% of total training compute.
Recipe 3: Infrastructure (FlashAttention + Ring Attention)
The training infrastructure must support the long context lengths efficiently:
- FlashAttention v2/v3 for IO-efficient attention within each GPU.
- Ring attention or context parallelism to distribute the sequence across GPUs.
- Sequence packing to minimize padding waste when mixing different-length documents.
- Gradient checkpointing to reduce memory for activations (at the cost of recomputation).
How Frontier Models Achieve Long Context
| Model | Context Length | Position Encoding | Attention Method | Training Approach |
|---|---|---|---|---|
| Llama 3.1 405B | 128K | RoPE (scaled) | FlashAttention + CP | Continued pretraining at 128K |
| Claude 3.5 Sonnet | 200K | Unknown (proprietary) | Unknown | Unknown (proprietary) |
| GPT-4 Turbo | 128K | Unknown | Unknown | Unknown (proprietary) |
| Gemini 1.5 Pro | 1M | Unknown (likely RoPE variant) | Likely ring attention variant | Trained on long documents |
| Mistral Large | 128K | RoPE | SWA + full attention layers | Mixed-length training |
| Qwen-2.5 72B | 128K (1M w/ YaRN) | RoPE + YaRN | FlashAttention | Continued pretraining + YaRN |
Recipe 4: Inference Optimizations for Long Context
Serving models at long context lengths requires additional optimizations beyond training:
KV cache quantization: Compress the KV cache from FP16 to FP8 or INT4, halving or quartering memory. Recent work (KIVI, KVQuant) shows this can be done with minimal quality loss, especially for older KV entries that have been “attended over” many times.
KV cache eviction: Discard KV entries for tokens that receive low attention. H2O (Heavy Hitter Oracle) keeps only the tokens with the highest cumulative attention scores. StreamingLLM keeps the first few tokens (attention sinks) plus a sliding window of recent tokens.
Prefix caching: For repeated prefixes (system prompts, RAG contexts), cache the KV values and reuse them across requests. This avoids recomputing the prefill for shared context.
```python
def kv_cache_management_strategies():
    """Compare KV cache optimization strategies for long context."""
    strategies = {
        "full_kv_cache": {
            "memory": "O(n * L * d)",
            "quality": "100% (baseline)",
            "max_context_80GB": "~32K tokens (70B model)",
        },
        "gqa_kv_cache": {
            "memory": "O(n * L * d * (kv_heads / q_heads))",
            "quality": "~99.5%",
            "max_context_80GB": "~128K tokens (70B model, 8 KV heads)",
        },
        "quantized_kv_fp8": {
            "memory": "O(n * L * d / 2)",
            "quality": "~99%",
            "max_context_80GB": "~64K tokens (70B model)",
        },
        "quantized_kv_int4": {
            "memory": "O(n * L * d / 4)",
            "quality": "~97%",
            "max_context_80GB": "~128K tokens (70B model)",
        },
        "sliding_window_eviction": {
            "memory": "O(W * L * d)",
            "quality": "~95% (degrades with distance)",
            "max_context_80GB": "Unlimited (fixed memory)",
        },
        "h2o_eviction": {
            "memory": "O(budget * L * d)",
            "quality": "~97% (keeps important tokens)",
            "max_context_80GB": "Unlimited (fixed memory)",
        },
    }
    return strategies
```
KV Cache Optimization: Memory vs Quality (70B Model, 128K Context)
| Strategy | KV Cache Memory | Needle-in-Haystack Accuracy | Throughput (tok/s) |
|---|---|---|---|
| Full KV (FP16) | 42 GB | 99.2% | 450 |
| GQA (8 KV heads, FP16) | 5.2 GB | 99.0% | 1,200 |
| GQA + FP8 quantization | 2.6 GB | 98.5% | 1,800 |
| GQA + INT4 quantization | 1.3 GB | 96.8% | 2,400 |
| Sliding window (32K) | 1.3 GB | 85% (at 128K distance) | 2,800 |
| H2O (budget=4K) | 0.16 GB | 92% | 3,200 |
When Sparse Attention Still Wins
Despite FlashAttention’s dominance, there are scenarios where sparse attention remains the right approach.
Very Long Sequences (greater than 1M Tokens)
At sequence lengths beyond 1M tokens, even FlashAttention’s computation becomes prohibitive. The attention FLOPs for a single layer at 1M tokens with $d = 128$:

$$2n^2 d = 2 \times (10^6)^2 \times 128 = 2.56 \times 10^{14} \text{ FLOPs}$$

Even on an H100 at 1 PFLOPS, that is 256 milliseconds per layer for attention alone. With 80 layers, attention takes 20 seconds per forward pass. Sparse attention with $O(n\sqrt{n})$ or $O(n \log n)$ complexity would reduce this to under 1 second.
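The arithmetic above, spelled out (counting the $QK^\top$ score matmul at $2n^2d$ FLOPs, as in the text; the 1 PFLOPS figure is a round H100-class number):

```python
n, d, n_layers = 1_000_000, 128, 80
peak_flops = 1e15                        # ~1 PFLOPS, rough H100-class figure

flops_per_layer = 2 * n**2 * d           # QK^T score matmul: 2 * n^2 * d FLOPs
seconds_per_layer = flops_per_layer / peak_flops
total_seconds = seconds_per_layer * n_layers

print(flops_per_layer)     # 2.56e14 FLOPs per layer
print(seconds_per_layer)   # 0.256 s per layer
print(total_seconds)       # ~20.5 s per forward pass, attention alone
```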
Emerging use cases that need very long contexts:
- Genomics: DNA sequences of 1M-10M base pairs. Each base is a token.
- Whole-codebase analysis: Entire repositories as context (10M+ tokens).
- Long-form video understanding: 1 hour of video at 1 frame/second with 256 tokens per frame = 921K tokens.
- Multi-document reasoning: 100+ documents in context for complex question answering.
Resource-Constrained Deployment
For edge and mobile deployment, even FlashAttention does not help if the device cannot store the full KV cache. Sparse or windowed attention bounds memory consumption:
- On-device LLMs (phones, laptops): Limited to 4-8 GB of memory for the entire model + KV cache.
- Embedded systems: Even more constrained. Sliding window attention with a small window $W$ may be necessary.
- Cost-sensitive serving: When serving millions of users, every GB of KV cache per request translates to hardware costs. Sparse attention or aggressive eviction directly reduces serving costs.
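The bounded-memory property is easy to see in a toy rolling cache. A minimal sketch (not a real inference-engine cache) using a fixed-size deque, where eviction happens automatically no matter how long the stream runs:

```python
from collections import deque

# Sketch of a fixed-size rolling KV cache for sliding-window attention:
# memory stays bounded by the window size, independent of stream length.

class RollingKVCache:
    def __init__(self, window):
        # deque(maxlen=...) silently evicts the oldest entry on overflow
        self.keys = deque(maxlen=window)
        self.values = deque(maxlen=window)

    def append(self, k, v):
        self.keys.append(k)
        self.values.append(v)

cache = RollingKVCache(window=4)
for t in range(1000):                  # stream 1000 tokens through the cache
    cache.append(f"k{t}", f"v{t}")
print(len(cache.keys))                 # 4: only the window survives
```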
When to Use Which Attention Method
| Scenario | Best Method | Why | Quality Trade-off |
|---|---|---|---|
| Seq length < 8K | FlashAttention (dense) | Fast enough, exact attention | None |
| 8K-64K, single GPU | FlashAttention (dense) | Manageable memory with GQA | None |
| 64K-256K, multi-GPU | Ring Attention + FlashAttention | Distribute across GPUs | None (exact) |
| 256K-1M, multi-GPU | Ring Attention + FlashAttention | More GPUs for context parallelism | None (exact) |
| > 1M tokens | Sparse + Ring Attention hybrid | Quadratic cost too high even distributed | Small (1-3%) |
| Edge / mobile | Sliding window or sparse | Memory bounded | Moderate (5-10% at long range) |
| Cost-optimized serving | SWA + KV eviction | Minimize memory per request | Moderate (context-dependent) |
Specific Task Types
Some tasks have inherent locality that sparse attention exploits:
- Language modeling: Most dependencies are local. Sliding window captures 95%+ of the relevant context.
- DNA/protein sequences: Local structure matters most. Sparse attention with biological priors outperforms dense attention at the same compute budget.
- Time series: Temporal locality is strong. Local + periodic attention patterns match the data structure.
- Structured data (tables, code with nesting): Sparse patterns aligned with document structure outperform random dense attention.
Start with FlashAttention (dense). If your sequence length exceeds what your hardware can handle, add context parallelism (ring attention). If your sequences exceed 1M tokens or you are severely memory-constrained, consider sparse attention. Do not reach for sparse attention as a first resort — the quality-complexity trade-off is rarely worth it when FlashAttention works.
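This guidance can be condensed into a rough decision heuristic. The thresholds below are the ones used in this post, not universal constants, and the function is my own sketch rather than an API from any library:

```python
def choose_attention(seq_len, n_gpus=1, memory_constrained=False):
    """Rough heuristic mirroring the decision table above."""
    if memory_constrained:
        # edge/mobile or cost-optimized serving: bound memory first
        return "sliding window or sparse attention"
    if seq_len <= 64_000:
        # FlashAttention handles this exactly on a single GPU
        return "FlashAttention (dense)"
    if seq_len <= 1_000_000 and n_gpus > 1:
        # distribute exact attention across GPUs
        return "Ring Attention + FlashAttention"
    if seq_len > 1_000_000:
        # quadratic cost too high even distributed; accept a small quality hit
        return "sparse + Ring Attention hybrid"
    return "FlashAttention (dense); add GPUs for context parallelism"
```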
The Evolution of the Attention Landscape
The trajectory of attention efficiency tells a clear story about how systems engineering trumps algorithmic cleverness when hardware constraints shift.
Phase 1: Algorithmic Innovation (2019-2021)
The initial response to the quadratic bottleneck was algorithmic: design new attention patterns that avoid the cost. This produced Sparse Transformer, Longformer, BigBird, Performer, Linformer, and many others. Each paper proposed a different sparsity pattern with different theoretical properties.
The problem: each pattern required custom implementation, custom CUDA kernels, custom backward passes, and careful validation that the quality trade-off was acceptable for the target task. The ecosystem was fragmented. No single sparse attention method became a standard.
Phase 2: Systems Optimization (2022-2023)
FlashAttention shifted the conversation from “how to avoid computing attention” to “how to compute attention efficiently on actual hardware.” The insight was that the bottleneck was not FLOPS but memory bandwidth — and you could solve that without approximating the computation.
This was a systems engineering insight, not an algorithmic one. The same computation, reorganized for the memory hierarchy, became 2-4x faster and used $O(n)$ rather than $O(n^2)$ memory. The entire sparse attention research direction became less relevant for moderate sequence lengths.
Phase 3: Distributed Computation (2023-2025)
For truly long contexts, the solution was distributed systems: ring attention, context parallelism, sequence sharding. These methods keep exact attention but distribute it across GPUs.
The evolution mirrors a common pattern in computing: when a problem seems to require a clever algorithm, often a better systems approach (faster hardware, better memory management, more parallelism) makes the naive algorithm fast enough.
Long-Context Capability Over Time
(Figure: line chart of maximum context length, in tokens, over time.)

The Attention Efficiency Timeline
| Year | Key Development | Maximum Practical Context | Approach |
|---|---|---|---|
| 2017 | Original Transformer | ~512 tokens | Dense attention, no optimization |
| 2019 | Sparse Transformer, Transformer-XL | ~8K tokens | Sparse patterns, recurrence |
| 2020 | Longformer, BigBird | ~16K tokens | Local + global sparse attention |
| 2021 | ALiBi, RoPE | ~8K tokens (better extrapolation) | Position encoding improvements |
| 2022 | FlashAttention v1 | ~16K-32K tokens | IO-aware dense attention |
| 2023 | FlashAttention v2, Ring Attention | ~128K tokens | Better FA + distributed attention |
| 2024 | FA3, YaRN, context parallelism | ~1M tokens | H100 optimization + RoPE scaling + CP |
| 2025 | Production 1M+ context | ~2M tokens | Full stack: hardware + FA + CP + RoPE |
Performance Benchmarks: Long Context in Practice
How well do modern long-context models actually use their context? The needle-in-a-haystack test measures this: insert a specific fact at various positions in a long document and test retrieval accuracy.
Needle-in-a-Haystack Performance (2025 Models)
| Model | Context Claimed | 95%+ Accuracy Up To | Attention Method | Notes |
|---|---|---|---|---|
| GPT-4 Turbo | 128K | ~120K | Unknown (likely FA variant) | Some degradation at end of context |
| Claude 3.5 Sonnet | 200K | ~195K | Unknown (proprietary) | Near-perfect across full context |
| Llama 3.1 70B | 128K | ~120K | FlashAttention + CP | Slight degradation near boundaries |
| Gemini 1.5 Pro | 1M | ~1M | Unknown | Near-perfect even at 1M tokens |
| Mistral Large | 128K | ~100K | SWA + full attention | Degrades past SWA window distance |
| Qwen-2.5 72B | 128K | ~110K | FlashAttention | Good performance with YaRN scaling |
There is an important distinction between a model’s nominal context length and how effectively it uses that context. Many models can accept 128K tokens but show degraded performance for information in the middle of the context (“lost in the middle” phenomenon). True long-context capability means uniform performance regardless of where the relevant information appears. This is still an active area of improvement.
Throughput at Long Context
Long context dramatically affects serving throughput because the KV cache consumes memory that would otherwise be used for batching:
Serving Throughput vs Context Length (Llama 3.1 70B, 8xH100)
| Context Length | Max Batch Size | Throughput (tok/s) | Cost per 1M tokens | KV Cache per Request |
|---|---|---|---|---|
| 4K | 128 | 24,000 | $0.12 | 0.16 GB |
| 16K | 64 | 14,400 | $0.20 | 0.65 GB |
| 32K | 32 | 8,800 | $0.33 | 1.3 GB |
| 128K | 4 | 1,600 | $1.80 | 5.2 GB |
| 512K (theoretical) | 1 | 320 | $9.00 | 20.8 GB |
The economics are stark: going from 4K to 128K context increases cost per token by 15x, primarily because the KV cache limits batch size. This is why long-context API pricing is significantly higher and why KV cache optimization is critical for production viability.
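The 15x figure follows directly from the throughput column: with a fixed cluster cost per hour, cost per token scales inversely with throughput. Using the numbers from the table above:

```python
tok_s_4k, cost_4k = 24_000, 0.12   # 4K-context row: throughput, $ per 1M tokens
tok_s_128k = 1_600                 # 128K-context row: throughput

# Fixed hourly cluster cost => cost per token is inversely
# proportional to throughput.
cost_ratio = tok_s_4k / tok_s_128k
cost_128k = cost_4k * cost_ratio

print(cost_ratio)   # 15.0x
print(cost_128k)    # ~$1.80 per 1M tokens, matching the table
```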
Future Directions
Attention-Free Architectures
State space models (Mamba, RWKV) and linear attention variants avoid the quadratic bottleneck entirely. Mamba processes sequences in $O(n)$ time with $O(1)$ state size per token. If these architectures can match transformer quality at scale, the entire attention optimization story becomes moot.
Current status: Mamba-2 and Jamba show promise at 7B-52B scale, but no attention-free model has yet demonstrated frontier-level quality at 400B+ parameters. The transformer’s attention mechanism appears to provide something that linear models struggle to replicate — precise, content-based retrieval over long contexts.
Native Long-Context Training
Current models train primarily on short sequences and extend to long context via continued pre-training. Future models may train natively at long context from the start, enabled by cheaper hardware and better distributed training frameworks. This could improve long-context quality significantly.
Hybrid Sparse-Dense Approaches
Some emerging approaches use dense attention for nearby tokens and sparse attention for distant tokens, with the transition controlled by a learned boundary or a fixed heuristic. This matches the empirical observation that attention weights are typically concentrated locally with sparse long-range connections.
Hardware-Software Co-Design
Future accelerators may have hardware support for attention-specific operations: on-chip softmax units, dedicated KV cache memory, or attention-aware memory hierarchies. Groq’s LPU and Cerebras’s wafer-scale engine are early examples of hardware designed around the attention bottleneck.
Conclusion
The long-context story is a case study in how systems engineering solves problems that algorithms alone cannot. The original response to the attention bottleneck was algorithmic: sparse attention patterns that traded quality for efficiency. This worked, but the trade-off was uncomfortable — every sparse pattern had failure modes, and the implementation complexity was high.
FlashAttention disrupted this by showing that the bottleneck was memory IO, not computation. By reorganizing the same computation to be IO-aware, FlashAttention made dense attention fast enough for moderate context lengths (up to 32K-64K) with zero quality loss. Sparse attention’s complexity was no longer justified for these lengths.
For truly long contexts (128K-1M+), the solution shifted to distributed systems. Ring attention and context parallelism distribute the sequence across GPUs, keeping exact attention while scaling memory linearly with GPU count. Combined with RoPE scaling and continued pre-training on long documents, this stack enables production-grade million-token contexts.
Sparse attention is not dead. It retains clear advantages for very long sequences (greater than 1M tokens), resource-constrained deployment, and domain-specific tasks with inherent locality. But it is no longer the primary tool for extending context length. The primary tools are now FlashAttention for single-GPU efficiency and ring attention for multi-GPU distribution.
The practical lesson is clear: when you need longer context, first try FlashAttention with your existing architecture. If that is not enough, add context parallelism with more GPUs. Only reach for sparse attention when you have exhausted the dense-attention approaches or have specific constraints (memory, latency, domain structure) that sparse attention uniquely addresses. The field has learned that exact computation with good systems engineering beats approximate computation with clever algorithms — at least until the sequences get long enough that even the best systems cannot make quadratic scaling work.