Chain-of-thought prompting, best-of-N sampling, and tree search all share one property: the model’s weights are frozen during inference. The reasoning strategy is static. A model that struggles with algebraic manipulation will continue to struggle — it can try different paths, but it cannot improve its underlying capability mid-generation.

Policy of Thoughts (PoT) breaks this constraint. During a single inference pass, PoT maintains a lightweight transient adapter (a LoRA module with rank 4-16) that is updated based on reward signals computed from the model’s own intermediate outputs. The model literally learns to reason better on the current problem while solving it.

Static Reasoning and Its Limits

Standard test-time compute scaling works by generating more tokens or exploring more paths. Consider best-of-N with a Process Reward Model (PRM):

def best_of_n(model, prompt, n=64, prm=None):
    """Standard best-of-N: frozen weights, multiple samples.

    model.generate and prm.score are assumed interfaces; heuristic_score is a
    placeholder for any scorer that does not need a PRM.
    """
    candidates = []
    for _ in range(n):
        reasoning_trace = model.generate(prompt, temperature=0.7)
        score = prm.score(prompt, reasoning_trace) if prm else heuristic_score(reasoning_trace)
        candidates.append((reasoning_trace, score))
    # Keep the highest-scoring (trace, score) pair
    return max(candidates, key=lambda x: x[1])

This generates 64 independent samples from the same distribution. If the model assigns low probability to the correct reasoning path, even 64 samples may not find it. The distribution is fixed — sampling more does not shift probability mass toward better strategies.

Best-of-N Diminishing Returns (Llama 8B, MATH-500)

| N (samples) | Accuracy | Marginal Gain (pts) | Total FLOPs (relative) |
|---|---|---|---|
| 1 | 34.2% | - | 1x |
| 4 | 48.1% | +13.9 | 4x |
| 16 | 55.8% | +7.7 | 16x |
| 64 | 59.3% | +3.5 | 64x |
| 256 | 61.0% | +1.7 | 256x |
| 1024 | 61.9% | +0.9 | 1024x |

Note: Accuracy plateaus because the model's distribution does not cover the correct solution for 38% of problems.

Beyond $N \approx 64$, gains become marginal (+1.7 and +0.9 points for each further 4x increase in samples). The model's frozen distribution simply does not assign enough probability to the correct solution path for the remaining problems. Spending 16x more compute ($N = 1024$ versus $N = 64$) for a 2.6 percentage point gain is wasteful.
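The plateau is what a fixed distribution predicts. Take an illustrative model, not the source's measurement: each problem's correct path has some per-sample probability p under the frozen model. With a perfect selector, the chance that at least one of N independent samples is correct is 1 - (1 - p)^N. Problems with p = 0 are never solved, however large N gets. A minimal sketch with made-up probabilities:

```python
def coverage_at_n(per_problem_p, n):
    """Expected fraction of problems where at least one of n independent
    samples is correct, assuming a perfect selector (an upper bound for
    best-of-N with a real PRM)."""
    return sum(1.0 - (1.0 - p) ** n for p in per_problem_p) / len(per_problem_p)


# Hypothetical per-sample success probabilities for 10 problems (not measured).
# The last three have p = 0: the frozen distribution never produces the answer.
probs = [0.9, 0.6, 0.3, 0.1, 0.05, 0.02, 0.01, 0.0, 0.0, 0.0]

for n in (1, 4, 16, 64, 256, 1024):
    print(f"N={n:5d}  coverage={coverage_at_n(probs, n):.3f}")
```

Coverage climbs quickly for problems with moderate p and slowly for rare-path problems. In this toy it never reaches 0.7, because three of the ten problems have p = 0. PoT targets that remaining gap: it shifts probability mass onto such paths instead of drawing more samples.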

The core question PoT addresses: instead of sampling more from a fixed distribution, what if we shift the distribution itself during inference?

The PoT Mechanism

PoT augments a pretrained model with a transient LoRA adapter that exists only for the duration of a single query. The adapter starts at zero (no effect on the base model) and is updated iteratively as the model generates reasoning steps.

Definition: PoT Update Rule

Let $\theta$ be the frozen base model weights and $\Delta\theta_t$ the transient adapter at reasoning step $t$. The model generates the tokens of segment $t$ using the effective weights $\theta + \Delta\theta_t$.

After generating a reasoning segment $S_t = (x_{t_0}, \ldots, x_{t_1})$, compute the reward $r_t = R(S_t \mid S_{1:t-1}, \text{prompt})$.

Update the adapter with learning rate $\eta$:

$$\Delta\theta_{t+1} = \Delta\theta_t + \eta \cdot \nabla_{\Delta\theta} \log p(S_t \mid S_{1:t-1}, \text{prompt}; \theta + \Delta\theta_t) \cdot r_t$$

This is a single-sample REINFORCE update. The adapter accumulates a policy gradient that biases the model toward reasoning patterns that receive high reward on the current problem.
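Here is a toy PyTorch sketch of the update rule. It replaces the LoRA adapter with a hypothetical additive offset `delta` on a frozen 5-way logit vector, which stands in for the model's choice among candidate next steps. It then applies one single-sample REINFORCE step, once with a positive reward and once with a negative one:

```python
import torch


def reinforce_step(base_logits, delta, action, reward, eta=0.5):
    """One single-sample REINFORCE update of the additive offset `delta`:
    delta <- delta + eta * grad_delta log p(action) * reward."""
    delta = delta.detach().clone().requires_grad_(True)
    log_p = torch.log_softmax(base_logits + delta, dim=-1)[action]
    (grad,) = torch.autograd.grad(log_p, delta)
    return (delta + eta * grad * reward).detach()


torch.manual_seed(0)
base_logits = torch.randn(5)   # frozen "base model" over 5 candidate steps
delta0 = torch.zeros(5)        # transient offset starts at zero


def prob(delta, action):
    return torch.softmax(base_logits + delta, dim=-1)[action].item()


up = reinforce_step(base_logits, delta0, action=2, reward=+1.0)
down = reinforce_step(base_logits, delta0, action=2, reward=-1.0)
print(prob(delta0, 2), prob(up, 2), prob(down, 2))
```

The gradient of the log-softmax with respect to the logits is the one-hot action minus the current probabilities. A positive reward therefore raises the sampled step's probability, and a negative reward lowers it.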

The key properties:

  1. Zero initialization: $\Delta\theta_0 = 0$. The first reasoning step uses the unmodified base model.
  2. Transient: The adapter is discarded after the query completes. No persistent weight changes.
  3. Low rank: Rank 4-16 LoRA keeps the adapter tiny (about 0.015% of a 7B model's parameters in the default rank-8 configuration described below), so updates are fast.
  4. Cumulative: Each update builds on previous updates. The adapter accumulates problem-specific reasoning improvements.

The LoRA Adapter Structure

import torch
import torch.nn as nn
import math

class TransientLoRA(nn.Module):
    """Transient LoRA adapter for Policy of Thoughts.

    The product A @ B is zero at initialization, so the first
    forward pass is identical to the base model.
    """
    def __init__(self, in_dim, out_dim, rank=8, alpha=16.0):
        super().__init__()
        self.rank = rank
        self.scaling = alpha / rank

        # A is initialized with small random values so gradients can reach B
        # B is initialized to zero so the initial output is zero
        self.A = nn.Parameter(torch.randn(in_dim, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))

        # Adam moment buffers (registered so .to(device) moves them with the module)
        self.register_buffer("m_A", torch.zeros_like(self.A))
        self.register_buffer("m_B", torch.zeros_like(self.B))
        self.register_buffer("v_A", torch.zeros_like(self.A))
        self.register_buffer("v_B", torch.zeros_like(self.B))
        self.step_count = 0

    def forward(self, x):
        # x: [batch, seq, in_dim]
        # Output: [batch, seq, out_dim]
        return (x @ self.A @ self.B) * self.scaling

    def reset(self):
        """Re-initialize for a new query: A random, B zero, optimizer state cleared."""
        with torch.no_grad():
            nn.init.normal_(self.A, std=0.01)
            nn.init.zeros_(self.B)
        self.m_A.zero_()
        self.m_B.zero_()
        self.v_A.zero_()
        self.v_B.zero_()
        self.step_count = 0

The B-zero initialization is critical. At $t = 0$, self.A @ self.B is all zeros, so the adapter contributes nothing. The base model operates unmodified. After the first update, B acquires non-zero values and the adapter begins influencing generation.
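The consequences of B-zero initialization are easy to check. The adapter output is exactly zero. The first gradient reaches only B, because the gradient for A is proportional to B. A must also be nonzero, or B would receive no gradient either. A self-contained check with toy dimensions (not the model's real shapes):

```python
import torch

torch.manual_seed(0)
in_dim, out_dim, rank, scaling = 16, 12, 4, 2.0
x = torch.randn(3, in_dim)
upstream = torch.randn(3, out_dim)   # stands in for dLoss/dOutput


def adapter_grads(A_init):
    """Return (output, grad_A, grad_B) for a LoRA term x @ A @ B with B = 0."""
    A = A_init.clone().requires_grad_(True)
    B = torch.zeros(rank, out_dim, requires_grad=True)
    out = (x @ A @ B) * scaling
    (out * upstream).sum().backward()
    return out.detach(), A.grad, B.grad


# Standard LoRA init: small random A, zero B
out, gA, gB = adapter_grads(torch.randn(in_dim, rank) * 0.01)
print(out.abs().max().item(), gA.abs().max().item(), gB.abs().max().item())

# If A were also zero, neither factor would ever receive a gradient
_, gA0, gB0 = adapter_grads(torch.zeros(in_dim, rank))
print(gA0.abs().max().item(), gB0.abs().max().item())
```

So the first PoT update moves only B. A starts changing from the second update onward, once B is nonzero.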

Where to Insert the Adapter

PoT applies transient LoRA to the attention query/value projections in every $k$-th layer (default: every 4th layer). This provides a reasonable balance between expressiveness and update cost.

class PoTModel(nn.Module):
    """Base model augmented with transient LoRA adapters.

    Assumes base_model exposes embed, layers, norm, and lm_head, and that each
    layer has input_layernorm and a forward_with_qv(hidden, q=..., v=...) hook
    (hypothetical) that runs the block with precomputed Q and V projections.
    """

    def __init__(self, base_model, lora_rank=8, lora_alpha=16.0, lora_every_n=4):
        super().__init__()
        self.base = base_model
        self.adapters = nn.ModuleDict()

        for i, layer in enumerate(base_model.layers):
            if i % lora_every_n == 0:
                attn = layer.self_attn
                d = attn.q_proj.in_features
                # Take output widths from the projections: with grouped-query
                # attention, v_proj is narrower than q_proj
                self.adapters[f"layer_{i}_q"] = TransientLoRA(
                    d, attn.q_proj.out_features, lora_rank, lora_alpha)
                self.adapters[f"layer_{i}_v"] = TransientLoRA(
                    d, attn.v_proj.out_features, lora_rank, lora_alpha)

    def forward(self, input_ids, use_adapter=True):
        hidden = self.base.embed(input_ids)

        for i, layer in enumerate(self.base.layers):
            key = f"layer_{i}_q"
            if use_adapter and key in self.adapters:
                q_adapter = self.adapters[f"layer_{i}_q"]
                v_adapter = self.adapters[f"layer_{i}_v"]

                # Pre-norm blocks project the normalized input, so the
                # adapters must see the same input as q_proj/v_proj
                normed = layer.input_layernorm(hidden)
                q = layer.self_attn.q_proj(normed) + q_adapter(normed)
                v = layer.self_attn.v_proj(normed) + v_adapter(normed)

                hidden = layer.forward_with_qv(hidden, q=q, v=v)
            else:
                hidden = layer(hidden)

        hidden = self.base.norm(hidden)  # final normalization before the LM head
        return self.base.lm_head(hidden)

    def reset_adapters(self):
        for adapter in self.adapters.values():
            adapter.reset()
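The `forward_with_qv` hook above is hypothetical. Another common way to inject an adapter is to wrap the projections themselves, which leaves the layer's own forward untouched. This sketch assumes each attention module calls its `q_proj` and `v_proj` submodules inside its forward. Because the wrapper sees whatever input the layer feeds the projection (the normalized hidden state in a pre-norm block), it also avoids reimplementing attention. `LoRALinear` and `ToyAttn` are illustrative names, not library classes:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a zero-initialized low-rank term:
    y = base(x) + scaling * (x @ A @ B)."""
    def __init__(self, base: nn.Linear, rank=8, alpha=16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)          # base weights stay frozen
        self.scaling = alpha / rank
        self.A = nn.Parameter(torch.randn(base.in_features, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, base.out_features))

    def forward(self, x):
        return self.base(x) + (x @ self.A @ self.B) * self.scaling


class ToyAttn(nn.Module):
    """Stand-in for an attention module; only the projections matter here."""
    def __init__(self, d, d_kv):
        super().__init__()
        self.q_proj = nn.Linear(d, d, bias=False)
        self.v_proj = nn.Linear(d, d_kv, bias=False)   # GQA-style narrow V


torch.manual_seed(0)
attn = ToyAttn(d=32, d_kv=8)
x = torch.randn(2, 5, 32)
v_before = attn.v_proj(x)

# Swap the projections in place; code that calls attn.v_proj(...) is unchanged
attn.q_proj = LoRALinear(attn.q_proj)
attn.v_proj = LoRALinear(attn.v_proj)
v_after = attn.v_proj(x)
print(torch.allclose(v_before, v_after), tuple(v_after.shape))
```

Resetting for a new query then means re-initializing A and B in each wrapper, just as `TransientLoRA.reset` does.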
Parameter Count

For a 7B model with 32 layers and hidden size 4096, inserting rank-8 LoRA into the Q and V projections of every 4th layer gives 8 layers x 2 (Q and V) = 16 adapters. Each is an (A, B) pair with $4096 \times 8 + 8 \times 4096 = 65{,}536$ parameters. Total: 16 x 65,536 = 1,048,576 parameters, about 0.015% of the base model. The memory overhead is negligible.
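The callout's arithmetic generalizes to a small helper; the helper names are illustrative. This sketch also covers grouped-query attention, where the V projection is narrower than the hidden size. For example, 8 KV heads of dimension 128 give a 1024-wide V:

```python
def lora_params(d_in, d_out, rank):
    """Parameters in one LoRA adapter: A is d_in x rank, B is rank x d_out."""
    return d_in * rank + rank * d_out


def pot_adapter_params(n_layers, every_n, d_model, q_out, v_out, rank):
    """Total adapter parameters when Q and V of every `every_n`-th layer get LoRA."""
    adapted_layers = len(range(0, n_layers, every_n))   # layers 0, every_n, ...
    per_layer = lora_params(d_model, q_out, rank) + lora_params(d_model, v_out, rank)
    return adapted_layers * per_layer


# 7B model with standard multi-head attention: Q and V are both 4096 wide
total = pot_adapter_params(32, 4, 4096, q_out=4096, v_out=4096, rank=8)
print(total, f"{total / 7e9:.4%} of 7B")

# Grouped-query attention with 8 KV heads of dim 128: V is only 1024 wide
gqa_total = pot_adapter_params(32, 4, 4096, q_out=4096, v_out=1024, rank=8)
print(gqa_total)
```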

The Reward Signal

The adapter update requires a reward signal $r_t$ for each reasoning segment. PoT computes this from three sources: confidence scoring, coherence checking, and self-verification.

Confidence Scoring

After the model generates a reasoning segment, compute the average log-probability of the tokens in that segment. High confidence (high log-prob) suggests the model is on a familiar and likely-correct reasoning path. Low confidence suggests uncertainty or guessing.

def confidence_reward(model, tokens, segment_start, segment_end):
    """Compute confidence-based reward for a reasoning segment.

    tokens is a 1-D tensor of prompt + generated token ids. Returns a value
    in [-1, 1] based on the model's own confidence in the tokens it generated.
    """
    with torch.no_grad():
        logits = model(tokens[:segment_end].unsqueeze(0))[0]  # [seq, vocab]

    # Logits at position i predict token i + 1, so the segment tokens
    # tokens[segment_start:segment_end] are scored by the rows one position earlier
    segment_logits = logits[segment_start - 1:segment_end - 1]
    segment_tokens = tokens[segment_start:segment_end]

    log_probs = torch.log_softmax(segment_logits, dim=-1)
    token_log_probs = log_probs.gather(1, segment_tokens.unsqueeze(1)).squeeze(1)

    avg_log_prob = token_log_probs.mean().item()

    # Map to [-1, 1]: an average log-prob of -5 maps to 0, -10 to -1, 0 to +1
    # (typical averages range from -10, very uncertain, to -0.1, very confident)
    normalized = (avg_log_prob + 5.0) / 5.0
    return max(-1.0, min(1.0, normalized))

Confidence alone is insufficient — a model can be confidently wrong. The other reward components address this.
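Aligning logits with targets in a causal LM is easy to get wrong, so it is worth isolating. This sketch uses a hypothetical helper, `segment_log_probs`, and a toy logit matrix whose row i puts nearly all mass on the token at position i + 1. The correctly shifted slice scores the segment near log-prob 0. The unshifted slice scores every token as nearly impossible:

```python
import torch


def segment_log_probs(logits, tokens, start, end):
    """Log-probabilities of tokens[start:end] under a causal LM.

    logits: [seq_len, vocab] from running the model on tokens[:end].
    Row i predicts token i + 1, so the segment is scored by rows
    start - 1 .. end - 2. Requires start >= 1 (a non-empty prompt).
    """
    log_probs = torch.log_softmax(logits[start - 1:end - 1], dim=-1)
    targets = tokens[start:end]
    return log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)


# Toy "model": consecutive tokens, and row i strongly predicts tokens[i] + 1
vocab = 10
tokens = torch.tensor([3, 4, 5, 6, 7, 8])
logits = torch.full((len(tokens), vocab), -20.0)
logits[torch.arange(len(tokens)), (tokens + 1) % vocab] = 20.0

lp = segment_log_probs(logits, tokens, start=2, end=6)

# Unshifted slice: each row is compared with the token at its own position
wrong = torch.log_softmax(logits[2:6], dim=-1).gather(1, tokens[2:6].unsqueeze(1)).squeeze(1)
print(lp, wrong)
```

The same slicing applies wherever segment log-probs are needed, including the REINFORCE update.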

Coherence Checking

Coherence measures whether the current reasoning step is logically consistent with previous steps. PoT implements this by checking whether the model can predict a summary of the previous steps given the current step.

def coherence_reward(model, previous_context, current_segment):
    """Check if the current segment is coherent with previous reasoning.

    previous_context holds the tokens before the current segment (prompt and
    earlier steps). Uses the model itself to judge coherence by computing the
    cross-entropy of previous conclusions given the current step.
    """
    # Extract key conclusions from previous steps (tokens appearing after
    # "therefore", "so", "thus", "="); extract_conclusions is an assumed
    # helper returning a 1-D tensor of token ids
    conclusion_tokens = extract_conclusions(previous_context)

    if conclusion_tokens is None or len(conclusion_tokens) == 0:
        return 0.0  # No previous conclusions to check against

    # Compute how well the current segment predicts previous conclusions
    with torch.no_grad():
        logits = model(torch.cat([current_segment, conclusion_tokens]).unsqueeze(0))[0]

    # Row len(current_segment) - 1 predicts the first conclusion token, and so on
    pred_logits = logits[len(current_segment) - 1:-1]
    ce_loss = nn.functional.cross_entropy(
        pred_logits, conclusion_tokens, reduction='mean'
    ).item()

    # Low cross-entropy = coherent (model finds conclusions unsurprising)
    coherence = max(-1.0, min(1.0, (3.0 - ce_loss) / 3.0))
    return coherence
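`extract_conclusions` is not defined above. One hypothetical implementation of the cue-word heuristic from the code comment is a regular expression over the decoded text. The returned phrases would then be tokenized and concatenated into `conclusion_tokens`. A sketch:

```python
import re

# Text after a cue word or "=" up to the next sentence break (hypothetical heuristic)
CONCLUSION_CUE = re.compile(
    r"(?:\btherefore\b|\bthus\b|\bso\b|=)\s*([^.\n;]+)", re.IGNORECASE
)


def extract_conclusions(text):
    """Return conclusion phrases found in decoded reasoning text."""
    return [m.group(1).strip() for m in CONCLUSION_CUE.finditer(text)
            if m.group(1).strip()]


print(extract_conclusions("We have 2x = 6. Therefore x = 3."))
print(extract_conclusions("So the area is 12; thus done"))
```

This is deliberately crude. "So" also appears in phrasing that concludes nothing, and "=" captures every right-hand side, so the coherence signal inherits that noise.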

Self-Verification

For problems with checkable intermediate results (math equations, code correctness), PoT prompts the model to verify its own work.

def verification_reward(model, prompt, reasoning_so_far, current_segment):
    """Prompt the model to verify the current reasoning step.

    reasoning_so_far and current_segment are token tensors; reasoning_so_far
    already begins with the prompt tokens, so the prompt is not re-added.
    Appends a verification question and compares the model's probability of
    answering YES versus NO.
    """
    verify_suffix = tokenize(
        "\n\nVerification: Is the above step correct? Answer YES or NO.\n"
    )
    verify_tokens = torch.cat([reasoning_so_far, current_segment, verify_suffix])
    with torch.no_grad():
        logits = model(verify_tokens.unsqueeze(0))[0]

    # YES_TOKEN_ID / NO_TOKEN_ID are the answer token ids; many tokenizers
    # encode a word differently with and without a leading space
    last_logits = logits[-1]
    pair = torch.stack([last_logits[YES_TOKEN_ID], last_logits[NO_TOKEN_ID]])

    # Reward = P(YES) - P(NO), renormalized over the two answers; range [-1, 1]
    probs = torch.softmax(pair, dim=0)
    reward = (probs[0] - probs[1]).item()
    return reward
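Because the softmax runs over only two logits, the verification reward has a closed form. With gap d = yes_logit - no_logit, P(YES) - P(NO) = 2σ(d) - 1 = tanh(d / 2). The sketch below checks this numerically:

```python
import math

import torch


def yes_no_reward(yes_logit, no_logit):
    """P(YES) - P(NO) after renormalizing over just the two answer logits."""
    probs = torch.softmax(torch.tensor([yes_logit, no_logit], dtype=torch.float64), dim=0)
    return (probs[0] - probs[1]).item()


pairs = [(2.0, -1.0), (0.0, 0.0), (-3.5, 1.2)]
for yes, no in pairs:
    print(yes_no_reward(yes, no), math.tanh((yes - no) / 2))
```

So the reward depends only on the logit gap and saturates at ±1 for large gaps. It ignores how much probability the model places on tokens other than YES and NO.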

Combined Reward

def compute_reward(model, prompt, reasoning_so_far, current_segment, tokens,
                   segment_start, segment_end):
    """Combine all reward signals with fixed, hand-tuned weights."""
    r_conf = confidence_reward(model, tokens, segment_start, segment_end)
    r_cohere = coherence_reward(model, reasoning_so_far, current_segment)
    r_verify = verification_reward(model, prompt, reasoning_so_far, current_segment)

    # Weighted combination (weights can be tuned per task)
    w_conf, w_cohere, w_verify = 0.2, 0.3, 0.5
    combined = w_conf * r_conf + w_cohere * r_cohere + w_verify * r_verify

    return combined
Reward Computation Cost

Each reward computation requires 1-2 extra forward passes through the model (one for confidence/coherence, one for verification). With the reward computed every $K$ tokens (typically $K = 64\text{-}256$), that is 2 extra forward calls per $K$ generated tokens. For $K = 128$, this is about 1.6% more calls than token-by-token decoding. In FLOPs, however, each of these passes processes the full $K$-token segment, so the reward costs roughly $1.5K$ token-equivalents per segment. That is the same order as the adapter update (see Compute Cost Analysis below).

The Update Loop

The complete PoT generation loop: generate a segment, compute reward, update the adapter, continue generating with the improved adapter.

class PoTGenerator:
    """Complete Policy of Thoughts generation pipeline.

    tokenize, detokenize, sample, and EOS_TOKEN_ID are assumed helpers:
    tokenize returns a 1-D LongTensor and sample returns a 0-d LongTensor.
    """

    def __init__(self, model, segment_size=128, lr=1e-4, max_updates=32):
        self.model = model
        self.segment_size = segment_size
        self.lr = lr
        self.max_updates = max_updates
        # Freeze the base model so backward passes only produce adapter gradients
        for p in model.base.parameters():
            p.requires_grad_(False)

    def generate(self, prompt, max_tokens=4096):
        """Generate with online adapter updates."""
        self.model.reset_adapters()

        prompt_tokens = tokenize(prompt)
        generated = []
        all_tokens = prompt_tokens.clone()
        update_count = 0

        while len(generated) < max_tokens:
            # Phase 1: Generate a segment with the current adapter
            segment = self._generate_segment(all_tokens, self.segment_size)

            if self._is_done(segment):
                # Model signaled end of reasoning
                generated.extend(segment.tolist())
                break

            segment_start = len(all_tokens)
            all_tokens = torch.cat([all_tokens, segment])
            segment_end = len(all_tokens)
            generated.extend(segment.tolist())

            # Phases 2-3: score the segment and update the adapter, only while
            # the update budget lasts (rewards are useless after that)
            if update_count < self.max_updates:
                reward = compute_reward(
                    self.model, prompt, all_tokens[:segment_start],
                    segment, all_tokens, segment_start, segment_end
                )
                self._update_adapter(all_tokens, segment_start, segment_end, reward)
                update_count += 1

        return detokenize(generated), update_count

    def _generate_segment(self, context, length):
        """Generate up to `length` tokens using the current adapter.

        Re-encodes the full context each step for clarity; a real
        implementation would reuse the KV cache.
        """
        tokens = []
        current = context
        for _ in range(length):
            with torch.no_grad():
                logits = self.model(current.unsqueeze(0))
            next_token = sample(logits[0, -1])
            tokens.append(next_token)
            current = torch.cat([current, next_token.view(1)])
            if next_token.item() == EOS_TOKEN_ID:
                break
        return torch.stack(tokens)

    def _update_adapter(self, tokens, seg_start, seg_end, reward):
        """REINFORCE update on the transient adapter (plain SGD)."""
        # Enable gradients only for adapter parameters
        adapter_params = list(self.model.adapters.parameters())
        for p in adapter_params:
            p.requires_grad_(True)

        # Forward pass to compute log-probs of the segment
        logits = self.model(tokens[:seg_end].unsqueeze(0), use_adapter=True)
        # Logits at position i predict token i + 1
        segment_logits = logits[0, seg_start - 1:seg_end - 1]
        segment_targets = tokens[seg_start:seg_end]

        log_probs = torch.log_softmax(segment_logits, dim=-1)
        token_log_probs = log_probs.gather(1, segment_targets.unsqueeze(1)).squeeze(1)
        policy_log_prob = token_log_probs.sum()

        # REINFORCE: maximize log_prob * reward by minimizing its negative
        loss = -policy_log_prob * reward
        loss.backward()

        with torch.no_grad():
            for p in adapter_params:
                if p.grad is not None:
                    p.sub_(self.lr * p.grad)
                    p.grad = None

        # Disable gradients for inference
        for p in adapter_params:
            p.requires_grad_(False)

    def _is_done(self, segment):
        return EOS_TOKEN_ID in segment.tolist()
Gradient Through KV Cache

The adapter update requires a backward pass through the segment’s forward computation, so the KV cache entries for the segment must retain their computation graph. In practice, PoT maintains a small “gradient window” of the most recent segment’s KV entries with gradients enabled, while older entries are detached. This bounds the backward pass cost to $O(\text{segment\_size} \times d_{\text{model}}^2)$ per update. The simplified `_update_adapter` above does not do this: it recomputes the forward pass over the whole context with gradients enabled.

What Changes After Each Update

To build intuition for what the adapter learns, consider a math problem where the model needs to factor a polynomial. On the first attempt (update 0), the model might try a brute-force approach that leads nowhere, and the combined reward for this segment is negative.

After the first update, the adapter has learned (via the negative reward) to suppress the token patterns associated with brute-force factoring. On the second segment, the model is more likely to try a different strategy — synthetic division or the rational root theorem.

If synthetic division yields a positive reward (coherent steps, self-verification passes), the adapter reinforces this pattern. Subsequent segments are generated with an increasingly refined policy that favors productive reasoning strategies for this specific problem.

def trace_adapter_evolution(model, prompt, segment_size=128, num_updates=16):
    """Diagnostic: track how the adapter changes the generation distribution.

    generate_segment and update_adapter are assumed standalone versions of
    PoTGenerator._generate_segment and PoTGenerator._update_adapter.
    """
    model.reset_adapters()
    tokens = tokenize(prompt)
    distributions = []

    for step in range(num_updates):
        # Record the next-token distribution before this step's update
        with torch.no_grad():
            logits = model(tokens.unsqueeze(0))
            probs = torch.softmax(logits[0, -1], dim=-1)
            top_k = torch.topk(probs, k=20)
            distributions.append({
                'step': step,
                'top_tokens': [(detokenize([t.item()]), p.item())
                               for t, p in zip(top_k.indices, top_k.values)],
                # log2 so that entropy is reported in bits
                'entropy_bits': -(probs * torch.log2(probs + 1e-10)).sum().item(),
            })

        # Generate a segment, score it, and update the adapter
        segment = generate_segment(model, tokens, segment_size)
        seg_start, seg_end = len(tokens), len(tokens) + len(segment)
        tokens = torch.cat([tokens, segment])
        reward = compute_reward(model, prompt, tokens[:seg_start],
                                segment, tokens, seg_start, seg_end)
        update_adapter(model, tokens, seg_start, seg_end, reward)

    return distributions

Token Entropy Over PoT Updates (Llama 8B, Competition Math)

| PoT update | Next-token entropy (bits) |
|---|---|
| 0 (high entropy: uncertain) | 8.2 |
| 2 | 7.1 |
| 4 | 5.8 |
| 8 | 4.3 |
| 12 | 3.5 |
| 16 (low entropy: focused) | 3.1 |

Entropy drops as the adapter accumulates updates. The model becomes more decisive — its distribution sharpens around the reasoning patterns that received high reward on this specific problem. This is the core mechanism: PoT converts test-time compute into distribution refinement, not just distribution sampling.

When PoT Helps and When It Hurts

PoT’s overhead is substantial: 3-5x more compute than standard generation due to the reward computation and adapter update backward passes. This overhead is only justified when the quality gain exceeds what cheaper methods (best-of-N, beam search) can achieve.

Problems Where PoT Excels

Multi-step mathematical proofs, code debugging, and logical deduction chains share a common property: the correct strategy is not immediately obvious, but once found, it can be verified and reinforced. PoT thrives here because:

  1. The reward signal is informative: intermediate results such as equations and code can be checked, so self-verification and coherence checks carry real signal.
  2. The problem requires strategy adaptation — trying multiple approaches sequentially until one works.
  3. The correct approach, once identified, benefits from reinforcement across subsequent steps.
PoT vs Best-of-N vs Standard (Llama 8B)

| Task | Standard | Best-of-64 | PoT (16 updates) | PoT Compute |
|---|---|---|---|---|
| MATH-500 (competition) | 34.2% | 59.3% | 67.8% | 3.2x standard |
| HumanEval (code) | 41.5% | 62.0% | 70.2% | 3.8x standard |
| LogiQA (logic puzzles) | 52.1% | 65.4% | 73.1% | 3.5x standard |
| ARC-Challenge (science) | 68.3% | 78.1% | 80.5% | 3.3x standard |
| TriviaQA (factual) | 71.2% | 73.0% | 72.8% | 3.1x standard |
| MMLU (general knowledge) | 64.5% | 66.2% | 65.1% | 3.0x standard |

Note: Best-of-64 uses 64x compute. PoT uses 3-4x compute. PoT is more compute-efficient on reasoning tasks.

On MATH-500, PoT with 16 updates (67.8%, at under 5x the compute of standard generation) beats best-of-64 (59.3%, at 64x compute). The adapter updates shift the distribution in ways that sampling alone cannot.

Problems Where PoT Hurts

Factual recall (TriviaQA), broad knowledge (MMLU), and simple tasks show only small gains over standard generation (+0.6 to +1.6 points) and fall short of best-of-64. The reasons:

  1. No strategy to adapt: Factual questions have one correct answer that the model either knows or does not. No amount of adapter updating will create knowledge that is not in the weights.
  2. Reward signal is weak: For factual questions, confidence scoring is unreliable (the model may be confidently wrong about a fact), and self-verification adds noise.
  3. Overhead exceeds benefit: A 3x compute penalty for a gain of at most 1.6 points is a net loss.
PoT Can Degrade Quality

On simple tasks, PoT updates can push the model away from the correct answer. If the first segment receives a spuriously low reward (e.g., the model is uncertain about a correct fact), the adapter update suppresses the correct reasoning path. For production systems, a difficulty classifier should gate PoT: only activate it for problems estimated to benefit from extended reasoning.

Compute Cost Analysis

PoT’s compute cost has four components per segment:

  1. Forward pass (generation): Same as standard generation. Cost: 1 forward pass per segment.
  2. Forward pass (reward): 1-2 forward passes for confidence + verification. Cost: 1.5 forward passes per segment.
  3. Forward pass (gradient computation): One forward pass with gradients enabled for the adapter parameters. Cost: 1 forward pass per segment.
  4. Backward pass (adapter update): Backward through the segment with respect to adapter parameters only. Cost: approximately 1 forward pass equivalent per segment (adapters are small, so the backward is cheaper than a full model backward).

Total per segment: approximately 4.5 forward pass equivalents, versus 1 for standard generation. With segment size 128 and maximum 16 updates over 2048 generated tokens:

$$\text{Total FLOPs} = \underbrace{2048 \times F_{\text{fwd}}}_{\text{generation}} + \underbrace{16 \times 1.5 \times F_{\text{fwd}} \times 128}_{\text{reward}} + \underbrace{16 \times 2.0 \times F_{\text{fwd}} \times 128}_{\text{update}}$$

where $F_{\text{fwd}}$ is the FLOPs for one forward pass on one token.

Simplifying (using per-token cost as the unit):

$$\text{Total cost} = 2048 + 16 \times 1.5 \times 128 + 16 \times 2.0 \times 128 = 2048 + 3072 + 4096 = 9216 \text{ token-equivalents}$$

This is approximately 4.5x the cost of standard generation (2048 tokens). In practice, measured overhead is 3-5x depending on model size and segment length.
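The accounting above fits in a few lines; the function names here are illustrative. This sketch reproduces the 9216 token-equivalent total and exposes one property of the idealized model. Whenever every segment is updated (updates x segment size = generated tokens), the predicted overhead is 4.5x regardless of segment size:

```python
def pot_cost(generated_tokens, segment_size, updates,
             reward_passes=1.5, update_passes=2.0):
    """Token-equivalent cost under the accounting above: each generated token
    costs 1, and each updated segment adds (reward_passes + update_passes)
    forward-pass equivalents per segment token."""
    return generated_tokens + updates * segment_size * (reward_passes + update_passes)


def overhead(generated_tokens, segment_size, updates):
    """Cost relative to standard generation of the same number of tokens."""
    return pot_cost(generated_tokens, segment_size, updates) / generated_tokens


print(pot_cost(2048, 128, 16), overhead(2048, 128, 16))   # 9216.0 4.5

# Configurations that update every segment of a 2048-token output
for seg, upd in [(64, 32), (256, 8), (512, 4)]:
    print(seg, upd, overhead(2048, seg, upd))
```

The measured overheads for these configurations in the table that follows range from 3.2x to 5.1x. Since this FLOP-only model predicts the same 4.5x for all of them, the spread must come from effects it leaves out.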

PoT Compute Overhead by Configuration

| Segment Size | Max Updates | Generated Tokens | Overhead vs Standard | MATH-500 Accuracy |
|---|---|---|---|---|
| 64 | 32 | 2048 | 5.1x | 69.2% |
| 128 | 16 | 2048 | 4.5x | 67.8% |
| 256 | 8 | 2048 | 3.8x | 65.1% |
| 512 | 4 | 2048 | 3.2x | 60.3% |
| 128 | 32 | 4096 | 4.5x | 71.4% |
| 128 | 0 (no updates) | 2048 | 1.0x | 34.2% |

Note: Segment size 128 with 16 updates balances cost and accuracy. Smaller segments update more frequently but each update sees less context.

The sweet spot is segment size 128 with 16 updates: 4.5x overhead for a 33.6 percentage point gain over standard generation. Smaller segments (64 tokens) give slightly better accuracy but at higher cost. Larger segments (512 tokens) update too infrequently for the adapter to converge.

PoT vs Existing Test-Time Methods

Comparison with Best-of-N

Best-of-N generates $N$ independent samples and selects the best. PoT generates one sample with iterative refinement. The compute efficiency comparison:

$$\text{Best-of-N efficiency} = \frac{\Delta Q}{N \times C_{\text{gen}}} \qquad \text{PoT efficiency} = \frac{\Delta Q}{K \times C_{\text{update}}}$$

where $\Delta Q$ is the quality improvement, $N$ is the number of samples, $K$ is the number of adapter updates, $C_{\text{gen}}$ is the generation cost per sample, and $C_{\text{update}}$ is the cost per update cycle (generation + reward + backward).

Quality vs Compute Budget (Llama 8B, MATH-500)

| Method (relative compute) | MATH-500 Accuracy |
|---|---|
| Standard (1x) | 34.2% |
| Best-of-4 (4x) | 48.1% |
| PoT, 4 updates (3.2x) | 60.3% |
| Best-of-16 (16x) | 55.8% |
| PoT, 16 updates (4.5x) | 67.8% |
| Best-of-64 (64x) | 59.3% |
| PoT, 32 updates (5.1x) | 69.2% |

In every configuration shown, PoT reaches higher accuracy than best-of-N while using far less compute. The reason is structural: PoT’s improvements compound (each update builds on previous updates), while best-of-N’s improvements are independent draws from a fixed distribution.

Comparison with Tree Search (MCTS)

Monte Carlo Tree Search explores a tree of reasoning paths, using a PRM to score branches. PoT is fundamentally different: it does not explore multiple paths. It generates a single path but continuously improves the generator.

# MCTS: explore multiple paths, select best
#   Compute: O(branching_factor * depth * forward_cost)
#   Memory: O(branching_factor * depth * KV_cache_per_token)
#   Quality: depends on branching factor and PRM quality

# PoT: single path, improve generator
#   Compute: O(num_updates * segment_size * update_cost)
#   Memory: O(sequence_length * KV_cache_per_token + adapter_size)
#   Quality: depends on reward quality and learning rate

PoT has a significant memory advantage: it maintains one KV cache for one generation path, plus the small adapter. MCTS maintains KV caches for all active branches. For a branching factor of 8 and depth 5, MCTS requires 8x the KV cache memory.

PoT vs MCTS Memory and Compute (Llama 8B, 2048 tokens)

| Method | KV Cache Memory | Compute (FLOPs) | MATH-500 Accuracy |
|---|---|---|---|
| Standard generation | 256 MiB | 1.0x | 34.2% |
| MCTS (branch=4, depth=8) | 2.0 GiB | 12x | 64.5% |
| MCTS (branch=8, depth=8) | 4.0 GiB | 24x | 68.1% |
| PoT (16 updates) | 256 MiB + 4 MB adapter | 4.5x | 67.8% |
| PoT + MCTS hybrid | 1.0 GiB | 10x | 74.3% |

Note: KV cache for Llama 8B: 32 layers x 2 (K and V) x 8 KV heads x 128 dim x 2048 tokens x 2 bytes (FP16) = 256 MiB.

PoT achieves comparable accuracy to MCTS (branch=8) at 5.3x less compute and 16x less memory. The hybrid (PoT-refined model used as the policy in MCTS) achieves the best accuracy by combining distribution refinement with path exploration.
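KV cache sizes like the one in the table note follow from a one-line formula. This sketch assumes FP16 (2 bytes per element) and the Llama 8B shape used in the note: 32 layers, 8 KV heads, head dimension 128. The leading factor of 2 accounts for storing both keys and values:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, n_tokens, bytes_per_elem=2):
    """Bytes of K and V cache for one sequence (factor 2 = keys + values)."""
    return 2 * n_layers * n_kv_heads * head_dim * n_tokens * bytes_per_elem


size = kv_cache_bytes(32, 8, 128, 2048)   # Llama-8B-style GQA shape, FP16
print(size, size / 2**20, "MiB")          # 268435456 256.0 MiB
```

Every additional live branch needs its own copy of this cache, which is why tree search memory grows with the branching factor while PoT keeps one cache plus the small adapter.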

Implementation: Adam Optimizer for Adapter Updates

The basic REINFORCE update shown earlier uses vanilla SGD. In practice, Adam provides significantly faster adapter convergence because the reward signal is noisy.

class AdamAdapterOptimizer:
    """Adam optimizer specialized for transient LoRA updates.

    Uses aggressive hyperparameters suited for few-step optimization:
    higher learning rate, lower beta2 (faster variance adaptation).
    """
    def __init__(self, adapters, lr=3e-4, beta1=0.9, beta2=0.95, eps=1e-8):
        self.adapters = adapters
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

    def step(self, reward):
        """Apply one Adam step to all adapter parameters.

        Assumes .grad holds the gradient of -log p(segment), i.e. the backward
        pass was run WITHOUT the reward; the reward is applied here as a scalar
        multiplier, which makes the step a REINFORCE update.
        """
        with torch.no_grad():
            for name, adapter in self.adapters.items():
                # One optimizer step per call: advance the counter once per
                # adapter (not once per parameter) so bias correction is right
                adapter.step_count += 1
                t = adapter.step_count

                for pname in ['A', 'B']:
                    p = getattr(adapter, pname)
                    if p.grad is None:
                        continue

                    # Scale gradient by reward (REINFORCE)
                    grad = p.grad * reward

                    # Update moment buffers
                    m = getattr(adapter, f'm_{pname}')
                    v = getattr(adapter, f'v_{pname}')
                    m.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
                    v.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)

                    # Bias correction
                    m_hat = m / (1 - self.beta1 ** t)
                    v_hat = v / (1 - self.beta2 ** t)

                    # Descend on -log p * reward, i.e. ascend log p * reward
                    p.sub_(self.lr * m_hat / (v_hat.sqrt() + self.eps))
                    p.grad.zero_()

The hyperparameters differ from standard training Adam:

  • Learning rate 3e-4 (vs 1e-4 for training): We need fast adaptation in 16-32 steps.
  • Beta2 = 0.95 (vs the common default of 0.999): The variance estimate must adapt quickly because the reward distribution shifts as the adapter changes.
  • No weight decay: The adapter is transient and discarded after the query. Regularization is unnecessary.
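A hand-written Adam step like the one above can be checked against PyTorch's `torch.optim.Adam` on the same gradient sequence. This sketch advances the step counter once per optimizer step, which is what the bias correction assumes. The seed and shapes are arbitrary:

```python
import torch


def manual_adam_step(p, grad, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95, eps=1e-8):
    """One Adam step with bias correction; t is the 1-based step count."""
    m.mul_(beta1).add_(grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    with torch.no_grad():
        p.sub_(lr * m_hat / (v_hat.sqrt() + eps))


torch.manual_seed(0)
init = torch.randn(4, 3)
grads = [torch.randn(4, 3) for _ in range(5)]

# Reference: torch.optim.Adam with the same hyperparameters
p_ref = init.clone().requires_grad_(True)
opt = torch.optim.Adam([p_ref], lr=3e-4, betas=(0.9, 0.95), eps=1e-8)
for g in grads:
    p_ref.grad = g.clone()
    opt.step()

# Manual version: the step counter advances exactly once per step
p_man = init.clone()
m = torch.zeros_like(p_man)
v = torch.zeros_like(p_man)
for t, g in enumerate(grads, start=1):
    manual_adam_step(p_man, g, m, v, t)

print(torch.allclose(p_ref.detach(), p_man, atol=1e-6))
```

Incrementing the counter twice per step (once for A, once for B) would use a different t in the bias correction and break this agreement.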

Stabilization: Reward Normalization and Gradient Clipping

Raw REINFORCE is high-variance. Two techniques stabilize PoT updates.

Running Reward Baseline

Subtract a baseline from the reward to reduce variance. PoT maintains a running mean of recent rewards:

class RewardBaseline:
    """Exponential moving average baseline for variance reduction."""
    def __init__(self, decay=0.9):
        self.decay = decay
        self.baseline = 0.0
        self.initialized = False

    def update_and_normalize(self, reward):
        if not self.initialized:
            self.baseline = reward
            self.initialized = True
            return 0.0  # First reward has zero advantage

        advantage = reward - self.baseline
        self.baseline = self.decay * self.baseline + (1 - self.decay) * reward
        return advantage
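REINFORCE with raw rewards raises the probability of any segment whose reward is positive, even a mediocre one. With the EMA baseline, only segments that beat recent rewards get a positive advantage. A minimal functional sketch that mirrors `RewardBaseline` (same decay and first-reward rule); the function name is illustrative:

```python
def advantages(rewards, decay=0.9):
    """Advantages under an EMA baseline: the first reward only initializes
    the baseline; later rewards are compared against it."""
    out, baseline = [], None
    for r in rewards:
        if baseline is None:
            baseline = r
            out.append(0.0)
            continue
        out.append(r - baseline)
        baseline = decay * baseline + (1 - decay) * r
    return out


# Three mediocre-but-positive segments, then a clearly better one
print(advantages([0.3, 0.3, 0.3, 0.8]))
# A drop in reward produces a negative advantage
print(advantages([0.6, 0.2]))
```

Without the baseline, all three 0.3-reward segments would be reinforced. With it, they get roughly zero advantage, and only the 0.8 segment is pushed up.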

Gradient Clipping

Clip the adapter gradients to prevent catastrophic updates from extreme reward signals:

def clip_adapter_gradients(adapters, max_norm=1.0):
    """Clip gradients across all adapter parameters."""
    all_grads = []
    for adapter in adapters.values():
        for pname in ['A', 'B']:
            p = getattr(adapter, pname)
            if p.grad is not None:
                all_grads.append(p.grad)

    if not all_grads:
        return 0.0

    total_norm = torch.sqrt(sum(g.norm() ** 2 for g in all_grads))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in all_grads:
            g.mul_(clip_coef)

    return total_norm.item()
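`clip_adapter_gradients` implements global-norm clipping, the same rule as PyTorch's built-in `torch.nn.utils.clip_grad_norm_`, which also divides by `total_norm + 1e-6`. A quick equivalence check on random gradients (shapes and scale are arbitrary):

```python
import torch


def manual_clip(grads, max_norm=1.0):
    """Global-norm clipping, as in clip_adapter_gradients."""
    total_norm = torch.sqrt(sum(g.norm() ** 2 for g in grads))
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.mul_(clip_coef)
    return total_norm


torch.manual_seed(0)
params = [torch.nn.Parameter(torch.randn(16, 8)), torch.nn.Parameter(torch.randn(8, 16))]
raw = [torch.randn_like(p) * 5 for p in params]

# Built-in: clips p.grad in place and returns the pre-clip total norm
for p, g in zip(params, raw):
    p.grad = g.clone()
norm_builtin = torch.nn.utils.clip_grad_norm_(params, max_norm=1.0)
clipped_builtin = [p.grad.clone() for p in params]

manual = [g.clone() for g in raw]
norm_manual = manual_clip(manual, max_norm=1.0)
print(norm_builtin.item(), norm_manual.item())
```

In practice, the built-in can replace the hand-written loop, for example `torch.nn.utils.clip_grad_norm_(model.adapters.parameters(), max_norm=1.0)`.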

Production Deployment Considerations

Difficulty-Gated Activation

PoT should only activate for problems that benefit from extended reasoning. A lightweight classifier (a linear probe on the model’s first-layer hidden states) estimates problem difficulty:

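class DifficultyGate(nn.Module):
    """Predict whether PoT will benefit this query.

    Trained on (query, label) pairs where the label is 1 if
    accuracy_with_PoT - accuracy_without_PoT > 0, else 0.
    """
    def __init__(self, d_model):
        super().__init__()
        self.probe = nn.Linear(d_model, 1)

    def forward(self, hidden_states, attention_mask=None):
        # hidden_states: [batch, seq, d_model]. Decoder-only models have no
        # [CLS] token, and under causal attention position 0 sees only itself,
        # so mean-pool over the (unpadded) query tokens instead
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return torch.sigmoid(self.probe(pooled)).squeeze(-1)

    def should_activate_pot(self, hidden_states, threshold=0.6, attention_mask=None):
        # Single-query helper: hidden_states has batch size 1
        benefit_prob = self.forward(hidden_states, attention_mask).item()
        return benefit_prob > threshold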

Batching Challenges

PoT complicates batched inference. Different requests in a batch may be at different adapter update stages, and each request has its own transient adapter. Two approaches:

  1. Per-request adapters: Each request in the batch has its own LoRA adapter. The forward pass applies different adapters to different batch elements. This requires custom CUDA kernels for batched LoRA application.

  2. Synchronous updates: All requests in a batch update their adapters at the same step boundaries. Simpler to implement but forces uniform segment sizes.

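import torch
import torch.nn as nn

class BatchedPoTEngine:
    """Manage per-request adapters in a batched inference engine.

    Simplified sketch: each slot's adapter is applied as one low-rank
    correction to the final hidden states. Real LoRA adapters sit inside
    every adapted projection, which is what batched-LoRA kernels handle.
    """

    def __init__(self, model, batch_size, lora_rank=8):
        self.model = model
        self.batch_size = batch_size
        # One adapter per batch slot. create_output_adapter is a hypothetical
        # helper returning an nn.Module that maps [seq_len, d_model] hidden
        # states to a same-shaped low-rank correction.
        self.per_request_adapters = nn.ModuleList(
            create_output_adapter(model, lora_rank) for _ in range(batch_size)
        )

    def batched_forward(self, input_ids, active_mask):
        """Forward pass applying per-request adapters.

        input_ids: [batch_size, seq_len]
        active_mask: [batch_size] bool, which slots have active PoT
        """
        # Base model forward, shared by all slots: [batch_size, seq_len, d_model]
        hidden = self.model.base_forward(input_ids)

        # Per-request corrections; inactive slots get zeros. Built out of place
        # (rather than hidden[i] += ...) so autograd routes gradients to each
        # slot's own adapter.
        corrections = torch.stack([
            self.per_request_adapters[i](hidden[i]) if active_mask[i]
            else torch.zeros_like(hidden[i])
            for i in range(self.batch_size)
        ])

        return self.model.lm_head(hidden + corrections)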
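Inside a single adapted linear layer, the per-request approach amounts to one shared base matmul plus a per-row low-rank correction selected by an adapter index. The plain-Python sketch below spells out those reference semantics; a batched-LoRA kernel would compute the same result in one launch instead of a Python loop. The (d_in x r) A and (r x d_out) B orientation mirrors the `A @ B` product used in this section's adapter code, and using index -1 for "PoT inactive" is an assumption of the sketch.

```python
def matvec(x, M):
    # Row vector x (length d_in) times matrix M (d_in x d_out)
    return [sum(x[i] * M[i][j] for i in range(len(x))) for j in range(len(M[0]))]

def batched_lora_linear(X, W, adapters, adapter_idx, scaling=1.0):
    """y_b = x_b W + scaling * (x_b A_k) B_k, with k = adapter_idx[b].

    X: batch of input rows; W: shared base weight (d_in x d_out)
    adapters: list of (A, B) pairs, one per request slot
    adapter_idx: per-row slot index, or -1 for "no adapter"
    """
    out = []
    for x, k in zip(X, adapter_idx):
        y = matvec(x, W)  # shared base projection
        if k >= 0:
            A, B = adapters[k]
            delta = matvec(matvec(x, A), B)  # rank-r correction for this row only
            y = [yi + scaling * di for yi, di in zip(y, delta)]
        out.append(y)
    return out

# d_in = d_out = 2, rank 1, two request slots
W = [[1.0, 0.0], [0.0, 1.0]]
adapters = [
    ([[1.0], [0.0]], [[0.0, 2.0]]),  # slot 0: adds 2 * x[0] to output 1
    ([[0.0], [1.0]], [[3.0, 0.0]]),  # slot 1: adds 3 * x[1] to output 0
]
X = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
Y = batched_lora_linear(X, W, adapters, adapter_idx=[0, 1, -1])
```

Identical inputs produce three different outputs, one per slot's adapter, while the expensive `x W` product is shared; that split is what makes per-request adapters cheap relative to running separate model copies.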
💡 Adapter Memory per Batch Slot

A rank-8 adapter for a 7B model uses approximately 4 MB. For a batch of 32 requests with PoT active, that is 128 MB of adapter memory — negligible compared to the KV cache. The overhead is compute (per-request backward passes), not memory.
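The callout does not state which projections are adapted or at what precision, so the arithmetic below reproduces the 4 MB figure under one assumed configuration: 32 layers, hidden size 4096, one adapted 4096 x 4096 projection per layer, fp16 weights. It also shows how the figure moves if more projections are adapted, or if gradients and both Adam moment buffers are stored alongside the adapter in fp32.

```python
def lora_adapter_bytes(n_layers, projections_per_layer, d_in, d_out, rank, bytes_per_param):
    # Each adapted projection stores A (d_in x rank) and B (rank x d_out)
    params = n_layers * projections_per_layer * rank * (d_in + d_out)
    return params * bytes_per_param

MiB = 1024 * 1024

# Assumed 7B-class shape: 32 layers, one 4096 x 4096 projection per layer, fp16
per_slot = lora_adapter_bytes(32, 1, 4096, 4096, rank=8, bytes_per_param=2)
batch_total = 32 * per_slot

# Adapting two projections per layer (e.g. query and value) doubles it
two_proj = lora_adapter_bytes(32, 2, 4096, 4096, rank=8, bytes_per_param=2)

# Adapter weights + gradients + both Adam moments, all fp32: 4 copies x 4 bytes
with_adam = lora_adapter_bytes(32, 1, 4096, 4096, rank=8, bytes_per_param=16)
```

Under these assumptions one slot is exactly 4 MiB and 32 slots are 128 MiB; keeping full fp32 optimizer state raises a slot to 32 MiB (1 GiB for 32 slots), which is larger but still modest next to the KV cache of 32 long-context requests.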

Throughput Impact

📊

Throughput Impact of PoT (Llama 8B on H100)

| Mode | Tokens/sec | Latency (2K tokens) | Accuracy (MATH-500) |
|---|---|---|---|
| Standard generation | 4200 | 0.48 s | 34.2% |
| PoT (16 updates, seg=128) | 920 | 2.2 s | 67.8% |
| PoT (8 updates, seg=256) | 1350 | 1.5 s | 65.1% |
| Best-of-64 | 4200 (per sample) | 30.5 s (total) | 59.3% |

Note: PoT latency is per request. Best-of-64 latency is total (sequential generation). Parallelized best-of-64 is faster per request but uses 64x the GPU resources.

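PoT cuts single-request throughput by about 4.5x (920 vs 4200 tok/s) but is far more compute-efficient than best-of-64: it reaches 67.8% in 2.2s per request, while sequential best-of-64 needs about 30.5s (roughly 14x longer on the same GPU) to reach only 59.3%. The latency cost over standard generation (2.2s vs 0.48s) is acceptable for high-value tasks where users already expect longer processing times.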
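The throughput and cost comparisons follow directly from the table's numbers; the snippet below simply redoes that arithmetic. Treating latency as GPU time assumes, as the table note states, that best-of-64 runs sequentially on one GPU.

```python
# Values copied from the throughput table (Llama 8B, H100, 2K-token generations)
modes = {
    "standard":   {"tok_s": 4200, "latency_s": 0.48, "acc": 0.342},
    "pot_seg128": {"tok_s": 920,  "latency_s": 2.2,  "acc": 0.678},
    "pot_seg256": {"tok_s": 1350, "latency_s": 1.5,  "acc": 0.651},
    "best_of_64": {"tok_s": 4200, "latency_s": 30.5, "acc": 0.593},
}

def slowdown(mode):
    # Throughput reduction relative to standard generation
    return modes["standard"]["tok_s"] / modes[mode]["tok_s"]

def time_ratio(a, b):
    # How many times longer mode a takes than mode b per request
    return modes[a]["latency_s"] / modes[b]["latency_s"]

pot128_slowdown = slowdown("pot_seg128")                # about 4.6x
pot256_slowdown = slowdown("pot_seg256")                # about 3.1x
bo64_vs_pot = time_ratio("best_of_64", "pot_seg128")    # about 13.9x
```

The two PoT settings bracket the cost range: fewer, longer segments (seg=256) drop the slowdown to about 3.1x for 2.7 points less accuracy.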

Failure Modes

Reward Hacking

If the confidence reward dominates, the adapter may learn to produce confident but wrong text: reasoning the model assigns high log-probability to even though it is incorrect, such as a circular restatement of an earlier step. The self-verification reward component mitigates this, but imperfect verification can still be exploited.

# Example of reward hacking:
# Step 1: Model generates "Let x = 5" (moderate confidence)
# Step 2: Adapter update increases confidence bias
# Step 3: Model generates "Therefore x = 5" (high confidence, circular)
# Step 4: High reward from confidence, neutral from verification
# Step 5: Adapter reinforces circular reasoning

# Mitigation: monotonically increase verification weight over updates
def adaptive_reward_weights(update_step, total_updates):
    progress = update_step / total_updates
    w_conf = 0.3 * (1 - progress)      # Decrease confidence weight
    w_cohere = 0.3                       # Keep coherence constant
    w_verify = 0.4 + 0.3 * progress     # Increase verification weight
    return w_conf, w_cohere, w_verify
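The schedule keeps the three weights summing to 1 at every step, since 0.3(1 - p) + 0.3 + (0.4 + 0.3p) = 1. The sketch below assumes the per-segment reward is a weighted sum of the three signals, each scaled to [0, 1] (the exact combination and scales are assumptions here), and scores the circular step from the example above at the first and last update.

```python
def adaptive_reward_weights(update_step, total_updates):
    # Same schedule as above: confidence weight decays, verification weight grows
    progress = update_step / total_updates
    return 0.3 * (1 - progress), 0.3, 0.4 + 0.3 * progress

def combined_reward(conf, cohere, verify, update_step, total_updates):
    # Assumed combination: weighted sum of signals, each in [0, 1]
    w_conf, w_cohere, w_verify = adaptive_reward_weights(update_step, total_updates)
    return w_conf * conf + w_cohere * cohere + w_verify * verify

# The circular step: high confidence, neutral coherence and verification
circular = dict(conf=1.0, cohere=0.5, verify=0.5)
early = combined_reward(**circular, update_step=0, total_updates=16)
late = combined_reward(**circular, update_step=16, total_updates=16)
```

The same circular step earns 0.65 at the first update but only 0.5 at the last, so the adapter gains progressively less from confidence that verification does not back up.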

Catastrophic Adapter Divergence

With aggressive learning rates, the adapter can diverge — producing wildly different outputs from the base model. Gradient clipping and a maximum adapter norm constraint prevent this:

import torch

@torch.no_grad()
def constrain_adapter_norm(adapters, max_norm=2.0):
    """Keep each adapter's weight update ||scaling * A @ B||_F <= max_norm.

    A and B each shrink by sqrt(factor), so their product shrinks by factor.
    """
    for adapter in adapters.values():
        delta_norm = (adapter.A @ adapter.B).norm() * adapter.scaling
        if delta_norm > max_norm:
            scale = (max_norm / delta_norm).sqrt()
            adapter.A.mul_(scale)
            adapter.B.mul_(scale)
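The square root matters because the constraint is on the product A @ B, not on either factor. A small pure-Python check (the matrices and `max_norm` below are arbitrary illustrative values) confirms that shrinking both factors by sqrt(max_norm / norm) lands the product's Frobenius norm exactly on `max_norm`.

```python
import math

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def frobenius(M):
    return math.sqrt(sum(v * v for row in M for v in row))

def constrain(A, B, scaling, max_norm):
    """Rescale A and B so that ||scaling * A @ B||_F <= max_norm."""
    delta_norm = frobenius(matmul(A, B)) * scaling
    if delta_norm <= max_norm:
        return A, B
    s = math.sqrt(max_norm / delta_norm)  # split the shrink factor evenly
    return [[a * s for a in row] for row in A], [[b * s for b in row] for row in B]

A = [[2.0], [1.0]]   # d_in x r, with r = 1
B = [[3.0, 4.0]]     # r x d_out
A2, B2 = constrain(A, B, scaling=1.0, max_norm=2.0)
```

Applying the full factor to B alone would satisfy the constraint too; splitting it evenly keeps A and B at comparable magnitudes, which avoids one factor drifting toward zero over repeated rescalings.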

Summary

Policy of Thoughts introduces online learning into inference. Instead of sampling more from a frozen distribution, PoT shifts the distribution itself by updating a transient LoRA adapter based on self-generated reward signals. The mechanism is a forward-reward-update loop: generate a reasoning segment, score it for confidence/coherence/correctness, backpropagate through the adapter, and continue generating with the improved policy.

The cost is 3-5x standard generation — substantial but far more efficient than best-of-N for reasoning tasks. A 3.2x compute investment yields accuracy improvements that best-of-64 (64x compute) cannot match on mathematical reasoning. The key limitation is task-dependence: PoT helps on problems requiring strategy adaptation (math, code, logic) and hurts on factual recall where no amount of policy refinement can create missing knowledge.

The architecture — zero-initialized transient LoRA, multi-signal reward, Adam with aggressive hyperparameters, difficulty-gated activation — reflects a broader principle in inference-time compute scaling: the most efficient use of extra compute is not generating more samples from the same distribution, but improving the distribution itself.