Part of the series Transformer Anatomy (1 of 23):

1. The Transformer Attention Mechanism: From First Principles to Performance Reality
2. Tokenization and BPE: How LLMs See Text — From Characters to Subwords
3. Embedding Layers: The Geometry of Meaning in LLMs
4. Position Encoding in Transformers: From Sinusoidal to RoPE, ALiBi, and Long-Context Scaling
5. Softmax Numerics: Log-Sum-Exp, Temperature, and Why Numerical Stability Matters
6. Attention Variants Compared: MHA, MQA, GQA, and MLA
7. Normalization in Transformers: LayerNorm, RMSNorm, and the Training Stability Story
8. Residual Connections and Skip Paths: Why Transformers Can Be 100 Layers Deep
9. The Feed-Forward Network: SwiGLU, Gating, and the FFN-as-Memory Hypothesis
10. Mixture of Experts: Why Conditional Computation Is the Path to Trillion-Parameter Models
11. The Output Head: Unembedding, Weight Tying, and Vocabulary Projection
12. Cross-Entropy Loss: How the Loss Function Shapes What an LLM Learns
13. Encoder vs Decoder: Why Decoder-Only Won
14. DeepSeek V3: How 671B Parameters Trained for the Cost of a 70B Dense Model
15. Building a Transformer From Scratch: Putting Every Component Together
16. Gradient Flow and Backpropagation Through Transformers: What Happens During the Backward Pass
17. Weight Initialization: Xavier, Kaiming, and Why mu-P Changes Everything for Large Models
18. Training Loop Anatomy: Forward Pass, Loss Computation, Backward Pass, Optimizer Step
19. Learning Rate Schedules: Warmup, Cosine Decay, and Why WSD Changes Everything
20. Activation Functions Deep Dive: ReLU, GELU, SiLU, and Why Each Matters for Transformers
21. Attention Masking: Causal, Bidirectional, Sliding Window, Block Sparse, and Custom Patterns
22. Knowledge Distillation: Training Small Models to Match Large Ones
23. Model Merging: Weight Averaging, TIES, DARE, and Evolutionary Search

The attention mechanism is the computational heart of every modern large language model. It replaced recurrence, enabled parallelism, and made it possible to train models on thousands of GPUs simultaneously. But attention is not free. Its quadratic cost in sequence length is the single largest constraint on how much context a model can process, and understanding exactly where that cost comes from — in FLOPs, in memory, in memory bandwidth — is essential for anyone building or deploying these systems.

This post starts from first principles: why attention was invented, what problem it solves, how the math works, and then proceeds through a full performance autopsy of every tensor, every matrix multiply, and every memory access pattern in the attention computation. By the end, you will know not just what attention does, but exactly how expensive it is and why.

Part 1: Why Attention Was Created

The RNN Bottleneck

Before transformers, sequence modeling was dominated by recurrent neural networks — LSTMs, GRUs, and their variants. RNNs process sequences one token at a time, maintaining a hidden state that is updated at each step:

h_t = f(h_{t-1}, x_t)

This design has three fundamental problems.

Problem 1: Sequential processing prevents parallelism. Because $h_t$ depends on $h_{t-1}$, you cannot compute any hidden state until all previous states have been computed. A sequence of length $n$ requires $n$ sequential steps. On a GPU with thousands of cores, most of those cores sit idle waiting for the previous step to complete. Training on long sequences is painfully slow — not because the hardware lacks capacity, but because the algorithm cannot use it.

Problem 2: Vanishing and exploding gradients. During backpropagation through time, gradients must flow backward through every sequential step. At each step, gradients are multiplied by the recurrent weight matrix. Over hundreds or thousands of steps, these repeated multiplications cause gradients to either vanish (approach zero) or explode (approach infinity). LSTMs and GRUs introduced gating mechanisms to mitigate this, but they only partially solve the problem. In practice, RNNs struggle to learn dependencies beyond a few hundred tokens.

Problem 3: The information bottleneck. The hidden state $h_t$ is a fixed-size vector, typically a few hundred to a few thousand dimensions. All information from the entire preceding sequence must be compressed into this single vector. For a 10,000-token document, the model must somehow encode every relevant fact into, say, 1024 floating-point numbers. Information is inevitably lost.

ℹ️ The Core Insight

Attention solves all three problems simultaneously. It replaces sequential computation with parallel pairwise comparisons, provides direct gradient paths between any two positions (regardless of distance), and allows each output position to selectively access information from every input position without compression into a fixed-size bottleneck.

The Attention Solution

The key idea is simple: instead of processing tokens sequentially and compressing history into a hidden state, let every token look at every other token directly. For a sequence of $n$ tokens, compute $n \times n$ pairwise similarity scores, then use those scores to create weighted combinations of token representations.

This is fully parallelizable — all $n^2$ similarity scores can be computed in a single matrix multiplication. Gradients flow directly between any two positions through the attention weights, with no repeated multiplication through recurrent connections. And there is no information bottleneck: each output position has direct access to the full representation of every input position.

The cost? That $n^2$ factor. We will dissect it thoroughly.

Part 2: Attention as Database Lookup

The most useful mental model for attention is a soft database lookup. Imagine a key-value store where, instead of exact matching, you perform a fuzzy search and return a weighted combination of all values based on how well each key matches your query.

The Three Projections

Given an input sequence of token embeddings $X \in \mathbb{R}^{n \times d}$, attention creates three different “views” of the data:

Q = XW_Q, \quad K = XW_K, \quad V = XW_V

where $W_Q, W_K, W_V \in \mathbb{R}^{d \times d}$ are learned weight matrices. Each serves a distinct role:

  • Query ($Q$): “What am I looking for?” Each row of $Q$ represents what a token wants to find in the sequence. When token $i$ generates its query vector, that vector encodes something like “I need the subject of this clause” or “I need the most recent noun.”

  • Key ($K$): “What do I contain?” Each row of $K$ represents what a token offers to other tokens searching for information. A noun might generate a key that says “I am a noun, I am the subject, I am at position 3.”

  • Value ($V$): “What do I return if matched?” Each row of $V$ is the actual information payload. If a query matches a key well, the corresponding value is what gets passed along.

The separation of K and V is crucial. The key determines whether a token is relevant; the value determines what information is extracted. A word might be highly relevant (strong key match) but contribute different types of information depending on context (different value representations for different downstream tasks).

The Dot Product as Similarity

The similarity between query $i$ and key $j$ is computed as a dot product:

\text{score}_{ij} = q_i \cdot k_j = \sum_{l=1}^{d} q_{il} \, k_{jl}

Why dot product? It is a measure of alignment in the embedding space. Two vectors pointing in the same direction yield a large positive score; orthogonal vectors yield zero; opposing vectors yield a large negative score. The dot product is also extremely fast to compute — it maps directly to matrix multiplication hardware (tensor cores on NVIDIA GPUs).

In matrix form, all $n^2$ scores are computed in one operation:

S = QK^T \in \mathbb{R}^{n \times n}

Each element $S_{ij}$ tells us how much token $i$ should attend to token $j$.

Why Scale by $\sqrt{d}$

The raw dot products grow in magnitude with the dimension $d$. If $q$ and $k$ are vectors of independent random values with zero mean and unit variance, the expected value of $q \cdot k$ is zero but its variance is $d$. For $d = 128$, the standard deviation of the scores would be about $\sqrt{128} \approx 11.3$, pushing softmax into saturated regions where gradients vanish.

Dividing by $\sqrt{d}$ normalizes the variance back to 1:

S = \frac{QK^T}{\sqrt{d_k}}

This keeps the softmax in a well-behaved regime throughout training.
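The variance argument is easy to verify empirically. A small sketch, with standard-normal random vectors standing in for queries and keys:

```python
import torch

torch.manual_seed(0)
d = 128
n_samples = 10_000

# Standard-normal "queries" and "keys": zero mean, unit variance per dimension
q = torch.randn(n_samples, d)
k = torch.randn(n_samples, d)

raw_scores = (q * k).sum(dim=-1)       # dot products: variance grows to ~d
scaled_scores = raw_scores / d ** 0.5  # rescaled: variance back to ~1

print(f"raw std:    {raw_scores.std():.1f}")     # close to sqrt(128) = 11.3
print(f"scaled std: {scaled_scores.std():.2f}")  # close to 1.0
```

The unscaled scores spread over a range of roughly plus or minus 30, which would push softmax deep into saturation; after scaling they behave like standard normals.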

Why Softmax: A Probability Distribution Over Values

The scaled scores are passed through softmax to produce attention weights:

\alpha_{ij} = \frac{\exp(S_{ij})}{\sum_{k=1}^{n} \exp(S_{ik})}

Softmax has three important properties:

  1. All weights are non-negative. This means the output is always a convex combination of values — the output “lives in the same space” as the input values.
  2. Weights sum to 1. Each query produces a proper probability distribution over positions. This acts as a form of normalization that prevents output magnitudes from scaling with sequence length.
  3. It is differentiable. Gradients flow smoothly from the output through the attention weights back to Q, K, and V.
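All three properties can be checked directly on random scores; a quick sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(4, 4, requires_grad=True)  # 4 queries x 4 positions
weights = F.softmax(scores, dim=-1)

# 1. All weights are non-negative
assert (weights >= 0).all()
# 2. Each query's weights sum to 1 (a proper probability distribution)
assert torch.allclose(weights.sum(dim=-1), torch.ones(4))
# 3. Differentiable: gradients flow back through softmax to the scores
(weights ** 2).sum().backward()
assert scores.grad is not None
```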

The final output is a weighted sum of values:

\text{output}_i = \sum_{j=1}^{n} \alpha_{ij} v_j

Or in matrix form:

\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V

import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Scaled dot-product attention.

    Q: [batch, seq_len, d_k]
    K: [batch, seq_len, d_k]
    V: [batch, seq_len, d_v]
    Returns: [batch, seq_len, d_v], [batch, seq_len, seq_len]
    """
    d_k = Q.size(-1)

    # Step 1: QK^T — compute all pairwise similarities
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Step 2: Optional masking (for causal / padding)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 3: Softmax — convert scores to probability distribution
    attn_weights = F.softmax(scores, dim=-1)

    # Step 4: Weighted sum of values
    output = torch.matmul(attn_weights, V)

    return output, attn_weights

💡 The Database Analogy, Precisely

In a traditional database: you have a query, you scan keys for an exact match, and you return the matched value. In attention: you have a query vector, you compute similarity against all key vectors (dot product), convert similarities to weights (softmax), and return a weighted blend of all values. The “fuzziness” — the fact that you get a blend rather than an exact match — is what makes attention differentiable and trainable.

Part 3: Multi-Head Attention

Why Multiple Heads?

A single attention mechanism computes one set of attention weights — one way of deciding “who attends to whom.” But language has many simultaneous types of relationships:

  • Syntactic relationships: A verb attends to its subject and object.
  • Semantic relationships: A pronoun attends to its antecedent.
  • Positional relationships: A token attends to its local neighbors.
  • Long-range dependencies: A closing bracket attends to its matching opening bracket.

A single attention head must somehow capture all of these in one set of weights. Multi-head attention solves this by running multiple independent attention computations in parallel, each with its own learned projections, allowing different heads to specialize in different relationship types.

The Mechanics

Given $h$ attention heads and a model dimension $d_{\text{model}}$, each head operates on a subspace of dimension $d_k = d_{\text{model}} / h$:

\text{head}_i = \text{Attention}(XW_Q^i, XW_K^i, XW_V^i)

where $W_Q^i, W_K^i, W_V^i \in \mathbb{R}^{d_{\text{model}} \times d_k}$.

The outputs from all heads are concatenated and projected back to $d_{\text{model}}$:

\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W_O

where $W_O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$.

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Separate projections for Q, K, V
        self.W_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_k = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_v = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_o = torch.nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.size()

        # Project: [batch, seq_len, d_model]
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into heads: [batch, num_heads, seq_len, d_k]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

        # Per-head attention: scores [batch, num_heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_output = torch.matmul(attn_weights, V)
        # attn_output: [batch, num_heads, seq_len, d_k]

        # Concatenate heads: [batch, seq_len, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)

        # Final projection
        output = self.W_o(attn_output)

        return output, attn_weights

What Heads Actually Learn

Research on trained models reveals striking specialization:

  • Some heads learn positional patterns — always attending to the previous token, or to the first token in the sequence.
  • Some heads learn syntactic roles — tracking subject-verb agreement across clauses.
  • Some heads learn coreference — connecting pronouns to their referents many positions away.
  • Some heads learn induction patterns — detecting and continuing repeated sequences (critical for in-context learning).

This specialization emerges entirely from training. The architecture provides the capacity for $h$ independent attention patterns; gradient descent finds useful specializations.

The Cost of Multi-Head Attention

A key insight: multi-head attention with $h$ heads of dimension $d_k = d_{\text{model}}/h$ has the same total compute as single-head attention with dimension $d_{\text{model}}$. The work is redistributed, not multiplied. Each head computes attention over a smaller $d_k$-dimensional subspace, and there are $h$ such computations.

The total FLOPs for the QKV projection matrices are:

3 \times 2 \times n \times d_{\text{model}}^2

This is the same regardless of how many heads you use. The attention computation itself (QK^T, softmax, AV) is also the same in total, just partitioned differently. What multi-head attention adds is the output projection $W_O$: an additional $2 \times n \times d_{\text{model}}^2$ FLOPs.
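These per-layer counts can be tallied with a small helper (an illustrative function of my own, not from any library):

```python
def attention_flops(n, d_model, num_heads):
    """Per-layer attention FLOPs, counting a multiply-add as 2 FLOPs."""
    qkv_proj = 3 * 2 * n * d_model ** 2   # Q, K, V projections
    out_proj = 2 * n * d_model ** 2       # output projection W_O
    qk_t = 2 * n ** 2 * d_model           # score matrix, summed over heads
    softmax = 5 * num_heads * n ** 2      # max, subtract, exp, sum, divide
    att_v = 2 * n ** 2 * d_model          # weighted sum of values
    return {"projections": qkv_proj + out_proj,
            "attention_ops": qk_t + softmax + att_v}

flops = attention_flops(n=2048, d_model=4096, num_heads=32)
print(f"projections:   {flops['projections'] / 1e9:.1f} GFLOPs")   # ~274.9
print(f"attention ops: {flops['attention_ops'] / 1e9:.1f} GFLOPs") # ~69.4
```

At a 2K context, the projections cost roughly 4x more than the attention-specific operations, a point quantified later in this post.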

📊

Multi-Head Attention Parameter Count (d_model = 4096)

Component | Shape       | Parameters | Memory (FP16)
W_Q       | 4096 x 4096 | 16.8M      | 32 MB
W_K       | 4096 x 4096 | 16.8M      | 32 MB
W_V       | 4096 x 4096 | 16.8M      | 32 MB
W_O       | 4096 x 4096 | 16.8M      | 32 MB
Total     | 4 x 4096^2  | 67.1M      | 128 MB
Note: Per layer. A 32-layer model has 32 x 128 MB = 4 GB in attention projection weights alone.

Part 4: The O(n^2) Cost Breakdown

The attention computation has three major stages, each with distinct computational and memory characteristics. Let us dissect them one by one.

Stage 1: QK^T — The Score Matrix

S = QK^T \in \mathbb{R}^{n \times n}

This is a matrix multiplication of $Q \in \mathbb{R}^{n \times d_k}$ by $K^T \in \mathbb{R}^{d_k \times n}$.

FLOPs: $2 \times n^2 \times d_k$ per head (each of the $n^2$ output elements requires $d_k$ multiply-adds). With $h$ heads, the total is $2 \times n^2 \times d_k \times h = 2 \times n^2 \times d_{\text{model}}$.

Memory written: $n^2$ elements for the score matrix per head, so $h \times n^2$ total. In FP16 (2 bytes per element), that is $2 \times h \times n^2$ bytes.

Arithmetic intensity: each output element costs $2 \times d_k$ FLOPs. For $d_k = 128$ (typical for Llama-class models), that is 256 FLOPs per element, and with the operand reuse of a large GEMM the effective FLOPs-per-byte ratio sits well above the compute-to-bandwidth ratio of modern GPUs. QK^T is therefore compute-bound during prefill (a large GEMM). During decode, when $Q$ has only 1 row, it degenerates to a matrix-vector product with no reuse of $K$, and the operation becomes memory-bound.

Stage 2: Softmax

\alpha_{ij} = \frac{\exp(S_{ij} - \max_k S_{ik})}{\sum_{k} \exp(S_{ik} - \max_k S_{ik})}

(The subtraction of the row maximum is for numerical stability.)

FLOPs: $O(n^2)$ — three passes over the $n \times n$ matrix (find max, compute exp and sum, divide). Each pass touches $n^2$ elements.

Memory accesses: This is where softmax becomes expensive. Even though the FLOPs are “only” $O(n^2)$ compared to $O(n^2 d)$ for the matrix multiplies, softmax requires reading and writing the entire $n \times n$ attention matrix. On modern GPUs, memory bandwidth — not compute — is the scarce resource. The softmax forces the full $n^2$ attention matrix to be materialized in GPU memory (HBM), read back for the three passes, and written out again. This is exactly the bottleneck that FlashAttention eliminates by fusing QK^T, softmax, and AV into a single kernel that keeps the attention matrix in on-chip SRAM.

Softmax: Cheap in FLOPs, Expensive in Memory Traffic

Softmax performs about $5n^2$ FLOPs per head (max, subtract, exp, sum, divide) but generates roughly $3 \times h \times n^2 \times 2$ bytes of memory traffic in FP16 (read scores, write weights, read weights again for AV). At sequence length 4096 with 32 heads, that is $3 \times 32 \times 4096^2 \times 2 \approx 3.2$ GB of memory traffic just for softmax. On an A100 with 2 TB/s of bandwidth, softmax alone takes ~1.6 ms — comparable to the matrix multiplies despite having far fewer FLOPs.
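That back-of-the-envelope estimate is simple enough to encode directly (the bandwidth default is the A100's nominal ~2 TB/s):

```python
def softmax_traffic(n, num_heads, bytes_per_el=2, bw_bytes_per_s=2e12):
    """Softmax memory traffic (read scores, write weights, re-read weights
    for AV) and the time that traffic takes at a given HBM bandwidth."""
    traffic = 3 * num_heads * n ** 2 * bytes_per_el       # bytes
    return traffic / 1e9, traffic / bw_bytes_per_s * 1e3  # (GB, ms)

gb, ms = softmax_traffic(n=4096, num_heads=32)
print(f"{gb:.1f} GB of traffic -> ~{ms:.1f} ms at 2 TB/s")  # 3.2 GB, ~1.6 ms
```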

Stage 3: Attention x V

O = \alpha V \in \mathbb{R}^{n \times d_v}

This multiplies the attention weights $\alpha \in \mathbb{R}^{n \times n}$ by the values $V \in \mathbb{R}^{n \times d_v}$.

FLOPs: $2 \times n^2 \times d_v$ per head, or $2 \times n^2 \times d_{\text{model}}$ total.

Memory read (per head): $n^2$ elements (attention weights) + $n \times d_v$ elements (the V matrix).

Same arithmetic intensity analysis as QK^T applies here: compute-bound during prefill, memory-bound during decode.

Total Cost Summary

📊

Attention FLOPs Breakdown (per layer, all heads)

Operation         | FLOPs    | Memory Written | Bound (Prefill) | Bound (Decode)
QKV Projection    | 6 n d^2  | 3 n d          | Compute         | Memory BW
QK^T              | 2 n^2 d  | h n^2          | Compute         | Memory BW
Softmax           | ~5 h n^2 | h n^2          | Memory BW       | Memory BW
Attention x V     | 2 n^2 d  | n d            | Compute         | Memory BW
Output Projection | 2 n d^2  | n d            | Compute         | Memory BW
Note: n = sequence length, d = d_model, h = num_heads. FLOPs counted as multiply-adds x 2.

The attention-specific operations (QK^T + softmax + AV) cost $4n^2 d + 5hn^2$ FLOPs. The projection operations (QKV + output) cost $8nd^2$ FLOPs.

The crossover point — where attention operations exceed projection operations — occurs when:

4n^2 d > 8nd^2 \implies n > 2d

For $d = 4096$ (Llama-7B), attention dominates when $n > 8192$. For shorter sequences, the projection matrices dominate. This is a critical insight: for most practical workloads today (context lengths under 8K), the linear projections cost more than the quadratic attention computation itself.
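The crossover can be checked numerically (ignoring the small softmax term):

```python
def attention_vs_projection_gflops(n, d):
    """Per-layer GFLOPs: attention ops (4 n^2 d) vs projections (8 n d^2)."""
    return 4 * n ** 2 * d / 1e9, 8 * n * d ** 2 / 1e9

d = 4096
for n in (512, 2048, 8192, 32768):
    attn, proj = attention_vs_projection_gflops(n, d)
    marker = "<-- crossover (n = 2d)" if n == 2 * d else ""
    print(f"n={n:6d}: attention {attn:9.1f}  projections {proj:9.1f} {marker}")
```

At n = 8192 the two terms are exactly equal; quadrupling the context to 32K makes the attention operations 4x the projection cost.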

Attention vs Projection FLOPs by Sequence Length (d = 4096)

n     | Projections ($8nd^2$) | Attention ops ($4n^2 d$)
512   | 68.7 GFLOPs           | 4.3 GFLOPs (projections dominate)
2048  | 274.9 GFLOPs          | 68.7 GFLOPs
8192  | 1,099.5 GFLOPs        | 1,099.5 GFLOPs (crossover point)
32768 | 4,398.0 GFLOPs        | 17,592.2 GFLOPs (attention dominates, 4x the projections)

Part 5: Memory Analysis — Every Tensor, Every Byte

Understanding where memory goes is essential for capacity planning. Let us trace every intermediate tensor through the attention computation for a single layer and a single batch element.

The Tensors

For a model with $d_{\text{model}} = 4096$, $h = 32$ heads, $d_k = 128$, using FP16 (2 bytes per element):

Input: $X \in \mathbb{R}^{n \times d}$ — $n \times 4096 \times 2$ bytes = $8192n$ bytes

Q, K, V: each is $\mathbb{R}^{n \times d}$ — $3 \times 8192n$ bytes = $24576n$ bytes total

Score matrix (QK^T): $\mathbb{R}^{h \times n \times n}$ — $32 \times n^2 \times 2$ bytes = $64n^2$ bytes

Attention weights (after softmax): same shape as the scores — $64n^2$ bytes

Attention output (before concatenation): $\mathbb{R}^{n \times d}$ — $8192n$ bytes

Final output (after $W_O$): $\mathbb{R}^{n \times d}$ — $8192n$ bytes

📊

Intermediate Tensor Memory (per layer, batch=1, d=4096, h=32, FP16)

Tensor              | n=512 | n=2048 | n=8192   | n=32768    | n=131072
Q + K + V           | 12 MB | 48 MB  | 192 MB   | 768 MB     | 3,072 MB
Score matrix (QK^T) | 16 MB | 256 MB | 4,096 MB | 65,536 MB  | 1,048,576 MB
Attention weights   | 16 MB | 256 MB | 4,096 MB | 65,536 MB  | 1,048,576 MB
Attention output    | 4 MB  | 16 MB  | 64 MB    | 256 MB     | 1,024 MB
Total activations   | 48 MB | 576 MB | 8,448 MB | 132,096 MB | 2,101,248 MB
Note: Score matrix and attention weights are the dominant terms (sizes follow the per-tensor formulas above). At n=8192, they consume 8 GB combined — half the memory of a 16 GB T4 GPU.

The Crossover: When Attention Matrix Exceeds QKV Memory

The QKV tensors occupy $3 \times n \times d \times 2$ bytes (linear in $n$). The score matrix occupies $h \times n^2 \times 2$ bytes (quadratic in $n$). The crossover occurs when:

h \times n^2 \times 2 > 3 \times n \times d \times 2 \implies n > \frac{3d}{h} = \frac{3 \times 4096}{32} = 384

So for a Llama-7B-class model, the attention matrix exceeds QKV memory at just 384 tokens. For essentially any practical sequence length, the attention matrix is the dominant memory consumer.
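The per-tensor sizes follow directly from the formulas above; a small helper (illustrative, my own naming):

```python
def attention_activation_mb(n, d_model=4096, num_heads=32, bytes_per_el=2):
    """Per-layer, batch-1 activation sizes in MiB (FP16 by default)."""
    mib = 1024 ** 2
    return {
        "QKV": 3 * n * d_model * bytes_per_el / mib,         # 3 * n * d * 2
        "scores": num_heads * n ** 2 * bytes_per_el / mib,   # h * n^2 * 2
        "weights": num_heads * n ** 2 * bytes_per_el / mib,  # same shape
        "output": n * d_model * bytes_per_el / mib,          # n * d * 2
    }

for n in (512, 2048, 8192):
    sizes = attention_activation_mb(n)
    print(f"n={n:5d}: " + ", ".join(f"{k}={v:,.0f} MB" for k, v in sizes.items()))
```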

Attention Memory Layout at n=2048, d=4096, h=32 (FP16)

Input X: 16 MB ($n \times d \times 2$ bytes) · Q + K + V tensors: 48 MB ($3 \times n \times d \times 2$ bytes, linear in $n$) · Score matrix (QK^T): 256 MB ($h \times n^2 \times 2$ bytes, quadratic in $n$) · Attention weights: 256 MB (same as scores) · Output: 16 MB. The score matrix alone is more than 5x larger than Q + K + V combined.

This is exactly why FlashAttention was invented. By computing attention in tiles and never materializing the full $n \times n$ score matrix in HBM, FlashAttention reduces attention memory from $O(n^2)$ to $O(n)$ — a reduction from hundreds of megabytes to a few megabytes at $n = 2048$.

Part 6: Causal Masking

Why Decoders Need It

In autoregressive language models, the model generates tokens one at a time, left to right. During training, the model processes entire sequences in parallel (teacher forcing), but each position must only attend to positions at or before it — otherwise the model would “see the answer” during training and learn nothing useful.

Mathematically, for a decoder-only model, position $i$ can attend to positions $\{1, 2, \ldots, i\}$ but not $\{i+1, \ldots, n\}$. This constraint is called causal masking (because it enforces the causal structure: effects cannot depend on future causes).

How It Is Implemented

There are two common approaches:

Approach 1: Additive mask before softmax. Create an upper-triangular matrix of $-\infty$ values and add it to the score matrix before softmax. Since $\exp(-\infty) = 0$, the masked positions contribute zero weight:

def causal_attention(Q, K, V):
    """Causal (autoregressive) attention."""
    d_k = Q.size(-1)
    seq_len = Q.size(-2)

    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Create causal mask: upper triangle = -inf
    causal_mask = torch.triu(
        torch.full((seq_len, seq_len), float('-inf'), device=Q.device),
        diagonal=1
    )
    scores = scores + causal_mask

    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, V)

    return output

Approach 2: Boolean mask with masked_fill. Create a boolean upper-triangular mask and fill masked positions with $-\infty$:

mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float('-inf'))

Both approaches are mathematically equivalent. In practice, most implementations use approach 1 because the additive mask can be precomputed and broadcast efficiently.
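The equivalence is easy to confirm numerically on a small example:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len = 8
scores = torch.randn(seq_len, seq_len)

# Approach 1: additive -inf mask
additive = torch.triu(
    torch.full((seq_len, seq_len), float('-inf')), diagonal=1
)
w1 = F.softmax(scores + additive, dim=-1)

# Approach 2: boolean mask + masked_fill
boolean = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
w2 = F.softmax(scores.masked_fill(boolean, float('-inf')), dim=-1)

assert torch.allclose(w1, w2)            # identical attention weights
assert (w1.triu(diagonal=1) == 0).all()  # future positions get zero weight
```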

Computational Savings from Causal Masking

With causal masking, approximately half of the attention matrix is zeroed out. Does this save computation?

In theory: Yes. Only $n(n+1)/2$ of the $n^2$ attention scores are nonzero — roughly half the computation for QK^T and attention-times-V.

In practice: It depends on the implementation. A naive implementation still computes the full $n \times n$ matrix and then masks it. Optimized implementations (including FlashAttention) can skip computation for the masked region entirely, yielding close to a 2x speedup for the attention-specific operations. Some hardware-aware kernels go further by using triangular tile patterns that avoid loading masked tiles altogether.

📊

Causal Masking Computational Impact

Implementation                | QK^T FLOPs | AV FLOPs | Memory | Notes
Full attention (no mask)      | 2 n^2 d    | 2 n^2 d  | O(n^2) | Encoder / bidirectional
Naive causal (compute + mask) | 2 n^2 d    | 2 n^2 d  | O(n^2) | Computes everything, zeros half
Optimized causal (skip upper) | n^2 d      | n^2 d    | O(n^2) | ~2x fewer FLOPs
FlashAttention causal         | n^2 d      | n^2 d    | O(n)   | Tiled + skip + fused
Note: FlashAttention's causal mode skips upper-triangular tiles entirely, saving both FLOPs and memory traffic.

KV Cache: Exploiting Causal Structure at Inference Time

During autoregressive generation, we generate one token at a time. Thanks to causal masking, the key and value projections for the previous tokens $\{1, \ldots, t-1\}$ needed when generating token $t$ are identical to those computed at step $t-1$; nothing about earlier positions changes. We can therefore cache the keys and values for all previous tokens and only compute the new Q, K, V for the current token:

class CachedCausalAttention(torch.nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_k = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_v = torch.nn.Linear(d_model, d_model, bias=False)
        self.W_o = torch.nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None):
        """
        x: [batch, 1, d_model] during generation (single new token)
        kv_cache: tuple of (cached_K, cached_V) from previous steps
        """
        batch_size = x.size(0)

        # Project the new token
        q = self.W_q(x)  # [batch, 1, d_model]
        k = self.W_k(x)  # [batch, 1, d_model]
        v = self.W_v(x)  # [batch, 1, d_model]

        # Reshape for multi-head
        q = q.view(batch_size, 1, self.num_heads, self.d_k).transpose(1, 2)
        k = k.view(batch_size, 1, self.num_heads, self.d_k).transpose(1, 2)
        v = v.view(batch_size, 1, self.num_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            # Append new K, V to cache
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)

        new_cache = (k, v)

        # Attention: Q is [b, h, 1, d_k], K is [b, h, t, d_k]
        # Scores: [b, h, 1, t] — only one row of the attention matrix!
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, v)

        output = output.transpose(1, 2).contiguous().view(batch_size, 1, -1)
        output = self.W_o(output)

        return output, new_cache
ℹ️ KV Cache: Trading Memory for Compute

KV caching reduces per-token attention from $O(n^2 d)$ to $O(n d)$ — you compute only one row of the attention matrix instead of all $n$ rows. The cost is $O(n d)$ memory per layer to store the cached K and V tensors. For Llama-7B (32 layers, $d = 4096$) at sequence length 2048 in FP16, the KV cache is $32 \times 2 \times 2048 \times 4096 \times 2$ bytes = 1.0 GB. At 131K context, it grows to 64 GB — often larger than the model weights themselves.
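The cache-size arithmetic as a one-line helper (illustrative):

```python
def kv_cache_gib(num_layers, seq_len, d_model, bytes_per_el=2):
    """KV cache size: each layer stores K and V of shape [seq_len, d_model]."""
    return num_layers * 2 * seq_len * d_model * bytes_per_el / 1024 ** 3

# Llama-7B-class: 32 layers, d_model = 4096, FP16
print(f"{kv_cache_gib(32, 2048, 4096):.1f} GB at 2K context")      # 1.0 GB
print(f"{kv_cache_gib(32, 131072, 4096):.1f} GB at 131K context")  # 64.0 GB
```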

Part 7: The Computational Profile — Prefill vs Decode

The performance characteristics of attention differ dramatically between the two phases of inference.

Prefill: Processing the Prompt

During prefill, the model processes the entire input prompt in one forward pass. The attention computation involves large matrix multiplications: $Q \in \mathbb{R}^{n \times d}$ multiplied by $K^T \in \mathbb{R}^{d \times n}$ to produce an $n \times n$ score matrix.

This is a large GEMM (General Matrix Multiply) — exactly the workload that GPUs are designed for. Tensor cores on modern NVIDIA GPUs can sustain hundreds of teraFLOPs on large GEMMs. Prefill attention is compute-bound: the GPU is limited by how fast it can execute multiply-accumulate operations, not by how fast it can read data from memory.

Optimization strategy for prefill: Maximize compute utilization. Use large batch sizes to create larger GEMMs. FlashAttention helps by reducing memory traffic (avoiding the HBM round-trip for the attention matrix), which allows the compute cores to stay fed with data.

Decode: Generating One Token at a Time

During decode, $Q$ has only 1 row (the new token). The QK^T computation becomes a matrix-vector product: $q \in \mathbb{R}^{1 \times d}$ times $K^T \in \mathbb{R}^{d \times t}$, producing a $1 \times t$ score vector.

Matrix-vector products have terrible arithmetic intensity. For each output element, you perform $2d_k$ FLOPs but must read $d_k$ elements of $K$ from memory. With $d_k = 128$ and FP16, that is 256 FLOPs per 256 bytes read = 1 FLOP/byte. An A100 has ~312 TFLOPS of FP16 compute but only ~2 TB/s of HBM bandwidth — a ratio of ~156 FLOPs/byte. At 1 FLOP/byte, you are utilizing less than 1% of the GPU’s compute capability. Decode attention is profoundly memory-bandwidth-bound.

The dominant cost during decode is not computing the attention scores — it is reading the KV cache from HBM. For each token generated, the entire KV cache must be read once. At sequence length $t$ with $L$ layers:

\text{KV cache read} = L \times 2 \times t \times d \times 2 \text{ bytes}

For Llama-7B at $t = 4096$: $32 \times 2 \times 4096 \times 4096 \times 2$ bytes = 2.0 GB per token generated. At 2 TB/s bandwidth, just reading the KV cache takes ~1 ms per token.
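The same arithmetic as a quick estimator (the default bandwidth is the A100's nominal ~2 TB/s):

```python
def decode_kv_read_ms(num_layers, t, d_model, bw_bytes_per_s=2e12,
                      bytes_per_el=2):
    """Time to stream the whole KV cache from HBM once (one decode step)."""
    bytes_read = num_layers * 2 * t * d_model * bytes_per_el
    return bytes_read / bw_bytes_per_s * 1e3

print(f"{decode_kv_read_ms(32, 4096, 4096):.2f} ms per token")  # ~1.07 ms
```

This is a hard floor on per-token latency: no amount of compute optimization helps until the cache itself shrinks.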

Optimization strategy for decode: Reduce memory traffic. Grouped-Query Attention (GQA) reduces the KV cache by sharing K and V across groups of query heads — Llama-2 70B uses 8 KV heads instead of 64, reducing KV cache by 8x. KV cache quantization (FP16 to INT8 or INT4) provides another 2-4x reduction. Batching multiple requests together amortizes the weight-loading cost across requests.
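To make the head sharing concrete, here is a minimal GQA sketch with illustrative shapes (64 query heads over 8 KV heads, the Llama-2 70B ratio; this is not Llama's actual implementation):

```python
import torch

# 64 query heads share 8 KV heads, so the cached K and V are 8x smaller
batch, seq_len, d_k = 1, 16, 128
num_q_heads, num_kv_heads = 64, 8

q = torch.randn(batch, num_q_heads, seq_len, d_k)
k = torch.randn(batch, num_kv_heads, seq_len, d_k)  # cached: 8x fewer heads
v = torch.randn(batch, num_kv_heads, seq_len, d_k)

# Broadcast each KV head across its group of query heads, then attend as usual
group = num_q_heads // num_kv_heads        # 8 query heads per KV head
k_exp = k.repeat_interleave(group, dim=1)  # [1, 64, 16, 128]
v_exp = v.repeat_interleave(group, dim=1)

scores = q @ k_exp.transpose(-2, -1) / d_k ** 0.5
out = torch.softmax(scores, dim=-1) @ v_exp
print(out.shape)  # torch.Size([1, 64, 16, 128])
```

Only `k` and `v` need to live in the KV cache; the expanded copies exist transiently during the attention computation (and fused kernels avoid materializing them at all).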

Prefill vs Decode Attention Profile (Llama-7B, A100 80GB)

| Metric | Prefill (n=2048) | Decode (t=2048) | Ratio |
| --- | --- | --- | --- |
| Q shape | [2048, 4096] | [1, 4096] | 2048:1 |
| Score matrix shape | [2048, 2048] | [1, 2048] | 2048:1 |
| Attention FLOPs | ~34 GFLOPs | ~16 MFLOPs | ~2000:1 |
| KV cache read | N/A (computed) | 2.0 GB | -- |
| Bottleneck | Compute (tensor cores) | Memory BW (HBM) | -- |
| GPU utilization | 60-80% | under 5% | -- |
| Optimization | FlashAttention, larger batch | GQA, KV quant, batching | -- |

Note: Decode attention is so memory-bound that the GPU's compute cores are almost entirely idle. This is why batching multiple requests is essential for throughput.

GPU Resource Utilization: Prefill vs Decode (% of peak)

  • Prefill compute: ~70% of peak (large GEMMs saturate tensor cores)
  • Prefill memory bandwidth: ~40% of peak
  • Decode compute: ~3% of peak (matrix-vector products waste compute)
  • Decode memory bandwidth: ~85% of peak (bottleneck: reading the KV cache)

Part 8: Where Attention Fits in the Transformer

Attention does not operate in isolation. It is one component in a larger transformer block. Understanding the full block is essential for understanding where attention’s costs fit in the bigger picture.

The Transformer Block

A standard transformer block comes in two flavors. The post-LN variant (the original transformer, BERT) applies LayerNorm after each residual addition:

x = LayerNorm(x + Attention(x))    // Attention sub-block, then residual + norm
x = LayerNorm(x + FFN(x))          // FFN sub-block, then residual + norm

The pre-LN variant (GPT-2, Llama, most modern models) applies the norm before each sub-layer instead:

x_norm = RMSNorm(x)
x = x + Attention(x_norm)
x_norm = RMSNorm(x)
x = x + FFN(x_norm)

Each transformer block therefore contains:

  1. RMSNorm / LayerNorm (2x) — normalizes activations
  2. Multi-Head Attention — the Q/K/V projections, attention computation, and output projection
  3. Feed-Forward Network (FFN) — two or three linear layers with a nonlinearity
  4. Residual connections (2x) — adds the sub-layer input to the sub-layer output
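
RMSNorm, the normalization in the pre-LN block above, is small enough to sketch in pure Python (a minimal illustration, not a tuned kernel; the gain vector defaults to ones and `eps` is chosen arbitrarily):

```python
import math

def rms_norm(x, gain=None, eps=1e-6):
    """y_i = g_i * x_i / sqrt(mean(x_j^2) + eps).
    Unlike LayerNorm, RMSNorm skips mean-centering and the bias term."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    g = gain if gain is not None else [1.0] * len(x)
    return [gi * v / rms for gi, v in zip(g, x)]

y = rms_norm([3.0, 4.0])   # rms = sqrt((9 + 16) / 2) ~= 3.536
```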

Why Residual Connections Matter

Residual connections are the unsung hero of deep transformers. They provide a “gradient highway” — during backpropagation, the gradient of the loss with respect to layer $l$’s input includes a direct additive term from the gradient at layer $l+1$:

$$\frac{\partial \mathcal{L}}{\partial x_l} = \frac{\partial \mathcal{L}}{\partial x_{l+1}} \cdot \left(1 + \frac{\partial F_l(x_l)}{\partial x_l}\right)$$

That “1+” term means gradients always have a direct path backward through the residual connections, regardless of what the attention or FFN layers do. Without residual connections, training transformers deeper than a few layers becomes extremely difficult due to vanishing gradients — the same problem that plagued deep RNNs.
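
A quick numerical illustration of that “1 +” term. Here `F` is an arbitrary toy sub-layer (a stand-in, not real attention) whose own derivative is nearly zero; the residual path keeps the end-to-end derivative pinned near 1:

```python
import math

def F(x):
    return 0.001 * math.tanh(x)   # toy sub-layer with derivative ~0.001

def residual(x):
    return x + F(x)               # the skip connection: y = x + F(x)

def numeric_grad(f, x, h=1e-6):
    # central finite difference
    return (f(x + h) - f(x - h)) / (2 * h)

g_with_skip = numeric_grad(residual, 0.5)   # ~1.0008: the "1 +" term dominates
g_without = numeric_grad(F, 0.5)            # ~0.0008: vanishes without the skip
```

Stacking many such layers multiplies these derivatives together; with the skip, the product stays near 1, while without it the gradient shrinks geometrically.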

Why FFN Accounts for 2/3 of the FLOPs

The FFN in a modern transformer (Llama-style SwiGLU) has three weight matrices:

$$\text{FFN}(x) = \left(\text{SiLU}(xW_{\text{gate}}) \odot xW_{\text{up}}\right) W_{\text{down}}$$

where $W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{d \times d_{ff}}$ and $W_{\text{down}} \in \mathbb{R}^{d_{ff} \times d}$.

With the typical expansion factor $d_{ff} \approx 2.7d$ (for SwiGLU) or $d_{ff} = 4d$ (for the original transformer), the FFN’s parameter count and FLOPs per token are:

  • Original transformer FFN: $2 \times d \times 4d = 8d^2$ parameters, $2 \times 2 \times n \times d \times 4d = 16nd^2$ FLOPs
  • SwiGLU FFN: $3 \times d \times 2.7d \approx 8d^2$ parameters, $3 \times 2 \times n \times d \times 2.7d \approx 16nd^2$ FLOPs

Meanwhile, attention projections (QKV + output) have $4d^2$ parameters and $8nd^2$ FLOPs. The attention computation itself adds $4n^2d$ FLOPs.

So the FFN has roughly twice the parameter count and twice the FLOPs of the attention projections. At moderate sequence lengths where $n < 2d$, the attention computation ($4n^2d$) is smaller than the attention projections ($8nd^2$), making the total attention sub-block about $8nd^2 + 4n^2d$ versus the FFN’s $16nd^2$. The FFN dominates.
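
These per-layer terms are easy to tabulate. A sketch using only the approximate formulas above (it ignores norms, residuals, and the exact Llama $d_{ff}$, so shares differ slightly from the measured tables in this article):

```python
def layer_flops(n, d):
    attn_proj = 8 * n * d * d   # QKV + output projections
    attn_comp = 4 * n * n * d   # QK^T and AV
    ffn = 16 * n * d * d        # SwiGLU FFN with d_ff ~ 2.7d
    return attn_proj, attn_comp, ffn

proj, comp, ffn = layer_flops(n=512, d=4096)
ffn_share = ffn / (proj + comp + ffn)              # ~0.65: FFN dominates at short n

proj, comp, ffn = layer_flops(n=16384, d=4096)
attn_share = (proj + comp) / (proj + comp + ffn)   # ~0.60: attention overtakes at long n
```

The crossover follows directly from the algebra: the attention sub-block exceeds the FFN once $4n^2d > 8nd^2$, i.e. once $n > 2d$.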

FLOPs per Component Within a Transformer Layer (Llama-7B, d=4096)

| Component | FLOPs at n=512 | Share | FLOPs at n=4096 | Share |
| --- | --- | --- | --- | --- |
| QKV projections | 25.8 GFLOPs | 21% | 206.2 GFLOPs | 16% |
| Attention (QK^T + softmax + AV) | 4.3 GFLOPs | 3% | 274.9 GFLOPs | 22% |
| Output projection | 8.6 GFLOPs | 7% | 68.7 GFLOPs | 5% |
| FFN (gate + up + down) | 57.9 GFLOPs | 47% | 463.5 GFLOPs | 37% |
| Norms + residuals | ~0.4 GFLOPs | under 1% | ~3.4 GFLOPs | under 1% |
| Total per layer | ~97 GFLOPs | 100% | ~1017 GFLOPs | 100% |

Note: At n=512, FFN is 47% vs attention at 31%. At n=4096, FFN drops to 37% as attention's quadratic cost grows to 22%.

Component Share of Layer FLOPs vs Sequence Length (% of layer FLOPs)

  • n=512: FFN 47%, attention total 31% (FFN dominates)
  • n=4096: FFN 37%, attention total 43% (attention catches up)
  • n=16384: FFN 19%, attention total 72% (attention dominates)

Component Scaling Across Model Sizes

The ratio of attention to FFN cost is remarkably stable across model sizes because both scale with $d^2$:

Component Scaling Across Model Sizes (n=1024)

| Model | Layers | d_model | Attention Share | FFN Share | Norm/Residual |
| --- | --- | --- | --- | --- | --- |
| Llama-1B | 22 | 2048 | 28% | 68% | 4% |
| Llama-7B | 32 | 4096 | 28% | 67% | 5% |
| Llama-13B | 40 | 5120 | 29% | 66% | 5% |
| Llama-70B | 80 | 8192 | 30% | 65% | 5% |

Note: At n=1024, the ratios are nearly identical. Attention share only increases significantly when sequence length grows.

Part 9: Scaling Laws — How Attention Cost Scales

With Model Size

For a model with $L$ layers, $d_{\text{model}}$ dimensions, and $h$ heads, the total attention FLOPs per forward pass are:

$$\text{Attention FLOPs} = L \times (8nd^2 + 4n^2d)$$

The $8nd^2$ term (projections) scales quadratically with $d$ and linearly with $n$. As models grow wider (larger $d$), this term dominates at short sequences. Scaling from $d = 4096$ to $d = 8192$ quadruples the projection cost.

The $4n^2d$ term (attention computation) scales linearly with $d$ and quadratically with $n$. Doubling $d$ only doubles this term, but doubling $n$ quadruples it.

With Sequence Length

Sequence length has a unique and problematic scaling behavior. At length $n$:

  • Compute: $O(n^2)$ for attention, $O(n)$ for everything else
  • Memory (activations): $O(n^2)$ for the attention matrix, $O(n)$ for everything else
  • Memory (KV cache): $O(n)$ per layer, but this applies to all $L$ layers

The quadratic scaling means that doubling context length from 4K to 8K does not merely double the cost — it roughly quadruples the attention cost. Going from 4K to 128K (a 32x increase) results in a 1,024x increase in attention FLOPs and memory.
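
The quadratic blow-up is easy to verify from the single-layer formulas, using the Llama-7B shapes ($d = 4096$, 32 heads) assumed throughout:

```python
def attn_cost(n, d=4096, heads=32):
    flops = 4 * n * n * d             # QK^T + AV for one layer
    score_bytes = heads * n * n * 2   # naive FP16 n x n score matrix, all heads
    return flops, score_bytes

f1, m1 = attn_cost(1024)          # m1 = 64 MiB of attention scores
f2, m2 = attn_cost(128 * 1024)
# a 128x longer context costs 128^2 = 16,384x more FLOPs and score memory
```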

Attention Scaling with Sequence Length (Llama-7B, single layer)

| Seq Length | Attention FLOPs | Attention Memory | KV Cache (all layers) | vs n=1024 |
| --- | --- | --- | --- | --- |
| 1,024 | 8.6 GFLOPs | 64 MB | 0.5 GB | 1x |
| 4,096 | 137.4 GFLOPs | 1,024 MB | 2.0 GB | 16x |
| 16,384 | 2,199 GFLOPs | 16,384 MB | 8.0 GB | 256x |
| 65,536 | 35,184 GFLOPs | 262,144 MB | 32.0 GB | 4,096x |
| 131,072 | 140,737 GFLOPs | 1,048,576 MB | 64.0 GB | 16,384x |

Note: At 131K context, the attention matrix alone would require 1 TB per layer — impossible without FlashAttention or similar memory-efficient methods. KV cache at 131K exceeds the memory of most GPUs.

Why Context Length Is the Frontier Challenge

The quadratic scaling wall is why extending context length is one of the most active research areas in LLM development:

  • FlashAttention: Reduces memory from $O(n^2)$ to $O(n)$ by tiling, but does not reduce FLOPs.
  • Ring Attention: Distributes the attention computation across multiple GPUs by partitioning the sequence.
  • Sparse attention patterns (Longformer, BigBird): Reduce FLOPs to $O(n\sqrt{n})$ or $O(n \log n)$ by attending to only a subset of positions.
  • Linear attention (Mamba, RWKV, RetNet): Replace softmax attention with recurrent or state-space mechanisms that have $O(n)$ cost, but sacrifice some of the quality that makes standard attention powerful.
  • Sliding window attention (Mistral): Limit each token’s attention to a local window, reducing cost to $O(n \times w)$ where $w$ is the window size.

Each approach makes a different trade-off between computational cost, memory, and model quality. No approach has yet matched full quadratic attention in quality while achieving truly subquadratic scaling for all tasks.

Attention Cost Scaling Approaches (relative FLOPs at n=32K)

  • Full attention, $O(n^2)$: 100 (baseline)
  • Sliding window (w=4K), $O(n \times w)$: 12.5
  • Sparse (sqrt pattern), $O(n\sqrt{n})$: 5.6
  • Linear attention, $O(n \times d)$: 3.1

Part 10: When Attention Is NOT the Bottleneck

It is tempting to assume that attention is always the performance bottleneck. After all, it has quadratic complexity. But in many practical scenarios, attention is not the dominant cost.

Short Sequences: FFN Dominates

At sequence lengths under ~2K tokens, the FFN’s $16nd^2$ FLOPs comfortably exceed attention’s $8nd^2 + 4n^2d$ FLOPs. For a typical chatbot interaction (512-1024 tokens of context), the FFN accounts for roughly half of all computation, while the attention computation (excluding projections) is under 10%.

If you are optimizing inference for short-context workloads, quantizing FFN weights (INT4/INT8) will yield a larger speedup than optimizing attention. At $n = 512$ on Llama-7B, the FFN contains 58% of all parameters and consumes 47% of compute — quantizing it from FP16 to INT4 can reduce model memory by 6 GB and improve throughput by 2-3x.

Batch Decode: Weight Loading Dominates

During autoregressive decode with batch size 1, the dominant cost is not attention or FFN computation — it is loading model weights from HBM into the compute units. The model weights for Llama-7B are ~14 GB in FP16. Each token requires a full forward pass through all weights. At 2 TB/s HBM bandwidth, simply reading the weights takes 7 ms per token — this is the floor latency regardless of any attention optimization.
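
That floor latency is pure arithmetic. A sketch, using the text's round numbers (7B parameters, FP16, 2 TB/s):

```python
def weight_stream_ms(n_params, bytes_per_param, hbm_bw_bytes_per_s):
    # Time to read all weights from HBM once: the per-token floor latency
    # at batch size 1, independent of any attention optimization.
    return n_params * bytes_per_param / hbm_bw_bytes_per_s * 1e3

fp16_floor = weight_stream_ms(7e9, 2.0, 2e12)   # 7.0 ms/token in FP16
int4_floor = weight_stream_ms(7e9, 0.5, 2e12)   # 1.75 ms/token after INT4 quantization
```

This is why weight quantization is the highest-leverage batch-1 latency optimization: it shrinks the bytes that must stream through HBM on every single token.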

With larger batch sizes, the weight-loading cost is amortized across the batch. At batch size 32, the same 14 GB of weights are read once but used for 32 tokens. This shifts the bottleneck from weight loading toward the attention KV cache (which is per-request and scales with batch size).

Large Batch Prefill: Compute Capacity Limits

For large-batch prefill workloads, the GPU’s compute capacity (TFLOPS) is the bottleneck, not any specific component. The attention GEMMs and FFN GEMMs both achieve high tensor-core utilization. Optimization at this point is about maximizing occupancy and minimizing kernel launch overhead, not about attention-specific tricks.

Bottleneck Analysis by Workload

| Workload | Dominant Bottleneck | Attention Share | Best Optimization |
| --- | --- | --- | --- |
| Short context (n under 1K), batch=1 | Weight loading (memory BW) | ~10% of compute | Weight quantization (INT4) |
| Short context, large batch | FFN compute | ~15% of compute | FFN quantization + batching |
| Long context (n over 8K), prefill | Attention compute | ~50%+ of compute | FlashAttention, ring attention |
| Long context, decode | KV cache loading (memory BW) | ~70% of time | GQA, KV quantization, paging |
| Very long context (n over 64K) | Attention memory capacity | Dominates everything | Sparse/linear attention, offloading |

Note: Profile your specific workload before optimizing. The bottleneck shifts dramatically between these regimes.

The Decision Framework

When deciding what to optimize, measure first. Here is a rough framework:

  1. Is your sequence length > 2 x d_model? If yes, attention likely dominates. Optimize attention (FlashAttention, GQA, sparse patterns).

  2. Is your sequence length < d_model? If yes, FFN likely dominates. Optimize FFN (weight quantization, pruning).

  3. Is your batch size 1? If yes, you are memory-bandwidth-bound. Quantize everything, use speculative decoding, or increase batch size.

  4. Is your KV cache larger than your model weights? If yes, the KV cache is the bottleneck. Use GQA, KV quantization, or paged attention.

  5. None of the above? Profile with a tool like PyTorch profiler or Nsight Systems. The bottleneck may be in an unexpected place — kernel launch overhead, CPU-GPU synchronization, or memory allocation.
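
The checklist above can be sketched as a single dispatch function. The thresholds are the rough rules of thumb from the text, and `kv_cache_bytes` / `weight_bytes` are hypothetical inputs you would measure for your own deployment:

```python
def suggest_optimization(seq_len, d_model, batch_size, kv_cache_bytes, weight_bytes):
    if seq_len > 2 * d_model:
        return "attention: FlashAttention, GQA, sparse patterns"
    if seq_len < d_model:
        return "ffn: weight quantization, pruning"
    if batch_size == 1:
        return "bandwidth: quantize everything, speculative decoding, bigger batch"
    if kv_cache_bytes > weight_bytes:
        return "kv-cache: GQA, KV quantization, paged attention"
    return "profile: PyTorch profiler / Nsight Systems"
```

Treat the output as a starting hypothesis, not a verdict; the final step of the framework (profile before optimizing) still applies.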

Putting It All Together

The attention mechanism is one of the most elegant ideas in modern deep learning: a differentiable database lookup that replaces sequential recurrence with parallel pairwise comparison. Its design — separate query, key, and value projections; scaled dot-product similarity; softmax normalization; multi-head decomposition — solves the fundamental problems of RNNs while enabling massive parallelism.

But elegance comes with a cost. The $O(n^2)$ scaling in both compute and memory means that attention’s cost grows faster than every other component in the transformer. At short sequences, this barely matters — the FFN dominates and attention is a modest overhead. At long sequences, attention becomes the overwhelming bottleneck, consuming the majority of FLOPs and memory.

The performance characteristics also differ sharply between phases. During prefill, attention is a large GEMM — compute-bound, high utilization, amenable to hardware acceleration. During decode, attention degenerates into memory-bound vector operations that waste most of the GPU’s compute capacity. These two regimes require fundamentally different optimization strategies.

Understanding these performance realities is not academic — it is essential for making good engineering decisions. Should you invest in FlashAttention or weight quantization? GQA or pruning? Longer context or larger batch size? The answer depends entirely on where in the performance landscape your workload sits.

The Practitioner's Summary

For sequences under ~2K tokens, optimize the FFN first (quantization saves the most memory and FLOPs). For sequences over ~8K tokens, optimize attention (FlashAttention, GQA, sparse patterns). For single-request decode latency, optimize memory bandwidth (quantize everything, increase batch size). For long-context decode, the KV cache is king — GQA and KV quantization are the highest-leverage interventions. Always profile before optimizing.