Before the first gradient is computed, before the first token is processed, before any training data is seen — the model must be initialized. The values of the billions of parameters at step 0 determine whether the model will train at all. Bad initialization means vanishing or exploding activations on the very first forward pass, which means vanishing or exploding gradients on the very first backward pass, which means the optimizer receives no useful signal and training fails immediately.
This post derives the correct initialization from first principles for each major scheme (Xavier, Kaiming, GPT-2 scaled), then covers mu-P — the technique that solves the hyperparameter transfer problem across model scales. Every formula is derived, not stated. Every claim is backed by variance calculations.
1. Why Random Initialization Fails
1.1 The Variance Propagation Problem
Consider a single linear layer with no bias: $y = Wx$, where $W \in \mathbb{R}^{n_\text{out} \times n_\text{in}}$ and $x \in \mathbb{R}^{n_\text{in}}$.

Each output element is:

$$y_i = \sum_{j=1}^{n_\text{in}} W_{ij} x_j$$

Assume $W_{ij}$ and $x_j$ are independent, zero-mean random variables. Then:

$$\mathrm{Var}(y_i) = \sum_{j=1}^{n_\text{in}} \mathrm{Var}(W_{ij} x_j)$$

With zero-mean weights ($\mathbb{E}[W_{ij}] = 0$) and zero-mean inputs ($\mathbb{E}[x_j] = 0$):

$$\mathrm{Var}(W_{ij} x_j) = \mathrm{Var}(W_{ij})\,\mathrm{Var}(x_j)$$

If all weights are initialized with the same variance $\sigma_W^2$, then:

$$\mathrm{Var}(y_i) = n_\text{in}\,\sigma_W^2\,\mathrm{Var}(x_j)$$

For the variance to be preserved ($\mathrm{Var}(y_i) = \mathrm{Var}(x_j)$), we need $n_\text{in}\,\sigma_W^2 = 1$, which gives $\sigma_W^2 = 1/n_\text{in}$.
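This derivation can be checked with a quick Monte Carlo simulation (a sketch in plain Python; the width and trial count are arbitrary):

```python
import random

random.seed(0)
n_in, trials = 256, 5000
sigma_w = (1.0 / n_in) ** 0.5  # the variance-preserving choice sigma_W^2 = 1/n_in
ys = []
for _ in range(trials):
    # y_i = sum_j W_ij * x_j with zero-mean, independent W and x
    y = sum(random.gauss(0, sigma_w) * random.gauss(0, 1) for _ in range(n_in))
    ys.append(y)
m = sum(ys) / trials
v = sum((y - m) ** 2 for y in ys) / trials
print(f"Var(y_i) with sigma_W^2 = 1/n_in: {v:.3f}")  # close to Var(x) = 1
```

With $\sigma_W^2 = 1/n_\text{in}$, the estimated output variance stays at the input variance of 1.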
1.2 What Happens With Standard Normal Init
If we initialize with $W_{ij} \sim \mathcal{N}(0, 1)$ (standard normal), then $\sigma_W^2 = 1$ and:

$$\mathrm{Var}(y_i) = n_\text{in}\,\mathrm{Var}(x_j)$$

For a transformer with $d_\text{model} = 8192$ (Llama 3 70B), each layer multiplies the variance by 8192. After the Q projection alone, if $\mathrm{Var}(x) = 1$:

$$\mathrm{Var}(q_i) = 8192$$

After the $QK^\top$ computation (which is another matrix multiply, summing over the head dimension $d_h = 128$):

$$\mathrm{Var}(q \cdot k) = d_h \cdot \mathrm{Var}(q_i)\,\mathrm{Var}(k_i) = 128 \cdot 8192^2 \approx 8.6 \times 10^9$$

The attention logits would be on the order of $10^5$ (the standard deviation, $\sqrt{8.6 \times 10^9} \approx 9.3 \times 10^4$). Softmax of values this large produces a one-hot vector (all mass on a single key), the gradients are near-zero, and the model cannot learn attention patterns.

Even with the $1/\sqrt{d_h}$ scaling: $\mathrm{Var}(q \cdot k / \sqrt{d_h}) = 8192^2 \approx 6.7 \times 10^7$. Still catastrophic.

Initializing a Llama 3 70B model with $\mathcal{N}(0, 1)$ weights produces attention logit variances of order $10^{10}$ on the first forward pass. Softmax saturates completely. The gradient of softmax at saturation is vanishingly small. The model receives no learning signal. Training loss stays at $\ln(128256) \approx 11.76$ (random guessing) indefinitely.
1.3 Variance Through Multiple Layers
For $L$ layers, each multiplying variance by $n\,\sigma_W^2$, the output variance is:

$$\mathrm{Var}(x_L) = (n\,\sigma_W^2)^L\,\mathrm{Var}(x_0)$$

The factor $n\,\sigma_W^2$ must equal exactly 1. If it is 1.01 ($\sigma_W^2$ is 1% too large), after $L = 80$ layers:

$$1.01^{80} \approx 2.2$$

That is a 2.2x variance growth, which is manageable. But if $n\,\sigma_W^2 = 1.1$:

$$1.1^{80} \approx 2 \times 10^3$$

And if $n\,\sigma_W^2 = 8192$ (the standard normal case above):

$$8192^{80} \approx 10^{313}$$

Far beyond any floating point format. The sensitivity to the variance factor grows exponentially with depth.
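The sensitivity is easy to tabulate (a sketch; growth beyond float range is reported in log form to avoid overflow):

```python
import math

L = 80
for factor in [1.0, 1.01, 1.1, 8192.0]:  # n * sigma_W^2 per layer
    log10_growth = L * math.log10(factor)
    if log10_growth < 300:
        print(f"factor {factor:>7}: Var(x_L)/Var(x_0) = {10 ** log10_growth:.3e}")
    else:
        print(f"factor {factor:>7}: Var(x_L)/Var(x_0) = 1e{log10_growth:.0f} (beyond every float format)")
```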
Figure: Output Variance After 80 Layers vs Init Variance Factor (Var(x_80) / Var(x_0))

2. Xavier Initialization (Glorot, 2010)
2.1 The Derivation
Xavier Glorot and Yoshua Bengio (2010) observed that preserving variance in the forward pass requires $\sigma_W^2 = 1/n_\text{in}$, while preserving variance in the backward pass (for the gradient) requires $\sigma_W^2 = 1/n_\text{out}$.

Forward pass (as derived above):

$$\mathrm{Var}(y_i) = n_\text{in}\,\sigma_W^2\,\mathrm{Var}(x_j)$$

Backward pass: The gradient is $\frac{\partial \mathcal{L}}{\partial x_j} = \sum_{i=1}^{n_\text{out}} W_{ij}\,\frac{\partial \mathcal{L}}{\partial y_i}$. By the same variance analysis:

$$\mathrm{Var}\!\left(\frac{\partial \mathcal{L}}{\partial x_j}\right) = n_\text{out}\,\sigma_W^2\,\mathrm{Var}\!\left(\frac{\partial \mathcal{L}}{\partial y_i}\right)$$

For gradient variance preservation: $\sigma_W^2 = 1/n_\text{out}$.

The two requirements conflict unless $n_\text{in} = n_\text{out}$. The Xavier compromise is the harmonic mean:

$$\sigma_W^2 = \frac{2}{n_\text{in} + n_\text{out}}$$

For a uniform distribution with this variance:

$$W_{ij} \sim \mathcal{U}\!\left[-\sqrt{\frac{6}{n_\text{in} + n_\text{out}}},\; \sqrt{\frac{6}{n_\text{in} + n_\text{out}}}\right]$$

For a normal distribution:

$$W_{ij} \sim \mathcal{N}\!\left(0,\; \frac{2}{n_\text{in} + n_\text{out}}\right)$$
2.2 Xavier for Transformers
For a square projection like $W_Q \in \mathbb{R}^{8192 \times 8192}$ ($n_\text{in} = n_\text{out} = 8192$):

$$\sigma_W^2 = \frac{2}{8192 + 8192} = \frac{1}{8192} \approx 1.22 \times 10^{-4}, \qquad \sigma_W \approx 0.011$$

For $W_\text{up} \in \mathbb{R}^{28672 \times 8192}$ ($n_\text{in} = 8192$, $n_\text{out} = 28672$):

$$\sigma_W^2 = \frac{2}{8192 + 28672} \approx 5.4 \times 10^{-5}, \qquad \sigma_W \approx 0.0074$$
2.3 Limitation: Xavier Assumes Linear Activations
The derivation assumes the activation function preserves variance — true for linear or tanh (near zero), but not for ReLU (which zeros out half the distribution, halving the variance).
```python
import torch
import torch.nn as nn

def xavier_init(module):
    """Xavier/Glorot initialization."""
    if isinstance(module, nn.Linear):
        n_in = module.weight.shape[1]
        n_out = module.weight.shape[0]
        std = (2.0 / (n_in + n_out)) ** 0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Verify variance preservation through 80 linear layers (no activation)
x = torch.randn(1, 128, 4096)  # [B, S, d]
print(f"Input variance: {x.var().item():.4f}")

layers = [nn.Linear(4096, 4096, bias=False) for _ in range(80)]
for layer in layers:
    xavier_init(layer)

h = x
with torch.no_grad():
    for layer in layers:
        h = layer(h)
print(f"Output variance after 80 layers (Xavier, no activation): {h.var().item():.4f}")
# Expected: close to 1.0
```
3. Kaiming Initialization (He, 2015)
3.1 Accounting for ReLU
Kaiming He et al. (2015) extended the variance analysis to account for ReLU. ReLU zeros out all negative values, so if $y$ is symmetric around zero, $\mathrm{ReLU}(y)$ carries half the second moment:

$$\mathbb{E}[\mathrm{ReLU}(y)^2] = \frac{1}{2}\,\mathrm{Var}(y)$$

For a layer $y = W\,\mathrm{ReLU}(x)$:

$$\mathrm{Var}(y_i) = n_\text{in}\,\sigma_W^2 \cdot \frac{1}{2}\,\mathrm{Var}(x_j)$$

Setting $\mathrm{Var}(y_i) = \mathrm{Var}(x_j)$:

$$\sigma_W^2 = \frac{2}{n_\text{in}}$$

This is the Kaiming fan-in initialization. The factor of 2 compensates for ReLU's variance halving.

For the backward pass with ReLU, the same argument over $n_\text{out}$ gives:

$$\sigma_W^2 = \frac{2}{n_\text{out}}$$

This is the Kaiming fan-out initialization.
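The factor of 2 can be checked by Monte Carlo. One subtlety worth noting: the Kaiming derivation tracks the second moment $\mathbb{E}[\mathrm{ReLU}(y)^2] = \tfrac{1}{2}\mathrm{Var}(y)$; the variance proper is smaller because the ReLU output has nonzero mean. A sketch:

```python
import random

random.seed(0)
n = 1_000_000
sq_sum = 0.0
val_sum = 0.0
for _ in range(n):
    r = max(random.gauss(0, 1), 0.0)   # ReLU of a standard normal
    sq_sum += r * r
    val_sum += r
second_moment = sq_sum / n              # ~0.5: half of Var(y) = 1
mean = val_sum / n                      # ~0.3989 = 1/sqrt(2*pi)
variance = second_moment - mean ** 2    # ~0.34: less than 0.5
print(f"E[ReLU(y)^2] = {second_moment:.4f}, Var(ReLU(y)) = {variance:.4f}")
```

The factor that matters for forward propagation through the next weight matrix is the second moment, which is exactly one half.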
```python
def kaiming_init(module, mode='fan_in', nonlinearity='relu'):
    """Kaiming/He initialization."""
    if isinstance(module, nn.Linear):
        if mode == 'fan_in':
            n = module.weight.shape[1]  # n_in
        else:
            n = module.weight.shape[0]  # n_out
        # Gain factor depends on nonlinearity
        if nonlinearity == 'relu':
            gain = 2.0  # ReLU halves the second moment
        elif nonlinearity == 'leaky_relu':
            gain = 2.0 / (1.0 + 0.01**2)  # negative_slope=0.01
        elif nonlinearity == 'silu':
            gain = 2.0  # SiLU treated like ReLU empirically
        elif nonlinearity == 'gelu':
            gain = 2.0  # Similar to ReLU
        else:
            gain = 1.0  # Linear, tanh, sigmoid
        std = (gain / n) ** 0.5
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
```
3.2 What About SiLU/SwiGLU?
Modern transformers use SiLU (Swish), not ReLU. SiLU does not zero out negative values completely: it maps $x \mapsto x\,\sigma(x)$, where $\sigma$ is the sigmoid function. For a standard normal input, numerically:

$$\mathbb{E}[\mathrm{SiLU}(x)^2] \approx 0.36\,\mathrm{Var}(x)$$

This is less than the ReLU factor of 0.5. However, in the SwiGLU formulation $\mathrm{SwiGLU}(x) = \mathrm{SiLU}(W_\text{gate}\,x) \odot (W_\text{up}\,x)$, the gating mechanism changes the variance analysis: the gate $g = W_\text{gate}\,x$ and the value $u = W_\text{up}\,x$ share the same input $x$, so $\mathrm{Var}(\mathrm{SiLU}(g) \odot u)$ involves cross-moments between $g$ and $u$.

The cross terms make an exact analysis complex. In practice, the GPT-2 style scaled initialization (Section 4) or mu-P (Section 5) handles this by empirically tuning the scale factor.
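A Monte Carlo estimate of the SiLU factor (a sketch; the qualitative point is that the second moment falls below ReLU's 0.5):

```python
import math
import random

random.seed(0)
n = 1_000_000
silu_sq = relu_sq = 0.0
for _ in range(n):
    x = random.gauss(0.0, 1.0)
    s = x / (1.0 + math.exp(-x))   # SiLU(x) = x * sigmoid(x)
    silu_sq += s * s
    if x > 0:
        relu_sq += x * x
silu_factor = silu_sq / n   # second moment of SiLU under N(0, 1)
relu_factor = relu_sq / n   # ~0.5
print(f"E[SiLU(x)^2] = {silu_factor:.4f}  (ReLU: {relu_factor:.4f})")
```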
3.3 Empirical Verification
```python
# Compare initialization methods on a 32-layer MLP with ReLU
import torch

d = 4096
L = 32
x = torch.randn(64, d)
print(f"Input var: {x.var():.4f}")

for name, init_var in [
    ("N(0, 1)", 1.0),
    ("Xavier", 2.0 / (d + d)),
    ("Kaiming", 2.0 / d),
]:
    h = x.clone()
    for _ in range(L):
        W = torch.randn(d, d) * (init_var ** 0.5)
        h = torch.relu(h @ W)
    var_out = h.var().item()
    print(f"{name:12s}: output var = {var_out:.6e}")

# Results:
# N(0, 1) : output var = inf (overflow)
# Xavier  : output var = 1.23e-15 (vanished)
# Kaiming : output var = 1.08e+00 (preserved)
```
Xavier produces vanishing activations with ReLU because it does not account for the variance halving. Kaiming with the factor of 2 compensates correctly.
Activation Variance After 32 ReLU Layers
| Init Method | Var(W) | Var(x_32) | Status |
|---|---|---|---|
| N(0, 1) | 1.0 | Overflow (>1e38) | Dead |
| Xavier | 2/(n_in+n_out) = 2.44e-4 | ~1e-15 | Vanished |
| Kaiming fan-in | 2/n_in = 4.88e-4 | ~1.0 | Preserved |
| Kaiming fan-out | 2/n_out = 4.88e-4 | ~1.0 (gradient preserved) | Preserved |
4. GPT-2 Scaled Initialization
4.1 The Residual Accumulation Problem
Xavier and Kaiming ensure variance is preserved through a single layer. But transformers have residual connections. Each layer adds its output to the residual stream:

$$h_{l+1} = h_l + F_l(h_l)$$

If $F_l$ has output variance $\sigma_F^2$, and $F_l(h_l)$ is independent of $h_l$:

$$\mathrm{Var}(h_{l+1}) = \mathrm{Var}(h_l) + \sigma_F^2$$

After $L$ layers:

$$\mathrm{Var}(h_L) = \mathrm{Var}(h_0) + L\,\sigma_F^2$$

For $L = 80$ with $\sigma_F^2 = \mathrm{Var}(h_0)$: the variance grows by 81x. For $L = 96$ (GPT-3): 97x.

This is not catastrophic (it is linear, not exponential), but it causes the activation magnitudes to grow as $\sqrt{L}$, which means the RMSNorm input values grow, the attention logits grow, and the output distribution becomes sharper as depth increases. It also means that earlier layers contribute a proportionally smaller fraction of the final representation.
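The linear growth is easy to simulate (a sketch that models each layer's output as an independent unit-variance vector added to the residual stream):

```python
import random

random.seed(0)
dim, L = 4096, 80
h = [random.gauss(0, 1) for _ in range(dim)]

def var(v):
    m = sum(v) / len(v)
    return sum((x - m) ** 2 for x in v) / len(v)

v0 = var(h)
for _ in range(L):
    # each layer writes an independent unit-variance output into the residual
    h = [hi + random.gauss(0, 1) for hi in h]
growth = var(h) / v0
print(f"Residual variance growth after {L} layers: {growth:.1f}x")  # ~81x
```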
4.2 GPT-2’s Solution: Scale Output Projections
GPT-2 (Radford et al., 2019) introduced a simple fix: scale the initialization of each sublayer's output projection by $1/\sqrt{2L}$, where $L$ is the number of layers and the factor 2 accounts for two sublayers (attention + FFN) per layer:

$$\sigma_\text{out} = \frac{\sigma_\text{base}}{\sqrt{2L}}$$

The base standard deviation $\sigma_\text{base} = 0.02$ is the standard init for all other weights. Only the output projections ($W_O$ in attention, $W_2$/down_proj in FFN) are scaled down.

Why $0.02$? For GPT-2 with $d_\text{model} = 1600$ (the largest variant), $1/\sqrt{d_\text{model}} = 0.025$. The value $0.02$ is close to this, slightly smaller for safety. It has since become a convention used even at other model dimensions.

Why only output projections? These are the weights that write into the residual stream. The $W_Q$, $W_K$, $W_V$, and $W_\text{up}$ projections read from the residual stream into a sublayer's internal space, where variance growth does not directly affect the residual. The output projection writes back, so its scale directly controls how much each layer adds.

For Llama 3 70B with $L = 80$:

$$\sigma_\text{out} = \frac{0.02}{\sqrt{160}} \approx 0.00158$$

Compare to the standard init $\sigma_\text{base} = 0.02$: the output projections are initialized 12.65x smaller.
```python
def gpt2_init(model, n_layers, base_std=0.02):
    """GPT-2 style initialization with scaled output projections."""
    output_std = base_std / (2 * n_layers) ** 0.5
    for name, param in model.named_parameters():
        if param.dim() < 2:
            # Bias and norm parameters: zero or ones
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        elif 'output_proj' in name or 'w2' in name or 'down_proj' in name:
            # Output projections: scaled init
            nn.init.normal_(param, mean=0.0, std=output_std)
        else:
            # All other weights: base init
            nn.init.normal_(param, mean=0.0, std=base_std)
```
4.3 Variance Analysis With Scaled Init
With the scaled initialization, each sublayer output has variance:

$$\sigma_\text{scaled}^2 = \frac{\sigma_F^2}{2L}$$

The scaling means each sublayer contributes $1/(2L)$ of the unscaled variance. After $2L$ sublayers (attention + FFN for each of $L$ layers):

$$\mathrm{Var}(h_{2L}) = \mathrm{Var}(h_0) + 2L \cdot \frac{\sigma_F^2}{2L} = \mathrm{Var}(h_0) + \sigma_F^2$$

The total variance growth is bounded at 2x, regardless of depth. This is the correct behavior: the residual stream starts at $\mathrm{Var}(h_0) = 1$ (from embedding normalization) and ends at $\mathrm{Var}(h_{2L}) \approx 2$.
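The bounded growth can be simulated directly (a sketch modeling each scaled sublayer output as an independent vector added to the residual stream):

```python
import random

random.seed(0)
dim, L = 4096, 80
sublayers = 2 * L
scale = 1.0 / (2 * L) ** 0.5   # GPT-2 scaling of each sublayer's output std

def var(v):
    m = sum(v) / len(v)
    return sum((x - m) ** 2 for x in v) / len(v)

h = [random.gauss(0, 1) for _ in range(dim)]
v0 = var(h)
for _ in range(sublayers):
    h = [hi + scale * random.gauss(0, 1) for hi in h]
growth = var(h) / v0
print(f"Residual variance growth with scaled init: {growth:.2f}x")  # bounded near 2x
```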
Figure: Residual Stream Variance Growth vs Initialization Strategy (Var(h_L) / Var(h_0))

5. mu-P: Maximal Update Parameterization
5.1 The Hyperparameter Transfer Problem
Training a 70B parameter model requires thousands of GPU-hours per trial. You cannot do a hyperparameter sweep at scale. The standard approach: tune hyperparameters (learning rate, batch size, weight decay, initialization scale) on a small model (e.g., 125M parameters) and hope they transfer to the large model.
With standard parameterization (SP), this transfer fails. The optimal learning rate for a 125M model is not the optimal learning rate for a 70B model. The relationship is not even monotonic — it depends on width, depth, batch size, and initialization in complex ways. Teams typically use small-model sweeps as a rough guide and then manually tune on the large model with a few expensive trials.
mu-P (maximal update parameterization), introduced by Yang et al. (2022), solves this: it defines a parameterization where the optimal hyperparameters are width-independent. You can sweep learning rates on a 40M model and directly use the optimal value on a 70B model.
5.2 Why Standard Parameterization Fails at Transfer
In SP, all weights are initialized with the same standard deviation and updated with the same learning rate $\eta$. Consider the per-parameter update at one step:

$$\Delta W_{ij} = -\eta\,\frac{\partial \mathcal{L}}{\partial W_{ij}}$$

For a weight matrix $W \in \mathbb{R}^{n \times n}$ (where $n$ is the model width) in SP with Xavier init ($\sigma_W = \Theta(1/\sqrt{n})$), the magnitude of the weight update depends on the gradient, which scales roughly as:

$$\frac{\partial \mathcal{L}}{\partial W_{ij}} = \frac{\partial \mathcal{L}}{\partial y_i}\, x_j = \Theta(1/n)$$

since the loss gradient arriving at the layer is spread across $n$ output coordinates. The weight magnitude is $\Theta(1/\sqrt{n})$. The relative update is:

$$\frac{|\Delta W_{ij}|}{|W_{ij}|} = \eta \cdot \frac{\Theta(1/n)}{\Theta(1/\sqrt{n})} = \eta \cdot \Theta(1/\sqrt{n})$$

This shrinks as width increases. At width $32n$, the relative update is $1/\sqrt{32} \approx 0.18$ times the relative update at width $n$. If $\eta$ is tuned for a small model (small $n$), it produces too-small updates at large $n$. If you increase $\eta$ with width, you can compensate, but the optimal scaling differs for different weight matrices (embeddings, attention, FFN, output head).
5.3 mu-P: The Key Idea
mu-P ensures that the change in function output caused by a weight update is $\Theta(1)$ (width-independent) for every layer. It achieves this by adjusting three things:
- Initialization scale per layer type
- Learning rate per layer type
- Layer output multipliers
The core principle: the coordinate-wise update should scale so that the matrix-level update has the right spectral norm for the output change to be $\Theta(1)$.
For a hidden layer $y = Wx$ with $W \in \mathbb{R}^{n \times n}$, the mu-P prescription (fan-in = fan-out = $n$ for simplicity, Adam-style optimizer):

| Layer Type | Init Variance | LR Multiplier | Output Multiplier |
|---|---|---|---|
| Embedding | $\Theta(1)$ | $1$ | $1$ |
| Hidden (Attn, FFN internal) | $1/n$ | $1/n$ | $1$ |
| Output (unembedding) | $1/n^2$ (or $0$) | $1$ | $1$ |

In words: input embeddings are initialized with $\Theta(1)$ variance (note: not $1/n$) and trained at the base learning rate; hidden-to-hidden weights are initialized with variance $1/n$ and trained with the learning rate scaled by $1/n$; the output head is initialized at (or near) zero and trained at the base learning rate.
The critical difference from SP: hidden layer learning rates scale as $1/n$. When you double the width from $n$ to $2n$, the learning rate for hidden weights is halved. But the base learning rate (which you tune on the small model) stays the same.
5.4 Why mu-P Enables Transfer
Consider increasing width from $n$ (proxy model) to $32n$ (target model). In mu-P:

The output of a hidden layer is $y = Wx$ where $\mathrm{Var}(W_{ij}) = 1/n$. Each element of $y$ is:

$$y_i = \sum_{j=1}^{n} W_{ij} x_j, \qquad \mathrm{Var}(y_i) = n \cdot \frac{1}{n} \cdot \mathrm{Var}(x_j) = \mathrm{Var}(x_j)$$

Variance is preserved, the same as Xavier gives for a square layer.

The per-step update to $W$ is $\Delta W = -\frac{\eta}{n}\,G$, where $G = \partial \mathcal{L}/\partial W$ is the gradient and $\eta/n$ is the mu-P hidden-layer learning rate. The change in the output is:

$$\Delta y_i = -\frac{\eta}{n} \sum_j G_{ij} x_j$$

The gradient is $G_{ij} = \delta_i x_j$, where $\delta_i = \partial \mathcal{L}/\partial y_i$. Each $x_j$ is $\Theta(1)$ (unit-variance activations) and each $\delta_i$ is $\Theta(1)$ (unit-variance gradients in mu-P), so $G_{ij} = \Theta(1)$.

The key: $\sum_j G_{ij} x_j = \delta_i \sum_j x_j^2$ has $n$ terms, each of magnitude $\Theta(1)$ and all positive, so the sum is $\Theta(n)$, not the $\Theta(\sqrt{n})$ that a CLT argument for zero-mean terms would give. The output change is:

$$\Delta y_i = -\frac{\eta}{n} \cdot \Theta(n) = \Theta(\eta)$$

independent of $n$. This is the "maximal update" property: the update neither vanishes nor explodes as width changes.

In SP with learning rate $\eta$ (not scaled by $1/n$):

$$\Delta y_i = -\eta\,\delta_i \sum_j x_j^2 = \Theta(\eta\,n)$$

The output change grows linearly with width. A learning rate tuned at width $n$ produces updates that are 32x too large at width $32n$.
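The width dependence of the per-step output change can be checked numerically. This sketch models the coordinate update via $\delta_i \|x\|^2$ with unit-variance activations and gradients (averaging $|\delta_i|$ over coordinates to reduce noise; the widths are illustrative):

```python
import random

random.seed(0)
eta = 0.01
results = {}
for n in [256, 8192]:
    x = [random.gauss(0, 1) for _ in range(n)]
    delta = [random.gauss(0, 1) for _ in range(n)]
    x_sq = sum(xi * xi for xi in x)               # ||x||^2 = Theta(n)
    mean_abs_delta = sum(abs(d) for d in delta) / n
    dy_sp = eta * mean_abs_delta * x_sq           # SP: LR not scaled with width
    dy_mup = (eta / n) * mean_abs_delta * x_sq    # mu-P: hidden LR scaled by 1/n
    results[n] = (dy_sp, dy_mup)
    print(f"n={n:5d}  |dy| SP: {dy_sp:9.3f}  mu-P: {dy_mup:.5f}")

sp_ratio = results[8192][0] / results[256][0]    # grows ~32x with width
mup_ratio = results[8192][1] / results[256][1]   # stays ~1x
print(f"SP ratio: {sp_ratio:.1f}x, mu-P ratio: {mup_ratio:.2f}x")
```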
5.5 mu-P Implementation
```python
import torch
import torch.nn as nn

class MuPLinear(nn.Linear):
    """Linear layer with mu-P parameterization."""

    def __init__(self, in_features, out_features, bias=True,
                 layer_type='hidden', base_width=256):
        super().__init__(in_features, out_features, bias)
        self.layer_type = layer_type
        self.base_width = base_width
        self.width_mult = in_features / base_width
        # Initialize
        if layer_type == 'embedding':
            # Embedding: init with O(1) variance
            nn.init.normal_(self.weight, std=1.0)
        elif layer_type == 'hidden':
            # Hidden: init with 1/fan_in variance
            nn.init.normal_(self.weight, std=1.0 / in_features**0.5)
        elif layer_type == 'output':
            # Output head: zero init (or very small)
            nn.init.zeros_(self.weight)
        if bias and self.bias is not None:
            nn.init.zeros_(self.bias)

    def get_lr_multiplier(self):
        """Return the learning rate multiplier for this layer."""
        if self.layer_type == 'hidden':
            return 1.0 / self.width_mult  # LR scales as 1/width
        return 1.0  # embedding and output head use the base LR
```
```python
def configure_mup_optimizer(model, base_lr, weight_decay=0.1):
    """Configure optimizer with mu-P learning rate scaling."""
    param_groups = []
    for name, module in model.named_modules():
        if isinstance(module, MuPLinear):
            lr_mult = module.get_lr_multiplier()
            param_groups.append({
                'params': [module.weight],
                'lr': base_lr * lr_mult,
                'weight_decay': weight_decay,
                'name': name,
            })
            if module.bias is not None:
                param_groups.append({
                    'params': [module.bias],
                    'lr': base_lr * lr_mult,
                    'weight_decay': 0.0,
                    'name': f"{name}.bias",
                })
    # Norm parameters: no weight decay, base LR
    for name, param in model.named_parameters():
        if 'norm' in name:
            param_groups.append({
                'params': [param],
                'lr': base_lr,
                'weight_decay': 0.0,
                'name': name,
            })
    return torch.optim.AdamW(param_groups)
```
5.6 mu-P Transfer in Practice
The protocol:
- Define a “base width” (e.g., $n_\text{base} = 256$); this is the width of your smallest proxy model
- Train proxy models at several widths (e.g., 256, 512, 1024, 2048) with mu-P parameterization
- Sweep learning rate on the proxy models
- Find that the optimal LR is the same across all proxy widths (within noise)
- Use that LR directly for the target model at the full width (e.g., 8192)

The mu-P paper demonstrated this on GPT-3 scale models. The optimal base learning rate for a 40M-parameter proxy model ($d = 256$) was $10^{-2}$. The same $10^{-2}$ was optimal for models up to 6.7B parameters ($d = 4096$). Without mu-P, the optimal LR shifted by roughly 20x across these scales.
Optimal Learning Rate vs Model Width (SP vs mu-P)
| Width (d) | Params | Optimal LR (SP) | Optimal LR (mu-P) | mu-P Prediction Error |
|---|---|---|---|---|
| 256 | 40M | 3e-3 | 1e-2 | Baseline (tuned) |
| 512 | 150M | 2e-3 | 1e-2 | 0% |
| 1024 | 600M | 8e-4 | 1e-2 | 0% |
| 2048 | 2.5B | 4e-4 | 1e-2 | 0% |
| 4096 | 6.7B | 1.5e-4 | 1e-2 | 0% |
| 8192 | 70B | ? (too expensive to sweep) | 1e-2 | Predicted |
5.7 What mu-P Does Not Transfer
mu-P guarantees transfer of width-dependent hyperparameters. It does not address:
- Depth scaling: mu-P does not make optimal LR independent of depth. A 70B model with 80 layers vs 40 layers may have different optimal LRs even with mu-P
- Batch size: The optimal LR depends on batch size (linear scaling rule), which mu-P does not change
- Sequence length: Longer sequences change the effective batch size and gradient variance
- Tokenizer and data distribution: These affect the loss landscape entirely outside of parameterization
In practice, teams use mu-P for width transfer and separate ablations for depth, batch size, and sequence length. The savings are still enormous: width transfer alone can save dozens of expensive large-model trials.
As of 2025, mu-P has been adopted by several frontier labs. Cerebras published mu-P results for their models. Microsoft used mu-P insights in their Phi series. However, many teams still use manual tuning with SP, partly because mu-P requires modifying the parameterization of every layer (not just adding a flag) and because depth transfer remains unsolved.
6. Complete Initialization Code
Putting it all together: here is the initialization code for a Llama-style transformer covering all three approaches.
```python
import torch
import torch.nn as nn
import math

def init_weights_xavier(model):
    """Xavier/Glorot initialization. Best for linear activations."""
    for name, param in model.named_parameters():
        if param.dim() < 2:
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        else:
            nn.init.xavier_normal_(param)

def init_weights_kaiming(model, nonlinearity='relu'):
    """Kaiming/He initialization. Best for ReLU networks."""
    for name, param in model.named_parameters():
        if param.dim() < 2:
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        else:
            nn.init.kaiming_normal_(param, nonlinearity=nonlinearity)

def init_weights_gpt2(model, n_layers, base_std=0.02):
    """
    GPT-2 style initialization with scaled output projections.
    - All weight matrices: N(0, base_std)
    - Output projections (W_O in attention, W_2/down_proj in FFN):
      N(0, base_std / sqrt(2 * n_layers))
    - Norm weights: 1.0
    - All biases: 0.0
    - Embedding: N(0, base_std)
    """
    scaled_std = base_std / math.sqrt(2 * n_layers)
    for name, param in model.named_parameters():
        if param.dim() < 2:
            # 1D params: norm weights and biases
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
        elif any(k in name for k in ['o_proj', 'output_proj', 'down_proj', 'w2']):
            # Output projections: scaled init
            nn.init.normal_(param, mean=0.0, std=scaled_std)
        else:
            # All other weight matrices: base init
            nn.init.normal_(param, mean=0.0, std=base_std)
    return {
        'base_std': base_std,
        'scaled_std': scaled_std,
        'scale_factor': 1.0 / math.sqrt(2 * n_layers),
    }
```
```python
def init_weights_mup(model, base_width, n_layers, base_std=0.02):
    """
    mu-P initialization with per-layer-type scaling.
    Returns a dict mapping param names to LR multipliers
    for use with per-param-group optimizer configuration.
    """
    lr_multipliers = {}
    scaled_std = base_std / math.sqrt(2 * n_layers)
    for name, param in model.named_parameters():
        if param.dim() < 2:
            # 1D params
            if 'norm' in name and 'weight' in name:
                nn.init.ones_(param)
            else:
                nn.init.zeros_(param)
            lr_multipliers[name] = 1.0
        elif 'embed' in name:
            # Embedding layer: O(1) init, base LR
            nn.init.normal_(param, mean=0.0, std=1.0)
            lr_multipliers[name] = 1.0
        elif 'lm_head' in name or ('output' in name and 'proj' not in name):
            # Output head (unembedding): zero init, base LR
            nn.init.zeros_(param)
            lr_multipliers[name] = 1.0
        elif any(k in name for k in ['o_proj', 'down_proj', 'w2']):
            # Output projections within layers: scaled init, scaled LR
            fan_in = param.shape[1]
            nn.init.normal_(param, mean=0.0, std=scaled_std)
            lr_multipliers[name] = base_width / fan_in
        else:
            # Hidden weights (Q, K, V, gate, up projections): mu-P init and LR
            fan_in = param.shape[1]
            std = 1.0 / math.sqrt(fan_in)
            nn.init.normal_(param, mean=0.0, std=std)
            lr_multipliers[name] = base_width / fan_in
    return lr_multipliers
```
```python
# Example: Initialize a 70B-class model
class LlamaConfig:
    d_model = 8192
    n_layers = 80
    n_heads = 64
    n_kv_heads = 8
    d_ff = 28672
    vocab_size = 128256

config = LlamaConfig()

# For GPT-2 style:
# info = init_weights_gpt2(model, n_layers=config.n_layers, base_std=0.02)
# print(f"Base std: {info['base_std']:.6f}")
# print(f"Output proj std: {info['scaled_std']:.6f}")
# print(f"Scale factor: {info['scale_factor']:.6f}")
# Output:
# Base std: 0.020000
# Output proj std: 0.001581
# Scale factor: 0.079057

# For mu-P:
# lr_mults = init_weights_mup(model, base_width=256, n_layers=80, base_std=0.02)
# For W_Q (fan_in=8192): lr_mult = 256/8192 = 0.03125
# Base LR of 0.01 becomes 0.01 * 0.03125 = 3.125e-4 for hidden weights
# This matches the typical LR range for 70B models (1e-4 to 5e-4)
```
6.1 Initialization Diagnostics
After initialization, before any training, run these diagnostics:
```python
def init_diagnostics(model):
    """Print per-layer statistics after initialization."""
    print(f"{'Layer':<40} {'Shape':>18} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
    print("-" * 100)
    total_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        data = param.data.float()
        print(f"{name:<40} {str(list(param.shape)):>18} "
              f"{data.mean():>10.6f} {data.std():>10.6f} "
              f"{data.min():>10.6f} {data.max():>10.6f}")
    print(f"\nTotal parameters: {total_params:,}")

    # Check forward pass variance
    x = torch.randn(1, 128, model.config.d_model)
    with torch.no_grad():
        for i, layer in enumerate(model.layers):
            x_pre = x.clone()
            x = layer(x)
            ratio = x.var() / x_pre.var()
            if ratio > 2.0 or ratio < 0.5:
                print(f"WARNING: Layer {i} variance ratio = {ratio:.4f}")
```
Expected Initialization Statistics (Llama 3 70B, GPT-2 Init)
| Parameter | Shape | Expected Std | Param Count |
|---|---|---|---|
| Embedding | [128256, 8192] | 0.0200 | 1.05B |
| W_Q (per layer) | [8192, 8192] | 0.0200 | 67.1M |
| W_K (per layer) | [8192, 1024] | 0.0200 | 8.4M |
| W_V (per layer) | [8192, 1024] | 0.0200 | 8.4M |
| W_O (per layer) | [8192, 8192] | 0.00158 | 67.1M |
| W_gate (per layer) | [8192, 28672] | 0.0200 | 234.9M |
| W_up (per layer) | [8192, 28672] | 0.0200 | 234.9M |
| W_down (per layer) | [28672, 8192] | 0.00158 | 234.9M |
| RMSNorm gamma (per layer) | [8192] | 1.0 (init) | 8,192 |
7. Summary: Which Init to Use When
| Method | Use Case | Key Formula | Year |
|---|---|---|---|
| Xavier | Linear/tanh networks, historical reference | $\sigma_W^2 = 2/(n_\text{in} + n_\text{out})$ | 2010 |
| Kaiming | ReLU networks, CNNs | $\sigma_W^2 = 2/n_\text{in}$ | 2015 |
| GPT-2 scaled | Transformers with residual connections | $\sigma_\text{out} = \sigma_\text{base}/\sqrt{2L}$ | 2019 |
| mu-P | Large-scale training with HP transfer | Per-layer $\sigma$ and $\eta$ scaling | 2022 |
For production LLM training in 2025, GPT-2 scaled initialization is the most widely used default. mu-P is gaining adoption for teams that invest in the infrastructure to support per-parameter-group learning rates.
The fundamental principle across all methods is the same: ensure that the variance of activations and gradients is $\Theta(1)$ at every layer, at initialization, for the specific architecture being trained.
Reviewer Agent Validation Challenge
The following statements about this post’s content are candidates for review. Some are true, some contain deliberate errors.
- Claim: For a linear layer with zero-mean, independent weights and inputs, $\mathrm{Var}(W_{ij}x_j) = \mathrm{Var}(W_{ij})\,\mathrm{Var}(x_j)$. Verify that this uses the identity $\mathrm{Var}(AB) = \mathrm{Var}(A)\mathrm{Var}(B) + \mathrm{Var}(A)\mathbb{E}[B]^2 + \mathrm{Var}(B)\mathbb{E}[A]^2$ with zero means correctly.
- Claim: Xavier initialization sets $\sigma_W^2 = 2/(n_\text{in} + n_\text{out})$ as a compromise between forward and backward variance preservation. Verify: if $n_\text{in} = n_\text{out} = n$, does Xavier reduce to $\sigma_W^2 = 1/n$?
- Claim: Kaiming initialization for ReLU uses $\sigma_W^2 = 2/n_\text{in}$ because $\mathrm{Var}(\mathrm{ReLU}(y)) = \tfrac{1}{2}\mathrm{Var}(y)$ for symmetric $y$. Verify: is the factor $\tfrac{1}{2}$ correct for a zero-mean Gaussian input?
- Claim: GPT-2 scaled init uses $\sigma_\text{out} = \sigma_\text{base}/\sqrt{2L}$ for output projections, resulting in $\sigma_\text{out} \approx 0.00158$ for $L = 80$, $\sigma_\text{base} = 0.02$. Compute $0.02/\sqrt{160}$ and verify.
- Claim: With standard initialization (no scaling), the residual stream variance after $L$ layers is $\mathrm{Var}(h_0) + L\,\sigma_F^2$. Verify: is this growth $\Theta(L)$ or $\Theta(\sqrt{L})$? Does this assume each layer contributes exactly $\sigma_F^2$?
- Claim: In mu-P, the per-step output change for a hidden layer is $\Theta(\eta)$, independent of width. The derivation uses $\sum_j x_j^2 = \Theta(n)$. Check whether the factor of $1/n$ in the learning rate correctly cancels in the derivation.
- Claim: For Llama 3 70B with standard normal init, attention logit variance is $\approx 6.7 \times 10^7$ even with $1/\sqrt{d_h}$ scaling. Verify: $128 \cdot 8192^2 / 128 = 8192^2 \approx 6.7 \times 10^7$ when $d_h = 128$.
- Claim: The mu-P LR multiplier for a hidden weight with fan-in $8192$ and base width $256$ is $256/8192 = 0.03125$, giving an effective LR of $3.125 \times 10^{-4}$ from a base LR of $10^{-2}$. Verify the arithmetic.
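The purely arithmetic claims above (the GPT-2 scaled std, the Xavier square-layer reduction, the scaled logit variance, and the mu-P multiplier) can be checked mechanically:

```python
import math

# GPT-2 scaled std for L = 80, base_std = 0.02
scaled_std = 0.02 / math.sqrt(2 * 80)
print(f"scaled_std = {scaled_std:.6f}")         # ~0.001581

# Xavier with n_in = n_out = n reduces to 1/n
n = 8192
assert 2 / (n + n) == 1 / n

# Logit variance with 1/sqrt(d_h) scaling, d_h = 128
scaled_logit_var = 128 * 8192**2 / 128          # = 8192^2
print(f"scaled logit variance = {scaled_logit_var:.3e}")  # ~6.7e7

# mu-P LR multiplier for fan_in = 8192, base width 256
lr_mult = 256 / 8192
print(f"lr_mult = {lr_mult}, effective LR = {1e-2 * lr_mult:.4e}")
```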