Standard attention computes $\text{softmax}(QK^\top/\sqrt{d})\,V$. The $QK^\top$ term produces an $N \times N$ matrix. At $N = 1\text{M}$ tokens, this matrix has $10^{12}$ entries, or 2 TB in FP16. Even FlashAttention, which avoids materializing this matrix, still performs $O(N^2 d)$ FLOPs.
Lightning Attention replaces the $(QK^\top)V$ computation with $Q(K^\top V)$ by exploiting the associativity of matrix multiplication. The key: compute $K^\top V$ first (size $d \times d$), THEN multiply by $Q$. This reversal changes the complexity from $O(N^2 d)$ to $O(N d^2)$, linear in sequence length.
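The reordering is easy to check numerically. A minimal sketch (sizes chosen arbitrarily for illustration):

```python
import torch

torch.manual_seed(0)
N, d = 512, 64  # small sizes for illustration
Q = torch.randn(N, d)
K = torch.randn(N, d)
V = torch.randn(N, d)

# Standard order: (Q K^T) V -- materializes an N x N matrix
out_quadratic = (Q @ K.T) @ V   # O(N^2 d) FLOPs

# Reordered: Q (K^T V) -- the intermediate is only d x d
out_linear = Q @ (K.T @ V)      # O(N d^2) FLOPs

assert torch.allclose(out_quadratic, out_linear, atol=1e-3)
```

Note that the equivalence holds only because there is no softmax between the two matmuls; the row-wise softmax is exactly what the feature-map approximation below has to remove.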
The Mathematical Foundation
Standard attention for query position $i$:

$$\text{Attn}(Q, K, V)_i = \frac{\sum_j \exp(q_i^\top k_j)\, v_j}{\sum_j \exp(q_i^\top k_j)}$$

Replace $\exp(q_i^\top k_j)$ with $\phi(q_i)^\top \phi(k_j)$ for some feature map $\phi$:

$$\text{Attn}_i \approx \frac{\sum_j \phi(q_i)^\top \phi(k_j)\, v_j}{\sum_j \phi(q_i)^\top \phi(k_j)}$$

Define $S = \sum_j \phi(k_j)\, v_j^\top$ and $z = \sum_j \phi(k_j)$.

Then:

$$\text{Attn}_i \approx \frac{\phi(q_i)^\top S}{\phi(q_i)^\top z}$$

$S$ and $z$ are independent of $i$. Compute them once: $O(N d' d)$. Query each position: $O(d' d)$ per query. Total: $O(N d' d)$, linear in $N$.

The core insight: by factoring $Q(K^\top V)$ instead of $(QK^\top)V$, we avoid the $N \times N$ intermediate.
Feature Map Choice
The simplest feature map: $\phi(x) = \text{elu}(x) + 1$, ensuring strictly positive outputs. This is the approach used in early linear attention (Katharopoulos et al., 2020). Quality degrades because ELU does not sharpen attention the way softmax's exponential does.
Lightning Attention uses a refined feature map with learned parameters:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LightningFeatureMap(nn.Module):
    """Learned feature map for Lightning Attention.

    Maps d-dimensional Q/K vectors to d'-dimensional feature space
    where the kernel approximation is more faithful to softmax.
    """
    def __init__(self, d_head, d_feature=None):
        super().__init__()
        self.d_feature = d_feature or d_head
        # Learned random feature projection
        self.W = nn.Parameter(torch.randn(d_head, self.d_feature) * 0.1)

    def forward(self, x):
        # x: [..., d_head]
        # Project and apply non-linearity
        projected = x @ self.W  # [..., d_feature]
        # Use ReLU-squared for sharper attention (better than ELU)
        return F.relu(projected) ** 2 + 1e-6  # Ensure positive
```
The ReLU-squared non-linearity provides stronger sharpening than ELU, approximating softmax's exponential more closely. The learned projection adapts the feature space to the data distribution.
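The sharpening claim can be illustrated directly. The snippet below (an illustrative comparison, not part of Lightning Attention itself) contrasts how far each map spreads a vector of scores:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([0.5, 1.0, 2.0, 4.0])

elu_feat = F.elu(scores) + 1   # elu(x) + 1, the classic linear-attention map
relu_sq = F.relu(scores) ** 2  # ReLU-squared, as in the feature map above

# Ratio of largest to smallest feature: higher = sharper attention weights
elu_ratio = (elu_feat.max() / elu_feat.min()).item()
relu_sq_ratio = (relu_sq.max() / relu_sq.min()).item()

print(f"elu+1 ratio: {elu_ratio:.1f}, relu^2 ratio: {relu_sq_ratio:.1f}")
```

ReLU-squared separates large scores far more aggressively, which is the behavior softmax's exponential provides in standard attention.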
Chunk-Wise Computation
Pure linear attention processes the entire sequence with a single accumulated state $S$. For very long sequences this works, but it loses fine-grained local attention patterns. Lightning Attention uses chunk-wise computation: divide the sequence into chunks of size $C$, use standard quadratic attention within each chunk, and linear attention between chunks.
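The decomposition is exact in coverage: every causal pair $(i, j)$ with $j \le i$ is handled exactly once, by the intra-chunk path when both positions share a chunk and by the inter-chunk state otherwise. A quick sanity check of this partition (small sizes for illustration):

```python
N, C = 12, 4  # small sequence and chunk size for illustration

# Pairs handled by intra-chunk quadratic attention: same chunk, causal
intra = {(i, j) for i in range(N) for j in range(N)
         if j <= i and i // C == j // C}

# Pairs handled by the inter-chunk linear state: j lives in an earlier chunk
inter = {(i, j) for i in range(N) for j in range(N)
         if j // C < i // C}

# Together they cover the full causal mask, with no overlap
causal = {(i, j) for i in range(N) for j in range(N) if j <= i}
assert intra | inter == causal
assert not (intra & inter)
```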
```python
class ChunkedLightningAttention(nn.Module):
    """Lightning Attention with chunk-wise local + linear global."""

    def __init__(self, d_head, chunk_size=256, d_feature=None):
        super().__init__()
        self.chunk_size = chunk_size
        self.d_head = d_head
        self.feature_map = LightningFeatureMap(d_head, d_feature)
        self.scale = 1.0 / (d_head ** 0.5)

    def forward(self, Q, K, V):
        """
        Q, K, V: [batch, seq_len, d_head]
        Returns: [batch, seq_len, d_head]
        """
        B, N, D = Q.shape
        C = self.chunk_size

        # Pad to multiple of chunk_size
        pad = (C - N % C) % C
        if pad > 0:
            Q = F.pad(Q, (0, 0, 0, pad))
            K = F.pad(K, (0, 0, 0, pad))
            V = F.pad(V, (0, 0, 0, pad))
        N_padded = Q.shape[1]
        num_chunks = N_padded // C

        # Reshape into chunks: [B, num_chunks, C, D]
        Q_chunks = Q.view(B, num_chunks, C, D)
        K_chunks = K.view(B, num_chunks, C, D)
        V_chunks = V.view(B, num_chunks, C, D)

        # === Intra-chunk: standard quadratic attention ===
        # [B, num_chunks, C, C] score matrix per chunk
        intra_scores = torch.matmul(Q_chunks, K_chunks.transpose(-1, -2)) * self.scale
        # Causal mask within chunk
        causal = torch.triu(torch.ones(C, C, device=Q.device), diagonal=1).bool()
        intra_scores.masked_fill_(causal, float('-inf'))
        intra_weights = F.softmax(intra_scores, dim=-1)
        intra_output = torch.matmul(intra_weights, V_chunks)  # [B, nc, C, D]

        # === Inter-chunk: linear attention via accumulated state ===
        phi_K = self.feature_map(K_chunks)  # [B, nc, C, d']
        phi_Q = self.feature_map(Q_chunks)  # [B, nc, C, d']
        d_feat = phi_K.shape[-1]

        # Accumulate S and z across chunks
        S = torch.zeros(B, d_feat, D, device=Q.device)  # State: d' x d
        z = torch.zeros(B, d_feat, device=Q.device)     # Normalizer: d'
        inter_outputs = []
        for chunk_idx in range(num_chunks):
            # Query this chunk against accumulated state from PREVIOUS chunks
            phi_q = phi_Q[:, chunk_idx]  # [B, C, d']
            inter_attn = torch.matmul(phi_q, S)  # [B, C, D]
            inter_norm = torch.matmul(phi_q, z.unsqueeze(-1)).squeeze(-1)  # [B, C]
            inter_norm = inter_norm.clamp(min=1e-6)
            inter_out = inter_attn / inter_norm.unsqueeze(-1)
            inter_outputs.append(inter_out)

            # Update state with this chunk's K, V
            phi_k = phi_K[:, chunk_idx]  # [B, C, d']
            v = V_chunks[:, chunk_idx]   # [B, C, D]
            S = S + torch.matmul(phi_k.transpose(-1, -2), v)  # [B, d', D]
            z = z + phi_k.sum(dim=1)  # [B, d']

        inter_output = torch.stack(inter_outputs, dim=1)  # [B, nc, C, D]

        # Combine intra-chunk and inter-chunk attention
        output = intra_output + inter_output
        output = output.view(B, N_padded, D)

        # Remove padding
        return output[:, :N, :]
```
Intra-chunk: $N/C$ chunks at $O(C^2 d)$ each, so $O(NCd)$ total. With $C = 256$, this is a small constant times $Nd$. Inter-chunk: $O(N d' d)$ for state updates plus $O(N d' d)$ for queries. Total: $O(NCd + N d' d)$, linear in $N$. At $N = 1\text{M}$: over 1000x less compute than standard attention's $O(N^2 d)$.
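A back-of-the-envelope FLOP count at the stated sizes supports this. The sketch below counts only the dominant matmuls (2 FLOPs per multiply-add, two matmuls per attention path); the constants are approximate:

```python
N, C, d, d_feat = 1_000_000, 256, 128, 128

# Standard attention: Q K^T and A V, each ~2*N^2*d FLOPs
standard = 2 * 2 * N**2 * d

# Chunked: intra-chunk (N/C chunks, two matmuls of ~2*C^2*d each)
intra = 2 * 2 * (N // C) * C**2 * d
# Inter-chunk: state updates and queries, ~2*2*N*d*d_feat
inter = 2 * 2 * N * d * d_feat
lightning = intra + inter

print(f"speedup: ~{standard / lightning:.0f}x")
```

This rough count lands in the low thousands, consistent with the order-of-magnitude claim above.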
Memory Analysis
Memory Usage: Standard vs Lightning Attention (d=128, FP16)
| Sequence Length | Standard Attention | FlashAttention | Lightning (C=256) |
|---|---|---|---|
| 4K | 32 MB (score matrix) | 0.5 MB (tiled) | 2 MB (state + chunks) |
| 128K | 32 GB | 16 MB | 2 MB |
| 1M | 2 TB | 128 MB | 2 MB |
| 4M | 32 TB | 512 MB | 2 MB |
The $d' \times d$ state is fixed-size regardless of sequence length. This is Lightning Attention's key advantage: constant $O(d' d)$ memory per layer, compared to FlashAttention's $O(Nd)$ memory (for the KV cache) and standard attention's $O(N^2)$.
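The standard-attention column in the table follows directly from the score-matrix size (FP16 = 2 bytes per entry), and the fixed state size from $d' \times d$. A quick check of the arithmetic:

```python
def score_matrix_bytes(n):
    # N x N attention scores in FP16 (2 bytes per entry)
    return n * n * 2

for n, label in [(4_096, "4K"), (131_072, "128K"), (1_048_576, "1M")]:
    print(f"{label}: {score_matrix_bytes(n) / 2**30:,.2f} GiB")

# Lightning's per-layer state is d' x d in FP16, regardless of N
d, d_feat = 128, 128
state_bytes = d_feat * d * 2  # 32 KB
```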
The Quality Tradeoff
Linear attention approximates softmax. The approximation error depends on the feature map quality:
[Figure: perplexity vs. sequence length, softmax vs. Lightning; y-axis shows % of softmax perplexity (lower is better).]

Lightning Attention with learned features loses only 1.2% perplexity on short contexts (4K tokens). At very long contexts (1M+ tokens), it actually closes the gap, because softmax attention's finite-precision issues at extreme sequence lengths degrade quality slightly.
For sequences under 32K tokens: use softmax + FlashAttention. The quality is better and the compute is manageable. For sequences 32K-128K: either works, with FlashAttention slightly ahead on quality. For sequences above 128K: Lightning Attention is the only practical option — softmax attention’s quadratic cost becomes prohibitive regardless of FlashAttention’s memory optimization.
Inference: No KV Cache
A remarkable property of linear attention: there is no KV cache. During autoregressive generation:
Standard attention: cache all previous K, V tensors. Cost: $O(Nd)$ memory, growing with every token.

Lightning Attention: maintain the accumulated state $S$ and normalizer $z$. When a new token arrives:
```python
# New token arrives: q_new, k_new, v_new (all shape [d])
phi_k_new = feature_map(k_new)  # [d']
phi_q_new = feature_map(q_new)  # [d']

# Query the state
output = (phi_q_new @ S) / (phi_q_new @ z + 1e-6)  # [d]

# Update state
S += phi_k_new.unsqueeze(-1) @ v_new.unsqueeze(0)  # [d', d]
z += phi_k_new  # [d']
```
Cost per new token: $O(d' d)$ for the state query and update. Constant, independent of how many previous tokens exist. No KV cache growth. No PagedAttention needed. No memory pressure from long sequences.
This is why MiniMax-01 can process 4M tokens: the per-layer state is a fixed 32 KB (at $d = d' = 128$, FP16), regardless of whether the sequence is 1K or 4M tokens long.
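A minimal decode loop makes the constant-memory property concrete. This is a self-contained sketch: the feature map here is the simple ReLU-squared stand-in with $d' = d$, not the learned module from earlier:

```python
import torch

torch.manual_seed(0)
d = 64
phi = lambda x: torch.relu(x) ** 2 + 1e-6  # stand-in feature map, d' = d

S = torch.zeros(d, d)  # accumulated state
z = torch.zeros(d)     # accumulated normalizer

for step in range(1_000):  # generate 1000 tokens autoregressively
    q, k, v = torch.randn(3, d)
    # Query the state, then fold the new token in
    out = (phi(q) @ S) / (phi(q) @ z + 1e-6)
    S = S + phi(k).unsqueeze(-1) @ v.unsqueeze(0)
    z = z + phi(k)

# The state never grew: still d x d after 1000 tokens
assert S.shape == (d, d) and z.shape == (d,)
```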
Reviewer Agent Validation
Challenge: Using only this post, implement the core linear attention forward pass (without chunking) that processes Q, K, V in O(N) time.
Expected implementation:
```python
def linear_attention(Q, K, V, feature_map_fn):
    """O(N) linear attention. No N x N matrix."""
    phi_Q = feature_map_fn(Q)  # [B, N, d']
    phi_K = feature_map_fn(K)  # [B, N, d']

    # Accumulate S = sum(phi_k * v^T) and z = sum(phi_k)
    S = torch.matmul(phi_K.transpose(-1, -2), V)  # [B, d', d]
    z = phi_K.sum(dim=1)  # [B, d']

    # Query: O = phi_Q @ S / (phi_Q @ z)
    numerator = torch.matmul(phi_Q, S)  # [B, N, d]
    denominator = torch.matmul(phi_Q, z.unsqueeze(-1))  # [B, N, 1]
    return numerator / denominator.clamp(min=1e-6)
```
This 10-line function is the complete linear attention mechanism. If the Reviewer Agent can produce it, the mathematical foundation was explained with sufficient precision.
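One way to validate such an implementation is to compare it against an explicit $O(N^2)$ kernel-attention computation with the same feature map. A sketch (non-causal, like the function above; the function is repeated so the snippet runs standalone):

```python
import torch

def linear_attention(Q, K, V, feature_map_fn):
    """O(N) linear attention. No N x N matrix."""
    phi_Q = feature_map_fn(Q)
    phi_K = feature_map_fn(K)
    S = torch.matmul(phi_K.transpose(-1, -2), V)
    z = phi_K.sum(dim=1)
    numerator = torch.matmul(phi_Q, S)
    denominator = torch.matmul(phi_Q, z.unsqueeze(-1))
    return numerator / denominator.clamp(min=1e-6)

torch.manual_seed(0)
B, N, d = 2, 128, 32
Q, K, V = torch.randn(3, B, N, d)
phi = lambda x: torch.relu(x) ** 2 + 1e-6  # stand-in feature map

fast = linear_attention(Q, K, V, phi)

# Explicit O(N^2) reference: A_ij = phi(q_i) . phi(k_j), row-normalized
A = torch.matmul(phi(Q), phi(K).transpose(-1, -2))  # [B, N, N]
slow = torch.matmul(A, V) / A.sum(-1, keepdim=True).clamp(min=1e-6)

assert torch.allclose(fast, slow, atol=1e-4)
```

The two paths compute the same quantity with different association, so they should agree to floating-point precision while the fast path never forms the $N \times N$ matrix.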