Here’s a wasteful truth: in a standard transformer, the word “the” gets exactly the same computational treatment as your model’s most complex reasoning step. All 80 layers of Llama 70B, burning through trillions of FLOPs, processing a function word that contributes almost nothing to the final output. Mixture of Depths (MoD) fixes this by giving each layer a router that makes a simple binary decision: should this token compute through attention and FFN, or should it skip directly via the residual connection? The result is 30-50% fewer FLOPs at inference with minimal quality loss, achieved by letting simple tokens coast through most layers while complex tokens get the full compute budget.
The Core Idea
At each layer l, a lightweight router produces a binary decision g_t^l per token:

    g_t^l = 1 if the router score r^l(x_t^l) is among the top-k for the sequence, else 0

If g_t^l = 1: the token passes through attention + FFN normally. If g_t^l = 0: the token passes through only the residual connection (x_t^{l+1} = x_t^l), skipping all computation in layer l.
```python
import torch
import torch.nn as nn

class MoDRouter(nn.Module):
    """Per-layer router: decides which tokens compute and which skip."""

    def __init__(self, d_model, capacity_ratio=0.5):
        super().__init__()
        self.gate = nn.Linear(d_model, 1, bias=True)
        self.capacity_ratio = capacity_ratio  # fraction of tokens that compute

    def forward(self, hidden_states):
        """
        hidden_states: [B, S, d_model]
        Returns: mask [B, S] with 1 for compute, 0 for skip.
        """
        scores = self.gate(hidden_states).squeeze(-1)  # [B, S]
        if self.training:
            # Top-k: select the top capacity_ratio fraction of tokens, so the
            # compute budget per sequence is fixed and known ahead of time
            k = max(1, int(scores.shape[1] * self.capacity_ratio))
            _, topk_idx = scores.topk(k, dim=-1)
            mask = torch.zeros_like(scores)
            mask.scatter_(1, topk_idx, 1.0)
            # Straight-through estimator: the sigmoid terms cancel in the
            # forward pass (value is the hard mask), but gradients flow
            # through sigmoid(scores) in the backward pass
            return mask + scores.sigmoid() - scores.sigmoid().detach()
        # At inference, top-k over the sequence would peek at future tokens,
        # so fall back to a causal per-token threshold
        return (scores > 0).float()
```
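The straight-through line is worth seeing in isolation: the sigmoid terms cancel in the forward pass, so the value is exactly the hard mask, yet the backward pass sees the sigmoid and delivers a gradient to every score (a standalone sketch with made-up scores):

```python
import torch

scores = torch.tensor([[2.0, -1.0, 0.5, -3.0]], requires_grad=True)

# Hard top-k mask (k = 2): tokens 0 and 2 win
mask = torch.zeros_like(scores)
mask.scatter_(1, scores.topk(2, dim=-1).indices, 1.0)

# Straight-through: forward value == mask, gradient flows via sigmoid
ste = mask + scores.sigmoid() - scores.sigmoid().detach()
print(ste.detach())   # tensor([[1., 0., 1., 0.]])

ste.sum().backward()
print(scores.grad)    # nonzero at every position, including skipped tokens
```

Without this trick, the hard top-k decision is non-differentiable and the router would never learn which tokens deserve compute.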
```python
class MoDTransformerLayer(nn.Module):
    """Transformer layer with Mixture of Depths routing."""

    def __init__(self, attention, ffn, d_model, capacity_ratio=0.5):
        super().__init__()
        self.attention = attention
        self.ffn = ffn
        self.norm1 = nn.RMSNorm(d_model)
        self.norm2 = nn.RMSNorm(d_model)
        self.router = MoDRouter(d_model, capacity_ratio)

    def forward(self, x):
        # Router decides which tokens compute (forward value is a hard 0/1 mask)
        mask = self.router(x)  # [B, S]
        idx = (mask > 0).nonzero(as_tuple=True)
        if idx[0].numel() > 0:
            selected = x[idx]  # [num_selected, d_model]
            # Attention + FFN on the selected tokens only.
            # (Simplification: the gathered tokens are treated as one flat
            # sequence; a real implementation gathers per sequence, keeps
            # token order, and adjusts positions and the causal mask.)
            h = selected.unsqueeze(0)
            h = h + self.attention(self.norm1(h))
            h = h + self.ffn(self.norm2(h))
            delta = h.squeeze(0) - selected  # what attention + FFN contributed
            # Scale the contribution by the router value: it is 1.0 in the
            # forward pass, but the multiplication routes gradients back to
            # the router so it can learn which tokens to select
            x = x.clone()
            x[idx] = selected + mask[idx].unsqueeze(-1) * delta
        # Skipped tokens pass through unchanged (residual only)
        return x
```
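The skip path can be checked end to end with stand-in modules. This sketch reproduces the layer's gather/compute/scatter pattern with a plain linear block in place of attention + FFN (names and shapes are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, D = 2, 8, 16
x = torch.randn(B, S, D)

gate = nn.Linear(D, 1)
block = nn.Sequential(nn.LayerNorm(D), nn.Linear(D, D))  # stand-in for attn + FFN

# Route: the top half of tokens compute, the rest ride the residual stream
scores = gate(x).squeeze(-1)                             # [B, S]
mask = torch.zeros_like(scores)
mask.scatter_(1, scores.topk(S // 2, dim=-1).indices, 1.0)
idx = (mask > 0).nonzero(as_tuple=True)

out = x.clone()
out[idx] = x[idx] + block(x[idx])   # compute only the selected tokens

# Skipped tokens are bit-identical to their inputs: pure residual pass-through
skipped = (mask == 0).nonzero(as_tuple=True)
print(torch.equal(out[skipped], x[skipped]))  # True
```

The bit-exact pass-through is the whole point: a skipped token costs zero attention and FFN FLOPs in that layer, not merely fewer.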
Mixture of Depths: FLOPs Savings vs Quality
| Capacity Ratio | FLOPs Reduction | Perplexity Impact | Tokens Skipped per Layer |
|---|---|---|---|
| 1.0 (all compute) | 0% | Baseline | 0% |
| 0.75 (75% compute) | 18% | +0.1 PPL | 25% skip |
| 0.50 (50% compute) | 35% | +0.3 PPL | 50% skip |
| 0.25 (25% compute) | 52% | +0.8 PPL | 75% skip |
| 0.12 (12% compute) | 65% | +2.1 PPL | 88% skip |
At capacity_ratio = 0.5, half the tokens skip each layer. This saves 35% of total FLOPs with only 0.3 perplexity points degradation. The router learns that function words, punctuation, and repetitive patterns can safely skip most layers, while content words, reasoning steps, and novel concepts need full processing.
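A back-of-envelope check on that 35% figure (a sketch under assumed numbers: routing applied to every other layer, as in the MoD paper; attention cost scaling roughly quadratically in the number of computing tokens, since skipped tokens contribute neither queries nor keys/values; FFN cost scaling linearly; and an assumed 40/60 attention/FFN FLOPs split):

```python
c = 0.5            # capacity ratio: fraction of tokens that compute
attn_share = 0.4   # assumed fraction of a layer's FLOPs in attention
ffn_share = 0.6    # assumed fraction in the FFN
routed_frac = 0.5  # MoD is typically applied to every other layer

# Attention scales ~quadratically with computing tokens, the FFN linearly
layer_cost = attn_share * c**2 + ffn_share * c      # 0.40 of a full layer
savings = routed_frac * (1.0 - layer_cost)          # 0.30
print(f"Estimated FLOPs reduction: {savings:.0%}")  # → "Estimated FLOPs reduction: 30%"
```

The estimate lands near the table's 35%; the exact number depends on sequence length, model shape, and which layers are routed.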
What the Router Learns
Analysis of trained MoD routers shows consistent patterns:
- Always compute (router score consistently high): First token, last token, tokens after question marks, numbers in arithmetic, code keywords
- Usually skip (router score consistently low): Articles (“the”, “a”), prepositions (“of”, “in”), common conjunctions, repeated whitespace
- Layer-dependent: Some tokens compute in early layers (syntactic processing) but skip late layers (already resolved). Others skip early layers but compute in late layers (semantic integration).
[Figure: Token skip rate by token type (Llama 7B + MoD, capacity = 0.5), measured as % of layers skipped]
Connection to MoE
MoD is orthogonal to MoE: MoE selects WHICH expert processes each token (conditional routing to different FFNs), while MoD selects WHETHER any processing happens at all (conditional compute vs skip). They can be combined: a token first decides whether to compute (MoD router), then if yes, which expert to use (MoE router).
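The two-stage decision can be sketched as nested routing. This is illustrative code with hypothetical names (`MoDMoEBlock`, `depth_gate`, `expert_gate`), not the exact combined design from the paper; gating gradients (STE or score scaling) are omitted for brevity:

```python
import torch
import torch.nn as nn

class MoDMoEBlock(nn.Module):
    """Sketch: first decide WHETHER to compute (MoD), then WHICH expert (MoE)."""

    def __init__(self, d_model, n_experts=4, capacity_ratio=0.5):
        super().__init__()
        self.depth_gate = nn.Linear(d_model, 1)           # MoD: compute or skip
        self.expert_gate = nn.Linear(d_model, n_experts)  # MoE: pick an expert
        self.experts = nn.ModuleList([
            nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                          nn.Linear(4 * d_model, d_model))
            for _ in range(n_experts)])
        self.capacity_ratio = capacity_ratio

    def forward(self, x):                                  # x: [B, S, d_model]
        B, S, _ = x.shape
        k = max(1, int(S * self.capacity_ratio))
        # Stage 1 (MoD): only the top-k tokens per sequence compute at all
        keep = self.depth_gate(x).squeeze(-1).topk(k, dim=-1).indices  # [B, k]
        idx = (torch.arange(B).unsqueeze(-1).expand(B, k), keep)
        selected = x[idx]                                  # [B, k, d_model]
        # Stage 2 (MoE): each surviving token goes to its top-1 expert
        choice = self.expert_gate(selected).argmax(-1)     # [B, k]
        out = torch.zeros_like(selected)
        for e, expert in enumerate(self.experts):
            hit = choice == e
            if hit.any():
                out[hit] = expert(selected[hit])
        result = x.clone()
        result[idx] = selected + out   # residual update for computed tokens
        return result                  # skipped tokens pass through unchanged
```

Note the asymmetry: a token that fails the MoD gate never reaches the expert gate at all, so the MoE routing cost itself is only paid for the tokens that compute.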
Summary
Mixture of Depths challenges a fundamental assumption in transformer design: that every token needs every layer. By adding a lightweight router that decides per-token, per-layer whether to compute or skip, MoD reduces inference FLOPs by 30-50% while maintaining quality within 0.3 perplexity points of the full model. The key insight is that most tokens in most layers are easy — function words, punctuation, repeated patterns — and can safely skip computation via the residual connection. Only the hard tokens — content words, reasoning steps, novel concepts — need the full attention and FFN treatment. The router learns this distinction automatically during training, and at inference you get a model that naturally allocates compute where it matters most. This is conditional computation at its most practical: no architectural changes beyond the router, no complex load balancing, just skip or compute.