At 1M tokens, standard quadratic attention requires on the order of 10^12 score operations per layer per head (N^2 with N = 10^6). Even FlashAttention, which reduces HBM traffic, still performs these FLOPs. The KV cache for a single sequence reaches 327 GB (Llama 70B, FP16), far exceeding any single GPU’s memory. Achieving million-token context requires fundamentally different approaches across three dimensions: position encoding that extrapolates, attention mechanisms that scale subquadratically, and distributed systems that split sequences across GPUs.
The Three Challenges
Challenges at Different Context Lengths
| Context | Attention FLOPs/layer | KV Cache (70B FP16) | Primary Bottleneck |
|---|---|---|---|
| 4K | 33M | 1.3 GB | None (fits easily) |
| 32K | 2.1B | 10.5 GB | KV cache memory |
| 128K | 33.6B | 41.9 GB | KV cache exceeds 1 GPU |
| 1M | 2.0T | 327 GB | Both compute and memory |
| 10M | 200T | 3.3 TB | Fundamental hardware limits |
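The KV-cache column follows directly from the per-token footprint. A quick sketch, assuming Llama 70B's published shape (80 layers, 8 KV heads via GQA, head dim 128, FP16); the helper name is illustrative:

```python
def kv_cache_bytes(tokens, layers=80, kv_heads=8, head_dim=128, bytes_per=2):
    # K and V each store kv_heads * head_dim values per token per layer
    return tokens * layers * 2 * kv_heads * head_dim * bytes_per

print(kv_cache_bytes(1_000_000) / 1e9)  # ~327.7 GB, matching the 1M row
print(kv_cache_bytes(4096) / 1e9)       # ~1.3 GB, matching the 4K row
```

The per-token constant works out to 327,680 bytes, or 4,096 bytes per layer — the same figure used in the validation challenge at the end.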
Approach 1: RoPE Scaling for Context Extension
The simplest approach: take a model trained on 4K-8K context and extend it to 128K+ by modifying the RoPE frequencies. Three methods, in order of quality:
Linear Scaling (PI)
Divide all positions by a scale factor s: with s = 32, position 128,000 becomes 128,000/32 = 4,000 (within the training range).
Problem: uniformly compresses all frequencies, losing fine-grained local position discrimination.
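The compression effect can be made concrete with a small sketch (pure Python; the helper name is illustrative). RoPE's rotation angle for dimension pair i is position × base^(−2i/d), so dividing positions by s also divides the phase gap between adjacent tokens by s — which is exactly the lost local discrimination:

```python
def rope_angle(position, dim_index, head_dim=128, base=10000.0, scale=1.0):
    # RoPE rotation angle for dimension pair `dim_index` at a
    # (possibly interpolated) position
    freq = base ** (-2 * dim_index / head_dim)
    return (position / scale) * freq

# Highest-frequency pair (dim_index=0): phase gap between adjacent tokens
native = rope_angle(101, 0) - rope_angle(100, 0)
scaled = rope_angle(101, 0, scale=32.0) - rope_angle(100, 0, scale=32.0)
# the gap shrinks by exactly the scale factor (32x)
```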
NTK-Aware Scaling
Scale the RoPE base frequency instead of positions: base' = base · s^(d/(d−2)), where d is the head dimension and s the scale factor.
This leaves high-frequency dimensions (local patterns) nearly untouched while interpolating low-frequency dimensions (global patterns) by up to the full factor s, preserving local position discrimination.
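The per-dimension behavior can be checked directly (a sketch; function names are illustrative):

```python
def ntk_base(base, scale, head_dim):
    # NTK-aware scaling: raise the RoPE base so low frequencies are
    # interpolated while high frequencies stay nearly fixed
    return base * scale ** (head_dim / (head_dim - 2))

def rope_freq(base, dim_index, head_dim=128):
    return base ** (-2 * dim_index / head_dim)

new_base = ntk_base(10000.0, 32, 128)
# dim pair 0 (highest frequency): frequency essentially unchanged
# dim pair 63 (lowest frequency): wavelength stretched by exactly 32x
```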
YaRN (Yet Another RoPE Extension)
Combine NTK scaling with per-dimension attention temperature adjustment:
import math
import torch

def yarn_rope_scaling(dim, base=10000.0, scale=32, original_max_pos=4096):
    """YaRN: blend NTK-scaled and original RoPE frequencies per dimension.

    (The attention-temperature term YaRN also applies, roughly
    0.1 * ln(scale) + 1 on the logits, is omitted from this sketch.)
    """
    # NTK-aware base: high-frequency dims barely change,
    # low-frequency dims are interpolated by up to the full scale
    new_base = base * (scale ** (dim / (dim - 2)))
    freqs = 1.0 / (new_base ** (torch.arange(0, dim, 2).float() / dim))
    # Per-dimension ramp based on how each dimension's wavelength
    # compares to the original context window
    wavelengths = 2 * math.pi / freqs
    ratios = wavelengths / original_max_pos
    low = 1.0    # ratio below this: use the NTK-scaled frequency
    high = 32.0  # ratio above this: keep the original frequency
    smooth = ((ratios - low) / (high - low)).clamp(0, 1)
    # Blend: short-wavelength dims use the NTK-scaled freqs,
    # long-wavelength dims keep the original freqs
    original_freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    adjusted_freqs = (1 - smooth) * freqs + smooth * original_freqs
    return adjusted_freqs
Perplexity at Extended Context (Llama 2 7B, trained on 4K)
[Chart: perplexity (lower is better) for each extension method; data not preserved in text]
Approach 2: Ring Attention for Distributed Sequences
When a sequence is too long for one GPU’s memory, split it across GPUs:
GPU 0: tokens [0, N/P) — holds KV cache for this chunk
GPU 1: tokens [N/P, 2N/P) — holds KV cache for this chunk
...
GPU P-1: tokens [(P-1)N/P, N) — holds KV cache for this chunk
Each GPU computes attention for its chunk of queries against ALL keys/values. KV blocks are passed in a ring: GPU 0 sends its KV to GPU 1, GPU 1 sends to GPU 2, …, GPU P-1 sends to GPU 0. After P rounds, every GPU has seen every other GPU’s KV.
import math
import torch

def ring_attention_step(Q_local, K_local, V_local, P, comm):
    """Ring attention for one GPU's query chunk, over P ring rounds."""
    d = Q_local.shape[-1]
    output = torch.zeros_like(Q_local)
    max_scores = torch.full((Q_local.shape[0],), float("-inf"))
    sum_exp = torch.zeros(Q_local.shape[0])
    K_recv, V_recv = K_local, V_local
    for step in range(P):
        # Attention of local queries against the current KV block
        scores = Q_local @ K_recv.T / math.sqrt(d)
        # Online softmax update (same rescaling trick as FlashAttention)
        new_max = torch.maximum(max_scores, scores.max(dim=-1).values)
        correction = torch.exp(max_scores - new_max)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        sum_exp = sum_exp * correction + exp_scores.sum(dim=-1)
        output = output * correction.unsqueeze(-1) + exp_scores @ V_recv
        max_scores = new_max
        # Pass the KV block to the next GPU in the ring (async, so it
        # overlaps with compute; the final send returns blocks home)
        K_recv, V_recv = comm.ring_send_recv(K_recv, V_recv)
    return output / sum_exp.unsqueeze(-1)
Memory: each GPU stores N/P tokens of KV cache. For 1M tokens on 8 GPUs: 125K tokens per GPU = 40.9 GB (Llama 70B). Fits on one H100.
Communication: each ring step sends the local KV block — (N/P) × 2 × n_kv_heads × d_head × 2 bytes per layer, about 512 MB per layer in the 1M-token, 8-GPU example. At NVLink 900 GB/s: negligible compared to compute. At InfiniBand 50 GB/s: becomes the bottleneck for more than 4 GPUs.
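The NVLink-vs-InfiniBand claim falls out of simple arithmetic. A sketch, assuming the 70B KV shape used throughout (8 KV heads × 128 dims, FP16) and illustrative helper names:

```python
def ring_step_bytes(tokens_per_gpu, kv_heads=8, head_dim=128, bytes_per=2):
    # One ring step moves one layer's local K and V block
    return tokens_per_gpu * 2 * kv_heads * head_dim * bytes_per

def step_time_ms(tokens_per_gpu, link_gb_per_s):
    return ring_step_bytes(tokens_per_gpu) / (link_gb_per_s * 1e9) * 1e3

# 1M tokens across 8 GPUs -> 125K tokens per GPU
nvlink_ms = step_time_ms(125_000, 900)  # well under a millisecond
ib_ms = step_time_ms(125_000, 50)       # ~18x slower -> communication-bound
```

Whether the slower link actually stalls the GPU depends on how much compute each step has to hide it behind, which shrinks as P grows — hence the more-than-4-GPUs caveat.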
Approach 3: Linear Attention (Lightning Attention)
As covered in Frontier Research Part 2: replace softmax attention with linear attention. No KV cache growth. Constant memory per layer. The only approach that truly scales to 10M+ tokens.
Tradeoff: 1-2% quality loss on short contexts (under 32K tokens) where softmax attention is superior.
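Why there is no growing KV cache can be seen in a minimal sketch of generic kernelized linear attention (not MiniMax's exact Lightning Attention kernel; the feature map here is a placeholder assumption): the entire per-head state is a fixed d × d matrix plus a d-vector, updated once per token.

```python
import numpy as np

def phi(x):
    # Simple positive feature map (an assumption; real kernels vary)
    return np.maximum(x, 0.0) + 1e-6

def linear_attention_stream(qs, ks, vs):
    # Causal linear attention: state is S = sum_j k_j v_j^T and z = sum_j k_j.
    # Memory is O(d^2) regardless of sequence length -- no KV cache growth.
    d = qs.shape[-1]
    S = np.zeros((d, d))
    z = np.zeros(d)
    outs = []
    for q, k, v in zip(qs, ks, vs):
        q, k = phi(q), phi(k)
        S += np.outer(k, v)          # accumulate sum_j k_j v_j^T
        z += k                       # accumulate sum_j k_j
        outs.append(q @ S / (q @ z)) # = sum_j (q.k_j) v_j / sum_j (q.k_j)
    return np.array(outs)

rng = np.random.default_rng(0)
qs, ks, vs = rng.normal(size=(3, 6, 4))
out = linear_attention_stream(qs, ks, vs)
```

Each output matches what explicit attention with the same kernel weights would produce, but nothing proportional to sequence length is ever stored.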
Production Context Length Limits
Practical Context Limits by Approach (2025)
| Approach | Max Practical Context | Quality at Max | Infrastructure Required |
|---|---|---|---|
| RoPE + FlashAttention (single GPU) | 128K | Native quality | 1x H100 (with KV quantization) |
| Ring Attention (multi-GPU) | 1M | Near-native | 8x H100 (NVLink required) |
| Lightning Attention | 4M+ (train), 10M+ (infer) | 1-2% PPL degradation | 1x H100 (constant memory) |
| Hybrid (Ring + Lightning) | 10M+ | Near-native | 8x H100 |
128K context with RoPE scaling + FlashAttention covers 95% of production use cases. Ring Attention for 1M is needed only for full-codebase analysis, multi-document QA, or long video understanding. Lightning Attention for 4M+ is niche — extremely long documents, genomics, or continuous monitoring.
Reviewer Agent Validation
Challenge: For a Llama 70B model on 4 H100 GPUs with Ring Attention, compute: (a) KV cache per GPU at 512K context, (b) ring communication volume per attention layer.
Expected:
- (a) Each GPU holds 512K/4 = 128K tokens. KV per token per layer = 4,096 bytes. Across 80 layers: 128K x 80 x 4096 = 41.9 GB per GPU. Fits in 80GB H100.
- (b) Each ring step sends 128K tokens x 2 (K and V) x 8 heads x 128 dims x 2 bytes = 524 MB. Over NVLink (900 GB/s): 0.58 ms per step, so 4 steps come to 2.3 ms per layer.
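As a quick check, the expected figures reproduce mechanically, using the same per-token KV constant of 4,096 bytes per layer:

```python
tokens_per_gpu = 512_000 // 4                     # (a) 128K tokens per GPU
kv_per_gpu_gb = tokens_per_gpu * 80 * 4096 / 1e9  # ~41.9 GB, fits in 80 GB
step_bytes = tokens_per_gpu * 2 * 8 * 128 * 2     # (b) ~524 MB per ring step
step_ms = step_bytes / 900e9 * 1e3                # ~0.58 ms over NVLink
layer_ms = 4 * step_ms                            # ~2.3 ms per layer
```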