A single training iteration of a large language model is among the most expensive computational operations routinely performed on modern hardware. For a 70B parameter model, one step costs approximately 4,000 TFLOPs per sequence — consuming several seconds on a cluster of 256 H100 GPUs. A full training run consists of hundreds of thousands of these steps. Meta’s Llama 3 70B was trained on roughly 15 trillion tokens; this post uses a 1.4-trillion-token run as its reference, about 6 × 10^23 FLOPs in total.
This post dissects one training iteration into its constituent operations. Every step is concrete: specific shapes, specific FLOPs, specific memory usage, specific latencies. The post culminates in a complete training loop in approximately 50 lines of PyTorch.
We use Llama 3 70B as the reference model: d_model = 8192, 80 layers, 64 query heads, 8 KV heads (head dim 128), d_ff = 28,672 (SwiGLU), vocab size V = 128,256, seq_len = 8192, BF16 training.
1. The Complete Training Iteration
A single training step consists of 7 stages, executed in strict sequence:
1. Data loading (CPU, overlapped with previous GPU step)
2. Forward pass (GPU, BF16 autocast)
3. Loss computation (GPU, FP32)
4. Backward pass (GPU, BF16/FP32)
5. Gradient clipping (GPU, FP32)
6. Optimizer step (GPU, FP32)
7. LR schedule update (CPU, negligible)
Each stage has a specific cost profile:
Training Step Time Breakdown (70B, 256 H100s, seq_len=8192)
| Stage | Time (ms) | % of Step | Compute Bound? | Memory Peak? |
|---|---|---|---|---|
| Data loading | ~0 (overlapped) | 0% | No (CPU) | No |
| Forward pass | 580 | 33% | Yes (matmuls) | Growing |
| Loss computation | 15 | 1% | No | No |
| Backward pass | 1080 | 62% | Yes (2x forward) | Peak here |
| Gradient clipping | 8 | 0.5% | No (reduction) | No |
| Optimizer step | 55 | 3% | Memory bound | 2nd peak |
| LR update | 0.01 | 0% | CPU | No |
| Total | 1738 | 100% | | |
2. Data Loading Pipeline
2.1 Tokenization
Raw text is pre-tokenized offline and stored as memory-mapped binary files. Each token is a 32-bit integer (to support vocabularies larger than 65,536). A 1 trillion token dataset therefore occupies 10^12 tokens × 4 bytes = 4 TB.
The data loader reads from these files at training time. There is no on-the-fly tokenization — that would be a CPU bottleneck at scale.
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
class TokenizedDataset(Dataset):
"""Memory-mapped tokenized dataset for LLM training."""
def __init__(self, data_path, seq_len=8192):
self.seq_len = seq_len
# Memory-map the file: no copy into RAM
self.tokens = np.memmap(data_path, dtype=np.uint32, mode='r')
self.n_sequences = len(self.tokens) // seq_len
def __len__(self):
return self.n_sequences
def __getitem__(self, idx):
start = idx * self.seq_len
end = start + self.seq_len
tokens = torch.from_numpy(self.tokens[start:end].astype(np.int64))
# Input: tokens[0:seq_len-1], Target: tokens[1:seq_len]
input_ids = tokens[:-1] # [seq_len - 1]
labels = tokens[1:] # [seq_len - 1]
return input_ids, labels
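To sanity-check the shift-by-one logic, the memmap path can be exercised end to end on a tiny synthetic file (the file name and sizes here are illustrative, not from the original post):

```python
import os
import tempfile

import numpy as np
import torch

# Write a tiny uint32 token file, then read one sequence back through
# the same memmap + shift-by-one logic used by TokenizedDataset.
tokens = np.arange(32, dtype=np.uint32)
path = os.path.join(tempfile.gettempdir(), "tokens_demo.bin")
tokens.tofile(path)

seq_len = 8
mm = np.memmap(path, dtype=np.uint32, mode="r")
n_sequences = len(mm) // seq_len                # 32 // 8 = 4 sequences

idx = 1                                         # second sequence: tokens 8..15
chunk = torch.from_numpy(mm[idx * seq_len:(idx + 1) * seq_len].astype(np.int64))
input_ids, labels = chunk[:-1], chunk[1:]       # next-token prediction shift
```

Each label is the token that follows the corresponding input position, so position i of `labels` equals position i+1 of the raw chunk.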
2.2 Sequence Packing
Naive padding wastes compute. If your dataset has sequences of varying length and you pad to the longest in the batch, short sequences waste FLOPs on padding tokens.
Packing concatenates multiple documents into a single sequence of exactly seq_len tokens, separated by end-of-document (EOD) tokens. An attention mask prevents cross-document attention:
def pack_sequences(documents, seq_len, eod_token):
"""
Pack variable-length documents into fixed-length sequences.
Args:
documents: list of token lists (variable length)
seq_len: target sequence length
eod_token: end-of-document token ID
Returns:
packed_sequences: list of [seq_len] token arrays
document_masks: list of [seq_len, seq_len] attention masks
"""
packed = []
masks = []
current_seq = []
current_doc_starts = []
    for doc in documents:
        tokens = (doc + [eod_token])[:seq_len]  # clamp docs longer than seq_len
        if len(current_seq) + len(tokens) > seq_len:
            # Current sequence is full (or nearly full): pad and flush it
            pad_len = seq_len - len(current_seq)
            current_seq.extend([eod_token] * pad_len)
            masks.append(build_packing_mask(current_doc_starts, seq_len))
            packed.append(current_seq)
            current_seq = []
            current_doc_starts = []
        current_doc_starts.append(len(current_seq))
        current_seq.extend(tokens)
    if current_seq:
        # Flush the final, partially filled sequence
        current_seq.extend([eod_token] * (seq_len - len(current_seq)))
        masks.append(build_packing_mask(current_doc_starts, seq_len))
        packed.append(current_seq)
    return packed, masks
def build_packing_mask(doc_starts, seq_len):
"""
Build a causal attention mask that prevents cross-document attention.
Each document can only attend to tokens within the same document,
and only to previous positions (causal).
"""
mask = torch.zeros(seq_len, seq_len, dtype=torch.bool)
for i, start in enumerate(doc_starts):
end = doc_starts[i + 1] if i + 1 < len(doc_starts) else seq_len
# Within-document causal mask
for row in range(start, end):
mask[row, start:row+1] = True
return mask
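The double loop in `build_packing_mask` is fine for illustration; a vectorized equivalent (a sketch, built from per-position document IDs) produces the same block-diagonal causal mask:

```python
import torch

def build_packing_mask_vectorized(doc_starts, seq_len):
    """Vectorized equivalent of the loop-based mask: position (row, col)
    is attendable iff it is causal AND both positions share a document."""
    is_start = torch.zeros(seq_len, dtype=torch.long)
    is_start[torch.tensor(doc_starts)] = 1
    doc_id = is_start.cumsum(0) - 1          # document index per position
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    same_doc = doc_id.unsqueeze(1) == doc_id.unsqueeze(0)
    return causal & same_doc

# Two documents: positions 0-2 and 3-5
mask = build_packing_mask_vectorized([0, 3], seq_len=6)
```

With `doc_starts=[0, 3]`, position 3 cannot attend to position 2 (different documents), while position 5 can attend to position 3 (same document, earlier position).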
2.3 Packing Efficiency
The efficiency of packing depends on the distribution of document lengths relative to seq_len:
Packing Efficiency vs Document Length Distribution
| Document Length (mean) | Packing Efficiency | Wasted Tokens (%) | Throughput Gain vs Padding |
|---|---|---|---|
| 256 tokens (short docs) | 96% | 4% | +45% vs naive padding |
| 1024 tokens | 98% | 2% | +20% |
| 4096 tokens | 99% | 1% | +8% |
| 8192+ tokens (full seq) | 100% | 0% | +0% (no padding anyway) |
2.4 Data Loading Overlap
The data loader runs in separate CPU worker processes, preparing upcoming batches while the GPU processes the current batch. With num_workers=4 and prefetch_factor=2, the data loader keeps up to 8 ready batches in flight. At 1.7 seconds per training step, the CPU has ample time to load, memmap, and collate the next batch.
train_loader = DataLoader(
dataset,
batch_size=micro_batch_size, # Per-GPU batch size (e.g., 1-4)
shuffle=True,
num_workers=4,
pin_memory=True, # Pre-copy to GPU-pinned host memory
prefetch_factor=2,
persistent_workers=True, # Keep workers alive between epochs
drop_last=True, # Avoid uneven batch sizes
)
3. Forward Pass in Mixed Precision
3.1 BF16 Autocast
The forward pass runs under torch.autocast, which automatically selects the precision for each operation:
| Operation | Precision | Why |
|---|---|---|
| Linear layers (matmuls) | BF16 | Tensor cores require FP16/BF16 input |
| Embedding lookup | BF16 | No compute, just memory access |
| RMSNorm | FP32 | Variance computation sensitive to precision |
| Softmax | FP32 | Overflow risk with large attention logits |
| SiLU activation | BF16 | Elementwise, precision not critical |
| Residual addition | BF16 | Elementwise |
| Dropout | BF16 | Binary mask, precision irrelevant |
The critical rule: normalization and softmax in FP32, everything else in BF16. This is because normalization computes a variance (a sum of squares divided by d_model), and in BF16 with d_model = 8192, the accumulated sum over 8192 terms can lose significant precision. Softmax computes exponentials, which can overflow for large attention logits without the max-subtraction trick.
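The precision argument for the variance sum can be demonstrated directly. This sketch accumulates 8192 squared terms sequentially in each dtype (naive sequential summation; library reductions use smarter orderings that shrink, but do not eliminate, the gap):

```python
import torch

# Sequentially accumulate 8192 squared terms — the core of a variance
# computation — in BF16 vs FP32. float64 serves as the reference.
torch.manual_seed(0)
x = torch.randn(8192, dtype=torch.float64)
exact = float((x * x).sum())

def naive_sum_of_squares(values):
    acc = values[0] * 0                      # zero in the input dtype
    for v in values:
        acc = acc + v * v                    # every add rounds to the dtype
    return float(acc)

err_fp32 = abs(naive_sum_of_squares(x.float()) - exact) / exact
err_bf16 = abs(naive_sum_of_squares(x.bfloat16()) - exact) / exact
```

Once the BF16 running sum passes a few thousand, its ulp exceeds the typical increment (~1.0), so most additions round away entirely; the FP32 sum stays accurate to several decimal digits.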
def forward_pass_annotated(model, input_ids):
    """Annotated forward pass showing precision at each step.
    Note: no torch.no_grad() here; training must save activations for backward."""
B, S = input_ids.shape
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
# Embedding: BF16 lookup
h = model.embed_tokens(input_ids) # [B, S, 8192] BF16
# Memory: B * S * 8192 * 2 bytes
for layer in model.layers:
# RMSNorm: promoted to FP32 internally
h_norm = layer.input_layernorm(h) # [B, S, 8192] -> FP32 -> BF16
# QKV projection: BF16 matmul on tensor cores
q = h_norm @ layer.self_attn.q_proj.weight.T # [B, S, 8192] BF16
k = h_norm @ layer.self_attn.k_proj.weight.T # [B, S, 1024] BF16
v = h_norm @ layer.self_attn.v_proj.weight.T # [B, S, 1024] BF16
# FlashAttention: BF16 matmuls, FP32 softmax internally
attn_output = flash_attention(q, k, v) # [B, S, 8192] BF16
# Output projection: BF16 matmul
attn_output = attn_output @ layer.self_attn.o_proj.weight.T
# Residual: BF16 addition
h = h + attn_output # [B, S, 8192] BF16
# FFN block (same pattern: norm in FP32, matmuls in BF16)
h_norm = layer.post_attention_layernorm(h)
gate = torch.nn.functional.silu(h_norm @ layer.mlp.gate_proj.weight.T)
up = h_norm @ layer.mlp.up_proj.weight.T
ffn_out = (gate * up) @ layer.mlp.down_proj.weight.T
h = h + ffn_out
# Final norm
h = model.norm(h) # [B, S, 8192] BF16
# Output projection (lm_head): BF16 matmul
logits = h @ model.lm_head.weight.T # [B, S, 128256] BF16
return logits
3.2 Forward Pass Memory Accumulation
During training (not inference), the forward pass must save intermediate tensors for the backward pass. Memory grows as each layer deposits its activations:
GPU Memory During Forward Pass (70B, B=1, S=8192, FlashAttn): activation memory grows roughly linearly, layer by layer (GB).
With gradient checkpointing (every layer), the forward pass stores only one activation per layer (128 MB each), reducing the activation footprint to about 10 GB instead of 160 GB.
3.3 Forward Pass FLOP Count
For Llama 3 70B, one forward pass through all 80 layers with B = 1, S = 8192:
Each matrix multiply of an [M, K] by a [K, N] matrix costs 2·M·K·N FLOPs. Per layer:
Q projection: 2 * (B*S) * d * d = 2 * 8192 * 8192 * 8192 = 1,100 GFLOPs
K projection: 2 * (B*S) * d * (n_kv*dk) = 2 * 8192 * 8192 * 1024 = 138 GFLOPs
V projection: 2 * (B*S) * d * (n_kv*dk) = 138 GFLOPs
QK^T: 2 * B * n_h * S * S * dk = 2 * 1 * 64 * 8192 * 8192 * 128 = 1,100 GFLOPs
A*V: 2 * B * n_h * S * dk * S = 1,100 GFLOPs
O projection: 2 * (B*S) * d * d = 1,100 GFLOPs
FFN gate_proj: 2 * (B*S) * d * d_ff = 2 * 8192 * 8192 * 28672 = 3,848 GFLOPs
FFN up_proj: 3,848 GFLOPs
FFN down_proj: 3,848 GFLOPs
---
Per layer total: ~16,220 GFLOPs
80 layers: ~1,300 TFLOPs
+ output head: 2 * 8192 * 8192 * 128256 = 17.2 TFLOPs
---
Total forward: ~1,317 TFLOPs
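The arithmetic above can be packaged as a small calculator (same 2·M·K·N rule, same Llama 3 70B defaults):

```python
def forward_flops(B=1, S=8192, d=8192, n_h=64, n_kv=8, d_head=128,
                  d_ff=28672, V=128256, n_layers=80):
    """Forward-pass FLOPs, counting 2*M*K*N per matmul, mirroring
    the per-layer breakdown above."""
    T = B * S
    q_proj = 2 * T * d * d
    kv_proj = 2 * (2 * T * d * (n_kv * d_head))      # K and V projections
    scores = 2 * B * n_h * S * S * d_head            # QK^T
    values = 2 * B * n_h * S * d_head * S            # A @ V
    o_proj = 2 * T * d * d
    ffn = 3 * (2 * T * d * d_ff)                     # gate, up, down
    per_layer = q_proj + kv_proj + scores + values + o_proj + ffn
    total = n_layers * per_layer + 2 * T * d * V     # + output head
    return per_layer, total

per_layer, total = forward_flops()   # ~1.62e13 and ~1.31e15
```

Changing B, S, or the model shape immediately shows how the FFN terms dominate at this scale while the attention score matmuls grow quadratically in S.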
4. Loss Computation
4.1 Cross-Entropy in FP32
The loss is always computed in FP32, even when the forward pass runs in BF16. The logits ([B, S, V], with V = 128,256) are upcast to FP32 before the softmax:
def compute_loss(logits_bf16, labels):
"""
Compute cross-entropy loss in FP32.
logits_bf16: [B, S, V] in BF16
labels: [B, S] integer token IDs
"""
# Upcast to FP32 for numerical stability
logits = logits_bf16.float() # [B, S, 128256] FP32
# Reshape for cross_entropy: [B*S, V] and [B*S]
logits_flat = logits.view(-1, logits.size(-1)) # [B*S, 128256]
labels_flat = labels.view(-1) # [B*S]
# Cross-entropy = -log(softmax(logits)[correct_class])
# PyTorch fuses this into a numerically stable operation:
# 1. Subtract max (log-sum-exp trick)
# 2. Compute log-softmax
# 3. Gather the correct class
loss = torch.nn.functional.cross_entropy(
logits_flat, labels_flat,
ignore_index=-100, # Padding tokens
reduction='mean'
)
return loss # Scalar, FP32
4.2 Why FP32 Matters for Loss
The cross-entropy loss for a single token is -log p_y, where p_y = softmax(z)_y over the logits z. The softmax involves:
- max_i z_i — finding the maximum across 128,256 classes
- exp(z_i - max) — exponentiation
- sum_i exp(z_i - max) — summation over 128,256 terms
In BF16, the summation over 128,256 terms loses precision due to the limited 7-bit mantissa. The relative error in the sum can reach 1-2%, which translates to a bias in the gradient of the same magnitude. Over millions of steps, this bias accumulates into a measurably worse model.
The FP32 upcast of logits adds B × S × V × 4 bytes. For B = 1, S = 8192, V = 128,256: that is 4.2 GB. This is a significant memory spike. Some implementations compute the loss in chunks (processing 1024 vocabulary elements at a time) to avoid materializing the full FP32 logit tensor.
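A chunked loss can be sketched as follows: compute the log-sum-exp one vocabulary slice at a time, so only a [B·S, chunk] FP32 slice is live at once (an illustrative sketch; it omits ignore_index handling, and `chunk` is an assumed parameter):

```python
import torch
import torch.nn.functional as F

def chunked_cross_entropy(logits_bf16, labels, chunk=1024):
    """Cross-entropy without materializing the full FP32 logit tensor:
    upcast and reduce one vocabulary slice at a time."""
    flat = logits_bf16.view(-1, logits_bf16.size(-1))
    labels = labels.view(-1)
    # log-sum-exp over the vocab, assembled from per-chunk partials
    partials = [torch.logsumexp(flat[:, i:i + chunk].float(), dim=-1)
                for i in range(0, flat.size(-1), chunk)]
    lse = torch.logsumexp(torch.stack(partials, dim=-1), dim=-1)
    correct = flat.gather(-1, labels.unsqueeze(-1)).squeeze(-1).float()
    return (lse - correct).mean()    # loss = logsumexp(z) - z_y, averaged
```

Because log-sum-exp composes (logsumexp of per-chunk logsumexps equals the global one), the chunked result matches the monolithic FP32 computation up to rounding.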
4.3 Loss Values and Perplexity
For reference, here are typical loss values during 70B training:
Training Loss Milestones (70B Model)
| Training Stage | Tokens Seen | Loss | Perplexity | Notes |
|---|---|---|---|---|
| Random init | 0 | 11.76 | 128,256 | = ln(V), random guessing |
| After 100 steps | ~840M | 7.5 | 1,808 | Learning basic frequencies |
| After 1K steps | ~8.4B | 4.2 | 66.7 | Learning common patterns |
| After 10K steps | ~84B | 3.0 | 20.1 | Grammatical text |
| After 100K steps | ~840B | 2.3 | 10.0 | Coherent paragraphs |
| After 150K steps | ~1.26T | 1.9 | 6.7 | Factual knowledge emerging |
| End of training (~170K steps) | ~1.4T | 1.6 | 5.0 | Near convergence |
5. Backward Pass
5.1 Overview
The backward pass computes for every parameter in the model via reverse-mode automatic differentiation (backpropagation). The computational cost is approximately 2x the forward pass (each forward matmul generates two backward matmuls: data gradient and weight gradient).
The backward pass flows in reverse through the model:
- Gradient of the loss w.r.t. the logits: softmax(z) - one_hot(y), scaled by 1/(B·S) for mean reduction. Shape: [B, S, V]
- Gradient through the output head (lm_head)
- Gradient through each transformer layer (80 to 0): as detailed in Part 16 of this series
- Gradient through the embedding table
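The first step of the chain, (softmax(z) - one_hot(y)) / N for the mean-reduced loss, can be checked directly with autograd on a toy tensor:

```python
import torch
import torch.nn.functional as F

# Analytic gradient of mean-reduced cross-entropy w.r.t. the logits:
# dL/dz = (softmax(z) - one_hot(y)) / N, with N = B*S tokens.
torch.manual_seed(0)
B, S, V = 2, 4, 10
logits = torch.randn(B, S, V, requires_grad=True)
labels = torch.randint(0, V, (B, S))

loss = F.cross_entropy(logits.view(-1, V), labels.view(-1))
loss.backward()

expected = (F.softmax(logits.detach(), dim=-1)
            - F.one_hot(labels, V).float()) / (B * S)
```

`logits.grad` matches `expected` elementwise, confirming the starting point of the backward sweep described above.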
5.2 Backward Pass Memory Timeline
The backward pass is where memory peaks. For each layer, the backward computes gradients and then frees the saved activations for that layer. The peak occurs when:
- All model weights are in memory (140 GB in BF16)
- Optimizer states are in memory (840 GB in FP32 — but distributed across GPUs)
- Gradients for the current layer are being computed (transient)
- Saved activations for remaining layers (not yet freed)
Memory timeline during backward (single GPU shard, simplified):
Time -->
|<-- Backward layer 80 -->|<-- Layer 79 -->| ... |<-- Layer 0 -->|
| | | | |
Activations saved: | layers 0-80 | layers 0-79 | | layer 0 |
Gradients computed: | layer 80 grads | layer 79 grads | | layer 0 grads |
| | | | |
Memory: | PEAK | decreasing | | minimum |
The peak is at the start of backward (layer 80), where all forward activations are still in memory plus the first backward computation.
With gradient checkpointing, the picture changes: activations are recomputed on-the-fly, so the “saved activations” component is much smaller (only the checkpoint tensors).
5.3 The GradScaler (FP16 Only)
When training in FP16 (not BF16), a GradScaler is needed to prevent gradient underflow:
scaler = torch.cuda.amp.GradScaler(
init_scale=2**16, # Initial loss scale: 65536
growth_factor=2.0, # Double scale every growth_interval steps
backoff_factor=0.5, # Halve scale on overflow
growth_interval=2000, # Steps between scale increases
)
# In the training loop:
with torch.cuda.amp.autocast(dtype=torch.float16):
logits = model(input_ids)
loss = F.cross_entropy(logits.view(-1, V), labels.view(-1))
scaler.scale(loss).backward() # loss * scale_factor, then backward
scaler.unscale_(optimizer) # Divide gradients by scale_factor
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer) # Skip step if inf/nan detected
scaler.update() # Adjust scale for next step
With BF16, the GradScaler is not needed — BF16’s dynamic range matches FP32. This simplifies the training loop and eliminates the overhead of scale management.
BF16 is preferred for training because: (1) No loss scaling needed — same dynamic range as FP32. (2) No GradScaler overhead. (3) Same tensor core throughput as FP16 on A100/H100. The only downside: BF16 has less mantissa precision (7 bits vs 10 bits), which occasionally matters for very precise operations. All modern LLM training uses BF16 unless the hardware does not support it.
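The range difference is easy to see from torch.finfo:

```python
import torch

# BF16 shares FP32's 8-bit exponent, so it needs no loss scaling;
# FP16's 5-bit exponent caps representable magnitudes at 65,504.
fp16 = torch.finfo(torch.float16)
bf16 = torch.finfo(torch.bfloat16)

big = torch.tensor(1e10)                         # e.g., an unscaled loss value
overflows_fp16 = torch.isinf(big.half()).item()      # True: out of FP16 range
fits_bf16 = torch.isfinite(big.bfloat16()).item()    # True: well within BF16 range
```

The trade-off is visible in the machine epsilons: `bf16.eps` (7 mantissa bits) is about 8x larger than `fp16.eps` (10 mantissa bits), which is the precision downside mentioned above.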
6. Gradient Accumulation
6.1 Simulating Larger Batches
Large batch sizes improve training efficiency (better gradient estimates, higher hardware utilization) but require more memory. Gradient accumulation simulates a large batch by splitting it into micro-batches and accumulating gradients:
Effective batch size = micro_batch_size × accumulation_steps × n_data_parallel_gpus
For example, to achieve an effective batch of 1024 sequences with 256 GPUs and micro_batch_size = 1: accumulation_steps = 1024 / (1 × 256) = 4.
Each GPU processes 4 micro-batches, accumulating gradients, before performing a single optimizer step.
def train_step_with_accumulation(model, optimizer, batches, accumulation_steps):
"""
Train with gradient accumulation.
batches: iterator yielding (input_ids, labels) micro-batches
"""
optimizer.zero_grad()
total_loss = 0.0
for step in range(accumulation_steps):
input_ids, labels = next(batches)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
logits = model(input_ids)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1)
) / accumulation_steps # Scale loss by 1/N to average gradients
loss.backward() # Gradients accumulate (not overwritten)
total_loss += loss.item() * accumulation_steps
# After all micro-batches: clip and step
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
return total_loss / accumulation_steps
6.2 The Division Matters
The loss must be divided by accumulation_steps before backward(). This is because:
- Without division: gradients are the sum across micro-batches, N times larger than the mean gradient. This is equivalent to using an N× larger learning rate
- With division: gradients are the mean across micro-batches, matching what you would get from a single large batch
This is a common bug that causes divergence when increasing accumulation steps without adjusting the loss scaling.
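The effect of the 1/N scaling can be verified on a toy model: two micro-batches with scaled losses reproduce the full-batch mean gradient exactly:

```python
import torch

# Full batch of 4 examples vs two micro-batches of 2 with loss / N.
torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
xb = torch.randn(4, 3)
yb = torch.randn(4)

# Full batch: mean squared error over all 4 examples
((xb @ w - yb) ** 2).mean().backward()
full_grad = w.grad.clone()

# Accumulated: two micro-batches, each loss divided by N = 2 steps
w.grad = None
for mb in range(2):
    x, y = xb[2 * mb:2 * mb + 2], yb[2 * mb:2 * mb + 2]
    (((x @ w - y) ** 2).mean() / 2).backward()   # gradients accumulate
accum_grad = w.grad.clone()
```

Dropping the `/ 2` would leave `accum_grad` exactly twice `full_grad` — the hidden learning-rate doubling described above.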
6.3 Memory Impact
Gradient accumulation does not save memory on activations (each micro-batch still needs full activation memory for its forward-backward). But it avoids needing memory for a large batch’s activations simultaneously:
Memory: Large Batch vs Gradient Accumulation
| Approach | Activations | Gradients | Total | Notes |
|---|---|---|---|---|
| B=4, no accumulation | 640 GB | 140 GB | 780 GB | All 4 sequences simultaneously |
| B=1, 4x accumulation | 160 GB | 140 GB | 300 GB | One sequence at a time |
| B=1, 4x accum + checkpoint | 12 GB | 140 GB | 152 GB | Minimum memory |
7. The Optimizer Step
7.1 AdamW
Virtually all LLM training uses AdamW (Adam with decoupled weight decay). The update rule for parameter θ with gradient g_t at step t:

m_t = β1·m_{t-1} + (1 - β1)·g_t
v_t = β2·v_{t-1} + (1 - β2)·g_t²
m̂_t = m_t / (1 - β1^t),   v̂_t = v_t / (1 - β2^t)
θ_t = θ_{t-1}·(1 - η·λ) - η·m̂_t / (√v̂_t + ε)
Standard hyperparameters for LLM training:
| Hyperparameter | Symbol | Typical Value | Notes |
|---|---|---|---|
| Learning rate | η | 3e-4 | Peak LR, after warmup |
| Beta 1 | β1 | 0.9 | Momentum decay |
| Beta 2 | β2 | 0.95 | Variance decay |
| Epsilon | ε | 1e-8 | Numerical stability |
| Weight decay | λ | 0.1 | Applied to all weights, not biases/norms |
7.2 Optimizer Memory
AdamW stores two state tensors per parameter (the first moment m and the second moment v), both in FP32. Together with the weights, master weights, and gradients:
Model weights (BF16): 70B * 2 bytes = 140 GB
Model weights (FP32 master): 70B * 4 bytes = 280 GB
Adam first moment m (FP32): 70B * 4 bytes = 280 GB
Adam second moment v (FP32): 70B * 4 bytes = 280 GB
Gradients (BF16): 70B * 2 bytes = 140 GB
------------------------------------------------------
Total optimizer + weights: 1,120 GB
This is why ZeRO-3 / FSDP is essential: the FP32 optimizer state (master weights plus the two Adam moments, 840 GB) exceeds the memory of even 8 H100s (640 GB total).
With ZeRO-3 sharding across 256 GPUs:
Per-GPU Adam moments (m, v): 560 GB / 256 = 2.19 GB
Per-GPU master weights: 280 GB / 256 = 1.09 GB
Per-GPU gradients: 140 GB / 256 = 0.55 GB
Per-GPU BF16 weights: 140 GB / 256 = 0.55 GB
----------------------------------------------
Per-GPU total: 1,120 GB / 256 = 4.38 GB
This leaves approximately 75 GB per H100 for activations and communication buffers.
7.3 The Optimizer Step in Detail
def adamw_step(param, grad, m, v, step, lr, beta1, beta2, eps, weight_decay):
    """
    One AdamW update step for a single parameter tensor.
    All inputs are FP32 (master weights and optimizer states).
    """
    # Decoupled weight decay: applied directly to the parameter,
    # before and outside of the Adam update (as in torch.optim.AdamW)
    param.mul_(1 - lr * weight_decay)
    # Update biased first moment estimate
    m.mul_(beta1).add_(grad, alpha=1 - beta1)  # m = beta1*m + (1-beta1)*g
    # Update biased second moment estimate
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v = beta2*v + (1-beta2)*g^2
    # Bias correction
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    # Compute step size
    step_size = lr / bias_correction1
    # Compute denominator: sqrt(v_hat) + eps
    denom = (v.sqrt() / (bias_correction2 ** 0.5)).add_(eps)
    # Adam update
    param.addcdiv_(m, denom, value=-step_size)
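One step of this update can be checked against torch.optim.AdamW (a self-contained comparison from zero optimizer state at t = 1, in float64 so rounding noise does not cloud the result):

```python
import torch

# One AdamW step by hand — decoupled decay, then Adam with bias
# correction — compared against torch.optim.AdamW.
torch.manual_seed(0)
lr, b1, b2, eps, wd = 3e-4, 0.9, 0.95, 1e-8, 0.1
p0 = torch.randn(5, dtype=torch.float64)
grad = torch.randn(5, dtype=torch.float64)

# Reference: PyTorch's optimizer, one step
p_ref = p0.clone().requires_grad_()
p_ref.grad = grad.clone()
torch.optim.AdamW([p_ref], lr=lr, betas=(b1, b2),
                  eps=eps, weight_decay=wd).step()

# Manual step, t = 1, zero initial moments
p = p0.clone()
p.mul_(1 - lr * wd)                    # decoupled weight decay
m = (1 - b1) * grad                    # m_1 (m_0 = 0)
v = (1 - b2) * grad * grad             # v_1 (v_0 = 0)
m_hat = m / (1 - b1)                   # bias correction at t = 1
v_hat = v / (1 - b2)
p -= lr * m_hat / (v_hat.sqrt() + eps)
```

The two results agree to floating-point precision, confirming that the decay is a plain multiplicative shrink applied outside the adaptive update.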
7.4 Why Weight Decay is Decoupled
In the original Adam optimizer, weight decay was implemented as L2 regularization: adding λθ to the gradient. This is wrong for adaptive optimizers because Adam divides the update by √v̂ + ε, which makes the effective weight decay λθ/(√v̂ + ε) — different for each parameter, depending on the gradient variance.
AdamW fixes this by applying weight decay directly to the parameter, outside of the Adam update. The decay is not divided by √v̂ + ε, so it is uniform across all parameters (proportional to each parameter’s magnitude).
The practical impact: with L2 regularization in Adam, parameters with large gradients (high v̂) experience less effective weight decay. This means frequently updated parameters (which tend to be the most important) get less regularization — the opposite of what you want. AdamW corrects this.
Memory Breakdown: 70B Model Training on 256 H100s (GB per GPU)
8. Learning Rate Schedule
8.1 Warmup + Cosine Decay
The standard LR schedule for LLM training:
- Linear warmup: LR increases linearly from 0 to peak over warmup_steps (typically 2000)
- Cosine decay: LR decreases following a cosine curve to min_lr over the remaining steps
- Minimum LR: typically 10% of peak LR (3e-5 for peak 3e-4)
import math
def cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr):
"""
Warmup + cosine decay learning rate schedule.
step: current training step (0-indexed)
warmup_steps: number of warmup steps
total_steps: total training steps
peak_lr: maximum learning rate (reached at end of warmup)
min_lr: minimum learning rate (reached at end of training)
"""
if step < warmup_steps:
# Linear warmup: 0 -> peak_lr
return peak_lr * step / warmup_steps
# Cosine decay: peak_lr -> min_lr
progress = (step - warmup_steps) / (total_steps - warmup_steps)
cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
return min_lr + (peak_lr - min_lr) * cosine_factor
8.2 Why Warmup is Necessary
At initialization, the model’s activations and gradients are noisy and poorly conditioned. Large learning rates in this regime cause destructive updates: the model overshoots in parameter space, activations explode, and training diverges.
Warmup allows the model to:
- Build up Adam’s second moment estimate (which takes on the order of 1/(1 - β2) ≈ 20 steps to begin stabilizing, and far longer to become reliable)
- Move weights to a region of parameter space with a smoother loss landscape
- Let activation scales and attention patterns settle to the data distribution
The typical warmup of 2000 steps means the model sees warmup_steps × batch_size × seq_len tokens before reaching peak LR. For an effective batch of 1024 sequences and seq_len = 8192: 2000 × 1024 × 8192 ≈ 16.8 billion tokens of warmup.
The 2000-step warmup is not universal. Shorter warmups (500-1000 steps) sometimes work for smaller models. Longer warmups (5000-10000 steps) are sometimes used for very large models or when training instability is observed. The key diagnostic: if training loss spikes during the transition from warmup to peak LR, the warmup was too short.
8.3 LR Values Through Training
Learning Rate at Key Training Milestones (70B)
| Step | Tokens Seen | LR | Phase | Notes |
|---|---|---|---|---|
| 0 | 0 | 0 | Warmup start | Zero LR, no updates |
| 1000 | 8.4B | 1.5e-4 | Mid-warmup | Half of peak |
| 2000 | 16.8B | 3.0e-4 | Peak | Maximum learning rate |
| 50000 | 420B | 2.5e-4 | Early decay | 83% of peak |
| 100000 | 840B | 1.3e-4 | Mid decay | 43% of peak |
| 150000 | 1.26T | 3.9e-5 | Late decay | 13% of peak |
| 170000 | 1.43T | 3.0e-5 | Near end | Minimum LR |
9. Memory Timeline: Peak During Backward
9.1 Full Memory Timeline
Here is the complete memory timeline for one training step of a 70B model on a single GPU (with ZeRO-3 sharding across 256 GPUs):
Phase | Weights | Optim States | Activations | Gradients | Total
| (GB) | (GB) | (GB) | (GB) | (GB)
---------------+---------+--------------+-------------+-----------+------
Idle | 0.55 | 3.28 | 0 | 0 | 3.83
Data loading | 0.55 | 3.28 | 0 | 0 | 3.83
Forward (L=0) | 0.55 | 3.28 | 0.13 | 0 | 3.96
Forward (L=40) | 0.55 | 3.28 | 5.13 | 0 | 8.96
Forward (L=80) | 0.55 | 3.28 | 10.13 | 0 | 13.96
Loss compute | 0.55 | 3.28 | 10.13+4.2 | 0 | 18.16
Backward (L=80)| 0.55 | 3.28 | 10.13 | 0.55 | 14.51
Backward (L=40)| 0.55 | 3.28 | 5.13 | 0.55 | 9.51
Backward (L=0) | 0.55 | 3.28 | 0.13 | 0.55 | 4.51
AllReduce grads| 0.55 | 3.28 | 0 | 0.55 | 4.38
Optimizer step | 0.55 | 3.28 | 0 | 0.55 | 4.38
After step | 0.55 | 3.28 | 0 | 0 | 3.83
Key observations:
- Peak memory occurs during loss computation (4.2 GB of FP32 logits on top of 10.13 GB of activations, 18.16 GB total) or at the start of backward
- Activations are freed progressively during backward — layer 80’s activations are freed after layer 80’s backward is complete
- Gradients accumulate during backward but are sharded across GPUs via reduce-scatter
- After optimizer step, gradients are zeroed and memory returns to baseline
9.2 Communication Overlap
In production training with ZeRO-3/FSDP, communication overlaps with computation:
- During forward: All-gather weights for the next layer while computing the current layer
- During backward: Reduce-scatter gradients for the current layer while computing the previous layer’s backward
- During optimizer step: Each GPU updates its shard of the optimizer states locally (no communication)
The communication cost per layer:
All-gather weights (forward): 2 * layer_params * (n-1)/n bytes
Reduce-scatter grads (backward): 2 * layer_params * (n-1)/n bytes
For one layer with 856M parameters in BF16 and n = 256 GPUs: 2 bytes × 856M × 255/256 ≈ 1.7 GB per GPU for each collective.
On 400 Gbps InfiniBand (50 GB/s per GPU): roughly 34 ms per layer for each all-gather or reduce-scatter. With overlap, this is hidden behind the 13.5 ms compute time per layer only if the network bandwidth is sufficient — which it often is not, making communication the bottleneck for large-scale training.
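Plugging in the numbers from the text (856M parameters per layer, 256 GPUs, 50 GB/s per GPU):

```python
# Per-layer ZeRO-3 collective cost from the formula above:
# traffic = 2 bytes/elem * layer_params * (n-1)/n, time = traffic / bandwidth.
layer_params = 856e6          # parameters in one transformer layer
n_gpus = 256
bandwidth = 50e9              # 400 Gbps InfiniBand ~= 50 GB/s per GPU

traffic_bytes = 2 * layer_params * (n_gpus - 1) / n_gpus   # ~1.7e9 bytes
t_ms = traffic_bytes / bandwidth * 1e3                     # ~34 ms
```

The ~34 ms collective against ~13.5 ms of per-layer compute makes the bandwidth deficit concrete: overlap alone cannot hide it at this cluster size.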
10. The Complete Training Loop
Here is the core of a production training loop in approximately 50 lines of PyTorch:
import torch
import torch.nn.functional as F
import math
from torch.nn.utils import clip_grad_norm_
def train(
model,
optimizer,
train_loader,
total_steps,
warmup_steps=2000,
peak_lr=3e-4,
min_lr=3e-5,
max_grad_norm=1.0,
accumulation_steps=1,
log_interval=10,
dtype=torch.bfloat16,
):
"""Complete LLM training loop."""
model.train()
step = 0
accum_loss = 0.0
data_iter = iter(train_loader)
while step < total_steps:
optimizer.zero_grad(set_to_none=True) # Faster than zero_grad()
# --- Gradient accumulation loop ---
for micro_step in range(accumulation_steps):
try:
input_ids, labels = next(data_iter)
except StopIteration:
data_iter = iter(train_loader)
input_ids, labels = next(data_iter)
input_ids = input_ids.cuda(non_blocking=True)
labels = labels.cuda(non_blocking=True)
# Forward pass in mixed precision
with torch.cuda.amp.autocast(dtype=dtype):
logits = model(input_ids)
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
labels.view(-1),
ignore_index=-100,
) / accumulation_steps
# Backward pass (gradients accumulate)
loss.backward()
accum_loss += loss.item() * accumulation_steps
# --- Gradient clipping ---
grad_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
        # --- Learning rate schedule (set LR before the step that uses it) ---
        step += 1
        lr = cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        # --- Optimizer step ---
        optimizer.step()
# --- Logging ---
if step % log_interval == 0:
avg_loss = accum_loss / log_interval
ppl = math.exp(min(avg_loss, 20)) # Cap to avoid overflow
tokens_seen = step * accumulation_steps * train_loader.batch_size * 8192
print(f"step={step:>6d} | loss={avg_loss:.4f} | ppl={ppl:.1f} | "
f"lr={lr:.2e} | grad_norm={grad_norm:.2f} | "
f"tokens={tokens_seen/1e9:.1f}B")
accum_loss = 0.0
def cosine_schedule_with_warmup(step, warmup_steps, total_steps, peak_lr, min_lr):
if step < warmup_steps:
return peak_lr * step / warmup_steps
progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
return min_lr + (peak_lr - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
10.1 What This Loop Does Not Include
The 50-line loop above covers the core computation. A production training script also needs:
- Distributed setup: torch.distributed.init_process_group(), FSDP/ZeRO wrapping, tensor parallelism
- Checkpointing: Save model + optimizer + scheduler state every N steps
- Evaluation: Periodic evaluation on held-out data to track validation loss
- Gradient accumulation sync: For FSDP, the no_sync() context manager during micro-batches to avoid premature all-reduce
- Exception handling: NaN detection, automatic checkpoint recovery
- Throughput logging: Tokens per second, MFU (model FLOP utilization)
- Data sampling: Multi-source data mixing with configurable ratios
10.2 Training Throughput
The end-to-end throughput for 70B training on common hardware configurations:
Training Throughput: 70B Model on Various Configurations
| Hardware | GPUs | Tokens/sec | MFU | Time for 1.4T tokens |
|---|---|---|---|---|
| 8x H100 (single node) | 8 | ~4,800 | ~32% | ~3,370 days |
| 64x H100 (8 nodes) | 64 | ~35,000 | ~38% | ~463 days |
| 256x H100 (32 nodes) | 256 | ~130,000 | ~44% | ~125 days |
| 512x H100 (64 nodes) | 512 | ~240,000 | ~48% | ~67 days |
| 2048x H100 (256 nodes) | 2048 | ~850,000 | ~50% | ~19 days |
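A quick consistency check on the last column: total tokens divided by throughput, using the tokens/sec values from the table above:

```python
# Days to push 1.4T tokens through each cluster at its measured throughput.
throughput = {8: 4_800, 64: 35_000, 256: 130_000,
              512: 240_000, 2048: 850_000}              # tokens/sec
days = {gpus: 1.4e12 / tps / 86_400 for gpus, tps in throughput.items()}
```

The computed values (about 3,376 days at 8 GPUs, 125 days at 256, 19 days at 2048) line up with the table’s final column.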
At 2048 H100 GPUs, the cluster consumes approximately 1.5 MW of power. At $2/GPU-hour, the compute cost is $2 × 2048 GPUs × 24 hours × 19 days ≈ $1.9M.
11. Putting It All Together: Annotated Single Step
To close, here is one complete training step with every operation annotated with its cost:
# === STEP t ===
# Memory at start: 3.83 GB (weights + optimizer, sharded)
# 1. DATA LOADING (overlapped with previous step, effectively free)
input_ids, labels = next(data_iter) # [B, S] int64, from pinned memory
input_ids = input_ids.cuda(non_blocking=True) # Async H2D transfer: 0.06 ms
labels = labels.cuda(non_blocking=True)
# 2. FORWARD PASS (580 ms, 1,317 TFLOPs)
# Memory: 3.83 -> 18.16 GB (peak with logits)
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
logits = model(input_ids) # 80 layers, each:
        # RMSNorm: 0.1 ms, ~0.13 GFLOPs
# QKV project: 1.2 ms, 1,376 GFLOPs
# FlashAttn: 2.5 ms, 2,200 GFLOPs
# O project: 0.8 ms, 1,100 GFLOPs
# RMSNorm: 0.1 ms, ~0 TFLOPs
        # SwiGLU FFN: 2.5 ms, 11,544 GFLOPs
        # Total/layer: 7.2 ms, 16,220 GFLOPs
# Output head: 2.1 ms, 17,200 GFLOPs
# 3. LOSS COMPUTATION (15 ms, negligible FLOPs)
loss = F.cross_entropy(
logits.float().view(-1, 128256), # Upcast BF16->FP32: 4.2 GB
labels.view(-1)
) # Scalar FP32
# 4. BACKWARD PASS (1,080 ms, 2,744 TFLOPs)
# Memory: 18.16 -> 14.51 -> decreasing -> 4.51 GB
loss.backward()
# Gradient through output head: 4 ms
# Per layer (80x): 13.5 ms each
# Back through FFN: 7 ms
# Back through attention: 5 ms
# Back through RMSNorm (2x): 0.5 ms
# Back through residual: 0.01 ms (just addition)
# Reduce-scatter gradients: overlapped
# 5. GRADIENT CLIPPING (8 ms)
grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)
# All-reduce to compute global norm: 5 ms
# Scale gradients if needed: 3 ms
# 6. OPTIMIZER STEP (55 ms)
optimizer.step()
# Per parameter: 5 FP32 operations (m, v, bias_correct, update, decay)
# Memory-bound: reads/writes 4 tensors per param (param, m, v, grad)
# Total memory traffic: 70B * 4 * 4 bytes = 1.12 TB
# 7. LR UPDATE (0.01 ms)
lr = cosine_schedule_with_warmup(step, 2000, 170000, 3e-4, 3e-5)
for pg in optimizer.param_groups:
pg['lr'] = lr
optimizer.zero_grad(set_to_none=True) # Free gradient memory
# Memory back to 3.83 GB
# === Total step time: ~1,738 ms ===
# === Tokens processed: B * S = 1024 * 8192 = 8,388,608 ===
# === Throughput: 8.39M / 1.738 = 4.83M tokens/sec (across 256 GPUs) ===
The backward pass dominates at 62% of step time. The forward pass is 33%. The optimizer step is 3%. Everything else (data loading, loss, clipping, LR update) is collectively less than 2%. Any optimization effort should focus on making the backward pass faster — which is why FlashAttention (reducing backward memory and enabling larger batches), gradient checkpointing (enabling longer sequences), and communication overlap (hiding all-reduce latency) are the highest-impact techniques.
Reviewer Agent Validation Challenge
The following statements about this post’s content are candidates for review. Some are true, some contain deliberate errors.
1. Claim: The forward pass of Llama 3 70B costs approximately 1,317 TFLOPs for B = 1, S = 8192. Verify: 80 layers at ~16,200 GFLOPs each ≈ 1,300 TFLOPs, plus 17.2 TFLOPs for the output head ≈ 1,317 TFLOPs. Is this calculation consistent?
2. Claim: The backward pass costs exactly 2x the forward pass in FLOPs. Check: the annotated step in Section 11 shows 2,744 TFLOPs backward vs 1,317 TFLOPs forward, a ratio of 2.08x. Is this exactly 2x or slightly more? Why?
3. Claim: AdamW stores 3 tensors per parameter: the parameter itself, the first moment m, and the second moment v. But the memory calculation shows 4 tensors per parameter (BF16 weights, FP32 master weights, m, v). Verify: how many distinct copies of each parameter exist, and what is the total per-parameter memory?
4. Claim: Loss at random initialization is ln(128,256) ≈ 11.76. Verify: is the log base correct (natural log), and is the value correct? Cross-entropy of a uniform distribution over V classes with a one-hot target is ln(V).
5. Claim: Warmup of 2000 steps with effective batch 1024 and seq_len = 8192 means 2000 × 1024 × 8192 ≈ 16.8B tokens of warmup. Verify the arithmetic.
6. Claim: Per-GPU memory for optimizer states with ZeRO-3 across 256 GPUs is 840 GB / 256 = 3.28 GB. Check: is 840 GB the correct total for the FP32 optimizer state of a 70B model?
7. Claim: The output head matmul costs 17.2 TFLOPs. Verify: the output head weight is [8192, 128256], so the cost is 2 × (B·S) × d × V. With B·S = 8192: 2 × 8192 × 8192 × 128,256. Compute this value.
8. Claim: Communication cost per layer for the all-gather is 2 × 856M × (255/256) bytes ≈ 1.7 GB. Verify: the factor of 2 at the start is for 2 bytes per BF16 element. Should the formula use bytes or elements? Is 1.7 GB correct?