Chinchilla showed GPT-3 was trained on roughly 12x too few tokens for its size. The insight: for a fixed compute budget, you get better performance by training a 70B model on 1.4T tokens than a 175B model on 300B tokens. This single paper rewrote the economics of frontier training. Every lab pivoted from "biggest model possible" to "optimal data-to-parameters ratio," and Llama, Grok, DeepSeek, and Qwen all follow Chinchilla-style scaling. The implications cascade: once your compute budget is fixed, the scaling math dictates your model size and token count before you write a line of code.
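To ground the budget arithmetic, here is a quick sketch of the Chinchilla-optimal allocation for a one PetaFLOP/s-day budget, using the C = 6*N*D approximation and the 20 tokens-per-parameter rule developed later in this section:

```python
# One PetaFLOP/s-day expressed in FLOPs
pflop_day = 1e15 * 86400  # 8.64e19 FLOPs

# Chinchilla-optimal allocation for that budget:
# C = 6*N*D with D = 20*N  =>  C = 120 * N^2
N = (pflop_day / 120) ** 0.5  # ~0.85B params
D = 20 * N                    # ~17B tokens
```

So a single PetaFLOP/s-day compute-optimally trains a model just under 1B parameters on roughly 17B tokens.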
The Kaplan Scaling Laws (2020)
OpenAI’s original scaling laws paper established three power-law relationships:
import math
def kaplan_loss_params(N: float, alpha_N: float = 0.076, N_c: float = 8.8e13) -> float:
"""
Kaplan scaling law: loss as a function of parameter count.
L(N) = (N_c / N) ^ alpha_N
N: parameter count
alpha_N: exponent (Kaplan found ~0.076)
N_c: critical parameter count
"""
return (N_c / N) ** alpha_N
def kaplan_loss_data(D: float, alpha_D: float = 0.095, D_c: float = 5.4e13) -> float:
"""
Loss as a function of dataset size (tokens).
L(D) = (D_c / D) ^ alpha_D
"""
return (D_c / D) ** alpha_D
def kaplan_loss_compute(C: float, alpha_C: float = 0.050, C_c: float = 3.1e8) -> float:
    """
    Loss as a function of compute.
    L(C) = (C_c / C) ^ alpha_C
    Note: Kaplan's fit uses C in PF-days, with C_c approx 3.1e8 PF-days.
    """
    return (C_c / C) ** alpha_C
# Key insight: Kaplan's compute-optimal fit allocates most of any extra
# compute to model size (N ~ C^0.73) and little to data (D ~ C^0.27)
# Implication: for fixed compute, make the model bigger and train on less data
def kaplan_optimal_allocation(C_total: float) -> dict:
"""Kaplan's recommendation: scale N faster than D."""
# N proportional to C^0.73
# D proportional to C^0.27
N_optimal = C_total ** 0.73
D_optimal = C_total ** 0.27
return {"params": N_optimal, "tokens": D_optimal}
The Kaplan scaling laws had three critical implications:
- Loss follows a smooth power law with no sign of diminishing returns across the compute range studied
- Model size matters more than data size: the compute-optimal fit allocates parameters as N proportional to C^0.73 but data as only D proportional to C^0.27
- For a 10x increase in compute, you should increase model size by ~5.5x and data by only ~1.8x
Kaplan Optimal Scaling (10x Compute Increases)
| Compute Budget | Optimal Params | Optimal Tokens | N/D Ratio |
|---|---|---|---|
| 1e20 FLOPs | 400M | 8B | 0.05 |
| 1e21 FLOPs | 2.2B | 14.4B | 0.15 |
| 1e22 FLOPs | 12B | 26B | 0.46 |
| 1e23 FLOPs | 66B | 47B | 1.4 |
| 1e24 FLOPs | 360B | 85B | 4.2 |
This led to the GPT-3 era: train a 175B parameter model on only 300B tokens. The model was massively undertrained by modern standards.
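Calibrating Kaplan's exponents to the table above (anchoring at 400M parameters and 8B tokens per 1e20 FLOPs; the anchor values come from that table, not directly from Kaplan's paper) shows why GPT-3's shape looked reasonable at the time:

```python
def kaplan_calibrated(C: float) -> tuple:
    """Kaplan-style allocation, anchored to the table above."""
    N = 400e6 * (C / 1e20) ** 0.73
    D = 8e9 * (C / 1e20) ** 0.27
    return N, D

C_gpt3 = 6 * 175e9 * 300e9  # ~3.15e23 FLOPs
N, D = kaplan_calibrated(C_gpt3)
# Roughly 140B params on ~70B tokens: a big model on little data,
# which is the regime GPT-3 (175B / 300B) was designed in
```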
The Chinchilla Revolution
Hoffmann et al. (2022) re-examined scaling laws with a crucial methodological improvement: they tuned the learning-rate schedule to match each training duration, instead of evaluating every token budget under a single fixed schedule as Kaplan et al. had done.
def chinchilla_optimal_params_tokens(C: float) -> dict:
"""
Chinchilla scaling law: for compute budget C (FLOPs),
the optimal parameter count N and token count D satisfy:
N_opt proportional to C^0.50
D_opt proportional to C^0.50
With the constraint: C approx 6 * N * D (forward + backward pass)
The key ratio: D/N approx 20 (train on 20 tokens per parameter)
"""
# C = 6 * N * D and D = 20 * N
# C = 6 * N * 20 * N = 120 * N^2
N_opt = math.sqrt(C / 120)
D_opt = 20 * N_opt
return {
"optimal_params": N_opt,
"optimal_tokens": D_opt,
"tokens_per_param": D_opt / N_opt,
"compute_check": 6 * N_opt * D_opt
}
def chinchilla_loss(N: float, D: float, E: float = 1.69,
A: float = 406.4, B: float = 410.7,
alpha: float = 0.34, beta: float = 0.28) -> float:
"""
Chinchilla parametric loss function:
L(N, D) = E + A / N^alpha + B / D^beta
E: irreducible loss (entropy of natural language)
A, B: scaling constants
alpha, beta: exponents for parameters and data
"""
return E + A / (N ** alpha) + B / (D ** beta)
# Compare GPT-3 vs Chinchilla
gpt3_loss = chinchilla_loss(N=175e9, D=300e9)
chinchilla_loss_val = chinchilla_loss(N=70e9, D=1.4e12)
print(f"GPT-3 (175B, 300B tokens): L = {gpt3_loss:.3f}")
print(f"Chinchilla (70B, 1.4T tokens): L = {chinchilla_loss_val:.3f}")
# Chinchilla achieves lower loss with 2.5x fewer parameters
Chinchilla’s core finding: for any compute budget C, the compute-optimal model trains on approximately 20 tokens per parameter. GPT-3 trained on only 1.7 tokens per parameter, making it 12x undertrained by Chinchilla’s criterion.
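Plugging GPT-3's approximate training compute into the Chinchilla allocation makes the gap concrete:

```python
import math

# GPT-3's approximate training compute: C = 6 * N * D
C_gpt3 = 6 * 175e9 * 300e9  # ~3.15e23 FLOPs

# Chinchilla-optimal allocation for the same budget (C = 120 * N^2)
N_opt = math.sqrt(C_gpt3 / 120)  # ~51B params
D_opt = 20 * N_opt               # ~1.0T tokens

tokens_per_param = 300e9 / 175e9         # ~1.7
undertrained_by = 20 / tokens_per_param  # ~12x
```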
The mathematical derivation proceeds from minimizing $L(N, D) = E + A N^{-\alpha} + B D^{-\beta}$ subject to $6ND = C$:

Using Lagrange multipliers with the constraint $6ND = C$:

$$\frac{\alpha A}{N^{\alpha+1}} = 6\lambda D, \qquad \frac{\beta B}{D^{\beta+1}} = 6\lambda N$$

Dividing these two equations:

$$\frac{\alpha A}{\beta B} \cdot \frac{D^{\beta+1}}{N^{\alpha+1}} = \frac{D}{N} \quad\Longrightarrow\quad D^{\beta} = \frac{\beta B}{\alpha A} N^{\alpha}$$

This gives the optimal ratio $D/N \approx 20$ at the compute scales Chinchilla studied, when using Chinchilla's fitted parameters.
Post-Chinchilla: Over-Training Is the New Norm
Chinchilla optimized for training compute. But inference compute often dominates total cost. A smaller model trained on more data costs less to serve.
def total_cost_model(
N: float,
D_train: float,
cost_per_training_flop: float,
num_inference_tokens: float,
cost_per_inference_flop: float
) -> dict:
"""
Total cost = training cost + inference cost.
Training cost = 6 * N * D_train * cost_per_training_flop
Inference cost = 2 * N * num_inference_tokens * cost_per_inference_flop
For high-volume deployment, inference dominates.
"""
train_flops = 6 * N * D_train
train_cost = train_flops * cost_per_training_flop
inference_flops = 2 * N * num_inference_tokens
inference_cost = inference_flops * cost_per_inference_flop
return {
"training_cost": train_cost,
"inference_cost": inference_cost,
"total_cost": train_cost + inference_cost,
"inference_fraction": inference_cost / (train_cost + inference_cost)
}
# Example: Llama 3 8B serving 1 quadrillion (1e15) inference tokens over its lifetime
llama3_cost = total_cost_model(
    N=8e9,
    D_train=15e12,  # 15T tokens (~94x Chinchilla optimal)
    cost_per_training_flop=1e-18,
    num_inference_tokens=1e15,
    cost_per_inference_flop=3e-18  # inference runs at lower hardware utilization
)
# inference_fraction approx 0.985: inference dominates for popular models
Chinchilla-Optimal vs Over-Trained Models
| Model | Params | Tokens | Tokens/Param | Chinchilla Ratio |
|---|---|---|---|---|
| Chinchilla (70B) | 70B | 1.4T | 20 | 1.0x |
| Llama 2 (70B) | 70B | 2.0T | 29 | 1.4x |
| Llama 3 (8B) | 8B | 15T | 1,875 | 94x |
| Llama 3.1 (405B) | 405B | 15T | 37 | 1.9x |
| Mistral 7B | 7B | 8T+ | 1,143+ | 57x+ |
| Gemma 2 (9B) | 9B | 8T | 889 | 44x |
Over-Training Ratio (Tokens/Param vs Chinchilla Optimal)
The smaller models (7B-9B) are overtrained by 50-100x relative to Chinchilla’s recommendation. This is intentional: a 7B model trained on 15T tokens costs more to train than Chinchilla-optimal, but vastly less to serve per token than a 70B model with equivalent quality.
How Scaling Laws Determine Architecture
Given a target compute budget and deployment scenario, scaling laws constrain the key architecture dimensions.
def derive_architecture_from_budget(
compute_budget_flops: float,
over_training_factor: float,
target_params: float = None
) -> dict:
"""
Given compute and over-training preference,
derive architecture dimensions.
Standard transformer: N approx 12 * L * d^2
where L = layers, d = hidden dimension
Typical d/L ratio: d approx 128 * L for efficient models
"""
if target_params is None:
# Chinchilla-optimal parameters for this compute
N_chinchilla = math.sqrt(compute_budget_flops / 120)
# Adjust for over-training: smaller model, more data
target_params = N_chinchilla / math.sqrt(over_training_factor)
# Derive tokens from compute and params
tokens = compute_budget_flops / (6 * target_params)
# Derive architecture dimensions
    # N = 12 * L * d^2 (attention + FFN params per layer)
# Typical: d = 128 * L (heuristic from successful models)
# N = 12 * L * (128*L)^2 = 12 * 128^2 * L^3
# L = (N / (12 * 128^2))^(1/3)
L = (target_params / (12 * 128 * 128)) ** (1/3)
L = max(1, round(L))
d = round(math.sqrt(target_params / (12 * L)))
# Round to multiple of 128 for hardware efficiency
d = 128 * round(d / 128)
actual_params = 12 * L * d * d
return {
"layers": L,
"hidden_dim": d,
"actual_params": actual_params,
"training_tokens": tokens,
"tokens_per_param": tokens / actual_params
}
# Example: 1e24 FLOP budget, various over-training factors
for otf in [1, 10, 50, 100]:
arch = derive_architecture_from_budget(1e24, otf)
print(f"OTF={otf}x: {arch['layers']}L, d={arch['hidden_dim']}, "
f"N={arch['actual_params']/1e9:.1f}B, "
f"D={arch['training_tokens']/1e12:.1f}T")
The relationship $N \approx 12Ld^2$ comes from counting parameters in a standard transformer layer: self-attention has $4d^2$ parameters (Q, K, V, O projections), and the FFN has $8d^2$ parameters (up-projection $d \times 4d$ and down-projection $4d \times d$). So each layer has $12d^2$ parameters, and $L$ layers gives $N \approx 12Ld^2$ total.
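As a quick check of the 12*L*d^2 approximation, here is a Llama-2-7B-like configuration (a sketch; real models add embedding and normalization parameters and use gated FFNs, so the nominal count differs somewhat):

```python
# Llama-2-7B-like config: 32 layers, hidden dim 4096
L, d = 32, 4096
n_params = 12 * L * d ** 2
# ~6.4B from the transformer layers alone, close to the nominal "7B"
# (embeddings account for most of the remainder)
```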
Width vs Depth Trade-offs
Scaling laws tell you the total parameter count N. They do not tell you how to allocate those parameters between width (hidden dimension d) and depth (number of layers L). The trade-offs are easiest to see by enumerating configurations:
def width_depth_tradeoffs(total_params: float) -> list:
"""
For a given N, enumerate width/depth configurations and
their expected properties.
Deeper models: better feature composition, harder to parallelize
Wider models: easier to parallelize (TP), diminishing returns
"""
configs = []
# Sweep depth from shallow to deep
for L in [16, 32, 48, 64, 80, 96, 128]:
d = round(math.sqrt(total_params / (12 * L)))
d = 128 * max(1, round(d / 128))
actual_N = 12 * L * d * d
# Metrics
param_efficiency = actual_N / total_params # how close to target
        tp_comm_per_layer = 2 * d * 4  # allreduce payload per token per layer, fp32 bytes
        pipeline_bubble = 1.0 / L  # crude proxy: bubble fraction shrinks with depth
        kv_cache_per_token = 2 * L * d * 4  # K and V per layer, fp32: 8*L*d bytes total
configs.append({
"layers": L,
"hidden_dim": d,
"actual_params_b": actual_N / 1e9,
"tp_comm_bytes": tp_comm_per_layer,
"pipeline_bubble_frac": pipeline_bubble,
"kv_bytes_per_token": kv_cache_per_token
})
return configs
# For a 70B parameter model
for cfg in width_depth_tradeoffs(70e9):
print(f"L={cfg['layers']:3d}, d={cfg['hidden_dim']:5d}, "
f"N={cfg['actual_params_b']:.1f}B, "
f"bubble={cfg['pipeline_bubble_frac']:.3f}")
Width vs Depth Configurations (70B Parameter Budget)
| Config | Layers | Hidden Dim | KV Cache/Token (KB) | Pipeline Bubble |
|---|---|---|---|---|
| Wide-Shallow | 32 | 13,568 | 3,392 | 3.1% |
| Balanced | 64 | 9,600 | 4,800 | 1.6% |
| Llama 2 70B | 80 | 8,192 | 5,120 | 1.3% |
| Deep-Narrow | 128 | 6,784 | 6,784 | 0.8% |
Deeper models have smaller pipeline bubbles (better PP efficiency) but use more KV cache memory per token. Wider models are easier to tensor-parallelize but have larger allreduce payloads per layer. The industry has converged on depth-to-width ratios where d/L falls roughly between 100 and 130.
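Computing d/L from a few published dense configurations (layer counts and widths as reported for these models):

```python
# Published configs: (layers, hidden dim)
configs = {
    "GPT-3 175B": (96, 12288),
    "Llama 2 70B": (80, 8192),
    "Llama 3.1 405B": (126, 16384),
}
ratios = {name: d / L for name, (L, d) in configs.items()}
# GPT-3: 128, Llama 2 70B: ~102, Llama 3.1 405B: ~130
```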
Vocabulary Size and Embedding Scaling
Scaling laws primarily model transformer layer parameters. But vocabulary size and embedding dimension add parameters that scale differently.
def embedding_parameter_fraction(
vocab_size: int, hidden_dim: int, num_layers: int
) -> dict:
"""
Embedding params vs transformer params.
Embedding: V * d (input) + V * d (output, often tied)
Transformer: 12 * L * d^2
"""
embed_params = vocab_size * hidden_dim # tied embeddings
transformer_params = 12 * num_layers * hidden_dim * hidden_dim
total = embed_params + transformer_params
return {
"embedding_params": embed_params,
"transformer_params": transformer_params,
"embedding_fraction": embed_params / total,
"total_params": total
}
# Impact of vocabulary size
for V in [32000, 50257, 100000, 128256, 200000]:
result = embedding_parameter_fraction(V, 4096, 32)
print(f"V={V:>7d}: embed={result['embedding_params']/1e9:.2f}B "
f"({result['embedding_fraction']*100:.1f}% of total)")
Vocabulary Size Impact on Parameter Budget (7B-class Model)
| Vocab Size | Embed Params | % of Total | Model Examples |
|---|---|---|---|
| 32,000 | 0.13B | 2.0% | Llama 2, Mistral |
| 50,257 | 0.21B | 3.1% | GPT-2/3 |
| 100,000 | 0.41B | 6.0% | cl100k (GPT-4 tokenizer) |
| 128,256 | 0.53B | 7.5% | Llama 3 |
| 200,000 | 0.82B | 11.3% | o200k (GPT-4o tokenizer) |
Larger vocabularies improve tokenization efficiency (fewer tokens per document), which effectively increases how much text the model sees per training token. Llama 3's move from a 32K to a 128K vocabulary means each training token carries more semantic content, partially compensating for the parameter overhead.
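A toy illustration of the effect, with assumed (not measured) tokenizer fertilities:

```python
# Fertility = tokens per word. The 1.40 and 1.20 figures below are
# ILLUSTRATIVE assumptions, not measurements of any real tokenizer.
fertility_32k = 1.40   # assumed tokens/word with a 32K vocab
fertility_128k = 1.20  # assumed tokens/word with a 128K vocab

budget_tokens = 15e12
words_32k = budget_tokens / fertility_32k
words_128k = budget_tokens / fertility_128k
extra_text = words_128k / words_32k - 1  # ~17% more text for the same token budget
```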
Compute-Optimal Batch Size
Scaling laws also predict the optimal batch size for a given model size:
def optimal_batch_size(loss: float, B_crit_constant: float = 2e8) -> dict:
    """
    The critical batch size B_crit scales with the loss value:
        B_crit approx B_star / loss^(1/alpha_B)
    Kaplan et al. fit B_star approx 2e8 tokens and alpha_B approx 0.21,
    so 1/alpha_B approx 4.8.
    Below B_crit: compute-efficient (each step makes maximum progress)
    Above B_crit: time-efficient (steps are faster but less efficient)
    """
    # B_crit grows as loss falls:
    # At loss ~3.0 (early training): B_crit ~ 1M tokens
    # At loss ~2.0 (mid training): B_crit ~ 7M tokens
    # At loss ~1.8 (late training): B_crit ~ 12M tokens
    B_crit = B_crit_constant / (loss ** 4.8)
    return {
        "critical_batch_tokens": B_crit,
        "critical_batch_seqs_2k": B_crit / 2048,
        "recommendation": "Use B_crit for compute efficiency, 2-3x B_crit for time efficiency"
    }

# Batch size scaling during training
for loss_val in [3.5, 3.0, 2.5, 2.0, 1.8]:
    result = optimal_batch_size(loss_val)
    print(f"Loss={loss_val:.1f}: B_crit={result['critical_batch_tokens']/1e6:.1f}M tokens "
          f"({result['critical_batch_seqs_2k']:.0f} seqs of 2048)")
Frontier models increase batch size during training as loss decreases. Llama 3 started with a batch size of 4M tokens and increased to 16M tokens over the course of training. This follows the compute-optimal batch size curve.
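A ramp like this can be sketched as a step schedule keyed to tokens seen. The thresholds below are illustrative, not Llama 3's actual switch points:

```python
def batch_size_schedule(tokens_seen: float) -> float:
    """Step-wise batch size ramp (thresholds are illustrative)."""
    if tokens_seen < 0.3e12:   # early: high loss, small B_crit
        return 4e6
    elif tokens_seen < 3e12:   # mid training
        return 8e6
    else:                      # late: low loss, large B_crit
        return 16e6
```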
Learning Rate and Model Size Scaling
The optimal learning rate decreases with model size, following another power law:
def optimal_learning_rate(N: float, batch_size_tokens: float) -> dict:
"""
Empirical scaling of peak learning rate with model size.
lr_opt proportional to N^(-0.5) (approximate)
    With mu-Parametrization (muP), lr_opt transfers from
    small-model sweeps with little additional tuning.
"""
# Standard heuristic
base_lr = 3e-4 # for ~125M params
base_N = 125e6
lr_scaled = base_lr * (base_N / N) ** 0.5
# muP prediction (if available)
# lr_muP = lr_proxy * (d_proxy / d_target)
# where proxy is a small model you ran LR sweep on
return {
"standard_lr": lr_scaled,
"cosine_min_lr": lr_scaled * 0.1,
"warmup_tokens": min(batch_size_tokens * 2000, 5e9)
}
for size_b in [0.125, 1, 7, 13, 34, 70, 405]:
result = optimal_learning_rate(size_b * 1e9, 4e6)
print(f"{size_b:>5.1f}B: lr={result['standard_lr']:.2e}, "
f"min_lr={result['cosine_min_lr']:.2e}")
Learning Rate Scaling Across Model Sizes
| Model Size | Peak LR | Min LR (Cosine) | Warmup Steps |
|---|---|---|---|
| 125M | 3.0e-4 | 3.0e-5 | 2,000 |
| 1B | 1.0e-4 | 1.0e-5 | 2,000 |
| 7B | 3.0e-5 | 3.0e-6 | 2,000 |
| 70B | 1.5e-5 | 1.5e-6 | 2,000 |
| 405B | 8.0e-6 | 8.0e-7 | 2,000 |
Scaling Law Failures: When the Laws Break Down
Scaling laws predict average loss on a test set. They do not predict:
def scaling_law_limitations():
"""
Cases where smooth power-law scaling breaks down.
"""
failures = {
"emergence": {
"description": "Capabilities that appear suddenly at specific scales",
"examples": [
"Chain-of-thought reasoning (appears ~60B+)",
"In-context learning quality jumps",
"Multi-step arithmetic (appears ~100B+)"
],
"note": "Schaeffer et al. argue emergence is a measurement artifact"
},
"data_quality": {
"description": "Scaling laws assume IID data quality",
"impact": "Phi-2 (2.7B) matches 7B models via curated data",
"implication": "Data quality shifts the scaling curve vertically"
},
"architecture_changes": {
"description": "Laws fit to one architecture don't transfer",
"examples": [
"MoE models break compute-optimal N predictions",
"Mixture models have different N_eff",
"Recurrent models (RWKV, Mamba) have different scaling"
]
},
"post_training": {
"description": "RLHF/DPO shifts the capability frontier",
"impact": "A 7B model after RLHF can match 70B base on some benchmarks"
}
}
return failures
Scaling laws predict pretraining loss, not downstream task performance. A model with 0.1 lower loss might score 20 percentage points higher on a reasoning benchmark, or it might score the same. The relationship between loss and capabilities is not smooth for individual tasks.
Modern Scaling Law Extensions
Recent work extends scaling laws beyond the original Chinchilla framework:
def inference_aware_scaling(
compute_budget: float,
expected_inference_tokens: float,
training_cost_per_flop: float = 1e-18,
inference_cost_per_flop: float = 3e-18
) -> dict:
"""
Optimize for total cost (training + inference).
When inference dominates, the optimal model is SMALLER
than Chinchilla-optimal, trained on MORE data.
Total cost = 6*N*D * c_train + 2*N*I * c_infer
where I = total inference tokens over deployment lifetime
"""
    # Chinchilla-optimal baseline
    N_chinchilla = math.sqrt(compute_budget / 120)
    D_chinchilla = 20 * N_chinchilla
    # Inference-to-training cost ratio, evaluated at the Chinchilla-optimal
    # model size (inference FLOPs are 2*N*I, so the ratio depends on N)
    infer_train_ratio = (2 * N_chinchilla * expected_inference_tokens *
                         inference_cost_per_flop) / \
                        (compute_budget * training_cost_per_flop)
    # Heuristic: shrink N, grow D as the inference share rises
    shrink_factor = (1 + infer_train_ratio) ** 0.5
N_optimal = N_chinchilla / shrink_factor
D_optimal = compute_budget / (6 * N_optimal)
return {
"chinchilla_N": N_chinchilla,
"chinchilla_D": D_chinchilla,
"inference_optimal_N": N_optimal,
"inference_optimal_D": D_optimal,
"over_training_factor": D_optimal / (20 * N_optimal),
"inference_train_cost_ratio": infer_train_ratio
}
# Example: 1e24 FLOP training budget, expect 1e15 inference tokens
result = inference_aware_scaling(1e24, 1e15)
print(f"Chinchilla: N={result['chinchilla_N']/1e9:.1f}B")
print(f"Inference-optimal: N={result['inference_optimal_N']/1e9:.1f}B")
print(f"Over-training factor: {result['over_training_factor']:.0f}x")
Training Compute Allocation: Chinchilla vs Inference-Aware
Practical Scaling Law Recipe
How frontier labs actually use scaling laws today:
def scaling_law_recipe(target_compute: float, deployment_scenario: str) -> dict:
"""
Step-by-step recipe for using scaling laws in model design.
"""
# Step 1: Run small-scale experiments
# Train 5-10 models from 100M to 2B on 10B-200B tokens each
# Fit the parametric loss function L(N, D)
small_experiments = {
"sizes": [100e6, 250e6, 500e6, 1e9, 2e9],
"tokens_each": "10x Chinchilla-optimal for each size",
"total_compute": "less than 0.1% of target budget",
"fit_parameters": "E, A, B, alpha, beta"
}
# Step 2: Determine over-training factor from deployment
if deployment_scenario == "api_serving":
over_training = 50 # high volume, inference-dominated
elif deployment_scenario == "research":
over_training = 1 # Chinchilla-optimal
elif deployment_scenario == "on_device":
over_training = 100 # must be small
else:
over_training = 5
# Step 3: Compute target model size
N_chinchilla = math.sqrt(target_compute / 120)
N_target = N_chinchilla / math.sqrt(over_training)
D_target = target_compute / (6 * N_target)
# Step 4: Choose architecture dimensions
# Use muP to transfer hyperparameters from small experiments
# Width multiple of 128, depth divisible by PP degree
# Step 5: Validate with medium-scale run
# Train at 1-5% of target compute to verify loss trajectory
return {
"target_params": N_target,
"target_tokens": D_target,
"over_training_factor": over_training,
"validation_compute": 0.02 * target_compute,
"small_experiment_compute": 0.001 * target_compute
}
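Step 1 of the recipe, fitting scaling-law parameters from small runs, can be sketched in miniature. This toy fit assumes the irreducible loss E is known and the data is noise-free; a real fit optimizes all five parameters jointly on measured losses:

```python
import math

# Toy version of Step 1: recover alpha and A in L(N) = E + A / N^alpha
# from small-model (N, loss) pairs, assuming E is known.
E = 1.69
A_true, alpha_true = 406.4, 0.34
sizes = [1e8, 2.5e8, 5e8, 1e9, 2e9]
losses = [E + A_true / N ** alpha_true for N in sizes]

# log(L - E) = log A - alpha * log N  ->  ordinary least squares on logs
xs = [math.log(N) for N in sizes]
ys = [math.log(L - E) for L in losses]
n = len(xs)
mx, my = sum(xs) / n, sum(ys) / n
slope = sum((x - mx) * (y - my) for x, y in zip(xs, ys)) / \
        sum((x - mx) ** 2 for x in xs)
alpha_fit = -slope
A_fit = math.exp(my - slope * mx)
# alpha_fit recovers 0.34, A_fit recovers 406.4 (exactly, since noise-free)
```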
Scaling Law Predictions vs Actual Frontier Models
| Model | Predicted Loss | Actual Loss | Error |
|---|---|---|---|
| Llama 2 7B | 2.05 | 2.08 | +1.5% |
| Llama 2 70B | 1.75 | 1.73 | -1.1% |
| Llama 3 8B (15T) | 1.82 | 1.79 | -1.7% |
| Llama 3.1 405B | 1.48 | 1.49 | +0.7% |
| Mistral 7B (8T+) | 1.85 | 1.83 | -1.1% |
Scaling law predictions are remarkably accurate for pretraining loss — typically within 2% for models within the fitted compute range. They become less reliable when extrapolating more than 100x beyond the fitting range.
Key Takeaways for Model Design
The practical impact of scaling laws on current frontier model design:
SCALING_LAW_DESIGN_PRINCIPLES = {
"1_budget_first": (
"Start with your compute budget and deployment scenario. "
"Scaling laws then determine model size and data requirements."
),
"2_over_train_for_inference": (
"If you expect high inference volume, train a smaller model "
"on more data. The extra training cost is amortized over inference."
),
"3_data_quality_shifts_curve": (
"Higher quality data shifts the scaling curve down. "
"Invest in data curation before scaling model size."
),
"4_validate_cheaply": (
"Run small experiments at 0.1% of target compute to fit "
"scaling law parameters specific to your data and architecture."
),
"5_architecture_follows_scale": (
"At a given N, the depth/width ratio is determined by "
"hardware efficiency: TP communication, PP bubble fraction, "
"and KV cache memory."
),
"6_moe_breaks_laws": (
"MoE models have different effective N for compute vs memory. "
"Scaling laws must be re-fit for sparse architectures."
)
}
The single most impactful finding from scaling laws research: pretraining loss is predictable from small experiments. You do not need to train a 70B model to know what loss it will achieve. This transforms model design from guesswork into engineering.
Scaling laws have moved from an academic curiosity to the foundation of billion-dollar training decisions. Every frontier lab runs scaling law experiments before committing to a large training run. The key shift from Kaplan to Chinchilla to modern practice: the optimal model size depends not just on training compute, but on the total lifecycle cost including inference. This is why the industry has moved decisively toward smaller, over-trained models for deployment.