An LLM can write a convincing essay about gravity, but it cannot predict where a ball will land when thrown. It can describe the physics of a pendulum with textbook accuracy, but it cannot simulate the pendulum’s trajectory. This gap between linguistic knowledge and physical understanding is the central challenge of embodied AI: building systems that do not merely describe the world but model it — predicting future states from current observations and planned actions.
World models are the bridge. A world model takes the current state of an environment (pixels, sensor readings, text descriptions) and an action, then predicts the next state. If the model is accurate, it enables planning without physical trial-and-error: the agent can simulate thousands of action sequences internally and choose the best one. This is what humans do when we mentally rehearse catching a ball or navigating a room.
This post covers the architecture of world models, the connection between video generation (Sora) and world simulation, joint embedding predictive architectures (V-JEPA), how these connect to LLM-based reasoning, robotics foundation models, and an implementation of a simple world model.
What Is a World Model
Formal Definition
A world model is a learned function:

ŝ_{t+1} = f_θ(s_t, a_t)

where s_t is the state at time t, a_t is the action taken, and ŝ_{t+1} is the predicted next state. The model parameters θ are learned from observation data.
The state can be represented at different levels of abstraction:
- Pixel-level: s_t is a raw image or video frame. The model predicts the next frame.
- Latent-level: s_t is a compressed representation (embedding) of the visual scene. The model predicts the next embedding.
- Symbolic-level: s_t is a structured representation (object positions, velocities, relations). The model predicts the next symbolic state.
Each level trades off between generality and efficiency. Pixel-level models are general (they work for any visual scene) but computationally expensive. Symbolic models are efficient but require defining the state representation in advance, which fails for novel environments.
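The symbolic level can be made concrete with a hand-written dynamics function: a point mass under gravity, stepped with Euler integration. This is a ground-truth world model rather than a learned one, and the state layout, thrust action, and step size are illustrative choices:

```python
GRAVITY = 9.81  # m/s^2
DT = 0.05       # integration step in seconds

def step(state, action):
    """Symbolic world model: state = (x, y, vx, vy), action = (ax, ay) thrust.
    Returns the next state under Euler integration with gravity."""
    x, y, vx, vy = state
    ax, ay = action
    vx_next = vx + ax * DT
    vy_next = vy + (ay - GRAVITY) * DT
    return (x + vx * DT, y + vy * DT, vx_next, vy_next)

# Roll out a thrown ball (no thrust) until it would hit the ground
state = (0.0, 0.0, 3.0, 4.0)  # launched at roughly 53 degrees
while True:
    nxt = step(state, (0.0, 0.0))
    if nxt[1] < 0.0:
        break
    state = nxt
print(f"ball lands near x = {state[0]:.2f} m")
```

A learned world model replaces `step` with a neural network fit to data; the interface (state, action) -> next state stays the same.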
The Latent World Model
Modern world models operate in latent space — a compressed representation learned by an encoder. The architecture has three components:
- Encoder E: maps observations to latent states: z_t = E(o_t)
- Dynamics model f: predicts next latent state: ẑ_{t+1} = f(z_t, a_t)
- Decoder D: reconstructs observations from latent states: ô_t = D(z_t)
import torch
import torch.nn as nn
class LatentWorldModel(nn.Module):
"""
Latent-space world model with encoder, dynamics, and decoder.
Operates on image observations and discrete actions.
"""
def __init__(
self,
obs_channels=3,
obs_size=64,
latent_dim=256,
action_dim=4,
hidden_dim=512,
):
super().__init__()
# Encoder: image -> latent
self.encoder = nn.Sequential(
nn.Conv2d(obs_channels, 32, 4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(32, 64, 4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(128, 256, 4, stride=2, padding=1),
nn.ReLU(),
nn.Flatten(),
nn.Linear(256 * (obs_size // 16) ** 2, latent_dim),
)
# Dynamics model: (latent, action) -> next latent
self.dynamics = nn.Sequential(
nn.Linear(latent_dim + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim),
)
# Decoder: latent -> image (transposed convolutions)
self.decoder_fc = nn.Linear(
latent_dim, 256 * (obs_size // 16) ** 2
)
self.decoder_conv = nn.Sequential(
nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
nn.ReLU(),
nn.ConvTranspose2d(32, obs_channels, 4, stride=2, padding=1),
nn.Sigmoid(),
)
self.obs_size = obs_size
def encode(self, obs):
"""Encode observation to latent state."""
return self.encoder(obs)
def predict_next_latent(self, z, action):
"""Predict next latent state from current latent + action."""
combined = torch.cat([z, action], dim=-1)
return self.dynamics(combined)
def decode(self, z):
"""Decode latent state to observation."""
h = self.decoder_fc(z)
spatial = self.obs_size // 16
h = h.view(-1, 256, spatial, spatial)
return self.decoder_conv(h)
def forward(self, obs, action):
"""
Full forward pass: encode current obs, predict next latent,
decode to predicted next obs.
"""
z_t = self.encode(obs)
z_t1 = self.predict_next_latent(z_t, action)
obs_pred = self.decode(z_t1)
return obs_pred, z_t, z_t1
def imagine(self, initial_obs, action_sequence):
"""
Imagine a sequence of future states from initial observation
and a planned action sequence.
Returns list of predicted observations.
"""
z = self.encode(initial_obs)
predictions = []
for action in action_sequence:
z = self.predict_next_latent(z, action)
obs_pred = self.decode(z)
predictions.append(obs_pred)
return predictions
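The open-loop rollout pattern in `imagine` can be isolated with a hand-coded linear dynamics standing in for the trained network; the matrices A and B below are arbitrary stand-ins:

```python
import torch

torch.manual_seed(0)
latent_dim, action_dim = 8, 2
A = torch.eye(latent_dim) * 0.9           # state transition (stand-in)
B = torch.randn(latent_dim, action_dim)   # action effect (stand-in)

def dynamics(z, a):
    """Linear stand-in for the learned dynamics model."""
    return z @ A.T + a @ B.T

# Open-loop rollout: feed predicted latents back in, as imagine() does
z = torch.randn(1, latent_dim)
actions = [torch.tensor([[1.0, 0.0]])] * 5
trajectory = []
for a in actions:
    z = dynamics(z, a)
    trajectory.append(z)
print(len(trajectory), trajectory[-1].shape)  # 5 latent predictions
```

Note that each predicted latent becomes the input for the next step, which is exactly why per-step errors compound over long horizons.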
Training the World Model
def train_world_model(model, dataset, epochs=100, lr=1e-4):
"""
Train the world model on observation-action-next_observation triples.
dataset: yields (obs_t, action_t, obs_t1) batches
"""
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
reconstruction_loss = nn.MSELoss()
latent_consistency_loss = nn.MSELoss()
for epoch in range(epochs):
total_loss = 0.0
num_batches = 0
for obs_t, action_t, obs_t1 in dataset:
# Forward pass
obs_pred, z_t, z_t1_pred = model(obs_t, action_t)
# Loss 1: Reconstruction -- predicted obs should match actual
loss_recon = reconstruction_loss(obs_pred, obs_t1)
# Loss 2: Latent consistency -- predicted latent should
# match encoded actual next obs
with torch.no_grad():
z_t1_actual = model.encode(obs_t1)
loss_latent = latent_consistency_loss(
z_t1_pred, z_t1_actual
)
# Combined loss
loss = loss_recon + 0.5 * loss_latent
optimizer.zero_grad()
loss.backward()
optimizer.step()
total_loss += loss.item()
num_batches += 1
avg_loss = total_loss / max(num_batches, 1)
if epoch % 10 == 0:
print(f"Epoch {epoch}: loss={avg_loss:.4f}")
Sora as a World Simulator
Video Generation Is World Modeling
Sora generates videos by predicting future visual frames conditioned on a text prompt and (optionally) initial frames. This is fundamentally a world model — it predicts future states of a visual scene. The key architectural insight: Sora uses a diffusion transformer operating on spacetime patches.
The architecture (reconstructed from the technical report):
- Visual encoder: Compress video frames into a latent spacetime grid using a VAE. A 1080p, 60fps, 10-second video becomes a 3D grid of latent tokens.
- Diffusion transformer: Denoise the latent grid conditioned on text embeddings. Each transformer layer attends to all spacetime positions, enabling the model to reason about temporal consistency.
- Visual decoder: Decompress the latent grid back to pixel space.
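Spacetime patchification itself is just tensor reshaping. The sketch below carves a latent video grid into non-overlapping spatial patches per frame; the sizes are arbitrary, and Sora's actual tokenizer is not public:

```python
import torch

# A latent video grid: T frames, C channels, H x W spatial
T, C, H, W = 8, 4, 32, 32
latents = torch.randn(T, C, H, W)
p = 16  # spatial patch size

# Carve into non-overlapping spatial patches per frame
x = latents.reshape(T, C, H // p, p, W // p, p)
x = x.permute(0, 2, 4, 3, 5, 1)             # [T, H/p, W/p, p, p, C]
tokens = x.reshape(T, (H // p) * (W // p), p * p * C)
print(tokens.shape)  # [8, 4, 1024]: 32 spacetime tokens of dim 1024
```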
class SoraStyleWorldModel(nn.Module):
"""
Simplified Sora-style architecture: diffusion transformer
operating on spacetime latent patches.
This is a sketch -- the real Sora is vastly larger.
"""
def __init__(
self,
patch_dim=16,
num_frames=16,
latent_channels=4,
transformer_dim=768,
num_layers=12,
num_heads=12,
):
super().__init__()
# Patch embedding: map spacetime patches to transformer dim
self.patch_embed = nn.Linear(
patch_dim * patch_dim * latent_channels,
transformer_dim,
)
# Temporal position encoding
self.temporal_pos = nn.Parameter(
torch.randn(1, num_frames, 1, transformer_dim)
)
# Spatial position encoding
self.spatial_pos = nn.Parameter(
torch.randn(1, 1, 256, transformer_dim) # Up to 16x16 patches
)
# Transformer layers
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=transformer_dim,
nhead=num_heads,
dim_feedforward=transformer_dim * 4,
batch_first=True,
)
for _ in range(num_layers)
])
# Text conditioning via cross-attention
self.text_proj = nn.Linear(768, transformer_dim)
# Output projection
self.output_proj = nn.Linear(
transformer_dim,
patch_dim * patch_dim * latent_channels,
)
self.num_frames = num_frames
self.transformer_dim = transformer_dim
def forward(self, noisy_latents, timestep, text_embedding):
"""
Predict denoised latents from noisy input.
noisy_latents: [B, T, H_patches, W_patches, patch_dim^2 * C]
        timestep: diffusion timestep (unused in this sketch; a full
            model would condition on a timestep embedding)
text_embedding: [B, seq_len, 768]
"""
B, T, H, W, D = noisy_latents.shape
# Flatten spatial dimensions
x = noisy_latents.reshape(B, T, H * W, D)
# Patch embedding
x = self.patch_embed(x) # [B, T, H*W, transformer_dim]
# Add positional encodings
x = x + self.temporal_pos[:, :T, :, :]
x = x + self.spatial_pos[:, :, :H*W, :]
# Flatten time and space for full attention
x = x.reshape(B, T * H * W, self.transformer_dim)
# Add text conditioning
text_tokens = self.text_proj(text_embedding)
x = torch.cat([text_tokens, x], dim=1)
# Transformer
for layer in self.layers:
x = layer(x)
# Remove text tokens
text_len = text_embedding.shape[1]
x = x[:, text_len:, :]
# Output projection
x = self.output_proj(x)
# Reshape back
x = x.reshape(B, T, H, W, D)
return x
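The denoising objective rests on the forward-noising identity x_t = sqrt(a_bar) * x0 + sqrt(1 - a_bar) * eps; a minimal numeric check, with the schedule value a_bar chosen arbitrarily:

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(2, 16)            # clean latent
eps = torch.randn_like(x0)         # Gaussian noise
alpha_bar = 0.7                    # cumulative schedule value (arbitrary)

# Forward process: mix clean latent with noise
x_t = alpha_bar**0.5 * x0 + (1 - alpha_bar)**0.5 * eps

# If the network predicted eps exactly, the clean latent is recoverable
x0_hat = (x_t - (1 - alpha_bar)**0.5 * eps) / alpha_bar**0.5
print(torch.allclose(x0_hat, x0, atol=1e-6))  # True
```

Training pushes the transformer's noise prediction toward `eps`; sampling then inverts the mixing step by step.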
What Sora Learns About Physics
Sora demonstrates emergent physical reasoning:
- Objects fall when dropped (gravity)
- Liquids flow and splash with some realism
- Reflections on surfaces roughly obey optics
- Camera motion produces consistent parallax
But it also fails:
- Objects sometimes pass through each other
- Physics violations in long sequences (objects change size, disappear)
- Inconsistent lighting as the camera moves
- Poor understanding of cause-and-effect chains longer than 2-3 steps
Sora Physical Reasoning Assessment
| Physical Phenomenon | Short Clips (2-5s) | Medium Clips (10-20s) | Long Clips (30-60s) |
|---|---|---|---|
| Gravity (falling objects) | Mostly correct | Occasional violations | Frequent violations |
| Object permanence | Good | Moderate (objects vanish) | Poor |
| Fluid dynamics | Plausible | Approximate | Unrealistic |
| Rigid body collisions | Moderate | Poor | Very poor |
| Light/shadow consistency | Good | Moderate | Poor |
Sora does not learn Newtonian mechanics. It learns statistical regularities in video data — what scenes typically look like after a ball is thrown, not the differential equation that governs the ball's trajectory. This distinction matters: statistical models can produce physically plausible short clips but diverge from reality over longer horizons because errors compound. True physical reasoning requires internalizing the governing equations, not just the visual patterns.
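The compounding is easy to quantify: a per-step relative error of eps accumulates roughly as (1 + eps)^n - 1 over an n-step rollout. Pure arithmetic, with the step counts loosely matching the clip lengths in the table above:

```python
# Per-step prediction error of 3% compounding over a rollout
eps = 0.03
for n in [5, 20, 60]:  # short, medium, long clips (illustrative step counts)
    accumulated = (1 + eps) ** n - 1
    print(f"{n:3d} steps -> {accumulated:.0%} accumulated error")
```

A 3% per-step error is barely visible over 5 steps but grows to several hundred percent over 60, which matches the pattern of mostly-correct short clips and unrealistic long ones.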
V-JEPA: Joint Embedding Prediction
Why Not Predict Pixels
Pixel-level prediction has a fundamental problem: most pixels in a video frame are redundant (background, texture details) and predicting them wastes model capacity. A model that perfectly predicts every background pixel but misses the moving ball has low reconstruction error but poor world understanding.
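This failure mode shows up directly in the numbers. In the toy 64x64 frame below, a prediction that misplaces the only moving object still beats one that renders the object perfectly but tints the background slightly (all values arbitrary):

```python
import numpy as np

frame = np.zeros((64, 64))
frame[30:34, 40:44] = 1.0   # the moving ball: 16 of 4096 pixels

# Prediction A: perfect background, ball drawn in the wrong place
pred_a = np.zeros((64, 64))
pred_a[30:34, 50:54] = 1.0

# Prediction B: ball exactly right, background off by a small constant
pred_b = frame + 0.1

mse_a = np.mean((pred_a - frame) ** 2)
mse_b = np.mean((pred_b - frame) ** 2)
print(mse_a, mse_b)  # A wins on MSE despite missing the ball entirely
```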
V-JEPA (Video Joint Embedding Predictive Architecture, LeCun/Meta) solves this by predicting in embedding space, not pixel space. The model predicts the latent representation of future frames, not the frames themselves.
The V-JEPA Architecture
ẑ_y = P_φ(E_θ(x), m_y)

where E_θ is the frame encoder, P_φ is the predictor, and m_y marks the temporal positions of the masked frames; the predictor receives a partially masked sequence of frame embeddings and must predict the embeddings of the masked (future) frames.
class VJEPAWorldModel(nn.Module):
"""
V-JEPA style world model: predict future frame embeddings
from current frame embeddings.
No pixel-level prediction -- all reasoning happens in
embedding space.
"""
def __init__(
self,
encoder_dim=768,
predictor_dim=384,
num_predictor_layers=6,
num_heads=6,
max_context_frames=16,
max_predict_frames=8,
):
super().__init__()
# The target encoder is an EMA (exponential moving average)
# of the context encoder -- not a separate model.
# Here we represent both as the same architecture.
self.context_encoder_proj = nn.Linear(
encoder_dim, predictor_dim
)
# Predictor: takes context embeddings + position tokens
# for target frames, outputs predicted target embeddings
self.mask_token = nn.Parameter(
torch.randn(1, 1, predictor_dim)
)
self.temporal_pos = nn.Parameter(
torch.randn(
1,
max_context_frames + max_predict_frames,
predictor_dim,
)
)
self.predictor = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=predictor_dim,
nhead=num_heads,
dim_feedforward=predictor_dim * 4,
batch_first=True,
),
num_layers=num_predictor_layers,
)
self.output_proj = nn.Linear(predictor_dim, encoder_dim)
def forward(
self,
context_embeddings,
target_positions,
num_context,
num_targets,
):
"""
Predict target frame embeddings from context.
context_embeddings: [B, num_context, encoder_dim]
Embeddings of visible (context) frames
target_positions: [B, num_targets]
Temporal positions of frames to predict
"""
B = context_embeddings.shape[0]
# Project context to predictor dim
context = self.context_encoder_proj(context_embeddings)
# Create mask tokens for target positions
mask_tokens = self.mask_token.expand(B, num_targets, -1)
# Combine context + mask tokens
sequence = torch.cat([context, mask_tokens], dim=1)
        # Add temporal position encoding. Simplification: positions are
        # assigned in sequence order here; a faithful implementation would
        # index temporal_pos by the actual frame indices in target_positions.
        total_len = num_context + num_targets
        sequence = sequence + self.temporal_pos[:, :total_len, :]
# Predict
output = self.predictor(sequence)
# Extract predictions for target positions
target_preds = output[:, num_context:, :]
# Project back to encoder dim
target_preds = self.output_proj(target_preds)
return target_preds
Training V-JEPA
The training objective: minimize the distance between predicted embeddings and actual embeddings (from the target encoder), NOT reconstructed pixels.
L = || ẑ_target − z̄_target ||²

where z̄_target is the embedding from the target encoder (EMA of the context encoder).
def train_vjepa(
context_encoder,
target_encoder,
predictor,
video_dataset,
ema_decay=0.996,
lr=1e-4,
epochs=100,
):
"""
Train V-JEPA predictor.
context_encoder: produces embeddings for visible frames
target_encoder: EMA of context_encoder (produces targets)
predictor: predicts target embeddings from context
"""
optimizer = torch.optim.AdamW(
list(context_encoder.parameters())
+ list(predictor.parameters()),
lr=lr,
weight_decay=0.05,
)
for epoch in range(epochs):
total_loss = 0.0
num_batches = 0
for video_frames in video_dataset:
B, T, C, H, W = video_frames.shape
# Random masking: select context and target frames
num_context = T // 2
num_targets = T - num_context
# Random permutation to select context frames
perm = torch.randperm(T)
context_idx = perm[:num_context].sort().values
target_idx = perm[num_context:].sort().values
context_frames = video_frames[:, context_idx]
target_frames = video_frames[:, target_idx]
# Encode context
context_flat = context_frames.reshape(
B * num_context, C, H, W
)
context_emb = context_encoder(context_flat)
context_emb = context_emb.reshape(B, num_context, -1)
# Encode targets (with EMA encoder, no gradient)
with torch.no_grad():
target_flat = target_frames.reshape(
B * num_targets, C, H, W
)
target_emb = target_encoder(target_flat)
target_emb = target_emb.reshape(B, num_targets, -1)
# Predict target embeddings
pred_emb = predictor(
context_emb, target_idx,
num_context, num_targets,
)
# Loss: L2 distance in embedding space
loss = nn.functional.mse_loss(pred_emb, target_emb)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Update target encoder via EMA
with torch.no_grad():
for p_target, p_context in zip(
target_encoder.parameters(),
context_encoder.parameters(),
):
p_target.data.mul_(ema_decay).add_(
p_context.data, alpha=1 - ema_decay
)
total_loss += loss.item()
num_batches += 1
if epoch % 10 == 0:
avg = total_loss / max(num_batches, 1)
print(f"Epoch {epoch}: loss={avg:.4f}")
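The EMA update at the end of the loop can be watched in isolation on a single tensor, using the same in-place `mul_`/`add_` pattern:

```python
import torch

ema_decay = 0.996
target = torch.zeros(4)
context = torch.ones(4)  # pretend the context encoder has moved to 1.0

# Repeatedly apply: target <- decay * target + (1 - decay) * context
for step in range(1000):
    target.mul_(ema_decay).add_(context, alpha=1 - ema_decay)

# After 1000 steps the target has closed most of the gap:
# 1 - 0.996**1000 is roughly 0.98
print(target[0].item())
```

The slow-moving target is what prevents representation collapse: the predictor chases a stable set of embeddings rather than one that shifts with every gradient step.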
Connecting World Models to LLMs
The Multimodal Transformer
World models and LLMs converge in multimodal transformers that process both text and visual tokens in a shared attention space. The architecture:
class MultimodalWorldModelLLM(nn.Module):
"""
Multimodal transformer that combines LLM text reasoning
with world model visual prediction.
Text tokens and visual embedding tokens share the same
transformer backbone.
"""
def __init__(
self,
llm_backbone,
visual_encoder,
visual_decoder,
llm_dim=4096,
visual_dim=768,
):
super().__init__()
self.llm = llm_backbone
self.visual_encoder = visual_encoder
self.visual_decoder = visual_decoder
# Projectors between visual and LLM spaces
self.visual_to_llm = nn.Linear(visual_dim, llm_dim)
self.llm_to_visual = nn.Linear(llm_dim, visual_dim)
# Special tokens
self.visual_start_token = nn.Parameter(torch.randn(1, 1, llm_dim))
self.visual_end_token = nn.Parameter(torch.randn(1, 1, llm_dim))
def encode_visual_sequence(self, frames):
"""Encode video frames to LLM token space."""
B, T, C, H, W = frames.shape
flat = frames.reshape(B * T, C, H, W)
visual_emb = self.visual_encoder(flat)
visual_emb = visual_emb.reshape(B, T, -1)
return self.visual_to_llm(visual_emb)
def forward_with_visual_context(
self,
text_tokens,
visual_frames,
predict_next_frame=False,
):
"""
Process interleaved text and visual tokens.
The LLM sees: [visual_start, frame1, frame2, ..., frameT,
visual_end, text_token1, text_token2, ...]
"""
B = text_tokens.shape[0]
# Encode visual frames
visual_tokens = self.encode_visual_sequence(visual_frames)
# Get text embeddings
text_embeds = self.llm.embed_tokens(text_tokens)
# Interleave: visual tokens first, then text
start = self.visual_start_token.expand(B, -1, -1)
end = self.visual_end_token.expand(B, -1, -1)
combined = torch.cat(
[start, visual_tokens, end, text_embeds], dim=1
)
# Run through LLM
output = self.llm(inputs_embeds=combined)
if predict_next_frame:
# Extract the last visual token's output
last_visual_idx = 1 + visual_tokens.shape[1]
visual_output = output[:, last_visual_idx - 1, :]
# Project to visual space and decode
visual_pred = self.llm_to_visual(visual_output)
next_frame = self.visual_decoder(visual_pred)
return output, next_frame
return output
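The token layout built by `forward_with_visual_context` can be verified with a quick shape walkthrough; dimensions are shrunk for illustration, and the projection and boundary tokens are freshly initialized stand-ins:

```python
import torch
import torch.nn as nn

B, T, visual_dim, llm_dim, text_len = 2, 4, 768, 1024, 10
visual_emb = torch.randn(B, T, visual_dim)
text_emb = torch.randn(B, text_len, llm_dim)

proj = nn.Linear(visual_dim, llm_dim)          # visual_to_llm stand-in
start = torch.randn(1, 1, llm_dim).expand(B, -1, -1)
end = torch.randn(1, 1, llm_dim).expand(B, -1, -1)

# [visual_start, frame tokens, visual_end, text tokens]
combined = torch.cat([start, proj(visual_emb), end, text_emb], dim=1)
print(combined.shape)  # [2, 1 + 4 + 1 + 10, 1024] = [2, 16, 1024]
```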
The Planning Loop
A world model + LLM enables planning: the LLM generates action plans in text, the world model simulates the outcomes, and the LLM evaluates whether the predicted outcome matches the goal.
def plan_with_world_model(
llm,
world_model,
current_observation,
goal_description,
num_candidates=10,
planning_horizon=5,
):
"""
Use LLM + world model for model-based planning.
1. LLM proposes action sequences (text)
2. World model simulates each sequence
3. LLM evaluates which simulation best matches the goal
"""
# Step 1: LLM generates candidate action sequences
prompt = (
f"Given the current scene, propose {num_candidates} "
f"different action sequences (each {planning_horizon} "
f"steps) to achieve the goal: {goal_description}\n"
f"Format each as: action1, action2, action3, ...\n"
)
candidates = llm.generate_candidates(
prompt, num_candidates=num_candidates
)
# Step 2: Simulate each candidate with the world model
simulations = []
for action_sequence in candidates:
predicted_states = world_model.imagine(
current_observation, action_sequence
)
simulations.append({
"actions": action_sequence,
"predicted_states": predicted_states,
"final_state": predicted_states[-1],
})
# Step 3: LLM evaluates which final state best matches goal
best_score = -float('inf')
best_plan = None
for sim in simulations:
eval_prompt = (
f"Goal: {goal_description}\n"
f"Does the predicted final state achieve this goal? "
f"Rate 0-10."
)
score = llm.evaluate(
eval_prompt,
visual_context=sim["final_state"],
)
if score > best_score:
best_score = score
best_plan = sim["actions"]
return best_plan, best_score
Robotics Foundation Models
From Video Understanding to Physical Interaction
The leap from video understanding to robotics requires three additional capabilities:
- Action tokenization: Convert continuous robot motor commands into discrete tokens that the transformer can process.
- Proprioceptive conditioning: Include robot joint angles, forces, and velocities as input tokens alongside visual observations.
- Safety constraints: The model must never predict actions that damage the robot or its environment.
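Action tokenization, the first capability above, can be sketched as uniform binning of each joint command into a small discrete vocabulary; the bin count and command range here are illustrative:

```python
import numpy as np

NUM_BINS = 256
LOW, HIGH = -1.0, 1.0  # normalized joint-command range (illustrative)

def tokenize(action):
    """Map a continuous 7-DOF action to 7 discrete tokens in [0, 255]."""
    clipped = np.clip(action, LOW, HIGH)
    return np.floor((clipped - LOW) / (HIGH - LOW) * (NUM_BINS - 1)).astype(int)

def detokenize(tokens):
    """Map tokens back to bin-center continuous values."""
    return LOW + (tokens + 0.5) / (NUM_BINS - 1) * (HIGH - LOW)

a = np.array([0.3, -0.7, 0.0, 1.0, -1.0, 0.5, 0.05])
t = tokenize(a)
a_rec = detokenize(t)
print(np.max(np.abs(a_rec - a)))  # within one bin width (< 0.008)
```

The round-trip error is bounded by the bin width, so the transformer can treat motor commands like any other token vocabulary at a controlled loss of precision.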
Robotics Foundation Model Scale (2024-2026)
(Figure: parameter counts in billions; chart omitted)
class RoboticsWorldModel(nn.Module):
"""
World model for robotic manipulation.
Inputs: visual observation + proprioceptive state + action
Output: predicted next visual observation + proprioceptive state
"""
def __init__(
self,
visual_dim=768,
proprioceptive_dim=14, # 7 joint angles + 7 joint velocities
action_dim=7, # 7-DOF robot arm
latent_dim=512,
hidden_dim=1024,
):
super().__init__()
# Fuse visual and proprioceptive inputs
self.visual_proj = nn.Linear(visual_dim, latent_dim)
self.proprio_proj = nn.Linear(proprioceptive_dim, latent_dim)
# Cross-modal fusion
self.fusion = nn.Sequential(
nn.Linear(latent_dim * 2 + action_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
# Predict next state
self.visual_predictor = nn.Linear(hidden_dim, visual_dim)
self.proprio_predictor = nn.Linear(
hidden_dim, proprioceptive_dim
)
# Safety head: predict whether action is safe
self.safety_head = nn.Sequential(
nn.Linear(hidden_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Sigmoid(),
)
def forward(self, visual_emb, proprio_state, action):
"""
Predict next state from current state + action.
"""
v = self.visual_proj(visual_emb)
p = self.proprio_proj(proprio_state)
combined = torch.cat([v, p, action], dim=-1)
hidden = self.fusion(combined)
next_visual = self.visual_predictor(hidden)
next_proprio = self.proprio_predictor(hidden)
safety_score = self.safety_head(hidden)
return next_visual, next_proprio, safety_score
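One way the safety head might be used at deployment is as a gate over candidate actions. The scorer and threshold below are stand-ins, not part of any published system:

```python
import torch

torch.manual_seed(0)
SAFETY_THRESHOLD = 0.5

def safety_score(action):
    """Stand-in for the model's safety head: penalize large commands."""
    return torch.sigmoid(2.0 - action.norm())

# Candidate 7-DOF actions at increasing magnitudes
candidates = [torch.randn(7) * s for s in (0.1, 1.0, 5.0)]
safe = [a for a in candidates if safety_score(a) > SAFETY_THRESHOLD]
print(f"{len(safe)} of {len(candidates)} candidate actions pass the gate")
```

Filtering before execution keeps the planner free to propose aggressive actions while the gate enforces the constraint that unsafe ones are never sent to the robot.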
The Gap Between Language and Physical Reasoning
What LLMs Get Wrong
LLMs trained on text fail at physical reasoning tasks that humans find trivial:
LLM vs Human Physical Reasoning Accuracy
| Task | GPT-4 (text only) | GPT-4V (with images) | Human |
|---|---|---|---|
| Which object is heavier? (visual) | 42% | 61% | 95% |
| Will this tower of blocks fall? (image) | 38% | 55% | 89% |
| Where will the ball land? (trajectory) | 29% | 44% | 92% |
| Can this container hold this liquid? | 51% | 63% | 97% |
| Which path is shorter? (visual) | 55% | 72% | 98% |
The gap is enormous. Text-only GPT-4 achieves 29% on trajectory prediction (near random for 4-choice). Even with vision, it reaches only 44%. Humans achieve 92%. The model has textbook knowledge of parabolic trajectories but cannot apply it to a specific visual scene.
Why the Gap Exists
- No physical experience: The model has never thrown a ball, poured water, or stacked blocks. It has only read about these activities.
- No spatial grounding: Text is 1D sequential; the physical world is 3D spatial. The model cannot mentally rotate objects or visualize spatial relationships from text.
- No temporal simulation: Physical reasoning requires running a forward simulation (“what happens next?”). Text models process the entire sequence at once — they do not naturally unroll time.
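The simulation step that text models lack is often tiny to write down. The tower-of-blocks question from the table above reduces to a center-of-mass check, here for uniform blocks of equal mass and width (a simplification of real statics):

```python
def tower_falls(offsets, width=1.0):
    """offsets[i]: x-offset of block i relative to block i-1.
    The tower falls if, at any level, the center of mass of the
    blocks above lies outside the supporting block's footprint."""
    # Absolute x-center of each block
    centers = []
    x = 0.0
    for off in offsets:
        x += off
        centers.append(x)
    for level in range(1, len(centers)):
        above = centers[level:]
        com = sum(above) / len(above)
        support_center = centers[level - 1]
        if abs(com - support_center) > width / 2:
            return True
    return False

print(tower_falls([0.0, 0.3, 0.3]))   # False: small offsets, stable
print(tower_falls([0.0, 0.6, 0.6]))   # True: upper blocks overhang too far
```

Ten lines of forward reasoning solve a task where text-only models hover near chance; the challenge is getting a learned model to extract the block positions from pixels and run this check implicitly.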
Closing the Gap
World models are the missing piece. By training on video data where physics plays out visually, the model builds implicit physical intuitions that text alone cannot provide. The research frontier is unifying these visual intuitions with LLM reasoning:
World models and LLMs are solving complementary problems. LLMs excel at abstract reasoning, planning, and communication. World models excel at predicting physical dynamics and spatial relationships. Embodied AI requires both: the LLM formulates plans in language, the world model simulates their physical consequences, and the combined system acts in the real world. The Sora/V-JEPA line of research is building the world model half of this equation. The open question is how to fuse these capabilities into a single system that reasons about both language and physics with human-level competence.