Part of series: Transformer Anatomy (26 of 36)

Here’s a wasteful truth: in a standard transformer, the word “the” gets exactly the same computational treatment as your model’s most complex reasoning step. All 80 layers of Llama 70B, roughly 140 billion FLOPs per token, spent on a function word that contributes almost nothing to the final output. Mixture of Depths (MoD) fixes this by giving each layer a router that makes a simple binary decision: should this token compute through attention and FFN, or should it skip directly via the residual connection? The result is 30-50% fewer FLOPs at inference with minimal quality loss, achieved by letting simple tokens coast through most layers while complex tokens get the full compute budget.

The Core Idea

At each layer $l$, a lightweight router produces a binary decision per token:

$$
d_{t,l} = \begin{cases} 1 & \text{if } \text{router}(h_t^{(l)}) > \tau \\ 0 & \text{otherwise (skip this layer)} \end{cases}
$$

If $d_{t,l} = 1$, the token passes through attention + FFN normally. If $d_{t,l} = 0$, the token passes through only the residual connection ($h_t^{(l+1)} = h_t^{(l)}$), skipping all computation in layer $l$.

import torch
import torch.nn as nn

class MoDRouter(nn.Module):
    """Per-layer router: decides which tokens compute, which skip."""

    def __init__(self, d_model, capacity_ratio=0.5):
        super().__init__()
        self.gate = nn.Linear(d_model, 1, bias=True)
        self.capacity_ratio = capacity_ratio  # Fraction of tokens that compute

    def forward(self, hidden_states):
        """
        hidden_states: [B, S, d_model]
        Returns: mask [B, S] with 1.0 for compute, 0.0 for skip
        """
        scores = self.gate(hidden_states).squeeze(-1)  # [B, S]

        if self.training:
            # Top-k: select the top capacity_ratio fraction of tokens per sequence
            k = max(1, int(scores.shape[1] * self.capacity_ratio))
            topk_vals, topk_idx = scores.topk(k, dim=-1)
            mask = torch.zeros_like(scores)
            mask.scatter_(1, topk_idx, 1.0)
            # Straight-through estimator: forward value is exactly the hard mask,
            # but the gradient flows through the sigmoid so the router can learn
            return mask + (scores.sigmoid() - scores.sigmoid().detach())
        else:
            # At inference: threshold-based decision (no fixed capacity needed)
            return (scores > 0).float()

class MoDTransformerLayer(nn.Module):
    """Transformer layer with Mixture of Depths routing."""

    def __init__(self, attention, ffn, d_model, capacity_ratio=0.5):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.router = MoDRouter(d_model, capacity_ratio)

    def forward(self, x):
        # Router decides which tokens compute
        mask = self.router(x)  # [B, S], 1.0 = compute, 0.0 = skip

        # Only compute attention + FFN for selected tokens
        if mask.sum() > 0:
            # Gather selected tokens. Note: this flattens selected tokens from
            # all sequences into one pseudo-batch, a simplification; a real
            # implementation gathers per sequence and preserves the causal mask.
            selected_idx = (mask > 0).nonzero(as_tuple=True)
            selected = x[selected_idx]  # [num_selected, d_model]
            # Gate is 1.0 in the forward pass but carries the straight-through
            # gradient back to the router during training
            gate = mask[selected_idx].unsqueeze(-1)  # [num_selected, 1]

            # Attention (on selected tokens only)
            normed = self.norm1(selected.unsqueeze(0))
            attn_out = self.attention(normed)
            selected = selected + gate * attn_out.squeeze(0)

            # FFN (on selected tokens only)
            normed = self.norm2(selected.unsqueeze(0))
            ffn_out = self.ffn(normed)
            selected = selected + gate * ffn_out.squeeze(0)

            # Scatter updated tokens back into place
            x = x.clone()
            x[selected_idx] = selected

        # Skipped tokens pass through unchanged (residual only)
        return x
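The straight-through line in the router is the subtle part, and it is easy to check in isolation: the forward value is exactly the hard 0/1 mask, while the backward pass sees the sigmoid's gradient, so every token's score receives a learning signal, even tokens the mask zeroed out. A minimal demonstration:

```python
import torch

# Hard top-k mask whose forward value is exactly 0/1, but whose gradient
# is the sigmoid's -- the router's straight-through estimator in isolation.
scores = torch.tensor([2.0, -1.0, 0.5, -3.0], requires_grad=True)
k = 2
mask = torch.zeros_like(scores)
mask.scatter_(0, scores.topk(k).indices, 1.0)

# sigmoid(s) - sigmoid(s).detach() is exactly zero in value,
# so out == mask in the forward pass.
out = mask + (scores.sigmoid() - scores.sigmoid().detach())
print(out.detach())  # tensor([1., 0., 1., 0.])

# Backward: d(out)/d(scores) = sigmoid'(scores), dense for every token,
# including the two tokens the mask excluded.
out.sum().backward()
s = scores.detach().sigmoid()
print(torch.allclose(scores.grad, s * (1 - s)))  # True
```

Without this trick, indexing with the hard mask would be non-differentiable and the router weights would never receive a gradient.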

Mixture of Depths: FLOPs Savings vs Quality

| Capacity Ratio | FLOPs Reduction | Perplexity Impact | Tokens Skipped per Layer |
| --- | --- | --- | --- |
| 1.00 (all compute) | 0% | baseline | 0% |
| 0.75 | 18% | +0.1 PPL | 25% |
| 0.50 | 35% | +0.3 PPL | 50% |
| 0.25 | 52% | +0.8 PPL | 75% |
| 0.12 | 65% | +2.1 PPL | 88% |

Note: capacity ratio = fraction of tokens that pass through attention + FFN. Lower ratio = more savings but more quality loss.

The Sweet Spot: 50% Capacity

At capacity_ratio = 0.5, half the tokens skip each layer. This saves 35% of total FLOPs at a cost of only 0.3 perplexity points. The router learns that function words, punctuation, and repetitive patterns can safely skip most layers, while content words, reasoning steps, and novel concepts need full processing.
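A back-of-the-envelope model shows where savings of this magnitude come from. The FLOP constants below are the usual 2-FLOPs-per-multiply-accumulate estimates, the shapes are Llama-7B-like for illustration, and the alternate-layer routing assumption follows the configuration explored in the MoD paper; none of these numbers are profiled:

```python
def layer_flops(s, d_model, d_ff):
    """Approximate forward FLOPs for one transformer layer over s tokens."""
    attn_proj = 8 * s * d_model ** 2   # Q, K, V, O projections (linear in s)
    attn_mix = 4 * s * s * d_model     # scores + weighted sum (quadratic in s)
    ffn = 4 * s * d_model * d_ff       # up + down projection (linear in s)
    return attn_proj + attn_mix + ffn

S, D, DFF = 4096, 4096, 11008          # illustrative Llama-7B-ish shapes
dense = layer_flops(S, D, DFF)
routed = layer_flops(S // 2, D, DFF)   # capacity 0.5: half the tokens compute

# Assume every other layer routes (the MoD paper's preferred setup),
# so a routed layer is always paired with a dense one.
saving = 1 - (dense + routed) / (2 * dense)
print(f"~{saving:.0%} FLOPs saved per routed+dense layer pair")  # ~27%
```

The quadratic attention term shrinks 4x while the linear terms shrink 2x, and the untouched dense layers dilute the total, so the overall savings land below 50% even at capacity 0.5; exact figures depend on sequence length and on how many layers route.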

What the Router Learns

Analysis of trained MoD routers shows consistent patterns:

  • Always compute (capacity_score consistently high): First token, last token, tokens after question marks, numbers in arithmetic, code keywords
  • Usually skip (capacity_score consistently low): Articles (“the”, “a”), prepositions (“of”, “in”), common conjunctions, repeated whitespace
  • Layer-dependent: Some tokens compute in early layers (syntactic processing) but skip late layers (already resolved). Others skip early layers but compute in late layers (semantic integration).

Token Skip Rate by Token Type (Llama 7B + MoD, capacity=0.5)

| Token type | Typical behavior | % of layers skipped |
| --- | --- | --- |
| Function words (“the”, “is”, “of”) | skip most layers | 78% |
| Common nouns (“data”, “model”) | skip some layers | 45% |
| Technical terms (“eigenvalue”) | compute most layers | 22% |
| Numbers (“3.14159”) | rarely skip | 12% |
| First/last position | almost always compute | 5% |

Connection to MoE

MoD is orthogonal to MoE: MoE selects *which* expert processes each token (conditional routing to different FFNs), while MoD selects *whether* any processing happens at all (conditional compute vs skip). They can be combined: a token first decides whether to compute (MoD router), then, if yes, which expert to use (MoE router).
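That two-stage composition can be sketched in a few lines. Everything below (the class name, the top-1 expert choice, the per-sequence Python loop) is illustrative rather than any published implementation:

```python
import torch
import torch.nn as nn

class MoDMoELayer(nn.Module):
    """Sketch of MoD + MoE composition: a depth gate first picks which
    tokens compute at all, then an expert gate dispatches each computing
    token to one expert FFN. Hypothetical, for illustration only."""

    def __init__(self, d_model, d_ff, n_experts=4, capacity_ratio=0.5):
        super().__init__()
        self.depth_gate = nn.Linear(d_model, 1)           # MoD: compute or skip?
        self.expert_gate = nn.Linear(d_model, n_experts)  # MoE: which expert?
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, d_ff), nn.SiLU(), nn.Linear(d_ff, d_model))
            for _ in range(n_experts)
        ])
        self.capacity_ratio = capacity_ratio

    def forward(self, x):                                  # x: [B, S, d_model]
        B, S, _ = x.shape
        k = max(1, int(S * self.capacity_ratio))
        # Stage 1 (MoD): top-k tokens per sequence get any compute at all.
        _, idx = self.depth_gate(x).squeeze(-1).topk(k, dim=-1)
        out = x.clone()                                    # skipped tokens: residual only
        for b in range(B):                                 # python loop for clarity, not speed
            sel = x[b, idx[b]]                             # [k, d_model]
            # Stage 2 (MoE): top-1 expert per computing token.
            expert_id = self.expert_gate(sel).argmax(dim=-1)
            upd = torch.stack(
                [self.experts[e](tok) for tok, e in zip(sel, expert_id.tolist())]
            )
            out[b, idx[b]] = sel + upd                     # residual around the expert
        return out

torch.manual_seed(0)
layer = MoDMoELayer(d_model=32, d_ff=64)
y = layer(torch.randn(2, 8, 32))
print(y.shape)  # torch.Size([2, 8, 32])
```

A production version would batch the expert dispatch and add load balancing, but the control flow, skip first, then route, is the essential point.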

Summary

Mixture of Depths challenges a fundamental assumption in transformer design: that every token needs every layer. By adding a lightweight router that decides per-token, per-layer whether to compute or skip, MoD reduces inference FLOPs by 30-50% while maintaining quality within 0.3 perplexity points of the full model. The key insight is that most tokens in most layers are easy — function words, punctuation, repeated patterns — and can safely skip computation via the residual connection. Only the hard tokens — content words, reasoning steps, novel concepts — need the full attention and FFN treatment. The router learns this distinction automatically during training, and at inference you get a model that naturally allocates compute where it matters most. This is conditional computation at its most practical: no architectural changes beyond the router, no complex load balancing, just skip or compute.