DeepMind’s DreamerV3 trained a world model on Atari frames and learned to act by imagining 15 steps ahead inside it. The result: superhuman performance on 26 Atari games with roughly 100x fewer environment interactions than model-free RL. The key: instead of trying 10 million random actions in the real environment, the agent simulates them in its learned world model and executes only the best trajectory. The cost of a prediction error: a world model that is 1% wrong per step accumulates about 14% cumulative error over 15 steps (1 - 0.99^15 ≈ 0.14) — but that is still far cheaper than 10 million real environment queries.
World Model Fundamentals
State Transition Prediction
from dataclasses import dataclass, field
from enum import Enum
import numpy as np
from typing import Optional
class StateSpace(Enum):
PIXEL = "pixel"
LATENT = "latent"
SYMBOLIC = "symbolic"
TEXT = "text"
HYBRID = "hybrid"
@dataclass
class WorldModelConfig:
"""Configuration for a world model."""
state_dim: int
action_dim: int
latent_dim: int
state_space: StateSpace
sequence_length: int
prediction_horizon: int
stochastic: bool = True
deterministic_size: int = 512
stochastic_size: int = 32
discrete_latent: bool = True
n_categories: int = 32
category_size: int = 32
class LatentDynamicsModel:
"""
Latent dynamics world model (RSSM architecture).
The model operates in three stages:
1. Encoder: observation -> latent state
(compress high-dimensional observation to compact latent)
2. Transition: (latent_t, action_t) -> latent_{t+1}
(predict next latent state from current state + action)
3. Decoder: latent -> observation, reward
(reconstruct observation and predict reward)
The latent state has two components:
- Deterministic: captures slow-changing features
(recurrent state, h_t)
- Stochastic: captures fast-changing, uncertain features
(sampled from a learned distribution, z_t)
Full state: s_t = concat(h_t, z_t)
"""
def __init__(self, config):
self.config = config
self.encoder = None # observation -> z
self.dynamics = None # (h_t, z_t, a_t) -> h_{t+1}
self.posterior = None # (h_{t+1}, obs) -> z_{t+1}
self.prior = None # h_{t+1} -> z_{t+1}
self.decoder = None # (h_t, z_t) -> obs_hat
self.reward_head = None # (h_t, z_t) -> reward
def encode(self, observation):
"""
Encode observation to latent representation.
For images: CNN encoder
For text: transformer encoder
For symbolic states: MLP encoder
"""
if self.encoder is None:
return np.zeros(self.config.latent_dim)
return self.encoder(observation)
def predict_next_state(self, state, action):
"""
Predict next latent state given current state and action.
This is the core world model prediction:
s_{t+1} = f(s_t, a_t)
With the RSSM architecture:
h_{t+1} = GRU(h_t, concat(z_t, a_t))
z_{t+1} ~ p(z | h_{t+1}) [prior, no observation]
"""
h_t = state[:self.config.deterministic_size]
z_t = state[self.config.deterministic_size:]
# Deterministic transition
combined = np.concatenate([z_t, action])
h_next = self._gru_step(h_t, combined)
# Stochastic prediction (prior)
z_prior_params = self._prior_network(h_next)
z_next = self._sample_stochastic(z_prior_params)
next_state = np.concatenate([h_next, z_next])
return next_state
def predict_trajectory(self, initial_state, actions):
"""
Predict a full trajectory of states from
initial state and action sequence.
Used for planning: simulate many possible
action sequences and pick the best one.
"""
states = [initial_state]
current_state = initial_state
for action in actions:
next_state = self.predict_next_state(
current_state, action
)
states.append(next_state)
current_state = next_state
# Decode rewards for the trajectory
rewards = [
self._predict_reward(s) for s in states
]
return {
"states": states,
"rewards": rewards,
"total_reward": sum(rewards),
}
    def _gru_step(self, h, x):
        """Single GRU step (toy placeholder for a learned GRU cell)."""
        # Resize the input to match h so the blend is shape-safe even
        # when concat(z_t, a_t) is shorter than the deterministic state.
        x = np.resize(x, h.shape)
        return np.tanh(0.9 * h + 0.1 * x)
def _prior_network(self, h):
"""Prior distribution parameters from h."""
return {"mean": h[:self.config.stochastic_size],
"std": np.ones(self.config.stochastic_size)}
def _sample_stochastic(self, params):
"""Sample from stochastic distribution."""
return params["mean"] + params["std"] * np.random.randn(
*params["mean"].shape
)
def _predict_reward(self, state):
"""Predict reward from latent state."""
return 0.0
The RSSM (Recurrent State-Space Model) architecture separates the latent state into deterministic and stochastic components. The deterministic component (GRU hidden state) captures slow dynamics that are predictable. The stochastic component captures fast, uncertain dynamics. This separation is critical: a purely deterministic model cannot represent ambiguity (will the ball go left or right?), while a purely stochastic model has high variance in long rollouts.
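The config above reserves `n_categories` and `category_size` for DreamerV3-style discrete latents: the stochastic state z_t is a stack of one-hot categorical samples rather than a Gaussian. A minimal numpy sketch of that sampling step (illustration only; the real model trains through the sample with straight-through gradients):

```python
import numpy as np

def sample_discrete_latent(logits, rng):
    """Sample one one-hot vector per categorical and flatten
    them into a single stochastic state z_t."""
    n_categories, category_size = logits.shape
    # Softmax per categorical, numerically stabilized.
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    z = np.zeros_like(probs)
    for i in range(n_categories):
        idx = rng.choice(category_size, p=probs[i])
        z[i, idx] = 1.0
    return z.reshape(-1)

rng = np.random.default_rng(0)
logits = rng.normal(size=(32, 32))   # n_categories x category_size
z = sample_discrete_latent(logits, rng)
# z contains exactly 32 ones (one per categorical) and has length 32 * 32
```

With the default config (32 categoricals of 32 classes), z_t is a 1024-dimensional sparse binary vector, which is what gets concatenated with h_t to form the full state.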
JEPA: Joint Embedding Predictive Architecture
Predicting in Latent Space
class JEPA:
"""
Joint Embedding Predictive Architecture (Yann LeCun, 2022).
Key insight: predict in latent space, not pixel space.
Pixel-space prediction is hard because it requires
predicting every pixel, including irrelevant details
(exact texture, lighting variations). Predicting in
a learned latent space focuses the model on
semantically meaningful state changes.
Architecture:
1. Context encoder: encodes observed context
2. Target encoder: encodes future target (EMA of context encoder)
3. Predictor: predicts target embedding from context embedding
Loss: distance between predicted target embedding
and actual target embedding (no decoder needed).
Collapse prevention: VICReg, Barlow Twins, or
asymmetric architecture (predictor only on one side).
"""
def __init__(self, config):
self.embed_dim = config.get("embed_dim", 768)
self.pred_depth = config.get("pred_depth", 6)
self.ema_decay = config.get("ema_decay", 0.996)
def forward(self, context_frames, target_frames):
"""
Forward pass for JEPA training.
context_frames: observed frames [B, T_ctx, C, H, W]
target_frames: future frames [B, T_tgt, C, H, W]
"""
# Encode context
context_embeddings = self._context_encoder(
context_frames
)
# Encode targets (with EMA encoder, no gradient)
target_embeddings = self._target_encoder(
target_frames
)
# Predict target embeddings from context
predicted_embeddings = self._predictor(
context_embeddings
)
# Loss: L2 distance in embedding space
loss = np.mean(
(predicted_embeddings - target_embeddings) ** 2
)
return {
"loss": loss,
"context_embeddings": context_embeddings,
"predicted_embeddings": predicted_embeddings,
"target_embeddings": target_embeddings,
}
def _context_encoder(self, frames):
"""Encode context frames to embeddings."""
return np.zeros((frames.shape[0], self.embed_dim))
def _target_encoder(self, frames):
"""Encode target frames (EMA, no gradient)."""
return np.zeros((frames.shape[0], self.embed_dim))
def _predictor(self, context_embeddings):
"""Predict target embeddings from context."""
return context_embeddings # Placeholder
def plan_with_jepa(self, current_state,
candidate_actions, goal_state):
"""
Planning with JEPA: select actions that
move toward the goal in embedding space.
1. Encode current state
2. Encode goal state
3. For each candidate action sequence:
a. Predict future embedding
b. Compute distance to goal embedding
4. Select action sequence with smallest distance
"""
current_embed = self._context_encoder(
current_state
)
goal_embed = self._target_encoder(goal_state)
best_actions = None
best_distance = float("inf")
for actions in candidate_actions:
# Simulate trajectory in latent space
embed = current_embed
for action in actions:
embed = self._predict_with_action(
embed, action
)
distance = np.linalg.norm(embed - goal_embed)
if distance < best_distance:
best_distance = distance
best_actions = actions
return best_actions, best_distance
def _predict_with_action(self, embed, action):
"""Predict next embedding given action."""
return embed # Placeholder
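The target encoder in JEPA is not trained by gradient descent; it tracks the context encoder through an exponential moving average (the `ema_decay` of 0.996 above). A self-contained sketch of that update rule:

```python
import numpy as np

def ema_update(target_params, context_params, decay=0.996):
    """theta_target <- decay * theta_target + (1 - decay) * theta_context.
    No gradients flow into the target encoder; this is its only update."""
    return {
        name: decay * target_params[name]
        + (1.0 - decay) * context_params[name]
        for name in target_params
    }

# Toy check: the target slowly drifts toward the context weights.
target = {"w": np.zeros(4)}
context = {"w": np.ones(4)}
for _ in range(1000):
    target = ema_update(target, context)
# After 1000 steps the target has closed 1 - 0.996**1000, about 98%,
# of the gap to the context encoder.
```

The slow decay is what stabilizes training: the prediction targets change gradually even while the context encoder updates every step, which is part of how embedding collapse is avoided.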
World Model Architectures: Prediction Quality
| Architecture | State Space | 1-Step Error | 10-Step Error | 50-Step Error | Training Data | Parameters |
|---|---|---|---|---|---|---|
| Pixel CNN (autoregressive) | Pixel | Low | Medium | Very High | Video frames | 100-500M |
| RSSM (DreamerV3) | Latent | Medium | Low | Medium | RL trajectories | 10-200M |
| JEPA (V-JEPA) | Latent | Low | Low | Low | Video | 300M-1B |
| Transformer (Genie) | Latent | Low | Medium | High | Video + actions | 1-11B |
| Diffusion world model | Pixel | Very Low | Medium | High | Video | 500M-2B |
DreamerV3: Model-Based RL with World Models
Learning to Act by Dreaming
class DreamerV3:
"""
DreamerV3: mastering diverse domains without data-specific tuning.
Architecture:
1. World model: RSSM with discrete latent variables
2. Actor: policy that selects actions in latent space
3. Critic: value function that estimates returns in latent space
Training loop:
1. Collect experience in real environment
2. Train world model on collected experience
3. "Dream": generate imagined trajectories in world model
4. Train actor and critic on imagined trajectories
5. Repeat
Key innovation in V3: symlog predictions and
free bits for KL balancing. These enable training
across environments with very different reward
scales without hyperparameter tuning.
"""
def __init__(self, config):
self.world_model = LatentDynamicsModel(config)
self.imagination_horizon = config.get(
"imagination_horizon", 15
)
self.n_imagination_trajectories = config.get(
"n_trajectories", 16
)
self.discount = config.get("discount", 0.997)
self.lambda_gae = config.get("lambda_gae", 0.95)
def imagine_trajectories(self, initial_states):
"""
Generate imagined trajectories in the world model.
Start from real encoded states, then roll out
using the learned dynamics and actor policy.
No real environment interaction needed.
"""
trajectories = []
for state in initial_states:
trajectory = {
"states": [state],
"actions": [],
"rewards": [],
"values": [],
}
current = state
for t in range(self.imagination_horizon):
# Actor selects action in latent space
action = self._actor(current)
# World model predicts next state
next_state = (
self.world_model.predict_next_state(
current, action
)
)
# Predict reward
reward = (
self.world_model._predict_reward(
next_state
)
)
# Critic estimates value
value = self._critic(next_state)
trajectory["states"].append(next_state)
trajectory["actions"].append(action)
trajectory["rewards"].append(reward)
trajectory["values"].append(value)
current = next_state
trajectories.append(trajectory)
return trajectories
def compute_lambda_returns(self, rewards, values):
"""
Compute lambda-returns for actor-critic training.
G_t^lambda = r_t + gamma * (
(1-lambda) * V(s_{t+1})
+ lambda * G_{t+1}^lambda
)
This blends TD(0) (lambda=0, low variance,
high bias) with Monte Carlo (lambda=1, high
variance, low bias).
"""
T = len(rewards)
returns = np.zeros(T)
# Bootstrap from final value
next_return = values[-1] if values else 0.0
for t in reversed(range(T)):
td_target = (
rewards[t] + self.discount * values[t]
)
returns[t] = (
(1 - self.lambda_gae) * td_target
+ self.lambda_gae * (
rewards[t]
+ self.discount * next_return
)
)
next_return = returns[t]
return returns
def symlog(self, x):
"""
Symlog transformation (DreamerV3).
symlog(x) = sign(x) * log(|x| + 1)
Compresses large values while preserving sign.
Enables training across environments with
reward scales differing by 6 orders of magnitude
without reward normalization.
"""
return np.sign(x) * np.log(np.abs(x) + 1)
def symexp(self, x):
"""Inverse of symlog."""
return np.sign(x) * (np.exp(np.abs(x)) - 1)
def _actor(self, state):
"""Policy network: state -> action."""
return np.zeros(self.world_model.config.action_dim)
def _critic(self, state):
"""Value network: state -> value estimate."""
return 0.0
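The symlog/symexp pair can be exercised standalone to see the compression it buys (a minimal sketch restating the two transforms defined above):

```python
import numpy as np

def symlog(x):
    return np.sign(x) * np.log(np.abs(x) + 1.0)

def symexp(x):
    return np.sign(x) * (np.exp(np.abs(x)) - 1.0)

rewards = np.array([-1000.0, -1.0, 0.0, 1.0, 1000.0])
compressed = symlog(rewards)
# Rewards spanning [-1000, 1000] land in roughly [-6.9, 6.9],
# so a single network output scale covers sparse and dense rewards.
recovered = symexp(compressed)  # exact inverse, up to float precision
```

Because the critic and reward head predict in symlog space and decode with symexp, no per-environment reward normalization is needed.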
Video Prediction as World Modeling
Learning Physics from Internet Video
class VideoWorldModel:
"""
World model trained on internet video.
Internet video is the largest source of world dynamics:
objects fall, liquids flow, people walk, cars drive.
A model that can predict future video frames given
past frames has learned implicit physics.
Challenges:
- No action labels in internet video (unlike RL)
- Resolution and diversity requirements are extreme
- Computational cost: video transformers are expensive
- Evaluation: FVD, LPIPS, but human judgment is needed
"""
def __init__(self, config):
self.resolution = config.get("resolution", 256)
self.frame_rate = config.get("frame_rate", 4)
self.context_frames = config.get("context_frames", 16)
self.prediction_frames = config.get(
"prediction_frames", 16
)
self.latent_dim = config.get("latent_dim", 1024)
def train_on_video(self, video_dataset):
"""
Training loop for video world model.
Loss components:
1. Reconstruction loss: predicted frames vs actual
2. Perceptual loss: LPIPS distance in feature space
3. KL divergence: regularize latent distribution
4. Temporal consistency: smooth latent trajectories
"""
metrics = {
"reconstruction_loss": [],
"perceptual_loss": [],
"kl_loss": [],
}
for batch in video_dataset:
context = batch[:, :self.context_frames]
target = batch[:, self.context_frames:]
# Encode context to latent
context_latent = self._encode_video(context)
# Predict future latents
predicted_latents = self._predict_future(
context_latent,
n_steps=self.prediction_frames,
)
# Decode to predicted frames
predicted_frames = self._decode_video(
predicted_latents
)
# Compute losses
recon_loss = np.mean(
(predicted_frames - target) ** 2
)
metrics["reconstruction_loss"].append(recon_loss)
return metrics
def predict_with_action(self, context_frames, action):
"""
Predict future given context and an action.
For video models trained without action labels,
the 'action' is specified as a text description
or a directional vector. The model must learn
to condition its predictions on the action.
"""
context_latent = self._encode_video(context_frames)
# Condition on action (text -> embedding)
action_embedding = self._encode_action(action)
# Predict future latent conditioned on action
future_latent = self._predict_conditional(
context_latent, action_embedding
)
# Decode to frames
future_frames = self._decode_video(future_latent)
return future_frames
def _encode_video(self, frames):
"""Encode video frames to latent."""
return np.zeros(
(frames.shape[0], self.latent_dim)
)
def _predict_future(self, latent, n_steps):
"""Predict future latents autoregressively."""
return [latent] * n_steps
def _decode_video(self, latents):
"""Decode latents to video frames."""
return np.zeros((1, len(latents), 3,
self.resolution, self.resolution))
def _encode_action(self, action):
"""Encode action to embedding."""
return np.zeros(self.latent_dim)
def _predict_conditional(self, latent, action_embed):
"""Predict future conditioned on action."""
return latent
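Of the four loss components listed in `train_on_video`, only the reconstruction term is computed in the sketch above. The temporal-consistency term can be sketched just as simply (an assumed form, a penalty on jumps between consecutive latents, not a specific published loss):

```python
import numpy as np

def temporal_consistency_loss(latents):
    """Mean squared difference between consecutive latent frames.
    Smooth latent trajectories incur a small penalty; jumpy ones a large one."""
    latents = np.asarray(latents, dtype=float)
    diffs = latents[1:] - latents[:-1]
    return float(np.mean(diffs ** 2))

smooth = np.linspace(0.0, 1.0, 8).reshape(-1, 1) * np.ones((8, 4))
jumpy = np.zeros((8, 4))
jumpy[::2] = 1.0  # latent alternates 1, 0, 1, 0, ...
loss_smooth = temporal_consistency_loss(smooth)  # (1/7)^2, about 0.02
loss_jumpy = temporal_consistency_loss(jumpy)    # 1.0
```

Penalizing latent jumps pushes the encoder toward representations in which dynamics are smooth, which is exactly what makes multi-step latent rollouts tractable.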
Video Prediction Quality: FVD Score by Prediction Horizon
| Architecture | 4 frames | 8 frames | 16 frames | 32 frames | 64 frames |
|---|---|---|---|---|---|
| Diffusion-based (SVD) | | | | | |
| Autoregressive (VideoGPT) | | | | | |
| Latent dynamics (RSSM) | | | | | |
| JEPA (V-JEPA) | | | | | |
FVD (Frechet Video Distance) measures the quality of generated video distributions, not individual frame accuracy. A low FVD means the generated videos are statistically similar to real videos. However, FVD does not capture physical plausibility: a model can achieve good FVD by generating realistic-looking but physically impossible scenes. Human evaluation remains essential for assessing whether the world model has learned actual physics.
World Models for LLM Agents
Integrating World Models with Language
class LLMWorldModelAgent:
"""
Integrate a world model with an LLM agent
for planning.
The LLM generates candidate action plans in
natural language. The world model simulates
each plan and predicts outcomes. The LLM selects
the plan with the best predicted outcome.
This separation allows:
- LLM: creative plan generation (broad search)
- World model: accurate outcome prediction (evaluation)
"""
def __init__(self, llm, world_model, n_candidates=10):
self.llm = llm
self.world_model = world_model
self.n_candidates = n_candidates
def plan(self, task_description, current_state):
"""
Generate and evaluate candidate plans.
"""
# Step 1: LLM generates candidate plans
plans = self._generate_candidate_plans(
task_description, current_state
)
# Step 2: World model evaluates each plan
evaluations = []
for plan in plans:
actions = self._parse_plan_to_actions(plan)
result = self.world_model.predict_trajectory(
current_state, actions
)
evaluations.append({
"plan": plan,
"predicted_reward": result["total_reward"],
"predicted_states": result["states"],
"n_steps": len(actions),
})
# Step 3: Select best plan
evaluations.sort(
key=lambda x: x["predicted_reward"],
reverse=True,
)
return evaluations[0]
def _generate_candidate_plans(self, task, state):
"""Use LLM to generate candidate action plans."""
prompt = (
f"Task: {task}\n"
f"Current state: {state}\n\n"
f"Generate {self.n_candidates} different "
f"approaches to complete this task. "
f"Each approach should be a concrete sequence "
f"of actions."
)
plans = self.llm.generate(
prompt,
n=self.n_candidates,
temperature=0.9,
)
return plans
def _parse_plan_to_actions(self, plan_text):
"""Parse a text plan into executable actions."""
return [] # Placeholder
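Steps 2 and 3 of `plan` reduce to "score every candidate, keep the argmax". A self-contained toy of that selection loop (the plans and predicted rewards here are invented purely for illustration):

```python
def select_best_plan(plans, evaluate):
    """Score each candidate plan and return (score, plan) for the best."""
    scored = [(evaluate(plan), plan) for plan in plans]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return scored[0]

# Hypothetical plans with hand-picked predicted rewards standing in
# for world-model rollouts.
plans = ["open drawer then grasp", "grasp directly", "push then grasp"]
predicted_reward = {
    "open drawer then grasp": 0.8,
    "grasp directly": 0.3,
    "push then grasp": 0.6,
}
best_score, best_plan = select_best_plan(plans, predicted_reward.get)
# best_plan == "open drawer then grasp"
```

In the full agent, `evaluate` is the world-model rollout (`predict_trajectory`) rather than a lookup table, but the control flow is identical.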
Compounding Error Analysis
The Fundamental Challenge
class CompoundingErrorAnalysis:
"""
Analyze compounding prediction errors in world models.
A world model with 1% per-step error accumulates
errors over multi-step rollouts:
After T steps, expected error ~ 1 - (1 - e)^T
For e = 0.01, T = 50: error ~ 39%
For e = 0.01, T = 100: error ~ 63%
For e = 0.001, T = 50: error ~ 4.9%
Mitigation strategies:
1. Re-plan frequently (shorten T)
2. Train on multi-step predictions (not just 1-step)
3. Use ensembles (average over multiple models)
4. Plan in latent space (lower-dimensional, smoother)
"""
def simulate_error_accumulation(
self, per_step_error, max_steps=100
):
"""
Simulate error accumulation over rollout steps.
"""
results = []
cumulative_error = 0.0
for t in range(1, max_steps + 1):
cumulative_error = (
1.0 - (1.0 - per_step_error) ** t
)
results.append({
"step": t,
"cumulative_error": cumulative_error,
"remaining_accuracy": 1.0 - cumulative_error,
})
return results
def find_safe_horizon(self, per_step_error,
max_acceptable_error=0.2):
"""
Find the maximum planning horizon for a given
per-step error rate and acceptable total error.
"""
import math
if per_step_error <= 0:
return float("inf")
horizon = math.log(1 - max_acceptable_error) / math.log(
1 - per_step_error
)
return int(horizon)
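Plugging numbers into the horizon formula (a standalone restatement of `find_safe_horizon` so the check is self-contained):

```python
import math

def find_safe_horizon(per_step_error, max_acceptable_error=0.2):
    """Largest T with 1 - (1 - e)**T <= max_acceptable_error."""
    if per_step_error <= 0:
        return float("inf")
    return int(math.log(1.0 - max_acceptable_error)
               / math.log(1.0 - per_step_error))

horizon = find_safe_horizon(0.01, max_acceptable_error=0.2)  # -> 22
# Consistent with the docstring numbers above:
error_at_50 = 1.0 - (1.0 - 0.01) ** 50  # ~0.395
```

So with a 1% per-step model and a 20% error budget, plans longer than about 22 steps should be re-planned rather than executed open-loop.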
Compounding Error: Cumulative Error by Rollout Length
| Per-step error | 1 | 5 | 10 | 20 | 50 | 100 |
|---|---|---|---|---|---|---|
| 1% per-step error | 1.0% | 4.9% | 9.6% | 18.2% | 39.5% | 63.4% |
| 0.5% per-step error | 0.5% | 2.5% | 4.9% | 9.5% | 22.2% | 39.4% |
| 0.1% per-step error | 0.1% | 0.5% | 1.0% | 2.0% | 4.9% | 9.5% |
| Ensemble (5 models, 1% each; effective ~0.4%) | 0.4% | 2.0% | 3.9% | 7.7% | 18.2% | 33.0% |

Values follow the 1 - (1 - e)^T model above; columns are rollout length in steps.
Key Takeaways
World models predict the consequences of actions, enabling planning without real-world execution. The field has matured from pixel-space prediction to latent-space dynamics, with JEPA and DreamerV3 representing the current state of the art.
The critical findings:
- Predict in latent space, not pixel space: Pixel-level prediction wastes capacity on irrelevant details (exact textures, lighting). JEPA and RSSM models predict in learned latent spaces, focusing on semantically meaningful state changes and achieving lower compounding error over long rollouts.
- DreamerV3’s symlog trick enables generalization: The symlog transformation compresses reward magnitudes, allowing a single set of hyperparameters to work across environments with reward scales differing by 6 orders of magnitude. This is the key innovation that made DreamerV3 the first model-based agent to match model-free RL across diverse domains.
- Compounding error limits planning horizon: A model with 1% per-step error accumulates roughly 40% cumulative error after 50 steps. Practical implications: re-plan every 5-15 steps rather than executing a single 50-step plan. Ensembles of 5 models reduce effective per-step error from 1% to 0.4%.
- Video prediction from internet data learns implicit physics: Models trained on internet video (Sora, Genie, SVD) learn object permanence, gravity, and collision dynamics without explicit physics supervision. However, they do not reliably learn fine-grained physical reasoning (exact trajectories, conservation laws).
- LLM + world model is a natural architecture for agents: LLMs generate creative candidate plans (broad search). World models evaluate plans by simulating outcomes (accurate scoring). This separation leverages the strengths of each: LLMs for plan diversity, world models for physical grounding.