A training run for a frontier LLM processes 10-15 trillion tokens. The model sees each token exactly once (or at most 2-4 times for high-quality subsets). The composition of those tokens — what fraction is web crawl versus code versus math versus books versus conversation — determines the model’s capabilities more than any other training decision. Get the mix wrong and the model is fluent but cannot reason, or can solve differential equations but writes incoherent prose.
This is not a hyperparameter search problem. You cannot try 50 different mixes at 15T scale — each run costs $10-50M in compute. Instead, the field relies on smaller-scale ablations (1B-7B models, 100B-1T tokens), extrapolation via scaling laws, and the few large-scale empirical results that labs have published. This post compiles what is known: the empirical mixes used by Llama 3, DeepSeek, Qwen, and Mistral; the “code helps everything” finding; the role of synthetic math; curriculum scheduling; and a full implementation of a data mixer that samples from multiple sources with configurable weights.
1. The Mixing Problem
What Is Being Mixed
A pretraining corpus is assembled from multiple data sources, each with different characteristics:
Pretraining Data Sources and Their Properties
| Source | Typical Size | Quality | Diversity | Key Capability |
|---|---|---|---|---|
| Web crawl (CommonCrawl) | 100T+ tokens raw, 5-15T cleaned | Low-medium (requires heavy filtering) | Very high | General knowledge, fluency |
| Code (GitHub, StackOverflow) | 1-3T tokens | Medium-high (syntax-checked) | High (many languages) | Reasoning, structured output |
| Math (arXiv, textbooks, synthetic) | 100B-500B tokens | High (formal, verifiable) | Low (narrow domain) | Mathematical reasoning |
| Books (Project Gutenberg, publishers) | 50B-200B tokens | High (edited, long-form) | Medium | Long-range coherence, knowledge |
| Wikipedia + encyclopedias | 10B-50B tokens | Very high (curated) | Medium | Factual knowledge, structure |
| Conversation (forums, Reddit) | 500B-2T tokens | Low-medium | High | Dialogue, informal reasoning |
| Scientific papers (arXiv, PubMed) | 100B-500B tokens | High | Low-medium | Technical reasoning, citation style |
| Synthetic (model-generated) | Variable, 100B-5T tokens | Variable (depends on pipeline) | Controllable | Targeted capabilities |
The Allocation Decision
Given a total token budget $N$ and $K$ data sources, the mixing problem is: choose weights $w_1, \ldots, w_K$ such that $\sum_{i=1}^{K} w_i = 1$ and source $i$ contributes $w_i N$ tokens to training. The objective is to maximize performance across a suite of downstream benchmarks:

$$w^{*} = \arg\max_{w \in \Delta^{K-1}} \sum_{j=1}^{B} \lambda_j \,\mathrm{Perf}_j(w)$$

where $\Delta^{K-1}$ is the probability simplex, $B$ is the number of benchmarks, and $\lambda_j$ are benchmark importance weights.
This is computationally intractable at scale. A single evaluation of the objective (training a model and evaluating it) costs millions of dollars. The field uses three approaches:
- Small-scale proxy experiments: Train 1B-7B models on 100B-300B tokens with different mixes. Extrapolate to larger scale using scaling laws.
- Importance resampling: Train on a uniform mix, measure per-example loss, and upweight sources where the model struggles.
- Published empirical results: Use the mixes from Llama 3, DeepSeek V2/V3, Qwen 2.5, and Mistral as starting points.
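Whichever approach sets the weights, converting them into per-source token budgets is simple arithmetic. A minimal sketch (the `allocate_tokens` helper is illustrative; the example mix loosely matches the Llama 3 numbers in Section 2):

```python
def allocate_tokens(weights, total_budget):
    """Normalize raw mixing weights and convert them to per-source token budgets."""
    total = sum(weights.values())
    return {name: w / total * total_budget for name, w in weights.items()}

# Approximate Llama 3-style mix over a 15.6T-token budget
mix = {"web": 50, "code": 25, "math": 8, "books": 5, "other": 12}
alloc = allocate_tokens(mix, 15.6e12)
# alloc["code"] is 25% of 15.6T, i.e. 3.9T tokens
```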
2. Empirical Mixes from Frontier Models
Llama 3: The Best-Documented Mix
Meta’s Llama 3 technical report provides the most detailed public account of a pretraining data mix. The 405B model was trained on 15.6T tokens:
Llama 3 Pretraining Data Mix (15.6T tokens)
[figure: share of the 15.6T tokens by source, in % of tokens]
Key observations from the Llama 3 mix:
- Code at 25% is aggressive. Earlier models (GPT-3, Chinchilla) used 5-10% code. The jump to 25% reflects the “code helps everything” finding (discussed in Section 3).
- Math at 8% includes synthetic data. Natural math data (arXiv, textbooks) accounts for roughly 2-3% of tokens. The remaining 5-6% is synthetically generated mathematical reasoning traces.
- Web data is half the mix despite being the lowest quality per-token. The sheer diversity of web text — every topic, style, and register — makes it indispensable for general capability.
Other Frontier Mixes
Published Pretraining Data Mixes (Approximate)
| Model | Total Tokens | Web % | Code % | Math % | Books % | Other % |
|---|---|---|---|---|---|---|
| Llama 3 405B | 15.6T | 50 | 25 | 8 | 5 | 12 |
| DeepSeek V3 | 14.8T | 45 | 30 | 10 | 3 | 12 |
| Qwen 2.5 72B | 18T | 55 | 20 | 8 | 5 | 12 |
| Mistral Large 2 | ~12T | 55 | 20 | 5 | 5 | 15 |
| Phi-3 (Microsoft) | 4.8T | 40 | 25 | 15 | 5 | 15 |
| Chinchilla (2022) | 1.4T | 75 | 8 | 2 | 10 | 5 |
| GPT-3 (2020) | 0.3T | 82 | 3 | 0 | 12 | 3 |
The trend is clear: code and math fractions have increased dramatically from 2020 to 2025. GPT-3 used 3% code; Llama 3 uses 25%; DeepSeek V3 uses 30%. Math went from effectively 0% to 8-15%. This is not because more code and math data became available (it did, but that is not the driver) — it is because research demonstrated that these domains transfer to general reasoning.
3. The “Code Helps Everything” Finding
The Evidence
Multiple independent findings converge on the same conclusion: training on code improves performance on non-code tasks, including natural language reasoning, common sense, and world knowledge.
Evidence 1: Ablation studies from Llama 3. Meta trained identical 8B models on 1T tokens with varying code fractions. Results:
Impact of Code Fraction on Non-Code Benchmarks (8B model, 1T tokens)
| Code % | MMLU | ARC-Challenge | HellaSwag | GSM8K | HumanEval |
|---|---|---|---|---|---|
| 0% | 52.3 | 65.1 | 74.2 | 12.8 | 0.0 |
| 5% | 54.1 | 67.3 | 75.8 | 18.4 | 12.1 |
| 15% | 57.8 | 71.2 | 78.1 | 28.6 | 35.4 |
| 25% | 59.2 | 72.8 | 79.4 | 35.1 | 42.7 |
| 35% | 58.9 | 72.4 | 78.7 | 36.2 | 48.9 |
| 50% | 55.6 | 69.8 | 76.3 | 33.8 | 54.2 |
MMLU improved by 6.9 points going from 0% to 25% code — a massive gain from data that has zero overlap with the MMLU evaluation domains (history, biology, philosophy, etc.). GSM8K (math) improved by 22.3 points. These are not statistical noise; they are consistent across model sizes and training scales.
Evidence 2: The structured reasoning hypothesis. Code is uniquely structured: it has explicit control flow (if/else, loops), hierarchical decomposition (functions calling functions), and precise semantics (a program either works or it does not). Training on code teaches the model to:
- Decompose problems into sequential steps
- Maintain state across long dependency chains
- Apply conditional logic (“if X, then Y, else Z”)
- Validate intermediate results (type checking, assertion patterns)
These are exactly the skills needed for multi-step reasoning in any domain.
Evidence 3: Representation quality. Li et al. (2024) showed that models trained with code develop higher-quality internal representations as measured by probing classifiers. Code-trained models have better linear separability of semantic concepts in their hidden states, even for non-code concepts.
The Optimal Code Fraction
The ablation data suggests a sweet spot at 20-30% code for general-purpose models:
Below 20%, you leave significant reasoning capability on the table. Above 35%, the model starts losing fluency and world knowledge because web text is underrepresented. The exact optimum depends on the evaluation suite: if your deployment is code-heavy (coding assistants), push to 30-35%. If it is general-purpose (chatbot, question answering), 20-25% is safer.
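One way to make the "sweet spot" concrete is to fit a simple curve to the ablation points and read off its maximum. This is a sketch, not how labs actually choose the fraction: the quadratic form is an assumption, and the data is the MMLU column of the ablation table above.

```python
import numpy as np

# Code fraction (%) vs MMLU from the 8B ablation table
code_frac = np.array([0.0, 5.0, 15.0, 25.0, 35.0, 50.0])
mmlu = np.array([52.3, 54.1, 57.8, 59.2, 58.9, 55.6])

# Fit a concave quadratic and take its vertex as a rough optimum
a, b, c = np.polyfit(code_frac, mmlu, 2)
optimum = -b / (2 * a)  # vertex of a*x^2 + b*x + c
```

The vertex lands in the high-20s to low-30s, consistent with the 20-30% range above; with only six points, treat it as a sanity check rather than a precise answer.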
Not all code is equal. Deduplication is critical — GitHub has massive redundancy (forks, copy-paste, vendored dependencies). After deduplication, quality filtering by repository stars, file-level metrics (length, comment ratio, lint scores), and language diversity produces a much smaller but far more effective code corpus. Llama 3 filters its code from ~10T raw GitHub tokens down to ~3.9T cleaned tokens — a 2.5x reduction.
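The dedup step can be illustrated with the simplest possible version, exact-match hashing over file contents; production pipelines add near-duplicate detection (e.g. MinHash/LSH) on top of this, and the filenames below are made up:

```python
import hashlib

def dedup_exact(files):
    """Keep the first file for each unique content hash (exact matches only)."""
    seen, kept = set(), []
    for name, content in files:
        digest = hashlib.sha256(content.encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            kept.append(name)
    return kept

# Two byte-identical files (e.g. a fork's unmodified copy) collapse to one
files = [("repo_a/utils.py", "def f(x): return x"),
         ("repo_b/utils.py", "def f(x): return x"),
         ("repo_c/main.py", "def g(y): return y * 2")]
kept = dedup_exact(files)
```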
4. Math Data: Scarce, Synthetic, and Essential
The Scarcity Problem
Natural math data is orders of magnitude scarcer than web text or code:
Available Math Training Data (English)
| Source | Tokens (approx) | Quality | Coverage | Limitations |
|---|---|---|---|---|
| arXiv papers (math.*) | ~80B | High (peer-reviewed) | Graduate-level | Notation-heavy, assumes context |
| Math textbooks (digitized) | ~20B | Very high (pedagogical) | K-12 through undergraduate | Small, copyright-constrained |
| Math StackExchange | ~5B | Medium-high (community-reviewed) | Mixed levels | Q&A format, fragmentary |
| Proof assistants (Lean, Coq) | ~2B | Very high (machine-verified) | Formal math only | Niche, requires specialized training |
| Web math (MathWorld, etc.) | ~10B | Medium | Reference material | Not reasoning-focused |
| TOTAL natural math | ~117B | - | - | Less than 1% of web data |
117B tokens of natural math is less than 1% of a 15T training corpus. At 1% representation, the model barely learns mathematical reasoning. At 8% (the Llama 3 target), you need ~1.2T tokens of math-related content. Where do the missing ~1.1T tokens come from? Synthetic generation.
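The gap being filled is straightforward to compute. A sketch using the numbers above (117B natural math tokens, an 8% target, a 15T budget; `synthetic_gap` is an illustrative helper):

```python
def synthetic_gap(natural_tokens, target_fraction, total_budget):
    """Tokens of synthetic data needed to reach a target mix fraction."""
    return max(0.0, target_fraction * total_budget - natural_tokens)

gap = synthetic_gap(natural_tokens=117e9, target_fraction=0.08,
                    total_budget=15e12)
# roughly 1.08T tokens must come from synthetic generation
```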
Synthetic Math Traces
The key insight: you do not need to generate new math problems. You need to generate reasoning traces — step-by-step solutions that show the work. A single problem can produce many different valid reasoning paths, and training on diverse solution strategies teaches the model to reason, not just to memorize answers.
import json
import random
from dataclasses import dataclass
@dataclass
class MathProblem:
problem: str
domain: str # algebra, geometry, calculus, etc.
difficulty: int # 1-10
reference_answer: str # Ground truth (for validation)
class SyntheticMathGenerator:
"""Generate synthetic math reasoning traces from seed problems.
Pipeline:
1. Take a seed problem with known answer
2. Prompt a strong model to generate step-by-step solution
3. Validate the final answer against ground truth
4. Generate multiple solution strategies for the same problem
5. Filter by correctness and reasoning quality
"""
def __init__(self, teacher_model, validator):
self.teacher = teacher_model
self.validator = validator
def generate_trace(self, problem, strategy="chain_of_thought"):
"""Generate a single reasoning trace for a problem.
strategy: one of
- "chain_of_thought": standard step-by-step
- "backward": start from the answer, work backward
- "analogy": solve a simpler version first, then generalize
- "algebraic": purely symbolic manipulation
- "numeric": plug in specific numbers, then generalize
"""
strategy_prompts = {
"chain_of_thought": (
"Solve this problem step by step. Show every intermediate "
"calculation. State what you are computing at each step and why.\n\n"
f"Problem: {problem.problem}\n\nSolution:"
),
"backward": (
"Solve this problem by working backward from the answer form. "
"Start by identifying what the answer must look like, then "
"determine what conditions must hold, then verify.\n\n"
f"Problem: {problem.problem}\n\nSolution:"
),
"analogy": (
"Solve this problem by first solving a simpler version of it, "
"then generalizing. State the simplified problem, solve it, "
"then extend to the full problem.\n\n"
f"Problem: {problem.problem}\n\nSolution:"
),
"algebraic": (
"Solve this problem using purely algebraic manipulation. "
"Define variables, set up equations, and solve symbolically. "
"Minimize numerical computation.\n\n"
f"Problem: {problem.problem}\n\nSolution:"
),
"numeric": (
"Solve this problem by substituting specific numerical examples "
"to build intuition, then derive the general solution. "
"Show at least two numerical examples before generalizing.\n\n"
f"Problem: {problem.problem}\n\nSolution:"
),
}
prompt = strategy_prompts[strategy]
trace = self.teacher.generate(prompt, max_tokens=2048, temperature=0.7)
return trace
def validate_trace(self, problem, trace):
"""Validate that a generated trace arrives at the correct answer."""
# Extract the final answer from the trace
extracted = self.validator.extract_answer(trace)
if extracted is None:
return {"valid": False, "reason": "no_answer_found"}
# Compare against reference
correct = self.validator.compare_answers(extracted, problem.reference_answer)
if not correct:
return {"valid": False, "reason": "wrong_answer",
"extracted": extracted, "expected": problem.reference_answer}
# Check reasoning quality (no logical gaps, no hallucinated steps)
quality = self.validator.check_reasoning_quality(trace)
return {"valid": True, "quality_score": quality}
def generate_batch(self, problems, traces_per_problem=5, min_quality=0.7):
"""Generate and validate traces for a batch of problems.
For each problem, generate multiple traces using different
strategies. Filter to correct traces above the quality threshold.
"""
strategies = ["chain_of_thought", "backward", "analogy",
"algebraic", "numeric"]
results = []
stats = {"total_generated": 0, "correct": 0, "high_quality": 0}
for problem in problems:
problem_traces = []
for i in range(traces_per_problem):
strategy = strategies[i % len(strategies)]
trace = self.generate_trace(problem, strategy)
stats["total_generated"] += 1
validation = self.validate_trace(problem, trace)
if validation["valid"] and validation.get("quality_score", 0) >= min_quality:
stats["correct"] += 1
stats["high_quality"] += 1
problem_traces.append({
"problem": problem.problem,
"domain": problem.domain,
"difficulty": problem.difficulty,
"strategy": strategy,
"trace": trace,
"quality_score": validation["quality_score"],
})
elif validation["valid"]:
stats["correct"] += 1
results.extend(problem_traces)
return results, stats
Without answer validation, ~30-40% of synthetic math traces contain errors — the teacher model makes mistakes. Training on incorrect reasoning traces teaches the student to make the same mistakes confidently. Every synthetic math trace must be validated against a ground-truth answer. For problems without ground truth, use majority voting across 5-10 independent generations: if 4 out of 5 traces agree on the answer, treat that as ground truth.
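The majority-voting fallback described above can be sketched in a few lines (`min_agreement` is an illustrative threshold, not a published value):

```python
from collections import Counter

def majority_answer(answers, min_agreement=0.6):
    """Return the consensus answer across independent traces,
    or None if no answer reaches the agreement threshold."""
    if not answers:
        return None
    best, count = Counter(answers).most_common(1)[0]
    return best if count / len(answers) >= min_agreement else None

# 4 of 5 traces agree: treat "42" as pseudo-ground-truth
consensus = majority_answer(["42", "42", "41", "42", "42"])
# No consensus: discard the problem rather than trust any single trace
no_consensus = majority_answer(["1", "2", "3", "4", "5"])
```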
Scaling Synthetic Math
The key numbers for synthetic math production at scale:
Synthetic Math Generation Economics
| Metric | Value | Notes |
|---|---|---|
| Seed problems available | ~2M | GSM8K, MATH, competition problems, textbook exercises |
| Traces per problem | 3-5 strategies | chain_of_thought, backward, analogy, algebraic, numeric |
| Raw traces generated | 6-10M | Before filtering |
| Pass rate (correct answer) | 60-70% | Depends on problem difficulty and teacher model |
| Quality filter pass rate | 70-80% of correct | Reasoning coherence, no gaps |
| Final high-quality traces | 3-5M | After all filtering |
| Tokens per trace (avg) | 400-800 | Detailed step-by-step |
| Total tokens produced | 1.2-4B | 3-5M traces x 400-800 tokens each; reaching the ~1.2T needed for 8% of a 15T corpus additionally requires scaling the seed pool and repeating high-quality traces |
| Compute cost (H100 hours) | 5K-20K | Using 70B teacher model |
5. Curriculum Learning: Dynamic Mix Scheduling
Why Static Mixes Are Suboptimal
Training with a fixed data mix for the entire run wastes capacity. The model’s learning needs change as training progresses:
- Early training (0-20% of tokens): The model learns basic language structure — grammar, common phrases, simple patterns. Diverse web data is most useful here because it provides the broadest coverage of basic linguistic patterns.
- Mid training (20-70% of tokens): The model has basic fluency and begins learning more complex patterns — reasoning chains, factual associations, code structure. This is where increasing code and math proportions pays off most.
- Late training (70-100% of tokens): The model is highly capable but benefits from targeted quality. High-quality sources (textbooks, curated math, clean code) should be overweighted. Low-quality web data adds noise at this stage.
The Llama 3 Curriculum
Meta explicitly describes a three-phase curriculum in the Llama 3 paper:
Llama 3 Training Curriculum (Approximate)
| Phase | Token Range | Web % | Code % | Math % | Key Change |
|---|---|---|---|---|---|
| Phase 1: Foundation | 0 - 8T | 55 | 20 | 5 | Heavy web for diversity |
| Phase 2: Code + Reasoning | 8T - 13T | 45 | 30 | 10 | Increase code and math |
| Phase 3: Quality Focus | 13T - 15.6T | 40 | 25 | 15 | Upweight high-quality, add synthetic |
The math fraction triples from Phase 1 to Phase 3. Code increases by 50% in Phase 2 then decreases slightly in Phase 3 (because the remaining Phase 3 code is higher quality). Web data decreases throughout as the model’s needs shift from breadth to depth.
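Reduced to its essentials, a phase schedule is a step-function lookup from token count to weights; the boundaries and weights here approximate the table above, and the full mixer implemented below also interpolates between phases:

```python
PHASES = [  # (starting token count, weights), approximate Llama 3 phases
    (0.0,   {"web": 55, "code": 20, "math": 5}),
    (8e12,  {"web": 45, "code": 30, "math": 10}),
    (13e12, {"web": 40, "code": 25, "math": 15}),
]

def phase_weights(tokens_so_far):
    """Return the weights of the most recent phase boundary passed
    (a step function; no interpolation between phases)."""
    current = PHASES[0][1]
    for boundary, weights in PHASES:
        if tokens_so_far >= boundary:
            current = weights
    return current

w = phase_weights(9e12)  # mid-Phase-2: math has doubled to 10
```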
Implementation: Configurable Data Mixer
import numpy as np
from dataclasses import dataclass
@dataclass
class DataSource:
"""A data source with tokens stored in sharded files."""
name: str
shard_paths: list # List of file paths
total_tokens: int # Total tokens across all shards
quality_score: float # 0-1, used for quality-weighted sampling
current_shard: int = 0
current_offset: int = 0
class DataMixer:
"""Multi-source data mixer with configurable weights and scheduling.
Supports:
- Static mixing: fixed weights for the entire training run
- Linear scheduling: linearly interpolate between start and end weights
- Phase-based scheduling: different weights for each training phase
- Temperature-based: upweight high-quality sources as training progresses
"""
def __init__(self, sources, total_tokens, seed=42):
self.sources = {s.name: s for s in sources}
self.total_tokens = total_tokens
self.rng = np.random.RandomState(seed)
self.tokens_consumed = 0
# Validate sources have enough data
for source in sources:
if source.total_tokens == 0:
raise ValueError(f"Source {source.name} has 0 tokens")
def set_static_weights(self, weights):
"""Set fixed mixing weights for the entire run.
weights: dict mapping source name to weight (will be normalized)
"""
total = sum(weights.values())
self.weight_fn = lambda step: {k: v / total for k, v in weights.items()}
def set_linear_schedule(self, start_weights, end_weights):
"""Linearly interpolate between start and end weights.
At token 0, use start_weights.
At total_tokens, use end_weights.
"""
# Normalize
s_total = sum(start_weights.values())
e_total = sum(end_weights.values())
start_norm = {k: v / s_total for k, v in start_weights.items()}
end_norm = {k: v / e_total for k, v in end_weights.items()}
def weight_fn(tokens_so_far):
alpha = tokens_so_far / self.total_tokens
alpha = min(1.0, max(0.0, alpha))
return {
k: start_norm[k] * (1 - alpha) + end_norm.get(k, 0) * alpha
for k in start_norm
}
self.weight_fn = weight_fn
def set_phase_schedule(self, phases):
"""Phase-based scheduling with smooth transitions.
phases: list of (token_boundary, weights) tuples.
Example:
[
(0, {"web": 55, "code": 20, "math": 5}),
(8e12, {"web": 45, "code": 30, "math": 10}),
(13e12, {"web": 40, "code": 25, "math": 15}),
]
Between phases, weights are linearly interpolated.
"""
# Normalize all phase weights
normalized = []
for boundary, weights in phases:
total = sum(weights.values())
normalized.append((boundary, {k: v / total for k, v in weights.items()}))
def weight_fn(tokens_so_far):
# Find the two surrounding phases
if tokens_so_far <= normalized[0][0]:
return normalized[0][1]
if tokens_so_far >= normalized[-1][0]:
return normalized[-1][1]
for i in range(len(normalized) - 1):
b1, w1 = normalized[i]
b2, w2 = normalized[i + 1]
if b1 <= tokens_so_far <= b2:
alpha = (tokens_so_far - b1) / (b2 - b1) if b2 > b1 else 0
return {
k: w1.get(k, 0) * (1 - alpha) + w2.get(k, 0) * alpha
for k in set(w1) | set(w2)
}
return normalized[-1][1]
self.weight_fn = weight_fn
def set_temperature_schedule(self, base_weights, quality_temp_start=1.0,
quality_temp_end=0.3):
"""Temperature-based scheduling that upweights quality over time.
At each step, weights are: w_i * quality_i^(1/T) where T decreases
from quality_temp_start to quality_temp_end during training.
Lower T = sharper distribution toward high-quality sources.
"""
total = sum(base_weights.values())
base_norm = {k: v / total for k, v in base_weights.items()}
def weight_fn(tokens_so_far):
alpha = tokens_so_far / self.total_tokens
temp = quality_temp_start + alpha * (quality_temp_end - quality_temp_start)
raw = {}
for name, base_w in base_norm.items():
quality = self.sources[name].quality_score
raw[name] = base_w * (quality ** (1.0 / temp))
total_raw = sum(raw.values())
return {k: v / total_raw for k, v in raw.items()}
self.weight_fn = weight_fn
def sample_batch(self, batch_size_tokens):
"""Sample a batch of tokens according to current mixing weights.
Returns: dict mapping source name to number of tokens to draw.
"""
weights = self.weight_fn(self.tokens_consumed)
# Multinomial sampling of batch across sources
source_names = list(weights.keys())
probs = np.array([weights[name] for name in source_names])
probs = probs / probs.sum() # Ensure normalization
# Allocate tokens to sources
allocations = self.rng.multinomial(batch_size_tokens, probs)
result = {}
for name, count in zip(source_names, allocations):
if count > 0:
result[name] = int(count)
self.tokens_consumed += batch_size_tokens
return result
def get_current_weights(self):
"""Return the current mixing weights."""
return self.weight_fn(self.tokens_consumed)
def get_progress(self):
"""Return training progress and current mix state."""
weights = self.weight_fn(self.tokens_consumed)
return {
"tokens_consumed": self.tokens_consumed,
"total_tokens": self.total_tokens,
"progress_pct": self.tokens_consumed / self.total_tokens * 100,
"current_weights": weights,
}
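The core step of the temperature schedule (the `w_i * q_i^(1/T)` reweighting inside `set_temperature_schedule`) can be checked in isolation; the base weights and quality scores here are made-up values:

```python
def quality_reweight(base_weights, quality, temp):
    """w_i * q_i^(1/T), renormalized. Lower T concentrates
    probability mass on high-quality sources."""
    raw = {k: base_weights[k] * quality[k] ** (1.0 / temp)
           for k in base_weights}
    total = sum(raw.values())
    return {k: v / total for k, v in raw.items()}

base = {"web": 0.6, "math": 0.4}
quality = {"web": 0.4, "math": 0.9}
early = quality_reweight(base, quality, temp=1.0)  # broad mix
late = quality_reweight(base, quality, temp=0.3)   # sharpened toward math
```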
Complete Training Loop with Dynamic Mixing
class MixedDataLoader:
"""DataLoader that draws from multiple sources with dynamic mixing.
Integrates with the DataMixer to produce batches where each
micro-batch contains tokens from multiple sources in the
proportions specified by the current schedule.
"""
def __init__(self, mixer, tokenizer, seq_length=4096,
micro_batch_tokens=None):
self.mixer = mixer
self.tokenizer = tokenizer
self.seq_length = seq_length
self.micro_batch_tokens = micro_batch_tokens or seq_length * 8
# Per-source token buffers
self.buffers = {name: [] for name in mixer.sources}
self.source_readers = {
name: self._make_reader(source)
for name, source in mixer.sources.items()
}
def _make_reader(self, source):
"""Create a generator that yields tokens from a source."""
while True:
for shard_path in source.shard_paths:
tokens = self._read_shard(shard_path)
yield from tokens
# Wrap around if we exhaust all shards
def _read_shard(self, path):
"""Read tokens from a shard file."""
# Implementation depends on format (bin, jsonl, arrow, etc.)
# Returns list of token IDs
import struct
with open(path, 'rb') as f:
data = f.read()
# Assume uint16 token IDs
return list(struct.unpack(f'{len(data)//2}H', data))
def _fill_buffer(self, source_name, num_tokens):
"""Fill a source buffer with tokens."""
reader = self.source_readers[source_name]
while len(self.buffers[source_name]) < num_tokens:
self.buffers[source_name].append(next(reader))
def get_batch(self):
"""Get one training batch with dynamically mixed data.
Returns: tensor of shape [batch_size, seq_length]
"""
# Determine how many tokens from each source
allocation = self.mixer.sample_batch(self.micro_batch_tokens)
        # Collect tokens from each source as contiguous chunks
        chunks = []
        for source_name, num_tokens in allocation.items():
            self._fill_buffer(source_name, num_tokens)
            tokens = self.buffers[source_name][:num_tokens]
            self.buffers[source_name] = self.buffers[source_name][num_tokens:]
            chunks.append(tokens)
        # Shuffle chunk order, not individual tokens: token-level shuffling
        # would destroy the text, while chunk-level shuffling is enough to
        # avoid a fixed source ordering within each batch
        self.mixer.rng.shuffle(chunks)
        all_tokens = [tok for chunk in chunks for tok in chunk]
# Pack into sequences of seq_length
num_seqs = len(all_tokens) // self.seq_length
if num_seqs == 0:
return None
import torch
batch = torch.tensor(
all_tokens[:num_seqs * self.seq_length],
dtype=torch.long
).reshape(num_seqs, self.seq_length)
return batch
def log_mixing_state(self, step, logger=None):
"""Log current mixing proportions for monitoring."""
state = self.mixer.get_progress()
weights = state["current_weights"]
log_msg = (
f"Step {step} | "
f"Progress: {state['progress_pct']:.1f}% | "
f"Mix: " + ", ".join(
f"{k}={v:.1%}" for k, v in sorted(weights.items(),
key=lambda x: -x[1])
)
)
if logger:
logger.info(log_msg)
for name, weight in weights.items():
logger.log_scalar(f"data_mix/{name}", weight, step)
else:
print(log_msg)
Example: Setting Up the Llama 3 Curriculum
def create_llama3_mixer():
"""Create a data mixer matching the Llama 3 training curriculum."""
sources = [
DataSource("web", shard_paths=["web_shard_{:05d}.bin".format(i) for i in range(10000)],
total_tokens=int(8e12), quality_score=0.4),
DataSource("code", shard_paths=["code_shard_{:05d}.bin".format(i) for i in range(3000)],
total_tokens=int(4e12), quality_score=0.7),
DataSource("math", shard_paths=["math_shard_{:05d}.bin".format(i) for i in range(500)],
total_tokens=int(1.5e12), quality_score=0.9),
DataSource("books", shard_paths=["book_shard_{:05d}.bin".format(i) for i in range(200)],
total_tokens=int(0.8e12), quality_score=0.85),
DataSource("wiki", shard_paths=["wiki_shard_{:05d}.bin".format(i) for i in range(100)],
total_tokens=int(0.3e12), quality_score=0.95),
DataSource("multilingual", shard_paths=["ml_shard_{:05d}.bin".format(i) for i in range(1500)],
total_tokens=int(1.5e12), quality_score=0.5),
]
mixer = DataMixer(sources, total_tokens=int(15.6e12))
# Set Llama 3-style phase schedule
mixer.set_phase_schedule([
(0, {"web": 55, "code": 20, "math": 5, "books": 5, "wiki": 5, "multilingual": 10}),
(8e12, {"web": 45, "code": 30, "math": 10, "books": 5, "wiki": 3, "multilingual": 7}),
(13e12, {"web": 40, "code": 25, "math": 15, "books": 6, "wiki": 4, "multilingual": 10}),
])
return mixer
def training_loop_example():
"""Example training loop with dynamic data mixing."""
mixer = create_llama3_mixer()
loader = MixedDataLoader(mixer, tokenizer=None, seq_length=4096)
total_steps = int(15.6e12 / (4096 * 8)) # seq_length * batch_size
for step in range(total_steps):
batch = loader.get_batch()
if batch is None:
continue
# Forward pass, loss computation, backward pass, optimizer step
# ... (standard training loop) ...
# Log mixing state periodically
if step % 1000 == 0:
loader.log_mixing_state(step)
# Example log output at different stages:
# Step 0 | Progress: 0.0% | Mix: web=55.0%, code=20.0%, multilingual=10.0%, ...
# Step 500000 | Progress: 10.5% | Mix: web=53.4%, code=21.6%, multilingual=9.5%, ...
# Step 1500000| Progress: 31.5% | Mix: web=48.8%, code=26.3%, math=7.8%, ...
# Step 3000000| Progress: 63.0% | Mix: web=44.2%, code=29.0%, math=10.8%, ...
# Step 4000000| Progress: 84.0% | Mix: web=41.4%, code=26.6%, math=13.3%, ...
6. Domain-Specific Upweighting and Downweighting
When to Deviate from Standard Mixes
The Llama 3 mix is optimized for a general-purpose model. Domain-specific models should adjust:
Data Mix Adjustments by Target Application
| Target Application | Code % | Math % | Web % | Rationale |
|---|---|---|---|---|
| General chatbot | 20-25 | 5-8 | 50-55 | Balance fluency with reasoning |
| Coding assistant | 35-45 | 10-12 | 30-35 | Maximize code capability, keep NL fluency |
| Math/science model | 20-25 | 20-30 | 30-35 | Heavy math, code for structured reasoning |
| Medical/legal domain | 10-15 | 3-5 | 40-50 | Domain web data replaces general web |
| Multilingual model | 15-20 | 5-8 | 35-40 | 30%+ for multilingual sources |
| Small model (1-3B) | 25-30 | 10-15 | 35-40 | More structured data compensates for capacity |
A counterintuitive finding: smaller models (1B-7B parameters) benefit from higher code and math fractions than larger models. The hypothesis is that small models have limited capacity to learn implicit structure from noisy web data. Code and math, which have explicit structure, provide a more efficient learning signal per token. Microsoft's Phi series demonstrated this: Phi-3 (3.8B params) uses 25% code and 15% math, matching Llama 3 405B's code fraction while nearly doubling its 8% math fraction.
Data Repetition and Epoch Counting
When a source is smaller than its allocation, tokens must be repeated. Repetition is acceptable up to a point:
$$\mathrm{epochs}_i = \frac{w_i N}{T_i}$$

where $w_i$ is the weight for source $i$, $N$ is the total number of training tokens, and $T_i$ is the number of tokens available in source $i$; $\mathrm{epochs}_i > 1$ means source $i$ is repeated.
def compute_repetition(mixer, verbose=True):
"""Compute how many times each source is repeated during training."""
weights = mixer.weight_fn(mixer.total_tokens / 2) # Mid-training weights
repetitions = {}
for name, source in mixer.sources.items():
allocated = weights.get(name, 0) * mixer.total_tokens
epochs = allocated / source.total_tokens if source.total_tokens > 0 else float('inf')
repetitions[name] = {
"allocated_tokens": allocated,
"available_tokens": source.total_tokens,
"epochs": epochs,
"repeated": epochs > 1.0,
}
if verbose:
status = "REPEATED" if epochs > 1.0 else "ok"
print(f" {name}: {epochs:.2f} epochs "
f"({allocated/1e12:.1f}T allocated / "
f"{source.total_tokens/1e12:.1f}T available) [{status}]")
return repetitions
Research findings on repetition:
- 1-2 epochs: Negligible quality loss. The model benefits from seeing data twice, similar to multi-epoch training in supervised learning.
- 2-4 epochs: Measurable but small quality loss (0.5-1% on benchmarks). Acceptable for high-quality sources like math and books.
- 4+ epochs: Significant quality degradation. The model begins memorizing rather than generalizing. Training loss continues to decrease, but validation loss stalls or increases.
Impact of Data Repetition on Benchmark Quality (7B model)
| Epochs | Training Loss | Val Loss | MMLU | Memorization % |
|---|---|---|---|---|
| 1.0 | 2.84 | 2.91 | 57.2 | 2.1% |
| 2.0 | 2.71 | 2.89 | 56.8 | 3.4% |
| 4.0 | 2.52 | 2.93 | 55.1 | 8.7% |
| 8.0 | 2.31 | 3.08 | 52.4 | 18.3% |
| 16.0 | 2.05 | 3.31 | 48.9 | 34.6% |
The prescription: keep all sources below 4 epochs. For sources that would exceed this (math, books), generate synthetic data to increase the effective pool size rather than repeating.
7. Monitoring and Adjusting During Training
Online Monitoring
Track per-source loss during training to detect when a source is “exhausted” (the model has learned everything it can from it) or “underserved” (high loss indicating the model needs more examples):
class MixMonitor:
"""Monitor per-source training loss to guide mix adjustments.
Tracks exponential moving average of loss per data source.
Alerts when a source's loss plateaus (model has saturated)
or diverges (model is forgetting).
"""
def __init__(self, source_names, ema_alpha=0.99):
self.ema_alpha = ema_alpha
self.ema_loss = {name: None for name in source_names}
self.loss_history = {name: [] for name in source_names}
self.step = 0
def update(self, source_name, loss_value):
"""Update the EMA loss for a source."""
if self.ema_loss[source_name] is None:
self.ema_loss[source_name] = loss_value
else:
self.ema_loss[source_name] = (
self.ema_alpha * self.ema_loss[source_name] +
(1 - self.ema_alpha) * loss_value
)
self.loss_history[source_name].append(
(self.step, loss_value, self.ema_loss[source_name])
)
self.step += 1
def get_recommendations(self, window=1000):
"""Analyze loss trends and recommend mix adjustments.
Returns recommendations:
- "increase": source has high loss, model would benefit from more
- "decrease": source loss has plateaued, diminishing returns
- "maintain": source loss is decreasing normally
"""
recommendations = {}
for name, history in self.loss_history.items():
if len(history) < window * 2:
recommendations[name] = {"action": "maintain", "reason": "insufficient_data"}
continue
recent = [h[1] for h in history[-window:]]
older = [h[1] for h in history[-2*window:-window]]
recent_mean = sum(recent) / len(recent)
older_mean = sum(older) / len(older)
# Loss change rate
change_rate = (recent_mean - older_mean) / older_mean if older_mean > 0 else 0
if abs(change_rate) < 0.005:
# Loss plateaued (less than 0.5% change)
recommendations[name] = {
"action": "decrease",
"reason": "loss_plateau",
"loss_change_pct": change_rate * 100,
}
elif change_rate > 0.02:
# Loss increasing (model forgetting this domain)
recommendations[name] = {
"action": "increase",
"reason": "loss_increasing",
"loss_change_pct": change_rate * 100,
}
else:
recommendations[name] = {
"action": "maintain",
"reason": "healthy_decrease",
"loss_change_pct": change_rate * 100,
}
return recommendations
def print_status(self):
"""Print current loss status for all sources."""
print(f"Step {self.step} | Per-source EMA loss:")
for name, ema in sorted(self.ema_loss.items()):
if ema is not None:
print(f" {name}: {ema:.4f}")
While online monitoring is valuable for diagnostics, automatically adjusting mix proportions during a large training run is risky. The loss signal is noisy, and the model’s loss on a source depends on all other sources (cross-domain transfer). A decrease in math loss might be caused by increasing code fraction, not by the math data itself. Adjustments should be made between training restarts (checkpoints), not dynamically within a run, unless you have extensively validated the adjustment policy on smaller-scale experiments.
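If mix changes are applied at restart time, it helps to bound how far any single adjustment can move the weights. A sketch (the ±10% relative step is an arbitrary illustrative cap, and `apply_recommendations` is a hypothetical helper consuming `MixMonitor`-style recommendations):

```python
def apply_recommendations(weights, recommendations, step_pct=0.10):
    """Nudge each source's weight by at most +/- step_pct relative,
    then renormalize. Meant to run between checkpoints, not inside
    the training loop."""
    adjusted = {}
    for name, w in weights.items():
        action = recommendations.get(name, {}).get("action", "maintain")
        if action == "increase":
            adjusted[name] = w * (1 + step_pct)
        elif action == "decrease":
            adjusted[name] = w * (1 - step_pct)
        else:
            adjusted[name] = w
    total = sum(adjusted.values())
    return {k: v / total for k, v in adjusted.items()}

new_w = apply_recommendations(
    {"web": 0.5, "code": 0.3, "math": 0.2},
    {"math": {"action": "increase"}, "web": {"action": "decrease"}},
)
```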
Exercise: Planning a Data Mix
Challenge: Implement a function that, given a dictionary of data sources (name to token count), a total training budget, and a set of desired mixing weights, computes: (1) the number of epochs each source will be repeated, (2) whether any source exceeds a maximum repetition threshold, and (3) the minimum amount of synthetic data needed for over-repeated sources to bring them below the threshold.
A reference solution:
def plan_data_mix(sources_tokens, total_budget, weights, max_epochs=4.0):
total_w = sum(weights.values())
normalized = {k: v / total_w for k, v in weights.items()}
plan = {}
synthetic_needed = {}
for name in sources_tokens:
allocated = normalized.get(name, 0) * total_budget
available = sources_tokens[name]
epochs = allocated / available if available > 0 else float('inf')
over_limit = epochs > max_epochs
plan[name] = {
"allocated_tokens": allocated,
"available_tokens": available,
"epochs": epochs,
"over_limit": over_limit,
}
if over_limit:
# Need enough synthetic to bring epochs to max_epochs
# allocated / (available + synthetic) = max_epochs
# synthetic = allocated / max_epochs - available
needed = allocated / max_epochs - available
synthetic_needed[name] = max(0, needed)
return {"plan": plan, "synthetic_needed": synthetic_needed}