Rotary Position Embeddings (RoPE) encode position by rotating pairs of query and key dimensions, which is equivalent to multiplying complex numbers by a position-dependent phase. This formulation is cheap to compute and, critically, allows context length to be extended through scaling techniques.
RoPE Mathematics
RoPE rotates query and key vectors based on position:
import torch
import torch.nn as nn


def compute_rope_frequencies(dim: int, max_seq_len: int, base: float = 10000.0):
    """
    Compute the cos/sin rotation tables for every position.
    θ_i = 1 / (base^(2i/d)) for i = 0, 1, ..., d/2 - 1
    Returns (cos, sin), each of shape [max_seq_len, dim].
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    # [seq_len, dim/2]: angle m * θ_i for position m and pair i
    freqs = torch.outer(positions, inv_freq)
    # [seq_len, dim]: the angles are concatenated with themselves (not
    # interleaved), so both halves hold the same dim/2 angles
    emb = torch.cat([freqs, freqs], dim=-1)
    return torch.cos(emb), torch.sin(emb)
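def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """
    Apply rotary embedding to input tensor.
    x: [batch, seq_len, num_heads, head_dim]
    cos, sin: [..., head_dim], broadcastable to x, in the concatenated
        [freqs, freqs] layout returned by compute_rope_frequencies
    Returns: rotated x
    """
    # Split x into adjacent pairs (x[2i], x[2i+1]) for rotation
    x1 = x[..., ::2]   # Even indices
    x2 = x[..., 1::2]  # Odd indices
    # Both halves of cos/sin hold the same angles, so the first half
    # provides exactly one angle per pair
    half = x.shape[-1] // 2
    cos, sin = cos[..., :half], sin[..., :half]
    # Apply rotation
    # [[cos, -sin], [sin, cos]] @ [x1, x2]^T = [x1*cos - x2*sin, x1*sin + x2*cos]
    rotated = torch.stack([
        x1 * cos - x2 * sin,
        x1 * sin + x2 * cos,
    ], dim=-1).flatten(-2)
    return rotated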
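The key property: the dot product Q(m)·K(n) depends only on the offset (m-n), because rotating the query by mθ and the key by nθ changes their inner product only through the angle difference (m-n)θ. This is what makes RoPE a relative position encoding.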
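The relative-position property can be checked numerically. The sketch below is a minimal pure-Python illustration rather than the PyTorch code above: it treats each adjacent pair of dimensions as one complex number and rotates it by `pos * θ_i`. The helper names `rope_rotate` and `attention_score` are hypothetical and exist only for this example.

```python
import cmath


def rope_rotate(vec, pos, base=10000.0):
    """Rotate adjacent pairs (vec[2i], vec[2i+1]) by the angle pos * theta_i."""
    dim = len(vec)
    out = []
    for i in range(dim // 2):
        theta = base ** (-2.0 * i / dim)          # theta_i = 1 / base^(2i/d)
        z = complex(vec[2 * i], vec[2 * i + 1])   # treat the pair as a complex number
        z *= cmath.exp(1j * pos * theta)          # rotation is a complex multiply
        out.extend([z.real, z.imag])
    return out


def attention_score(q, k, m, n):
    """Dot product of q rotated to position m and k rotated to position n."""
    qm = rope_rotate(q, m)
    kn = rope_rotate(k, n)
    return sum(a * b for a, b in zip(qm, kn))


q = [0.3, -1.2, 0.7, 0.5]
k = [1.1, 0.4, -0.6, 0.9]

# Same offset (m - n = 3) at very different absolute positions.
near = attention_score(q, k, 5, 2)
far = attention_score(q, k, 105, 102)
print(abs(near - far) < 1e-9)
```

The two scores agree to within floating-point error, even though the absolute positions differ by 100.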
Efficient Implementation
Avoid repeated frequency computation in inference:
class RoPECache(nn.Module):
    """Cache RoPE sin/cos for efficient inference."""

    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()  # required before register_buffer
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Precompute inverse frequencies (kept in FP32)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        # Cache will be populated on first forward
        self._cos_cached = None
        self._sin_cached = None
        self._seq_len_cached = 0

    def _update_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        stale = (
            self._cos_cached is None
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
        )
        if seq_len <= self._seq_len_cached and not stale:
            return
        self._seq_len_cached = max(seq_len, self.max_seq_len)
        # Compute angles in FP32, then cast: half-precision positions lose
        # integer accuracy at long contexts
        positions = torch.arange(self._seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, self.inv_freq.to(device))
        # [seq_len, dim]
        emb = torch.cat([freqs, freqs], dim=-1)
        self._cos_cached = emb.cos().to(dtype)
        self._sin_cached = emb.sin().to(dtype)

    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        x: [batch, seq_len, num_heads, head_dim]
        position_ids: [batch, seq_len], integer positions
        """
        self._update_cache(int(position_ids.max().item()) + 1, x.device, x.dtype)
        cos = self._cos_cached[position_ids]  # [batch, seq_len, dim]
        sin = self._sin_cached[position_ids]
        return apply_rope(x, cos.unsqueeze(2), sin.unsqueeze(2))
Context Length Extension
Models trained on 4K context can be extended using scaling techniques:
Linear Scaling (Position Interpolation)
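def linear_scaling(position_ids: torch.Tensor, scale: float) -> torch.Tensor:
    """
    Scale positions to fit longer sequences into trained range.
    Training: 4K context, Inference: 16K context, scale = 4
    Position 8000 -> 2000 (within training range)
    The result is a float tensor of fractional positions, so multiply it
    by inv_freq directly rather than using it to index a cos/sin cache.
    """
    return position_ids.float() / scale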
Linear scaling beyond 2-4x often degrades quality significantly. The model wasn’t trained on the compressed position patterns.
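To make the arithmetic concrete, the sketch below computes rotation angles directly from fractional positions. It is a pure-Python illustration under the docstring's assumptions (4K training length, scale = 4); `rope_angles` is a hypothetical helper, not part of any library.

```python
def rope_angles(position, dim, scale=1.0, base=10000.0):
    """Rotation angle for each dimension pair at a (possibly scaled) position."""
    scaled_pos = position / scale  # position interpolation
    return [scaled_pos * base ** (-2.0 * i / dim) for i in range(dim // 2)]


# Position 8000 in a 16K window, compressed by 4x, lands on position 2000.
interpolated = rope_angles(8000, dim=8, scale=4.0)
trained = rope_angles(2000, dim=8)
print(interpolated == trained)

# Adjacent tokens are now only a quarter of a position apart, which is the
# compressed pattern the model did not see during training.
step = rope_angles(1, dim=8, scale=4.0)[0]
print(step)
```

Every angle stays inside the trained range, but the spacing between neighboring tokens shrinks by the scale factor, which is why large factors hurt.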
Dynamic NTK Scaling
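def dynamic_ntk_scaling(
    dim: int,
    max_seq_len: int,
    original_max_len: int,
    base: float = 10000.0
) -> torch.Tensor:
    """
    NTK-aware scaling adjusts the base frequency.
    Preserves high-frequency components while scaling low-frequency.
    "Dynamic" means max_seq_len is the current sequence length, so the
    base only grows once the input exceeds the trained range.
    """
    # Never shrink the base for sequences inside the trained range
    scale = max(1.0, max_seq_len / original_max_len)
    # Raise the base by a power of the scale; the dim / (dim - 2) exponent
    # slows the lowest-frequency pair by exactly `scale`
    scaled_base = base * (scale ** (dim / (dim - 2)))
    inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
    return inv_freq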
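A quick worked check shows what the `dim / (dim - 2)` exponent does. In the pure-Python sketch below, `ntk_inv_freq` is a hypothetical stand-in that applies the same base-scaling formula as the function above; it is meant as an illustration, not a replacement.

```python
def ntk_inv_freq(dim, scale, base=10000.0):
    """Inverse frequencies after NTK-aware base scaling (scale=1 is plain RoPE)."""
    scaled_base = base * scale ** (dim / (dim - 2))
    return [scaled_base ** (-2.0 * i / dim) for i in range(dim // 2)]


dim, scale = 128, 4.0
plain = ntk_inv_freq(dim, 1.0)
scaled = ntk_inv_freq(dim, scale)

# i = 0: the exponent is 0, so the highest frequency is unchanged.
print(plain[0], scaled[0])

# i = d/2 - 1: base'^((d-2)/d) = base^((d-2)/d) * scale, so the lowest
# frequency is divided by the scale factor, as in linear interpolation.
print(round(plain[-1] / scaled[-1], 6))
```

Dimensions in between are slowed by intermediate amounts, which is how the method spreads the interpolation unevenly across frequencies.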
YaRN (Yet another RoPE extensioN)
import math


class YaRNRoPE:
    """
    YaRN combines per-dimension ("NTK-by-parts") frequency scaling with
    attention scaling. Achieves better quality at longer contexts.
    """

    def __init__(
        self,
        dim: int,
        original_max_len: int,
        target_max_len: int,
        base: float = 10000.0,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ):
        self.scale = target_max_len / original_max_len
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Full rotations each dimension pair completes over the original context
        rotations = original_max_len * self.inv_freq / (2 * math.pi)
        # Ramp is 1 for fast pairs (>= beta_fast rotations) and 0 for slow
        # pairs (<= beta_slow rotations), with a linear blend in between
        ramp = torch.clamp(
            (rotations - beta_slow) / (beta_fast - beta_slow),
            0, 1
        )
        # Low dims (high frequency): left unscaled
        # High dims (low frequency): linearly interpolated by 1 / scale
        self.scaling_factors = (1 - ramp) * (1 / self.scale) + ramp * 1.0
        # Attention scaling factor: 0.1 * ln(scale) + 1
        self.attention_scale = 0.1 * math.log(self.scale) + 1.0

    def compute_frequencies(self, positions: torch.Tensor):
        # Apply per-dimension scaling to frequencies
        scaled_freqs = positions.unsqueeze(-1).float() * self.inv_freq * self.scaling_factors
        return scaled_freqs
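In the YaRN paper's formulation, the blend weight for each dimension pair is driven by how many full rotations that pair completes within the original context. The pure-Python sketch below illustrates this with the `beta_fast` and `beta_slow` defaults from the class above; `yarn_ramp` is a hypothetical helper and the 4K to 16K setting is an assumed example.

```python
import math


def yarn_ramp(dim, original_max_len, base=10000.0, beta_fast=32.0, beta_slow=1.0):
    """Per-pair blend weight: 1.0 keeps the frequency, 0.0 interpolates it fully."""
    ramp = []
    for i in range(dim // 2):
        inv_freq = base ** (-2.0 * i / dim)
        # Full rotations this pair completes across the original context.
        rotations = original_max_len * inv_freq / (2 * math.pi)
        weight = (rotations - beta_slow) / (beta_fast - beta_slow)
        ramp.append(min(1.0, max(0.0, weight)))
    return ramp


scale = 4.0  # assumed example: 4K -> 16K
ramp = yarn_ramp(dim=128, original_max_len=4096)
factors = [(1 - r) / scale + r for r in ramp]

print(factors[0])    # fastest pair: frequency left unscaled
print(factors[-1])   # slowest pair: fully interpolated (1 / scale)
print(round(0.1 * math.log(scale) + 1.0, 4))  # attention scale for this ratio
```

Pairs that rotate many times within the trained window already encode fine-grained local order, so they are left alone; only the slow pairs, which would otherwise see unseen angles, are compressed.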
Context Extension Quality (Perplexity on 16K context)
| Method | 4K→8K | 4K→16K | 4K→32K |
|---|---|---|---|
| No scaling (baseline) | 5.2 | 7.8 | 15.4 |
| Linear scaling | 5.4 | 6.2 | 8.1 |
| NTK scaling | 5.3 | 5.8 | 6.5 |
| YaRN | 5.2 | 5.4 | 5.9 |
[Figure: Effective Context Utilization by Extension Method (% of baseline quality)]
Implementation Tips
- Fuse into attention kernel: RoPE can be applied inside FlashAttention
- Cache aggressively: Precompute sin/cos for all positions at startup
- Use FP32 for frequencies: Low precision causes drift at long contexts
- Test perplexity at target length: Don’t trust training context results
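# Testing context extension
# compute_perplexity is assumed to be provided by your evaluation harness
def evaluate_context_extension(model, test_data, context_lengths):
    results = {}
    for ctx_len in context_lengths:
        # Keep only samples long enough to fill the whole context
        samples = [s[:ctx_len] for s in test_data if len(s) >= ctx_len]
        ppl = compute_perplexity(model, samples)
        results[ctx_len] = ppl
        print(f"Context {ctx_len}: PPL = {ppl:.2f}")
    return results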
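The FP32 tip above is easy to demonstrate without a GPU. The sketch below round-trips position indices through IEEE half precision using Python's `struct` module (format code `e`); `to_fp16` is a hypothetical helper written for this illustration. Above 2048, fp16 can no longer represent every integer, so distinct positions collapse onto the same value before they are ever multiplied by the frequencies.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", value))[0]


# Positions that fp16 can and cannot represent exactly.
for pos in (1000, 2047, 4097, 8001):
    stored = to_fp16(float(pos))
    print(pos, stored, "exact" if stored == pos else "collapsed")
```

Computing the position-frequency product in FP32 and casting only the final cos/sin tables avoids this loss.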
Conclusion
RoPE enables elegant position encoding with efficient implementation. For context extension:
- Up to 2x: Linear scaling works fine
- 2-4x: Use NTK scaling
- 4x+: Use YaRN or fine-tune on longer data
The field is moving toward training on longer contexts directly (e.g., Llama 3.1’s 128K), but scaling techniques remain valuable for extending existing models.