Part 5 of 23 in the series Transformer Anatomy:

1. The Transformer Attention Mechanism: From First Principles to Performance Reality
2. Tokenization and BPE: How LLMs See Text — From Characters to Subwords
3. Embedding Layers: The Geometry of Meaning in LLMs
4. Position Encoding in Transformers: From Sinusoidal to RoPE, ALiBi, and Long-Context Scaling
5. Softmax Numerics: Log-Sum-Exp, Temperature, and Why Numerical Stability Matters (this post)
6. Attention Variants Compared: MHA, MQA, GQA, and MLA
7. Normalization in Transformers: LayerNorm, RMSNorm, and the Training Stability Story
8. Residual Connections and Skip Paths: Why Transformers Can Be 100 Layers Deep
9. The Feed-Forward Network: SwiGLU, Gating, and the FFN-as-Memory Hypothesis
10. Mixture of Experts: Why Conditional Computation Is the Path to Trillion-Parameter Models
11. The Output Head: Unembedding, Weight Tying, and Vocabulary Projection
12. Cross-Entropy Loss: How the Loss Function Shapes What an LLM Learns
13. Encoder vs Decoder: Why Decoder-Only Won
14. DeepSeek V3: How 671B Parameters Trained for the Cost of a 70B Dense Model
15. Building a Transformer From Scratch: Putting Every Component Together
16. Gradient Flow and Backpropagation Through Transformers: What Happens During the Backward Pass
17. Weight Initialization: Xavier, Kaiming, and Why mu-P Changes Everything for Large Models
18. Training Loop Anatomy: Forward Pass, Loss Computation, Backward Pass, Optimizer Step
19. Learning Rate Schedules: Warmup, Cosine Decay, and Why WSD Changes Everything
20. Activation Functions Deep Dive: ReLU, GELU, SiLU, and Why Each Matters for Transformers
21. Attention Masking: Causal, Bidirectional, Sliding Window, Block Sparse, and Custom Patterns
22. Knowledge Distillation: Training Small Models to Match Large Ones
23. Model Merging: Weight Averaging, TIES, DARE, and Evolutionary Search

In the previous four parts of this series, we built up the transformer’s machinery piece by piece: attention as a soft database lookup (Part 1), tokenization (Part 2), embeddings (Part 3), and positional encoding (Part 4). One function appeared repeatedly, always at a critical juncture, and we took it for granted each time: softmax. It converts raw attention scores into weights. It converts final logits into a probability distribution over the vocabulary. It appears in every single forward pass of every single transformer ever deployed.

Softmax looks trivial. It is an exponential followed by a normalization. A first-year student can implement it in two lines of Python. And yet, this function is the source of some of the most insidious numerical bugs in deep learning, it is the computational bottleneck that FlashAttention had to rethink from scratch, and its behavior under different scaling regimes determines whether your language model produces coherent text or degenerate repetition.

This post gives softmax the treatment it deserves: we start with the definition and its properties, walk through the numerical instability that makes a naive implementation dangerous, derive the log-sum-exp trick that every production framework uses, then proceed to the online softmax algorithm that made FlashAttention possible. We cover temperature scaling and its connection to entropy, examine how softmax interacts with the attention mechanism’s scaling factor, and finally survey the alternatives that researchers have proposed — and why none of them have displaced softmax from its central role.


Part 1: The Softmax Function

Definition

Given a vector of real numbers z = (z_1, z_2, \dots, z_n), the softmax function maps it to a probability distribution:

\sigma(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}

Each output \sigma(z)_i is a value in (0, 1), and the outputs sum to exactly 1. The inputs z_i are called logits — they are unconstrained real numbers that can be positive, negative, or zero.

Why Exponential?

The exponential function is not an arbitrary choice. It has three properties that make it uniquely suited for this role.

Property 1: Positivity. For any real z_i, e^{z_i} > 0. This guarantees that every output of softmax is strictly positive — no probability is ever exactly zero. This matters because zero probabilities create problems in downstream computations like cross-entropy loss, where \log(0) = -\infty.

Property 2: Monotonicity. If z_i > z_j, then e^{z_i} > e^{z_j}, and therefore \sigma(z)_i > \sigma(z)_j. Softmax preserves the ordering of its inputs. The largest logit always gets the highest probability. This is essential: we want the function to respect the model’s “preferences.”

Property 3: Amplification. The exponential function grows super-linearly. A small difference between two logits becomes a large difference in the corresponding probabilities. If z_1 = 3 and z_2 = 1, the ratio of their probabilities is e^{3}/e^{1} = e^{2} \approx 7.4. The exponential amplifies a gap of 2 into a probability ratio of about 7.4. This is what makes softmax a “soft” version of argmax: it concentrates most of the probability mass on the largest logit, but still assigns nonzero probability to all others.

The “Soft Argmax” Interpretation

The argmax function returns a one-hot vector — all zeros except for a 1 at the position of the largest element. It is not differentiable (you cannot compute gradients through it), which makes it useless for training neural networks.

Softmax is the differentiable relaxation of argmax. As the differences between logits grow, softmax output approaches a one-hot vector. As differences shrink toward zero, softmax approaches a uniform distribution. It provides a smooth, differentiable interpolation between these extremes.
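This interpolation is easy to see numerically. A minimal numpy sketch (toy logits chosen only for illustration), using the stable max-subtraction form of softmax described later in this post:

```python
import numpy as np

def softmax(z):
    # Stable softmax: subtract the max before exponentiating
    e = np.exp(z - np.max(z))
    return e / e.sum()

z = np.array([2.0, 1.0, 0.5])

sharp = softmax(10 * z)   # scaling logits up -> output approaches one-hot (argmax)
flat = softmax(0.01 * z)  # scaling logits down -> output approaches uniform

print(np.round(sharp, 4))
print(np.round(flat, 4))
```

Scaling the logits by a constant is exactly the temperature knob covered in Part 5: large logit gaps give near-argmax behavior, small gaps give near-uniform behavior.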

Σ Definition: Softmax as Gibbs Distribution

The softmax function is equivalent to the Gibbs (Boltzmann) distribution from statistical mechanics: P(i) = e^{z_i} / Z, where Z = \sum_j e^{z_j} is the partition function. The logits z_i play the role of negative energies: lower-energy states (higher logits) are more probable. This connection is not merely analogical — it becomes precise when we introduce temperature scaling later in this post.

What Softmax Does in a Transformer

Softmax appears in two critical places:

  1. Attention weights. After computing the raw scores QK^T, softmax converts them to weights that sum to 1 across each row. This ensures that the output of attention is a convex combination of value vectors — a weighted average where the weights are non-negative and normalized.

  2. Output logits. The final linear layer of a transformer produces logits over the vocabulary (a vector of size V, typically 32,000 to 128,000). Softmax converts these to probabilities for the next token.

Both uses exploit the same core property: softmax transforms arbitrary real-valued scores into a proper probability distribution.
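Both uses can be sketched in a few lines of numpy. The shapes here (N=4 positions, d=8 head dimension, V=10 vocabulary) are toy values chosen only to keep the demo readable:

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z, axis=-1):
    # Row-wise stable softmax
    e = np.exp(z - z.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Use 1: attention weights — softmax over each row of QK^T
N, d = 4, 8
Q, K = rng.normal(size=(N, d)), rng.normal(size=(N, d))
weights = softmax(Q @ K.T / np.sqrt(d))
print(weights.sum(axis=1))   # each row sums to 1

# Use 2: next-token probabilities — softmax over vocabulary logits
V = 10
logits = rng.normal(size=V)
probs = softmax(logits)
print(probs.sum())           # sums to 1
```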


Part 2: The Numerical Instability Problem

The definition of softmax involves computing ezie^{z_i}. This is where things go wrong in practice.

The Overflow Problem

Consider what happens when z_i is large. In IEEE 754 float32, the largest representable value is approximately 3.4 \times 10^{38}. The exponential e^{z_i} overflows to infinity when z_i > \ln(3.4 \times 10^{38}) \approx 88.7. Any logit above 88.7 produces inf in float32.

In float16, the situation is far worse. The maximum representable value is 65,504, and e^{z_i} overflows when z_i > \ln(65504) \approx 11.1. Logits above 11.1 produce inf in float16.

These are not extreme or pathological values. During training, logits routinely reach magnitudes of 20, 50, or higher. In attention, the magnitude of unscaled query-key dot products grows with the head dimension (more on this in Part 6), pushing raw scores well past the float16 limit.

A Concrete Example

Consider a simple three-element softmax with logits z = [1000, 1001, 1002] in float32:

import numpy as np

z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
exp_z = np.exp(z)
print(exp_z)
# [inf, inf, inf]

softmax = exp_z / exp_z.sum()
print(softmax)
# [nan, nan, nan]

Every exponential overflows to inf. The sum of infinities is inf. And inf / inf is nan (not a number). The computation is completely destroyed — not just inaccurate, but producing values that propagate corruption through every subsequent calculation.

🚨 NaN Propagation

A single NaN in the softmax output infects every downstream computation. In attention, it corrupts the weighted sum of values. In the output layer, it corrupts the loss. During backpropagation, NaN gradients corrupt every parameter update. One overflow in one softmax call can silently destroy an entire training run.

The Underflow Problem

Overflow is the more dramatic failure, but underflow also causes issues. When z_i is very negative (say -1000), e^{-1000} underflows to 0. If all logits are very negative, every exponential underflows, the denominator becomes 0, and 0/0 gives NaN.

More subtly, even when the denominator does not become zero, underflow can cause a loss of relative precision. If a probability should be 10^{-20} but is rounded to 0, the subsequent \log(0) = -\infty in cross-entropy loss produces another catastrophic failure.
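The all-negative failure mode is easy to reproduce — and the max-subtraction fix derived in Part 3 recovers the correct answer. A small sketch:

```python
import numpy as np

z = np.array([-1000.0, -1000.0, -1000.0], dtype=np.float32)

with np.errstate(under="ignore", invalid="ignore"):
    exp_z = np.exp(z)              # every term underflows to 0.0
    naive = exp_z / exp_z.sum()    # 0/0 -> NaN
print(exp_z, naive)

# Subtracting the max first gives the correct (uniform) answer:
e = np.exp(z - z.max())            # exp(0) = 1 for every element
print(e / e.sum())                 # [1/3, 1/3, 1/3]
```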

How Common Is This in Practice?

📊 Logit Magnitudes in Real Models (typical ranges)

| Scenario | Typical Max Logit | FP32 Overflow? | FP16 Overflow? |
|---|---|---|---|
| Attention scores (d=64, no scaling) | ~50-80 | No | Yes |
| Attention scores (d=128, no scaling) | ~80-120 | Yes | Yes |
| Vocab logits (early training) | ~5-15 | No | Maybe |
| Vocab logits (late training) | ~20-50 | No | Yes |
| Adversarial/pathological inputs | ~100-1000+ | Yes | Yes |
Note: FP32 overflows at logits > ~88.7. FP16 overflows at logits > ~11.1. BF16 has the same exponent range as FP32 (overflow at ~88.7) but lower mantissa precision.

The conclusion is inescapable: naive softmax is broken for any real-world deployment. We need a mathematically equivalent formulation that avoids computing large exponentials.


Part 3: The Log-Sum-Exp Trick

The solution is elegant: subtract the maximum value from all logits before exponentiating.

The Key Identity

Σ Theorem: Shift Invariance of Softmax

For any constant c \in \mathbb{R} and any vector z \in \mathbb{R}^n:

\sigma(z)_i = \sigma(z - c)_i

That is, softmax is invariant to additive shifts. Adding the same constant to every logit does not change the output probabilities.

Proof. Let c be any real constant. Then:

\sigma(z - c)_i = \frac{e^{z_i - c}}{\sum_{j=1}^{n} e^{z_j - c}} = \frac{e^{z_i} \cdot e^{-c}}{\sum_{j=1}^{n} e^{z_j} \cdot e^{-c}} = \frac{e^{-c} \cdot e^{z_i}}{e^{-c} \cdot \sum_{j=1}^{n} e^{z_j}} = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}} = \sigma(z)_i

The e^{-c} factor cancels in the numerator and denominator. This holds for any c, but the optimal choice for numerical stability is c = \max(z).
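The identity is trivial to verify numerically. The naive formulation is safe here only because the test logits are small:

```python
import numpy as np

def naive_softmax(z):
    # Naive softmax — fine for small logits, used here only to test the identity
    e = np.exp(z)
    return e / e.sum()

rng = np.random.default_rng(1)
z = rng.normal(size=5)

# Any shift c leaves the output unchanged; c = max(z) is the stable choice
for c in [0.0, 3.0, -7.5, z.max()]:
    assert np.allclose(naive_softmax(z), naive_softmax(z - c))
print("shift invariance holds")
```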

Why c=max(z)c = \max(z) Works

When we set c=m=maxjzjc = m = \max_j z_j, every shifted logit zimz_i - m satisfies zim0z_i - m \leq 0. Therefore, every exponential ezime^{z_i - m} satisfies 0<ezim10 < e^{z_i - m} \leq 1. No exponential can overflow, because the largest exponent is exactly 0 (corresponding to emm=e0=1e^{m - m} = e^0 = 1).

Underflow is still possible for very negative shifted logits, but this is benign: if zimz_i \ll m, then ezim0e^{z_i - m} \approx 0, which is the correct answer — that element’s probability is genuinely negligible.

The Two-Pass Algorithm

The stable softmax requires two passes over the data:

Pass 1: Find the maximum.

m = \max_{j=1}^{n} z_j

Pass 2: Compute shifted exponentials and normalize.

\sigma(z)_i = \frac{e^{z_i - m}}{\sum_{j=1}^{n} e^{z_j - m}}

In Python:

import numpy as np

def stable_softmax(z):
    m = np.max(z)               # Pass 1: find max
    exp_z = np.exp(z - m)       # Pass 2a: shifted exponentials
    return exp_z / exp_z.sum()  # Pass 2b: normalize

z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
print(stable_softmax(z))
# [0.09003057, 0.24472848, 0.66524094]

The same logits that produced NaN with naive softmax now give correct, meaningful probabilities.

The Log-Sum-Exp Connection

The name “log-sum-exp trick” comes from the closely related problem of computing \log(\sum_j e^{z_j}) — the log of the partition function. This quantity appears in cross-entropy loss and log-likelihood computations. Naively computing it overflows for the same reason softmax does. The stable version is:

\text{LSE}(z) = m + \log\left(\sum_{j=1}^{n} e^{z_j - m}\right)

where m = \max_j z_j. This is mathematically equivalent to \log(\sum_j e^{z_j}) but numerically stable, because the exponentials are all in the range (0, 1].

The connection to softmax is direct: \log(\sigma(z)_i) = z_i - \text{LSE}(z). In practice, many frameworks compute log-softmax (which is needed for cross-entropy loss) using the LSE trick directly, avoiding the need to compute softmax and then take the log.
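A direct numpy sketch of stable LSE and log-softmax, mirroring in spirit what frameworks do internally:

```python
import numpy as np

def logsumexp(z):
    # Stable log(sum(exp(z))): shift by the max, add it back outside the log
    m = np.max(z)
    return m + np.log(np.sum(np.exp(z - m)))

def log_softmax(z):
    # log(softmax(z)_i) = z_i - LSE(z); no intermediate softmax, no log(0)
    return z - logsumexp(z)

z = np.array([1000.0, 1001.0, 1002.0])
ls = log_softmax(z)
print(ls)             # finite log-probabilities
print(np.exp(ls))     # recovers the softmax probabilities
```

Computing `np.log(softmax(z))` on these logits would produce NaN; the LSE route stays finite throughout.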

Framework Implementations

PyTorch’s torch.nn.functional.log_softmax and torch.nn.functional.cross_entropy both use the log-sum-exp trick internally. If you compute softmax followed by log manually, you lose both numerical stability (underflow in softmax produces \log(0) = -\infty) and performance (two kernel launches instead of one). Always use the fused versions.

Performance Cost of Two Passes

The two-pass algorithm reads the input data twice: once to find the maximum, once to compute the exponentials. For a vector of n elements, this is 2n reads from memory. In the attention context, where softmax operates on each row of the N \times N attention matrix, this means 2N^2 reads across all rows. On memory-bandwidth-limited hardware (which GPUs are for this operation), that factor of 2 matters.

This cost is what motivated the search for a single-pass algorithm.


Part 4: Online Softmax (Milakov-Gimelshein)

The two-pass softmax algorithm requires seeing all elements before computing any output: you need the global maximum before you can compute any exponential safely. This is a problem for FlashAttention, which processes the key-value pairs in tiles — it sees one block of keys at a time and cannot afford to materialize the full N \times N attention matrix.

The online softmax algorithm, described by Milakov and Gimelshein (2018), solves this problem. It computes the exact same result as the two-pass algorithm, but in a single streaming pass that can process data incrementally.

The Core Insight

The key observation is that if you have a partial softmax computed over elements 1, \dots, k, and new elements k+1, \dots, k+p arrive, you can update the running result without restarting from scratch. The only extra work is a rescaling of the partial sum to account for the new maximum.

The Algorithm

Maintain two running variables:

  • m^{(k)} — the maximum of all logits seen so far (through element k)
  • \ell^{(k)} — the running sum of exponentials \sum_{j=1}^{k} e^{z_j - m^{(k)}}, computed with respect to the current maximum

Initialization:

m^{(0)} = -\infty, \quad \ell^{(0)} = 0

Update for each new element z_{k+1}:

m^{(k+1)} = \max(m^{(k)}, z_{k+1})

\ell^{(k+1)} = e^{m^{(k)} - m^{(k+1)}} \cdot \ell^{(k)} + e^{z_{k+1} - m^{(k+1)}}

The first term rescales the old sum: all previous exponentials were computed with respect to m^{(k)}, but we now need them with respect to m^{(k+1)}. Multiplying by e^{m^{(k)} - m^{(k+1)}} accomplishes this exactly. The second term adds the contribution of the new element.

After processing all n elements, \ell^{(n)} = \sum_{j=1}^{n} e^{z_j - m^{(n)}}, and the softmax is:

\sigma(z)_i = \frac{e^{z_i - m^{(n)}}}{\ell^{(n)}}

Σ Theorem: Correctness of Online Softmax

After processing all n elements, the online algorithm produces:

m^{(n)} = \max_{j=1}^{n} z_j, \quad \ell^{(n)} = \sum_{j=1}^{n} e^{z_j - m^{(n)}}

which gives the identical result to the two-pass stable softmax.

Proof by induction. The base case k = 0 is trivial: m^{(0)} = -\infty and \ell^{(0)} = 0. For the inductive step, assume m^{(k)} = \max_{j \leq k} z_j and \ell^{(k)} = \sum_{j=1}^{k} e^{z_j - m^{(k)}}. Then:

m^{(k+1)} = \max(m^{(k)}, z_{k+1}) = \max_{j \leq k+1} z_j

And:

\ell^{(k+1)} = e^{m^{(k)} - m^{(k+1)}} \cdot \sum_{j=1}^{k} e^{z_j - m^{(k)}} + e^{z_{k+1} - m^{(k+1)}}

= \sum_{j=1}^{k} e^{z_j - m^{(k)}} \cdot e^{m^{(k)} - m^{(k+1)}} + e^{z_{k+1} - m^{(k+1)}}

= \sum_{j=1}^{k} e^{z_j - m^{(k+1)}} + e^{z_{k+1} - m^{(k+1)}} = \sum_{j=1}^{k+1} e^{z_j - m^{(k+1)}}

The rescaling factor e^{m^{(k)} - m^{(k+1)}} corrects each exponential in the partial sum from the old maximum to the new one. This is exact arithmetic — no approximation is involved.

Block-Wise Extension for FlashAttention

In FlashAttention, we do not process one element at a time. We process tiles of B_c key-value pairs. The algorithm extends naturally to blocks. Suppose we have processed blocks 1, \dots, t and maintain (m^{(t)}, \ell^{(t)}, o^{(t)}), where o^{(t)} is the running output (the weighted sum of value vectors). When block t+1 with keys K_{t+1} and values V_{t+1} arrives:

  1. Compute local attention scores: s_{t+1} = Q K_{t+1}^T / \sqrt{d_k}
  2. Compute the local maximum: m_{\text{local}} = \max(s_{t+1})
  3. Update the global maximum: m^{(t+1)} = \max(m^{(t)}, m_{\text{local}})
  4. Rescale the old sum: \ell^{(t+1)} = e^{m^{(t)} - m^{(t+1)}} \cdot \ell^{(t)} + \sum e^{s_{t+1} - m^{(t+1)}}
  5. Rescale the old output and add the new contribution:

o^{(t+1)} = \frac{e^{m^{(t)} - m^{(t+1)}} \cdot \ell^{(t)} \cdot o^{(t)} + \sum e^{s_{t+1} - m^{(t+1)}} \cdot V_{t+1}}{\ell^{(t+1)}}

ℹ️ Why This Matters for FlashAttention

The entire reason FlashAttention can process attention in tiles without materializing the N \times N matrix is that online softmax allows incremental computation. Each tile of K and V is loaded from HBM to SRAM, the running softmax statistics are updated in registers, and the tile is discarded. The full attention matrix never exists in memory. Without online softmax, tiled attention would require two full passes over the KV data — defeating the purpose of tiling.

Implementation in Python

Here is the online softmax algorithm for a single vector, demonstrating correctness:

import numpy as np

def online_softmax(z):
    n = len(z)
    m = -np.inf  # running max
    l = 0.0      # running sum of shifted exponentials

    # Single pass: compute m and l incrementally
    for i in range(n):
        m_new = max(m, z[i])
        l = np.exp(m - m_new) * l + np.exp(z[i] - m_new)
        m = m_new

    # Compute final softmax using accumulated statistics
    return np.exp(z - m) / l

z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float64)
print(online_softmax(z))
# [0.09003057, 0.24472847, 0.66524096]

This produces the same result as the two-pass algorithm, while accumulating the max and normalizer statistics in a single sweep over the input; in a fused kernel, the final normalization folds into whatever consumes the softmax output.
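The block-wise extension from the previous subsection can be sketched the same way. This toy numpy version (block size and shapes are arbitrary illustrative choices) maintains (m, \ell, o) per query row and matches full softmax attention exactly:

```python
import numpy as np

def blockwise_attention(Q, K, V, block=4):
    # Online-softmax attention over tiles of keys/values, FlashAttention-style.
    # Shapes: Q (N, d), K (N, d), V (N, d). Toy sizes only — real kernels
    # run this per tile in SRAM with the statistics held in registers.
    N, d = Q.shape
    m = np.full(N, -np.inf)        # running row maxima
    l = np.zeros(N)                # running row sums of exponentials
    o = np.zeros((N, d))           # running unnormalized outputs

    for start in range(0, N, block):
        Kb, Vb = K[start:start+block], V[start:start+block]
        s = Q @ Kb.T / np.sqrt(d)              # local scores (N, block)
        m_new = np.maximum(m, s.max(axis=1))   # updated row maxima
        scale = np.exp(m - m_new)              # rescale factor for old stats
        p = np.exp(s - m_new[:, None])         # local shifted exponentials
        l = scale * l + p.sum(axis=1)
        o = scale[:, None] * o + p @ Vb
        m = m_new

    return o / l[:, None]

rng = np.random.default_rng(0)
N, d = 8, 4
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))

# Reference: full (two-pass) softmax attention
s = Q @ K.T / np.sqrt(d)
p = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p / p.sum(axis=1, keepdims=True)) @ V

print(np.allclose(blockwise_attention(Q, K, V), ref))  # True
```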

Performance Comparison

Memory reads for softmax, per row of N elements:

| Algorithm | Passes over data | Reads per row |
|---|---|---|
| Naive (overflow-prone) | 1 | N |
| Two-pass stable | 2 | 2N |
| Online (Milakov-Gimelshein) | 1 | N |

The online algorithm matches the naive version’s single-pass behavior while achieving the two-pass version’s numerical stability. For FlashAttention, where each “pass” over KV data requires a round-trip from HBM to SRAM, halving the number of passes translates directly to halving the memory traffic for the softmax portion of the computation.


Part 5: Temperature Scaling

Temperature is a single scalar parameter that controls how “sharp” or “flat” the softmax distribution is. It provides a knob to tune the trade-off between confident (low-entropy) and exploratory (high-entropy) outputs.

Definition

The temperature-scaled softmax is:

\sigma(z/T)_i = \frac{e^{z_i / T}}{\sum_{j=1}^{n} e^{z_j / T}}

where T > 0 is the temperature parameter. The logits z_i are divided by T before the softmax is applied.

The Three Regimes

T \to 0 (Low temperature): approaches argmax. As temperature decreases toward zero, the softmax concentrates all probability mass on the largest logit. In the limit, \lim_{T \to 0^+} \sigma(z/T) = \text{one-hot}(\arg\max z). The distribution becomes deterministic: only the most probable element survives.

To see why, consider two logits z_1 > z_2. The ratio of their probabilities is e^{(z_1 - z_2)/T}. As T \to 0, this ratio grows without bound — the larger logit dominates exponentially.

T = 1 (Unit temperature): the model’s native distribution. The softmax operates on the raw logits as learned during training. This is the distribution the model was optimized to produce.

T \to \infty (High temperature): approaches uniform. As temperature increases, division by T shrinks all logits toward zero, erasing the differences between them. In the limit, \lim_{T \to \infty} \sigma(z/T) = (1/n, 1/n, \dots, 1/n). Every element is equally likely.

📊 Effect of Temperature on Softmax Output

| Temperature | softmax([3, 1, 0.5]) | Max Prob | Entropy (bits) |
|---|---|---|---|
| T = 0.1 | [1.0000, 0.0000, 0.0000] | ~100.00% | ~0.000 |
| T = 0.5 | [0.9756, 0.0179, 0.0066] | 97.56% | 0.186 |
| T = 1.0 | [0.8214, 0.1112, 0.0674] | 82.14% | 0.848 |
| T = 2.0 | [0.6044, 0.2224, 0.1732] | 60.44% | 1.360 |
| T = 5.0 | [0.4392, 0.2944, 0.2664] | 43.92% | 1.549 |
| T = 100 | [0.3384, 0.3317, 0.3300] | 33.84% | 1.585 |

Note: Entropy is measured in bits. Maximum entropy for 3 classes is \log_2(3) \approx 1.585 bits (uniform distribution).

The Connection to Entropy

Shannon entropy of a probability distribution p is H(p) = -\sum_i p_i \log p_i. Higher entropy means more “randomness” or “uncertainty.” The entropy of the softmax output is a monotonically increasing function of temperature:

  • At T \to 0: entropy approaches 0 (complete certainty)
  • At T \to \infty: entropy approaches \log n (maximum uncertainty)

This relationship is why temperature is the primary control for generation quality in language models. Sampling from a low-temperature distribution produces repetitive, predictable text. Sampling from a high-temperature distribution produces creative but potentially incoherent text. The optimal temperature depends on the task: code generation favors low temperature (correctness matters), creative writing favors moderate temperature (variety matters).

Why Temperature Works Mathematically

Temperature scaling is equivalent to exponentiating the original softmax probabilities and renormalizing:

\sigma(z/T)_i = \frac{\sigma(z)_i^{1/T}}{\sum_j \sigma(z)_j^{1/T}}

When T < 1, we raise probabilities to a power greater than 1, which makes large probabilities larger and small probabilities smaller (sharpening). When T > 1, we raise probabilities to a power less than 1, which makes all probabilities more equal (flattening).
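This equivalence is a one-liner to check numerically (toy logits; both routes give the same distribution):

```python
import numpy as np

def softmax(z):
    # Stable softmax
    e = np.exp(z - np.max(z))
    return e / e.sum()

z = np.array([3.0, 1.0, 0.5])

for T in [0.5, 1.0, 2.0]:
    direct = softmax(z / T)          # temperature-scaled logits
    p = softmax(z) ** (1.0 / T)      # exponentiate the T=1 probabilities...
    assert np.allclose(direct, p / p.sum())  # ...then renormalize

print(np.round(softmax(z / 0.5), 4))  # T < 1: sharpened
print(np.round(softmax(z / 2.0), 4))  # T > 1: flattened
```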

💡 Temperature in Practice

Most inference APIs expose temperature as a user-facing parameter. OpenAI’s API defaults to T = 1.0. Common settings: T = 0.0 for deterministic outputs (greedy decoding), T = 0.2-0.4 for factual tasks, T = 0.7-1.0 for general conversation, and T = 1.0-1.5 for creative generation. Values above 2.0 rarely produce useful output.

Learned Temperature in Attention

Some transformer architectures go beyond a fixed temperature. The original “Attention Is All You Need” paper uses a fixed scaling factor 1/\sqrt{d_k}, which is equivalent to a temperature of T = \sqrt{d_k}. But more recent work has explored:

  • Per-head learned temperature. Each attention head learns its own scalar temperature, allowing some heads to be “sharp” (attending to a single position) and others to be “broad” (attending to many positions). This was explored in various multi-scale attention designs.

  • Query-dependent temperature. The temperature is computed as a function of the query vector, allowing the sharpness of attention to vary by position and context. When the query is confident (high norm), the temperature can be lowered automatically.

These learned temperature variants add minimal parameters (one scalar per head or a small projection) but can improve performance on tasks where different attention patterns require different levels of focus.


Part 6: Softmax in Attention

Having understood softmax as an isolated function, we now examine its specific role in the attention mechanism and the engineering decisions that surround it.

The Scaled Dot-Product Attention

The standard attention formula is:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

The softmax is applied row-wise to the matrix QK^T / \sqrt{d_k}, where Q \in \mathbb{R}^{N \times d_k}, K \in \mathbb{R}^{N \times d_k}, and the resulting attention matrix is in \mathbb{R}^{N \times N}.

Why 1/\sqrt{d_k} Scaling Is Essential

Without the 1/\sqrt{d_k} factor, the dot products q_i \cdot k_j have a variance that scales linearly with d_k. If each component of q and k is drawn from a distribution with zero mean and unit variance, then:

\text{Var}(q \cdot k) = \sum_{l=1}^{d_k} \text{Var}(q_l \cdot k_l) = d_k

The standard deviation of the dot product is \sqrt{d_k}. For d_k = 128 (a common head dimension), the dot products have a standard deviation of about 11.3. This means many raw scores will have magnitudes of 20, 30, or higher.

⚠️ Softmax Saturation

When the inputs to softmax have large magnitude, the output approaches a one-hot vector — nearly all probability mass concentrates on the single largest element. In this saturated regime, gradients become vanishingly small (the Jacobian of softmax approaches the zero matrix), and the model cannot learn to adjust its attention patterns. This is not merely a numerical issue — it is a fundamental training failure mode.

Dividing by \sqrt{d_k} normalizes the variance back to approximately 1, keeping the softmax inputs in a regime where the function is sensitive and gradients flow properly. This is the “temperature” of attention, baked into the architecture:

T_{\text{attention}} = \sqrt{d_k}

For d_k = 64, T = 8. For d_k = 128, T \approx 11.3.
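The variance argument is easy to check empirically by sampling random unit-variance queries and keys (sample count is an arbitrary choice):

```python
import numpy as np

rng = np.random.default_rng(0)

for d_k in [64, 128]:
    q = rng.normal(size=(10_000, d_k))  # 10k random query vectors
    k = rng.normal(size=(10_000, d_k))  # 10k random key vectors
    scores = (q * k).sum(axis=1)        # raw dot products
    scaled = scores / np.sqrt(d_k)      # with the 1/sqrt(d_k) factor
    print(d_k, round(scores.std(), 2), round(scaled.std(), 2))
    # std(scores) is ~sqrt(d_k) (about 8 and 11.3); std(scaled) is ~1
```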

The Causal Mask

In autoregressive (decoder-only) models, token i must not attend to any token j > i — it cannot look into the future. This is enforced by adding a mask to the attention scores before softmax:

\text{score}_{ij} = \begin{cases} q_i \cdot k_j / \sqrt{d_k} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}

After softmax, e^{-\infty} = 0, so masked positions receive zero weight. There is a numerical subtlety, though: IEEE 754 can represent -\infty, but kernels often substitute a large finite negative constant to avoid infinity-related edge cases (for example, a fully masked row would produce -\infty - (-\infty) = \text{NaN} during max-subtraction). The convention is:

  • FP32: Use -1e9 or float('-inf') (IEEE 754 negative infinity, which is representable)
  • FP16: Use -65504 (the most negative finite FP16 value). True negative infinity is representable in FP16, but some kernels use the finite minimum instead.
  • BF16: Use -3.39e38 (near the most negative finite BF16 value) or BF16’s representable negative infinity.
ℹ️ The -65504 Constant

If you have ever seen the magic number -65504 in attention kernel code and wondered where it comes from: it is the most negative finite value in IEEE 754 half precision. After the \max subtraction in the log-sum-exp trick, a masked score contributes e^{-65504 - m}, an exponent so negative that the result underflows to exactly 0.0 in any floating-point format. The causal mask works correctly even though the kernel never uses a true -\infty.

Numerical Stability with Masking

The interaction between masking and the log-sum-exp trick deserves attention. When we compute m = \max_j s_{ij}, the masked positions are effectively -\infty and do not affect the maximum (assuming at least one unmasked position exists). The subsequent e^{s_{ij} - m} for masked positions evaluates to e^{-\infty - m} = 0. The system is self-consistent: the mask, the max-subtraction, and the exponential all cooperate to produce the correct zero weights.

One edge case: for the very first token in a sequence (position 0), the causal mask allows attention only to itself. The softmax of a single element is always 1.0, regardless of the logit’s value. This is numerically stable by construction.
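A reference sketch of causally masked, numerically stable softmax (toy 4×4 scores), showing the mask, max-subtraction, and exponential cooperating as described:

```python
import numpy as np

def causal_softmax(scores):
    # Row-wise stable softmax with a causal mask applied before normalization
    N = scores.shape[0]
    mask = np.triu(np.ones((N, N), dtype=bool), k=1)  # True strictly above diagonal
    s = np.where(mask, -np.inf, scores)
    m = s.max(axis=-1, keepdims=True)      # unaffected by the -inf entries
    e = np.exp(s - m)                      # masked entries -> exp(-inf) = 0
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
w = causal_softmax(rng.normal(size=(4, 4)))
print(np.round(w, 3))
# Row 0 is [1, 0, 0, 0]: the first token can only attend to itself.
```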


Part 7: Alternatives to Softmax

Despite its ubiquity, softmax is not the only way to normalize attention scores. Researchers have proposed several alternatives, motivated by the desire to reduce softmax’s computational cost or remove its sequence-length-dependent normalization.

Sigmoid Attention

Replace the row-wise softmax with element-wise sigmoid:

\alpha_{ij} = \sigma_{\text{sigmoid}}(q_i \cdot k_j / \sqrt{d_k}) = \frac{1}{1 + e^{-q_i \cdot k_j / \sqrt{d_k}}}

Each attention weight is independently mapped to (0, 1). The weights for a given query no longer sum to 1.

Advantages: Sigmoid is cheaper than softmax (no global normalization across the sequence). Each weight can be computed independently, which is more hardware-friendly. Related ideas have seen real use: SigLIP replaces the softmax in its contrastive loss (rather than in attention itself) with a sigmoid, and dedicated sigmoid-attention studies report that it can be competitive when paired with careful normalization.

Disadvantages: Without normalization, the output of attention is no longer a weighted average — it is a weighted sum. The magnitude of the output depends on how many keys match the query, which can vary widely across positions and layers. This typically requires additional normalization (such as LayerNorm) after attention.

ReLU Attention

Replace softmax with \text{ReLU}(q_i \cdot k_j):

\alpha_{ij} = \max(0, q_i \cdot k_j / \sqrt{d_k})

Negative scores become zero; positive scores pass through unchanged.

Advantages: ReLU is extremely cheap and introduces true sparsity — many attention weights are exactly zero, which could enable sparse computation.

Disadvantages: Like sigmoid, no normalization. But worse: ReLU outputs can be arbitrarily large, making the attention output scale unpredictable. Training instability is common. The gradient of ReLU is either 0 or 1, losing the smooth gradient landscape that softmax provides.

Squared ReLU Attention

A refinement: \alpha_{ij} = \text{ReLU}(q_i \cdot k_j)^2. The squaring amplifies large values (similar to how softmax’s exponential amplifies differences) and provides a smoother gradient at zero. This has shown competitive results in some research papers, but has not seen widespread adoption.

Linear Attention

The most radical departure: remove the nonlinearity entirely. Rewrite attention as:

\text{Attention}(Q, K, V) = \frac{\phi(Q) \left(\phi(K)^T V\right)}{\phi(Q) \left(\phi(K)^T \mathbf{1}\right)}

where \phi is a feature map (such as \text{elu}(x) + 1 or a random Fourier feature map). The key insight is the associativity of matrix multiplication: instead of computing the N \times N attention matrix first (cost O(N^2 d)), we compute \phi(K)^T V first (cost O(N d^2)), then multiply by \phi(Q) (cost O(N d^2)). The total cost is O(N d^2) — linear in sequence length.
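The associativity trick is easy to verify: compute attention both ways with the commonly used \text{elu}(x) + 1 feature map (shapes are arbitrary toy choices) and compare:

```python
import numpy as np

def phi(x):
    # elu(x) + 1 feature map: strictly positive everywhere
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(0)
N, d = 512, 16
Q, K, V = (rng.normal(size=(N, d)) for _ in range(3))
q, k = phi(Q), phi(K)

# Quadratic order: form the (N x N) attention matrix first — O(N^2 d)
A = q @ k.T
quadratic = (A @ V) / A.sum(axis=1, keepdims=True)

# Linear order: form the (d x d) summary first — O(N d^2)
kv = k.T @ V                        # (d, d) key-value summary
z = k.sum(axis=0)                   # (d,) normalizer summary
linear = (q @ kv) / (q @ z)[:, None]

print(np.allclose(quadratic, linear))  # True
```

The two orderings are algebraically identical; only the cost differs, and for N \gg d the linear ordering wins.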

Asymptotic compute, softmax vs linear attention (normalized by N \cdot d, at N = 4096, d = 128):

| Attention | Complexity | Relative compute |
|---|---|---|
| Softmax Attention | O(N^2 d) | 4,096 |
| FlashAttention | O(N^2 d), reduced I/O | 4,096 |
| Linear Attention | O(N d^2) | 128 |

The catch: Linear attention consistently underperforms softmax attention on quality benchmarks, especially for language modeling. The normalization property of softmax — that weights sum to 1 — appears to be critical for learning sharp, selective attention patterns. Without it, attention becomes “mushy”: every token attends to everything with similar weight, degrading the model’s ability to focus on relevant context.


Softmax vs Alternatives: Quality Comparison (indicative)

| Attention Type | Normalization | Complexity | LM Quality | Adoption |
|---|---|---|---|---|
| Softmax | Row-wise sum-to-1 | $O(N^2 d)$ | Baseline | Dominant |
| Sigmoid | Element-wise (0, 1) | $O(N^2 d)$ | -1 to -3% ppl | Niche (vision) |
| ReLU | None | $O(N^2 d)$ | -5 to -10% ppl | Research only |
| Squared ReLU | None | $O(N^2 d)$ | -2 to -5% ppl | Research only |
| Linear | Approximate | $O(N d^2)$ | -5 to -15% ppl | Research only |
Note: Perplexity degradation is approximate and varies by model size, dataset, and training duration. Softmax attention with FlashAttention remains the production standard.

Why Softmax Persists

Softmax has survived every challenger because of a combination of mathematical and practical properties:

  1. Probability distribution. Outputs are positive and sum to 1. This means attention output is a convex combination of value vectors, with a clear probabilistic interpretation.

  2. Gradient flow. The softmax Jacobian provides smooth, well-conditioned gradients across a wide range of inputs. The saturated regime (near one-hot) has small gradients, but the $1/\sqrt{d_k}$ scaling keeps inputs away from saturation.

  3. Sharp attention. The exponential function allows softmax to produce very peaked distributions when needed (attending strongly to one or two positions) while remaining smooth and differentiable. Linear and sigmoid alternatives struggle to be both selective and stable.

  4. Mature infrastructure. FlashAttention, cuDNN, and every major framework have heavily optimized softmax attention. The constant factor in any alternative must beat not just the theoretical complexity but also the engineering investment in softmax kernels.

  5. Scaling law alignment. Models trained with softmax attention have demonstrated consistent, predictable scaling behavior across orders of magnitude in compute. Alternatives have not been validated at frontier scale, and the risk of switching is high when training runs cost tens of millions of dollars.
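Properties 1 and 3 above are easy to see in a few lines. In this sketch (toy logits, not from any model), the weights always land on the probability simplex, yet simply widening the score gaps drives the distribution toward one-hot — selectivity without ever losing differentiability:

```python
import numpy as np

def softmax(z):
    # max-subtraction (log-sum-exp trick) keeps exp() in a safe range
    e = np.exp(z - z.max())
    return e / e.sum()

z = np.array([1.0, 2.0, 3.0])
p = softmax(z)
print(p.sum())            # 1.0 up to rounding: a convex combination

sharp = softmax(10 * z)   # widen the gaps -> nearly one-hot, still smooth
print(sharp.round(4))     # mass concentrates on the largest logit
```

Sigmoid attention can also saturate toward 0/1 per element, but without the shared normalizer the weights do not compete, so "attend to exactly one position" is not the natural limiting behavior.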


Part 8: Input/Output Specifications

To close, here is the precise specification of what softmax takes in and produces in transformer contexts.

As the Output Layer

  • Input: A logit vector $z \in \mathbb{R}^V$, where $V$ is the vocabulary size (commonly 32,000 to 128,256). Each element $z_i$ is an unbounded real number representing the model’s unnormalized score for token $i$.
  • Output: A probability vector $p \in \mathbb{R}^V$ on the probability simplex: $p_i > 0$ for all $i$, and $\sum_{i=1}^{V} p_i = 1$.
  • Computational cost: $O(V)$ for a single softmax. Across a batch of $B$ sequences each of length $N$, the total cost is $O(BNV)$ at the output layer, though in practice this is dominated by the preceding linear projection.
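A sketch of this specification at vocabulary scale (random logits standing in for model output; V is illustrative). Working in log space via log-sum-exp keeps everything finite even though the raw exponentials could overflow:

```python
import numpy as np

V = 128_256
rng = np.random.default_rng(0)
z = rng.standard_normal(V) * 10          # unbounded real logits

m = z.max()
lse = m + np.log(np.exp(z - m).sum())    # log-sum-exp, overflow-safe
log_p = z - lse                          # log-probabilities (useful for loss)
p = np.exp(log_p)

print(p.sum())                           # ~1.0: on the probability simplex
print(np.isfinite(lse))                  # True even for large logits
```

In training, the cross-entropy loss consumes `log_p` directly; materializing `p` is only needed at sampling time.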

As the Attention Normalizer

  • Input: An attention score matrix $S \in \mathbb{R}^{N \times N}$ (for one head, one layer), where $S_{ij} = q_i \cdot k_j / \sqrt{d_k}$. Values are real numbers, typically in $[-10, 10]$ after scaling. Masked positions are set to a large negative value (e.g., $-10^9$ or $-65504$).
  • Output: An attention weight matrix $P \in \mathbb{R}^{N \times N}$ where each row lies on the probability simplex: $P_{ij} \geq 0$ and $\sum_j P_{ij} = 1$ for every row $i$.
  • Computational cost: $O(N^2)$ per head. For $H$ heads, $L$ layers, and batch size $B$, the total softmax cost across the model is $O(BLHN^2)$.
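The same specification, end to end, for one toy head (N and $d_k$ are illustrative): build scores, apply a causal mask with the $-10^9$ value from above, then take a row-wise stable softmax. The masked entries underflow to exactly zero and every row still sums to 1:

```python
import numpy as np

N, d_k = 6, 16
rng = np.random.default_rng(0)
Q = rng.standard_normal((N, d_k))
K = rng.standard_normal((N, d_k))

S = Q @ K.T / np.sqrt(d_k)                        # score matrix, (N, N)
mask = np.triu(np.ones((N, N), dtype=bool), k=1)  # strictly-future positions
S = np.where(mask, -1e9, S)                       # mask BEFORE softmax

# Row-wise stable softmax: subtract each row's max, normalize each row
P = np.exp(S - S.max(axis=1, keepdims=True))
P /= P.sum(axis=1, keepdims=True)

print(np.allclose(P.sum(axis=1), 1.0))  # True: every row on the simplex
print(P[0, 1:])                         # masked weights are exactly 0
```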

Softmax Computation Cost by Context

| Context | Input Size | Cost per Call | Total per Forward Pass |
|---|---|---|---|
| Output logits (V = 128K) | $\mathbb{R}^{128\text{K}}$ | $O(V) = O(\text{128K})$ | $O(B \cdot N \cdot V)$ |
| Attention (N = 2K, H = 32, L = 80) | $\mathbb{R}^{2\text{K} \times 2\text{K}}$ | $O(N^2) = O(\text{4M})$ | $O(B \cdot L \cdot H \cdot N^2) \approx O(B \cdot \text{10G})$ |
| Attention (N = 128K, H = 32, L = 80) | $\mathbb{R}^{128\text{K} \times 128\text{K}}$ | $O(N^2) = O(\text{16B})$ | $O(B \cdot L \cdot H \cdot N^2) \approx O(B \cdot \text{42T})$ |

Note: For long-context models, attention softmax is by far the dominant softmax cost. FlashAttention computes it without materializing the $N \times N$ matrix.

Numerical Requirements

The log-sum-exp trick is non-negotiable for any production implementation. Online softmax is required for FlashAttention-style tiled computation. Temperature scaling must be applied before the max-subtraction, not after. Causal masking must be applied before softmax, and the mask value must be sufficiently negative to underflow to zero after exponentiation.
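The mask-value requirement deserves a concrete demonstration, because it interacts with precision. This sketch shows why fp16 kernels use $-65504$ rather than $-10^9$: the latter simply is not representable in half precision, while the finite fp16 minimum still underflows to an exact zero after exponentiation:

```python
import numpy as np

# -1e9 overflows float16's range and becomes -inf; -65504 is the
# largest-magnitude finite fp16 value and stays finite.
print(np.float16(-1e9))        # -inf
print(np.float16(-65504.0))    # -65504.0

# One toy attention row in fp16: a real score and a masked position
row = np.array([2.0, -65504.0], dtype=np.float16)
m = row.max()
w = np.exp(row - m)            # masked entry underflows to exactly 0
print(w)
```

An `-inf` mask value is dangerous because a fully masked row would compute `-inf - (-inf) = NaN` during max-subtraction; a large finite value sidesteps that failure mode.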

The Complete Softmax Pipeline in Modern Transformers

In a single forward pass of a modern LLM, softmax executes thousands of times: once per head per layer for attention ($H \times L = 32 \times 80 = 2{,}560$ times for a 70B model), plus once at the output layer. The FlashAttention kernel fuses the softmax computation (including the online max tracking, exponential computation, and normalization) into the same CUDA kernel as the matrix multiplications, so softmax never appears as a separate operation in the execution timeline. It is computed in registers and SRAM, tile by tile, invisible to HBM. That is the final engineering triumph of the ideas in this post: the function that appears in every textbook as a simple fraction is, in practice, a carefully orchestrated streaming computation that never fully materializes its inputs or outputs.
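The streaming core of that fused kernel is the online softmax recurrence. Here is a minimal NumPy sketch (not FlashAttention itself, just the scalar recurrence it tiles): process the logits chunk by chunk, tracking a running max and a running sum that is rescaled whenever the max improves, and recover the exact log-sum-exp without ever holding the full row:

```python
import numpy as np

def online_softmax_stats(chunks):
    """Streaming pass over logit chunks -> (running max, rescaled exp-sum)."""
    m = -np.inf   # running max seen so far
    s = 0.0       # running sum of exp(x - m)
    for x in chunks:
        m_new = max(m, x.max())
        # Rescale the old sum to the new max, then add this chunk's terms
        s = s * np.exp(m - m_new) + np.exp(x - m_new).sum()
        m = m_new
    return m, s   # logsumexp = m + log(s)

rng = np.random.default_rng(0)
z = rng.standard_normal(1024)
m, s = online_softmax_stats(np.split(z, 8))   # 8 tiles of 128 logits

ref = z.max() + np.log(np.exp(z - z.max()).sum())  # one-shot reference
print(np.isclose(m + np.log(s), ref))              # True: exact match
```

FlashAttention applies the same rescaling to the partial attention-times-value accumulator, which is why the final output is exact rather than an approximation.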


Summary

Softmax is far more than $e^{z_i} / \sum_j e^{z_j}$. It is a numerically treacherous function that requires the log-sum-exp trick for stability, a streaming algorithm (online softmax) for FlashAttention, a temperature parameter for controlling distribution sharpness, and careful engineering around masking, scaling, and floating-point representation. Every alternative that has been proposed — sigmoid, ReLU, linear attention — trades away one or more of softmax’s core properties (normalization, sharp selectivity, smooth gradients) and has failed to displace it at scale.

If you are building systems that run transformers, you need to understand softmax at this level. Not because you will implement it from scratch, but because the numerical choices made inside softmax — the max-subtraction, the rescaling in online softmax, the $1/\sqrt{d_k}$ temperature, the -65504 mask value — determine whether your model trains stably, generates coherently, and runs efficiently. These are the details that separate a working system from one that produces NaN.

In Part 6, we will move to the next component in the transformer block: the attention variants (MHA, MQA, GQA, and MLA) that reshape the KV cache and determine inference memory cost.