AlphaGo Zero trained entirely from self-play: no human games, just iterative self-improvement. After 3 days it surpassed AlphaGo Lee, the version that had beaten Lee Sedol; after 40 days it surpassed every previous version, including the one that had beaten the world's top-ranked player. Self-improving LLMs attempt the same loop: generate reasoning traces, filter for correctness, retrain, repeat. When the loop works, performance compounds: STaR improved GSM8K accuracy from 10% to 44% over 4 iterations. When it fails, model collapse degrades every metric until the model outputs repetitive nonsense. The difference is quality filtering.
This post covers the three foundational approaches: STaR (Self-Taught Reasoner), ReST (Reinforced Self-Training), and self-play, along with the collapse mechanisms and mitigation strategies.
STaR: Self-Taught Reasoner
The Core Idea
STaR (Zelikman et al., 2022) improves a model’s reasoning by having it generate rationales (chain-of-thought), keeping only the rationales that lead to correct answers, and retraining on those successful rationales.
The algorithm:
```python
import re

import torch


class STaR:
    """
    Self-Taught Reasoner.

    The model generates reasoning traces, keeps correct ones, and retrains.
    """

    def __init__(self, model, tokenizer, problems_with_answers):
        self.model = model
        self.tokenizer = tokenizer
        self.problems = problems_with_answers  # List of (question, answer) pairs

    def run_iteration(self, temperature=0.7, num_samples=8):
        """Run one STaR iteration."""
        successful_rationales = []
        rationalized_from_hint = []
        for question, correct_answer in self.problems:
            # Step 1: Generate rationales (chain-of-thought)
            rationales = self._generate_rationales(
                question, num_samples, temperature
            )
            # Step 2: Check which rationales lead to the correct answer
            found_correct = False
            for rationale in rationales:
                predicted_answer = self._extract_answer(rationale)
                if self._answers_match(predicted_answer, correct_answer):
                    successful_rationales.append({
                        "question": question,
                        "rationale": rationale,
                        "answer": correct_answer,
                    })
                    found_correct = True
                    break  # Keep the first correct rationale
            # Step 3: Rationalization from hint (if no correct rationale found)
            if not found_correct:
                # Give the model the correct answer and ask it to justify
                hint_rationale = self._rationalize_from_hint(
                    question, correct_answer
                )
                if hint_rationale:
                    rationalized_from_hint.append({
                        "question": question,
                        "rationale": hint_rationale,
                        "answer": correct_answer,
                    })
        # Step 4: Create training dataset
        training_data = successful_rationales + rationalized_from_hint
        return {
            "training_examples": training_data,
            "correct_without_hint": len(successful_rationales),
            "rationalized_from_hint": len(rationalized_from_hint),
            "total_problems": len(self.problems),
            "success_rate": len(successful_rationales) / len(self.problems),
        }

    def _generate_rationales(self, question, num_samples, temperature):
        """Generate multiple reasoning traces for a question."""
        prompt = (
            f"Question: {question}\n"
            f"Let me think step by step.\n"
        )
        rationales = []
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        for _ in range(num_samples):
            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    max_new_tokens=512,
                    temperature=temperature,
                    top_p=0.95,
                    do_sample=True,
                )
            rationale = self.tokenizer.decode(
                output[0][input_ids.shape[1]:], skip_special_tokens=True
            )
            rationales.append(rationale)
        return rationales

    def _rationalize_from_hint(self, question, correct_answer):
        """Generate a rationale given the correct answer (hint)."""
        prompt = (
            f"Question: {question}\n"
            f"The answer is {correct_answer}.\n"
            f"Let me explain step by step why this is correct.\n"
        )
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            output = self.model.generate(
                input_ids, max_new_tokens=512,
                do_sample=True,
                temperature=0.3,  # Lower temperature for hint-based generation
            )
        rationale = self.tokenizer.decode(
            output[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        # Verify the rationalization is coherent
        predicted = self._extract_answer(rationale)
        if self._answers_match(predicted, correct_answer):
            return rationale
        return None

    def _extract_answer(self, rationale):
        """Extract the final answer from a reasoning trace."""
        # Look for "The answer is X" or "Therefore, X" patterns
        patterns = [
            r"(?:the answer is|therefore|thus|so)\s*[:.]?\s*(.+?)(?:\.|$)",
            r"=\s*(\d+(?:\.\d+)?)",
        ]
        for pattern in patterns:
            match = re.search(pattern, rationale.lower())
            if match:
                return match.group(1).strip()
        # Fallback: last line
        lines = rationale.strip().split('\n')
        return lines[-1].strip() if lines else ""

    def _answers_match(self, predicted, correct):
        """Check if two answers match (with normalization)."""
        pred = str(predicted).strip().lower().rstrip('.')
        corr = str(correct).strip().lower().rstrip('.')
        if pred == corr:
            return True
        # Compare numerically when possible: substring matching alone is
        # too loose ("1" would match "10").
        try:
            return float(pred) == float(corr)
        except ValueError:
            # Substring fallback catches formatting noise ("42 apples" vs "42")
            return pred in corr or corr in pred
```
The Rationalization Trick
The key innovation in STaR is “rationalization from hint.” When the model cannot generate a correct rationale on its own, it is given the correct answer and asked to work backward. This is crucial because:
- Without hints, STaR can only improve on problems it can already sometimes solve.
- With hints, STaR can learn to solve problems it has never solved before.
- The hint-based rationales teach the model new reasoning patterns.
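One detail matters when assembling the training set: the hint is stripped before fine-tuning, so the model learns to produce the rationale from the question alone. A minimal sketch, assuming the example dicts produced by `run_iteration` above (`format_for_training` and the sample rationale are illustrative, not from the paper):

```python
def format_for_training(example):
    """Format a STaR example for fine-tuning.

    The training prompt never includes the hint, even for examples that
    were rationalized from one: at training time the model must learn to
    reach the answer from the question alone.
    """
    prompt = f"Question: {example['question']}\nLet me think step by step.\n"
    completion = f"{example['rationale']}\nThe answer is {example['answer']}."
    return {"prompt": prompt, "completion": completion}


sample = {
    "question": "What is 12 * 7?",
    "rationale": "10 * 7 = 70 and 2 * 7 = 14, so 12 * 7 is 70 + 14.",
    "answer": "84",
}
formatted = format_for_training(sample)
```

Without this asymmetry, the model would learn to parrot hints rather than to reason toward answers.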
[Figure: STaR improvement over iterations on GSM8K math, in % accuracy.]

ReST: Reinforced Self-Training
Algorithm
ReST (Gulcehre et al., 2023) is a more general framework that combines self-generation with reward model scoring:
```python
import re

import torch


class ReST:
    """
    Reinforced Self-Training.

    Generate -> Score -> Filter -> Retrain loop.
    """

    def __init__(self, model, tokenizer, reward_model, problems):
        self.model = model
        self.tokenizer = tokenizer
        self.reward_model = reward_model
        self.problems = problems

    def run_iteration(
        self,
        num_samples_per_problem=16,
        top_k_fraction=0.25,
        temperature=0.8,
    ):
        """Run one ReST iteration."""
        all_samples = []
        # Step 1: Generate (grow phase)
        for problem in self.problems:
            samples = self._generate_samples(
                problem, num_samples_per_problem, temperature
            )
            all_samples.extend(samples)
        # Step 2: Score with reward model (filter phase)
        scored = self._score_samples(all_samples)
        # Step 3: Filter top-K
        scored.sort(key=lambda x: x["reward"], reverse=True)
        top_k = int(len(scored) * top_k_fraction)
        filtered = scored[:top_k]
        # Step 4: Create training dataset
        training_data = [
            {
                "prompt": s["problem"],
                "response": s["response"],
                "reward": s["reward"],
            }
            for s in filtered
        ]
        return {
            "total_generated": len(all_samples),
            "after_filtering": len(filtered),
            "avg_reward_all": sum(s["reward"] for s in scored) / len(scored),
            "avg_reward_filtered": sum(s["reward"] for s in filtered) / len(filtered),
            "training_data": training_data,
        }

    def _generate_samples(self, problem, n, temperature):
        """Generate n candidate responses for a problem."""
        samples = []
        input_ids = self.tokenizer.encode(
            problem["prompt"], return_tensors="pt"
        ).to("cuda")
        for _ in range(n):
            with torch.no_grad():
                output = self.model.generate(
                    input_ids,
                    max_new_tokens=1024,
                    temperature=temperature,
                    do_sample=True,
                )
            response = self.tokenizer.decode(
                output[0][input_ids.shape[1]:], skip_special_tokens=True
            )
            samples.append({
                "problem": problem["prompt"],
                "response": response,
                "correct_answer": problem.get("answer"),
            })
        return samples

    def _score_samples(self, samples):
        """Score samples using the reward model."""
        scored = []
        for sample in samples:
            # Reward model scores the response quality
            reward = self.reward_model.score(
                sample["problem"], sample["response"]
            )
            # Bonus for correct answer (if verifiable)
            if sample.get("correct_answer"):
                predicted = self._extract_answer(sample["response"])
                if predicted == sample["correct_answer"]:
                    reward += 1.0  # Correctness bonus
            sample["reward"] = reward
            scored.append(sample)
        return scored

    def _extract_answer(self, response):
        """Pull the final answer out of a response (same heuristic as STaR)."""
        match = re.search(
            r"(?:the answer is|therefore)\s*[:.]?\s*(.+?)(?:\.|$)",
            response.lower(),
        )
        if match:
            return match.group(1).strip()
        lines = response.strip().split('\n')
        return lines[-1].strip() if lines else ""
```
STaR vs. ReST
| Aspect | STaR | ReST |
|---|---|---|
| Verification | Answer matching | Reward model |
| Applicability | Tasks with verifiable answers | Any task |
| Hint mechanism | Yes (rationalization) | No |
| Samples per problem | ~8 | ~16-64 |
| Training signal | Binary (correct/incorrect) | Continuous (reward score) |
| Risk of reward hacking | Low (ground truth) | Medium (reward model bias) |
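The "training signal" row has a concrete consequence for the fine-tuning step. STaR's binary signal amounts to plain cross-entropy on kept examples; with a reward model, filtered examples can instead be softmax-weighted by reward so higher-scoring responses dominate the loss. A sketch of that conversion (the `weight_examples` helper and its `beta` sharpness knob are illustrative, not from either paper):

```python
import math


def weight_examples(training_data, beta=1.0):
    """Convert continuous rewards into per-example SFT loss weights.

    STaR's binary signal corresponds to weight 1 for kept examples and
    0 for dropped ones; here a softmax over rewards lets higher-reward
    responses count for more. `beta` controls sharpness.
    """
    rewards = [ex["reward"] for ex in training_data]
    m = max(rewards)  # subtract the max for numerical stability
    exps = [math.exp(beta * (r - m)) for r in rewards]
    z = sum(exps)
    return [dict(ex, weight=e / z) for ex, e in zip(training_data, exps)]


batch = [{"reward": 0.1}, {"reward": 0.9}, {"reward": 0.5}]
weighted = weight_examples(batch)
```

At high `beta` this approaches STaR-style hard filtering; at `beta=0` every example counts equally.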
Self-Play for Reasoning
The Self-Play Framework
Self-play generates training data by having the model play two roles, solver and verifier:
```python
import torch


class SelfPlayReasoner:
    """
    Self-play for reasoning improvement.

    The model plays both 'solver' and 'verifier' roles.
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def run_self_play_round(self, problems):
        """Run a round of self-play over a batch of problems."""
        training_data = []
        for problem in problems:
            # Round 1: Generate multiple solutions
            solutions = self._generate_solutions(problem, n=8)
            # Round 2: Self-verify each solution
            verified = self._self_verify(problem, solutions)
            # Round 3: Generate improved solution based on verification
            improved = self._generate_improved(problem, verified)
            # Keep the best chain
            if improved["quality_score"] > 0.7:
                training_data.append({
                    "problem": problem,
                    "solution": improved["solution"],
                    "verification": improved["verification"],
                    "quality": improved["quality_score"],
                })
        return training_data

    def _generate_solutions(self, problem, n=8):
        """Generate n candidate solutions."""
        prompt = f"Solve this problem step by step:\n{problem}\n\nSolution:"
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        solutions = []
        for _ in range(n):
            with torch.no_grad():
                output = self.model.generate(
                    input_ids, max_new_tokens=512,
                    temperature=0.8, do_sample=True,
                )
            sol = self.tokenizer.decode(
                output[0][input_ids.shape[1]:], skip_special_tokens=True
            )
            solutions.append(sol)
        return solutions

    def _self_verify(self, problem, solutions):
        """Have the model verify each solution."""
        verified = []
        for sol in solutions:
            prompt = (
                f"Problem: {problem}\n\n"
                f"Proposed solution:\n{sol}\n\n"
                f"Is this solution correct? Check each step carefully. "
                f"Output CORRECT or INCORRECT with explanation."
            )
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
            with torch.no_grad():
                output = self.model.generate(
                    input_ids, max_new_tokens=256,
                    temperature=0.3, do_sample=True,
                )
            verification = self.tokenizer.decode(
                output[0][input_ids.shape[1]:], skip_special_tokens=True
            )
            # Check the verdict line; "incorrect" contains "correct", so
            # test for the negative verdict first.
            first_line = verification.lower().split('\n')[0]
            is_correct = "incorrect" not in first_line and "correct" in first_line
            verified.append({
                "solution": sol,
                "verification": verification,
                "self_judged_correct": is_correct,
            })
        return verified

    def _generate_improved(self, problem, verified_solutions):
        """Generate an improved solution based on self-verification feedback."""
        # Find the best verified solution
        correct_solutions = [v for v in verified_solutions if v["self_judged_correct"]]
        if correct_solutions:
            best = correct_solutions[0]
            return {
                "solution": best["solution"],
                "verification": best["verification"],
                "quality_score": 0.9,
            }
        # No correct solution -- generate one informed by the errors
        error_feedback = "\n".join(
            f"Attempt: {v['solution'][:200]}...\nIssue: {v['verification'][:200]}"
            for v in verified_solutions[:3]
        )
        prompt = (
            f"Problem: {problem}\n\n"
            f"Previous attempts had these issues:\n{error_feedback}\n\n"
            f"Avoiding those mistakes, here is the correct solution:\n"
        )
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            output = self.model.generate(
                input_ids, max_new_tokens=512,
                temperature=0.5, do_sample=True,
            )
        improved = self.tokenizer.decode(
            output[0][input_ids.shape[1]:], skip_special_tokens=True
        )
        return {
            "solution": improved,
            "verification": "Generated from error feedback",
            "quality_score": 0.6,
        }
```
The Data Flywheel
Iteration Dynamics
```python
import torch


class DataFlywheel:
    """
    The complete self-improvement loop.

    Model -> Generate Data -> Filter -> Retrain -> Better Model -> Repeat.
    """

    def __init__(self, initial_model, tokenizer, problems, reward_model=None):
        self.model = initial_model
        self.tokenizer = tokenizer
        self.problems = problems
        self.reward_model = reward_model
        self.iteration_history = []

    def run(self, num_iterations=5, method="star"):
        """Run the full flywheel for multiple iterations."""
        for i in range(num_iterations):
            print(f"\n=== Iteration {i+1} ===")
            # Generate and filter training data
            if method == "star":
                star = STaR(self.model, self.tokenizer, self.problems)
                result = star.run_iteration()
                training_data = result["training_examples"]
            elif method == "rest":
                rest = ReST(self.model, self.tokenizer,
                            self.reward_model, self.problems)
                result = rest.run_iteration()
                training_data = result["training_data"]
            elif method == "self_play":
                sp = SelfPlayReasoner(self.model, self.tokenizer)
                training_data = sp.run_self_play_round(self.problems)
            else:
                raise ValueError(f"Unknown method: {method}")
            # Evaluate before retraining
            pre_score = self._evaluate()
            # Retrain model
            self.model = self._retrain(training_data)
            # Evaluate after retraining
            post_score = self._evaluate()
            self.iteration_history.append({
                "iteration": i + 1,
                "training_examples": len(training_data),
                "pre_score": pre_score,
                "post_score": post_score,
                "improvement": post_score - pre_score,
            })
            print(f"  Training examples: {len(training_data)}")
            print(f"  Score: {pre_score:.3f} -> {post_score:.3f}")
            # Early stopping if no improvement
            if post_score <= pre_score:
                print("  No improvement, stopping.")
                break
        return self.iteration_history

    def _evaluate(self):
        """Evaluate on a fixed subset (use a true held-out set in practice)."""
        subset = self.problems[:100]
        correct = 0
        for problem in subset:
            response = self._generate_one(problem["prompt"])
            predicted = self._extract_answer(response)
            if predicted == problem.get("answer", ""):
                correct += 1
        return correct / len(subset)

    def _generate_one(self, prompt):
        """Greedy-decode a single response."""
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            output = self.model.generate(input_ids, max_new_tokens=512)
        return self.tokenizer.decode(
            output[0][input_ids.shape[1]:], skip_special_tokens=True
        )

    def _extract_answer(self, response):
        """Take the last line as the answer (swap in a task-specific parser)."""
        lines = response.strip().split('\n')
        return lines[-1].strip() if lines else ""

    def _retrain(self, training_data):
        """Fine-tune the model on the generated data."""
        # In practice, this would be SFT on the training data;
        # here we return the model unchanged for demonstration.
        return self.model
```
Convergence and Collapse
[Figure: Self-improvement trajectories, success vs. collapse, in % accuracy.]

Model Collapse: The Failure Mode
What Causes Collapse
Model collapse occurs when the model trains on its own outputs without sufficient diversity pressure:
```python
class CollapseDetector:
    """Detect signs of model collapse during self-improvement."""

    def __init__(self):
        self.diversity_history = []
        self.accuracy_history = []

    def check_for_collapse(self, generated_texts, accuracy):
        """Check for early signs of collapse."""
        # Metric 1: Output diversity
        diversity = self._compute_diversity(generated_texts)
        self.diversity_history.append(diversity)
        # Metric 2: Accuracy trend
        self.accuracy_history.append(accuracy)
        warnings = []
        # Warning 1: Diversity dropping for three consecutive checks
        if len(self.diversity_history) >= 3:
            recent = self.diversity_history[-3:]
            if all(recent[i] < recent[i - 1] for i in range(1, len(recent))):
                warnings.append({
                    "type": "diversity_decline",
                    "message": f"Output diversity declining: {recent}",
                    "severity": "high",
                })
        # Warning 2: Accuracy declining over the last three checks
        if len(self.accuracy_history) >= 4:
            recent = self.accuracy_history[-4:]
            if recent[-1] < recent[-2] < recent[-3]:
                warnings.append({
                    "type": "accuracy_decline",
                    "message": f"Accuracy declining: {recent}",
                    "severity": "critical",
                })
        # Warning 3: Repetition in outputs
        repetition_rate = self._compute_repetition(generated_texts)
        if repetition_rate > 0.3:
            warnings.append({
                "type": "high_repetition",
                "message": f"Repetition rate: {repetition_rate:.2%}",
                "severity": "high",
            })
        return {
            "diversity": diversity,
            "accuracy": accuracy,
            "repetition_rate": repetition_rate,
            "collapse_risk": len(warnings) > 0,
            "warnings": warnings,
        }

    def _compute_diversity(self, texts):
        """Compute vocabulary diversity (type-token ratio)."""
        all_words = []
        for text in texts[:100]:
            all_words.extend(text.lower().split())
        if not all_words:
            return 0
        return len(set(all_words)) / len(all_words)

    def _compute_repetition(self, texts):
        """Compute the fraction of text pairs that are near-duplicates."""
        if len(texts) < 2:
            return 0
        # Simple: check 3-gram overlap between nearby pairs
        trigram_sets = []
        for text in texts[:100]:
            words = text.lower().split()
            trigrams = set(
                tuple(words[i:i + 3]) for i in range(len(words) - 2)
            )
            trigram_sets.append(trigrams)
        duplicate_pairs = 0
        total_pairs = 0
        for i in range(len(trigram_sets)):
            for j in range(i + 1, min(i + 10, len(trigram_sets))):
                intersection = trigram_sets[i] & trigram_sets[j]
                union = trigram_sets[i] | trigram_sets[j]
                jaccard = len(intersection) / max(len(union), 1)
                if jaccard > 0.5:
                    duplicate_pairs += 1
                total_pairs += 1
        return duplicate_pairs / max(total_pairs, 1)
```
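To see why the type-token ratio is a useful collapse signal, here is a toy comparison, stripped down from the detector above (the example strings are illustrative):

```python
def type_token_ratio(texts):
    """Distinct words divided by total words across a batch of texts."""
    words = [w for text in texts for w in text.lower().split()]
    return len(set(words)) / len(words) if words else 0.0


diverse = [
    "compute the area of the triangle",
    "factor the quadratic into two binomials",
    "convert the fraction to a decimal",
]
collapsed = ["the answer is the answer is the answer"] * 3

# A collapsing model reuses the same phrases, so its ratio is far lower.
```

The diverse batch scores above 0.8 here while the repetitive one scores near 0.1; in practice the absolute value matters less than a sustained downward trend across iterations.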
Mitigation Strategies
```python
import random


class CollapseMitigator:
    """Strategies to prevent model collapse during self-improvement."""

    def __init__(self):
        # Natural data to mix in, as a fraction of the synthetic set size
        self.original_data_fraction = 0.3

    def apply_mitigations(self, synthetic_data, original_data):
        """Apply collapse mitigation strategies."""
        # Strategy 1: Mix with original (natural) data
        natural_sample = random.sample(
            original_data,
            min(
                int(len(synthetic_data) * self.original_data_fraction),
                len(original_data),
            ),
        )
        mixed_data = synthetic_data + natural_sample
        random.shuffle(mixed_data)
        # Strategy 2: Diversity filtering
        # Remove synthetic examples that are too similar to each other
        diverse_data = self._diversity_filter(mixed_data, max_similarity=0.7)
        # Strategy 3: Temperature scheduling
        # Use higher temperature in later iterations to maintain diversity
        # (applied during generation, not here)
        return diverse_data

    def _diversity_filter(self, data, max_similarity=0.7):
        """Remove near-duplicate examples to maintain diversity."""
        # Simple dedup using trigram overlap
        if not data:
            return []
        kept = [data[0]]
        kept_trigrams = [self._get_trigrams(data[0].get("response", ""))]
        for example in data[1:]:
            trigrams = self._get_trigrams(example.get("response", ""))
            is_duplicate = False
            for kt in kept_trigrams[-50:]:  # Compare with recent examples
                overlap = len(trigrams & kt) / max(len(trigrams | kt), 1)
                if overlap > max_similarity:
                    is_duplicate = True
                    break
            if not is_duplicate:
                kept.append(example)
                kept_trigrams.append(trigrams)
        return kept

    def _get_trigrams(self, text):
        words = text.lower().split()
        return set(tuple(words[i:i + 3]) for i in range(len(words) - 2))
```
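Strategy 3 above happens at generation time rather than in the mitigator. One simple schedule raises temperature with each iteration and bumps it further when measured diversity falls; this sketch and all its constants (`base`, `step`, the 0.1 bump) are illustrative defaults, not tuned values:

```python
def generation_temperature(iteration, base=0.7, step=0.05, max_temp=1.1,
                           diversity=None, diversity_floor=0.3):
    """Sampling temperature for a given self-improvement iteration.

    Ramps temperature up each iteration, with an extra bump whenever
    measured output diversity (e.g. type-token ratio) drops below a
    floor, capped so generation does not become incoherent.
    """
    temp = base + step * iteration
    if diversity is not None and diversity < diversity_floor:
        temp += 0.1  # push harder against narrowing outputs
    return min(temp, max_temp)
```

The cap matters: past roughly 1.0-1.2, extra temperature buys diversity at the cost of mostly broken reasoning traces that the filter then discards anyway.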
Shumailov et al. (2023) showed that training on model-generated data for even a few generations causes measurable distribution shift: the model's output distribution narrows, losing tail behaviors and minority patterns. This is not a catastrophic failure, since the model still generates coherent text, but the diversity and coverage of its outputs degrade progressively. The fix: always include a substantial fraction (20-30%) of natural, human-generated data in the training mix.
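The narrowing effect is easy to reproduce in miniature: repeatedly fit a Gaussian to a finite sample and resample from the fit, and the spread (a stand-in for the tails of a model's output distribution) shrinks generation over generation. A toy simulation using only the standard library; the generation count and sample size are arbitrary:

```python
import random
import statistics


def one_generation(samples, n):
    """Fit a Gaussian to `samples`, then draw n fresh points from the fit.

    This mimics training on model outputs: each generation sees only the
    previous generation's finite, slightly narrowed sample.
    """
    mu = statistics.fmean(samples)
    sigma = statistics.pstdev(samples)  # MLE estimate, biased slightly low
    return [random.gauss(mu, sigma) for _ in range(n)]


random.seed(0)
n = 100
samples = [random.gauss(0.0, 1.0) for _ in range(n)]  # stand-in "human" data
initial_std = statistics.pstdev(samples)
for _ in range(500):  # 500 generations of training on own outputs
    samples = one_generation(samples, n)
final_std = statistics.pstdev(samples)
# The tails collapse: final_std typically ends up far below initial_std.
```

Mixing even a modest fraction of draws from the original distribution into each generation's fitting sample largely arrests the shrinkage, which is the statistical intuition behind the 20-30% natural-data rule above.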
Complete Self-Improvement Pipeline
```python
class SelfImprovementPipeline:
    """
    Complete self-improvement pipeline with collapse detection and mitigation.
    """

    def __init__(self, model, tokenizer, problems, original_data):
        self.model = model
        self.tokenizer = tokenizer
        self.problems = problems
        self.original_data = original_data
        self.collapse_detector = CollapseDetector()
        self.mitigator = CollapseMitigator()
        self.history = []

    def run(self, max_iterations=10):
        """Run the self-improvement pipeline with safety checks."""
        for i in range(max_iterations):
            # Generate training data (STaR here; ReST or self-play drop in
            # the same way -- see DataFlywheel for the dispatch)
            star = STaR(self.model, self.tokenizer, self.problems)
            result = star.run_iteration()
            synthetic_data = result["training_examples"]
            # Check for collapse
            generated_texts = [d.get("rationale", "") for d in synthetic_data]
            collapse_check = self.collapse_detector.check_for_collapse(
                generated_texts, result["success_rate"]
            )
            if collapse_check["collapse_risk"]:
                print(f"Iteration {i+1}: Collapse risk detected!")
                for warning in collapse_check["warnings"]:
                    print(f"  WARNING: {warning['message']}")
                if any(w["severity"] == "critical"
                       for w in collapse_check["warnings"]):
                    print("Critical collapse detected. Stopping.")
                    break
            # Apply mitigations
            training_data = self.mitigator.apply_mitigations(
                synthetic_data, self.original_data
            )
            # Retrain
            self.model = self._finetune(training_data)
            # Evaluate
            score = self._evaluate()
            self.history.append({
                "iteration": i + 1,
                "success_rate": result["success_rate"],
                "eval_score": score,
                "collapse_risk": collapse_check["collapse_risk"],
                "diversity": collapse_check["diversity"],
            })
            print(f"Iteration {i+1}: score={score:.3f}, "
                  f"diversity={collapse_check['diversity']:.3f}")
        return self.history

    def _finetune(self, training_data):
        """SFT on the mitigated data (stubbed: returns the model unchanged)."""
        return self.model

    def _evaluate(self):
        """Accuracy on a held-out set (stubbed; see DataFlywheel._evaluate)."""
        return 0.0
```