Part of the series Transformer Anatomy (35 of 36)

The transformer architecture from “Attention Is All You Need” (2017) is unrecognizable in 2026. Every component has been replaced, optimized, or augmented. Yet the core computation — residual stream plus attention plus feedforward, repeated L times — remains. This post catalogs what has settled into consensus, what is actively contested, and what might change next.

This is a snapshot, not a prediction. The field moves fast. But certain architectural choices have converged across labs, and certain research directions have enough momentum that their trajectory is predictable.

1. The Settled Stack

These components appear in nearly every frontier model deployed in 2025-2026. They are no longer research questions — they are engineering defaults.

1.1 RMSNorm Over LayerNorm

Every major model (Llama 3, Mistral, Qwen 2.5, DeepSeek V3, Gemma 2) uses RMSNorm instead of LayerNorm. The mean subtraction in LayerNorm is unnecessary for transformer performance and adds computational cost.

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """The standard normalization in 2026 transformers."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # No mean subtraction -- just normalize by RMS
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

class LayerNorm(nn.Module):
    """The 2017-2022 default. Now obsolete for LLMs."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.bias = nn.Parameter(torch.zeros(dim))
        self.eps = eps

    def forward(self, x):
        # Mean subtraction + variance normalization
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias

Why RMSNorm won: (1) 20-30% faster than LayerNorm due to no mean computation, (2) no bias parameter needed, (3) empirically equivalent quality in all tested settings. The mean subtraction in LayerNorm was a legacy from batch normalization that transformers never needed.
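One way to see that mean subtraction is the only difference: for zero-mean inputs the two normalizations coincide. A small sketch (learned weights and biases omitted; this is a check of the math, not the classes above):

```python
import torch

def rms_norm(x, eps=1e-6):
    # Normalize by root-mean-square only
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

def layer_norm(x, eps=1e-6):
    # Subtract the mean, then normalize by the standard deviation
    mu = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)

torch.manual_seed(0)
x = torch.randn(4, 64)
x = x - x.mean(-1, keepdim=True)  # force zero mean per vector

# With the mean removed, var == mean(x^2), so the two outputs match
assert torch.allclose(rms_norm(x), layer_norm(x), atol=1e-5)
```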

1.2 SwiGLU Over Standard FFN

The feedforward network in every modern transformer uses SwiGLU (or a variant like GeGLU):

class SwiGLU(nn.Module):
    """Gate * SiLU(gate_proj) * up_proj -- the 2026 standard."""

    def __init__(self, d_model, d_ff=None):
        super().__init__()
        if d_ff is None:
            # Llama convention: d_ff = 8/3 * d_model, rounded to multiple of 256
            d_ff = int(8 / 3 * d_model)
            d_ff = ((d_ff + 255) // 256) * 256

        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down_proj(
            nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
        )

class StandardFFN(nn.Module):
    """The 2017 original. No longer used in frontier models."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear2(torch.relu(self.linear1(x)))

SwiGLU has 50% more parameters than the standard FFN at the same d_ff (three projections instead of two), but empirical scaling laws show it provides better loss per FLOP. The 8/3 ratio for d_ff compensates: 8/3 × 3 = 8, matching the standard 4d FFN with two projections (4 × 2 = 8, i.e. roughly 8·d_model² parameters either way).
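That accounting is easy to check directly; with the Llama-style rounding the two FFN variants land within about 1% of each other in parameter count (d_model = 4096 here is just an illustrative size):

```python
d_model = 4096

# SwiGLU: three projections at d_ff = 8/3 * d_model, rounded to 256
d_ff = int(8 / 3 * d_model)
d_ff = ((d_ff + 255) // 256) * 256
swiglu_params = 3 * d_model * d_ff

# Standard FFN: two projections at d_ff = 4 * d_model
standard_params = 2 * d_model * (4 * d_model)

ratio = swiglu_params / standard_params
print(f"SwiGLU: {swiglu_params:,}  Standard: {standard_params:,}  ratio: {ratio:.3f}")
assert abs(ratio - 1) < 0.01  # equal budget to within rounding
```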

1.3 RoPE for Positional Encoding

Rotary Position Embeddings (RoPE) replaced all alternatives: learned absolute, sinusoidal, ALiBi. Every major model uses RoPE.

def precompute_rope_freqs(dim, max_seq_len, base=10000.0):
    """Precompute the complex exponentials for RoPE.

    Each pair of dimensions rotates at a different frequency,
    determined by the geometric sequence base^(-2i/dim).
    """
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()

    # Outer product: [seq_len, dim/2]
    angles = torch.outer(positions, freqs)

    # Complex representation: cos(theta) + i*sin(theta)
    return torch.polar(torch.ones_like(angles), angles)

def apply_rope(q, k, rope_freqs):
    """Apply rotary embeddings to query and key tensors.

    q, k: [B, n_heads, S, head_dim]
    rope_freqs: [S, head_dim/2] complex tensor
    """
    # Reshape to pairs: [B, n_heads, S, head_dim/2, 2]
    q_pairs = q.float().reshape(*q.shape[:-1], -1, 2)
    k_pairs = k.float().reshape(*k.shape[:-1], -1, 2)

    # Convert to complex
    q_complex = torch.view_as_complex(q_pairs)
    k_complex = torch.view_as_complex(k_pairs)

    # Rotate by element-wise multiplication with complex exponentials
    freqs = rope_freqs.unsqueeze(0).unsqueeze(0)  # [1, 1, S, head_dim/2]
    q_rotated = torch.view_as_real(q_complex * freqs).flatten(-2)
    k_rotated = torch.view_as_real(k_complex * freqs).flatten(-2)

    return q_rotated.type_as(q), k_rotated.type_as(k)

RoPE won because: (1) it encodes relative position through the dot product (the rotation angle between two positions depends only on their distance), (2) it extrapolates to longer sequences than seen in training (with NTK-aware scaling), (3) it adds zero parameters, (4) it is compatible with KV caching.
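The relative-position claim in (1) can be verified numerically: rotate a fixed query/key pair at different absolute positions, and the score depends only on the offset between them. A self-contained sketch in plain complex arithmetic (assuming nothing beyond the base-10000 frequency schedule used above):

```python
import cmath
import math
import random

def rope_score(q, k, pos_q, pos_k, base=10000.0):
    """Dot product of a RoPE-rotated query/key pair.

    q, k: lists of complex numbers, one per 2-D dimension pair.
    Pair i rotates at frequency base^(-2i/dim).
    """
    dim = 2 * len(q)
    score = 0.0
    for i, (qi, ki) in enumerate(zip(q, k)):
        theta = base ** (-2 * i / dim)
        q_rot = qi * cmath.exp(1j * pos_q * theta)
        k_rot = ki * cmath.exp(1j * pos_k * theta)
        # Real dot product of the two rotated 2-D pairs
        score += (q_rot * k_rot.conjugate()).real
    return score

random.seed(0)
q = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(4)]
k = [complex(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(4)]

# Positions (2, 0) and (7, 5) share the same offset, so the scores match
assert math.isclose(rope_score(q, k, 2, 0), rope_score(q, k, 7, 5), abs_tol=1e-9)
```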

1.4 BPE Tokenization

Byte-Pair Encoding with a vocabulary of 100K-200K tokens is the universal standard. GPT-4 uses cl100k (100K tokens), Llama 3 uses a 128K vocabulary, Qwen 2.5 uses 152K. The tokenizer is trained on multilingual data with byte-level fallback.

def bpe_characteristics_2026():
    """Summary of modern BPE tokenizer properties."""
    return {
        "vocab_size": "100K-200K tokens",
        "algorithm": "Byte-level BPE (SentencePiece or tiktoken)",
        "byte_fallback": True,  # Any byte sequence is representable
        "special_tokens": [
            "<|begin_of_text|>",
            "<|end_of_text|>",
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eot_id|>",
        ],
        "compression_ratio": "3.5-4.5 characters per token (English)",
        "multilingual": True,
        "training_data": "Trillions of tokens, balanced across languages",
    }
ℹ️ Note

The settled stack as of early 2026: Pre-RMSNorm (before attention and FFN), SwiGLU feedforward, RoPE positional encoding, BPE tokenization with 100K+ vocabulary, GQA (Grouped Query Attention) for inference efficiency, and no bias terms in linear layers. If you are building a new model from scratch, use exactly this stack.
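Put together, the settled stack is a short pre-norm block. A minimal sketch with illustrative dimensions (RoPE, GQA, and the 8/3 FFN ratio are omitted here to keep it compact; all names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x / rms * self.weight

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.gate = nn.Linear(d_model, d_ff, bias=False)
        self.up = nn.Linear(d_model, d_ff, bias=False)
        self.down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class Block(nn.Module):
    """Pre-norm: x + Attn(norm(x)), then x + FFN(norm(x)). No biases."""

    def __init__(self, d_model=256, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.attn_norm = RMSNorm(d_model)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o = nn.Linear(d_model, d_model, bias=False)
        self.ffn_norm = RMSNorm(d_model)
        self.ffn = SwiGLU(d_model, 4 * d_model)

    def forward(self, x):
        B, S, D = x.shape
        h = self.attn_norm(x)
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        q, k, v = (t.reshape(B, S, self.n_heads, -1).transpose(1, 2)
                   for t in (q, k, v))
        attn = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        x = x + self.o(attn.transpose(1, 2).reshape(B, S, D))
        return x + self.ffn(self.ffn_norm(x))

y = Block()(torch.randn(2, 10, 256))
assert y.shape == (2, 10, 256)
```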

2. Active Frontiers

These techniques are deployed in at least one major model but are not universal. The field is still evaluating their tradeoffs.

2.1 Multi-head Latent Attention (MLA)

DeepSeek V2/V3 introduced MLA, which compresses the KV cache by projecting keys and values into a low-rank latent space before caching:

class MultiHeadLatentAttention(nn.Module):
    """MLA: compress KV cache via low-rank projection.

    Standard GQA caches: n_kv_heads * head_dim * 2 * seq_len per layer
    MLA caches: latent_dim * seq_len per layer (shared across heads)

    For DeepSeek V3: latent_dim=512 vs GQA would need 1024+
    """

    def __init__(self, d_model, n_heads, latent_dim, head_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.latent_dim = latent_dim

        # Compress input to latent representation (cached)
        self.down_proj = nn.Linear(d_model, latent_dim, bias=False)

        # Expand latent to per-head K and V (computed on the fly)
        self.k_proj = nn.Linear(latent_dim, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(latent_dim, n_heads * head_dim, bias=False)

        # Standard Q projection
        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False)

        self.scale = head_dim ** -0.5

    def forward(self, x, cached_latent=None):
        B, S, _ = x.shape

        # Compress to latent (this is what gets cached)
        latent = self.down_proj(x)  # [B, S, latent_dim]
        if cached_latent is not None:
            # Incremental decoding: prepend the previously cached latents
            latent = torch.cat([cached_latent, latent], dim=1)
        S_kv = latent.shape[1]

        # Expand to K, V
        k = self.k_proj(latent).reshape(B, S_kv, self.n_heads, self.head_dim)
        v = self.v_proj(latent).reshape(B, S_kv, self.n_heads, self.head_dim)
        q = self.q_proj(x).reshape(B, S, self.n_heads, self.head_dim)

        # Transpose for attention: [B, n_heads, seq, head_dim]
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        # Standard attention (causal masking omitted for brevity)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = out.transpose(1, 2).reshape(B, S, -1)
        return self.o_proj(out), latent  # Return latent for caching
Performance

MLA KV cache comparison for a 67B-class model at sequence length 128K. Standard MHA (128 heads): 128 × 128 × 128,000 × 2 ≈ 4.2 GB per layer. GQA (8 KV heads): 8 × 128 × 128,000 × 2 ≈ 262 MB per layer. MLA (latent_dim=512): 512 × 128,000 × 2 ≈ 131 MB per layer. MLA halves the KV cache compared to GQA with 8 heads.
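The arithmetic in that callout can be checked directly (these are element counts per layer, using the same factors as the figures above; actual byte sizes depend on the cache dtype):

```python
seq_len = 128_000
head_dim = 128

mha = 128 * head_dim * seq_len * 2  # 128 heads, K and V both cached
gqa = 8 * head_dim * seq_len * 2    # 8 KV heads, K and V both cached
mla = 512 * seq_len * 2             # shared latent, per the figures above

print(f"MHA: {mha / 1e9:.1f}B, GQA: {gqa / 1e6:.0f}M, MLA: {mla / 1e6:.0f}M")
assert gqa == 2 * mla  # MLA halves the cache vs 8-head GQA
```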

2.2 Linear Attention and State Space Models

The quadratic cost of attention (O(n²·d)) motivates subquadratic alternatives. Linear attention replaces the softmax with a kernel function that allows the attention computation to be factored:

class LinearAttention(nn.Module):
    """Linear attention: O(n * d^2) instead of O(n^2 * d).

    Replaces softmax(QK^T) with phi(Q) * phi(K)^T, where phi is a
    feature map. The key insight: phi(Q)(phi(K)^T V) can be computed
    left-to-right in O(n * d^2) by maintaining a running sum.
    """

    def __init__(self, d_model, n_heads, head_dim):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim

        self.q_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, d_model, bias=False)

    def feature_map(self, x):
        """ELU + 1 feature map (from Katharopoulos et al.)"""
        return nn.functional.elu(x) + 1

    def forward(self, x):
        B, S, _ = x.shape

        q = self.q_proj(x).reshape(B, S, self.n_heads, self.head_dim)
        k = self.k_proj(x).reshape(B, S, self.n_heads, self.head_dim)
        v = self.v_proj(x).reshape(B, S, self.n_heads, self.head_dim)

        # Apply feature map
        q = self.feature_map(q)  # [B, S, H, D]
        k = self.feature_map(k)  # [B, S, H, D]

        # Causal linear attention via cumulative sum
        # S = cumsum(k^T @ v) -- the "state" matrix [B, H, D, D]
        # output_t = q_t @ S_t / (q_t @ cumsum(k))

        kv = torch.einsum("bshd,bshe->bshde", k, v)  # [B, S, H, D, D]
        state = kv.cumsum(dim=1)  # Running sum of outer products

        # Numerator: q @ state
        num = torch.einsum("bshd,bshde->bshe", q, state)  # [B, S, H, D]

        # Denominator: q @ cumsum(k) for normalization
        k_cumsum = k.cumsum(dim=1)
        den = torch.einsum("bshd,bshd->bsh", q, k_cumsum).unsqueeze(-1)

        out = num / (den + 1e-6)
        out = out.reshape(B, S, -1)
        return self.o_proj(out)

The problem with linear attention: the state matrix S ∈ ℝ^{d×d} must capture all historical context, which limits recall ability. Recent work (Based, GLA, RWKV-6, Mamba-2) addresses this through data-dependent state transitions and selective forgetting.

2.3 Mixture of Experts (MoE)

MoE replaces the dense FFN with a set of expert FFNs and a router that selects k experts per token:

class MoELayer(nn.Module):
    """Mixture of Experts: route each token to top-k experts.

    Mixtral: 8 experts, top-2
    DeepSeek V3: 256 experts, top-8
    """

    def __init__(self, d_model, d_ff, n_experts=8, top_k=2):
        super().__init__()
        self.n_experts = n_experts
        self.top_k = top_k

        # Router: produces expert selection probabilities
        self.router = nn.Linear(d_model, n_experts, bias=False)

        # Expert FFNs (each is a SwiGLU)
        self.experts = nn.ModuleList([
            SwiGLU(d_model, d_ff) for _ in range(n_experts)
        ])

    def forward(self, x):
        B, S, D = x.shape
        x_flat = x.reshape(-1, D)  # [B*S, D]

        # Route
        router_logits = self.router(x_flat)  # [B*S, n_experts]
        router_probs = router_logits.softmax(dim=-1)

        # Select top-k experts per token
        top_k_probs, top_k_indices = router_probs.topk(self.top_k, dim=-1)

        # Normalize selected expert weights
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Compute expert outputs (simplified -- real impl uses grouped GEMM)
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, i]  # [B*S]
            expert_weight = top_k_probs[:, i]  # [B*S]

            for e in range(self.n_experts):
                mask = expert_idx == e
                if mask.any():
                    expert_input = x_flat[mask]
                    expert_output = self.experts[e](expert_input)
                    output[mask] += expert_weight[mask].unsqueeze(-1) * expert_output

        return output.reshape(B, S, D)

MoE is now proven at scale: Mixtral 8x7B (2023), DeepSeek V3 671B (2024), and several undisclosed models. The key tradeoff: MoE models have more total parameters (higher memory) but activate fewer per token (lower FLOPs). A 671B MoE model with top-8 of 256 experts activates roughly 37B parameters per token — comparable FLOPs to a 37B dense model but with the representational capacity of a much larger network.

2.4 Multi-Token Prediction (MTP)

Instead of predicting one next token, predict K tokens simultaneously. This provides a denser training signal and enables self-speculative decoding:

class MTPHead(nn.Module):
    """Multi-Token Prediction: predict K future tokens.

    DeepSeek V3 uses K=2 (next token + one ahead).
    During inference, the extra prediction enables speculative decoding
    without a separate draft model.
    """

    def __init__(self, d_model, vocab_size, K=2):
        super().__init__()
        self.K = K

        # Each future position gets a projection layer that
        # transforms the hidden state for that prediction horizon
        self.transform = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model, bias=False),
                RMSNorm(d_model),
            ) for _ in range(K - 1)  # First token uses raw hidden state
        ])

        # Shared output embedding (tied with input embedding)
        self.output = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, hidden_states):
        """
        hidden_states: [B, S, d_model]
        Returns: list of K logit tensors, each [B, S, vocab_size]
        """
        logits = [self.output(hidden_states)]  # Token t+1

        h = hidden_states
        for k in range(self.K - 1):
            h = self.transform[k](h)
            logits.append(self.output(h))  # Token t+k+2

        return logits

    def mtp_loss(self, logits_list, target_ids):
        """
        logits_list: K tensors of [B, S, V]
        target_ids: [B, S] ground truth token IDs
        """
        total_loss = 0
        for k, logits in enumerate(logits_list):
            # Shift targets: predict token at position t+k+1
            shift = k + 1
            shifted_logits = logits[:, :-shift, :]
            shifted_targets = target_ids[:, shift:]

            loss = nn.functional.cross_entropy(
                shifted_logits.reshape(-1, shifted_logits.size(-1)),
                shifted_targets.reshape(-1),
            )
            # Weight future predictions less
            weight = 1.0 / (k + 1)
            total_loss += weight * loss

        return total_loss
💡 Tip

MTP’s primary value during training is providing a richer gradient signal — predicting future tokens forces the model to build better internal representations. The speculative decoding benefit at inference is a bonus. DeepSeek reports MTP improves training efficiency by 10-15% (fewer tokens needed to reach the same loss).

3. Emerging Directions

These are research-stage techniques with strong results but limited production deployment. They represent the likely next wave.

3.1 Mamba-Transformer Hybrids

Pure Mamba (selective state space model) struggles with in-context retrieval tasks that require precise token-to-token attention. Pure transformers struggle with very long sequences due to O(n²) cost. Hybrids interleave the two:

class MambaTransformerBlock(nn.Module):
    """Hybrid block: alternate Mamba and attention layers.

    Jamba (AI21, 2024) pattern: 7 Mamba layers per 1 attention layer.
    The attention layers handle precise retrieval.
    The Mamba layers handle long-range context compression.
    """

    def __init__(self, d_model, layer_type="mamba"):
        super().__init__()
        self.layer_type = layer_type
        self.norm = RMSNorm(d_model)

        if layer_type == "mamba":
            self.core = MambaBlock(d_model)  # O(n) per layer
        elif layer_type == "attention":
            self.core = AttentionBlock(d_model)  # O(n^2) per layer

        self.ffn_norm = RMSNorm(d_model)
        self.ffn = SwiGLU(d_model)

    def forward(self, x, **kwargs):
        x = x + self.core(self.norm(x), **kwargs)
        x = x + self.ffn(self.ffn_norm(x))
        return x

class MambaBlock(nn.Module):
    """Simplified Mamba-2 block (selective state space model).

    Core idea: state transition matrix A is data-dependent,
    allowing the model to selectively remember or forget.
    """

    def __init__(self, d_model, d_state=64, d_conv=4, expand=2):
        super().__init__()
        d_inner = d_model * expand

        self.in_proj = nn.Linear(d_model, d_inner * 2, bias=False)
        self.conv1d = nn.Conv1d(
            d_inner, d_inner, kernel_size=d_conv,
            padding=d_conv - 1, groups=d_inner
        )

        # SSM parameters (data-dependent)
        self.dt_proj = nn.Linear(d_inner, d_inner, bias=True)
        self.B_proj = nn.Linear(d_inner, d_state, bias=False)
        self.C_proj = nn.Linear(d_inner, d_state, bias=False)

        # Fixed A: log-space parameterization for stability
        A = torch.arange(1, d_state + 1).float()
        self.A_log = nn.Parameter(torch.log(A).repeat(d_inner, 1))

        self.D = nn.Parameter(torch.ones(d_inner))
        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        B, S, D = x.shape

        # Project and split into main path and gate
        xz = self.in_proj(x)  # [B, S, 2*d_inner]
        x_main, z = xz.chunk(2, dim=-1)  # Each [B, S, d_inner]

        # Conv1d on the main path
        x_conv = self.conv1d(x_main.transpose(1, 2))[:, :, :S].transpose(1, 2)
        x_conv = nn.functional.silu(x_conv)

        # Compute data-dependent SSM parameters
        dt = nn.functional.softplus(self.dt_proj(x_conv))  # [B, S, d_inner]
        B_param = self.B_proj(x_conv)  # [B, S, d_state]
        C_param = self.C_proj(x_conv)  # [B, S, d_state]
        A = -torch.exp(self.A_log)  # [d_inner, d_state]

        # Selective scan (simplified -- real impl uses CUDA kernel)
        y = selective_scan(x_conv, dt, A, B_param, C_param, self.D)

        # Gate and project out
        y = y * nn.functional.silu(z)
        return self.out_proj(y)

def selective_scan(x, dt, A, B, C, D):
    """Selective scan: the core Mamba operation.

    Processes sequence left-to-right, maintaining a hidden state
    that is selectively updated based on the input.

    x: [B, S, d_inner]
    dt: [B, S, d_inner] -- discretization step (data-dependent)
    A: [d_inner, d_state] -- state transition (log-space)
    B: [B, S, d_state] -- input projection (data-dependent)
    C: [B, S, d_state] -- output projection (data-dependent)
    D: [d_inner] -- skip connection
    """
    B_batch, S, d_inner = x.shape
    d_state = A.shape[1]

    # Discretize: A_bar = exp(dt * A), B_bar = dt * B
    dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
    dB = dt.unsqueeze(-1) * B.unsqueeze(2)  # [B, S, d_inner, d_state]

    # Sequential scan
    h = torch.zeros(B_batch, d_inner, d_state, device=x.device)
    outputs = []

    for t in range(S):
        h = dA[:, t] * h + dB[:, t] * x[:, t].unsqueeze(-1)
        y_t = (h * C[:, t].unsqueeze(1)).sum(dim=-1)  # [B, d_inner]
        outputs.append(y_t)

    y = torch.stack(outputs, dim=1)  # [B, S, d_inner]
    y = y + x * D.unsqueeze(0).unsqueeze(0)  # Skip connection
    return y

Empirical results from Jamba and follow-up work show that hybrids with a 7:1 or 4:1 ratio of Mamba-to-attention layers match pure transformer quality while reducing KV cache by 80%+ and improving throughput on long sequences by 2-3x.

3.2 Mixture of Depths (MoD)

Not every token needs every layer. MoD adds a router at each layer that decides whether a token should be processed or skip via the residual connection:

class MoDLayer(nn.Module):
    """Mixture of Depths: skip layers for easy tokens.

    capacity_ratio=0.5 means only 50% of tokens are processed.
    The other 50% pass through via residual connection.
    """

    def __init__(self, d_model, capacity_ratio=0.5):
        super().__init__()
        self.router = nn.Linear(d_model, 1, bias=True)
        self.capacity_ratio = capacity_ratio
        self.attention = AttentionBlock(d_model)
        self.ffn = SwiGLU(d_model)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)

    def forward(self, x):
        B, S, D = x.shape
        scores = self.router(x).squeeze(-1)  # [B, S]

        # Select top-k tokens (k = capacity_ratio * S); sort the indices
        # so the processed subsequence preserves causal order
        k = int(S * self.capacity_ratio)
        _, top_indices = scores.topk(k, dim=1)
        top_indices, _ = top_indices.sort(dim=1)
        idx = top_indices.unsqueeze(-1).expand(-1, -1, D)

        # Gather selected tokens
        selected = torch.gather(x, 1, idx)

        # Attention + FFN on selected tokens only; scaling the block
        # outputs by a gate on the router score gives the router a gradient
        gate = torch.gather(scores, 1, top_indices).sigmoid().unsqueeze(-1)
        selected = selected + gate * self.attention(self.norm1(selected))
        selected = selected + gate * self.ffn(self.norm2(selected))

        # Scatter back; unselected tokens pass through via the residual
        output = x.clone()
        output.scatter_(1, idx, selected)

        return output
📊

Mixture of Depths Impact (32-layer Transformer)

| Configuration | Relative FLOPs | Perplexity | FLOP Savings |
|---|---|---|---|
| Standard (capacity=1.0) | 1.00x | 8.42 | baseline |
| MoD (capacity=0.75) | 0.78x | 8.48 | 22% |
| MoD (capacity=0.50) | 0.56x | 8.71 | 44% |
| MoD (capacity=0.25) | 0.38x | 9.34 | 62% |

3.3 Test-Time Compute Scaling

A major shift in 2024-2025: instead of only scaling training compute, scale inference compute. Models like o1, o3, DeepSeek R1, and QwQ allocate variable compute per query based on difficulty. The key mechanism is chain-of-thought sampling with verification:

def test_time_scaling(model, tokenizer, prompt,
                      max_attempts=16, verifier=None):
    """Scale test-time compute by sampling multiple solutions
    and selecting the best one.

    Budget allocation:
    - Easy questions: 1 sample, short chain-of-thought
    - Hard questions: 16+ samples, long chain-of-thought, verification
    """
    solutions = []

    for attempt in range(max_attempts):
        # Sample with temperature > 0 for diversity
        output = model.generate(
            tokenizer.encode(prompt, return_tensors="pt").cuda(),
            max_new_tokens=4096,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
        )
        solution = tokenizer.decode(output[0])
        solutions.append(solution)

        # Early stopping: if verifier is confident, stop sampling
        if verifier is not None:
            confidence = verifier.score(prompt, solution)
            if confidence > 0.95:
                return solution

    # Select best solution via majority voting or verifier
    if verifier is not None:
        scores = [verifier.score(prompt, s) for s in solutions]
        return solutions[scores.index(max(scores))]
    else:
        # Majority voting on final answer
        return majority_vote(solutions)

def majority_vote(solutions):
    """Extract final answers and return the most common one."""
    from collections import Counter
    # extract_final_answer is task-specific and not shown here
    # (e.g. a regex that pulls out the final boxed answer)
    answers = [extract_final_answer(s) for s in solutions]
    counter = Counter(answers)
    return solutions[answers.index(counter.most_common(1)[0][0])]

This is not an architectural change to the transformer itself, but it changes how transformers are used. The model architecture must support long-form reasoning (extended context, reliable chain-of-thought), and the training pipeline must include reinforcement learning for reasoning quality.

3.4 Native Long Context

Context windows have expanded from 1K (GPT-2) to 128K (Llama 3.1) to 1M+ (Gemini). The architectural requirements:

  1. RoPE scaling for position extrapolation:
def ntk_aware_rope(dim, max_seq_len, base=10000.0,
                   training_length=8192, alpha=None):
    """NTK-aware RoPE scaling for context extension.

    Adjusts the base frequency to avoid concentrated attention
    patterns at positions beyond training length.
    """
    if alpha is None:
        # Dynamic alpha based on sequence length ratio
        alpha = (max_seq_len / training_length) - 1
        alpha = max(alpha, 0) + 1

    # Modify base to spread frequencies
    adjusted_base = base * alpha ** (dim / (dim - 2))

    freqs = 1.0 / (adjusted_base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_seq_len).float()
    angles = torch.outer(positions, freqs)

    return torch.polar(torch.ones_like(angles), angles)
  2. FlashAttention for memory efficiency (no O(n²) score materialization)
  3. Ring attention for distributing long sequences across GPUs:
def ring_attention_concept(q, k, v, ring_size):
    """Ring attention: distribute sequence across GPUs in a ring.

    Each GPU holds a chunk of the sequence. KV pairs are passed
    around the ring so every chunk attends to every other chunk.

    Memory per GPU: O(n/P * d) instead of O(n * d)
    Communication: P-1 rounds of KV transfer
    """
    chunk_size = q.shape[1]  # Each GPU has seq_len / ring_size tokens

    local_q = q  # This GPU's queries (stays put)
    local_kv = (k, v)  # Start with local KV

    output = torch.zeros_like(q)
    running_max = torch.full((q.shape[0], q.shape[1], q.shape[2], 1),
                              float("-inf"), device=q.device)
    running_sum = torch.zeros_like(output)
    running_denom = torch.zeros(
        q.shape[0], q.shape[1], q.shape[2], 1, device=q.device
    )

    for step in range(ring_size):
        k_chunk, v_chunk = local_kv

        # Compute attention for this chunk
        scores = (local_q @ k_chunk.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
        # Apply causal mask if needed

        chunk_max = scores.max(dim=-1, keepdim=True).values
        new_max = torch.maximum(running_max, chunk_max)

        # Online softmax update
        exp_scores = torch.exp(scores - new_max)
        exp_old = torch.exp(running_max - new_max)

        running_sum = exp_old * running_sum + exp_scores @ v_chunk
        running_denom = exp_old * running_denom + exp_scores.sum(dim=-1, keepdim=True)
        running_max = new_max

        # Send KV to next GPU in ring, receive from previous
        # local_kv = ring_send_recv(local_kv)

    output = running_sum / running_denom
    return output
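The online-softmax update in that loop is exact, not an approximation: processing KV in chunks reproduces full softmax attention. A standalone check with small illustrative shapes (non-causal, single head, for simplicity):

```python
import torch

torch.manual_seed(0)
S, D, n_chunks = 8, 4, 2
q, k, v = torch.randn(S, D), torch.randn(S, D), torch.randn(S, D)
scale = D ** -0.5

# Reference: materialize the full S x S score matrix
ref = torch.softmax(q @ k.T * scale, dim=-1) @ v

# Chunked online softmax, mirroring the update in the ring loop
m = torch.full((S, 1), float("-inf"))
num = torch.zeros(S, D)
den = torch.zeros(S, 1)
for kc, vc in zip(k.chunk(n_chunks), v.chunk(n_chunks)):
    s = q @ kc.T * scale                        # scores vs this KV chunk
    m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
    alpha = torch.exp(m - m_new)                # rescale old accumulators
    e = torch.exp(s - m_new)
    num = alpha * num + e @ vc
    den = alpha * den + e.sum(dim=-1, keepdim=True)
    m = m_new

assert torch.allclose(num / den, ref, atol=1e-5)
```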

4. What Did Not Work

Several hyped approaches from 2023-2024 have not achieved broad adoption:

4.1 Pure State Space Models

RWKV, Mamba (as standalone architectures), and RetNet showed promise for linear-time sequence processing. But pure SSMs consistently underperform transformers on in-context learning benchmarks, especially few-shot tasks requiring precise attention to specific input tokens. The consensus: SSMs are excellent compression mechanisms but poor retrieval mechanisms. Hybrids (Section 3.1) are the path forward.

4.2 ALiBi Position Encoding

ALiBi (Attention with Linear Biases) was used in MPT and some BLOOM variants. It adds a linear bias to attention scores based on position distance. RoPE proved superior for context extension, and ALiBi has been abandoned by all major labs.

4.3 Sparse Mixture of Experts with Few Experts

Mixtral’s 8-expert design was influential but DeepSeek V3 showed that many more experts (256) with finer granularity gives better routing and quality. The trend is toward more experts, not fewer.

4.4 Retrieval-Augmented Generation (as Architecture)

RAG was proposed as a way to extend context by retrieving relevant documents. It works well as a system design, but attempts to build retrieval into the architecture (RETRO, Atlas) have not been adopted. The winning approach is large native context windows plus external retrieval at the system level, not the architecture level.

Quantitative Architecture Comparison

Architecture FLOPs vs Quality (7B-scale models)

Validation loss (lower is better) at each multiple of the 1.0x training-FLOPs budget:

Architecture                      0.25x   0.5x   0.75x   1.0x   1.5x   2.0x
Dense Transformer (2026 stack)    3.20    2.85   2.65    2.52   2.38   2.29
MoE 8x (2x total params)          3.00    2.68   2.50    2.38   2.25   2.17
Mamba-Transformer Hybrid          3.15    2.80   2.60    2.48   2.35   2.26
Dense Transformer (2022 stack)    3.40    3.05   2.85    2.72   2.57   2.48

The 2026 stack (RMSNorm + SwiGLU + RoPE + GQA) reaches roughly 7-8% lower loss than the 2022 stack (LayerNorm + ReLU FFN + learned absolute positions + MHA) at the same FLOP budget (2.52 vs. 2.72 at 1.0x). MoE gives another 5-8% at the cost of higher memory. Hybrids are competitive with the dense transformer and win at long context lengths.

The Infrastructure Stack

Architecture choices do not exist in isolation. They are constrained by and co-evolved with the infrastructure stack:

6.1 Training Infrastructure

def training_stack_2026():
    """The standard training infrastructure."""
    return {
        "hardware": "NVIDIA H100/H200 or AMD MI300X clusters",
        "interconnect": "NVLink + InfiniBand (400 Gbps+)",
        "parallelism": {
            "data": "FSDP (ZeRO-3) or PyTorch FSDP2",
            "tensor": "Megatron-style column/row parallel",
            "pipeline": "1F1B or interleaved schedule",
            "context": "Ring attention or Ulysses for long sequences",
            "expert": "Expert parallelism for MoE (all-to-all comm)",
        },
        "precision": "BF16 forward/backward, FP32 master weights",
        "optimizer": "AdamW (beta1=0.9, beta2=0.95, eps=1e-8)",
        "framework": "PyTorch 2.x with torch.compile",
        "checkpointing": "Activation checkpointing for memory, async I/O",
    }

6.2 Inference Infrastructure

def inference_stack_2026():
    """The standard inference infrastructure."""
    return {
        "serving": "vLLM, TensorRT-LLM, or SGLang",
        "batching": "Continuous batching with PagedAttention",
        "kv_cache": "Paged, with prefix caching for shared prompts",
        "quantization": "W4A16 (GPTQ/AWQ) or W8A8 (SmoothQuant)",
        "speculative_decoding": "Self-speculative (MTP) or draft model",
        "hardware": "H100 SXM (80GB HBM3) or multi-GPU with NVLink",
        "memory_management": "PagedAttention with block size 16",
    }
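To see why paged KV management is non-negotiable, a back-of-envelope KV cache calculation for a GQA model with Llama-3-8B-like shapes (illustrative numbers, matching the 7B config later in this post):

```python
# KV cache footprint per token for a GQA model (illustrative config)
n_layers, n_kv_heads, head_dim = 32, 8, 128
bytes_per_elem = 2                                  # BF16
kv_bytes_per_token = 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem
print(kv_bytes_per_token // 1024, "KiB per token")  # K and V, all layers

# At a 128K context: 131072 tokens * 128 KiB = 16 GiB of KV cache for a
# single sequence, which is why PagedAttention allocates it lazily in
# small fixed-size blocks instead of reserving the worst case up front.
context_gib = 131072 * kv_bytes_per_token / 2**30
```

Note that without GQA (32 KV heads instead of 8) the same context would need 64 GiB, more than an H100's entire HBM.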

6.3 Key Metric: Tokens Per Dollar

The ultimate metric for architecture evaluation is tokens per dollar — combining training cost, inference cost, and quality:

def tokens_per_dollar(model_params_B, gpu_tflops, gpu_cost_per_hour,
                      utilization=0.4):
    """Estimate inference cost in tokens per dollar.

    Example: Llama 70B on H100
    - flops_per_token: 2 * 70e9 = 140 GFLOP
    - gpu_tflops: 990 TFLOP/s (BF16 tensor core)
    - utilization: ~40% for autoregressive decoding (memory-bound)
    - effective_tflops: 396 TFLOP/s
    - tokens/sec: 396e12 / 140e9 = 2828 tokens/sec
    - gpu_cost: $2.50/hr (cloud spot)
    - tokens/dollar: 2828 * 3600 / 2.50 = 4.07M tokens/$
    """
    flops_per_token = 2 * model_params_B * 1e9   # ~2 FLOPs per param per token
    effective_tflops = gpu_tflops * utilization  # memory-bound decode
    tokens_per_sec = (effective_tflops * 1e12) / flops_per_token
    tokens_per_hour = tokens_per_sec * 3600
    return tokens_per_hour / gpu_cost_per_hour
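Running the docstring's Llama-70B numbers through the same arithmetic by hand:

```python
# Worked example: Llama 70B decoding on one H100 (assumed spot pricing)
params = 70e9
flops_per_token = 2 * params                 # 140 GFLOP per generated token
effective_flops = 990e12 * 0.4               # 40% utilization, memory-bound
tokens_per_sec = effective_flops / flops_per_token
tokens_per_dollar = tokens_per_sec * 3600 / 2.50
print(f"{tokens_per_sec:.0f} tok/s -> {tokens_per_dollar / 1e6:.2f}M tokens/$")
```

The 40% utilization figure is the load-bearing assumption: prefill is compute-bound and can run far hotter, while single-stream decode without batching can fall well below it.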

Predictions for 2027

Based on current trajectories, high-confidence predictions (greater than 70% probability):

7.1 MoE Becomes Default

By late 2027, most frontier models will use MoE. The evidence: DeepSeek V3 demonstrated that MoE at scale works. The economic argument is overwhelming — MoE gives better quality per inference FLOP, and inference cost dominates total cost of ownership for deployed models. The remaining challenge is efficient expert parallelism during training, which is being solved by better collective communication libraries.

7.2 Native Context Hits 10M+ Tokens

Gemini already claims 1M+ context. Ring attention and hybrid architectures will push this to 10M+ by 2027. The constraint shifts from architecture (solved by FlashAttention + ring attention + hybrids) to training data (few documents are 10M tokens long) and evaluation (no good benchmarks for ultra-long context).

7.3 Test-Time Compute Scaling Matures

The o1/o3/R1 paradigm of scaling inference compute will become standard for all hard reasoning tasks. Architectural support will include built-in verification, backtracking, and budget allocation. Models will learn to allocate their own compute budget per query.

7.4 Quantization Moves to Training

Current practice: train in BF16, quantize post-training. By 2027, training in FP8 or FP4 will be standard on Blackwell and successor hardware. This halves training cost without post-training quantization artifacts.

7.5 Medium-Confidence Predictions

  • Hybrid SSM-Transformer: at least one frontier model will use a Mamba-Transformer hybrid for production inference (probability: 60%).
  • MLA adoption beyond DeepSeek: at least two other labs will ship models with MLA-style KV cache compression (probability: 55%).
  • Differentiable architecture search at scale: automated methods for finding optimal layer-type schedules (which layers are attention, which are SSM, which use MoE) will produce competitive models (probability: 40%).

7.6 What Stays the Same

Some things are unlikely to change by 2027:

  • The residual stream as the backbone (near-certain to remain)
  • RMSNorm and SwiGLU (no compelling alternatives)
  • BPE tokenization (character/byte-level models remain less efficient)
  • AdamW optimizer (or close variants like Adafactor)
  • PyTorch as the training framework (JAX maintains a niche)
  • Autoregressive left-to-right generation as the dominant paradigm
ℹ️ Note

The transformer in 2026 is a mature architecture with well-understood design principles. The remaining degrees of freedom are in the attention mechanism (standard vs. MLA vs. hybrid), the FFN (dense vs. MoE), the position encoding parameters (base frequency, scaling method), and the training recipe (data mix, learning rate schedule, context length curriculum). The transformer will evolve, but it will not be replaced.

Building a 2026-Stack Model from Scratch

For reference, here is the complete configuration for a 7B-class model using the settled 2026 stack:

def model_config_2026_7b():
    """Complete configuration for a 2026-stack 7B model."""
    return {
        # Architecture
        "d_model": 4096,
        "n_layers": 32,
        "n_heads": 32,
        "n_kv_heads": 8,           # GQA: 4x compression
        "head_dim": 128,
        "d_ff": 14336,             # 3.5 * d_model (Llama-3 style)
        "vocab_size": 128256,
        "max_seq_len": 131072,     # 128K context

        # Normalization
        "norm_type": "rmsnorm",
        "norm_eps": 1e-5,
        "norm_position": "pre",    # Pre-norm (before attention and FFN)

        # Attention
        "attention_type": "gqa",
        "rope_base": 500000.0,    # Extended for long context
        "rope_scaling": "ntk-aware",

        # FFN
        "ffn_type": "swiglu",
        "ffn_activation": "silu",
        "ffn_bias": False,

        # Other
        "tie_word_embeddings": False,
        "attention_bias": False,
        "mlp_bias": False,

        # Training
        "precision": "bf16",
        "optimizer": "adamw",
        "lr": 3e-4,
        "min_lr": 3e-5,
        "warmup_steps": 2000,
        "total_tokens": 15_000_000_000_000,  # 15T tokens
        "batch_size_tokens": 4_000_000,
        "weight_decay": 0.1,
        "grad_clip": 1.0,
    }

This configuration represents the consensus best practices as of early 2026. Individual labs may deviate on specific parameters, but the overall structure is remarkably consistent across Llama, Mistral, Qwen, Gemma, and other open model families.
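As a sanity check, the configuration above pins down the parameter count (a quick recomputation assuming standard Llama-style weight shapes; norm weights are negligible and omitted):

```python
d_model, n_layers = 4096, 32
n_heads, n_kv_heads, head_dim = 32, 8, 128
d_ff, vocab_size = 14336, 128256

attn = d_model * n_heads * head_dim            # W_q
attn += 2 * d_model * n_kv_heads * head_dim    # W_k, W_v (GQA-compressed)
attn += n_heads * head_dim * d_model           # W_o
ffn = 3 * d_model * d_ff                       # gate, up, down (SwiGLU)
embed = vocab_size * d_model

body = n_layers * (attn + ffn)                 # non-embedding parameters
total = body + 2 * embed                       # untied input + output embeddings
print(f"non-embedding: {body / 1e9:.2f}B, total: {total / 1e9:.2f}B")
```

The "7B-class" label refers to the roughly 7.0B non-embedding parameters; with the 128K vocabulary and untied embeddings the full model lands near 8B, consistent with published models of this exact shape.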

References

  1. Vaswani et al. “Attention Is All You Need.” NeurIPS 2017.
  2. Zhang and Sennrich. “Root Mean Square Layer Normalization.” NeurIPS 2019.
  3. Shazeer. “GLU Variants Improve Transformer.” arXiv 2020.
  4. Su et al. “RoFormer: Enhanced Transformer with Rotary Position Embedding.” Neurocomputing 2024.
  5. DeepSeek. “DeepSeek-V3 Technical Report.” arXiv 2024.
  6. AI21 Labs. “Jamba: A Hybrid Transformer-Mamba Language Model.” arXiv 2024.
  7. Gu and Dao. “Mamba: Linear-Time Sequence Modeling with Selective State Spaces.” arXiv 2023.
  8. Raposo et al. “Mixture-of-Depths: Dynamically Allocating Compute in Transformer-Based Language Models.” arXiv 2024.