Part of Series: Frontier Research 2025-2026, 16 of 30

Synthetic Reasoning Data: STaR, ReST, and How Models Bootstrap Their Own Training Signal

Expert-written math reasoning traces cost $50-200 per problem (hiring PhDs to write step-by-step solutions). At 10,000 problems, that is $500K-$2M. STaR (Self-Taught Reasoner) eliminates this cost: the model generates its own reasoning traces, filters for correctness, and retrains on successful examples. After 4 iterations, STaR improved GSM8K accuracy from 10% to 44% with zero human-written rationales. The risk: model collapse, where self-generated data amplifies errors until performance degrades. The reward: exponential improvement from fixed compute if collapse is avoided.

This post covers the algorithms, their implementations, the data flywheel dynamics, collapse mechanisms, and production-ready mitigation strategies.

STaR: Self-Taught Reasoner

The Algorithm

STaR has three steps per iteration:

  1. Generate: the model produces chain-of-thought rationales for each problem
  2. Filter: keep only rationales that lead to the correct final answer
  3. Retrain: fine-tune the model on the filtered rationales

When the model fails to produce any correct rationale for a problem, STaR uses “rationalization from hint”: provide the correct answer and ask the model to generate a rationale that leads to it. This prevents the training distribution from shrinking to only “easy” problems.

import torch
import json
import random
from dataclasses import dataclass
from typing import Sequence

@dataclass
class Problem:
    question: str
    correct_answer: str
    difficulty: str = "unknown"

@dataclass
class Rationale:
    question: str
    chain_of_thought: str
    final_answer: str
    correct: bool
    source: str  # "direct" or "hint"

class STaRTrainer:
    """
    Complete STaR training loop.
    Self-Taught Reasoner: generate rationales, filter, retrain.
    """

    def __init__(self, model, tokenizer, optimizer_cls, lr=1e-5):
        self.model = model
        self.tokenizer = tokenizer
        self.optimizer_cls = optimizer_cls
        self.lr = lr

    def generate_rationales(self, problems, num_samples=8,
                           temperature=0.7, max_tokens=512):
        """
        Step 1: Generate chain-of-thought rationales for each problem.

        Args:
            problems: List of Problem objects
            num_samples: Number of rationales to generate per problem
            temperature: Sampling temperature
            max_tokens: Maximum tokens per rationale
        """
        all_rationales = []

        for problem in problems:
            prompt = self._build_rationale_prompt(problem.question)
            inputs = self.tokenizer(prompt, return_tensors="pt").to(
                self.model.device
            )

            rationales_for_problem = []
            for _ in range(num_samples):
                with torch.no_grad():
                    output = self.model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                        top_p=0.95,
                    )

                generated = self.tokenizer.decode(
                    output[0][inputs['input_ids'].shape[1]:],
                    skip_special_tokens=True,
                )

                # Parse chain-of-thought and final answer
                cot, answer = self._parse_rationale(generated)
                is_correct = self._check_answer(answer, problem.correct_answer)

                rationales_for_problem.append(Rationale(
                    question=problem.question,
                    chain_of_thought=cot,
                    final_answer=answer,
                    correct=is_correct,
                    source="direct",
                ))

            all_rationales.append(rationales_for_problem)

        return all_rationales

    def rationalize_from_hint(self, problem, max_tokens=512):
        """
        Generate a rationale given the correct answer as a hint.
        Used when no direct rationale was correct.
        """
        prompt = self._build_hint_prompt(problem.question, problem.correct_answer)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(
            self.model.device
        )

        with torch.no_grad():
            output = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=0.3,  # Lower temperature for hint-based
                do_sample=True,
            )

        generated = self.tokenizer.decode(
            output[0][inputs['input_ids'].shape[1]:],
            skip_special_tokens=True,
        )

        cot, answer = self._parse_rationale(generated)
        is_correct = self._check_answer(answer, problem.correct_answer)

        return Rationale(
            question=problem.question,
            chain_of_thought=cot,
            final_answer=answer,
            correct=is_correct,
            source="hint",
        )

    def filter_and_collect(self, problems, all_rationales):
        """
        Step 2: Filter for correct rationales.
        For problems with no correct direct rationale, use hint rationalization.
        """
        training_data = []
        stats = {
            'direct_correct': 0,
            'hint_correct': 0,
            'no_correct': 0,
            'total_problems': len(problems),
        }

        for problem, rationales in zip(problems, all_rationales):
            # Find first correct direct rationale
            correct_rationale = None
            for r in rationales:
                if r.correct:
                    correct_rationale = r
                    stats['direct_correct'] += 1
                    break

            # Rationalize from hint if no direct rationale was correct
            if correct_rationale is None:
                hint_rationale = self.rationalize_from_hint(problem)
                if hint_rationale.correct:
                    correct_rationale = hint_rationale
                    stats['hint_correct'] += 1
                else:
                    stats['no_correct'] += 1
                    continue

            training_data.append({
                'question': correct_rationale.question,
                'rationale': correct_rationale.chain_of_thought,
                'answer': correct_rationale.final_answer,
                'source': correct_rationale.source,
            })

        return training_data, stats

    def retrain(self, training_data, epochs=3, batch_size=8):
        """
        Step 3: Fine-tune the model on filtered rationales.
        """
        optimizer = self.optimizer_cls(self.model.parameters(), lr=self.lr)
        self.model.train()

        # Prepare training examples
        examples = []
        for item in training_data:
            text = self._format_training_example(
                item['question'], item['rationale'], item['answer']
            )
            encoded = self.tokenizer(
                text, return_tensors="pt", truncation=True,
                max_length=1024, padding="max_length",
            )
            examples.append(encoded)

        for epoch in range(epochs):
            random.shuffle(examples)
            total_loss = 0.0
            num_batches = 0

            for i in range(0, len(examples), batch_size):
                batch = examples[i:i + batch_size]
                # Stack batch
                input_ids = torch.cat(
                    [e['input_ids'] for e in batch]
                ).to(self.model.device)
                attention_mask = torch.cat(
                    [e['attention_mask'] for e in batch]
                ).to(self.model.device)

                # Mask padding so loss is computed on real tokens only
                # (examples are padded to max_length, so labels=input_ids
                # would otherwise train on pad tokens)
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )

                loss = outputs.loss
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            print(f"  Epoch {epoch + 1}/{epochs}: loss = {avg_loss:.4f}")

    def run_iteration(self, problems, num_samples=8, epochs=3):
        """Run one complete STaR iteration."""
        print("Generating rationales...")
        all_rationales = self.generate_rationales(problems, num_samples)

        print("Filtering...")
        training_data, stats = self.filter_and_collect(problems, all_rationales)
        print(f"  Direct correct: {stats['direct_correct']}")
        print(f"  Hint correct: {stats['hint_correct']}")
        print(f"  No correct: {stats['no_correct']}")

        print("Retraining...")
        self.retrain(training_data, epochs)

        return training_data, stats

    def run_full_training(self, problems, num_iterations=5,
                         num_samples=8, epochs=3):
        """Run multiple STaR iterations."""
        all_stats = []

        for iteration in range(num_iterations):
            print(f"\n=== STaR Iteration {iteration + 1}/{num_iterations} ===")
            _, stats = self.run_iteration(problems, num_samples, epochs)
            all_stats.append(stats)

            # Evaluate
            accuracy = self.evaluate(problems)
            print(f"  Accuracy after iteration {iteration + 1}: {accuracy:.3f}")

        return all_stats

    def evaluate(self, problems):
        """Evaluate model accuracy with greedy decoding."""
        self.model.eval()
        correct = 0
        for problem in problems:
            prompt = self._build_rationale_prompt(problem.question)
            inputs = self.tokenizer(prompt, return_tensors="pt").to(
                self.model.device
            )

            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False,  # greedy; no temperature needed
                )

            generated = self.tokenizer.decode(
                output[0][inputs['input_ids'].shape[1]:],
                skip_special_tokens=True,
            )
            _, answer = self._parse_rationale(generated)
            if self._check_answer(answer, problem.correct_answer):
                correct += 1

        self.model.train()
        return correct / len(problems)

    def _build_rationale_prompt(self, question):
        return (f"Question: {question}\n"
                f"Let's think step by step.\n")

    def _build_hint_prompt(self, question, answer):
        return (f"Question: {question}\n"
                f"The answer is {answer}.\n"
                f"Explain step by step why.\n")

    def _format_training_example(self, question, rationale, answer):
        return (f"Question: {question}\n"
                f"Let's think step by step.\n"
                f"{rationale}\n"
                f"The answer is {answer}.")

    def _parse_rationale(self, text):
        """Split generated text into chain-of-thought and final answer."""
        # Look for "The answer is" pattern
        markers = ["the answer is", "therefore,", "thus,", "so the answer is"]
        text_lower = text.lower()

        for marker in markers:
            if marker in text_lower:
                idx = text_lower.index(marker)
                cot = text[:idx].strip()
                answer = text[idx:].strip()
                # Extract just the answer value
                answer = answer.split('\n')[0].strip()
                return cot, answer

        # If no marker found, last line is the answer
        lines = text.strip().split('\n')
        if len(lines) > 1:
            return '\n'.join(lines[:-1]), lines[-1]
        return text, text

    def _check_answer(self, predicted, correct):
        """Check if predicted answer matches correct answer."""
        pred = predicted.lower().strip().rstrip('.')
        corr = correct.lower().strip().rstrip('.')
        return pred == corr or corr in pred

STaR Accuracy Improvement Over Iterations (GSM8K)

Model              Base  Iter 1  Iter 2  Iter 3  Iter 4  Iter 5
7B Model on GSM8K  35.2  48.7    55.3    59.1    61.8    63.2

ReST: Reinforced Self-Training

From Filtering to Reinforcement Learning

ReST treats the self-improvement loop as an offline reinforcement learning problem. Instead of binary filtering (correct/incorrect), ReST assigns a scalar reward to each generated rationale and alternates two phases: Grow, which samples solutions from the current policy, and Improve, which trains on those samples filtered and weighted by reward.
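
Before the full implementation, the objective itself fits in a few lines. A toy sketch in plain Python (the log-probabilities and rewards below are made-up values; in practice they come from the model and the reward function):

```python
# Reward-weighted negative log-likelihood over a toy batch of 3 samples.
logprobs = [-2.0, -1.0, -3.0]  # sequence log-probs under the current policy
rewards = [1.0, 0.5, 0.0]      # scalar reward per sampled solution

# High-reward samples contribute more to the gradient;
# zero-reward samples contribute nothing.
loss = -sum(r * lp for r, lp in zip(rewards, logprobs)) / len(logprobs)
print(round(loss, 4))  # 0.8333
```

Binary filtering (STaR) is the special case where every reward is exactly 0 or 1.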

class ReSTTrainer:
    """
    Reinforced Self-Training.

    Two phases:
    1. Grow: Generate samples from the current policy
    2. Improve: Train on samples weighted by reward
    """

    def __init__(self, model, tokenizer, reward_fn, lr=1e-5):
        self.model = model
        self.tokenizer = tokenizer
        self.reward_fn = reward_fn
        self.lr = lr

    def grow_phase(self, problems, num_samples=16, temperature=0.8):
        """
        Phase 1: Generate multiple solutions for each problem.
        Score each solution with the reward function.
        """
        dataset = []

        for problem in problems:
            prompt = f"Question: {problem.question}\nSolution:\n"
            inputs = self.tokenizer(prompt, return_tensors="pt").to(
                self.model.device
            )

            solutions = []
            for _ in range(num_samples):
                with torch.no_grad():
                    output = self.model.generate(
                        **inputs,
                        max_new_tokens=512,
                        temperature=temperature,
                        do_sample=True,
                        top_p=0.95,
                    )

                generated = self.tokenizer.decode(
                    output[0][inputs['input_ids'].shape[1]:],
                    skip_special_tokens=True,
                )

                # Score with reward function
                reward = self.reward_fn(
                    problem.question, generated, problem.correct_answer
                )

                solutions.append({
                    'question': problem.question,
                    'solution': generated,
                    'reward': reward,
                })

            dataset.extend(solutions)

        return dataset

    def improve_phase(self, dataset, threshold_percentile=70,
                      epochs=3, batch_size=8):
        """
        Phase 2: Train on high-reward solutions.

        Options:
        1. Binary: train on solutions above reward threshold
        2. Weighted: weight loss by reward value
        3. Contrastive: pair high-reward and low-reward solutions
        """
        # Compute reward threshold (e.g., 70th percentile)
        rewards = [d['reward'] for d in dataset]
        threshold = sorted(rewards)[int(len(rewards) * threshold_percentile / 100)]

        # Filter to high-reward solutions
        good_solutions = [d for d in dataset if d['reward'] >= threshold]

        print(f"  Grow phase: {len(dataset)} solutions")
        print(f"  Reward threshold (p{threshold_percentile}): {threshold:.3f}")
        print(f"  Training on {len(good_solutions)} solutions "
              f"({100 * len(good_solutions) / len(dataset):.0f}%)")

        # Train
        optimizer = self.optimizer_cls(self.model.parameters(), lr=self.lr)
        self.model.train()

        for epoch in range(epochs):
            random.shuffle(good_solutions)
            total_loss = 0.0
            n_batches = 0

            for i in range(0, len(good_solutions), batch_size):
                batch = good_solutions[i:i + batch_size]

                # Format as training texts
                texts = [
                    f"Question: {d['question']}\nSolution:\n{d['solution']}"
                    for d in batch
                ]
                encoded = self.tokenizer(
                    texts, return_tensors="pt", truncation=True,
                    max_length=1024, padding=True,
                )
                input_ids = encoded['input_ids'].to(self.model.device)
                attention_mask = encoded['attention_mask'].to(self.model.device)

                # Optional: weight loss by reward
                rewards_tensor = torch.tensor(
                    [d['reward'] for d in batch],
                    device=self.model.device,
                )

                # Mask padding so loss is computed on real tokens only
                labels = input_ids.clone()
                labels[attention_mask == 0] = -100

                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )

                # Reward-weighted loss. outputs.loss is already averaged
                # over the batch, so scaling by the batch-mean reward is a
                # coarse approximation of per-example weighting.
                loss = outputs.loss * rewards_tensor.mean()

                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()

                total_loss += loss.item()
                n_batches += 1

            print(f"  Epoch {epoch + 1}: loss = {total_loss / n_batches:.4f}")

    def run_iteration(self, problems, num_samples=16, epochs=3):
        """Run one Grow-Improve iteration."""
        print("Grow phase...")
        dataset = self.grow_phase(problems, num_samples)

        print("Improve phase...")
        self.improve_phase(dataset, epochs=epochs)

        return dataset

    # Class-level default optimizer, referenced by improve_phase above
    optimizer_cls = torch.optim.AdamW

STaR vs ReST Comparison

Aspect               STaR                                  ReST
Selection signal     Binary (correct/incorrect)            Scalar reward
Hint mechanism       Rationalization from hint             No hint (higher threshold)
Training objective   Cross-entropy on correct rationales   Reward-weighted cross-entropy
Samples per problem  8 typical                             16-64 typical
Best use case        Tasks with verifiable answers         Tasks with graded quality
Collapse risk        Moderate (hint distribution shift)    Higher (reward hacking)
Note: STaR is simpler and works well when answers are easily verifiable (math, code). ReST is more flexible but requires a reliable reward function.
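
The percentile cutoff that ReST's Improve phase applies can be traced with toy numbers (the reward list below is illustrative, not from any run):

```python
rewards = [0.2, 0.5, 0.9, 0.1, 0.7, 0.8, 0.3, 0.6, 0.4, 1.0]
threshold_percentile = 70

# Same index arithmetic as improve_phase: the value at the p70 position
threshold = sorted(rewards)[int(len(rewards) * threshold_percentile / 100)]
kept = [r for r in rewards if r >= threshold]
print(threshold, sorted(kept))  # 0.8 [0.8, 0.9, 1.0]
```

Only the top ~30% of samples survive, which is why ReST typically needs 16-64 samples per problem where STaR gets by with 8.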

The Data Flywheel

How Self-Improvement Creates Exponential Gains

class DataFlywheel:
    """
    Model the self-improvement dynamics.

    Each iteration:
    1. Better model generates more correct rationales
    2. More correct rationales = better training data
    3. Better training data = better model (next iteration)

    This compounds until hitting a ceiling.
    """

    def __init__(self, initial_accuracy=0.35, ceiling=0.85):
        self.accuracy = initial_accuracy
        self.ceiling = ceiling
        self.history = [initial_accuracy]

    def simulate_iteration(self, num_samples=8):
        """
        Simulate one STaR iteration.

        The improvement per iteration depends on:
        - Current accuracy (determines quality of generated data)
        - Number of samples (more samples = more likely to find correct ones)
        - Distance from ceiling (diminishing returns)
        """
        # Probability of at least one correct rationale in N samples
        p_at_least_one = 1 - (1 - self.accuracy) ** num_samples

        # Training data quality: fraction of problems with correct rationales
        data_quality = p_at_least_one

        # Improvement: proportional to data quality and distance from ceiling
        headroom = self.ceiling - self.accuracy
        improvement = data_quality * headroom * 0.2  # 20% of possible improvement

        self.accuracy = min(self.ceiling, self.accuracy + improvement)
        self.history.append(self.accuracy)

        return {
            'accuracy': self.accuracy,
            'p_at_least_one_correct': p_at_least_one,
            'data_quality': data_quality,
            'improvement': improvement,
        }

    def simulate_full(self, iterations=10, num_samples=8):
        """Simulate multiple iterations."""
        results = []
        for i in range(iterations):
            result = self.simulate_iteration(num_samples)
            results.append(result)
        return results
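
The recurrence in `simulate_iteration` can also be run standalone; a condensed sketch with the same constants (default ceiling, 0.2 improvement factor) shows the characteristic fast-then-flat curve:

```python
def simulate(acc=0.35, ceiling=0.85, num_samples=8, iterations=10):
    """Flywheel recurrence: improvement shrinks as accuracy nears the ceiling."""
    history = [acc]
    for _ in range(iterations):
        p_hit = 1 - (1 - acc) ** num_samples   # P(>= 1 correct sample)
        acc = min(ceiling, acc + p_hit * (ceiling - acc) * 0.2)
        history.append(acc)
    return history

hist = simulate()
# Accuracy rises every iteration but gains shrink, and the curve
# approaches the ceiling without ever crossing it.
```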

Why the Flywheel Stalls

The flywheel hits diminishing returns for three reasons:

FLYWHEEL_STALL_REASONS = {
    "distribution_narrowing": {
        "description": (
            "Each iteration, the model only trains on problems it can solve. "
            "Hard problems are never represented in training data. "
            "The model gets very good at easy problems but never improves "
            "on hard ones."
        ),
        "mitigation": "Hint rationalization (STaR) or curriculum sampling.",
    },
    "mode_collapse": {
        "description": (
            "The model converges to a single reasoning style. "
            "It generates the same chain-of-thought pattern for all problems, "
            "even when that pattern is suboptimal."
        ),
        "mitigation": "Temperature diversity, multiple starting points, "
                      "reward for diversity.",
    },
    "error_amplification": {
        "description": (
            "Correct answers can have incorrect reasoning. "
            "The model learns to produce answers that happen to be right "
            "for wrong reasons. These errors compound across iterations."
        ),
        "mitigation": "Process reward models (score individual steps, "
                      "not just final answers).",
    },
}
⚠️ Warning

The most insidious failure mode is error amplification with correct answers. On GSM8K, approximately 15-20% of “correct” rationales contain arithmetic errors that cancel out or reasoning steps that skip critical logic. Training on these rationales teaches the model to produce plausible-looking but logically flawed reasoning. After 5+ iterations, the model’s reasoning becomes increasingly superficial even as its accuracy plateaus.

Model Collapse: Detection and Prevention

What Collapse Looks Like

class CollapseDetector:
    """
    Detect early signs of model collapse in self-improvement loops.
    """

    def __init__(self):
        self.iteration_metrics = []

    def measure_diversity(self, generated_rationales):
        """
        Measure diversity of generated rationales.
        Collapse = all rationales look the same.
        """
        import numpy as np

        # Metric 1: Unique n-gram ratio
        all_trigrams = []
        for rationale in generated_rationales:
            words = rationale.lower().split()
            trigrams = [tuple(words[i:i + 3]) for i in range(len(words) - 2)]
            all_trigrams.extend(trigrams)

        unique_ratio = len(set(all_trigrams)) / max(len(all_trigrams), 1)

        # Metric 2: Average pairwise similarity (lower = more diverse)
        from difflib import SequenceMatcher
        similarities = []
        sample_size = min(50, len(generated_rationales))
        sample = random.sample(generated_rationales, sample_size)
        for i in range(len(sample)):
            for j in range(i + 1, len(sample)):
                sim = SequenceMatcher(None, sample[i], sample[j]).ratio()
                similarities.append(sim)

        avg_similarity = np.mean(similarities) if similarities else 0.0

        # Metric 3: Vocabulary size relative to text length
        all_words = ' '.join(generated_rationales).split()
        vocab_ratio = len(set(all_words)) / max(len(all_words), 1)

        return {
            'unique_trigram_ratio': unique_ratio,
            'avg_pairwise_similarity': avg_similarity,
            'vocab_ratio': vocab_ratio,
        }

    def measure_difficulty_coverage(self, problems, success_mask):
        """
        Check if the model is solving diverse difficulty levels.
        Collapse = only solving easy problems.
        """
        difficulty_success = {}
        for problem, success in zip(problems, success_mask):
            d = problem.difficulty
            if d not in difficulty_success:
                difficulty_success[d] = {'correct': 0, 'total': 0}
            difficulty_success[d]['total'] += 1
            if success:
                difficulty_success[d]['correct'] += 1

        for d in difficulty_success:
            stats = difficulty_success[d]
            stats['accuracy'] = stats['correct'] / max(stats['total'], 1)

        return difficulty_success

    def check_collapse(self, iteration_num, diversity_metrics,
                       difficulty_coverage, accuracy):
        """
        Determine if collapse is occurring.
        """
        alerts = []

        # Check diversity
        if diversity_metrics['unique_trigram_ratio'] < 0.3:
            alerts.append("LOW_DIVERSITY: unique trigram ratio below 0.3")
        if diversity_metrics['avg_pairwise_similarity'] > 0.7:
            alerts.append("HIGH_SIMILARITY: average pairwise similarity above 0.7")

        # Check difficulty coverage
        hard_accuracy = difficulty_coverage.get('hard', {}).get('accuracy', 0)
        easy_accuracy = difficulty_coverage.get('easy', {}).get('accuracy', 0)
        if hard_accuracy < 0.1 and easy_accuracy > 0.8:
            alerts.append("DIFFICULTY_COLLAPSE: solving easy but not hard problems")

        # Check accuracy plateau
        if len(self.iteration_metrics) >= 3:
            recent = [m['accuracy'] for m in self.iteration_metrics[-3:]]
            if max(recent) - min(recent) < 0.005:
                alerts.append("ACCURACY_PLATEAU: no improvement in 3 iterations")

        self.iteration_metrics.append({
            'iteration': iteration_num,
            'accuracy': accuracy,
            'diversity': diversity_metrics,
        })

        return alerts
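
Of the three diversity metrics, the unique-trigram ratio is the cheapest early-warning signal. A standalone sketch on synthetic rationales (the strings are made up; the 0.3 threshold mirrors `check_collapse`):

```python
def unique_trigram_ratio(texts):
    """Fraction of word trigrams that are unique across all texts."""
    trigrams = []
    for t in texts:
        w = t.lower().split()
        trigrams.extend(tuple(w[i:i + 3]) for i in range(len(w) - 2))
    return len(set(trigrams)) / max(len(trigrams), 1)

# A collapsed model repeats one template; a healthy one varies its phrasing
collapsed = ["step one add step two done"] * 20
diverse = [f"first compute {i} then subtract {i * 2} and report {i + 3}"
           for i in range(20)]

print(unique_trigram_ratio(collapsed) < 0.3)  # True: trips the LOW_DIVERSITY alert
print(unique_trigram_ratio(diverse) > 0.3)    # True: passes
```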

Mitigation Strategies

class CollapseMitigation:
    """
    Strategies to prevent model collapse in self-improvement loops.
    """

    @staticmethod
    def data_mixing(synthetic_data, original_data, synthetic_fraction=0.5):
        """
        Mix synthetic data with original training data.
        Prevents distribution drift by anchoring to real data.
        """
        num_synthetic = int(len(original_data) * synthetic_fraction /
                           (1 - synthetic_fraction))
        num_synthetic = min(num_synthetic, len(synthetic_data))
        sampled_synthetic = random.sample(synthetic_data, num_synthetic)
        mixed = original_data + sampled_synthetic
        random.shuffle(mixed)
        return mixed

    @staticmethod
    def temperature_curriculum(iteration, max_iterations=10):
        """
        Increase temperature over iterations to maintain diversity.
        Early iterations: low temperature for quality.
        Later iterations: higher temperature to explore.
        """
        base_temp = 0.5
        max_temp = 1.2
        progress = iteration / max_iterations
        return base_temp + (max_temp - base_temp) * progress

    @staticmethod
    def process_reward_filtering(rationales, process_reward_model):
        """
        Filter by reasoning process quality, not just final answer.
        Prevents training on correct-answer-wrong-reasoning examples.
        """
        filtered = []
        for r in rationales:
            if not r.correct:
                continue

            # Score each reasoning step
            steps = r.chain_of_thought.split('\n')
            step_scores = []
            for i, step in enumerate(steps):
                context = '\n'.join(steps[:i + 1])
                score = process_reward_model.score_step(
                    r.question, context, step
                )
                step_scores.append(score)

            # Keep only rationales where ALL steps score above threshold
            min_step_score = min(step_scores) if step_scores else 0
            if min_step_score >= 0.5:
                filtered.append(r)

        return filtered

    @staticmethod
    def difficulty_reweighting(problems, current_accuracy_by_difficulty):
        """
        Oversample hard problems to prevent difficulty collapse.
        """
        reweighted = []
        for problem in problems:
            difficulty = problem.difficulty
            current_acc = current_accuracy_by_difficulty.get(difficulty, 0.5)
            # Higher weight for lower accuracy (harder problems)
            weight = max(0.1, 1.0 - current_acc)
            # Keep each problem with probability equal to its weight
            # (downsamples easy problems rather than truly oversampling)
            if random.random() < weight:
                reweighted.append(problem)

        return reweighted
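
The temperature curriculum is a pure function of the iteration index; tracing it with the same constants as above:

```python
def temperature_curriculum(iteration, max_iterations=10,
                           base_temp=0.5, max_temp=1.2):
    """Linear ramp from base_temp to max_temp across iterations."""
    progress = iteration / max_iterations
    return base_temp + (max_temp - base_temp) * progress

temps = [temperature_curriculum(i) for i in (0, 5, 10)]
print([round(t, 2) for t in temps])  # [0.5, 0.85, 1.2]
```

Early iterations sample conservatively for data quality; later ones sample hotter to fight mode collapse.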

Effect of Mitigation Strategies on GSM8K Accuracy

Strategy                 Base  Iter 1  Iter 2  Iter 3  Iter 4  Iter 5  Iter 8  Iter 10
No Mitigation            35.2  48.7    55.3    57.1    57.5    57.3    56.8    55.2
Data Mixing (50%)        35.2  47.5    54.1    58.2    61.0    62.8    64.5    65.1
Process Reward + Mixing  35.2  46.8    53.7    58.9    62.5    65.3    68.1    69.5

Reward Function Design for Self-Training

Types of Reward Signals

class RewardFunctions:
    """
    Different reward functions for self-training.
    The reward function determines what the model optimizes for.
    """

    @staticmethod
    def binary_correctness(question, solution, correct_answer):
        """
        Simplest reward: 1 if correct, 0 if wrong.
        Used by basic STaR.
        """
        predicted = extract_final_answer(solution)
        return 1.0 if answers_match(predicted, correct_answer) else 0.0

    @staticmethod
    def soft_correctness(question, solution, correct_answer):
        """
        Soft reward based on partial correctness.
        Useful for multi-step problems.
        """
        predicted = extract_final_answer(solution)

        if answers_match(predicted, correct_answer):
            return 1.0

        # Partial credit for being close
        try:
            pred_val = float(predicted)
            correct_val = float(correct_answer)
            relative_error = abs(pred_val - correct_val) / max(abs(correct_val), 1e-8)
            return max(0, 1.0 - relative_error)
        except (ValueError, TypeError):
            return 0.0

    @staticmethod
    def process_reward(question, solution, correct_answer, prm_model):
        """
        Process Reward Model: score the reasoning process.
        Rewards correct reasoning steps, not just correct answers.
        """
        # Score each step
        steps = solution.split('\n')
        step_scores = []
        for i, step in enumerate(steps):
            context = '\n'.join(steps[:i + 1])
            score = prm_model.score_step(question, context, step)
            step_scores.append(score)

        if not step_scores:
            return 0.0

        # Combine: weighted sum of answer correctness and the weakest
        # step's score (a chain is only as good as its worst step)
        correctness = 1.0 if answers_match(
            extract_final_answer(solution), correct_answer
        ) else 0.0

        process_quality = min(step_scores)
        return 0.7 * correctness + 0.3 * process_quality

def extract_final_answer(solution):
    """Extract the final numerical or text answer from a solution."""
    import re

    lines = solution.strip().split('\n')
    for line in reversed(lines):
        line = line.strip()
        if line and any(c.isdigit() for c in line):
            # Take the last number on the last line that contains one
            numbers = re.findall(r'-?\d+\.?\d*', line)
            if numbers:
                return numbers[-1]
    return lines[-1] if solution.strip() else ""

def answers_match(predicted, correct):
    """Check if two answers match (handling numeric comparison)."""
    pred = str(predicted).strip().lower()
    corr = str(correct).strip().lower()

    if pred == corr:
        return True

    try:
        return abs(float(pred) - float(corr)) < 1e-6
    except (ValueError, TypeError):
        return False
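The matcher above compares raw strings and floats; model outputs on GSM8K-style problems often carry formatting noise such as thousands separators, dollar signs, or a trailing period. A hedged extension sketch — `normalize_answer` and `answers_match_normalized` are hypothetical helpers, and the stripping rules are illustrative, not exhaustive:

```python
def normalize_answer(ans):
    """Strip common formatting noise ('$1,000.' -> '1000') before comparison."""
    s = str(ans).strip().lower()
    s = s.replace(",", "").replace("$", "").replace("%", "")
    return s.rstrip(".")

def answers_match_normalized(predicted, correct, tol=1e-6):
    """Like answers_match, but normalizes formatting first."""
    pred, corr = normalize_answer(predicted), normalize_answer(correct)
    if pred == corr:
        return True
    try:
        return abs(float(pred) - float(corr)) < tol
    except ValueError:
        return False
```

This kind of normalization matters in practice: a self-training loop that rejects "$1,000." against a gold answer of "1000" silently discards correct rationales.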

Production Considerations

Compute Budget

COMPUTE_BUDGET = {
    "star_7b_gsm8k": {
        "model_size": "7B",
        "dataset": "GSM8K (8.5K problems)",
        "samples_per_problem": 8,
        "iterations": 5,
        "generation_cost_per_iter": "~2 H100-hours (68K generations)",
        "training_cost_per_iter": "~0.5 H100-hours (3 epochs on 6K examples)",
        "total_cost": "~12.5 H100-hours = $37.50",
        "accuracy_improvement": "35% -> 63%",
    },
    "rest_70b_math": {
        "model_size": "70B",
        "dataset": "MATH (12.5K problems)",
        "samples_per_problem": 64,
        "iterations": 3,
        "generation_cost_per_iter": "~160 H100-hours (800K generations)",
        "training_cost_per_iter": "~20 H100-hours (3 epochs on 50K examples)",
        "total_cost": "~540 H100-hours = $1,620",
        "accuracy_improvement": "42% -> 58%",
    },
}
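Both `total_cost` entries decompose the same way: iterations × (generation + training) H100-hours, priced at the roughly $3 per H100-hour rate the dollar figures imply. A quick arithmetic check (the hourly rate is inferred from the budget entries, not stated in them):

```python
H100_RATE_USD = 3.00  # implied by the budget entries above

def total_cost(iterations, gen_hours_per_iter, train_hours_per_iter,
               rate=H100_RATE_USD):
    """Total H100-hours and dollar cost for a self-training run."""
    hours = iterations * (gen_hours_per_iter + train_hours_per_iter)
    return hours, hours * rate

# star_7b_gsm8k: 5 iterations x (2 + 0.5) H100-hours = 12.5 hours, $37.50
assert total_cost(5, 2, 0.5) == (12.5, 37.50)
# rest_70b_math: 3 iterations x (160 + 20) H100-hours = 540 hours, $1,620
assert total_cost(3, 160, 20) == (540, 1620.0)
```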

Self-Training Compute Efficiency

Method        Model  Dataset  GPU-Hours  Accuracy Gain  Gain/GPU-Hour
STaR          7B     GSM8K    12.5       +28%           2.24%/hr
ReST          7B     GSM8K    18         +25%           1.39%/hr
STaR          70B    GSM8K    200        +18%           0.09%/hr
ReST          70B    MATH     540        +16%           0.03%/hr
SFT baseline  7B     GSM8K    5          +20%           4.00%/hr
Note: STaR is most efficient on smaller models where generation is cheap. For larger models, the generation cost dominates. SFT on human-written solutions is more efficient per GPU-hour but requires expensive human annotation.
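The Gain/GPU-Hour column is just the accuracy gain divided by GPU-hours; a quick reproduction from the table's own numbers:

```python
# (method, model, gpu_hours, accuracy_gain_pct) copied from the table above
runs = [
    ("STaR", "7B", 12.5, 28),
    ("ReST", "7B", 18, 25),
    ("STaR", "70B", 200, 18),
    ("ReST", "70B", 540, 16),  # MATH, not GSM8K
    ("SFT baseline", "7B", 5, 20),
]

def gain_per_gpu_hour(gpu_hours, gain_pct):
    return gain_pct / gpu_hours

for method, model, hours, gain in runs:
    print(f"{method} {model}: {gain_per_gpu_hour(hours, gain):.2f}%/hr")
# STaR 7B: 2.24%/hr ... SFT baseline 7B: 4.00%/hr
```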

Self-improving systems are not a replacement for high-quality human data. They are a multiplier: starting from a small amount of human-verified reasoning data, STaR and ReST can bootstrap 2-3x the model’s initial reasoning capability at low compute cost. The key constraint is collapse prevention. Without data mixing, process reward filtering, and diversity maintenance, the self-improvement loop degrades after 3-5 iterations. With these mitigations, the loop sustains improvement for 8-10 iterations before hitting the ceiling imposed by the base model’s representation capacity.
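One practical way to act on those iteration limits is to monitor held-out accuracy across iterations and halt the loop once it stops improving. A minimal early-stopping sketch — the function name and the `patience` and `min_delta` defaults are illustrative choices, not values from the experiments:

```python
def should_stop_self_training(accuracy_history, patience=2, min_delta=0.001):
    """Stop the STaR/ReST loop once held-out accuracy has failed to
    improve for `patience` consecutive iterations (a collapse signal)."""
    if len(accuracy_history) <= patience:
        return False
    best_before = max(accuracy_history[:-patience])
    recent = accuracy_history[-patience:]
    return all(acc < best_before + min_delta for acc in recent)
```

Run against the GSM8K trajectories above, this check fires on the no-mitigation run (accuracy peaks, then declines) and stays quiet on the mitigated run, which is still improving at iteration 10.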