Part 7 of 23 in the Transformer Anatomy series

In the previous parts of this series, we have built up the transformer piece by piece: the attention mechanism, tokenization, embeddings, positional encoding, softmax numerics, and attention variants. Every one of those components assumes that the activations flowing through the network are well-behaved — that they occupy a reasonable numerical range, that gradients neither vanish nor explode, and that each layer receives inputs from a roughly stable distribution. Without normalization, none of those assumptions hold. A 96-layer transformer trained without any normalization will diverge within the first few hundred steps. The loss will spike to infinity, the gradients will overflow, and the run will be wasted.

Normalization is the component that makes deep transformers trainable at all. It is also the component where the choice of variant — LayerNorm vs RMSNorm, Pre-Norm vs Post-Norm, with or without QK-Norm — has quietly shaped the entire landscape of modern LLM architectures. This post explains why normalization exists, how each variant works from first principles, what specific training stability problem each one solves, and what the inference cost implications are.


1. Why Normalize at All?

The Activation Explosion Problem

Consider a deep network as a composition of functions: $f_L \circ f_{L-1} \circ \cdots \circ f_1(x)$. Each layer $f_l$ applies a linear transformation followed by a nonlinearity. If the linear transformation has weights whose spectral norm is slightly greater than 1 — even by a small amount like 1.01 — the activations grow exponentially through depth:

$$\|a_L\| \approx \|a_0\| \cdot (1.01)^L$$

For $L = 96$ layers (the depth of GPT-3), that factor is $(1.01)^{96} \approx 2.6$, which is manageable. But if the effective multiplier is 1.1, we get $(1.1)^{96} \approx 9{,}412$. At 1.5, we get $(1.5)^{96} \approx 8.0 \times 10^{16}$ — far beyond the representable range of FP16 (max $\approx 65{,}504$). BF16 can still represent such values (max $\approx 3.4 \times 10^{38}$), but precision degrades catastrophically long before overflow.
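This growth is easy to reproduce. A minimal sketch, using orthogonal weight matrices scaled by a gain factor so that each layer multiplies the activation norm by exactly that gain (dimensions and seed are arbitrary):

```python
import torch

torch.manual_seed(0)
d, L = 256, 96
x = torch.randn(d, dtype=torch.float64)

for gain in (1.01, 1.1):
    a = x.clone()
    for _ in range(L):
        # Random orthogonal matrix scaled by `gain`: spectral norm == gain,
        # so the norm grows by exactly `gain` per layer.
        Q, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))
        a = (gain * Q) @ a
    # gain 1.01 -> ~2.6, gain 1.1 -> ~9,412, matching (gain)^96
    print(f"gain {gain}: ||a_L|| / ||a_0|| = {(a.norm() / x.norm()).item():,.1f}")
```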

The same problem applies in reverse for the backward pass. Gradients flowing backward through the network are multiplied by the transpose of each layer’s Jacobian. If activations grow forward, gradients shrink backward (vanishing gradients), and vice versa (exploding gradients). Either failure mode makes optimization impossible.

Internal Covariate Shift

Beyond the magnitude problem, there is a distributional problem. During training, the parameters of layer $l$ are updated by gradient descent. This changes the function $f_l$, which changes the distribution of inputs to layer $l+1$. Layer $l+1$ must now adapt to a shifted input distribution, but its own update simultaneously shifts the input to layer $l+2$. This cascading distributional shift — termed internal covariate shift by Ioffe and Szegedy (2015) — forces each layer to chase a moving target, slowing convergence and requiring very conservative (small) learning rates.

ℹ️ The Mental Model

Think of normalization as a “reset button” applied at every layer. Before each sublayer processes its input, normalization rescales the activations to have a consistent mean and variance. This ensures that (1) the magnitude stays bounded regardless of depth, and (2) the input distribution to each layer remains approximately stationary even as earlier layers update their weights. The result: you can use larger learning rates, train deeper networks, and converge faster.

What Normalization Must Provide

A useful normalization scheme for deep transformers must satisfy three properties:

  1. Bounded activations: outputs of the normalization should have controlled magnitude (typically zero mean, unit variance) regardless of the input scale.
  2. Per-example independence: the normalization of one example must not depend on other examples in the batch, because batch sizes vary between training and inference, and autoregressive generation often uses batch size 1.
  3. Differentiability: the normalization must be smooth enough to permit stable gradient computation.

These requirements immediately rule out certain approaches (notably BatchNorm, as we will see) and explain why LayerNorm became the default for transformers.

Σ Definition: Normalization as Reparameterization

Given activations $a \in \mathbb{R}^d$, a normalization function $\text{Norm}(a)$ reparameterizes the activations such that $\text{Norm}(a)$ has approximately zero mean and unit variance across the normalized dimensions. The key design choice is which dimensions to normalize over: the batch dimension, the feature dimension, the spatial dimensions, or some combination.


2. BatchNorm: Why It Works for CNNs, Fails for Transformers

How BatchNorm Works

Batch Normalization (Ioffe and Szegedy, 2015) was the first widely adopted normalization technique, and it powered the deep learning revolution in computer vision. For a given feature channel $k$, BatchNorm computes the mean and variance across the batch dimension:

$$\mu_k = \frac{1}{B} \sum_{i=1}^{B} a_{i,k}, \quad \sigma_k^2 = \frac{1}{B} \sum_{i=1}^{B} (a_{i,k} - \mu_k)^2$$
$$\text{BN}(a_{i,k}) = \gamma_k \cdot \frac{a_{i,k} - \mu_k}{\sqrt{\sigma_k^2 + \epsilon}} + \beta_k$$

where $\gamma_k$ and $\beta_k$ are learnable affine parameters and $\epsilon$ is a small constant for numerical stability.
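The formulas above can be checked directly against PyTorch's built-in module. A small sketch (shapes arbitrary; a freshly constructed `nn.BatchNorm1d` is in training mode and uses exactly these batch statistics):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C = 8, 4                          # batch of 8 examples, 4 feature channels
a = torch.randn(B, C)
gamma, beta, eps = torch.ones(C), torch.zeros(C), 1e-5

mu = a.mean(dim=0)                   # per-channel mean over the batch
var = a.var(dim=0, unbiased=False)   # per-channel (biased) variance
bn = gamma * (a - mu) / torch.sqrt(var + eps) + beta

# Matches nn.BatchNorm1d in training mode (default gamma=1, beta=0)
ref = nn.BatchNorm1d(C, eps=eps)(a)
print(torch.allclose(bn, ref, atol=1e-5))  # True
```

Note that `bn` for example 0 changes if any other row of `a` changes — the batch-coupling problem discussed below.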

Why BatchNorm Dominates in CNNs

In convolutional networks, BatchNorm works exceptionally well for several reasons:

  • Fixed spatial structure: every image in a batch has the same height and width, so the batch statistics are always computed over the same number of elements.
  • Large batches: vision training typically uses batch sizes of 256 or more, providing robust statistics.
  • Channel-wise normalization: each convolutional channel captures a single feature map, and normalizing per-channel aligns well with the inductive bias of CNNs.

BatchNorm introduced an effective regularization effect through the noise in mini-batch statistics, often reducing or eliminating the need for dropout. It enabled training of very deep networks (ResNet-152 and beyond) that were previously intractable.

Why BatchNorm Fails for Transformers

Transformers violate every assumption that makes BatchNorm effective:

Variable sequence lengths. In a batch of text sequences, different examples have different lengths. Position 500 might be a valid token in one example and a padding token in another. Computing batch statistics across these positions mixes meaningful activations with padding, corrupting the statistics. Masking helps but adds complexity and still degrades the statistical quality when sequences have highly variable lengths.

Small batches at inference. During autoregressive generation, the effective batch size is often 1 (a single user request). BatchNorm requires maintaining running statistics from training, and these running averages may not match the test-time distribution well — a persistent source of train-test mismatch that plagues BatchNorm-based models in low-batch regimes.

Token-by-token generation. During the decode phase, each forward pass processes a single new token. There is no batch of tokens over which to compute meaningful statistics. The running statistics from training become the sole normalization reference, but they were computed over full sequences and may not apply to individual decode steps.

Coupling between examples. BatchNorm creates a dependency between examples in a batch: the normalization of example $i$ depends on all other examples through the batch statistics. This is acceptable in vision (images are independent), but in language modeling it means the output for one sequence changes depending on what other sequences happen to be in the same batch — a property that violates the principle of independent inference.

⚠️ The BatchNorm-Transformer Mismatch

Several early papers attempted to adapt BatchNorm for transformers (e.g., PowerNorm, Shen et al. 2020). While some tricks helped, the fundamental mismatch remains: BatchNorm computes statistics across the batch dimension, but transformer inference needs per-example normalization. The community converged on LayerNorm for good reason.


3. LayerNorm from First Principles

The Core Idea

Layer Normalization (Ba, Kiros, and Hinton, 2016) takes a fundamentally different approach from BatchNorm: instead of normalizing across the batch, it normalizes across the feature dimension. For each individual token in each individual example, LayerNorm computes the mean and variance across the $d_\text{model}$ features:

$$\mu = \frac{1}{d} \sum_{i=1}^{d} x_i, \quad \sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2$$
$$\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Here $x \in \mathbb{R}^d$ is the activation vector for a single token, $\gamma, \beta \in \mathbb{R}^d$ are learnable scale and shift parameters (element-wise), and $\epsilon$ is a small constant (typically $10^{-5}$ or $10^{-6}$).

Tensor Shape Walkthrough

Understanding exactly which dimensions are involved is critical for implementation. In a transformer, the activations after any sublayer have shape [batch, seq_len, d_model]. LayerNorm operates on the last dimension:

import torch
import torch.nn as nn

# Input shape: [batch=4, seq_len=128, d_model=4096]
# (LLMs typically run in bfloat16; float32 here keeps the example portable)
x = torch.randn(4, 128, 4096)

# LayerNorm normalizes over d_model (last dim)
ln = nn.LayerNorm(4096)  # normalized_shape = (4096,)

# For EACH of the 4*128 = 512 tokens independently:
#   1. Compute mean across 4096 features  -> scalar
#   2. Compute variance across 4096 features -> scalar
#   3. Normalize: (x - mean) / sqrt(var + eps)
#   4. Scale and shift: gamma * normalized + beta
y = ln(x)  # Output shape: [4, 128, 4096] -- same as input

The key property: every token is normalized independently. Token 0 in batch element 0 is normalized using only its own 4096-dimensional feature vector. It has zero dependence on any other token in the sequence or any other example in the batch. This makes LayerNorm perfectly compatible with variable sequence lengths, arbitrary batch sizes, and token-by-token autoregressive generation.

The Learnable Affine Parameters

After normalization, the activations have approximately zero mean and unit variance across the feature dimension. But this constraint might be too rigid — the network might need certain features to have larger magnitudes or non-zero means. The learnable parameters $\gamma$ (scale) and $\beta$ (shift) restore the network’s ability to represent any affine transformation of the normalized activations.

In principle, the network could learn $\gamma = \sigma_\text{original}$ and $\beta = \mu_\text{original}$, recovering the un-normalized activations exactly. In practice, the normalization step acts as a beneficial regularizer, and the learned $\gamma$ and $\beta$ converge to values that improve both training stability and final quality.

Computational Cost

LayerNorm requires two passes over the feature dimension per token:

  1. First pass: compute $\mu$ (one reduction over $d$ elements)
  2. Second pass: compute $\sigma^2$ (one reduction over $d$ elements, using $\mu$)
  3. Normalize: one element-wise operation ($d$ multiplications and additions)

Total: approximately $5d$ FLOPs per token (two reductions of $d$, one element-wise multiply, one element-wise subtract, one element-wise add). For $d = 4096$, that is roughly 20K FLOPs per token — negligible compared to the matrix multiplications in attention and FFN, which require millions of FLOPs per token.

The cost is not in FLOPs but in memory bandwidth. Each LayerNorm reads and writes the full activation tensor, and the two-pass algorithm means the data must be loaded from memory twice. At large $d_\text{model}$ (8192 or above), this becomes a meaningful contributor to per-layer latency, especially during decode when the computation is entirely memory-bandwidth-bound.

💡 Fused LayerNorm Kernels

In practice, libraries such as Apex and FlashInfer, along with custom Triton kernels, provide fused LayerNorm implementations that compute the mean, variance, and normalization in a single pass over the data. This halves the memory bandwidth requirement and is essential for achieving good performance at large model dimensions. Never use a naive two-pass implementation in production.
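The single-pass idea can be sketched in plain PyTorch: accumulate the sum and sum of squares together, then derive the variance as $E[x^2] - E[x]^2$ (real fused kernels do this inside one GPU kernel, often with Welford-style updates for better numerical stability):

```python
import torch
import torch.nn.functional as F

def layernorm_one_pass(x, gamma, beta, eps=1e-5):
    # One read of the data yields both sum and sum-of-squares;
    # biased variance follows from Var[x] = E[x^2] - E[x]^2.
    d = x.shape[-1]
    s = x.sum(dim=-1, keepdim=True)
    sq = (x * x).sum(dim=-1, keepdim=True)
    mean = s / d
    var = sq / d - mean * mean
    return gamma * (x - mean) * torch.rsqrt(var + eps) + beta

x = torch.randn(4, 128, 512)
gamma, beta = torch.ones(512), torch.zeros(512)
ref = F.layer_norm(x, (512,), gamma, beta, eps=1e-5)
print(torch.allclose(layernorm_one_pass(x, gamma, beta), ref, atol=1e-4))  # True
```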


4. Pre-Norm vs Post-Norm: THE Architectural Decision

The Original: Post-Norm (Vaswani et al., 2017)

The original transformer (“Attention Is All You Need”) placed normalization after the residual addition:

$$x_{l+1} = \text{LN}(x_l + \text{Sublayer}(x_l))$$

In this arrangement — called Post-Norm — the output of the sublayer (attention or FFN) is added to the residual stream, and then the sum is normalized. The normalized result becomes the input to the next sublayer.

Post-Norm was the default in the original transformer, BERT, and the first generation of large language models. It works. But it has a critical weakness that becomes increasingly severe as models get deeper.

The Modern Default: Pre-Norm (GPT-2 and All Modern LLMs)

Starting with GPT-2, the community shifted to placing normalization before the sublayer:

$$x_{l+1} = x_l + \text{Sublayer}(\text{LN}(x_l))$$

In Pre-Norm, the residual stream $x_l$ passes through a normalization layer before entering the sublayer. The sublayer’s output is then added directly to the un-normalized residual stream.

This seemingly minor change has profound consequences for gradient flow.
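The two placements differ by a single line. A minimal sketch (the `sublayer` argument stands in for attention or the FFN; names are illustrative):

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    # One residual sublayer wrapper; `pre_norm` selects the placement.
    def __init__(self, d_model, sublayer, pre_norm=True):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        self.pre_norm = pre_norm

    def forward(self, x):
        if self.pre_norm:
            return x + self.sublayer(self.norm(x))  # GPT-2 style: x + f(LN(x))
        return self.norm(x + self.sublayer(x))      # original: LN(x + f(x))

ffn = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
x = torch.randn(2, 10, 64)
print(Block(64, ffn, pre_norm=True)(x).shape)  # torch.Size([2, 10, 64])
```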

Gradient Flow Analysis: Why Pre-Norm Won

To understand why Pre-Norm is more stable, we need to trace the gradient path from the loss $\mathcal{L}$ back through the network.

Post-Norm gradient path. In Post-Norm, the gradient at layer $l$ must flow through the normalization layer:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_{l+1}} \cdot \frac{\partial\, \text{LN}(x_l + \text{Sublayer}(x_l))}{\partial x_l}$$

The Jacobian of LayerNorm is not the identity — it is a complex function that depends on the current activation statistics. More critically, the gradient through LayerNorm has a projection component that removes the mean gradient, and a scaling component that depends on the variance. Through many layers, these repeated projections and rescalings can attenuate the gradient signal, making it difficult for early layers to receive useful learning signals.

Pre-Norm gradient path. In Pre-Norm, the residual connection provides a clean, unobstructed gradient path:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_{l+1}} \cdot \left(I + \frac{\partial\, \text{Sublayer}(\text{LN}(x_l))}{\partial x_l}\right)$$

The identity matrix $I$ in this expression is the crucial term. It means that gradients flow directly from $x_{l+1}$ back to $x_l$ through the skip connection without any modification. The sublayer and the normalization within it contribute an additive correction to the gradient, but the base gradient signal is preserved at full magnitude.

Σ Theorem: Pre-Norm Gradient Preservation

In a Pre-Norm transformer with $L$ layers, the gradient from the loss to the input of layer $l$ has a direct path through $L - l$ identity operations. The gradient magnitude is bounded below by the gradient at the output layer, regardless of depth. In contrast, Post-Norm gradients must traverse $L - l$ LayerNorm Jacobians, each of which can attenuate the signal.

The Empirical Evidence

The stability advantage of Pre-Norm is not merely theoretical. Multiple large-scale studies have confirmed it:

  • Xiong et al. (2020), “On Layer Normalization in the Transformer Architecture”: demonstrated that Pre-Norm eliminates the need for learning rate warmup (which Post-Norm requires to avoid early-training divergence) and enables stable training of transformers up to 100+ layers.
  • GPT-2, GPT-3, GPT-4: all use Pre-Norm. OpenAI switched from Post-Norm (GPT-1 followed the original transformer) to Pre-Norm in GPT-2 and never looked back.
  • Llama 1/2/3, Mistral, DeepSeek, Qwen: every major open-weight LLM uses Pre-Norm.

The pattern is unambiguous: for training stability at depth, Pre-Norm is strictly superior.
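The gradient-path argument can be checked empirically. A small experiment at random initialization, comparing the gradient norm reaching the input of an FFN-only residual stack under both placements (depth, width, and seed are arbitrary):

```python
import torch
import torch.nn as nn

def input_grad_norm(pre_norm: bool, depth: int = 32, d: int = 64) -> float:
    torch.manual_seed(0)
    blocks = nn.ModuleList([
        nn.ModuleDict({
            "norm": nn.LayerNorm(d),
            "ffn": nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d)),
        })
        for _ in range(depth)
    ])
    x = torch.randn(8, d, requires_grad=True)
    h = x
    for b in blocks:
        if pre_norm:
            h = h + b["ffn"](b["norm"](h))   # x + f(LN(x))
        else:
            h = b["norm"](h + b["ffn"](h))   # LN(x + f(x))
    h.sum().backward()
    return x.grad.norm().item()

print(f"pre-norm  grad norm at input: {input_grad_norm(True):.4f}")
print(f"post-norm grad norm at input: {input_grad_norm(False):.4f}")
```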

The Post-Norm Trade-off

There is one area where Post-Norm has an advantage: final model quality. Several studies (Liu et al., 2020; Nguyen and Salazar, 2019) have shown that Post-Norm models, when they can be trained successfully, achieve slightly better perplexity and downstream task performance than Pre-Norm models of the same size.

The intuition is that Post-Norm applies normalization to the combined residual + sublayer output, which provides a stronger regularization effect. The sublayer output must “earn” its contribution against the normalization of the sum, preventing any single sublayer from dominating the residual stream.

But this advantage is academic for practitioners. Post-Norm requires careful hyperparameter tuning (especially learning rate warmup schedules), is brittle at depth, and provides at best a fraction of a percent improvement in quality. Pre-Norm is the robust default, and the field has moved on — with one notable exception: DeepNorm, which we will cover in Section 7.

📊

Pre-Norm vs Post-Norm Training Stability

Configuration          | Max Stable Depth | LR Warmup Required | Final PPL (if trainable)
-----------------------|------------------|--------------------|-------------------------
Post-Norm, no warmup   | 6 layers         | N/A (diverges)     | —
Post-Norm, with warmup | 24 layers        | Yes (4000 steps)   | 15.2
Pre-Norm, no warmup    | 96+ layers       | No                 | 15.4
Pre-Norm, with warmup  | 96+ layers       | Optional           | 15.3

Note: Hypothetical 350M parameter model trained on OpenWebText. PPL = perplexity on validation set (lower is better). Pre-Norm sacrifices ~0.1-0.2 PPL for dramatically improved stability.

Residual Stream as Information Highway

The Pre-Norm perspective reveals a powerful mental model: the residual stream is the primary information carrier in a transformer. Each sublayer (attention or FFN) reads from the residual stream (after normalization), computes an update, and writes that update additively back to the residual stream. The residual stream itself is never modified in place — it only receives additive contributions.

This “information highway” view, popularized by Elhage et al. (2021) in the Anthropic interpretability work, explains why transformers can be so deep: the residual stream provides a direct channel from input to output, and each layer is free to contribute or not contribute to the final representation. Normalization before each sublayer ensures that the sublayer always receives well-conditioned inputs, regardless of what previous layers have added to the residual stream.


5. RMSNorm: Drop the Mean, Keep the RMS

Motivation: What Actually Matters in LayerNorm?

In 2019, Zhang and Sennrich published “Root Mean Square Layer Normalization,” asking a simple but important question: does the mean subtraction in LayerNorm actually contribute to model quality?

Recall the LayerNorm formula:

$$\text{LN}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

This performs two operations:

  1. Re-centering: subtracting the mean $\mu$ (shifting the distribution to have zero mean)
  2. Re-scaling: dividing by the standard deviation $\sqrt{\sigma^2 + \epsilon}$ (scaling the distribution to have unit variance)

Zhang and Sennrich hypothesized that re-scaling is the critical operation — it is what prevents activation magnitudes from growing unboundedly — while re-centering is a secondary effect that contributes little to training stability. Their experiments confirmed this.

The RMSNorm Formula

RMSNorm keeps the re-scaling but drops the re-centering:

$$\text{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2}$$
$$\text{RMSNorm}(x) = \gamma \odot \frac{x}{\text{RMS}(x) + \epsilon}$$

Note what has changed compared to LayerNorm:

  • No mean subtraction: the $x - \mu$ term is gone; we divide $x$ directly by the RMS.
  • No bias parameter: the $\beta$ shift is removed.
  • RMS instead of standard deviation: $\text{RMS}(x) = \sqrt{\frac{1}{d}\sum x_i^2}$, which equals the standard deviation only when the mean is zero. When the mean is non-zero, $\text{RMS}(x) \geq \sigma(x)$.
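A minimal RMSNorm sketch along these lines (following common implementations such as Llama’s, which place $\epsilon$ inside the square root rather than outside as in the formula above):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    # Minimal sketch of Zhang & Sennrich (2019): scale-only, no mean subtraction.
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))  # gamma only; no beta

    def forward(self, x):
        ms = x.pow(2).mean(dim=-1, keepdim=True)   # mean of squares = RMS^2
        return self.weight * x * torch.rsqrt(ms + self.eps)

x = torch.randn(2, 5, 512)
print(RMSNorm(512)(x).shape)  # torch.Size([2, 5, 512])
```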

Why It Works: Re-scaling Invariance

The key insight from Zhang and Sennrich is that normalization’s primary benefit is providing re-scaling invariance: if the input activations are multiplied by any scalar $\alpha$, the normalized output remains the same:

$$\text{RMSNorm}(\alpha x) = \gamma \odot \frac{\alpha x}{\text{RMS}(\alpha x)} = \gamma \odot \frac{\alpha x}{|\alpha| \cdot \text{RMS}(x)} = \text{sign}(\alpha) \cdot \text{RMSNorm}(x)$$

For positive $\alpha$ (which is the typical case since we are concerned with growing activation magnitudes), this is exact invariance. The normalization completely cancels any uniform scaling of the activations — which is precisely the failure mode (exponential growth or decay through layers) that normalization exists to prevent.

LayerNorm provides both re-scaling invariance and re-centering invariance (invariance to uniform shifts). RMSNorm provides only re-scaling invariance. The experimental finding is that re-centering invariance is unnecessary: it is a “nice to have” that costs compute without improving training dynamics.

Σ Theorem: RMSNorm Invariance Property

For any scalar $\alpha > 0$ and input $x \in \mathbb{R}^d$: $\text{RMSNorm}(\alpha x) = \text{RMSNorm}(x)$. This re-scaling invariance is sufficient to prevent activation magnitude explosion/collapse through depth, which is the primary failure mode that normalization addresses.
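The invariance can be verified numerically. A quick check with $\gamma = 1$ and $\epsilon = 0$, so the cancellation is exact:

```python
import torch

def rms_normalize(x):
    # gamma = 1, eps = 0 to show the invariance exactly
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True))

x = torch.randn(3, 512, dtype=torch.float64)
for alpha in (0.1, 7.0, 1e6):
    assert torch.allclose(rms_normalize(alpha * x), rms_normalize(x))
print("re-scaling invariance holds for all tested alpha")
```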

Computational Savings

RMSNorm is cheaper than LayerNorm because it eliminates the mean computation:

LayerNorm requires:

  1. Compute mean $\mu$: one reduction ($d$ additions, 1 division)
  2. Subtract mean: $d$ subtractions
  3. Compute variance $\sigma^2$: $d$ squarings, one reduction, 1 division
  4. Normalize: $d$ divisions, $d$ multiplications, $d$ additions (for the affine transform)

RMSNorm requires:

  1. Compute sum of squares: $d$ squarings, one reduction, 1 division
  2. Normalize: $d$ divisions, $d$ multiplications

The savings are in the mean computation and mean subtraction steps — roughly $2d$ fewer FLOPs. In isolation, this sounds trivial. But remember that normalization is applied twice per transformer layer (once before attention, once before the FFN), and there are 80+ layers in a modern LLM. At $d_\text{model} = 8192$, the savings are:

$$2 \times 2 \times 8192 \times 80 = 2{,}621{,}440 \text{ FLOPs per token saved}$$

More importantly, the memory bandwidth savings are significant. RMSNorm requires one pass over the data (compute sum of squares and normalize in a single fused kernel), while a naive LayerNorm requires two passes (one for mean, one for variance + normalize). Even with fused LayerNorm kernels that reduce this to one pass, RMSNorm’s kernel is simpler and achieves higher memory bandwidth utilization.

Who Uses RMSNorm

The adoption of RMSNorm reads like a list of the most important LLMs of the past three years:

📊

RMSNorm Adoption in Major LLMs

Model       | Normalization | d_model   | Layers | Year
------------|---------------|-----------|--------|-----
GPT-3       | LayerNorm     | 12288     | 96     | 2020
Llama 1     | RMSNorm       | 4096-8192 | 32-80  | 2023
Llama 2     | RMSNorm       | 4096-8192 | 32-80  | 2023
Llama 3     | RMSNorm       | 8192      | 80     | 2024
Mistral 7B  | RMSNorm       | 4096      | 32     | 2023
DeepSeek-V2 | RMSNorm       | 5120      | 60     | 2024
Qwen 2      | RMSNorm       | 8192      | 80     | 2024
Gemma 2     | RMSNorm       | 3072-4608 | 26-42  | 2024

Note: The shift from LayerNorm to RMSNorm happened around 2023 with the release of Llama 1. Hardly any major LLM released since has used full LayerNorm.

The transition was swift and decisive. Llama 1 (February 2023) demonstrated that RMSNorm matched LayerNorm in quality while being measurably faster, and the entire field followed. It is now effectively the standard.


6. QK-Norm: Taming Attention Logit Growth

The Problem: Unbounded Attention Logits

Recall from our attention discussion that the raw attention scores are computed as:

$$S = \frac{QK^\top}{\sqrt{d_k}}$$

The $1/\sqrt{d_k}$ scaling keeps the variance of the scores approximately 1 when $Q$ and $K$ have unit-variance entries. But during training, $Q$ and $K$ are the outputs of learned linear projections, and nothing constrains their magnitude. In very large models (billions of parameters), the norms of query and key vectors can grow steadily during training, causing the attention logits to become very large.

When attention logits are large, the softmax becomes nearly one-hot: one position receives almost all the attention weight, and all other positions receive effectively zero. This is called softmax saturation or attention entropy collapse. The consequences are severe:

  • Lost information: the output is nearly a copy of a single value vector, discarding the information from all other positions.
  • Vanishing gradients through softmax: when softmax is saturated, its Jacobian has near-zero entries everywhere except at the argmax, so gradients do not flow to positions that received low attention.
  • Training instability: attention patterns become spiky and volatile, especially in early training when the model has not yet learned stable representations.

The Solution: Normalize Q and K

QK-Norm applies normalization to the query and key vectors before computing attention scores:

$$S = \frac{\text{Norm}(Q)\, \text{Norm}(K)^\top}{\sqrt{d_k}}$$

The normalization (typically L2 normalization or LayerNorm applied per head) constrains the magnitude of $Q$ and $K$, which in turn bounds the attention logits. If both $Q$ and $K$ are L2-normalized to unit length, the dot product $q_i \cdot k_j$ is bounded by the Cauchy-Schwarz inequality:

$$|q_i \cdot k_j| \leq \|q_i\| \cdot \|k_j\| = 1$$

This means the attention logits are bounded in $[-1/\sqrt{d_k}, 1/\sqrt{d_k}]$ (or $[-1, 1]$ before the $\sqrt{d_k}$ scaling), completely eliminating the possibility of softmax saturation from logit magnitude growth.
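A quick numerical illustration of this bound, using L2 normalization (`F.normalize`) on deliberately large-magnitude queries and keys:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dk = 64
# Queries/keys with inflated magnitudes, as can develop during training
q = 50 * torch.randn(128, dk)
k = 50 * torch.randn(128, dk)

raw = (q @ k.T) / dk**0.5
qn, kn = F.normalize(q, dim=-1), F.normalize(k, dim=-1)  # unit-length rows
bounded = (qn @ kn.T) / dk**0.5

print(f"raw logit range:  [{raw.min().item():.1f}, {raw.max().item():.1f}]")
print(f"QK-normed range:  [{bounded.min().item():.3f}, {bounded.max().item():.3f}]")
assert bounded.abs().max().item() <= 1 / dk**0.5 + 1e-6  # Cauchy-Schwarz bound
```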

The Connection to Scaled Dot-Product Attention

The $1/\sqrt{d_k}$ scaling in standard attention is itself a normalization — it normalizes for the expected variance of dot products under the assumption that $Q$ and $K$ have unit-variance entries. QK-Norm is a stronger version of the same idea: instead of assuming unit variance, it enforces bounded magnitude.

The relationship is this: $1/\sqrt{d_k}$ scaling works when $Q$ and $K$ are initialized with appropriate variance and the variance stays stable during training. QK-Norm provides a guarantee that works regardless of what happens during training.

ℹ️ QK-Norm Adoption

QK-Norm is not yet universal, but its adoption is growing. Gemma 2 (Google, 2024) uses QK-Norm across all attention heads. Cohere Command R+ and several other production models have adopted it. As models continue to scale, the pressure toward QK-Norm will increase — larger models with more parameters tend to have more severe logit growth.

Implementation

In practice, QK-Norm is typically implemented as an RMSNorm (or L2 norm) applied to the query and key tensors after projection but before the attention score computation:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class QKNormAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.head_dim = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        # Per-head RMSNorm for Q and K (normalizes over head_dim)
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def forward(self, x):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, self.num_heads, self.head_dim)
        V = self.W_v(x).view(B, L, self.num_heads, self.head_dim)

        # Normalize Q and K per head, taming logit growth
        Q = self.q_norm(Q)
        K = self.k_norm(K)

        # Standard scaled dot-product attention (causal mask omitted for clarity)
        scores = torch.einsum('blhd,bshd->bhls', Q, K) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V)
        return self.W_o(out.reshape(B, L, D))

The additional cost is two RMSNorm operations per attention layer, each over a tensor of shape [batch, seq, heads, head_dim]. This is typically negligible compared to the QKV projection matrix multiplications.


7. DeepNorm: Scaling to 1000 Layers

The Problem at Extreme Depth

Pre-Norm enables training of deep transformers (up to roughly 100 layers with standard hyperparameters). But as researchers pushed toward even deeper models — 200, 500, 1000 layers — Pre-Norm itself began to show limitations. The issue is subtle: while Pre-Norm preserves gradient magnitude through the residual connections, the relative contribution of each sublayer’s output to the residual stream decreases with depth. In a 1000-layer Pre-Norm transformer, each sublayer’s contribution is a tiny fraction of the accumulated residual, and the model struggles to learn effectively.

Microsoft’s DeepNorm

Wang et al. (2022) at Microsoft Research introduced DeepNorm, which modifies the residual connection scaling:

$$x_{l+1} = \text{LN}(\alpha \cdot x_l + \text{Sublayer}(x_l))$$

where $\alpha$ is a constant scaling factor applied to the residual stream before the sublayer output is added. Additionally, the sublayer’s weights are initialized with a scaling factor $\beta$ that depends on the depth:

$$\alpha = (2L)^{1/4}, \quad \beta = (8L)^{-1/4}$$

where $L$ is the total number of layers.

The intuition: by scaling up the residual ($\alpha > 1$), the residual stream retains more of its accumulated information, while the smaller initialization ($\beta < 1$) ensures that each sublayer’s contribution starts small and grows during training. This prevents the sublayer outputs from overwhelming the residual stream in early training, which is the primary failure mode at extreme depth.
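A sketch of the DeepNorm update and its depth-dependent constants, using the formulas above (`sublayer` again stands in for attention or the FFN; scaling all matrix-shaped parameters by $\beta$ is a simplification of the paper’s recipe, which scales specific projections):

```python
import torch
import torch.nn as nn

def deepnorm_constants(num_layers: int):
    alpha = (2 * num_layers) ** 0.25    # residual up-scaling
    beta = (8 * num_layers) ** -0.25    # sublayer weight init down-scaling
    return alpha, beta

class DeepNormBlock(nn.Module):
    def __init__(self, d_model, sublayer, num_layers):
        super().__init__()
        self.alpha, beta = deepnorm_constants(num_layers)
        self.norm = nn.LayerNorm(d_model)
        self.sublayer = sublayer
        with torch.no_grad():            # scaled initialization
            for p in self.sublayer.parameters():
                if p.dim() > 1:
                    p.mul_(beta)

    def forward(self, x):
        # Post-Norm placement with an up-scaled residual:
        # x_{l+1} = LN(alpha * x_l + Sublayer(x_l))
        return self.norm(self.alpha * x + self.sublayer(x))

alpha, beta = deepnorm_constants(1000)
print(f"L=1000: alpha = {alpha:.3f}, beta = {beta:.3f}")  # alpha = 6.687, beta = 0.106
```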

💡 DeepNorm's Architectural Position

Note that DeepNorm uses a Post-Norm placement (normalization after the residual addition), but with the α scaling. This is a hybrid approach: the α scaling addresses Post-Norm’s gradient attenuation problem, while the Post-Norm placement provides the quality advantage of normalizing the combined residual + sublayer output. DeepNorm achieves the best of both worlds for extremely deep models.

Practical Impact

Microsoft trained a 1000-layer transformer (2,000 attention and FFN sublayers) using DeepNorm, which would have been impossible with either standard Pre-Norm or Post-Norm. The resulting model showed continued improvement with depth, confirming that the scaling laws had not saturated — it was the training dynamics, not the architecture’s expressivity, that had previously limited depth.

For most practitioners, 1000-layer transformers are not yet relevant (current frontier models use 80-128 layers). But DeepNorm demonstrates an important principle: the normalization scheme and the residual connection scheme must be co-designed, especially at scale. As the field pushes toward deeper and more efficient architectures, DeepNorm’s ideas are likely to become increasingly important.


8. The Training Stability Connection

Normalization does not exist in isolation. It interacts with every other training decision: learning rate, weight initialization, optimizer choice, and numerical precision. Understanding these interactions is essential for stable large-scale training.

Normalization and Learning Rate

Without normalization, the maximum stable learning rate is tightly coupled to the network depth. Deeper networks require smaller learning rates because each layer’s update propagates forward and gets amplified (or attenuated) by subsequent layers. The maximum stable learning rate scales roughly as η_max ∝ 1/√L for unnormalized networks.

Normalization decouples the learning rate from depth. Because each layer’s input is normalized to a consistent scale, the effective learning rate at each layer is approximately the same regardless of depth. This is why Pre-Norm transformers can train with the same learning rate schedule whether they have 12 layers or 96 layers.

In practice, well-normalized transformers train with peak learning rates of 3×10⁻⁴ to 1×10⁻³, regardless of depth. Without normalization, a 96-layer model might require a learning rate of 10⁻⁵ or smaller — and even then, training is fragile.

Normalization and Weight Initialization

The standard initialization schemes (Xavier, Kaiming/He) are designed to maintain activation variance through the forward pass of an unnormalized network. With normalization, the initialization matters less: even if the initial forward pass produces activations with wildly varying magnitudes, the normalization layers will immediately bring them into a well-behaved range.

This does not mean initialization is irrelevant. Bad initialization can still cause problems in the first few training steps (before the optimizer has adjusted the parameters), and it can affect the early-training loss landscape. The GPT-2 and GPT-3 papers both use a scaled initialization for the residual path (1/√(2L) scaling of the output projection weights), which ensures that the initial sublayer contributions are small relative to the residual stream. This complements the normalization rather than replacing it.
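A sketch of that residual-path scaling, using GPT-2's 0.02 base standard deviation (the helper name is illustrative):

```python
import math

def residual_init_std(num_layers, base_std=0.02):
    """Std-dev for residual-path output projections, scaled down by sqrt(2L)
    so the sum of 2L sublayer contributions keeps roughly unit variance."""
    return base_std / math.sqrt(2 * num_layers)

# For a 96-layer model (GPT-3 scale): ~0.00144, versus 0.02 unscaled.
std_96 = residual_init_std(96)
```

The deeper the model, the smaller the initial sublayer contributions, which is exactly the property DeepNorm later formalized with its β factor.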

Mixed Precision and Normalization: The FP32 Rule

This is the most practically important interaction, and getting it wrong will silently corrupt your training run.

Modern LLM training uses mixed precision: the bulk of the computation (matrix multiplications) runs in BF16 or FP8, while certain operations run in FP32. Normalization must always run in FP32. The reason is the variance computation:

\sigma^2 = \frac{1}{d} \sum_{i=1}^{d} (x_i - \mu)^2

In BF16 (which has only 7 mantissa bits, giving roughly 2-3 decimal digits of precision), the subtraction x_i − μ suffers catastrophic cancellation when x_i is close to μ — which it often is, since we are computing the deviation from the mean. The resulting variance estimate can be wildly inaccurate, leading to incorrect normalization, corrupted gradients, and training instability.
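The effect is easy to demonstrate with NumPy, using float16 as a stand-in for low precision (BF16, with only 7 mantissa bits, is even coarser at this magnitude). When activations sit around 100 with a standard deviation of 0.01, the float16 grid spacing near 100 is 0.0625, so almost every deviation rounds away entirely:

```python
import numpy as np

rng = np.random.default_rng(0)
# Activations with a large mean and small spread, as residual streams often have.
x32 = (100.0 + rng.normal(0.0, 0.01, size=8192)).astype(np.float32)

var_fp32 = float(x32.var())  # close to the true 1e-4

# Round-trip through float16: most values snap to exactly 100.0,
# so the variance estimate collapses toward zero.
var_fp16 = float(x32.astype(np.float16).astype(np.float32).var())

print(var_fp32, var_fp16)
```

Running this shows the float16 variance estimate is an order of magnitude too small, which is precisely the failure mode the FP32 rule exists to prevent.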

🚨 Never Run Normalization in Low Precision

If you are implementing a training pipeline and you see normalization running in BF16 or FP16, that is a bug. The variance computation requires FP32 precision. The standard pattern: upcast the input to FP32, compute the normalization, then downcast the output back to BF16/FP16. PyTorch’s nn.LayerNorm and most framework implementations handle this automatically, but custom kernels (Triton, CUDA) must do it explicitly.
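The upcast-compute-downcast pattern the callout describes looks like this in NumPy (float16 standing in for BF16; a sketch, not a production kernel):

```python
import numpy as np

def rmsnorm_mixed_precision(x_lowp, gamma, eps=1e-6):
    x32 = x_lowp.astype(np.float32)                   # 1. upcast to FP32
    rms = np.sqrt((x32 * x32).mean(axis=-1, keepdims=True) + eps)
    y32 = (x32 / rms) * gamma.astype(np.float32)      # 2. normalize + scale in FP32
    return y32.astype(x_lowp.dtype)                   # 3. downcast for the next matmul

x = np.random.default_rng(0).normal(size=(4, 8192)).astype(np.float16)
y = rmsnorm_mixed_precision(x, np.ones(8192, dtype=np.float32))
```

Only the statistics and the division happen in FP32; the tensor handed to the next matrix multiplication is back in the low-precision format.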

The FP32 normalization requirement has a performance cost. The upcast and downcast operations add latency, and the FP32 arithmetic is slower than BF16 on modern hardware (A100 and H100 tensor cores are 2x faster in BF16 than FP32). But this cost is non-negotiable — incorrect normalization will destroy your training run.

For FP8 training (the current frontier of mixed-precision research), the normalization precision requirement becomes even more critical. FP8 has only 2-3 mantissa bits (E5M2 and E4M3 respectively), making any statistical computation (mean, variance, RMS) completely unreliable. The normalization layers are always the first operations excluded from the FP8 compute graph.

Gradient Scaling and Normalization

Normalization also interacts with loss scaling, which is used in mixed-precision training to prevent gradient underflow in FP16. The normalization layers’ gradients involve division by σ (or RMS), which can amplify small gradients. If the loss scale is too small, the gradients through the normalization layer may underflow to zero in FP16 before being rescaled. If the loss scale is too large, the amplified gradients may overflow.

Dynamic loss scaling (used by default in PyTorch AMP and DeepSpeed) handles this automatically by adjusting the scale factor based on whether overflows are detected. But understanding the interaction explains why dynamic loss scaling occasionally reduces the scale factor during training — it is often the normalization gradients that trigger the overflow detection.
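The core of dynamic loss scaling is a small state machine, sketched here. The default constants mirror common AMP settings, but treat the class and its parameters as illustrative, not a real framework API:

```python
class DynamicLossScaler:
    """Halve the scale on overflow; double it after a run of clean steps."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000,
                 backoff=0.5, growth=2.0):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.backoff = backoff
        self.growth = growth
        self._clean_steps = 0

    def update(self, found_overflow):
        if found_overflow:
            # Often triggered by normalization gradients (division by sigma/RMS).
            self.scale *= self.backoff
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps >= self.growth_interval:
                self.scale *= self.growth
                self._clean_steps = 0
```

The asymmetry is deliberate: backing off is immediate because an overflow corrupts the step, while growth waits for a long streak of clean steps.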

Training Stability: Effect of Normalization Precision

| Configuration | Behavior | % of target quality at 100K steps |
|---|---|---|
| FP32 Norm | Stable baseline | 100 |
| BF16 Norm | Degraded convergence | 73 |
| FP16 Norm | Frequent loss spikes | 45 |
| No Norm | Diverges immediately | 0 |

9. Inference Cost: LayerNorm vs RMSNorm in Production

Where Normalization Sits in the Latency Budget

During inference, especially in the autoregressive decode phase (generating one token at a time), the computation is dominated by memory bandwidth — reading model weights and KV cache from HBM. The matrix multiplications in attention and FFN are the primary consumers. Normalization is a relatively small but non-negligible contributor.

A single transformer layer contains two normalization operations: one before attention, one before FFN. For a model with L layers, that is 2L normalization operations per token generated. Let us compute the absolute cost.

Microbenchmark: LayerNorm vs RMSNorm

Consider a model with d_model = 8192 (Llama 3 70B scale). Each normalization operation processes one vector of 8192 elements.

LayerNorm per operation:

  • Read input: 8192 elements x 2 bytes (BF16) = 16 KB
  • Upcast to FP32, compute mean (reduction), compute variance (reduction), normalize, apply affine, downcast
  • Read γ and β parameters: 2 x 8192 x 4 bytes (FP32) = 64 KB
  • Write output: 16 KB
  • Total memory traffic: ~96 KB
  • Estimated latency on H100 (3.35 TB/s HBM bandwidth): ~0.028 us

RMSNorm per operation:

  • Read input: 16 KB
  • Upcast to FP32, compute sum of squares (single reduction), normalize, apply scale, downcast
  • Read γ parameter only: 8192 x 4 bytes = 32 KB (no β)
  • Write output: 16 KB
  • Total memory traffic: ~64 KB
  • Estimated latency on H100: ~0.019 us
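The per-operation estimates above are just traffic divided by bandwidth; a quick check of the arithmetic (the 3.35 TB/s figure is the H100 HBM bandwidth assumed in the text):

```python
def norm_latency_us(traffic_kb, bandwidth_tb_per_s=3.35):
    # Bandwidth-bound lower bound: bytes moved / bytes-per-second, in microseconds.
    return (traffic_kb * 1e3) / (bandwidth_tb_per_s * 1e12) * 1e6

layernorm_us = norm_latency_us(96)  # ~0.029 us
rmsnorm_us = norm_latency_us(64)    # ~0.019 us
```

The results land within rounding of the figures quoted above (the exact third digit depends on whether a KB is taken as 1000 or 1024 bytes).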

The difference per operation is approximately 0.009 us. But these are theoretical minimums based on pure bandwidth. In practice, kernel launch overhead, cache effects, and the FP32 upcast make the actual latencies somewhat higher. Measured microbenchmarks on H100 show:

📊

Normalization Microbenchmarks (H100, d_model=8192, batch=1)

| Operation | Latency (us) | Memory Traffic (KB) | Bandwidth Utilization |
|---|---|---|---|
| LayerNorm (fused) | 0.52 | 96 | 61% |
| RMSNorm (fused) | 0.38 (-27%) | 64 (-33%) | 56% |
| LayerNorm (naive, 2-pass) | 0.91 | 192 | 70% |
| RMSNorm (naive) | 0.61 | 128 | 70% |
Note: Measured with Triton kernels, CUDA 12.3, single stream. Fused kernels compute normalization in a single pass with online variance computation.

Total Impact Across a Full Model

For Llama 3 70B (L = 80 layers, d_model = 8192):

  • LayerNorm total: 2 x 80 x 0.52 = 83.2 us per token
  • RMSNorm total: 2 x 80 x 0.38 = 60.8 us per token
  • Savings: 22.4 us per token

During batch-1 decode (the latency-critical path), a single token takes roughly 8-12 ms end-to-end on a tensor-parallel H100 setup. The normalization savings of 22.4 us represent about 0.2-0.3% of the total decode latency.

Normalization as % of Total Decode Latency (Llama 3 70B, H100 TP=4)

| Component | Notes | Latency (us) |
|---|---|---|
| QKV + O Projections | Matrix multiplications | 3,200 |
| FFN (up + gate + down) | Largest component | 5,100 |
| KV Cache Read | Memory-bound | 800 |
| RMSNorm (160 ops) | 0.6% of total | 61 |
| Softmax + Masking | Per-layer attention | 120 |
| AllReduce (TP comm) | Network-bound | 450 |

When Does It Matter?

The 0.2-0.3% savings from RMSNorm over LayerNorm sound negligible, and for a single model serving a moderate workload, they are. But consider the scale:

  • Fleet-wide impact: if you serve 100 billion tokens per day (a realistic number for a large provider), the RMSNorm savings of 22.4 us per token across 160 normalization operations amount to 100 × 10⁹ × 22.4 × 10⁻⁶ = 2,240,000 GPU-seconds saved per day. That is roughly 26 GPU-days saved every day, the equivalent of 26 H100s running around the clock.
  • Marginal but compounding: the savings from RMSNorm are one of many small optimizations (fused attention, quantized KV cache, speculative decoding) that individually contribute fractions of a percent but collectively determine whether a model is economically viable to serve.
  • Free quality: the most important property of the LayerNorm-to-RMSNorm switch is that it costs nothing in model quality. Unlike most performance optimizations (quantization, KV cache compression, attention approximation), RMSNorm provides strictly the same quality as LayerNorm. It is a rare genuine free lunch.
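Spelling out the fleet-wide arithmetic from the first bullet (token volume and per-token savings as assumed in the text):

```python
tokens_per_day = 100e9          # assumed fleet-wide token volume
savings_per_token = 22.4e-6     # seconds saved per token (160 RMSNorm ops)

gpu_seconds_saved = tokens_per_day * savings_per_token  # 2.24e6 GPU-seconds/day
gpu_days_saved = gpu_seconds_saved / 86_400             # ~25.9 GPU-days/day
```

Dividing by 86,400 seconds per day converts the saving into the number of GPUs that would otherwise have to run continuously.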

💡 The Real Reason for RMSNorm

The inference latency savings from RMSNorm are real but small. The primary reason every modern LLM uses RMSNorm is the training speedup: fewer FLOPs per step across billions of training tokens translates to days or weeks of saved training time, which at the scale of frontier model training represents millions of dollars in compute cost.


Summary: A Decision Framework

The normalization landscape in transformers has converged to a clear set of best practices. Here is how to think about each choice:

LayerNorm vs RMSNorm: Use RMSNorm. It matches LayerNorm in quality, is faster in both training and inference, and is the standard in every major LLM since Llama 1. The only reason to use LayerNorm is backward compatibility with an existing codebase or checkpoint.

Pre-Norm vs Post-Norm: Use Pre-Norm. It provides dramatically better training stability, requires less hyperparameter tuning, and enables training of deep models (60+ layers) without special techniques. The slight quality advantage of Post-Norm is not worth the stability risk.

QK-Norm: Use it if your model is large (billions of parameters) or if you observe attention entropy collapse during training. The cost is negligible and the stability benefit is real. As models continue to scale, QK-Norm will likely become the default.

DeepNorm: Only relevant if you are training extremely deep models (200+ layers). For standard architectures (up to 128 layers), Pre-Norm with RMSNorm is sufficient.

Mixed precision: Always run normalization in FP32, regardless of what precision the rest of the network uses. This is non-negotiable.

📊

Normalization Decision Matrix

| Scenario | Normalization | Placement | QK-Norm | Precision |
|---|---|---|---|---|
| Standard LLM (7B-70B) | RMSNorm | Pre-Norm | Optional | FP32 |
| Very large LLM (100B+) | RMSNorm | Pre-Norm | Recommended | FP32 |
| Ultra-deep (200+ layers) | RMSNorm + DeepNorm | Modified Post-Norm | Recommended | FP32 |
| Legacy BERT-style encoder | LayerNorm | Post-Norm | No | FP32 |
| Vision Transformer | LayerNorm or RMSNorm | Pre-Norm | Optional | FP32 |
Note: This reflects the consensus as of early 2025. The field evolves quickly -- check recent publications for the latest developments.

The story of normalization in transformers is a story of the field learning what actually matters. BatchNorm worked for CNNs but not for transformers. LayerNorm solved the batch-dependence problem. Pre-Norm solved the gradient flow problem. RMSNorm solved the unnecessary computation problem. QK-Norm solved the attention logit growth problem. DeepNorm solved the extreme-depth problem. Each variant exists because it addresses a specific failure mode, and understanding those failure modes is the key to making informed architectural decisions.

In the next part of this series, we will turn to the Feed-Forward Network (FFN) — the other half of each transformer layer — and examine why the SwiGLU activation function replaced ReLU, what the role of the FFN is in the context of the residual stream, and how mixture-of-experts (MoE) FFN layers trade compute for capacity.