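Rotary Position Embeddings (RoPE) encode position by rotating each pair of query and key dimensions, viewed as a complex number, through an angle proportional to the token's position. This formulation is cheap to compute and, critically, lets a trained model's context length be extended through scaling techniques.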

RoPE Mathematics

RoPE rotates query and key vectors based on position:

import torch
import torch.nn as nn

def compute_rope_frequencies(dim: int, max_seq_len: int, base: float = 10000.0):
    """
    Compute rotation frequencies for each dimension.
    
    θ_i = 1 / (base^(2i/d)) for i = 0, 1, ..., d/2 - 1
    """
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
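    positions = torch.arange(max_seq_len, dtype=torch.float32)
    
    # [seq_len, dim/2]: angle m·θ_i for position m and frequency i
    freqs = torch.outer(positions, inv_freq)
    
    # [seq_len, dim]: the dim/2 frequencies repeated across both halves
    # ("half-split" layout, as in GPT-NeoX / Hugging Face Llama), not interleaved
    emb = torch.cat([freqs, freqs], dim=-1)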
    
    return torch.cos(emb), torch.sin(emb)

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """
    Apply rotary embedding to input tensor.
    
    x: [batch, seq_len, num_heads, head_dim]
    Returns: rotated x
    """
    # Split x into pairs for rotation
    x1 = x[..., ::2]   # Even indices
    x2 = x[..., 1::2]  # Odd indices
    
    # Apply rotation
    # [x1, x2] @ [[cos, -sin], [sin, cos]] = [x1*cos - x2*sin, x1*sin + x2*cos]
    rotated = torch.stack([
        x1 * cos[..., ::2] - x2 * sin[..., ::2],
        x1 * sin[..., 1::2] + x2 * cos[..., 1::2]
    ], dim=-1).flatten(-2)
    
    return rotated
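The introduction describes RoPE as a rotation in complex space. The sketch below makes that literal under one assumption: it pairs dimensions as interleaved (x[2i], x[2i+1]), the layout of the original RoFormer formulation, views each pair as the complex number x[2i] + i·x[2i+1], and multiplies it by e^{i·m·θ_i}. It then checks the result against the explicit cos/sin rotation. `rope_complex` and `rope_real` are illustrative names. Interleaved and half-split pairings are equivalent up to a fixed permutation of head dimensions, but a checkpoint must be run with the layout it was trained with.

```python
import torch


def rope_complex(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """RoPE via complex multiplication. x: [seq_len, dim], dim even."""
    seq_len, dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)  # [seq, dim/2]
    rot = torch.polar(torch.ones_like(angles), angles)  # e^{i * m * theta_i}
    # View interleaved pairs (x[2i], x[2i+1]) as complex numbers x[2i] + i*x[2i+1]
    xc = torch.view_as_complex(x.float().reshape(seq_len, dim // 2, 2).contiguous())
    return torch.view_as_real(xc * rot).reshape(seq_len, dim)


def rope_real(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Same rotation written with explicit cos/sin on interleaved pairs."""
    seq_len, dim = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.flatten(-2)  # re-interleave back to [seq_len, dim]


x = torch.randn(16, 8)
print(torch.allclose(rope_complex(x), rope_real(x), atol=1e-5))  # True
```

The complex form is compact, while the real form maps more directly onto fused kernels; both apply the same rotation.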
ℹ️ Relative Position Encoding

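The key property: if R_m is the rotation applied at position m, then (R_m q)·(R_n k) = q·(R_{n−m} k), because R_mᵀ R_n = R_{n−m} (rotations by mθ and nθ compose into a rotation by (n−m)θ). The attention score therefore depends on the two positions only through their offset n − m, which gives RoPE its relative position encoding property.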
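A quick numerical check of this property, as a self-contained sketch (`rotate` is an illustrative helper using interleaved pairs, computed in float64 to keep rounding out of the picture): shifting the query and key positions by the same amount leaves the score unchanged, while changing their offset generally does not.

```python
import torch


def rotate(v: torch.Tensor, pos: int, base: float = 10000.0) -> torch.Tensor:
    """Apply RoPE to a single vector v of even length at integer position pos."""
    dim = v.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    angle = pos * inv_freq
    cos, sin = angle.cos(), angle.sin()
    v1, v2 = v[0::2], v[1::2]
    return torch.stack([v1 * cos - v2 * sin, v1 * sin + v2 * cos], dim=-1).flatten()


torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

s1 = rotate(q, 10) @ rotate(k, 3)        # offset 7
s2 = rotate(q, 1010) @ rotate(k, 1003)   # offset 7, both shifted by 1000
s3 = rotate(q, 10) @ rotate(k, 4)        # offset 6
print(torch.isclose(s1, s2).item())  # True
print(s1.item(), s3.item())
```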

Efficient Implementation

Avoid repeated frequency computation in inference:

class RoPECache(nn.Module):
    """Cache RoPE sin/cos for efficient inference."""
    
    def __init__(self, dim: int, max_seq_len: int, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        
        # Precompute frequencies (non-persistent: not saved in the state dict)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq, persistent=False)
        
        # Cache will be populated on first forward
        self._cos_cached = None
        self._sin_cached = None
        self._seq_len_cached = 0
    
    def _update_cache(self, seq_len: int, device: torch.device):
        # Rebuild only if the cache is too short or lives on another device
        if seq_len <= self._seq_len_cached and self._cos_cached.device == device:
            return
        
        self._seq_len_cached = max(seq_len, self.max_seq_len)
        
        # Always compute in FP32: low-precision positions get rounded at long contexts
        positions = torch.arange(self._seq_len_cached, device=device, dtype=torch.float32)
        freqs = torch.outer(positions, self.inv_freq.to(device=device, dtype=torch.float32))
        
        # [seq_len, dim], half-split layout
        emb = torch.cat([freqs, freqs], dim=-1)
        self._cos_cached = emb.cos()
        self._sin_cached = emb.sin()
    
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
        """
        x: [batch, seq_len, num_heads, head_dim]
        position_ids: [batch, seq_len]
        """
        self._update_cache(int(position_ids.max()) + 1, x.device)
        
        # Gather per-token rows, then cast to the activation dtype
        cos = self._cos_cached[position_ids].to(x.dtype)  # [batch, seq_len, dim]
        sin = self._sin_cached[position_ids].to(x.dtype)
        
        # unsqueeze(2) broadcasts over the heads dimension
        return apply_rope(x, cos.unsqueeze(2), sin.unsqueeze(2))
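What the cache buys at decode time is that each step only gathers rows from a precomputed table instead of calling cos/sin again. A minimal standalone sketch of that pattern (table names are illustrative): build the table once in float32, then index it with `position_ids` for one decode step in which two sequences sit at positions 37 and 512. The gathered row for position 37 matches a direct computation.

```python
import torch

head_dim, max_len, base = 64, 2048, 10000.0

# Build the tables once, in float32 (half-split layout: frequencies repeated)
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(torch.arange(max_len, dtype=torch.float32), inv_freq)  # [max_len, head_dim/2]
emb = torch.cat([freqs, freqs], dim=-1)                                     # [max_len, head_dim]
cos_table, sin_table = emb.cos(), emb.sin()

# One decode step: batch of 2 sequences, one new token each, at different positions
position_ids = torch.tensor([[37], [512]])  # [batch, seq_len=1]
cos = cos_table[position_ids]               # [batch, 1, head_dim]: a gather, no trig
sin = sin_table[position_ids]

# Reference: compute position 37's row from scratch
direct = torch.cat([37.0 * inv_freq, 37.0 * inv_freq]).cos()
print(cos.shape)                          # torch.Size([2, 1, 64])
print(torch.allclose(cos[0, 0], direct))  # True
```

In a model, `cos.unsqueeze(2)` (shape [batch, 1, 1, head_dim]) broadcasts over the heads dimension of a [batch, 1, num_heads, head_dim] query or key.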

Context Length Extension

Models trained on 4K context can be extended using scaling techniques:

Linear Scaling (Position Interpolation)

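def linear_scaling(position_ids: torch.Tensor, scale: float) -> torch.Tensor:
    """
    Position Interpolation: scale positions to fit longer sequences into the trained range.
    Training: 4K context, Inference: 16K context, scale = 4
    Position 8000 -> 2000 (within training range)
    """
    # The result is fractional: feed it into the angle computation
    # (torch.outer(positions, inv_freq)), not into an integer-indexed cos/sin cache
    return position_ids.float() / scale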
⚠️ Quality Degradation

Linear scaling beyond 2-4x often degrades quality significantly, especially without fine-tuning: compressing positions packs neighboring tokens closer together (in rotation angle) than anything the model saw during training.
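Because the rotation angle is the product m·θ_i, dividing positions by `scale` is the same as dividing every frequency by `scale`, which is how interpolation is often folded into the frequency table instead. The sketch below (illustrative sizes: 4K trained context, 16K target, head dimension 128) checks that equivalence and shows that the largest interpolated position, 4095.75, stays inside the trained range. The interpolated positions are fractional, so they have to go into the angle computation directly rather than serve as integer indices into a cos/sin cache.

```python
import torch

dim, base = 128, 10000.0
train_len, target_len = 4096, 16384
scale = target_len / train_len  # 4.0

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
positions = torch.arange(target_len, dtype=torch.float32)

# Option A: interpolate the positions
angles_a = torch.outer(positions / scale, inv_freq)
# Option B: shrink the frequencies instead
angles_b = torch.outer(positions, inv_freq / scale)

print((positions / scale).max().item())               # 4095.75
print(torch.allclose(angles_a, angles_b, rtol=1e-6))  # True
```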

Dynamic NTK Scaling

def dynamic_ntk_scaling(
    dim: int,
    max_seq_len: int,
    original_max_len: int,
    base: float = 10000.0
) -> torch.Tensor:
    """
    NTK-aware scaling adjusts the base frequency.
    Preserves high-frequency components while scaling low-frequency.
    """
    scale = max_seq_len / original_max_len
    
    # Scale the base exponentially
    scaled_base = base * (scale ** (dim / (dim - 2)))
    
    inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2).float() / dim))
    return inv_freq
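Why the exponent d/(d−2)? With base' = base·s^{d/(d−2)}, the highest-frequency pair (i = 0) keeps θ_0 = 1, while the lowest-frequency pair (i = d/2 − 1) becomes base'^{−(d−2)/d} = θ_{d/2−1}/s, exactly what linear interpolation would give at that end. Frequencies in between are scaled by s^{−2i/(d−2)}, a smooth geometric blend. The sketch below checks both endpoints in float64; `ntk_inv_freq` is a hypothetical helper implementing a simplified dynamic variant (scale recomputed from the current length and floored at 1), not any particular library's formula.

```python
import torch


def ntk_inv_freq(dim: int, seq_len: int, original_max_len: int, base: float = 10000.0):
    """NTK-aware frequencies; 'dynamic' when seq_len is the current sequence length."""
    scale = max(1.0, seq_len / original_max_len)  # no change inside the trained range
    scaled_base = base * scale ** (dim / (dim - 2))
    return 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))


dim, orig = 128, 4096
plain = ntk_inv_freq(dim, 2048, orig)    # within trained context: unscaled
scaled = ntk_inv_freq(dim, 16384, orig)  # scale = 4

ratio = scaled / plain
print(ratio[0].item())   # 1.0: highest frequency untouched
print(ratio[-1].item())  # approximately 0.25: lowest frequency divided by the scale
```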

YaRN (Yet another RoPE extensioN)

import math

class YaRNRoPE:
    """
    YaRN combines "NTK-by-parts" frequency interpolation with attention scaling.
    Achieves better quality at longer contexts.
    """
    
    def __init__(
        self,
        dim: int,
        original_max_len: int,
        target_max_len: int,
        base: float = 10000.0,
        beta_fast: float = 32,
        beta_slow: float = 1,
    ):
        self.scale = target_max_len / original_max_len
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        
        # Full rotations each dimension completes over the original context:
        # r_i = original_max_len / wavelength_i = original_max_len * θ_i / (2π)
        rotations = original_max_len * inv_freq / (2 * math.pi)
        
        # ramp = 1 where r_i > beta_fast, 0 where r_i < beta_slow, linear in between
        ramp = torch.clamp(
            (rotations - beta_slow) / (beta_fast - beta_slow),
            0, 1
        )
        
        # High-frequency dims (ramp = 1): keep the original frequency
        # Low-frequency dims (ramp = 0): linear interpolation (frequency / scale)
        self.scaling_factors = (1 - ramp) * (1 / self.scale) + ramp * 1.0
        self.inv_freq = inv_freq * self.scaling_factors
        
        # Attention temperature from the YaRN paper: sqrt(1/t) = 0.1 * ln(scale) + 1.
        # Multiply cos and sin (i.e. both q and k) by it, so the logits scale by 1/t.
        self.attention_scale = 0.1 * math.log(self.scale) + 1.0
    
    def compute_frequencies(self, positions: torch.Tensor):
        # [..., dim/2]: angles with the per-dimension scaling already folded in
        return positions.float().unsqueeze(-1) * self.inv_freq
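To see what the ramp does per dimension, here is a standalone sketch (the helper `yarn_factors` is hypothetical) that follows the YaRN paper's definitions: r is the number of full rotations a dimension completes over the original context, dimensions with r above `beta_fast` keep their frequency, those below `beta_slow` are divided by the scale, and the attention factor is 0.1·ln(s) + 1. For a 4K model with head dimension 128 extended 4×, the fastest pair keeps factor 1.0, the slowest gets 0.25, and the attention factor is about 1.1386.

```python
import math

import torch


def yarn_factors(dim, original_max_len, scale, base=10000.0, beta_fast=32.0, beta_slow=1.0):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
    rotations = original_max_len * inv_freq / (2 * math.pi)  # r = L / wavelength
    ramp = torch.clamp((rotations - beta_slow) / (beta_fast - beta_slow), 0.0, 1.0)
    factors = (1 - ramp) / scale + ramp   # multiplier applied to each frequency
    mscale = 0.1 * math.log(scale) + 1.0  # applied to cos/sin, i.e. to both q and k
    return factors, mscale


factors, mscale = yarn_factors(dim=128, original_max_len=4096, scale=4.0)
print(factors[0].item(), factors[-1].item())  # 1.0 0.25
print(round(mscale, 4))                       # 1.1386
```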
📊 Context Extension Quality (perplexity at the target context length; lower is better)

Method                  4K→8K   4K→16K   4K→32K
No scaling (baseline)   5.2     7.8      15.4
Linear scaling          5.4     6.2      8.1
NTK scaling             5.3     5.8      6.5
YaRN                    5.2     5.4      5.9

Note: Llama-7B base model, PG19 evaluation

Effective Context Utilization by Extension Method (% of baseline quality)

Method          % of baseline quality
Linear (8K)     96%
Linear (32K)    67%
YaRN (8K)       99%
YaRN (32K)      92%

Implementation Tips

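  1. Fuse into the attention kernel: applying RoPE inside the attention kernel (some FlashAttention entry points accept rotary cos/sin) avoids a separate pass that reads and writes Q and K
  2. Cache aggressively: precompute sin/cos for all positions at startup instead of recomputing them every step
  3. Use FP32 for frequencies: bf16 represents integers exactly only up to 256 and fp16 only up to 2048, so low-precision positions and angles drift at long contexts
  4. Test perplexity at the target length: results at the training context length say little about quality after extension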
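Tip 3 is easy to check directly. bfloat16 carries 8 significant bits, so integer positions above 256 are already rounded before any trigonometry happens. The sketch below round-trips positions 0 to 8191 through bfloat16: the first 256 survive exactly, and the worst error reaches 16 positions. For the highest-frequency pair (θ_0 = 1, so the angle equals the position) an error of that size changes cos substantially.

```python
import torch

positions = torch.arange(8192, dtype=torch.float32)
positions_bf16 = positions.to(torch.bfloat16).float()  # round-trip through bf16

# Positions below 256 survive exactly; larger ones are rounded
print(torch.equal(positions[:256], positions_bf16[:256]))  # True
print((positions - positions_bf16).abs().max().item())     # 16.0

# For theta_0 = 1 the angle is the position itself, so position error = angle error
angle_err = (positions.cos() - positions_bf16.cos()).abs().max().item()
print(f"max |cos error| for theta_0: {angle_err:.3f}")
```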
# Testing context extension
# compute_perplexity(model, samples) is assumed to be defined elsewhere
def evaluate_context_extension(model, test_data, context_lengths):
    results = {}
    for ctx_len in context_lengths:
        # Keep only samples long enough to fill the whole context
        samples = [s[:ctx_len] for s in test_data if len(s) >= ctx_len]
        ppl = compute_perplexity(model, samples)
        results[ctx_len] = ppl
        print(f"Context {ctx_len}: PPL = {ppl:.2f}")
    return results
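`compute_perplexity` is not defined in the snippet above. A minimal sketch, assuming `model(input_ids)` returns raw logits of shape [batch, seq_len, vocab] (models that return an output object need `.logits` or similar), exponentiates the mean next-token negative log-likelihood over all tokens in all samples. A dummy model that always outputs uniform logits should score a perplexity equal to the vocabulary size, which makes a convenient sanity check.

```python
import math

import torch
import torch.nn.functional as F


@torch.no_grad()
def compute_perplexity(model, samples) -> float:
    """exp(mean next-token NLL) over a list of token-id sequences."""
    total_nll, total_tokens = 0.0, 0
    for ids in samples:
        ids = torch.as_tensor(ids, dtype=torch.long).unsqueeze(0)  # [1, seq_len]
        logits = model(ids)                                        # [1, seq_len, vocab]
        # Logits at position t predict token t + 1
        nll = F.cross_entropy(logits[0, :-1], ids[0, 1:], reduction="sum")
        total_nll += nll.item()
        total_tokens += ids.shape[1] - 1
    return math.exp(total_nll / total_tokens)


# Sanity check with a dummy "model" that always predicts a uniform distribution
vocab = 100


def uniform_model(ids: torch.Tensor) -> torch.Tensor:
    return torch.zeros(ids.shape[0], ids.shape[1], vocab)


samples = [torch.randint(0, vocab, (50,)) for _ in range(3)]
print(round(compute_perplexity(uniform_model, samples), 3))  # 100.0
```

Summing NLL over tokens before dividing weights every token equally, so longer samples are not under-weighted relative to a per-sample average.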

Conclusion

RoPE encodes relative position with no learned parameters, using only an elementwise rotation of queries and keys. For context extension:

  • Up to 2x: Linear scaling works fine
  • 2-4x: Use NTK scaling
  • 4x+: Use YaRN or fine-tune on longer data

The field is moving toward training on longer contexts directly (e.g., Llama 3.1’s 128K), but scaling techniques remain valuable for extending existing models.