Multi-Head Attention (MHA) dedicates separate K, V projections to each head. Multi-Query Attention (MQA) shares one K, V across all heads. Grouped Query Attention (GQA) interpolates between them. Understanding these trade-offs is essential for model deployment.
Attention Variant Comparison
def attention_variants_kv_size(
    num_heads: int,
    head_dim: int,
    seq_len: int,
    dtype_bytes: int = 2,
    batch_size: int = 1,
    num_layers: int = 1,
) -> dict:
    """Calculate KV cache size in bytes for MHA, GQA and MQA.

    The cache holds one K and one V tensor per layer, each of shape
    ``[batch_size, seq_len, num_kv_heads, head_dim]``, so its size is
    ``2 * num_kv_heads * seq_len * head_dim * dtype_bytes * batch_size * num_layers``.
    The variants differ only in ``num_kv_heads``.

    Args:
        num_heads: Number of query heads; also the KV head count used for MHA.
        head_dim: Per-head feature size.
        seq_len: Number of cached token positions.
        dtype_bytes: Bytes per cached element (2 for fp16/bf16).
        batch_size: Number of sequences cached at once (default 1).
        num_layers: Number of transformer layers (default 1, i.e. per layer).

    Returns:
        Dict mapping 'MHA', 'GQA-8', 'GQA-4', 'GQA-2', 'MQA' to sizes in bytes,
        where 'GQA-g' means g KV heads. The GQA entries use a fixed g regardless
        of ``num_heads``, so they are only meaningful when g divides ``num_heads``.
    """
    # Bytes contributed by a single KV head (its K and V) over all positions,
    # sequences and layers.
    per_kv_head = 2 * seq_len * head_dim * dtype_bytes * batch_size * num_layers

    def kv_bytes(num_kv_heads: int) -> int:
        return num_kv_heads * per_kv_head

    return {
        'MHA': kv_bytes(num_heads),  # each query head has its own K, V
        'GQA-8': kv_bytes(8),        # 8 KV heads
        'GQA-4': kv_bytes(4),        # 4 KV heads
        'GQA-2': kv_bytes(2),        # 2 KV heads
        'MQA': kv_bytes(1),          # a single K, V shared by all heads
    }
# Example: 32 heads, 128 head_dim, 4096 seq_len, fp16 (2 bytes), one layer
sizes = attention_variants_kv_size(32, 128, 4096)
# MHA: 64 MiB, GQA-8: 16 MiB, GQA-4: 8 MiB, GQA-2: 4 MiB, MQA: 2 MiB
KV Cache Size per Layer in MiB (32 heads, head_dim=128, seq=4096, fp16)
GQA Implementation
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention: ``num_heads`` query heads share ``num_kv_heads`` K/V heads.

    K and V are expanded with ``repeat_interleave``. As a result, each consecutive
    run of ``num_heads // num_kv_heads`` query heads uses the same K/V head:
    query head ``h`` reads KV head ``h // (num_heads // num_kv_heads)``.
    ``num_kv_heads == num_heads`` reduces to MHA; ``num_kv_heads == 1`` to MQA.

    Args:
        hidden_dim: Feature size of the input and output.
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads (groups). Must divide ``num_heads``.
        head_dim: Per-head feature size.

    Raises:
        ValueError: If ``num_heads`` is not a positive multiple of ``num_kv_heads``.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,      # Query heads
        num_kv_heads: int,   # KV heads (groups)
        head_dim: int,
    ):
        super().__init__()
        # Without this check, floor division below would silently yield a wrong
        # head count (e.g. 6 // 4 == 1), and forward() would fail later with a
        # shape mismatch. A zero num_kv_heads would raise ZeroDivisionError.
        if num_heads <= 0 or num_kv_heads <= 0 or num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be a positive multiple of "
                f"num_kv_heads ({num_kv_heads})"
            )
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Query heads served by each KV head.
        self.num_groups = num_heads // num_kv_heads
        # Q projection: full heads
        self.q_proj = nn.Linear(hidden_dim, num_heads * head_dim)
        # K, V projections: reduced heads
        self.k_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim)
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_dim)

    def forward(self, x, kv_cache=None, is_causal: bool = False):
        """Apply grouped-query self-attention to ``x``.

        Args:
            x: Input of shape [batch, seq_len, hidden_dim].
            kv_cache: Accepted for interface compatibility but currently
                ignored; K and V are computed from ``x`` only.
            is_causal: If True, position i attends only to positions <= i.
                Defaults to False (full bidirectional attention).

        Returns:
            Tensor of shape [batch, seq_len, hidden_dim].
        """
        batch, seq_len, _ = x.shape
        # Project Q, K, V
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        # Expand K, V to match Q heads:
        # [batch, seq, num_kv_heads, head_dim] -> [batch, seq, num_heads, head_dim]
        # repeat_interleave allocates and copies a full-width tensor here. GQA's
        # saving is in what gets stored (num_kv_heads heads), not in this temporary.
        k = k.repeat_interleave(self.num_groups, dim=2)
        v = v.repeat_interleave(self.num_groups, dim=2)
        # Standard attention computation
        q = q.transpose(1, 2)  # [batch, num_heads, seq, head_dim]
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if is_causal:
            # Mask strictly-future positions (above the diagonal).
            future = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=x.device
            ).triu(1)
            scores = scores.masked_fill(future, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(out)
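Note that `repeat_interleave` is not a view. It allocates a new tensor and copies K and V out to the full `num_heads` width, so each forward pass pays for a temporary as large as MHA's K and V. GQA's memory savings are still real where they matter most: the persistent KV cache stores only `num_kv_heads` heads. But the "expansion" is not free. (`unsqueeze` + `expand` does produce a view, but the `reshape` that merges the head dimensions copies. Fused kernels like the one below avoid materializing the expansion at all.)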
Quality vs Efficiency Trade-off
Attention Variant Quality (Llama-scale models)
| Variant | KV Heads | Perplexity | KV Memory | Throughput |
|---|---|---|---|---|
| MHA | 32 | 5.12 | 1.0x | 1.0x |
| GQA-8 | 8 | 5.14 | 0.25x | 2.1x |
| GQA-4 | 4 | 5.18 | 0.125x | 2.8x |
| GQA-2 | 2 | 5.31 | 0.0625x | 3.2x |
| MQA | 1 | 5.67 | 0.03125x | 3.5x |
The sweet spot is GQA with 4–8 KV heads (groups). It stays within about 0.06 perplexity of MHA while cutting KV cache memory by 4–8x.
Kernel Optimization for GQA
Efficient GQA kernels avoid explicit K, V expansion:
// GQA attention kernel that handles grouping implicitly: each query head reads
// the K/V rows of its KV head directly, so K and V are never expanded to
// NUM_HEADS in memory.
//
// Launch layout (implied by the index mapping below):
//   gridDim  = (seq_q, NUM_HEADS, batch)
//   blockDim = 1 -- threadIdx is not used, so additional threads would only
//                   repeat the same computation and the same writes.
// All tensors are dense and row-major in the layouts noted on each parameter.
// No causal mask is applied: every query attends to all seq_kv positions.
template<int NUM_HEADS, int NUM_KV_HEADS, int HEAD_DIM>
__global__ void gqa_attention_kernel(
    const half* __restrict__ q,       // [batch, seq_q, num_heads, head_dim]
    const half* __restrict__ k,       // [batch, seq_kv, num_kv_heads, head_dim]
    const half* __restrict__ v,       // [batch, seq_kv, num_kv_heads, head_dim]
    half* __restrict__ output,        // [batch, seq_q, num_heads, head_dim]
    const int seq_q,                  // number of query positions
    const int seq_kv                  // number of key/value positions
) {
    static_assert(NUM_HEADS % NUM_KV_HEADS == 0,
                  "NUM_HEADS must be a multiple of NUM_KV_HEADS");
    constexpr int HEADS_PER_GROUP = NUM_HEADS / NUM_KV_HEADS;
    // Elements between consecutive sequence positions in K and V.
    constexpr int KV_POS_STRIDE = NUM_KV_HEADS * HEAD_DIM;

    const int batch_idx = blockIdx.z;
    const int head_idx = blockIdx.y;
    const int q_idx = blockIdx.x;
    if (q_idx >= seq_q || head_idx >= NUM_HEADS) return;

    // Map query head to the KV head shared by its group.
    const int kv_head_idx = head_idx / HEADS_PER_GROUP;

    // Flat offsets into the row-major tensors; output shares q's layout.
    const size_t q_off =
        (((size_t)batch_idx * seq_q + q_idx) * NUM_HEADS + head_idx) * HEAD_DIM;
    const half* q_row = q + q_off;
    half* out_row = output + q_off;
    const size_t kv_base = (size_t)batch_idx * seq_kv * KV_POS_STRIDE
                         + (size_t)kv_head_idx * HEAD_DIM;

    // Load query for this head, converted to float once.
    float q_vec[HEAD_DIM];
    for (int d = 0; d < HEAD_DIM; d++) {
        q_vec[d] = __half2float(q_row[d]);
    }

    const float inv_sqrt_d = 1.0f / sqrtf((float)HEAD_DIM);  // hoisted out of the loop

    // Online softmax state over K, V of the corresponding KV head.
    float acc[HEAD_DIM] = {};
    float max_score = -INFINITY;
    float sum_exp = 0.0f;
    for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) {
        // K, V use kv_head_idx, not head_idx
        const size_t row = kv_base + (size_t)kv_idx * KV_POS_STRIDE;
        const half* k_row = k + row;
        const half* v_row = v + row;

        float score = 0.0f;
        for (int d = 0; d < HEAD_DIM; d++) {
            score += q_vec[d] * __half2float(k_row[d]);
        }
        score *= inv_sqrt_d;

        // Rescale previous state to the new running max. On the first
        // iteration max_score is -inf, so scale is 0 and the zeroed state stays 0.
        const float new_max = fmaxf(max_score, score);
        const float scale = expf(max_score - new_max);
        const float weight = expf(score - new_max);
        sum_exp = sum_exp * scale + weight;
        for (int d = 0; d < HEAD_DIM; d++) {
            acc[d] = acc[d] * scale + weight * __half2float(v_row[d]);
        }
        max_score = new_max;
    }

    // Normalize and store; an empty KV sequence yields zeros instead of 0/0 NaN.
    const float inv_sum = sum_exp > 0.0f ? 1.0f / sum_exp : 0.0f;
    for (int d = 0; d < HEAD_DIM; d++) {
        out_row[d] = __float2half(acc[d] * inv_sum);
    }
}
Choosing the Right Variant
Decision framework:
def recommend_attention_variant(
    deployment_scenario: str,
    quality_sensitivity: str,
    memory_constraint_gb: float,
    target_throughput: float
) -> str:
    """
    Map a deployment profile to an attention variant label.

    Returns one of "MQA", "GQA-4" or "GQA-8". Unrecognized scenarios fall
    back to "GQA-8". Which inputs are consulted depends on the scenario:
    batch_inference and long_context look at memory_constraint_gb, and
    interactive looks at quality_sensitivity. target_throughput is accepted
    but not currently used.
    """
    scenario = deployment_scenario

    if scenario == "batch_inference":
        # Throughput-oriented: trade a little quality for memory when tight.
        return "GQA-4" if memory_constraint_gb < 40 else "GQA-8"

    if scenario == "interactive":
        # Latency-sensitive single requests: quality preference decides.
        return "GQA-8" if quality_sensitivity == "high" else "GQA-4"

    if scenario == "long_context":
        # Context length dominates: smallest KV cache when memory is scarce.
        return "MQA" if memory_constraint_gb < 24 else "GQA-4"

    # Fallback for any other scenario.
    return "GQA-8"
Conclusion
GQA represents an elegant interpolation between MHA's quality and MQA's efficiency. For most deployments, GQA with 4-8 KV heads provides the best balance. The figures below are relative to the 32-head MHA baseline in the table above:
- 4 KV heads: 8x KV cache reduction, ~0.06 perplexity increase
- 8 KV heads: 4x KV cache reduction, ~0.02 perplexity increase
Modern open models increasingly adopt GQA as the default attention mechanism. Examples include the larger Llama 2 variants (34B, 70B), Llama 3, and Mistral 7B.