Expert-written math reasoning traces cost $500K-2M. STaR (Self-Taught Reasoner) eliminates this cost: the model generates its own reasoning traces, filters them for correctness, and retrains on the successful examples. After 4 iterations, STaR improved GSM8K accuracy from 10% to 44% with zero human-written rationales. The risk: model collapse, where self-generated data amplifies errors until performance degrades. The reward: compounding improvement from a fixed compute budget if collapse is avoided.
This post covers the algorithms, their implementations, the data flywheel dynamics, collapse mechanisms, and production-ready mitigation strategies.
STaR: Self-Taught Reasoner
The Algorithm
STaR has three steps per iteration:
- Generate: the model produces chain-of-thought rationales for each problem
- Filter: keep only rationales that lead to the correct final answer
- Retrain: fine-tune the model on the filtered rationales
When the model fails to produce any correct rationale for a problem, STaR uses “rationalization from hint”: provide the correct answer and ask the model to generate a rationale that leads to it. This prevents the training distribution from shrinking to only “easy” problems.
import torch
import json
import random
from dataclasses import dataclass
from typing import Sequence
@dataclass
class Problem:
question: str
correct_answer: str
difficulty: str = "unknown"
@dataclass
class Rationale:
question: str
chain_of_thought: str
final_answer: str
correct: bool
source: str # "direct" or "hint"
class STaRTrainer:
"""
Complete STaR training loop.
Self-Taught Reasoner: generate rationales, filter, retrain.
"""
def __init__(self, model, tokenizer, optimizer_cls, lr=1e-5):
self.model = model
self.tokenizer = tokenizer
self.optimizer_cls = optimizer_cls
self.lr = lr
def generate_rationales(self, problems, num_samples=8,
temperature=0.7, max_tokens=512):
"""
Step 1: Generate chain-of-thought rationales for each problem.
Args:
problems: List of Problem objects
num_samples: Number of rationales to generate per problem
temperature: Sampling temperature
max_tokens: Maximum tokens per rationale
"""
all_rationales = []
for problem in problems:
prompt = self._build_rationale_prompt(problem.question)
inputs = self.tokenizer(prompt, return_tensors="pt").to(
self.model.device
)
rationales_for_problem = []
for _ in range(num_samples):
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
do_sample=True,
top_p=0.95,
)
generated = self.tokenizer.decode(
output[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True,
)
# Parse chain-of-thought and final answer
cot, answer = self._parse_rationale(generated)
is_correct = self._check_answer(answer, problem.correct_answer)
rationales_for_problem.append(Rationale(
question=problem.question,
chain_of_thought=cot,
final_answer=answer,
correct=is_correct,
source="direct",
))
all_rationales.append(rationales_for_problem)
return all_rationales
def rationalize_from_hint(self, problem, max_tokens=512):
"""
Generate a rationale given the correct answer as a hint.
Used when no direct rationale was correct.
"""
prompt = self._build_hint_prompt(problem.question, problem.correct_answer)
inputs = self.tokenizer(prompt, return_tensors="pt").to(
self.model.device
)
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.3, # Lower temperature for hint-based
do_sample=True,
)
generated = self.tokenizer.decode(
output[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True,
)
cot, answer = self._parse_rationale(generated)
is_correct = self._check_answer(answer, problem.correct_answer)
return Rationale(
question=problem.question,
chain_of_thought=cot,
final_answer=answer,
correct=is_correct,
source="hint",
)
def filter_and_collect(self, problems, all_rationales):
"""
Step 2: Filter for correct rationales.
For problems with no correct direct rationale, use hint rationalization.
"""
training_data = []
stats = {
'direct_correct': 0,
'hint_correct': 0,
'no_correct': 0,
'total_problems': len(problems),
}
for problem, rationales in zip(problems, all_rationales):
# Find first correct direct rationale
correct_rationale = None
for r in rationales:
if r.correct:
correct_rationale = r
stats['direct_correct'] += 1
break
# Rationalize from hint if no direct rationale was correct
if correct_rationale is None:
hint_rationale = self.rationalize_from_hint(problem)
if hint_rationale.correct:
correct_rationale = hint_rationale
stats['hint_correct'] += 1
else:
stats['no_correct'] += 1
continue
training_data.append({
'question': correct_rationale.question,
'rationale': correct_rationale.chain_of_thought,
'answer': correct_rationale.final_answer,
'source': correct_rationale.source,
})
return training_data, stats
def retrain(self, training_data, epochs=3, batch_size=8):
"""
Step 3: Fine-tune the model on filtered rationales.
"""
optimizer = self.optimizer_cls(self.model.parameters(), lr=self.lr)
self.model.train()
# Prepare training examples
examples = []
for item in training_data:
text = self._format_training_example(
item['question'], item['rationale'], item['answer']
)
encoded = self.tokenizer(
text, return_tensors="pt", truncation=True,
max_length=1024, padding="max_length",
)
examples.append(encoded)
for epoch in range(epochs):
random.shuffle(examples)
total_loss = 0.0
num_batches = 0
for i in range(0, len(examples), batch_size):
batch = examples[i:i + batch_size]
# Stack batch
input_ids = torch.cat(
[e['input_ids'] for e in batch]
).to(self.model.device)
attention_mask = torch.cat(
[e['attention_mask'] for e in batch]
).to(self.model.device)
                # Mask padding so pad tokens do not contribute to the loss
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
loss = outputs.loss
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / num_batches
print(f" Epoch {epoch + 1}/{epochs}: loss = {avg_loss:.4f}")
def run_iteration(self, problems, num_samples=8, epochs=3):
"""Run one complete STaR iteration."""
print("Generating rationales...")
all_rationales = self.generate_rationales(problems, num_samples)
print("Filtering...")
training_data, stats = self.filter_and_collect(problems, all_rationales)
print(f" Direct correct: {stats['direct_correct']}")
print(f" Hint correct: {stats['hint_correct']}")
print(f" No correct: {stats['no_correct']}")
print("Retraining...")
self.retrain(training_data, epochs)
return training_data, stats
def run_full_training(self, problems, num_iterations=5,
num_samples=8, epochs=3):
"""Run multiple STaR iterations."""
all_stats = []
for iteration in range(num_iterations):
print(f"\n=== STaR Iteration {iteration + 1}/{num_iterations} ===")
_, stats = self.run_iteration(problems, num_samples, epochs)
all_stats.append(stats)
# Evaluate
accuracy = self.evaluate(problems)
print(f" Accuracy after iteration {iteration + 1}: {accuracy:.3f}")
return all_stats
    def evaluate(self, problems):
"""Evaluate model accuracy (greedy decoding)."""
self.model.eval()
correct = 0
for problem in problems:
prompt = self._build_rationale_prompt(problem.question)
inputs = self.tokenizer(prompt, return_tensors="pt").to(
self.model.device
)
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=512,
                    do_sample=False,  # greedy decoding
)
generated = self.tokenizer.decode(
output[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True,
)
_, answer = self._parse_rationale(generated)
if self._check_answer(answer, problem.correct_answer):
correct += 1
self.model.train()
return correct / len(problems)
def _build_rationale_prompt(self, question):
return (f"Question: {question}\n"
f"Let's think step by step.\n")
def _build_hint_prompt(self, question, answer):
return (f"Question: {question}\n"
f"The answer is {answer}.\n"
f"Explain step by step why.\n")
def _format_training_example(self, question, rationale, answer):
return (f"Question: {question}\n"
f"Let's think step by step.\n"
f"{rationale}\n"
f"The answer is {answer}.")
def _parse_rationale(self, text):
"""Split generated text into chain-of-thought and final answer."""
# Look for "The answer is" pattern
markers = ["the answer is", "therefore,", "thus,", "so the answer is"]
text_lower = text.lower()
for marker in markers:
if marker in text_lower:
idx = text_lower.index(marker)
cot = text[:idx].strip()
answer = text[idx:].strip()
# Extract just the answer value
answer = answer.split('\n')[0].strip()
return cot, answer
# If no marker found, last line is the answer
lines = text.strip().split('\n')
if len(lines) > 1:
return '\n'.join(lines[:-1]), lines[-1]
return text, text
def _check_answer(self, predicted, correct):
"""Check if predicted answer matches correct answer."""
pred = predicted.lower().strip().rstrip('.')
corr = correct.lower().strip().rstrip('.')
return pred == corr or corr in pred
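A minimal usage sketch, assuming a Hugging Face causal LM; the checkpoint name and example problem are placeholders, not part of the original recipe:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; any causal LM with a compatible tokenizer works
model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # needed for padding="max_length" in retrain()

problems = [
    Problem(question="A pack holds 12 pencils. How many pencils are in 7 packs?",
            correct_answer="84", difficulty="easy"),
    # ... load the rest of the training problems here
]

trainer = STaRTrainer(model, tokenizer, optimizer_cls=torch.optim.AdamW, lr=1e-5)
all_stats = trainer.run_full_training(problems, num_iterations=5,
                                      num_samples=8, epochs=3)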
STaR Accuracy Improvement Over Iterations (GSM8K): 7B model, base model through iteration 5.
ReST: Reinforced Self-Training
From Filtering to Reinforcement Learning
ReST treats the self-improvement loop as an offline RL problem. Instead of binary filtering (correct/incorrect), ReST assigns a scalar reward to each generated sample and trains on the high-reward ones with a reward-weighted objective.
class ReSTTrainer:
"""
Reinforced Self-Training.
Two phases:
1. Grow: Generate samples from the current policy
2. Improve: Train on samples weighted by reward
"""
    def __init__(self, model, tokenizer, reward_fn,
                 optimizer_cls=torch.optim.AdamW, lr=1e-5):
        self.model = model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn
        self.optimizer_cls = optimizer_cls
        self.lr = lr
def grow_phase(self, problems, num_samples=16, temperature=0.8):
"""
Phase 1: Generate multiple solutions for each problem.
Score each solution with the reward function.
"""
dataset = []
for problem in problems:
prompt = f"Question: {problem.question}\nSolution:\n"
inputs = self.tokenizer(prompt, return_tensors="pt").to(
self.model.device
)
solutions = []
for _ in range(num_samples):
with torch.no_grad():
output = self.model.generate(
**inputs,
max_new_tokens=512,
temperature=temperature,
do_sample=True,
top_p=0.95,
)
generated = self.tokenizer.decode(
output[0][inputs['input_ids'].shape[1]:],
skip_special_tokens=True,
)
# Score with reward function
reward = self.reward_fn(
problem.question, generated, problem.correct_answer
)
solutions.append({
'question': problem.question,
'solution': generated,
'reward': reward,
})
dataset.extend(solutions)
return dataset
def improve_phase(self, dataset, threshold_percentile=70,
epochs=3, batch_size=8):
"""
Phase 2: Train on high-reward solutions.
Options:
1. Binary: train on solutions above reward threshold
2. Weighted: weight loss by reward value
3. Contrastive: pair high-reward and low-reward solutions
"""
# Compute reward threshold (e.g., 70th percentile)
rewards = [d['reward'] for d in dataset]
threshold = sorted(rewards)[int(len(rewards) * threshold_percentile / 100)]
# Filter to high-reward solutions
good_solutions = [d for d in dataset if d['reward'] >= threshold]
print(f" Grow phase: {len(dataset)} solutions")
print(f" Reward threshold (p{threshold_percentile}): {threshold:.3f}")
print(f" Training on {len(good_solutions)} solutions "
f"({100 * len(good_solutions) / len(dataset):.0f}%)")
# Train
optimizer = self.optimizer_cls(self.model.parameters(), lr=self.lr)
self.model.train()
for epoch in range(epochs):
random.shuffle(good_solutions)
total_loss = 0.0
n_batches = 0
for i in range(0, len(good_solutions), batch_size):
batch = good_solutions[i:i + batch_size]
# Format as training texts
texts = [
f"Question: {d['question']}\nSolution:\n{d['solution']}"
for d in batch
]
encoded = self.tokenizer(
texts, return_tensors="pt", truncation=True,
max_length=1024, padding=True,
)
input_ids = encoded['input_ids'].to(self.model.device)
attention_mask = encoded['attention_mask'].to(self.model.device)
# Optional: weight loss by reward
rewards_tensor = torch.tensor(
[d['reward'] for d in batch],
device=self.model.device,
)
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
                # Per-token cross-entropy, masked so padding does not count
                logits = outputs.logits[:, :-1, :]
                labels = input_ids[:, 1:]
                mask = attention_mask[:, 1:].float()
                per_token_loss = torch.nn.functional.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    labels.reshape(-1),
                    reduction="none",
                ).view(labels.shape)
                per_example_loss = (per_token_loss * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
                # Reward-weighted loss: scale each example's loss by its own reward
                loss = (per_example_loss * rewards_tensor).mean()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
n_batches += 1
print(f" Epoch {epoch + 1}: loss = {total_loss / n_batches:.4f}")
def run_iteration(self, problems, num_samples=16, epochs=3):
"""Run one Grow-Improve iteration."""
print("Grow phase...")
dataset = self.grow_phase(problems, num_samples)
print("Improve phase...")
self.improve_phase(dataset, epochs=epochs)
return dataset
optimizer_cls = torch.optim.AdamW  # optimizer class passed to the trainers above
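A minimal sketch of running the Grow-Improve loop with the same model and problems as the STaR example, using the binary correctness reward defined in the reward-function section below:

rest_trainer = ReSTTrainer(
    model, tokenizer,
    reward_fn=RewardFunctions.binary_correctness,  # defined later in this post
    optimizer_cls=torch.optim.AdamW,
    lr=1e-5,
)

for i in range(3):
    print(f"\n=== ReST Iteration {i + 1}/3 ===")
    rest_trainer.run_iteration(problems, num_samples=16, epochs=3)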
STaR vs ReST Comparison
| Aspect | STaR | ReST |
|---|---|---|
| Selection signal | Binary (correct/incorrect) | Scalar reward |
| Hint mechanism | Rationalization from hint | No hint (higher threshold) |
| Training objective | Cross-entropy on correct rationales | Reward-weighted cross-entropy |
| Samples per problem | 8 typical | 16-64 typical |
| Best use case | Tasks with verifiable answers | Tasks with graded quality |
| Collapse risk | Moderate (hint distribution shift) | Higher (reward hacking) |
The Data Flywheel
How Self-Improvement Creates Compounding Gains
class DataFlywheel:
"""
Model the self-improvement dynamics.
Each iteration:
1. Better model generates more correct rationales
2. More correct rationales = better training data
3. Better training data = better model (next iteration)
This compounds until hitting a ceiling.
"""
def __init__(self, initial_accuracy=0.35, ceiling=0.85):
self.accuracy = initial_accuracy
self.ceiling = ceiling
self.history = [initial_accuracy]
def simulate_iteration(self, num_samples=8):
"""
Simulate one STaR iteration.
The improvement per iteration depends on:
- Current accuracy (determines quality of generated data)
- Number of samples (more samples = more likely to find correct ones)
- Distance from ceiling (diminishing returns)
"""
# Probability of at least one correct rationale in N samples
p_at_least_one = 1 - (1 - self.accuracy) ** num_samples
# Training data quality: fraction of problems with correct rationales
data_quality = p_at_least_one
# Improvement: proportional to data quality and distance from ceiling
headroom = self.ceiling - self.accuracy
improvement = data_quality * headroom * 0.2 # 20% of possible improvement
self.accuracy = min(self.ceiling, self.accuracy + improvement)
self.history.append(self.accuracy)
return {
'accuracy': self.accuracy,
'p_at_least_one_correct': p_at_least_one,
'data_quality': data_quality,
'improvement': improvement,
}
def simulate_full(self, iterations=10, num_samples=8):
"""Simulate multiple iterations."""
results = []
for i in range(iterations):
result = self.simulate_iteration(num_samples)
results.append(result)
return results
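Running the toy simulation makes the diminishing-returns curve visible; the numbers come from the simplified model above, not from real training runs:

flywheel = DataFlywheel(initial_accuracy=0.35, ceiling=0.85)
for i, result in enumerate(flywheel.simulate_full(iterations=10, num_samples=8), start=1):
    print(f"iter {i:2d}: accuracy={result['accuracy']:.3f}  "
          f"p(at least one correct)={result['p_at_least_one_correct']:.3f}  "
          f"improvement={result['improvement']:+.3f}")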
Why the Flywheel Stalls
The flywheel hits diminishing returns for three reasons:
FLYWHEEL_STALL_REASONS = {
"distribution_narrowing": {
"description": (
"Each iteration, the model only trains on problems it can solve. "
"Hard problems are never represented in training data. "
"The model gets very good at easy problems but never improves "
"on hard ones."
),
"mitigation": "Hint rationalization (STaR) or curriculum sampling.",
},
"mode_collapse": {
"description": (
"The model converges to a single reasoning style. "
"It generates the same chain-of-thought pattern for all problems, "
"even when that pattern is suboptimal."
),
"mitigation": "Temperature diversity, multiple starting points, "
"reward for diversity.",
},
"error_amplification": {
"description": (
"Correct answers can have incorrect reasoning. "
"The model learns to produce answers that happen to be right "
"for wrong reasons. These errors compound across iterations."
),
"mitigation": "Process reward models (score individual steps, "
"not just final answers).",
},
}
The most insidious failure mode is error amplification with correct answers. On GSM8K, approximately 15-20% of “correct” rationales contain arithmetic errors that cancel out or reasoning steps that skip critical logic. Training on these rationales teaches the model to produce plausible-looking but logically flawed reasoning. After 5+ iterations, the model’s reasoning becomes increasingly superficial even as its accuracy plateaus.
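A toy illustration of the failure mode: both steps below contain arithmetic errors, but they cancel, so answer-only filtering (the `_check_answer`-style check used above) happily keeps the rationale. The numbers are invented for illustration:

flawed_rationale = (
    "The store sells 12 boxes per day, so in 5 days it sells 12 * 5 = 70 boxes.\n"  # wrong: 60
    "After returning 10 boxes, 70 - 10 = 50 boxes are left."                        # wrong: 70 - 10 = 60
)
final_answer = "50"
correct_answer = "50"   # the true computation: 12 * 5 - 10 = 50

# Answer-only filtering accepts this trace, so the flawed reasoning
# enters the next iteration's training data
print(final_answer == correct_answer)  # True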
Model Collapse: Detection and Prevention
What Collapse Looks Like
class CollapseDetector:
"""
Detect early signs of model collapse in self-improvement loops.
"""
def __init__(self):
self.iteration_metrics = []
def measure_diversity(self, generated_rationales):
"""
Measure diversity of generated rationales.
Collapse = all rationales look the same.
"""
        import random
        import numpy as np
# Metric 1: Unique n-gram ratio
all_trigrams = []
for rationale in generated_rationales:
words = rationale.lower().split()
trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
all_trigrams.extend(trigrams)
unique_ratio = len(set(all_trigrams)) / max(len(all_trigrams), 1)
# Metric 2: Average pairwise similarity (lower = more diverse)
from difflib import SequenceMatcher
similarities = []
sample_size = min(50, len(generated_rationales))
sample = random.sample(generated_rationales, sample_size)
for i in range(len(sample)):
for j in range(i + 1, len(sample)):
sim = SequenceMatcher(None, sample[i], sample[j]).ratio()
similarities.append(sim)
avg_similarity = np.mean(similarities) if similarities else 0.0
# Metric 3: Vocabulary size relative to text length
all_words = ' '.join(generated_rationales).split()
vocab_ratio = len(set(all_words)) / max(len(all_words), 1)
return {
'unique_trigram_ratio': unique_ratio,
'avg_pairwise_similarity': avg_similarity,
'vocab_ratio': vocab_ratio,
}
def measure_difficulty_coverage(self, problems, success_mask):
"""
Check if the model is solving diverse difficulty levels.
Collapse = only solving easy problems.
"""
difficulty_success = {}
for problem, success in zip(problems, success_mask):
d = problem.difficulty
if d not in difficulty_success:
difficulty_success[d] = {'correct': 0, 'total': 0}
difficulty_success[d]['total'] += 1
if success:
difficulty_success[d]['correct'] += 1
for d in difficulty_success:
stats = difficulty_success[d]
stats['accuracy'] = stats['correct'] / max(stats['total'], 1)
return difficulty_success
def check_collapse(self, iteration_num, diversity_metrics,
difficulty_coverage, accuracy):
"""
Determine if collapse is occurring.
"""
alerts = []
# Check diversity
if diversity_metrics['unique_trigram_ratio'] < 0.3:
alerts.append("LOW_DIVERSITY: unique trigram ratio below 0.3")
if diversity_metrics['avg_pairwise_similarity'] > 0.7:
alerts.append("HIGH_SIMILARITY: average pairwise similarity above 0.7")
# Check difficulty coverage
hard_accuracy = difficulty_coverage.get('hard', {}).get('accuracy', 0)
easy_accuracy = difficulty_coverage.get('easy', {}).get('accuracy', 0)
if hard_accuracy < 0.1 and easy_accuracy > 0.8:
alerts.append("DIFFICULTY_COLLAPSE: solving easy but not hard problems")
# Check accuracy plateau
if len(self.iteration_metrics) >= 3:
recent = [m['accuracy'] for m in self.iteration_metrics[-3:]]
if max(recent) - min(recent) < 0.005:
alerts.append("ACCURACY_PLATEAU: no improvement in 3 iterations")
self.iteration_metrics.append({
'iteration': iteration_num,
'accuracy': accuracy,
'diversity': diversity_metrics,
})
return alerts
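A sketch of wiring the detector into the STaR loop from earlier, calling the trainer's individual steps so per-problem success flags and raw rationales are available (all names refer to the classes defined above):

detector = CollapseDetector()

for iteration in range(5):
    all_rationales = trainer.generate_rationales(problems, num_samples=8)
    training_data, stats = trainer.filter_and_collect(problems, all_rationales)
    trainer.retrain(training_data, epochs=3)

    # Per-problem success: did any directly sampled rationale reach the right answer?
    success_mask = [any(r.correct for r in rs) for rs in all_rationales]
    rationale_texts = [r.chain_of_thought for rs in all_rationales for r in rs]

    diversity = detector.measure_diversity(rationale_texts)
    coverage = detector.measure_difficulty_coverage(problems, success_mask)
    accuracy = trainer.evaluate(problems)

    alerts = detector.check_collapse(iteration, diversity, coverage, accuracy)
    if alerts:
        print("Collapse warnings:", alerts)
        break  # pause self-training before the errors compound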
Mitigation Strategies
class CollapseMitigation:
"""
Strategies to prevent model collapse in self-improvement loops.
"""
@staticmethod
def data_mixing(synthetic_data, original_data, synthetic_fraction=0.5):
"""
Mix synthetic data with original training data.
Prevents distribution drift by anchoring to real data.
"""
num_synthetic = int(len(original_data) * synthetic_fraction /
(1 - synthetic_fraction))
num_synthetic = min(num_synthetic, len(synthetic_data))
sampled_synthetic = random.sample(synthetic_data, num_synthetic)
mixed = original_data + sampled_synthetic
random.shuffle(mixed)
return mixed
@staticmethod
def temperature_curriculum(iteration, max_iterations=10):
"""
Increase temperature over iterations to maintain diversity.
Early iterations: low temperature for quality.
Later iterations: higher temperature to explore.
"""
base_temp = 0.5
max_temp = 1.2
progress = iteration / max_iterations
return base_temp + (max_temp - base_temp) * progress
@staticmethod
def process_reward_filtering(rationales, process_reward_model):
"""
Filter by reasoning process quality, not just final answer.
Prevents training on correct-answer-wrong-reasoning examples.
"""
filtered = []
for r in rationales:
if not r.correct:
continue
# Score each reasoning step
steps = r.chain_of_thought.split('\n')
step_scores = []
for i, step in enumerate(steps):
context = '\n'.join(steps[:i + 1])
score = process_reward_model.score_step(
r.question, context, step
)
step_scores.append(score)
# Keep only rationales where ALL steps score above threshold
min_step_score = min(step_scores) if step_scores else 0
if min_step_score >= 0.5:
filtered.append(r)
return filtered
@staticmethod
def difficulty_reweighting(problems, current_accuracy_by_difficulty):
"""
        Bias the training pool toward hard problems to prevent difficulty collapse.
        Each problem is kept with probability proportional to 1 - accuracy,
        so easy (high-accuracy) problems are downsampled.
"""
reweighted = []
for problem in problems:
difficulty = problem.difficulty
current_acc = current_accuracy_by_difficulty.get(difficulty, 0.5)
# Higher weight for lower accuracy (harder problems)
weight = max(0.1, 1.0 - current_acc)
            # Keep with probability proportional to weight (downsamples easy problems)
if random.random() < weight:
reweighted.append(problem)
return reweighted
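A sketch of one mitigated iteration that combines the three ideas. It assumes `original_data` is the seed set of human-written examples in the same dict format that `retrain()` expects, and `prm` is a process reward model exposing the `score_step(question, context, step)` method used above:

def mitigated_star_iteration(trainer, problems, original_data, prm,
                             iteration, max_iterations=10):
    # Diversity: raise the sampling temperature over iterations
    temp = CollapseMitigation.temperature_curriculum(iteration, max_iterations)
    all_rationales = trainer.generate_rationales(problems, num_samples=8,
                                                 temperature=temp)

    # Quality: keep only rationales whose every step clears the PRM threshold
    flat = [r for rs in all_rationales for r in rs]
    good = CollapseMitigation.process_reward_filtering(flat, prm)
    synthetic = [{'question': r.question, 'rationale': r.chain_of_thought,
                  'answer': r.final_answer, 'source': r.source} for r in good]

    # Anchoring: mix synthetic data with the original human-written examples
    mixed = CollapseMitigation.data_mixing(synthetic, original_data,
                                           synthetic_fraction=0.5)
    trainer.retrain(mixed, epochs=3)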
Effect of Mitigation Strategies on GSM8K Accuracy (base model through iteration 10): no mitigation vs. data mixing (50%) vs. process reward + mixing.
Reward Function Design for Self-Training
Types of Reward Signals
class RewardFunctions:
"""
Different reward functions for self-training.
The reward function determines what the model optimizes for.
"""
@staticmethod
def binary_correctness(question, solution, correct_answer):
"""
Simplest reward: 1 if correct, 0 if wrong.
Used by basic STaR.
"""
predicted = extract_final_answer(solution)
return 1.0 if answers_match(predicted, correct_answer) else 0.0
@staticmethod
def soft_correctness(question, solution, correct_answer):
"""
Soft reward based on partial correctness.
Useful for multi-step problems.
"""
predicted = extract_final_answer(solution)
if answers_match(predicted, correct_answer):
return 1.0
# Partial credit for being close
try:
pred_val = float(predicted)
correct_val = float(correct_answer)
relative_error = abs(pred_val - correct_val) / max(abs(correct_val), 1e-8)
return max(0, 1.0 - relative_error)
except (ValueError, TypeError):
return 0.0
@staticmethod
def process_reward(question, solution, correct_answer, prm_model):
"""
Process Reward Model: score the reasoning process.
Rewards correct reasoning steps, not just correct answers.
"""
# Score each step
steps = solution.split('\n')
step_scores = []
for i, step in enumerate(steps):
context = '\n'.join(steps[:i + 1])
score = prm_model.score_step(question, context, step)
step_scores.append(score)
if not step_scores:
return 0.0
# Combine: minimum step score * correctness
correctness = 1.0 if answers_match(
extract_final_answer(solution), correct_answer
) else 0.0
process_quality = min(step_scores)
return 0.7 * correctness + 0.3 * process_quality
def extract_final_answer(solution):
"""Extract the final numerical or text answer from a solution."""
    import re
    lines = solution.strip().split('\n')
    for line in reversed(lines):
        line = line.strip()
        if line and any(c.isdigit() for c in line):
            # Take the last number on the line
            numbers = re.findall(r'-?\d+\.?\d*', line)
if numbers:
return numbers[-1]
return solution.strip().split('\n')[-1] if solution.strip() else ""
def answers_match(predicted, correct):
"""Check if two answers match (handling numeric comparison)."""
pred = str(predicted).strip().lower()
corr = str(correct).strip().lower()
if pred == corr:
return True
try:
return abs(float(pred) - float(corr)) < 1e-6
except (ValueError, TypeError):
return False
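A quick check of how the first two rewards score the same slightly-off solution (the process reward additionally needs a PRM, so it is only noted in a comment; the example problem is made up):

question = "A pack holds 12 pencils. How many pencils are in 7 packs?"
solution = ("Each pack holds 12 pencils.\n"
            "7 packs hold 12 * 7 = 82 pencils.\n"
            "The answer is 82.")
correct = "84"

print(RewardFunctions.binary_correctness(question, solution, correct))  # 0.0 (wrong answer)
print(RewardFunctions.soft_correctness(question, solution, correct))    # ~0.98 (close, partial credit)
# RewardFunctions.process_reward(question, solution, correct, prm_model) additionally
# needs a PRM with a score_step(question, context, step) method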
Production Considerations
Compute Budget
COMPUTE_BUDGET = {
"star_7b_gsm8k": {
"model_size": "7B",
"dataset": "GSM8K (8.5K problems)",
"samples_per_problem": 8,
"iterations": 5,
"generation_cost_per_iter": "~2 H100-hours (68K generations)",
"training_cost_per_iter": "~0.5 H100-hours (3 epochs on 6K examples)",
"total_cost": "~12.5 H100-hours = $37.50",
"accuracy_improvement": "35% -> 63%",
},
"rest_70b_math": {
"model_size": "70B",
"dataset": "MATH (12.5K problems)",
"samples_per_problem": 64,
"iterations": 3,
"generation_cost_per_iter": "~160 H100-hours (800K generations)",
"training_cost_per_iter": "~20 H100-hours (3 epochs on 50K examples)",
"total_cost": "~540 H100-hours = $1,620",
"accuracy_improvement": "42% -> 58%",
},
}
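A back-of-envelope helper for sizing a run. The per-1K-generation rate and the $3/H100-hour price are assumptions chosen to be consistent with the totals quoted above:

def estimate_star_budget(num_problems, samples_per_problem, iterations,
                         gen_hours_per_1k_samples, train_hours_per_iter,
                         usd_per_gpu_hour=3.0):
    """Rough cost estimate for a self-training run; all rates are assumptions."""
    generations = num_problems * samples_per_problem * iterations
    gen_hours = generations / 1000 * gen_hours_per_1k_samples
    train_hours = train_hours_per_iter * iterations
    total_hours = gen_hours + train_hours
    return {
        "total_generations": generations,
        "gpu_hours": round(total_hours, 1),
        "usd": round(total_hours * usd_per_gpu_hour, 2),
    }

# Roughly reproduces the 7B GSM8K line above (~2 GPU-hours per 68K generations)
print(estimate_star_budget(8_500, 8, 5,
                           gen_hours_per_1k_samples=2 / 68,
                           train_hours_per_iter=0.5))
# -> {'total_generations': 340000, 'gpu_hours': 12.5, 'usd': 37.5}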
Self-Training Compute Efficiency
| Method | Model | Dataset | GPU-Hours | Accuracy Gain | Gain/GPU-Hour |
|---|---|---|---|---|---|
| STaR | 7B | GSM8K | 12.5 | +28% | 2.24%/hr |
| ReST | 7B | GSM8K | 18 | +25% | 1.39%/hr |
| STaR | 70B | GSM8K | 200 | +18% | 0.09%/hr |
| ReST | 70B | MATH | 540 | +16% | 0.03%/hr |
| SFT baseline | 7B | GSM8K | 5 | +20% | 4.00%/hr |
Self-improving systems are not a replacement for high-quality human data. They are a multiplier: starting from a small amount of human-verified reasoning data, STaR and ReST can bootstrap 2-3x the model’s initial reasoning capability at low compute cost. The key constraint is collapse prevention. Without data mixing, process reward filtering, and diversity maintenance, the self-improvement loop degrades after 3-5 iterations. With these mitigations, the loop sustains improvement for 8-10 iterations before hitting the ceiling imposed by the base model’s representation capacity.