In the previous four parts of this series, we built up the transformer’s machinery piece by piece: attention as a soft database lookup (Part 1), tokenization (Part 2), embeddings (Part 3), and positional encoding (Part 4). One function appeared repeatedly, always at a critical juncture, and we took it for granted each time: softmax. It converts raw attention scores into weights. It converts final logits into a probability distribution over the vocabulary. It appears in every single forward pass of every single transformer ever deployed.
Softmax looks trivial. It is an exponential followed by a normalization. A first-year student can implement it in two lines of Python. And yet, this function is the source of some of the most insidious numerical bugs in deep learning, it is the computational bottleneck that FlashAttention had to rethink from scratch, and its behavior under different scaling regimes determines whether your language model produces coherent text or degenerate repetition.
This post gives softmax the treatment it deserves: we start with the definition and its properties, walk through the numerical instability that makes a naive implementation dangerous, derive the log-sum-exp trick that every production framework uses, then proceed to the online softmax algorithm that made FlashAttention possible. We cover temperature scaling and its connection to entropy, examine how softmax interacts with the attention mechanism’s scaling factor, and finally survey the alternatives that researchers have proposed — and why none of them have displaced softmax from its central role.
Part 1: The Softmax Function
Definition
Given a vector of real numbers $z = (z_1, \dots, z_n)$, the softmax function maps it to a probability distribution:

$$\text{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}$$

Each output is a value in $(0, 1)$, and the outputs sum to exactly 1. The inputs $z_i$ are called logits — they are unconstrained real numbers that can be positive, negative, or zero.
Why Exponential?
The exponential function is not an arbitrary choice. It has three properties that make it uniquely suited for this role.
Property 1: Positivity. For any real $x$, $e^x > 0$. This guarantees that every output of softmax is strictly positive — no probability is ever exactly zero. This matters because zero probabilities create problems in downstream computations like cross-entropy loss, where $\log(0) = -\infty$.
Property 2: Monotonicity. If $z_i > z_j$, then $e^{z_i} > e^{z_j}$, and therefore $\text{softmax}(z)_i > \text{softmax}(z)_j$. Softmax preserves the ordering of its inputs. The largest logit always gets the highest probability. This is essential: we want the function to respect the model’s “preferences.”
Property 3: Amplification. The exponential function grows super-linearly. A small difference between two logits becomes a large difference in the corresponding probabilities. If $z_1 = 3$ and $z_2 = 1$, the ratio of their probabilities is $e^{z_1}/e^{z_2} = e^{2} \approx 7.4$. The exponential amplifies the gap from 2 to 7.4x. This is what makes softmax a “soft” version of argmax: it concentrates most of the probability mass on the largest logit, but still assigns nonzero probability to all others.
The “Soft Argmax” Interpretation
The argmax function returns a one-hot vector — all zeros except for a 1 at the position of the largest element. It is not differentiable (you cannot compute gradients through it), which makes it useless for training neural networks.
Softmax is the differentiable relaxation of argmax. As the differences between logits grow, softmax output approaches a one-hot vector. As differences shrink toward zero, softmax approaches a uniform distribution. It provides a smooth, differentiable interpolation between these extremes.
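This interpolation is easy to see numerically. The sketch below uses a small stable softmax helper (the max-subtraction it performs is the subject of Part 3) and scales the same logit vector up and down:

```python
import numpy as np

def softmax(z):
    z = np.asarray(z, dtype=np.float64)
    e = np.exp(z - z.max())          # max-subtraction for stability
    return e / e.sum()

z = np.array([2.0, 1.0, 0.5])
print(softmax(0.1 * z))  # shrunken gaps -> near uniform
print(softmax(z))        # moderate gaps -> soft preference
print(softmax(10 * z))   # amplified gaps -> near one-hot
```

Multiplying the logits by a constant is exactly the inverse-temperature scaling covered later in this post: large multipliers push softmax toward argmax, small ones toward the uniform distribution.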
The softmax function is equivalent to the Gibbs (Boltzmann) distribution from statistical mechanics: $p_i = \frac{e^{-E_i/T}}{Z}$, where $Z = \sum_j e^{-E_j/T}$ is the partition function. The logits play the role of negative energies: lower energy states (higher logits) are more probable. This connection is not merely analogical — it becomes precise when we introduce temperature scaling later in this post.
What Softmax Does in a Transformer
Softmax appears in two critical places:
- Attention weights. After computing the raw scores $QK^T/\sqrt{d_k}$, softmax converts them to weights that sum to 1 across each row. This ensures that the output of attention is a convex combination of value vectors — a weighted average where the weights are non-negative and normalized.
- Output logits. The final linear layer of a transformer produces logits over the vocabulary (a vector of size $|V|$, typically 32,000 to 128,000). Softmax converts these to probabilities for the next token.
Both uses exploit the same core property: softmax transforms arbitrary real-valued scores into a proper probability distribution.
Part 2: The Numerical Instability Problem
The definition of softmax involves computing . This is where things go wrong in practice.
The Overflow Problem
Consider what happens when $z_i$ is large. In IEEE 754 float32, the largest representable value is approximately $3.4 \times 10^{38}$. The exponential $e^z$ overflows to infinity when $z > \ln(3.4 \times 10^{38}) \approx 88.7$. Any logit above 88.7 produces inf in float32.
In float16, the situation is far worse. The maximum representable value is 65,504, and $e^z$ overflows when $z > \ln(65504) \approx 11.1$. Logits above 11.1 produce inf in float16.
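These thresholds are easy to verify directly in NumPy, a quick sketch:

```python
import numpy as np

# float32: exp overflows just above ln(3.4e38) ≈ 88.7
print(np.exp(np.float32(88.0)))   # ~1.65e38, still finite
print(np.exp(np.float32(89.0)))   # inf

# float16: exp overflows just above ln(65504) ≈ 11.1
print(np.exp(np.float16(11.0)))   # ~59,870, still finite
print(np.exp(np.float16(11.5)))   # inf
```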
These are not extreme or pathological values. During training, logits routinely reach magnitudes of 20, 50, or higher. In attention, dot products between query and key vectors can easily exceed 100, especially for large head dimensions.
A Concrete Example
Consider a simple three-element softmax with logits in float32:
```python
import numpy as np

z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
exp_z = np.exp(z)
print(exp_z)
# [inf, inf, inf]

softmax = exp_z / exp_z.sum()
print(softmax)
# [nan, nan, nan]
```
Every exponential overflows to inf. The sum of infinities is inf. And inf / inf is nan (not a number). The computation is completely destroyed — not just inaccurate, but producing values that propagate corruption through every subsequent calculation.
A single NaN in the softmax output infects every downstream computation. In attention, it corrupts the weighted sum of values. In the output layer, it corrupts the loss. During backpropagation, NaN gradients corrupt every parameter update. One overflow in one softmax call can silently destroy an entire training run.
The Underflow Problem
Overflow is the more dramatic failure, but underflow also causes issues. When $z_i$ is very negative (below roughly $-103$ in float32, the point where even subnormals run out), $e^{z_i}$ underflows to 0. If all logits are very negative, every exponential underflows and the denominator becomes 0, giving us $0/0$ = NaN.
More subtly, even when the denominator does not become zero, underflow can cause a loss of relative precision. If a probability should be tiny but positive and is instead rounded to 0, the subsequent $\log(0) = -\infty$ in cross-entropy loss produces another catastrophic failure.
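A minimal sketch of this failure mode, contrasted with a log-softmax computed via the max-shift identity derived in the next part:

```python
import numpy as np

z = np.array([0.0, -120.0], dtype=np.float32)
p = np.exp(z) / np.exp(z).sum()      # naive softmax: p[1] underflows to 0
print(np.log(p[1]))                  # -inf: cross-entropy on this target explodes

# Log-softmax via log-sum-exp stays finite:
m = z.max()
log_p = z - (m + np.log(np.exp(z - m).sum()))
print(log_p[1])                      # -120.0, the correct log-probability
```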
How Common Is This in Practice?
Logit Magnitudes in Real Models (typical ranges)
| Scenario | Typical Max Logit | FP32 Overflow? | FP16 Overflow? |
|---|---|---|---|
| Attention scores (d=64, no scaling) | ~50-80 | No | Yes |
| Attention scores (d=128, no scaling) | ~80-120 | Yes | Yes |
| Vocab logits (early training) | ~5-15 | No | Maybe |
| Vocab logits (late training) | ~20-50 | No | Yes |
| Adversarial/pathological inputs | ~100-1000+ | Yes | Yes |
The conclusion is inescapable: naive softmax is broken for any real-world deployment. We need a mathematically equivalent formulation that avoids computing large exponentials.
Part 3: The Log-Sum-Exp Trick
The solution is elegant: subtract the maximum value from all logits before exponentiating.
The Key Identity
For any constant $c$ and any vector $z$:

$$\text{softmax}(z - c) = \text{softmax}(z)$$

That is, softmax is invariant to additive shifts. Adding the same constant to every logit does not change the output probabilities.

Proof. Let $c$ be any real constant. Then:

$$\text{softmax}(z - c)_i = \frac{e^{z_i - c}}{\sum_j e^{z_j - c}} = \frac{e^{-c}\, e^{z_i}}{e^{-c} \sum_j e^{z_j}} = \frac{e^{z_i}}{\sum_j e^{z_j}} = \text{softmax}(z)_i$$

The factor $e^{-c}$ cancels in the numerator and denominator. This holds for any $c$, but the optimal choice for numerical stability is $c = \max_i z_i$.
Why $c = \max_i z_i$ Works
When we set $m = \max_i z_i$, every shifted logit satisfies $z_i - m \le 0$. Therefore, every exponential satisfies $0 < e^{z_i - m} \le 1$. No exponential can overflow, because the largest exponent is exactly 0 (corresponding to $z_i = m$).
Underflow is still possible for very negative shifted logits, but this is benign: if $e^{z_i - m}$ underflows to 0, then $\text{softmax}(z)_i \approx 0$, which is the correct answer — that element’s probability is genuinely negligible.
The Two-Pass Algorithm
The stable softmax requires two passes over the data:
Pass 1: Find the maximum.

$$m = \max_i z_i$$

Pass 2: Compute shifted exponentials and normalize.

$$\text{softmax}(z)_i = \frac{e^{z_i - m}}{\sum_j e^{z_j - m}}$$
In Python:
```python
import numpy as np

def stable_softmax(z):
    m = np.max(z)                 # Pass 1: find max
    exp_z = np.exp(z - m)         # Pass 2a: shifted exponentials
    return exp_z / exp_z.sum()    # Pass 2b: normalize

z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)
print(stable_softmax(z))
# [0.09003057, 0.24472848, 0.66524094]
```
The same logits that produced NaN with naive softmax now give correct, meaningful probabilities.
The Log-Sum-Exp Connection
The name “log-sum-exp trick” comes from the closely related problem of computing $\text{LSE}(z) = \log \sum_j e^{z_j}$ — the log of the partition function. This quantity appears in cross-entropy loss and log-likelihood computations. Naively computing it overflows for the same reason softmax does. The stable version is:

$$\text{LSE}(z) = m + \log \sum_j e^{z_j - m}$$

where $m = \max_j z_j$. This is mathematically equivalent but numerically stable, because the exponentials are all in the range $(0, 1]$.

The connection to softmax is direct: $\log \text{softmax}(z)_i = z_i - \text{LSE}(z)$. In practice, many frameworks compute log-softmax (which is needed for cross-entropy loss) using the LSE trick directly, avoiding the need to compute softmax and then take the log.
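The two identities above fit in a few lines. A minimal sketch (the helper names are ours):

```python
import numpy as np

def logsumexp(z):
    """Stable log(sum(exp(z))) via the max-shift identity."""
    m = np.max(z)
    return m + np.log(np.sum(np.exp(z - m)))

def log_softmax(z):
    """log softmax(z)_i = z_i - LSE(z), without ever forming softmax."""
    return z - logsumexp(z)

z = np.array([1000.0, 1001.0, 1002.0])
print(logsumexp(z))      # ~1002.4076; naive np.log(np.exp(z).sum()) gives inf
print(log_softmax(z))    # finite log-probabilities; exponentiating sums to 1
```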
PyTorch’s torch.nn.functional.log_softmax and torch.nn.functional.cross_entropy both use the log-sum-exp trick internally. If you compute softmax followed by log manually, you lose both numerical stability (underflow in softmax produces $\log(0) = -\infty$) and performance (two kernel launches instead of one). Always use the fused versions.
Performance Cost of Two Passes
The two-pass algorithm reads the input data twice: once to find the maximum, once to compute the exponentials. For a vector of $N$ elements, this is $2N$ reads from memory. In the attention context, where softmax operates on each row of the $N \times N$ attention matrix, this means $2N^2$ reads across all rows. On memory-bandwidth-limited hardware (which GPUs are for this operation), that factor of 2 matters.
This cost is what motivated the search for a single-pass algorithm.
Part 4: Online Softmax (Milakov-Gimelshein)
The two-pass softmax algorithm requires seeing all elements before computing any output: you need the global maximum before you can compute any exponential safely. This is a problem for FlashAttention, which processes the key-value pairs in tiles — it sees one block of keys at a time and cannot afford to materialize the full attention matrix.
The online softmax algorithm, described by Milakov and Gimelshein (2018), solves this problem. It computes the exact same result as the two-pass algorithm, but in a single streaming pass that can process data incrementally.
The Core Insight
The key observation is that if you have a partial softmax computed over elements , and new elements arrive, you can update the running result without restarting from scratch. The only extra work is a rescaling of the partial sum to account for the new maximum.
The Algorithm
Maintain two running variables:

- $m_i$ — the maximum of all logits seen so far (through element $i$)
- $\ell_i$ — the running sum of exponentials $\sum_{j \le i} e^{z_j - m_i}$, computed using the current maximum

Initialization:

$$m_0 = -\infty, \qquad \ell_0 = 0$$

Update for each new element $z_i$:

$$m_i = \max(m_{i-1}, z_i), \qquad \ell_i = \ell_{i-1}\, e^{m_{i-1} - m_i} + e^{z_i - m_i}$$

The first term rescales the old sum: all previous exponentials were computed with respect to $m_{i-1}$, but we now need them with respect to $m_i$. Multiplying by $e^{m_{i-1} - m_i}$ accomplishes this exactly. The second term adds the contribution of the new element.
After processing all $n$ elements, $m_n = \max_i z_i$ and $\ell_n = \sum_j e^{z_j - m_n}$, so the online algorithm produces:

$$\text{softmax}(z)_i = \frac{e^{z_i - m_n}}{\ell_n}$$

which is identical to the result of the two-pass stable softmax.
Proof by induction. The base case is trivial: $m_0 = -\infty$, $\ell_0 = 0$ (an empty maximum and an empty sum). For the inductive step, assume $m_{i-1} = \max_{j \le i-1} z_j$ and $\ell_{i-1} = \sum_{j \le i-1} e^{z_j - m_{i-1}}$. Then:

$$m_i = \max(m_{i-1}, z_i) = \max_{j \le i} z_j$$

And:

$$\ell_i = e^{m_{i-1} - m_i} \sum_{j \le i-1} e^{z_j - m_{i-1}} + e^{z_i - m_i} = \sum_{j \le i-1} e^{z_j - m_i} + e^{z_i - m_i} = \sum_{j \le i} e^{z_j - m_i}$$

The rescaling factor $e^{m_{i-1} - m_i}$ corrects each exponential in the partial sum from the old maximum to the new one. This is exact arithmetic — no approximation is involved.
Block-Wise Extension for FlashAttention
In FlashAttention, we do not process one element at a time. We process tiles of key-value pairs. The algorithm extends naturally to blocks. Suppose we have processed blocks $1, \dots, j-1$ and maintained $(m, \ell, O)$, where $O$ is the running output (the unnormalized weighted sum of value vectors). When block $j$ with keys $K_j$ and values $V_j$ arrives:

- Compute local attention scores: $S_j = Q K_j^T / \sqrt{d_k}$
- Compute local maximum: $\tilde{m}_j = \text{rowmax}(S_j)$
- Update global maximum: $m' = \max(m, \tilde{m}_j)$
- Rescale old sum and add new contribution: $\ell' = e^{m - m'}\, \ell + \text{rowsum}\big(e^{S_j - m'}\big)$
- Rescale old output and add new contribution: $O' = e^{m - m'}\, O + e^{S_j - m'}\, V_j$
The entire reason FlashAttention can process attention in tiles without materializing the matrix is that online softmax allows incremental computation. Each tile of K and V is loaded from HBM to SRAM, the running softmax statistics are updated in registers, and the tile is discarded. The full attention matrix never exists in memory. Without online softmax, tiled attention would require two full passes over the KV data — defeating the purpose of tiling.
Implementation in Python
Here is the online softmax algorithm for a single vector, demonstrating correctness:
```python
import numpy as np

def online_softmax(z):
    n = len(z)
    m = -np.inf   # running max
    l = 0.0       # running sum of shifted exponentials
    # Single pass: compute m and l incrementally
    for i in range(n):
        m_new = max(m, z[i])
        l = np.exp(m - m_new) * l + np.exp(z[i] - m_new)
        m = m_new
    # Compute final softmax using accumulated statistics
    return np.exp(z - m) / l

z = np.array([1000.0, 1001.0, 1002.0], dtype=np.float64)
print(online_softmax(z))
# [0.09003057, 0.24472847, 0.66524096]
```
This produces the same result as the two-pass algorithm but processes the input in a single sweep.
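The block-wise extension from the previous section can be sketched the same way. This is an illustrative NumPy model of the FlashAttention recurrence, not a real kernel: it tiles only over K/V, omits causal masking, and the function name `tiled_attention` is ours:

```python
import numpy as np

def tiled_attention(Q, K, V, block=4):
    """Single-head attention computed tile-by-tile over K/V, carrying the
    online softmax statistics (m: running row max, l: running row sum,
    O: running unnormalized output)."""
    N, d = Q.shape
    m = np.full(N, -np.inf)          # running max per query row
    l = np.zeros(N)                  # running sum of shifted exponentials
    O = np.zeros((N, d))             # running (unnormalized) output
    for start in range(0, N, block):
        Kj = K[start:start + block]
        Vj = V[start:start + block]
        S = Q @ Kj.T / np.sqrt(d)             # local scores for this tile
        m_new = np.maximum(m, S.max(axis=1))  # updated row maxima
        scale = np.exp(m - m_new)             # rescale old statistics
        P = np.exp(S - m_new[:, None])        # local shifted exponentials
        l = scale * l + P.sum(axis=1)
        O = scale[:, None] * O + P @ Vj
        m = m_new
    return O / l[:, None]                     # final normalization

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))

# Reference: full-matrix attention with two-pass stable softmax
S_ref = Q @ K.T / np.sqrt(4)
P_ref = np.exp(S_ref - S_ref.max(axis=1, keepdims=True))
ref = (P_ref / P_ref.sum(axis=1, keepdims=True)) @ V
print(np.allclose(tiled_attention(Q, K, V), ref))  # True
```

The tile size does not affect the result, only how the work is chunked, which is exactly the property FlashAttention exploits.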
Performance Comparison
Memory Reads for Softmax (per row of N elements)

| Algorithm | Passes over data | Memory reads | Numerically stable? |
|---|---|---|---|
| Naive | 1 | N | No |
| Two-pass (max-subtraction) | 2 | 2N | Yes |
| Online | 1 | N | Yes |

The online algorithm matches the naive version’s single-pass behavior while achieving the two-pass version’s numerical stability. For FlashAttention, where each “pass” over KV data requires a round-trip from HBM to SRAM, halving the number of passes translates directly to halving the memory traffic for the softmax portion of the computation.
Part 5: Temperature Scaling
Temperature is a single scalar parameter that controls how “sharp” or “flat” the softmax distribution is. It provides a knob to tune the trade-off between confident (low-entropy) and exploratory (high-entropy) outputs.
Definition
The temperature-scaled softmax is:

$$\text{softmax}_T(z)_i = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}}$$

where $T > 0$ is the temperature parameter. The logits are divided by $T$ before the softmax is applied.
The Three Regimes
$T < 1$ (Low temperature): approaches argmax. As temperature decreases toward zero, the softmax concentrates all probability mass on the largest logit. In the limit $T \to 0$, the output becomes a one-hot vector at the argmax. The distribution becomes deterministic: only the most probable element survives.

To see why, consider two logits $z_1 > z_2$. The ratio of their probabilities is $e^{(z_1 - z_2)/T}$. As $T \to 0$, this ratio grows without bound — the larger logit dominates exponentially.

$T = 1$ (Unit temperature): the model’s native distribution. The softmax operates on the raw logits as learned during training. This is the distribution the model was optimized to produce.

$T > 1$ (High temperature): approaches uniform. As temperature increases, division by $T$ shrinks all logits toward zero, erasing the differences between them. In the limit $T \to \infty$, every probability approaches $1/n$. Every element is equally likely.
Effect of Temperature on Softmax Output
| Temperature | softmax([3, 1, 0.5]) | Max Prob | Entropy (bits) |
|---|---|---|---|
| T = 0.1 | [1.0000, 0.0000, 0.0000] | ~100% | ~0.000 |
| T = 0.5 | [0.9756, 0.0179, 0.0066] | 97.56% | 0.186 |
| T = 1.0 | [0.8214, 0.1112, 0.0674] | 82.14% | 0.848 |
| T = 2.0 | [0.6045, 0.2224, 0.1732] | 60.45% | 1.360 |
| T = 5.0 | [0.4392, 0.2944, 0.2664] | 43.92% | 1.549 |
| T = 100 | [0.3384, 0.3317, 0.3300] | 33.84% | 1.585 |
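A table like this is easy to recompute yourself. A minimal sketch (helper names are ours):

```python
import numpy as np

def softmax_T(z, T):
    z = np.asarray(z, dtype=np.float64) / T   # divide logits by temperature
    e = np.exp(z - z.max())                   # stable softmax
    return e / e.sum()

def entropy_bits(p):
    p = p[p > 0]                              # 0 log 0 = 0 by convention
    return float(-(p * np.log2(p)).sum())

z = [3.0, 1.0, 0.5]
for T in [0.1, 0.5, 1.0, 2.0, 5.0, 100.0]:
    p = softmax_T(z, T)
    print(f"T={T:>5}: {np.round(p, 4)}  H={entropy_bits(p):.3f} bits")
```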
The Connection to Entropy
Shannon entropy of a probability distribution $p$ is $H(p) = -\sum_i p_i \log_2 p_i$. Higher entropy means more “randomness” or “uncertainty.” The entropy of the softmax output is a monotonically increasing function of temperature:

- As $T \to 0$: entropy approaches 0 (complete certainty)
- As $T \to \infty$: entropy approaches $\log_2 n$ (maximum uncertainty)
This relationship is why temperature is the primary control for generation quality in language models. Sampling from a low-temperature distribution produces repetitive, predictable text. Sampling from a high-temperature distribution produces creative but potentially incoherent text. The optimal temperature depends on the task: code generation favors low temperature (correctness matters), creative writing favors moderate temperature (variety matters).
Why Temperature Works Mathematically
Temperature scaling is equivalent to exponentiating the original softmax probabilities and renormalizing:

$$\text{softmax}_T(z)_i = \frac{p_i^{1/T}}{\sum_j p_j^{1/T}}, \qquad \text{where } p = \text{softmax}(z)$$

When $T < 1$, we raise probabilities to a power greater than 1, which makes large probabilities larger and small probabilities smaller (sharpening). When $T > 1$, we raise probabilities to a power less than 1, which makes all probabilities more equal (flattening).
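This power identity can be checked numerically in a few lines, a quick sketch:

```python
import numpy as np

z = np.random.default_rng(1).standard_normal(5)
T = 0.7
p = np.exp(z - z.max())
p /= p.sum()                                  # softmax at T = 1

# Route 1: renormalized p ** (1/T)
via_power = p ** (1 / T) / (p ** (1 / T)).sum()

# Route 2: softmax of z / T
via_logits = np.exp(z / T - (z / T).max())
via_logits /= via_logits.sum()

print(np.allclose(via_power, via_logits))     # True
```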
Most inference APIs expose temperature as a user-facing parameter. OpenAI’s API defaults to $T = 1.0$. Common values: $T = 0$ for deterministic outputs (greedy decoding), roughly $0.1$–$0.3$ for factual tasks, $0.7$–$1.0$ for general conversation, and $1.0$–$1.5$ for creative generation. Values above 2.0 rarely produce useful output.
Learned Temperature in Attention
Some transformer architectures go beyond a fixed temperature. The original “Attention Is All You Need” paper uses a fixed scaling factor $1/\sqrt{d_k}$, which is equivalent to a temperature of $T = \sqrt{d_k}$. But more recent work has explored:
- Per-head learned temperature. Each attention head learns its own scalar temperature, allowing some heads to be “sharp” (attending to a single position) and others to be “broad” (attending to many positions). This was explored in various multi-scale attention designs.
- Query-dependent temperature. The temperature is computed as a function of the query vector, allowing the sharpness of attention to vary by position and context. When the query is confident (high norm), the temperature can be lowered automatically.
These learned temperature variants add minimal parameters (one scalar per head or a small projection) but can improve performance on tasks where different attention patterns require different levels of focus.
Part 6: Softmax in Attention
Having understood softmax as an isolated function, we now examine its specific role in the attention mechanism and the engineering decisions that surround it.
The Scaled Dot-Product Attention
The standard attention formula is:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

The softmax is applied row-wise to the $N \times N$ score matrix $QK^T/\sqrt{d_k}$, where $Q, K \in \mathbb{R}^{N \times d_k}$, and the resulting attention matrix is $N \times N$.
Why Scaling Is Essential
Without the $1/\sqrt{d_k}$ factor, the dot products have a variance that scales linearly with $d_k$. If each component of $q$ and $k$ is drawn from a distribution with zero mean and unit variance, then:

$$\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i k_i) = d_k$$

The standard deviation of the dot product is $\sqrt{d_k}$. For $d_k = 128$ (a common head dimension), the dot products have a standard deviation of about 11.3. This means many raw scores will have magnitudes of 20, 30, or higher.
When the inputs to softmax have large magnitude, the output approaches a one-hot vector — nearly all probability mass concentrates on the single largest element. In this saturated regime, gradients become vanishingly small (the Jacobian of softmax approaches the zero matrix), and the model cannot learn to adjust its attention patterns. This is not merely a numerical issue — it is a fundamental training failure mode.
Dividing by $\sqrt{d_k}$ normalizes the variance back to approximately 1, keeping the softmax inputs in a regime where the function is sensitive and gradients flow properly. This is the “temperature” of attention, baked into the architecture:

$$\text{softmax}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \text{softmax}_{T = \sqrt{d_k}}(q \cdot k)$$

For $d_k = 64$, $\sqrt{d_k} = 8$. For $d_k = 128$, $\sqrt{d_k} \approx 11.3$.
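The variance argument is easy to confirm empirically. A small sketch that draws random query/key pairs and measures the spread of their dot products before and after scaling:

```python
import numpy as np

rng = np.random.default_rng(0)
for d in [64, 128]:
    q = rng.standard_normal((50_000, d))
    k = rng.standard_normal((50_000, d))
    scores = (q * k).sum(axis=1)          # 50k random q·k dot products
    print(d, scores.std(), (scores / np.sqrt(d)).std())
    # unscaled std ≈ sqrt(d); after dividing by sqrt(d), std ≈ 1
```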
The Causal Mask
In autoregressive (decoder-only) models, token $i$ must not attend to any token $j > i$ — it cannot look into the future. This is enforced by adding a mask $M$ to the attention scores before softmax:

$$S = \frac{QK^T}{\sqrt{d_k}} + M, \qquad M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}$$

After softmax, $e^{-\infty} = 0$, so masked positions receive zero weight. But there is a numerical subtlety: although IEEE 754 does define a negative infinity, pushing a literal $-\infty$ through arithmetic invites NaN (for example, $-\infty - (-\infty)$ during the max-subtraction of a fully masked row), so kernels commonly substitute a large finite negative value. The convention is:
- FP32: Use `-1e9` or `float('-inf')` (IEEE 754 negative infinity, which is representable)
- FP16: Use `-65504` (the most negative finite FP16 value). True negative infinity is representable in FP16, but some kernels use the finite minimum instead.
- BF16: Use `-3.39e38` or BF16’s representable negative infinity.
If you have ever seen the magic number -65504 in attention kernel code and wondered where it comes from: it is the most negative finite value in IEEE 754 half-precision floating point. After the max-subtraction in the log-sum-exp trick, a masked score becomes something like $-65504 - m$ (with $m$ the row maximum), which is so astronomically negative that the exponential underflows to exactly 0.0 in any floating-point format. The causal mask works correctly even without using a literal $-\infty$.
Numerical Stability with Masking
The interaction between masking and the log-sum-exp trick deserves attention. When we compute the row maximum $m = \max_j S_{ij}$, the masked positions are effectively $-\infty$ and do not affect the maximum (assuming at least one unmasked position exists). The subsequent $e^{S_{ij} - m}$ for masked positions evaluates to 0. The system is self-consistent: the mask, the max-subtraction, and the exponential all cooperate to produce the correct zero weights.
One edge case: for the very first token in a sequence (position 0), the causal mask allows attention only to itself. The softmax of a single element is always 1.0, regardless of the logit’s value. This is numerically stable by construction.
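The cooperation between mask, max-subtraction, and exponential fits in one small sketch (the helper name is ours; `-1e9` plays the role of $-\infty$ as in the FP32 convention above):

```python
import numpy as np

def masked_softmax_row(scores):
    """Row softmax where masked positions were pre-filled with a large
    negative value; the max-subtraction drives them to exactly 0."""
    m = scores.max()
    e = np.exp(scores - m)        # masked entries: exp(-1e9 - m) underflows to 0
    return e / e.sum()

# Causal row for query position 2 in a length-5 sequence:
row = np.array([1.2, -0.3, 0.8, -1e9, -1e9])   # future positions masked
w = masked_softmax_row(row)
print(w)             # last two weights are exactly 0.0
print(w.sum())       # 1.0
```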
Part 7: Alternatives to Softmax
Despite its ubiquity, softmax is not the only way to normalize attention scores. Researchers have proposed several alternatives, motivated by the desire to reduce softmax’s computational cost or remove its sequence-length-dependent normalization.
Sigmoid Attention
Replace the row-wise softmax with an element-wise sigmoid:

$$\text{Attention}(Q, K, V) = \sigma\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Each attention weight is independently mapped to $(0, 1)$. The weights for a given query no longer sum to 1.
Advantages: Sigmoid is cheaper than softmax (no global normalization across the sequence). Each weight can be computed independently, which is more hardware-friendly. Some vision transformer variants (notably SigLIP) have used sigmoid attention successfully.
Disadvantages: Without normalization, the output of attention is no longer a weighted average — it is a weighted sum. The magnitude of the output depends on how many keys match the query, which can vary widely across positions and layers. This typically requires additional normalization (such as LayerNorm) after attention.
ReLU Attention
Replace softmax with $\text{ReLU}(x) = \max(0, x)$:

$$\text{Attention}(Q, K, V) = \text{ReLU}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Negative scores become zero; positive scores pass through unchanged.
Advantages: ReLU is extremely cheap and introduces true sparsity — many attention weights are exactly zero, which could enable sparse computation.
Disadvantages: Like sigmoid, no normalization. But worse: ReLU outputs can be arbitrarily large, making the attention output scale unpredictable. Training instability is common. The gradient of ReLU is either 0 or 1, losing the smooth gradient landscape that softmax provides.
Squared ReLU Attention
A refinement: replace softmax with $\text{ReLU}(x)^2$. The squaring amplifies large values (similar to how softmax’s exponential amplifies differences) and provides a smoother gradient at zero. This has shown competitive results in some research papers, but has not seen widespread adoption.
Linear Attention
The most radical departure: remove the nonlinearity entirely. Rewrite attention as:

$$\text{Attention}(Q, K, V) \approx \phi(Q)\left(\phi(K)^T V\right)$$

where $\phi$ is a feature map (such as $\text{elu}(x) + 1$ or a random Fourier feature). The key insight is the associativity of matrix multiplication: instead of computing the attention matrix $\phi(Q)\phi(K)^T$ first (cost $O(N^2 d)$), we compute $\phi(K)^T V$ first (cost $O(N d^2)$), then multiply by $\phi(Q)$ (cost $O(N d^2)$). The total cost is $O(N d^2)$ — linear in sequence length $N$.
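The associativity argument can be demonstrated directly. A sketch of unnormalized linear attention (no denominator term), using $\text{elu}(x) + 1$ as the feature map:

```python
import numpy as np

def phi(x):
    """elu(x) + 1: a positive feature map used in some linear-attention work."""
    return np.where(x > 0, x + 1.0, np.exp(x))

rng = np.random.default_rng(2)
N, d = 1000, 16
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))

# Quadratic order: (phi(Q) phi(K)^T) V  -- materializes an N x N matrix
out_quadratic = (phi(Q) @ phi(K).T) @ V

# Linear order:    phi(Q) (phi(K)^T V) -- only a d x d intermediate
out_linear = phi(Q) @ (phi(K).T @ V)

print(np.allclose(out_quadratic, out_linear))   # True: same result, O(N d^2)
```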
Asymptotic Complexity: Softmax vs Linear Attention

For $N = 4096$ and $d = 128$, softmax attention costs on the order of $N^2 d \approx 2.1 \times 10^9$ multiply-adds per head, while linear attention costs $N d^2 \approx 6.7 \times 10^7$ — a factor of $N/d = 32$ less compute, and the gap widens linearly with sequence length.

The catch: Linear attention consistently underperforms softmax attention on quality benchmarks, especially for language modeling. The normalization property of softmax — that weights sum to 1 — appears to be critical for learning sharp, selective attention patterns. Without it, attention becomes “mushy”: every token attends to everything with similar weight, degrading the model’s ability to focus on relevant context.
Softmax vs Alternatives: Quality Comparison (indicative)
| Attention Type | Normalization | Complexity | LM Quality | Adoption |
|---|---|---|---|---|
| Softmax | Row-wise sum-to-1 | O(N^2 d) | Baseline | Dominant |
| Sigmoid | Element-wise (0,1) | O(N^2 d) | -1 to -3% ppl | Niche (vision) |
| ReLU | None | O(N^2 d) | -5 to -10% ppl | Research only |
| Squared ReLU | None | O(N^2 d) | -2 to -5% ppl | Research only |
| Linear | Approximate | O(N d^2) | -5 to -15% ppl | Research only |
Why Softmax Persists
Softmax has survived every challenger because of a combination of mathematical and practical properties:
- Probability distribution. Outputs are positive and sum to 1. This means attention output is a convex combination of value vectors, with a clear probabilistic interpretation.
- Gradient flow. The softmax Jacobian provides smooth, well-conditioned gradients across a wide range of inputs. The saturated regime (near one-hot) has small gradients, but the scaling keeps inputs away from saturation.
- Sharp attention. The exponential function allows softmax to produce very peaked distributions when needed (attending strongly to one or two positions) while remaining smooth and differentiable. Linear and sigmoid alternatives struggle to be both selective and stable.
- Mature infrastructure. FlashAttention, cuDNN, and every major framework have heavily optimized softmax attention. The constant factor in any alternative must beat not just the theoretical complexity but also the engineering investment in softmax kernels.
- Scaling law alignment. Models trained with softmax attention have demonstrated consistent, predictable scaling behavior across orders of magnitude in compute. Alternatives have not been validated at frontier scale, and the risk of switching is high when training runs cost tens of millions of dollars.
Part 8: Input/Output Specifications
To close, here is the precise specification of what softmax takes in and produces in transformer contexts.
As the Output Layer
- Input: A logit vector $z \in \mathbb{R}^{|V|}$, where $|V|$ is the vocabulary size (commonly 32,000 to 128,256). Each element $z_i$ is an unbounded real number representing the model’s unnormalized score for token $i$.
- Output: A probability vector $p$ on the probability simplex: $p_i > 0$ for all $i$, and $\sum_i p_i = 1$.
- Computational cost: $O(|V|)$ for a single softmax. Across a batch of $B$ sequences each of length $N$, the total cost is $O(B N |V|)$ at the output layer, though in practice this is dominated by the preceding linear projection.
As the Attention Normalizer
- Input: An attention score matrix $S \in \mathbb{R}^{N \times N}$ (for one head, one layer), where $N$ is the sequence length. Values are real numbers of modest magnitude after scaling. Masked positions are set to a large negative value (e.g., $-10^9$ or $-\infty$).
- Output: An attention weight matrix $A \in \mathbb{R}^{N \times N}$ where each row lies on the probability simplex: $A_{ij} > 0$ and $\sum_j A_{ij} = 1$ for every row $i$.
- Computational cost: $O(N^2)$ per head. For $H$ heads, $L$ layers, and batch size $B$, the total softmax cost across the model is $O(B L H N^2)$.
Softmax Computation Cost by Context
| Context | Input Size | Cost per Call | Total per Forward Pass |
|---|---|---|---|
| Output logits (V=128K) | R^128000 | O(V) = O(128K) | O(B * N * V) |
| Attention (N=2K, H=32, L=80) | R^(2K x 2K) | O(N^2) = O(4M) | O(B * L * H * N^2) ≈ O(B * 10^10) |
| Attention (N=128K, H=32, L=80) | R^(128K x 128K) | O(N^2) = O(16B) | O(B * L * H * N^2) |
Numerical Requirements
The log-sum-exp trick is non-negotiable for any production implementation. Online softmax is required for FlashAttention-style tiled computation. Temperature scaling must be applied before the max-subtraction, not after. Causal masking must be applied before softmax, and the mask value must be sufficiently negative to underflow to zero after exponentiation.
In a single forward pass of a modern LLM, softmax executes thousands of times: once per head per layer for attention ($L \times H$ times — e.g., $80 \times 64 = 5{,}120$ for a Llama-70B-class model), plus once at the output layer. The FlashAttention kernel fuses the softmax computation (including the online max tracking, exponential computation, and normalization) into the same CUDA kernel as the matrix multiplications, so softmax never appears as a separate operation in the execution timeline. It is computed in registers and SRAM, tile by tile, invisible to HBM. That is the final engineering triumph of the ideas in this post: the function that appears in every textbook as a simple fraction is, in practice, a carefully orchestrated streaming computation that never fully materializes its inputs or outputs.
Summary
Softmax is far more than . It is a numerically treacherous function that requires the log-sum-exp trick for stability, a streaming algorithm (online softmax) for FlashAttention, a temperature parameter for controlling distribution sharpness, and careful engineering around masking, scaling, and floating-point representation. Every alternative that has been proposed — sigmoid, ReLU, linear attention — trades away one or more of softmax’s core properties (normalization, sharp selectivity, smooth gradients) and has failed to displace it at scale.
If you are building systems that run transformers, you need to understand softmax at this level. Not because you will implement it from scratch, but because the numerical choices made inside softmax — the max-subtraction, the rescaling in online softmax, the temperature, the -65504 mask value — determine whether your model trains stably, generates coherently, and runs efficiently. These are the details that separate a working system from one that produces NaN.
In Part 6, we will move to the next component in the transformer block: the attention variants (MHA, MQA, GQA, and MLA) that reshape the KV cache and determine inference memory cost.