Language models generate text one token at a time. Can the same approach generate images? Yes, if you tokenize images first. A 256x256 image can be compressed into a sequence of 256 discrete tokens by a VQ-VAE (Vector Quantized Variational Autoencoder). An autoregressive transformer then predicts these image tokens one at a time, conditioned on a text prompt. The generated token sequence is decoded back into pixels. This is conceptually identical to text generation: predict the next token given the previous tokens.
The alternative is diffusion: start with pure noise and iteratively denoise it into an image, guided by a text embedding. Diffusion models (Stable Diffusion, DALL-E 3, Imagen) currently produce higher-quality images than autoregressive approaches, but autoregressive models are catching up (Parti, Chameleon, Emu3) and have one fundamental advantage: they can be unified with text generation in a single model that understands and generates text, images, and video with one architecture and one training process.
This post covers both paradigms, diffusion and autoregressive, and the unified architectures that combine understanding and generation across modalities.
Image Tokenization
Why Tokenize Images
To generate images with a language model, images must be represented as discrete token sequences. A 256x256 RGB image has 196,608 raw pixel values, far too long for a transformer sequence. Tokenization compresses this to a few hundred discrete tokens while preserving enough information to reconstruct the image.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class VectorQuantizer(nn.Module):
"""
Vector Quantization layer.
Maps continuous encoder outputs to discrete codebook entries.
Given encoder output z of shape (B, D, H, W):
1. For each spatial position, find the nearest codebook entry
2. Replace the continuous vector with the codebook entry
3. Pass gradients through via straight-through estimator
The codebook has K entries, each of dimension D.
"""
def __init__(self, n_codes=8192, code_dim=256,
commitment_cost=0.25):
super().__init__()
self.n_codes = n_codes
self.code_dim = code_dim
self.commitment_cost = commitment_cost
# Codebook: K vectors of dimension D
self.codebook = nn.Embedding(n_codes, code_dim)
# Initialize uniformly
self.codebook.weight.data.uniform_(
-1.0 / n_codes, 1.0 / n_codes
)
def forward(self, z):
"""
Quantize encoder output.
z: (B, D, H, W) -> continuous encoder features
Returns: quantized (B, D, H, W), loss, indices (B, H, W)
"""
B, D, H, W = z.shape
# Reshape: (B, D, H, W) -> (B*H*W, D)
z_flat = z.permute(0, 2, 3, 1).contiguous().view(-1, D)
# Compute distances to all codebook entries
# d(z, e) = ||z||^2 + ||e||^2 - 2 * z^T * e
d = (
z_flat.pow(2).sum(dim=1, keepdim=True)
+ self.codebook.weight.pow(2).sum(dim=1)
- 2 * z_flat @ self.codebook.weight.t()
)
# Find nearest codebook entry for each position
indices = d.argmin(dim=1) # (B*H*W,)
# Look up quantized vectors
z_q = self.codebook(indices) # (B*H*W, D)
# Reshape back
z_q = z_q.view(B, H, W, D).permute(0, 3, 1, 2)
indices = indices.view(B, H, W)
# Compute losses
# Codebook loss: move codebook entries toward encoder output
codebook_loss = F.mse_loss(z_q.detach(), z)
# Commitment loss: encourage encoder to commit to entries
commitment_loss = F.mse_loss(z_q, z.detach())
loss = codebook_loss + self.commitment_cost * commitment_loss
# Straight-through estimator: copy gradients from z_q to z
z_q = z + (z_q - z).detach()
return z_q, loss, indices
def decode_indices(self, indices):
"""
Convert token indices back to continuous vectors.
indices: (B, H, W) -> z_q: (B, D, H, W)
"""
B, H, W = indices.shape
z_q = self.codebook(indices.view(-1))
z_q = z_q.view(B, H, W, self.code_dim).permute(0, 3, 1, 2)
return z_q
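The distance computation and the straight-through trick above can be exercised on toy tensors. This is a minimal sketch with illustrative sizes (K=8 codes of dimension D=4, a 2x2 spatial grid), not the real configuration:

```python
import torch

torch.manual_seed(0)
codebook = torch.randn(8, 4)                    # (K, D)
z = torch.randn(1, 4, 2, 2)                     # (B, D, H, W) encoder output
z_flat = z.permute(0, 2, 3, 1).reshape(-1, 4)   # (B*H*W, D)

# ||z - e||^2 expanded as ||z||^2 + ||e||^2 - 2 z.e
d = (z_flat.pow(2).sum(1, keepdim=True)
     + codebook.pow(2).sum(1)
     - 2 * z_flat @ codebook.t())
indices = d.argmin(1)                           # nearest code per position
z_q = codebook[indices].view(1, 2, 2, 4).permute(0, 3, 1, 2)

# Straight-through estimator: forward uses z_q, gradients flow to z.
z_st = z + (z_q - z).detach()
print(indices.shape, z_st.shape)
```

The expanded-distance form matches a direct squared pairwise distance, which is an easy property to verify.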
Complete VQ-VAE Image Tokenizer
class ResidualBlock(nn.Module):
"""Residual block for encoder/decoder."""
def __init__(self, channels):
super().__init__()
self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
self.norm1 = nn.GroupNorm(32, channels)
self.norm2 = nn.GroupNorm(32, channels)
def forward(self, x):
residual = x
x = self.norm1(x)
x = F.silu(x)
x = self.conv1(x)
x = self.norm2(x)
x = F.silu(x)
x = self.conv2(x)
return x + residual
class ImageTokenizer(nn.Module):
"""
VQ-VAE Image Tokenizer.
Encodes a 256x256 image into a 16x16 grid of discrete tokens
(256 tokens total). Each token is an index into a codebook
of 8192 entries.
The compression ratio: 256*256*3 = 196,608 values
compressed to 256 tokens = 768x compression.
Architecture:
- Encoder: ConvNet that downsamples 256x256 -> 16x16
- Vector Quantizer: maps 16x16 continuous features to
discrete tokens
- Decoder: ConvNet that upsamples 16x16 -> 256x256
"""
def __init__(self, in_channels=3, hidden_dim=256,
n_codes=8192, code_dim=256,
n_downsample=4):
super().__init__()
self.n_downsample = n_downsample
# 4 downsamples: 256 -> 128 -> 64 -> 32 -> 16
# Encoder
encoder_layers = [
nn.Conv2d(in_channels, hidden_dim, 3, padding=1),
]
for i in range(n_downsample):
encoder_layers.extend([
ResidualBlock(hidden_dim),
ResidualBlock(hidden_dim),
nn.Conv2d(hidden_dim, hidden_dim, 4, stride=2,
padding=1),
])
encoder_layers.extend([
ResidualBlock(hidden_dim),
nn.Conv2d(hidden_dim, code_dim, 1),
])
self.encoder = nn.Sequential(*encoder_layers)
# Vector Quantizer
self.quantizer = VectorQuantizer(
n_codes=n_codes, code_dim=code_dim
)
# Decoder
decoder_layers = [
nn.Conv2d(code_dim, hidden_dim, 1),
ResidualBlock(hidden_dim),
]
for i in range(n_downsample):
decoder_layers.extend([
nn.ConvTranspose2d(hidden_dim, hidden_dim, 4,
stride=2, padding=1),
ResidualBlock(hidden_dim),
ResidualBlock(hidden_dim),
])
decoder_layers.append(
nn.Conv2d(hidden_dim, in_channels, 3, padding=1)
)
self.decoder = nn.Sequential(*decoder_layers)
def encode(self, x):
"""
Encode an image to discrete tokens.
x: (B, 3, 256, 256) -> indices: (B, 16, 16)
"""
z = self.encoder(x) # (B, code_dim, 16, 16)
z_q, vq_loss, indices = self.quantizer(z)
return indices, vq_loss
def decode(self, indices):
"""
Decode discrete tokens back to an image.
indices: (B, 16, 16) -> x_recon: (B, 3, 256, 256)
"""
z_q = self.quantizer.decode_indices(indices)
x_recon = self.decoder(z_q)
return x_recon
def forward(self, x):
"""Full encode-quantize-decode pass."""
z = self.encoder(x)
z_q, vq_loss, indices = self.quantizer(z)
x_recon = self.decoder(z_q)
# Reconstruction loss
recon_loss = F.mse_loss(x_recon, x)
return {
"reconstruction": x_recon,
"indices": indices,
"loss": recon_loss + vq_loss,
"recon_loss": recon_loss.item(),
"vq_loss": vq_loss.item(),
}
def tokens_to_image(self, token_sequence):
"""
Convert a flat token sequence to an image.
token_sequence: (B, 256) -> image: (B, 3, 256, 256)
This is what the autoregressive model produces:
a flat sequence of 256 token IDs.
"""
B = token_sequence.shape[0]
grid_size = int(np.sqrt(token_sequence.shape[1]))
indices = token_sequence.view(B, grid_size, grid_size)
return self.decode(indices)
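The claimed 256x256 -> 16x16 downsampling can be sanity-checked in isolation: four stride-2 convolutions each halve the spatial resolution. This sketch uses 8 channels only to keep the check cheap (the tokenizer above uses 256):

```python
import torch
import torch.nn as nn

# Four stride-2 convs: 256 -> 128 -> 64 -> 32 -> 16,
# i.e. 16*16 = 256 token positions per image.
downsample = nn.Sequential(*[
    nn.Conv2d(8, 8, 4, stride=2, padding=1) for _ in range(4)
])
y = downsample(torch.randn(1, 8, 256, 256))
print(y.shape)   # torch.Size([1, 8, 16, 16])
```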
Image Tokenizer Comparison
| Tokenizer | Codebook Size | Tokens (256x256) | rFID | Params |
|---|---|---|---|---|
| VQGAN (Esser 2021) | 16384 | 256 (16x16) | 7.94 | 72M |
| DALL-E dVAE | 8192 | 1024 (32x32) | 32.0 | 40M |
| LlamaGen tokenizer | 16384 | 256 (16x16) | 2.19 | 72M |
| Cosmos tokenizer (NVIDIA) | 65536 | 256 (16x16) | 1.12 | 150M |
| Open-MAGVIT2 | 262144 | 256 (16x16) | 1.17 | 200M |
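One way to read the table is in bits per image: tokens times log2(codebook size). A larger codebook buys reconstruction quality at a modest bit cost. The figures below are taken from the table; only the bit arithmetic is computed here:

```python
import math

tokenizers = {
    "VQGAN":        (256, 16384),
    "DALL-E dVAE":  (1024, 8192),
    "Cosmos":       (256, 65536),
    "Open-MAGVIT2": (256, 262144),
}
for name, (n_tokens, codebook) in tokenizers.items():
    # Each token carries log2(codebook) bits of information.
    print(f"{name}: {n_tokens * math.log2(codebook):.0f} bits/image")
# VQGAN: 3584, DALL-E dVAE: 13312, Cosmos: 4096, Open-MAGVIT2: 4608
```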
Autoregressive Image Generation
Generating Images Token by Token
Once images are tokenized, an autoregressive transformer generates image tokens conditioned on text tokens. The architecture is identical to a language model: predict the next token given all previous tokens. The input sequence is [text_tokens] [image_tokens], and the model generates image tokens one at a time.
class AutoregressiveImageGenerator:
"""
Generate images using an autoregressive transformer.
Input: text prompt (tokenized)
Output: sequence of image tokens
The model is a standard decoder-only transformer
trained on interleaved text-image sequences:
[BOS] text tokens [IMG_START] image tokens [IMG_END]
"""
def __init__(self, transformer, text_tokenizer,
image_tokenizer, image_token_offset=32000):
self.transformer = transformer
self.text_tokenizer = text_tokenizer
self.image_tokenizer = image_tokenizer
# Image tokens start at image_token_offset to avoid collision
# with the text vocabulary; the two slots just below the offset
# are reserved for IMG_START / IMG_END (the text tokenizer is
# assumed to leave them unused)
self.image_token_offset = image_token_offset
self.img_start_token = image_token_offset - 2
self.img_end_token = image_token_offset - 1
self.n_image_tokens = 256 # 16x16 grid
def generate(self, text_prompt, temperature=1.0,
top_k=256, top_p=0.95, cfg_scale=5.0):
"""
Generate an image from a text prompt.
Uses classifier-free guidance (CFG):
logits = logits_uncond + cfg_scale * (logits_cond - logits_uncond)
"""
# Tokenize text prompt
text_tokens = self.text_tokenizer.encode(text_prompt)
prefix = torch.tensor(
text_tokens + [self.img_start_token]
).unsqueeze(0)
# For CFG: also prepare unconditional prefix
# (empty text + img_start)
uncond_prefix = torch.tensor(
[self.img_start_token]
).unsqueeze(0)
generated_tokens = []
for i in range(self.n_image_tokens):
# Conditional logits
cond_input = torch.cat([
prefix,
torch.tensor(generated_tokens).unsqueeze(0)
], dim=1) if generated_tokens else prefix
cond_logits = self.transformer(cond_input).logits[:, -1]
# Unconditional logits (for CFG)
uncond_input = torch.cat([
uncond_prefix,
torch.tensor(generated_tokens).unsqueeze(0)
], dim=1) if generated_tokens else uncond_prefix
uncond_logits = self.transformer(uncond_input).logits[:, -1]
# Classifier-free guidance
logits = (
uncond_logits
+ cfg_scale * (cond_logits - uncond_logits)
)
# Only consider image tokens
img_logits = logits[:, self.image_token_offset:
self.image_token_offset
+ self.image_tokenizer.quantizer.n_codes]
# Temperature scaling
img_logits = img_logits / temperature
# Top-k filtering
if top_k > 0:
top_k_vals, _ = torch.topk(img_logits, top_k)
threshold = top_k_vals[:, -1].unsqueeze(-1)
img_logits[img_logits < threshold] = float("-inf")
# Top-p (nucleus) filtering
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(
img_logits, descending=True
)
cumulative_probs = torch.cumsum(
F.softmax(sorted_logits, dim=-1), dim=-1
)
mask = cumulative_probs > top_p
mask[:, 1:] = mask[:, :-1].clone()
mask[:, 0] = False
# Scatter the mask back to the original (unsorted) positions
# so the filtering is correct for any batch size
indices_to_remove = torch.zeros_like(mask).scatter(
1, sorted_indices, mask
)
img_logits[indices_to_remove] = float("-inf")
# Sample
probs = F.softmax(img_logits, dim=-1)
next_token = torch.multinomial(probs, 1).item()
generated_tokens.append(
next_token + self.image_token_offset
)
# Decode tokens to image (strip the vocabulary offset first:
# the tokenizer expects raw codebook indices)
image_indices = torch.tensor(
[t - self.image_token_offset for t in generated_tokens]
).unsqueeze(0)
image = self.image_tokenizer.tokens_to_image(image_indices)
return {
"image": image,
"tokens": generated_tokens,
"n_tokens": len(generated_tokens),
}
Autoregressive image generation produces tokens sequentially: 256 forward passes for a 16x16 grid. At 20ms per forward pass on an H100, that is 5.1 seconds per image. Diffusion models use 20-50 denoising steps, each processing the full image in parallel, taking 2-5 seconds on the same hardware. The speed gap is narrowing as speculative decoding and parallel token prediction are applied to image generation.
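The latency arithmetic above is simple enough to spell out. The 20 ms/pass figure is from the text; the 60 ms diffusion step cost is an illustrative assumption chosen to land in the quoted 2-5 second range:

```python
# 256 sequential forward passes at ~20 ms each (autoregressive)
ar_seconds = 256 * 20 / 1000
# 50 denoising steps, each processing the full image in parallel
# (60 ms/step is an assumed, illustrative per-step cost)
diffusion_seconds = 50 * 60 / 1000
print(ar_seconds, diffusion_seconds)   # 5.12 3.0
```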
Diffusion Models
The Denoising Process
Diffusion models generate images by learning to reverse a noise-adding process. The forward process gradually adds Gaussian noise to a clean image x_0 over T steps until it becomes pure noise. The reverse process learns to denoise: given a noisy image x_t at step t, predict the clean image (or equivalently, predict the noise that was added).
The noise schedule beta_1, ..., beta_T defines how much noise is added at each step. At step t, the noisy image is:
x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
where alpha_t = 1 - beta_t, alpha_bar_t = alpha_1 * alpha_2 * ... * alpha_t, and epsilon is standard Gaussian noise.
class SimpleDiffusion:
"""
Minimal diffusion model for image generation.
Implements DDPM (Denoising Diffusion Probabilistic Models).
"""
def __init__(self, model, n_steps=1000,
beta_start=0.0001, beta_end=0.02):
self.model = model # U-Net that predicts noise
self.n_steps = n_steps
# Linear noise schedule
self.betas = torch.linspace(beta_start, beta_end, n_steps)
self.alphas = 1.0 - self.betas
self.alpha_bars = torch.cumprod(self.alphas, dim=0)
def forward_process(self, x_0, t):
"""
Add noise to x_0 at timestep t.
Returns noisy image and the noise that was added.
"""
noise = torch.randn_like(x_0)
alpha_bar_t = self.alpha_bars[t].view(-1, 1, 1, 1)
x_t = (
torch.sqrt(alpha_bar_t) * x_0
+ torch.sqrt(1 - alpha_bar_t) * noise
)
return x_t, noise
def train_step(self, x_0):
"""
One training step: sample a random timestep,
add noise, predict the noise, compute loss.
"""
B = x_0.shape[0]
# Random timestep for each example
t = torch.randint(0, self.n_steps, (B,))
# Add noise
x_t, noise = self.forward_process(x_0, t)
# Predict noise
noise_pred = self.model(x_t, t)
# Loss: MSE between predicted and actual noise
loss = F.mse_loss(noise_pred, noise)
return loss
@torch.no_grad()
def sample(self, shape, text_embedding=None, cfg_scale=7.5):
"""
Generate an image by iterative denoising.
Start from pure noise and denoise step by step.
"""
# Start from pure noise
x = torch.randn(shape)
for t in reversed(range(self.n_steps)):
t_batch = torch.full(
(shape[0],), t, dtype=torch.long
)
# Predict noise
if text_embedding is not None and cfg_scale > 1.0:
# Classifier-free guidance
noise_cond = self.model(
x, t_batch, text_embedding
)
noise_uncond = self.model(
x, t_batch, None
)
noise_pred = (
noise_uncond
+ cfg_scale * (noise_cond - noise_uncond)
)
else:
noise_pred = self.model(x, t_batch, text_embedding)
# Denoise: reverse one step
alpha_t = self.alphas[t]
alpha_bar_t = self.alpha_bars[t]
alpha_bar_prev = (
self.alpha_bars[t - 1] if t > 0
else torch.tensor(1.0)
)
# Predicted x_0
x_0_pred = (
x - torch.sqrt(1 - alpha_bar_t) * noise_pred
) / torch.sqrt(alpha_bar_t)
# Clamp for stability
x_0_pred = torch.clamp(x_0_pred, -1.0, 1.0)
# Compute x_{t-1}
beta_t = self.betas[t]
posterior_mean = (
torch.sqrt(alpha_bar_prev) * beta_t
/ (1 - alpha_bar_t) * x_0_pred
+ torch.sqrt(alpha_t)
* (1 - alpha_bar_prev)
/ (1 - alpha_bar_t) * x
)
if t > 0:
noise = torch.randn_like(x)
posterior_var = (
beta_t * (1 - alpha_bar_prev)
/ (1 - alpha_bar_t)
)
x = posterior_mean + torch.sqrt(posterior_var) * noise
else:
x = posterior_mean
return x
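The linear schedule used by SimpleDiffusion can be checked standalone: alpha_bar should decay from near 1 (almost clean) to near 0 (almost pure noise), and the signal-to-noise ratio should fall monotonically across timesteps:

```python
import torch

n_steps = 1000
betas = torch.linspace(1e-4, 0.02, n_steps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

# SNR of x_t relative to x_0: sqrt(alpha_bar / (1 - alpha_bar))
snr = torch.sqrt(alpha_bars / (1 - alpha_bars))
print(float(alpha_bars[0]), float(alpha_bars[-1]))
```

At t=0 almost all signal survives; by t=999 alpha_bar is on the order of 1e-5, so x_T is essentially pure noise.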
Latent Diffusion (Stable Diffusion Architecture)
Running diffusion in pixel space is expensive: a 512x512 image has 786,432 dimensions. Latent diffusion runs the diffusion process in a compressed latent space (typically 64x64x4 = 16,384 dimensions), reducing compute by 48x.
class LatentDiffusion:
"""
Latent Diffusion Model (Stable Diffusion architecture).
Instead of denoising in pixel space (512x512x3),
denoise in latent space (64x64x4).
Pipeline:
1. VAE encoder compresses image to latent
2. Diffusion operates in latent space
3. VAE decoder reconstructs image from denoised latent
Text conditioning uses CLIP text encoder to produce
embeddings that guide the denoising process via
cross-attention.
"""
def __init__(self, vae, unet, text_encoder, diffusion):
self.vae = vae # VAE for image compression
self.unet = unet # U-Net for denoising
self.text_encoder = text_encoder # CLIP text encoder
self.diffusion = diffusion
def encode_image(self, image):
"""Compress image to latent space."""
# image: (B, 3, 512, 512) -> latent: (B, 4, 64, 64)
latent = self.vae.encode(image)
# Scale by VAE constant
return latent * 0.18215
def decode_latent(self, latent):
"""Decompress latent to image."""
latent = latent / 0.18215
return self.vae.decode(latent)
def encode_text(self, text_tokens):
"""Encode text to CLIP embeddings."""
# Returns (B, seq_len, 768) for CLIP ViT-L
return self.text_encoder(text_tokens)
@torch.no_grad()
def generate(self, prompt_embeddings, n_steps=50,
cfg_scale=7.5, height=512, width=512):
"""
Generate an image from text embeddings.
"""
latent_h = height // 8 # VAE downsamples 8x
latent_w = width // 8
# Start from noise in latent space
latent = torch.randn(
(1, 4, latent_h, latent_w)
)
# Denoising loop (using fewer steps with DDIM)
timesteps = self._get_ddim_timesteps(n_steps)
for i, t in enumerate(timesteps):
t_batch = torch.tensor([t])
# Conditional prediction
noise_cond = self.unet(
latent, t_batch, prompt_embeddings
)
# Unconditional prediction (for CFG)
null_embeddings = torch.zeros_like(prompt_embeddings)
noise_uncond = self.unet(
latent, t_batch, null_embeddings
)
# Apply CFG
noise_pred = (
noise_uncond
+ cfg_scale * (noise_cond - noise_uncond)
)
# DDIM step (deterministic, allows fewer steps)
latent = self._ddim_step(latent, noise_pred, t,
timesteps[i + 1]
if i + 1 < len(timesteps)
else 0)
# Decode latent to image
image = self.decode_latent(latent)
return image
def _get_ddim_timesteps(self, n_steps):
"""Get evenly spaced timesteps for DDIM sampling."""
step_size = self.diffusion.n_steps // n_steps
return list(range(
self.diffusion.n_steps - 1, 0, -step_size
))[:n_steps]
def _ddim_step(self, x_t, noise_pred, t, t_prev):
"""DDIM deterministic sampling step."""
alpha_bar_t = self.diffusion.alpha_bars[t]
alpha_bar_prev = (
self.diffusion.alpha_bars[t_prev]
if t_prev > 0 else torch.tensor(1.0)
)
# Predicted x_0
x_0_pred = (
x_t - torch.sqrt(1 - alpha_bar_t) * noise_pred
) / torch.sqrt(alpha_bar_t)
# Direction pointing to x_t
dir_xt = torch.sqrt(1 - alpha_bar_prev) * noise_pred
# DDIM update (deterministic)
x_prev = torch.sqrt(alpha_bar_prev) * x_0_pred + dir_xt
return x_prev
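The timestep subsampling in _get_ddim_timesteps is worth seeing concretely. A standalone copy of the same logic, walked from noisy to clean:

```python
def ddim_timesteps(total_steps, n_steps):
    # Evenly spaced subsample of the training timesteps, descending.
    step_size = total_steps // n_steps
    return list(range(total_steps - 1, 0, -step_size))[:n_steps]

ts = ddim_timesteps(1000, 50)
print(ts[0], ts[-1], len(ts))   # 999 19 50
```

Fifty steps instead of a thousand is what makes DDIM sampling roughly 20x faster than full DDPM sampling.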
Image Generation Quality: Diffusion vs Autoregressive
The standard head-to-head metric is FID on COCO-30K, comparing diffusion models (DALL-E 2, SD 1.5, SDXL, SD 3) against autoregressive models (Parti, LlamaGen, Chameleon, Emu3).
Text-to-Video Generation
Video as a Sequence of Frames
Video generation extends image generation to the temporal dimension. A 4-second video at 8 FPS is 32 frames. Each frame is tokenized (or encoded to latent space), producing a 3D grid of tokens: spatial (H x W) times temporal (T).
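The token budget described above can be wrapped in a small helper (strides match the setup in the text: temporal stride 4, spatial stride 16):

```python
def video_tokens(frames, height, width, t_stride=4, s_stride=16):
    # Token grid is (T/t_stride) x (H/s_stride) x (W/s_stride)
    return (frames // t_stride) * (height // s_stride) * (width // s_stride)

print(video_tokens(32, 256, 256))    # 2048 tokens: a 4 s clip at 8 FPS
print(video_tokens(128, 256, 256))   # 8192 tokens: a 16 s clip
```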
class VideoTokenizer(nn.Module):
"""
Tokenize video into discrete tokens.
A video of T frames at resolution HxW is encoded into
a (T', H', W') grid of tokens where T' = T/t_stride,
H' = H/s_stride, W' = W/s_stride.
For T=32, H=W=256, t_stride=4, s_stride=16:
Token grid: 8 x 16 x 16 = 2048 tokens per video.
"""
def __init__(self, image_tokenizer, temporal_stride=4):
super().__init__()
self.image_tokenizer = image_tokenizer
self.temporal_stride = temporal_stride
# Temporal compression: 3D convolutions
code_dim = image_tokenizer.quantizer.code_dim
self.temporal_encoder = nn.Sequential(
nn.Conv3d(code_dim, code_dim, (temporal_stride, 1, 1),
stride=(temporal_stride, 1, 1)),
nn.GroupNorm(32, code_dim),
nn.SiLU(),
nn.Conv3d(code_dim, code_dim, (3, 1, 1), padding=(1, 0, 0)),
)
self.temporal_decoder = nn.Sequential(
nn.ConvTranspose3d(code_dim, code_dim,
(temporal_stride, 1, 1),
stride=(temporal_stride, 1, 1)),
nn.GroupNorm(32, code_dim),
nn.SiLU(),
nn.Conv3d(code_dim, code_dim, (3, 1, 1), padding=(1, 0, 0)),
)
# Shared quantizer with image tokenizer
self.quantizer = image_tokenizer.quantizer
def encode(self, video):
"""
Encode a video to discrete tokens.
video: (B, T, 3, H, W) -> indices: (B, T', H', W')
"""
B, T, C, H, W = video.shape
# Encode each frame spatially
frames = video.view(B * T, C, H, W)
z = self.image_tokenizer.encoder(frames)
# Reshape for temporal processing
_, D, Hp, Wp = z.shape
z = z.view(B, T, D, Hp, Wp).permute(0, 2, 1, 3, 4)
# z: (B, D, T, H', W')
# Temporal compression
z = self.temporal_encoder(z) # (B, D, T', H', W')
# Quantize
Tp = z.shape[2]
z_flat = z.permute(0, 2, 3, 4, 1).contiguous()
z_flat = z_flat.view(-1, D)
# Find nearest codebook entries
d = (
z_flat.pow(2).sum(dim=1, keepdim=True)
+ self.quantizer.codebook.weight.pow(2).sum(dim=1)
- 2 * z_flat @ self.quantizer.codebook.weight.t()
)
indices = d.argmin(dim=1).view(B, Tp, Hp, Wp)
return indices
def decode(self, indices):
"""
Decode discrete tokens back to video.
indices: (B, T', H', W') -> video: (B, T, 3, H, W)
"""
B, Tp, Hp, Wp = indices.shape
D = self.quantizer.code_dim
# Look up codebook
z_q = self.quantizer.codebook(indices.view(-1))
z_q = z_q.view(B, Tp, Hp, Wp, D).permute(0, 4, 1, 2, 3)
# Temporal decompression
z = self.temporal_decoder(z_q) # (B, D, T, H', W')
T = z.shape[2]
# Decode each frame spatially
z = z.permute(0, 2, 1, 3, 4).contiguous()
z = z.view(B * T, D, Hp, Wp)
frames = self.image_tokenizer.decoder(z)
_, C, H, W = frames.shape
video = frames.view(B, T, C, H, W)
return video
Unified Multimodal Architectures
Understanding vs Generation
Most multimodal models are understanding-only: they take images as input and produce text as output (LLaVA, Qwen-VL). They use a vision encoder to convert images to continuous embeddings, which are fed to the LLM. These models cannot generate images because they have no image decoder and no image token vocabulary.
Unified models both understand and generate across modalities. There are two approaches:
from dataclasses import dataclass
@dataclass
class UnifiedArchitecture:
name: str
approach: str
modalities: list
understands: list
generates: list
key_innovation: str
UNIFIED_ARCHITECTURES = [
UnifiedArchitecture(
name="Chameleon (Meta)",
approach="All tokens in one vocabulary",
modalities=["text", "image"],
understands=["text", "image"],
generates=["text", "image"],
key_innovation=(
"Unified tokenizer: text BPE (65536 tokens) + "
"image VQ (8192 tokens) = 73728 total vocabulary. "
"Single transformer processes everything. "
"Images and text are interleaved freely."
),
),
UnifiedArchitecture(
name="Gemini (Google)",
approach="Native multimodal from pretraining",
modalities=["text", "image", "video", "audio"],
understands=["text", "image", "video", "audio"],
generates=["text", "image"],
key_innovation=(
"Trained from scratch on interleaved multimodal "
"data. Uses SoundStorm for audio and Imagen "
"for image generation. Not a single decoder "
"for all modalities."
),
),
UnifiedArchitecture(
name="Emu3 (BAAI)",
approach="Predict next visual token",
modalities=["text", "image", "video"],
understands=["text", "image"],
generates=["text", "image", "video"],
key_innovation=(
"Pure next-token prediction for all modalities. "
"SBER-MoVQGAN tokenizer for images/video. "
"No diffusion -- autoregressive only. "
"Matches diffusion quality on generation benchmarks."
),
),
UnifiedArchitecture(
name="Transfusion (Meta)",
approach="Hybrid: autoregressive text + diffusion images",
modalities=["text", "image"],
understands=["text", "image"],
generates=["text", "image"],
key_innovation=(
"Single transformer with dual training objectives. "
"Text tokens: next-token prediction loss. "
"Image tokens: diffusion denoising loss. "
"Best of both worlds: LM for text, diffusion quality "
"for images."
),
),
]
Implementing a Unified Token Space
class UnifiedTokenizer:
"""
Unified tokenizer that handles both text and images
in a single vocabulary.
Vocabulary layout:
[0, 32000) - Text tokens (BPE)
[32000, 32002) - Special tokens (IMG_START, IMG_END)
[32002, 40194) - Image tokens (VQ codebook, 8192 entries)
[40194, 40196) - Special tokens (VID_START, VID_END)
"""
def __init__(self, text_tokenizer, image_tokenizer):
self.text_tokenizer = text_tokenizer
self.image_tokenizer = image_tokenizer
# Vocabulary offsets
self.text_vocab_size = 32000
self.img_start_id = 32000
self.img_end_id = 32001
self.image_offset = 32002
self.image_vocab_size = image_tokenizer.quantizer.n_codes
self.vid_start_id = self.image_offset + self.image_vocab_size
self.vid_end_id = self.vid_start_id + 1
self.total_vocab_size = self.vid_end_id + 1
def encode_text(self, text):
"""Encode text to token IDs."""
return self.text_tokenizer.encode(text)
def encode_image(self, image):
"""
Encode image to token IDs in the unified vocabulary.
Returns: [IMG_START] + image_tokens + [IMG_END]
"""
indices, _ = self.image_tokenizer.encode(image)
# Flatten 2D grid to 1D sequence (raster scan order)
flat_indices = indices.flatten().tolist()
# Offset into unified vocabulary
image_tokens = [idx + self.image_offset for idx in flat_indices]
return [self.img_start_id] + image_tokens + [self.img_end_id]
def encode_interleaved(self, items):
"""
Encode an interleaved sequence of text and images.
items: list of {"type": "text", "content": "..."} or
{"type": "image", "content": tensor}
"""
token_ids = []
for item in items:
if item["type"] == "text":
token_ids.extend(self.encode_text(item["content"]))
elif item["type"] == "image":
token_ids.extend(self.encode_image(item["content"]))
return token_ids
def decode_tokens(self, token_ids):
"""
Decode a token sequence back to text and images.
Returns a list of {"type": ..., "content": ...}
"""
results = []
current_text_tokens = []
current_image_tokens = []
in_image = False
for token_id in token_ids:
if token_id == self.img_start_id:
# Flush text
if current_text_tokens:
text = self.text_tokenizer.decode(
current_text_tokens
)
results.append({"type": "text", "content": text})
current_text_tokens = []
in_image = True
continue
elif token_id == self.img_end_id:
# Decode image
if current_image_tokens:
indices = torch.tensor([
t - self.image_offset
for t in current_image_tokens
]).unsqueeze(0)
image = self.image_tokenizer.tokens_to_image(
indices
)
results.append({"type": "image", "content": image})
current_image_tokens = []
in_image = False
continue
if in_image:
current_image_tokens.append(token_id)
else:
current_text_tokens.append(token_id)
# Flush remaining text
if current_text_tokens:
text = self.text_tokenizer.decode(current_text_tokens)
results.append({"type": "text", "content": text})
return results
def get_token_type(self, token_id):
"""Determine the type of a token."""
if token_id < self.text_vocab_size:
return "text"
elif token_id == self.img_start_id:
return "img_start"
elif token_id == self.img_end_id:
return "img_end"
elif token_id < self.vid_start_id:
return "image"
else:
return "special"
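The vocabulary layout documented in UnifiedTokenizer can be verified with plain arithmetic. A standalone sketch of the same boundaries:

```python
TEXT_VOCAB = 32000
IMG_START, IMG_END = 32000, 32001
IMAGE_OFFSET, IMAGE_VOCAB = 32002, 8192
VID_START = IMAGE_OFFSET + IMAGE_VOCAB   # first video special token
TOTAL_VOCAB = VID_START + 2              # VID_START and VID_END

def token_type(token_id):
    # Mirrors UnifiedTokenizer.get_token_type above.
    if token_id < TEXT_VOCAB:
        return "text"
    if token_id == IMG_START:
        return "img_start"
    if token_id == IMG_END:
        return "img_end"
    if token_id < VID_START:
        return "image"
    return "special"

print(VID_START, TOTAL_VOCAB)   # 40194 40196
```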
Training Unified Models
Interleaved Training Data
Unified models are trained on interleaved sequences of text and images. The key challenge is balancing the two modalities: too much text and the model forgets how to generate images, too many images and text quality degrades.
class InterleavedDataLoader:
"""
Data loader for interleaved text-image training.
Data sources:
1. Text-only: books, web text (like standard LLM pretraining)
2. Image-text pairs: LAION, CC12M (captioned images)
3. Interleaved: web pages with inline images (MMC4, OBELICS)
4. Image-only: ImageNet (image generation without text)
"""
def __init__(self, tokenizer, sequence_length=4096):
self.tokenizer = tokenizer
self.sequence_length = sequence_length
# Mixing ratios (fraction of batches)
self.mix_ratios = {
"text_only": 0.50, # Standard text pretraining
"image_text_pairs": 0.25, # Captioned images
"interleaved": 0.15, # Web pages with images
"image_only": 0.10, # Image generation training
}
def create_text_only_example(self, text):
"""Standard text pretraining example."""
tokens = self.tokenizer.encode_text(text)
return self._truncate_and_pad(tokens)
def create_image_text_pair(self, image, caption):
"""Image-caption pair: [caption] [IMG_START] ... [IMG_END]."""
text_tokens = self.tokenizer.encode_text(caption)
image_tokens = self.tokenizer.encode_image(image)
tokens = text_tokens + image_tokens
return self._truncate_and_pad(tokens)
def create_interleaved_example(self, items):
"""
Interleaved text-image sequence from a web page.
items: list of text spans and images in document order.
"""
tokens = self.tokenizer.encode_interleaved(items)
return self._truncate_and_pad(tokens)
def create_image_only_example(self, image):
"""Image-only: [IMG_START] ... [IMG_END]."""
tokens = self.tokenizer.encode_image(image)
return self._truncate_and_pad(tokens)
def _truncate_and_pad(self, tokens):
"""Truncate or pad to sequence_length."""
if len(tokens) > self.sequence_length:
tokens = tokens[:self.sequence_length]
else:
tokens = tokens + [0] * (self.sequence_length - len(tokens))
return torch.tensor(tokens)
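The loader defines mixing ratios but not the sampler that uses them. A sketch of how batches might be drawn according to those ratios (the sampler is an assumption, not part of the loader above):

```python
import random

MIX_RATIOS = {
    "text_only": 0.50,
    "image_text_pairs": 0.25,
    "interleaved": 0.15,
    "image_only": 0.10,
}

def sample_source(rng):
    # Weighted draw of the data source for the next batch.
    names = list(MIX_RATIOS)
    return rng.choices(names, weights=[MIX_RATIOS[n] for n in names])[0]

rng = random.Random(0)
counts = {name: 0 for name in MIX_RATIOS}
for _ in range(10_000):
    counts[sample_source(rng)] += 1
print(counts)   # roughly 5000 / 2500 / 1500 / 1000
```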
Unified Model Comparison
| Model | Text Quality (MMLU) | Image Gen (FID) | Image Understanding | Approach |
|---|---|---|---|---|
| Chameleon 34B | 55.4 | 6.80 | 67.2% (VQAv2) | All autoregressive |
| Emu3 8B | N/A | 4.20 | 72.1% (VQAv2) | All autoregressive |
| Gemini 1.5 Pro | 81.9 | ~5.0 (est.) | 82.4% (VQAv2) | Mixed (AR + diffusion) |
| Transfusion 7B | ~52 (est.) | ~5.5 (est.) | 68% (est.) | Hybrid loss |
| GPT-4o (for reference) | 86.4 | N/A (not AR) | 77.2% (VQAv2) | Understanding only + DALL-E 3 |
Key Takeaways
Multimodal generation is converging toward unified architectures where a single model both understands and generates across text, images, and video.
The technical landscape:
- Image tokenization is the enabler: VQ-VAE compresses a 256x256 image from 196,608 values to 256 tokens. Codebook quality (measured by reconstruction FID) directly determines generation quality. Modern tokenizers (Cosmos, Open-MAGVIT2) achieve near-lossless reconstruction with 64K-256K codebooks.
- Autoregressive vs diffusion: diffusion currently produces higher-quality images (lower FID), but autoregressive models are closing the gap. The advantage of autoregressive: it unifies with text generation naturally. The advantage of diffusion: parallel denoising is faster than sequential token generation.
- Unified models trade quality for generality: Chameleon and Emu3 can both understand and generate images, but their text quality (MMLU) and image quality (FID) are each lower than specialized models. The gap is shrinking with scale: larger unified models approach specialized model quality.
- Video multiplies the challenge: a 4-second 256x256 video at 8 FPS has 32 frames, producing 2048-8192 tokens. This is 8-32x more tokens than a single image, proportionally increasing compute. Temporal compression (3D convolutions) and frame interpolation reduce the token count.
- Interleaved training data matters: web pages with inline images (MMC4, OBELICS) teach the model the relationship between text and images in context. Caption-image pairs alone teach correspondence but not the more complex text-image interleaving patterns that humans produce.
The compute equation: generating a 256x256 image autoregressively requires 256 forward passes through the transformer. At the same model size, this is equivalent to generating 256 text tokens. For a 7B parameter model on an H100, that is approximately 5 seconds per image. Diffusion with 50 steps on the same hardware takes approximately 2-5 seconds (each step processes the full latent in parallel). The latency gap is modest; the quality gap is what drives architecture choice.