This post requires the Transformer Anatomy series Parts 1-15 (the complete forward pass). It also connects to the Inference Optimization Timeline series for production-scale optimizations. Read the capstone post (Part 15) first.
The Transformer Anatomy series showed how to build the model — tokenize, embed, attend, FFN, project logits. But training and inference are fundamentally different operations. Training processes the entire sequence at once (teacher forcing). Inference must generate tokens one at a time, feeding each generated token back as input for the next step. This autoregressive loop, combined with KV cache management, is what turns a trained model into a text generator.
## The Autoregressive Generation Loop
During training, the model sees all tokens simultaneously:
```python
# Training: parallel forward pass over entire sequence
logits = model(input_ids)  # [B, S, V] — all positions computed at once
loss = cross_entropy(logits[:, :-1], labels[:, 1:])
```
During inference, tokens are generated one at a time:
```python
# Inference: sequential generation
generated = list(prompt_tokens)
for step in range(max_new_tokens):
    logits = model(generated)        # forward pass on ALL tokens so far
    next_token = sample(logits[-1])  # sample from the last position only
    generated.append(next_token)
    if next_token == eos_token:
        break
```
The problem: at step *t*, the forward pass processes all *t* tokens generated so far. The attention computation is O(t²), and *t* grows every step. Without optimization, generating 1000 tokens requires recomputing attention over the entire growing sequence 1000 times. This is catastrophically wasteful.
## KV Cache: The Core Optimization
The key insight: when generating token *t*, the Key and Value vectors for tokens 1 through *t−1* don’t change. They were computed during previous steps. Only the new token’s Q, K, V need computation. Cache the K and V vectors and reuse them.
```python
import torch


class KVCache:
    """Stores Key and Value tensors from previous forward passes."""

    def __init__(self, num_layers, max_seq_len, num_kv_heads, head_dim, dtype):
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        # Pre-allocate cache tensors.
        # Shape per layer: [batch, num_kv_heads, max_seq_len, head_dim]
        self.k_cache = [
            torch.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.v_cache = [
            torch.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.seq_len = 0  # number of tokens cached across ALL layers
        self._layer_len = [0] * num_layers  # per-layer write position

    def update(self, layer_idx, new_k, new_v):
        """Append new K, V to the cache for one layer.

        new_k, new_v: [batch, num_kv_heads, num_new_tokens, head_dim]
        """
        num_new = new_k.shape[2]
        start = self._layer_len[layer_idx]
        end = start + num_new
        self.k_cache[layer_idx][:, :, start:end, :] = new_k
        self.v_cache[layer_idx][:, :, start:end, :] = new_v
        # Track length per layer so get() on this layer immediately sees the
        # new tokens; seq_len advances only once the last layer has written.
        self._layer_len[layer_idx] = end
        if layer_idx == self.num_layers - 1:
            self.seq_len = end

    def get(self, layer_idx):
        """Return cached K, V for this layer, including tokens written this step."""
        end = self._layer_len[layer_idx]
        return (
            self.k_cache[layer_idx][:, :, :end, :],
            self.v_cache[layer_idx][:, :, :end, :],
        )
```
Without KV cache: generating 1000 tokens from a 500-token prompt requires processing 500 + 501 + 502 + … + 1499 = ~1M total tokens through the model. With KV cache: process 500 tokens once (prefill), then 1 token per step for 1000 steps = 1500 total tokens processed. That is a 667x reduction in total computation.
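The token counts above can be checked with a few lines of arithmetic (the exact sum is 999,500, which rounds to ~1M):

```python
# Token-count check: without a KV cache, step t re-processes the whole
# sequence; with one, only the new token passes through the model.
prompt_len, new_tokens = 500, 1000

# Without cache: step t runs a forward pass over prompt_len + t tokens.
without_cache = sum(prompt_len + t for t in range(new_tokens))

# With cache: one prefill pass over the prompt, then one token per step.
with_cache = prompt_len + new_tokens

print(without_cache)  # 999500 (~1M)
print(with_cache)     # 1500
print(without_cache / with_cache)  # ~666x fewer tokens processed
```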
## The Two-Phase Generation: Prefill and Decode
*Figure: inference split into prefill and decode phases.*
Prefill: Process the entire prompt at once. All prompt tokens flow through the model in a single forward pass (like training). K and V for all prompt tokens are computed and stored in the KV cache. Output: logits at the last position (the first generated token).
Decode: Generate tokens one at a time. Each step: (1) compute Q, K, V for the new token only, (2) append K, V to cache, (3) compute attention between the new Q and ALL cached K, V, (4) sample the next token.
```python
def generate(model, prompt_ids, max_new_tokens, temperature=1.0, top_p=0.95):
    """Complete generation with prefill + decode phases."""
    # Initialize KV cache
    kv_cache = KVCache(
        num_layers=model.num_layers,
        max_seq_len=len(prompt_ids) + max_new_tokens,
        num_kv_heads=model.num_kv_heads,
        head_dim=model.head_dim,
        dtype=torch.float16,
    )

    # === PREFILL PHASE ===
    # Process the entire prompt at once (compute-bound, high GPU utilization).
    prompt_tensor = torch.tensor([prompt_ids], device="cuda")
    positions = torch.arange(len(prompt_ids), device="cuda").unsqueeze(0)
    logits = model.forward(prompt_tensor, positions, kv_cache)
    # logits shape: [1, prompt_len, vocab_size]
    # The KV cache now holds K, V for all prompt tokens.

    # Sample the first generated token from the last position's logits.
    next_token = sample_token(logits[0, -1, :], temperature, top_p)
    generated = [next_token]

    # === DECODE PHASE ===
    # Generate one token at a time (memory-bandwidth-bound).
    for step in range(max_new_tokens - 1):
        # Forward pass on a SINGLE new token (KV cache supplies the context).
        token_tensor = torch.tensor([[next_token]], device="cuda")
        position = torch.tensor([[len(prompt_ids) + step]], device="cuda")
        logits = model.forward(token_tensor, position, kv_cache)
        # logits shape: [1, 1, vocab_size]

        # Sample the next token.
        next_token = sample_token(logits[0, 0, :], temperature, top_p)
        generated.append(next_token)

        # Check the stop condition.
        if next_token == model.eos_token_id:
            break

    return generated
```
```python
def sample_token(logits, temperature=1.0, top_p=0.95):
    """Sample a token from logits with temperature and top-p."""
    if temperature == 0:
        return logits.argmax().item()  # greedy decoding

    # Temperature scaling
    logits = logits / temperature

    # Top-p (nucleus) sampling
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability above the threshold,
    # shifted by one position so the token that crosses top_p is kept.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = False
    logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

    # Sample from the filtered distribution
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()
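The shift-by-one in the top-p filter is easy to get wrong, so here is a standalone check of that logic on a toy distribution (the values and `top_p=0.75` are illustrative, not from the post):

```python
import torch

# Toy distribution: probs [0.5, 0.3, 0.15, 0.05]. With top_p=0.75 the
# cumulative sums are [0.5, 0.8, 0.95, 1.0]; token 1 is the one that
# crosses the threshold, and the shift keeps it, so tokens 0 and 1 survive.
probs = torch.tensor([0.5, 0.3, 0.15, 0.05])
logits = probs.log()
top_p = 0.75

sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cumulative > top_p
remove[1:] = remove[:-1].clone()  # shift: keep the token crossing top_p
remove[0] = False
logits[sorted_indices[remove]] = float('-inf')

kept = torch.isfinite(logits).nonzero().flatten()
print(kept.tolist())  # [0, 1]
```

Without the shift, token 1 (cumulative 0.8 > 0.75) would be discarded and the nucleus would cover only half the probability mass.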
## Attention with KV Cache
The model’s attention layer must handle two modes:
```python
class CachedAttention:
    """Attention that uses the KV cache during inference."""

    def forward(self, hidden, positions, kv_cache, layer_idx):
        B = hidden.shape[0]  # batch size

        # Compute Q, K, V for the CURRENT tokens only
        Q = self.W_q(hidden)  # [B, num_new_tokens, n_heads * d_head]
        K = self.W_k(hidden)  # [B, num_new_tokens, n_kv_heads * d_head]
        V = self.W_v(hidden)  # [B, num_new_tokens, n_kv_heads * d_head]

        # Apply RoPE to Q and K
        Q = apply_rope(Q, positions)
        K = apply_rope(K, positions)

        # Reshape for attention
        Q = Q.view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        K = K.view(B, -1, self.n_kv_heads, self.d_head).transpose(1, 2)
        V = V.view(B, -1, self.n_kv_heads, self.d_head).transpose(1, 2)

        # Update the KV cache with the new K, V
        kv_cache.update(layer_idx, K, V)

        # Get the FULL cached K, V (all previous + current tokens)
        K_full, V_full = kv_cache.get(layer_idx)
        # K_full: [B, n_kv_heads, total_seq_len, d_head]

        # GQA: expand KV heads to match query heads
        if self.n_heads != self.n_kv_heads:
            repeat = self.n_heads // self.n_kv_heads
            K_full = K_full.repeat_interleave(repeat, dim=1)
            V_full = V_full.repeat_interleave(repeat, dim=1)

        # Compute attention: Q (new tokens) attends to K_full (all tokens)
        scores = torch.matmul(Q, K_full.transpose(-1, -2)) / (self.d_head ** 0.5)

        # Causal mask: new tokens can attend to all previous tokens + self.
        # During decode no mask is needed (the single new token attends to
        # everything cached); during prefill the standard causal mask applies.
        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_full)

        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_head)
        return self.W_o(output)
```
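The GQA expansion step is worth seeing in isolation. A shape-only sketch (using Llama-style 32 query heads / 8 KV heads as the example config) shows what `repeat_interleave` does to the cached keys:

```python
import torch

# GQA head expansion: 8 KV heads repeated to match 32 query heads,
# so each group of 4 consecutive query heads shares one KV head.
n_heads, n_kv_heads, d_head, seq_len = 32, 8, 64, 10
K_full = torch.randn(1, n_kv_heads, seq_len, d_head)

repeat = n_heads // n_kv_heads  # 4 query heads per KV head
K_expanded = K_full.repeat_interleave(repeat, dim=1)

print(K_expanded.shape)  # torch.Size([1, 32, 10, 64])
# Query heads 0-3 all attend over KV head 0's keys:
print(torch.equal(K_expanded[:, 0], K_expanded[:, 3]))  # True
```

Note the expansion materializes copies; fused kernels avoid this by indexing the shared KV head directly, which is part of why GQA saves cache memory but not necessarily attention FLOPs.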
## KV Cache Memory Math
For Llama 3 70B (80 layers, GQA with 8 KV heads, head_dim = 128, FP16):
KV Cache Memory at Different Sequence Lengths (Llama 70B)
| Sequence Length | KV Cache Size | Notes |
|---|---|---|
| 512 | 160 MB | Short prompt + response |
| 4096 | 1.28 GB | Typical conversation |
| 32768 | 10.24 GB | Long document analysis |
| 131072 | 40.96 GB | Max context (128K) |
At batch=32 with 4K context, KV cache = 32 x 1.28 GB = 41 GB. Model weights (FP16) = 140 GB. Total = 181 GB, requiring 3 H100 GPUs. The KV cache, not the model, determines how many concurrent requests you can serve. This is why PagedAttention (Inference Timeline Part 5), KV compression (Part 37), and MLA (MoE Masterclass Part 3) exist.
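The table's numbers fall out of a one-line formula: K and V are each cached per layer, per KV head, per head dimension, at 2 bytes per element in FP16.

```python
# KV cache size for Llama 3 70B: 80 layers, 8 KV heads, head_dim 128, FP16.
num_layers, num_kv_heads, head_dim, bytes_per_el = 80, 8, 128, 2

# Factor of 2 = one K tensor + one V tensor per layer.
bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_el
print(bytes_per_token // 1024, "KiB per token")  # 320 KiB per token

for seq_len in (512, 4096, 32768, 131072):
    mib = bytes_per_token * seq_len / 2**20
    print(f"{seq_len:>6} tokens -> {mib:,.0f} MiB")
# 512 -> 160 MiB; 4096 -> 1,280 MiB (the table's "1.28 GB"); and so on.
```

At 320 KiB per token the cache scales linearly with context length and batch size, which is exactly why it, not the fixed-size weights, becomes the serving bottleneck.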
## Why Decode Is Memory-Bandwidth-Bound
During decode, the model processes 1 token. The compute: ~2P FLOPs for P parameters (one multiply-accumulate per weight). The memory: load all 2P bytes of FP16 weights from HBM.
Arithmetic intensity = 2P FLOPs / 2P bytes = 1 FLOP/byte. The H100 ridge point is 295 FLOP/byte (FP16). Decode is 295x below the ridge point — deeply memory-bandwidth-bound. This means:
- Tensor core utilization during decode: ~0.3%
- Throughput limited by HBM bandwidth: 3.35 TB/s / (140 GB weights) = ~24 tokens/sec at batch=1
- Batching helps: at batch=32, effective AI = 32 FLOP/byte — still below ridge but 32x more efficient
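The roofline arithmetic behind these bullets is short enough to write out (P = 70B and the H100 figures are taken from the text above):

```python
# Roofline arithmetic for decode: at batch B, one step does ~2*P*B FLOPs
# but still streams all 2*P weight bytes from HBM exactly once, so
# arithmetic intensity grows linearly with batch size.
P = 70e9             # parameters (Llama 70B)
hbm_bw = 3.35e12     # H100 HBM bandwidth, bytes/s
weight_bytes = 2 * P # FP16 weights

for batch in (1, 32):
    flops = 2 * P * batch
    ai = flops / weight_bytes  # FLOP per byte loaded
    print(f"batch={batch}: {ai:.0f} FLOP/byte")  # 1 and 32

# Bandwidth-bound ceiling at batch=1: each token streams all weights once.
tokens_per_sec = hbm_bw / weight_bytes
print(f"{tokens_per_sec:.1f} tokens/sec")  # ~23.9
```

Both batch sizes sit far below the 295 FLOP/byte ridge point, which is why batching raises throughput almost for free: the weight traffic is paid once per step regardless of how many sequences share it.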
*Figure: decode throughput (total tokens/sec) vs batch size, Llama 70B on H100.*

This is why the Inference Optimization Timeline series exists: every optimization (batching, quantization, FlashAttention, speculative decoding, disaggregated serving) attacks the decode memory-bandwidth bottleneck from a different angle.
## Streaming Output
Production systems stream tokens as they are generated:
```python
async def generate_streaming(model, prompt_ids, max_new_tokens, callback):
    """Generate tokens and stream each one via callback."""
    kv_cache = KVCache(...)

    # Prefill
    logits = model.forward(prompt_ids, kv_cache=kv_cache)
    next_token = sample_token(logits[0, -1, :])
    await callback(next_token)  # stream the first token (TTFT)

    # Decode with streaming
    for step in range(max_new_tokens - 1):
        logits = model.forward([[next_token]], kv_cache=kv_cache)
        next_token = sample_token(logits[0, 0, :])
        await callback(next_token)  # stream each token (TBT)
        if next_token == eos_token_id:
            break
```
TTFT (Time to First Token) = prefill time. TBT (Time Between Tokens) = single decode step time. Users perceive TTFT as “response start latency” and TBT as “typing speed.”
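Measuring these two latencies requires nothing more than timestamps around the token stream. A minimal sketch (the `fake_decode_step` stand-in and all timings are illustrative, not a real model):

```python
import time

def fake_decode_step():
    """Stand-in for one model forward pass during decode."""
    time.sleep(0.01)  # pretend per-step latency
    return 42

def timed_stream(num_tokens):
    """Emit tokens and record TTFT and per-token gaps (TBT)."""
    start = time.perf_counter()
    timestamps = []
    for _ in range(num_tokens):
        fake_decode_step()
        timestamps.append(time.perf_counter())
    ttft = timestamps[0] - start  # time to first token
    tbts = [b - a for a, b in zip(timestamps, timestamps[1:])]  # gaps
    return ttft, tbts

ttft, tbts = timed_stream(5)
print(ttft > 0, len(tbts))  # True 4
```

In production the first timestamp comes after the prefill forward pass, so TTFT grows with prompt length while TBT stays roughly constant per decode step.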