A 448x448 image processed by a ViT encoder produces 1,024 visual tokens — as many as 200 words of text. Add three images to a prompt and you have 3,000 visual tokens before the user types a single question. These visual tokens must flow through the same continuous batching and KV cache machinery as text, but they arrive differently: an image needs a full ViT forward pass (4-12ms depending on resolution and encoder model) before it becomes usable by the LLM. vLLM v1 handles this by integrating vision encoders as a preprocessing stage that runs asynchronously with text processing, producing visual tokens that merge seamlessly with text tokens in the LLM’s input sequence.
The Visual Token Pipeline
End-to-End Architecture
The pipeline from raw image bytes to generated text:
Raw Image (JPEG/PNG bytes)
|
v
Preprocessing (resize, normalize, patch extraction)
|
v
ViT Encoder (forward pass, produces visual features)
|
v
Projection Layer (maps ViT features to LLM embedding space)
|
v
Visual Token Sequence (N tokens of dimension d_model)
|
v
Token Injection (insert visual tokens into LLM input sequence)
|
v
LLM Forward Pass (attention over text + visual tokens)
|
v
Generated Text
Each stage has specific compute and memory characteristics:
Pipeline Stage Costs (LLaVA-1.5 7B, 336x336 image, H100)
| Stage | Time (ms) | Memory | Output Shape |
|---|---|---|---|
| Image decode (JPEG) | 0.5 | ~1 MB | 336x336x3 uint8 |
| Preprocessing | 1.2 | ~2 MB | 336x336x3 float32 |
| ViT encoder | 4.8 | ~150 MB (weights) | 576 x 1024 |
| Projection (MLP) | 0.3 | ~50 MB (weights) | 576 x 4096 |
| Token injection | 0.1 | Negligible | text_len + 576 x 4096 |
| LLM prefill (text + visual) | 28.0 | KV cache | 1 x vocab_size |
The ViT encoder adds ~5ms to the prefill cost. The real cost is the increased sequence length: 576 visual tokens are equivalent to 576 text tokens for attention computation.
Image Preprocessing
The Preprocessing Pipeline
Each vision model specifies its own preprocessing requirements. The common steps are:
import torch
from torchvision import transforms
from PIL import Image
import io
class ImagePreprocessor:
    """Turn raw image bytes into normalized ViT input tensors."""

    def __init__(self, config):
        # Target square resolution and ViT patch geometry, e.g. 336 / 14.
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.mean = config.image_mean  # per-channel normalization mean
        self.std = config.image_std  # per-channel normalization std
        # 336 / 14 = 24 patches per side -> 24^2 = 576 patches total.
        self.num_patches = (self.image_size // self.patch_size) ** 2
        resize = transforms.Resize(
            (self.image_size, self.image_size),
            interpolation=transforms.InterpolationMode.BICUBIC,
        )
        normalize = transforms.Normalize(mean=self.mean, std=self.std)
        # HWC uint8 -> resized image -> CHW float32 in [0, 1] -> normalized.
        self.transform = transforms.Compose(
            [resize, transforms.ToTensor(), normalize]
        )

    def preprocess(self, image_bytes: bytes) -> torch.Tensor:
        """Decode raw image bytes and return a [3, H, W] normalized tensor."""
        decoded = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return self.transform(decoded)

    def preprocess_batch(self, images: list) -> torch.Tensor:
        """Preprocess several images and stack them into [B, 3, H, W]."""
        return torch.stack([self.preprocess(img) for img in images])
Dynamic Resolution (Qwen-VL, InternVL Style)
Some models support dynamic resolution: instead of resizing all images to a fixed size, they divide the image into tiles at the native resolution:
class DynamicResolutionPreprocessor:
    """Dynamic resolution preprocessing for models like Qwen-VL."""

    def __init__(self, config):
        self.patch_size = config.patch_size  # ViT patch edge, e.g. 14
        self.min_patches = config.min_dynamic_patch  # minimum tile count, e.g. 1
        self.max_patches = config.max_dynamic_patch  # maximum tile count, e.g. 12
        self.image_size = config.image_size  # per-tile resolution, e.g. 448

    def find_best_grid(self, width: int, height: int):
        """Pick the (rows, cols) tiling whose aspect ratio best matches the image."""
        target_aspect = width / height
        chosen = (1, 1)
        smallest_error = float('inf')
        for tile_count in range(self.min_patches, self.max_patches + 1):
            # Enumerate exact factorizations tile_count == rows * cols.
            for n_cols in range(1, tile_count + 1):
                n_rows = tile_count // n_cols
                if n_rows * n_cols != tile_count or n_rows == 0:
                    continue
                error = abs(n_cols / n_rows - target_aspect)
                # Strictly-better only: earlier (smaller) grids win ties.
                if error < smallest_error:
                    smallest_error = error
                    chosen = (n_rows, n_cols)
        return chosen

    def preprocess(self, image_bytes: bytes):
        """Preprocess with dynamic tiling; returns (pixel_values, token count)."""
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        n_rows, n_cols = self.find_best_grid(*image.size)
        tile_total = n_rows * n_cols
        # Resize so the image divides evenly into the chosen grid.
        image = image.resize(
            (n_cols * self.image_size, n_rows * self.image_size), Image.BICUBIC
        )
        tiles = []
        for r in range(n_rows):
            for c in range(n_cols):
                box = (
                    c * self.image_size,
                    r * self.image_size,
                    (c + 1) * self.image_size,
                    (r + 1) * self.image_size,
                )
                tiles.append(self._normalize(image.crop(box)))
        # A low-resolution global view is appended after the detail tiles.
        global_view = image.resize(
            (self.image_size, self.image_size), Image.BICUBIC
        )
        tiles.append(self._normalize(global_view))
        pixel_values = torch.stack(tiles)  # [tile_total + 1, 3, H, W]
        patches_per_tile = (self.image_size // self.patch_size) ** 2
        return pixel_values, (tile_total + 1) * patches_per_tile

    def _normalize(self, image):
        # CLIP-style normalization constants.
        tensor = transforms.ToTensor()(image)
        return transforms.Normalize(
            mean=[0.4815, 0.4578, 0.4082],
            std=[0.2686, 0.2613, 0.2758],
        )(tensor)
Dynamic resolution produces a variable number of visual tokens per image. A small thumbnail might produce 256 tokens; a large detailed image might produce 3072+ tokens.
A 1920x1080 image with 12 tiles at 448x448 each produces 13,312 visual tokens (13 tiles including the global thumbnail, at 1,024 patches per 448x448 tile). This is equivalent to a 13K-token text prompt for attention computation. Dynamic resolution can dramatically increase prefill latency and KV cache consumption.
The Vision Encoder
ViT Architecture
The Vision Transformer (ViT) processes the preprocessed image by:
- Splitting it into non-overlapping patches
- Embedding each patch into a vector
- Processing through transformer layers
- Outputting a sequence of feature vectors (one per patch)
class VisionEncoder(torch.nn.Module):
    """Simplified ViT encoder for multimodal LLMs."""

    def __init__(self, config):
        super().__init__()
        self.patch_size = config.patch_size  # e.g. 14
        self.hidden_size = config.vision_hidden_size  # e.g. 1024 (ViT-L)
        self.num_layers = config.vision_num_layers  # e.g. 24
        self.num_heads = config.vision_num_heads  # e.g. 16
        # Non-overlapping patch embedding: a strided conv maps each
        # patch_size x patch_size patch to one hidden_size vector.
        self.patch_embed = torch.nn.Conv2d(
            in_channels=3,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        # One learnable position per patch, plus one for the CLS token.
        grid_positions = (config.image_size // self.patch_size) ** 2
        self.position_embed = torch.nn.Embedding(
            grid_positions + 1,
            self.hidden_size,
        )
        # Learnable classification token prepended to the patch sequence.
        self.cls_token = torch.nn.Parameter(
            torch.randn(1, 1, self.hidden_size)
        )
        self.layers = torch.nn.ModuleList(
            VisionTransformerLayer(config) for _ in range(self.num_layers)
        )
        self.layer_norm = torch.nn.LayerNorm(self.hidden_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode images into per-patch feature vectors.

        Args:
            pixel_values: [batch_size, 3, H, W]
        Returns:
            features: [batch_size, num_patches, hidden_size]
        """
        batch = pixel_values.shape[0]
        # [B, 3, H, W] -> [B, C, H/P, W/P] -> [B, N, C]
        tokens = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        # Prepend the CLS token, then add position embeddings.
        tokens = torch.cat(
            [self.cls_token.expand(batch, -1, -1), tokens], dim=1
        )
        pos_ids = torch.arange(tokens.shape[1], device=tokens.device)
        hidden = tokens + self.position_embed(pos_ids)
        for block in self.layers:
            hidden = block(hidden)
        hidden = self.layer_norm(hidden)
        # Drop the CLS token -- this model returns patch features only
        # (some models keep CLS; this one does not).
        return hidden[:, 1:]
Feature Selection: Which Layers to Use
Different multimodal models extract features from different ViT layers:
class FeatureExtractor:
    """Extract features from specific ViT layers.

    strategy:
        "last"         -- run the full encoder and use its output (LLaVA).
        "intermediate" -- re-run the encoder stage by stage and collect the
                          hidden states of ``layers_to_use`` (indices may be
                          negative, e.g. [-2] for second-to-last).
    """

    def __init__(self, strategy="last", layers_to_use=None):
        self.strategy = strategy
        self.layers_to_use = layers_to_use  # required for "intermediate"

    def extract(self, encoder, pixel_values):
        """Run encoder and extract features from the configured layers.

        Raises:
            ValueError: for an unknown strategy, or when "intermediate" is
                used without a non-empty layers_to_use (or none of the
                requested layers exist).
        """
        if self.strategy == "last":
            # LLaVA: use last layer output.
            return encoder(pixel_values)
        if self.strategy != "intermediate":
            # BUG FIX: previously an unknown strategy silently returned None.
            raise ValueError(f"Unknown feature strategy: {self.strategy!r}")
        # BUG FIX: previously `i in None` raised TypeError when layers_to_use
        # was left at its default with the "intermediate" strategy.
        if not self.layers_to_use:
            raise ValueError(
                "strategy='intermediate' requires a non-empty layers_to_use"
            )
        # Replay the embedding stage (patch embed + CLS + positions).
        hidden = encoder.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        cls = encoder.cls_token.expand(hidden.shape[0], -1, -1)
        hidden = torch.cat([cls, hidden], dim=1)
        positions = torch.arange(hidden.shape[1], device=hidden.device)
        hidden = hidden + encoder.position_embed(positions)
        wanted = set(self.layers_to_use)
        num_layers = len(encoder.layers)
        selected = []
        for i, layer in enumerate(encoder.layers):
            hidden = layer(hidden)
            # Accept both positive and negative layer indices.
            if i in wanted or (i - num_layers) in wanted:
                selected.append(hidden)
        if not selected:
            raise ValueError(
                f"layers_to_use={self.layers_to_use} matched no encoder layer"
            )
        # One layer: return its patch features (CLS stripped); several
        # layers: concatenate along the feature dimension.
        if len(selected) == 1:
            return selected[0][:, 1:]
        return torch.cat([h[:, 1:] for h in selected], dim=-1)
The Projection Layer
Mapping ViT Features to LLM Space
The ViT hidden dimension (e.g., 1024 for ViT-L) differs from the LLM hidden dimension (e.g., 4096 for Llama 7B). A projection layer bridges this gap:
class VisionProjection(torch.nn.Module):
    """Project ViT features into LLM embedding space.

    Supported projection types:
        "linear"          -- LLaVA v1: single linear layer.
        "mlp"             -- LLaVA v1.5+: two-layer MLP with GELU.
        "cross_attention" -- Qwen-VL style learnable-query resampler that
                             always emits a fixed number of tokens.
    """

    def __init__(self, config, projection_type="mlp"):
        super().__init__()
        self.projection_type = projection_type
        vision_dim = config.vision_hidden_size  # e.g. 1024
        llm_dim = config.hidden_size  # e.g. 4096
        if projection_type == "linear":
            self.proj = torch.nn.Linear(vision_dim, llm_dim)
        elif projection_type == "mlp":
            self.proj = torch.nn.Sequential(
                torch.nn.Linear(vision_dim, llm_dim),
                torch.nn.GELU(),
                torch.nn.Linear(llm_dim, llm_dim),
            )
        elif projection_type == "cross_attention":
            self.num_queries = config.num_query_tokens  # e.g. 256
            self.queries = torch.nn.Parameter(
                torch.randn(1, self.num_queries, llm_dim)
            )
            self.cross_attn = torch.nn.MultiheadAttention(
                embed_dim=llm_dim,
                kdim=vision_dim,
                vdim=vision_dim,
                num_heads=16,
                batch_first=True,
            )
            # NOTE: MultiheadAttention already projects K/V via kdim/vdim,
            # so this layer is unused in forward(); kept only so existing
            # checkpoints with a kv_proj key still load.
            self.kv_proj = torch.nn.Linear(vision_dim, llm_dim)
        else:
            # BUG FIX: previously an unknown type built no layers and
            # forward() silently returned None.
            raise ValueError(f"Unknown projection_type: {projection_type!r}")

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            visual_features: [batch_size, num_patches, vision_dim]
        Returns:
            projected: [batch_size, num_tokens, llm_dim] -- num_tokens is
            num_patches for linear/mlp, num_queries for cross_attention.
        """
        if self.projection_type in ("linear", "mlp"):
            return self.proj(visual_features)
        # __init__ guarantees the only remaining type is "cross_attention".
        batch_size = visual_features.shape[0]
        queries = self.queries.expand(batch_size, -1, -1)
        projected, _ = self.cross_attn(
            queries,
            visual_features,
            visual_features,
        )
        return projected  # [B, num_queries, llm_dim]
The projection type affects the number of visual tokens:
Projection Types and Visual Token Counts
| Method | Visual Tokens (336x336) | Visual Tokens (672x672) | Token Reduction |
|---|---|---|---|
| Linear/MLP (LLaVA) | 576 | 2304 | None (1:1) |
| Cross-attention 256q (Qwen-VL) | 256 | 256 | Fixed queries |
| Pixel shuffle 2x (InternVL) | 144 | 576 | 4:1 spatial merge |
| Avg pool 2x2 (custom) | 144 | 576 | 4:1 pooling |
Pixel Shuffle Token Reduction
Some models reduce the number of visual tokens by merging adjacent patches:
class PixelShuffleProjection(torch.nn.Module):
    """Reduce visual tokens via pixel shuffle (InternVL style).

    Each output token absorbs a ``downsample x downsample`` neighborhood of
    input patches, cutting the token count by ``downsample ** 2`` while
    growing the per-token feature size before projection.
    """

    def __init__(self, vision_dim, llm_dim, downsample=2):
        super().__init__()
        self.downsample = downsample
        # Each merged token carries downsample^2 concatenated patch features.
        merged_dim = vision_dim * downsample * downsample
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(merged_dim, llm_dim),
            torch.nn.GELU(),
            torch.nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            features: [B, N, C] where N = H_p * W_p patches (square grid)
        Returns:
            reduced: [B, N / downsample^2, llm_dim]
        Raises:
            ValueError: if N is not a perfect square, or the grid side is
                not divisible by the downsample factor.
        """
        B, N, C = features.shape
        side = int(N ** 0.5)
        # IMPROVEMENT: validate up front instead of failing inside view()
        # with an opaque shape error.
        if side * side != N:
            raise ValueError(f"Expected a square patch grid, got N={N}")
        d = self.downsample
        if side % d != 0:
            raise ValueError(
                f"Grid side {side} is not divisible by downsample={d}"
            )
        # [B, N, C] -> [B, H, W, C] -> [B, H/d, d, W/d, d, C]
        grid = features.view(B, side, side, C)
        grid = grid.view(B, side // d, d, side // d, d, C)
        # Bring the d x d neighborhood next to the channel dim, then flatten.
        grid = grid.permute(0, 1, 3, 2, 4, 5)  # [B, H/d, W/d, d, d, C]
        merged = grid.reshape(B, (side // d) * (side // d), d * d * C)
        return self.proj(merged)
Token Injection into the LLM
The Injection Mechanism
After the ViT produces visual tokens, they must be inserted into the LLM’s input sequence at the correct position. Multimodal models use a special placeholder token (e.g., <image>) that marks where visual tokens should go:
Text input: "User: <image>\nDescribe this image."
Token IDs: [User, :, <image>, \n, Describe, this, image, .]
After injection:
Token IDs: [User, :, [vis_0, vis_1, ..., vis_575], \n, Describe, this, image, .]
The injection replaces the single <image> token with visual token embeddings:
class TokenInjector:
    """Inject visual tokens into the LLM input sequence."""

    def __init__(self, image_token_id: int):
        # Token id of the <image> placeholder in the prompt.
        self.image_token_id = image_token_id

    def inject(
        self,
        input_ids: torch.Tensor,  # [seq_len]
        text_embeddings: torch.Tensor,  # [seq_len, d_model]
        visual_tokens: torch.Tensor,  # [num_visual, d_model]
    ) -> tuple:
        """
        Replace each <image> placeholder with the visual token embeddings.

        Every placeholder receives the SAME visual_tokens block; use a
        multi-image injector when images differ per placeholder.

        Returns:
            combined_embeddings: [new_seq_len, d_model]
            new_positions: [new_seq_len] contiguous position ids, created on
                the same device as the embeddings
        """
        device = text_embeddings.device
        # Find all positions of the <image> token.
        image_positions = (input_ids == self.image_token_id).nonzero(as_tuple=True)[0]
        if len(image_positions) == 0:
            # No image token -- return text embeddings unchanged.
            # BUG FIX: positions now live on the embedding device (previously
            # a CPU tensor was paired with potentially-GPU embeddings).
            return text_embeddings, torch.arange(len(text_embeddings), device=device)
        parts = []
        prev_end = 0
        for img_pos in image_positions.tolist():
            # Text before this <image> token.
            if img_pos > prev_end:
                parts.append(text_embeddings[prev_end:img_pos])
            # Visual tokens standing in for the placeholder.
            parts.append(visual_tokens)
            prev_end = img_pos + 1  # Skip the <image> token itself.
        # Remaining text after the last <image>.
        if prev_end < len(text_embeddings):
            parts.append(text_embeddings[prev_end:])
        combined = torch.cat(parts, dim=0)
        # The original per-chunk aranges always concatenated to 0..len-1,
        # so build the contiguous position ids directly.
        positions = torch.arange(combined.shape[0], device=device)
        return combined, positions
Multi-Image Handling
When a request contains multiple images, each <image> placeholder is replaced by the visual tokens from the corresponding image:
class MultiImageInjector:
    """Handle requests with multiple images."""

    def __init__(self, image_token_id: int):
        self.image_token_id = image_token_id

    def inject_multiple(
        self,
        input_ids: torch.Tensor,
        text_embeddings: torch.Tensor,
        visual_tokens_list: list,  # List of [num_visual_i, d_model] tensors
    ) -> tuple:
        """Replace the i-th <image> placeholder with the i-th image's tokens.

        Returns (combined_embeddings, positions). Raises ValueError when the
        placeholder count and the image count disagree.
        """
        placeholder_idx = (input_ids == self.image_token_id).nonzero(as_tuple=True)[0]
        if len(placeholder_idx) != len(visual_tokens_list):
            raise ValueError(
                f"Found {len(placeholder_idx)} <image> tokens but "
                f"received {len(visual_tokens_list)} images"
            )
        segments = []
        cursor = 0
        for slot, vis in zip(placeholder_idx.tolist(), visual_tokens_list):
            # Copy any text between the previous placeholder and this one.
            if slot > cursor:
                segments.append(text_embeddings[cursor:slot])
            # Visual tokens for this image.
            segments.append(vis)
            cursor = slot + 1
        # Trailing text after the final placeholder.
        if cursor < len(text_embeddings):
            segments.append(text_embeddings[cursor:])
        combined = torch.cat(segments, dim=0)
        positions = torch.arange(combined.shape[0], device=combined.device)
        return combined, positions
A request with 3 images (each producing 576 visual tokens) and 200 text tokens has a total sequence length of 1,928 tokens (3 × 576 + 200). The attention computation scales as O(n²) during prefill, so multimodal requests can be 10x more expensive than their text-only counterparts.
vLLM Integration Architecture
The Full Pipeline in vLLM
class MultimodalModelRunner:
    """
    vLLM model runner with vision encoder integration.
    Handles preprocessing, encoding, projection, and injection.
    """

    def __init__(self, model_config, vision_config):
        # Text model
        self.llm = load_llm(model_config)
        # Vision components
        self.preprocessor = ImagePreprocessor(vision_config)
        self.vision_encoder = VisionEncoder(vision_config)
        self.projection = VisionProjection(vision_config)
        self.injector = MultiImageInjector(
            image_token_id=model_config.image_token_id
        )
        # Pre-compute embedding layer reference
        # (avoids repeated attribute traversal on every forward).
        self.embed_tokens = self.llm.model.embed_tokens

    def execute_model(self, input_data):
        """Execute a multimodal forward pass."""
        # Requests without images bypass the vision pipeline entirely.
        # NOTE(review): _execute_text_only is not defined in this class as
        # shown here -- confirm it exists elsewhere.
        if input_data.has_images:
            return self._execute_multimodal(input_data)
        else:
            return self._execute_text_only(input_data)

    def _execute_multimodal(self, input_data):
        # Step 1: Preprocess images (CPU -> GPU)
        pixel_values = self._preprocess_images(input_data.images)
        # Step 2: Run vision encoder
        # no_grad: inference only -- avoids building autograd state for the
        # ViT and projection passes.
        with torch.no_grad():
            visual_features = self.vision_encoder(pixel_values)
            visual_tokens = self.projection(visual_features)
        # Step 3: Get text embeddings
        text_embeddings = self.embed_tokens(input_data.input_ids)
        # Step 4: Inject visual tokens
        # NOTE(review): inject_multiple expects a list of per-image tensors;
        # a batched 3-D tensor also works because len() / indexing yield
        # per-image slices -- confirm this is intended.
        combined_embeddings, positions = self.injector.inject_multiple(
            input_data.input_ids,
            text_embeddings,
            visual_tokens,  # [num_images, num_patches, d_model]
        )
        # Step 5: Run LLM with combined embeddings
        # (bypass the embedding layer, feed embeddings directly)
        output = self.llm.forward_from_embeddings(
            combined_embeddings,
            positions,
            kv_cache=input_data.kv_cache,
        )
        return output

    def _preprocess_images(self, images):
        """Preprocess images and move to GPU."""
        pixel_values = self.preprocessor.preprocess_batch(images)
        # fp16 on CUDA -- assumes the vision encoder weights are loaded in
        # half precision on the GPU; confirm against engine setup.
        return pixel_values.to(device="cuda", dtype=torch.float16)
Scheduler Awareness of Visual Tokens
The scheduler must account for visual tokens when computing sequence lengths and KV cache requirements:
class MultimodalScheduler:
    """Scheduler that accounts for visual token overhead."""

    def compute_sequence_length(self, request):
        """Total sequence length once placeholders become visual tokens."""
        text_len = len(request.prompt_tokens)
        # Per-image visual tokens: dynamic-resolution images carry their own
        # count; fixed-resolution images use the configured default.
        visual_len = sum(
            img.num_visual_tokens if img.dynamic_resolution
            else self.default_visual_tokens
            for img in request.images
        )
        # Each <image> placeholder is replaced (not kept), so subtract them.
        num_placeholders = sum(
            1 for t in request.prompt_tokens if t == self.image_token_id
        )
        return text_len - num_placeholders + visual_len

    def can_schedule(self, request, available_blocks):
        """True when enough KV cache blocks exist for the full sequence."""
        total_len = self.compute_sequence_length(request)
        # Ceiling division by the KV cache block size.
        blocks_needed = -(-total_len // self.block_size)
        return blocks_needed <= available_blocks
Performance Optimization
Encoder Caching
If multiple requests share the same image (e.g., different questions about the same image), the ViT encoding can be cached:
class VisionEncoderCache:
    """LRU cache of projected visual tokens keyed by image hash.

    Avoids re-running the ViT + projection when multiple requests reference
    the same image (e.g. different questions about one picture).
    """

    def __init__(self, max_cache_size=100):
        # Insertion-ordered dict used as an LRU: the most recently used
        # entry sits at the end; eviction pops from the front.
        self.cache = {}
        self.max_size = max_cache_size

    def get_or_compute(self, image_hash, pixel_values, encoder, projection):
        """Return cached visual tokens for image_hash, computing on a miss.

        Args:
            image_hash: stable key identifying the image content.
            pixel_values: single-image [3, H, W] tensor (a batch dimension
                is added internally).
            encoder, projection: callables forming the vision pipeline.
        """
        if image_hash in self.cache:
            # IMPROVEMENT: refresh recency on hit (true LRU). Previously
            # this was FIFO -- a hot entry could be evicted despite use.
            tokens = self.cache.pop(image_hash)
            self.cache[image_hash] = tokens
            return tokens
        with torch.no_grad():
            features = encoder(pixel_values.unsqueeze(0))
            tokens = projection(features).squeeze(0)
        # Evict the least recently used entry if the cache is full.
        if len(self.cache) >= self.max_size:
            lru_key = next(iter(self.cache))
            del self.cache[lru_key]
        # detach() so cached tensors hold no autograd graph.
        self.cache[image_hash] = tokens.detach()
        return tokens
Async Preprocessing
Image preprocessing (decode + resize + normalize) runs on CPU and can overlap with GPU computation:
import concurrent.futures
import hashlib
class AsyncImagePipeline:
    """Async image preprocessing pipeline."""

    def __init__(self, preprocessor, max_workers=4):
        self.preprocessor = preprocessor
        # CPU-bound decode/resize/normalize work runs on a small thread
        # pool so it can overlap with GPU computation.
        self.executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=max_workers
        )

    def submit(self, image_bytes_list):
        """Queue each image for preprocessing; returns one future per image."""
        return [
            self.executor.submit(self._preprocess_one, raw)
            for raw in image_bytes_list
        ]

    def _preprocess_one(self, image_bytes):
        """Preprocess one image (runs on thread pool)."""
        # Content hash doubles as an encoder-cache key.
        digest = hashlib.sha256(image_bytes).hexdigest()[:16]
        return digest, self.preprocessor.preprocess(image_bytes)

    def collect(self, futures):
        """Gather results as they finish (completion order, not submit order)."""
        return [
            f.result()
            for f in concurrent.futures.as_completed(futures)
        ]
Batched Encoding
When multiple images arrive in the same scheduling window, they can be batched through the ViT encoder:
def batch_encode_images(
    encoder,
    projection,
    pixel_values_list,
    max_batch_encode=8,
    device="cuda",
    dtype=torch.float16,
):
    """Batch multiple images through the ViT encoder.

    Args:
        encoder: vision encoder mapping [B, 3, H, W] -> [B, N, C_vision].
        projection: maps [B, N, C_vision] -> [B, N, C_llm].
        pixel_values_list: per-image [3, H, W] tensors of identical shape.
        max_batch_encode: cap on images per encoder forward pass, bounding
            peak activation memory.
        device, dtype: where/how to run the encoder. GENERALIZED: these were
            previously hard-coded to CUDA fp16; the defaults preserve that
            behavior while allowing CPU or other dtypes.

    Returns:
        List of per-image [N, C_llm] visual token tensors, in input order.
    """
    results = []
    for start in range(0, len(pixel_values_list), max_batch_encode):
        chunk = pixel_values_list[start:start + max_batch_encode]
        batch = torch.stack(chunk).to(device=device, dtype=dtype)
        with torch.no_grad():
            features = encoder(batch)  # [B, N, C_vision]
            tokens = projection(features)  # [B, N, C_llm]
        # Unbatch back into per-image tensors.
        results.extend(tokens.unbind(0))
    return results
Memory Accounting
ViT Memory Budget
The vision encoder has its own memory footprint that must be accounted for in the GPU memory budget:
def compute_vision_memory(config):
    """Compute memory consumed by vision components.

    Returns a dict with ViT weight, projection weight, and per-image
    activation sizes in MB (assuming FP16), plus the persistent total.
    """
    hidden = config.vision_hidden_size
    # ViT parameter count: patch-embedding conv + position table (+1 slot
    # for CLS) + roughly 12*h^2 per transformer layer
    # (4*h^2 for QKV/O attention, 8*h^2 for the 4x-expansion FFN).
    vit_params = (
        3 * hidden * config.patch_size ** 2
        + (config.max_patches + 1) * hidden
        + config.vision_num_layers * (4 * hidden ** 2 + 8 * hidden ** 2)
    )
    # Two-layer MLP projection into the LLM embedding space.
    proj_params = hidden * config.hidden_size + config.hidden_size ** 2
    dtype_bytes = 2  # FP16
    mb = 1024 ** 2
    vit_mb = vit_params * dtype_bytes / mb
    proj_mb = proj_params * dtype_bytes / mb
    # Transient activations while encoding one image (rough 2x buffer).
    num_patches = (config.image_size // config.patch_size) ** 2
    activation_per_image_mb = (
        num_patches * hidden * dtype_bytes * 2
    ) / mb
    return {
        'vit_weights_mb': vit_mb,
        'projection_weights_mb': proj_mb,
        'activation_per_image_mb': activation_per_image_mb,
        'total_persistent_mb': vit_mb + proj_mb,
    }
# Example: ViT-L/14 + LLaVA projection
# vit_weights: ~304 MB
# projection: ~33 MB
# total: ~337 MB (reduces KV cache capacity by ~337 MB)
Vision Encoder Latency by Image Resolution (ViT-L/14, H100)
(Latencies in ms; chart omitted from this extraction.)
Complete Multimodal Serving Example
class MultimodalServingEngine:
    """
    Complete multimodal serving engine.

    Integrates preprocessing, vision encoding (with caching), token
    injection, and LLM generation.
    """

    def __init__(self, model_name, max_images_per_request=5):
        self.config = load_config(model_name)
        # Hard cap on images per request, enforced in process_request().
        self.max_images = max_images_per_request
        # Initialize components; the vision stack runs in fp16 on the GPU.
        self.preprocessor = ImagePreprocessor(self.config)
        self.encoder = VisionEncoder(self.config).cuda().half()
        self.projection = VisionProjection(self.config).cuda().half()
        self.injector = MultiImageInjector(self.config.image_token_id)
        self.encoder_cache = VisionEncoderCache()
        self.async_pipeline = AsyncImagePipeline(self.preprocessor)
        # LLM engine (text generation)
        self.llm_engine = LLMEngine(self.config)

    def process_request(self, text, images):
        """Process a multimodal request and return the generation result.

        Raises:
            ValueError: if more than max_images images are supplied.
        """
        if len(images) > self.max_images:
            raise ValueError(f"Max {self.max_images} images per request")
        # Preprocess asynchronously on the CPU thread pool.
        futures = self.async_pipeline.submit(images)
        # BUG FIX: gather results in SUBMISSION order. The pipeline's
        # collect() uses as_completed(), whose completion order can differ
        # from submit order and scramble the pairing between images and
        # their <image> placeholders in the prompt.
        preprocessed = [f.result() for f in futures]
        # Encode images, reusing cached encodings for repeated content.
        visual_tokens_list = []
        for img_hash, pixel_values in preprocessed:
            tokens = self.encoder_cache.get_or_compute(
                img_hash, pixel_values, self.encoder, self.projection
            )
            visual_tokens_list.append(tokens)
        # Tokenize text
        input_ids = self.llm_engine.tokenize(text)
        # Scheduling length: each placeholder is replaced by that image's
        # visual tokens, so subtract placeholders and add visual tokens.
        num_image_tokens = sum(t.shape[0] for t in visual_tokens_list)
        num_placeholders = sum(1 for t in input_ids if t == self.config.image_token_id)
        total_seq_len = len(input_ids) - num_placeholders + num_image_tokens
        # Submit to LLM engine with visual context
        return self.llm_engine.generate(
            input_ids=input_ids,
            visual_tokens=visual_tokens_list,
            max_new_tokens=512,
            total_seq_len=total_seq_len,
        )