Part 18 of 23 in the Transformer Anatomy series

A single training iteration of a large language model is the most expensive computational operation routinely performed on modern hardware. For a 70B-parameter model, one step costs approximately 4,000 TFLOPs of compute per sequence (forward plus backward), consuming several seconds on a cluster of 256 H100 GPUs at full batch. A full training run consists of hundreds of thousands of these steps. Meta's Llama 3 70B was trained on roughly 15 trillion tokens, a total compute budget on the order of 6 × 10^24 FLOPs.

This post dissects one training iteration into its constituent operations. Every step is concrete: specific shapes, specific FLOPs, specific memory usage, specific latencies. The post culminates in a complete, production-grade training loop in approximately 50 lines of PyTorch.

We use Llama 3 70B as the reference model: d = 8192, L = 80 layers, n_h = 64 query heads, n_kv = 8 KV heads, d_ff = 28672 (SwiGLU), V = 128256, BF16 training.


1. The Complete Training Iteration

A single training step consists of 7 stages, executed in strict sequence:

1. Data loading       (CPU, overlapped with previous GPU step)
2. Forward pass       (GPU, BF16 autocast)
3. Loss computation   (GPU, FP32)
4. Backward pass      (GPU, BF16/FP32)
5. Gradient clipping  (GPU, FP32)
6. Optimizer step     (GPU, FP32)
7. LR schedule update (CPU, negligible)

Each stage has a specific cost profile:

Training Step Time Breakdown (70B, 256 H100s, seq_len=8192)

| Stage | Time (ms) | % of Step | Compute Bound? | Memory Peak? |
|---|---|---|---|---|
| Data loading | ~0 (overlapped) | 0% | No (CPU) | No |
| Forward pass | 580 | 33% | Yes (matmuls) | Growing |
| Loss computation | 15 | 1% | No | No |
| Backward pass | 1080 | 62% | Yes (2x forward) | Peak here |
| Gradient clipping | 8 | 0.5% | No (reduction) | No |
| Optimizer step | 55 | 3% | Memory bound | 2nd peak |
| LR update | 0.01 | 0% | CPU | No |
| **Total** | **1738** | **100%** | | |

Note: 256 H100 GPUs with tensor parallelism (TP=8), pipeline parallelism (PP=4), data parallelism (DP=8). Effective batch size 1024 sequences of 8192 tokens each.

2. Data Loading Pipeline

2.1 Tokenization

Raw text is pre-tokenized offline and stored as memory-mapped binary files. Each token is a 32-bit integer (to support vocabularies larger than 65,536). A 1 trillion token dataset occupies:

1 × 10^12 tokens × 4 bytes = 4 TB

The data loader reads from these files at training time. There is no on-the-fly tokenization — that would be a CPU bottleneck at scale.

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class TokenizedDataset(Dataset):
    """Memory-mapped tokenized dataset for LLM training."""

    def __init__(self, data_path, seq_len=8192):
        self.seq_len = seq_len
        # Memory-map the file: no copy into RAM
        self.tokens = np.memmap(data_path, dtype=np.uint32, mode='r')
        self.n_sequences = len(self.tokens) // seq_len

    def __len__(self):
        return self.n_sequences

    def __getitem__(self, idx):
        start = idx * self.seq_len
        end = start + self.seq_len
        tokens = torch.from_numpy(self.tokens[start:end].astype(np.int64))

        # Input: tokens[0:seq_len-1], Target: tokens[1:seq_len]
        input_ids = tokens[:-1]     # [seq_len - 1]
        labels = tokens[1:]         # [seq_len - 1]
        return input_ids, labels
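The memmap round trip can be exercised end to end on a small synthetic file (the path, token count, and seq_len below are illustrative):

```python
import numpy as np
import torch

# Write a small synthetic token file in the uint32 on-disk format
# described above (illustrative path and sizes)
tokens = np.arange(1000, dtype=np.uint32)
tokens.tofile("/tmp/tokens.bin")

# Re-open it the way TokenizedDataset does: memory-mapped, no RAM copy
mm = np.memmap("/tmp/tokens.bin", dtype=np.uint32, mode="r")
seq_len = 128
n_sequences = len(mm) // seq_len  # partial trailing sequence is dropped

# First training example: inputs and next-token targets
seq = torch.from_numpy(mm[:seq_len].astype(np.int64))
input_ids, labels = seq[:-1], seq[1:]
```

Each label is the token that follows the corresponding input position, which is exactly the shift the loss function expects.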

2.2 Sequence Packing

Naive padding wastes compute. If your dataset has sequences of varying length and you pad to the longest in the batch, short sequences waste FLOPs on padding tokens.

Packing concatenates multiple documents into a single sequence of exactly seq_len tokens, separated by end-of-document (EOD) tokens. An attention mask prevents cross-document attention:

def pack_sequences(documents, seq_len, eod_token):
    """
    Pack variable-length documents into fixed-length sequences.

    Args:
        documents: list of token lists (variable length)
        seq_len: target sequence length
        eod_token: end-of-document token ID

    Returns:
        packed_sequences: list of [seq_len] token lists
        document_masks: list of [seq_len, seq_len] attention masks
    """
    packed = []
    masks = []
    current_seq = []
    current_doc_starts = []

    for doc in documents:
        # Append the EOD token; truncate documents longer than seq_len
        tokens = (doc + [eod_token])[:seq_len]

        if len(current_seq) + len(tokens) > seq_len:
            # Current sequence is full (or nearly full): pad and emit
            pad_len = seq_len - len(current_seq)
            current_seq.extend([eod_token] * pad_len)

            # Build attention mask: block-diagonal
            masks.append(build_packing_mask(current_doc_starts, seq_len))
            packed.append(current_seq)

            current_seq = []
            current_doc_starts = []

        current_doc_starts.append(len(current_seq))
        current_seq.extend(tokens)

    # Flush the final partial sequence
    if current_seq:
        pad_len = seq_len - len(current_seq)
        current_seq.extend([eod_token] * pad_len)
        masks.append(build_packing_mask(current_doc_starts, seq_len))
        packed.append(current_seq)

    return packed, masks


def build_packing_mask(doc_starts, seq_len):
    """
    Build a causal attention mask that prevents cross-document attention.

    Each document can only attend to tokens within the same document,
    and only to previous positions (causal).
    """
    mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)

    for i, start in enumerate(doc_starts):
        end = doc_starts[i + 1] if i + 1 < len(doc_starts) else seq_len
        # Within-document causal mask
        for row in range(start, end):
            mask[row, start:row+1] = True

    return mask

2.3 Packing Efficiency

The efficiency of packing depends on the distribution of document lengths relative to seq_len:

Packing Efficiency vs Document Length Distribution

| Document Length (mean) | Packing Efficiency | Wasted Tokens (%) | Throughput Gain vs Padding |
|---|---|---|---|
| 256 tokens (short docs) | 96% | 4% | +45% vs naive padding |
| 1024 tokens | 98% | 2% | +20% |
| 4096 tokens | 99% | 1% | +8% |
| 8192+ tokens (full seq) | 100% | 0% | +0% (no padding anyway) |

Note: seq_len = 8192. Packing gain is largest when document lengths are much shorter than seq_len.
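These numbers can be approximated with a short simulation of greedy first-fit packing (the helper name and the uniform length distribution are illustrative, not from the article's dataset):

```python
import random

def packing_efficiency(doc_lens, seq_len):
    """Fraction of non-pad tokens under greedy first-fit packing
    (one EOD token appended per document, long docs truncated)."""
    used = capacity = current = 0
    for n in doc_lens:
        n = min(n + 1, seq_len)        # +1 for the EOD token
        if current + n > seq_len:      # flush the full sequence, pad the rest
            capacity += seq_len
            current = 0
        current += n
        used += n
    if current:
        capacity += seq_len            # final partial sequence
    return used / capacity

random.seed(0)
short_docs = [random.randint(64, 448) for _ in range(10_000)]  # mean ~256 tokens
eff = packing_efficiency(short_docs, seq_len=8192)  # high: many short docs tile one sequence
```

The only waste is the tail of each 8192-token sequence where the next document no longer fits, which is why efficiency stays well above 95% even for short documents.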

2.4 Data Loading Overlap

The data loader runs on CPU in a separate thread (or process), preparing the next batch while the GPU processes the current batch. With num_workers=4 and prefetch_factor=2, the data loader maintains a buffer of 8 ready batches. At 1.7 seconds per training step, the CPU has ample time to load, memmap, and collate the next batch.

train_loader = DataLoader(
    dataset,
    batch_size=micro_batch_size,   # Per-GPU batch size (e.g., 1-4)
    shuffle=True,
    num_workers=4,
    pin_memory=True,               # Pre-copy to GPU-pinned host memory
    prefetch_factor=2,
    persistent_workers=True,        # Keep workers alive between epochs
    drop_last=True,                 # Avoid uneven batch sizes
)

3. Forward Pass in Mixed Precision

3.1 BF16 Autocast

The forward pass runs under torch.autocast, which automatically selects the precision for each operation:

| Operation | Precision | Why |
|---|---|---|
| Linear layers (matmuls) | BF16 | Tensor cores require FP16/BF16 input |
| Embedding lookup | BF16 | No compute, just memory access |
| RMSNorm | FP32 | Variance computation sensitive to precision |
| Softmax | FP32 | Overflow risk with large attention logits |
| SiLU activation | BF16 | Elementwise, precision not critical |
| Residual addition | BF16 | Elementwise |
| Dropout | BF16 | Binary mask, precision irrelevant |

The critical rule: normalization and softmax in FP32, everything else in BF16. This is because normalization computes a variance (sum of squares divided by d), and in BF16 with d = 8192, the accumulated sum can lose significant precision. Softmax computes exponentials, which can overflow in BF16 for large inputs.
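The variance claim is easy to demonstrate: accumulating a sum of 8192 squares in a running BF16 accumulator drifts far from the true value, while FP32 stays tight. This is a synthetic worst case (real kernels differ in accumulation order), but it shows why the rule exists:

```python
import torch

torch.manual_seed(0)
d = 8192
x = torch.randn(d)

exact = (x.double() ** 2).sum().item()   # FP64 reference
fp32 = (x ** 2).sum().item()             # what the FP32-norm rule gives you

# Worst case the rule avoids: a sequential BF16 accumulator. Once the
# running sum grows large, BF16's 7-bit mantissa rounds most small
# additions away entirely.
acc = torch.tensor(0.0, dtype=torch.bfloat16)
for xi in x.to(torch.bfloat16):
    acc = acc + xi * xi
bf16 = acc.item()

rel_fp32 = abs(fp32 - exact) / exact     # tiny
rel_bf16 = abs(bf16 - exact) / exact     # orders of magnitude worse
```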

def forward_pass_annotated(model, input_ids):
    """Annotated forward pass showing precision at each step.

    Illustrative sketch: head reshaping and RoPE are omitted, and
    flash_attention stands in for the fused attention kernel.
    """
    B, S = input_ids.shape

    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        # Embedding: BF16 lookup
        h = model.embed_tokens(input_ids)          # [B, S, 8192] BF16
        # Memory: B * S * 8192 * 2 bytes

        for layer in model.layers:
            # RMSNorm: promoted to FP32 internally
            h_norm = layer.input_layernorm(h)       # [B, S, 8192] -> FP32 -> BF16

            # QKV projection: BF16 matmul on tensor cores
            q = h_norm @ layer.self_attn.q_proj.weight.T   # [B, S, 8192] BF16
            k = h_norm @ layer.self_attn.k_proj.weight.T   # [B, S, 1024] BF16
            v = h_norm @ layer.self_attn.v_proj.weight.T   # [B, S, 1024] BF16

            # FlashAttention: BF16 matmuls, FP32 softmax internally
            attn_output = flash_attention(q, k, v)  # [B, S, 8192] BF16

            # Output projection: BF16 matmul
            attn_output = attn_output @ layer.self_attn.o_proj.weight.T

            # Residual: BF16 addition
            h = h + attn_output                     # [B, S, 8192] BF16

            # FFN block (same pattern: norm in FP32, matmuls in BF16)
            h_norm = layer.post_attention_layernorm(h)
            gate = torch.nn.functional.silu(h_norm @ layer.mlp.gate_proj.weight.T)
            up = h_norm @ layer.mlp.up_proj.weight.T
            ffn_out = (gate * up) @ layer.mlp.down_proj.weight.T
            h = h + ffn_out

        # Final norm
        h = model.norm(h)                           # [B, S, 8192] BF16

        # Output projection (lm_head): BF16 matmul
        logits = h @ model.lm_head.weight.T         # [B, S, 128256] BF16

    return logits

3.2 Forward Pass Memory Accumulation

During training (not inference), the forward pass must save intermediate tensors for the backward pass. Memory grows as each layer deposits its activations:

GPU Memory During Forward Pass (70B, B=1, S=8192, FlashAttention)

| Point in forward pass | Memory (GB) | vs. weights alone |
|---|---|---|
| Model weights (BF16), constant | 140 | — |
| After embedding | 140.1 | +0.1% |
| After layer 20 | 180 | +28.6% |
| After layer 40 | 220 | +57.1% |
| After layer 60 | 260 | +85.7% |
| After layer 80 (activations: 160 GB) | 300 | +114.3% |

With gradient checkpointing (every layer), the forward pass stores only one activation per layer (128 MB each), reducing the activation footprint to 80 × 128 MB = 10 GB instead of 160 GB.
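In PyTorch, per-layer checkpointing is a thin wrapper around each block's forward. A minimal sketch with a toy block (the module and sizes are illustrative, not the 70B model):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    """Toy FFN block standing in for a transformer layer."""
    def __init__(self, d):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.ff = nn.Sequential(nn.Linear(d, 4 * d), nn.SiLU(), nn.Linear(4 * d, d))

    def forward(self, x):
        # checkpoint() saves only x; the norm/FFN activations inside
        # are recomputed during backward instead of being stored
        return x + checkpoint(lambda h: self.ff(self.norm(h)), x,
                              use_reentrant=False)

d = 64
model = nn.Sequential(*[Block(d) for _ in range(4)])
x = torch.randn(2, 8, d, requires_grad=True)
model(x).pow(2).mean().backward()  # each block's forward is re-run here
```

The trade is the one described above: roughly one extra forward pass of compute per step in exchange for storing only the block inputs.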

3.3 Forward Pass FLOP Count

For Llama 3 70B, one forward pass through all 80 layers with B = 1, S = 8192:

Each matrix multiply costs 2 × M × N × K FLOPs. Per layer:

Q projection:   2 * (B*S) * d * d         = 2 * 8192 * 8192 * 8192     = 1,100 GFLOPs
K projection:   2 * (B*S) * d * (n_kv*dk) = 2 * 8192 * 8192 * 1024     =   138 GFLOPs
V projection:   2 * (B*S) * d * (n_kv*dk)                              =   138 GFLOPs
QK^T:           2 * B * n_h * S * S * dk  = 2 * 64 * 8192 * 8192 * 128 = 1,100 GFLOPs
A*V:            2 * B * n_h * S * dk * S                               = 1,100 GFLOPs
O projection:   2 * (B*S) * d * d                                      = 1,100 GFLOPs
FFN gate_proj:  2 * (B*S) * d * d_ff      = 2 * 8192 * 8192 * 28672    = 3,848 GFLOPs
FFN up_proj:                                                             3,848 GFLOPs
FFN down_proj:                                                           3,848 GFLOPs
---
Per layer total: ~16,220 GFLOPs
80 layers:       ~1,300 TFLOPs
+ output head:   2 * 8192 * 8192 * 128256 = 17.2 TFLOPs
---
Total forward:   ~1,315 TFLOPs
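The same tally as a reusable function (the function name and argument layout are my own, matching the variables above):

```python
def forward_flops(B, S, d, n_layers, n_h, n_kv, d_head, d_ff, V):
    """Approximate forward-pass FLOPs (matmuls only), per the tally above."""
    tok = B * S
    attn = (
        2 * tok * d * d                        # Q projection
        + 2 * (2 * tok * d * n_kv * d_head)    # K and V projections
        + 2 * (2 * B * n_h * S * S * d_head)   # QK^T and A*V
        + 2 * tok * d * d                      # output projection
    )
    ffn = 3 * (2 * tok * d * d_ff)             # gate, up, down projections
    return n_layers * (attn + ffn) + 2 * tok * d * V  # + output head

llama70b = forward_flops(B=1, S=8192, d=8192, n_layers=80, n_h=64,
                         n_kv=8, d_head=128, d_ff=28672, V=128256)
# ~1.31e15 FLOPs, matching the ~1,300 TFLOPs total above
```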

4. Loss Computation

4.1 Cross-Entropy in FP32

The loss is always computed in FP32, even when the forward pass runs in BF16. The logits ([B, S, V] with V = 128256) are upcast to FP32 before softmax:

def compute_loss(logits_bf16, labels):
    """
    Compute cross-entropy loss in FP32.

    logits_bf16: [B, S, V] in BF16
    labels: [B, S] integer token IDs
    """
    # Upcast to FP32 for numerical stability
    logits = logits_bf16.float()  # [B, S, 128256] FP32

    # Reshape for cross_entropy: [B*S, V] and [B*S]
    logits_flat = logits.view(-1, logits.size(-1))  # [B*S, 128256]
    labels_flat = labels.view(-1)                     # [B*S]

    # Cross-entropy = -log(softmax(logits)[correct_class])
    # PyTorch fuses this into a numerically stable operation:
    #   1. Subtract max (log-sum-exp trick)
    #   2. Compute log-softmax
    #   3. Gather the correct class
    loss = torch.nn.functional.cross_entropy(
        logits_flat, labels_flat,
        ignore_index=-100,   # Padding tokens
        reduction='mean'
    )

    return loss  # Scalar, FP32

4.2 Why FP32 Matters for Loss

The cross-entropy loss for a single token is −log p(y_true) where p = softmax(logits). The softmax involves:

  1. max(logits) — finding the maximum across 128,256 classes
  2. exp(logit_i − max) — exponentiation
  3. Σ exp(·) — summation over 128,256 terms

In BF16, the summation over 128,256 terms loses precision due to the limited 7-bit mantissa. The relative error in the sum can reach 1-2%, which translates to a bias in the gradient of the same magnitude. Over millions of steps, this bias accumulates into a measurably worse model.

Loss Memory Cost

The FP32 upcast of logits adds B × S × V × 4 bytes. For B = 1, S = 8192, V = 128256: that is 8192 × 128256 × 4 bytes = 4.2 GB. This is a significant memory spike. Some implementations compute loss in chunks (processing 1024 vocabulary elements at a time) to avoid materializing the full FP32 logit tensor.
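A chunked variant is straightforward to sketch. This version chunks over token positions rather than the vocabulary dimension, which is simpler to express in plain PyTorch (production kernels typically chunk the vocabulary instead):

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(logits_bf16, labels, chunk_size=1024):
    """Cross-entropy that upcasts one chunk of rows to FP32 at a time,
    so the full FP32 [B*S, V] tensor is never materialized."""
    logits_flat = logits_bf16.view(-1, logits_bf16.size(-1))
    labels_flat = labels.view(-1)
    total = logits_flat.new_zeros((), dtype=torch.float32)
    for i in range(0, logits_flat.size(0), chunk_size):
        chunk = logits_flat[i:i + chunk_size].float()  # FP32 for this chunk only
        total = total + F.cross_entropy(chunk, labels_flat[i:i + chunk_size],
                                        reduction="sum")
    return total / labels_flat.numel()

# Matches the monolithic FP32 loss on random data
logits = torch.randn(2, 64, 512, dtype=torch.bfloat16)
labels = torch.randint(0, 512, (2, 64))
loss = chunked_cross_entropy(logits, labels, chunk_size=16)
```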

4.3 Loss Values and Perplexity

For reference, here are typical loss values during 70B training:

Training Loss Milestones (70B Model)

| Training Stage | Tokens Seen | Loss | Perplexity | Notes |
|---|---|---|---|---|
| Random init | 0 | 11.76 | 128,256 | = ln(V), random guessing |
| After 100 steps | ~840M | 7.5 | 1,808 | Learning basic frequencies |
| After 1K steps | ~8.4B | 4.2 | 66.7 | Learning common patterns |
| After 10K steps | ~84B | 3.0 | 20.1 | Grammatical text |
| After 100K steps | ~840B | 2.3 | 10.0 | Coherent paragraphs |
| After 150K steps | ~1.26T | 1.9 | 6.7 | Factual knowledge emerging |
| End of training | ~1.4T | 1.6 | 5.0 | Near convergence |

Note: Loss is natural-log cross-entropy; perplexity = exp(loss). Tokens per step = 1024 × 8192 ≈ 8.4M. Values are approximate and depend on data mix.

5. Backward Pass

5.1 Overview

The backward pass computes ∂L/∂θ for every parameter θ in the model via reverse-mode automatic differentiation (backpropagation). The computational cost is approximately 2x the forward pass (each forward matmul generates two backward matmuls: data gradient and weight gradient).

The backward pass flows in reverse through the model:

  1. Gradient of loss w.r.t. logits: ∂L/∂logits = softmax(logits) − one_hot(labels). Shape: [B, S, V]
  2. Gradient through the output head (lm_head): [B, S, V] × [V, d] → [B, S, d]
  3. Gradient through each transformer layer (80 to 0): as detailed in Part 16 of this series
  4. Gradient through the embedding table
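Step 1's closed-form gradient can be verified against autograd on a toy vocabulary:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
V = 7
logits = torch.randn(3, V, requires_grad=True)
labels = torch.tensor([2, 0, 5])

# reduction="sum" so each row's gradient is exactly softmax - one_hot
F.cross_entropy(logits, labels, reduction="sum").backward()

analytic = torch.softmax(logits.detach(), dim=-1) - F.one_hot(labels, V).float()
```

With `reduction='mean'` the same expression holds, scaled by 1/(number of tokens).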

5.2 Backward Pass Memory Timeline

The backward pass is where memory peaks. For each layer, the backward computes gradients and then frees the saved activations for that layer. The peak occurs when:

  • All model weights are in memory (140 GB in BF16)
  • Optimizer states are in memory (840 GB in FP32 — but distributed across GPUs)
  • Gradients for the current layer are being computed (transient)
  • Saved activations for remaining layers (not yet freed)
Memory timeline during backward (single GPU shard, simplified):

Time -->
                     |<-- Backward layer 80 -->|<-- Layer 79 -->| ... |<-- Layer 0 -->|
                     |                          |                |     |               |
Activations saved:   | layers 0-80             | layers 0-79    |     | layer 0       |
Gradients computed:  | layer 80 grads          | layer 79 grads |     | layer 0 grads |
                     |                          |                |     |               |
Memory:              | PEAK                    | decreasing     |     | minimum       |

The peak is at the start of backward (layer 80), where all forward activations are still in memory plus the first backward computation.

With gradient checkpointing, the picture changes: activations are recomputed on-the-fly, so the “saved activations” component is much smaller (only the checkpoint tensors).

5.3 The GradScaler (FP16 Only)

When training in FP16 (not BF16), a GradScaler is needed to prevent gradient underflow:

scaler = torch.cuda.amp.GradScaler(
    init_scale=2**16,          # Initial loss scale: 65536
    growth_factor=2.0,         # Double scale every growth_interval steps
    backoff_factor=0.5,        # Halve scale on overflow
    growth_interval=2000,      # Steps between scale increases
)

# In the training loop:
with torch.cuda.amp.autocast(dtype=torch.float16):
    logits = model(input_ids)
    loss = F.cross_entropy(logits.view(-1, V), labels.view(-1))

scaler.scale(loss).backward()     # loss * scale_factor, then backward
scaler.unscale_(optimizer)        # Divide gradients by scale_factor
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)            # Skip step if inf/nan detected
scaler.update()                   # Adjust scale for next step

With BF16, the GradScaler is not needed — BF16’s dynamic range matches FP32. This simplifies the training loop and eliminates the overhead of scale management.

ℹ️ BF16 vs FP16 for Training

BF16 is preferred for training because: (1) No loss scaling needed — same dynamic range as FP32. (2) No GradScaler overhead. (3) Same tensor core throughput as FP16 on A100/H100. The only downside: BF16 has less mantissa precision (7 bits vs 10 bits), which occasionally matters for very precise operations. All modern LLM training uses BF16 unless the hardware does not support it.


6. Gradient Accumulation

6.1 Simulating Larger Batches

Large batch sizes improve training efficiency (better gradient estimates, higher hardware utilization) but require more memory. Gradient accumulation simulates a large batch by splitting it into micro-batches and accumulating gradients:

Effective batch size = micro_batch_size × accumulation_steps × n_data_parallel_gpus

For example, to achieve an effective batch of 1024 sequences with 256 GPUs and micro_batch_size = 1:

accumulation_steps = 1024 / (1 × 256) = 4

Each GPU processes 4 micro-batches, accumulating gradients, before performing a single optimizer step.

def train_step_with_accumulation(model, optimizer, batches, accumulation_steps):
    """
    Train with gradient accumulation.

    batches: iterator yielding (input_ids, labels) micro-batches
    """
    optimizer.zero_grad()
    total_loss = 0.0

    for step in range(accumulation_steps):
        input_ids, labels = next(batches)

        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            logits = model(input_ids)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1)
            ) / accumulation_steps  # Scale loss by 1/N to average gradients

        loss.backward()  # Gradients accumulate (not overwritten)
        total_loss += loss.item() * accumulation_steps

    # After all micro-batches: clip and step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    return total_loss / accumulation_steps

6.2 The Division Matters

The loss must be divided by accumulation_steps before backward(). This is because:

  • Without division: gradients are the sum across micro-batches, which is N× larger than the mean gradient. This is equivalent to using an N× larger learning rate
  • With division: gradients are the mean across micro-batches, matching what you would get from a single large batch

This is a common bug that causes divergence when increasing accumulation steps without adjusting the loss scaling.
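The equivalence is easy to verify on a toy model: with the 1/N division, accumulated gradients match the single large-batch gradients exactly (for equal-sized micro-batches):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
x, y = torch.randn(8, 4), torch.randint(0, 3, (8,))
N = 4  # accumulation steps, micro-batch size 2

# Reference: one large-batch backward
model.zero_grad()
F.cross_entropy(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: N micro-batches, loss divided by N before backward
model.zero_grad()
for i in range(0, 8, 2):
    (F.cross_entropy(model(x[i:i + 2]), y[i:i + 2]) / N).backward()
accum_grad = model.weight.grad.clone()
```

Dropping the `/ N` makes `accum_grad` four times larger here, which is exactly the hidden learning-rate multiplication described above.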

6.3 Memory Impact

Gradient accumulation does not save memory on activations (each micro-batch still needs full activation memory for its forward-backward). But it avoids needing memory for a large batch’s activations simultaneously:

Memory: Large Batch vs Gradient Accumulation

| Approach | Activations | Gradients | Total | Notes |
|---|---|---|---|---|
| B=4, no accumulation | 640 GB | 140 GB | 780 GB | All 4 sequences simultaneously |
| B=1, 4x accumulation | 160 GB | 140 GB | 300 GB | One sequence at a time |
| B=1, 4x accum + checkpoint | 12 GB | 140 GB | 152 GB | Minimum memory |

Note: Activations assume 70B model with S=8192, FlashAttention. Gradients are per-parameter (same regardless of batch).

7. The Optimizer Step

7.1 AdamW

Virtually all LLM training uses AdamW (Adam with decoupled weight decay). The update rule for parameter θ\theta at step tt:

m_t = β₁ · m_{t−1} + (1 − β₁) · g_t
v_t = β₂ · v_{t−1} + (1 − β₂) · g_t²
m̂_t = m_t / (1 − β₁ᵗ)
v̂_t = v_t / (1 − β₂ᵗ)
θ_t = θ_{t−1} − η · ( m̂_t / (√v̂_t + ε) + λ · θ_{t−1} )

Standard hyperparameters for LLM training:

| Hyperparameter | Symbol | Typical Value | Notes |
|---|---|---|---|
| Learning rate | η | 3e-4 | Peak LR, after warmup |
| Beta 1 | β₁ | 0.9 | Momentum decay |
| Beta 2 | β₂ | 0.95 | Variance decay |
| Epsilon | ε | 1e-8 | Numerical stability |
| Weight decay | λ | 0.1 | Applied to all weights, not biases/norms |

7.2 Optimizer Memory

AdamW stores two state tensors per parameter (first moment mm and second moment vv), both in FP32:

Model weights (BF16):           70B * 2 bytes = 140 GB
Model weights (FP32 master):    70B * 4 bytes = 280 GB
Adam first moment m (FP32):     70B * 4 bytes = 280 GB
Adam second moment v (FP32):    70B * 4 bytes = 280 GB
Gradients (BF16):               70B * 2 bytes = 140 GB
------------------------------------------------------
Total optimizer + weights:                      1,120 GB

This is why ZeRO-3 / FSDP is essential: the FP32 master weights and Adam moments alone (840 GB) exceed the memory of even 8 H100s (640 GB total).

With ZeRO-3 sharding across 256 GPUs:

Per-GPU Adam moments m, v:  560 GB / 256 = 2.19 GB
Per-GPU master weights:     280 GB / 256 = 1.09 GB
Per-GPU gradients:          140 GB / 256 = 0.55 GB
Per-GPU BF16 weights:       140 GB / 256 = 0.55 GB
----------------------------------------------
Per-GPU total (states + weights): 4.38 GB

This leaves approximately 75 GB per H100 for activations and communication buffers.

7.3 The Optimizer Step in Detail

def adamw_step(param, grad, m, v, step, lr, beta1, beta2, eps, weight_decay):
    """
    One AdamW update step for a single parameter tensor.
    All inputs are FP32 (master weights and optimizer states).
    """
    # Decoupled weight decay: shrink theta_{t-1} directly, outside Adam
    # (equivalent to subtracting lr * weight_decay * theta_{t-1})
    param.mul_(1 - lr * weight_decay)

    # Update biased first moment estimate: m = beta1*m + (1-beta1)*g
    m.mul_(beta1).add_(grad, alpha=1 - beta1)

    # Update biased second moment estimate: v = beta2*v + (1-beta2)*g^2
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias corrections
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # Compute step size
    step_size = lr / bias_correction1

    # Compute denominator
    denom = (v.sqrt() / (bias_correction2 ** 0.5)).add_(eps)

    # Adam update
    param.addcdiv_(m, denom, value=-step_size)
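As a sanity check, the same update can be written out by hand and compared against `torch.optim.AdamW` for a single step (a self-contained replica; PyTorch also applies the decoupled decay to θ_{t−1} before the Adam update):

```python
import torch

torch.manual_seed(0)
p_ref = torch.randn(10, requires_grad=True)
p_man = p_ref.detach().clone()
grad = torch.randn(10)
lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.95, 1e-8, 0.1

# One step with PyTorch's optimizer
opt = torch.optim.AdamW([p_ref], lr=lr, betas=(beta1, beta2),
                        eps=eps, weight_decay=wd)
p_ref.grad = grad.clone()
opt.step()

# The same step by hand: decay theta_{t-1}, then the Adam update (t = 1)
m, v, t = torch.zeros(10), torch.zeros(10), 1
p_man.mul_(1 - lr * wd)
m.mul_(beta1).add_(grad, alpha=1 - beta1)
v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (v.sqrt() / (1 - beta2 ** t) ** 0.5).add_(eps)
p_man.addcdiv_(m, denom, value=-(lr / (1 - beta1 ** t)))
```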

7.4 Why Weight Decay is Decoupled

In the original Adam optimizer, weight decay was implemented as L2 regularization: adding λθ\lambda \theta to the gradient. This is wrong for adaptive optimizers because Adam divides the gradient by v\sqrt{v}, which means the effective weight decay is λ/v\lambda / \sqrt{v} — different for each parameter, depending on the gradient variance.

AdamW fixes this by applying weight decay directly to the parameter, outside of the Adam update. The decay λθt1\lambda \theta_{t-1} is not divided by v\sqrt{v}, so it is uniform across all parameters (proportional to each parameter’s magnitude).

The practical impact: with L2 regularization in Adam, parameters with large gradients (high vv) experience less effective weight decay. This means frequently updated parameters (which tend to be the most important) get less regularization — the opposite of what you want. AdamW corrects this.

Memory Breakdown: 70B Model Training on 256 H100s

| Component | GB per GPU |
|---|---|
| Adam moments (sharded) | 2.19 |
| Master weights (sharded) | 1.09 |
| BF16 weights (sharded, gathered per layer) | 0.55 |
| Gradients (sharded) | 0.55 |
| Activations (w/ checkpointing) | 12 |
| Comm buffers + fragmentation | 8 |

8. Learning Rate Schedule

8.1 Warmup + Cosine Decay

The standard LR schedule for LLM training:

  1. Linear warmup: LR increases linearly from 0 to peak over warmup_steps (typically 2000)
  2. Cosine decay: LR decreases following a cosine curve to min_lr over the remaining steps
  3. Minimum LR: Typically 10% of peak LR (3 × 10⁻⁵ for a peak of 3 × 10⁻⁴)

import math

def cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr):
    """
    Warmup + cosine decay learning rate schedule.

    step: current training step (0-indexed)
    warmup_steps: number of warmup steps
    total_steps: total training steps
    peak_lr: maximum learning rate (reached at end of warmup)
    min_lr: minimum learning rate (reached at end of training)
    """
    if step < warmup_steps:
        # Linear warmup: 0 -> peak_lr
        return peak_lr * step / warmup_steps

    # Cosine decay: peak_lr -> min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + (peak_lr - min_lr) * cosine_factor
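A few boundary checks pin down the schedule's behavior (the formula is restated inline so the check is self-contained; the helper name and the numbers match the article's 170K-step run):

```python
import math

def lr_at(step, warmup=2000, total=170_000, peak=3e-4, floor=3e-5):
    # Same formula as cosine_schedule_with_warmup above
    if step < warmup:
        return peak * step / warmup
    progress = (step - warmup) / (total - warmup)
    return floor + (peak - floor) * 0.5 * (1.0 + math.cos(math.pi * progress))

start, mid_warmup = lr_at(0), lr_at(1000)        # 0.0 and 1.5e-4
peak_val, end_val = lr_at(2000), lr_at(170_000)  # 3e-4 and 3e-5
```

Step 0 gives zero LR (no update), step `warmup` lands exactly on the peak, and the final step lands exactly on the minimum.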

8.2 Why Warmup is Necessary

At initialization, the model’s activations and gradients are noisy and poorly conditioned. Large learning rates in this regime cause destructive updates: the model overshoots in parameter space, activations explode, and training diverges.

Warmup allows the model to:

  1. Build up Adam's second moment estimate v (which takes ∼1/(1 − β₂) steps to stabilize)
  2. Move weights to a region of parameter space with smoother loss landscape
  3. Calibrate the RMSNorm statistics and attention patterns to the data distribution

The typical warmup of 2000 steps means the model sees 2000 × B_eff × S tokens before reaching peak LR. For B_eff = 1024 and S = 8192, that is 16.8 billion tokens of warmup.

💡 Warmup Duration

The 2000-step warmup is not universal. Shorter warmups (500-1000 steps) sometimes work for smaller models. Longer warmups (5000-10000 steps) are sometimes used for very large models or when training instability is observed. The key diagnostic: if training loss spikes during the transition from warmup to peak LR, the warmup was too short.

8.3 LR Values Through Training

Learning Rate at Key Training Milestones (70B)

| Step | Tokens Seen | LR | Phase | Notes |
|---|---|---|---|---|
| 0 | 0 | 0 | Warmup start | Zero LR, no updates |
| 1000 | 8.4B | 1.5e-4 | Mid-warmup | Half of peak |
| 2000 | 16.8B | 3.0e-4 | Peak | Maximum learning rate |
| 50000 | 420B | 2.5e-4 | Early decay | ~83% of peak |
| 100000 | 840B | 1.3e-4 | Mid decay | ~43% of peak |
| 150000 | 1.26T | 3.9e-5 | Late decay | ~13% of peak |
| 170000 | 1.43T | 3.0e-5 | End | Minimum LR |

Note: Assumes 170K total steps, 2000 warmup, peak LR 3e-4, min LR 3e-5, effective batch 1024, seq_len 8192. LR values follow the warmup + cosine formula of Section 8.1.

9. Memory Timeline: Peak During Backward

9.1 Full Memory Timeline

Here is the complete memory timeline for one training step of a 70B model on a single GPU (with ZeRO-3 sharding across 256 GPUs):

Phase          | Weights | Optim States | Activations | Gradients | Total
               | (GB)    | (GB)         | (GB)        | (GB)      | (GB)
---------------+---------+--------------+-------------+-----------+------
Idle           | 0.55    | 3.28         | 0           | 0         | 3.83
Data loading   | 0.55    | 3.28         | 0           | 0         | 3.83
Forward (L=0)  | 0.55    | 3.28         | 0.13        | 0         | 3.96
Forward (L=40) | 0.55    | 3.28         | 5.13        | 0         | 8.96
Forward (L=80) | 0.55    | 3.28         | 10.13       | 0         | 13.96
Loss compute   | 0.55    | 3.28         | 10.13+4.2   | 0         | 18.16
Backward (L=80)| 0.55    | 3.28         | 10.13       | 0.55      | 14.51
Backward (L=40)| 0.55    | 3.28         | 5.13        | 0.55      | 9.51
Backward (L=0) | 0.55    | 3.28         | 0.13        | 0.55      | 4.51
Reduce-scatter | 0.55    | 3.28         | 0           | 0.55      | 4.38
Optimizer step | 0.55    | 3.28         | 0           | 0.55      | 4.38
After step     | 0.55    | 3.28         | 0           | 0         | 3.83

Key observations:

  • Peak memory occurs during loss computation (the 4.2 GB FP32 logit tensor on top of 10.1 GB of activations, 18.2 GB total) or at the start of backward
  • Activations are freed progressively during backward — layer 80’s activations are freed after layer 80’s backward is complete
  • Gradients accumulate during backward but are sharded across GPUs via reduce-scatter
  • After optimizer step, gradients are zeroed and memory returns to baseline

9.2 Communication Overlap

In production training with ZeRO-3/FSDP, communication overlaps with computation:

  • During forward: All-gather weights for the next layer while computing the current layer
  • During backward: Reduce-scatter gradients for the current layer while computing the previous layer’s backward
  • During optimizer step: Each GPU updates its shard of the optimizer states locally (no communication)

The communication cost per layer:

All-gather weights (forward):    layer_params × 2 bytes × (n−1)/n
Reduce-scatter grads (backward): layer_params × 2 bytes × (n−1)/n

For one layer with 856M parameters in BF16, n = 256 GPUs:

856 × 10⁶ × 2 × 255/256 ≈ 1.71 GB per collective

On 400 Gbps InfiniBand (50 GB/s per GPU): 1.71 / 50 ≈ 34 ms per layer for each all-gather or reduce-scatter. With overlap, this must hide behind the ~13.5 ms of backward compute per layer, which the network bandwidth often cannot sustain, making communication the bottleneck for large-scale training.


10. The Complete Training Loop

Here is the production-grade training loop in approximately 50 lines of PyTorch:

import torch
import torch.nn.functional as F
import math
from torch.nn.utils import clip_grad_norm_

def train(
    model,
    optimizer,
    train_loader,
    total_steps,
    warmup_steps=2000,
    peak_lr=3e-4,
    min_lr=3e-5,
    max_grad_norm=1.0,
    accumulation_steps=1,
    log_interval=10,
    dtype=torch.bfloat16,
):
    """Complete LLM training loop."""
    model.train()
    step = 0
    accum_loss = 0.0

    data_iter = iter(train_loader)

    while step < total_steps:
        optimizer.zero_grad(set_to_none=True)  # Frees grads instead of writing zeros

        # --- Gradient accumulation loop ---
        for micro_step in range(accumulation_steps):
            try:
                input_ids, labels = next(data_iter)
            except StopIteration:
                data_iter = iter(train_loader)
                input_ids, labels = next(data_iter)

            input_ids = input_ids.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

            # Forward pass in mixed precision
            with torch.autocast(device_type="cuda", dtype=dtype):
                logits = model(input_ids)
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    labels.view(-1),
                    ignore_index=-100,
                ) / accumulation_steps

            # Backward pass (gradients accumulate)
            loss.backward()
            accum_loss += loss.item()  # Sums to the per-step mean over micro-batches

        # --- Gradient clipping ---
        grad_norm = clip_grad_norm_(model.parameters(), max_grad_norm)

        # --- Learning rate schedule (set before the optimizer consumes it) ---
        step += 1
        lr = cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Optimizer step ---
        optimizer.step()

        # --- Logging ---
        if step % log_interval == 0:
            avg_loss = accum_loss / log_interval
            ppl = math.exp(min(avg_loss, 20))  # Cap to avoid overflow
            tokens_seen = step * accumulation_steps * train_loader.batch_size * 8192  # 8192 = sequence length S
            print(f"step={step:>6d} | loss={avg_loss:.4f} | ppl={ppl:.1f} | "
                  f"lr={lr:.2e} | grad_norm={grad_norm:.2f} | "
                  f"tokens={tokens_seen/1e9:.1f}B")
            accum_loss = 0.0


def cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr):
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
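A quick sanity check of the schedule's endpoints, using the loop's default hyperparameters (the function is restated so the snippet runs standalone):

```python
import math

def cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr):
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

# Linear warmup from 0 to peak over 2000 steps, cosine decay to min_lr.
peak, floor, warmup, total = 3e-4, 3e-5, 2000, 170_000
assert cosine_schedule_with_warmup(0, warmup, total, peak, floor) == 0.0
assert abs(cosine_schedule_with_warmup(warmup, warmup, total, peak, floor) - peak) < 1e-12
assert abs(cosine_schedule_with_warmup(total, warmup, total, peak, floor) - floor) < 1e-12
```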

10.1 What This Loop Does Not Include

The 50-line loop above covers the core computation. A production training script also needs:

  1. Distributed setup: torch.distributed.init_process_group(), FSDP/ZeRO wrapping, tensor parallelism
  2. Checkpointing: Save model + optimizer + scheduler state every N steps
  3. Evaluation: Periodic evaluation on held-out data to track validation loss
  4. Gradient accumulation sync: For FSDP, no_sync() context manager during micro-batches to avoid premature gradient reduce-scatter
  5. Exception handling: NaN detection, automatic checkpoint recovery
  6. Throughput logging: Tokens per second, MFU (model FLOP utilization)
  7. Data sampling: Multi-source data mixing with configurable ratios
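Item 7 is the least standard of these. A minimal sketch of weighted multi-source sampling; the source names and the 90/10 ratio are made up for illustration:

```python
import itertools
import random

def mixed_stream(sources, weights, seed=0):
    """Yield batches, picking each batch's source by configurable ratio."""
    rng = random.Random(seed)
    names = list(sources)
    while True:
        (name,) = rng.choices(names, weights=weights)
        yield next(sources[name])

# Hypothetical 90/10 mix of web text and code. Each source is an iterator
# of batches; itertools.cycle stands in for a real shuffled loader.
sources = {"web": itertools.cycle(["web_batch"]),
           "code": itertools.cycle(["code_batch"])}
stream = mixed_stream(sources, weights=[0.9, 0.1])
sample = [next(stream) for _ in range(1000)]
```

Production data pipelines also handle epoch boundaries, shuffling, and deterministic resumption per source, none of which this sketch attempts.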

10.2 Training Throughput

The end-to-end throughput for 70B training on common hardware configurations:

Training Throughput: 70B Model on Various Configurations

Hardware               | GPUs | Tokens/sec | MFU  | Time for 1.4T tokens
8x H100 (single node)  | 8    | ~4,800     | ~32% | ~3,370 days
64x H100 (8 nodes)     | 64   | ~35,000    | ~38% | ~463 days
256x H100 (32 nodes)   | 256  | ~130,000   | ~44% | ~125 days
512x H100 (64 nodes)   | 512  | ~240,000   | ~48% | ~67 days
2048x H100 (256 nodes) | 2048 | ~850,000   | ~50% | ~19 days

Note: MFU = Model FLOP Utilization. Actual throughput depends on network topology, parallelism strategy, and batch size. Times assume continuous training without interruption.

At 2048 H100 GPUs, the cluster consumes approximately 1.5 MW of power. At $2/GPU-hour, the compute cost for 19 days of training is approximately 2 × 2048 × 24 × 19 ≈ $1.9M.
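The time column and the cost estimate follow from the same two-line calculation. A sketch for the 2048-GPU row, with $2/GPU-hour as the assumed rental rate:

```python
# Days to train 1.4T tokens, and rental cost, for the 2048-GPU configuration.
tokens_total = 1.4e12
tokens_per_sec = 850_000
gpus, usd_per_gpu_hour = 2048, 2.0

days = tokens_total / tokens_per_sec / 86_400     # wall-clock days
cost_usd = usd_per_gpu_hour * gpus * 24 * days    # total GPU rental cost
print(f"{days:.0f} days, ${cost_usd / 1e6:.2f}M")
```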


11. Putting It All Together: Annotated Single Step

To close, here is one complete training step with every operation annotated with its cost:

# === STEP t ===
# Memory at start: 3.83 GB (weights + optimizer, sharded)

# 1. DATA LOADING (overlapped with previous step, effectively free)
input_ids, labels = next(data_iter)             # [B, S] int64, from pinned memory
input_ids = input_ids.cuda(non_blocking=True)   # Async H2D transfer: 0.06 ms
labels = labels.cuda(non_blocking=True)

# 2. FORWARD PASS (580 ms, 1,317 TFLOPs)
# Memory: 3.83 -> 18.16 GB (peak with logits)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(input_ids)                   # 80 layers, each:
    #   RMSNorm:      0.1 ms, ~0.13 GFLOPs
    #   QKV project:  1.2 ms, 1,376 GFLOPs
    #   FlashAttn:    2.5 ms, 2,200 GFLOPs
    #   O project:    0.8 ms, 1,100 GFLOPs
    #   RMSNorm:      0.1 ms, ~0.13 GFLOPs
    #   SwiGLU FFN:   2.5 ms, 11,574 GFLOPs
    #   Total/layer:  7.2 ms, 16,250 GFLOPs
    # Output head:    2.1 ms, 17,200 GFLOPs

# 3. LOSS COMPUTATION (15 ms, negligible FLOPs)
    loss = F.cross_entropy(
        logits.float().view(-1, 128256),        # Upcast BF16->FP32: 4.2 GB
        labels.view(-1)
    )                                            # Scalar FP32

# 4. BACKWARD PASS (1,080 ms, 2,744 TFLOPs)
# Memory: 18.16 -> 14.51 -> decreasing -> 4.51 GB
loss.backward()
    # Gradient through output head:       4 ms
    # Per layer (80x):                    13.5 ms each
    #   Back through FFN:                 7 ms
    #   Back through attention:           5 ms
    #   Back through RMSNorm (2x):        0.5 ms
    #   Back through residual:            0.01 ms (just addition)
    #   Reduce-scatter gradients:         overlapped

# 5. GRADIENT CLIPPING (8 ms)
grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
    # All-reduce to compute global norm:  5 ms
    # Scale gradients if needed:          3 ms

# 6. OPTIMIZER STEP (55 ms)
optimizer.step()
    # Per parameter: 5 FP32 operations (m, v, bias_correct, update, decay)
    # Memory-bound: reads/writes 4 tensors per param (param, m, v, grad)
    # Total memory traffic: 70B * 4 * 4 bytes = 1.12 TB

# 7. LR UPDATE (0.01 ms)
lr = cosine_schedule_with_warmup(step, 2000, 170000, 3e-4, 3e-5)
for pg in optimizer.param_groups:
    pg['lr'] = lr

optimizer.zero_grad(set_to_none=True)           # Free gradient memory
# Memory back to 3.83 GB

# === Total step time: ~1,738 ms ===
# === Tokens processed: B * S = 1024 * 8192 = 8,388,608 ===
# === Throughput: 8.39M / 1.738 = 4.83M tokens/sec (across 256 GPUs) ===
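The five FP32 operations per parameter in step 6 can be spelled out. A scalar sketch of one AdamW update; the beta/epsilon/weight-decay values are typical LLM settings, not taken from the post:

```python
import math

def adamw_step(p, g, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95,
               eps=1e-8, wd=0.1):
    """One AdamW update for a single FP32 parameter: the five
    operations annotated above (m, v, bias_correct, update, decay)."""
    m = beta1 * m + (1 - beta1) * g                 # 1. first-moment EMA
    v = beta2 * v + (1 - beta2) * g * g             # 2. second-moment EMA
    m_hat = m / (1 - beta1 ** t)                    # 3. bias correction
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)   # 4. parameter update
    p = p - lr * wd * p                             # 5. decoupled weight decay
    return p, m, v

p, m, v = 0.5, 0.0, 0.0
p, m, v = adamw_step(p, g=0.1, m=m, v=v, t=1)
```

Every line reads and writes a full tensor in the real optimizer, which is why the step is memory-bound rather than compute-bound.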
Step Time Budget

The backward pass dominates at 62% of step time. The forward pass is 33%. The optimizer step is 3%. Everything else (data loading, loss, clipping, LR update) is collectively less than 2%. Any optimization effort should focus on making the backward pass faster — which is why FlashAttention (reducing backward memory and enabling larger batches), gradient checkpointing (enabling longer sequences), and communication overlap (hiding all-reduce latency) are the highest-impact techniques.
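The percentages above come straight from the annotated timings. A sketch reproducing the budget (the remaining ~0.1 ms of data loading and LR update is omitted as negligible):

```python
# Step-time budget from the annotated step (milliseconds per operation).
parts = {"forward": 580, "loss": 15, "backward": 1080,
         "clip": 8, "optimizer": 55}
total_ms = sum(parts.values())
share = {k: 100 * v / total_ms for k, v in parts.items()}
print(total_ms, {k: f"{s:.0f}%" for k, s in share.items()})
```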


Reviewer Agent Validation Challenge

The following statements about this post’s content are candidates for review. Some are true, some contain deliberate errors.

  1. Claim: The forward pass of Llama 3 70B costs approximately 1,317 TFLOPs for B = 1, S = 8192. Verify: 80 layers at 16,250 GFLOPs each = 1,300 TFLOPs, plus 17.2 TFLOPs for the output head = 1,317.2 TFLOPs. Is this calculation consistent?

  2. Claim: The backward pass costs exactly 2x the forward pass in FLOPs. Check: the table in Section 5 shows 2,744 TFLOPs backward vs 1,317 TFLOPs forward, which is a ratio of 2.08x. Is this exactly 2x or slightly more? Why?

  3. Claim: AdamW stores 3 tensors per parameter: the parameter itself, the first moment m, and the second moment v. But the memory calculation shows 4 tensors (BF16 weights, FP32 master weights, m, v). Verify: how many distinct copies of each parameter exist, and what is the total per-parameter memory?

  4. Claim: Loss at random initialization is ln(V) = ln(128256) ≈ 11.76. Verify: is the loss log base e (natural log) and is the value correct? Cross-entropy of a uniform distribution over V classes with a one-hot target is -log(1/V) = log(V).

  5. Claim: Warmup of 2000 steps with effective batch 1024 and S = 8192 means 2000 × 1024 × 8192 = 16.8 × 10^9 tokens of warmup. Verify the arithmetic.

  6. Claim: Per-GPU memory for optimizer states with ZeRO-3 across 256 GPUs is 840 / 256 = 3.28 GB. Check: is 840 GB the correct total for Adam states on a 70B model?

  7. Claim: The output head matmul costs 2 × 8192 × 8192 × 128256 = 17.2 TFLOPs. Verify: the output head shape is [B×S, d] × [d, V], so the cost is 2 × B × S × d × V. With B = 1, S = 8192, d = 8192, V = 128256: 2 × 1 × 8192 × 8192 × 128256. Compute this value.

  8. Claim: Communication cost per layer for all-gather is 2 × 856 × 10^6 × 2 × 255/256 = 3.42 GB. Verify: the factor of 2 at the start is for bytes per BF16 element. Should the formula use bytes or elements? Is 3.42 GB correct?
  8. Claim: Communication cost per layer for all-gather is 2×856×106×2×255/256=3.422 \times 856 \times 10^6 \times 2 \times 255/256 = 3.42 GB. Verify: the factor of 2 at the start is for bytes per BF16 element. Should the formula use bytes or elements? Is 3.42 GB correct?