The Inference Engine: Token Generation Loop, KV Cache Management, and Autoregressive Decoding

Part 35 of 36 in the Transformer Anatomy series
⚠️ Prerequisites

This post requires the Transformer Anatomy series Parts 1-15 (the complete forward pass). It also connects to the Inference Optimization Timeline series for production-scale optimizations. Read the capstone post (Part 15) first.

The Transformer Anatomy series showed how to build the model — tokenize, embed, attend, FFN, project logits. But training and inference are fundamentally different operations. Training processes the entire sequence at once (teacher forcing). Inference must generate tokens one at a time, feeding each generated token back as input for the next step. This autoregressive loop, combined with KV cache management, is what turns a trained model into a text generator.

The Autoregressive Generation Loop

During training, the model sees all tokens simultaneously:

# Training: parallel forward pass over entire sequence
logits = model(input_ids)  # [B, S, V] — all positions computed at once
loss = cross_entropy(logits[:, :-1], labels[:, 1:])
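In PyTorch specifically, `F.cross_entropy` expects `[N, V]` logits and `[N]` targets, so the shifted slices are flattened first. A minimal sketch with toy shapes (the dimensions here are illustrative):

```python
import torch
import torch.nn.functional as F

B, S, V = 2, 5, 11                  # toy batch, sequence, and vocab sizes
logits = torch.randn(B, S, V)       # stand-in for model(input_ids)
labels = torch.randint(0, V, (B, S))

# Positions 0..S-2 predict tokens 1..S-1: drop the last logit, shift labels left
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, V),  # [B*(S-1), V]
    labels[:, 1:].reshape(-1),      # [B*(S-1)]
)
```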

During inference, tokens are generated one at a time:

# Inference: sequential generation
generated = list(prompt_tokens)
for step in range(max_new_tokens):
    logits = model(generated)     # Forward pass on ALL tokens so far
    next_token = sample(logits[-1])  # Sample from last position only
    generated.append(next_token)
    if next_token == eos_token:
        break

The problem: at step $t$, the forward pass processes $|\text{prompt}| + t$ tokens. Attention is $O(n^2)$, and $n$ grows every step. Without optimization, generating 1000 tokens requires recomputing attention over the entire growing sequence 1000 times. This is catastrophically wasteful.

KV Cache: The Core Optimization

The key insight: when generating token $t+1$, the Key and Value vectors for tokens $0$ through $t$ don't change; they were computed during previous steps. Only the new token's Q, K, V need computation. So cache the K and V vectors and reuse them.

import torch

class KVCache:
    """Stores Key and Value tensors from previous forward passes."""

    def __init__(self, num_layers, max_seq_len, num_kv_heads, head_dim, dtype):
        self.num_layers = num_layers
        self.max_seq_len = max_seq_len
        # Pre-allocate cache tensors
        # Shape per layer: [batch, num_kv_heads, max_seq_len, head_dim]
        self.k_cache = [
            torch.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.v_cache = [
            torch.zeros(1, num_kv_heads, max_seq_len, head_dim, dtype=dtype)
            for _ in range(num_layers)
        ]
        self.seq_len = 0  # Current number of cached tokens

    def update(self, layer_idx, new_k, new_v):
        """Append new K, V to the cache for one layer.

        new_k, new_v: [batch, num_kv_heads, num_new_tokens, head_dim]
        """
        num_new = new_k.shape[2]
        start = self.seq_len
        end = start + num_new
        self.k_cache[layer_idx][:, :, start:end, :] = new_k
        self.v_cache[layer_idx][:, :, start:end, :] = new_v
        # Only update seq_len after last layer processes
        if layer_idx == self.num_layers - 1:
            self.seq_len = end

    def get(self, layer_idx):
        """Return cached K, V up to current seq_len."""
        return (
            self.k_cache[layer_idx][:, :, :self.seq_len, :],
            self.v_cache[layer_idx][:, :, :self.seq_len, :],
        )
Why KV Cache Is Essential

Without KV cache: generating 1000 tokens from a 500-token prompt requires processing 500 + 501 + 502 + … + 1499 = 999,500 total tokens through the model. With KV cache: process 500 tokens once (prefill), then 1 token per step for 1000 steps = 1,500 total tokens processed. That is a ~666x reduction in total computation.
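The arithmetic is easy to verify:

```python
prompt_len, new_tokens = 500, 1000

# Without a KV cache, decode step t re-runs the model over prompt_len + t tokens
without_cache = sum(prompt_len + t for t in range(new_tokens))  # 999,500 tokens

# With a KV cache: one prefill pass over the prompt, then 1 token per step
with_cache = prompt_len + new_tokens                            # 1,500 tokens

reduction = without_cache / with_cache                          # ~666x
```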

The Two-Phase Generation: Prefill and Decode

• Prefill (prompt processing): process all prompt tokens in parallel. Compute-bound: large batched GEMMs saturate the tensor cores.
• Decode (token generation): generate one token at a time. Memory-bandwidth-bound: the full model weights are loaded from HBM to process a single token.

Prefill: Process the entire prompt at once. All prompt tokens flow through the model in a single forward pass (like training). K and V for all prompt tokens are computed and stored in the KV cache. Output: logits at the last position, from which the first generated token is sampled.

Decode: Generate tokens one at a time. Each step: (1) compute Q, K, V for the new token only, (2) append K, V to cache, (3) compute attention between the new Q and ALL cached K, V, (4) sample the next token.

def generate(model, prompt_ids, max_new_tokens, temperature=1.0, top_p=0.95):
    """Complete generation with prefill + decode phases."""

    # Initialize KV cache
    kv_cache = KVCache(
        num_layers=model.num_layers,
        max_seq_len=len(prompt_ids) + max_new_tokens,
        num_kv_heads=model.num_kv_heads,
        head_dim=model.head_dim,
        dtype=torch.float16,
    )

    # === PREFILL PHASE ===
    # Process entire prompt at once (compute-bound, high GPU utilization)
    prompt_tensor = torch.tensor([prompt_ids], device="cuda")
    positions = torch.arange(len(prompt_ids), device="cuda").unsqueeze(0)
    logits = model.forward(prompt_tensor, positions, kv_cache)
    # logits shape: [1, prompt_len, vocab_size]
    # KV cache now holds K, V for all prompt tokens

    # Sample first generated token from last position logits
    next_token = sample_token(logits[0, -1, :], temperature, top_p)
    generated = [next_token]

    # === DECODE PHASE ===
    # Generate one token at a time (memory-bandwidth-bound)
    for step in range(max_new_tokens - 1):
        # Forward pass on SINGLE new token (using KV cache for context)
        token_tensor = torch.tensor([[next_token]], device="cuda")
        position = torch.tensor([[len(prompt_ids) + step]], device="cuda")
        logits = model.forward(token_tensor, position, kv_cache)
        # logits shape: [1, 1, vocab_size]

        # Sample next token
        next_token = sample_token(logits[0, 0, :], temperature, top_p)
        generated.append(next_token)

        # Check stop condition
        if next_token == model.eos_token_id:
            break

    return generated

def sample_token(logits, temperature=1.0, top_p=0.95):
    """Sample a token from logits with temperature and top-p."""
    if temperature == 0:
        return logits.argmax().item()  # Greedy

    # Temperature scaling
    logits = logits / temperature

    # Top-p (nucleus) sampling
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

    # Remove tokens with cumulative probability above threshold
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
    sorted_indices_to_remove[0] = False
    logits[sorted_indices[sorted_indices_to_remove]] = float('-inf')

    # Sample from filtered distribution
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

Attention with KV Cache

The model’s attention layer must handle two modes:

class CachedAttention:
    """Attention that uses KV cache during inference."""

    def forward(self, hidden, positions, kv_cache, layer_idx):
        B = hidden.shape[0]  # batch size, used for reshapes below
        # Compute Q, K, V for CURRENT tokens only
        Q = self.W_q(hidden)   # [B, num_new_tokens, n_heads * d_head]
        K = self.W_k(hidden)   # [B, num_new_tokens, n_kv_heads * d_head]
        V = self.W_v(hidden)   # [B, num_new_tokens, n_kv_heads * d_head]

        # Apply RoPE to Q and K
        Q = apply_rope(Q, positions)
        K = apply_rope(K, positions)

        # Reshape for attention
        Q = Q.view(B, -1, self.n_heads, self.d_head).transpose(1, 2)
        K = K.view(B, -1, self.n_kv_heads, self.d_head).transpose(1, 2)
        V = V.view(B, -1, self.n_kv_heads, self.d_head).transpose(1, 2)

        # Update KV cache with new K, V
        kv_cache.update(layer_idx, K, V)

        # Get FULL cached K, V (all previous + current tokens)
        K_full, V_full = kv_cache.get(layer_idx)
        # K_full: [B, n_kv_heads, total_seq_len, d_head]

        # GQA: expand KV heads to match query heads
        if self.n_heads != self.n_kv_heads:
            repeat = self.n_heads // self.n_kv_heads
            K_full = K_full.repeat_interleave(repeat, dim=1)
            V_full = V_full.repeat_interleave(repeat, dim=1)

        # Compute attention: Q (new tokens) attends to K_full (all tokens)
        scores = torch.matmul(Q, K_full.transpose(-1, -2)) / (self.d_head ** 0.5)

        # Causal mask is only needed during prefill (multiple new tokens);
        # during decode the single new query may attend to everything cached
        num_new = Q.shape[2]
        if num_new > 1:
            total = K_full.shape[2]
            causal_mask = torch.triu(
                torch.ones(num_new, total, dtype=torch.bool, device=scores.device),
                diagonal=total - num_new + 1,
            )
            scores = scores.masked_fill(causal_mask, float("-inf"))

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V_full)

        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(B, -1, self.n_heads * self.d_head)
        return self.W_o(output)
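The GQA expansion step is worth seeing in isolation. A self-contained sketch with toy dimensions (2 KV heads serving 8 query heads; the sizes are illustrative):

```python
import torch

n_heads, n_kv_heads, seq, d_head = 8, 2, 3, 4
K = torch.randn(1, n_kv_heads, seq, d_head)

# Each KV head is duplicated n_heads // n_kv_heads = 4 times along the head dim
K_exp = K.repeat_interleave(n_heads // n_kv_heads, dim=1)

assert K_exp.shape == (1, n_heads, seq, d_head)
# Query heads 0-3 all attend against copies of KV head 0's keys
assert torch.equal(K_exp[0, 0], K_exp[0, 3])
assert torch.equal(K_exp[0, 4], K[0, 1])
```

Production attention kernels typically handle the head-group mapping internally rather than materializing these copies, which keeps cache memory at the KV-head count.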

KV Cache Memory Math

For Llama 3 70B (80 layers, GQA with 8 KV heads, $d_h = 128$, FP16):

$$\text{KV per token} = 2 \times L \times n_{\text{kv}} \times d_h \times 2\,\text{bytes} = 2 \times 80 \times 8 \times 128 \times 2 = 327{,}680\ \text{bytes} \approx 320\ \text{KB}$$


KV Cache Memory at Different Sequence Lengths (Llama 70B)

Sequence Length   KV Cache Size   Notes
512               160 MB          Short prompt + response
4,096             1.28 GB         Typical conversation
32,768            10.24 GB        Long document analysis
131,072           40.96 GB        Max context (128K)

Note: 320 KB per token across all 80 layers. At 128K context, the KV cache alone is ~41 GB, half of an H100's 80 GB of HBM.
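The table values follow directly from the per-token formula. A small calculator with the Llama 3 70B figures baked in as defaults:

```python
def kv_bytes_per_token(layers=80, kv_heads=8, head_dim=128, bytes_per_elem=2):
    """2 tensors (K and V) per layer; FP16 = 2 bytes per element."""
    return 2 * layers * kv_heads * head_dim * bytes_per_elem

per_token = kv_bytes_per_token()  # 327,680 bytes per token, i.e. 320 KB
for seq_len in (512, 4096, 32768, 131072):
    mb = per_token * seq_len / 2**20
    print(f"{seq_len:>7} tokens: {mb:>8.0f} MB")  # 160 MB up to 40,960 MB (~41 GB)
```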
⚠️ KV Cache Is the Memory Bottleneck

At batch=32 with 4K context, KV cache = 32 x 1.28 GB = 41 GB. Model weights (FP16) = 140 GB. Total = 181 GB, requiring 3 H100 GPUs. The KV cache, not the model, determines how many concurrent requests you can serve. This is why PagedAttention (Inference Timeline Part 5), KV compression (Part 37), and MLA (MoE Masterclass Part 3) exist.
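A back-of-the-envelope sketch of that capacity limit, using the figures above (3 x 80 GB H100s, 140 GB of FP16 weights, 1.28 GB of KV per 4K-context request); it ignores activation memory and fragmentation, which reduce the real number:

```python
def max_concurrent_requests(gpus=3, hbm_per_gpu_gb=80.0,
                            weights_gb=140.0, kv_per_request_gb=1.28):
    """How many requests fit once the weights are resident in HBM."""
    free_gb = gpus * hbm_per_gpu_gb - weights_gb   # 100 GB left for KV cache
    return int(free_gb // kv_per_request_gb)

print(max_concurrent_requests())  # 78 concurrent 4K-context requests
```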

Why Decode Is Memory-Bandwidth-Bound

During decode, the model processes one token. Compute: $2 N_{\text{params}}$ FLOPs (one multiply-accumulate per parameter). Memory: load all $N_{\text{params}} \times 2$ bytes of FP16 weights from HBM.

Arithmetic intensity $= \frac{2N}{2N} = 1$ FLOP/byte. The H100 ridge point is ~295 FLOP/byte (FP16), so decode sits 295x below it: deeply memory-bandwidth-bound. This means:

  • Tensor core utilization during decode: ~0.3%
  • Throughput limited by HBM bandwidth: 3.35 TB/s / (140 GB weights) = ~24 tokens/sec at batch=1
  • Batching helps: at batch=32, effective AI = 32 FLOP/byte — still below ridge but 32x more efficient
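The numbers above as a quick sketch, assuming H100 SXM figures (~989 TFLOP/s dense FP16, 3.35 TB/s HBM3):

```python
peak_flops = 989e12          # H100 dense FP16, FLOP/s (no sparsity)
hbm_bandwidth = 3.35e12      # bytes/s
weights_bytes = 140e9        # 70B params x 2 bytes (FP16)

ridge_point = peak_flops / hbm_bandwidth          # ~295 FLOP/byte
tokens_per_sec = hbm_bandwidth / weights_bytes    # batch=1 upper bound: ~24 tok/s

print(round(ridge_point), round(tokens_per_sec))  # 295 24
```

At batch $b$, each streamed weight is reused $b$ times, so arithmetic intensity scales roughly linearly with batch size until the ridge point is reached.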

Decode Throughput vs Batch Size (Llama 70B, H100, total tokens/sec)

Batch 1     24 tok/s      Deeply bandwidth-bound
Batch 8     180 tok/s     Near-linear scaling
Batch 32    650 tok/s     Good utilization
Batch 128   1,800 tok/s   Approaching compute saturation
Batch 256   2,400 tok/s   Near peak

This is why the Inference Optimization Timeline series exists: every optimization (batching, quantization, FlashAttention, speculative decoding, disaggregated serving) attacks the decode memory-bandwidth bottleneck from a different angle.

Streaming Output

Production systems stream tokens as they are generated:

async def generate_streaming(model, prompt_ids, max_new_tokens, callback):
    """Generate tokens and stream each one via callback."""
    kv_cache = KVCache(...)

    # Prefill
    logits = model.forward(prompt_ids, kv_cache=kv_cache)
    next_token = sample_token(logits[0, -1, :])
    await callback(next_token)  # Stream first token (TTFT)

    # Decode with streaming
    for step in range(max_new_tokens - 1):
        logits = model.forward([[next_token]], kv_cache=kv_cache)
        next_token = sample_token(logits[0, 0, :])
        await callback(next_token)  # Stream each token (TBT)
        if next_token == eos_token_id:
            break

TTFT (Time to First Token) = prefill time. TBT (Time Between Tokens) = single decode step time. Users perceive TTFT as “response start latency” and TBT as “typing speed.”
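Both metrics fall out of per-token timestamps. A hypothetical helper (the names and timings here are illustrative, not a real serving API):

```python
def latency_metrics(request_start, token_times):
    """TTFT: request start to first token. TBT: mean gap between later tokens."""
    ttft = token_times[0] - request_start
    gaps = [b - a for a, b in zip(token_times, token_times[1:])]
    tbt = sum(gaps) / len(gaps) if gaps else 0.0
    return ttft, tbt

# e.g. prefill finishes at t=0.5 s, then one token streams every 40 ms
times = [0.5 + 0.04 * i for i in range(6)]
ttft, tbt = latency_metrics(0.0, times)
print(f"TTFT={ttft:.2f}s  TBT={tbt*1000:.0f}ms")  # TTFT=0.50s  TBT=40ms
```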