The attention mechanism’s KV cache grows linearly with sequence length and batch size, dominating memory in long-context inference. Different attention variants trade model quality for memory efficiency. Let’s analyze them quantitatively.
Multi-Head Attention (MHA)
The standard attention formulation:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Standard MHA: Each head has independent Q, K, V projections.
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # One fused projection per tensor; its output is split into heads in forward()
        self.W_q = nn.Linear(d_model, d_model)  # [d_model, num_heads * head_dim]
        self.W_k = nn.Linear(d_model, d_model)  # [d_model, num_heads * head_dim]
        self.W_v = nn.Linear(d_model, d_model)  # [d_model, num_heads * head_dim]
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, self.num_heads, self.head_dim)
        V = self.W_v(x).view(B, L, self.num_heads, self.head_dim)
        # KV cache size: 2 * batch * seq_len * num_heads * head_dim * dtype_bytes
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=1)
            V = torch.cat([kv_cache[1], V], dim=1)
        # Scores: [B, num_heads, L, S], where S counts cached plus new tokens
        # (causal masking is omitted to keep the example short)
        scores = torch.einsum('blhd,bshd->bhls', Q, K) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V)
        return self.W_o(out.reshape(B, L, D)), (K, V)
Memory per token per layer:
KV cache = 2 × num_heads × head_dim × dtype_bytes
= 2 × 32 × 128 × 2 = 16 KB (for 32 heads, FP16)
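The arithmetic above, and the linear growth with sequence length and batch size, can be checked with a few lines of Python. This is a minimal sketch: the helper name `kv_cache_bytes` and the 32-layer, batch-8, 4,096-token deployment are illustrative assumptions, not values from a specific model.

```python
def kv_cache_bytes(num_kv_heads: int, head_dim: int, dtype_bytes: int = 2,
                   seq_len: int = 1, batch: int = 1, num_layers: int = 1) -> int:
    """Bytes of KV cache: one K and one V vector per KV head, token, sequence, and layer."""
    return 2 * num_kv_heads * head_dim * dtype_bytes * seq_len * batch * num_layers


# Per token, per layer: 32 heads of dimension 128 in FP16
per_token = kv_cache_bytes(num_kv_heads=32, head_dim=128)
print(per_token)  # 16384 bytes = 16 KB

# Hypothetical deployment: 32 layers, batch of 8, 4,096-token context
total = kv_cache_bytes(32, 128, seq_len=4096, batch=8, num_layers=32)
print(total / 2**30)  # 16.0 (GiB)
```

Doubling either `seq_len` or `batch` doubles the result, which is why the cache dominates memory at long context.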
Multi-Query Attention (MQA)
MQA uses a single K, V head shared across all query heads:
class MultiQueryAttention(nn.Module):
    """
    MQA: Single K, V head shared by all Q heads.
    Proposed by Shazeer (2019) for faster inference.
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        # Full projection for Q
        self.W_q = nn.Linear(d_model, d_model)
        # Single head projection for K, V
        self.W_k = nn.Linear(d_model, self.head_dim)  # Only 1 head!
        self.W_v = nn.Linear(d_model, self.head_dim)  # Only 1 head!
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, 1, self.head_dim)  # [B, L, 1, head_dim]
        V = self.W_v(x).view(B, L, 1, self.head_dim)
        # Append to the cache BEFORE broadcasting: the cache holds a single
        # KV head, so it is num_heads times smaller than MHA (32x for 32 heads)
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=1)
            V = torch.cat([kv_cache[1], V], dim=1)
        new_cache = (K, V)  # [B, S, 1, head_dim] each
        # Broadcast K, V to all heads (expand creates a view, not a copy)
        K_all = K.expand(-1, -1, self.num_heads, -1)
        V_all = V.expand(-1, -1, self.num_heads, -1)
        # Rest is identical to MHA
        scores = torch.einsum('blhd,bshd->bhls', Q, K_all) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V_all)
        return self.W_o(out.reshape(B, L, D)), new_cache
Memory per token per layer:
KV cache = 2 × 1 × head_dim × dtype_bytes
= 2 × 1 × 128 × 2 = 512 bytes (32x reduction!)
MQA typically costs some model quality; in the benchmark table later in this article the drop is roughly 2 to 4 points relative to MHA. The single KV head becomes an information bottleneck, especially for tasks requiring fine-grained token interactions.
Grouped-Query Attention (GQA)
GQA interpolates between MHA and MQA:
class GroupedQueryAttention(nn.Module):
    """
    GQA: Groups of Q heads share K, V heads.
    Used by Llama 2 70B, Llama 3, Mistral, etc.
    """
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_heads // num_kv_heads  # Q heads per KV head
        self.head_dim = d_model // num_heads
        self.W_q = nn.Linear(d_model, num_heads * self.head_dim)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, self.num_kv_heads, self.head_dim)
        V = self.W_v(x).view(B, L, self.num_kv_heads, self.head_dim)
        # Append to the cache BEFORE expanding, so the cache
        # stores only num_kv_heads heads
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=1)
            V = torch.cat([kv_cache[1], V], dim=1)
        new_cache = (K, V)  # [B, S, num_kv_heads, head_dim] each
        # Expand K, V to match Q heads
        # Each KV head serves num_groups Q heads
        K_all = K.repeat_interleave(self.num_groups, dim=2)  # [B, S, num_heads, head_dim]
        V_all = V.repeat_interleave(self.num_groups, dim=2)
        scores = torch.einsum('blhd,bshd->bhls', Q, K_all) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V_all)
        return self.W_o(out.reshape(B, L, D)), new_cache
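The head mapping that `repeat_interleave` implements can be sketched in plain Python. The helper `kv_head_for` is a hypothetical name used only for illustration, and it assumes that consecutive Q heads form a group.

```python
def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head that a given Q head reads from."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    num_groups = num_heads // num_kv_heads  # Q heads per KV head
    return q_head // num_groups


# 8 Q heads sharing 2 KV heads: Q heads 0-3 read KV head 0, Q heads 4-7 read KV head 1
mapping = [kv_head_for(q, num_heads=8, num_kv_heads=2) for q in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]

# The two endpoints of the interpolation
mha = [kv_head_for(q, 4, 4) for q in range(4)]  # [0, 1, 2, 3]: every Q head has its own KV head
mqa = [kv_head_for(q, 4, 1) for q in range(4)]  # [0, 0, 0, 0]: all Q heads share one KV head
```

Setting `num_kv_heads = num_heads` recovers MHA and `num_kv_heads = 1` recovers MQA, which is the sense in which GQA interpolates between them.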
Llama 2 70B configuration:
- num_heads = 64, num_kv_heads = 8
- Compression ratio: 64/8 = 8x
- Memory per token: 2 × 8 × 128 × 2 = 4 KB
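Scaling the per-layer figure to the whole model gives a sense of the absolute sizes involved. The sketch below assumes Llama 2 70B's 80 transformer layers (a value not stated above) and an FP16 cache; the constant names are illustrative.

```python
NUM_LAYERS = 80      # Llama 2 70B layer count (assumption added for this example)
NUM_HEADS = 64
NUM_KV_HEADS = 8
HEAD_DIM = 128
DTYPE_BYTES = 2      # FP16

# K and V for every KV head, summed over all layers
gqa_per_token = 2 * NUM_KV_HEADS * HEAD_DIM * DTYPE_BYTES * NUM_LAYERS
mha_per_token = 2 * NUM_HEADS * HEAD_DIM * DTYPE_BYTES * NUM_LAYERS

print(gqa_per_token // 1024)  # 320 (KB per token across all layers)
print(mha_per_token // 1024)  # 2560 (KB per token if every Q head kept its own K, V)

context = 4096
print(gqa_per_token * context / 2**30)  # 1.25 (GiB for one 4,096-token sequence)
```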
Memory Comparison
KV Cache Size Comparison (Per Layer, Per Token; head_dim = 128, FP16)
| Variant | Config | Bytes/Token | Reduction |
|---|---|---|---|
| MHA | 32 heads | 16,384 | 1x (baseline) |
| GQA | 32 Q / 8 KV | 4,096 | 4x |
| GQA | 32 Q / 4 KV | 2,048 | 8x |
| MQA | 32 Q / 1 KV | 512 | 32x |
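To see what these per-token sizes mean for context length, the sketch below divides a fixed memory budget by the per-token cost. The 32-layer depth and the 40 GiB left for the cache are hypothetical numbers chosen for illustration, and `max_context_tokens` is not a library function.

```python
def max_context_tokens(budget_bytes: int, bytes_per_token_per_layer: int,
                       num_layers: int, batch: int = 1) -> int:
    """Largest sequence length whose KV cache fits in the budget."""
    per_token = bytes_per_token_per_layer * num_layers * batch
    return budget_bytes // per_token


BUDGET = 40 * 2**30   # hypothetical: 40 GiB left for the cache after weights
LAYERS = 32           # hypothetical model depth

# Bytes per token per layer, taken from the table above
variants = {
    "MHA (32 heads)": 16_384,
    "GQA (32 Q / 8 KV)": 4_096,
    "GQA (32 Q / 4 KV)": 2_048,
    "MQA (32 Q / 1 KV)": 512,
}
for name, bytes_per_token in variants.items():
    tokens = max_context_tokens(BUDGET, bytes_per_token, LAYERS)
    print(f"{name}: {tokens // 1024}K tokens")
```

Under a fixed budget the reachable context length scales by the same factor as the reduction column.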
[Figure: Maximum Context Length at Fixed Memory (80GB), in K tokens]
Multi-head Latent Attention (MLA)
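DeepSeek-V2’s MLA compresses KV differently: instead of sharing KV heads, it caches a low-rank latent vector from which K and V are reconstructed: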
class MultiHeadLatentAttention(nn.Module):
    """
    MLA: Compresses KV into low-rank latent space.
    Premise: K and V can be reconstructed well from a low-dimensional latent.
    Simplified sketch: the RoPE rotation itself is not applied, and each head
    gets its own RoPE key (DeepSeek-V2 shares one RoPE key across heads).
    """
    def __init__(self, d_model: int, num_heads: int,
                 kv_latent_dim: int, rope_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_latent_dim = kv_latent_dim  # e.g., 512
        self.rope_dim = rope_dim            # Rotary embedding dimension
        # Q projection (standard)
        self.W_q = nn.Linear(d_model, d_model)
        # Compress KV to latent space
        self.W_kv_compress = nn.Linear(d_model, kv_latent_dim)  # Down-project
        # Expand from latent to K, V
        self.W_k_expand = nn.Linear(kv_latent_dim, d_model)
        self.W_v_expand = nn.Linear(kv_latent_dim, d_model)
        # Separate RoPE keys (not compressed)
        self.W_k_rope = nn.Linear(d_model, rope_dim * num_heads)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        # Compress to latent
        kv_latent = self.W_kv_compress(x)  # [B, L, kv_latent_dim]
        # RoPE keys (for position encoding)
        k_rope = self.W_k_rope(x).view(B, L, self.num_heads, self.rope_dim)
        # KV cache stores: latent + k_rope (much smaller than full K, V!)
        if kv_cache is not None:
            kv_latent = torch.cat([kv_cache[0], kv_latent], dim=1)
            k_rope = torch.cat([kv_cache[1], k_rope], dim=1)
        # Expand K, V from latent (done at attention time)
        K_content = self.W_k_expand(kv_latent).view(B, -1, self.num_heads, self.head_dim)
        V = self.W_v_expand(kv_latent).view(B, -1, self.num_heads, self.head_dim)
        # Combine content K with RoPE K; the last rope_dim content channels
        # are dropped so that K keeps head_dim channels and matches Q
        K = torch.cat([K_content[..., :self.head_dim - self.rope_dim],
                       k_rope], dim=-1)
        # Standard attention
        scores = torch.einsum('blhd,bshd->bhls', Q, K) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V)
        return self.W_o(out.reshape(B, L, D)), (kv_latent, k_rope)
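Memory per token per layer (example values: 128 heads, head_dim 128, FP16):
KV cache = (kv_latent_dim + rope_dim × num_heads) × dtype_bytes
= (512 + 64 × 128) × 2
= 17,408 bytes ≈ 17 KB (vs 2 × 128 × 128 × 2 = 64 KB for equivalent MHA)
This counts one RoPE key per head, as in the sketch above. DeepSeek-V2 itself shares a single RoPE key across all heads, so its cache holds only 512 + 64 = 576 values per token per layer.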
MLA adds compute (the latent expansion) but reduces the number of bytes read from memory per decoded token. Because the decode phase is typically memory-bound, this trade is often beneficial.
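The cache accounting for MLA can be written out the same way as for the other variants. This sketch counts cached values (elements) per token per layer and compares two cases: one RoPE key per head, as in the simplified class above, and a single RoPE key shared across heads, which is how the DeepSeek-V2 paper describes its cache. The function names are hypothetical.

```python
def mla_cache_elements(kv_latent_dim: int, rope_dim: int, num_rope_keys: int) -> int:
    """Cached values per token per layer: one latent vector plus the RoPE keys."""
    return kv_latent_dim + rope_dim * num_rope_keys


def mha_cache_elements(num_heads: int, head_dim: int) -> int:
    """Cached values per token per layer when full K and V are stored."""
    return 2 * num_heads * head_dim


NUM_HEADS, HEAD_DIM = 128, 128
mha = mha_cache_elements(NUM_HEADS, HEAD_DIM)
per_head_rope = mla_cache_elements(512, 64, num_rope_keys=NUM_HEADS)
shared_rope = mla_cache_elements(512, 64, num_rope_keys=1)

for label, n in [("MHA", mha),
                 ("MLA, per-head RoPE keys", per_head_rope),
                 ("MLA, shared RoPE key", shared_rope)]:
    print(f"{label}: {n} elements, {n * 2} bytes in FP16, {n / mha:.1%} of MHA")
```

With per-head RoPE keys the RoPE part dominates the cache (8,192 of 8,704 elements), which is why sharing the RoPE key matters so much for the final size.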
Quality-Memory Trade-off
Attention Variant Benchmark Results
| Variant | MMLU | HumanEval | KV Memory |
|---|---|---|---|
| MHA (baseline) | 70.2% | 67.5% | 100% |
| GQA (8 groups) | 69.8% | 66.8% | 12.5% |
| GQA (4 groups) | 69.4% | 66.1% | 6.25% |
| MQA | 68.1% | 63.4% | 3.1% |
| MLA | 70.0% | 67.2% | ~25% |
Implementation Considerations
Efficient GQA Kernel
// GQA-aware attention kernel (sketch, not a tuned implementation).
// Launch with grid = (batch, NUM_Q_HEADS). The kernel does not index over
// seq_q or threadIdx, so it covers the single-query decode case with one
// thread per block. load_q, load_kv, and dot_product are device helpers
// that are not shown here.
#include <cuda_fp16.h>

template<int NUM_Q_HEADS, int NUM_KV_HEADS, int HEAD_DIM>
__global__ void gqa_attention_kernel(
    const half* Q,        // [batch, seq_q, num_q_heads, head_dim]
    const half* K_cache,  // [batch, seq_kv, num_kv_heads, head_dim]
    const half* V_cache,  // [batch, seq_kv, num_kv_heads, head_dim]
    half* output,
    int seq_q, int seq_kv
) {
    static_assert(NUM_Q_HEADS % NUM_KV_HEADS == 0,
                  "Q heads must divide evenly into KV heads");
    constexpr int HEADS_PER_GROUP = NUM_Q_HEADS / NUM_KV_HEADS;
    int batch_idx = blockIdx.x;
    int q_head_idx = blockIdx.y;
    int kv_head_idx = q_head_idx / HEADS_PER_GROUP;  // Map Q head to KV head

    // Load Q for this head
    half q_reg[HEAD_DIM];
    load_q(Q, q_reg, batch_idx, q_head_idx);

    // Iterate over K, V (using shared KV head)
    float acc[HEAD_DIM] = {0};
    float max_score = -INFINITY;
    float sum_exp = 0;
    for (int kv_pos = 0; kv_pos < seq_kv; kv_pos++) {
        // Load from KV head (not Q head!)
        half k_reg[HEAD_DIM], v_reg[HEAD_DIM];
        load_kv(K_cache, V_cache, k_reg, v_reg, batch_idx, kv_head_idx, kv_pos);

        // Compute attention score
        float score = dot_product(q_reg, k_reg) / sqrtf((float)HEAD_DIM);

        // Online softmax: rescale previous partial results to the new running max
        float new_max = fmaxf(max_score, score);
        float exp_diff = expf(max_score - new_max);
        float exp_score = expf(score - new_max);

        // Update accumulator
        for (int d = 0; d < HEAD_DIM; d++) {
            acc[d] = acc[d] * exp_diff + exp_score * __half2float(v_reg[d]);
        }
        sum_exp = sum_exp * exp_diff + exp_score;
        max_score = new_max;
    }

    // Normalize and store (output index elided in the original)
    for (int d = 0; d < HEAD_DIM; d++) {
        output[...] = __float2half(acc[d] / sum_exp);
    }
}
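The online-softmax update inside the loop can be checked against a plain two-pass softmax in Python. This is a minimal sketch with hypothetical helper names (`online_attention`, `two_pass_attention`); it handles one query and uses Python lists in place of GPU registers.

```python
import math


def online_attention(scores, values):
    """Single pass over (score, value) pairs with a running max and running sum."""
    dim = len(values[0])
    acc = [0.0] * dim
    max_score = -math.inf
    sum_exp = 0.0
    for score, v in zip(scores, values):
        new_max = max(max_score, score)
        exp_diff = math.exp(max_score - new_max)  # rescales everything seen so far
        exp_score = math.exp(score - new_max)
        acc = [a * exp_diff + exp_score * x for a, x in zip(acc, v)]
        sum_exp = sum_exp * exp_diff + exp_score
        max_score = new_max
    return [a / sum_exp for a in acc]


def two_pass_attention(scores, values):
    """Reference: subtract the global max, normalize, then take the weighted sum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


scores = [0.5, 2.0, -1.0, 3.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5]]
a = online_attention(scores, values)
b = two_pass_attention(scores, values)
print(all(math.isclose(x, y, rel_tol=1e-9) for x, y in zip(a, b)))  # True
```

The single-pass form is what lets the kernel stream over the KV cache once, without first materializing all scores to find their maximum.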
Memory Layout for GQA
def optimal_gqa_kv_layout(num_kv_heads, head_dim, max_seq_len, max_batch):
    """
    Choose memory layout optimizing for cache locality.
    """
    # Option 1: [batch, seq, kv_heads, head_dim] - good for sequential access
    # Option 2: [batch, kv_heads, seq, head_dim] - good for head parallelism
    # For most GQA implementations, Option 1 is better because:
    # - Multiple Q heads access same KV head sequentially
    # - Better L2 cache utilization
    return {
        'k_cache_shape': (max_batch, max_seq_len, num_kv_heads, head_dim),
        'v_cache_shape': (max_batch, max_seq_len, num_kv_heads, head_dim),
        'memory_per_token': 2 * num_kv_heads * head_dim * 2,  # K and V, FP16 (2 bytes)
    }
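To make the two layouts concrete, the sketch below computes the flat element offset of one cache entry under each. The function names are hypothetical, and contiguous row-major storage is assumed.

```python
def offset_seq_major(b, s, h, d, max_seq_len, num_kv_heads, head_dim):
    """Option 1: [batch, seq, kv_heads, head_dim]."""
    return ((b * max_seq_len + s) * num_kv_heads + h) * head_dim + d


def offset_head_major(b, s, h, d, max_seq_len, num_kv_heads, head_dim):
    """Option 2: [batch, kv_heads, seq, head_dim]."""
    return ((b * num_kv_heads + h) * max_seq_len + s) * head_dim + d


cfg = dict(max_seq_len=4096, num_kv_heads=8, head_dim=128)

# Distance in elements between consecutive tokens of the same KV head
stride_1 = offset_seq_major(0, 1, 0, 0, **cfg) - offset_seq_major(0, 0, 0, 0, **cfg)
stride_2 = offset_head_major(0, 1, 0, 0, **cfg) - offset_head_major(0, 0, 0, 0, **cfg)
print(stride_1, stride_2)  # 1024 128
```

In Option 1 all KV heads of a token sit next to each other, so appending a new token is one contiguous write. In Option 2 a single head's entries for successive positions are contiguous instead.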
Recommendations
| Use Case | Recommendation | Rationale |
|---|---|---|
| Quality-critical | MHA or MLA | Preserve attention capacity |
| Long context | GQA (4-8 groups) | Good quality/memory balance |
| Memory-constrained | MQA | Maximum compression |
| Inference-optimized | GQA (8 groups) | Industry standard |
For most production deployments, GQA with 8 KV heads (matching Llama 2/3 70B configuration) offers the best balance. Start there and only move to MQA if memory constraints are severe.
Conclusion
The progression from MHA → MQA → GQA → MLA reflects the increasing importance of memory efficiency in LLM deployment. GQA has emerged as the practical default: in the 32-head example above it cuts the KV cache by 4-8x, and in the benchmark table it loses less than one point on MMLU. Understanding these trade-offs is essential for architecting inference systems.