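Multi-Head Attention (MHA) gives every query head its own K and V projections. Multi-Query Attention (MQA) shares a single K, V head across all query heads. Grouped Query Attention (GQA) interpolates between the two: query heads are split into groups, and each group shares one K, V head. Because the KV cache stores one K and one V tensor per KV head, this choice directly sets inference memory, so understanding these trade-offs is essential for model deployment.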
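To make the grouping concrete, here is a minimal sketch of how a query head finds the K, V head it reads from. It assumes contiguous grouping (query heads 0..g-1 share KV head 0, the next g share KV head 1, and so on), which is the layout the repeat_interleave code and the CUDA kernel later in this post rely on. `kv_head_for_query_head` is a hypothetical helper name.

```python
def kv_head_for_query_head(head_idx: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the KV head shared by query head `head_idx` (contiguous grouping assumed)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    heads_per_group = num_heads // num_kv_heads  # query heads sharing one KV head
    return head_idx // heads_per_group


num_heads = 8
for name, num_kv_heads in [("MHA", 8), ("GQA-2", 2), ("MQA", 1)]:
    mapping = [kv_head_for_query_head(h, num_heads, num_kv_heads) for h in range(num_heads)]
    print(name, mapping)
# MHA [0, 1, 2, 3, 4, 5, 6, 7]
# GQA-2 [0, 0, 0, 0, 1, 1, 1, 1]
# MQA [0, 0, 0, 0, 0, 0, 0, 0]
```

MHA and MQA fall out as the two extremes of the same mapping: one KV head per query head, or one KV head for all of them.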

Attention Variant Comparison

def attention_variants_kv_size(
    num_heads: int,
    head_dim: int,
    seq_len: int,
    dtype_bytes: int = 2
) -> dict:
    """Per-layer, per-sequence KV cache size in bytes for each attention variant.

    GQA-N denotes N KV heads (groups); dtype_bytes=2 corresponds to FP16/BF16.
    """
    
    # MHA: Each head has its own K, V
    mha_kv = 2 * num_heads * seq_len * head_dim * dtype_bytes
    
    # MQA: Single K, V shared across all heads  
    mqa_kv = 2 * 1 * seq_len * head_dim * dtype_bytes
    
    # GQA with g groups: g sets of K, V
    def gqa_kv(num_groups):
        return 2 * num_groups * seq_len * head_dim * dtype_bytes
    
    return {
        'MHA': mha_kv,
        'GQA-8': gqa_kv(8),  # 8 groups
        'GQA-4': gqa_kv(4),  # 4 groups
        'GQA-2': gqa_kv(2),  # 2 groups
        'MQA': mqa_kv,
    }

# Example: 32 heads, head_dim 128, seq_len 4096, FP16 (2 bytes)
sizes = attention_variants_kv_size(32, 128, 4096)
# Per layer, per sequence (bytes / 2**20):
# MHA: 64 MB, GQA-8: 16 MB, GQA-4: 8 MB, GQA-2: 4 MB, MQA: 2 MB

Figure: KV cache size per layer, per sequence (32 heads, head_dim 128, seq_len 4096, FP16): MHA (32 KV heads) 64 MB, GQA-8 (8 groups) 16 MB, GQA-4 (4 groups) 8 MB, GQA-2 (2 groups) 4 MB, MQA (1 KV head) 2 MB.
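The figure covers a single layer and a single sequence. The full cache scales with layer count and batch size as well. The sketch below assumes a 7B-class configuration (32 layers, 32 query heads, head_dim 128, FP16) with seq_len 4096 and a batch of 8; `total_kv_cache_bytes` is a hypothetical helper.

```python
def total_kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                         seq_len: int, batch_size: int, dtype_bytes: int = 2) -> int:
    # Two tensors (K and V) per layer, each [batch, seq_len, num_kv_heads, head_dim]
    return 2 * num_layers * batch_size * seq_len * num_kv_heads * head_dim * dtype_bytes


GiB = 2 ** 30
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("GQA-4", 4), ("MQA", 1)]:
    size = total_kv_cache_bytes(num_layers=32, num_kv_heads=kv_heads, head_dim=128,
                                seq_len=4096, batch_size=8)
    print(f"{name}: {size / GiB:.2f} GiB")
# MHA: 16.00 GiB
# GQA-8: 4.00 GiB
# GQA-4: 2.00 GiB
# MQA: 0.50 GiB
```

At this scale the MHA cache alone can rival the size of the FP16 weights, which is why the KV head count matters so much for serving.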

GQA Implementation

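import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int,      # Query heads
        num_kv_heads: int,   # KV heads (= number of groups)
        head_dim: int,
    ):
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError("num_heads must be divisible by num_kv_heads")
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # Query heads per KV head (the group size, not the number of groups)
        self.heads_per_group = num_heads // num_kv_heads
        
        # Q projection: full heads
        self.q_proj = nn.Linear(hidden_dim, num_heads * head_dim)
        
        # K, V projections: reduced heads
        self.k_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim)
        self.v_proj = nn.Linear(hidden_dim, num_kv_heads * head_dim)
        
        self.o_proj = nn.Linear(num_heads * head_dim, hidden_dim)
    
    def forward(self, x, kv_cache=None):
        """x: [batch, seq_len, hidden_dim]; kv_cache: optional (past_k, past_v),
        each [batch, past_len, num_kv_heads, head_dim]. Returns (output, new_kv_cache)."""
        batch, seq_len, _ = x.shape
        
        # Project Q, K, V
        q = self.q_proj(x).view(batch, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        v = self.v_proj(x).view(batch, seq_len, self.num_kv_heads, self.head_dim)
        
        # Append to the cache BEFORE expansion, so the cache holds only num_kv_heads heads
        if kv_cache is not None:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        new_kv_cache = (k, v)
        kv_len = k.shape[1]
        
        # Expand K, V to match Q heads (repeat_interleave copies data)
        # [batch, kv_len, num_kv_heads, head_dim] -> [batch, kv_len, num_heads, head_dim]
        k = k.repeat_interleave(self.heads_per_group, dim=2)
        v = v.repeat_interleave(self.heads_per_group, dim=2)
        
        # Standard attention computation
        q = q.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
        k = k.transpose(1, 2)  # [batch, num_heads, kv_len, head_dim]
        v = v.transpose(1, 2)
        
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Causal mask: new query i sits at absolute position kv_len - seq_len + i
        q_pos = torch.arange(kv_len - seq_len, kv_len, device=x.device).unsqueeze(1)
        k_pos = torch.arange(kv_len, device=x.device).unsqueeze(0)
        scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.o_proj(out), new_kv_cache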
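
💡 Memory vs Compute

repeat_interleave is not a view in PyTorch: it allocates a new tensor and copies each K, V head once for every query head in its group. GQA's memory savings still hold where they matter, because the persistent KV cache stores only num_kv_heads heads and the expanded copy is a short-lived buffer. The expansion does cost extra memory traffic and transient memory on every step, which is why efficient kernels index the shared KV head directly instead of materializing it.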
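One way to check whether an operation copies is to compare data pointers and strides. This small sketch (arbitrary shapes, float32) contrasts repeat_interleave with a stride-0 expand view. Merging the expanded group axis back into the head axis with reshape forces a copy too, so the copy is only truly avoided by never forming the num_heads-wide tensor at all.

```python
import torch

batch, seq, num_kv_heads, num_heads, head_dim = 1, 4, 2, 8, 16
group = num_heads // num_kv_heads  # query heads per KV head
k = torch.randn(batch, seq, num_kv_heads, head_dim)

# repeat_interleave allocates a new tensor holding num_heads heads
k_rep = k.repeat_interleave(group, dim=2)

# expand on a new group axis is a zero-copy view (stride 0 along that axis)
k_view = k.unsqueeze(3).expand(batch, seq, num_kv_heads, group, head_dim)

print(k_rep.data_ptr() == k.data_ptr())   # False: new allocation
print(k_view.data_ptr() == k.data_ptr())  # True: same storage
print(k_view.stride(3))                   # 0

# Merging the group axis into the head axis cannot be a view, so reshape copies
k_merged = k_view.reshape(batch, seq, num_heads, head_dim)
print(torch.equal(k_merged, k_rep))         # True: same values, same head order
print(k_merged.data_ptr() == k.data_ptr())  # False
```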

Quality vs Efficiency Trade-off

Attention Variant Quality (Llama-scale models)

Variant   KV Heads   Perplexity   KV Memory (vs MHA)   Throughput (vs MHA)
MHA       32         5.12         1.0x                 1.0x
GQA-8     8          5.14         0.25x                2.1x
GQA-4     4          5.18         0.125x               2.8x
GQA-2     2          5.31         0.0625x              3.2x
MQA       1          5.67         0.03125x             3.5x

Note: 7B models, all trained on the same data; perplexity on the validation set (lower is better).

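The sweet spot is GQA with 4-8 KV heads: perplexity rises by only 0.02-0.06 over MHA while KV memory shrinks 4-8x. Below 4 KV heads, the quality loss grows quickly (0.19 for GQA-2, 0.55 for MQA).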
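The sweet-spot reasoning can be phrased as a rule: pick the variant with the fewest KV heads whose perplexity stays within a tolerance of MHA. A minimal sketch using the numbers from the table above; the tolerance values are arbitrary choices for illustration, and `smallest_kv_within_tolerance` is a hypothetical helper.

```python
# (variant, kv_heads, perplexity) taken from the table above; 32 query heads
RESULTS = [
    ("MHA", 32, 5.12),
    ("GQA-8", 8, 5.14),
    ("GQA-4", 4, 5.18),
    ("GQA-2", 2, 5.31),
    ("MQA", 1, 5.67),
]


def smallest_kv_within_tolerance(results, max_ppl_increase):
    """Variant with the fewest KV heads whose perplexity is within the tolerance of MHA."""
    baseline = next(ppl for name, _, ppl in results if name == "MHA")
    candidates = [(kv, name) for name, kv, ppl in results if ppl - baseline <= max_ppl_increase]
    return min(candidates)[1]  # fewest KV heads = smallest cache


print(smallest_kv_within_tolerance(RESULTS, 0.05))  # GQA-8
print(smallest_kv_within_tolerance(RESULTS, 0.10))  # GQA-4
```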

Kernel Optimization for GQA

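Efficient GQA kernels avoid the explicit K, V expansion: each query head computes which KV head its group shares and reads K, V from that head directly. The simplified kernel below (one thread per batch, query head, and query position) shows the mapping. A further optimization, not shown here, is to process all query heads of a group in one thread block so each K, V tile is read from global memory only once.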

#include <cuda_fp16.h>
#include <math.h>

// Simplified GQA attention kernel that handles grouping implicitly:
// each query head reads its group's KV head, so K and V are never expanded.
// Non-causal; all tensors contiguous, row-major.
// Launch: grid(seq_q, NUM_HEADS, batch), block(1).
template<int NUM_HEADS, int NUM_KV_HEADS, int HEAD_DIM>
__global__ void gqa_attention_kernel(
    const half* __restrict__ q,     // [batch, seq_q, num_heads, head_dim]
    const half* __restrict__ k,     // [batch, seq_kv, num_kv_heads, head_dim]
    const half* __restrict__ v,     // [batch, seq_kv, num_kv_heads, head_dim]
    half* __restrict__ output,      // [batch, seq_q, num_heads, head_dim]
    int seq_q,
    int seq_kv
) {
    static_assert(NUM_HEADS % NUM_KV_HEADS == 0, "NUM_HEADS must be a multiple of NUM_KV_HEADS");
    constexpr int HEADS_PER_GROUP = NUM_HEADS / NUM_KV_HEADS;
    
    const int batch_idx = blockIdx.z;
    const int head_idx = blockIdx.y;
    const int q_idx = blockIdx.x;
    
    // Map query head to KV head
    const int kv_head_idx = head_idx / HEADS_PER_GROUP;
    
    // Flat offsets into the contiguous tensors
    const size_t q_off = (((size_t)batch_idx * seq_q + q_idx) * NUM_HEADS + head_idx) * HEAD_DIM;
    const size_t kv_base = (size_t)batch_idx * seq_kv * NUM_KV_HEADS * HEAD_DIM;
    
    // Load query for this head into registers
    float q_vec[HEAD_DIM];
    for (int d = 0; d < HEAD_DIM; d++) q_vec[d] = __half2float(q[q_off + d]);
    
    // Compute attention over K, V for the corresponding KV head
    const float softmax_scale = rsqrtf((float)HEAD_DIM);
    float acc[HEAD_DIM] = {0.0f};
    float max_score = -INFINITY;
    float sum_exp = 0.0f;
    
    for (int kv_idx = 0; kv_idx < seq_kv; kv_idx++) {
        // K, V use kv_head_idx, not head_idx
        const size_t kv_off = kv_base + ((size_t)kv_idx * NUM_KV_HEADS + kv_head_idx) * HEAD_DIM;
        
        float score = 0.0f;
        for (int d = 0; d < HEAD_DIM; d++) score += q_vec[d] * __half2float(k[kv_off + d]);
        score *= softmax_scale;
        
        // Online softmax: rescale previous sums when the running max grows
        float new_max = fmaxf(max_score, score);
        float scale = expf(max_score - new_max);
        float weight = expf(score - new_max);
        sum_exp = sum_exp * scale + weight;
        
        // Accumulate weighted V
        for (int d = 0; d < HEAD_DIM; d++) {
            acc[d] = acc[d] * scale + weight * __half2float(v[kv_off + d]);
        }
        
        max_score = new_max;
    }
    
    // Normalize and store
    for (int d = 0; d < HEAD_DIM; d++) {
        output[q_off + d] = __float2half(acc[d] / sum_exp);
    }
}
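The same idea can be sketched in PyTorch without a custom kernel: fold each group's query heads into the row dimension so one batched matmul per KV head serves the whole group, and K, V are never repeated to num_heads. This is a non-causal sketch under the contiguous-grouping assumption; `gqa_attention_grouped` and `gqa_attention_reference` are hypothetical names, and the reference uses repeat_interleave only to check the result.

```python
import math

import torch


def gqa_attention_grouped(q, k, v):
    """q: [batch, seq_q, num_heads, head_dim]; k, v: [batch, seq_kv, num_kv_heads, head_dim]."""
    batch, seq_q, num_heads, head_dim = q.shape
    num_kv_heads = k.shape[2]
    group = num_heads // num_kv_heads
    # [b, s_q, kvh, group, d] -> [b, kvh, group, s_q, d] -> [b, kvh, group * s_q, d]
    q_g = q.view(batch, seq_q, num_kv_heads, group, head_dim).permute(0, 2, 3, 1, 4)
    q_g = q_g.reshape(batch, num_kv_heads, group * seq_q, head_dim)
    k_t = k.transpose(1, 2)  # [b, kvh, s_kv, d], still num_kv_heads wide
    v_t = v.transpose(1, 2)
    scores = q_g @ k_t.transpose(-2, -1) / math.sqrt(head_dim)  # [b, kvh, group*s_q, s_kv]
    out = torch.softmax(scores, dim=-1) @ v_t                     # [b, kvh, group*s_q, d]
    # Undo the folding: back to [b, s_q, num_heads, d]
    out = out.view(batch, num_kv_heads, group, seq_q, head_dim).permute(0, 3, 1, 2, 4)
    return out.reshape(batch, seq_q, num_heads, head_dim)


def gqa_attention_reference(q, k, v):
    """Explicit-expansion version, used only to check the grouped one."""
    group = q.shape[2] // k.shape[2]
    k_rep = k.repeat_interleave(group, dim=2).transpose(1, 2)
    v_rep = v.repeat_interleave(group, dim=2).transpose(1, 2)
    scores = q.transpose(1, 2) @ k_rep.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return (torch.softmax(scores, dim=-1) @ v_rep).transpose(1, 2)


torch.manual_seed(0)
q = torch.randn(2, 3, 8, 16)   # 8 query heads
k = torch.randn(2, 5, 2, 16)   # 2 KV heads
v = torch.randn(2, 5, 2, 16)
print(torch.allclose(gqa_attention_grouped(q, k, v), gqa_attention_reference(q, k, v), atol=1e-5))  # True
```

The folding copies Q instead of K and V; during decoding seq_q is 1, so that copy is small compared with the cached K, V.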

Choosing the Right Variant

Decision framework:

def recommend_attention_variant(
    deployment_scenario: str,
    quality_sensitivity: str,
    memory_constraint_gb: float,
    target_throughput: float
) -> str:
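    """
    Recommend an attention variant based on deployment requirements.

    deployment_scenario: "batch_inference", "interactive", or "long_context".
    quality_sensitivity: "high" or anything else (only used for "interactive").
    memory_constraint_gb: available GPU memory, in GB.
    target_throughput: accepted for completeness; not used by these rules.

    The 40 GB and 24 GB thresholds are illustrative heuristics.
    """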
    
    if deployment_scenario == "batch_inference":
        # High throughput, can tolerate slight quality loss
        if memory_constraint_gb < 40:
            return "GQA-4"  # Good balance
        else:
            return "GQA-8"  # Better quality
    
    elif deployment_scenario == "interactive":
        # Latency-sensitive, single requests
        if quality_sensitivity == "high":
            return "GQA-8"
        else:
            return "GQA-4"
    
    elif deployment_scenario == "long_context":
        # Context length is priority
        if memory_constraint_gb < 24:
            return "MQA"  # Maximum context
        else:
            return "GQA-4"
    
    return "GQA-8"  # Default recommendation
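To put numbers behind memory_constraint_gb, the sketch below estimates how many tokens of KV cache fit in a given budget. It assumes the whole budget goes to the KV cache (model weights excluded) and a 7B-class shape of 32 layers, head_dim 128, FP16; `max_context_tokens` is a hypothetical helper. The token count is shared across all sequences in a batch.

```python
def max_context_tokens(memory_gb: float, num_layers: int, num_kv_heads: int,
                       head_dim: int, dtype_bytes: int = 2) -> int:
    # K and V for every layer, for one token
    bytes_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    return int(memory_gb * 2**30) // bytes_per_token


for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("GQA-4", 4), ("MQA", 1)]:
    print(name, max_context_tokens(24, num_layers=32, num_kv_heads=kv_heads, head_dim=128))
# MHA 49152
# GQA-8 196608
# GQA-4 393216
# MQA 1572864
```

Under these assumptions a 24 GB budget holds only about 48K tokens with MHA but 1.5M with MQA, which is the intuition behind the "long_context" branch.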

Conclusion

GQA is an elegant interpolation between MHA's quality and MQA's efficiency. In the comparison above, GQA with 4-8 KV heads gives the best balance relative to 32-head MHA:

  • 4 KV heads: 8x KV cache reduction, ~0.06 perplexity increase
  • 8 KV heads: 4x KV cache reduction, ~0.02 perplexity increase

Recent models such as Llama 2 70B and Mistral 7B use GQA, and it is increasingly the default attention mechanism.