Chinchilla scaling laws prescribe a fixed ratio: train on approximately 20 tokens per parameter. A 7B model needs 140B tokens. A 70B model needs 1.4T tokens. But these laws assume random data sampling. If we select data more carefully — filtering for quality, ordering by difficulty, reweighting domains — we can achieve Chinchilla-optimal performance with fewer tokens. This matters because data is finite (the internet has an estimated 15-20T unique high-quality tokens) and training compute is expensive.
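For concreteness, here is the Chinchilla arithmetic together with the smaller budgets that data selection aims for (the 30-50% savings figure is this post's claim, not part of the scaling law itself):

```python
def chinchilla_tokens(n_params: float, ratio: float = 20.0) -> float:
    """Chinchilla-optimal token budget at ~20 tokens per parameter."""
    return n_params * ratio

for n_params in (7e9, 70e9):
    budget = chinchilla_tokens(n_params)
    # With careful data selection, 50-70% of the budget may suffice
    # (i.e., the 30-50% savings claimed in this post).
    print(f"{n_params / 1e9:.0f}B params -> {budget / 1e9:.0f}B tokens "
          f"(selected: {budget * 0.5 / 1e9:.0f}-{budget * 0.7 / 1e9:.0f}B)")
```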
Data-efficient training encompasses three strategies. Curriculum learning orders training data from easy to hard, following the same principle used in human education. Quality filtering removes low-quality data before training, using proxy models to estimate each document’s contribution to model quality. Experience replay revisits high-value examples multiple times rather than seeing each example once. Together, these techniques reduce the token budget by 30-50% for equivalent model quality.
This post covers the mathematics and implementation of data-efficient training: curriculum learning schedules, quality filtering with proxy models, experience replay buffers, online data selection during training, and the DoReMi domain reweighting algorithm.
Curriculum Learning for LLMs
Ordering Data by Difficulty
```python
import numpy as np
from dataclasses import dataclass
from enum import Enum


class DifficultyMetric(Enum):
    PERPLEXITY = "perplexity"
    TOKEN_RARITY = "token_rarity"
    SEQUENCE_LENGTH = "sequence_length"
    SYNTAX_COMPLEXITY = "syntax_complexity"
    DOMAIN_SPECIFICITY = "domain_specificity"


@dataclass
class CurriculumConfig:
    """Configuration for curriculum learning."""
    strategy: str = "linear"  # linear, competence, sqrt
    initial_difficulty: float = 0.0
    final_difficulty: float = 1.0
    warmup_fraction: float = 0.1
    difficulty_metric: DifficultyMetric = DifficultyMetric.PERPLEXITY


class CurriculumScheduler:
    """
    Schedule data difficulty during training.

    Curriculum strategies:
    1. Linear: difficulty increases linearly from 0 to 1
    2. Competence-based: difficulty increases as the model
       improves (measured by validation loss)
    3. Square root: difficulty ~ sqrt(training_progress)
       (fast initial ramp, slow final approach)
    4. Anti-curriculum: start with the hardest data first
       (surprisingly effective in some settings)
    """

    def __init__(self, config):
        self.config = config

    def get_difficulty_threshold(self, progress):
        """
        Get the current difficulty threshold.

        progress: fraction of training completed (0 to 1)
        Returns: maximum difficulty level to include
        (0 = easiest only, 1 = all data)
        """
        if progress < self.config.warmup_fraction:
            # Warmup: only the easiest data
            return self.config.initial_difficulty

        adjusted_progress = (
            (progress - self.config.warmup_fraction)
            / (1.0 - self.config.warmup_fraction)
        )

        strategy = self.config.strategy
        if strategy == "linear":
            threshold = (
                self.config.initial_difficulty
                + adjusted_progress
                * (self.config.final_difficulty - self.config.initial_difficulty)
            )
        elif strategy == "sqrt":
            threshold = (
                self.config.initial_difficulty
                + np.sqrt(adjusted_progress)
                * (self.config.final_difficulty - self.config.initial_difficulty)
            )
        elif strategy == "competence":
            # S-shaped (sigmoid) curve centered at the midpoint of training
            threshold = (
                self.config.initial_difficulty
                + (1.0 / (1.0 + np.exp(-10 * (adjusted_progress - 0.5))))
                * (self.config.final_difficulty - self.config.initial_difficulty)
            )
        else:
            threshold = self.config.final_difficulty

        return min(threshold, self.config.final_difficulty)


class DataDifficultyScorer:
    """
    Score documents by difficulty for curriculum ordering.

    Difficulty metrics:
    1. Perplexity under a reference model (higher = harder)
    2. Token rarity: fraction of rare tokens (higher = harder)
    3. Sequence length: longer sequences are harder
    4. Domain specificity: technical domains are harder
    """

    def __init__(self, reference_model=None, tokenizer=None):
        self.reference_model = reference_model
        self.tokenizer = tokenizer

    def score_document(self, text):
        """
        Compute a difficulty score for a document.
        Returns a score in [0, 1] where 0 is easiest.
        """
        scores = {}

        # Perplexity-based difficulty
        if self.reference_model:
            ppl = self._compute_perplexity(text)
            # Normalize: most text has perplexity 5-50
            scores["perplexity"] = min(np.log(ppl) / np.log(100), 1.0)

        # Token rarity
        if self.tokenizer:
            scores["token_rarity"] = self._compute_rarity(text)

        # Length-based
        word_count = len(text.split())
        scores["length"] = min(word_count / 5000, 1.0)

        # Combine scores (weighted average over the metrics available)
        weights = {
            "perplexity": 0.5,
            "token_rarity": 0.3,
            "length": 0.2,
        }
        total_weight = sum(weights.get(k, 0) for k in scores)
        if total_weight == 0:
            return 0.5
        difficulty = sum(
            scores[k] * weights.get(k, 0) for k in scores
        ) / total_weight
        return float(difficulty)

    def _compute_perplexity(self, text):
        """Compute perplexity under the reference model."""
        return 10.0  # Placeholder

    def _compute_rarity(self, text):
        """Compute the fraction of rare tokens."""
        if self.tokenizer is None:
            return 0.5
        tokens = self.tokenizer.encode(text)
        # In practice, compare tokens against a token frequency table
        return 0.3  # Placeholder

    def score_batch(self, texts):
        """Score a batch of documents."""
        return [self.score_document(t) for t in texts]
```
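As a quick sanity check, the schedule formulas above can be exercised standalone. This sketch re-implements `get_difficulty_threshold` as a free function with the default config values baked in:

```python
import numpy as np

def difficulty_threshold(progress, strategy="linear",
                         warmup=0.1, lo=0.0, hi=1.0):
    """Standalone version of CurriculumScheduler.get_difficulty_threshold."""
    if progress < warmup:
        return lo  # warmup phase: easiest data only
    p = (progress - warmup) / (1.0 - warmup)
    if strategy == "linear":
        t = lo + p * (hi - lo)
    elif strategy == "sqrt":
        t = lo + np.sqrt(p) * (hi - lo)  # fast ramp, slow final approach
    elif strategy == "competence":
        t = lo + (1.0 / (1.0 + np.exp(-10 * (p - 0.5)))) * (hi - lo)
    else:
        t = hi
    return min(t, hi)

for prog in (0.05, 0.25, 0.5, 1.0):
    print(prog, round(float(difficulty_threshold(prog, "sqrt")), 3))
```

Note how the sqrt schedule reaches roughly 40% difficulty only a quarter of the way through training, while linear would still be near 17%.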
Curriculum Strategy Comparison (7B Model, 100B Tokens)
| Strategy | Final Perplexity | Tokens to Match Random | Best Benchmark Improvement | Risk |
|---|---|---|---|---|
| Random (baseline) | 7.8 | 100B (baseline) | 0% | None |
| Linear curriculum | 7.5 | 75B | +2.1% on MMLU | Slight overfitting on easy data |
| Square root curriculum | 7.4 | 70B | +2.8% on MMLU | Low |
| Competence-based | 7.3 | 65B | +3.2% on MMLU | Requires validation monitoring |
| Anti-curriculum (hard first) | 7.6 | 85B | +1.5% on MMLU | Training instability |
Curriculum learning for LLMs is less dramatic than for smaller models. The typical improvement is 5-15% fewer tokens for equivalent quality, compared to 50%+ savings seen in computer vision curriculum learning. The reason: LLM pretraining already randomizes data order across many epochs, and the model’s capacity is large enough to handle hard examples even early in training. The biggest gains come from the warmup phase (avoiding very hard/noisy data in the first 10% of training).
Quality Filtering with Proxy Models
Data Selection Before Training
```python
class QualityFilterPipeline:
    """
    Filter training data using quality proxy models.

    The proxy model is a small, cheap-to-run classifier
    that predicts whether a document will be beneficial
    for training. Documents that the proxy rates as
    low-quality are excluded.

    Proxy training: fine-tune a small model (~125M parameters)
    on the training loss of a larger model. Documents that
    cause moderately high loss in the large model are high
    quality (the model has something to learn); documents with
    very low loss are already known (less valuable); documents
    with very high loss are noise or garbage.
    """

    def __init__(self, proxy_model, config):
        self.proxy_model = proxy_model
        self.quality_threshold = config.get("quality_threshold", 0.5)
        self.upper_threshold = config.get("upper_threshold", 0.95)

    def filter_dataset(self, documents):
        """
        Filter documents by quality score.

        Keep documents in the "goldilocks zone":
        not too easy, not too hard, just right.
        """
        scored = []
        for doc in documents:
            score = self.proxy_model.score(doc)
            scored.append((doc, score))

        # Filter: keep the middle range
        filtered = [
            (doc, score) for doc, score in scored
            if self.quality_threshold <= score <= self.upper_threshold
        ]

        return {
            "original_count": len(documents),
            "filtered_count": len(filtered),
            "retention_rate": len(filtered) / max(len(documents), 1),
            "documents": [doc for doc, _ in filtered],
            "scores": [score for _, score in filtered],
        }


class ProxyModelTrainer:
    """
    Train a proxy model for data quality estimation.

    The proxy model predicts how much a document
    contributes to the training-loss reduction of a
    larger target model.

    Training data for the proxy:
    1. Train the target model on a random sample
    2. Record per-document loss deltas
    3. Train the proxy to predict loss deltas from document text
    """

    def __init__(self, proxy_architecture, target_model):
        self.proxy = proxy_architecture
        self.target = target_model

    def collect_training_signal(self, documents, batch_size=32):
        """
        Collect the training signal: how much does each
        document reduce the target model's validation loss?
        """
        signals = []
        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]

            # Measure validation loss before training on the batch
            val_loss_before = self.target.evaluate()
            # Train the target on the batch
            self.target.train_step(batch)
            # Measure validation loss after
            val_loss_after = self.target.evaluate()

            # Per-document signal (a batch-level approximation)
            loss_delta = val_loss_before - val_loss_after
            per_doc_signal = loss_delta / len(batch)
            for doc in batch:
                signals.append({
                    "document": doc,
                    "quality_signal": per_doc_signal,
                })
        return signals

    def train_proxy(self, signals, epochs=5):
        """Train the proxy model to predict quality signals."""
        # Normalize signals to [0, 1]
        values = [s["quality_signal"] for s in signals]
        min_val = min(values)
        max_val = max(values)
        range_val = (max_val - min_val) or 1.0
        for signal in signals:
            signal["normalized"] = (
                (signal["quality_signal"] - min_val) / range_val
            )

        # Train the proxy (simplified: a real implementation would take an
        # optimizer step on this squared error; here we only track the loss)
        for epoch in range(epochs):
            loss = 0.0
            for signal in signals:
                prediction = self.proxy.predict(signal["document"])
                target = signal["normalized"]
                loss += (prediction - target) ** 2
            avg_loss = loss / len(signals)
        return {"final_loss": avg_loss, "n_samples": len(signals)}
```
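A toy run of the goldilocks-zone filter, using a stand-in proxy whose score is just a length heuristic (a real proxy would be a small trained classifier; the documents and thresholds here are made up for illustration):

```python
class StubProxy:
    """Stand-in scorer; scores by length only, not a real quality signal."""
    def score(self, doc):
        return min(len(doc) / 100.0, 1.0)

docs = ["ok" * 5, "x" * 60, "y" * 200]   # scores: 0.1, 0.6, 1.0
proxy = StubProxy()
lo, hi = 0.5, 0.95                        # the default goldilocks zone

kept = [d for d in docs if lo <= proxy.score(d) <= hi]
print(len(kept))  # -> 1: only the mid-scoring document survives
```

The too-easy document (score 0.1) and the likely-garbage one (score 1.0) are both dropped, which is exactly the behavior `filter_dataset` implements with its two thresholds.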
DoReMi: Domain Reweighting
Optimizing the Data Mix
```python
import numpy as np


class DoReMi:
    """
    DoReMi: Optimizing Data Mixtures by Reweighting
    Domains (Xie et al., 2023).

    Standard practice: mix data domains with fixed
    proportions (e.g., 50% web, 20% books, 15% code,
    10% academic, 5% conversation).

    DoReMi learns optimal mixing proportions:
    1. Train a small proxy model on a uniform mix
    2. Train a second proxy with different domain weights
    3. Measure which domains the second proxy needs most
    4. Reweight to emphasize under-learned domains
    5. Use the optimized weights for the full training run

    The key insight: domains where the model has high
    excess loss (compared to a reference) should receive
    higher weight.
    """

    def __init__(self, domains, reference_model):
        self.domains = domains
        self.reference_model = reference_model
        self.n_domains = len(domains)

    def optimize_weights(self, proxy_model, n_steps=1000):
        """
        Optimize domain weights using the DoReMi algorithm.

        Algorithm:
        1. Initialize weights uniformly
        2. For each step:
           a. Sample batches according to the current weights
           b. Compute the excess loss per domain
           c. Update weights: increase them for high-excess domains
        """
        weights = np.ones(self.n_domains) / self.n_domains
        weight_history = [weights.copy()]

        for step in range(n_steps):
            # Compute the excess loss per domain
            excess_losses = np.zeros(self.n_domains)
            for i, domain in enumerate(self.domains):
                # Sample from the domain
                batch = domain.sample(batch_size=32)
                # Proxy model loss
                proxy_loss = proxy_model.compute_loss(batch)
                # Reference model loss
                ref_loss = self.reference_model.compute_loss(batch)
                # Excess loss (how much worse the proxy is than the reference)
                excess_losses[i] = max(0, proxy_loss - ref_loss)

            # Update weights via exponentiated gradient ascent
            lr = 0.1
            log_weights = np.log(weights + 1e-10)
            log_weights += lr * excess_losses
            weights = np.exp(log_weights)
            weights /= weights.sum()
            weight_history.append(weights.copy())

        return {
            "optimal_weights": weights.tolist(),
            "weight_history": weight_history,
            "domain_names": [d.name for d in self.domains],
        }

    def evaluate_mixing(self, model, weights, eval_data):
        """Evaluate model quality under specific mixing weights."""
        results = {}
        for i, domain in enumerate(self.domains):
            domain_eval = eval_data.get(domain.name, [])
            if domain_eval:
                loss = model.compute_loss(domain_eval)
                results[domain.name] = {
                    "loss": loss,
                    "weight": weights[i],
                }
        return results
```
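The weight update at the heart of the loop is easy to inspect in isolation. This sketch applies one exponentiated-gradient step to made-up excess losses; `w * exp(lr * excess)` is algebraically the same as the log-space form used above:

```python
import numpy as np

# One DoReMi-style exponentiated-gradient step on synthetic excess losses.
excess = np.array([0.1, 0.8, 0.3, 0.0])  # hypothetical per-domain excess loss
weights = np.ones(4) / 4                 # start from a uniform mix
lr = 0.1

new = weights * np.exp(lr * excess)      # boost high-excess domains
new /= new.sum()                         # renormalize to a distribution
print(new.round(4))
```

The domain with the highest excess loss (index 1) ends up with the largest weight, while the fully-learned domain (index 3) is de-emphasized; repeated over many steps, the weights drift toward the under-learned domains.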
DoReMi vs Uniform Mixing: Domain-Specific Performance
(Table data missing: it compared DoReMi-optimized weights against hand-tuned Llama-style weights across the Web, Books, Code, Academic, Conversation, and Math domains, plus the average.)
Experience Replay
Revisiting High-Value Data
```python
import numpy as np
from collections import defaultdict


class ExperienceReplayBuffer:
    """
    Experience replay for LLM pretraining.

    Standard pretraining sees each token once (one epoch).
    Experience replay revisits high-value documents
    multiple times while seeing low-value documents
    once or not at all.

    The replay buffer stores documents with their
    estimated learning value. Documents that caused
    the largest loss reduction are replayed more often.
    """

    def __init__(self, max_size, replay_ratio=0.2):
        self.max_size = max_size
        self.replay_ratio = replay_ratio
        self.buffer = []
        self.total_replays = defaultdict(int)

    def add(self, document, learning_value):
        """
        Add a document to the replay buffer.

        learning_value: how much this document contributed
        to training (measured by the loss delta).
        """
        entry = {
            "document": document,
            "learning_value": learning_value,
            "times_seen": 1,
            "hash": hash(document[:100]),
        }
        self.buffer.append(entry)

        # Evict the lowest-value entries if over capacity
        if len(self.buffer) > self.max_size:
            self.buffer.sort(key=lambda x: x["learning_value"])
            self.buffer = self.buffer[len(self.buffer) - self.max_size:]

    def sample_replay_batch(self, batch_size):
        """
        Sample a replay batch weighted by learning value.

        Higher learning value = higher probability of
        being replayed. Documents that have already been
        replayed many times get reduced weight to prevent
        overfitting.
        """
        if not self.buffer:
            return []
        n_replay = int(batch_size * self.replay_ratio)

        # Compute sampling weights: value discounted by times seen
        weights = np.array([
            entry["learning_value"] / (1 + 0.5 * entry["times_seen"])
            for entry in self.buffer
        ])
        weights = np.maximum(weights, 0)
        total_weight = weights.sum()
        if total_weight == 0:
            return []
        probs = weights / total_weight

        # Sample without replacement, proportionally to the weights
        indices = np.random.choice(
            len(self.buffer),
            size=min(n_replay, len(self.buffer)),
            replace=False,
            p=probs,
        )
        batch = []
        for idx in indices:
            self.buffer[idx]["times_seen"] += 1
            batch.append(self.buffer[idx]["document"])
        return batch

    def get_statistics(self):
        """Get buffer statistics."""
        if not self.buffer:
            return {}
        values = [e["learning_value"] for e in self.buffer]
        replays = [e["times_seen"] for e in self.buffer]
        return {
            "buffer_size": len(self.buffer),
            "avg_learning_value": float(np.mean(values)),
            "max_learning_value": float(np.max(values)),
            "avg_times_seen": float(np.mean(replays)),
            "max_times_seen": int(np.max(replays)),
        }
```
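The sampling weights used by `sample_replay_batch` can be inspected in isolation: learning value divided by a seen-count penalty, then normalized into a distribution (the values and counts here are synthetic):

```python
import numpy as np

values = np.array([0.9, 0.5, 0.1])   # hypothetical learning values
times_seen = np.array([1, 1, 4])     # the third doc was replayed often

# Same form as the buffer: value discounted by how often it was seen
weights = values / (1 + 0.5 * times_seen)
probs = weights / weights.sum()
print(probs.round(3))
```

The high-value, rarely-seen document dominates the replay distribution, while the heavily-replayed one is nearly suppressed even before its low value is considered — this is the anti-overfitting discount in action.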
Data-Efficient Training: Token Savings by Technique
| Technique | Token Savings | Quality Impact | Compute Overhead | Best Use Case |
|---|---|---|---|---|
| Curriculum learning | 15-25% | +1-3% on benchmarks | 5% (scoring) | Large pretraining runs |
| Quality filtering (proxy) | 30-50% | +2-5% on benchmarks | 10% (proxy inference) | All pretraining runs |
| DoReMi domain reweighting | 10-20% | +3-7% on underserved domains | 15% (proxy training) | Multi-domain training |
| Experience replay (20%) | 15-25% | +1-3% on benchmarks | 5% (buffer management) | Data-constrained settings |
| All combined | 40-60% | +5-10% on benchmarks | 25% (total overhead) | Maximum efficiency |
Online Data Selection
Selecting Data During Training
```python
import numpy as np


class OnlineDataSelector:
    """
    Select training data online (during training)
    based on the model's current state.

    Unlike offline filtering (done once before training),
    online selection adapts to the model's evolving
    competency. Data that was too hard early in training
    becomes appropriate later.

    Method: at each step, score candidate documents
    using the current model's loss. Select documents
    in the "zone of proximal development" --
    hard enough to learn from, but not so hard as
    to be noise.
    """

    def __init__(self, model, config):
        self.model = model
        self.target_loss_range = config.get("target_loss_range", (2.0, 5.0))
        self.candidate_buffer_size = config.get("candidate_buffer_size", 10000)

    def select_batch(self, candidate_documents, batch_size):
        """
        Select a training batch from candidates.

        Score each candidate by its loss under the
        current model, then select documents in the
        target loss range.
        """
        scored = []
        for doc in candidate_documents:
            loss = self.model.compute_loss([doc])
            scored.append((doc, loss))

        # Filter to the target loss range
        lo, hi = self.target_loss_range
        in_range = [
            (doc, loss) for doc, loss in scored
            if lo <= loss <= hi
        ]

        if len(in_range) >= batch_size:
            # Sample uniformly from the in-range documents
            indices = np.random.choice(
                len(in_range), size=batch_size, replace=False,
            )
            return [in_range[i][0] for i in indices]

        # Not enough in range: fall back to the documents closest
        # to the center of the target range
        scored.sort(key=lambda x: abs(x[1] - (lo + hi) / 2))
        return [doc for doc, _ in scored[:batch_size]]

    def update_target_range(self, current_val_loss):
        """
        Adjust the target loss range based on the model's
        current competency.

        As the model improves, the target range shifts
        to maintain the challenge level.
        """
        margin = 1.5
        self.target_loss_range = (
            current_val_loss - margin * 0.3,
            current_val_loss + margin,
        )
```
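A stripped-down version of the selection logic, with precomputed losses standing in for `model.compute_loss` (the loss values are hypothetical):

```python
import random

def select(losses, lo, hi, k):
    """Pick k candidate indices in the target loss range [lo, hi];
    fall back to the candidates closest to the range's center."""
    in_range = [i for i, l in enumerate(losses) if lo <= l <= hi]
    if len(in_range) >= k:
        return random.sample(in_range, k)
    center = (lo + hi) / 2
    return sorted(range(len(losses)), key=lambda i: abs(losses[i] - center))[:k]

losses = [1.2, 2.5, 3.1, 6.0, 4.4]   # per-candidate loss under current model
print(select(losses, 2.0, 5.0, 2))   # two indices from {1, 2, 4}
```

As the model improves and `update_target_range` slides the window down, the same candidate pool yields a different in-range set — documents at loss 6.0 that are noise today may fall inside the window later.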
Online vs Offline Data Selection: Perplexity Over Training
(Table data missing: it tracked perplexity for online selection, offline quality filtering, and random sampling at the 10, 25, 50, 75, and 100 marks — likely billions of tokens, given the 100B-token comparison below.)
Key Takeaways
Data-efficient training reduces the token budget for reaching a target model quality by 30-60%. The techniques are complementary: quality filtering removes bad data, curriculum learning orders good data, domain reweighting balances the mix, and experience replay emphasizes high-value data.
The critical findings:
- Quality filtering provides the largest single gain: Removing the bottom 30-50% of training data by quality (using a proxy model) improves final model quality by 2-5% on benchmarks while halving the token budget. This is the highest-ROI technique and should be standard practice.
- DoReMi outperforms hand-tuned domain weights: Algorithmically optimized domain mixing (DoReMi) improves performance on underserved domains (math, code) by 10-18% compared to uniform mixing, with only a 2% decrease on the dominant web domain.
- Curriculum learning helps early training stability: Starting with easier data in the first 10% of training prevents loss spikes and gradient explosions. The long-term quality improvement is modest (1-3%), but the training stability improvement is significant for large runs where restarts are expensive.
- Experience replay works best in data-constrained settings: When the total available data is less than the Chinchilla-optimal amount, replaying high-value documents 2-3 times produces better results than single-epoch training on everything. Returns diminish beyond 3 replays.
- Online selection adapts to model competency: Documents that are too hard for a 10B-token model become appropriate at 50B tokens. Online selection automatically adjusts the difficulty window, achieving 6% lower perplexity than offline filtering alone at the 100B-token mark.