Why GRPO Over PPO
PPO (Proximal Policy Optimization) requires four models in memory simultaneously:
- Policy model (the model being trained): 140 GB for 70B FP16
- Reference model (frozen copy of the initial policy): 140 GB
- Reward model (scores outputs): 140 GB
- Value model (critic, predicts expected return): 140 GB
Total: 560 GB for a 70B model. Requires 8 H100 GPUs just for model weights.
GRPO eliminates the value model entirely by using within-group comparisons as the baseline:
- Policy model: 140 GB
- Reference model: 140 GB (can be quantized to 35 GB)
- Reward model: 140 GB (can be external API)
Total: 280-315 GB, nearly half the memory of PPO.
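The arithmetic behind those totals, as a quick sanity check (FP16 = 2 bytes per parameter; dividing the reference model by 4 for 4-bit quantization is where the 35 GB figure comes from):

```python
def fp16_gb(params_billions):
    """Memory for FP16 weights: 2 bytes per parameter, reported in GB."""
    return params_billions * 1e9 * 2 / 1e9

policy = fp16_gb(70)       # 140 GB
reference = fp16_gb(70)    # 140 GB
reward = fp16_gb(70)       # 140 GB
value = fp16_gb(70)        # 140 GB (PPO only)

ppo_total = policy + reference + reward + value  # 560 GB
grpo_min = policy + reference                    # 280 GB (reward via external API)
grpo_max = policy + reference / 4 + reward       # 315 GB (4-bit reference, local reward)
```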
The GRPO Algorithm
For a prompt $q$, generate $K$ completions $\{o_1, \dots, o_K\}$ from the current policy $\pi_{\theta_{\text{old}}}$. Compute rewards $\{r_1, \dots, r_K\}$. The GRPO loss:

$$\mathcal{L}_{\text{GRPO}}(\theta) = -\frac{1}{K} \sum_{i=1}^{K} \left[ \min\left( \rho_i A_i,\ \operatorname{clip}(\rho_i,\, 1-\epsilon,\, 1+\epsilon)\, A_i \right) - \beta\, D_{\text{KL}}\!\left( \pi_\theta \,\|\, \pi_{\text{ref}} \right) \right]$$

where $\rho_i = \frac{\pi_\theta(o_i \mid q)}{\pi_{\theta_{\text{old}}}(o_i \mid q)}$ is the importance ratio, $A_i$ is the group-relative advantage, and $\beta$ is the KL penalty coefficient.
Step 1: Group Sampling
For each prompt, generate $K$ completions (typically $K = 8$ to $16$):
```python
def group_sample(model, prompt, K=8, max_tokens=2048, temperature=1.0):
    """Generate K completions for one prompt."""
    completions = []
    log_probs = []
    for _ in range(K):
        tokens, lp = model.generate_with_logprobs(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
        )
        completions.append(tokens)
        log_probs.append(lp)  # sum of log P(token_t | tokens_0..t-1)
    return completions, log_probs
```
Step 2: Compute Rewards
Score each completion with the reward model:
```python
import torch

def compute_rewards(reward_model, prompt, completions):
    """Score each completion. Higher = better."""
    rewards = []
    for completion in completions:
        # ORM: score based on final answer correctness
        # PRM: score based on step-by-step reasoning quality
        r = reward_model.score(prompt, completion)
        rewards.append(r)
    return torch.tensor(rewards)
```
Step 3: Group-Relative Advantage
The key GRPO innovation: no value model needed. The advantage is computed relative to the group:
```python
def compute_group_advantage(rewards):
    """
    Compute advantage relative to the group mean/std.
    This replaces the PPO critic (value model).
    rewards: [K] tensor, reward for each completion in the group
    Returns: [K] tensor of normalized advantages
    """
    mean = rewards.mean()
    std = rewards.std()
    if std < 1e-8:
        # All rewards equal -> no learning signal for this group
        return torch.zeros_like(rewards)
    advantages = (rewards - mean) / (std + 1e-8)
    return advantages
```
In PPO, the advantage is $A_t = r_t - V(s_t)$, where $V(s_t)$ is the value model's prediction. In GRPO, the advantage is $A_i = \frac{r_i - \operatorname{mean}(r_{1..K})}{\operatorname{std}(r_{1..K})}$. The group mean replaces the value model: it is an unbiased estimator of the expected reward for the prompt. The group std normalizes the scale. This works because with $K$ in the 8-16 range, the group statistics are reliable enough for stable training.
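A concrete example in pure Python (`statistics.stdev` uses the same unbiased sample std as `torch.std`'s default): suppose 5 of 8 completions in a group are correct (reward 1) and 3 are wrong (reward 0).

```python
from statistics import mean, stdev

rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0]  # 5 correct, 3 wrong
mu, sigma = mean(rewards), stdev(rewards)
advantages = [(r - mu) / sigma for r in rewards]

# Correct completions get a positive advantage, wrong ones a negative one,
# and the advantages average to zero: the group mean acts as the baseline.
```

Note the degenerate case: if all 8 completions are correct (or all wrong), the std is zero and the group contributes no gradient, which is why `compute_group_advantage` returns zeros there.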
Step 4: Policy Gradient with Clipping
The clipped surrogate loss prevents the policy from changing too much in one update:
```python
def grpo_loss(
    policy_model,
    ref_model,
    prompts,
    completions,
    advantages,
    old_log_probs,
    clip_eps=0.2,
    kl_coeff=0.01,
):
    """
    Compute GRPO loss for a batch of prompt-completion pairs.
    policy_model: current policy being optimized
    ref_model: frozen reference (initial SFT checkpoint)
    prompts: list of prompt token sequences
    completions: list of completion token sequences
    advantages: [batch] normalized advantages
    old_log_probs: [batch] log probs under the policy at sampling time
    """
    total_loss = torch.tensor(0.0, device="cuda")
    for i in range(len(prompts)):
        # Current log probability of the completion under the policy
        input_ids = torch.cat([prompts[i], completions[i]])
        current_log_prob = policy_model.compute_log_prob(
            input_ids, completion_start=len(prompts[i])
        )
        # Reference log probability (for the KL penalty)
        with torch.no_grad():
            ref_log_prob = ref_model.compute_log_prob(
                input_ids, completion_start=len(prompts[i])
            )
        # Importance ratio
        ratio = torch.exp(current_log_prob - old_log_probs[i])
        # Clipped surrogate
        clipped_ratio = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
        surrogate = torch.min(ratio * advantages[i], clipped_ratio * advantages[i])
        # KL divergence penalty (keeps the policy close to the reference)
        kl = current_log_prob - ref_log_prob
        # Negative because we maximize the surrogate
        total_loss += -surrogate + kl_coeff * kl
    return total_loss / len(prompts)
```
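To see what the clipping does in isolation, here is a scalar version of the `torch.clamp`/`torch.min` lines above, in plain Python:

```python
def clipped_surrogate(ratio, advantage, clip_eps=0.2):
    """min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A), as in PPO/GRPO."""
    clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    return min(ratio * advantage, clipped * advantage)

# For a positive advantage, the gain is flat beyond ratio = 1.2: there is no
# incentive to push the policy further in a single update. For a negative
# advantage, the penalty at ratio > 1.2 is NOT clipped, so moves that
# increase the probability of bad completions are still fully discouraged.
```

The asymmetry is the point: `min` makes the objective pessimistic, so clipping only ever removes incentive, never removes penalty.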
Step 5: Complete Training Loop
```python
def train_grpo(
    policy_model,
    ref_model,
    reward_model,
    train_prompts,
    K=8,
    num_epochs=3,
    batch_size=4,
    lr=1e-6,
    clip_eps=0.2,
    kl_coeff=0.01,
):
    """Complete GRPO training loop."""
    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=lr)
    for epoch in range(num_epochs):
        for batch_start in range(0, len(train_prompts), batch_size):
            batch_prompts = train_prompts[batch_start:batch_start + batch_size]
            all_completions = []
            all_advantages = []
            all_old_log_probs = []
            all_prompts_expanded = []
            all_rewards = []
            for prompt in batch_prompts:
                # Step 1: Generate K completions
                completions, log_probs = group_sample(
                    policy_model, prompt, K=K
                )
                # Step 2: Compute rewards
                rewards = compute_rewards(reward_model, prompt, completions)
                all_rewards.append(rewards)
                # Step 3: Compute group-relative advantages
                advantages = compute_group_advantage(rewards)
                # Store for training
                for j in range(K):
                    all_prompts_expanded.append(prompt)
                    all_completions.append(completions[j])
                    all_advantages.append(advantages[j])
                    all_old_log_probs.append(log_probs[j])
            # Step 4: Policy gradient update
            advantages_tensor = torch.stack(all_advantages)
            old_lp_tensor = torch.stack(all_old_log_probs)
            loss = grpo_loss(
                policy_model, ref_model,
                all_prompts_expanded, all_completions,
                advantages_tensor, old_lp_tensor,
                clip_eps=clip_eps, kl_coeff=kl_coeff,
            )
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(policy_model.parameters(), 1.0)
            optimizer.step()
            # Logging: mean reward over the whole batch, not just the last group
            mean_reward = torch.cat(all_rewards).mean().item()
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}, "
                  f"Mean Reward: {mean_reward:.4f}")
```
GRPO Training Hyperparameters (DeepSeek-R1 Style)
| Parameter | Value | Rationale |
|---|---|---|
| K (group size) | 8-16 | Larger K = more stable advantages, more compute |
| clip_eps | 0.2 | Standard PPO clipping, prevents large updates |
| kl_coeff (beta) | 0.01-0.05 | Higher = stay closer to reference, less exploration |
| Learning rate | 1e-6 to 5e-7 | Much lower than SFT (prevent catastrophic forgetting) |
| Max tokens per completion | 2048-8192 | Reasoning traces can be long |
| Temperature | 1.0 | Full stochasticity for diverse group samples |
| Gradient clipping | 1.0 | Prevents RL training instability |
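The table can be captured in a config object. The defaults below pick mid-range values from the table; the class itself is illustrative, not from any particular library:

```python
from dataclasses import dataclass

@dataclass
class GRPOConfig:
    group_size: int = 8           # K: completions sampled per prompt
    clip_eps: float = 0.2         # PPO-style clipping range
    kl_coeff: float = 0.01        # beta: KL penalty toward the reference
    learning_rate: float = 1e-6   # far below typical SFT learning rates
    max_completion_tokens: int = 4096
    temperature: float = 1.0      # full stochasticity for diverse groups
    grad_clip_norm: float = 1.0   # gradient clipping for RL stability
```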
The DeepSeek-R1 Three-Stage Recipe
- Stage 1: SFT – Fine-tune the base model on high-quality reasoning traces. This gives the model the format of thinking.
- Stage 2: GRPO – Run GRPO with a math/code reward model. This teaches the model to reason correctly (not just format correctly).
- Stage 3: Rejection Sampling + SFT – Generate many solutions with the GRPO policy, keep the best (correct + clean), and fine-tune on those. This "distills" the RL policy into a cleaner model.
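Stage 3 can be sketched as a filter over sampled solutions. Here `is_correct` and `is_clean` are stand-in predicates (e.g. answer verification and a length/format check), not fixed APIs:

```python
def rejection_sample(solutions, is_correct, is_clean, max_per_prompt=1):
    """Keep at most max_per_prompt correct-and-clean solutions per prompt.

    solutions: list of (prompt, completion) pairs sampled from the GRPO policy.
    Returns a list of (prompt, completion) pairs for the final SFT stage.
    """
    kept, per_prompt = [], {}
    for prompt, completion in solutions:
        if not (is_correct(prompt, completion) and is_clean(completion)):
            continue  # reject wrong or messy solutions
        if per_prompt.get(prompt, 0) >= max_per_prompt:
            continue  # cap per-prompt copies so easy prompts don't dominate
        per_prompt[prompt] = per_prompt.get(prompt, 0) + 1
        kept.append((prompt, completion))
    return kept
```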