Meta trained Llama 3 405B on 15.6 trillion tokens — nearly 8x more than Llama 2's 2 trillion. The reason: Chinchilla scaling laws dictate that a 400B model needs 8T+ tokens for optimal quality, and Meta overshot that target by roughly 90% to ensure frontier performance. Every other architectural decision — dense not MoE, GQA-8 not MHA, 128K vocabulary not 32K — flows from a single constraint: Meta serves billions of inference requests per day across Instagram, WhatsApp, and Facebook, and operational simplicity matters more than training efficiency.
Why Dense (Not MoE)
Meta’s Position
Meta chose a dense architecture for Llama 3 despite the clear training efficiency advantages of MoE. The reasoning:
- Serving simplicity: Meta serves Llama models at massive scale across its products. MoE requires expert parallelism, all-to-all communication, and careful load balancing. Dense models are straightforward to shard with tensor parallelism.
- Training budget is not the constraint: Meta has one of the largest GPU fleets in the world (600K+ H100s). The binding constraint is model quality per token generated at inference, not training cost.
- Open-source ecosystem: MoE models are harder for the community to run. A dense 70B model fits on 2 consumer GPUs at INT4. A 46.7B MoE model (Mixtral) requires the same memory but gives less compute per byte loaded.
def dense_vs_moe_serving_analysis():
"""
Why dense is better for Meta's specific serving requirements.
"""
# Meta's inference scenario: billions of queries/day
# Priority: latency + cost per query + operational simplicity
dense_405b = {
"total_params": 405e9,
"active_params": 405e9,
"flops_per_token": 2 * 405e9,
"memory_gb_fp16": 810,
"gpus_needed": 12, # A100 80GB
"serving_complexity": "simple", # TP only
"latency_overhead": "none",
}
# Hypothetical MoE of equivalent quality
moe_equivalent = {
"total_params": 1500e9, # ~4x total for equivalent quality
"active_params": 100e9, # ~25% active
"flops_per_token": 2 * 100e9,
"memory_gb_fp16": 3000,
"gpus_needed": 45, # More GPUs for memory
"serving_complexity": "complex", # EP + TP + load balancing
"latency_overhead": "all-to-all communication",
}
# At Meta's scale (billions of queries):
# Dense: fewer GPUs per instance, simpler scaling
# MoE: more GPUs per instance, but faster per query (fewer FLOPs)
# Crossover: MoE wins when GPU utilization is high
# Meta's varied workload means many queries are small batch
# Dense wins for small-batch latency
return dense_405b, moe_equivalent
Meta’s bet is that training compute is cheaper than serving complexity. They can afford to spend $100M+ training a dense model because the result is operationally simpler to deploy at scale. DeepSeek, with less compute budget, made the opposite bet: spend engineering effort on MoE to reduce training cost. Both are rational given their constraints.
Grouped Query Attention (GQA-8)
The GQA Decision
Llama 3 uses GQA with 8 KV heads at every scale: the 8B pairs them with 32 query heads, the 70B with 64, and the 405B with 128. The choice of 8 KV heads balances quality against KV cache size.
def gqa_analysis(
    d_model,
    num_q_heads,
    head_dim,
    num_layers,
    seq_len,
    batch_size,
    kv_head_options,
):
    """
    Analyze the tradeoff between KV head count, cache size, and quality.
    """
    results = []
    for num_kv_heads in kv_head_options:
        kv_dim = num_kv_heads * head_dim
        # KV cache size per token per layer (K + V, FP16 = 2 bytes)
        kv_per_token_layer = 2 * kv_dim * 2
        # Ratio of Q heads to KV heads (the "group" size)
        group_size = num_q_heads // num_kv_heads
        # Quality impact: more KV heads = better, with diminishing returns.
        # Illustrative numbers in the spirit of published GQA ablations:
        #   MHA (num_kv = num_q) = 100% quality
        #   GQA-8 = 99.5% quality
        #   GQA-1 (MQA) = 97% quality
        quality_lookup = {
            1: 97.0,
            4: 99.0,
            8: 99.5,
            16: 99.8,
            32: 99.9,
            64: 100.0,
            128: 100.0,
        }
        quality = quality_lookup.get(num_kv_heads, 99.0)
        # Memory savings vs MHA (the layer count cancels in the ratio)
        mha_kv_per_token_layer = 2 * num_q_heads * head_dim * 2
        savings = 1 - (kv_per_token_layer / mha_kv_per_token_layer)
        results.append({
            "num_kv_heads": num_kv_heads,
            "group_size": group_size,
            "kv_bytes_per_token": kv_per_token_layer * num_layers,
            # The full-context cache must count every layer
            "kv_gb_full_ctx": kv_per_token_layer * num_layers * seq_len * batch_size / 1e9,
            "quality_pct": quality,
            "memory_savings_pct": savings * 100,
        })
    return results
GQA Head Count Analysis (Llama 3.1 405B, 8K Context, FP16, All 126 Layers)
| KV Heads | Group Size | KV Cache (BS=1, GiB) | Quality | Savings vs MHA |
|---|---|---|---|---|
| 128 (MHA) | 1 | 63.0 GB | 100% | 0% |
| 32 | 4 | 15.8 GB | 99.9% | 75% |
| 16 | 8 | 7.9 GB | 99.8% | 87.5% |
| 8 (Llama 3) | 16 | 3.9 GB | 99.5% | 93.75% |
| 4 | 32 | 2.0 GB | 99.0% | 96.9% |
| 1 (MQA) | 128 | 0.5 GB | 97.0% | 99.2% |
GQA-8 is the sweet spot: 93.75% memory savings with only 0.5% quality loss versus full MHA. Going to GQA-4 saves another 1.9 GB but costs an additional 0.5% quality — not worth it.
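KV cache size scales linearly with context length, so at the full 128K window the numbers get dramatic. A quick sketch of the arithmetic (FP16 cache, 126 layers, batch size 1, 405B head dimensions):

```python
def kv_cache_gb(num_kv_heads, head_dim=128, num_layers=126,
                seq_len=131072, batch_size=1, bytes_per_elem=2):
    """Total KV cache in GB: K and V, every layer, full context."""
    per_token = 2 * num_kv_heads * head_dim * bytes_per_elem * num_layers
    return per_token * seq_len * batch_size / 1e9

print(round(kv_cache_gb(128)))  # 1082 -- MHA: over a terabyte per sequence
print(round(kv_cache_gb(8)))    # 68   -- GQA-8: fits alongside the weights
```

At 128K, GQA-8 is the difference between a KV cache that fits on the serving node and one that does not.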
Implementation
import torch
import torch.nn as nn
import math
class Llama3Attention(nn.Module):
"""
Llama 3 GQA-8 attention implementation.
"""
    def __init__(
        self,
        d_model=16384,  # 405B config
        num_q_heads=128,
        num_kv_heads=8,
        head_dim=128,  # = d_model / num_q_heads for the 405B
        max_seq_len=131072,
    ):
super().__init__()
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.group_size = num_q_heads // num_kv_heads # 16
# Projections
self.q_proj = nn.Linear(d_model, num_q_heads * head_dim, bias=False)
self.k_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
self.v_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
self.o_proj = nn.Linear(num_q_heads * head_dim, d_model, bias=False)
def forward(self, x, kv_cache=None, position_ids=None):
B, T, _ = x.shape
q = self.q_proj(x).view(B, T, self.num_q_heads, self.head_dim)
k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
# Apply RoPE to q, k (omitted for clarity)
# Expand KV heads to match Q heads
# Each KV head is shared by group_size Q heads
k = k.unsqueeze(3).expand(-1, -1, -1, self.group_size, -1)
k = k.reshape(B, T, self.num_q_heads, self.head_dim)
v = v.unsqueeze(3).expand(-1, -1, -1, self.group_size, -1)
v = v.reshape(B, T, self.num_q_heads, self.head_dim)
# Standard attention
q = q.transpose(1, 2) # [B, H, T, D]
k = k.transpose(1, 2)
v = v.transpose(1, 2)
scale = 1.0 / math.sqrt(self.head_dim)
attn = torch.matmul(q, k.transpose(-2, -1)) * scale
# Causal mask
causal_mask = torch.triu(
torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
)
attn = attn.masked_fill(causal_mask, float('-inf'))
attn = torch.softmax(attn, dim=-1)
output = torch.matmul(attn, v)
output = output.transpose(1, 2).reshape(B, T, -1)
return self.o_proj(output)
128K Vocabulary
Why 128K?
Llama 3 increased the vocabulary from 32K (Llama 2) to 128K (128,256 tokens). This was one of the most impactful changes.
def vocabulary_size_analysis():
"""
Analyze the impact of vocabulary size on compression and quality.
"""
vocab_configs = {
"Llama 2 (32K)": {
"vocab_size": 32000,
"bytes_per_token_english": 3.7,
"bytes_per_token_chinese": 1.8,
"bytes_per_token_code": 3.2,
"embedding_params": 32000 * 4096, # Llama 2 d_model
},
"Llama 3 (128K)": {
"vocab_size": 128256,
"bytes_per_token_english": 4.4,
"bytes_per_token_chinese": 3.5,
"bytes_per_token_code": 4.0,
"embedding_params": 128256 * 8192, # Llama 3 405B d_model
},
}
# Higher bytes per token = better compression
# Fewer tokens needed for the same text = faster inference
# But: larger embedding table = more parameters
    # Inference speedup is proportional to the token-count reduction:
    # English improves from 3.7 to 4.4 bytes/token — ~19% better
    # compression, which means ~16% fewer tokens for the same text.
    return vocab_configs
Vocabulary Size Impact
| Metric | Llama 2 (32K) | Llama 3 (128K) | Improvement |
|---|---|---|---|
| English bytes/token | 3.7 | 4.4 | 19% better compression |
| Chinese bytes/token | 1.8 | 3.5 | 94% better compression |
| Code bytes/token | 3.2 | 4.0 | 25% better compression |
| Tokens for 1K words (English) | ~350 | ~294 | 16% fewer tokens |
| Embedding parameters | 131M (Llama 2 7B, d=4096) | 1.05B (Llama 3 70B, d=8192) | 8x more |
| Inference speedup (fewer tokens) | - | ~16% | Significant at scale |
The Multilingual Argument
The primary motivation for 128K vocabulary was multilingual coverage. With 32K BPE tokens, Chinese text requires 2+ tokens per character (average 1.8 bytes/token vs 3+ for efficient encoding). With 128K tokens, the tokenizer can allocate more tokens to CJK characters, Cyrillic, Arabic, and other scripts.
def tokenization_efficiency(text, tokenizer):
"""
Measure tokenization efficiency: bytes per token.
Higher is better (more information per token).
"""
encoded = tokenizer.encode(text)
num_tokens = len(encoded)
num_bytes = len(text.encode('utf-8'))
return {
"num_tokens": num_tokens,
"num_bytes": num_bytes,
"bytes_per_token": num_bytes / num_tokens,
}
The Cost
The larger vocabulary increases the embedding and output projection sizes. For Llama 3 405B with V = 128,256 and d_model = 16,384:
- Embedding: 128,256 × 16,384 ≈ 2.1B parameters
- Output head: another ≈ 2.1B parameters (Llama 3 does not tie the output head to the input embedding)
This is roughly 4B parameters that contribute to token prediction but do not increase the model's representational capacity in the transformer layers. The tradeoff: better tokenization efficiency at inference justifies the parameter cost.
A 128K vocabulary means about 16% fewer tokens for English text and roughly half as many for CJK languages. Since inference cost is per-token (KV cache grows per token, generation is sequential per token), this translates directly into proportionally faster inference for the same text output. At Meta's scale, this is worth far more than the few billion parameters of embedding overhead.
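The token-count arithmetic behind those percentages, using the bytes-per-token figures from the table above (illustrative numbers, not official tokenizer measurements):

```python
def token_reduction(old_bytes_per_token, new_bytes_per_token):
    """Fraction of tokens saved when bytes/token improves."""
    return 1 - old_bytes_per_token / new_bytes_per_token

print(round(token_reduction(3.7, 4.4), 2))  # 0.16 -- English: ~16% fewer tokens
print(round(token_reduction(1.8, 3.5), 2))  # 0.49 -- Chinese: ~half the tokens
```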
RoPE (Rotary Position Embeddings)
Why RoPE Won
Llama 3 uses RoPE with base frequency 500,000 for position encoding. RoPE won over alternatives for three reasons:
- Context extension: RoPE can be extended beyond training length via frequency scaling (YaRN, NTK-aware interpolation). Learned position embeddings cannot.
- No learned parameters: RoPE is computed analytically. No additional parameters to train.
- Relative position: RoPE encodes relative distances, which generalize better than absolute positions.
def compute_rope_embeddings(
seq_len,
head_dim,
base=500000.0,
device="cuda",
):
"""
Compute RoPE sin/cos embeddings.
Llama 3 uses base frequency 500,000 (vs 10,000 for Llama 2).
"""
# Compute frequencies
dim = head_dim
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
# freqs shape: [dim/2]
# Compute positions
positions = torch.arange(seq_len, device=device).float()
# positions shape: [seq_len]
# Outer product: position * frequency
angles = torch.outer(positions, freqs) # [seq_len, dim/2]
# Sin/cos embeddings
cos_embed = angles.cos()
sin_embed = angles.sin()
return cos_embed, sin_embed
def apply_rope(q, k, cos, sin):
"""
Apply RoPE to query and key tensors.
q, k: [B, H, T, D]
cos, sin: [T, D/2]
"""
# Split into pairs
q1, q2 = q[..., ::2], q[..., 1::2]
k1, k2 = k[..., ::2], k[..., 1::2]
    # Rotate each (even, odd) pair. The concatenation below places all
    # rotated even components before all odd ones (block layout rather
    # than interleaved); this is harmless as long as q and k share the
    # same layout, since attention only uses the q·k inner product.
cos = cos.unsqueeze(0).unsqueeze(0) # [1, 1, T, D/2]
sin = sin.unsqueeze(0).unsqueeze(0)
q_rotated = torch.cat([
q1 * cos - q2 * sin,
q1 * sin + q2 * cos,
], dim=-1)
k_rotated = torch.cat([
k1 * cos - k2 * sin,
k1 * sin + k2 * cos,
], dim=-1)
return q_rotated, k_rotated
Base Frequency 500,000
Llama 2 used a RoPE base frequency of 10,000. Llama 3 uses 500,000. The higher base frequency:
- Stretches the frequency range, allowing the model to distinguish positions over longer ranges
- Was essential for extending context from 4K to 128K tokens
- Was determined empirically through context extension experiments
def rope_base_frequency_impact(bases, seq_len, dim):
"""
Show how base frequency affects the wavelengths in RoPE.
"""
for base in bases:
freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
wavelengths = 2 * torch.pi / freqs
min_wavelength = wavelengths.min().item()
max_wavelength = wavelengths.max().item()
# Can the model distinguish positions at seq_len?
# Need at least one frequency with wavelength > seq_len
can_distinguish = max_wavelength > seq_len
print(f" Base {base:>10.0f}: wavelengths [{min_wavelength:.0f}, {max_wavelength:.0f}]")
print(f" Can distinguish at {seq_len}: {can_distinguish}")
# Base 10,000: wavelengths [6, 62,832] — works up to ~62K
# Base 500,000: wavelengths [6, 3,141,593] — works up to ~3M
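RoPE's relative-position property can be checked numerically on a single rotation pair — the score between a rotated query and key depends only on the offset m − n, never on absolute positions. A self-contained sketch using the complex-number view of one channel pair:

```python
import cmath

def rope_pair(x, pos, freq):
    """Rotate one (x0, x1) channel pair by angle pos * freq."""
    z = complex(x[0], x[1]) * cmath.exp(1j * pos * freq)
    return (z.real, z.imag)

def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]

q, k, freq = (0.3, -1.2), (0.8, 0.5), 0.01
s_near = dot(rope_pair(q, 5, freq), rope_pair(k, 2, freq))     # positions 5, 2
s_far = dot(rope_pair(q, 103, freq), rope_pair(k, 100, freq))  # positions 103, 100
print(abs(s_near - s_far) < 1e-9)  # True: both offsets are m - n = 3
```

This is what makes continued pretraining at longer lengths viable: the model learns offset-dependent attention patterns, not position-dependent ones.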
SwiGLU FFN
The Standard Choice
Llama 3 uses SwiGLU (SiLU-gated linear unit) for the FFN activation, following the near-universal consensus:
class Llama3SwiGLUFFN(nn.Module):
    """
    Llama 3 SwiGLU FFN.
    The base SwiGLU multiplier is 8/3 (chosen to match the parameter count
    of a standard 2-matrix 4x FFN), but the released Llama 3 configs apply
    an extra ~1.3x width factor, giving d_ff ≈ 3.5 * d_model
    (e.g. 14336 for the 8B, 28672 for the 70B).
    """
    def __init__(self, d_model=8192, multiplier=3.5):
        super().__init__()
        # Compute intermediate size
        d_ff = int(d_model * multiplier)
        d_ff = ((d_ff + 255) // 256) * 256  # Round up to a multiple of 256
self.gate_proj = nn.Linear(d_model, d_ff, bias=False) # W1
self.up_proj = nn.Linear(d_model, d_ff, bias=False) # W3
self.down_proj = nn.Linear(d_ff, d_model, bias=False) # W2
def forward(self, x):
return self.down_proj(
torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
)
The 8/3 Multiplier
With SwiGLU, the FFN has 3 weight matrices instead of 2 (gate, up, down vs just up, down for a standard FFN). To keep the total FFN parameter count the same, the intermediate dimension is reduced by a factor of 2/3: with a standard intermediate size of 4 × d_model, SwiGLU uses (2/3) × 4 × d_model = (8/3) × d_model. (The released Llama 3 configs then apply an extra ~1.3x width multiplier on top, which is why the actual d_ff values work out to ≈ 3.5 × d_model.)
Same parameter count, but SwiGLU consistently outperforms ReLU and GELU activations by 1-3% on perplexity benchmarks.
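The published d_ff values can be reproduced from this recipe. A sketch assuming Llama-codebase-style rounding (round up to a multiple of 1024) and a 1.3 width multiplier — the 405B appears to use a slightly different multiplier, so only the 8B and 70B are checked here:

```python
def llama_ffn_dim(d_model, multiplier=1.3, multiple_of=1024):
    """SwiGLU hidden size: 2/3 of 4*d_model, scaled, rounded up."""
    hidden = int(2 * (4 * d_model) / 3)
    hidden = int(hidden * multiplier)
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(llama_ffn_dim(4096))  # 14336 -- Llama 3 8B
print(llama_ffn_dim(8192))  # 28672 -- Llama 3 70B
```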
Training Data: 15 Trillion Tokens
The Data Decision
Llama 3 was trained on over 15 trillion tokens, a massive increase from Llama 2’s 2 trillion. This is among the most tokens ever used for training a single model.
def training_data_analysis():
"""
Training data scale analysis for Llama 3.
"""
configs = {
"Llama 2 70B": {
"params_B": 70,
"tokens_T": 2.0,
"tokens_per_param": 2e12 / 70e9, # ~28.6
"chinchilla_optimal": 70e9 * 20 / 1e12, # 1.4T (Chinchilla: 20 tokens/param)
"over_chinchilla": 2.0 / 1.4, # 1.4x
},
"Llama 3 70B": {
"params_B": 70,
"tokens_T": 15.0,
"tokens_per_param": 15e12 / 70e9, # ~214
"chinchilla_optimal": 1.4,
"over_chinchilla": 15.0 / 1.4, # 10.7x
},
"Llama 3.1 405B": {
"params_B": 405,
"tokens_T": 15.0,
"tokens_per_param": 15e12 / 405e9, # ~37
"chinchilla_optimal": 405e9 * 20 / 1e12, # 8.1T
"over_chinchilla": 15.0 / 8.1, # 1.85x
},
}
return configs
Training Data Scale Comparison
| Model | Parameters | Tokens | Tokens/Param | vs Chinchilla Optimal |
|---|---|---|---|---|
| Llama 2 70B | 70B | 2.0T | 28.6 | 1.4x over-trained |
| Llama 3 70B | 70B | 15.0T | 214 | 10.7x over-trained |
| Llama 3.1 405B | 405B | 15.0T | 37 | 1.85x over-trained |
| DeepSeek V3 | 671B (37B active) | 14.8T | 22 (total), 400 (active) | N/A (MoE) |
| Chinchilla 70B | 70B | 1.4T | 20 | Optimal (by Chinchilla law) |
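The over-training ratios in the table reduce to one line of arithmetic (using 20 tokens/param as the Chinchilla rule of thumb):

```python
def over_chinchilla(params, tokens, optimal_tokens_per_param=20):
    """How far past the Chinchilla-optimal token count a run went."""
    return tokens / (params * optimal_tokens_per_param)

print(round(over_chinchilla(70e9, 2.0e12), 1))    # 1.4  -- Llama 2 70B
print(round(over_chinchilla(70e9, 15.0e12), 1))   # 10.7 -- Llama 3 70B
print(round(over_chinchilla(405e9, 15.0e12), 2))  # 1.85 -- Llama 3.1 405B
```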
Why Over-Train?
Llama 3 70B is trained 10.7x beyond the Chinchilla-optimal token count. This is deliberate:
- Inference-optimal scaling: Chinchilla optimizes for training FLOPs. But if the model will be queried billions of times, it is cheaper to spend more FLOPs during training (once) to get a smaller, higher-quality model than to serve a larger model.
- The smaller model gets better: Training the 70B model on 15T tokens makes it match the quality of a much larger model trained on fewer tokens. A 70B model trained on 15T tokens approaches the quality of a 200B+ model trained on 2T tokens.
- Data is available: Meta has access to massive text corpora. The marginal cost of additional training data processing is small relative to the quality gains.
def inference_optimal_breakeven(
    extra_training_flops,
    params_large,
    params_small,
    tokens_per_query,
):
    """
    Optimal model size depends on how many times the model will be queried.
    Chinchilla minimizes training loss for a fixed training budget;
    inference-optimal scaling minimizes total cost (training + serving).
    Assuming the over-trained small model matches the large model's quality,
    each query saves ~2 * (N_large - N_small) FLOPs per generated token.
    Returns the query count at which over-training pays for itself.
    """
    flops_saved_per_query = 2 * (params_large - params_small) * tokens_per_query
    return extra_training_flops / flops_saved_per_query
Meta’s approach can be summarized as: “train a smaller model for much longer to get a model that is cheap to serve.” The 70B model, over-trained by 10x, approaches the quality of models 3-5x its size. Since inference cost scales linearly with model size, this is a net win after the model is queried enough times. For Meta’s products (billions of queries), the math works out strongly in favor of over-training.
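A toy breakeven estimate makes this concrete. Everything here is an illustrative assumption: the 6·N·D training-FLOPs rule of thumb, a 200B dense model as the quality-equivalent alternative (per the comparison above), and 500 generated tokens per query:

```python
# Extra training FLOPs for over-training 70B from 1.4T to 15T tokens
extra_tokens = 15.0e12 - 1.4e12
extra_training_flops = 6 * 70e9 * extra_tokens   # ~5.7e24 FLOPs (6*N*D)

# Inference FLOPs saved per query vs serving a 200B dense model
flops_saved_per_token = 2 * (200e9 - 70e9)       # 260 GFLOPs per token
tokens_per_query = 500
breakeven = extra_training_flops / (flops_saved_per_token * tokens_per_query)
print(f"{breakeven:.2e}")  # 4.39e+10 queries
```

Roughly 44 billion queries sounds like a lot, but at billions of queries per day the payback period is measured in weeks.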
Architecture Specifications
Complete Llama 3 Family
LLAMA3_CONFIGS = {
"8B": {
"d_model": 4096,
"num_layers": 32,
"num_q_heads": 32,
"num_kv_heads": 8,
"head_dim": 128,
"d_ff": 14336,
"vocab_size": 128256,
"total_params": "8.03B",
"rope_base": 500000,
"context": 8192, # Extended to 128K in 3.1
},
"70B": {
"d_model": 8192,
"num_layers": 80,
"num_q_heads": 64,
"num_kv_heads": 8,
"head_dim": 128,
"d_ff": 28672,
"vocab_size": 128256,
"total_params": "70.6B",
"rope_base": 500000,
"context": 8192, # Extended to 128K in 3.1
},
"405B": {
"d_model": 16384,
"num_layers": 126,
"num_q_heads": 128,
"num_kv_heads": 8,
"head_dim": 128,
"d_ff": 53248,
"vocab_size": 128256,
"total_params": "405.5B",
"rope_base": 500000,
"context": 8192, # Extended to 128K in 3.1
},
}
def compute_param_breakdown(config):
"""Compute detailed parameter counts for a Llama 3 config."""
d = config["d_model"]
L = config["num_layers"]
Hq = config["num_q_heads"]
Hkv = config["num_kv_heads"]
hd = config["head_dim"]
d_ff = config["d_ff"]
V = config["vocab_size"]
# Attention per layer
q_params = d * Hq * hd
k_params = d * Hkv * hd
v_params = d * Hkv * hd
o_params = Hq * hd * d
attn_per_layer = q_params + k_params + v_params + o_params
# FFN per layer (SwiGLU: 3 matrices)
ffn_per_layer = 3 * d * d_ff
# Norms per layer (2 RMSNorm)
norm_per_layer = 2 * d
# Total per layer
per_layer = attn_per_layer + ffn_per_layer + norm_per_layer
    # Embeddings and output head. Llama 3 does NOT tie the output head to
    # the input embedding; the untied head is needed to reproduce the
    # official totals (e.g. 8.03B for the 8B model).
    embed = V * d  # Input embedding
    lm_head = V * d  # Output projection
    final_norm = d  # Final RMSNorm
    total = per_layer * L + embed + lm_head + final_norm
return {
"attention_per_layer_M": attn_per_layer / 1e6,
"ffn_per_layer_M": ffn_per_layer / 1e6,
"per_layer_M": per_layer / 1e6,
"embedding_M": embed / 1e6,
"total_B": total / 1e9,
}
Llama 3 Family Parameter Breakdown
| Component | 8B | 70B | 405B |
|---|---|---|---|
| d_model | 4096 | 8192 | 16384 |
| Layers | 32 | 80 | 126 |
| Q heads | 32 | 64 | 128 |
| KV heads | 8 | 8 | 8 |
| Head dim | 128 | 128 | 128 |
| FFN dim | 14336 | 28672 | 53248 |
| Attention params/layer | 41.9M | 151.0M | 570.4M |
| FFN params/layer | 176.2M | 704.6M | 2617.2M |
| Total params | 8.03B | 70.6B | 405.5B |
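The 8B total can be reproduced by hand from the config above; the untied output head is what closes the gap to the official 8.03B figure:

```python
# Llama 3 8B parameter count from first principles
d, n_layers, Hq, Hkv, hd, d_ff, V = 4096, 32, 32, 8, 128, 14336, 128256
attn = d * Hq * hd + 2 * d * Hkv * hd + Hq * hd * d  # q, k, v, o projections
ffn = 3 * d * d_ff                                   # gate, up, down
layer = attn + ffn + 2 * d                           # plus two RMSNorm weights
total = layer * n_layers + 2 * V * d + d             # embed + untied head + final norm
print(round(total / 1e9, 2))  # 8.03
```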
RMSNorm (Pre-Norm)
Why RMSNorm Over LayerNorm
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization.
Llama uses this instead of standard LayerNorm.
"""
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
# RMSNorm: x / sqrt(mean(x^2) + eps) * weight
rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
return x * rms * self.weight
RMSNorm is roughly 10-15% faster than LayerNorm because it skips the mean subtraction (and has no bias term). The quality difference is negligible, and nearly every frontier open model now uses it.
Pre-Norm Architecture
Llama 3 applies normalization before each sublayer (attention and FFN), not after. The forward pass for one transformer block:
class Llama3TransformerBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.attention_norm = RMSNorm(config["d_model"])
self.attention = Llama3Attention(
d_model=config["d_model"],
num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"],
head_dim=config["head_dim"],
)
self.ffn_norm = RMSNorm(config["d_model"])
self.ffn = Llama3SwiGLUFFN(
d_model=config["d_model"],
)
def forward(self, x, kv_cache=None, position_ids=None):
# Pre-norm attention
residual = x
x = self.attention_norm(x)
x = self.attention(x, kv_cache, position_ids)
x = residual + x
# Pre-norm FFN
residual = x
x = self.ffn_norm(x)
x = self.ffn(x)
x = residual + x
return x
Context Extension: 8K to 128K
The Extension Strategy
Llama 3 was trained with an 8K context window. Llama 3.1 extended it to 128K through two techniques:
- RoPE base frequency increase: From 10,000 (Llama 2) to 500,000 (Llama 3). The higher base provides longer wavelengths, enabling position discrimination at longer ranges.
- Continued pretraining on long sequences: After the initial 8K training, additional training on sequences of gradually increasing length, up to 128K.
def context_extension_schedule():
"""
Llama 3.1 context extension training schedule.
"""
stages = [
{
"stage": 1,
"context_length": 8192,
"tokens": "15T (full pretraining)",
"rope_base": 500000,
"lr": 2.2e-4,
},
{
"stage": 2,
"context_length": 16384,
"tokens": "~100B",
"rope_base": 500000,
"lr": 1e-5,
},
{
"stage": 3,
"context_length": 65536,
"tokens": "~100B",
"rope_base": 500000,
"lr": 5e-6,
},
{
"stage": 4,
"context_length": 131072,
"tokens": "~100B",
"rope_base": 500000,
"lr": 2e-6,
},
]
return stages
NIAH (Needle in a Haystack) Accuracy by Context Length
[Figure: NIAH retrieval accuracy (%) across context lengths]
Training Infrastructure
The 16,384-GPU H100 Cluster
Llama 3 405B was trained on 16,384 H100 GPUs. The parallelism configuration:
def llama3_training_parallelism():
"""
Training parallelism for Llama 3.1 405B.
"""
config = {
"total_gpus": 16384,
"tensor_parallelism": 8, # Within a node
"pipeline_parallelism": 16, # Across nodes
"data_parallelism": 128, # 16384 / (8 * 16) = 128
"sequence_parallelism": True, # Overlaps with TP
"fsdp": True, # Fully Sharded Data Parallelism
"precision": "BF16",
"gradient_accumulation_steps": 1,
"micro_batch_size": 1,
"global_batch_size": 128, # 128 DP ranks * 1 micro-batch
}
return config
Llama 3.1 405B Training Configuration
| Parameter | Value | Notes |
|---|---|---|
| Hardware | 16,384 H100 80GB | Meta's GPU cluster |
| TP | 8 | Within 8-GPU node |
| PP | 16 | 16 pipeline stages across nodes |
| DP | 128 | 128 data-parallel ranks |
| Precision | BF16 | No FP8 (unlike DeepSeek V3) |
| Tokens | 15T+ | Across multiple phases |
| Training time | ~54 days | Estimated from MFU |
| MFU | ~38-40% | Typical for large-scale training |
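The layout factorizes as TP × PP × DP = 8 × 16 × 128 = 16,384. A sketch of the basic arithmetic (BF16 weight shards only — gradients, optimizer state, and activations dominate actual per-GPU memory):

```python
total_gpus, tp, pp = 16384, 8, 16
dp = total_gpus // (tp * pp)
print(dp)  # 128 data-parallel replicas

params = 405e9
weight_gb_total = params * 2 / 1e9      # 810 GB of BF16 weights
weight_gb_per_gpu = weight_gb_total / (tp * pp)
print(round(weight_gb_per_gpu, 1))  # 6.3 GB of weight shard per GPU
```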
Post-Training: SFT and RLHF
The Alignment Pipeline
The Llama 3 report includes a detailed description of the post-training pipeline:
- Supervised Fine-Tuning (SFT): Train on high-quality instruction-response pairs curated by human annotators.
- Reward Model Training: Train a reward model on human preference data (response A vs response B).
- Direct Preference Optimization (DPO): Optimize the policy directly against preferences without a separate reward model (Llama 3 uses DPO as the primary RLHF method).
- Safety Training: Additional rounds with safety-specific data.
def llama3_post_training_pipeline():
"""
Post-training pipeline for Llama 3.
"""
stages = {
"SFT": {
"data_size": "~10M examples",
"epochs": 2,
"lr": 1e-5,
"method": "Standard supervised fine-tuning",
},
"DPO": {
"data_size": "~1M preference pairs",
"beta": 0.1,
"epochs": 1,
"lr": 5e-7,
"method": "Direct Preference Optimization",
},
"Safety": {
"data_size": "~100K safety examples",
"method": "Additional SFT + DPO on safety data",
},
}
return stages
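The DPO objective in stage 2 fits in a few lines (a minimal sketch on scalar sequence log-probabilities; a real implementation sums per-token log-probs from the policy and a frozen reference model):

```python
import math

def dpo_loss(pi_chosen, pi_rejected, ref_chosen, ref_rejected, beta=0.1):
    """DPO: -log sigmoid(beta * (chosen margin - rejected margin))."""
    margin = (pi_chosen - ref_chosen) - (pi_rejected - ref_rejected)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))

# With no preference signal the loss sits at log(2); it falls as the
# policy shifts probability toward the chosen response:
print(round(dpo_loss(0, 0, 0, 0), 3))        # 0.693
print(dpo_loss(-10, -14, -12, -13) < 0.693)  # True (margin = +3)
```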
Summary
Llama 3’s architecture is a lesson in practical engineering at scale. Every choice optimizes for Meta’s specific constraints:
- Dense: Simpler serving at billions-of-queries scale. Meta has the training compute budget.
- GQA-8: 94% KV cache savings with 0.5% quality cost. The optimal balance point.
- 128K vocabulary: 16-50% fewer tokens for multilingual text. Directly reduces inference cost.
- RoPE (base 500K): Enables context extension from 8K to 128K without retraining from scratch.
- SwiGLU: Universal consensus; no reason to deviate.
- 15T tokens: Over-train the smaller model for inference-optimal scaling. Makes the 70B model approach larger models in quality.
- BF16 training: Simpler than FP8, acceptable given Meta’s compute budget.
The Llama 3 recipe is not universally optimal — DeepSeek V3 shows that MoE + FP8 can reach comparable quality at roughly an order of magnitude lower training cost. But for a lab with abundant training compute and a priority on serving simplicity, Llama 3 makes the right set of tradeoffs.