The transformer architecture from “Attention Is All You Need” (2017) is unrecognizable in 2026. Every component has been replaced, optimized, or augmented. Yet the core computation — residual stream plus attention plus feedforward, repeated N times — remains. This post catalogs what has settled into consensus, what is actively contested, and what might change next.
This is a snapshot, not a prediction. The field moves fast. But certain architectural choices have converged across labs, and certain research directions have enough momentum that their trajectory is predictable.
The Settled Stack
These components appear in nearly every frontier model deployed in 2025-2026. They are no longer research questions — they are engineering defaults.
1.1 RMSNorm Over LayerNorm
Every major model (Llama 3, Mistral, Qwen 2.5, DeepSeek V3, Gemma 2) uses RMSNorm instead of LayerNorm. The mean subtraction in LayerNorm is unnecessary for transformer performance and adds computational cost.
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """The standard normalization in 2026 transformers."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # No mean subtraction -- just normalize by RMS
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

class LayerNorm(nn.Module):
    """The 2017-2022 default. Now obsolete for LLMs."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        # Mean subtraction + variance normalization
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias
Why RMSNorm won: (1) 20-30% faster than LayerNorm due to no mean computation, (2) no bias parameter needed, (3) empirically equivalent quality in all tested settings. The mean subtraction in LayerNorm was a legacy from batch normalization that transformers never needed.
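That the mean subtraction is dispensable can be checked directly: on zero-mean features the two normalizations coincide. A minimal sketch using unscaled norms (no learned parameters), with hypothetical helper names:

```python
import torch

def rms_norm(x, eps=1e-6):
    # RMSNorm without a learned scale: divide by root-mean-square
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def layer_norm(x, eps=1e-6):
    # LayerNorm without affine params: subtract mean, divide by std
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps)

x = torch.randn(2, 8, 64)
x_centered = x - x.mean(dim=-1, keepdim=True)
# On zero-mean features, RMS equals standard deviation,
# so the two normalizations agree
same = torch.allclose(rms_norm(x_centered), layer_norm(x_centered), atol=1e-5)
```

In practice transformer activations are not exactly zero-mean, but empirically the centering step buys nothing, which is why dropping it costs no quality.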
1.2 SwiGLU Over Standard FFN
The feedforward network in every modern transformer uses SwiGLU (or a variant like GeGLU):
class SwiGLU(nn.Module):
    """down_proj(SiLU(gate_proj(x)) * up_proj(x)) -- the 2026 standard."""

    def __init__(self, d_model, d_ff=None):
        super().__init__()
        if d_ff is None:
            # Llama convention: d_ff = 8/3 * d_model, rounded up to a multiple of 256
            d_ff = int(8 / 3 * d_model)
            d_ff = ((d_ff + 255) // 256) * 256
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down_proj(
            nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
        )

class StandardFFN(nn.Module):
    """The 2017 original. No longer used in frontier models."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))
SwiGLU has three weight matrices where the standard FFN has two, so at the same d_ff it would carry 50% more parameters, but empirical scaling laws show it provides better loss per FLOP. The d_ff = (8/3) * d_model convention compensates: 3 * (8/3) * d_model^2 = 8 * d_model^2 weights per FFN, matching the standard FFN with d_ff = 4 * d_model (2 * 4 * d_model^2 = 8 * d_model^2 total weights).
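A quick check of that parameter arithmetic (ignoring the round-to-256 adjustment used in practice):

```python
d_model = 4096

# Standard FFN: two matrices with d_ff = 4 * d_model
std_params = 2 * d_model * (4 * d_model)

# SwiGLU: three matrices with d_ff = 8/3 * d_model
swiglu_d_ff = int(8 / 3 * d_model)
swiglu_params = 3 * d_model * swiglu_d_ff

# Both come out to roughly 8 * d_model^2
ratio = swiglu_params / std_params
```

The two counts agree to within a fraction of a percent, so swapping in SwiGLU changes the quality per parameter without changing the parameter budget.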
1.3 RoPE for Positional Encoding
Rotary Position Embeddings (RoPE) replaced all alternatives: learned absolute, sinusoidal, ALiBi. Every major model uses RoPE.
def precompute_rope_freqs(dim, max_seq_len, base=10000.0):
    """Precompute the complex exponentials for RoPE.
    Each pair of dimensions rotates at a different frequency,
    determined by the geometric sequence base^(-2i/dim).
    """
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    # Outer product: [seq_len, dim/2]
    angles = torch.outer(positions, freqs)
    # Complex representation: cos(theta) + i*sin(theta)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rope(q, k, rope_freqs):
    """Apply rotary embeddings to query and key tensors.
    q, k: [B, n_heads, S, head_dim]
    rope_freqs: [S, head_dim/2] complex tensor
    """
    # Reshape to pairs: [B, n_heads, S, head_dim/2, 2]
    q_pairs = q.float().reshape(*q.shape[:-1], -1, 2)
    k_pairs = k.float().reshape(*k.shape[:-1], -1, 2)
    # Convert to complex
    q_complex = torch.view_as_complex(q_pairs)
    k_complex = torch.view_as_complex(k_pairs)
    # Rotate by element-wise multiplication with the complex exponentials
    freqs = rope_freqs.unsqueeze(0).unsqueeze(0)  # [1, 1, S, head_dim/2]
    q_rotated = torch.view_as_real(q_complex * freqs).flatten(-2)
    k_rotated = torch.view_as_real(k_complex * freqs).flatten(-2)
    return q_rotated.type_as(q), k_rotated.type_as(k)
RoPE won because: (1) it encodes relative position through the dot product (the rotation angle between two positions depends only on their distance), (2) it extrapolates to longer sequences than seen in training (with NTK-aware scaling), (3) it adds zero parameters, (4) it is compatible with KV caching.
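The relative-position property in point (1) is easy to verify numerically for a single frequency pair: rotate two 2-D vectors by angles proportional to their positions and check that their dot product depends only on the offset. The angle step (0.1) and test vectors below are arbitrary illustrative choices:

```python
import torch

def rotate(x, pos, theta=0.1):
    # Rotate a 2-D vector by pos * theta (one RoPE frequency pair)
    angle = torch.tensor(pos * theta)
    c, s = torch.cos(angle), torch.sin(angle)
    rot = torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
    return rot @ x

q = torch.tensor([1.0, 2.0])
k = torch.tensor([0.5, -1.0])

# Same positional offset (3), different absolute positions:
# the attention score is identical in both cases
score_a = rotate(q, 10) @ rotate(k, 7)
score_b = rotate(q, 103) @ rotate(k, 100)
```

This follows from R(a)^T R(b) = R(b - a) for rotation matrices: the query-key dot product sees only the rotation by the positional difference.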
1.4 BPE Tokenization
Byte-Pair Encoding with a vocabulary of 100K-200K tokens is the universal standard. GPT-4 uses cl100k (100K tokens), Llama 3 uses a 128K vocabulary, Qwen 2.5 uses 152K. The tokenizer is trained on multilingual data with byte-level fallback.
def bpe_characteristics_2026():
    """Summary of modern BPE tokenizer properties."""
    return {
        "vocab_size": "100K-200K tokens",
        "algorithm": "Byte-level BPE (SentencePiece or tiktoken)",
        "byte_fallback": True,  # Any byte sequence is representable
        "special_tokens": [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eot_id|>",
        ],
        "compression_ratio": "3.5-4.5 characters per token (English)",
        "multilingual": True,
        "training_data": "Trillions of tokens, balanced across languages",
    }
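The merge-learning loop at the heart of BPE fits in a few lines. This is a toy character-level sketch on the classic word-frequency example; production tokenizers work on bytes, use regex pre-tokenization, and have specified tie-breaking rules:

```python
from collections import Counter

def merge_word(word, pair, merged):
    """Replace every occurrence of `pair` in a symbol tuple."""
    out, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
            out.append(merged)
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)

def learn_bpe_merges(word_counts, n_merges):
    """Learn merge rules from word frequencies (toy, character-level)."""
    vocab = {tuple(w): c for w, c in word_counts.items()}
    merges = []
    for _ in range(n_merges):
        # Count all adjacent symbol pairs, weighted by word frequency
        pairs = Counter()
        for word, count in vocab.items():
            for a, b in zip(word, word[1:]):
                pairs[(a, b)] += count
        if not pairs:
            break
        # Greedily merge the most frequent pair everywhere
        best = max(pairs, key=pairs.get)
        merges.append(best)
        vocab = {merge_word(w, best, best[0] + best[1]): c
                 for w, c in vocab.items()}
    return merges

merges = learn_bpe_merges({"low": 5, "lower": 2, "newest": 6, "widest": 3}, 3)
```

At 100K+ merges trained on trillions of tokens, this greedy procedure yields the compression ratios quoted above.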
The settled stack as of early 2026: Pre-RMSNorm (before attention and FFN), SwiGLU feedforward, RoPE positional encoding, BPE tokenization with 100K+ vocabulary, GQA (Grouped Query Attention) for inference efficiency, and no bias terms in linear layers. If you are building a new model from scratch, use exactly this stack.
Active Frontiers
These techniques are deployed in at least one major model but are not universal. The field is still evaluating their tradeoffs.
2.1 Multi-head Latent Attention (MLA)
DeepSeek V2/V3 introduced MLA, which compresses the KV cache by projecting keys and values into a low-rank latent space before caching:
class MultiHeadLatentAttention(nn.Module):
    """MLA: compress the KV cache via low-rank projection.
    Standard GQA caches: n_kv_heads * head_dim * 2 * seq_len per layer.
    MLA caches: latent_dim * seq_len per layer (shared across heads).
    DeepSeek V3 uses latent_dim=512, far less than the
    2 * n_kv_heads * head_dim a comparable GQA cache requires.
    """

    def __init__(self, d_model, n_heads, latent_dim, head_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.latent_dim = latent_dim
        # Compress the input to a latent representation (this is what gets cached)
        self.down_proj = nn.Linear(d_model, latent_dim, bias=False)
        # Expand the latent to per-head K and V (computed on the fly)
        self.k_proj = nn.Linear(latent_dim, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(latent_dim, n_heads * head_dim, bias=False)
        # Standard Q projection
        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False)
        self.scale = head_dim ** -0.5

    def forward(self, x, cached_latent=None):
        B, S, _ = x.shape
        # Compress to latent (the cached_latent plumbing for incremental
        # decoding is omitted in this sketch)
        latent = self.down_proj(x)  # [B, S, latent_dim]
        # Expand to K, V
        k = self.k_proj(latent).reshape(B, S, self.n_heads, self.head_dim)
        v = self.v_proj(latent).reshape(B, S, self.n_heads, self.head_dim)
        q = self.q_proj(x).reshape(B, S, self.n_heads, self.head_dim)
        # Transpose for attention: [B, n_heads, S, head_dim]
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        # Standard attention (causal mask omitted for brevity)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v
        out = out.transpose(1, 2).reshape(B, S, -1)
        return self.o_proj(out), latent  # Return the latent for caching
KV cache comparison for a 67B-class model at sequence length 128K (FP16, head_dim 128). Standard MHA with 64 heads: 4 GB per layer. GQA with 8 KV heads: 512 MB per layer. MLA with latent_dim=512: 128 MB per layer, a 4x reduction relative to 8-head GQA.
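The arithmetic behind a comparison like this, assuming FP16 (2 bytes per value), 64 query heads, and head_dim 128 (illustrative values, not DeepSeek's exact configuration):

```python
seq_len = 128 * 1024
head_dim = 128
bytes_fp16 = 2

def kv_cache_bytes(dims_per_token):
    """Per-layer cache size for a given number of cached dims per token."""
    return dims_per_token * seq_len * bytes_fp16

mha = kv_cache_bytes(2 * 64 * head_dim)  # 64 KV heads, K and V
gqa = kv_cache_bytes(2 * 8 * head_dim)   # 8 KV heads, K and V
mla = kv_cache_bytes(512)                # one shared latent vector

mha_gb = mha / 2 ** 30   # per-layer cache in GiB
gqa_mb = gqa / 2 ** 20   # per-layer cache in MiB
mla_mb = mla / 2 ** 20
```

The reduction compounds across layers: at 60+ layers the difference between GQA and MLA is tens of gigabytes of HBM per long-context request.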
2.2 Linear Attention and State Space Models
The quadratic cost of attention (O(n^2) in sequence length) motivates subquadratic alternatives. Linear attention replaces the softmax with a kernel feature map that allows the attention computation to be factored:
class LinearAttention(nn.Module):
    """Linear attention: O(n * d^2) instead of O(n^2 * d).
    Replaces softmax(QK^T) with phi(Q) phi(K)^T, where phi is a
    feature map. The key insight: phi(Q)(phi(K)^T V) can be computed
    left-to-right in O(n * d^2) by maintaining a running sum.
    """

    def __init__(self, d_model, n_heads, head_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False)

    def feature_map(self, x):
        """ELU + 1 feature map (from Katharopoulos et al.)"""
        return nn.functional.elu(x) + 1

    def forward(self, x):
        B, S, _ = x.shape
        q = self.q_proj(x).reshape(B, S, self.n_heads, self.head_dim)
        k = self.k_proj(x).reshape(B, S, self.n_heads, self.head_dim)
        v = self.v_proj(x).reshape(B, S, self.n_heads, self.head_dim)
        # Apply the feature map
        q = self.feature_map(q)  # [B, S, H, D]
        k = self.feature_map(k)  # [B, S, H, D]
        # Causal linear attention via cumulative sum
        # state = cumsum(k^T v) -- one [D, D] "state" matrix per position
        # output_t = q_t @ state_t / (q_t . cumsum(k)_t)
        kv = torch.einsum("bshd,bshe->bshde", k, v)  # [B, S, H, D, D]
        state = kv.cumsum(dim=1)  # Running sum of outer products
        # Numerator: q @ state
        num = torch.einsum("bshd,bshde->bshe", q, state)  # [B, S, H, D]
        # Denominator: q . cumsum(k) for normalization
        k_cumsum = k.cumsum(dim=1)
        den = torch.einsum("bshd,bshd->bsh", q, k_cumsum).unsqueeze(-1)
        out = num / (den + 1e-6)
        out = out.reshape(B, S, -1)
        return self.o_proj(out)
The problem with linear attention: the state matrix must capture all historical context, which limits recall ability. Recent work (Based, GLA, RWKV-6, Mamba-2) addresses this through data-dependent state transitions and selective forgetting.
2.3 Mixture of Experts (MoE)
MoE replaces the dense FFN with a set of expert FFNs and a router that selects experts per token:
class MoELayer(nn.Module):
    """Mixture of Experts: route each token to its top-k experts.
    Mixtral: 8 experts, top-2
    DeepSeek V3: 256 experts, top-8
    """

    def __init__(self, d_model, d_ff, n_experts=8, top_k=2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k
        # Router: produces expert selection probabilities
        self.router = nn.Linear(d_model, n_experts, bias=False)
        # Expert FFNs (each is a SwiGLU)
        self.experts = nn.ModuleList([
            SwiGLU(d_model, d_ff) for _ in range(n_experts)
        ])

    def forward(self, x):
        B, S, D = x.shape
        x_flat = x.reshape(-1, D)  # [B*S, D]
        # Route
        router_logits = self.router(x_flat)  # [B*S, n_experts]
        router_probs = router_logits.softmax(dim=-1)
        # Select top-k experts per token
        top_k_probs, top_k_indices = router_probs.topk(self.top_k, dim=-1)
        # Normalize the selected expert weights
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Compute expert outputs (simplified -- real impls use grouped GEMM)
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, i]   # [B*S]
            expert_weight = top_k_probs[:, i]  # [B*S]
            for e in range(self.n_experts):
                mask = expert_idx == e
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += expert_weight[mask].unsqueeze(-1) * expert_output
        return output.reshape(B, S, D)
MoE is now proven at scale: Mixtral 8x7B (2023), DeepSeek V3 671B (2024), and several undisclosed models. The key tradeoff: MoE models have more total parameters (higher memory) but activate fewer per token (lower FLOPs). A 671B MoE model with top-8 of 256 experts activates roughly 37B parameters per token — comparable FLOPs to a 37B dense model but with the representational capacity of a much larger network.
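The memory-versus-compute tradeoff can be made concrete with a toy parameter count for a Mixtral-style layer. Dimensions are illustrative, and real implementations also carry router weights (and, in DeepSeek's case, shared experts):

```python
def moe_layer_stats(d_model, d_ff, n_experts, top_k):
    """Total vs. active FFN parameters for one MoE layer."""
    expert_params = 3 * d_model * d_ff   # one SwiGLU expert (3 matrices)
    total = n_experts * expert_params    # held in GPU memory
    active = top_k * expert_params       # actually used per token
    return total, active

# Mixtral-style layer: 8 experts, top-2 routing
total, active = moe_layer_stats(4096, 14336, n_experts=8, top_k=2)
memory_to_compute = total / active  # 8 experts / top-2
```

With 8 experts and top-2 routing, every token pays the FLOPs of two experts while the deployment pays the memory of eight, which is exactly the "more capacity, same compute" bargain described above.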
2.4 Multi-Token Prediction (MTP)
Instead of predicting one next token, predict K tokens simultaneously. This provides a denser training signal and enables self-speculative decoding:
class MTPHead(nn.Module):
    """Multi-Token Prediction: predict K future tokens.
    DeepSeek V3 uses K=2 (next token + one ahead).
    During inference, the extra prediction enables speculative decoding
    without a separate draft model.
    """

    def __init__(self, d_model, vocab_size, K=2):
        super().__init__()
        self.K = K
        # Each additional prediction horizon gets a projection that
        # transforms the hidden state for that horizon
        self.transform = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model, bias=False),
                RMSNorm(d_model),
            ) for _ in range(K - 1)  # The first token uses the raw hidden state
        ])
        # Shared output embedding (tied with the input embedding)
        self.output = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, hidden_states):
        """
        hidden_states: [B, S, d_model]
        Returns: list of K logit tensors, each [B, S, vocab_size]
        """
        logits = [self.output(hidden_states)]  # Token t+1
        h = hidden_states
        for k in range(self.K - 1):
            h = self.transform[k](h)
            logits.append(self.output(h))  # Token t+k+2
        return logits

    def mtp_loss(self, logits_list, target_ids):
        """
        logits_list: K tensors of [B, S, V]
        target_ids: [B, S] ground truth token IDs
        """
        total_loss = 0
        for k, logits in enumerate(logits_list):
            # Shift targets: head k predicts the token k+1 positions ahead
            shift = k + 1
            shifted_logits = logits[:, :-shift, :]
            shifted_targets = target_ids[:, shift:]
            loss = nn.functional.cross_entropy(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_targets.reshape(-1),
            )
            # Weight future predictions less
            weight = 1.0 / (k + 1)
            total_loss += weight * loss
        return total_loss
MTP’s primary value during training is providing a richer gradient signal — predicting future tokens forces the model to build better internal representations. The speculative decoding benefit at inference is a bonus. DeepSeek reports MTP improves training efficiency by 10-15% (fewer tokens needed to reach the same loss).
Emerging Directions
These are research-stage techniques with strong results but limited production deployment. They represent the likely next wave.
3.1 Mamba-Transformer Hybrids
Pure Mamba (a selective state space model) struggles with in-context retrieval tasks that require precise token-to-token attention. Pure transformers struggle with very long sequences due to the O(n^2) cost of attention. Hybrids interleave the two:
class MambaTransformerBlock(nn.Module):
    """Hybrid block: alternate Mamba and attention layers.
    Jamba (AI21, 2024) pattern: 7 Mamba layers per 1 attention layer.
    The attention layers handle precise retrieval.
    The Mamba layers handle long-range context compression.
    """

    def __init__(self, d_model, layer_type="mamba"):
        super().__init__()
        self.layer_type = layer_type
        self.norm = RMSNorm(d_model)
        if layer_type == "mamba":
            self.core = MambaBlock(d_model)      # O(n) per layer
        elif layer_type == "attention":
            self.core = AttentionBlock(d_model)  # O(n^2) per layer
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = SwiGLU(d_model)

    def forward(self, x, **kwargs):
        x = x + self.core(self.norm(x), **kwargs)
        x = x + self.ffn(self.ffn_norm(x))
        return x

class MambaBlock(nn.Module):
    """Simplified Mamba-2 block (selective state space model).
    Core idea: the effective state transition is data-dependent
    (via the input-conditioned step size dt), allowing the model
    to selectively remember or forget.
    """

    def __init__(self, d_model, d_state=64, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand
        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, kernel_size=d_conv,
            padding=d_conv - 1, groups=d_inner
        )
        # SSM parameters (data-dependent)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.B_proj = nn.Linear(d_inner, d_state, bias=False)
        self.C_proj = nn.Linear(d_inner, d_state, bias=False)
        # Fixed A: log-space parameterization for stability
        A = torch.arange(1, d_state + 1).float()
        self.A_log = nn.Parameter(torch.log(A).unsqueeze(0).repeat(d_inner, 1))
        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        B, S, D = x.shape
        # Project and split into the main path and a gate
        xz = self.in_proj(x)             # [B, S, 2*d_inner]
        x_main, z = xz.chunk(2, dim=-1)  # Each [B, S, d_inner]
        # Depthwise conv1d on the main path (causal: trim the right padding)
        x_conv = self.conv1d(x_main.transpose(1, 2))[:, :, :S].transpose(1, 2)
        x_conv = nn.functional.silu(x_conv)
        # Compute data-dependent SSM parameters
        dt = nn.functional.softplus(self.dt_proj(x_conv))  # [B, S, d_inner]
        B_param = self.B_proj(x_conv)                      # [B, S, d_state]
        C_param = self.C_proj(x_conv)                      # [B, S, d_state]
        A = -torch.exp(self.A_log)                         # [d_inner, d_state]
        # Selective scan (simplified -- real impls use a fused CUDA kernel)
        y = selective_scan(x_conv, dt, A, B_param, C_param, self.D)
        # Gate and project out
        y = y * nn.functional.silu(z)
        return self.out_proj(y)
def selective_scan(x, dt, A, B, C, D):
    """Selective scan: the core Mamba operation.
    Processes the sequence left-to-right, maintaining a hidden state
    that is selectively updated based on the input.
    x:  [B, S, d_inner]
    dt: [B, S, d_inner] -- discretization step (data-dependent)
    A:  [d_inner, d_state] -- state transition (log-space)
    B:  [B, S, d_state] -- input projection (data-dependent)
    C:  [B, S, d_state] -- output projection (data-dependent)
    D:  [d_inner] -- skip connection
    """
    B_batch, S, d_inner = x.shape
    d_state = A.shape[1]
    # Discretize: A_bar = exp(dt * A), B_bar = dt * B
    dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
    dB = dt.unsqueeze(-1) * B.unsqueeze(2)  # [B, S, d_inner, d_state]
    # Sequential scan
    h = torch.zeros(B_batch, d_inner, d_state, device=x.device)
    outputs = []
    for t in range(S):
        h = dA[:, t] * h + dB[:, t] * x[:, t].unsqueeze(-1)
        y_t = (h * C[:, t].unsqueeze(1)).sum(dim=-1)  # [B, d_inner]
        outputs.append(y_t)
    y = torch.stack(outputs, dim=1)  # [B, S, d_inner]
    y = y + x * D.unsqueeze(0).unsqueeze(0)  # Skip connection
    return y
Empirical results from Jamba and follow-up work show that hybrids with a 7:1 or 4:1 ratio of Mamba-to-attention layers match pure transformer quality while reducing KV cache by 80%+ and improving throughput on long sequences by 2-3x.
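The KV-cache figure follows directly from the layer mix: Mamba layers keep a constant-size recurrent state rather than a per-token cache, so only the attention layers contribute. A quick sketch for a 32-layer stack:

```python
n_layers = 32

def kv_cache_reduction(attn_every):
    """Fraction of KV cache saved when only 1 in `attn_every` layers
    is attention (Mamba layers keep no per-token KV cache)."""
    attn_layers = n_layers // attn_every
    return 1 - attn_layers / n_layers

# Jamba-style 7:1 ratio -> 1 attention layer per 8 layers
savings_7_to_1 = kv_cache_reduction(8)
# 3:1 ratio -> 1 attention layer per 4 layers
savings_3_to_1 = kv_cache_reduction(4)
```

A 7:1 mix eliminates 87.5% of the cache, and even a 3:1 mix clears the 75% mark, consistent with the "80%+" figure reported for Jamba-style ratios.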
3.2 Mixture of Depths (MoD)
Not every token needs every layer. MoD adds a router at each layer that decides whether a token should be processed by the layer or skip it via the residual connection:
class MoDLayer(nn.Module):
    """Mixture of Depths: skip layers for easy tokens.
    capacity_ratio=0.5 means only 50% of tokens are processed.
    The other 50% pass through via the residual connection.
    """

    def __init__(self, d_model, capacity_ratio=0.5):
        super().__init__()
        self.router = nn.Linear(d_model, 1, bias=True)
        self.capacity_ratio = capacity_ratio
        self.attention = AttentionBlock(d_model)
        self.ffn = SwiGLU(d_model)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x, training=True):
        B, S, D = x.shape
        scores = self.router(x).squeeze(-1)  # [B, S]
        # Select top-k tokens (k = capacity_ratio * S)
        k = int(S * self.capacity_ratio)
        _, top_indices = scores.topk(k, dim=1)
        # Gather the selected tokens
        selected = torch.gather(
            x, 1, top_indices.unsqueeze(-1).expand(-1, -1, D)
        )
        # Attention + FFN on the selected tokens only
        selected = selected + self.attention(self.norm1(selected))
        selected = selected + self.ffn(self.norm2(selected))
        # Scatter the processed tokens back; the rest pass through unchanged
        output = x.clone()
        output.scatter_(1, top_indices.unsqueeze(-1).expand(-1, -1, D), selected)
        return output
Mixture of Depths Impact (32-layer Transformer)
| Configuration | Relative FLOPs | Perplexity | FLOP Savings |
|---|---|---|---|
| Standard (capacity=1.0) | 1.00x | 8.42 | baseline |
| MoD (capacity=0.75) | 0.78x | 8.48 | 22% |
| MoD (capacity=0.50) | 0.56x | 8.71 | 44% |
| MoD (capacity=0.25) | 0.38x | 9.34 | 62% |
3.3 Test-Time Compute Scaling
A major shift in 2024-2025: instead of only scaling training compute, scale inference compute. Models like o1, o3, DeepSeek R1, and QwQ allocate variable compute per query based on difficulty. The key mechanism is chain-of-thought sampling with verification:
def test_time_scaling(model, tokenizer, prompt,
                      max_attempts=16, verifier=None):
    """Scale test-time compute by sampling multiple solutions
    and selecting the best one.
    Budget allocation:
    - Easy questions: 1 sample, short chain-of-thought
    - Hard questions: 16+ samples, long chain-of-thought, verification
    """
    solutions = []
    for attempt in range(max_attempts):
        # Sample with temperature > 0 for diversity
        output = model.generate(
            tokenizer.encode(prompt, return_tensors="pt").cuda(),
            max_new_tokens=4096,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )
        solution = tokenizer.decode(output[0])
        solutions.append(solution)
        # Early stopping: if the verifier is confident, stop sampling
        if verifier is not None:
            confidence = verifier.score(prompt, solution)
            if confidence > 0.95:
                return solution
    # Select the best solution via verifier score or majority voting
    if verifier is not None:
        scores = [verifier.score(prompt, s) for s in solutions]
        return solutions[scores.index(max(scores))]
    else:
        # Majority voting on the final answer
        return majority_vote(solutions)

def majority_vote(solutions):
    """Extract final answers and return the most common one."""
    from collections import Counter
    # extract_final_answer is a task-specific parser (e.g. a regex
    # for a boxed answer), not defined here
    answers = [extract_final_answer(s) for s in solutions]
    counter = Counter(answers)
    return solutions[answers.index(counter.most_common(1)[0][0])]
This is not an architectural change to the transformer itself, but it changes how transformers are used. The model architecture must support long-form reasoning (extended context, reliable chain-of-thought), and the training pipeline must include reinforcement learning for reasoning quality.
3.4 Native Long Context
Context windows have expanded from 2K (GPT-2) to 128K (Llama 3.1) to 1M+ (Gemini). The architectural requirements:
- RoPE scaling for position extrapolation:
def ntk_aware_rope(dim, max_seq_len, base=10000.0,
                   training_length=8192, alpha=None):
    """NTK-aware RoPE scaling for context extension.
    Adjusts the base frequency to avoid concentrated attention
    patterns at positions beyond the training length.
    """
    if alpha is None:
        # Dynamic alpha based on the sequence length ratio
        alpha = (max_seq_len / training_length) - 1
        alpha = max(alpha, 0) + 1
    # Modify the base to spread the frequencies
    adjusted_base = base * alpha ** (dim / (dim - 2))
    freqs = 1.0 / (adjusted_base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)
    return torch.polar(torch.ones_like(angles), angles)
- FlashAttention for memory efficiency (the full n x n attention matrix is never materialized)
- Ring attention for distributing long sequences across GPUs:
def ring_attention_concept(q, k, v, ring_size):
    """Ring attention: distribute the sequence across GPUs in a ring.
    Each GPU holds a chunk of the sequence. KV pairs are passed
    around the ring so every chunk attends to every other chunk.
    q, k, v: [B, n_heads, chunk_len, head_dim] -- this GPU's chunk
    Memory per GPU: O(n/P * d) instead of O(n * d)
    Communication: P-1 rounds of KV transfer
    """
    local_q = q        # This GPU's queries (stay put)
    local_kv = (k, v)  # Start with the local KV chunk
    running_max = torch.full(
        (q.shape[0], q.shape[1], q.shape[2], 1),
        float("-inf"), device=q.device,
    )
    running_sum = torch.zeros_like(q)
    running_denom = torch.zeros(
        q.shape[0], q.shape[1], q.shape[2], 1, device=q.device
    )
    for step in range(ring_size):
        k_chunk, v_chunk = local_kv
        # Compute attention scores for this chunk
        scores = (local_q @ k_chunk.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        # (Apply the causal mask here if needed)
        chunk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, chunk_max)
        # Online softmax update (FlashAttention-style rescaling)
        exp_scores = torch.exp(scores - new_max)
        exp_old = torch.exp(running_max - new_max)
        running_sum = exp_old * running_sum + exp_scores @ v_chunk
        running_denom = exp_old * running_denom + exp_scores.sum(dim=-1, keepdim=True)
        running_max = new_max
        # Send KV to the next GPU in the ring, receive from the previous
        # local_kv = ring_send_recv(local_kv)
    return running_sum / running_denom
What Did Not Work
Several hyped approaches from 2023-2024 have not achieved broad adoption:
4.1 Pure State Space Models
RWKV, Mamba (as standalone architectures), and RetNet showed promise for linear-time sequence processing. But pure SSMs consistently underperform transformers on in-context learning benchmarks, especially few-shot tasks requiring precise attention to specific input tokens. The consensus: SSMs are excellent compression mechanisms but poor retrieval mechanisms. Hybrids (Section 3.1) are the path forward.
4.2 ALiBi Position Encoding
ALiBi (Attention with Linear Biases) was used in MPT and some BLOOM variants. It adds a linear bias to attention scores based on position distance. RoPE proved superior for context extension, and ALiBi has been abandoned by all major labs.
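For reference, the mechanism ALiBi used is simple enough to state in a few lines. This sketch follows the paper's head-slope recipe (2^(-8h/n_heads)) under the usual causal layout:

```python
import torch

def alibi_bias(n_heads, seq_len):
    """ALiBi: a per-head linear penalty on attention by distance."""
    # Head slopes: a geometric sequence, as in the ALiBi paper
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads)
                           for h in range(n_heads)])
    # distance[i, j] = i - j for causal positions (j <= i)
    pos = torch.arange(seq_len)
    distance = (pos[:, None] - pos[None, :]).clamp(min=0).float()
    # Added to attention scores before softmax: [n_heads, S, S]
    return -slopes[:, None, None] * distance

bias = alibi_bias(n_heads=8, seq_len=4)
```

The bias requires no parameters and no position embeddings at all, which made it attractive; it simply never matched RoPE-plus-scaling on long-context extension.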
4.3 Sparse Mixture of Experts with Few Experts
Mixtral’s 8-expert design was influential but DeepSeek V3 showed that many more experts (256) with finer granularity gives better routing and quality. The trend is toward more experts, not fewer.
4.4 Retrieval-Augmented Generation (as Architecture)
RAG was proposed as a way to extend context by retrieving relevant documents. It works well as a system design, but attempts to build retrieval into the architecture (RETRO, Atlas) have not been adopted. The winning approach is large native context windows plus external retrieval at the system level, not the architecture level.
Quantitative Architecture Comparison
(Figure: loss versus relative training FLOPs, from 0.25x to 2.0x, for 7B-scale models built with four recipes: the dense 2026-stack transformer, an 8-expert MoE with 2x total parameters, a Mamba-Transformer hybrid, and the dense 2022-stack transformer.)
The 2026 stack (RMSNorm + SwiGLU + RoPE + GQA) gives roughly 8-10% better loss per FLOP than the 2022 stack (LayerNorm + ReLU FFN + learned absolute positions + MHA). MoE gives another 5-8% at the cost of higher memory. Hybrids are competitive with the dense transformer and win at long context lengths.
The Infrastructure Stack
Architecture choices do not exist in isolation. They are constrained by and co-evolved with the infrastructure stack:
6.1 Training Infrastructure
def training_stack_2026():
    """The standard training infrastructure."""
    return {
        "hardware": "NVIDIA H100/H200 or AMD MI300X clusters",
        "interconnect": "NVLink + InfiniBand (400 Gbps+)",
        "parallelism": {
            "data": "FSDP (ZeRO-3) or PyTorch FSDP2",
            "tensor": "Megatron-style column/row parallel",
            "pipeline": "1F1B or interleaved schedule",
            "context": "Ring attention or Ulysses for long sequences",
            "expert": "Expert parallelism for MoE (all-to-all comm)",
        },
        "precision": "BF16 forward/backward, FP32 master weights",
        "optimizer": "AdamW (beta1=0.9, beta2=0.95, eps=1e-8)",
        "framework": "PyTorch 2.x with torch.compile",
        "checkpointing": "Activation checkpointing for memory, async I/O",
    }
6.2 Inference Infrastructure
def inference_stack_2026():
    """The standard inference infrastructure."""
    return {
        "serving": "vLLM, TensorRT-LLM, or SGLang",
        "batching": "Continuous batching with PagedAttention",
        "kv_cache": "Paged, with prefix caching for shared prompts",
        "quantization": "W4A16 (GPTQ/AWQ) or W8A8 (SmoothQuant)",
        "speculative_decoding": "Self-speculative (MTP) or draft model",
        "hardware": "H100 SXM (80GB HBM3) or multi-GPU with NVLink",
        "memory_management": "PagedAttention with block size 16",
    }
6.3 Key Metric: Tokens Per Dollar
The ultimate metric for architecture evaluation is tokens per dollar — combining training cost, inference cost, and quality:
def tokens_per_dollar(flops_per_token, gpu_tflops, gpu_cost_per_hour,
                      utilization=0.4):
    """Estimate inference cost in tokens per dollar.
    Example: Llama 70B on H100
    - flops_per_token: 2 * 70e9 = 140 GFLOP
    - gpu_tflops: 990 TFLOP/s (BF16 tensor core)
    - utilization: ~40% for autoregressive decoding (memory-bound)
    - effective throughput: 396 TFLOP/s
    - tokens/sec: 396e12 / 140e9 = 2828 tokens/sec
    - gpu_cost: $2.50/hr (cloud spot)
    - tokens/dollar: 2828 * 3600 / 2.50 = 4.07M tokens/$
    """
    effective_tflops = gpu_tflops * utilization
    tokens_per_sec = (effective_tflops * 1e12) / flops_per_token
    tokens_per_hour = tokens_per_sec * 3600
    return tokens_per_hour / gpu_cost_per_hour
Predictions for 2027
Based on current trajectories, high-confidence predictions (greater than 70% probability):
7.1 MoE Becomes Default
By late 2027, most frontier models will use MoE. The evidence: DeepSeek V3 demonstrated that MoE at scale works. The economic argument is overwhelming — MoE gives better quality per inference FLOP, and inference cost dominates total cost of ownership for deployed models. The remaining challenge is efficient expert parallelism during training, which is being solved by better collective communication libraries.
7.2 Native Context Hits 10M+ Tokens
Gemini already claims 1M+ context. Ring attention and hybrid architectures will push this to 10M+ by 2027. The constraint shifts from architecture (solved by FlashAttention + ring attention + hybrids) to training data (few documents are 10M tokens long) and evaluation (no good benchmarks for ultra-long context).
7.3 Test-Time Compute Scaling Matures
The o1/o3/R1 paradigm of scaling inference compute will become standard for all hard reasoning tasks. Architectural support will include built-in verification, backtracking, and budget allocation. Models will learn to allocate their own compute budget per query.
7.4 Quantization Moves to Training
Current practice: train in BF16, quantize post-training. By 2027, training in FP8 or FP4 will be standard on Blackwell and successor hardware. This halves training cost without post-training quantization artifacts.
7.5 Medium-Confidence Predictions
- Hybrid SSM-Transformer: at least one frontier model will use a Mamba-Transformer hybrid for production inference (probability: 60%).
- MLA adoption beyond DeepSeek: at least two other labs will ship models with MLA-style KV cache compression (probability: 55%).
- Differentiable architecture search at scale: automated methods for finding optimal layer-type schedules (which layers are attention, which are SSM, which use MoE) will produce competitive models (probability: 40%).
7.6 What Stays the Same
Some things are unlikely to change by 2027:
- The residual stream as the backbone (near-certain to remain)
- RMSNorm and SwiGLU (no compelling alternatives)
- BPE tokenization (character/byte-level models remain less efficient)
- AdamW optimizer (or close variants like Adafactor)
- PyTorch as the training framework (JAX maintains a niche)
- Autoregressive left-to-right generation as the dominant paradigm
The transformer in 2026 is a mature architecture with well-understood design principles. The remaining degrees of freedom are in the attention mechanism (standard vs. MLA vs. hybrid), the FFN (dense vs. MoE), the position encoding parameters (base frequency, scaling method), and the training recipe (data mix, learning rate schedule, context length curriculum). The transformer will evolve, but it will not be replaced.
Building a 2026-Stack Model from Scratch
For reference, here is the complete configuration for a 7B-class model using the settled 2026 stack:
def model_config_2026_7b():
    """Complete configuration for a 2026-stack 7B model."""
    return {
        # Architecture
        "d_model": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,  # GQA: 4x compression
        "head_dim": 128,
        "d_ff": 14336,  # 8/3 * 4096, rounded
        "vocab_size": 128256,
        "max_seq_len": 131072,  # 128K context
        # Normalization
        "norm_type": "rmsnorm",
        "norm_eps": 1e-5,
        "norm_position": "pre",  # Pre-norm (before attention and FFN)
        # Attention
        "attention_type": "gqa",
        "rope_base": 500000.0,  # Extended for long context
        "rope_scaling": "ntk-aware",
        # FFN
        "ffn_type": "swiglu",
        "ffn_activation": "silu",
        "ffn_bias": False,
        # Other
        "tie_word_embeddings": False,
        "attention_bias": False,
        "mlp_bias": False,
        # Training
        "precision": "bf16",
        "optimizer": "adamw",
        "lr": 3e-4,
        "min_lr": 3e-5,
        "warmup_steps": 2000,
        "total_tokens": 15_000_000_000_000,  # 15T tokens
        "batch_size_tokens": 4_000_000,
        "weight_decay": 0.1,
        "grad_clip": 1.0,
    }
This configuration represents the consensus best practices as of early 2026. Individual labs may deviate on specific parameters, but the overall structure is remarkably consistent across Llama, Mistral, Qwen, Gemma, and other open model families.
References
- Vaswani et al. “Attention Is All You Need.” NeurIPS 2017.
- Zhang and Sennrich. “Root Mean Square Layer Normalization.” NeurIPS 2019.
- Shazeer. “GLU Variants Improve Transformer.” arXiv 2020.
- Su et al. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” Neurocomputing 2024.
- DeepSeek. “DeepSeek-V3 Technical Report.” arXiv 2024.
- AI21 Labs. “Jamba: A Hybrid Transformer-Mamba Language Model.” arXiv 2024.
- Gu and Dao. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces.” arXiv 2023.
- Raposo et al. “Mixture-of-Depths: Dynamically Allocating Compute in Transformer-Based Language Models.” arXiv 2024.