A reasoning model's quality is bounded by its reward model. If the reward model cannot distinguish good reasoning from bad, no amount of RL training or test-time search will produce better answers. This post covers the two reward model paradigms (outcome-based and process-based), their architectures, training data requirements, and implementation.
Outcome Reward Models (ORM)
An ORM scores the final answer only. Given a problem x and a candidate solution y (including reasoning trace), the ORM produces a single scalar r = ORM(x, y) in [0, 1], interpreted as the probability that the final answer is correct.
The ORM is trained on (problem, solution, correctness_label) triples. It learns to predict whether the final answer is correct, regardless of the reasoning path.
```python
import torch
import torch.nn as nn

class OutcomeRewardModel(nn.Module):
    """ORM: scores the final answer, ignoring intermediate reasoning."""

    def __init__(self, base_model, hidden_dim=4096):
        super().__init__()
        self.base_model = base_model  # Pretrained LLM backbone
        self.reward_head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask):
        # Get the hidden state at the final token position
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        last_hidden = outputs.last_hidden_state[:, -1, :]  # [B, hidden_dim]
        reward = torch.sigmoid(self.reward_head(last_hidden))  # [B, 1]
        return reward.squeeze(-1)
```
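One subtlety: the `[:, -1, :]` indexing above assumes left-padded batches. With right-padding, position -1 is a pad token, and you want the hidden state at each sequence's last *real* token instead. A minimal mask-aware sketch (the helper name is mine, not a library function):

```python
import torch

def last_token_hidden(hidden_states, attention_mask):
    # hidden_states: [B, T, H]; attention_mask: [B, T] with 1 = real token
    last_idx = attention_mask.sum(dim=1) - 1        # index of last real token, [B]
    batch = torch.arange(hidden_states.size(0))
    return hidden_states[batch, last_idx]           # [B, H]

h = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)  # [B=2, T=3, H=4]
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])                 # sequence 0 has length 2
out = last_token_hidden(h, mask)
print(out.shape)  # torch.Size([2, 4])
```

For sequence 0, this selects the hidden state at position 1 rather than the padded position 2.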
Training data: For math, generate many solutions per problem, check correctness against ground truth. Label: 1 if final answer matches, 0 otherwise. For code, run tests. For general tasks, use human annotations.
Limitation: ORM cannot distinguish a correct answer reached by flawed reasoning from one reached by sound reasoning. A model that memorizes "the answer to this type of problem is always 7" gets the same reward as one that derives the answer step by step.
Process Reward Models (PRM)
A PRM scores each step of the reasoning. Given a problem x and reasoning steps s_1, ..., s_n, the PRM produces a score for each prefix: r_i = PRM(x, s_1, ..., s_i).
Each step gets a score between 0 and 1. The overall score is the product r = r_1 * r_2 * ... * r_n: if any step is wrong (score near 0), the total score collapses.
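As a tiny illustration of the product aggregation (with `min` as a common, slightly more forgiving alternative; the function name is mine):

```python
import math

def aggregate_prm(step_scores, method="product"):
    """Collapse per-step PRM scores into one solution-level score."""
    if method == "product":
        return math.prod(step_scores)  # one bad step tanks the total
    return min(step_scores)            # score of the weakest step

scores = [0.95, 0.9, 0.85, 0.1]  # one bad step near the end
print(round(aggregate_prm(scores), 4))      # 0.0727
print(aggregate_prm(scores, method="min"))  # 0.1
```

Note how a single 0.1 step drags the product below 0.08 even though every other step scored 0.85 or higher.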
```python
class ProcessRewardModel(nn.Module):
    """PRM: scores each reasoning step independently."""

    def __init__(self, base_model, hidden_dim=4096):
        super().__init__()
        self.base_model = base_model
        self.step_head = nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask, step_boundaries):
        """
        step_boundaries: list of token indices where each step ends
        Returns: per-step reward scores
        """
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [B, seq_len, hidden_dim]
        step_scores = []
        for boundary in step_boundaries:
            step_hidden = hidden_states[:, boundary, :]  # [B, hidden_dim]
            score = torch.sigmoid(self.step_head(step_hidden))
            step_scores.append(score.squeeze(-1))
        return torch.stack(step_scores, dim=-1)  # [B, num_steps]
```
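One practical question is where `step_boundaries` come from. If you know each step's token count after tokenization, the last-token index per step follows from a running sum. A sketch (the helper name is illustrative):

```python
def step_boundaries_from_lengths(step_token_lengths, prompt_length):
    """Index of each step's final token, given per-step token counts."""
    boundaries, pos = [], prompt_length
    for n in step_token_lengths:
        pos += n
        boundaries.append(pos - 1)  # last token of this step
    return boundaries

# Prompt of 10 tokens, then steps of 7, 5, and 9 tokens:
print(step_boundaries_from_lengths([7, 5, 9], prompt_length=10))  # [16, 21, 30]
```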
Training data: Much harder to obtain. Requires step-level annotations: "step 3 introduced an error." Options: (1) human annotators mark each step (expensive; OpenAI's PRM800K dataset), or (2) automated labeling: check whether the correct answer is still reachable after each step via Monte Carlo sampling, i.e., generate many completions from that step and measure the fraction that reach the correct answer.
ORM vs PRM Comparison
| Property | ORM | PRM |
|---|---|---|
| Scores | Final answer only | Each reasoning step |
| Training data | Easy (auto-check answers) | Hard (step-level labels) |
| Best-of-N quality (N=8) | +8% on MATH | +12% on MATH |
| Reward hacking risk | High (correct answer, wrong reason) | Low (each step checked) |
| Compute at inference | 1 forward pass per solution | 1 forward pass, with scores read at each step boundary |
| Used by | DeepSeek-R1 (rule-based outcome rewards) | OpenAI's "Let's Verify Step by Step" (PRM800K) |
On MATH-500 with best-of-8 selection: ORM improves accuracy from 71% to 79% (+8 points); PRM improves it from 71% to 83% (+12 points). The 4-point gap between ORM and PRM is the value of step-level feedback. For high-stakes applications (math competitions, code generation), the extra labeling cost of a PRM is usually worth paying.
Reward-Guided Search
At inference time, the reward model guides which reasoning paths to explore:
Best-of-N
Generate N complete solutions. Score each with the reward model. Return the highest-scoring one.
```python
def best_of_n(model, reward_model, prompt, n=8, max_tokens=4096):
    """Generate N solutions, return the one with highest reward."""
    candidates = []
    for _ in range(n):
        solution = model.generate(prompt, max_tokens=max_tokens,
                                  temperature=0.7, top_p=0.95)
        score = reward_model.score(prompt + solution)
        candidates.append((score, solution))
    candidates.sort(key=lambda x: x[0], reverse=True)
    return candidates[0][1]  # Highest-scoring solution
```
Cost: N times the generation cost of a single solution. Parallelizable across GPUs.
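Since the N samples are independent, best-of-N parallelizes trivially. A sketch using a thread pool, assuming the same `model`/`reward_model` interface as `best_of_n` above (appropriate when generation calls go to a serving backend; for local GPU inference you would batch instead):

```python
from concurrent.futures import ThreadPoolExecutor

def best_of_n_parallel(model, reward_model, prompt, n=8, workers=8):
    """Best-of-N with concurrent sampling across worker threads."""
    def sample(_):
        solution = model.generate(prompt, max_tokens=4096,
                                  temperature=0.7, top_p=0.95)
        return reward_model.score(prompt + solution), solution

    with ThreadPoolExecutor(max_workers=workers) as pool:
        candidates = list(pool.map(sample, range(n)))
    return max(candidates, key=lambda c: c[0])[1]
```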
Beam Search with PRM
Use the PRM to prune bad reasoning paths early:
```python
def prm_beam_search(model, prm, prompt, beam_width=4, max_steps=20):
    """Beam search guided by a process reward model."""
    beams = [(1.0, prompt)]  # (cumulative_score, text_so_far)
    for step in range(max_steps):
        candidates = []
        for score, text in beams:
            # Sample beam_width candidate next steps; temperature > 0
            # so the samples differ. A step ends at a blank line.
            for _ in range(beam_width):
                next_step = model.generate(text, max_tokens=200,
                                           temperature=0.8, stop=["\n\n"])
                step_score = prm.score_step(text + next_step)
                candidates.append((score * step_score, text + next_step))
        # Keep the top beam_width candidates
        candidates.sort(key=lambda x: x[0], reverse=True)
        beams = candidates[:beam_width]
        # Return as soon as any beam reaches a final answer
        for score, text in beams:
            if "Final Answer:" in text:
                return text
    return beams[0][1]  # Best surviving beam
```
[Figure: accuracy vs compute for different search strategies on MATH-500, in % accuracy.]
PRM beam search achieves the highest accuracy because it prunes bad reasoning paths early, directing compute toward promising directions rather than generating complete solutions that may be fatally flawed from step 3.
Practical Reward Model Training
Step 1: Generate Training Data
```python
def generate_orm_data(model, problems, solutions_per_problem=16):
    """Generate ORM training data: (problem, solution, label)."""
    dataset = []
    for problem in problems:
        for _ in range(solutions_per_problem):
            solution = model.generate(problem.text, temperature=1.0)
            answer = extract_answer(solution)
            correct = check_answer(answer, problem.ground_truth)
            dataset.append({
                "problem": problem.text,
                "solution": solution,
                "label": 1.0 if correct else 0.0,
            })
    return dataset
```
Step 2: Train with Binary Cross-Entropy
```python
def train_orm(reward_model, dataset, epochs=3, lr=1e-5):
    optimizer = torch.optim.AdamW(reward_model.parameters(), lr=lr)
    loss_fn = nn.BCELoss()
    for epoch in range(epochs):
        for batch in dataloader(dataset, batch_size=32):
            scores = reward_model(batch.input_ids, batch.attention_mask)
            loss = loss_fn(scores, batch.labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
```
PRM Training with Monte Carlo Labels
For PRM, we need step-level labels. The Monte Carlo method:
```python
def generate_prm_labels(model, problem, num_rollouts=32):
    """Label each step by the fraction of completions from that step that succeed."""
    full_solution = model.generate(problem, temperature=0.0)
    steps = split_into_steps(full_solution)
    step_labels = []
    for i, step in enumerate(steps):
        partial = problem + "".join(steps[:i + 1])
        # Complete from this point multiple times
        successes = 0
        for _ in range(num_rollouts):
            completion = model.generate(partial, temperature=0.7)
            if check_answer(extract_answer(partial + completion),
                            problem.ground_truth):
                successes += 1
        step_labels.append(successes / num_rollouts)
    return step_labels  # e.g., [0.95, 0.88, 0.72, 0.05, 0.02]
                        # Step 4 introduced an error (success rate drops from 72% to 5%)
```
For each step of each solution, you generate 32 completions. A 10-step solution needs 320 completions. Across 10K problems x 16 solutions each: 51.2 million generations, roughly 25.6 billion tokens at ~500 tokens per generation. Depending on model size and serving throughput, that is on the order of hundreds to thousands of H100-hours. This is feasible but far from trivial. Frontier labs amortize this cost because the PRM dramatically improves downstream model quality.
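The budget arithmetic above can be sketched as a quick calculator (the parameter values are the post's assumptions, and the function name is mine):

```python
def prm_label_budget(problems, solutions_per_problem, steps_per_solution,
                     rollouts_per_step, tokens_per_rollout):
    """Total generations and tokens for Monte Carlo PRM labeling."""
    generations = (problems * solutions_per_problem
                   * steps_per_solution * rollouts_per_step)
    return generations, generations * tokens_per_rollout

gens, tokens = prm_label_budget(10_000, 16, 10, 32, 500)
print(f"{gens:,} generations, {tokens / 1e9:.1f}B tokens")
# 51,200,000 generations, 25.6B tokens
```

Dividing the token count by your serving throughput (tokens/sec/GPU) gives the GPU-hour estimate for your setup.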
Exercise: Finding the First Error Step
Challenge: Using only this post, implement a function that takes a list of reasoning step scores from a PRM and determines which step (if any) introduced the first critical error (score drop of more than 50%).
Expected:
```python
def find_error_step(step_scores):
    """Find the first step where quality drops dramatically."""
    for i in range(1, len(step_scores)):
        if step_scores[i] < step_scores[i - 1] * 0.5:
            return i  # This step introduced the error
    return -1  # No critical error found

# Example: [0.95, 0.88, 0.72, 0.05, 0.02]
# Returns: 3 (step_scores[3] = 0.05 < step_scores[2] * 0.5 = 0.36)
```