The attention mechanism’s KV cache grows linearly with sequence length and batch size, dominating memory in long-context inference. Different attention variants trade model quality for memory efficiency. Let’s analyze them quantitatively.

Multi-Head Attention (MHA)

The standard attention formulation:

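import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """
    Standard MHA: Each head has independent Q, K, V projections.
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads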
        
        # Separate projections for each head
        self.W_q = nn.Linear(d_model, d_model)  # [d_model, num_heads * head_dim]
        self.W_k = nn.Linear(d_model, d_model)  # [d_model, num_heads * head_dim]
        self.W_v = nn.Linear(d_model, d_model)  # [d_model, num_heads * head_dim]
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, self.num_heads, self.head_dim)
        V = self.W_v(x).view(B, L, self.num_heads, self.head_dim)
        
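        # KV cache size per layer: 2 * batch * seq_len * num_heads * head_dim * dtype_bytes
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=1)
            V = torch.cat([kv_cache[1], V], dim=1)

        # Attention scores: [B, num_heads, L, S], where S = cached + new tokens
        # (causal mask omitted for brevity; it is needed whenever L > 1)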
        scores = torch.einsum('blhd,bshd->bhls', Q, K) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V)
        
        return self.W_o(out.reshape(B, L, D)), (K, V)
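
To make the role of the cache concrete, here is a minimal single-head sketch in plain Python lists (no PyTorch, no projections). The helper names `attend`, `decode_with_cache`, and `recompute_from_scratch` are illustrative only. It shows that appending one K, V pair per token and attending over the cache gives the same outputs as recomputing causal attention over the whole prefix, while the cache grows by one entry per token.

```python
import math

def attend(q, keys, values):
    """Single-head softmax attention of one query over lists of key/value vectors."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def decode_with_cache(qs, ks, vs):
    """Append each new token's K, V to the cache, then attend over the cache."""
    k_cache, v_cache, outputs = [], [], []
    for q, k, v in zip(qs, ks, vs):
        k_cache.append(k)
        v_cache.append(v)
        outputs.append(attend(q, k_cache, v_cache))
    return outputs, len(k_cache)

def recompute_from_scratch(qs, ks, vs):
    """Causal attention recomputed over the full prefix at every step."""
    return [attend(qs[t], ks[:t + 1], vs[:t + 1]) for t in range(len(qs))]

# Toy inputs: 3 tokens, vectors of dimension 2
qs = [[1.0, 0.0], [0.5, -0.5], [0.0, 2.0]]
ks = [[0.2, 0.1], [1.0, -1.0], [0.3, 0.7]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

cached, cache_len = decode_with_cache(qs, ks, vs)
full = recompute_from_scratch(qs, ks, vs)
print(cache_len)       # one cached (K, V) pair per token
print(cached == full)  # same outputs with and without the cache
```

The cache trades memory for compute: each step reuses stored K, V instead of re-projecting the whole prefix.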

Memory per token per layer:

KV cache = 2 × num_heads × head_dim × dtype_bytes
         = 2 × 32 × 128 × 2 = 16 KB (for 32 heads, FP16)
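
The arithmetic above can be wrapped in a small helper. This is a minimal sketch; the function name `kv_bytes_per_token_per_layer` and the `DTYPE_BYTES` table are illustrative and not part of any library.

```python
# Illustrative helper; names are assumptions, not a library API.
DTYPE_BYTES = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}

def kv_bytes_per_token_per_layer(num_kv_heads: int, head_dim: int,
                                 dtype: str = "fp16") -> int:
    """Bytes of K plus V stored for one token in one layer."""
    return 2 * num_kv_heads * head_dim * DTYPE_BYTES[dtype]

mha_bytes = kv_bytes_per_token_per_layer(num_kv_heads=32, head_dim=128)
print(mha_bytes, "bytes =", mha_bytes // 1024, "KB")  # 16384 bytes = 16 KB
```

The same helper covers the other variants discussed here by changing only `num_kv_heads`.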

Multi-Query Attention (MQA)

MQA uses a single K, V head shared across all query heads:

class MultiQueryAttention(nn.Module):
    """
    MQA: Single K, V head shared by all Q heads.
    Proposed by Shazeer (2019) for faster inference.
    """
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Full projection for Q
        self.W_q = nn.Linear(d_model, d_model)
        # Single head projection for K, V
        self.W_k = nn.Linear(d_model, self.head_dim)  # Only 1 head!
        self.W_v = nn.Linear(d_model, self.head_dim)  # Only 1 head!
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, 1, self.head_dim)  # [B, L, 1, head_dim]
        V = self.W_v(x).view(B, L, 1, self.head_dim)
        
        # Append to the cache BEFORE broadcasting, so the cache keeps a single
        # KV head (num_heads x smaller than MHA, e.g. 32x for 32 heads)
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=1)
            V = torch.cat([kv_cache[1], V], dim=1)

        # Broadcast K, V to all heads (expand returns a view, no copy)
        K_all = K.expand(-1, -1, self.num_heads, -1)
        V_all = V.expand(-1, -1, self.num_heads, -1)

        # Rest is identical to MHA
        scores = torch.einsum('blhd,bshd->bhls', Q, K_all) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V_all)

        # Return the single-head K, V as the updated cache
        return self.W_o(out.reshape(B, L, D)), (K, V)

Memory per token per layer:

KV cache = 2 × 1 × head_dim × dtype_bytes
         = 2 × 1 × 128 × 2 = 512 bytes (32x reduction!)
⚠️ Quality Trade-off

MQA typically reduces model quality by roughly 1-3% on most benchmarks. The single KV head becomes an information bottleneck, especially for tasks that require fine-grained token interactions.

Grouped-Query Attention (GQA)

GQA interpolates between MHA and MQA:

class GroupedQueryAttention(nn.Module):
    """
    GQA: Groups of Q heads share K, V heads.
    Used by Llama 2 70B, Llama 3, Mistral, etc.
    """
    def __init__(self, d_model: int, num_heads: int, num_kv_heads: int):
        super().__init__()
        assert num_heads % num_kv_heads == 0, "num_heads must be divisible by num_kv_heads"
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_heads // num_kv_heads  # Q heads per KV head
        self.head_dim = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, num_heads * self.head_dim)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        K = self.W_k(x).view(B, L, self.num_kv_heads, self.head_dim)
        V = self.W_v(x).view(B, L, self.num_kv_heads, self.head_dim)
        
        # Append the new K, V to the cache BEFORE expanding, so the cache
        # holds only num_kv_heads heads
        if kv_cache is not None:
            K = torch.cat([kv_cache[0], K], dim=1)
            V = torch.cat([kv_cache[1], V], dim=1)

        # Expand K, V to match Q heads:
        # each KV head serves num_groups consecutive Q heads
        K_all = K.repeat_interleave(self.num_groups, dim=2)  # [B, S, num_heads, head_dim]
        V_all = V.repeat_interleave(self.num_groups, dim=2)

        scores = torch.einsum('blhd,bshd->bhls', Q, K_all) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V_all)

        # Return the unexpanded K, V (num_kv_heads heads) as the updated cache
        return self.W_o(out.reshape(B, L, D)), (K, V)
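
The head sharing implied by `repeat_interleave` can be written as a plain index mapping. The sketch below uses hypothetical helper names (`kv_head_for`, `head_map`) and shows MHA and MQA as the two endpoints of the same rule.

```python
from typing import List

def kv_head_for(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head that query head `q_head` reads from."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads  # Q heads per KV head
    return q_head // group_size

def head_map(num_heads: int, num_kv_heads: int) -> List[int]:
    """KV head index for every query head, in order."""
    return [kv_head_for(h, num_heads, num_kv_heads) for h in range(num_heads)]

print(head_map(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print(head_map(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print(head_map(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```

Consecutive query heads share a KV head, which is the same integer division a GQA-aware kernel can use to avoid materializing the expanded tensors.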

Llama 2 70B configuration:

  • num_heads = 64, num_kv_heads = 8
  • Compression ratio: 64/8 = 8x
  • Memory per token per layer: 2 × 8 × 128 × 2 = 4,096 bytes = 4 KB (FP16, head_dim = 128)
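
Scaling the per-layer figure to the whole model gives a feel for the total budget. The sketch below assumes Llama 2 70B's 80 layers and FP16 storage; the `KVConfig` class is a hypothetical helper, and the MHA row is a counterfactual with 64 KV heads, not a released model.

```python
from dataclasses import dataclass

@dataclass
class KVConfig:
    num_layers: int
    num_kv_heads: int
    head_dim: int
    dtype_bytes: int = 2  # FP16/BF16

    def bytes_per_token(self) -> int:
        # K and V, summed over all layers
        return 2 * self.num_layers * self.num_kv_heads * self.head_dim * self.dtype_bytes

    def cache_bytes(self, seq_len: int, batch: int = 1) -> int:
        return self.bytes_per_token() * seq_len * batch

llama2_70b = KVConfig(num_layers=80, num_kv_heads=8, head_dim=128)
mha_equivalent = KVConfig(num_layers=80, num_kv_heads=64, head_dim=128)  # counterfactual

for name, cfg in [("GQA (8 KV heads)", llama2_70b), ("MHA (64 KV heads)", mha_equivalent)]:
    gib = cfg.cache_bytes(seq_len=4096) / 2**30
    print(f"{name}: {cfg.bytes_per_token() // 1024} KiB/token, {gib:.2f} GiB at 4096 tokens")
```

Under these assumptions the GQA cache is 320 KiB per token (1.25 GiB for one 4,096-token sequence), versus 2,560 KiB per token (10 GiB) for the MHA counterfactual.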

Memory Comparison

KV Cache Size Comparison (Per Layer, Per Token)

Variant   Config         Bytes/Token   Reduction
MHA       32 heads       16,384        1x (baseline)
GQA       32 Q / 8 KV    4,096         4x
GQA       32 Q / 4 KV    2,048         8x
MQA       32 Q / 1 KV    512           32x

Note: FP16, head_dim=128

Maximum Context Length at Fixed Memory (80GB)

  • MHA (Llama 7B): 32K tokens
  • GQA 8x (Llama 70B): 128K tokens
  • MQA 32x: 256K tokens
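
The figures above do not state their assumptions (batch size, weight memory, layer count), so they cannot be reproduced directly. As a rough sketch of the underlying calculation, the hypothetical helper below divides the memory left after loading weights by the per-token cache size. The 7B-class parameters (32 layers, 32 heads, about 14 GB of FP16 weights) are illustrative assumptions, and activations and allocator overhead are ignored.

```python
def max_context_tokens(budget_bytes: int, weight_bytes: int, num_layers: int,
                       num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2, batch: int = 1) -> int:
    """Largest per-sequence context whose KV cache fits after weights are loaded."""
    per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    free = budget_bytes - weight_bytes
    if free <= 0:
        return 0  # the weights alone exceed the budget
    return free // (per_token * batch)

GB = 10**9
# Hypothetical 7B-class model on an 80 GB device, varying only the KV head count
for kv_heads in (32, 8, 1):
    tokens = max_context_tokens(80 * GB, 14 * GB, num_layers=32,
                                num_kv_heads=kv_heads, head_dim=128)
    print(f"{kv_heads:>2} KV heads: {tokens:,} tokens")
```

Because the per-token cost is linear in `num_kv_heads`, the maximum context scales inversely with it once weight memory is fixed.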

Multi-head Latent Attention (MLA)

DeepSeek-V2’s MLA compresses KV differently:

class MultiHeadLatentAttention(nn.Module):
    """
    MLA: Compresses KV into low-rank latent space.
    Key insight: K, V have low intrinsic dimensionality.
    Simplified sketch: the rotary embedding itself is not applied here.
    """
    def __init__(self, d_model: int, num_heads: int, 
                 kv_latent_dim: int, rope_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_latent_dim = kv_latent_dim  # e.g., 512
        self.rope_dim = rope_dim  # Rotary embedding dimension
        
        # Q projection (standard)
        self.W_q = nn.Linear(d_model, d_model)
        
        # Compress KV to latent space
        self.W_kv_compress = nn.Linear(d_model, kv_latent_dim)  # Down-project
        
        # Expand from latent to K, V
        self.W_k_expand = nn.Linear(kv_latent_dim, d_model)
        self.W_v_expand = nn.Linear(kv_latent_dim, d_model)
        
        # Separate RoPE keys (not compressed)
        self.W_k_rope = nn.Linear(d_model, rope_dim * num_heads)
        
        self.W_o = nn.Linear(d_model, d_model)
    
    def forward(self, x, kv_cache=None):
        B, L, D = x.shape
        
        Q = self.W_q(x).view(B, L, self.num_heads, self.head_dim)
        
        # Compress to latent
        kv_latent = self.W_kv_compress(x)  # [B, L, kv_latent_dim]
        
        # RoPE keys (for position encoding)
        k_rope = self.W_k_rope(x).view(B, L, self.num_heads, self.rope_dim)
        
        # KV cache stores: latent + k_rope (much smaller than full K, V!)
        if kv_cache is not None:
            kv_latent = torch.cat([kv_cache[0], kv_latent], dim=1)
            k_rope = torch.cat([kv_cache[1], k_rope], dim=1)
        
        # Expand K, V from latent (done at attention time)
        K_content = self.W_k_expand(kv_latent).view(B, -1, self.num_heads, self.head_dim)
        V = self.W_v_expand(kv_latent).view(B, -1, self.num_heads, self.head_dim)
        
        # Combine content K with RoPE K
        K = torch.cat([K_content[..., :self.head_dim-self.rope_dim], 
                       k_rope], dim=-1)
        
        # Standard attention
        scores = torch.einsum('blhd,bshd->bhls', Q, K) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.einsum('bhls,bshd->blhd', attn, V)
        
        return self.W_o(out.reshape(B, L, D)), (kv_latent, k_rope)

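Memory per token per layer (DeepSeek-V2-style example values):

KV cache = (kv_latent_dim + rope_dim × num_heads) × dtype_bytes
         = (512 + 64 × 128) × 2  # Example values: 128 heads, FP16
         = 8,704 elements × 2 bytes ≈ 17 KB (vs 2 × 128 × 128 × 2 bytes = 64 KB for equivalent MHA)

This count matches the code above, which keeps one RoPE key per head. DeepSeek-V2 itself shares a single RoPE key across all heads, so its cache is 512 + 64 = 576 elements per token per layer.
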
ℹ️ MLA Trade-off

MLA adds compute (the latent expansion) but reduces the number of bytes read from the KV cache at each step. In the decode phase, which is typically memory-bandwidth-bound, this trade is often beneficial.
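
A rough sketch of both sides of this trade-off, using the example values above (latent dimension 512, RoPE dimension 64, 128 heads) and assuming a head dimension of 128 and FP16 storage. The function names are hypothetical, and the FLOP count is a naive estimate for re-expanding K and V from the latent at one decode step.

```python
def mla_cache_elems(kv_latent_dim: int, rope_dim: int, rope_heads: int) -> int:
    """Elements cached per token per layer: the latent plus the RoPE key(s)."""
    return kv_latent_dim + rope_dim * rope_heads

def mha_cache_elems(num_heads: int, head_dim: int) -> int:
    """Elements cached per token per layer for full K and V."""
    return 2 * num_heads * head_dim

def expansion_flops(kv_latent_dim: int, kv_width: int, cached_tokens: int) -> int:
    """Naive FLOPs to expand K and V from the latent for all cached tokens
    (two linear maps, 2 FLOPs per multiply-add)."""
    return 2 * 2 * cached_tokens * kv_latent_dim * kv_width

NUM_HEADS, HEAD_DIM, LATENT, ROPE_DIM = 128, 128, 512, 64  # assumed example values
FP16 = 2  # bytes per element

per_head_rope = mla_cache_elems(LATENT, ROPE_DIM, rope_heads=NUM_HEADS)
shared_rope = mla_cache_elems(LATENT, ROPE_DIM, rope_heads=1)
mha = mha_cache_elems(NUM_HEADS, HEAD_DIM)

# Bytes per token per layer: per-head RoPE keys, one shared RoPE key, full MHA
print(per_head_rope * FP16, shared_rope * FP16, mha * FP16)
# Extra compute per decode step with 4096 cached tokens
print(expansion_flops(LATENT, NUM_HEADS * HEAD_DIM, cached_tokens=4096))
```

The `rope_heads` parameter separates the per-head RoPE keys used in the code above from a design that shares one RoPE key across heads, which shrinks the cache much further.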

Quality-Memory Trade-off

Attention Variant Benchmark Results

Variant          MMLU    HumanEval   KV Memory
MHA (baseline)   70.2%   67.5%       100%
GQA (8 groups)   69.8%   66.8%       12.5%
GQA (4 groups)   69.4%   66.1%       6.25%
MQA              68.1%   63.4%       3.1%
MLA              70.0%   67.2%       ~25%

Note: 70B-class models, normalized memory relative to MHA

Implementation Considerations

Efficient GQA Kernel

// GQA-aware attention kernel
template<int NUM_Q_HEADS, int NUM_KV_HEADS, int HEAD_DIM>
__global__ void gqa_attention_kernel(
    const half* Q,        // [batch, seq_q, num_q_heads, head_dim]
    const half* K_cache,  // [batch, seq_kv, num_kv_heads, head_dim]
    const half* V_cache,  // [batch, seq_kv, num_kv_heads, head_dim]
    half* output,
    int seq_q, int seq_kv
) {
    constexpr int HEADS_PER_GROUP = NUM_Q_HEADS / NUM_KV_HEADS;
    
    int batch_idx = blockIdx.x;
    int q_head_idx = blockIdx.y;
    int kv_head_idx = q_head_idx / HEADS_PER_GROUP;  // Map Q head to KV head
    
    // Load Q for this head
    half q_reg[HEAD_DIM];
    load_q(Q, q_reg, batch_idx, q_head_idx);
    
    // Iterate over K, V (using shared KV head)
    float acc[HEAD_DIM] = {0};
    float max_score = -INFINITY;
    float sum_exp = 0;
    
    for (int kv_pos = 0; kv_pos < seq_kv; kv_pos++) {
        // Load from KV head (not Q head!)
        half k_reg[HEAD_DIM], v_reg[HEAD_DIM];
        load_kv(K_cache, V_cache, k_reg, v_reg, batch_idx, kv_head_idx, kv_pos);
        
        // Compute attention score
        float score = dot_product(q_reg, k_reg) / sqrtf(HEAD_DIM);
        
        // Online softmax
        float new_max = fmaxf(max_score, score);
        float exp_diff = expf(max_score - new_max);
        float exp_score = expf(score - new_max);
        
        // Update accumulator
        for (int d = 0; d < HEAD_DIM; d++) {
            acc[d] = acc[d] * exp_diff + exp_score * __half2float(v_reg[d]);
        }
        sum_exp = sum_exp * exp_diff + exp_score;
        max_score = new_max;
    }
    
    // Normalize and store
    for (int d = 0; d < HEAD_DIM; d++) {
        output[...] = __float2half(acc[d] / sum_exp);
    }
}
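
The online softmax update in the kernel loop can be checked on the host. The sketch below mirrors the loop in plain Python (same variable names, hypothetical function names) and compares it with a conventional two-pass softmax-weighted sum.

```python
import math

def online_softmax_weighted_sum(scores, values):
    """Single pass over (score, value-vector) pairs, as in the kernel loop."""
    acc = [0.0] * len(values[0])
    max_score = -math.inf
    sum_exp = 0.0
    for score, v in zip(scores, values):
        new_max = max(max_score, score)
        exp_diff = math.exp(max_score - new_max)  # rescales old state; 0.0 on the first step
        exp_score = math.exp(score - new_max)
        acc = [a * exp_diff + exp_score * x for a, x in zip(acc, v)]
        sum_exp = sum_exp * exp_diff + exp_score
        max_score = new_max
    return [a / sum_exp for a in acc]

def two_pass_reference(scores, values):
    """Conventional softmax: find the max first, then normalize."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

scores = [1.0, 3.0, -2.0, 0.5]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]
online = online_softmax_weighted_sum(scores, values)
reference = two_pass_reference(scores, values)
print(max(abs(a - b) for a, b in zip(online, reference)))  # tiny floating-point difference
```

Rescaling the accumulator by `exp_diff` whenever the running maximum changes is what lets the kernel avoid a separate pass to find the maximum score.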

Memory Layout for GQA

def optimal_gqa_kv_layout(num_kv_heads, head_dim, max_seq_len, max_batch):
    """
    Choose memory layout optimizing for cache locality.
    """
    # Option 1: [batch, seq, kv_heads, head_dim] - good for sequential access
    # Option 2: [batch, kv_heads, seq, head_dim] - good for head parallelism
    
    # For most GQA implementations, Option 1 is better because:
    # - Multiple Q heads access same KV head sequentially
    # - Better L2 cache utilization
    
    return {
        'k_cache_shape': (max_batch, max_seq_len, num_kv_heads, head_dim),
        'v_cache_shape': (max_batch, max_seq_len, num_kv_heads, head_dim),
        'memory_per_token': 2 * num_kv_heads * head_dim * 2,  # FP16
    }
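
The two layouts differ only in how a logical index maps to a flat offset. The sketch below (hypothetical helper names, contiguous row-major storage assumed) makes the strides explicit. Which layout is faster in practice depends on the kernel's access pattern, so this only illustrates the addressing.

```python
def offset_bshd(b, s, h, d, seq_len, num_kv_heads, head_dim):
    """Flat element offset in a contiguous [batch, seq, kv_heads, head_dim] array (Option 1)."""
    return ((b * seq_len + s) * num_kv_heads + h) * head_dim + d

def offset_bhsd(b, s, h, d, seq_len, num_kv_heads, head_dim):
    """Flat element offset in a contiguous [batch, kv_heads, seq, head_dim] array (Option 2)."""
    return ((b * num_kv_heads + h) * seq_len + s) * head_dim + d

shape = dict(seq_len=16, num_kv_heads=8, head_dim=128)

# Distance (in elements) between consecutive tokens of the same KV head
stride_opt1 = offset_bshd(0, 1, 0, 0, **shape) - offset_bshd(0, 0, 0, 0, **shape)
stride_opt2 = offset_bhsd(0, 1, 0, 0, **shape) - offset_bhsd(0, 0, 0, 0, **shape)
print(stride_opt1, stride_opt2)  # 1024 128

# Distance between adjacent KV heads of the same token
head_stride_opt1 = offset_bshd(0, 0, 1, 0, **shape) - offset_bshd(0, 0, 0, 0, **shape)
head_stride_opt2 = offset_bhsd(0, 0, 1, 0, **shape) - offset_bhsd(0, 0, 0, 0, **shape)
print(head_stride_opt1, head_stride_opt2)  # 128 2048
```

In Option 1 all KV heads of a token sit next to each other, so appending a new token is one contiguous write. In Option 2 each head's sequence is contiguous instead.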

Recommendations

Use Case              Recommendation     Rationale
Quality-critical      MHA or MLA         Preserve attention capacity
Long context          GQA (4-8 groups)   Good quality/memory balance
Memory-constrained    MQA                Maximum compression
Inference-optimized   GQA (8 groups)     Industry standard
💡 Practical Guidance

For most production deployments, GQA with 8 KV heads (matching Llama 2/3 70B configuration) offers the best balance. Start there and only move to MQA if memory constraints are severe.
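
As a design-time illustration of this guidance, the sketch below starts from a preferred KV head count and steps down only when the cache would not fit a given budget. The function `pick_num_kv_heads` and its parameters are hypothetical. The number of KV heads is fixed when a model is trained, so this is a sizing aid for choosing an architecture, not something to change at deployment.

```python
def pick_num_kv_heads(num_heads, num_layers, head_dim, seq_len, batch,
                      budget_bytes, dtype_bytes=2, preferred=8):
    """Largest divisor of num_heads that is <= preferred and whose KV cache fits.

    Returns None if even a single KV head (MQA) exceeds the budget.
    """
    start = min(preferred, num_heads)
    candidates = [k for k in range(start, 0, -1) if num_heads % k == 0]
    for k in candidates:
        need = 2 * num_layers * k * head_dim * dtype_bytes * seq_len * batch
        if need <= budget_bytes:
            return k
    return None

GiB = 2**30
# Hypothetical 70B-class shape: 64 Q heads, 80 layers, head_dim 128, 4096-token context
print(pick_num_kv_heads(64, 80, 128, seq_len=4096, batch=1, budget_bytes=2 * GiB))  # 8
print(pick_num_kv_heads(64, 80, 128, seq_len=4096, batch=1, budget_bytes=1 * GiB))  # 4
```

With a 2 GiB cache budget the preferred 8 KV heads fit (1.25 GiB), while a 1 GiB budget forces a step down to 4.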

Conclusion

The evolution from MHA → MQA → GQA → MLA reflects the increasing importance of memory efficiency in LLM deployment. GQA has emerged as the practical default, offering a 4-8x memory reduction for a quality loss of roughly one point or less in the benchmarks above. Understanding these trade-offs is essential for architecting inference systems.