Part of the series Transformer Anatomy (16 of 23):

1. The Transformer Attention Mechanism: From First Principles to Performance Reality
2. Tokenization and BPE: How LLMs See Text - From Characters to Subwords
3. Embedding Layers: The Geometry of Meaning in LLMs
4. Position Encoding in Transformers: From Sinusoidal to RoPE, ALiBi, and Long-Context Scaling
5. Softmax Numerics: Log-Sum-Exp, Temperature, and Why Numerical Stability Matters
6. Attention Variants Compared: MHA, MQA, GQA, and MLA
7. Normalization in Transformers: LayerNorm, RMSNorm, and the Training Stability Story
8. Residual Connections and Skip Paths: Why Transformers Can Be 100 Layers Deep
9. The Feed-Forward Network: SwiGLU, Gating, and the FFN-as-Memory Hypothesis
10. Mixture of Experts: Why Conditional Computation Is the Path to Trillion-Parameter Models
11. The Output Head: Unembedding, Weight Tying, and Vocabulary Projection
12. Cross-Entropy Loss: How the Loss Function Shapes What an LLM Learns
13. Encoder vs Decoder: Why Decoder-Only Won
14. DeepSeek V3: How 671B Parameters Trained for the Cost of a 70B Dense Model
15. Building a Transformer From Scratch: Putting Every Component Together
16. Gradient Flow and Backpropagation Through Transformers: What Happens During the Backward Pass
17. Weight Initialization: Xavier, Kaiming, and Why mu-P Changes Everything for Large Models
18. Training Loop Anatomy: Forward Pass, Loss Computation, Backward Pass, Optimizer Step
19. Learning Rate Schedules: Warmup, Cosine Decay, and Why WSD Changes Everything
20. Activation Functions Deep Dive: ReLU, GELU, SiLU, and Why Each Matters for Transformers
21. Attention Masking: Causal, Bidirectional, Sliding Window, Block Sparse, and Custom Patterns
22. Knowledge Distillation: Training Small Models to Match Large Ones
23. Model Merging: Weight Averaging, TIES, DARE, and Evolutionary Search

The forward pass of a transformer is, by now, well-documented. You know that queries, keys, and values are projected, that attention scores are computed via scaled dot-product, that outputs are passed through feed-forward networks with residual connections and normalization. The previous 15 parts of this series have covered every one of those operations in detail.

But the forward pass is only half the story. Training a transformer means computing gradients of the loss with respect to every parameter in the model (billions of them) and using those gradients to update weights. The backward pass is where the chain rule meets the attention mechanism, where numerical precision collides with memory limits, and where engineering decisions like gradient checkpointing and mixed-precision accumulation determine whether you can train a 70B model on your hardware budget or not.

This post covers the backward pass from first principles. Every gradient shape is explicit. Every memory cost is quantified. No hand-waving.


1. The Chain Rule Through a Transformer Layer

1.1 Layer Structure Review

A single Pre-Norm transformer layer computes:

$$h_{\text{mid}} = x + \text{Attn}(\text{RMSNorm}(x))$$

$$h_{\text{out}} = h_{\text{mid}} + \text{FFN}(\text{RMSNorm}(h_{\text{mid}}))$$

where $x \in \mathbb{R}^{B \times S \times d}$ is the input with batch size $B$, sequence length $S$, and hidden dimension $d$. For Llama 3 70B: $d = 8192$, $S$ up to 8192 during training, and $B$ depends on memory.

The loss $\mathcal{L}$ is a scalar. We receive $\frac{\partial \mathcal{L}}{\partial h_{\text{out}}} \in \mathbb{R}^{B \times S \times d}$ from the layer above and must compute:

  1. $\frac{\partial \mathcal{L}}{\partial x}$ to pass to the layer below
  2. $\frac{\partial \mathcal{L}}{\partial W_Q}$, $\frac{\partial \mathcal{L}}{\partial W_K}$, $\frac{\partial \mathcal{L}}{\partial W_V}$, $\frac{\partial \mathcal{L}}{\partial W_O}$ for attention weight updates
  3. $\frac{\partial \mathcal{L}}{\partial W_1}$, $\frac{\partial \mathcal{L}}{\partial W_2}$, $\frac{\partial \mathcal{L}}{\partial W_3}$ for FFN weight updates (SwiGLU has three matrices)
  4. $\frac{\partial \mathcal{L}}{\partial \gamma_1}$, $\frac{\partial \mathcal{L}}{\partial \gamma_2}$ for RMSNorm scale parameters

1.2 Backprop Through the FFN Residual

Start from the output. The residual connection $h_{\text{out}} = h_{\text{mid}} + \text{FFN}(\text{RMSNorm}(h_{\text{mid}}))$ gives:

$$\frac{\partial \mathcal{L}}{\partial h_{\text{mid}}} = \frac{\partial \mathcal{L}}{\partial h_{\text{out}}} \cdot \left(I + \frac{\partial\, \text{FFN}(\text{RMSNorm}(h_{\text{mid}}))}{\partial h_{\text{mid}}}\right)$$

The identity matrix $I$ in the sum is the key insight of residual connections: the gradient flows through unchanged, plus an additive contribution from the FFN path. This is why residual connections prevent vanishing gradients: the identity path preserves gradient magnitude regardless of what the FFN Jacobian looks like.

In practice, autograd frameworks do not form the full Jacobian. They compute vector-Jacobian products (VJPs). The VJP for the residual is simply:

# Backward through: h_out = h_mid + ffn(rmsnorm(h_mid))
# Given: grad_h_out  shape [B, S, d]
grad_ffn_out = grad_h_out                      # [B, S, d]
grad_h_mid = grad_h_out.clone()                 # Identity path: [B, S, d]
# Then backprop grad_ffn_out through FFN and RMSNorm
# and ADD the result to grad_h_mid
grad_h_mid += backprop_through_ffn_and_norm(grad_ffn_out, ...)

1.3 Backprop Through RMSNorm

RMSNorm computes:

$$\hat{x}_i = \frac{x_i}{\text{rms}(x)} \cdot \gamma_i, \quad \text{rms}(x) = \sqrt{\frac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}$$

The gradient with respect to $x$ involves the normalization Jacobian. For a single vector $x \in \mathbb{R}^d$:

$$\frac{\partial \hat{x}_i}{\partial x_j} = \frac{\gamma_i}{\text{rms}(x)} \left(\delta_{ij} - \frac{x_i x_j}{d \cdot \text{rms}(x)^2}\right)$$

where $\delta_{ij}$ is the Kronecker delta. In matrix form:

$$\frac{\partial \hat{x}}{\partial x} = \frac{1}{\text{rms}(x)}\, \text{diag}(\gamma) \left(I - \frac{x x^T}{d \cdot \text{rms}(x)^2}\right)$$

This Jacobian is $d \times d = 8192 \times 8192$ for Llama 3 70B. We never materialize it. The VJP computes:

def rmsnorm_backward(grad_output, x, gamma, rms_val):
    # grad_output: [B, S, d]
    # x: [B, S, d] (saved from forward)
    # gamma: [d]
    # rms_val: [B, S, 1] (saved from forward)
    d = x.shape[-1]

    x_hat = x / rms_val                                      # [B, S, d]
    grad_gamma = (grad_output * x_hat).sum(dim=(0, 1))       # [d]

    dx_hat = grad_output * gamma                             # [B, S, d]
    # Fused form (what kernels actually compute):
    # grad_x = (1/rms) * (dx_hat - x_hat * <dx_hat, x_hat> / d)
    grad_x = (1.0 / rms_val) * (
        dx_hat - x_hat * (dx_hat * x_hat).sum(dim=-1, keepdim=True) / d
    )

    return grad_x, grad_gamma

The key cost: we need $x$ (the input) saved from the forward pass, shape $[B, S, d]$. For $B=1$, $S=8192$, $d=8192$ in FP16, that is $1 \times 8192 \times 8192 \times 2 = 128$ MB per RMSNorm. Two RMSNorms per layer across 80 layers: $128 \times 2 \times 80 = 20{,}480$ MB = 20 GB just for RMSNorm activations.

⚑ RMSNorm Backward Cost

RMSNorm backward is 5 elementwise operations on tensors of shape $[B, S, d]$. FLOPs are negligible compared to the matmuls. But the memory to store the forward activation $x$ is substantial: 20 GB for a 70B model at sequence length 8192. This is one reason gradient checkpointing targets normalization layers for recomputation.
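The arithmetic above is easy to sanity-check; a two-line sketch using only the sizes quoted in the text:

```python
# RMSNorm input activation: [B, S, d] in FP16 (2 bytes per element)
B, S, d = 1, 8192, 8192
per_norm_mb = B * S * d * 2 / 2**20
total_gb = per_norm_mb * 2 * 80 / 1024   # 2 RMSNorms/layer, 80 layers
print(per_norm_mb, total_gb)             # 128.0 MB per norm, 20.0 GB total
```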

1.4 Backprop Through Attention: The Jacobian Shapes

This is where the backward pass gets mathematically dense. Multi-head attention computes:

$$Q = x W_Q, \quad K = x W_K, \quad V = x W_V$$

$$A = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}} + M\right), \quad O = A V, \quad \text{out} = O W_O$$

where $M$ is the causal mask. For GQA with $n_h = 64$ query heads and $n_{kv} = 8$ KV heads, the shapes are:

W_Q: [d, n_h * d_k]       = [8192, 8192]      (64 * 128 = 8192)
W_K: [d, n_kv * d_k]      = [8192, 1024]       (8 * 128 = 1024)
W_V: [d, n_kv * d_k]      = [8192, 1024]
W_O: [n_h * d_k, d]       = [8192, 8192]

Q:   [B, n_h, S, d_k]     = [B, 64, S, 128]
K:   [B, n_kv, S, d_k]    = [B, 8, S, 128]
V:   [B, n_kv, S, d_k]    = [B, 8, S, 128]
A:   [B, n_h, S, S]        = [B, 64, S, S]
O:   [B, n_h, S, d_k]     = [B, 64, S, 128]

The attention matrix $A$ has shape $[B, 64, S, S]$. For $S = 8192$, that is $B \times 64 \times 8192 \times 8192 \times 2 = B \times 8$ GB in FP16. For $B = 1$, that is 8 GB per layer. Across 80 layers: 640 GB. This is why FlashAttention exists: it never materializes $A$.
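The same back-of-envelope check for the attention matrix, with the head count and sequence length from the text:

```python
# Attention matrix A: [B, n_h, S, S] in FP16 (2 bytes per element)
B, n_h, S = 1, 64, 8192
a_gb = B * n_h * S * S * 2 / 2**30
print(a_gb, a_gb * 80)   # 8.0 GB per layer, 640.0 GB across 80 layers
```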

Gradient Through $W_O$

Starting from the attention output projection:

$$\text{out} = O W_O$$

$$\frac{\partial \mathcal{L}}{\partial O} = \frac{\partial \mathcal{L}}{\partial \text{out}}\, W_O^T \quad \in \mathbb{R}^{B \times S \times (n_h \cdot d_k)}$$

$$\frac{\partial \mathcal{L}}{\partial W_O} = O^T \frac{\partial \mathcal{L}}{\partial \text{out}} \quad \in \mathbb{R}^{(n_h \cdot d_k) \times d}$$

These are standard matrix-multiply backward passes. The reshape into per-head form gives $\frac{\partial \mathcal{L}}{\partial O_h} \in \mathbb{R}^{B \times S \times d_k}$ for each head $h$.

Gradient Through $AV$

For a single head (dropping the head subscript for clarity):

$$O = A V, \quad A \in \mathbb{R}^{B \times S \times S}, \quad V \in \mathbb{R}^{B \times S \times d_k}$$

$$\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial O}\, V^T \quad \in \mathbb{R}^{B \times S \times S}$$

$$\frac{\partial \mathcal{L}}{\partial V} = A^T \frac{\partial \mathcal{L}}{\partial O} \quad \in \mathbb{R}^{B \times S \times d_k}$$

The Jacobian $\frac{\partial O}{\partial A}$ is conceptually $\mathbb{R}^{(B \cdot S \cdot d_k) \times (B \cdot S \cdot S)}$, which is enormous. We never form it. The VJP computes $\text{grad\_A} = \text{grad\_O} \cdot V^T$ as a batched matmul.
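These two VJPs can be cross-checked against autograd on toy shapes (the sizes here are arbitrary, not the model's):

```python
import torch

# O = A @ V: autograd's VJP should match the closed forms
# grad_A = grad_O @ V^T and grad_V = A^T @ grad_O (batched matmuls).
B, S, d_k = 2, 4, 3
A = torch.randn(B, S, S, requires_grad=True)
V = torch.randn(B, S, d_k, requires_grad=True)
grad_O = torch.randn(B, S, d_k)

(A @ V).backward(grad_O)

assert torch.allclose(A.grad, grad_O @ V.detach().transpose(-1, -2), atol=1e-6)
assert torch.allclose(V.grad, A.detach().transpose(-1, -2) @ grad_O, atol=1e-6)
```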

Gradient Through Softmax

This is the most mathematically involved step. Softmax is applied row-wise to the scaled scores $P = Q K^T / \sqrt{d_k}$:

$$A_{ij} = \frac{\exp(P_{ij})}{\sum_k \exp(P_{ik})}$$

The Jacobian of softmax for a single row $a = \text{softmax}(p)$ is:

$$\frac{\partial a_i}{\partial p_j} = a_i (\delta_{ij} - a_j)$$

In matrix form: $\text{diag}(a) - a a^T$. This is an $S \times S$ matrix per row, per head, per batch element. The VJP is:

$$\frac{\partial \mathcal{L}}{\partial p_i} = a_i \left(\frac{\partial \mathcal{L}}{\partial a_i} - \sum_j a_j \frac{\partial \mathcal{L}}{\partial a_j}\right)$$

def softmax_backward(grad_a, a):
    # grad_a: [B, n_h, S, S]
    # a:      [B, n_h, S, S] (softmax output, saved from forward)
    dot = (grad_a * a).sum(dim=-1, keepdim=True)   # [B, n_h, S, 1]
    grad_p = a * (grad_a - dot)                     # [B, n_h, S, S]
    return grad_p

This requires saving $A$ (the attention weights) from the forward pass, shape $[B, n_h, S, S]$. This is the single largest activation in the network.
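The closed form above is easy to verify against autograd on a toy tensor:

```python
import torch

# VJP of row-wise softmax: grad_p = a * (grad_a - sum_j a_j * grad_a_j)
p = torch.randn(2, 3, 5, 5, requires_grad=True)
a = torch.softmax(p, dim=-1)
grad_a = torch.randn_like(a)
a.backward(grad_a)

dot = (grad_a * a.detach()).sum(dim=-1, keepdim=True)
assert torch.allclose(p.grad, a.detach() * (grad_a - dot), atol=1e-6)
```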

⚠️ FlashAttention Backward

FlashAttention avoids storing $A$ by recomputing it from $Q$ and $K$ during the backward pass, block by block. The backward kernel tiles over blocks of $Q$, $K$, $V$, recomputes the local softmax for each tile, and accumulates gradients. This trades one extra matmul (recomputing $QK^T$) for eliminating the $O(B \cdot n_h \cdot S^2)$ memory cost. For $S = 8192$, this saves 8 GB per layer.

Gradient Through $Q$ and $K$

Given $\frac{\partial \mathcal{L}}{\partial P}$ (the gradient through the pre-softmax scores):

$$P = \frac{Q K^T}{\sqrt{d_k}}$$

$$\frac{\partial \mathcal{L}}{\partial Q} = \frac{1}{\sqrt{d_k}} \frac{\partial \mathcal{L}}{\partial P}\, K \quad \in \mathbb{R}^{B \times n_h \times S \times d_k}$$

$$\frac{\partial \mathcal{L}}{\partial K} = \frac{1}{\sqrt{d_k}} \left(\frac{\partial \mathcal{L}}{\partial P}\right)^{\!T} Q \quad \in \mathbb{R}^{B \times n_{kv} \times S \times d_k}$$

For GQA, the gradient for $K$ requires summing over the query heads that share each KV head. With 64 query heads and 8 KV heads, each KV head serves 8 query heads:

# grad_K_expanded: [B, 64, S, d_k] -- gradient from all query heads
# Reshape to [B, 8, 8, S, d_k] and sum over the group dimension
grad_K = grad_K_expanded.reshape(B, n_kv, n_h // n_kv, S, d_k).sum(dim=2)
# Result: [B, 8, S, d_k]
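The reshape-and-sum above is exactly what autograd does when a KV head is broadcast across its query-head group; a small check (toy sizes, `repeat_interleave` standing in for the broadcast):

```python
import torch

# Sharing K across a group of query heads via repeat_interleave:
# autograd sums the incoming gradient over the group dimension.
B, n_kv, group, S, d_k = 1, 2, 4, 3, 5
n_h = n_kv * group
K = torch.randn(B, n_kv, S, d_k, requires_grad=True)
grad_K_expanded = torch.randn(B, n_h, S, d_k)

K.repeat_interleave(group, dim=1).backward(grad_K_expanded)

manual = grad_K_expanded.reshape(B, n_kv, group, S, d_k).sum(dim=2)
assert torch.allclose(K.grad, manual, atol=1e-6)
```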

Gradient Through Weight Projections

Finally, the weight gradients:

$$\frac{\partial \mathcal{L}}{\partial W_Q} = x_{\text{norm}}^T \frac{\partial \mathcal{L}}{\partial Q_{\text{flat}}} \quad \in \mathbb{R}^{d \times (n_h \cdot d_k)}$$

$$\frac{\partial \mathcal{L}}{\partial W_K} = x_{\text{norm}}^T \frac{\partial \mathcal{L}}{\partial K_{\text{flat}}} \quad \in \mathbb{R}^{d \times (n_{kv} \cdot d_k)}$$

$$\frac{\partial \mathcal{L}}{\partial W_V} = x_{\text{norm}}^T \frac{\partial \mathcal{L}}{\partial V_{\text{flat}}} \quad \in \mathbb{R}^{d \times (n_{kv} \cdot d_k)}$$

where $x_{\text{norm}} \in \mathbb{R}^{(B \cdot S) \times d}$ is the RMSNorm output reshaped to 2D. Each weight gradient is a GEMM of shape $(BS, d)^T \times (BS, \text{proj\_dim})$.
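The same pattern for a single projection, checked against autograd (toy sizes, not the model's):

```python
import torch

# Q_flat = x_norm @ W_Q  =>  dL/dW_Q = x_norm^T @ dL/dQ_flat
BS, d, proj_dim = 6, 8, 4
x_norm = torch.randn(BS, d)
W_Q = torch.randn(d, proj_dim, requires_grad=True)
grad_Q = torch.randn(BS, proj_dim)

(x_norm @ W_Q).backward(grad_Q)
assert torch.allclose(W_Q.grad, x_norm.T @ grad_Q, atol=1e-6)
```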

📊 Attention Backward Pass: Jacobian and Gradient Shapes (Llama 3 70B)

| Tensor | Shape | Size (FP16, B=1, S=8192) | Notes |
|---|---|---|---|
| dL/dA (attn weights grad) | [1, 64, 8192, 8192] | 8 GB | Eliminated by FlashAttention |
| dL/dQ | [1, 64, 8192, 128] | 128 MB | Needed for W_Q grad |
| dL/dK | [1, 8, 8192, 128] | 16 MB | Summed over GQA groups |
| dL/dV | [1, 8, 8192, 128] | 16 MB | From A^T · dL/dO |
| dL/dW_Q | [8192, 8192] | 128 MB | Accumulated across batch |
| dL/dW_K | [8192, 1024] | 16 MB | Accumulated across batch |
| dL/dW_V | [8192, 1024] | 16 MB | Accumulated across batch |
| dL/dW_O | [8192, 8192] | 128 MB | Accumulated across batch |

Note: B=1, S=8192, d=8192, n_h=64, n_kv=8, d_k=128. FlashAttention eliminates the 8 GB attention matrix storage.

1.5 Backprop Through SwiGLU FFN

SwiGLU computes:

$$\text{FFN}(x) = (\text{SiLU}(x W_1) \odot (x W_3))\, W_2$$

where $W_1, W_3 \in \mathbb{R}^{d \times d_{ff}}$ and $W_2 \in \mathbb{R}^{d_{ff} \times d}$, with $d_{ff} = 28672$ for Llama 3 70B.

Let $g_1 = x W_1$, $g_3 = x W_3$, $s = \text{SiLU}(g_1) = g_1 \cdot \sigma(g_1)$, and $h = s \odot g_3$.

Backward through $W_2$:

$$\frac{\partial \mathcal{L}}{\partial h} = \frac{\partial \mathcal{L}}{\partial \text{out}}\, W_2^T \quad \in \mathbb{R}^{B \times S \times d_{ff}}$$

Backward through the elementwise product:

$$\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial h} \odot g_3 \quad \in \mathbb{R}^{B \times S \times d_{ff}}$$

$$\frac{\partial \mathcal{L}}{\partial g_3} = \frac{\partial \mathcal{L}}{\partial h} \odot s \quad \in \mathbb{R}^{B \times S \times d_{ff}}$$

Backward through SiLU, using $\text{SiLU}'(x) = \sigma(x)(1 + x(1 - \sigma(x)))$:

$$\frac{\partial \mathcal{L}}{\partial g_1} = \frac{\partial \mathcal{L}}{\partial s} \odot \text{SiLU}'(g_1)$$
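The SiLU derivative used here can be confirmed against autograd in a few lines:

```python
import torch

# SiLU'(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
g1 = torch.randn(64, requires_grad=True)
(g1 * torch.sigmoid(g1)).backward(torch.ones(64))

sig = torch.sigmoid(g1.detach())
assert torch.allclose(g1.grad, sig * (1 + g1.detach() * (1 - sig)), atol=1e-6)
```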

Weight gradients:

$$\frac{\partial \mathcal{L}}{\partial W_1} = x_{\text{norm}}^T \frac{\partial \mathcal{L}}{\partial g_1}, \quad \frac{\partial \mathcal{L}}{\partial W_3} = x_{\text{norm}}^T \frac{\partial \mathcal{L}}{\partial g_3}, \quad \frac{\partial \mathcal{L}}{\partial W_2} = h^T \frac{\partial \mathcal{L}}{\partial \text{out}}$$

def swiglu_backward(grad_out, x_norm, W1, W2, W3, g1, g3, silu_g1):
    # grad_out: [B, S, d]
    # Saved from forward: x_norm, g1, g3, silu_g1 (= SiLU(g1))
    d = x_norm.shape[-1]
    d_ff = g1.shape[-1]

    grad_h = grad_out @ W2.T                       # [B, S, d_ff]  -- GEMM
    grad_s = grad_h * g3                           # [B, S, d_ff]  -- elementwise
    grad_g3 = grad_h * silu_g1                     # [B, S, d_ff]  -- elementwise

    sig_g1 = torch.sigmoid(g1)                     # [B, S, d_ff]
    silu_deriv = sig_g1 * (1 + g1 * (1 - sig_g1))  # [B, S, d_ff]
    grad_g1 = grad_s * silu_deriv                  # [B, S, d_ff]  -- elementwise

    # Weight gradients (GEMMs)
    grad_W2 = (silu_g1 * g3).reshape(-1, d_ff).T @ grad_out.reshape(-1, d)  # [d_ff, d]
    grad_W1 = x_norm.reshape(-1, d).T @ grad_g1.reshape(-1, d_ff)           # [d, d_ff]
    grad_W3 = x_norm.reshape(-1, d).T @ grad_g3.reshape(-1, d_ff)           # [d, d_ff]

    # Input gradient
    grad_x_norm = grad_g1 @ W1.T + grad_g3 @ W3.T  # [B, S, d]  -- two GEMMs

    return grad_x_norm, grad_W1, grad_W2, grad_W3
ℹ️ FFN Activation Memory

The SwiGLU backward requires saving $g_1$, $g_3$, and $\text{SiLU}(g_1)$ from the forward pass. Each is $[B, S, d_{ff}] = [B, S, 28672]$. In FP16 with $B=1$, $S=8192$: that is $3 \times 8192 \times 28672 \times 2$ bytes $\approx 1.31$ GB per layer. Across 80 layers: 105 GB. This is often the dominant activation memory cost, exceeding even the attention matrix (when FlashAttention is used).


2. Why Pre-Norm Helps Gradient Flow

2.1 Post-Norm vs Pre-Norm Gradient Paths

Post-Norm (original transformer):

$$h_{\text{out}} = \text{Norm}(x + \text{Sublayer}(x))$$

Pre-Norm (modern transformers):

$$h_{\text{out}} = x + \text{Sublayer}(\text{Norm}(x))$$

The difference is critical for gradient flow. In Pre-Norm, the gradient from $h_{\text{out}}$ to $x$ has a direct identity path:

$$\frac{\partial h_{\text{out}}}{\partial x} = I + \frac{\partial\, \text{Sublayer}(\text{Norm}(x))}{\partial x}$$

The identity matrix $I$ guarantees that the gradient passes through with magnitude 1, regardless of the sublayer's Jacobian. Stack $L$ layers and the gradient from the final layer to the input always has a component that is simply the product of $L$ identity matrices, which is the identity matrix.

In Post-Norm, the gradient path is:

$$\frac{\partial h_{\text{out}}}{\partial x} = \frac{\partial\, \text{Norm}}{\partial (\cdot)} \cdot \left(I + \frac{\partial\, \text{Sublayer}(x)}{\partial x}\right)$$

The normalization Jacobian $\frac{\partial\, \text{Norm}}{\partial (\cdot)}$ wraps the entire sum, including the identity path. This Jacobian has eigenvalues that can be less than 1 (it projects out the component along the mean direction). After $L$ layers, the product of $L$ normalization Jacobians can significantly attenuate the gradient, even though the individual attenuation per layer is small.

2.2 Quantifying the Difference

For a model with $L$ layers, the gradient magnitude ratio between Pre-Norm and Post-Norm can be estimated as:

Pre-Norm:

$$\|\nabla_{x_0} \mathcal{L}\| \sim \|\nabla_{x_L} \mathcal{L}\| \cdot (1 + O(1/L))^L \approx \|\nabla_{x_L} \mathcal{L}\| \cdot e$$

Post-Norm:

$$\|\nabla_{x_0} \mathcal{L}\| \sim \|\nabla_{x_L} \mathcal{L}\| \cdot \prod_{l=1}^{L} \lambda_l$$

where $\lambda_l$ is the effective spectral norm of the $l$-th normalization Jacobian, typically 0.95-0.99. For $L = 80$: $(0.97)^{80} \approx 0.087$, an attenuation of roughly 11x. For $L = 128$: $(0.97)^{128} \approx 0.020$, a roughly 50x attenuation.

Gradient Magnitude at Layer 0 (Relative to Layer L)

| Architecture | L=32 | L=80 | L=128 |
|---|---|---|---|
| Pre-Norm | 0.98 | 0.95 | 0.92 |
| Post-Norm | 0.38 | 0.087 | 0.020 |
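The Post-Norm attenuation is just the product of per-layer Jacobian norms; with the $\lambda = 0.97$ heuristic from above it is a one-liner to check:

```python
# Post-Norm gradient attenuation: lambda^L for the depths in the text
lam = 0.97
for depth in (32, 80, 128):
    print(depth, lam ** depth)
```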

2.3 The Clean Residual Path

The Pre-Norm architecture creates what Anthropic researchers call the "residual stream." The hidden state $h$ flows through the network, and each sublayer reads from it and writes an additive update:

$$h_L = h_0 + \sum_{l=1}^{L} f_l(\text{Norm}(h_{l-1}))$$

The gradient of the loss with respect to $h_0$ expands as:

$$\frac{\partial \mathcal{L}}{\partial h_0} = \frac{\partial \mathcal{L}}{\partial h_L} + \sum_{l=1}^{L} \frac{\partial \mathcal{L}}{\partial h_L} \cdot \frac{\partial f_l(\text{Norm}(h_{l-1}))}{\partial h_0}$$

The first term $\frac{\partial \mathcal{L}}{\partial h_L}$ passes directly from the output to the input: no attenuation, no transformation. The sum contains the "useful" gradient information from each layer's transformation. This structure means that even if some layers have vanishing or noisy gradients, the direct path ensures that the early layers always receive a meaningful signal.

In practice, this is why Pre-Norm models train stably with learning rates that would cause Post-Norm models to diverge. The gradient variance across layers is much lower:

📊 Gradient Norm Statistics by Layer Position

| Architecture | Grad Norm (First Layer) | Grad Norm (Last Layer) | Ratio | Training Outcome |
|---|---|---|---|---|
| Pre-Norm (80L) | 0.0042 | 0.0038 | 1.1x | Stable |
| Post-Norm (80L) | 0.00003 | 0.0041 | 137x | Diverges at high LR |
| Post-Norm (32L) | 0.0011 | 0.0039 | 3.5x | Stable but slower |
| Pre-Norm (128L) | 0.0040 | 0.0036 | 1.1x | Stable |

Note: Measured at initialization with random Gaussian weights, d=4096, S=2048.

3. Mixed Precision Gradient Accumulation

3.1 The Precision Hierarchy

Modern training uses a precision hierarchy:

  • Forward pass: FP16 or BF16 (2 bytes per element)
  • Attention scores: FP32 for the softmax intermediate (4 bytes), to avoid overflow
  • Loss computation: FP32 (4 bytes), since cross-entropy is numerically sensitive
  • Gradient computation: FP16/BF16 for most operations
  • Gradient accumulation: FP32 (4 bytes), since gradients are small and need precision
  • Master weights: FP32 (4 bytes), since weight updates are tiny relative to the weights
  • Optimizer state: FP32 (Adam keeps two states per parameter: 4 + 4 bytes)

For a 70B parameter model, the memory breakdown:

Master weights (FP32):          70B * 4 bytes  = 280 GB
Optimizer momentum (FP32):      70B * 4 bytes  = 280 GB
Optimizer variance (FP32):      70B * 4 bytes  = 280 GB
FP16 weights (working copy):    70B * 2 bytes  = 140 GB
FP16 gradients:                 70B * 2 bytes  = 140 GB
                                                --------
Total (no activations):                         1,120 GB

This is why 70B training requires distributed setups. A single H100 has 80 GB. Even with model parallelism across 8 GPUs (640 GB total), you need ZeRO-3 or FSDP to shard the optimizer states.

3.2 Why FP32 Gradients Matter

Consider a weight $W = 1.0$ with a gradient $g = 0.0001$ and learning rate $\eta = 3 \times 10^{-4}$. The update is $W \leftarrow W - \eta \cdot g = 1.0 - 0.00000003 = 0.99999997$.

In FP16, the smallest representable difference from 1.0 is $2^{-10} \approx 0.000977$. The update $0.00000003$ is about 32,000x smaller than the FP16 spacing at 1.0; it would be rounded away and the weight would never change.

In BF16, the smallest difference from 1.0 is $2^{-7} \approx 0.0078$, roughly 260,000x larger than the update. The update is lost.

In FP32, the smallest difference from 1.0 is $2^{-23} \approx 1.19 \times 10^{-7}$. The update $3 \times 10^{-8}$ is below even this, so a single FP32 step can also round it away. But Adam accumulates momentum over many steps, and the accumulated momentum in FP32 preserves the information.

This is why the master weights and optimizer states must be FP32. The individual updates are too small for reduced precision, but they accumulate over thousands of steps into meaningful weight changes.
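NumPy makes the FP16 case concrete; `np.spacing` reports the gap to the next representable value at a given magnitude:

```python
import numpy as np

# eta * g = 3e-8 applied to a weight of 1.0: the FP16 spacing at 1.0 is
# 2^-10 ~ 1e-3, so the subtraction rounds straight back to 1.0.
w = np.float16(1.0)
assert w - np.float16(3e-8) == w          # update silently vanishes in FP16

print(np.spacing(np.float16(1.0)))        # ~0.000977 (2^-10)
print(np.spacing(np.float32(1.0)))        # ~1.19e-7  (2^-23)
```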

3.3 Loss Scaling for FP16

When using FP16 (not BF16), gradients often underflow to zero because the dynamic range of FP16 only reaches down to $6.1 \times 10^{-5}$ (smallest normal) or $6 \times 10^{-8}$ (smallest subnormal). Many gradients in deep layers are smaller than this.

The solution: loss scaling. Multiply the loss by a large factor (1024, 4096, or dynamically adjusted) before backward, then divide gradients by the same factor after:

loss_scale = 1024.0

# Forward in FP16
with torch.cuda.amp.autocast(dtype=torch.float16):
    logits = model(input_ids)
    loss = cross_entropy(logits, labels)  # computed in FP32 inside autocast

# Scale loss before backward
scaled_loss = loss * loss_scale
scaled_loss.backward()  # gradients are 1024x larger, less underflow

# Unscale gradients before optimizer step
for param in model.parameters():
    if param.grad is not None:
        param.grad.data /= loss_scale

# Check for inf/nan (overflow from scaling too aggressively)
# If overflow detected: skip step, halve loss_scale
# If no overflow for 2000 steps: double loss_scale

BF16 largely eliminates the need for loss scaling because its dynamic range matches FP32 ($\pm 3.4 \times 10^{38}$, though with only 7 bits of mantissa). This is why BF16 is preferred for training on hardware that supports it (A100, H100, TPUs).

📊 Precision Format Comparison for Training

| Format | Exponent Bits | Mantissa Bits | Dynamic Range | Precision at 1.0 | Needs Loss Scaling |
|---|---|---|---|---|---|
| FP32 | 8 | 23 | 1e-38 to 3.4e38 | 1.19e-7 | No |
| BF16 | 8 | 7 | 1e-38 to 3.4e38 | 7.8e-3 | Usually no |
| FP16 | 5 | 10 | 6.1e-5 to 65504 | 9.8e-4 | Yes |
| FP8 (E4M3) | 4 | 3 | 1.95e-3 to 448 | 6.25e-2 | Yes |

Note: BF16 is the default for LLM training on modern hardware due to matching FP32 dynamic range.

4. Gradient Checkpointing

4.1 The Activation Memory Problem

During training, the forward pass must save intermediate activations for the backward pass. For each transformer layer, the saved tensors include:

| Activation | Shape | Size (FP16, B=1, S=8192, d=8192) |
|---|---|---|
| Input $x$ (for RMSNorm backward) | $[B, S, d]$ | 128 MB |
| RMSNorm output (for attention weight grads) | $[B, S, d]$ | 128 MB |
| $Q, K, V$ projections | $[B, S, 8192]$, $[B, S, 1024]$, $[B, S, 1024]$ | 160 MB |
| Attention weights $A$ (if no FlashAttn) | $[B, n_h, S, S]$ | 8,192 MB |
| Attention output (for $W_O$ backward) | $[B, S, d]$ | 128 MB |
| FFN RMSNorm input | $[B, S, d]$ | 128 MB |
| FFN gate activations $g_1, g_3$ | $[B, S, d_{ff}] \times 2$ | 896 MB |
| SiLU output | $[B, S, d_{ff}]$ | 448 MB |
Without FlashAttention: 10,208 MB per layer. 80 layers = 816 GB. With FlashAttention: 2,016 MB per layer. 80 layers = 161 GB.

Even with FlashAttention, 161 GB of activation memory for a single sequence of length 8192 is enormous. Add optimizer states (840 GB), model weights (140 GB FP16 + 280 GB FP32 master), and you need over 1,400 GB total. This does not fit in any single-node GPU setup.

4.2 How Gradient Checkpointing Works

The idea: do not save activations for every layer during the forward pass. Instead, save activations only at checkpoint boundaries (every kk layers). During the backward pass, when you need activations for a non-checkpointed layer, recompute them by running a partial forward pass from the nearest checkpoint.

def checkpoint_forward(model, x, checkpoint_every=10):
    """Forward pass that stores only segment-boundary activations."""
    checkpoints = {}
    h = x
    with torch.no_grad():  # keep no graph; segments are recomputed in backward
        for i, layer in enumerate(model.layers):
            if i % checkpoint_every == 0:
                checkpoints[i] = h.clone()  # save the segment input
            h = layer(h)
    return h, checkpoints

def checkpoint_backward(model, grad_output, checkpoints, checkpoint_every=10):
    """Backward pass, recomputing each segment from its stored checkpoint."""
    grad_h = grad_output
    n = len(model.layers)

    # Walk the segments from last to first
    for seg_start in reversed(range(0, n, checkpoint_every)):
        seg_end = min(seg_start + checkpoint_every, n)

        # Recompute the forward pass through this segment, building a graph
        h_in = checkpoints[seg_start].detach().requires_grad_(True)
        h = h_in
        for j in range(seg_start, seg_end):
            h = model.layers[j](h)

        # Backward through just this segment
        h.backward(grad_h)
        grad_h = h_in.grad

    return grad_h

In practice, PyTorch's torch.utils.checkpoint.checkpoint handles this automatically:

from torch.utils.checkpoint import checkpoint

class TransformerWithCheckpointing(nn.Module):
    def forward(self, x):
        for layer in self.layers:
            # checkpoint() saves only the input, recomputes forward during backward
            x = checkpoint(layer, x, use_reentrant=False)
        return x

4.3 Memory Math for 70B Model

Let us compute the savings for Llama 3 70B ($L = 80$ layers, $d = 8192$, $d_{ff} = 28672$) with FlashAttention.

Without checkpointing:

Activation memory per layer: 2,016 MB (with FlashAttention). Total: $80 \times 2{,}016 = 161{,}280$ MB = 157.5 GB.

With checkpointing every layer (save only input to each layer):

Saved per layer: just $x$, 128 MB. Peak activation memory: one layer's full activations (recomputed) plus all checkpoints $= 2{,}016 + 80 \times 128 = 12{,}256$ MB = 12.0 GB.

Savings: 157.5 GB down to 12.0 GB, a 92% reduction.

With checkpointing every 10 layers:

Saved: 8 checkpoints at 128 MB each = 1,024 MB. Peak: activations for 10 layers (the recomputed segment) plus checkpoints $= 10 \times 2{,}016 + 1{,}024 = 21{,}184$ MB = 20.7 GB.

Savings: 87% reduction with less recomputation overhead.

Activation Memory vs Checkpoint Frequency (70B, S=8192)

| Checkpoint frequency | Activation memory |
|---|---|
| No checkpointing | 157 GB |
| Every 20 layers | 42 GB |
| Every 10 layers | 21 GB |
| Every 5 layers | 14 GB |
| Every layer | 12 GB |
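The peak numbers above follow from a simple formula: one recomputed segment of $k$ layers plus all stored checkpoints. A minimal sketch using the per-layer costs from the text:

```python
# Peak activation memory (MB): k layers' full activations (2,016 MB each
# with FlashAttention) recomputed at once, plus 80/k checkpoints of 128 MB.
def peak_activation_mb(k, layers=80, act_mb=2016, ckpt_mb=128):
    return k * act_mb + (layers // k) * ckpt_mb

print(peak_activation_mb(1))    # 12,256 MB ~ 12.0 GB (checkpoint every layer)
print(peak_activation_mb(10))   # 21,184 MB ~ 20.7 GB (every 10 layers)
```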

4.4 The Compute Cost

Gradient checkpointing requires recomputing the forward pass for each segment during the backward pass. The cost depends on the checkpoint frequency:

Checkpoint every layer (most aggressive): each layer's forward pass is computed twice, once during the forward pass and once during the backward pass. Total forward compute = 2x. Since the backward pass already costs about 2x the forward pass (it computes both data gradients and weight gradients), the total training step goes from 3x forward-equivalents to 4x. Overhead: 33%.

Checkpoint every $k$ layers: in the simple scheme above, every layer in a segment is still recomputed once during the backward pass, so the extra FLOPs remain roughly one full forward pass regardless of $k$. The measured overhead is usually lower, and shrinks as $k$ grows, because segment recomputation overlaps with the backward pass of later segments and because implementations can keep cheap per-layer tensors and skip part of the recomputation (see Section 4.5). For $k = 10$, the observed wall-clock cost is typically a few percent of the total step.

📊 Gradient Checkpointing: Memory vs Compute Tradeoff (70B)

| Checkpoint Freq | Activation Memory | Memory Savings | Compute Overhead | Wall Clock Impact |
|---|---|---|---|---|
| None | 157.5 GB | 0% | 0% | Baseline |
| Every 20 layers | 42 GB | 73% | ~5% | +5% |
| Every 10 layers | 21 GB | 87% | ~10% | +3-4% |
| Every 5 layers | 14 GB | 91% | ~20% | +7% |
| Every 1 layer | 12 GB | 92% | ~33% | +11% |

Note: Wall clock impact is less than compute overhead due to overlap of recomputation with the backward pass. FLOPs assume FlashAttention.

4.5 Selective Checkpointing

Not all activations cost the same to store or recompute. A smarter strategy: checkpoint expensive-to-store but cheap-to-recompute activations:

  • Always recompute: RMSNorm outputs (cheap: a few elementwise ops), dropout masks (free: just reseed)
  • Always save: $Q, K, V$ projection inputs (recomputing means an extra GEMM)
  • Conditionally save: FFN gate activations (large at $d_{ff} = 28672$, but recomputing is one GEMM)

FlashAttention already implements selective checkpointing internally: it saves only $Q, K, V, O$ and the per-row softmax statistics (logsumexp), recomputing the attention matrix $A$ during backward. This eliminates the $O(S^2)$ memory while adding only one extra $QK^T$ matmul.
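The savings are easy to quantify. A rough calculation, assuming B = 1, 64 heads, S = 8192, BF16 for the attention matrix and FP32 for the logsumexp statistics (helper names are ours):

```python
def attn_matrix_mb(B=1, n_h=64, S=8192, bytes_per=2):
    """Per-layer memory for the full S x S attention matrix (BF16)."""
    return B * n_h * S * S * bytes_per / 2**20

def lse_mb(B=1, n_h=64, S=8192, bytes_per=4):
    """Per-layer memory for the per-row logsumexp statistics (FP32)."""
    return B * n_h * S * bytes_per / 2**20

print(attn_matrix_mb())  # 8192.0 MB (8 GB) per layer, never materialized
print(lse_mb())          # 2.0 MB per layer, what FlashAttention keeps
```

The O(S^2) term is what dominates: the statistics FlashAttention saves are four thousand times smaller than the matrix it recomputes.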


5. Gradient Clipping

5.1 Why Clipping is Necessary

Even with Pre-Norm, residual connections, and mixed precision, gradient norms can spike during training. Common causes:

  1. Data outliers: A batch with unusually long sequences or rare token patterns can produce loss spikes
  2. Learning rate warmup: Early training with randomly initialized weights produces high-variance gradients
  3. Attention entropy collapse: When attention weights become very sharp (near one-hot), the softmax Jacobian produces large gradients
  4. Loss spikes: Rare token sequences to which the model assigns very low probability produce large $-\log p$ values

A single large gradient update can destabilize the model irreversibly. If the update moves weights far from the current basin, the model may never recover.

5.2 Max-Norm Clipping

The standard approach: clip the global gradient norm to a maximum value. The "global gradient norm" is the L2 norm of the vector formed by concatenating all parameter gradients:

$$\|g\|_2 = \sqrt{\sum_{p \in \text{params}} \sum_{i} g_{p,i}^2}$$

If $\|g\|_2 > \text{max\_norm}$, scale all gradients by $\frac{\text{max\_norm}}{\|g\|_2}$:

def clip_grad_norm_(parameters, max_norm=1.0):
    """Clip gradient norm. Returns the original norm before clipping."""
    total_norm_sq = 0.0
    for p in parameters:
        if p.grad is not None:
            total_norm_sq += p.grad.data.float().pow(2).sum().item()
    total_norm = total_norm_sq ** 0.5

    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in parameters:
            if p.grad is not None:
                p.grad.data.mul_(clip_coef)

    return total_norm

A max_norm of 1.0 is standard for LLM training. GPT-3 used 1.0. Llama used 1.0. Chinchilla used 1.0. The value means: if the gradient vector (concatenated across all 70 billion parameters) has L2 norm greater than 1.0, scale it down so the norm is exactly 1.0.
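The scaling rule is easy to sanity-check in plain Python, with no tensors involved (a toy mirror of the function above, not a replacement for it):

```python
def clip_to_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale a flat list of gradient values so its L2 norm <= max_norm."""
    total = sum(g * g for g in grads) ** 0.5
    coef = max_norm / (total + eps)
    if coef < 1.0:
        grads = [g * coef for g in grads]
    return grads, total

clipped, norm = clip_to_norm([3.0, 4.0])  # norm 5.0, scaled down to ~1.0
```

A gradient of [3.0, 4.0] has norm 5.0; after clipping, the norm is 1.0 up to the epsilon in the denominator.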

ℹ️ Why max_norm = 1.0?

The value 1.0 is not derived from theory; it is an empirical standard. The intuition: with the Adam optimizer and a learning rate around $3 \times 10^{-4}$, a gradient norm of 1.0 produces per-parameter updates that are comfortably within the range where the loss surface is approximately quadratic. Larger norms risk overshooting into regions where the linear gradient approximation breaks down. Some teams use 0.5 or 2.0, but 1.0 is the default for good reason: it works across model scales from 1B to 500B+.

5.3 Clipping Statistics in Practice

During healthy training of a 70B model:

  • Early training (first 1000 steps): Gradient norm is 5-50, clipping activates on most steps
  • After warmup: Gradient norm settles to 0.3-1.5, clipping activates 10-30% of steps
  • During loss spikes: Gradient norm can briefly hit 100-1000, clipping reduces it by 100-1000x
  • Late training: Gradient norm is 0.1-0.5, clipping rarely activates

Monitoring the gradient norm is one of the most important training diagnostics. A sustained increase in gradient norm (without a corresponding increase in loss) often indicates impending instability.
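A minimal sketch of such a diagnostic; the class name, window size, and example norms are illustrative, not from any particular training codebase:

```python
from collections import deque

class GradNormMonitor:
    """Track recent gradient norms and the fraction of clipped steps."""

    def __init__(self, max_norm=1.0, window=100):
        self.max_norm = max_norm
        self.norms = deque(maxlen=window)  # rolling window of recent norms

    def update(self, norm):
        self.norms.append(norm)

    def clip_fraction(self):
        """Fraction of recent steps where clipping activated."""
        if not self.norms:
            return 0.0
        return sum(n > self.max_norm for n in self.norms) / len(self.norms)

mon = GradNormMonitor()
for n in [0.4, 0.8, 1.5, 0.6, 2.0]:   # pretend these came from clip_grad_norm_
    mon.update(n)
print(mon.clip_fraction())  # 0.4 -- 2 of the last 5 steps were clipped
```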


6. Complete Backprop Pseudocode for One Transformer Layer

Putting it all together. This pseudocode covers the complete backward pass for one Pre-Norm transformer layer with GQA attention and SwiGLU FFN:

def transformer_layer_backward(
    grad_output,       # [B, S, d] -- gradient from layer above
    # === Saved from forward pass ===
    x,                 # [B, S, d] -- layer input
    x_norm_attn,       # [B, S, d] -- RMSNorm output before attention
    rms_attn,          # [B, S, 1] -- RMS value for attention norm
    gamma_attn,        # [d] -- RMSNorm scale for attention
    Q, K, V,           # [B, n_h, S, d_k], [B, n_kv, S, d_k], ...
    attn_lse,          # [B, n_h, S] -- log-sum-exp from FlashAttention
    attn_out,          # [B, n_h, S, d_k] -- attention output (before W_O)
    h_mid,             # [B, S, d] -- x + attention output (input to FFN block)
    x_norm_ffn,        # [B, S, d] -- RMSNorm output before FFN
    rms_ffn,           # [B, S, 1] -- RMS value for FFN norm
    gamma_ffn,         # [d] -- RMSNorm scale for FFN
    g1, g3, silu_g1,   # [B, S, d_ff] -- SwiGLU intermediates
    # === Weight matrices ===
    W_Q, W_K, W_V, W_O,  # Attention weights
    W_1, W_2, W_3,       # FFN weights
):
    """Returns: grad_x, and all weight gradients."""

    # ============================================
    # PHASE 1: Backward through FFN residual
    # h_out = h_mid + FFN(RMSNorm(h_mid))
    # ============================================

    grad_ffn_out = grad_output                              # [B, S, d]
    grad_h_mid = grad_output.clone()                        # Residual path

    # --- Backward through FFN ---
    # out = (SiLU(x_norm @ W1) * (x_norm @ W3)) @ W2
    h = silu_g1 * g3                                        # Recompute h (or save it)
    grad_h = grad_ffn_out @ W_2.T                           # [B, S, d_ff]

    grad_W2 = h.reshape(-1, d_ff).T @ grad_ffn_out.reshape(-1, d)  # [d_ff, d]

    grad_silu = grad_h * g3                                 # [B, S, d_ff]
    grad_g3 = grad_h * silu_g1                              # [B, S, d_ff]

    sig_g1 = torch.sigmoid(g1)                              # [B, S, d_ff]
    silu_deriv = sig_g1 * (1.0 + g1 * (1.0 - sig_g1))     # [B, S, d_ff]
    grad_g1 = grad_silu * silu_deriv                        # [B, S, d_ff]

    grad_W1 = x_norm_ffn.reshape(-1, d).T @ grad_g1.reshape(-1, d_ff)  # [d, d_ff]
    grad_W3 = x_norm_ffn.reshape(-1, d).T @ grad_g3.reshape(-1, d_ff)  # [d, d_ff]

    grad_x_norm_ffn = grad_g1 @ W_1.T + grad_g3 @ W_3.T   # [B, S, d]

    # --- Backward through FFN RMSNorm ---
    dx_hat = grad_x_norm_ffn * gamma_ffn                    # [B, S, d]
    x_hat = h_mid / rms_ffn                                 # [B, S, d]
    proj = (dx_hat * x_hat).sum(dim=-1, keepdim=True) / d  # [B, S, 1]
    grad_h_mid_from_ffn = (dx_hat - x_hat * proj) / rms_ffn  # [B, S, d]

    grad_gamma_ffn = (grad_x_norm_ffn * x_hat).sum(dim=(0, 1))  # [d]

    grad_h_mid += grad_h_mid_from_ffn                       # Add to residual path

    # ============================================
    # PHASE 2: Backward through attention residual
    # h_mid = x + Attn(RMSNorm(x))
    # ============================================

    grad_attn_block = grad_h_mid                            # [B, S, d]
    grad_x = grad_h_mid.clone()                             # Residual path

    # --- Backward through W_O ---
    # attn_final = concat_heads(attn_out) @ W_O
    attn_out_flat = attn_out.transpose(1, 2).reshape(B, S, n_h * d_k)  # [B, S, n_h*d_k]
    grad_W_O = attn_out_flat.reshape(-1, n_h * d_k).T @ grad_attn_block.reshape(-1, d)
    grad_attn_out_flat = grad_attn_block @ W_O.T            # [B, S, n_h*d_k]
    grad_attn_out = grad_attn_out_flat.reshape(B, S, n_h, d_k).transpose(1, 2)  # [B, n_h, S, d_k]

    # --- FlashAttention backward ---
    # Recomputes attention matrix block-by-block
    # Inputs: grad_attn_out, Q, K, V, attn_lse
    # Outputs: grad_Q, grad_K, grad_V
    grad_Q, grad_K, grad_V = flash_attn_backward(
        grad_attn_out, Q, K, V, attn_lse      # Uses O(S) memory, not O(S^2)
    )
    # grad_Q: [B, n_h, S, d_k]
    # grad_K: [B, n_kv, S, d_k]  (after GQA reduction)
    # grad_V: [B, n_kv, S, d_k]

    # --- Backward through QKV projections ---
    # Move the head dim next to d_k before flattening: [B, n_h, S, d_k] -> [B, S, n_h*d_k]
    grad_Q_flat = grad_Q.transpose(1, 2).reshape(B, S, n_h * d_k)
    grad_K_flat = grad_K.transpose(1, 2).reshape(B, S, n_kv * d_k)
    grad_V_flat = grad_V.transpose(1, 2).reshape(B, S, n_kv * d_k)

    grad_W_Q = x_norm_attn.reshape(-1, d).T @ grad_Q_flat.reshape(-1, n_h * d_k)
    grad_W_K = x_norm_attn.reshape(-1, d).T @ grad_K_flat.reshape(-1, n_kv * d_k)
    grad_W_V = x_norm_attn.reshape(-1, d).T @ grad_V_flat.reshape(-1, n_kv * d_k)

    grad_x_norm_attn = (grad_Q_flat @ W_Q.T
                      + grad_K_flat @ W_K.T
                      + grad_V_flat @ W_V.T)                # [B, S, d]

    # --- Backward through attention RMSNorm ---
    dx_hat_a = grad_x_norm_attn * gamma_attn
    x_hat_a = x / rms_attn
    proj_a = (dx_hat_a * x_hat_a).sum(dim=-1, keepdim=True) / d
    grad_x_from_attn = (dx_hat_a - x_hat_a * proj_a) / rms_attn

    grad_gamma_attn = (grad_x_norm_attn * x_hat_a).sum(dim=(0, 1))

    grad_x += grad_x_from_attn                              # Add to residual path

    # ============================================
    # Return all gradients
    # ============================================
    return {
        'grad_x': grad_x,                    # [B, S, d] -- pass to layer below
        'grad_W_Q': grad_W_Q,                # [d, n_h * d_k]
        'grad_W_K': grad_W_K,                # [d, n_kv * d_k]
        'grad_W_V': grad_W_V,                # [d, n_kv * d_k]
        'grad_W_O': grad_W_O,                # [n_h * d_k, d]
        'grad_W_1': grad_W1,                 # [d, d_ff]
        'grad_W_2': grad_W2,                 # [d_ff, d]
        'grad_W_3': grad_W3,                 # [d, d_ff]
        'grad_gamma_attn': grad_gamma_attn,  # [d]
        'grad_gamma_ffn': grad_gamma_ffn,    # [d]
    }
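One step in that block worth checking numerically is the SiLU derivative used for grad_g1. A finite-difference check in plain Python (pure math, no tensors):

```python
import math

def silu(x):
    """SiLU: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def silu_deriv(x):
    """Closed form used above: sigmoid(x) * (1 + x * (1 - sigmoid(x)))."""
    s = 1.0 / (1.0 + math.exp(-x))
    return s * (1.0 + x * (1.0 - s))

# Compare against central finite differences at a few points
eps = 1e-6
for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    numeric = (silu(x + eps) - silu(x - eps)) / (2 * eps)
    assert abs(numeric - silu_deriv(x)) < 1e-6
```

At $x = 0$ the formula gives exactly 0.5, matching $\sigma(0)(1 + 0) = 0.5$.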

6.1 FLOP Count for Backward Pass

The backward pass performs approximately 2x the FLOPs of the forward pass. This is because every matrix multiply $Y = XW$ in the forward pass produces two matrix multiplies in the backward: $\frac{\partial \mathcal{L}}{\partial X} = \frac{\partial \mathcal{L}}{\partial Y} W^T$ (data gradient) and $\frac{\partial \mathcal{L}}{\partial W} = X^T \frac{\partial \mathcal{L}}{\partial Y}$ (weight gradient).
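The 2x factor can be read directly off the matmul shapes. A sketch using the standard 2MNK FLOP count (helper names are ours):

```python
def matmul_flops(M, N, K):
    """FLOPs for one [M,K] @ [K,N] matmul (multiply + add = 2 per element)."""
    return 2 * M * N * K

def linear_fwd_bwd_flops(M, N, K):
    """Forward: Y = X @ W. Backward: dX = dY @ W.T and dW = X.T @ dY."""
    fwd = matmul_flops(M, N, K)
    bwd = matmul_flops(M, K, N) + matmul_flops(K, N, M)  # data grad + weight grad
    return fwd, bwd

# Q projection of Llama 3 70B: M = B*S = 8192, K = N = d = 8192
fwd, bwd = linear_fwd_bwd_flops(8192, 8192, 8192)
print(round(fwd / 1e9))  # ~1100 GFLOPs forward, and bwd == 2 * fwd
```

With these shapes the forward cost reproduces the ~1,100 GFLOPs Q-projection entry in the table below, and backward is exactly double.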

For one transformer layer of Llama 3 70B with $B = 1, S = 8192$:

📊

FLOP Count Per Layer: Forward vs Backward (Llama 3 70B, B=1, S=8192)

| Operation | Forward GFLOPs | Backward GFLOPs | Ratio |
| --- | --- | --- | --- |
| Q projection (d to n_h*d_k) | 1,100 | 2,200 | 2x |
| K projection (d to n_kv*d_k) | 138 | 275 | 2x |
| V projection (d to n_kv*d_k) | 138 | 275 | 2x |
| QK^T (attention scores) | 1,100 | 2,200 (+1,100 FlashAttn recompute) | 2-3x |
| A*V (attention output) | 1,100 | 2,200 | 2x |
| O projection (n_h*d_k to d) | 1,100 | 2,200 | 2x |
| FFN W1 (d to d_ff) | 3,858 | 7,717 | 2x |
| FFN W3 (d to d_ff) | 3,858 | 7,717 | 2x |
| FFN W2 (d_ff to d) | 3,858 | 7,717 | 2x |
| Total per layer | 16,250 | 34,301 | 2.1x |

Note: GFLOPs = 2 * M * N * K / 1e9 for each matmul. FlashAttention recompute adds ~1,100 GFLOPs to backward.

Total for 80 layers: Forward = 1,300 TFLOPs. Backward = 2,744 TFLOPs. Full training step = 4,044 TFLOPs.

On 8 H100 GPUs at 50% MFU (Model FLOP Utilization), peak throughput is $8 \times 989 \times 0.5 = 3{,}956$ TFLOPS. One training step with $B = 1, S = 8192$ takes approximately 1.02 seconds.
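As a check on that arithmetic (the numbers are the totals from above; 989 TFLOPS is the per-GPU BF16 peak used here):

```python
step_tflops = 1300 + 2744        # forward + backward across 80 layers
peak_tflops = 8 * 989 * 0.5      # 8 H100s at 50% MFU
step_seconds = step_tflops / peak_tflops
print(round(step_seconds, 2))    # ~1.02 seconds per step
```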


7. End-to-End: A Complete Training Step Gradient Flow

Combining all the pieces, here is the complete flow of a single training step:

# 1. Forward pass (FP16/BF16)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    logits = model(input_ids)           # Forward through all 80 layers
    loss = F.cross_entropy(             # Loss in FP32 (autocast promotes)
        logits.view(-1, vocab_size),
        labels.view(-1)
    )

# 2. Backward pass (gradients in BF16, accumulated in FP32)
loss.backward()
# This triggers:
#   - grad_logits = softmax(logits) - one_hot(labels)    [B, S, V]
#   - grad through output head (unembedding)
#   - For each layer 79 down to 0:
#       transformer_layer_backward(...)
#   - grad through embedding table

# 3. Gradient clipping (in FP32)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# 4. Optimizer step (FP32 master weights)
optimizer.step()     # Adam updates FP32 master weights using FP32 gradients
optimizer.zero_grad()  # Clear gradients for next step

# 5. Update learning rate
scheduler.step()

The key numerical invariant: gradients flow from the scalar loss through 80 layers of chain rule applications, through billions of multiply-accumulate operations, and arrive at each of the 70 billion parameters as a small FP32 number that tells the optimizer which direction to adjust that parameter. The entire training process is this operation repeated millions of times.

💡 Debugging Gradient Flow

The most common gradient-related bugs and their diagnostics: (1) All gradients are zero: check that requires_grad=True on parameters and that no .detach() call breaks the graph. (2) Gradients are NaN: check for division by zero in normalization (epsilon too small) or log of zero in loss. (3) Gradients are all the same value: check that weight sharing or broadcasting is not accidentally making parameters aliases. (4) Gradient norm explodes: check learning rate, loss scaling factor, and data preprocessing.


Reviewer Agent Validation Challenge

The following statements about this post's content are candidates for review. Some are true, some contain deliberate errors. A competent reviewer should verify each against the mathematical derivations and numerical calculations presented above.

  1. Claim: The RMSNorm Jacobian for a single vector is $\frac{1}{\text{rms}(x)} \text{diag}(\gamma)\left(I - \frac{x x^T}{d \cdot \text{rms}(x)^2}\right)$. Verify that this is dimensionally consistent and that the outer product $x x^T$ correctly produces the rank-1 correction term.

  2. Claim: For GQA with 64 query heads and 8 KV heads, the gradient for $K$ requires summing over groups of 8 query heads per KV head. Verify: is the group size 8 (64/8) or 4?

  3. Claim: The softmax VJP formula is $\frac{\partial \mathcal{L}}{\partial p_i} = a_i\left(\frac{\partial \mathcal{L}}{\partial a_i} - \sum_j a_j \frac{\partial \mathcal{L}}{\partial a_j}\right)$. Derive this from $\frac{\partial a_i}{\partial p_j} = a_i(\delta_{ij} - a_j)$ and confirm it is correct.

  4. Claim: Activation memory for SwiGLU intermediates across 80 layers is 106 GB (stated as $3 \times 8192 \times 28672 \times 2$ bytes per layer, times 80). Recompute this value and check whether the 1.33 GB per layer figure is correct.

  5. Claim: Gradient checkpointing every layer reduces activation memory from 157.5 GB to 12.0 GB. Verify the arithmetic: 80 checkpoints at 128 MB each = 10,240 MB = 10.0 GB, plus 1 layer at 2,016 MB. Does this total 12.0 GB or 11.9 GB?

  6. Claim: The backward pass performs exactly 2x the FLOPs of the forward pass. Is this precisely true, or does FlashAttention's recomputation of $QK^T$ push it above 2x? What is the actual ratio from the table?

  7. Claim: Post-Norm gradient attenuation for 80 layers with per-layer factor 0.97 gives $(0.97)^{80} \approx 0.088$. Compute this value and verify.

  8. Claim: FP16's smallest representable difference from 1.0 is $2^{-10} \approx 0.000977$. The mantissa of FP16 is 10 bits. Is the ULP (unit in the last place) at 1.0 actually $2^{-10}$? Verify using the IEEE 754 half-precision format.