The forward pass of a transformer is, by now, well-documented. You know that queries, keys, and values are projected, that attention scores are computed via scaled dot-product, that outputs are passed through feed-forward networks with residual connections and normalization. The previous 15 parts of this series have covered every one of those operations in detail.
But the forward pass is only half the story. Training a transformer means computing gradients of the loss with respect to every parameter in the model (billions of them) and using those gradients to update weights. The backward pass is where the chain rule meets the attention mechanism, where numerical precision collides with memory limits, and where engineering decisions like gradient checkpointing and mixed-precision accumulation determine whether you can train a 70B model on your hardware budget or not.
This post covers the backward pass from first principles. Every gradient shape is explicit. Every memory cost is quantified. No hand-waving.
1. The Chain Rule Through a Transformer Layer
1.1 Layer Structure Review
A single Pre-Norm transformer layer computes:

h_mid = x + Attn(RMSNorm(x))
h_out = h_mid + FFN(RMSNorm(h_mid))

where x has shape [B, S, d] with batch size B, sequence length S, and hidden dimension d. For Llama 3 70B: d = 8192, S up to 8192 during training, B depends on memory.
The loss L is a scalar. We receive dL/dh_out from the layer above and must compute:
- dL/dx to pass to the layer below
- dL/dW_Q, dL/dW_K, dL/dW_V, dL/dW_O for attention weight updates
- dL/dW_1, dL/dW_2, dL/dW_3 for FFN weight updates (SwiGLU has three matrices)
- dL/dγ for RMSNorm scale parameters
1.2 Backprop Through the FFN Residual
Start from the output. The residual connection gives:

h_out = h_mid + FFN(RMSNorm(h_mid))
∂h_out/∂h_mid = I + ∂FFN(RMSNorm(h_mid))/∂h_mid

The identity matrix in the sum is the key insight of residual connections: the gradient flows through unchanged, plus an additive contribution from the FFN path. This is why residual connections prevent vanishing gradients: the identity path preserves gradient magnitude regardless of what the FFN Jacobian looks like.
In practice, autograd frameworks do not form the full Jacobian. They compute vector-Jacobian products (VJPs). The VJP for the residual is simply:

dL/dh_mid = dL/dh_out + (∂FFN(RMSNorm(h_mid))/∂h_mid)ᵀ dL/dh_out
# Backward through: h_out = h_mid + ffn(rmsnorm(h_mid))
# Given: grad_h_out shape [B, S, d]
grad_ffn_out = grad_h_out # [B, S, d]
grad_h_mid = grad_h_out.clone() # Identity path: [B, S, d]
# Then backprop grad_ffn_out through FFN and RMSNorm
# and ADD the result to grad_h_mid
grad_h_mid += backprop_through_ffn_and_norm(grad_ffn_out, ...)
1.3 Backprop Through RMSNorm
RMSNorm computes:

y = γ ⊙ x / rms(x),   rms(x) = sqrt(mean(x²) + ε)

The gradient with respect to x involves the normalization Jacobian. For a single vector x in R^d:

∂y_i/∂x_j = (γ_i / rms(x)) · (δ_ij − x_i x_j / (d · rms(x)²))

where δ_ij is the Kronecker delta. In matrix form, with x̂ = x / rms(x):

J = diag(γ) / rms(x) · (I − x̂ x̂ᵀ / d)

This Jacobian is d × d = 8192 × 8192 for Llama 3 70B. We never materialize it. The VJP computes dL/dx directly:
def rmsnorm_backward(grad_output, x, gamma, rms_val):
    # grad_output: [B, S, d]
    # x:           [B, S, d]  (saved from forward)
    # gamma:       [d]
    # rms_val:     [B, S, 1]  (saved from forward)
    d = x.shape[-1]
    x_hat = x / rms_val                                   # [B, S, d]
    grad_gamma = (grad_output * x_hat).sum(dim=(0, 1))    # [d]
    dx_hat = grad_output * gamma                          # [B, S, d]
    # Fused form (what kernels actually compute); algebraically equal to
    # dx_hat / rms - x * sum(dx_hat * x_hat) / (d * rms**2)
    proj = (dx_hat * x_hat).sum(dim=-1, keepdim=True) / d # [B, S, 1]
    grad_x = (dx_hat - x_hat * proj) / rms_val            # [B, S, d]
    return grad_x, grad_gamma
The key cost: we need x (the input) saved from the forward pass. Shape [B, S, d]. For B = 1, S = 8192, d = 8192 in FP16, that is 8192 × 8192 × 2 bytes = 128 MB per RMSNorm. Two RMSNorms per layer, 80 layers: 160 × 128 MB = 20,480 MB = 20 GB just for RMSNorm activations.
RMSNorm backward is 5 elementwise operations on tensors of shape [B, S, d]. FLOPs are negligible compared to the matmuls. But the memory to store the forward activation is substantial: 20 GB for a 70B model with sequence length 8192. This is one reason gradient checkpointing targets normalization layers for recomputation.
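These byte counts are easy to sanity-check; a quick script (the Llama 3 70B dimensions are the ones used throughout this series):

```python
def tensor_mib(*shape, bytes_per_el=2):
    """Size in MiB of a tensor with the given shape (FP16 = 2 bytes/element)."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_el / 2**20

B, S, d, layers = 1, 8192, 8192, 80
per_norm_mb = tensor_mib(B, S, d)             # input saved per RMSNorm
total_gb = 2 * layers * per_norm_mb / 1024    # two RMSNorms per layer
print(f"{per_norm_mb:.0f} MB per RMSNorm, {total_gb:.0f} GB model-wide")
# prints: 128 MB per RMSNorm, 20 GB model-wide
```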
1.4 Backprop Through Attention: The Jacobian Shapes
This is where the backward pass gets mathematically dense. Multi-head attention computes:

Q = x_norm W_Q,   K = x_norm W_K,   V = x_norm W_V
A = softmax(Q Kᵀ / sqrt(d_k) + M),   O = A V,   out = concat_heads(O) W_O

where M is the causal mask. For GQA with n_h = 64 query heads and n_kv = 8 KV heads, the shapes are:
W_Q: [d, n_h * d_k] = [8192, 8192] (64 * 128 = 8192)
W_K: [d, n_kv * d_k] = [8192, 1024] (8 * 128 = 1024)
W_V: [d, n_kv * d_k] = [8192, 1024]
W_O: [n_h * d_k, d] = [8192, 8192]
Q: [B, n_h, S, d_k] = [B, 64, S, 128]
K: [B, n_kv, S, d_k] = [B, 8, S, 128]
V: [B, n_kv, S, d_k] = [B, 8, S, 128]
A: [B, n_h, S, S] = [B, 64, S, S]
O: [B, n_h, S, d_k] = [B, 64, S, 128]
The attention matrix A has shape [B, n_h, S, S]. For B = 1, S = 4096: that is 64 × 4096² × 2 bytes = 2 GB in FP16. For S = 8192, that is 8 GB per layer. Across 80 layers: 640 GB. This is why FlashAttention exists: it never materializes A.
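The same arithmetic for the attention matrix, as a quick check:

```python
def attn_matrix_gib(B, n_h, S, bytes_per_el=2):
    """FP16 size of the [B, n_h, S, S] attention matrix, in GiB."""
    return B * n_h * S * S * bytes_per_el / 2**30

print(attn_matrix_gib(1, 64, 4096))        # 2.0 GiB per layer at S=4096
per_layer = attn_matrix_gib(1, 64, 8192)   # 8.0 GiB per layer at S=8192
print(per_layer, 80 * per_layer)           # 8.0 per layer, 640.0 across 80 layers
```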
Gradient Through W_O
Starting from the attention output projection out = concat_heads(O) W_O:

dL/dW_O = concat_heads(O)ᵀ · dL/dout
dL/dconcat_heads(O) = dL/dout · W_Oᵀ

These are standard matrix multiply backward passes. The reshape into per-head form gives dL/dO_h for each head h.
Gradient Through A V
For a single head (dropping the head subscript for clarity), O = A V, so:

dL/dA = dL/dO · Vᵀ
dL/dV = Aᵀ · dL/dO

The Jacobian ∂O/∂A is conceptually (S · d_k) × S², enormous. We never form it. The VJP computes dL/dA = dL/dO · Vᵀ as a batched matmul.
Gradient Through Softmax
This is the most mathematically involved step. Softmax is applied row-wise to the scaled scores Z = Q Kᵀ / sqrt(d_k) + M:

a_i = exp(z_i) / Σ_j exp(z_j)

The Jacobian of softmax for a single row is:

∂a_i/∂z_j = a_i (δ_ij − a_j)

In matrix form: J = diag(a) − a aᵀ. This is an S × S matrix per row, per head, per batch element. The VJP is:
def softmax_backward(grad_a, a):
    # grad_a: [B, n_h, S, S]
    # a:      [B, n_h, S, S] (softmax output, saved from forward)
    dot = (grad_a * a).sum(dim=-1, keepdim=True)  # [B, n_h, S, 1]
    grad_p = a * (grad_a - dot)                   # [B, n_h, S, S]
    return grad_p
This requires saving A (the attention weights) from the forward pass. Shape: [B, n_h, S, S]. This is the single largest activation in the network.
FlashAttention avoids storing A by recomputing it from Q and K during the backward pass, block by block. The backward kernel tiles over blocks of Q, K, V, recomputes the local softmax for each tile, and accumulates gradients. This trades 1 extra matmul (recomputing Q Kᵀ) for eliminating the O(S²) memory cost. For S = 8192, this saves 8 GB per layer.
Gradient Through Q and K
Given dL/dZ (the gradient through the pre-softmax scores):

dL/dQ = (1 / sqrt(d_k)) · dL/dZ · K
dL/dK = (1 / sqrt(d_k)) · dL/dZᵀ · Q

For GQA, the gradient for K (and likewise V) requires summing over the query heads that share each KV head. With 64 query heads and 8 KV heads, each KV head serves 8 query heads:
# grad_K_expanded: [B, 64, S, d_k] -- gradient from all query heads
# Reshape to [B, 8, 8, S, d_k] and sum over the group dimension
grad_K = grad_K_expanded.reshape(B, n_kv, n_h // n_kv, S, d_k).sum(dim=2)
# Result: [B, 8, S, d_k]
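The reshape-and-sum reduction is worth verifying against an explicit loop over query heads. A small numpy check (shapes shrunk from 64/8 heads to 8/2 so it runs instantly; the mapping assumes query head h feeds KV head h // group, as the reshape implies):

```python
import numpy as np

rng = np.random.default_rng(0)
B, n_h, n_kv, S, d_k = 2, 8, 2, 4, 3
group = n_h // n_kv  # query heads per KV head

grad_K_expanded = rng.standard_normal((B, n_h, S, d_k))

# Vectorized reduction: split the head axis into [n_kv, group], sum the group dim
grad_K = grad_K_expanded.reshape(B, n_kv, group, S, d_k).sum(axis=2)

# Reference: explicit loop -- query head h contributes to KV head h // group
ref = np.zeros((B, n_kv, S, d_k))
for h in range(n_h):
    ref[:, h // group] += grad_K_expanded[:, h]

assert grad_K.shape == (B, n_kv, S, d_k)
assert np.allclose(grad_K, ref)
```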
Gradient Through Weight Projections
Finally, the weight gradients:

dL/dW_Q = x_normᵀ · dL/dQ_flat
dL/dW_K = x_normᵀ · dL/dK_flat
dL/dW_V = x_normᵀ · dL/dV_flat

where x_norm is the RMSNorm output reshaped to 2D, [B·S, d]. Each weight gradient is a GEMM of shape [d, B·S] × [B·S, d_out].
Attention Backward Pass: Jacobian and Gradient Shapes (Llama 3 70B)
| Tensor | Shape | Size (FP16, B=1, S=8192) | Notes |
|---|---|---|---|
| dL/dA (attn weights grad) | [1, 64, 8192, 8192] | 8 GB | Eliminated by FlashAttn |
| dL/dQ | [1, 64, 8192, 128] | 128 MB | Needed for W_Q grad |
| dL/dK | [1, 8, 8192, 128] | 16 MB | Summed over GQA groups |
| dL/dV | [1, 8, 8192, 128] | 16 MB | From A^T * dL/dO |
| dL/dW_Q | [8192, 8192] | 128 MB | Accumulated across batch |
| dL/dW_K | [8192, 1024] | 16 MB | Accumulated across batch |
| dL/dW_V | [8192, 1024] | 16 MB | Accumulated across batch |
| dL/dW_O | [8192, 8192] | 128 MB | Accumulated across batch |
1.5 Backprop Through SwiGLU FFN
SwiGLU computes:

FFN(x) = (SiLU(x W_1) ⊙ (x W_3)) W_2

where W_1, W_3 have shape [d, d_ff] and W_2 has shape [d_ff, d], with d_ff = 28672 for Llama 3 70B.
Let g_1 = x W_1, g_3 = x W_3, h = SiLU(g_1) ⊙ g_3, and out = h W_2.
Backward through W_2:

dL/dh = dL/dout · W_2ᵀ,   dL/dW_2 = hᵀ · dL/dout

Backward through the elementwise product:

dL/dSiLU(g_1) = dL/dh ⊙ g_3,   dL/dg_3 = dL/dh ⊙ SiLU(g_1)

Backward through SiLU (SiLU(z) = z · σ(z)):

SiLU′(z) = σ(z) · (1 + z · (1 − σ(z))),   dL/dg_1 = dL/dSiLU(g_1) ⊙ SiLU′(g_1)

Weight gradients:

dL/dW_1 = xᵀ · dL/dg_1,   dL/dW_3 = xᵀ · dL/dg_3
def swiglu_backward(grad_out, x_norm, W1, W2, W3, g1, g3, silu_g1):
    # grad_out: [B, S, d]
    # Saved from forward: x_norm, g1, g3, silu_g1 (= SiLU(g1))
    d, d_ff = W1.shape
    grad_h = grad_out @ W2.T                        # [B, S, d_ff] -- GEMM
    grad_silu = grad_h * g3                         # [B, S, d_ff] -- elementwise
    grad_g3 = grad_h * silu_g1                      # [B, S, d_ff] -- elementwise
    sig_g1 = torch.sigmoid(g1)                      # [B, S, d_ff]
    silu_deriv = sig_g1 * (1 + g1 * (1 - sig_g1))   # [B, S, d_ff]
    grad_g1 = grad_silu * silu_deriv                # [B, S, d_ff] -- elementwise
    # Weight gradients (GEMMs)
    grad_W2 = (silu_g1 * g3).reshape(-1, d_ff).T @ grad_out.reshape(-1, d)  # [d_ff, d]
    grad_W1 = x_norm.reshape(-1, d).T @ grad_g1.reshape(-1, d_ff)           # [d, d_ff]
    grad_W3 = x_norm.reshape(-1, d).T @ grad_g3.reshape(-1, d_ff)           # [d, d_ff]
    # Input gradient
    grad_x_norm = grad_g1 @ W1.T + grad_g3 @ W3.T   # [B, S, d] -- two GEMMs
    return grad_x_norm, grad_W1, grad_W2, grad_W3
The SwiGLU backward requires saving g_1, g_3, and SiLU(g_1) from the forward pass. Each is [B, S, d_ff]. In FP16 with B = 1, S = 8192, d_ff = 28672: that is about 1.33 GB per layer. Across 80 layers: 106 GB. This is often the dominant activation memory cost, exceeding even the attention matrix (when FlashAttention is used).
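A quick recomputation of the per-layer figure, in binary units (MiB/GiB), to compare against the numbers above:

```python
B, S, d_ff, layers = 1, 8192, 28672, 80
per_tensor_mib = B * S * d_ff * 2 / 2**20   # one [B, S, d_ff] tensor in FP16
per_layer_gib = 3 * per_tensor_mib / 1024   # g1, g3, and SiLU(g1)
print(per_tensor_mib, per_layer_gib, layers * per_layer_gib)
# prints: 448.0 1.3125 105.0
```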
2. Why Pre-Norm Helps Gradient Flow
2.1 Post-Norm vs Pre-Norm Gradient Paths
Post-Norm (original transformer):

h_out = LayerNorm(x + Sublayer(x))

Pre-Norm (modern transformers):

h_out = x + Sublayer(RMSNorm(x))

The difference is critical for gradient flow. In Pre-Norm, the gradient from h_out to x has a direct identity path:

∂h_out/∂x = I + ∂Sublayer(RMSNorm(x))/∂x

The identity matrix guarantees that the gradient passes through with magnitude 1, regardless of the sublayer's Jacobian. Stack L layers and the gradient from the final layer to the input always has a component that is simply the product of L identity matrices, which is the identity matrix.
In Post-Norm, the gradient path is:

∂h_out/∂x = J_LN · (I + ∂Sublayer(x)/∂x)

The normalization Jacobian J_LN wraps the entire sum, including the identity path. This Jacobian has eigenvalues that can be less than 1 (it projects out the component along the mean direction). After L layers, the product of normalization Jacobians can significantly attenuate the gradient, even though the individual attenuation per layer is small.
2.2 Quantifying the Difference
For a model with L layers, the gradient magnitude ratio between Pre-Norm and Post-Norm can be estimated as:
Pre-Norm: ‖dL/dx_0‖ ≈ ‖dL/dx_L‖
Post-Norm: ‖dL/dx_0‖ ≈ ‖dL/dx_L‖ · Π_{i=1..L} λ_i
where λ_i is the effective spectral norm of the i-th normalization Jacobian, typically 0.95-0.99. For λ = 0.97, L = 80: 0.97⁸⁰ ≈ 0.087. The gradient is attenuated by roughly 11x. For L = 128: 0.97¹²⁸ ≈ 0.02, a roughly 50x attenuation.
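The exponentials are worth computing explicitly:

```python
lam = 0.97  # effective per-layer spectral norm of the normalization Jacobian
for L in (80, 128):
    atten = lam ** L
    print(f"L={L}: product {atten:.3f}, roughly {1 / atten:.0f}x attenuation")
```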
[Figure: gradient magnitude at layer 0 relative to layer L, Pre-Norm vs Post-Norm.]
2.3 The Clean Residual Path
The Pre-Norm architecture creates what Anthropic researchers call the "residual stream." The hidden state flows through the network, and each sublayer reads from it and writes an additive update:

x_{l+1} = x_l + f_l(x_l)

The gradient of the loss with respect to x_0 expands as:

dL/dx_0 = dL/dx_L + Σ_l J_lᵀ · dL/dx_L + (higher-order cross terms),   J_l = ∂f_l/∂x_l

The first term passes directly from the output to the input: no attenuation, no transformation. The sum contains the "useful" gradient information from each layer's transformation. This structure means that even if some layers have vanishing or noisy gradients, the direct path ensures that the early layers always receive a meaningful signal.
In practice, this is why Pre-Norm models train stably with learning rates that would cause Post-Norm models to diverge. The gradient variance across layers is much lower:
Gradient Norm Statistics by Layer Position
| Architecture | Grad Norm (First Layer) | Grad Norm (Last Layer) | Ratio | Training Outcome |
|---|---|---|---|---|
| Pre-Norm (80L) | 0.0042 | 0.0038 | 1.1x | Stable |
| Post-Norm (80L) | 0.00003 | 0.0041 | 137x | Diverges at high LR |
| Post-Norm (32L) | 0.0011 | 0.0039 | 3.5x | Stable but slower |
| Pre-Norm (128L) | 0.0040 | 0.0036 | 1.1x | Stable |
3. Mixed Precision Gradient Accumulation
3.1 The Precision Hierarchy
Modern training uses a precision hierarchy:
- Forward pass: FP16 or BF16 (2 bytes per element)
- Attention scores: FP32 for softmax intermediate (4 bytes), avoids overflow
- Loss computation: FP32 (4 bytes), cross-entropy is numerically sensitive
- Gradient computation: FP16/BF16 for most operations
- Gradient accumulation: FP32 (4 bytes), gradients are small and need precision
- Master weights: FP32 (4 bytes), weight updates are tiny relative to weights
- Optimizer state: FP32 (Adam has two states per parameter: 4 + 4 bytes)
For a 70B parameter model, the memory breakdown:
Master weights (FP32): 70B * 4 bytes = 280 GB
Optimizer momentum (FP32): 70B * 4 bytes = 280 GB
Optimizer variance (FP32): 70B * 4 bytes = 280 GB
FP16 weights (working copy): 70B * 2 bytes = 140 GB
FP16 gradients: 70B * 2 bytes = 140 GB
--------
Total (no activations): 1,120 GB
This is why 70B training requires distributed setups. A single H100 has 80 GB. Even with model parallelism across 8 GPUs (640 GB total), you need ZeRO-3 or FSDP to shard the optimizer states.
3.2 Why FP32 Gradients Matter
Consider a weight w = 1.0 with a gradient g = 10⁻³ and learning rate η = 3 × 10⁻⁵. The update is η·g = 3 × 10⁻⁸.
In FP16, the smallest representable difference from 1.0 is 2⁻¹⁰ ≈ 9.8 × 10⁻⁴. The update is roughly 32,000x smaller than the FP16 precision at 1.0. It would be rounded to zero; the weight would never change.
In BF16, the smallest difference from 1.0 is 2⁻⁷ ≈ 7.8 × 10⁻³, still about 260,000x larger than the update. The update is lost.
In FP32, the smallest difference from 1.0 is 2⁻²³ ≈ 1.2 × 10⁻⁷. The update is below this, so even FP32 loses it in a single step. But Adam accumulates momentum over many steps, and the accumulated momentum in FP32 preserves the information.
This is why the master weights and optimizer states must be FP32. The individual updates are too small for reduced precision, but they accumulate over thousands of steps into meaningful weight changes.
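This rounding behavior can be demonstrated directly with numpy, using the η·g = 3 × 10⁻⁸ update from above:

```python
import numpy as np

update = np.float32(3e-8)  # lr 3e-5 * grad 1e-3

# FP16: the tiny update disappears in the add (ULP at 1.0 is ~9.8e-4)
w16 = np.float16(1.0)
assert w16 + np.float16(update) == np.float16(1.0)

# FP32: a single update is below half the ULP at 1.0 and is lost too
w32 = np.float32(1.0)
assert w32 + update == np.float32(1.0)

# But accumulating updates first (as Adam's FP32 momentum effectively does)
# preserves them: 1000 steps sum to ~3e-5, well above FP32's ULP at 1.0
acc = np.float32(0.0)
for _ in range(1000):
    acc += update
assert w32 + acc > np.float32(1.0)
```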
3.3 Loss Scaling for FP16
When using FP16 (not BF16), gradients often underflow to zero because the dynamic range of FP16 only reaches down to about 6.1 × 10⁻⁵ (smallest normal) or 6 × 10⁻⁸ (smallest subnormal). Many gradients in deep layers are smaller than this.
The solution: loss scaling. Multiply the loss by a large factor (1024, 4096, or dynamically adjusted) before backward, then divide gradients by the same factor after:
loss_scale = 1024.0
# Forward in FP16
with torch.cuda.amp.autocast(dtype=torch.float16):
    logits = model(input_ids)
    loss = cross_entropy(logits, labels)  # computed in FP32 inside autocast
# Scale loss before backward
scaled_loss = loss * loss_scale
scaled_loss.backward()  # gradients are 1024x larger, less underflow
# Unscale gradients before optimizer step
for param in model.parameters():
    if param.grad is not None:
        param.grad.data /= loss_scale
# Check for inf/nan (overflow from scaling too aggressively)
# If overflow detected: skip step, halve loss_scale
# If no overflow for 2000 steps: double loss_scale
BF16 largely eliminates the need for loss scaling because its dynamic range matches FP32 (roughly 10⁻³⁸ to 3.4 × 10³⁸), though with only 8 bits of significand precision (7 explicit mantissa bits). This is why BF16 is preferred for training on hardware that supports it (A100, H100, TPUs).
Precision Format Comparison for Training
| Format | Exponent Bits | Mantissa Bits | Dynamic Range | Precision at 1.0 | Needs Loss Scaling |
|---|---|---|---|---|---|
| FP32 | 8 | 23 | 1e-38 to 3.4e38 | 1.19e-7 | No |
| BF16 | 8 | 7 | 1e-38 to 3.4e38 | 7.8e-3 | Usually no |
| FP16 | 5 | 10 | 6.1e-5 to 65504 | 9.8e-4 | Yes |
| FP8 (E4M3) | 4 | 3 | 1.95e-3 to 448 | 6.25e-2 | Yes |
4. Gradient Checkpointing
4.1 The Activation Memory Problem
During training, the forward pass must save intermediate activations for the backward pass. For each transformer layer, the saved tensors include:
| Activation | Shape | Size (FP16, B=1, S=8192, d=8192) |
|---|---|---|
| Input (for RMSNorm backward) | [B, S, d] | 128 MB |
| RMSNorm output (for attention weight grads) | [B, S, d] | 128 MB |
| Q, K, V projections | [B, S, 10240] (8192 + 1024 + 1024) | 160 MB |
| Attention weights A (if no FlashAttn) | [B, n_h, S, S] | 8,192 MB |
| Attention output (for backward) | [B, S, d] | 128 MB |
| FFN RMSNorm input | [B, S, d] | 128 MB |
| FFN gate activations g_1, g_3 | 2 × [B, S, d_ff] | 896 MB |
| SiLU output | [B, S, d_ff] | 448 MB |
Without FlashAttention: 10,208 MB per layer; 80 layers ≈ 816 GB. With FlashAttention: 2,016 MB per layer; 80 layers = 161,280 MB ≈ 157.5 GB.
Even with FlashAttention, roughly 157.5 GB of activation memory for a single sequence of length 8192 is enormous. Add optimizer states (840 GB), model weights (140 GB FP16 + 280 GB FP32 master), and you need over 1,400 GB total. This does not fit in any single-node GPU setup.
4.2 How Gradient Checkpointing Works
The idea: do not save activations for every layer during the forward pass. Instead, save activations only at checkpoint boundaries (every k layers). During the backward pass, when you need activations for a non-checkpointed layer, recompute them by running a partial forward pass from the nearest checkpoint.
def checkpoint_forward(model, x, checkpoint_every=10):
    """Forward pass with gradient checkpointing."""
    checkpoints = {}
    h = x
    for i, layer in enumerate(model.layers):
        if i % checkpoint_every == 0:
            checkpoints[i] = h.detach().clone()  # Save checkpoint
        h = layer(h)
    return h, checkpoints

def checkpoint_backward(model, grad_output, checkpoints, checkpoint_every=10):
    """Backward pass, recomputing activations one segment at a time."""
    grad_h = grad_output
    n = len(model.layers)
    # Walk the segments [ckpt_idx, ckpt_idx + checkpoint_every) in reverse
    for ckpt_idx in reversed(range(0, n, checkpoint_every)):
        h = checkpoints[ckpt_idx].requires_grad_(True)
        # Recompute forward through this segment, rebuilding its graph
        out = h
        for j in range(ckpt_idx, min(ckpt_idx + checkpoint_every, n)):
            out = model.layers[j](out)
        # Backward through just this segment
        out.backward(grad_h)
        grad_h = h.grad
    return grad_h
In practice, PyTorch's torch.utils.checkpoint.checkpoint handles this automatically:
from torch.utils.checkpoint import checkpoint

class TransformerWithCheckpointing(nn.Module):
    def forward(self, x):
        for layer in self.layers:
            # checkpoint() saves only the input, recomputes forward during backward
            x = checkpoint(layer, x, use_reentrant=False)
        return x
4.3 Memory Math for 70B Model
Let us compute the savings for Llama 3 70B (L = 80 layers, d = 8192, S = 8192) with FlashAttention.
Without checkpointing:
Activation memory per layer: 2,016 MB (with FlashAttention). Total: 80 × 2,016 MB = 161,280 MB ≈ 157.5 GB.
With checkpointing every layer (save only the input to each layer):
Saved per layer: just x = 128 MB. Peak activation memory: 1 layer's full activations (recomputed) + all checkpoints = 2,016 MB + 80 × 128 MB = 12,256 MB ≈ 12.0 GB.
Savings: 157.5 GB down to 12.0 GB, a 92% reduction.
With checkpointing every 10 layers:
Saved: 8 checkpoints at 128 MB each = 1,024 MB. Peak: activations for 10 layers (recomputed segment) + checkpoints = 10 × 2,016 MB + 1,024 MB = 21,184 MB ≈ 20.7 GB.
Savings: 87% reduction with less recomputation overhead.
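The peak-memory model used in this section (one recomputed segment plus all saved checkpoints) can be written as a function of the checkpoint interval:

```python
def peak_activation_gb(every, layers=80, per_layer_mb=2016, ckpt_mb=128):
    """Peak activation memory (GB) under this section's model: activations
    for one recomputed segment of `every` layers, plus one 128 MB input
    checkpoint per segment."""
    n_ckpts = layers // every
    return (every * per_layer_mb + n_ckpts * ckpt_mb) / 1024

print(round(peak_activation_gb(every=1), 1))    # ~12.0 GB
print(round(peak_activation_gb(every=10), 1))   # ~20.7 GB
```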
[Figure: activation memory (GB) vs checkpoint frequency.]
4.4 The Compute Cost
Gradient checkpointing requires recomputing the forward pass for each segment during the backward pass. The cost depends on the checkpoint frequency:
Checkpoint every layer (most aggressive): Each layer's forward pass is computed twice (once during forward, once during backward). Total forward compute = 2x. Since the backward pass already costs about 2x the forward pass (it computes both data gradients and weight gradients), the total training step goes from 3x forward to 4x forward. Overhead: 33%.
Checkpoint every k layers: Each layer's forward is recomputed once, but the recomputation of layers within a segment overlaps with the backward pass of later segments. The effective overhead is approximately 1/k of the forward cost. For k = 10: overhead is about 10% of forward, or 3.3% of total training step. For k = 5: about 6.7% of total.
Gradient Checkpointing: Memory vs Compute Tradeoff (70B)
| Checkpoint Freq | Activation Memory | Memory Savings | Compute Overhead | Wall Clock Impact |
|---|---|---|---|---|
| None | 157.5 GB | 0% | 0% | Baseline |
| Every 20 layers | 42 GB | 73% | ~5% | +5% |
| Every 10 layers | 21 GB | 87% | ~10% | +3-4% |
| Every 5 layers | 14 GB | 91% | ~20% | +7% |
| Every 1 layer | 12 GB | 92% | ~33% | +11% |
4.5 Selective Checkpointing
Not all activations cost the same to store or recompute. A smarter strategy: checkpoint expensive-to-store but cheap-to-recompute activations:
- Always recompute: RMSNorm outputs (cheap: a few elementwise ops), dropout masks (free: just reseed)
- Always save: weight projection inputs (recomputing them means an extra GEMM)
- Conditionally save: FFN gate activations (large at d_ff = 28672, but recomputing is one GEMM)
FlashAttention already implements selective checkpointing internally: it saves only the output O and the per-row softmax statistics (logsumexp), recomputing A during backward. This eliminates the O(S²) memory while adding only one extra matmul.
5. Gradient Clipping
5.1 Why Clipping is Necessary
Even with Pre-Norm, residual connections, and mixed precision, gradient norms can spike during training. Common causes:
- Data outliers: A batch with unusually long sequences or rare token patterns can produce loss spikes
- Learning rate warmup: Early training with randomly initialized weights produces high-variance gradients
- Attention entropy collapse: When attention weights become very sharp (near one-hot), the softmax Jacobian produces large gradients
- Loss spikes: Rare token sequences to which the model assigns very low probability produce large −log p values and correspondingly large gradients
A single large gradient update can destabilize the model irreversibly. If the update moves weights far from the current basin, the model may never recover.
5.2 Max-Norm Clipping
The standard approach: clip the global gradient norm to a maximum value. The "global gradient norm" is the L2 norm of the vector formed by concatenating all parameter gradients:

‖g‖ = sqrt(Σ_p ‖g_p‖²)

If ‖g‖ > c, scale all gradients by c / ‖g‖:
def clip_grad_norm_(parameters, max_norm=1.0):
    """Clip gradient norm. Returns the original norm before clipping."""
    total_norm_sq = 0.0
    for p in parameters:
        if p.grad is not None:
            total_norm_sq += p.grad.data.float().pow(2).sum().item()
    total_norm = total_norm_sq ** 0.5
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for p in parameters:
            if p.grad is not None:
                p.grad.data.mul_(clip_coef)
    return total_norm
The max_norm = 1.0 is standard for LLM training. GPT-3 used 1.0. Llama used 1.0. Chinchilla used 1.0. The value means: if the gradient vector (across all 70 billion parameters) has L2 norm greater than 1.0, scale it down so the norm is exactly 1.0.
The value 1.0 is not derived from theory; it is an empirical standard. The intuition: with the Adam optimizer and a learning rate around 10⁻⁴, a gradient norm of 1.0 produces per-parameter updates that are comfortably within the range where the loss surface is approximately quadratic. Larger norms risk overshooting into regions where the linear gradient approximation breaks down. Some teams use 0.5 or 2.0, but 1.0 is the default for good reason: it works across model scales from 1B to 500B+.
5.3 Clipping Statistics in Practice
During healthy training of a 70B model:
- Early training (first 1000 steps): Gradient norm is 5-50, clipping activates on most steps
- After warmup: Gradient norm settles to 0.3-1.5, clipping activates 10-30% of steps
- During loss spikes: Gradient norm can briefly hit 100-1000, clipping reduces it by 100-1000x
- Late training: Gradient norm is 0.1-0.5, clipping rarely activates
Monitoring the gradient norm is one of the most important training diagnostics. A sustained increase in gradient norm (without a corresponding increase in loss) often indicates impending instability.
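One way to turn that diagnostic into an automatic alert; a hypothetical sketch (the EMA decay and the 3x spike threshold are illustrative choices, not values from the post):

```python
class GradNormMonitor:
    """Tracks an EMA of the gradient norm and flags spikes against it."""

    def __init__(self, decay=0.99, spike_factor=3.0):
        self.decay = decay
        self.spike_factor = spike_factor
        self.ema = None

    def update(self, grad_norm):
        """Returns True if this step's norm is a spike vs the running EMA."""
        if self.ema is None:
            self.ema = grad_norm
            return False
        is_spike = grad_norm > self.spike_factor * self.ema
        self.ema = self.decay * self.ema + (1 - self.decay) * grad_norm
        return is_spike

monitor = GradNormMonitor()
flags = [monitor.update(g) for g in [0.5, 0.6, 0.55, 50.0, 0.5]]
# only the 50.0 step is flagged: [False, False, False, True, False]
```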
6. Complete Backprop Pseudocode for One Transformer Layer
Putting it all together. This pseudocode covers the complete backward pass for one Pre-Norm transformer layer with GQA attention and SwiGLU FFN:
def transformer_layer_backward(
    grad_output,        # [B, S, d] -- gradient from layer above
    # === Saved from forward pass ===
    x,                  # [B, S, d] -- layer input
    x_norm_attn,        # [B, S, d] -- RMSNorm output before attention
    rms_attn,           # [B, S, 1] -- RMS value for attention norm
    gamma_attn,         # [d]       -- RMSNorm scale for attention
    Q, K, V,            # [B, n_h, S, d_k], [B, n_kv, S, d_k], [B, n_kv, S, d_k]
    attn_lse,           # [B, n_h, S] -- log-sum-exp from FlashAttention
    attn_out,           # [B, n_h, S, d_k] -- attention output (before W_O)
    h_mid,              # [B, S, d] -- x + attention output (input to FFN block)
    x_norm_ffn,         # [B, S, d] -- RMSNorm output before FFN
    rms_ffn,            # [B, S, 1] -- RMS value for FFN norm
    gamma_ffn,          # [d]       -- RMSNorm scale for FFN
    g1, g3, silu_g1,    # [B, S, d_ff] -- SwiGLU intermediates
    # === Weight matrices ===
    W_Q, W_K, W_V, W_O,  # Attention weights
    W_1, W_2, W_3,       # FFN weights
):
    """Returns: grad_x, and all weight gradients."""
    B, S, d = x.shape
    n_h, n_kv, d_k = Q.shape[1], K.shape[1], Q.shape[-1]
    d_ff = g1.shape[-1]

    # ============================================
    # PHASE 1: Backward through FFN residual
    #   h_out = h_mid + FFN(RMSNorm(h_mid))
    # ============================================
    grad_ffn_out = grad_output          # [B, S, d]
    grad_h_mid = grad_output.clone()    # Residual path

    # --- Backward through FFN ---
    # out = (SiLU(x_norm @ W1) * (x_norm @ W3)) @ W2
    h = silu_g1 * g3                    # Recompute h (or save it)
    grad_h = grad_ffn_out @ W_2.T       # [B, S, d_ff]
    grad_W2 = h.reshape(-1, d_ff).T @ grad_ffn_out.reshape(-1, d)       # [d_ff, d]
    grad_silu = grad_h * g3             # [B, S, d_ff]
    grad_g3 = grad_h * silu_g1          # [B, S, d_ff]
    sig_g1 = torch.sigmoid(g1)          # [B, S, d_ff]
    silu_deriv = sig_g1 * (1.0 + g1 * (1.0 - sig_g1))                   # [B, S, d_ff]
    grad_g1 = grad_silu * silu_deriv    # [B, S, d_ff]
    grad_W1 = x_norm_ffn.reshape(-1, d).T @ grad_g1.reshape(-1, d_ff)   # [d, d_ff]
    grad_W3 = x_norm_ffn.reshape(-1, d).T @ grad_g3.reshape(-1, d_ff)   # [d, d_ff]
    grad_x_norm_ffn = grad_g1 @ W_1.T + grad_g3 @ W_3.T                 # [B, S, d]

    # --- Backward through FFN RMSNorm ---
    dx_hat = grad_x_norm_ffn * gamma_ffn                                # [B, S, d]
    x_hat = h_mid / rms_ffn                                             # [B, S, d]
    proj = (dx_hat * x_hat).sum(dim=-1, keepdim=True) / d               # [B, S, 1]
    grad_h_mid_from_ffn = (dx_hat - x_hat * proj) / rms_ffn             # [B, S, d]
    grad_gamma_ffn = (grad_x_norm_ffn * x_hat).sum(dim=(0, 1))          # [d]
    grad_h_mid += grad_h_mid_from_ffn   # Add to residual path

    # ============================================
    # PHASE 2: Backward through attention residual
    #   h_mid = x + Attn(RMSNorm(x))
    # ============================================
    grad_attn_block = grad_h_mid        # [B, S, d]
    grad_x = grad_h_mid.clone()         # Residual path

    # --- Backward through W_O ---
    # attn_final = concat_heads(attn_out) @ W_O
    attn_out_flat = attn_out.transpose(1, 2).reshape(B, S, n_h * d_k)   # [B, S, n_h*d_k]
    grad_W_O = attn_out_flat.reshape(-1, n_h * d_k).T @ grad_attn_block.reshape(-1, d)
    grad_attn_out_flat = grad_attn_block @ W_O.T                        # [B, S, n_h*d_k]
    grad_attn_out = grad_attn_out_flat.reshape(B, S, n_h, d_k).transpose(1, 2)

    # --- FlashAttention backward ---
    # Recomputes the attention matrix block-by-block.
    # Inputs:  grad_attn_out, Q, K, V, attn_lse
    # Outputs: grad_Q, grad_K, grad_V
    grad_Q, grad_K, grad_V = flash_attn_backward(
        grad_attn_out, Q, K, V, attn_lse  # Uses O(S) memory, not O(S^2)
    )
    # grad_Q: [B, n_h, S, d_k]
    # grad_K: [B, n_kv, S, d_k] (after GQA reduction)
    # grad_V: [B, n_kv, S, d_k]

    # --- Backward through QKV projections ---
    grad_Q_flat = grad_Q.transpose(1, 2).reshape(B, S, n_h * d_k)
    grad_K_flat = grad_K.transpose(1, 2).reshape(B, S, n_kv * d_k)
    grad_V_flat = grad_V.transpose(1, 2).reshape(B, S, n_kv * d_k)
    grad_W_Q = x_norm_attn.reshape(-1, d).T @ grad_Q_flat.reshape(-1, n_h * d_k)
    grad_W_K = x_norm_attn.reshape(-1, d).T @ grad_K_flat.reshape(-1, n_kv * d_k)
    grad_W_V = x_norm_attn.reshape(-1, d).T @ grad_V_flat.reshape(-1, n_kv * d_k)
    grad_x_norm_attn = (grad_Q_flat @ W_Q.T
                        + grad_K_flat @ W_K.T
                        + grad_V_flat @ W_V.T)  # [B, S, d]

    # --- Backward through attention RMSNorm ---
    dx_hat_a = grad_x_norm_attn * gamma_attn
    x_hat_a = x / rms_attn
    proj_a = (dx_hat_a * x_hat_a).sum(dim=-1, keepdim=True) / d
    grad_x_from_attn = (dx_hat_a - x_hat_a * proj_a) / rms_attn
    grad_gamma_attn = (grad_x_norm_attn * x_hat_a).sum(dim=(0, 1))
    grad_x += grad_x_from_attn          # Add to residual path

    # ============================================
    # Return all gradients
    # ============================================
    return {
        'grad_x': grad_x,                    # [B, S, d] -- pass to layer below
        'grad_W_Q': grad_W_Q,                # [d, n_h * d_k]
        'grad_W_K': grad_W_K,                # [d, n_kv * d_k]
        'grad_W_V': grad_W_V,                # [d, n_kv * d_k]
        'grad_W_O': grad_W_O,                # [n_h * d_k, d]
        'grad_W_1': grad_W1,                 # [d, d_ff]
        'grad_W_2': grad_W2,                 # [d_ff, d]
        'grad_W_3': grad_W3,                 # [d, d_ff]
        'grad_gamma_attn': grad_gamma_attn,  # [d]
        'grad_gamma_ffn': grad_gamma_ffn,    # [d]
    }
6.1 FLOP Count for Backward Pass
The backward pass performs approximately 2x the FLOPs of the forward pass. This is because every matrix multiply Y = X W in the forward pass produces two matrix multiplies in the backward: dL/dX = dL/dY · Wᵀ (data gradient) and dL/dW = Xᵀ · dL/dY (weight gradient).
For one transformer layer of Llama 3 70B with B = 1, S = 8192:
FLOP Count Per Layer: Forward vs Backward (Llama 3 70B, B=1, S=8192)
| Operation | Forward GFLOPs | Backward GFLOPs | Ratio |
|---|---|---|---|
| Q projection (d to n_h*d_k) | 1,100 | 2,200 | 2x |
| K projection (d to n_kv*d_k) | 138 | 275 | 2x |
| V projection (d to n_kv*d_k) | 138 | 275 | 2x |
| QK^T (attention scores) | 1,100 | 2,200 (+1,100 FlashAttn recompute) | 2-3x |
| A*V (attention output) | 1,100 | 2,200 | 2x |
| O projection (n_h*d_k to d) | 1,100 | 2,200 | 2x |
| FFN W1 (d to d_ff) | 3,858 | 7,717 | 2x |
| FFN W3 (d to d_ff) | 3,858 | 7,717 | 2x |
| FFN W2 (d_ff to d) | 3,858 | 7,717 | 2x |
| Total per layer | 16,250 | 34,301 | 2.1x |
Total for 80 layers: Forward = 1,300 TFLOPs. Backward = 2,744 TFLOPs. Full training step = 4,044 TFLOPs.
On 8 H100 GPUs at 50% MFU (Model FLOP Utilization), achieved throughput is 8 × 989 × 0.5 ≈ 3,956 TFLOPS (taking 989 TFLOPS as the H100 dense BF16 peak). One training step with B = 1 takes approximately 4,044 / 3,956 ≈ 1.02 seconds.
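As a sketch, the step-time estimate (assuming the commonly quoted 989 TFLOPS dense BF16 peak per H100):

```python
def step_time_seconds(total_tflops=4044, gpus=8, peak_tflops=989, mfu=0.5):
    """Wall-clock estimate: total training-step FLOPs / achieved throughput."""
    achieved = gpus * peak_tflops * mfu   # ~3,956 TFLOPS
    return total_tflops / achieved

print(f"{step_time_seconds():.2f} s")     # ~1.02 s
```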
7. End-to-End: A Complete Training Step Gradient Flow
Combining all the pieces, here is the complete flow of a single training step:
# 1. Forward pass (FP16/BF16)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    logits = model(input_ids)        # Forward through all 80 layers
    loss = F.cross_entropy(          # Loss in FP32 (autocast promotes)
        logits.view(-1, vocab_size),
        labels.view(-1)
    )
# 2. Backward pass (gradients in BF16, accumulated in FP32)
loss.backward()
# This triggers:
#   - grad_logits = softmax(logits) - one_hot(labels)   [B, S, V]
#   - grad through output head (unembedding)
#   - For each layer 79 down to 0:
#       transformer_layer_backward(...)
#   - grad through embedding table
# 3. Gradient clipping (in FP32)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 4. Optimizer step (FP32 master weights)
optimizer.step()       # Adam updates FP32 master weights using FP32 gradients
optimizer.zero_grad()  # Clear gradients for next step
# 5. Update learning rate
scheduler.step()
The key numerical invariant: gradients flow from the scalar loss through 80 layers of chain rule applications, through billions of multiply-accumulate operations, and arrive at each of the 70 billion parameters as a small FP32 number that tells the optimizer which direction to adjust that parameter. The entire training process is this operation repeated millions of times.
The most common gradient-related bugs and their diagnostics: (1) All gradients are zero: check that requires_grad=True on parameters and that no .detach() call breaks the graph. (2) Gradients are NaN: check for division by zero in normalization (epsilon too small) or log of zero in loss. (3) Gradients are all the same value: check that weight sharing or broadcasting is not accidentally making parameters aliases. (4) Gradient norm explodes: check learning rate, loss scaling factor, and data preprocessing.
Reviewer Agent Validation Challenge
The following statements about this post's content are candidates for review. Some are true, some contain deliberate errors. A competent reviewer should verify each against the mathematical derivations and numerical calculations presented above.
- Claim: The RMSNorm Jacobian for a single vector is J = diag(γ)/rms(x) · (I − x̂ x̂ᵀ/d). Verify that this is dimensionally consistent and that the outer product correctly produces the rank-1 correction term.
- Claim: For GQA with 64 query heads and 8 KV heads, the gradient for K requires summing over groups of 8 query heads per KV head. Verify: is the group size 8 (64/8) or 4?
- Claim: The softmax VJP formula is grad_z = a ⊙ (grad_a − ⟨grad_a, a⟩). Derive this from ∂a_i/∂z_j = a_i(δ_ij − a_j) and confirm it is correct.
- Claim: Activation memory for SwiGLU intermediates across 80 layers is 106 GB (stated as 3 × B·S·d_ff × 2 bytes per layer, times 80). Recompute this value and check whether the 1.33 GB per layer figure is correct.
- Claim: Gradient checkpointing every layer reduces activation memory from 157.5 GB to 12.0 GB. Verify the arithmetic: 80 checkpoints at 128 MB each = 10,240 MB = 10.0 GB, plus 1 layer at 2,016 MB. Does this total 12.0 GB or 11.9 GB?
- Claim: The backward pass performs exactly 2x the FLOPs of the forward pass. Is this precisely true, or does FlashAttention's recomputation of Q Kᵀ push it above 2x? What is the actual ratio from the table?
- Claim: Post-Norm gradient attenuation for 80 layers with per-layer factor 0.97 gives roughly 0.09, an ~11x attenuation. Compute this value and verify.
- Claim: FP16's smallest representable difference from 1.0 is 9.8 × 10⁻⁴. The mantissa of FP16 is 10 bits. Is the ULP (unit in the last place) at 1.0 actually 2⁻¹⁰? Verify using the IEEE 754 half-precision format.