Part 4 of 27 in the series Frontier Model Architectures

Meta trained Llama 3 405B on 15.6 trillion tokens — nearly 8x more than Llama 2's 2 trillion. The reason: Chinchilla scaling laws dictate that a ~400B model needs 8T+ tokens for optimal quality, and Meta nearly doubled that target to ensure frontier performance. Every other architectural decision — dense not MoE, GQA-8 not MHA, 128K vocabulary not 32K — flows from a single constraint: Meta serves billions of inference requests per day across Instagram, WhatsApp, and Facebook, and operational simplicity matters more than training efficiency.

Why Dense (Not MoE)

Meta’s Position

Meta chose a dense architecture for Llama 3 despite the clear training efficiency advantages of MoE. The reasoning:

  1. Serving simplicity: Meta serves Llama models at massive scale across its products. MoE requires expert parallelism, all-to-all communication, and careful load balancing. Dense models are straightforward to shard with tensor parallelism.

  2. Training budget is not the constraint: Meta has one of the largest GPU fleets in the world (600K+ H100s). The binding constraint is model quality per token generated at inference, not training cost.

  3. Open-source ecosystem: MoE models are harder for the community to run. A dense 70B model fits on 2 consumer GPUs at INT4. A 46.7B-total-parameter MoE model (Mixtral 8x7B) must hold every expert in memory while activating only ~13B parameters per token — less compute per byte loaded.

def dense_vs_moe_serving_analysis():
    """
    Why dense is better for Meta's specific serving requirements.
    """
    # Meta's inference scenario: billions of queries/day
    # Priority: latency + cost per query + operational simplicity

    dense_405b = {
        "total_params": 405e9,
        "active_params": 405e9,
        "flops_per_token": 2 * 405e9,
        "memory_gb_fp16": 810,
        "gpus_needed": 12,  # A100 80GB
        "serving_complexity": "simple",  # TP only
        "latency_overhead": "none",
    }

    # Hypothetical MoE of equivalent quality
    moe_equivalent = {
        "total_params": 1500e9,  # ~4x total for equivalent quality
        "active_params": 100e9,  # ~25% active
        "flops_per_token": 2 * 100e9,
        "memory_gb_fp16": 3000,
        "gpus_needed": 45,  # More GPUs for memory
        "serving_complexity": "complex",  # EP + TP + load balancing
        "latency_overhead": "all-to-all communication",
    }

    # At Meta's scale (billions of queries):
    # Dense: fewer GPUs per instance, simpler scaling
    # MoE: more GPUs per instance, but faster per query (fewer FLOPs)

    # Crossover: MoE wins when GPU utilization is high
    # Meta's varied workload means many queries are small batch
    # Dense wins for small-batch latency

    return dense_405b, moe_equivalent
ℹ️ The Dense Bet

Meta’s bet is that training compute is cheaper than serving complexity. They can afford to spend $100M+ training a dense model because the result is operationally simpler to deploy at scale. DeepSeek, with less compute budget, made the opposite bet: spend engineering effort on MoE to reduce training cost. Both are rational given their constraints.
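One rough way to frame the bet is serving cost per generated token. A minimal sketch with assumed numbers — the $2/GPU-hour rate and both throughputs are illustrative, not Meta's figures; instance sizes echo the comparison above:

```python
def cost_per_million_tokens(gpus, tokens_per_sec, dollars_per_gpu_hour=2.0):
    """Serving cost of one model instance, in $ per 1M generated tokens."""
    instance_dollars_per_hour = gpus * dollars_per_gpu_hour
    tokens_per_hour = tokens_per_sec * 3600
    return instance_dollars_per_hour / tokens_per_hour * 1e6

# Illustrative: a 12-GPU dense instance vs a 45-GPU MoE instance that,
# when kept saturated, generates more tokens/sec (fewer FLOPs per token)
dense_cost = cost_per_million_tokens(gpus=12, tokens_per_sec=30)
moe_cost = cost_per_million_tokens(gpus=45, tokens_per_sec=120)
```

With these made-up throughputs the MoE instance only edges ahead while it stays saturated — the crossover noted in the code above. At small batch sizes, the dense instance's lower GPU count and lack of all-to-all communication win.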

Grouped Query Attention (GQA-8)

The GQA Decision

Llama 3 uses GQA with 8 KV heads at every scale: the 8B pairs them with 32 query heads (group size 4), the 70B with 64 (group size 8), and the 405B with 128 (group size 16). The choice of 8 KV heads balances quality against KV cache size.

def gqa_analysis(
    d_model,
    num_q_heads,
    head_dim,
    num_layers,
    seq_len,
    batch_size,
    kv_head_options,
):
    """
    Analyze the tradeoff between KV head count, cache size, and quality.
    """
    results = []
    for num_kv_heads in kv_head_options:
        kv_dim = num_kv_heads * head_dim

        # KV cache bytes per token: K + V, FP16 (2 bytes), across all layers
        kv_per_token = 2 * kv_dim * 2 * num_layers

        # Ratio of Q heads to KV heads (the "group" size)
        group_size = num_q_heads // num_kv_heads

        # Quality impact: more KV heads = better, with diminishing returns.
        # Illustrative figures, consistent with published GQA ablations:
        # MHA (num_kv = num_q) = 100% quality
        # GQA-8 = 99.5% quality
        # GQA-1 (MQA) = 97% quality
        quality_lookup = {
            1: 97.0,
            4: 99.0,
            8: 99.5,
            16: 99.8,
            32: 99.9,
            64: 100.0,
            128: 100.0,
        }
        quality = quality_lookup.get(num_kv_heads, 99.0)

        # Memory savings vs MHA
        mha_kv_per_token = 2 * num_q_heads * head_dim * 2 * num_layers
        savings = 1 - (kv_per_token / mha_kv_per_token)

        results.append({
            "num_kv_heads": num_kv_heads,
            "group_size": group_size,
            "kv_bytes_per_token": kv_per_token,
            "kv_gb_128k_ctx": kv_per_token * seq_len * batch_size / 1e9,
            "quality_pct": quality,
            "memory_savings_pct": savings * 100,
        })

    return results
📊 GQA Head Count Analysis (Llama 3.1 405B, 128K Context)

KV Heads      Group Size   KV Cache (BS=1)   Quality   Savings vs MHA
128 (MHA)     1            1082 GB           100%      0%
32            4            271 GB            99.9%     75%
16            8            135 GB            99.8%     87.5%
8 (Llama 3)   16           68 GB             99.5%     93.75%
4             32           34 GB             99.0%     96.9%
1 (MQA)       128          8.5 GB            97.0%     99.2%

(FP16, 126 layers, head_dim 128, 131,072 tokens.)

GQA-8 is the sweet spot: 93.75% memory savings with only ~0.5% quality loss versus full MHA. Going to GQA-4 halves the cache again but costs an additional 0.5% quality — not worth it.

Implementation

import torch
import torch.nn as nn
import math

class Llama3Attention(nn.Module):
    """
    Llama 3 GQA-8 attention implementation.
    """
    def __init__(
        self,
        d_model=16384,       # 405B
        num_q_heads=128,
        num_kv_heads=8,
        head_dim=128,         # = d_model / num_q_heads
        max_seq_len=131072,
    ):
        super().__init__()
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.group_size = num_q_heads // num_kv_heads  # 16

        # Projections
        self.q_proj = nn.Linear(d_model, num_q_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, num_kv_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(num_q_heads * head_dim, d_model, bias=False)

    def forward(self, x, kv_cache=None, position_ids=None):
        B, T, _ = x.shape

        q = self.q_proj(x).view(B, T, self.num_q_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim)

        # Apply RoPE to q, k (omitted for clarity)

        # Expand KV heads to match Q heads
        # Each KV head is shared by group_size Q heads
        k = k.unsqueeze(3).expand(-1, -1, -1, self.group_size, -1)
        k = k.reshape(B, T, self.num_q_heads, self.head_dim)
        v = v.unsqueeze(3).expand(-1, -1, -1, self.group_size, -1)
        v = v.reshape(B, T, self.num_q_heads, self.head_dim)

        # Standard attention
        q = q.transpose(1, 2)  # [B, H, T, D]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = 1.0 / math.sqrt(self.head_dim)
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Causal mask
        causal_mask = torch.triu(
            torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1
        )
        attn = attn.masked_fill(causal_mask, float('-inf'))
        attn = torch.softmax(attn, dim=-1)

        output = torch.matmul(attn, v)
        output = output.transpose(1, 2).reshape(B, T, -1)
        return self.o_proj(output)

128K Vocabulary

Why 128K?

Llama 3 increased the vocabulary from 32K (Llama 2) to 128K (128,256 tokens). This was one of the most impactful changes.

def vocabulary_size_analysis():
    """
    Analyze the impact of vocabulary size on compression and quality.
    """
    vocab_configs = {
        "Llama 2 (32K)": {
            "vocab_size": 32000,
            "bytes_per_token_english": 3.7,
            "bytes_per_token_chinese": 1.8,
            "bytes_per_token_code": 3.2,
            "embedding_params": 32000 * 4096,  # Llama 2 d_model
        },
        "Llama 3 (128K)": {
            "vocab_size": 128256,
            "bytes_per_token_english": 4.4,
            "bytes_per_token_chinese": 3.5,
            "bytes_per_token_code": 4.0,
            "embedding_params": 128256 * 8192,  # Llama 3 405B d_model
        },
    }

    # Higher bytes per token = better compression
    # Fewer tokens needed for the same text = faster inference
    # But: larger embedding table = more parameters

    for cfg in vocab_configs.values():
        # Fraction of tokens saved vs Llama 2 for the same English text:
        # 1 - 3.7 / 4.4 ≈ 0.16 → ~16% fewer tokens, hence ~16% faster inference
        cfg["english_token_reduction"] = (
            1 - 3.7 / cfg["bytes_per_token_english"]
        )

    return vocab_configs
📊 Vocabulary Size Impact

Metric                             Llama 2 (32K)   Llama 3 (128K)   Improvement
English bytes/token                3.7             4.4              19% better compression
Chinese bytes/token                1.8             3.5              94% better compression
Code bytes/token                   3.2             4.0              25% better compression
Tokens per KB of English text      ~270            ~227             16% fewer tokens
Embedding parameters (70B)         131M            1.05B            8x more
Inference speedup (fewer tokens)   -               ~16%             Significant at scale

The Multilingual Argument

The primary motivation for 128K vocabulary was multilingual coverage. With 32K BPE tokens, Chinese text requires 2+ tokens per character (average 1.8 bytes/token vs 3+ for efficient encoding). With 128K tokens, the tokenizer can allocate more tokens to CJK characters, Cyrillic, Arabic, and other scripts.

def tokenization_efficiency(text, tokenizer):
    """
    Measure tokenization efficiency: bytes per token.
    Higher is better (more information per token).
    """
    encoded = tokenizer.encode(text)
    num_tokens = len(encoded)
    num_bytes = len(text.encode('utf-8'))
    return {
        "num_tokens": num_tokens,
        "num_bytes": num_bytes,
        "bytes_per_token": num_bytes / num_tokens,
    }

The Cost

The larger vocabulary increases the embedding and output projection sizes. For Llama 3 70B with d_model = 8192:

  • Embedding: 128,256 × 8,192 ≈ 1.05B parameters
  • Output head: 8,192 × 128,256 ≈ 1.05B parameters (not shared — Llama 3 does not tie input and output embeddings)

This is ~2.1B parameters that contribute to token lookup and prediction but do not increase the model’s representational capacity in the transformer layers. The tradeoff: better tokenization efficiency at inference justifies the parameter cost.

The Inference Speed Argument

A 128K vocabulary means 16% fewer tokens for English text and nearly 50% fewer for CJK languages. Since inference cost is per-token (KV cache grows per token, generation is sequential per token), this translates directly to 16-50% faster inference for the same text output. At Meta’s scale, this is worth far more than the ~2B parameters of vocabulary overhead.
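The arithmetic behind those percentages is a one-liner, using the bytes-per-token figures from the table above:

```python
def token_reduction(bytes_per_token_old, bytes_per_token_new):
    """Fraction of tokens saved for the same text under a better tokenizer."""
    return 1 - bytes_per_token_old / bytes_per_token_new

english = token_reduction(3.7, 4.4)   # ~0.16 → ~16% fewer tokens
chinese = token_reduction(1.8, 3.5)   # ~0.49 → ~49% fewer tokens
```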

RoPE (Rotary Position Embeddings)

Why RoPE Won

Llama 3 uses RoPE with base frequency 500,000 for position encoding. RoPE won over alternatives for three reasons:

  1. Context extension: RoPE can be extended beyond training length via frequency scaling (YaRN, NTK-aware interpolation). Learned position embeddings cannot.
  2. No learned parameters: RoPE is computed analytically. No additional parameters to train.
  3. Relative position: RoPE encodes relative distances, which generalize better than absolute positions.
def compute_rope_embeddings(
    seq_len,
    head_dim,
    base=500000.0,
    device="cuda",
):
    """
    Compute RoPE sin/cos embeddings.
    Llama 3 uses base frequency 500,000 (vs 10,000 for Llama 2).
    """
    # Compute frequencies
    dim = head_dim
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    # freqs shape: [dim/2]

    # Compute positions
    positions = torch.arange(seq_len, device=device).float()
    # positions shape: [seq_len]

    # Outer product: position * frequency
    angles = torch.outer(positions, freqs)  # [seq_len, dim/2]

    # Sin/cos embeddings
    cos_embed = angles.cos()
    sin_embed = angles.sin()

    return cos_embed, sin_embed

def apply_rope(q, k, cos, sin):
    """
    Apply RoPE to query and key tensors.
    q, k: [B, H, T, D]
    cos, sin: [T, D/2]
    """
    # Split into pairs
    q1, q2 = q[..., ::2], q[..., 1::2]
    k1, k2 = k[..., ::2], k[..., 1::2]

    # Rotate
    cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, T, D/2]
    sin = sin.unsqueeze(0).unsqueeze(0)

    q_rotated = torch.cat([
        q1 * cos - q2 * sin,
        q1 * sin + q2 * cos,
    ], dim=-1)

    k_rotated = torch.cat([
        k1 * cos - k2 * sin,
        k1 * sin + k2 * cos,
    ], dim=-1)

    return q_rotated, k_rotated
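The property that makes RoPE work — attention scores depending only on the relative offset m − n — can be sanity-checked with a single frequency pair. A plain-Python sketch, independent of the tensor code above:

```python
import math

def rotate(x, y, theta):
    """Rotate one 2-D (x, y) pair by angle theta — one RoPE frequency pair."""
    return (x * math.cos(theta) - y * math.sin(theta),
            x * math.sin(theta) + y * math.cos(theta))

def score(q, k, m, n, freq=0.01):
    """Dot product of a query at position m and a key at position n."""
    qx, qy = rotate(q[0], q[1], m * freq)
    kx, ky = rotate(k[0], k[1], n * freq)
    return qx * kx + qy * ky

q, k = (1.0, 2.0), (3.0, -1.0)
# Same relative offset (m - n = 7) at different absolute positions:
a = score(q, k, 10, 3)
b = score(q, k, 110, 103)
# a == b (up to float error): the score depends only on m - n
```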

Base Frequency 500,000

Llama 2 used a RoPE base frequency of 10,000. Llama 3 uses 500,000. The higher base frequency:

  • Stretches the frequency range, allowing the model to distinguish positions over longer ranges
  • Was essential for extending context from 4K to 128K tokens
  • Was determined empirically through context extension experiments
def rope_base_frequency_impact(bases, seq_len, dim):
    """
    Show how base frequency affects the wavelengths in RoPE.
    """
    for base in bases:
        freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        wavelengths = 2 * torch.pi / freqs

        min_wavelength = wavelengths.min().item()
        max_wavelength = wavelengths.max().item()

        # Can the model distinguish positions at seq_len?
        # Need at least one frequency with wavelength > seq_len
        can_distinguish = max_wavelength > seq_len

        print(f"  Base {base:>10.0f}: wavelengths [{min_wavelength:.0f}, {max_wavelength:.0f}]")
        print(f"  Can distinguish at {seq_len}: {can_distinguish}")

# Base 10,000 (dim=128): wavelengths ≈ [6.3, 54K] — positions distinguishable to ~54K
# Base 500,000 (dim=128): wavelengths ≈ [6.3, 2.5M] — covers 128K with room to spare

SwiGLU FFN

The Standard Choice

Llama 3 uses SwiGLU (SiLU-gated linear unit) for the FFN activation, following the near-universal consensus:

class Llama3SwiGLUFFN(nn.Module):
    """
    Llama 3 SwiGLU FFN.
    The intermediate size starts at d_model * 8/3, is scaled by a
    model-specific ffn_dim_multiplier, and rounded up — yielding
    14336 / 28672 / 53248 for the 8B / 70B / 405B.
    """
    def __init__(self, d_model=4096, multiplier=8/3,
                 ffn_dim_multiplier=1.3, multiple_of=1024):
        super().__init__()
        # Compute intermediate size (defaults reproduce the 8B's 14336)
        d_ff = int(d_model * multiplier * ffn_dim_multiplier)
        d_ff = ((d_ff + multiple_of - 1) // multiple_of) * multiple_of  # round up

        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)   # W1
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)     # W3
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)   # W2

    def forward(self, x):
        return self.down_proj(
            torch.nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
        )

The 8/3 Multiplier

With SwiGLU, the FFN has 3 weight matrices instead of 2 (gate, up, down vs just up, down for a standard FFN). To keep the total FFN parameter count the same, the intermediate dimension is reduced by a factor of 2/3. With a standard 4 × d_model intermediate size, SwiGLU uses (8/3) × d_model:

Standard FFN params: 2 × d × 4d = 8d²
SwiGLU FFN params:   3 × d × (8d/3) = 8d²

Same parameter count, but SwiGLU consistently outperforms ReLU and GELU activations by 1-3% on perplexity benchmarks.
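The equality is easy to verify numerically (any d works; 4096 here):

```python
d = 4096
standard_ffn = 2 * d * (4 * d)     # up + down projections at 4x width
swiglu_ffn = 3 * d * (8 * d / 3)   # gate + up + down at 8/3 width
# Both come out to 8 * d^2 parameters
```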

Training Data: 15 Trillion Tokens

The Data Decision

Llama 3 was trained on over 15 trillion tokens, a massive increase from Llama 2’s 2 trillion. This is among the most tokens ever used for training a single model.

def training_data_analysis():
    """
    Training data scale analysis for Llama 3.
    """
    configs = {
        "Llama 2 70B": {
            "params_B": 70,
            "tokens_T": 2.0,
            "tokens_per_param": 2e12 / 70e9,  # ~28.6
            "chinchilla_optimal": 70e9 * 20 / 1e12,  # 1.4T (Chinchilla: 20 tokens/param)
            "over_chinchilla": 2.0 / 1.4,  # 1.4x
        },
        "Llama 3 70B": {
            "params_B": 70,
            "tokens_T": 15.0,
            "tokens_per_param": 15e12 / 70e9,  # ~214
            "chinchilla_optimal": 1.4,
            "over_chinchilla": 15.0 / 1.4,  # 10.7x
        },
        "Llama 3.1 405B": {
            "params_B": 405,
            "tokens_T": 15.0,
            "tokens_per_param": 15e12 / 405e9,  # ~37
            "chinchilla_optimal": 405e9 * 20 / 1e12,  # 8.1T
            "over_chinchilla": 15.0 / 8.1,  # 1.85x
        },
    }

    return configs
📊 Training Data Scale Comparison

Model            Parameters          Tokens   Tokens/Param               vs Chinchilla Optimal
Llama 2 70B      70B                 2.0T     28.6                       1.4x over-trained
Llama 3 70B      70B                 15.0T    214                        10.7x over-trained
Llama 3.1 405B   405B                15.0T    37                         1.85x over-trained
DeepSeek V3      671B (37B active)   14.8T    22 (total), 400 (active)   N/A (MoE)
Chinchilla 70B   70B                 1.4T     20                         Optimal (by Chinchilla law)

Why Over-Train?

Llama 3 70B is trained 10.7x beyond the Chinchilla-optimal token count. This is deliberate:

  1. Inference-optimal scaling: Chinchilla optimizes for training FLOPs. But if the model will be queried billions of times, it is cheaper to spend more FLOPs during training (once) to get a smaller, higher-quality model than to serve a larger model.

  2. The smaller model gets better: Training the 70B model on 15T tokens makes it match the quality of a much larger model trained on fewer tokens. A 70B model trained on 15T tokens approaches the quality of a 200B+ model trained on 2T tokens.

  3. Data is available: Meta has access to massive text corpora. The marginal cost of additional training data processing is small relative to the quality gains.

def inference_optimal_scaling(
    extra_training_flops,
    params_large,
    params_small,
    tokens_per_query,
):
    """
    Chinchilla minimizes training loss for a fixed training budget.
    Inference-optimal scaling minimizes total cost:
        training FLOPs + queries * inference FLOPs per query.
    Returns the breakeven query count at which over-training the
    smaller model pays for itself.
    """
    # Inference FLOPs per generated token ~= 2 * N (params), so serving
    # the smaller model saves 2 * (N_large - N_small) FLOPs per token
    flops_saved_per_query = 2 * (params_large - params_small) * tokens_per_query

    # With more queries, smaller models win: the one-time cost of the
    # extra training FLOPs is amortized over every query served
    return extra_training_flops / flops_saved_per_query
💡 The Llama 3 Training Insight

Meta’s approach can be summarized as: “train a smaller model for much longer to get a model that is cheap to serve.” The 70B model, over-trained by 10x, approaches the quality of models 3-5x its size. Since inference cost scales linearly with model size, this is a net win after the model is queried enough times. For Meta’s products (billions of queries), the math works out strongly in favor of over-training.
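A back-of-envelope version of that math, using the 6ND approximation for training FLOPs and 2N FLOPs per generated token for inference. All numbers are illustrative — the "200B-class" comparison model and 1K tokens per query are assumptions, not Meta figures:

```python
# Extra training cost: pushing 70B from the Chinchilla-optimal 1.4T
# tokens to 15T is ~13.6T extra tokens at ~6 * N * D FLOPs
extra_training_flops = 6 * 70e9 * 13.6e12        # ~5.7e24 FLOPs

# Inference savings: serving 70B instead of a hypothetical 200B-class
# model of similar quality saves 2 * (200B - 70B) FLOPs per token
tokens_per_query = 1000
flops_saved_per_query = 2 * (200e9 - 70e9) * tokens_per_query

breakeven_queries = extra_training_flops / flops_saved_per_query
# ~2.2e10 — about 22B queries, i.e. days to weeks at Meta's scale
```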

Architecture Specifications

Complete Llama 3 Family

LLAMA3_CONFIGS = {
    "8B": {
        "d_model": 4096,
        "num_layers": 32,
        "num_q_heads": 32,
        "num_kv_heads": 8,
        "head_dim": 128,
        "d_ff": 14336,
        "vocab_size": 128256,
        "total_params": "8.03B",
        "rope_base": 500000,
        "context": 8192,  # Extended to 128K in 3.1
    },
    "70B": {
        "d_model": 8192,
        "num_layers": 80,
        "num_q_heads": 64,
        "num_kv_heads": 8,
        "head_dim": 128,
        "d_ff": 28672,
        "vocab_size": 128256,
        "total_params": "70.6B",
        "rope_base": 500000,
        "context": 8192,  # Extended to 128K in 3.1
    },
    "405B": {
        "d_model": 16384,
        "num_layers": 126,
        "num_q_heads": 128,
        "num_kv_heads": 8,
        "head_dim": 128,
        "d_ff": 53248,
        "vocab_size": 128256,
        "total_params": "405.5B",
        "rope_base": 500000,
        "context": 8192,  # Extended to 128K in 3.1
    },
}

def compute_param_breakdown(config):
    """Compute detailed parameter counts for a Llama 3 config."""
    d = config["d_model"]
    L = config["num_layers"]
    Hq = config["num_q_heads"]
    Hkv = config["num_kv_heads"]
    hd = config["head_dim"]
    d_ff = config["d_ff"]
    V = config["vocab_size"]

    # Attention per layer
    q_params = d * Hq * hd
    k_params = d * Hkv * hd
    v_params = d * Hkv * hd
    o_params = Hq * hd * d
    attn_per_layer = q_params + k_params + v_params + o_params

    # FFN per layer (SwiGLU: 3 matrices)
    ffn_per_layer = 3 * d * d_ff

    # Norms per layer (2 RMSNorm)
    norm_per_layer = 2 * d

    # Total per layer
    per_layer = attn_per_layer + ffn_per_layer + norm_per_layer

    # Embeddings: input table and output head (Llama 3 does not tie them)
    embed = V * d   # Input embedding
    head = V * d    # Output head

    total = per_layer * L + embed + head

    return {
        "attention_per_layer_M": attn_per_layer / 1e6,
        "ffn_per_layer_M": ffn_per_layer / 1e6,
        "per_layer_M": per_layer / 1e6,
        "embedding_M": embed / 1e6,
        "total_B": total / 1e9,
    }
📊 Llama 3 Family Parameter Breakdown

Component                8B       70B      405B
d_model                  4096     8192     16384
Layers                   32       80       126
Q heads                  32       64       128
KV heads                 8        8        8
Head dim                 128      128      128
FFN dim                  14336    28672    53248
Attention params/layer   41.9M    151.0M   570.4M
FFN params/layer         176.2M   704.6M   2617.2M
Total params             8.03B    70.6B    405.5B

RMSNorm (Pre-Norm)

Why RMSNorm Over LayerNorm

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.
    Llama uses this instead of standard LayerNorm.
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMSNorm: x / sqrt(mean(x^2) + eps) * weight
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight

RMSNorm is roughly 10-15% faster than LayerNorm because it skips the mean-subtraction step (and the bias term). The quality difference is negligible, and nearly every current open-weight frontier model — Llama, Qwen, Mistral, DeepSeek — uses RMSNorm.
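The relationship to LayerNorm is easy to see empirically: for zero-mean rows, mean(x²) equals the (biased) variance, so RMSNorm coincides with affine-free LayerNorm. A quick check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 64)
x = x - x.mean(-1, keepdim=True)   # force zero mean per row

# RMSNorm without the learned weight
rms_out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)

# LayerNorm without affine parameters
ln_out = F.layer_norm(x, (64,), eps=1e-5)

# Identical on zero-mean input; LayerNorm's only extra work is the
# mean subtraction that RMSNorm skips
```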

Pre-Norm Architecture

Llama 3 applies normalization before each sublayer (attention and FFN), not after. The forward pass for one transformer block:

class Llama3TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention_norm = RMSNorm(config["d_model"])
        self.attention = Llama3Attention(
            d_model=config["d_model"],
            num_q_heads=config["num_q_heads"],
            num_kv_heads=config["num_kv_heads"],
            head_dim=config["head_dim"],
        )
        self.ffn_norm = RMSNorm(config["d_model"])
        self.ffn = Llama3SwiGLUFFN(
            d_model=config["d_model"],
        )

    def forward(self, x, kv_cache=None, position_ids=None):
        # Pre-norm attention
        residual = x
        x = self.attention_norm(x)
        x = self.attention(x, kv_cache, position_ids)
        x = residual + x

        # Pre-norm FFN
        residual = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = residual + x

        return x

Context Extension: 8K to 128K

The Extension Strategy

Llama 3 was trained with an 8K context window. Llama 3.1 extended it to 128K through two techniques:

  1. RoPE base frequency increase: From 10,000 (Llama 2) to 500,000 (Llama 3). The higher base frequency provides longer wavelengths, enabling position discrimination at longer ranges.

  2. Continued pretraining on long sequences: After the initial 8K training, additional training on sequences up to 128K with gradually increasing length.

def context_extension_schedule():
    """
    Llama 3.1 context extension training schedule.
    """
    stages = [
        {
            "stage": 1,
            "context_length": 8192,
            "tokens": "15T (full pretraining)",
            "rope_base": 500000,
            "lr": 2.2e-4,
        },
        {
            "stage": 2,
            "context_length": 16384,
            "tokens": "~100B",
            "rope_base": 500000,
            "lr": 1e-5,
        },
        {
            "stage": 3,
            "context_length": 65536,
            "tokens": "~100B",
            "rope_base": 500000,
            "lr": 5e-6,
        },
        {
            "stage": 4,
            "context_length": 131072,
            "tokens": "~100B",
            "rope_base": 500000,
            "lr": 2e-6,
        },
    ]
    return stages

NIAH (Needle in a Haystack) Retrieval Accuracy by Context Length

Context Length       Accuracy
8K (base training)   99.5%
32K (extended)       98.2%
64K (extended)       96.8%
128K (extended)      95.1%

Training Infrastructure

The 16,000 H100 Cluster

Llama 3 405B was trained on 16,384 H100 GPUs. The parallelism configuration:

def llama3_training_parallelism():
    """
    Training parallelism for Llama 3.1 405B.
    """
    config = {
        "total_gpus": 16384,
        "tensor_parallelism": 8,     # Within a node
        "pipeline_parallelism": 16,  # Across nodes
        "data_parallelism": 128,     # 16384 / (8 * 16) = 128
        "sequence_parallelism": True,  # Overlaps with TP
        "fsdp": True,                # Fully Sharded Data Parallelism
        "precision": "BF16",
        "gradient_accumulation_steps": 1,
        "micro_batch_size": 1,
        "global_batch_size": 128,    # 128 DP ranks * 1 micro-batch
    }
    return config
📊 Llama 3.1 405B Training Configuration

Parameter       Value              Notes
Hardware        16,384 H100 80GB   Meta's GPU cluster
TP              8                  Within an 8-GPU node
PP              16                 16 pipeline stages across nodes
DP              128                128 data-parallel ranks
Precision       BF16               No FP8 (unlike DeepSeek V3)
Tokens          15T+               Across multiple phases
Training time   ~54 days           Estimated from MFU
MFU             ~38-40%            Typical for large-scale training
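The training-time estimate can be cross-checked with the standard 6ND FLOPs approximation (H100 BF16 dense peak is ~989 TFLOPS; MFU and token count are the table's). This is a rough estimate — expect the right order of tens of days rather than an exact match, since real runs mix phases, context lengths, and restarts:

```python
n_params = 405e9
n_tokens = 15.6e12
train_flops = 6 * n_params * n_tokens     # ~3.8e25 FLOPs

gpus = 16384
peak_flops = 989e12                        # H100 BF16 dense peak, FLOP/s
mfu = 0.40
sustained = gpus * peak_flops * mfu        # ~6.5e18 FLOP/s cluster-wide

days = train_flops / sustained / 86400     # ~68 days at 40% MFU
```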

Post-Training: SFT and RLHF

The Alignment Pipeline

The Llama 3 report describes the post-training pipeline in detail:

  1. Supervised Fine-Tuning (SFT): Train on high-quality instruction-response pairs curated by human annotators.
  2. Reward Model Training: Train a reward model on human preference data (response A vs response B); Llama 3 uses it for rejection sampling to curate SFT data rather than for PPO.
  3. Direct Preference Optimization (DPO): Optimize the policy directly against preference pairs, with no reward model in the loss (Llama 3 uses DPO rather than PPO as its primary preference-tuning method).
  4. Safety Training: Additional rounds with safety-specific data.
def llama3_post_training_pipeline():
    """
    Post-training pipeline for Llama 3.
    """
    stages = {
        "SFT": {
            "data_size": "~10M examples",
            "epochs": 2,
            "lr": 1e-5,
            "method": "Standard supervised fine-tuning",
        },
        "DPO": {
            "data_size": "~1M preference pairs",
            "beta": 0.1,
            "epochs": 1,
            "lr": 5e-7,
            "method": "Direct Preference Optimization",
        },
        "Safety": {
            "data_size": "~100K safety examples",
            "method": "Additional SFT + DPO on safety data",
        },
    }
    return stages
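The DPO objective at the core of step 3 is compact enough to write out. A minimal sketch — the per-sequence log-probs are assumed to be precomputed by the policy and the frozen reference model, and beta matches the 0.1 above:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.1):
    """
    Direct Preference Optimization loss.
    Pushes the policy to widen the (chosen - rejected) log-prob margin
    relative to the frozen reference model — no reward model in the loss.
    """
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    logits = beta * (chosen_ratio - rejected_ratio)
    return -F.logsigmoid(logits).mean()

# Toy batch: the policy already prefers the chosen response slightly
loss = dpo_loss(
    policy_chosen_logp=torch.tensor([-10.0]),
    policy_rejected_logp=torch.tensor([-14.0]),
    ref_chosen_logp=torch.tensor([-11.0]),
    ref_rejected_logp=torch.tensor([-13.0]),
)
# loss < log(2), the value at initialization (policy == reference)
```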

Summary

Llama 3’s architecture is a lesson in practical engineering at scale. Every choice optimizes for Meta’s specific constraints:

  • Dense: Simpler serving at billions-of-queries scale. Meta has the training compute budget.
  • GQA-8: 94% KV cache savings with 0.5% quality cost. The optimal balance point.
  • 128K vocabulary: 16-50% fewer tokens for multilingual text. Directly reduces inference cost.
  • RoPE (base 500K): Enables context extension from 8K to 128K without retraining from scratch.
  • SwiGLU: Universal consensus; no reason to deviate.
  • 15T tokens: Over-train the smaller model for inference-optimal scaling. Makes the 70B model approach larger models in quality.
  • BF16 training: Simpler than FP8, acceptable given Meta’s compute budget.

The Llama 3 recipe is not universally optimal — DeepSeek V3 proves that MoE + FP8 can achieve equivalent quality at 18x lower training cost. But for a lab with unlimited training compute and a priority on serving simplicity, Llama 3 makes the right set of tradeoffs.