Multimodal LLMs must combine information from different modalities (text, images, video, audio) into a single representation that the transformer can reason over. The architectural choice of HOW to combine them — the fusion strategy — determines both quality and inference cost. This post covers the three dominant approaches with implementation code.
Late Fusion (Projector-Based)
The simplest and most common approach (LLaVA, Qwen2-VL):
- Encode each modality with its own encoder (ViT for images, Whisper for audio)
- Project encoder outputs to LLM’s embedding dimension via a learned linear layer
- Concatenate projected embeddings with text token embeddings
- Feed the combined sequence to the LLM as if it were all text
```python
import torch
import torch.nn as nn

def insert_at_positions(text_embeds, visual_tokens, pos):
    # Simplest case: one image per sample, spliced in at a single position.
    # Real implementations handle multiple images at per-sample positions.
    return torch.cat([text_embeds[:, :pos], visual_tokens, text_embeds[:, pos:]], dim=1)

class LateFusionVLM(nn.Module):
    """Late fusion: encode image separately, project, concatenate with text."""

    def __init__(self, llm, vit_encoder, vit_dim=1024, llm_dim=4096):
        super().__init__()
        self.llm = llm
        self.vit = vit_encoder
        # Simple MLP projector: ViT dim -> LLM dim
        self.projector = nn.Sequential(
            nn.Linear(vit_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, text_ids, images, image_positions):
        # Step 1: Encode images
        with torch.no_grad():  # ViT often frozen
            visual_features = self.vit(images)  # [B, num_patches, vit_dim]
        # Step 2: Project to LLM dimension
        visual_tokens = self.projector(visual_features)  # [B, num_patches, llm_dim]
        # Step 3: Get text embeddings
        text_embeds = self.llm.embed_tokens(text_ids)  # [B, seq_len, llm_dim]
        # Step 4: Insert visual tokens at image_positions
        combined = insert_at_positions(text_embeds, visual_tokens, image_positions)
        # Step 5: Standard LLM forward on combined sequence
        return self.llm(inputs_embeds=combined)
```
Advantage: Simple. The LLM doesn’t need any architectural changes — it just sees a longer sequence of embeddings. Any text LLM can be converted to a VLM by adding an encoder + projector.
Disadvantage: The LLM has no special mechanism for attending to visual information differently from text. All cross-modal reasoning happens through standard self-attention. This works but may not capture fine-grained visual details.
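There is also a concrete memory cost: every visual token is a full sequence position, so it occupies KV cache in every layer. A back-of-the-envelope sketch, assuming LLaVA-1.5-style shapes (a 336×336 image at patch size 14 gives 576 visual tokens; 32 layers, 32 KV heads of head dimension 128, fp16; all illustrative numbers):

```python
# Rough KV-cache cost of late fusion: each visual token stores a key
# and a value vector in every layer, just like a text token.
visual_tokens = 576        # 336x336 image, patch size 14 -> 24 * 24 patches
n_layers = 32
n_kv_heads = 32
d_head = 128
bytes_per_elem = 2         # fp16

# K and V (factor of 2) per token, per layer, per head
kv_bytes = visual_tokens * n_layers * n_kv_heads * d_head * 2 * bytes_per_elem
print(f"{kv_bytes / 1e6:.0f} MB of KV cache per image")  # 302 MB
```

This is why the cross-attention designs below, which keep visual features out of the self-attention KV cache, are attractive for long multi-image contexts.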
Cross-Attention Fusion
The LLM has dedicated cross-attention layers that attend to encoder outputs (Flamingo, Llama 3.2 Vision):
```python
class CrossAttentionLayer(nn.Module):
    """Cross-attention: LLM queries attend to visual encoder outputs."""

    def __init__(self, llm_dim=4096, encoder_dim=1024, n_heads=32):
        super().__init__()
        self.q_proj = nn.Linear(llm_dim, llm_dim)      # Query from LLM hidden states
        self.k_proj = nn.Linear(encoder_dim, llm_dim)  # Key from encoder outputs
        self.v_proj = nn.Linear(encoder_dim, llm_dim)  # Value from encoder outputs
        self.o_proj = nn.Linear(llm_dim, llm_dim)
        self.n_heads = n_heads
        self.d_head = llm_dim // n_heads

    def forward(self, llm_hidden, encoder_output):
        """
        llm_hidden: [B, seq_len, llm_dim] — current LLM hidden states
        encoder_output: [B, num_patches, encoder_dim] — visual features
        """
        B, S, D = llm_hidden.shape
        _, P, _ = encoder_output.shape
        Q = self.q_proj(llm_hidden).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        K = self.k_proj(encoder_output).view(B, P, self.n_heads, self.d_head).transpose(1, 2)
        V = self.v_proj(encoder_output).view(B, P, self.n_heads, self.d_head).transpose(1, 2)
        # Cross-attention: LLM tokens attend to visual patches
        scores = (Q @ K.transpose(-1, -2)) / (self.d_head ** 0.5)  # [B, H, S, P]
        weights = torch.softmax(scores, dim=-1)
        output = (weights @ V).transpose(1, 2).contiguous().view(B, S, D)
        return self.o_proj(output)
```
Cross-attention is inserted every N layers (e.g., every 4th layer in Flamingo):
```python
class CrossAttentionTransformerBlock(nn.Module):
    def __init__(self, llm_layer, cross_attn_layer):
        super().__init__()
        self.self_attn = llm_layer          # Standard self-attention
        self.cross_attn = cross_attn_layer  # Visual cross-attention
        self.gate = nn.Parameter(torch.tensor(0.0))  # Learnable gate

    def forward(self, x, encoder_output):
        # Self-attention (standard)
        x = x + self.self_attn(x)
        # Cross-attention (gated — starts at zero, gradually activates)
        x = x + torch.tanh(self.gate) * self.cross_attn(x, encoder_output)
        return x
Advantage: Dedicated mechanism for cross-modal attention. More sequence-efficient than concatenating visual tokens: visual features never enter the self-attention sequence, so they consume neither context-length budget nor self-attention KV cache. The LLM reaches them through the separate cross-attention path instead.
Disadvantage: Requires architectural changes to the LLM (can’t just add a projector). Cross-attention layers add parameters and compute. Harder to initialize from a text-only checkpoint.
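To make the "every Nth layer" wiring concrete, here is a minimal sketch of an interleaved decoder stack. It is independent of the block class above but mirrors the same zero-initialized gating trick; `make_cross_attn` is a hypothetical factory returning a module with signature `(hidden, encoder_output)`:

```python
import torch
import torch.nn as nn

class InterleavedDecoder(nn.Module):
    """Insert gated cross-attention before every Nth self-attention layer."""

    def __init__(self, llm_layers, make_cross_attn, every_n=4):
        super().__init__()
        self.layers = nn.ModuleList(llm_layers)
        # Cross-attention only at layers 0, N, 2N, ... (Flamingo-style)
        self.cross = nn.ModuleDict({
            str(i): make_cross_attn()
            for i in range(len(llm_layers)) if i % every_n == 0
        })
        # Gates start at zero: tanh(0) = 0, so at init the model behaves
        # exactly like the original text-only stack.
        self.gates = nn.ParameterDict({
            k: nn.Parameter(torch.tensor(0.0)) for k in self.cross
        })

    def forward(self, x, encoder_output):
        for i, layer in enumerate(self.layers):
            key = str(i)
            if key in self.cross:
                x = x + torch.tanh(self.gates[key]) * self.cross[key](x, encoder_output)
            x = x + layer(x)
        return x
```

This preserves the text-only checkpoint's behavior at initialization, which is exactly why Flamingo could bolt vision onto a frozen Chinchilla LLM without destroying its language ability.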
Early Fusion (Native Multimodal)
All modalities are tokenized and processed together from the start (Gemini, Chameleon):
```python
class EarlyFusionTokenizer:
    """Convert all modalities to a unified token sequence."""

    def __init__(self, text_tokenizer, image_tokenizer, audio_tokenizer):
        self.text_tok = text_tokenizer
        self.image_tok = image_tokenizer  # VQ-VAE or similar
        self.audio_tok = audio_tokenizer  # e.g., EnCodec

    def encode(self, text, images, audio):
        # IMAGE_START/END_TOKEN and AUDIO_START/END_TOKEN are reserved
        # special-token ids in the shared vocabulary.
        tokens = []
        # Text tokens
        tokens.extend(self.text_tok.encode(text))
        # Image tokens (quantized to discrete codes)
        for img in images:
            tokens.append(IMAGE_START_TOKEN)
            tokens.extend(self.image_tok.encode(img))  # e.g., 256 discrete codes
            tokens.append(IMAGE_END_TOKEN)
        # Audio tokens
        for aud in audio:
            tokens.append(AUDIO_START_TOKEN)
            tokens.extend(self.audio_tok.encode(aud))
            tokens.append(AUDIO_END_TOKEN)
        return tokens
```
Advantage: Truly unified — the model processes all modalities identically. Can generate images and audio, not just consume them. Most flexible architecture.
Disadvantage: Requires training from scratch (can’t adapt a text-only model). Image tokenization (VQ-VAE) is lossy. Very long sequences (a single image = 256-1024 discrete tokens).
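To put the sequence-length problem in numbers, a quick sketch with illustrative rates (a 16×16 VQ code grid per image, roughly 75 audio tokens per second for an EnCodec-style codec; both assumptions, not measured values):

```python
# Early fusion turns every input into sequence positions.
text_tokens = 200
image_codes = 16 * 16        # one VQ grid per image -> 256 discrete codes
audio_tokens_per_sec = 75    # illustrative EnCodec-style rate

n_images, audio_secs = 4, 10
total = (text_tokens
         + n_images * (image_codes + 2)     # +2 for start/end delimiters
         + audio_secs * audio_tokens_per_sec + 2)
print(total)  # 200 + 4*258 + 752 = 1984
```

Four images and ten seconds of audio already cost nearly 2,000 positions before any meaningful text, which is why early-fusion models need long-context training from day one.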
Fusion Strategy Comparison
| Strategy | Can Generate Images? | Adapt from Text LLM? | KV Cache Impact | Quality (VQA) |
|---|---|---|---|---|
| Late Fusion (LLaVA) | No | Yes (add projector) | Visual tokens in KV cache | Good (78%) |
| Cross-Attention (Flamingo) | No | Partially (add layers) | No visual tokens in self-attn KV | Good (79%) |
| Early Fusion (Gemini) | Yes | No (train from scratch) | All tokens in same KV | Best (82%) |
Late fusion dominates open-source (LLaVA, Qwen2-VL) because it lets you upgrade the LLM and vision encoder independently. Cross-attention is used by Flamingo and Llama 3.2 Vision for its KV-cache efficiency on visual tokens. Early fusion is mostly Google (Gemini) and Meta (Chameleon): it requires training from scratch, which most organizations cannot afford.
Reviewer Agent Validation
Challenge: Implement a minimal late-fusion forward pass that takes text token IDs and pre-computed visual features, inserts visual tokens at a specified position, and runs the LLM.
Expected:
```python
def late_fusion_forward(llm, projector, text_ids, visual_features, insert_pos):
    text_embeds = llm.embed_tokens(text_ids)    # [B, S, D]
    visual_embeds = projector(visual_features)  # [B, P, D]
    # Insert visual tokens
    combined = torch.cat([
        text_embeds[:, :insert_pos],
        visual_embeds,
        text_embeds[:, insert_pos:]
    ], dim=1)
    return llm(inputs_embeds=combined)
```
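A quick smoke test of the expected solution with stub modules (`DummyLLM` is a hypothetical stand-in exposing only the two hooks the function touches, not a real model API):

```python
import torch
import torch.nn as nn

class DummyLLM(nn.Module):
    """Stub with the two hooks late_fusion_forward needs."""
    def __init__(self, vocab=100, d=16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, d)
    def forward(self, inputs_embeds):
        return inputs_embeds  # a real LLM would run its transformer layers here

def late_fusion_forward(llm, projector, text_ids, visual_features, insert_pos):
    text_embeds = llm.embed_tokens(text_ids)    # [B, S, D]
    visual_embeds = projector(visual_features)  # [B, P, D]
    combined = torch.cat([text_embeds[:, :insert_pos],
                          visual_embeds,
                          text_embeds[:, insert_pos:]], dim=1)
    return llm(inputs_embeds=combined)

llm = DummyLLM()
projector = nn.Linear(8, 16)  # ViT dim 8 -> LLM dim 16 (toy sizes)
out = late_fusion_forward(llm, projector,
                          torch.randint(0, 100, (2, 6)),   # 6 text tokens
                          torch.randn(2, 4, 8),            # 4 visual patches
                          insert_pos=3)
print(out.shape)  # torch.Size([2, 10, 16]): 6 text + 4 visual positions
```

The output sequence length is `S + P`, which is the whole story of late fusion: visual content is just more positions.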