Language models predict one token at a time, which is absurdly wasteful when you think about it. Your model computes a 4096-dimensional hidden state at position t, then throws away all that information except for the single next-token prediction. Multi-Token Prediction (MTP), pioneered by DeepSeek V3, exploits this waste by adding prediction heads that simultaneously predict several future tokens from the same hidden state. This gives you richer training signal (your hidden states must encode more future information) and enables self-speculative decoding at inference (the extra heads serve as a built-in draft model). No extra forward passes during training, no separate draft model during inference — just smarter use of compute you’re already spending.
Standard Next-Token Head
```python
import torch
import torch.nn as nn

class StandardLMHead(nn.Module):
    """Standard next-token prediction. Output: logits over vocabulary."""

    def __init__(self, d_model, vocab_size, tie_weights=None):
        super().__init__()
        if tie_weights is not None:
            self.weight = tie_weights  # Shared with embedding
        else:
            self.weight = nn.Parameter(torch.randn(vocab_size, d_model) * 0.02)

    def forward(self, hidden_states):
        # hidden_states: [B, S, d_model]
        return hidden_states @ self.weight.T  # [B, S, vocab_size]
```
The head is a single linear projection, `logits = hidden @ W_E.T`, where `W_E` is the weight-tied `[vocab_size, d_model]` embedding matrix. Cost: 2 · S · d_model · vocab_size FLOPs for a length-S sequence. For Llama 70B (d_model = 8192, vocab_size = 32,000): roughly 1.05 TFLOP per forward pass at a 2K context, 6% of total model FLOPs.
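A quick back-of-envelope check of that cost formula. The sizes below are assumed Llama-2-70B values; the per-pass figure scales linearly with whatever sequence length you assume:

```python
# Head cost: 2 * S * d_model * vocab_size FLOPs. Sizes are assumed
# Llama-2-70B values, chosen for illustration.
d_model, vocab_size, seq_len = 8192, 32_000, 2048
flops = 2 * seq_len * d_model * vocab_size
print(f"{flops / 1e12:.2f} TFLOP per forward pass")  # 1.07 TFLOP at a 2K context
```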
Multi-Token Prediction (MTP)
DeepSeek V3’s innovation: predict K future tokens from the same hidden state, using K separate prediction heads:
```python
class MultiTokenPredictionHead(nn.Module):
    """Predict tokens t+2, ..., t+K+1 from the hidden state at position t.

    (The standard LM head already covers t+1.)
    """

    def __init__(self, d_model, vocab_size, num_future_tokens=4):
        super().__init__()
        self.K = num_future_tokens
        # Each future token gets its own MLP head
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model),
                nn.SiLU(),
                nn.Linear(d_model, vocab_size),
            )
            for _ in range(self.K)
        ])

    def forward(self, hidden_states):
        """
        hidden_states: [B, S, d_model]
        Returns: list of K tensors, each [B, S, vocab_size]
        """
        return [head(hidden_states) for head in self.heads]
```
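A quick shape check of this head layout, using toy sizes (the dimensions below are arbitrary, chosen only to make the shapes visible):

```python
import torch
import torch.nn as nn

# Toy-size instance of the head layout above: K two-layer MLP heads that all
# read the same hidden state.
K, d_model, vocab_size = 4, 64, 100
heads = nn.ModuleList([
    nn.Sequential(nn.Linear(d_model, d_model), nn.SiLU(),
                  nn.Linear(d_model, vocab_size))
    for _ in range(K)
])
h = torch.randn(2, 8, d_model)        # [B=2, S=8, d_model]
logits = [head(h) for head in heads]  # K tensors, each [B, S, vocab_size]
assert len(logits) == K
assert logits[0].shape == (2, 8, vocab_size)
```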
Training with MTP
The total loss combines the standard next-token loss with the K MTP losses:

L_total = L_next + Σ_k λ_k · L_k

where L_k is the cross-entropy for predicting the k-th future token and λ_k is a weight that decays with distance: the code below uses λ = 1/shift, so the t+2 loss gets weight 1/2, the t+3 loss weight 1/3, and so on.
```python
import torch.nn.functional as F

def mtp_loss(model, mtp_heads, input_ids, labels):
    """Compute combined next-token + multi-token prediction loss."""
    hidden = model(input_ids).last_hidden_state  # [B, S, d]

    # Standard next-token loss
    logits_next = model.lm_head(hidden)  # [B, S, V]
    loss_next = F.cross_entropy(
        logits_next[:, :-1].reshape(-1, logits_next.size(-1)),
        labels[:, 1:].reshape(-1),
    )

    # MTP losses for tokens t+2, t+3, ..., t+K+1
    total_loss = loss_next
    K = len(mtp_heads.heads)
    for k in range(K):
        logits_k = mtp_heads.heads[k](hidden)  # [B, S, V]
        shift = k + 2  # Head k predicts token t+(k+2)
        if shift < labels.size(1):
            loss_k = F.cross_entropy(
                logits_k[:, :-shift].reshape(-1, logits_k.size(-1)),
                labels[:, shift:].reshape(-1),
            )
            lambda_k = 1.0 / (k + 2)  # Decreasing weight for further tokens
            total_loss = total_loss + lambda_k * loss_k
    return total_loss
```
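To see the loss mechanics end-to-end without a real transformer, here is a toy run with random hidden states standing in for model output and plain linear layers standing in for the heads (all names and sizes are illustrative, not from the source):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, S, d, V, K = 2, 16, 32, 50, 3
hidden = torch.randn(B, S, d)           # stand-in for model hidden states
lm_head = nn.Linear(d, V, bias=False)   # stand-in next-token head
mtp_heads = nn.ModuleList([nn.Linear(d, V, bias=False) for _ in range(K)])
labels = torch.randint(0, V, (B, S))

# Next-token loss: position t predicts token t+1
logits_next = lm_head(hidden)
loss = F.cross_entropy(logits_next[:, :-1].reshape(-1, V),
                       labels[:, 1:].reshape(-1))

# MTP losses: head k predicts token t+(k+2), down-weighted by 1/(k+2)
for k, head in enumerate(mtp_heads):
    shift = k + 2
    logits_k = head(hidden)
    loss = loss + F.cross_entropy(logits_k[:, :-shift].reshape(-1, V),
                                  labels[:, shift:].reshape(-1)) / (k + 2)

print(f"combined loss: {loss.item():.3f}")  # one scalar; backprop trains all heads jointly
```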
Predicting future tokens forces the hidden state to encode information about the upcoming sequence, not just the immediate next token. This creates richer representations that improve quality even at standard next-token-only inference. DeepSeek V3 reports 0.3-0.5 perplexity improvement from MTP training without any inference overhead (MTP heads are discarded after training if not used for speculation).
Inference: Self-Speculative Decoding
The MTP heads serve as a built-in draft model. At each decode step:
- Generate the hidden state h_t (standard forward pass)
- Head 0: predict token t+1 (standard LM head)
- Heads 1-3: predict tokens t+2, t+3, t+4 (MTP heads)
- Verify: run one forward pass on all 4 predicted tokens
- Accept consecutive correct predictions, reject at the first mismatch
```python
def mtp_speculative_step(model, mtp_heads, input_ids, kv_cache):
    # Forward pass: get the hidden state at the last position
    out = model(input_ids, kv_cache=kv_cache)
    hidden = out.last_hidden_state[:, -1:]  # [B, 1, d]

    # Draft: token t+1 from the real LM head, t+2... from the MTP heads
    draft_tokens = [model.lm_head(hidden).argmax(dim=-1)]  # t+1
    for head in mtp_heads.heads:
        draft_tokens.append(head(hidden).argmax(dim=-1))  # t+2, t+3, ...
    draft = torch.cat(draft_tokens, dim=-1)  # [B, 1+K]

    # Verify all drafted tokens in one forward pass. (Assumes batch size 1; a
    # real implementation must also roll back kv_cache entries for rejected tokens.)
    verify_out = model(draft, kv_cache=kv_cache)
    verify_logits = model.lm_head(verify_out.last_hidden_state)
    verify_tokens = verify_logits.argmax(dim=-1)  # [B, 1+K]

    # Accept the longest matching prefix. draft[:, 0] came from the real LM
    # head, so under greedy decoding it is always correct.
    accepted = 1
    for i in range(1, draft.size(1)):
        if (verify_tokens[:, i - 1] == draft[:, i]).all():
            accepted += 1
        else:
            break
    return draft[:, :accepted]  # Accepted tokens
```
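The acceptance rule is easy to exercise in isolation. A sketch with hand-picked toy tensors (batch size 1): the verifier agrees on the second and third draft tokens but not the fourth, so three tokens survive from a single verify pass:

```python
import torch

draft = torch.tensor([[11, 22, 33, 44]])          # t+1..t+4: LM head + 3 MTP heads
verify_tokens = torch.tensor([[22, 33, 99, 55]])  # verifier's own predictions

accepted = 1  # draft[:, 0] came from the real LM head: always kept (greedy)
for i in range(1, draft.size(1)):
    if verify_tokens[0, i - 1].item() == draft[0, i].item():
        accepted += 1
    else:
        break
print(draft[:, :accepted].tolist())  # [[11, 22, 33]]: 3 tokens from one step
```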
MTP Self-Speculation vs Separate Draft Model
| Method | Extra Memory | Tokens/Step | Speedup |
|---|---|---|---|
| No speculation | 0 GB | 1.0 | 1.0x |
| Separate 7B draft | 14 GB | 2.5 | 2.2x |
| MTP K=2 (self-draft) | 0.5 GB | 1.8 | 1.6x |
| MTP K=4 (self-draft) | 1.0 GB | 2.3 | 2.0x |
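One way to sanity-check the tokens-per-step column: if each drafted token is accepted with (roughly independent) probability p, the expected number of tokens emitted per verify step is 1 + p + p² + … + p^K. The p values below are assumptions picked to roughly reproduce the table, not measurements from the source:

```python
def expected_tokens_per_step(p: float, K: int) -> float:
    """Expected emitted tokens per verify step: 1 guaranteed token from the
    LM head, plus K drafts where the k-th survives with probability p**k."""
    return sum(p ** k for k in range(K + 1))

print(round(expected_tokens_per_step(0.55, 2), 2))  # 1.85 (close to the K=2 row)
print(round(expected_tokens_per_step(0.60, 4), 2))  # 2.31 (close to the K=4 row)
```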
Classifier Heads
For tasks like classification, sentiment analysis, or embedding generation, replace the LM head with a task-specific head:
```python
class ClassifierHead(nn.Module):
    """Classification head on top of a transformer."""

    def __init__(self, d_model, num_classes, pooling="last"):
        super().__init__()
        self.pooling = pooling
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(d_model, num_classes),
        )

    def forward(self, hidden_states):
        # Pool: use the last token (causal LMs) or the mean (encoders)
        if self.pooling == "last":
            pooled = hidden_states[:, -1, :]  # [B, d]
        elif self.pooling == "mean":
            pooled = hidden_states.mean(dim=1)  # [B, d]
        else:
            raise ValueError(f"Unknown pooling: {self.pooling}")
        return self.classifier(pooled)  # [B, num_classes]
```
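A minimal usage sketch of last-token pooling plus a linear classifier, with toy sizes (this mirrors the head above inline rather than importing it):

```python
import torch
import torch.nn as nn

# Last-token pooling assumes a causal LM: the final position has attended
# to the whole sequence. Sizes are toy values.
d_model, num_classes = 32, 3
classifier = nn.Sequential(
    nn.Linear(d_model, d_model),
    nn.Tanh(),
    nn.Dropout(0.1),
    nn.Linear(d_model, num_classes),
)
hidden_states = torch.randn(4, 10, d_model)  # [B=4, S=10, d_model]
pooled = hidden_states[:, -1, :]             # last-token pooling: [B, d_model]
logits = classifier(pooled)                  # [B, num_classes]
assert logits.shape == (4, 3)
```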
Summary
Multi-Token Prediction turns a wasteful one-output-per-forward-pass pattern into a richer training signal and a free draft model for inference. The extra prediction heads add minimal memory (0.5-1.0 GB for 4 future tokens) compared with the ~14 GB needed to load a separate 7B draft model, and they provide training benefits even if you never use them for speculation. DeepSeek V3’s results show this clearly: 0.3-0.5 perplexity improvement from MTP training alone, plus 2x inference speedup when using the heads for self-speculative decoding. The fundamental insight is that your hidden states already contain information about multiple future tokens; standard training just doesn’t ask the model to surface that information. MTP does, and both training and inference benefit.