The attention mechanism computes a weighted sum over value vectors, where the weights are determined by query-key dot products passed through softmax. The mask controls which query-key pairs are allowed to interact. Setting a mask entry to negative infinity before softmax drives the corresponding attention weight to zero, effectively preventing information flow between those two positions. Every architectural decision about what the model can and cannot attend to is expressed through this mask.
This post covers five masking patterns in detail: causal (autoregressive), bidirectional (encoder-style), sliding window (Mistral), block sparse (BigBird, Longformer), and custom masks for production scenarios like variable-length batching and multi-document processing. Each section includes a complete implementation and analysis of computational cost. Section 7 covers how FlashAttention handles these patterns at the kernel level.
1. The Attention Equation and Where Masks Enter
The standard scaled dot-product attention for a single head:

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V$$

where $Q \in \mathbb{R}^{n_q \times d_k}$, $K \in \mathbb{R}^{n_k \times d_k}$, $V \in \mathbb{R}^{n_k \times d_v}$, and $M \in \mathbb{R}^{n_q \times n_k}$ is the mask matrix. Here $n_q$ is the number of query positions and $n_k$ is the number of key/value positions.
The mask is an additive mask applied to the raw attention scores before softmax. Two conventions exist:
Additive mask: $M_{ij} = 0$ where attention is allowed, $M_{ij} = -\infty$ where attention is blocked. This is the mathematically natural form because $e^{-\infty} = 0$.
Boolean mask: A boolean tensor where True means "block this position" (the convention of PyTorch's nn.MultiheadAttention masks) or "allow this position" (the convention of PyTorch's F.scaled_dot_product_attention). The boolean mask is converted to additive form internally, so always check which convention the API you are calling expects.
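As a concrete check of the two conventions, here is a minimal helper (the name `bool_to_additive` is illustrative, not a library function) that converts a boolean mask to additive form:

```python
import torch

def bool_to_additive(bool_mask, blocked_is_true=True, dtype=torch.float32):
    """Convert a boolean mask to additive form (0 = allow, -inf = block)."""
    blocked = bool_mask if blocked_is_true else ~bool_mask
    additive = torch.zeros(bool_mask.shape, dtype=dtype)
    additive[blocked] = float('-inf')
    return additive

# True = block convention
blocked = torch.tensor([[False, True], [False, False]])
additive = bool_to_additive(blocked)
# Softmax over the masked row puts all weight on the allowed position
weights = torch.softmax(additive[0], dim=-1)
assert weights.tolist() == [1.0, 0.0]
```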
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q: [batch, heads, seq_q, d_k]
K: [batch, heads, seq_k, d_k]
V: [batch, heads, seq_k, d_v]
mask: [seq_q, seq_k] or [batch, 1, seq_q, seq_k], additive
"""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
if mask is not None:
scores = scores + mask # -inf entries zero out after softmax
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, V)
The mask shape deserves attention. For a batch of $B$ sequences with $H$ heads, the full score tensor is $[B, H, n_q, n_k]$. The mask can be:
- $[n_q, n_k]$: same mask for all batches and heads (causal mask)
- $[B, 1, n_q, n_k]$: per-batch mask, shared across heads (padding mask)
- $[B, H, n_q, n_k]$: per-batch, per-head mask (rarely needed, expensive)
Broadcasting rules apply. In practice, the mask is almost always $[n_q, n_k]$ or $[B, 1, n_q, n_k]$.
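A quick sketch of the broadcasting behavior, assuming an additive causal mask:

```python
import torch

B, H, n = 2, 4, 6
scores = torch.randn(B, H, n, n)
causal = torch.triu(torch.full((n, n), float('-inf')), diagonal=1)  # [n, n]
# An [n, n] mask broadcasts against [B, H, n, n] scores without copying
masked = scores + causal
assert masked.shape == (B, H, n, n)
# Every above-diagonal entry is blocked in every batch and head
assert torch.isinf(masked[..., 0, 5]).all()
```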
Memory Cost of Materializing the Mask
For a sequence length of $n$ and FP16 scores, the full attention score matrix is $2n^2$ bytes per head — 128 MB per head at $n = 8192$. For 32 heads and batch size 4:

$$4 \times 32 \times 2 \times 8192^2 \ \text{bytes} = 16\ \text{GB}$$
This is why FlashAttention avoids materializing the full score matrix. The mask must be applied tile-by-tile, never stored in full. We return to this in Section 7.
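The arithmetic behind these figures, taking $n = 8192$ as the example:

```python
# Score-matrix memory: batch * heads * n^2 elements at 2 bytes each (FP16)
batch, heads, n, bytes_per_elem = 4, 32, 8192, 2
per_head_mb = n * n * bytes_per_elem / 2**20
total_gb = batch * heads * n * n * bytes_per_elem / 2**30
print(per_head_mb)  # 128.0 MB per head
print(total_gb)     # 16.0 GB for the whole batch
```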
2. Causal (Autoregressive) Mask
The causal mask is the foundation of all decoder-only models (GPT, Llama, Mistral, DeepSeek). It enforces a strict constraint: token $i$ can only attend to tokens $j \le i$. This prevents information leakage from future tokens during training and is structurally necessary for autoregressive generation.
Construction
The causal mask is an upper-triangular matrix of $-\infty$ values, with zeros on and below the diagonal:
def create_causal_mask(seq_len, device='cuda', dtype=torch.float16):
"""
Create a causal (autoregressive) attention mask.
Returns: [seq_len, seq_len] tensor with 0 and -inf
"""
mask = torch.full(
(seq_len, seq_len), float('-inf'), device=device, dtype=dtype
)
mask = torch.triu(mask, diagonal=1)
return mask
# Example for seq_len=5:
# tensor([[ 0., -inf, -inf, -inf, -inf],
# [ 0., 0., -inf, -inf, -inf],
# [ 0., 0., 0., -inf, -inf],
# [ 0., 0., 0., 0., -inf],
# [ 0., 0., 0., 0., 0.]])
Boolean Alternative
PyTorch’s torch.nn.functional.scaled_dot_product_attention accepts a boolean attn_mask where True means “take part in attention” (allow), so the causal mask is the lower triangle. Note that nn.MultiheadAttention masks use the opposite convention (True = block):
def create_causal_mask_bool(seq_len, device='cuda'):
    """Boolean causal mask for F.scaled_dot_product_attention. True = allowed."""
    return torch.tril(
        torch.ones(seq_len, seq_len, device=device, dtype=torch.bool)
    )
The boolean form uses 1 byte per element instead of 2 (FP16) or 4 (FP32). For $n = 8192$: boolean mask is 64 MB vs 128 MB for FP16. In practice, neither is materialized in FlashAttention.
Causal Mask During Inference Decode
During the decode phase (generating one token at a time), the query has length 1 and the KV cache has length $t$ (the number of tokens processed so far). The attention score is a $1 \times t$ row. The causal mask is trivially all-zeros: a single query token at position $t$ can attend to all $t$ positions in the KV cache. No masking is needed during autoregressive decode.
During prefill (processing the full prompt), the causal mask is applied in full. This is the only phase where the causal mask matters at inference time.
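A minimal illustration of mask-free decode using PyTorch’s built-in SDPA (the shapes here are illustrative, not tied to any particular model):

```python
import torch
import torch.nn.functional as F

heads, d_head, t = 8, 64, 37           # 37 tokens already in the KV cache
q = torch.randn(1, heads, 1, d_head)   # single new query token
k = torch.randn(1, heads, t, d_head)   # cached keys
v = torch.randn(1, heads, t, d_head)   # cached values
# The lone query sits at the newest position, so every cached position is
# legal to attend to: no mask, and is_causal stays False
out = F.scaled_dot_product_attention(q, k, v)
assert out.shape == (1, heads, 1, d_head)
```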
FlashAttention 2 and 3 have a dedicated causal=True flag (is_causal in PyTorch’s SDPA). When set, the kernel skips all tiles that fall entirely above the diagonal — it never loads those K/V tiles or computes those dot products. This is not just masking; it is a genuine compute reduction. For a causal mask, roughly half the tiles are skipped, reducing both FLOPs and memory traffic by approximately 2x compared to a full (bidirectional) attention pass.
FLOPs for Causal Attention
Full bidirectional attention computes $2n^2 d$ FLOPs for the $QK^\top$ product (each of the $n^2$ entries requires a dot product of length $d$, which is $2d$ FLOPs). Causal attention only computes the lower triangle:

$$\text{FLOPs}_{\text{causal}} = 2d \cdot \frac{n(n+1)}{2} = n(n+1)d$$

compared to full attention:

$$\text{FLOPs}_{\text{full}} = 2n^2 d$$

Causal attention is almost exactly 2x cheaper in FLOPs. For $n = 8192$ and $d = 128$: 8.59 GFLOPs vs 17.18 GFLOPs per head.
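These figures can be checked with a few lines (counting 2 FLOPs per multiply-add, $QK^\top$ matmul only):

```python
def full_qkT_flops(n, d=128):
    """QK^T for one head: n^2 dot products of length d, 2d FLOPs each."""
    return 2 * n * n * d

def causal_qkT_flops(n, d=128):
    """Only the n(n+1)/2 lower-triangle entries are computed."""
    return 2 * d * n * (n + 1) // 2

n = 8192
print(full_qkT_flops(n) / 1e9)    # ≈ 17.18 GFLOPs
print(causal_qkT_flops(n) / 1e9)  # ≈ 8.59 GFLOPs
```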
3. Bidirectional (Encoder) Mask
Bidirectional attention allows every token to attend to every other token. There is no mask (or equivalently, the mask is all zeros). This is the attention pattern used in BERT, RoBERTa, and the encoder of encoder-decoder models like T5.
When Bidirectional Attention Makes Sense
Bidirectional attention is appropriate when the model processes a complete input and does not generate output autoregressively. Classification, retrieval, embedding, and the encoder stage of machine translation all use bidirectional attention.
def bidirectional_attention(Q, K, V):
"""No mask needed — all positions attend to all positions."""
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
weights = F.softmax(scores, dim=-1)
return torch.matmul(weights, V)
The Cost of Bidirectional
The full $n \times n$ score matrix is computed. No tiles are skipped. FLOPs are exactly $2n^2 d$ per head for $QK^\top$ — double the causal case. For long sequences, this is the most expensive attention pattern.
Attention FLOPs by Masking Pattern (Per Head, d=128, QKᵀ matmul only)
| Sequence Length | Causal (GFLOPs) | Bidirectional (GFLOPs) | Sliding W=256 (GFLOPs) |
|---|---|---|---|
| 1,024 | 0.13 | 0.27 | 0.067 |
| 4,096 | 2.15 | 4.29 | 0.27 |
| 8,192 | 8.59 | 17.18 | 0.54 |
| 32,768 | 137.4 | 274.9 | 2.15 |
| 131,072 | 2,199 | 4,398 | 8.59 |
At 131K sequence length, the difference between bidirectional and sliding window is 512x in FLOPs. This is why long-context models universally use causal or sparse patterns.
Prefix-LM: A Hybrid
Some models (PaLM, UL2) use a prefix-LM pattern: bidirectional attention over a prefix, causal attention over the rest. The mask is:

$$M_{ij} = \begin{cases} 0 & j < p \quad \text{(prefix is visible to all queries)} \\ 0 & p \le j \le i \quad \text{(causal within the suffix)} \\ -\infty & \text{otherwise} \end{cases}$$

where $p$ is the prefix length.
def create_prefix_lm_mask(seq_len, prefix_len, device='cuda', dtype=torch.float16):
"""
Bidirectional over prefix, causal over suffix.
"""
mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
# Prefix: all positions can attend to prefix
mask[:, :prefix_len] = 0.0
# Causal region: lower triangle
causal_part = torch.triu(
torch.ones(seq_len, seq_len, device=device), diagonal=1
)
# Apply causal only to non-prefix query positions
for i in range(prefix_len, seq_len):
for j in range(prefix_len, seq_len):
if j <= i:
mask[i, j] = 0.0
return mask
In practice, the vectorized version avoids the loop:
def create_prefix_lm_mask_fast(seq_len, prefix_len, device='cuda', dtype=torch.float16):
mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
mask[:, :prefix_len] = 0.0
    # Causal additive mask for the suffix block: 0 on/below diagonal, -inf above
    suffix_causal = torch.triu(
        torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype),
        diagonal=1,
    )
    mask[prefix_len:, prefix_len:] = suffix_causal[prefix_len:, prefix_len:]
    return mask
4. Sliding Window Attention (Mistral)
Sliding window attention restricts each token to attend only to the $W$ most recent tokens. Token $i$ attends to tokens in the range $[\max(0, i - W + 1), i]$. This is the attention pattern used in Mistral 7B (with $W = 4096$) and Mixtral.
Motivation: Linear Memory in Sequence Length
Standard causal attention has an $O(n)$ KV cache that grows linearly with sequence length, and the total attention computation over all positions is $O(n^2)$. Sliding window changes this:
- Each position attends to at most $W$ keys.
- Total FLOPs across all positions: $O(nW)$ — linear in $n$ if $W$ is fixed.
- KV cache at any decode step: only the last $W$ tokens need to be retained. Cache size is $O(W)$, independent of total sequence length.
For Mistral with $W = 4096$: a 128K-token sequence uses the same KV cache memory as a 4K-token sequence. The savings are enormous.
Implementation
def create_sliding_window_mask(seq_len, window_size, device='cuda', dtype=torch.float16):
"""
Sliding window causal mask.
Token i attends to [max(0, i - window_size + 1), i].
"""
mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
for i in range(seq_len):
start = max(0, i - window_size + 1)
mask[i, start:i+1] = 0.0
return mask
# Vectorized version
def create_sliding_window_mask_fast(seq_len, window_size, device='cuda', dtype=torch.float16):
row_idx = torch.arange(seq_len, device=device).unsqueeze(1)
col_idx = torch.arange(seq_len, device=device).unsqueeze(0)
# Causal: col <= row
# Window: col >= row - window_size + 1
valid = (col_idx <= row_idx) & (col_idx >= row_idx - window_size + 1)
mask = torch.where(
valid,
torch.tensor(0.0, device=device, dtype=dtype),
torch.tensor(float('-inf'), device=device, dtype=dtype),
)
return mask
For seq_len = 8 and $W = 3$, the mask pattern looks like (0 = attend, X = blocked):
Position: 0 1 2 3 4 5 6 7
Token 0: [ 0 X X X X X X X ]
Token 1: [ 0 0 X X X X X X ]
Token 2: [ 0 0 0 X X X X X ]
Token 3: [ X 0 0 0 X X X X ]
Token 4: [ X X 0 0 0 X X X ]
Token 5: [ X X X 0 0 0 X X ]
Token 6: [ X X X X 0 0 0 X ]
Token 7: [ X X X X X 0 0 0 ]
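The diagram can be verified directly from the window predicate used in the vectorized constructor:

```python
import torch

seq_len, W = 8, 3
row = torch.arange(seq_len).unsqueeze(1)
col = torch.arange(seq_len).unsqueeze(0)
valid = (col <= row) & (col >= row - W + 1)
# Row 3 of the diagram: attends to positions 1, 2, 3 only
assert valid[3].tolist() == [False, True, True, True, False, False, False, False]
# No row attends to more than W positions
assert int(valid.sum(dim=1).max()) == W
```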
Information Flow Across Layers
A single sliding window layer with window $W$ allows information to flow at most $W$ positions. But stacking $L$ layers creates an effective receptive field of $L \times W$. Mistral with $L = 32$ layers and $W = 4096$ has a theoretical receptive field of $32 \times 4096 = 131{,}072$ tokens — covering the full 128K context length.
This works because token $i$ at layer $\ell$ aggregates information from tokens $[i - W + 1, i]$ at layer $\ell - 1$. Those tokens themselves aggregated from $[i - 2(W - 1), i]$ at layer $\ell - 2$. By layer $L$, token $i$ has (indirect) access to tokens back to position $i - L(W - 1)$.
The theoretical receptive field of $L \times W$ assumes perfect information propagation through every layer. In practice, the effective attention range is substantially shorter because information degrades as it passes through many layers. Empirically, Mistral’s effective context length is closer to 16K-32K tokens despite the 128K theoretical receptive field. Models that need strong long-range retrieval typically use full attention for a subset of layers.
Sliding Window KV Cache Management
During inference, the KV cache becomes a ring buffer of size $W$:
class SlidingWindowKVCache:
def __init__(self, window_size, num_heads, head_dim, dtype=torch.float16):
self.window_size = window_size
self.num_heads = num_heads
self.head_dim = head_dim
# Pre-allocate ring buffer
self.k_cache = torch.zeros(
(num_heads, window_size, head_dim), dtype=dtype, device='cuda'
)
self.v_cache = torch.zeros(
(num_heads, window_size, head_dim), dtype=dtype, device='cuda'
)
self.position = 0 # Current write position in ring buffer
self.length = 0 # Number of valid entries
def update(self, k_new, v_new):
"""
k_new, v_new: [num_heads, 1, head_dim] (single new token)
"""
idx = self.position % self.window_size
self.k_cache[:, idx, :] = k_new[:, 0, :]
self.v_cache[:, idx, :] = v_new[:, 0, :]
self.position += 1
self.length = min(self.length + 1, self.window_size)
def get_kv(self):
"""Return valid K, V entries in correct order."""
if self.length < self.window_size:
return self.k_cache[:, :self.length, :], self.v_cache[:, :self.length, :]
# Ring buffer: reorder so oldest entry is first
idx = self.position % self.window_size
order = list(range(idx, self.window_size)) + list(range(0, idx))
return self.k_cache[:, order, :], self.v_cache[:, order, :]
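The ring-buffer reordering in get_kv is the subtle part; a plain-Python sketch of the same logic makes it easy to verify:

```python
# Write 6 tokens into a window-4 ring buffer, then read back in order
window = 4
buf = [None] * window
position = 0
for token in range(6):
    buf[position % window] = token  # overwrite the oldest slot
    position += 1
# The next write slot holds the oldest surviving entry
idx = position % window
order = list(range(idx, window)) + list(range(0, idx))
assert [buf[i] for i in order] == [2, 3, 4, 5]  # last W tokens, oldest first
```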
Memory savings for Mistral 7B ($W = 4096$, 8 KV heads via GQA, $d_\text{head} = 128$, BF16):

Per layer KV cache at $W = 4096$: $2 \times 8 \times 4096 \times 128 \times 2\ \text{B} = 16\ \text{MB}$

All 32 layers: $32 \times 16\ \text{MB} = 512\ \text{MB}$

Compare with full attention at 128K context: $32 \times 2 \times 8 \times 131{,}072 \times 128 \times 2\ \text{B} = 16\ \text{GB}$ — a 32x reduction.
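The same arithmetic as a hedged sketch, assuming Mistral-7B-like shapes (8 KV heads via GQA, head_dim 128, 32 layers, 2 bytes per BF16 element):

```python
def kv_cache_mb(kv_heads, cache_len, head_dim, layers, bytes_per_elem=2):
    """KV cache size in MB; the leading factor 2 covers both K and V."""
    return 2 * kv_heads * cache_len * head_dim * layers * bytes_per_elem / 2**20

print(kv_cache_mb(8, 4096, 128, layers=32))    # 512.0 MB — W=4096 window
print(kv_cache_mb(8, 131072, 128, layers=32))  # 16384.0 MB — full 128K cache
```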
KV Cache Memory: Sliding Window vs Full Attention (Mistral 7B, BF16)

5. Block Sparse Attention (BigBird, Longformer)
Block sparse attention replaces the dense attention matrix with a structured sparsity pattern composed of local blocks, global tokens, and random connections. This achieves $O(n)$ complexity while maintaining strong long-range modeling.
BigBird Sparsity Pattern
BigBird combines three attention patterns:
- Local window: Each token attends to $w$ neighboring tokens (similar to sliding window).
- Global tokens: A small set of $g$ tokens attend to and are attended by all tokens. These are typically the first few tokens (CLS, BOS) or learned sentinel tokens.
- Random connections: Each token randomly attends to $r$ additional tokens. This provides shortcut paths in the attention graph, reducing the diameter from $O(n)$ to $O(\log n)$.
def create_bigbird_mask(
    seq_len, window_size=64, num_global=16, num_random=8,
    causal=False, device='cuda', dtype=torch.float16
):
    """
    BigBird block sparse mask combining local, global, and random patterns.
    """
    mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
    # 1. Local window (band around diagonal)
    for i in range(seq_len):
        start = max(0, i - window_size // 2)
        end = min(seq_len, i + window_size // 2 + 1)
        mask[i, start:end] = 0.0
    # 2. Global tokens (first num_global tokens)
    mask[:num_global, :] = 0.0  # Global tokens attend to everything
    mask[:, :num_global] = 0.0  # Everything attends to global tokens
    # 3. Random connections
    for i in range(seq_len):
        random_indices = torch.randint(0, seq_len, (num_random,), device=device)
        mask[i, random_indices] = 0.0
    # 4. Optional causal constraint (decoder models only; BigBird itself is an encoder)
    if causal:
        upper = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        mask[upper] = float('-inf')
    return mask
Longformer Pattern
Longformer is similar but drops the random connections and instead uses a combination of sliding window attention (for most layers) and global attention on specific tokens:
def create_longformer_mask(
seq_len, window_size=512, global_token_indices=None,
device='cuda', dtype=torch.float16
):
"""
Longformer: sliding window + global attention on selected tokens.
global_token_indices: list of token positions that get global attention.
"""
if global_token_indices is None:
global_token_indices = [0] # Default: only CLS token is global
mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
# Sliding window for all tokens
for i in range(seq_len):
start = max(0, i - window_size // 2)
end = min(seq_len, i + window_size // 2 + 1)
mask[i, start:end] = 0.0
# Global tokens
for g in global_token_indices:
mask[g, :] = 0.0 # Global token attends to all
mask[:, g] = 0.0 # All attend to global token
return mask
Block-Level Implementation for Efficiency
Practical implementations operate on blocks, not individual tokens. Divide the sequence into blocks of size (typically 64 or 128). Each query block attends to a subset of key blocks:
def block_sparse_attention(Q, K, V, block_size=64, sparsity_map=None):
"""
Block-sparse attention with explicit sparsity map.
Q, K, V: [batch, heads, seq_len, d]
sparsity_map: dict mapping query_block_idx to list of key_block_idx
"""
batch, heads, seq_len, d = Q.shape
num_blocks = seq_len // block_size
output = torch.zeros_like(Q)
# Reshape into blocks: [batch, heads, num_blocks, block_size, d]
Q_blocks = Q.view(batch, heads, num_blocks, block_size, d)
K_blocks = K.view(batch, heads, num_blocks, block_size, d)
V_blocks = V.view(batch, heads, num_blocks, block_size, d)
O_blocks = output.view(batch, heads, num_blocks, block_size, d)
for q_idx in range(num_blocks):
if sparsity_map is not None:
k_indices = sparsity_map[q_idx]
else:
# Default: local window of 3 blocks + first block (global)
k_indices = list(set([0] + list(range(
max(0, q_idx - 1), min(num_blocks, q_idx + 2)
))))
# Gather relevant K, V blocks
K_sel = torch.cat([K_blocks[:, :, ki:ki+1] for ki in k_indices], dim=2)
V_sel = torch.cat([V_blocks[:, :, ki:ki+1] for ki in k_indices], dim=2)
# K_sel: [batch, heads, num_selected_blocks, block_size, d]
# Reshape for matmul
K_flat = K_sel.reshape(batch, heads, -1, d) # [B, H, num_sel*bs, d]
V_flat = V_sel.reshape(batch, heads, -1, d)
q = Q_blocks[:, :, q_idx] # [B, H, block_size, d]
scores = torch.matmul(q, K_flat.transpose(-2, -1)) / (d ** 0.5)
weights = F.softmax(scores, dim=-1)
O_blocks[:, :, q_idx] = torch.matmul(weights, V_flat)
return output
Sparsity and FLOP Reduction
For a sequence of $n$ tokens with block size $B_s$ and $k$ key blocks per query block, FLOPs scale as $O(n \cdot k \cdot B_s \cdot d)$ instead of $O(n^2 d)$.

For BigBird with $w = 3$ local window blocks, $g = 1$ global block, and $r = 1$ random block: $k = 5$. Full attention has $n / B_s$ key blocks per query block. The ratio:

$$\frac{k}{n / B_s} = \frac{k \cdot B_s}{n}$$

For $n = 16{,}384$, $B_s = 64$, and $k = 5$: speedup = $16{,}384 / (5 \times 64) \approx 51\text{x}$ fewer FLOPs.
Attention Patterns: Sparsity and FLOPs Comparison (n=16384, d=128, QKᵀ and PV matmuls)
| Pattern | Blocks Attended Per Query | Total FLOPs (GFLOPs) | vs Full Attention |
|---|---|---|---|
| Full bidirectional | 256 | 137.4 | 1.0x |
| Causal | 128 (avg) | 68.7 | 0.5x |
| Sliding window W=512 | 8 | 4.29 | 0.031x |
| BigBird (w=3, g=1, r=1) | 5 | 2.68 | 0.020x |
| Longformer (w=8, g=1) | 9 | 4.83 | 0.035x |
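The table’s FLOP column can be reproduced by counting both matmuls ($QK^\top$ and the weights·$V$ product), i.e. $4 \cdot n \cdot k_\text{attended} \cdot d$ FLOPs per head:

```python
n, d, block = 16384, 128, 64

def gflops(keys_attended):
    # 2 FLOPs/MAC * 2 matmuls (QK^T and PV) = factor of 4
    return 4 * n * keys_attended * d / 1e9

print(round(gflops(n), 1))           # full bidirectional: 137.4
print(round(gflops(512), 2))         # sliding window W=512: 4.29
print(round(gflops(5 * block), 2))   # BigBird, 5 blocks of 64: 2.68
print(round(gflops(9 * block), 2))   # Longformer, 9 blocks: 4.83
```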
6. Custom Masks for Production Scenarios
Beyond the standard patterns, production inference systems require custom masks for two critical scenarios: variable-length batching (padding tokens) and multi-document processing (preventing cross-document attention).
6.1 Padding Masks for Variable-Length Batching
Real requests have different lengths. When batching multiple sequences, shorter sequences are padded to the maximum length. Padding tokens must not receive or contribute attention.
def create_padding_mask(seq_lengths, max_len, device='cuda', dtype=torch.float16):
"""
Create a padding mask for variable-length batched sequences.
seq_lengths: [batch_size] tensor of actual sequence lengths
max_len: maximum sequence length (padding target)
Returns: [batch_size, 1, max_len, max_len] additive mask
"""
    batch_size = seq_lengths.size(0)
    positions = torch.arange(max_len, device=device).unsqueeze(0)  # [1, max_len]
    valid = positions < seq_lengths.unsqueeze(1)                   # [batch, max_len]
    # Mask the KEY dimension only: no query may attend to a padding key.
    # Queries at padding positions still see the valid keys, so their softmax
    # rows stay finite (no NaN); their outputs are discarded downstream.
    key_mask = valid.unsqueeze(1).unsqueeze(2).expand(-1, 1, max_len, -1)  # [batch, 1, max_len, max_len]
    mask = torch.where(
        key_mask,
        torch.tensor(0.0, device=device, dtype=dtype),
        torch.tensor(float('-inf'), device=device, dtype=dtype),
    )
    return mask
# Example: batch of 3 sequences with lengths [3, 5, 2], max_len=5
# Sequence 0: attends to positions [0,1,2], masks [3,4]
# Sequence 1: attends to positions [0,1,2,3,4], masks nothing
# Sequence 2: attends to positions [0,1], masks [2,3,4]
Combining Padding with Causal Mask
In practice, you combine the causal mask with the padding mask:
def create_causal_padding_mask(seq_lengths, max_len, device='cuda', dtype=torch.float16):
"""Combined causal + padding mask."""
causal = create_causal_mask(max_len, device=device, dtype=dtype) # [max_len, max_len]
padding = create_padding_mask(seq_lengths, max_len, device=device, dtype=dtype)
    # Broadcasting: causal is [n, n], padding is [batch, 1, n, n],
    # so the sum is [batch, 1, n, n].
    # With additive masks, addition takes the OR of the blocked sets:
    # 0 + 0 = 0 (both allow); anything + (-inf) = -inf (either blocks).
    combined = causal + padding
    return combined
For production systems, avoid constructing the combined mask tensor at all. Instead, pass the sequence lengths to the kernel and let it compute the mask predicate on-the-fly. FlashAttention’s variable-length interface (flash_attn_varlen_func) takes cu_seqlens (cumulative sequence lengths) and handles both padding and causal masking internally without materializing any mask tensor. This saves both memory and the time to construct and transfer the mask.
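Building cu_seqlens from per-document lengths is a one-line prefix sum (an int32 tensor starting at zero):

```python
import torch

seq_lengths = torch.tensor([100, 150, 150])
cu_seqlens = torch.zeros(len(seq_lengths) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lengths, dim=0)
max_seqlen = int(seq_lengths.max())
assert cu_seqlens.tolist() == [0, 100, 250, 400]
assert max_seqlen == 150
```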
6.2 Multi-Document Masks: Preventing Cross-Document Attention
When packing multiple documents into a single sequence for training efficiency (document packing / sequence packing), you must prevent tokens in one document from attending to tokens in another. Without this mask, the model can attend across document boundaries, which leaks information and degrades training quality.
def create_document_mask(doc_boundaries, seq_len, device='cuda', dtype=torch.float16):
"""
Prevent cross-document attention in packed sequences.
doc_boundaries: list of (start, end) tuples for each document.
Example: [(0, 100), (100, 250), (250, 400)] for 3 packed documents.
"""
mask = torch.full((seq_len, seq_len), float('-inf'), device=device, dtype=dtype)
for start, end in doc_boundaries:
mask[start:end, start:end] = 0.0
return mask
def create_causal_document_mask(doc_boundaries, seq_len, device='cuda', dtype=torch.float16):
"""Causal + document boundary mask."""
doc_mask = create_document_mask(doc_boundaries, seq_len, device=device, dtype=dtype)
causal_mask = create_causal_mask(seq_len, device=device, dtype=dtype)
    # Combine: block if either mask blocks. Both masks are additive with
    # -inf entries, and -inf + (-inf) = -inf in IEEE arithmetic, so plain
    # addition is correct — no clamping needed.
    combined = doc_mask + causal_mask
    return combined
Visualization for 3 documents packed into a single sequence:
Doc A: positions 0-3, Doc B: positions 4-7, Doc C: positions 8-11
Causal + Document mask (0 = attend, X = blocked):
0 1 2 3 4 5 6 7 8 9 10 11
Pos 0: [ 0 X X X X X X X X X X X ]
Pos 1: [ 0 0 X X X X X X X X X X ]
Pos 2: [ 0 0 0 X X X X X X X X X ]
Pos 3: [ 0 0 0 0 X X X X X X X X ]
Pos 4: [ X X X X 0 X X X X X X X ] <-- Doc B starts, no cross-doc
Pos 5: [ X X X X 0 0 X X X X X X ]
Pos 6: [ X X X X 0 0 0 X X X X X ]
Pos 7: [ X X X X 0 0 0 0 X X X X ]
Pos 8: [ X X X X X X X X 0 X X X ] <-- Doc C starts
Pos 9: [ X X X X X X X X 0 0 X X ]
Pos10: [ X X X X X X X X 0 0 0 X ]
Pos11: [ X X X X X X X X 0 0 0 0 ]
Each document forms an independent causal block along the diagonal. This is equivalent to processing each document separately but is more compute-efficient because it fills the GPU with a single large batch.
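A compact standalone check that the block-diagonal construction really confines attention within each document:

```python
import torch

def causal_doc_mask(boundaries, n):
    mask = torch.full((n, n), float('-inf'))
    for start, end in boundaries:
        size = end - start
        # Independent causal block for each document
        mask[start:end, start:end] = torch.triu(
            torch.full((size, size), float('-inf')), diagonal=1
        )
    return mask

m = causal_doc_mask([(0, 4), (4, 8), (8, 12)], 12)
# Position 5 (doc B) attends to 4 and 5 only — never into doc A or doc C
assert (m[5, 4:6] == 0).all()
assert torch.isinf(m[5, :4]).all() and torch.isinf(m[5, 6:]).all()
```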
6.3 FlashAttention Variable-Length Interface
FlashAttention provides flash_attn_varlen_func that handles document packing natively:
from flash_attn import flash_attn_varlen_func
def packed_attention(q, k, v, cu_seqlens, max_seqlen):
"""
Attention over packed sequences using FlashAttention.
q, k, v: [total_tokens, num_heads, head_dim] (packed, no padding)
cu_seqlens: [num_docs + 1] cumulative sequence lengths.
Example: [0, 100, 250, 400] for 3 docs of length 100, 150, 150.
max_seqlen: maximum document length in the batch.
"""
output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
causal=True,
)
return output
This interface is strictly superior to the mask-based approach:
- No padding tokens waste compute.
- No mask tensor is materialized.
- The kernel handles document boundaries internally.
- Memory usage is $O(\sum_i L_i)$ (total real tokens), not $O(n_\text{docs} \times \max_i L_i)$ (padded).
Packed vs Padded Attention: Memory and Compute (4 docs, lengths 512-4096)
| Method | Total Tokens Processed | Peak Memory (MB) | Wall Time (ms) |
|---|---|---|---|
| Padded batch (max_len=4096) | 16,384 | 2,048 | 3.2 |
| Packed (varlen) | 8,704 | 1,088 | 1.7 |
| Savings | 47% fewer tokens | 47% less memory | 47% faster |
7. Performance: Dense vs Sparse Masks and Kernel Support
The choice of masking pattern has direct consequences for compute cost, memory usage, and which kernels can efficiently execute the operation.
7.1 Dense Masks in Standard Attention
With naive (non-Flash) attention, the mask is applied element-wise to the materialized score matrix. The cost of the mask itself is negligible compared to the matmul. The bottleneck is the memory for the score matrix.
Dense masks (causal, bidirectional) add zero overhead to the score computation. The mask is just an element-wise add or a conditional write. Sparse masks stored densely (an $n \times n$ tensor that is mostly $-\infty$) are equally cheap to apply but do not save any FLOPs because the full $QK^\top$ product is still computed.
7.2 Sparse Masks with Specialized Kernels
To actually save FLOPs with sparse masks, you need kernels that skip the masked-out blocks entirely. This requires:
- A block-level sparsity pattern (individual element sparsity is impractical on GPUs).
- A kernel that iterates over only the non-zero blocks.
- Block sizes aligned with tensor core tile sizes (16, 32, 64, or 128).
Triton makes writing block-sparse attention kernels accessible:
import triton
import triton.language as tl
@triton.jit
def block_sparse_attention_kernel(
Q, K, V, Out,
block_table, # [num_query_blocks, max_kv_blocks_per_query]
num_kv_blocks, # [num_query_blocks] actual number of kv blocks per query block
stride_qb, stride_qh, stride_qm, stride_qd,
stride_kb, stride_kh, stride_kn, stride_kd,
stride_vb, stride_vh, stride_vn, stride_vd,
stride_ob, stride_oh, stride_om, stride_od,
    max_kv_blocks: tl.constexpr,  # number of columns in block_table
    BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
scale: tl.constexpr,
):
"""
Block-sparse attention: each query block attends to a subset of KV blocks
specified by block_table.
"""
pid_m = tl.program_id(0) # Query block index
pid_bh = tl.program_id(1) # Batch * head index
# Load query block
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, BLOCK_D)
q = tl.load(Q + pid_bh * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qd)
# Online softmax accumulators
m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)
# Iterate over KV blocks for this query block
n_kv = tl.load(num_kv_blocks + pid_m)
for kv_idx in range(0, n_kv):
kv_block_id = tl.load(block_table + pid_m * max_kv_blocks + kv_idx)
offs_n = kv_block_id * BLOCK_N + tl.arange(0, BLOCK_N)
k = tl.load(K + pid_bh * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kd)
v = tl.load(V + pid_bh * stride_vh + offs_n[:, None] * stride_vn + offs_d[None, :] * stride_vd)
# Compute attention scores for this block
s = tl.dot(q, tl.trans(k)) * scale
# Online softmax update
m_ij = tl.max(s, axis=1)
m_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_new)
beta = tl.exp(m_ij - m_new)
l_i = alpha * l_i + tl.sum(beta[:, None] * tl.exp(s - m_ij[:, None]), axis=1)
acc = alpha[:, None] * acc + tl.dot(tl.exp(s - m_ij[:, None]) * beta[:, None], v)
m_i = m_new
# Final normalization
acc = acc / l_i[:, None]
tl.store(Out + pid_bh * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_od, acc)
7.3 FlashAttention’s Native Mask Support
FlashAttention 2 supports three patterns natively (without materializing a mask):
| Pattern | Flag | Compute Savings |
|---|---|---|
| Full (bidirectional) | default | None |
| Causal | causal=True | ~2x (skips upper triangle tiles) |
| Variable-length | flash_attn_varlen_func | Proportional to padding saved |
FlashAttention 2 added sliding window support:
from flash_attn import flash_attn_func
output = flash_attn_func(
q, k, v,
causal=True,
window_size=(4096, 0), # (left_window, right_window)
# right_window=0 means causal (no future tokens)
# left_window=4096 means attend to 4096 tokens back
)
The kernel skips tiles that fall entirely outside the sliding window, providing genuine compute savings proportional to the sparsity.
FlashAttention 3 (Hopper GPUs) extends this with warp-specialized pipelining that overlaps the tile skipping decision with the compute of active tiles, further reducing the overhead of sparse patterns.
For arbitrary custom masks (not causal, not sliding window), FlashAttention currently requires passing a dense mask tensor and cannot skip tiles. The mask is loaded and applied per-tile but all tiles are still computed. If you need true compute savings from an arbitrary sparse pattern, you must write a custom Triton kernel with an explicit block schedule. This is a significant implementation investment but can yield 10-50x speedups for highly sparse patterns on long sequences.
7.4 Choosing the Right Pattern
Attention Kernel Throughput on H100 (seq_len=8192, d=128, BF16)

Key observations:
- FlashAttention with causal or sliding window achieves near-peak throughput because the tile-skipping logic is integrated into the kernel pipeline. The per-tile TFLOPS is similar to full attention; you just compute fewer tiles.
- Block sparse Triton kernels show lower per-tile TFLOPS because of the irregular memory access pattern (loading K/V blocks from non-contiguous memory) and the overhead of the block-table indirection. However, total wall time can be much lower because far fewer tiles are computed.
- Custom masks passed densely to FlashAttention achieve full TFLOPS but no tile skipping. The mask is loaded per-tile (adding bandwidth overhead) but all tiles are still computed.
8. Practical Implementation: Putting It All Together
Here is a complete multi-pattern attention module that selects the optimal kernel based on the requested pattern:
import torch
import torch.nn as nn
from enum import Enum
class AttentionPattern(Enum):
FULL = "full"
CAUSAL = "causal"
SLIDING_WINDOW = "sliding_window"
CAUSAL_SLIDING = "causal_sliding"
PREFIX_LM = "prefix_lm"
class MultiPatternAttention(nn.Module):
def __init__(self, d_model, n_heads, pattern, window_size=None, prefix_len=None):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.pattern = pattern
self.window_size = window_size
self.prefix_len = prefix_len
self.scale = self.d_head ** -0.5
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
def forward(self, x, cu_seqlens=None):
B, N, _ = x.shape
q = self.W_q(x).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
k = self.W_k(x).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
v = self.W_v(x).view(B, N, self.n_heads, self.d_head).transpose(1, 2)
if self.pattern == AttentionPattern.FULL:
return self._full_attention(q, k, v)
elif self.pattern == AttentionPattern.CAUSAL:
return self._causal_attention(q, k, v, N)
elif self.pattern == AttentionPattern.SLIDING_WINDOW:
return self._sliding_attention(q, k, v, N)
elif self.pattern == AttentionPattern.CAUSAL_SLIDING:
return self._causal_sliding_attention(q, k, v, N)
elif self.pattern == AttentionPattern.PREFIX_LM:
return self._prefix_lm_attention(q, k, v, N)
def _full_attention(self, q, k, v):
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)
B, H, N, D = out.shape
return self.W_o(out.transpose(1, 2).reshape(B, N, self.d_model))
def _causal_attention(self, q, k, v, seq_len):
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
mask = torch.triu(
torch.full((seq_len, seq_len), float('-inf'), device=q.device, dtype=q.dtype),
diagonal=1
)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)
B, H, N, D = out.shape
return self.W_o(out.transpose(1, 2).reshape(B, N, self.d_model))
def _sliding_attention(self, q, k, v, seq_len):
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
mask = create_sliding_window_mask_fast(seq_len, self.window_size, q.device, q.dtype)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)
B, H, N, D = out.shape
return self.W_o(out.transpose(1, 2).reshape(B, N, self.d_model))
def _causal_sliding_attention(self, q, k, v, seq_len):
"""Sliding window with causal constraint (Mistral-style)."""
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
row_idx = torch.arange(seq_len, device=q.device).unsqueeze(1)
col_idx = torch.arange(seq_len, device=q.device).unsqueeze(0)
valid = (col_idx <= row_idx) & (col_idx >= row_idx - self.window_size + 1)
mask = torch.where(
valid,
torch.tensor(0.0, device=q.device, dtype=q.dtype),
torch.tensor(float('-inf'), device=q.device, dtype=q.dtype),
)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)
B, H, N, D = out.shape
return self.W_o(out.transpose(1, 2).reshape(B, N, self.d_model))
def _prefix_lm_attention(self, q, k, v, seq_len):
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
mask = create_prefix_lm_mask_fast(seq_len, self.prefix_len, q.device, q.dtype)
scores = scores + mask
weights = torch.softmax(scores, dim=-1)
out = torch.matmul(weights, v)
B, H, N, D = out.shape
return self.W_o(out.transpose(1, 2).reshape(B, N, self.d_model))
9. Summary and Decision Framework
Attention Masking Pattern Decision Matrix
| Pattern | Use Case | Complexity | KV Cache | FlashAttention Support |
|---|---|---|---|---|
| Causal | Autoregressive LLMs | O(n^2/2) | O(n) | Native (causal=True) |
| Bidirectional | Encoders, embeddings | O(n^2) | N/A | Native (default) |
| Sliding window | Long-context LLMs | O(nW) | O(W) | Native (window_size) |
| Block sparse | Very long docs | O(nk*B_s) | Varies | Custom kernel needed |
| Prefix-LM | Encoder-decoder hybrid | O(p*n + (n-p)^2/2) | O(n) | Custom mask |
| Document packing | Training efficiency | O(sum(d_i^2)) | N/A | Native (varlen) |
The masking pattern is not a minor configuration detail. It determines the asymptotic complexity of attention, the KV cache memory requirements, and which hardware-optimized kernels can be used. For sequences up to 8K tokens, causal attention with FlashAttention is sufficient. For 8K-128K tokens, sliding window attention (Mistral-style) provides constant KV cache memory with minimal quality loss. Beyond 128K tokens, block sparse patterns become necessary, but require custom kernels that are significantly more complex to implement and maintain.
Verify the FLOP reduction claim for causal vs bidirectional attention. For a sequence of $n$ tokens: bidirectional computes $2n^2 d$ FLOPs per head for $QK^\top$ alone. Causal computes the lower triangle: $2d \cdot \frac{n(n+1)}{2} = n(n+1)d$ FLOPs. The ratio is $\frac{n+1}{2n} \approx 0.5$, confirming the roughly 2x reduction. The slight deviation from exactly 0.5 comes from the diagonal itself: the causal mask includes the $n$ diagonal entries that a strict lower triangle would exclude. This detail is often glossed over in approximations.
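The arithmetic in two lines, for $n = 8192$:

```python
n = 8192
ratio = (n * (n + 1) / 2) / n**2   # lower-triangle (incl. diagonal) / full matrix
assert ratio == (n + 1) / (2 * n)  # = 0.50006103515625, just above 0.5
```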