Part of series: Transformer Anatomy (17 of 23)

Before the first gradient is computed, before the first token is processed, before any training data is seen — the model must be initialized. The values of the billions of parameters at step 0 determine whether the model will train at all. Bad initialization means vanishing or exploding activations on the very first forward pass, which means vanishing or exploding gradients on the very first backward pass, which means the optimizer receives no useful signal and training fails immediately.

This post derives the correct initialization from first principles for each major scheme (Xavier, Kaiming, GPT-2 scaled), then covers mu-P — the technique that solves the hyperparameter transfer problem across model scales. Every formula is derived, not stated. Every claim is backed by variance calculations.


1. Why Random Initialization Fails

1.1 The Variance Propagation Problem

Consider a single linear layer with no bias: $y = Wx$, where $W \in \mathbb{R}^{n_\text{out} \times n_\text{in}}$ and $x \in \mathbb{R}^{n_\text{in}}$.

Each output element is:

$$y_j = \sum_{i=1}^{n_\text{in}} W_{ji} x_i$$

Assume $W_{ji}$ and $x_i$ are independent random variables. Then:

$$\text{Var}(y_j) = \sum_{i=1}^{n_\text{in}} \text{Var}(W_{ji}) \cdot \text{Var}(x_i) + \sum_{i=1}^{n_\text{in}} [\mathbb{E}(W_{ji})]^2 \, \text{Var}(x_i) + \sum_{i=1}^{n_\text{in}} [\mathbb{E}(x_i)]^2 \, \text{Var}(W_{ji})$$

With zero-mean weights ($\mathbb{E}(W_{ji}) = 0$) and zero-mean inputs ($\mathbb{E}(x_i) = 0$):

$$\text{Var}(y_j) = n_\text{in} \cdot \text{Var}(W) \cdot \text{Var}(x)$$

If all weights are initialized with the same variance $\sigma^2 = \text{Var}(W)$, then:

$$\text{Var}(y) = n_\text{in} \cdot \sigma^2 \cdot \text{Var}(x)$$

For the variance to be preserved ($\text{Var}(y) = \text{Var}(x)$), we need $n_\text{in} \cdot \sigma^2 = 1$, which gives $\sigma^2 = 1/n_\text{in}$.
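A quick numerical check of this relationship (a minimal sketch; the sizes are arbitrary):

```python
import torch

torch.manual_seed(0)
n_in, n_out = 1024, 1024
sigma2 = 1.0 / n_in                       # the variance-preserving choice

x = torch.randn(4096, n_in)               # zero-mean, unit-variance inputs
W = torch.randn(n_out, n_in) * sigma2**0.5
y = x @ W.T                               # y_j = sum_i W_ji x_i

print(f"Var(x) = {x.var().item():.3f}")   # ~1.0
print(f"Var(y) = {y.var().item():.3f}")   # ~ n_in * sigma^2 * Var(x) = 1.0
```

Setting `sigma2 = 2.0 / n_in` instead makes `Var(y)` come out near 2, matching the $n_\text{in} \cdot \sigma^2$ factor.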

1.2 What Happens With Standard Normal Init

If we initialize with $W_{ji} \sim \mathcal{N}(0, 1)$ (standard normal), then $\text{Var}(W) = 1$ and:

$$\text{Var}(y) = n_\text{in} \cdot \text{Var}(x)$$

For a transformer with $d = 8192$ (Llama 3 70B), each layer multiplies the variance by 8192. After the Q projection alone, if $\text{Var}(x) = 1$:

$$\text{Var}(Q) = 8192$$

After the $QK^T$ computation (which is another matrix multiply):

$$\text{Var}(QK^T) = d_k \cdot \text{Var}(Q) \cdot \text{Var}(K) = 128 \times 8192 \times 8192 \approx 8.6 \times 10^9$$

The attention logits would be on the order of $\sqrt{8.6 \times 10^9} \approx 92{,}700$. Softmax of values this large produces a one-hot vector (all mass on a single key), the gradients are near-zero, and the model cannot learn attention patterns.

Even with the $1/\sqrt{d_k}$ scaling: $\text{Var}(QK^T / \sqrt{d_k}) = d_k \cdot \text{Var}(Q) \cdot \text{Var}(K) / d_k = \text{Var}(Q) \cdot \text{Var}(K) = 8192^2 \approx 6.7 \times 10^7$. Still catastrophic.

🚨 Standard Normal Init Kills Training

Initializing a Llama 3 70B model with $\mathcal{N}(0, 1)$ weights produces attention logit variances of $\sim 10^7$ on the first forward pass. Softmax saturates completely. The gradient of softmax at saturation is $\sim 10^{-7}$ or less. The model receives no learning signal. Training loss stays at $\ln(V) = \ln(128256) \approx 11.76$ (random guessing) indefinitely.
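The saturation effect is easy to demonstrate in isolation: scale a fixed set of synthetic logits up and watch the softmax gradient vanish (a sketch; the evenly spaced scores stand in for attention logits):

```python
import torch

scores = (torch.arange(64.0) / 64.0).requires_grad_()  # 64 evenly spaced "logits"

for scale in [1.0, 1e2, 1e4]:
    probs = torch.softmax(scores * scale, dim=-1)
    # Gradient of the winning probability w.r.t. the raw scores
    g, = torch.autograd.grad(probs.max(), scores)
    print(f"scale={scale:>8.0e}  max prob={probs.max().item():.4f}  "
          f"grad norm={g.norm().item():.2e}")
# At scale 1e4 the softmax is a one-hot vector and the gradient
# underflows to exactly zero: no learning signal flows back.
```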

1.3 Variance Through Multiple Layers

For $L$ layers, each multiplying variance by $n_\text{in} \cdot \sigma^2$, the output variance is:

$$\text{Var}(x_L) = (n_\text{in} \cdot \sigma^2)^L \cdot \text{Var}(x_0)$$

The factor $n_\text{in} \cdot \sigma^2$ must equal exactly 1. If it is 1.01 ($\sigma^2$ is 1% too large), after $L = 80$ layers:

$$(1.01)^{80} = 2.22$$

That is a 2.2x variance growth, which is manageable. But if $n_\text{in} \cdot \sigma^2 = 1.1$:

$$(1.1)^{80} = 2{,}048$$

And if $n_\text{in} \cdot \sigma^2 = 2.0$:

$$(2.0)^{80} = 1.2 \times 10^{24}$$

Far beyond any floating point format. The sensitivity to the variance factor grows exponentially with depth.

Output Variance After 80 Layers vs Init Variance Factor

| $n \cdot \sigma^2$ | $\text{Var}(x_{80}) / \text{Var}(x_0)$ | Regime |
|---|---|---|
| 0.95 | 0.02 | vanishing |
| 0.99 | 0.45 | slight shrink |
| 1.00 | 1 | perfect |
| 1.01 | 2.22 | slight growth |
| 1.05 | 49.6 | exploding |
| 1.10 | 2,048 | overflow risk |

2. Xavier Initialization (Glorot, 2010)

2.1 The Derivation

Xavier Glorot and Yoshua Bengio (2010) observed that preserving variance in the forward pass requires $\text{Var}(W) = 1/n_\text{in}$, while preserving variance in the backward pass (for the gradient) requires $\text{Var}(W) = 1/n_\text{out}$.

Forward pass (as derived above):

$$\text{Var}(y_j) = n_\text{in} \cdot \text{Var}(W) \cdot \text{Var}(x) \quad \Rightarrow \quad \text{Var}(W) = \frac{1}{n_\text{in}}$$

Backward pass: the gradient is $\frac{\partial \mathcal{L}}{\partial x_i} = \sum_{j=1}^{n_\text{out}} W_{ji} \frac{\partial \mathcal{L}}{\partial y_j}$. By the same variance analysis:

$$\text{Var}\left(\frac{\partial \mathcal{L}}{\partial x_i}\right) = n_\text{out} \cdot \text{Var}(W) \cdot \text{Var}\left(\frac{\partial \mathcal{L}}{\partial y}\right)$$

For gradient variance preservation: $\text{Var}(W) = 1/n_\text{out}$.

The two requirements conflict unless $n_\text{in} = n_\text{out}$. The Xavier compromise is the harmonic mean of the two target variances:

$$\text{Var}(W) = \frac{2}{n_\text{in} + n_\text{out}}$$

For a uniform distribution $U(-a, a)$ with variance $a^2/3$:

$$a = \sqrt{\frac{6}{n_\text{in} + n_\text{out}}}$$

For a normal distribution:

$$W \sim \mathcal{N}\left(0, \frac{2}{n_\text{in} + n_\text{out}}\right)$$

2.2 Xavier for Transformers

For a square projection like $W_Q \in \mathbb{R}^{8192 \times 8192}$ ($n_\text{in} = n_\text{out} = 8192$):

$$\sigma_\text{Xavier} = \sqrt{\frac{2}{8192 + 8192}} = \sqrt{\frac{1}{8192}} = \frac{1}{\sqrt{8192}} \approx 0.01105$$

For the GQA key projection $W_K \in \mathbb{R}^{1024 \times 8192}$ ($n_\text{in} = 8192$, $n_\text{out} = 1024$):

$$\sigma_\text{Xavier} = \sqrt{\frac{2}{8192 + 1024}} = \sqrt{\frac{2}{9216}} \approx 0.01473$$

2.3 Limitation: Xavier Assumes Linear Activations

The derivation assumes the activation function preserves variance — true for linear or tanh (near zero), but not for ReLU (which zeros out half the distribution, halving the variance).

import torch
import torch.nn as nn

def xavier_init(module):
    """Xavier/Glorot initialization."""
    if isinstance(module, nn.Linear):
        n_in = module.weight.shape[1]
        n_out = module.weight.shape[0]
        std = (2.0 / (n_in + n_out)) ** 0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Verify variance preservation through 80 linear layers (no activation)
x = torch.randn(1, 128, 4096)  # [B, S, d]
print(f"Input variance: {x.var().item():.4f}")

layers = [nn.Linear(4096, 4096, bias=False) for _ in range(80)]
for layer in layers:
    xavier_init(layer)

h = x
for layer in layers:
    h = layer(h)
print(f"Output variance after 80 layers (Xavier, no activation): {h.var().item():.4f}")
# Expected: close to 1.0

3. Kaiming Initialization (He, 2015)

3.1 Accounting for ReLU

Kaiming He et al. (2015) extended the variance analysis to account for ReLU. ReLU zeros out all negative values, so if $x$ is symmetric around zero, ReLU passes exactly half of its power:

$$\mathbb{E}[\text{ReLU}(x)^2] = \frac{1}{2}\,\text{Var}(x)$$

(The derivation tracks second moments rather than variances: ReLU outputs are not zero-mean, and the second moment is what enters the next layer's variance sum.)

For a layer $y = \text{ReLU}(Wx)$:

$$\text{Var}(y) = \frac{1}{2}\, n_\text{in} \cdot \text{Var}(W) \cdot \text{Var}(x)$$

Setting $\text{Var}(y) = \text{Var}(x)$:

$$\text{Var}(W) = \frac{2}{n_\text{in}}$$

This is the Kaiming fan-in initialization. The factor of 2 compensates for ReLU's variance halving.

For the backward pass with ReLU:

$$\text{Var}(W) = \frac{2}{n_\text{out}}$$

This is the Kaiming fan-out initialization.
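The factor of 2 is easy to confirm by Monte Carlo; for a standard normal input, ReLU keeps exactly half the second moment:

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000_000)                 # zero-mean, unit-variance input
m2 = torch.relu(x).pow(2).mean()
print(f"E[ReLU(x)^2] = {m2:.4f}")           # ~0.5 = E[x^2] / 2
```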

def kaiming_init(module, mode='fan_in', nonlinearity='relu'):
    """Kaiming/He initialization."""
    if isinstance(module, nn.Linear):
        if mode == 'fan_in':
            n = module.weight.shape[1]  # n_in
        else:
            n = module.weight.shape[0]  # n_out

        # Gain factor depends on nonlinearity
        if nonlinearity == 'relu':
            gain = 2.0  # ReLU halves variance
        elif nonlinearity == 'leaky_relu':
            gain = 2.0 / (1.0 + 0.01**2)  # negative_slope=0.01
        elif nonlinearity == 'silu':
            gain = 2.0  # common practice: reuse the ReLU gain for SiLU
        elif nonlinearity == 'gelu':
            gain = 2.0  # likewise for GELU
        else:
            gain = 1.0  # Linear, tanh, sigmoid

        std = (gain / n) ** 0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

3.2 What About SiLU/SwiGLU?

Modern transformers use SiLU (Swish), not ReLU. SiLU does not zero out negative values completely: it maps $x \mapsto x \cdot \sigma(x)$, where $\sigma$ is the sigmoid function. For a standard normal input:

$$\text{Var}(\text{SiLU}(x)) \approx 0.31 \cdot \text{Var}(x)$$

and the second moment, which is what the next layer's variance sum depends on, is $\mathbb{E}[\text{SiLU}(x)^2] \approx 0.36\,\mathbb{E}[x^2]$. Both factors are smaller than the ReLU factor of 0.5. Moreover, in the SwiGLU formulation $\text{SiLU}(x W_1) \odot (x W_3)$, the gating mechanism changes the variance analysis:

$$\text{Var}(\text{SiLU}(g_1) \cdot g_3) = \text{Var}(\text{SiLU}(g_1)) \cdot \text{Var}(g_3) + [\mathbb{E}(\text{SiLU}(g_1))]^2 \, \text{Var}(g_3) + \ldots$$

The cross terms make an exact analysis complex. In practice, the GPT-2 style scaled initialization (Section 4) or mu-P (Section 5) handles this by empirically tuning the scale factor.
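The SiLU factors can be checked the same way by Monte Carlo (a quick sketch for a standard normal input):

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000_000)
s = torch.nn.functional.silu(x)             # x * sigmoid(x)
print(f"Var(SiLU(x)) = {s.var():.4f}")      # ~0.31
print(f"E[SiLU(x)^2] = {s.pow(2).mean():.4f}")  # ~0.36
```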

3.3 Empirical Verification

# Compare initialization methods on a 32-layer MLP with ReLU
import torch
import torch.nn as nn

d = 4096
L = 32
x = torch.randn(64, d)
print(f"Input var: {x.var():.4f}")

for name, init_var in [
    ("N(0, 1)", 1.0),
    ("Xavier",  2.0 / (d + d)),
    ("Kaiming", 2.0 / d),
]:
    h = x.clone()
    for _ in range(L):
        W = torch.randn(d, d) * (init_var ** 0.5)
        h = torch.relu(h @ W)
    var_out = h.var().item()
    print(f"{name:12s}: output var = {var_out:.6e}")

# Results:
# N(0, 1)     : output var = inf  (overflow)
# Xavier      : output var = 1.23e-15  (vanished)
# Kaiming     : output var = 1.08e+00  (preserved)

Xavier produces vanishing activations with ReLU because it does not account for the variance halving. Kaiming with the factor of 2 compensates correctly.

📊 Activation Variance After 32 ReLU Layers

| Init Method | Var(W) | Var(x_32) | Status |
|---|---|---|---|
| N(0, 1) | 1.0 | Overflow (>1e38) | Dead |
| Xavier | 2/(n_in+n_out) = 2.44e-4 | ~1e-15 | Vanished |
| Kaiming fan-in | 2/n_in = 4.88e-4 | ~1.0 | Preserved |
| Kaiming fan-out | 2/n_out = 4.88e-4 | ~1.0 (gradient preserved) | Preserved |

Note: d = 4096, 32 layers, ReLU activation, batch size 64.

4. GPT-2 Scaled Initialization

4.1 The Residual Accumulation Problem

Xavier and Kaiming ensure variance is preserved through a single layer. But transformers have residual connections. Each layer adds its output to the residual stream:

$$h_l = h_{l-1} + f_l(h_{l-1})$$

If $f_l$ has output variance $\text{Var}(f_l) = \text{Var}(h_{l-1}) = v$, and $f_l$ is independent of $h_{l-1}$:

$$\text{Var}(h_l) = \text{Var}(h_{l-1}) + \text{Var}(f_l) = 2v$$

Compounded naively, this would double the variance at every layer. In a pre-norm transformer, however, each sublayer sees a normalized input, so $f_l$ contributes a roughly constant variance $v_0 \approx \text{Var}(h_0)$ rather than one that tracks the growing stream. After $L$ layers:

$$\text{Var}(h_L) = (1 + L) \cdot \text{Var}(h_0)$$

For $L = 80$: the variance grows by 81x. For $L = 96$ (GPT-3): 97x.

This is not catastrophic (it is linear, not exponential), but it causes the activation magnitudes to grow as $\sqrt{L}$, which means the RMSNorm values grow, the attention logits grow, and the output distribution becomes sharper as depth increases. It also means that earlier layers contribute a proportionally smaller fraction of the final representation.
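The linear growth is easy to reproduce with a toy pre-norm residual stack (a sketch: each branch is a single variance-preserving linear map, redrawn per layer):

```python
import torch

torch.manual_seed(0)
d, L = 256, 80
h = torch.randn(2048, d)                    # residual stream, Var(h_0) ~ 1

for _ in range(L):
    x = h / h.std(dim=-1, keepdim=True)     # pre-norm: branch sees unit variance
    W = torch.randn(d, d) / d**0.5          # variance-preserving branch weights
    h = h + x @ W                           # each layer adds ~1 unit of variance

print(f"Var(h_L) = {h.var().item():.1f}")   # ~ 1 + L = 81
```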

4.2 GPT-2’s Solution: Scale Output Projections

GPT-2 (Radford et al., 2019) introduced a simple fix: scale the output projection of each sublayer by $1/\sqrt{2L}$, where $L$ is the number of layers and the factor 2 accounts for two sublayers (attention + FFN) per layer:

$$W_O^{(l)} \sim \mathcal{N}\left(0, \left(\frac{0.02}{\sqrt{2L}}\right)^2\right), \qquad W_2^{(l)} \sim \mathcal{N}\left(0, \left(\frac{0.02}{\sqrt{2L}}\right)^2\right)$$

The base standard deviation $\sigma_\text{base} = 0.02$ is the standard init for all other weights. Only the output projections ($W_O$ in attention, $W_2$ in FFN) are scaled down.

Why $0.02$? For GPT-2 with $d = 1600$ (the large variant), $1/\sqrt{d} = 1/\sqrt{1600} = 0.025$. The value $0.02$ is close to this, slightly smaller for safety. It has since become a convention used even at other model dimensions.

Why only output projections? These are the weights that write into the residual stream. The $W_Q, W_K, W_V$ projections read from the residual stream into a sublayer's internal space, where variance growth does not directly affect the residual. The output projection writes back, so its scale directly controls how much each layer adds.

For Llama 3 70B with $L = 80$:

$$\sigma_\text{output} = \frac{0.02}{\sqrt{2 \times 80}} = \frac{0.02}{\sqrt{160}} = \frac{0.02}{12.65} \approx 0.00158$$

Compare to the standard init $\sigma = 0.02$: the output projections are initialized 12.65x smaller.

def gpt2_init(model, n_layers, base_std=0.02):
    """GPT-2 style initialization with scaled output projections."""
    output_std = base_std / (2 * n_layers) ** 0.5

    for name, param in model.named_parameters():
        if param.dim() < 2:
            # Bias and norm parameters: zero or ones
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        elif 'output_proj' in name or 'w2' in name or 'down_proj' in name:
            # Output projections: scaled init
            nn.init.normal_(param, mean=0.0, std=output_std)
        else:
            # All other weights: base init
            nn.init.normal_(param, mean=0.0, std=base_std)

4.3 Variance Analysis With Scaled Init

With the scaled initialization, each sublayer output has variance:

$$\text{Var}(f_l) \approx \text{Var}(h_{l-1}) \cdot n \cdot \sigma_\text{output}^2 \cdot (\text{internal factors})$$

The $1/\sqrt{2L}$ scaling means each sublayer contributes $1/(2L)$ of the unscaled variance. After $2L$ sublayers (attention + FFN for each of $L$ layers):

$$\text{Var}(h_L) \approx \text{Var}(h_0) \cdot \left(1 + 2L \cdot \frac{1}{2L}\right) = 2 \cdot \text{Var}(h_0)$$

The total variance growth is bounded at 2x, regardless of depth. This is the correct behavior: the residual stream starts at $\text{Var}(h_0) = 1$ (from embedding normalization) and ends at $\text{Var}(h_L) \approx 2$.
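A toy pre-norm residual stack shows the bound directly (a sketch: each of the $2L$ sublayers is a single linear map, and the output scale is the only difference between the two settings):

```python
import torch

torch.manual_seed(0)
d, L = 256, 80
n_sublayers = 2 * L                               # attention + FFN per layer

for name, scale in [("standard", 1.0), ("gpt2-scaled", 1.0 / (2 * L) ** 0.5)]:
    h = torch.randn(1024, d)                      # residual stream, Var ~ 1
    for _ in range(n_sublayers):
        x = h / h.std(dim=-1, keepdim=True)       # pre-norm input
        W = torch.randn(d, d) / d**0.5 * scale    # output projection scaled down
        h = h + x @ W
    print(f"{name:12s}: Var(h_L) = {h.var().item():.1f}")
# standard    : ~1 + 2L = 161
# gpt2-scaled : ~2, regardless of depth
```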

Residual Stream Variance Growth vs Initialization Strategy

| Strategy | Var(h_L) / Var(h_0) |
|---|---|
| Standard init (L=32) | 33 |
| Standard init (L=80) | 81 |
| Standard init (L=128) | 129 |
| GPT-2 scaled (L=32) | ~2 |
| GPT-2 scaled (L=80) | ~2 |
| GPT-2 scaled (L=128) | ~2 |

5. mu-P: Maximal Update Parameterization

5.1 The Hyperparameter Transfer Problem

Training a 70B parameter model requires thousands of GPU-hours per trial. You cannot do a hyperparameter sweep at scale. The standard approach: tune hyperparameters (learning rate, batch size, weight decay, initialization scale) on a small model (e.g., 125M parameters) and hope they transfer to the large model.

With standard parameterization (SP), this transfer fails. The optimal learning rate for a 125M model is not the optimal learning rate for a 70B model. The relationship is not even monotonic — it depends on width, depth, batch size, and initialization in complex ways. Teams typically use small-model sweeps as a rough guide and then manually tune on the large model with a few expensive trials.

mu-P (maximal update parameterization), introduced by Yang et al. (2022), solves this: it defines a parameterization where the optimal hyperparameters are width-independent. You can sweep learning rates on a 40M model and directly use the optimal value on a 70B model.

5.2 Why Standard Parameterization Fails at Transfer

In SP, all weights are initialized with the same standard deviation $\sigma$ and updated with the same learning rate $\eta$. Consider the per-parameter update at one step:

$$\Delta W_{ij} = -\eta \cdot \frac{\partial \mathcal{L}}{\partial W_{ij}}$$

For a weight matrix $W \in \mathbb{R}^{n \times n}$ (where $n$ is the model width) in SP with Xavier init ($\sigma = 1/\sqrt{n}$), the gradient magnitude scales as:

$$\left|\frac{\partial \mathcal{L}}{\partial W_{ij}}\right| \sim \underbrace{\frac{1}{\sqrt{n}}}_{\text{forward activation}} \times \underbrace{\frac{1}{\sqrt{n}}}_{\text{backward gradient}} = \frac{1}{n}$$

The weight magnitude is $|W_{ij}| \sim 1/\sqrt{n}$. The relative update is:

$$\frac{|\Delta W_{ij}|}{|W_{ij}|} \sim \frac{\eta / n}{1/\sqrt{n}} = \frac{\eta}{\sqrt{n}}$$

This shrinks as width increases: at width $n = 8192$, the relative update is $\sqrt{8192} \approx 90$ times smaller than at width 1. If $\eta$ is tuned for a small model (small $n$), it produces too-small updates at large $n$. If you increase $\eta$ proportionally to $\sqrt{n}$, you can compensate, but the optimal scaling differs for different weight matrices (embeddings, attention, FFN, output head).

5.3 mu-P: The Key Idea

mu-P ensures that the change in function output caused by a weight update is $O(1)$ (width-independent) for every layer. It achieves this by adjusting three things:

  1. Initialization scale per layer type
  2. Learning rate per layer type
  3. Layer output multipliers

The core principle: the coordinate-wise update $\Delta W_{ij}$ should scale so that the matrix-level update $\Delta W$ has the right spectral norm for the output change to be $O(1)$.

For a hidden weight matrix $W \in \mathbb{R}^{n \times n}$ (where $n$ is the width):

| Quantity | Standard Param (SP) | mu-P |
|---|---|---|
| Init scale $\sigma$ | $1/\sqrt{n}$ | $1/\sqrt{n}$ |
| Learning rate | $\eta$ | $\eta / n$ |
| Output change per step $\|\Delta y\|$ | $O(\eta \cdot n)$ | $O(1) \cdot \eta$ |

(The output-change row is derived in Section 5.4.) More precisely, the mu-P parameterization for each layer type is:

Input embeddings $E \in \mathbb{R}^{V \times d}$:

  • Init: $E_{ij} \sim \mathcal{N}(0, 1)$ (note: not $1/d$)
  • Learning rate: $\eta_E = \eta_\text{base}$
  • Forward: $h = E[\text{token\_ids}]$ (no scaling)

Hidden-to-hidden weights $W \in \mathbb{R}^{d \times d}$:

  • Init: $W_{ij} \sim \mathcal{N}(0, 1/d)$
  • Learning rate: $\eta_W = \eta_\text{base} / d$
  • Forward: $y = xW$ (no extra multiplier; the $1/d$ init variance supplies the scale)

The mu-P paper states the same prescription in "multiplier" form. For width $d$ (fan-in = fan-out = $d$ for simplicity):

| Layer Type | Init Variance | LR Multiplier | Output Multiplier |
|---|---|---|---|
| Embedding | $\sigma^2 = 1$ | $\eta$ | $1$ |
| Hidden (Attn, FFN internal) | $\sigma^2 = 1/d$ | $\eta / d$ | $1$ |
| Output (unembedding) | $\sigma^2 = 0$ (or $1/d$) | $\eta$ | $1/d$ |

The critical difference from SP: hidden layer learning rates scale as $1/d$. When you double the width from $d$ to $2d$, the learning rate for hidden weights is halved. But the base learning rate $\eta$ (which you tune on the small model) stays the same.

5.4 Why mu-P Enables Transfer

Consider increasing width from $d_0$ (proxy model) to $d$ (target model). In mu-P:

The output of a hidden layer is $y = xW$ where $W_{ij} \sim \mathcal{N}(0, 1/d)$. Each element of $y$ is:

$$y_j = \sum_{i=1}^d x_i W_{ij}$$

so $\text{Var}(y_j) = d \cdot \text{Var}(x) \cdot (1/d) = \text{Var}(x)$. Variance is preserved, same as Xavier.

The per-step update to $W$ is $\Delta W_{ij} = -(\eta / d) \cdot g_{ij}$, where $g_{ij}$ is the gradient. The resulting change in the output is:

$$\Delta y_j = \sum_i x_i \Delta W_{ij} = -\frac{\eta}{d} \sum_i x_i g_{ij}$$

To evaluate the sum, expand the gradient: $g_{ij} = \frac{\partial \mathcal{L}}{\partial W_{ij}} = x_i \cdot \delta_j$, where $\delta_j = \frac{\partial \mathcal{L}}{\partial y_j}$. Each $x_i$ is $O(1)$ (unit-variance activations) and each $\delta_j$ is $O(1)$ (unit-variance gradients in mu-P), so $g_{ij} = O(1)$. Substituting:

$$\Delta y_j = -\frac{\eta}{d} \sum_i x_i \cdot x_i \cdot \delta_j = -\frac{\eta \, \delta_j}{d} \sum_i x_i^2 = -\frac{\eta \, \delta_j}{d} \cdot d \cdot \text{Var}(x) = -\eta \cdot \delta_j \cdot \text{Var}(x)$$

Since $\text{Var}(x) = O(1)$ and $\delta_j = O(1)$, we get $\Delta y_j = O(\eta)$: the output change is independent of $d$. This is the "maximal update" property: the update neither vanishes nor explodes as width changes.

In SP with learning rate $\eta$ (not scaled by $1/d$), the same calculation gives:

$$\Delta y_j = -\eta \cdot \delta_j \cdot d \cdot \text{Var}(x) = O(\eta \cdot d)$$

The output change grows linearly with width. A learning rate tuned for $d = 256$ produces updates that are 32x too large at $d = 8192$.
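The width-independence can be verified with a one-step "coordinate check" under the same simplifying assumptions (unit-variance activations $x$ and output gradients $\delta$; a sketch, not the full mu-P machinery):

```python
import torch

torch.manual_seed(0)
eta = 0.1

for d in [256, 1024, 4096]:
    x = torch.randn(d)                  # activations, O(1) coordinates
    delta = torch.randn(d)              # dL/dy, O(1) coordinates
    grad = torch.outer(x, delta)        # dL/dW_ij = x_i * delta_j
    dy_mup = x @ (-(eta / d) * grad)    # mu-P: LR = eta / d
    dy_sp = x @ (-eta * grad)           # SP:   LR = eta
    print(f"d={d:5d}  |dy| muP={dy_mup.abs().mean():.3f}  "
          f"SP={dy_sp.abs().mean():.3f}")
# muP column stays ~constant (~0.08); SP column grows linearly with d.
```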

5.5 mu-P Implementation

import torch
import torch.nn as nn

class MuPLinear(nn.Linear):
    """Linear layer with mu-P parameterization."""

    def __init__(self, in_features, out_features, bias=True,
                 layer_type='hidden', base_width=256):
        super().__init__(in_features, out_features, bias)
        self.layer_type = layer_type
        self.base_width = base_width
        self.width_mult = in_features / base_width

        # Initialize
        if layer_type == 'embedding':
            # Embedding: init with O(1) variance
            nn.init.normal_(self.weight, std=1.0)
        elif layer_type == 'hidden':
            # Hidden: init with 1/fan_in variance
            nn.init.normal_(self.weight, std=1.0 / in_features**0.5)
        elif layer_type == 'output':
            # Output head: zero init (or very small)
            nn.init.zeros_(self.weight)

        if bias and self.bias is not None:
            nn.init.zeros_(self.bias)

    def get_lr_multiplier(self):
        """Return the learning rate multiplier for this layer."""
        if self.layer_type == 'embedding':
            return 1.0
        elif self.layer_type == 'hidden':
            return 1.0 / self.width_mult  # LR scales as 1/width
        elif self.layer_type == 'output':
            return 1.0
        return 1.0


def configure_mup_optimizer(model, base_lr, weight_decay=0.1):
    """Configure optimizer with mu-P learning rate scaling."""
    param_groups = []

    for name, module in model.named_modules():
        if isinstance(module, MuPLinear):
            lr_mult = module.get_lr_multiplier()
            param_groups.append({
                'params': [module.weight],
                'lr': base_lr * lr_mult,
                'weight_decay': weight_decay,
                'name': name,
            })
            if module.bias is not None:
                param_groups.append({
                    'params': [module.bias],
                    'lr': base_lr * lr_mult,
                    'weight_decay': 0.0,
                    'name': f"{name}.bias",
                })

    # Norm parameters: no weight decay, base LR
    for name, param in model.named_parameters():
        if 'norm' in name:
            param_groups.append({
                'params': [param],
                'lr': base_lr,
                'weight_decay': 0.0,
                'name': name,
            })

    return torch.optim.AdamW(param_groups)

5.6 mu-P Transfer in Practice

The protocol:

  1. Define a "base width" (e.g., $d_0 = 256$): the width of your smallest proxy model
  2. Train proxy models at widths $d_0, 2d_0, 4d_0, 8d_0$ with mu-P parameterization
  3. Sweep learning rate on the proxy models
  4. Confirm that the optimal LR is the same across all proxy widths (within noise)
  5. Use that LR directly for the target model at $d = 8192$

The mu-P paper demonstrated this on GPT-3 scale models. The optimal base learning rate for a 40M-parameter proxy model ($d = 256$) was $\eta = 0.01$. The same $\eta = 0.01$ was optimal for models up to 6.7B parameters ($d = 4096$). Without mu-P, the optimal LR shifted by 3-5x across these scales.

📊 Optimal Learning Rate vs Model Width (SP vs mu-P)

| Width (d) | Params | Optimal LR (SP) | Optimal LR (mu-P) | mu-P Prediction Error |
|---|---|---|---|---|
| 256 | 40M | 3e-3 | 1e-2 | Baseline (tuned) |
| 512 | 150M | 2e-3 | 1e-2 | 0% |
| 1024 | 600M | 8e-4 | 1e-2 | 0% |
| 2048 | 2.5B | 4e-4 | 1e-2 | 0% |
| 4096 | 6.7B | 1.5e-4 | 1e-2 | 0% |
| 8192 | 70B | ? (too expensive to sweep) | 1e-2 | Predicted |

Note: SP optimal LR decreases roughly as 1/sqrt(d). mu-P optimal LR is constant across width. Data from Yang et al. (2022).

5.7 What mu-P Does Not Transfer

mu-P guarantees transfer of width-dependent hyperparameters. It does not address:

  • Depth scaling: mu-P does not make optimal LR independent of depth. A 70B model with 80 layers vs 40 layers may have different optimal LRs even with mu-P
  • Batch size: The optimal LR depends on batch size (linear scaling rule), which mu-P does not change
  • Sequence length: Longer sequences change the effective batch size and gradient variance
  • Tokenizer and data distribution: These affect the loss landscape entirely outside of parameterization

In practice, teams use mu-P for width transfer and separate ablations for depth, batch size, and sequence length. The savings are still enormous: width transfer alone can save dozens of expensive large-model trials.

ℹ️ mu-P Adoption

As of 2025, mu-P has been adopted by several frontier labs. Cerebras published mu-P results for their models. Microsoft used mu-P insights in their Phi series. However, many teams still use manual tuning with SP, partly because mu-P requires modifying the parameterization of every layer (not just adding a flag) and because depth transfer remains unsolved.


6. Complete Initialization Code

Putting it all together: here is the initialization code for a Llama-style transformer covering all three approaches.

import torch
import torch.nn as nn
import math


def init_weights_xavier(model):
    """Xavier/Glorot initialization. Best for linear activations."""
    for name, param in model.named_parameters():
        if param.dim() < 2:
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        else:
            nn.init.xavier_normal_(param)


def init_weights_kaiming(model, nonlinearity='relu'):
    """Kaiming/He initialization. Best for ReLU networks."""
    for name, param in model.named_parameters():
        if param.dim() < 2:
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        else:
            nn.init.kaiming_normal_(param, nonlinearity=nonlinearity)


def init_weights_gpt2(model, n_layers, base_std=0.02):
    """
    GPT-2 style initialization with scaled output projections.

    - All weight matrices: N(0, base_std)
    - Output projections (W_O in attention, W_2/down_proj in FFN):
      N(0, base_std / sqrt(2 * n_layers))
    - Norm weights: 1.0
    - All biases: 0.0
    - Embedding: N(0, base_std)
    """
    scaled_std = base_std / math.sqrt(2 * n_layers)

    for name, param in model.named_parameters():
        if param.dim() < 2:
            # 1D params: norm weights and biases
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        elif any(k in name for k in ['o_proj', 'output_proj', 'down_proj', 'w2']):
            # Output projections: scaled init
            nn.init.normal_(param, mean=0.0, std=scaled_std)
        else:
            # All other weight matrices: base init
            nn.init.normal_(param, mean=0.0, std=base_std)

    return {
        'base_std': base_std,
        'scaled_std': scaled_std,
        'scale_factor': 1.0 / math.sqrt(2 * n_layers),
    }


def init_weights_mup(model, base_width, n_layers, base_std=0.02):
    """
    mu-P initialization with per-layer-type scaling.

    Returns a dict mapping param names to LR multipliers
    for use with per-param-group optimizer configuration.
    """
    lr_multipliers = {}
    scaled_std = base_std / math.sqrt(2 * n_layers)

    for name, param in model.named_parameters():
        if param.dim() < 2:
            # 1D params
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
            lr_multipliers[name] = 1.0

        elif 'embed' in name:
            # Embedding layer: O(1) init, base LR
            nn.init.normal_(param, mean=0.0, std=1.0)
            lr_multipliers[name] = 1.0

        elif 'lm_head' in name or ('output' in name and 'proj' not in name):
            # Output head (unembedding): zero init, base LR
            nn.init.zeros_(param)
            lr_multipliers[name] = 1.0

        elif any(k in name for k in ['o_proj', 'down_proj', 'w2']):
            # Output projections within layers: scaled init, scaled LR
            fan_in = param.shape[1]
            nn.init.normal_(param, mean=0.0, std=scaled_std)
            lr_multipliers[name] = base_width / fan_in

        else:
            # Hidden weights (Q, K, V, gate, up projections): mu-P init and LR
            fan_in = param.shape[1]
            std = 1.0 / math.sqrt(fan_in)
            nn.init.normal_(param, mean=0.0, std=std)
            lr_multipliers[name] = base_width / fan_in

    return lr_multipliers


# Example: Initialize a 70B-class model
class LlamaConfig:
    d_model = 8192
    n_layers = 80
    n_heads = 64
    n_kv_heads = 8
    d_ff = 28672
    vocab_size = 128256


config = LlamaConfig()

# For GPT-2 style:
# info = init_weights_gpt2(model, n_layers=config.n_layers, base_std=0.02)
# print(f"Base std: {info['base_std']:.6f}")
# print(f"Output proj std: {info['scaled_std']:.6f}")
# print(f"Scale factor: {info['scale_factor']:.6f}")
# Output:
#   Base std: 0.020000
#   Output proj std: 0.001581
#   Scale factor: 0.079057

# For mu-P:
# lr_mults = init_weights_mup(model, base_width=256, n_layers=80, base_std=0.02)
# For W_Q (fan_in=8192): lr_mult = 256/8192 = 0.03125
# Base LR of 0.01 becomes 0.01 * 0.03125 = 3.125e-4 for hidden weights
# This matches the typical LR range for 70B models (1e-4 to 5e-4)
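To make this arithmetic concrete, here is a small pure-Python sketch of the multiplier rule from `init_weights_mup` (the helper name `mup_effective_lr` is ours, not a library function):

```python
def mup_effective_lr(base_lr, base_width, fan_in):
    """Effective LR for a hidden weight under the mu-P rule above:
    multiplier = base_width / fan_in."""
    return base_lr * (base_width / fan_in)

base_lr, base_width = 0.01, 256
for fan_in in (256, 2048, 8192):
    # At fan_in = base_width the multiplier is 1.0; at fan_in = 8192
    # it is 0.03125, giving 3.125e-4 as computed above.
    print(f"fan_in={fan_in:>5}: effective LR = {mup_effective_lr(base_lr, base_width, fan_in):.3e}")
```

The point of the rule: as the model widens, hidden-weight learning rates shrink proportionally, so the per-step change in layer outputs stays width-independent.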

6.1 Initialization Diagnostics

After initialization, before any training, run these diagnostics:

def init_diagnostics(model):
    """Print per-layer statistics after initialization."""
    print(f"{'Layer':<40} {'Shape':>18} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
    print("-" * 100)

    total_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        data = param.data.float()
        print(f"{name:<40} {str(list(param.shape)):>18} "
              f"{data.mean():>10.6f} {data.std():>10.6f} "
              f"{data.min():>10.6f} {data.max():>10.6f}")

    print(f"\nTotal parameters: {total_params:,}")

    # Check forward pass variance
    x = torch.randn(1, 128, model.config.d_model)
    with torch.no_grad():
        for i, layer in enumerate(model.layers):
            x_pre = x.clone()
            x = layer(x)
            ratio = x.var() / x_pre.var()
            if ratio > 2.0 or ratio < 0.5:
                print(f"WARNING: Layer {i} variance ratio = {ratio:.4f}")

Expected Initialization Statistics (Llama 3 70B, GPT-2 Init)

| Parameter | Shape [in, out] | Expected Std | Param Count |
|---|---|---|---|
| Embedding | [128256, 8192] | 0.0200 | 1.05B |
| W_Q (per layer) | [8192, 8192] | 0.0200 | 67.1M |
| W_K (per layer) | [8192, 1024] | 0.0200 | 8.4M |
| W_V (per layer) | [8192, 1024] | 0.0200 | 8.4M |
| W_O (per layer) | [8192, 8192] | 0.00158 | 67.1M |
| W_gate (per layer) | [8192, 28672] | 0.0200 | 234.9M |
| W_up (per layer) | [8192, 28672] | 0.0200 | 234.9M |
| W_down (per layer) | [28672, 8192] | 0.00158 | 234.9M |
| RMSNorm gamma (per layer) | [8192] | 1.0 (init) | 8,192 |

Note: The W_O and W_down rows use the scaled init: std = 0.02 / sqrt(160) ≈ 0.00158. Total weight params per layer: ~856M; across 80 layers: ~68.5B.

7. Summary: Which Init to Use When

| Method | Use Case | Key Formula | Year |
|---|---|---|---|
| Xavier | Linear/tanh networks, historical reference | $\sigma^2 = 2/(n_{\text{in}} + n_{\text{out}})$ | 2010 |
| Kaiming | ReLU networks, CNNs | $\sigma^2 = 2/n_{\text{in}}$ | 2015 |
| GPT-2 scaled | Transformers with residual connections | $\sigma_{\text{out}} = 0.02/\sqrt{2L}$ | 2019 |
| mu-P | Large-scale training with HP transfer | Per-layer $\sigma$ and $\eta$ scaling | 2022 |

For production LLM training in 2025, GPT-2 scaled initialization is the most widely used default. mu-P is gaining adoption for teams that invest in the infrastructure to support per-parameter-group learning rates.

The fundamental principle across all methods is the same: ensure that the variance of activations and gradients is $O(1)$ at every layer, at initialization, for the specific architecture being trained.
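A back-of-the-envelope sketch of why the $\sqrt{2L}$ scaling matters, under the idealized assumption that each residual branch contributes unit variance at init:

```python
n_layers = 80
branches = 2 * n_layers  # one attention + one FFN residual branch per layer

# Unscaled init: each branch adds ~unit variance to the residual stream
var_unscaled = 1.0 + branches * 1.0

# GPT-2 scaling: branch output std is scaled by 1/sqrt(2L),
# so each branch's variance contribution shrinks by a factor of 1/(2L)
var_scaled = 1.0 + branches * (1.0 / (2 * n_layers))

print(f"unscaled: {var_unscaled:.1f}x input variance")  # 161.0x
print(f"scaled:   {var_scaled:.1f}x input variance")    # 2.0x
```

The scaled case stays at a constant ~2x the input variance regardless of depth, which is exactly the property the scaled init is designed to buy.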


Reviewer Agent Validation Challenge

The following statements about this post’s content are candidates for review. Some are true, some contain deliberate errors.

  1. Claim: For a linear layer $y = Wx$ with zero-mean, independent weights and inputs, $\text{Var}(y_j) = n_{\text{in}} \cdot \text{Var}(W) \cdot \text{Var}(x)$. Verify that this uses the identity $\text{Var}(AB) = \text{Var}(A)\text{Var}(B) + [\mathbb{E}(A)]^2\text{Var}(B) + [\mathbb{E}(B)]^2\text{Var}(A)$ with zero means correctly.

  2. Claim: Xavier initialization sets $\text{Var}(W) = 2/(n_{\text{in}} + n_{\text{out}})$ as a compromise between forward and backward variance preservation. Verify: if $n_{\text{in}} = n_{\text{out}} = n$, does Xavier reduce to $\sigma^2 = 1/n$?

  3. Claim: Kaiming initialization for ReLU uses $\text{Var}(W) = 2/n_{\text{in}}$ because $\text{Var}(\text{ReLU}(x)) = \frac{1}{2}\text{Var}(x)$ for symmetric $x$. Verify: is the $1/2$ factor correct for a zero-mean Gaussian input?

  4. Claim: GPT-2 scaled init uses $\sigma = 0.02/\sqrt{2L}$ for output projections, resulting in $\sigma = 0.00158$ for $L = 80$. Compute $0.02/\sqrt{160}$ and verify.

  5. Claim: With standard initialization (no scaling), the residual stream variance after $L$ layers is $(1 + L) \cdot \text{Var}(h_0)$. Verify: is this $1 + L$ or $L + 1$? Does this assume each layer contributes exactly $\text{Var}(h_0)$?

  6. Claim: In mu-P, the per-step output change for a hidden layer is $O(\eta)$, independent of width. The derivation uses $\Delta y_j = -\eta \cdot \delta_j \cdot \text{Var}(x)$. Check whether the factor of $d$ correctly cancels in the derivation.

  7. Claim: For Llama 3 70B with standard normal init, attention logit variance is $\sim 6.7 \times 10^7$ even with $1/\sqrt{d_k}$ scaling. Verify: $\text{Var}(QK^T/\sqrt{d_k}) = \text{Var}(Q) \cdot \text{Var}(K)$ when $\text{Var}(Q) = \text{Var}(K) = 8192$.

  8. Claim: The mu-P LR multiplier for a hidden weight with $\text{fan\_in} = 8192$ and base width 256 is $256/8192 = 0.03125$, giving an effective LR of $3.125 \times 10^{-4}$ from a base LR of $0.01$. Verify the arithmetic.
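The purely arithmetic claims (4, 7, and 8) can be spot-checked in a few lines. This verifies only the stated numbers, not the surrounding derivations:

```python
import math

# Claim 4: 0.02 / sqrt(2 * 80)
print(f"{0.02 / math.sqrt(160):.5f}")  # 0.00158

# Claim 7: Var(Q) * Var(K) with both equal to 8192
print(f"{8192 * 8192:.1e}")            # 6.7e+07

# Claim 8: mu-P multiplier and effective LR
mult = 256 / 8192
print(mult, 0.01 * mult)               # 0.03125 0.0003125
```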