Part 11 of 23 in the Transformer Anatomy series

By the time a token’s representation exits the final transformer layer, it has been refined through dozens of attention and feed-forward transformations. It exists as a vector in $\mathbb{R}^{d_\text{model}}$, a point in the model’s internal semantic space. But this vector is useless to the outside world. The user needs a word, not an 8192-dimensional vector. The system orchestrating generation needs a probability distribution over the vocabulary, not a latent representation.

The output head is the component that bridges this gap. It takes the final hidden state and projects it into a vector of length $V$, one logit per token in the vocabulary, which is then converted to probabilities via softmax. This operation is the mirror image of embedding: where the embedding layer maps discrete token IDs into continuous vectors, the output head maps continuous vectors back into scores over discrete tokens. It is, in the most literal sense, the unembedding.

This post covers the full story: what the output head computes, why weight tying with the embedding matrix works and saves enormous amounts of memory, the computational bottleneck created by large vocabularies, techniques for efficient projection at scale, the logit lens as an interpretability tool, and connections to speculative decoding.


What the Output Head Computes

From Hidden State to Logits

Let $h_t \in \mathbb{R}^{d_\text{model}}$ be the hidden state at position $t$ after the final transformer layer and final LayerNorm. The output head applies a linear projection:

$$z_t = W_U h_t + b_U$$

where $W_U \in \mathbb{R}^{V \times d_\text{model}}$ is the unembedding matrix (also called the output projection or LM head), $b_U \in \mathbb{R}^V$ is an optional bias, and $z_t \in \mathbb{R}^V$ is the vector of logits, one per vocabulary token.

Most modern LLMs omit the bias term, so the projection simplifies to:

$$z_t = W_U h_t$$

Each logit $z_{t,i}$ is the dot product of the hidden state with the $i$-th row of $W_U$:

$$z_{t,i} = w_i^\top h_t$$

where $w_i$ is the unembedding vector for vocabulary token $i$. This dot product measures the alignment between the model’s internal representation and the direction in semantic space associated with token $i$. Tokens whose unembedding vectors align closely with the hidden state receive high logits; tokens pointing in unrelated directions receive low logits.
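As a toy numerical check (NumPy, made-up dimensions), the projection really is just one dot product per vocabulary row:

```python
import numpy as np

d_model, V = 4, 6                     # toy sizes for illustration
rng = np.random.default_rng(0)
W_U = rng.normal(size=(V, d_model))   # one unembedding vector w_i per token
h = rng.normal(size=d_model)          # final hidden state

logits = W_U @ h                      # z = W_U h, shape [V]

# The i-th logit is exactly the dot product of h with the i-th row of W_U
assert np.allclose(logits[2], W_U[2] @ h)
```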

From Logits to Probabilities

The logits are converted to a probability distribution via softmax:

$$p(x_{t+1} = i \mid x_{\leq t}) = \frac{\exp(z_{t,i})}{\sum_{j=1}^{V} \exp(z_{t,j})}$$

This is the model’s prediction for the next token. During training, these probabilities are compared to the ground-truth token via cross-entropy loss. During inference, a sampling strategy (greedy, top-k, top-p, temperature scaling) selects the next token from this distribution.
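A minimal sketch of this step (NumPy, with a hypothetical four-token vocabulary; greedy selection and temperature scaling shown, top-k/top-p omitted):

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()          # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1, -1.0])   # hypothetical logits, V = 4

probs = softmax(logits)                    # a proper distribution over tokens
greedy = int(np.argmax(probs))             # greedy decoding picks the argmax

# Temperature scaling divides logits by T before the softmax:
sharp = softmax(logits / 0.5)   # T < 1 concentrates mass on the top token
flat = softmax(logits / 2.0)    # T > 1 spreads mass out
```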

ℹ️ The Output Head Is Just a Linear Layer

Despite its importance, the output head is architecturally trivial: a single matrix multiplication with no nonlinearity. All of the model’s “intelligence” lives in the transformer layers that produced $h_t$. The output head simply reads off how well the final representation aligns with each vocabulary token’s direction.

The Unembedding as Inverse Embedding

The embedding layer maps token ID $i$ to a vector $e_i \in \mathbb{R}^{d_\text{model}}$ by looking up the $i$-th row of an embedding matrix $W_E \in \mathbb{R}^{V \times d_\text{model}}$. The unembedding layer does the reverse: it takes a vector $h \in \mathbb{R}^{d_\text{model}}$ and computes its similarity with every embedding vector, producing a score for each token.

If we think of $W_E$ as defining a coordinate system in semantic space, then $W_U$ defines how to read coordinates back out. The embedding converts discrete symbols to continuous geometry; the unembedding converts continuous geometry back to discrete scores.

Σ Theorem: Embedding-Unembedding Duality

Let $W_E \in \mathbb{R}^{V \times d_\text{model}}$ be the token embedding matrix. The logit for token $i$ given hidden state $h$ is:

$$z_i = w_{U,i}^\top h$$

If $W_U = W_E$ (weight tying), then $z_i = e_i^\top h$, and the logit is literally the dot-product similarity between the hidden state and the embedding of token $i$. The model predicts the token whose embedding is most aligned with its final hidden state.

This duality is not just mathematically elegant; it is the foundation of weight tying, which we turn to next.


Weight Tying: Sharing the Embedding and Unembedding Matrices

The Core Idea

Weight tying (also called weight sharing or tied embeddings) reuses the embedding matrix as the unembedding matrix. With the conventions above, both $W_E$ and $W_U$ live in $\mathbb{R}^{V \times d_\text{model}}$ and logits are computed as $z = W_U h$, so tying simply sets $W_U = W_E$. (Frameworks that store the output projection as a $d_\text{model} \times V$ matrix tie it to $W_E^\top$ instead; either way, no second matrix is learned.)

$$z_t = W_E h_t$$

No separate unembedding matrix is stored. The same matrix that maps tokens to embeddings also maps hidden states to logits.
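In PyTorch, tying comes down to aliasing one parameter tensor. A minimal sketch (a toy module, not any particular framework's exact implementation, with the transformer stack elided):

```python
import torch
import torch.nn as nn

class TiedLM(nn.Module):
    """Toy LM skeleton: embedding in, tied unembedding out."""
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)            # W_E: [V, d]
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: nn.Linear stores its weight as [out, in] = [V, d],
        # so the embedding matrix can serve as the unembedding directly.
        self.lm_head.weight = self.embed.weight

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        h = self.embed(token_ids)   # stand-in for the transformer layers
        return self.lm_head(h)      # logits: [..., V]

model = TiedLM(vocab_size=100, d_model=16)
# A single tensor backs both roles, so only one V x d matrix is stored:
assert model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr()
```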

Why Weight Tying Works

The deeper question is: why should the same matrix work for both directions? The answer lies in the geometry of the shared semantic space.

The shared space argument. Both the embedding and the output head operate in $\mathbb{R}^{d_\text{model}}$. The embedding layer places tokens at positions in this space based on their semantic properties. The output head needs to identify which token the model is predicting by finding the closest point in the same space. If the embedding places “cat” at position $e_\text{cat}$ and the transformer’s job is to produce a hidden state $h$ that points toward $e_\text{cat}$ when predicting “cat”, then using the embedding matrix for both operations is not just convenient: it is the natural choice.

The regularization argument. Weight tying constrains the model by forcing the input and output representations to be consistent. Without tying, the model could learn an embedding space where semantically similar tokens are far apart and an unembedding space that compensates for this. Tying prevents such degenerate solutions and acts as a structural regularizer.

The gradient flow argument. During backpropagation, gradients from the loss function flow through the output head and update $W_U$. With weight tying, these same gradients also update $W_E$. This means the embedding matrix receives training signal from two sources: (1) the forward pass through the transformer (how embeddings are consumed by attention and FFN layers) and (2) the output projection (how well embeddings serve as targets for prediction). This dual gradient signal improves the quality of the embedding space, especially for rare tokens that appear infrequently in training data.

💡 Historical Note

Weight tying was introduced by Press and Wolf (2017) in “Using the Output Embedding to Improve Language Models” and independently by Inan et al. (2017). It has been adopted by many major language models since, including GPT-2, BERT, T5, and Gemma. The technique is so standard that papers typically mention it only in passing.

Memory Savings

Weight tying eliminates one of the largest parameter matrices in the model. The savings are proportional to $V \times d_\text{model}$, which for modern models is enormous.

📊

Memory Savings from Weight Tying

| Model Config | V | d_model | Untied Params | Saved by Tying | % of Total Params |
|---|---|---|---|---|---|
| GPT-2 (124M) | 50,257 | 768 | 77.2M | 38.6M | 31.1% |
| Llama 2 7B | 32,000 | 4,096 | 262M | 131M | 1.9% |
| Llama 3 8B | 128,000 | 4,096 | 1,049M | 524M | 6.5% |
| Llama 3 70B | 128,000 | 8,192 | 2,097M | 1,049M | 1.5% |
| Hypothetical 405B | 128,000 | 16,384 | 4,194M | 2,097M | 0.5% |

Note: Untied params = 2 × V × d_model (embedding + unembedding). Saved = V × d_model. BF16 memory = saved params × 2 bytes.

For a model with $V = 128{,}000$ and $d_\text{model} = 8{,}192$:

$$W_U \text{ size} = 128{,}000 \times 8{,}192 = 1{,}048{,}576{,}000 \text{ parameters}$$

At 2 bytes per parameter (BF16), that is approximately 2 GB of memory saved. For a 70B model occupying roughly 140 GB in BF16, this is a modest 1.5% savings. But for smaller models like Llama 3 8B, the embedding matrix represents 6.5% of total parameters, a meaningful reduction.

The savings become even more significant during training, where optimizer states (Adam’s first and second moments) triple the effective memory per parameter. Eliminating 1 billion parameters frees roughly 6 GB in mixed-precision training once the weight and both Adam moments are counted at 2 bytes each.
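The arithmetic is easy to script. A small helper (assuming BF16 weights at 2 bytes and Adam's two moments at the same width; adjust the byte counts for other training recipes):

```python
def tying_savings(vocab_size: int, d_model: int,
                  weight_bytes: int = 2,    # BF16 weights (assumption)
                  moment_bytes: int = 2) -> dict:
    """Memory freed by dropping a separate unembedding matrix."""
    params = vocab_size * d_model
    return {
        "params": params,
        "weight_gb": params * weight_bytes / 1e9,
        # Adam tracks two moment tensors per parameter
        "optimizer_gb": params * 2 * moment_bytes / 1e9,
    }

saved = tying_savings(128_000, 8_192)
# ~1.05B params and ~2.1 GB of BF16 weights, plus optimizer-state savings
```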

Unembedding Matrix Size by Vocabulary and Hidden Dimension (BF16)

- V=32K, d=4096 (Llama 2 scale): 0.25 GB
- V=50K, d=768 (GPT-2 scale): 0.07 GB
- V=128K, d=4096 (Llama 3 8B scale): 1 GB
- V=128K, d=8192 (Llama 3 70B scale): 2 GB
- V=128K, d=16384 (hypothetical 400B+): 4 GB

Quality Impact: Neutral to Positive

The empirical evidence consistently shows that weight tying does not hurt quality and often helps slightly.

Press and Wolf (2017) reported perplexity improvements of 1–3 points on Penn Treebank when tying weights in LSTM language models. For transformers, the effect is smaller but still non-negative. The T5 paper (Raffel et al., 2020) used weight tying by default and ablated it, finding a negligible quality difference. GPT-2 and GPT-3 use weight tying, as does BERT.

The rare cases where untied weights outperform tied weights typically involve very large vocabularies (200K+) with very small hidden dimensions, where the embedding matrix’s rank is insufficient to serve dual purposes. At typical LLM scales ($d_\text{model} \geq 4096$), this is not a concern.

⚡ When to Untie

The main reason to use untied weights is when the input and output distributions are fundamentally different: for example, in machine translation where the source and target languages have separate vocabularies. For standard causal language models where the same vocabulary serves as both input and output, weight tying is almost universally beneficial.


The Vocabulary Projection Bottleneck

The Scale of the Problem

The output head performs a matrix-vector multiplication $z = W_E h$ where $W_E \in \mathbb{R}^{V \times d_\text{model}}$. For a single token, this is a matrix-vector product. For a batch of $B$ tokens, it becomes a matrix-matrix product:

$$Z = H W_E^\top$$

where $H \in \mathbb{R}^{B \times d_\text{model}}$ and $Z \in \mathbb{R}^{B \times V}$.

The number of floating-point operations (FLOPs) for this single matmul is:

$$\text{FLOPs} = 2 \times B \times V \times d_\text{model}$$

For $B = 1$ (single-token decode), $V = 128{,}000$, $d_\text{model} = 8{,}192$:

$$\text{FLOPs} = 2 \times 1 \times 128{,}000 \times 8{,}192 = 2.1 \times 10^9 = 2.1 \text{ GFLOPs}$$

That is 2.1 billion FLOPs for a single token’s output projection. For comparison, one full layer’s QKV projection (all heads) at the same dimension is:

$$\text{FLOPs}_\text{QKV} = 2 \times 1 \times 3 \times d_\text{model} \times d_\text{model} = 2 \times 3 \times 8192^2 \approx 0.4 \text{ GFLOPs}$$

So the output head alone costs roughly as much as five layers’ worth of QKV projections, though across an entire 80-layer model it remains on the order of 1% of total forward FLOPs. More importantly, during autoregressive decode the output head runs on every single token, and its memory bandwidth requirement (loading the full $V \times d_\text{model}$ matrix) is substantial.
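These per-token costs can be sanity-checked with the `2 * m * k * n` rule for an `[m, k] x [k, n]` matmul:

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an [m, k] x [k, n] matmul: one multiply-add per m*k*n."""
    return 2 * m * k * n

d_model, vocab = 8_192, 128_000

head = matmul_flops(1, d_model, vocab)                 # output head, B = 1
qkv_one_layer = matmul_flops(1, d_model, 3 * d_model)  # one layer's QKV

# head is ~2.1e9 FLOPs; a single layer's QKV projection is ~4.0e8 FLOPs
```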

Parameter Count Perspective

The unembedding matrix (or the tied embedding matrix) contains $V \times d_\text{model}$ parameters. For modern models, this is a significant fraction of the total.

📊

Output Head Parameters vs Total Model Parameters

| Model | Total Params | Embedding Params | % of Total | Equivalent Transformer Layers |
|---|---|---|---|---|
| GPT-2 124M | 124M | 38.6M | 31.1% | ~10 layers |
| Llama 2 7B | 6.7B | 131M | 1.9% | ~1 layer |
| Llama 3 8B | 8.0B | 524M | 6.5% | ~3 layers |
| Llama 3 70B | 70B | 1,049M | 1.5% | ~1 layer |
| Llama 3 405B | 405B | 2,097M | 0.5% | ~0.5 layers |

Note: Equivalent layers computed as embedding params / params-per-layer. Larger V inflates embedding cost.

The jump from Llama 2’s 32K vocabulary to Llama 3’s 128K vocabulary quadrupled the embedding parameter count. This was a deliberate trade-off: the larger vocabulary improves tokenization efficiency (fewer tokens per document, enabling longer effective context) at the cost of a larger embedding matrix.

The Bandwidth Problem During Decode

During autoregressive decoding, each token generation requires loading the entire embedding matrix from HBM to compute the output projection. The memory bandwidth cost is:

$$\text{Bytes loaded} = V \times d_\text{model} \times d_\text{type}$$

For $V = 128{,}000$, $d_\text{model} = 8{,}192$, BF16:

$$128{,}000 \times 8{,}192 \times 2 \approx 2.1 \text{ GB}$$

On an H100 with 3.35 TB/s of HBM bandwidth, loading 2.1 GB takes:

$$\frac{2.1 \text{ GB}}{3{,}350 \text{ GB/s}} \approx 0.63 \text{ ms}$$

This is the arithmetic intensity problem. The output head does 2.1 GFLOPs of compute while loading 2.1 GB of data, giving an arithmetic intensity of:

$$\frac{2.1 \times 10^9 \text{ FLOPs}}{2.1 \times 10^9 \text{ bytes}} = 1.0 \text{ FLOPs/byte}$$

The H100’s compute-to-bandwidth ratio is roughly $990 \text{ TFLOPS} / 3.35 \text{ TB/s} \approx 295 \text{ FLOPs/byte}$. An arithmetic intensity of 1.0 means the output head is 295x more bandwidth-bound than compute-bound during single-token decode. The GPU’s compute units sit almost entirely idle while waiting for the embedding matrix to stream from memory.

Σ Definition: Arithmetic Intensity of the Output Head

The arithmetic intensity of the vocabulary projection for batch size $B$ is:

$$\text{AI} = \frac{2BVd}{(Vd + Bd + BV) \cdot d_\text{type}} \approx \frac{2B}{d_\text{type}} \quad \text{(when } B \ll \min(d, V)\text{)}$$

For BF16 ($d_\text{type} = 2$) at $B = 1$: AI = 1.0. At $B = 128$: AI = 128. The output head becomes compute-bound only at large batch sizes, which is precisely why batched inference is essential for throughput.
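Plugging numbers into this definition (a direct transcription of the formula, counting weight, activation, and logit traffic):

```python
def arithmetic_intensity(B: int, V: int = 128_000, d: int = 8_192,
                         dtype_bytes: int = 2) -> float:
    flops = 2 * B * V * d
    bytes_moved = (V * d + B * d + B * V) * dtype_bytes
    return flops / bytes_moved

low = arithmetic_intensity(1)      # ~1.0: decode is bandwidth-bound
high = arithmetic_intensity(128)   # ~126: approaching compute-bound
```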


Efficient Vocabulary Projection

Tensor-Parallel Sharding

In tensor-parallel (TP) inference, the embedding matrix is sharded across GPUs along the vocabulary dimension. With $N$ GPUs, each GPU stores a $V/N \times d_\text{model}$ shard and computes logits for its slice of the vocabulary. An all-gather operation collects the full logit vector.

For $V = 128{,}000$ on 8 GPUs, each GPU handles 16,000 vocabulary entries. The per-GPU memory for the embedding shard is:

$$\frac{128{,}000 \times 8{,}192 \times 2}{8} = 262 \text{ MB}$$

This is easily manageable. The all-gather to collect logits transfers $B \times V \times d_\text{type}$ bytes per step:

$$1 \times 128{,}000 \times 2 = 256 \text{ KB}$$

This is negligible on NVLink. The vocabulary projection is one of the easiest operations to parallelize because there are no data dependencies between different vocabulary entries.

# Tensor-parallel output head (simplified)
import torch
from torch import Tensor

class TPOutputHead:
    """
    Each GPU computes logits for its vocabulary shard.
    A final all-gather produces the full logit vector.
    """
    def __init__(self, embedding_matrix: Tensor, tp_rank: int, tp_world: int):
        vocab_size = embedding_matrix.shape[0]
        shard = vocab_size // tp_world        # assumes V divisible by tp_world
        self.start = shard * tp_rank
        self.end = shard * (tp_rank + 1)
        # Each GPU holds only its shard of the embedding matrix
        self.weight = embedding_matrix[self.start:self.end]  # [V/N, d_model]

    def forward(self, h: Tensor) -> Tensor:
        # Local logits for this GPU's vocabulary shard
        local_logits = h @ self.weight.T      # [B, V/N]
        # All-gather across TP ranks (e.g. via torch.distributed.all_gather)
        # concatenates the shards into the full logit vector
        full_logits = all_gather(local_logits, dim=-1)  # [B, V]
        return full_logits

Output Head Latency by Tensor Parallelism Degree (Single-Token Decode)

- TP=1 (1 GPU), full matrix load: 0.63 ms
- TP=2, half matrix per GPU: 0.33 ms
- TP=4, quarter matrix: 0.18 ms
- TP=8, eighth matrix + all-gather: 0.11 ms

Fused Softmax + Cross-Entropy for Training

During training, the output head produces logits that feed into softmax and then cross-entropy loss. A naive implementation materializes the full $B \times V$ logit tensor, applies softmax, and then computes the loss. For $B = 2048$ (tokens in a micro-batch) and $V = 128{,}000$:

$$\text{Logit tensor size} = 2{,}048 \times 128{,}000 \times 4 = 1.05 \text{ GB (FP32)}$$

Materializing this tensor is wasteful because we ultimately only need the loss value (a scalar) and the gradient with respect to the logits. The fused softmax + cross-entropy kernel computes the loss without ever materializing the full softmax output:

  1. Compute logits row by row (or in small tiles).
  2. For each row, find the maximum logit (for numerical stability), compute the log-sum-exp, and subtract the logit at the target position.
  3. The gradient $\partial L / \partial z_i = p_i - \mathbb{1}[i = y]$ (softmax probability minus one-hot target) is computed in the same kernel.

This reduces peak memory from $O(BV)$ for the logit tensor to $O(B)$ for the per-position losses. For training at scale, this fusion is critical.

# Fused cross-entropy: never materializes the V-dimensional softmax output
import torch
from torch import Tensor

def fused_cross_entropy(logits: Tensor, targets: Tensor) -> Tensor:
    """
    logits: [B, V] -- raw output from the LM head
    targets: [B] -- integer token IDs

    Reference implementation: the loss needs only the per-row log-sum-exp
    and the target logit, so the full softmax is never formed.
    In practice, this is a custom fused CUDA kernel.
    """
    # Numerically stable log-sum-exp, computed per row
    max_logits = logits.max(dim=-1, keepdim=True).values                       # [B, 1]
    log_sum_exp = (logits - max_logits).exp().sum(dim=-1, keepdim=True).log()  # [B, 1]

    # Gather only the target logits -- no [B, V] softmax tensor needed
    target_logits = logits.gather(dim=-1, index=targets.unsqueeze(-1))         # [B, 1]

    # Per-position loss: logsumexp(z) - z_target
    return (max_logits + log_sum_exp - target_logits).mean()
⚡ Memory Savings from Fusion

For Llama 3 training with a micro-batch of 4096 tokens and $V = 128{,}000$, the fused kernel saves approximately 2 GB of activation memory per micro-batch. Across pipeline stages and gradient accumulation steps, this can free 10–20 GB per GPU, enough to increase batch size or sequence length.

Chunked Output Projection

An alternative to computing the full $B \times V$ matmul in one shot is to chunk along the vocabulary dimension. Compute logits for a chunk of $V_\text{chunk}$ vocabulary entries, immediately apply the fused loss for those entries, and accumulate. This further reduces peak memory:

$$\text{Peak logit memory} = B \times V_\text{chunk} \times d_\text{type}$$

For $V_\text{chunk} = 16{,}000$ (8 chunks for $V = 128{,}000$), peak logit memory drops from 1 GB to 128 MB. Libraries like liger-kernel implement this optimization.
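A sketch of the chunked idea in plain PyTorch (an illustrative reference implementation, not liger-kernel's actual fused Triton kernel; it keeps a running log-sum-exp across chunks, the same trick as online softmax):

```python
import torch
from torch import Tensor

def chunked_cross_entropy(h: Tensor, W: Tensor, targets: Tensor,
                          chunk: int = 16_000) -> Tensor:
    """
    h: [B, d] hidden states; W: [V, d] (tied) unembedding; targets: [B].
    Mean cross-entropy without materializing the full [B, V] logit tensor.
    """
    B = h.shape[0]
    running_max = torch.full((B,), float("-inf"))
    running_sum = torch.zeros(B)        # sum of exp(z - running_max)
    target_logit = torch.zeros(B)

    for start in range(0, W.shape[0], chunk):
        z = h @ W[start:start + chunk].T          # [B, chunk] at a time
        new_max = torch.maximum(running_max, z.max(dim=-1).values)
        # Rescale the old partial sum to the new max, then add this chunk
        running_sum = running_sum * (running_max - new_max).exp() \
            + (z - new_max[:, None]).exp().sum(dim=-1)
        running_max = new_max
        # Pick out target logits that fall inside this chunk
        in_chunk = (targets >= start) & (targets < start + chunk)
        if in_chunk.any():
            idx = (targets[in_chunk] - start).unsqueeze(-1)
            target_logit[in_chunk] = z[in_chunk].gather(-1, idx).squeeze(-1)

    log_sum_exp = running_max + running_sum.log()
    return (log_sum_exp - target_logit).mean()    # CE = lse(z) - z_target
```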


The Logit Lens: Reading the Model’s Mind

The Key Insight

Here is a remarkable observation: the output head can be applied not just to the final layer’s hidden state, but to any intermediate layer’s hidden state. If we take the hidden state $h_t^{(\ell)}$ at layer $\ell$ and project it through the output head:

$$z_t^{(\ell)} = W_E h_t^{(\ell)}$$

we get a “prediction” for the next token at layer $\ell$. This is the logit lens, introduced by nostalgebraist (2020).

The logit lens reveals how the transformer progressively refines its prediction layer by layer. At early layers, the prediction is typically diffuse or incoherent. By middle layers, the correct token often emerges as a top candidate. By the final layers, the prediction sharpens to a confident distribution.

Σ Theorem: Logit Lens Interpretation

Let $p^{(\ell)}$ be the probability distribution obtained by applying the output head to layer $\ell$’s hidden state. The sequence $p^{(1)}, p^{(2)}, \ldots, p^{(L)}$ traces the model’s “thought process”: how it iteratively builds its prediction from raw token representations to a refined next-token distribution. Each layer can be interpreted as a refinement step that moves the hidden state closer to the target direction in embedding space.

What the Logit Lens Reveals

Early layers (1–5): The logit lens typically shows high probability on the current token or tokens that frequently follow it in the training data. The model has not yet processed much contextual information. The top predictions are essentially bigram statistics.

Middle layers (10–20): Contextual information begins to dominate. For tasks like factual recall (“The capital of France is”), the correct answer (“Paris”) starts appearing in the top predictions around layers 10–15 for a 32-layer model. This is where the model transitions from local to global reasoning.

Late layers (25–32): The prediction sharpens. The probability mass concentrates on the correct token. The final few layers often make only minor adjustments, suggesting that much of the “work” is done in the middle of the network.

import torch

def logit_lens(model, input_ids):
    """
    Apply the output head to every intermediate layer's hidden state.
    Returns a [num_layers, batch, seq_len, vocab_size] tensor of logits.
    """
    hidden_states = []

    # Hook to capture each layer's output hidden state
    def hook_fn(module, input, output):
        hidden_states.append(output[0].detach())

    hooks = [layer.register_forward_hook(hook_fn)
             for layer in model.transformer.layers]

    with torch.no_grad():
        model(input_ids)

    for hook in hooks:
        hook.remove()

    # Apply the output head (shared embedding matrix) to each layer's output
    all_logits = []
    for h in hidden_states:
        # Apply the final layer norm before projecting, as the model does
        h_normed = model.transformer.final_norm(h)
        all_logits.append(h_normed @ model.lm_head.weight.T)

    return torch.stack(all_logits)  # [num_layers, batch, seq, vocab]

The Tuned Lens: A Refinement

The logit lens has a limitation: intermediate hidden states may not be in the same “format” as the final hidden state. The model’s residual stream undergoes LayerNorm at the end, and intermediate states may use the representational space differently than the final state.

The tuned lens (Belrose et al., 2023) addresses this by learning a small affine transformation per layer that maps each layer’s hidden state into the final layer’s representational format:

$$z_t^{(\ell)} = W_E \cdot \text{AffineProbe}^{(\ell)}(h_t^{(\ell)})$$

where $\text{AffineProbe}^{(\ell)}$ is trained (with the base model frozen) to predict the final layer’s output. The tuned lens produces cleaner trajectories and reveals layer-by-layer refinement more clearly than the raw logit lens.
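A sketch of such a probe (illustrative names; the real tuned lens trains one probe per layer against the frozen model's final-layer distribution):

```python
import torch
import torch.nn as nn

class AffineProbe(nn.Module):
    """Per-layer affine map from layer-l space into final-layer space."""
    def __init__(self, d_model: int):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)
        # Initialize at the identity so the probe starts as the raw logit lens
        nn.init.eye_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.linear(h)

# Training would minimize the divergence between the final-layer prediction
# and the probed prediction, updating only the probe's parameters.
```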

📊

Logit Lens vs Tuned Lens: Top-1 Accuracy at Intermediate Layers (Llama 2 7B)

| Layer | Logit Lens Top-1 | Tuned Lens Top-1 | Improvement |
|---|---|---|---|
| Layer 4 / 32 | 8.2% | 15.1% | +6.9% |
| Layer 8 / 32 | 18.7% | 29.4% | +10.7% |
| Layer 16 / 32 | 42.3% | 51.8% | +9.5% |
| Layer 24 / 32 | 61.5% | 65.2% | +3.7% |
| Layer 32 / 32 | 68.4% | 68.4% | 0.0% |

Note: Top-1 next-token accuracy on a held-out subset of The Pile. Final layer accuracy is identical by construction.

The gap between the logit lens and tuned lens is largest at early and middle layers, where the representational format differs most from the final layer. By the last quarter of the network, the two converge, suggesting that the model’s hidden states are already in approximately the right format for the output head.


Connection to Speculative Decoding

Why Speculative Decoding Works

Speculative decoding uses a small β€œdraft” model to generate candidate tokens quickly, then verifies them in parallel with the large β€œtarget” model. The key question is: why should a small model’s predictions be useful for a large model?

The answer connects directly to the output head and weight tying. Both models project hidden states through an embedding matrix into the same vocabulary space. If the smaller model has learned a similar (though less refined) semantic space, its top predictions will overlap substantially with the larger model’s.

Consider what happens when both models predict with high confidence. The small model might assign 85% probability to “Paris” after “The capital of France is”, while the large model assigns 95%. The token is the same. The small model’s logit vector, while less sharply peaked, points in the same direction in vocabulary space.

This is not a coincidence. Both models are trained on similar data and organize their embedding spaces so that semantically similar tokens are nearby. The shared structure of natural language means that a 1B model and a 70B model will agree on the next token the vast majority of the time; disagreements occur mainly on ambiguous continuations where multiple tokens are plausible.

ℹ️ Acceptance Rate and Semantic Alignment

The acceptance rate in speculative decoding (the fraction of draft tokens accepted by the target model) is directly related to how well the two models’ output distributions align. Empirically, a well-chosen 1B draft model achieves 70–85% acceptance rates with a 70B target model on typical text, meaning their output heads agree on the top token roughly 3 out of 4 times.

The Medusa Approach: Multiple Output Heads

Medusa (Cai et al., 2024) takes the output head concept further by adding multiple parallel LM heads to a single model. Each head predicts a different future token: head 0 predicts $x_{t+1}$, head 1 predicts $x_{t+2}$, head 2 predicts $x_{t+3}$, and so on.

Each Medusa head is a small MLP followed by a vocabulary projection (using the tied embedding matrix):

$$z_{t,k} = W_E \cdot \text{MLP}_k(h_t)$$

where $k$ is the lookahead index. The additional MLP layers ($\sim 2$ layers of $d_\text{model} \times d_\text{model}$) are tiny compared to the base model and can be trained cheaply with the base model frozen.

This creates a tree of candidate continuations that can be verified in a single forward pass, achieving 2–3x speedup without any separate draft model. The key insight is that the output head’s vocabulary projection is the expensive part; the small MLP that adapts the hidden state for multi-step prediction adds negligible cost.
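A sketch of one Medusa-style head (illustrative sizes and a plain two-layer adapter; the published implementation differs in details such as residual connections):

```python
import torch
import torch.nn as nn

class MedusaHead(nn.Module):
    """Predicts token t+1+k from the hidden state at position t."""
    def __init__(self, d_model: int, tied_embedding: nn.Parameter):
        super().__init__()
        # Small adapter MLP: ~2 layers of d_model x d_model
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.SiLU(),
            nn.Linear(d_model, d_model),
        )
        self.tied_embedding = tied_embedding   # W_E, shared with the base model

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # z_{t,k} = W_E . MLP_k(h_t)
        return self.mlp(h) @ self.tied_embedding.T   # [..., V]

d_model, V = 64, 1000                      # toy sizes
W_E = nn.Parameter(torch.randn(V, d_model))
heads = nn.ModuleList([MedusaHead(d_model, W_E) for _ in range(3)])
h = torch.randn(2, d_model)
logits_per_head = [head(h) for head in heads]   # one distribution per lookahead
```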


Tied vs Untied Weights: Empirical Results

Large-Scale Ablations

Despite the ubiquity of weight tying, there are scenarios where untied weights are preferable. Let us examine the empirical evidence.

T5 ablation (Raffel et al., 2020): T5-Base with tied weights achieved 83.28 GLUE score; untied achieved 83.15. The difference is within noise. T5 uses weight tying by default.

PaLM (Chowdhery et al., 2022): PaLM 540B shares its input and output embedding matrices, following earlier work. At this scale the embedding matrix is well under 1% of total parameters, so the choice is driven by consistency rather than memory savings.

Llama (Touvron et al., 2023): The released Llama checkpoints keep separate embedding and output projection matrices (`tie_word_embeddings` is false in their configurations); the technical reports do not ablate the choice. Only the later small Llama 3.2 variants (1B, 3B) tie weights, where the embedding matrix would otherwise dominate the parameter budget.

Mistral and Mixtral: Also ship with untied embeddings. The Mixtral MoE architecture has so many parameters in the expert FFN layers that the embedding matrix is a tiny fraction of total parameters regardless.

📊

Weight Tying: Empirical Comparison Across Models

| Study | Model Size | Tied | Untied | Winner | Notes |
|---|---|---|---|---|---|
| Press & Wolf 2017 | LSTM LM | PPL 75.2 | PPL 77.8 | Tied | 3.4% improvement |
| T5 (Raffel 2020) | 220M–11B | 83.28 GLUE | 83.15 GLUE | Tied (~tie) | Within noise |
| PaLM (2022) | 540B | Used | – | Tied (default) | Shared input-output embeddings |
| Llama 2 (2023) | 7B–70B | – | Used | Untied (default) | Not ablated |
| Gemma (2024) | 2B–7B | Used | – | Tied (default) | Not ablated |

Note: Production models rarely ablate this choice: small models tie to save a large fraction of parameters, while many large models leave the head untied where the cost is negligible.

When to Use Untied Weights

Based on the literature, untied weights are worth considering in these scenarios:

  1. Extremely large models (500B+) where the embedding matrix is a negligible fraction of total parameters and the additional capacity is essentially free.

  2. Cross-lingual or multi-modal models where the input and output distributions differ significantly. A model that takes image patches as input but produces text tokens as output should not tie the image encoder’s projection with the text unembedding.

  3. Models with very large vocabularies (200K+) and small hidden dimensions where the embedding matrix is rank-deficient for dual use. With $V = 256{,}000$ and $d_\text{model} = 2{,}048$, the embedding matrix has rank at most 2,048: it cannot represent 256,000 independent directions, and the compromise between embedding quality and unembedding quality may hurt both.

  4. Encoder-decoder models where the encoder and decoder have different vocabularies or the encoder’s embedding serves a fundamentally different role than the decoder’s output projection.

⚠️ The Default Should Be Tying

Unless you have a specific reason to untie, use weight tying. The memory savings are guaranteed, the quality impact is neutral or positive, and the implementation is simpler (one fewer parameter matrix to manage). Every major open-source LLM framework (Hugging Face Transformers, vLLM, TensorRT-LLM) supports weight tying natively.


The Output Head in the Full Forward Pass

Let us place the output head in context by tracing a complete forward pass.

Step 1: Tokenization. Input text is converted to token IDs: [x_1, x_2, \ldots, x_n].

Step 2: Embedding. Token IDs are mapped to vectors via lookup in W_E: e_t = W_E[x_t]. In modern decoder models no positional embedding is added here; position information enters later via RoPE, applied inside attention.

Step 3: Transformer layers. The embedded sequence passes through L layers of attention + FFN. Each layer reads from and writes to the residual stream. After L layers, we have h_t^{(L)}.

Step 4: Final normalization. The final hidden state is normalized: \hat{h}_t = \text{RMSNorm}(h_t^{(L)}).

Step 5: Output head. The normalized hidden state is projected to vocabulary logits: z_t = W_E \hat{h}_t (with weight tying).

Step 6: Softmax / sampling. Logits are converted to probabilities and a token is sampled.

The output head is the thinnest layer in the stack (a single matmul), but it touches the largest weight matrix in the model (by vocabulary dimension) and produces the distribution that ultimately determines every token the model generates.
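The six steps can be traced end to end in a toy numpy sketch. Attention and the FFN are replaced by stand-in random residual blocks, since only the head's position in the pipeline matters here; all sizes are illustrative:

```python
import numpy as np

rng = np.random.default_rng(1)
V, d, L = 50, 16, 4               # toy vocabulary, width, depth

W_E = rng.standard_normal((V, d)) * 0.02                 # Step 2: embedding
layers = [rng.standard_normal((d, d)) * 0.02 for _ in range(L)]

def rms_norm(x, eps=1e-6):                               # Step 4
    return x / np.sqrt(np.mean(x * x) + eps)

def decode_step(token_ids):
    h = W_E[token_ids[-1]]             # decode: only the last position
    for W in layers:                   # Step 3: stand-in residual blocks
        h = h + W @ h
    h = rms_norm(h)
    logits = W_E @ h                   # Step 5: tied output head
    z = logits - logits.max()          # Step 6: numerically stable softmax
    p = np.exp(z)
    return p / p.sum()

p = decode_step([3, 17, 8])            # Step 1's token IDs, given directly
next_token = int(p.argmax())           # greedy sampling
```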

Relative Compute Cost per Component (Single Token Decode, Llama 3 70B)

| Component | FLOPs | % of total FLOPs |
| --- | --- | --- |
| Attention QKV proj | 3 * d^2 per layer | 23% |
| Attention scores + softmax | scales with context | 5% |
| Attention output proj | d^2 per layer | 8% |
| FFN up + gate | 2 * d * d_ff per layer | 32% |
| FFN down | d_ff * d per layer | 16% |
| Output head | V * d (one matmul) | 1% |
| LayerNorm + other | norms, residuals | 15% |

At 1% of total FLOPs, the output head is compute-cheap. But remember: during decode, each component's latency is dominated by memory bandwidth, not compute. The output head loads roughly 2 GB of weights to produce 256 KB of logits. In terms of bandwidth consumption, it is closer to 5–8% of the total, making it a meaningful contributor to decode latency.
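The bandwidth claim is back-of-envelope arithmetic. A sketch with assumed Llama-3-70B-like numbers (V = 128,256, d_model = 8,192, 2-byte bf16 weights):

```python
# Assumed figures; the exact vocabulary size and dtype vary by deployment.
V, d, bytes_per_elem = 128_256, 8_192, 2

# Per decoded token, the head must stream the full (V x d) matrix in...
head_weight_bytes = V * d * bytes_per_elem   # ~2.1e9 bytes (~2 GB)
# ...and writes out only a single logit vector.
logit_bytes = V * bytes_per_elem             # 256,512 bytes (~256 KB)

# ~2 FLOPs (multiply + add) per weight loaded: arithmetic intensity
# near 1 FLOP/byte, far below GPU compute-to-bandwidth ratios,
# so the head is bandwidth-bound during decode.
arithmetic_intensity = (2 * V * d) / (head_weight_bytes + logit_bytes)
```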


Advanced Topics

Vocabulary Pruning and Adaptive Softmax

For extremely large vocabularies, computing logits over the entire vocabulary is wasteful because most tokens have near-zero probability. Several techniques address this.

Adaptive softmax (Grave et al., 2017) partitions the vocabulary into clusters by frequency. The most common tokens (head cluster) get the full d_\text{model}-dimensional projection. Rarer tokens (tail clusters) first project to a smaller dimension and then to the vocabulary, reducing computation proportionally. This was important for word-level vocabularies (100K+ words) but is less used with subword tokenizers that keep vocabularies manageable.
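The clustering idea can be sketched in a few lines of numpy: a full-rank projection for a small head cluster, a low-dimensional two-step projection for the tail, glued together so the result is still a valid distribution. The sizes and the single tail cluster are illustrative, not Grave et al.'s exact configuration:

```python
import numpy as np

rng = np.random.default_rng(2)
d, V_head, V_tail, d_tail = 512, 2_000, 98_000, 128

# Head cluster: full d-dimensional projection for frequent tokens,
# plus one extra logit slot that routes probability to the tail cluster.
W_head = rng.standard_normal((V_head + 1, d)) * 0.02
# Tail cluster: project d -> d_tail first, then out to the rare tokens.
W_down = rng.standard_normal((d_tail, d)) * 0.02
W_tail = rng.standard_normal((V_tail, d_tail)) * 0.02

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def adaptive_log_probs(h):
    head_lp = log_softmax(W_head @ h)             # (V_head + 1,)
    tail_lp = log_softmax(W_tail @ (W_down @ h))  # (V_tail,)
    # P(rare token) = P(tail slot) * P(token | tail cluster)
    return np.concatenate([head_lp[:V_head], head_lp[V_head] + tail_lp])

lp = adaptive_log_probs(rng.standard_normal(d))
# Cost: (V_head+1)*d + d_tail*d + V_tail*d_tail multiply-adds,
# roughly 13.6M versus 51.2M for the full (V_head+V_tail) x d projection.
```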

Vocabulary pruning at inference time restricts the output projection to a subset of the vocabulary. If we know the model is generating code, we can mask out non-code tokens, skipping their logit computation entirely. This requires careful implementation to avoid inadvertently excluding valid tokens.
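A minimal sketch of the masking variant, which computes all logits and then masks; a production kernel would instead skip the disallowed rows of the unembedding matmul entirely:

```python
import numpy as np

def restricted_softmax(logits, allowed_ids):
    # Mask disallowed tokens to -inf so they get exactly zero probability.
    masked = np.full_like(logits, -np.inf)
    masked[allowed_ids] = logits[allowed_ids]
    z = masked - masked[allowed_ids].max()   # stable softmax
    p = np.exp(z)                            # exp(-inf) == 0.0
    return p / p.sum()

logits = np.array([2.0, 5.0, 1.0, 4.0])
# Token 1 has the highest raw logit but is excluded from the allowed set.
p = restricted_softmax(logits, allowed_ids=[0, 2, 3])
```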

Temperature and the Output Head

Temperature scaling modifies the logits before softmax:

p(x_{t+1} = i) = \frac{\exp(z_{t,i} / \tau)}{\sum_{j=1}^{V} \exp(z_{t,j} / \tau)}

At \tau = 1, this is standard softmax. As \tau \to 0, the distribution collapses to a point mass on the highest logit (greedy decoding). As \tau \to \infty, the distribution approaches uniform.

Temperature acts after the output head, modifying how the logit geometry is converted to probabilities. A higher temperature "flattens" the sharp peaks that weight tying naturally creates (since the embedding space has strong geometric structure), enabling more diverse generation.
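Temperature is a two-line transformation on the logits. A small sketch:

```python
import numpy as np

def softmax_with_temperature(z, tau):
    z = np.asarray(z, dtype=float) / tau   # scale logits by 1/tau
    z = z - z.max()                        # subtract max for stability
    e = np.exp(z)
    return e / e.sum()

logits = np.array([4.0, 2.0, 1.0])
sharp = softmax_with_temperature(logits, tau=0.5)  # peakier than tau = 1
base  = softmax_with_temperature(logits, tau=1.0)  # standard softmax
flat  = softmax_with_temperature(logits, tau=5.0)  # closer to uniform
```

The top token's probability shrinks monotonically as tau grows, which is exactly the "flattening" described above.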

The Softmax Bottleneck

Yang et al. (2018) identified a fundamental limitation: the softmax output of a tied model has rank at most d_\text{model}, because the logit vector z = W_E h lies in the column space of W_E, which has rank \min(V, d_\text{model}) = d_\text{model}. This means the model cannot represent arbitrary probability distributions over the vocabulary; it is limited to distributions expressible as the softmax of a rank-d_\text{model} logit matrix.

This is the softmax bottleneck. For d_\text{model} = 8{,}192 and V = 128{,}000, the logits span at most 8,192 independent directions, far fewer than the 128,000 vocabulary entries. In practice this is rarely a binding constraint, because natural language distributions are highly structured (most probability mass concentrates on a small number of tokens), but it is a theoretical limitation worth understanding.
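The rank constraint is easy to verify numerically at toy scale: stack the logit vectors from many different hidden states, and the resulting matrix can never exceed rank d:

```python
import numpy as np

rng = np.random.default_rng(3)
V, d = 300, 8                     # toy scale with d << V

W_E = rng.standard_normal((V, d))
H = rng.standard_normal((d, 200)) # hidden states from 200 different contexts

# Each column of Z is one context's logit vector z = W_E @ h.
# All of them live in the column space of W_E, so the stacked
# matrix has rank at most d no matter how many contexts we add.
Z = W_E @ H
```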

Σ Theorem: Softmax Bottleneck (Yang et al., 2018)

For a language model with tied embedding matrix W_E \in \mathbb{R}^{V \times d} and hidden state h \in \mathbb{R}^d, the output distribution p = \text{softmax}(W_E h) is constrained to lie in a (d-1)-dimensional manifold within the (V-1)-simplex. When d \ll V, this prevents the model from expressing the true next-token distribution if it requires more than d degrees of freedom.

Proposed solutions include Mixture of Softmaxes (MoS), which computes multiple softmax outputs from different projections and averages them, breaking the rank constraint. However, at modern model scales (d_\text{model} \geq 4096), the softmax bottleneck is not empirically limiting, and MoS adds complexity without measurable benefit.
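A toy MoS sketch makes the rank-breaking concrete: because the log of a mixture of softmaxes is not a linear function of the hidden state, the stacked log-probability matrix escapes the rank-d limit. All sizes and the K = 4 component count are illustrative:

```python
import numpy as np

rng = np.random.default_rng(4)
V, d, K = 300, 8, 4              # toy scale; K mixture components

W_E = rng.standard_normal((V, d))
projs = [rng.standard_normal((d, d)) * 0.3 for _ in range(K)]
gate = rng.standard_normal((K, d)) * 0.5

def mos_probs(h):
    g = np.exp(gate @ h)
    g = g / g.sum()                       # input-dependent mixture weights
    p = np.zeros(V)
    for k in range(K):
        z = W_E @ (projs[k] @ h)
        e = np.exp(z - z.max())
        p += g[k] * e / e.sum()           # average of K softmaxes
    return p

# Stack log-probs from many contexts: unlike the single-softmax case,
# this matrix is not limited to rank ~d, because log(mixture) is
# nonlinear in h.
H = rng.standard_normal((d, 100))
logP = np.log(np.stack([mos_probs(H[:, i]) for i in range(100)], axis=1))
```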


Summary

The output head is where the transformer meets the real world. It converts internal geometry into token predictions through a single matrix multiplication: the unembedding projection. Weight tying elegantly unifies the input embedding and output projection, saving up to billions of parameters at no quality cost. The vocabulary projection creates a significant bandwidth bottleneck during inference, addressed through tensor parallelism and fused kernels during training. And the logit lens turns the output head from a mere projection layer into a window on the model's internal reasoning, revealing how predictions are refined layer by layer.

The output head's simplicity is deceptive. It is a linear layer with no activation function, yet it is the interface between the model's continuous internal world and the discrete tokens that humans read. Every token the model generates, every word of every answer and every line of every code snippet, passes through this single matrix multiplication. Understanding it is essential for understanding how language models work, how to optimize their inference, and how to interpret what they have learned.

In the next post, we turn to the only training signal that shapes all of this: the cross-entropy loss function.