At 1M tokens, standard quadratic attention computes $N^2 = 10^{12}$ query-key scores per layer per head. Even FlashAttention, which reduces HBM traffic, still performs all of these FLOPs. The KV cache for a single 1M-token sequence reaches roughly 328 GB for Llama 70B, far exceeding any single GPU’s memory. Achieving million-token context requires fundamentally different approaches along three dimensions: position encoding that extrapolates, attention mechanisms that scale subquadratically, and distributed systems that split sequences across GPUs.

The Three Challenges


Challenges at Different Context Lengths

| Context | Score ops per layer per head ($2N^2$) | KV Cache (70B, FP16) | Primary Bottleneck |
|---|---|---|---|
| 4K | 32M | 1.3 GB | None (fits easily) |
| 32K | 2.0B | 10.5 GB | KV cache memory |
| 128K | 32.8B | 41.9 GB | KV cache exceeds 1 GPU |
| 1M | 2.0T | 328 GB | Both compute and memory |
| 10M | 200T | 3.3 TB | Fundamental hardware limits |
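The table's numbers follow from simple arithmetic. This is a minimal sketch assuming the published Llama 2/3 70B attention shape (80 layers, 8 KV heads under grouped-query attention, head dimension 128), FP16 storage, decimal units (1K = 1,000 tokens), and a rough $2N^2$ count for the score matrix that ignores the head dimension and the attention-times-V product.

```python
# Sketch: reproduce the context-length table (assumed Llama 70B attention shape).
N_LAYERS = 80      # assumption: Llama 2/3 70B
N_KV_HEADS = 8     # grouped-query attention
HEAD_DIM = 128
BYTES = 2          # FP16


def kv_cache_bytes(n_tokens: int) -> int:
    """K and V for every layer: 2 x layers x kv_heads x head_dim x bytes, per token."""
    per_token = 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * BYTES  # 327,680 bytes
    return n_tokens * per_token


def score_ops(n_tokens: int) -> int:
    """Rough 2 * N^2 count for the N x N score matrix (one layer, one head)."""
    return 2 * n_tokens ** 2


for label, n in [("4K", 4_000), ("32K", 32_000), ("128K", 128_000),
                 ("1M", 1_000_000), ("10M", 10_000_000)]:
    print(f"{label:>4}: {score_ops(n):.2e} ops, KV {kv_cache_bytes(n) / 1e9:,.1f} GB")
```

Each KV row is linear in context length, while the score count is quadratic, which is why the bottleneck shifts from memory to both compute and memory as context grows.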

Approach 1: RoPE Scaling for Context Extension

The simplest approach: take a model trained on 4K-8K context and extend it to 128K+ by modifying its RoPE frequencies. Three methods, in order of increasing quality:

Linear Scaling (Position Interpolation, PI)

Divide all positions by a factor $s$: with $s = 32$, position 128,000 becomes position 128,000/32 = 4,000 (within the training range).

$$\theta'_i = \theta_i, \quad \text{pos}' = \text{pos} / s$$

Problem: uniformly compresses all frequencies, losing fine-grained local position discrimination.
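A minimal sketch of linear PI in plain Python, assuming an illustrative head dimension of 128 and the standard base of 10,000: dividing the position by $s$ maps position 128,000 onto exactly the rotation angles of position 4,000. The last line shows the cost: adjacent tokens now differ by only $1/s$ radians even in the fastest-rotating dimension.

```python
import math


def rope_angles(pos: float, dim: int = 128, base: float = 10000.0) -> list[float]:
    """RoPE rotation angle pos * theta_i for each of the dim/2 frequency pairs."""
    return [pos * base ** (-2 * i / dim) for i in range(dim // 2)]


def pi_angles(pos: float, scale: float, **kw) -> list[float]:
    """Linear Position Interpolation: keep theta_i, divide the position by s."""
    return rope_angles(pos / scale, **kw)


# Position 128,000 with s = 32 lands on the angles of position 4,000.
extended = pi_angles(128_000, scale=32)
native = rope_angles(4_000)
assert all(math.isclose(a, b) for a, b in zip(extended, native))

# The cost: neighbouring tokens are only 1/32 rad apart in the fastest dimension.
print(pi_angles(1, scale=32)[0] - pi_angles(0, scale=32)[0])  # 0.03125
```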

NTK-Aware Scaling

Scale the RoPE base frequency instead of positions:

$$\theta'_i = \text{base}'^{-2i/d}, \quad \text{where } \text{base}' = \text{base} \times s^{d/(d-2)}$$

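Here $d$ is the head dimension and $i = 0, \dots, d/2 - 1$. The per-dimension change is $\theta'_i / \theta_i = s^{-2i/(d-2)}$: the highest-frequency dimension ($i = 0$, local patterns) is left unchanged, while the lowest-frequency dimension ($i = d/2 - 1$, global patterns) is divided by exactly $s$, the same as linear PI. NTK-aware scaling therefore stretches low frequencies far more than high frequencies, preserving local position resolution.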
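This behavior can be checked numerically. The sketch below (plain Python, with illustrative values $d = 128$, $s = 32$, base 10,000) computes $\theta'_i / \theta_i$ for every dimension under NTK-aware scaling; the ratio is exactly 1 for the fastest dimension and $1/s$ for the slowest, decreasing monotonically in between.

```python
def rope_freqs(dim: int, base: float) -> list[float]:
    """theta_i = base^(-2i/d) for i = 0 .. d/2 - 1."""
    return [base ** (-2 * i / dim) for i in range(dim // 2)]


def ntk_aware_freqs(dim: int, scale: float, base: float = 10000.0) -> list[float]:
    """NTK-aware scaling: keep positions, raise the base to base * s^(d/(d-2))."""
    new_base = base * scale ** (dim / (dim - 2))
    return rope_freqs(dim, new_base)


dim, s = 128, 32
ratios = [new / old for new, old in zip(ntk_aware_freqs(dim, s), rope_freqs(dim, 10000.0))]
print(f"highest-frequency dim: x{ratios[0]:.3f}")   # x1.000 (unchanged)
print(f"lowest-frequency dim:  x{ratios[-1]:.5f}")  # x0.03125 = 1/s (same as PI)
```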

YaRN (Yet Another RoPE Extension)

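YaRN combines "NTK-by-parts" interpolation with an attention temperature. Each RoPE dimension is treated according to how many full rotations $r = L / \lambda_i$ it completes within the original context length $L$, where $\lambda_i = 2\pi / \theta_i$ is its wavelength. Dimensions with $r > \beta$ keep their original frequency, dimensions with $r < \alpha$ are interpolated by $s$ as in PI, and a linear ramp blends the two in between (the YaRN paper uses $\alpha = 1$, $\beta = 32$ for Llama). The attention logits are additionally scaled by $1/t$, where $\sqrt{1/t} = 0.1 \ln s + 1$: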

```python
import math

import torch


def yarn_rope_scaling(dim, base=10000.0, scale=32, original_max_pos=4096,
                      alpha=1.0, beta=32.0):
    """YaRN: NTK-by-parts frequency interpolation + attention temperature.

    Returns (inv_freqs, mscale). Multiply both q and k (or RoPE's cos/sin)
    by mscale, so the attention logits are scaled by mscale**2 = 1/t.
    """
    # Original RoPE frequencies theta_i = base^(-2i/d)
    original_freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # Fully interpolated frequencies, as in linear PI: theta_i / s
    interp_freqs = original_freqs / scale

    # r = full rotations each dimension completes within the original context
    wavelengths = 2 * math.pi / original_freqs
    r = original_max_pos / wavelengths

    # Ramp: gamma = 0 for r < alpha (long wavelength: interpolate),
    #       gamma = 1 for r > beta  (short wavelength: keep original)
    gamma = ((r - alpha) / (beta - alpha)).clamp(0.0, 1.0)
    adjusted_freqs = (1 - gamma) * interp_freqs + gamma * original_freqs

    # Attention temperature from the YaRN paper: sqrt(1/t) = 0.1 * ln(s) + 1
    mscale = 0.1 * math.log(scale) + 1.0
    return adjusted_freqs, mscale
```

Perplexity at Extended Context (Llama 2 7B, trained on 4K; evaluated at 32K, lower is better)

| Method | Perplexity at 32K | Assessment |
|---|---|---|
| No extension | 1,000 | Broken (no extrapolation) |
| Linear PI | 12.8 | Usable but degraded |
| NTK-aware | 8.2 | Good |
| YaRN | 6.1 | Near-native quality |
| Native 32K training | 5.8 | Best (but expensive) |

Approach 2: Ring Attention for Distributed Sequences

When a sequence is too long for one GPU’s memory, split it across $P$ GPUs:

GPU 0: tokens [0, N/P)       — holds KV cache for this chunk
GPU 1: tokens [N/P, 2N/P)    — holds KV cache for this chunk
...
GPU P-1: tokens [(P-1)N/P, N) — holds KV cache for this chunk

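Each GPU computes attention for its own chunk of queries against all keys and values. The KV blocks travel around a ring: GPU 0 sends its KV to GPU 1, GPU 1 sends to GPU 2, …, and GPU $P-1$ sends to GPU 0. After $P$ compute steps (with $P-1$ transfers in between), every GPU has attended to every KV block. Each transfer can be overlapped with the attention computation on the block the GPU already holds.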

```python
import math

import torch


def ring_attention_step(Q_local, K_local, V_local, P, comm):
    """Full ring-attention pass for this rank's query block (non-causal).

    `comm.ring_send_recv(K, V)` is a hypothetical helper: it sends (K, V) to the
    next rank and returns the blocks received from the previous rank.
    """
    n_q, d = Q_local.shape
    opts = dict(device=Q_local.device, dtype=Q_local.dtype)
    output = torch.zeros_like(Q_local)
    max_scores = torch.full((n_q,), float("-inf"), **opts)
    sum_exp = torch.zeros(n_q, **opts)

    K_recv, V_recv = K_local, V_local
    for step in range(P):
        # Scores of local queries against the KV block currently held
        scores = Q_local @ K_recv.T / math.sqrt(d)
        # Online softmax update (same as FlashAttention)
        new_max = torch.maximum(max_scores, scores.max(dim=-1).values)
        correction = torch.exp(max_scores - new_max)
        exp_scores = torch.exp(scores - new_max.unsqueeze(-1))
        sum_exp = sum_exp * correction + exp_scores.sum(dim=-1)
        output = output * correction.unsqueeze(-1) + exp_scores @ V_recv
        max_scores = new_max

        # Pass KV to the next GPU; no transfer is needed after the last block.
        # (Real implementations overlap this transfer with the next compute.)
        if step < P - 1:
            K_recv, V_recv = comm.ring_send_recv(K_recv, V_recv)

    return output / sum_exp.unsqueeze(-1)
```
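The online-softmax bookkeeping can be checked without GPUs. The sketch below is a hypothetical single-process simulation in plain Python (no communicator, no real transfers, no causal mask; `simulated_ring_attention` and `full_attention` are illustrative names): each simulated rank visits the KV blocks in ring order, and the result is compared against ordinary full softmax attention.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def full_attention(Q, K, V):
    """Reference: ordinary softmax attention over the whole sequence (non-causal)."""
    out = []
    for q in Q:
        scores = [dot(q, k) / math.sqrt(len(q)) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        out.append([dot(w, col) / sum(w) for col in zip(*V)])
    return out


def simulated_ring_attention(Q_blocks, K_blocks, V_blocks):
    """Simulate P ranks in one process. Rank r starts with KV block r and, since
    each rank sends to r + 1, holds block (r - step) mod P at each step."""
    P = len(Q_blocks)
    out = []
    for r in range(P):
        for q in Q_blocks[r]:
            m, z = float("-inf"), 0.0          # running max and softmax denominator
            acc = [0.0] * len(V_blocks[0][0])  # running unnormalized output
            for step in range(P):
                src = (r - step) % P
                scores = [dot(q, k) / math.sqrt(len(q)) for k in K_blocks[src]]
                new_m = max(m, max(scores))
                corr = math.exp(m - new_m)     # rescale sums from earlier blocks
                p = [math.exp(s - new_m) for s in scores]
                z = z * corr + sum(p)
                acc = [a * corr + dot(p, col) for a, col in zip(acc, zip(*V_blocks[src]))]
                m = new_m
            out.append([a / z for a in acc])
    return out


def random_rows(n, d, rng):
    return [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]


def split_blocks(rows, P):
    size = len(rows) // P
    return [rows[i * size:(i + 1) * size] for i in range(P)]


rng = random.Random(0)
P, n, d = 4, 12, 8
Q, K, V = (random_rows(n, d, rng) for _ in range(3))
ring = simulated_ring_attention(split_blocks(Q, P), split_blocks(K, P), split_blocks(V, P))
ref = full_attention(Q, K, V)
max_err = max(abs(a - b) for ra, rb in zip(ring, ref) for a, b in zip(ra, rb))
print(f"max |ring - full| = {max_err:.1e}")  # tiny: floating-point rounding only
```

Because the running max and denominator are rescaled whenever a larger score appears, the order in which a rank receives blocks does not change the result.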

Memory: each GPU stores $N/P$ tokens of KV cache. For 1M tokens on 8 GPUs, that is 125K tokens per GPU, or about 41 GB for Llama 70B, which fits within one H100’s 80 GB.

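Communication: each step sends $\frac{N}{P} \times 2 \times n_{\text{kv}} \times d \times 2$ bytes per layer (K and V in FP16). Per-step compute grows as $(N/P)^2$ while this traffic grows only as $N/P$, so the transfer can be hidden behind compute as long as each GPU’s block is large enough. NVLink (900 GB/s) makes that easy; over InfiniBand (about 50 GB/s per GPU) the minimum block size is roughly 18× larger, so communication becomes the bottleneck sooner as $P$ grows at a fixed $N$.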
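A rough way to compare transfer and compute per ring step is sketched below. Assumptions, all illustrative: Llama-70B-shaped attention (64 query heads, 8 KV heads, head dimension 128), FP16, no causal masking, perfect overlap, the 900 GB/s and 50 GB/s link figures quoted above, and a sustained 500 TFLOP/s of attention compute per GPU (a round number, not a measurement).

```python
# Sketch: per-layer ring-step cost for Llama-70B-shaped attention (assumed shape).
N_KV_HEADS, N_Q_HEADS, HEAD_DIM, BYTES = 8, 64, 128, 2  # assumption, FP16


def kv_block_bytes(n_tokens: int, P: int) -> int:
    """Bytes of K and V forwarded per ring step: N/P x 2 x n_kv x d x 2."""
    return (n_tokens // P) * 2 * N_KV_HEADS * HEAD_DIM * BYTES


def block_attention_flops(n_tokens: int, P: int) -> int:
    """FLOPs of one query block against one KV block (QK^T plus PV, no mask)."""
    c = n_tokens // P
    return 2 * 2 * c * c * N_Q_HEADS * HEAD_DIM


def step_times_ms(n_tokens, P, link_gb_s, tflops):
    comm = kv_block_bytes(n_tokens, P) / (link_gb_s * 1e9) * 1e3
    compute = block_attention_flops(n_tokens, P) / (tflops * 1e12) * 1e3
    return comm, compute


def break_even_block_tokens(link_gb_s, tflops):
    """Block size c at which transfer time (linear in c) equals compute (quadratic)."""
    bytes_per_token = 2 * N_KV_HEADS * HEAD_DIM * BYTES  # 4,096
    flops_per_pair = 2 * 2 * N_Q_HEADS * HEAD_DIM        # 32,768
    return bytes_per_token * tflops * 1e12 / (flops_per_pair * link_gb_s * 1e9)


for link in (900, 50):
    comm, compute = step_times_ms(1_000_000, 8, link, 500)
    print(f"{link:>4} GB/s: transfer {comm:.1f} ms vs compute {compute:.0f} ms per step")
    print(f"      break-even block: ~{break_even_block_tokens(link, 500):.0f} tokens")
```

Under these assumptions the 1M-token, 8-GPU case stays compute-bound even at 50 GB/s (about 10 ms of transfer against roughly a second of compute per layer-step). Communication only dominates once blocks shrink toward the break-even size, which scales with the inverse of link bandwidth.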

Approach 3: Linear Attention (Lightning Attention)

As covered in Frontier Research Part 2, linear attention replaces $O(N^2)$ softmax attention with $O(N)$ computation. There is no growing KV cache: each layer keeps a fixed-size state regardless of sequence length. This makes it the only one of these approaches that truly scales to 10M+ tokens.
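Why memory stays constant is easiest to see in the recurrent form of causal linear attention. The sketch below is generic kernelized linear attention using the elu(x)+1 feature map from Katharopoulos et al. (2020), chosen here as an assumption; it is not Lightning Attention’s own kernel. The per-head state is a $d \times d_v$ matrix plus a $d$-vector, whatever the number of tokens processed.

```python
import math
import random


def phi(x):
    """Positive feature map elu(x) + 1 (an assumption for this sketch)."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def causal_linear_attention(queries, keys, values):
    """Recurrent linear attention for one head.

    State: S = sum_j phi(k_j) v_j^T (d x d_v) and z = sum_j phi(k_j) (d).
    Its size never depends on how many tokens have been processed.
    """
    d, d_v = len(keys[0]), len(values[0])
    S = [[0.0] * d_v for _ in range(d)]
    z = [0.0] * d
    outputs = []
    for q, k, v in zip(queries, keys, values):
        fk, fq = phi(k), phi(q)
        for i in range(d):                    # S += phi(k) v^T ; z += phi(k)
            z[i] += fk[i]
            for j in range(d_v):
                S[i][j] += fk[i] * v[j]
        denom = sum(fq[i] * z[i] for i in range(d))
        outputs.append([sum(fq[i] * S[i][j] for i in range(d)) / denom for j in range(d_v)])
    return outputs, S, z


rng = random.Random(0)
d = 4
for T in (10, 1000):
    seq = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(T)]
    _, S, z = causal_linear_attention(seq, seq, seq)
    print(T, "tokens -> state floats:", len(S) * len(S[0]) + len(z))  # 20 both times
```

Each output equals $\sum_{j \le t} \phi(q_t)^\top \phi(k_j)\, v_j / \sum_{j \le t} \phi(q_t)^\top \phi(k_j)$, the same weighted average a quadratic implementation would compute, but without storing past keys and values.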

Tradeoff: 1-2% quality loss on short contexts (under 32K tokens) where softmax attention is superior.

Production Context Length Limits


Practical Context Limits by Approach (2025)

| Approach | Max Practical Context | Quality at Max | Infrastructure Required |
|---|---|---|---|
| RoPE + FlashAttention (single GPU) | 128K | Native quality | 1x H100 (with KV quantization) |
| Ring Attention (multi-GPU) | 1M | Near-native | 8x H100 (NVLink required) |
| Lightning Attention | 4M+ (train), 10M+ (infer) | 1-2% PPL degradation | 1x H100 (constant memory) |
| Hybrid (Ring + Lightning) | 10M+ | Near-native | 8x H100 |
The Practical Choice for Most Applications

128K context with RoPE scaling + FlashAttention covers 95% of production use cases. Ring Attention for 1M is needed only for full-codebase analysis, multi-document QA, or long video understanding. Lightning Attention for 4M+ is niche — extremely long documents, genomics, or continuous monitoring.

Reviewer Agent Validation

Challenge: For a Llama 70B model on 4 H100 GPUs with Ring Attention, compute: (a) KV cache per GPU at 512K context, (b) ring communication volume per attention layer.

Expected:

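  • (a) Each GPU holds 512K/4 = 128K tokens. KV per token per layer = 2 (K and V) × 8 KV heads × 128 dim × 2 bytes = 4,096 bytes. Across 80 layers: 128K × 80 × 4,096 bytes = 41.9 GB per GPU, which fits in an 80 GB H100.
  • (b) Each ring step sends 128K tokens × 2 × 8 heads × 128 dim × 2 bytes = 524 MB per layer. Over NVLink (900 GB/s): 0.58 ms per transfer. Only $P - 1 = 3$ transfers are needed per layer (each GPU starts with its own block), so about 1.75 ms per layer. Since 900 GB/s is the H100’s bidirectional NVLink total, a one-directional ring transfer at 450 GB/s roughly doubles these times.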