Token Prediction Heads: Next-Token, Multi-Token, and Classifier Heads

Part of the series Transformer Anatomy (25 of 36)

Language models predict one token at a time, which is absurdly wasteful when you think about it. Your model computes a 4096-dimensional hidden state at position $t$, then throws away all that information except for the single next-token prediction. Multi-Token Prediction (MTP), pioneered by DeepSeek V3, exploits this waste by adding $K-1$ additional prediction heads that simultaneously predict tokens $t+2, t+3, \ldots, t+K$ from the same hidden state. This gives you richer training signal (your hidden states must encode more future information) and enables self-speculative decoding at inference (the extra heads serve as a built-in draft model). No extra forward passes during training, no separate draft model during inference: just smarter use of compute you're already spending.

Standard Next-Token Head

import torch
import torch.nn as nn

class StandardLMHead(nn.Module):
    """Standard next-token prediction. Output: logits over vocabulary."""
    def __init__(self, d_model, vocab_size, tie_weights=None):
        super().__init__()
        if tie_weights is not None:
            self.weight = tie_weights  # Shared with embedding
        else:
            self.weight = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)

    def forward(self, hidden_states):
        # hidden_states: [B, S, d_model]
        return hidden_states @ self.weight.T  # [B, S, vocab_size]

The head is a single linear projection: $\text{logits} = h \cdot E^T$, where $E$ is the embedding matrix (weight-tied). Cost: $B \times S \times d \times V$ FLOPs. For Llama 70B ($d = 8192$, $V = 128{,}256$): 1.05 TFLOP per forward pass, about 6% of total model FLOPs.
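As a quick sanity check on those numbers, here is the back-of-envelope arithmetic in Python (a sketch: the factor of 2 counts the multiply and the add of each matmul element, and the per-forward-pass total then scales with batch size and sequence length):

```python
# Back-of-envelope cost of the unembedding projection for Llama 70B.
d_model = 8192
vocab_size = 128_256

head_params = d_model * vocab_size   # weights in the LM head
flops_per_token = 2 * head_params    # one multiply-add per weight

print(f"head parameters: {head_params / 1e9:.2f} B")        # ~1.05 B
print(f"FLOPs per token: {flops_per_token / 1e9:.2f} GFLOP")  # ~2.10 GFLOP
```

So the head alone carries about a billion parameters; whether that lands at a few percent of total model FLOPs depends on how the rest of the model's compute is counted.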

Multi-Token Prediction (MTP)

DeepSeek V3's innovation: predict multiple future tokens from the same hidden state. On top of the standard next-token head, $K$ extra heads predict the tokens beyond $t+1$:

class MultiTokenPredictionHead(nn.Module):
    """Extra heads predicting tokens t+2, t+3, ... from the hidden state at
    position t (the standard LM head covers t+1)."""

    def __init__(self, d_model, vocab_size, num_future_tokens=4):
        super().__init__()
        self.K = num_future_tokens
        # Each future token gets its own MLP head
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.SiLU(),
                nn.Linear(d_model, vocab_size),
            )
            for _ in range(self.K)
        ])

    def forward(self, hidden_states):
        """
        hidden_states: [B, S, d_model]
        Returns: list of K tensors, each [B, S, vocab_size]
        """
        return [head(hidden_states) for head in self.heads]
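A quick shape check of the module (re-declared in compact form here so the snippet runs standalone; the small dimensions are arbitrary):

```python
import torch
import torch.nn as nn

# Compact re-declaration of the MTP head for a standalone shape check.
class MultiTokenPredictionHead(nn.Module):
    def __init__(self, d_model, vocab_size, num_future_tokens=4):
        super().__init__()
        self.K = num_future_tokens
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.SiLU(),
                nn.Linear(d_model, vocab_size),
            )
            for _ in range(self.K)
        ])

    def forward(self, hidden_states):
        # One [B, S, vocab_size] logits tensor per future-token head
        return [head(hidden_states) for head in self.heads]

mtp = MultiTokenPredictionHead(d_model=32, vocab_size=100, num_future_tokens=4)
hidden = torch.randn(2, 16, 32)      # [B=2, S=16, d_model=32]
logits = mtp(hidden)
print(len(logits), logits[0].shape)  # 4 torch.Size([2, 16, 100])
```

Every position in the sequence gets $K$ extra distributions, which is what lets the training loss below supervise all of them at once.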

Training with MTP

The total loss combines standard next-token loss with MTP losses:

$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{next}} + \sum_{k=2}^{K} \lambda_k \mathcal{L}_k$$

where $\mathcal{L}_k = -\sum_t \log p_k(x_{t+k} \mid h_t)$ is the cross-entropy for predicting the $k$-th future token.

import torch.nn.functional as F

def mtp_loss(model, mtp_heads, input_ids, labels):
    """Compute combined next-token + multi-token prediction loss."""
    hidden = model(input_ids).last_hidden_state  # [B, S, d]

    # Standard next-token loss
    logits_next = model.lm_head(hidden)  # [B, S, V]
    loss_next = F.cross_entropy(
        logits_next[:, :-1].reshape(-1, logits_next.size(-1)),
        labels[:, 1:].reshape(-1),
    )

    # MTP losses for tokens t+2, t+3, ..., t+K
    total_loss = loss_next
    K = len(mtp_heads.heads)
    for k in range(K):
        logits_k = mtp_heads.heads[k](hidden)  # [B, S, V]
        shift = k + 2  # Predict t+2, t+3, ...
        if shift < labels.size(1):
            loss_k = F.cross_entropy(
                logits_k[:, :-shift].reshape(-1, logits_k.size(-1)),
                labels[:, shift:].reshape(-1),
            )
            lambda_k = 1.0 / (k + 2)  # Decreasing weight for further tokens
            total_loss = total_loss + lambda_k * loss_k

    return total_loss

Why MTP Improves Training Quality

Predicting future tokens forces the hidden state to encode information about the upcoming sequence, not just the immediate next token. This creates richer representations that improve quality even at standard next-token-only inference. DeepSeek V3 reports 0.3-0.5 perplexity improvement from MTP training without any inference overhead (MTP heads are discarded after training if not used for speculation).

Inference: Self-Speculative Decoding

The MTP heads serve as a built-in draft model. At each decode step:

  1. Generate hidden state $h_t$ (standard forward pass)
  2. Head 0: predict token $t+1$ (standard LM head)
  3. Heads 1-3: predict tokens $t+2, t+3, t+4$ (MTP heads)
  4. Verify: run one forward pass on all 4 predicted tokens
  5. Accept consecutive correct predictions, reject at first mismatch

def mtp_speculative_step(model, mtp_heads, input_ids, kv_cache):
    """One speculative decode step. Assumes batch size 1 and greedy decoding."""
    # Forward pass: get the hidden state at the last position
    out = model(input_ids, kv_cache=kv_cache)
    hidden = out.last_hidden_state[:, -1:]  # [B, 1, d]

    # Draft K+1 tokens: standard LM head for t+1, MTP heads for t+2, t+3, ...
    draft_tokens = [model.lm_head(hidden).argmax(dim=-1)]  # t+1
    for head in mtp_heads.heads:
        draft_tokens.append(head(hidden).argmax(dim=-1))  # t+2, t+3, ...

    draft = torch.cat(draft_tokens, dim=-1)  # [B, K+1]

    # Verify all drafted tokens in one forward pass
    verify_out = model(draft, kv_cache=kv_cache)
    verify_logits = model.lm_head(verify_out.last_hidden_state)
    verify_tokens = verify_logits.argmax(dim=-1)  # [B, K+1]

    # Accept the longest matching prefix. Token 0 came from the true hidden
    # state, so under greedy decoding it is always correct.
    accepted = 0
    for i in range(draft.size(1)):
        if i == 0 or (verify_tokens[:, i - 1] == draft[:, i]).all():
            accepted += 1
        else:
            break

    # Note: a real engine must also roll the KV cache back to the accepted
    # length, discarding entries written for rejected draft tokens.
    return draft[:, :accepted]  # Accepted tokens

MTP Self-Speculation vs Separate Draft Model

| Method | Extra Memory | Tokens/Step | Speedup |
| --- | --- | --- | --- |
| No speculation | 0 GB | 1.0 | 1.0x |
| Separate 7B draft | 14 GB | 2.5 | 2.2x |
| MTP K=2 (self-draft) | 0.5 GB | 1.8 | 1.6x |
| MTP K=4 (self-draft) | 1.0 GB | 2.3 | 2.0x |
Note: MTP heads add minimal memory (0.5-1.0 GB) vs 14 GB for a separate draft model. Acceptance rate is slightly lower (same hidden state, less context for future tokens) but the memory savings are significant.
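The tokens-per-step column is roughly what a simple independence model predicts: if each drafted token beyond the first matches the verifier with probability $p$ (an assumed, illustrative acceptance rate, not a measured one), the expected accepted length per step is $1 + p + p^2 + \cdots + p^K$:

```python
# Expected tokens accepted per step under an i.i.d. acceptance model.
# p is a hypothetical per-token acceptance rate; K is the number of MTP heads.
def expected_tokens_per_step(p: float, K: int) -> float:
    # Token t+1 comes from the true hidden state, so it always counts;
    # each further draft token survives only if all tokens before it matched.
    return sum(p ** i for i in range(K + 1))

for K in (2, 4):
    print(K, round(expected_tokens_per_step(0.6, K), 2))  # 1.96 and 2.31
```

With $p \approx 0.6$ this lands in the same ballpark as the table's 1.8 and 2.3; in practice acceptance decays for further-out heads, so the model slightly overestimates the K=2 case.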

Classifier Heads

For tasks like classification, sentiment analysis, or embedding generation, replace the LM head with a task-specific head:

class ClassifierHead(nn.Module):
    """Classification head on top of transformer."""
    def __init__(self, d_model, num_classes, pooling="last"):
        super().__init__()
        self.pooling = pooling
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(d_model, num_classes),
        )

    def forward(self, hidden_states):
        # Pool: use last token (for causal LMs) or mean (for encoders)
        if self.pooling == "last":
            pooled = hidden_states[:, -1, :]  # [B, d]
        elif self.pooling == "mean":
            pooled = hidden_states.mean(dim=1)  # [B, d]
        else:
            raise ValueError(f"unknown pooling mode: {self.pooling}")
        return self.classifier(pooled)  # [B, num_classes]
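One caveat with `pooling="last"`: in a right-padded batch, index `-1` lands on a pad position rather than the sequence's real last token. A small sketch of a fix (the `last_token_pool` helper and the mask convention are illustrative assumptions, not part of the series code):

```python
import torch

# Pool the last *non-pad* hidden state per sequence. Assumes right padding
# and an attention_mask with 1 for real tokens, 0 for padding.
def last_token_pool(hidden_states, attention_mask):
    # hidden_states: [B, S, d]; attention_mask: [B, S]
    last_idx = attention_mask.sum(dim=1) - 1     # index of last real token
    batch_idx = torch.arange(hidden_states.size(0))
    return hidden_states[batch_idx, last_idx]    # [B, d]

h = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])      # seq 0 has 2 real tokens
pooled = last_token_pool(h, mask)
print(pooled.shape)  # torch.Size([2, 4])
```

For left-padded batches (common in causal-LM inference) plain `[:, -1, :]` is already correct, which is why both conventions appear in the wild.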

Summary

Multi-Token Prediction turns a wasteful one-output-per-forward-pass pattern into a richer training signal and a free draft model for inference. The extra prediction heads add minimal memory (0.5-1.0 GB for 4 future tokens) compared with the roughly 14 GB required to load a separate 7B draft model, and they provide training benefits even if you never use them for speculation. DeepSeek V3's results show this clearly: a 0.3-0.5 perplexity improvement from MTP training alone, plus a roughly 2x inference speedup when the heads are used for self-speculative decoding. The fundamental insight is that your hidden states already contain information about multiple future tokens; standard training just doesn't ask the model to surface that information. MTP does, and both training and inference benefit.