A training run for a frontier LLM processes 10-15 trillion tokens. The model sees each token exactly once (or at most 2-4 times for high-quality subsets). The composition of those tokens — what fraction is web crawl versus code versus math versus books versus conversation — determines the model’s capabilities more than any other training decision. Get the mix wrong and the model is fluent but cannot reason, or can solve differential equations but writes incoherent prose.

This is not a hyperparameter search problem. You cannot try 50 different mixes at 15T scale — each run costs $10-50M in compute. Instead, the field relies on smaller-scale ablations (1B-7B models, 100B-1T tokens), extrapolation via scaling laws, and the few large-scale empirical results that labs have published. This post compiles what is known: the empirical mixes used by Llama 3, DeepSeek, Qwen, and Mistral; the “code helps everything” finding; the role of synthetic math; curriculum scheduling; and a full implementation of a data mixer that samples from multiple sources with configurable weights.


1. The Mixing Problem

What Is Being Mixed

A pretraining corpus is assembled from multiple data sources, each with different characteristics:

Pretraining Data Sources and Their Properties

| Source | Typical Size | Quality | Diversity | Key Capability |
| --- | --- | --- | --- | --- |
| Web crawl (CommonCrawl) | 100T+ tokens raw, 5-15T cleaned | Low-medium (requires heavy filtering) | Very high | General knowledge, fluency |
| Code (GitHub, StackOverflow) | 1-3T tokens | Medium-high (syntax-checked) | High (many languages) | Reasoning, structured output |
| Math (arXiv, textbooks, synthetic) | 100B-500B tokens | High (formal, verifiable) | Low (narrow domain) | Mathematical reasoning |
| Books (Project Gutenberg, publishers) | 50B-200B tokens | High (edited, long-form) | Medium | Long-range coherence, knowledge |
| Wikipedia + encyclopedias | 10B-50B tokens | Very high (curated) | Medium | Factual knowledge, structure |
| Conversation (forums, Reddit) | 500B-2T tokens | Low-medium | High | Dialogue, informal reasoning |
| Scientific papers (arXiv, PubMed) | 100B-500B tokens | High | Low-medium | Technical reasoning, citation style |
| Synthetic (model-generated) | Variable, 100B-5T tokens | Variable (depends on pipeline) | Controllable | Targeted capabilities |

Note: Token counts are approximate and vary by tokenizer. 'Quality' reflects information density and accuracy. Sizes are for English-dominant corpora; multilingual adds 2-5x.

The Allocation Decision

Given a total token budget $T$ and $K$ data sources, the mixing problem is: choose weights $w_1, w_2, \ldots, w_K$ such that $\sum_i w_i = 1$ and source $i$ contributes $w_i \times T$ tokens to training. The objective is to maximize performance across a suite of downstream benchmarks:

$$\max_{w \in \Delta^K} \sum_{j=1}^{M} \lambda_j \cdot \text{Benchmark}_j\left(\text{Train}(w, T)\right)$$

where $\Delta^K$ is the probability simplex, $M$ is the number of benchmarks, and $\lambda_j$ are benchmark importance weights.
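
To make the notation concrete, here is a minimal sketch of the allocation step: it checks that a weight vector lies on the simplex and converts it into per-source token budgets. The function name `allocate_tokens` is illustrative, and the example weights are the approximate Llama 3 proportions quoted later in this post, not exact published values.

```python
import math

def allocate_tokens(weights: dict[str, float], total_tokens: float) -> dict[str, float]:
    """Convert simplex weights w_i into per-source token budgets w_i * T."""
    if any(w < 0 for w in weights.values()):
        raise ValueError("weights must be non-negative")
    if not math.isclose(sum(weights.values()), 1.0, abs_tol=1e-9):
        raise ValueError("weights must sum to 1")
    return {name: w * total_tokens for name, w in weights.items()}

# Approximate proportions from this post (assumed, not exact published values).
mix = {"web": 0.50, "code": 0.25, "multilingual": 0.08,
       "math": 0.08, "books": 0.05, "other": 0.04}
budget = allocate_tokens(mix, 15.6e12)
for name, tokens in budget.items():
    print(f"{name:>12}: {tokens / 1e12:.2f}T tokens")
```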

This is computationally intractable at scale. A single evaluation of the objective (training a model and evaluating it) costs millions of dollars. The field uses three approaches:

  1. Small-scale proxy experiments: Train 1B-7B models on 100B-300B tokens with different mixes. Extrapolate to larger scale using scaling laws.
  2. Importance resampling: Train on a uniform mix, measure per-example loss, and upweight sources where the model struggles.
  3. Published empirical results: Use the mixes from Llama 3, DeepSeek V2/V3, Qwen 2.5, and Mistral as starting points.
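
The second approach can be sketched as a multiplicative-weights update: sources whose measured loss exceeds a reference loss receive more weight, and the result is renormalized onto the simplex. This is an illustrative sketch rather than a specific published algorithm; `reweight_by_excess_loss`, the step size `eta`, and every loss value below are assumptions.

```python
import math

def reweight_by_excess_loss(weights, source_loss, reference_loss, eta=1.0):
    """One multiplicative-weights step: upweight sources with high excess loss."""
    raw = {}
    for name, w in weights.items():
        # Clip at zero so sources the model already handles are not boosted.
        excess = max(0.0, source_loss[name] - reference_loss[name])
        raw[name] = w * math.exp(eta * excess)
    total = sum(raw.values())
    return {name: v / total for name, v in raw.items()}

uniform = {"web": 1 / 3, "code": 1 / 3, "math": 1 / 3}
# Hypothetical per-source losses from a small proxy run.
loss = {"web": 2.9, "code": 1.4, "math": 2.2}
ref = {"web": 2.8, "code": 1.4, "math": 1.7}
new_weights = reweight_by_excess_loss(uniform, loss, ref)
print({k: round(v, 3) for k, v in new_weights.items()})
```

In this toy example math has the largest excess loss, so it ends up with the largest weight, while code (no excess loss) shrinks after renormalization.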

2. Empirical Mixes from Frontier Models

Llama 3: The Best-Documented Mix

Meta’s Llama 3 technical report provides the most detailed public account of a pretraining data mix. The 405B model was trained on 15.6T tokens:

Llama 3 Pretraining Data Mix (15.6T tokens)

| Source | % of tokens | Approx. tokens | Notes |
| --- | --- | --- | --- |
| Web (English) | 50 | ~7.8T | Heavily filtered |
| Code | 25 | ~3.9T | Multi-language |
| Multilingual web | 8 | ~1.2T | 30+ languages |
| Math | 8 | ~1.2T | Includes synthetic |
| Books + long-form | 5 | ~0.8T | |
| Other (wiki, papers) | 4 | ~0.6T | |

Key observations from the Llama 3 mix:

  • Code at 25% is aggressive. Earlier models (GPT-3, Chinchilla) used 5-10% code. The jump to 25% reflects the “code helps everything” finding (discussed in Section 3).
  • Math at 8% includes synthetic data. Natural math data (arXiv, textbooks) accounts for roughly 2-3% of tokens. The remaining 5-6% is synthetically generated mathematical reasoning traces.
  • Web data is half the mix despite being the lowest quality per-token. The sheer diversity of web text — every topic, style, and register — makes it indispensable for general capability.

Other Frontier Mixes

Published Pretraining Data Mixes (Approximate)

| Model | Total Tokens | Web % | Code % | Math % | Books % | Other % |
| --- | --- | --- | --- | --- | --- | --- |
| Llama 3 405B | 15.6T | 50 | 25 | 8 | 5 | 12 |
| DeepSeek V3 | 14.8T | 45 | 30 | 10 | 3 | 12 |
| Qwen 2.5 72B | 18T | 55 | 20 | 8 | 5 | 12 |
| Mistral Large 2 | ~12T | 55 | 20 | 5 | 5 | 15 |
| Phi-3 (Microsoft) | 4.8T | 40 | 25 | 15 | 5 | 15 |
| Chinchilla (2022) | 1.4T | 75 | 8 | 2 | 10 | 5 |
| GPT-3 (2020) | 0.3T | 82 | 3 | 0 | 12 | 3 |

Note: Percentages are approximate, reconstructed from published reports and estimates. 'Other' includes Wikipedia, scientific papers, conversations, and multilingual data. Mixes evolve during training (see curriculum section).

The trend is clear: code and math fractions have increased dramatically from 2020 to 2025. GPT-3 used 3% code; Llama 3 uses 25%; DeepSeek V3 uses 30%. Math went from effectively 0% to 8-15%. This is not because more code and math data became available (it did, but that is not the driver) — it is because research demonstrated that these domains transfer to general reasoning.


3. The “Code Helps Everything” Finding

The Evidence

Multiple independent findings converge on the same conclusion: training on code improves performance on non-code tasks, including natural language reasoning, common sense, and world knowledge.

Evidence 1: Ablation studies from Llama 3. Meta trained identical 8B models on 1T tokens with varying code fractions. Results:

Impact of Code Fraction on Non-Code Benchmarks (8B model, 1T tokens)

| Code % | MMLU | ARC-Challenge | HellaSwag | GSM8K | HumanEval |
| --- | --- | --- | --- | --- | --- |
| 0% | 52.3 | 65.1 | 74.2 | 12.8 | 0.0 |
| 5% | 54.1 | 67.3 | 75.8 | 18.4 | 12.1 |
| 15% | 57.8 | 71.2 | 78.1 | 28.6 | 35.4 |
| 25% | 59.2 | 72.8 | 79.4 | 35.1 | 42.7 |
| 35% | 58.9 | 72.4 | 78.7 | 36.2 | 48.9 |
| 50% | 55.6 | 69.8 | 76.3 | 33.8 | 54.2 |

Note: MMLU, ARC, HellaSwag are non-code reasoning benchmarks. GSM8K is math (grade school). HumanEval is code generation. Optimal code fraction for non-code benchmarks: 25-35%. Beyond 35%, non-code performance degrades.

MMLU improved by 6.9 points going from 0% to 25% code — a massive gain from data that has zero overlap with the MMLU evaluation domains (history, biology, philosophy, etc.). GSM8K (math) improved by 22.3 points. These are not statistical noise; they are consistent across model sizes and training scales.
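
The gains quoted above can be rechecked directly from the table. The sketch below copies the table rows into a dictionary, recomputes the two deltas, and finds which code fraction maximizes the plain average of the three non-code benchmarks. The unweighted average is an assumption made here for illustration; a real evaluation suite would weight benchmarks by importance.

```python
# Rows copied from the ablation table in this post: code % -> benchmark scores.
ablation = {
    0:  {"MMLU": 52.3, "ARC": 65.1, "HellaSwag": 74.2, "GSM8K": 12.8, "HumanEval": 0.0},
    5:  {"MMLU": 54.1, "ARC": 67.3, "HellaSwag": 75.8, "GSM8K": 18.4, "HumanEval": 12.1},
    15: {"MMLU": 57.8, "ARC": 71.2, "HellaSwag": 78.1, "GSM8K": 28.6, "HumanEval": 35.4},
    25: {"MMLU": 59.2, "ARC": 72.8, "HellaSwag": 79.4, "GSM8K": 35.1, "HumanEval": 42.7},
    35: {"MMLU": 58.9, "ARC": 72.4, "HellaSwag": 78.7, "GSM8K": 36.2, "HumanEval": 48.9},
    50: {"MMLU": 55.6, "ARC": 69.8, "HellaSwag": 76.3, "GSM8K": 33.8, "HumanEval": 54.2},
}
NON_CODE = ("MMLU", "ARC", "HellaSwag")

def mean_non_code(scores):
    """Unweighted mean over the non-code benchmarks (an illustrative choice)."""
    return sum(scores[b] for b in NON_CODE) / len(NON_CODE)

best = max(ablation, key=lambda pct: mean_non_code(ablation[pct]))
mmlu_gain = ablation[25]["MMLU"] - ablation[0]["MMLU"]
gsm_gain = ablation[25]["GSM8K"] - ablation[0]["GSM8K"]
print(f"best code fraction for the non-code average: {best}%")
print(f"MMLU gain 0% -> 25%: {mmlu_gain:.1f}, GSM8K gain: {gsm_gain:.1f}")
```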

Evidence 2: The structured reasoning hypothesis. Code is uniquely structured: it has explicit control flow (if/else, loops), hierarchical decomposition (functions calling functions), and precise semantics (a program either works or it does not). Training on code teaches the model to:

  • Decompose problems into sequential steps
  • Maintain state across long dependency chains
  • Apply conditional logic (“if X, then Y, else Z”)
  • Validate intermediate results (type checking, assertion patterns)

These are exactly the skills needed for multi-step reasoning in any domain.

Evidence 3: Representation quality. Li et al. (2024) showed that models trained with code develop higher-quality internal representations as measured by probing classifiers. Code-trained models have better linear separability of semantic concepts in their hidden states, even for non-code concepts.

The Optimal Code Fraction

The ablation data suggests a sweet spot at 20-30% code for general-purpose models:

$$\text{code}_{\text{optimal}} \approx 0.20 + 0.05 \times \mathbb{1}[\text{target tasks include coding}]$$

Below 20%, you leave significant reasoning capability on the table. Above 35%, the model starts losing fluency and world knowledge because web text is underrepresented. The exact optimum depends on the evaluation suite: if your deployment is code-heavy (coding assistants), push to 30-35%. If it is general-purpose (chatbot, question answering), 20-25% is safer.

Code Quality Matters More Than Quantity

Not all code is equal. Deduplication is critical — GitHub has massive redundancy (forks, copy-paste, vendored dependencies). After deduplication, quality filtering by repository stars, file-level metrics (length, comment ratio, lint scores), and language diversity produces a much smaller but far more effective code corpus. Llama 3 filters its code from ~10T raw GitHub tokens down to ~3.9T cleaned tokens — a 2.5x reduction.
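
As a minimal sketch of the first filtering stage, the snippet below performs exact deduplication by hashing whitespace-normalized file contents, which is enough to collapse forks and vendored copies that differ only in trailing whitespace or blank lines. The helper names and the sample files are hypothetical. Production pipelines typically add near-duplicate detection on top, which is outside this sketch.

```python
import hashlib

def normalize(source: str) -> str:
    """Strip trailing whitespace and blank lines so trivial edits hash identically."""
    lines = [line.rstrip() for line in source.splitlines()]
    return "\n".join(line for line in lines if line)

def dedup_exact(files: dict[str, str]) -> dict[str, str]:
    """Keep the first file seen for each distinct normalized content hash."""
    seen, kept = set(), {}
    for path, source in files.items():
        digest = hashlib.sha256(normalize(source).encode("utf-8")).hexdigest()
        if digest not in seen:
            seen.add(digest)
            kept[path] = source
    return kept

files = {
    "repo_a/util.py": "def add(a, b):\n    return a + b\n",
    # A fork that differs only in whitespace noise.
    "fork_of_a/util.py": "def add(a, b):   \n    return a + b\n\n",
    "repo_b/util.py": "def sub(a, b):\n    return a - b\n",
}
print(sorted(dedup_exact(files)))
```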


4. Math Data: Scarce, Synthetic, and Essential

The Scarcity Problem

Natural math data is orders of magnitude scarcer than web text or code:

Available Math Training Data (English)

| Source | Tokens (approx) | Quality | Coverage | Limitations |
| --- | --- | --- | --- | --- |
| arXiv papers (math.*) | ~80B | High (peer-reviewed) | Graduate-level | Notation-heavy, assumes context |
| Math textbooks (digitized) | ~20B | Very high (pedagogical) | K-12 through undergraduate | Small, copyright-constrained |
| Math StackExchange | ~5B | Medium-high (community-reviewed) | Mixed levels | Q&A format, fragmentary |
| Proof assistants (Lean, Coq) | ~2B | Very high (machine-verified) | Formal math only | Niche, requires specialized training |
| Web math (MathWorld, etc.) | ~10B | Medium | Reference material | Not reasoning-focused |
| TOTAL natural math | ~117B | - | - | Less than 1% of web data |

Note: Token counts estimated from published corpus sizes. Actual filtered, deduplicated totals are lower.

117B tokens of natural math is less than 1% of a 15T training corpus. At 1% representation, the model barely learns mathematical reasoning. At 8% (the Llama 3 target), you need ~1.2T tokens of math-related content. Where do the missing ~1.1T tokens come from? Synthetic generation.
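
The size of the gap is easy to verify. The sketch below sums the natural math sources from the table, compares them with an 8% share of a 15.6T-token run, and also reports how many epochs over the natural data would be needed if no synthetic data were used. All inputs are this post's estimates.

```python
# Natural math sources from the table above (approximate token counts).
natural_math = {
    "arxiv": 80e9,
    "textbooks": 20e9,
    "stackexchange": 5e9,
    "proof_assistants": 2e9,
    "web_math": 10e9,
}
corpus_tokens = 15.6e12   # Llama 3 scale, as quoted in this post
target_share = 0.08       # math share targeted in the mix

natural_total = sum(natural_math.values())
target_tokens = target_share * corpus_tokens
gap = target_tokens - natural_total
epochs_without_synthetic = target_tokens / natural_total

print(f"natural math: {natural_total / 1e9:.0f}B "
      f"({natural_total / corpus_tokens:.2%} of corpus)")
print(f"target: {target_tokens / 1e12:.2f}T, gap: {gap / 1e12:.2f}T")
print(f"epochs over natural data without synthetic: {epochs_without_synthetic:.1f}")
```

Filling the target with natural data alone would mean repeating it roughly ten times, well beyond the 2-4 repetitions mentioned at the start of this post.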

Synthetic Math Traces

The key insight: you do not need to generate new math problems. You need to generate reasoning traces — step-by-step solutions that show the work. A single problem can produce many different valid reasoning paths, and training on diverse solution strategies teaches the model to reason, not just to memorize answers.

import json
import random
from dataclasses import dataclass

@dataclass
class MathProblem:
    problem: str
    domain: str          # algebra, geometry, calculus, etc.
    difficulty: int       # 1-10
    reference_answer: str # Ground truth (for validation)

class SyntheticMathGenerator:
    """Generate synthetic math reasoning traces from seed problems.

    Pipeline:
    1. Take a seed problem with known answer
    2. Prompt a strong model to generate step-by-step solution
    3. Validate the final answer against ground truth
    4. Generate multiple solution strategies for the same problem
    5. Filter by correctness and reasoning quality
    """

    def __init__(self, teacher_model, validator):
        self.teacher = teacher_model
        self.validator = validator

    def generate_trace(self, problem, strategy="chain_of_thought"):
        """Generate a single reasoning trace for a problem.

        strategy: one of
        - "chain_of_thought": standard step-by-step
        - "backward": start from the answer, work backward
        - "analogy": solve a simpler version first, then generalize
        - "algebraic": purely symbolic manipulation
        - "numeric": plug in specific numbers, then generalize
        """
        strategy_prompts = {
            "chain_of_thought": (
                "Solve this problem step by step. Show every intermediate "
                "calculation. State what you are computing at each step and why.\n\n"
                f"Problem: {problem.problem}\n\nSolution:"
            ),
            "backward": (
                "Solve this problem by working backward from the answer form. "
                "Start by identifying what the answer must look like, then "
                "determine what conditions must hold, then verify.\n\n"
                f"Problem: {problem.problem}\n\nSolution:"
            ),
            "analogy": (
                "Solve this problem by first solving a simpler version of it, "
                "then generalizing. State the simplified problem, solve it, "
                "then extend to the full problem.\n\n"
                f"Problem: {problem.problem}\n\nSolution:"
            ),
            "algebraic": (
                "Solve this problem using purely algebraic manipulation. "
                "Define variables, set up equations, and solve symbolically. "
                "Minimize numerical computation.\n\n"
                f"Problem: {problem.problem}\n\nSolution:"
            ),
            "numeric": (
                "Solve this problem by substituting specific numerical examples "
                "to build intuition, then derive the general solution. "
                "Show at least two numerical examples before generalizing.\n\n"
                f"Problem: {problem.problem}\n\nSolution:"
            ),
        }

        prompt = strategy_prompts[strategy]
        trace = self.teacher.generate(prompt, max_tokens=2048, temperature=0.7)

        return trace

    def validate_trace(self, problem, trace):
        """Validate that a generated trace arrives at the correct answer."""
        # Extract the final answer from the trace
        extracted = self.validator.extract_answer(trace)

        if extracted is None:
            return {"valid": False, "reason": "no_answer_found"}

        # Compare against reference
        correct = self.validator.compare_answers(extracted, problem.reference_answer)

        if not correct:
            return {"valid": False, "reason": "wrong_answer",
                    "extracted": extracted, "expected": problem.reference_answer}

        # Check reasoning quality (no logical gaps, no hallucinated steps)
        quality = self.validator.check_reasoning_quality(trace)

        return {"valid": True, "quality_score": quality}

    def generate_batch(self, problems, traces_per_problem=5, min_quality=0.7):
        """Generate and validate traces for a batch of problems.

        For each problem, generate multiple traces using different
        strategies. Filter to correct traces above the quality threshold.
        """
        strategies = ["chain_of_thought", "backward", "analogy",
                      "algebraic", "numeric"]

        results = []
        stats = {"total_generated": 0, "correct": 0, "high_quality": 0}

        for problem in problems:
            problem_traces = []

            for i in range(traces_per_problem):
                strategy = strategies[i % len(strategies)]

                trace = self.generate_trace(problem, strategy)
                stats["total_generated"] += 1

                validation = self.validate_trace(problem, trace)

                if validation["valid"] and validation.get("quality_score", 0) >= min_quality:
                    stats["correct"] += 1
                    stats["high_quality"] += 1
                    problem_traces.append({
                        "problem": problem.problem,
                        "domain": problem.domain,
                        "difficulty": problem.difficulty,
                        "strategy": strategy,
                        "trace": trace,
                        "quality_score": validation["quality_score"],
                    })
                elif validation["valid"]:
                    stats["correct"] += 1

            results.extend(problem_traces)

        return results, stats
ℹ️ Validation Is Non-Negotiable

Without answer validation, ~30-40% of synthetic math traces contain errors — the teacher model makes mistakes. Training on incorrect reasoning traces teaches the student to make the same mistakes confidently. Every synthetic math trace must be validated against a ground-truth answer. For problems without ground truth, use majority voting across 5-10 independent generations: if 4 out of 5 traces agree on the answer, treat that as ground truth.
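
The majority-voting rule can be sketched in a few lines. The function below is a hypothetical helper, not part of the generator class in this post: it takes the extracted final answers from several independent generations and returns the consensus only if enough of them agree. Answers are compared as stripped strings, which is an assumption; real pipelines need to normalize equivalent forms such as `1/2` and `0.5` first.

```python
from collections import Counter

def majority_answer(answers, min_agreement=0.8):
    """Return the consensus answer, or None if agreement is below the threshold.

    `answers` holds extracted final answers; None marks a trace with no answer.
    """
    if not answers:
        return None
    found = [a.strip() for a in answers if a is not None]
    if not found:
        return None
    top, count = Counter(found).most_common(1)[0]
    # Agreement is measured against all generations, so traces without an
    # extractable answer count against consensus.
    return top if count / len(answers) >= min_agreement else None

consensus = majority_answer(["42", "42", "42 ", "41", "42"])   # 4 of 5 agree
no_consensus = majority_answer(["42", "41", "40", "42", None])
print(consensus, no_consensus)
```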

Scaling Synthetic Math

The key numbers for synthetic math production at scale:

Synthetic Math Generation Economics

| Metric | Value | Notes |
| --- | --- | --- |
| Seed problems available | ~2M | GSM8K, MATH, competition problems, textbook exercises |
| Traces per problem | 3-5 strategies | chain_of_thought, backward, analogy, algebraic, numeric |
| Raw traces generated | 6-10M | Before filtering |
| Pass rate (correct answer) | 60-70% | Depends on problem difficulty and teacher model |
| Quality filter pass rate | 70-80% of correct | Reasoning coherence, no gaps |
| Final high-quality traces | 3-5M | After all filtering |
| Tokens per trace (avg) | 400-800 | Detailed step-by-step |
| Total tokens produced | 1.2-4B | 3-5M traces x 400-800 tokens. This is roughly 0.01-0.03% of a 15T corpus, so reaching an 8% math share requires scaling seeds or traces per problem by several hundred times, or adding other synthetic sources |
| Compute cost (H100 hours) | 5K-20K | Using 70B teacher model |

Note: Costs assume using an open 70B model locally. API-based generation with GPT-4 class models costs 3-5x more but may produce higher quality traces.
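
A quick sanity check on pipeline volume: multiplying the table's seed count, traces per problem, pass rates, and tokens per trace gives the token yield of one pass over the seed set. The helper `trace_token_range` is hypothetical, and the inputs are this post's estimates.

```python
def trace_token_range(seed_problems, traces_per_problem, pass_rate,
                      quality_rate, tokens_per_trace):
    """Return (low, high) total tokens; every tuple argument is (low, high)."""
    low = (seed_problems * traces_per_problem[0] * pass_rate[0]
           * quality_rate[0] * tokens_per_trace[0])
    high = (seed_problems * traces_per_problem[1] * pass_rate[1]
            * quality_rate[1] * tokens_per_trace[1])
    return low, high

low, high = trace_token_range(
    seed_problems=2_000_000,
    traces_per_problem=(3, 5),
    pass_rate=(0.60, 0.70),
    quality_rate=(0.70, 0.80),
    tokens_per_trace=(400, 800),
)
target = 0.08 * 15e12  # 8% of a 15T-token corpus
print(f"one pass yields {low / 1e9:.1f}B to {high / 1e9:.1f}B tokens")
print(f"scale-up needed to reach the target: "
      f"{target / high:.0f}x to {target / low:.0f}x")
```

With these inputs one pass yields on the order of a few billion tokens, so a trillion-token math share implies several hundred times more seeds, traces per problem, or synthesized problems than a single pass provides.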

5. Curriculum Learning: Dynamic Mix Scheduling

Why Static Mixes Are Suboptimal

Training with a fixed data mix for the entire run wastes capacity. The model’s learning needs change as training progresses:

  • Early training (0-20% of tokens): The model learns basic language structure — grammar, common phrases, simple patterns. Diverse web data is most useful here because it provides the broadest coverage of basic linguistic patterns.
  • Mid training (20-70% of tokens): The model has basic fluency and begins learning more complex patterns — reasoning chains, factual associations, code structure. This is where increasing code and math proportions pays off most.
  • Late training (70-100% of tokens): The model is highly capable but benefits from targeted quality. High-quality sources (textbooks, curated math, clean code) should be overweighted. Low-quality web data adds noise at this stage.

The Llama 3 Curriculum

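The Llama 3 paper describes adjusting the data mix during training and annealing on high-quality data near the end of the run. The three-phase breakdown below is an approximate reconstruction of that curriculum: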

Llama 3 Training Curriculum (Approximate)

| Phase | Token Range | Web % | Code % | Math % | Key Change |
| --- | --- | --- | --- | --- | --- |
| Phase 1: Foundation | 0 - 8T | 55 | 20 | 5 | Heavy web for diversity |
| Phase 2: Code + Reasoning | 8T - 13T | 45 | 30 | 10 | Increase code and math |
| Phase 3: Quality Focus | 13T - 15.6T | 40 | 25 | 15 | Upweight high-quality, add synthetic |

Note: These are reconstructed from Llama 3 paper descriptions and are approximate. The actual transitions are smoother (linear interpolation over billions of tokens).

The math fraction triples from Phase 1 to Phase 3. Code increases by 50% in Phase 2 then decreases slightly in Phase 3 (because the remaining Phase 3 code is higher quality). Web data decreases throughout as the model’s needs shift from breadth to depth.

Implementation: Configurable Data Mixer

import numpy as np
from dataclasses import dataclass

@dataclass
class DataSource:
    """A data source with tokens stored in sharded files."""
    name: str
    shard_paths: list      # List of file paths
    total_tokens: int       # Total tokens across all shards
    quality_score: float    # 0-1, used for quality-weighted sampling
    current_shard: int = 0
    current_offset: int = 0


class DataMixer:
    """Multi-source data mixer with configurable weights and scheduling.

    Supports:
    - Static mixing: fixed weights for the entire training run
    - Linear scheduling: linearly interpolate between start and end weights
    - Phase-based scheduling: different weights for each training phase
    - Temperature-based: upweight high-quality sources as training progresses
    """

    def __init__(self, sources, total_tokens, seed=42):
        self.sources = {s.name: s for s in sources}
        self.total_tokens = total_tokens
        self.rng = np.random.RandomState(seed)
        self.tokens_consumed = 0

        # Validate sources have enough data
        for source in sources:
            if source.total_tokens == 0:
                raise ValueError(f"Source {source.name} has 0 tokens")

    def set_static_weights(self, weights):
        """Set fixed mixing weights for the entire run.

        weights: dict mapping source name to weight (will be normalized)
        """
        total = sum(weights.values())
        self.weight_fn = lambda step: {k: v / total for k, v in weights.items()}

    def set_linear_schedule(self, start_weights, end_weights):
        """Linearly interpolate between start and end weights.

        At token 0, use start_weights.
        At total_tokens, use end_weights.
        """
        # Normalize
        s_total = sum(start_weights.values())
        e_total = sum(end_weights.values())
        start_norm = {k: v / s_total for k, v in start_weights.items()}
        end_norm = {k: v / e_total for k, v in end_weights.items()}

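        def weight_fn(tokens_so_far):
            alpha = tokens_so_far / self.total_tokens
            alpha = min(1.0, max(0.0, alpha))
            # Iterate over the union of keys so a source that appears only in
            # end_weights is phased in instead of being silently dropped
            # (which would also leave the weights summing to less than 1).
            return {
                k: start_norm.get(k, 0) * (1 - alpha) + end_norm.get(k, 0) * alpha
                for k in set(start_norm) | set(end_norm)
            }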

        self.weight_fn = weight_fn

    def set_phase_schedule(self, phases):
        """Phase-based scheduling with smooth transitions.

        phases: list of (token_boundary, weights) tuples.
        Example:
            [
                (0, {"web": 55, "code": 20, "math": 5}),
                (8e12, {"web": 45, "code": 30, "math": 10}),
                (13e12, {"web": 40, "code": 25, "math": 15}),
            ]

        Between phases, weights are linearly interpolated.
        """
        # Normalize all phase weights
        normalized = []
        for boundary, weights in phases:
            total = sum(weights.values())
            normalized.append((boundary, {k: v / total for k, v in weights.items()}))

        def weight_fn(tokens_so_far):
            # Find the two surrounding phases
            if tokens_so_far <= normalized[0][0]:
                return normalized[0][1]
            if tokens_so_far >= normalized[-1][0]:
                return normalized[-1][1]

            for i in range(len(normalized) - 1):
                b1, w1 = normalized[i]
                b2, w2 = normalized[i + 1]
                if b1 <= tokens_so_far <= b2:
                    alpha = (tokens_so_far - b1) / (b2 - b1) if b2 > b1 else 0
                    return {
                        k: w1.get(k, 0) * (1 - alpha) + w2.get(k, 0) * alpha
                        for k in set(w1) | set(w2)
                    }

            return normalized[-1][1]

        self.weight_fn = weight_fn

    def set_temperature_schedule(self, base_weights, quality_temp_start=1.0,
                                  quality_temp_end=0.3):
        """Temperature-based scheduling that upweights quality over time.

        At each step, weights are: w_i * quality_i^(1/T) where T decreases
        from quality_temp_start to quality_temp_end during training.
        Lower T = sharper distribution toward high-quality sources.
        """
        total = sum(base_weights.values())
        base_norm = {k: v / total for k, v in base_weights.items()}

        def weight_fn(tokens_so_far):
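            # Clamp progress to [0, 1] so the temperature never drops below
            # quality_temp_end (or to zero) if training overruns total_tokens.
            alpha = min(1.0, max(0.0, tokens_so_far / self.total_tokens))
            temp = quality_temp_start + alpha * (quality_temp_end - quality_temp_start)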

            raw = {}
            for name, base_w in base_norm.items():
                quality = self.sources[name].quality_score
                raw[name] = base_w * (quality ** (1.0 / temp))

            total_raw = sum(raw.values())
            return {k: v / total_raw for k, v in raw.items()}

        self.weight_fn = weight_fn

    def sample_batch(self, batch_size_tokens):
        """Sample a batch of tokens according to current mixing weights.

        Returns: dict mapping source name to number of tokens to draw.
        """
        weights = self.weight_fn(self.tokens_consumed)

        # Multinomial sampling of batch across sources
        source_names = list(weights.keys())
        probs = np.array([weights[name] for name in source_names])
        probs = probs / probs.sum()  # Ensure normalization

        # Allocate tokens to sources
        allocations = self.rng.multinomial(batch_size_tokens, probs)

        result = {}
        for name, count in zip(source_names, allocations):
            if count > 0:
                result[name] = int(count)

        self.tokens_consumed += batch_size_tokens

        return result

    def get_current_weights(self):
        """Return the current mixing weights."""
        return self.weight_fn(self.tokens_consumed)

    def get_progress(self):
        """Return training progress and current mix state."""
        weights = self.weight_fn(self.tokens_consumed)
        return {
            "tokens_consumed": self.tokens_consumed,
            "total_tokens": self.total_tokens,
            "progress_pct": self.tokens_consumed / self.total_tokens * 100,
            "current_weights": weights,
        }
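
To see what the temperature schedule does to the mix, here is a standalone sketch of the same formula, $w_i \cdot q_i^{1/T}$ followed by normalization, evaluated at the default start and end temperatures. The function name, base weights, and quality scores are illustrative assumptions.

```python
def quality_tempered_weights(base, quality, temp):
    """Return normalized weights proportional to base_i * quality_i ** (1 / temp)."""
    if temp <= 0:
        raise ValueError("temp must be positive")
    raw = {k: base[k] * quality[k] ** (1.0 / temp) for k in base}
    total = sum(raw.values())
    return {k: v / total for k, v in raw.items()}

base = {"web": 0.5, "code": 0.3, "math": 0.2}
quality = {"web": 0.4, "code": 0.7, "math": 0.9}   # illustrative quality scores

warm = quality_tempered_weights(base, quality, temp=1.0)   # start of training
cold = quality_tempered_weights(base, quality, temp=0.3)   # end of training
for label, w in (("T=1.0", warm), ("T=0.3", cold)):
    print(label, {k: round(v, 3) for k, v in w.items()})
```

Lowering the temperature raises the exponent on the quality score, so the low-quality web source loses share and the high-quality math source gains it, even though the base weights never change.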

Complete Training Loop with Dynamic Mixing

class MixedDataLoader:
    """DataLoader that draws from multiple sources with dynamic mixing.

    Integrates with the DataMixer to produce batches where each
    micro-batch contains tokens from multiple sources in the
    proportions specified by the current schedule.
    """

    def __init__(self, mixer, tokenizer, seq_length=4096,
                 micro_batch_tokens=None):
        self.mixer = mixer
        self.tokenizer = tokenizer
        self.seq_length = seq_length
        self.micro_batch_tokens = micro_batch_tokens or seq_length * 8

        # Per-source token buffers
        self.buffers = {name: [] for name in mixer.sources}
        self.source_readers = {
            name: self._make_reader(source)
            for name, source in mixer.sources.items()
        }

    def _make_reader(self, source):
        """Create a generator that yields tokens from a source."""
        while True:
            for shard_path in source.shard_paths:
                tokens = self._read_shard(shard_path)
                yield from tokens
            # Wrap around if we exhaust all shards

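    def _read_shard(self, path):
        """Read tokens from a shard file."""
        # Implementation depends on format (bin, jsonl, arrow, etc.)
        # Returns list of token IDs
        import struct
        with open(path, 'rb') as f:
            data = f.read()
        # Assume little-endian uint16 token IDs. uint16 only covers vocabularies
        # of up to 65,536 entries; larger vocabularies need a wider type (e.g. uint32).
        # Drop a trailing odd byte so a truncated shard does not raise struct.error.
        count = len(data) // 2
        return list(struct.unpack(f'<{count}H', data[:count * 2]))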

    def _fill_buffer(self, source_name, num_tokens):
        """Fill a source buffer with tokens."""
        reader = self.source_readers[source_name]
        while len(self.buffers[source_name]) < num_tokens:
            self.buffers[source_name].append(next(reader))

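    def get_batch(self):
        """Get one training batch with dynamically mixed data.

        Returns: tensor of shape [num_seqs, seq_length], or None if the
        micro-batch is smaller than one sequence.
        """
        import numpy as np
        import torch

        num_seqs = self.micro_batch_tokens // self.seq_length
        if num_seqs == 0:
            return None

        # Allocate whole sequences (not individual tokens) to sources, so every
        # sequence is a contiguous token stream from a single source.
        weights = self.mixer.get_current_weights()
        names = list(weights.keys())
        probs = np.array([weights[n] for n in names], dtype=np.float64)
        probs = probs / probs.sum()
        seq_counts = self.mixer.rng.multinomial(num_seqs, probs)

        sequences = []
        for name, count in zip(names, seq_counts):
            need = int(count) * self.seq_length
            if need == 0:
                continue
            self._fill_buffer(name, need)
            buf = self.buffers[name]
            for start in range(0, need, self.seq_length):
                sequences.append(buf[start:start + self.seq_length])
            self.buffers[name] = buf[need:]

        # Shuffle the order of sequences, never the tokens inside a sequence:
        # shuffling individual tokens would destroy the text the model learns from.
        order = self.mixer.rng.permutation(len(sequences))
        sequences = [sequences[i] for i in order]

        # Advance the schedule by the number of tokens actually emitted.
        self.mixer.tokens_consumed += num_seqs * self.seq_length

        return torch.tensor(sequences, dtype=torch.long)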

    def log_mixing_state(self, step, logger=None):
        """Log current mixing proportions for monitoring."""
        state = self.mixer.get_progress()
        weights = state["current_weights"]

        log_msg = (
            f"Step {step} | "
            f"Progress: {state['progress_pct']:.1f}% | "
            f"Mix: " + ", ".join(
                f"{k}={v:.1%}" for k, v in sorted(weights.items(),
                                                    key=lambda x: -x[1])
            )
        )

        if logger:
            logger.info(log_msg)
            for name, weight in weights.items():
                logger.log_scalar(f"data_mix/{name}", weight, step)
        else:
            print(log_msg)

Example: Setting Up the Llama 3 Curriculum

def create_llama3_mixer():
    """Create a data mixer matching the Llama 3 training curriculum."""

    sources = [
        DataSource("web", shard_paths=["web_shard_{:05d}.bin".format(i) for i in range(10000)],
                   total_tokens=int(8e12), quality_score=0.4),
        DataSource("code", shard_paths=["code_shard_{:05d}.bin".format(i) for i in range(3000)],
                   total_tokens=int(4e12), quality_score=0.7),
        DataSource("math", shard_paths=["math_shard_{:05d}.bin".format(i) for i in range(500)],
                   total_tokens=int(1.5e12), quality_score=0.9),
        DataSource("books", shard_paths=["book_shard_{:05d}.bin".format(i) for i in range(200)],
                   total_tokens=int(0.8e12), quality_score=0.85),
        DataSource("wiki", shard_paths=["wiki_shard_{:05d}.bin".format(i) for i in range(100)],
                   total_tokens=int(0.3e12), quality_score=0.95),
        DataSource("multilingual", shard_paths=["ml_shard_{:05d}.bin".format(i) for i in range(1500)],
                   total_tokens=int(1.5e12), quality_score=0.5),
    ]

    mixer = DataMixer(sources, total_tokens=int(15.6e12))

    # Set Llama 3-style phase schedule
    mixer.set_phase_schedule([
        (0,       {"web": 55, "code": 20, "math": 5, "books": 5, "wiki": 5, "multilingual": 10}),
        (8e12,    {"web": 45, "code": 30, "math": 10, "books": 5, "wiki": 3, "multilingual": 7}),
        (13e12,   {"web": 40, "code": 25, "math": 15, "books": 6, "wiki": 4, "multilingual": 10}),
    ])

    return mixer


def training_loop_example():
    """Example training loop with dynamic data mixing."""
    mixer = create_llama3_mixer()
    loader = MixedDataLoader(mixer, tokenizer=None, seq_length=4096)

    total_steps = int(15.6e12 / (4096 * 8))  # seq_length * batch_size

    for step in range(total_steps):
        batch = loader.get_batch()
        if batch is None:
            continue

        # Forward pass, loss computation, backward pass, optimizer step
        # ... (standard training loop) ...

        # Log mixing state periodically
        if step % 1000 == 0:
            loader.log_mixing_state(step)

        # Expected log output near selected token counts, derived from the
        # phase schedule in create_llama3_mixer (step numbers depend on batch size):
        # ~0T     | Progress: 0.0%  | Mix: web=55.0%, code=20.0%, multilingual=10.0%, ...
        # ~4T     | Progress: 25.6% | Mix: web=50.0%, code=25.0%, multilingual=8.5%, ...
        # ~8T     | Progress: 51.3% | Mix: web=45.0%, code=30.0%, math=10.0%, ...
        # ~10.5T  | Progress: 67.3% | Mix: web=42.5%, code=27.5%, math=12.5%, ...
        # ~13T+   | Progress: 83.3% | Mix: web=40.0%, code=25.0%, math=15.0%, ...

6. Domain-Specific Upweighting and Downweighting

When to Deviate from Standard Mixes

The Llama 3 mix is optimized for a general-purpose model. Domain-specific models should adjust:

Data Mix Adjustments by Target Application

| Target Application | Code % | Math % | Web % | Rationale |
| --- | --- | --- | --- | --- |
| General chatbot | 20-25 | 5-8 | 50-55 | Balance fluency with reasoning |
| Coding assistant | 35-45 | 10-12 | 30-35 | Maximize code capability, keep NL fluency |
| Math/science model | 20-25 | 20-30 | 30-35 | Heavy math, code for structured reasoning |
| Medical/legal domain | 10-15 | 3-5 | 40-50 | Domain web data replaces general web |
| Multilingual model | 15-20 | 5-8 | 35-40 | 30%+ for multilingual sources |
| Small model (1-3B) | 25-30 | 10-15 | 35-40 | More structured data compensates for capacity |

Note: Percentages are guidelines. Remaining tokens filled with books, wiki, conversation, and domain-specific data. Small models benefit disproportionately from structured data (code, math) because they have less capacity to extract structure from noisy web data.
💡 Small Models Need More Structure

A counterintuitive finding: smaller models (1B-7B parameters) benefit from higher code and math fractions than larger models. The hypothesis is that small models have limited capacity to learn implicit structure from noisy web data, so code and math, which have explicit structure, provide a more efficient learning signal per token. Microsoft’s Phi series is the usual example: Phi-3 (3.8B params) is described as using 25% code and 15% math. Set against Llama 3 405B’s 25% and 8% respectively, that is the same code share but nearly double the math share.

Data Repetition and Epoch Counting

When a source is smaller than its allocation, tokens must be repeated. Repetition is acceptable up to a point:

$$\text{effective\_epochs}(s) = \frac{w_s \times T}{\text{size}(s)}$$

where $w_s$ is the weight for source $s$ and $T$ is the total number of training tokens.

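def compute_repetition(mixer, verbose=True):
    """Estimate how many times each source is repeated during training.

    Uses the weights at the midpoint of training as a single-point proxy for
    the whole schedule, so the result is approximate when the mix changes.
    Assumes weight_fn returns fractions that sum to 1.
    """
    weights = mixer.weight_fn(mixer.total_tokens / 2)  # Mid-training weights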

    repetitions = {}
    for name, source in mixer.sources.items():
        allocated = weights.get(name, 0) * mixer.total_tokens
        epochs = allocated / source.total_tokens if source.total_tokens > 0 else float('inf')
        repetitions[name] = {
            "allocated_tokens": allocated,
            "available_tokens": source.total_tokens,
            "epochs": epochs,
            "repeated": epochs > 1.0,
        }

        if verbose:
            status = "REPEATED" if epochs > 1.0 else "ok"
            print(f"  {name}: {epochs:.2f} epochs "
                  f"({allocated/1e12:.1f}T allocated / "
                  f"{source.total_tokens/1e12:.1f}T available) [{status}]")

    return repetitions
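
Because the mix changes over the run, weights sampled at a single point only approximate each source's share of the total budget. A minimal sketch of a schedule-aware alternative follows. It assumes the schedule is a list of (token_position, weights) breakpoints like the one passed to the mixer above, that weights are linearly interpolated between breakpoints, and that the final weights hold until the end of training. `average_schedule_weights` is a hypothetical helper, not part of the mixer.

```python
def average_schedule_weights(schedule, total_tokens):
    """Return each source's token-weighted average share over a whole run.

    schedule: list of (token_position, {source: weight}) sorted by position,
        with the first breakpoint at 0 and the last at or before total_tokens.
    Assumes linear interpolation between breakpoints. This is an assumption
    about the mixer, not something confirmed by its implementation.
    """
    if total_tokens <= 0:
        raise ValueError("total_tokens must be positive")
    if not schedule or schedule[0][0] != 0:
        raise ValueError("schedule must start at token position 0")
    if schedule[-1][0] > total_tokens:
        raise ValueError("last breakpoint lies beyond total_tokens")

    points = list(schedule)
    if points[-1][0] < total_tokens:
        # Hold the final weights until the end of training.
        points.append((total_tokens, points[-1][1]))

    names = {name for _, weights in points for name in weights}
    area = {name: 0.0 for name in names}
    for (t0, w0), (t1, w1) in zip(points, points[1:]):
        for name in names:
            # Trapezoid rule: exact when weights vary linearly in a segment.
            area[name] += 0.5 * (w0.get(name, 0) + w1.get(name, 0)) * (t1 - t0)

    total_area = sum(area.values())
    if total_area == 0:
        raise ValueError("all weights are zero")
    # Normalize so the returned shares sum to 1.
    return {name: a / total_area for name, a in area.items()}


schedule = [
    (0,     {"web": 55, "code": 20, "math": 5, "books": 5, "wiki": 5, "multilingual": 10}),
    (8e12,  {"web": 45, "code": 30, "math": 10, "books": 5, "wiki": 3, "multilingual": 7}),
    (13e12, {"web": 40, "code": 25, "math": 15, "books": 6, "wiki": 4, "multilingual": 10}),
]
avg = average_schedule_weights(schedule, total_tokens=15.6e12)
for name, share in sorted(avg.items()):
    print(f"{name}: {share:.1%}")
```

Multiplying each averaged share by the total budget gives the allocated tokens per source, which can then be divided by the source size to get epochs.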

Research findings on repetition:

  • 1-2 epochs: Negligible quality loss. The model benefits from seeing data twice, similar to multi-epoch training in supervised learning.
  • 2-4 epochs: Measurable but small quality loss (0.5-1% on benchmarks). Acceptable for high-quality sources like math and books.
  • 4+ epochs: Significant quality degradation. The model begins memorizing rather than generalizing. Training loss continues to decrease, but validation loss stalls or increases.
📊

Impact of Data Repetition on Benchmark Quality (7B model)

Epochs | Training Loss | Val Loss | MMLU | Memorization %
1.0 | 2.84 | 2.91 | 57.2 | 2.1%
2.0 | 2.71 | 2.89 | 56.8 | 3.4%
4.0 | 2.52 | 2.93 | 55.1 | 8.7%
8.0 | 2.31 | 3.08 | 52.4 | 18.3%
16.0 | 2.05 | 3.31 | 48.9 | 34.6%
Note: Training loss continues improving with repetition (memorization). Validation loss and benchmark performance degrade beyond 2 epochs. Memorization % measured by exact-match of training sequences in model output.

The prescription: keep all sources below 4 epochs. For sources that would exceed this (math, books), generate synthetic data to increase the effective pool size rather than repeating.
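
As a worked instance of the epoch formula and the synthetic-data prescription, the sketch below uses made-up numbers: a 300B-token math source given a 10% weight in a 15T-token run. None of these figures come from a published mix; they are chosen only to make the arithmetic easy to follow.

```python
# Hypothetical inputs, chosen only to illustrate the arithmetic.
total_tokens = 15e12      # T: total training tokens
weight = 0.10             # w_s: fraction of the mix assigned to math
unique_tokens = 300e9     # size(s): unique math tokens available
max_epochs = 4.0          # repetition ceiling from the prescription above

allocated = weight * total_tokens      # tokens drawn from this source
epochs = allocated / unique_tokens     # effective_epochs(s)

# Smallest pool that keeps repetition at max_epochs, and the shortfall
# that synthetic data would have to cover.
required_pool = allocated / max_epochs
synthetic_needed = max(0.0, required_pool - unique_tokens)

print(f"allocated: {allocated / 1e12:.2f}T tokens")
print(f"epochs: {epochs:.2f}")
print(f"synthetic needed: {synthetic_needed / 1e9:.0f}B tokens")
```

With these inputs the source would be seen 5 times, and roughly 75B synthetic tokens would bring it back to 4 epochs.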


7. Monitoring and Adjusting During Training

Online Monitoring

Track per-source loss during training to detect when a source is “exhausted” (the model has learned everything it can from it) or “underserved” (high loss indicating the model needs more examples):

class MixMonitor:
    """Monitor per-source training loss to guide mix adjustments.

    Tracks exponential moving average of loss per data source.
    Alerts when a source's loss plateaus (model has saturated)
    or diverges (model is forgetting).
    """

    def __init__(self, source_names, ema_alpha=0.99):
        self.ema_alpha = ema_alpha
        self.ema_loss = {name: None for name in source_names}
        self.loss_history = {name: [] for name in source_names}
        self.step = 0

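    def update(self, source_name, loss_value, step=None):
        """Update the EMA loss for a source.

        Pass the global training step as `step` when several sources are
        updated per step; otherwise an internal counter advances once per call.
        """
        if step is not None:
            self.step = step
        if self.ema_loss[source_name] is None:
            # First observation seeds the EMA
            self.ema_loss[source_name] = loss_value
        else:
            self.ema_loss[source_name] = (
                self.ema_alpha * self.ema_loss[source_name] +
                (1 - self.ema_alpha) * loss_value
            )
        self.loss_history[source_name].append(
            (self.step, loss_value, self.ema_loss[source_name])
        )
        if step is None:
            self.step += 1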

    def get_recommendations(self, window=1000):
        """Analyze loss trends and recommend mix adjustments.

        Returns recommendations:
        - "increase": source loss rose by more than 2% between windows (possible forgetting)
        - "decrease": source loss has plateaued (under 0.5% change), diminishing returns
        - "maintain": anything else (steady decrease, mild increase, or too little history)
        """
        recommendations = {}

        for name, history in self.loss_history.items():
            if len(history) < window * 2:
                recommendations[name] = {"action": "maintain", "reason": "insufficient_data"}
                continue

            recent = [h[1] for h in history[-window:]]
            older = [h[1] for h in history[-2*window:-window]]

            recent_mean = sum(recent) / len(recent)
            older_mean = sum(older) / len(older)

            # Loss change rate
            change_rate = (recent_mean - older_mean) / older_mean if older_mean > 0 else 0

            if abs(change_rate) < 0.005:
                # Loss plateaued (less than 0.5% change)
                recommendations[name] = {
                    "action": "decrease",
                    "reason": "loss_plateau",
                    "loss_change_pct": change_rate * 100,
                }
            elif change_rate > 0.02:
                # Loss increasing (model forgetting this domain)
                recommendations[name] = {
                    "action": "increase",
                    "reason": "loss_increasing",
                    "loss_change_pct": change_rate * 100,
                }
            else:
                # Either a steady decrease or a mild increase (0.5% to 2%)
                recommendations[name] = {
                    "action": "maintain",
                    "reason": "healthy_decrease" if change_rate < 0 else "mild_increase",
                    "loss_change_pct": change_rate * 100,
                }

        return recommendations

    def print_status(self):
        """Print current loss status for all sources."""
        print(f"Step {self.step} | Per-source EMA loss:")
        for name, ema in sorted(self.ema_loss.items()):
            if ema is not None:
                print(f"  {name}: {ema:.4f}")
⚠️ Do Not Auto-Adjust in Large Runs

While online monitoring is valuable for diagnostics, automatically adjusting mix proportions during a large training run is risky. The loss signal is noisy, and the model’s loss on a source depends on all other sources (cross-domain transfer). A decrease in math loss might be caused by increasing code fraction, not by the math data itself. Adjustments should be made between training restarts (checkpoints), not dynamically within a run, unless you have extensively validated the adjustment policy on smaller-scale experiments.
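
One way to make the noise concern concrete is to check whether a shift in windowed mean loss is large relative to its own sampling noise before reading anything into it. The sketch below is an illustrative guard, not part of `MixMonitor`. The helper name `change_is_significant` and the threshold `z=2.0` are assumptions, and the standard-error formula treats per-step losses as independent, which real training losses only approximate. It also says nothing about cross-domain transfer, which no per-source statistic can detect.

```python
import math
import statistics


def change_is_significant(older, recent, z=2.0):
    """Return True when the shift in mean loss exceeds z standard errors.

    older, recent: per-step loss values from two consecutive windows.
    Assumes losses within a window are roughly independent samples.
    """
    if len(older) < 2 or len(recent) < 2:
        raise ValueError("each window needs at least two loss values")

    diff = statistics.fmean(recent) - statistics.fmean(older)
    # Standard error of the difference between two sample means.
    std_err = math.sqrt(
        statistics.variance(older) / len(older)
        + statistics.variance(recent) / len(recent)
    )
    if std_err == 0.0:
        return diff != 0.0
    return abs(diff) > z * std_err


# Deterministic toy windows: same level versus a clear drop.
flat = change_is_significant([2.0, 2.2] * 50, [2.0, 2.2] * 50)
drop = change_is_significant([2.0, 2.2] * 50, [1.0, 1.2] * 50)
print(flat, drop)
```

A check like this only filters out changes that are indistinguishable from noise. Whether a significant change should alter the mix is still a decision for offline validation.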


Reviewer Agent Validation

Challenge: Implement a function that, given a dictionary of data sources (name to token count), a total training budget, and a set of desired mixing weights, computes: (1) the number of epochs each source will be repeated, (2) whether any source exceeds a maximum repetition threshold, and (3) the minimum amount of synthetic data needed for over-repeated sources to bring them below the threshold.

Expected:

def plan_data_mix(sources_tokens, total_budget, weights, max_epochs=4.0):
    total_w = sum(weights.values())
    normalized = {k: v / total_w for k, v in weights.items()}

    plan = {}
    synthetic_needed = {}
    for name in sources_tokens:
        allocated = normalized.get(name, 0) * total_budget
        available = sources_tokens[name]
        epochs = allocated / available if available > 0 else float('inf')
        over_limit = epochs > max_epochs

        plan[name] = {
            "allocated_tokens": allocated,
            "available_tokens": available,
            "epochs": epochs,
            "over_limit": over_limit,
        }

        if over_limit:
            # Need enough synthetic to bring epochs to max_epochs
            # allocated / (available + synthetic) = max_epochs
            # synthetic = allocated / max_epochs - available
            needed = allocated / max_epochs - available
            synthetic_needed[name] = max(0, needed)

    return {"plan": plan, "synthetic_needed": synthetic_needed}