Large models don’t fit on a single GPU. Tensor parallelism splits individual layers across GPUs, requiring careful placement of AllReduce operations to maintain correctness while minimizing communication overhead.
Column vs Row Parallelism
For a linear layer Y = XW + b, two parallelization strategies exist:
Column Parallelism: Split W along columns
W = [W₁ | W₂] (each GPU holds one column partition)
Y = X[W₁ | W₂] = [XW₁ | XW₂]
Result: Each GPU computes a column slice of Y; an AllGather is needed only when the full Y must be materialized on every GPU
Row Parallelism: Split W along rows
W = [W₁ ; W₂] (row blocks stacked vertically), X = [X₁ | X₂] (input split along columns to match)
Y = X₁W₁ + X₂W₂
Result: Each GPU computes partial sum, needs AllReduce
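Both identities can be checked numerically on a single machine. The sketch below is a pure-Python illustration with two simulated "GPUs"; the helper names (`matmul`, `split_cols`, `split_rows`) and the matrix values are made up for this example, and list concatenation and element-wise summation stand in for AllGather and AllReduce.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_cols(m, parts):
    """Split a matrix into `parts` equal column blocks."""
    width = len(m[0]) // parts
    return [[row[i * width:(i + 1) * width] for row in m] for i in range(parts)]

def split_rows(m, parts):
    """Split a matrix into `parts` equal row blocks."""
    height = len(m) // parts
    return [m[i * height:(i + 1) * height] for i in range(parts)]

X = [[1, 2, 3, 4],
     [5, 6, 7, 8]]          # 2 x 4 input
W = [[1, 0, 2, 1],
     [0, 1, 1, 3],
     [2, 1, 0, 1],
     [1, 2, 1, 0]]          # 4 x 4 weight
Y = matmul(X, W)            # unsharded reference

# Column parallelism: every "GPU" sees the full X and one column block of W.
col_outputs = [matmul(X, w_i) for w_i in split_cols(W, 2)]
# AllGather stand-in: concatenate the column slices row by row.
Y_col = [sum((out[r] for out in col_outputs), []) for r in range(len(X))]

# Row parallelism: each "GPU" sees one column block of X and one row block of W.
partials = [matmul(x_i, w_i)
            for x_i, w_i in zip(split_cols(X, 2), split_rows(W, 2))]
# AllReduce stand-in: element-wise sum of the partial results.
Y_row = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

print(Y_col == Y, Y_row == Y)
```

The column-parallel partial results are disjoint slices of Y, while the row-parallel partial results are full-sized matrices that only become Y once summed.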
Transformer Block Strategy
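The key insight is to pair a column-parallel layer with a row-parallel layer, so that the column-parallel output feeds the row-parallel input with no communication in between: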
    import torch.nn as nn
    import torch.nn.functional as F

    # ColumnParallelLinear and RowParallelLinear are assumed to be defined
    # elsewhere; both take full (unsharded) dimensions and shard internally.

    class TensorParallelMLP(nn.Module):
        """
        MLP with tensor parallelism.
        Gate and up projections: column parallel (no communication)
        Down projection: row parallel (AllReduce after)
        """
        def __init__(self, hidden_dim, ffn_dim, tp_size, tp_rank):
            super().__init__()  # must run before submodules are assigned
            self.tp_size = tp_size
            self.tp_rank = tp_rank
            # Column parallel: each GPU has ffn_dim // tp_size columns
            self.gate_proj = ColumnParallelLinear(hidden_dim, ffn_dim, tp_size, tp_rank)
            self.up_proj = ColumnParallelLinear(hidden_dim, ffn_dim, tp_size, tp_rank)
            # Row parallel: each GPU has ffn_dim // tp_size rows
            self.down_proj = RowParallelLinear(ffn_dim, hidden_dim, tp_size, tp_rank)

        def forward(self, x):
            # Column parallel projections (no communication needed)
            gate = F.silu(self.gate_proj(x))
            up = self.up_proj(x)
            # Element-wise multiply (local operation on this GPU's slice)
            hidden = gate * up
            # Row parallel projection + AllReduce
            output = self.down_proj(hidden)  # AllReduce inside
            return output
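This design needs only one AllReduce per MLP block, after down_proj: SiLU and the element-wise multiply act on each column slice independently, so the column-parallel outputs feed the row-parallel input directly. Synchronizing after every projection instead would take three collectives.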
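The claim that a single reduction suffices can be verified with a small single-process simulation. This is a sketch under stated assumptions, not the layer implementation: the helpers (`swiglu_mlp`, `col_shard`, `row_shard`) and the tiny weight values are hypothetical, and a plain element-wise sum plays the role of the one AllReduce.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def silu(v):
    return v / (1.0 + math.exp(-v))

def swiglu_mlp(x, w_gate, w_up, w_down):
    """down( silu(x @ w_gate) * (x @ w_up) ), with no parallelism."""
    gate = [[silu(v) for v in row] for row in matmul(x, w_gate)]
    up = matmul(x, w_up)
    hidden = [[g * u for g, u in zip(g_row, u_row)] for g_row, u_row in zip(gate, up)]
    return matmul(hidden, w_down)

def col_shard(w, rank, tp):
    """Column block of w owned by `rank`."""
    n = len(w[0]) // tp
    return [row[rank * n:(rank + 1) * n] for row in w]

def row_shard(w, rank, tp):
    """Row block of w owned by `rank`."""
    n = len(w) // tp
    return w[rank * n:(rank + 1) * n]

# Toy sizes: hidden_dim=2, ffn_dim=4, tp=2 (values are arbitrary).
x = [[0.5, -1.0]]
w_gate = [[0.1, 0.2, -0.3, 0.4], [0.5, -0.6, 0.7, 0.8]]
w_up = [[1.0, -0.5, 0.25, 0.0], [0.3, 0.6, -0.9, 1.2]]
w_down = [[0.2, -0.1], [0.4, 0.3], [-0.5, 0.6], [0.7, 0.8]]
tp = 2

reference = swiglu_mlp(x, w_gate, w_up, w_down)

# Each rank runs the whole MLP on its own shards, with no communication.
partials = [
    swiglu_mlp(x, col_shard(w_gate, r, tp), col_shard(w_up, r, tp), row_shard(w_down, r, tp))
    for r in range(tp)
]
# The single AllReduce: sum the per-rank partial outputs.
output = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]

max_err = max(abs(a - b) for ra, rb in zip(output, reference) for a, b in zip(ra, rb))
print(max_err)
```

The per-rank results differ from the reference individually; only their sum matches, which is why the reduction cannot be dropped but also need not happen earlier.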
AllReduce Optimization
AllReduce performance is critical. For N GPUs with message size M:
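Ring AllReduce: 2(N-1)/N × M data transferred per GPU, in 2(N-1) sequential steps
Tree AllReduce: 2 log(N) × M data transferred, but only about 2 log₂(N) sequential steps, so it tends to win for small, latency-bound messages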
    def estimate_allreduce_time(
        message_bytes: int,
        num_gpus: int,
        bandwidth_gbps: float,   # In GB/s (gigabytes, not gigabits). NVLink: ~600 GB/s bidirectional
        latency_us: float = 5.0  # Per-hop latency
    ) -> float:
        """Estimate ring AllReduce time in microseconds."""
        # Ring AllReduce: bytes each GPU transfers over the whole operation
        bytes_per_gpu = 2 * (num_gpus - 1) / num_gpus * message_bytes
        # GB/s * 1e9 / 1e6 = bytes per microsecond
        transfer_time_us = bytes_per_gpu / (bandwidth_gbps * 1e9 / 1e6)
        # 2(N-1) communication steps
        total_latency_us = 2 * (num_gpus - 1) * latency_us
        return transfer_time_us + total_latency_us

    # Example: 70B model, 8 GPUs, hidden_dim=8192, one token
    hidden_bytes = 8192 * 2  # FP16
    allreduce_time = estimate_allreduce_time(hidden_bytes, 8, 600)
    # ~70 microseconds per AllReduce: 14 steps x 5 us of latency,
    # plus less than 0.1 us of actual data transfer
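The estimate has a fixed latency term and a size-dependent transfer term, so it is worth seeing where one overtakes the other. The sketch below reuses the same ring formula under the same assumed constants (600 GB/s, 5 µs per hop); the function names `ring_allreduce_terms_us` and `crossover_bytes` are hypothetical helpers for this illustration, not measurements.

```python
def ring_allreduce_terms_us(message_bytes, num_gpus, bandwidth_gb_per_s, latency_us=5.0):
    """Return (transfer_us, latency_us) for the ring AllReduce model."""
    bytes_per_us = bandwidth_gb_per_s * 1e3          # GB/s -> bytes per microsecond
    transfer = 2 * (num_gpus - 1) / num_gpus * message_bytes / bytes_per_us
    latency = 2 * (num_gpus - 1) * latency_us
    return transfer, latency

def crossover_bytes(num_gpus, bandwidth_gb_per_s, latency_us=5.0):
    """Message size at which transfer time equals total latency."""
    latency = 2 * (num_gpus - 1) * latency_us
    return latency * bandwidth_gb_per_s * 1e3 / (2 * (num_gpus - 1) / num_gpus)

hidden_dim, bytes_per_elem = 8192, 2                 # FP16 activations
for tokens in (1, 64, 4096):
    transfer, latency = ring_allreduce_terms_us(tokens * hidden_dim * bytes_per_elem, 8, 600)
    bound = "latency" if latency > transfer else "bandwidth"
    print(f"{tokens:>5} tokens: transfer={transfer:8.2f} us, latency={latency:5.1f} us ({bound}-bound)")

threshold = crossover_bytes(8, 600)
print(f"crossover at {threshold / 1e6:.0f} MB")
```

Under these assumptions, single-token decode steps sit far below the crossover, so per-hop latency rather than link bandwidth sets the AllReduce cost there.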
[Figure: AllReduce Time vs Message Size (8× A100 NVLink); y-axis: time (µs)]
Attention Tensor Parallelism
Attention parallelizes naturally across heads:
    class TensorParallelAttention(nn.Module):
        """
        Attention with head-wise tensor parallelism.
        Each GPU handles num_heads // tp_size heads.
        """
        def __init__(self, hidden_dim, num_heads, head_dim, tp_size, tp_rank):
            super().__init__()  # must run before submodules are assigned
            assert num_heads % tp_size == 0, "num_heads must be divisible by tp_size"
            self.tp_size = tp_size
            self.tp_rank = tp_rank
            self.head_dim = head_dim  # needed by forward()
            self.num_local_heads = num_heads // tp_size
            # Each GPU projects to its subset of heads. As in the MLP, the full
            # output size is passed and the layer keeps 1/tp_size of the columns.
            # The fused weight must be laid out so that each rank's shard holds
            # the Q, K and V columns of its own heads.
            self.qkv_proj = ColumnParallelLinear(
                hidden_dim,
                3 * num_heads * head_dim,
                tp_size, tp_rank
            )
            # Output projection with AllReduce
            self.o_proj = RowParallelLinear(
                num_heads * head_dim,  # Conceptual full size
                hidden_dim,
                tp_size, tp_rank
            )

        def forward(self, x, kv_cache=None):
            # Local QKV projection (no communication)
            qkv = self.qkv_proj(x)
            q, k, v = qkv.split(self.num_local_heads * self.head_dim, dim=-1)
            # Local attention computation: each GPU computes attention for its
            # heads only (_compute_attention is assumed to be defined elsewhere)
            attn_output = self._compute_attention(q, k, v, kv_cache)
            # Output projection + AllReduce
            output = self.o_proj(attn_output)
            return output
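Head-wise partitioning comes down to assigning each rank a contiguous block of heads and the matching columns of the projection weights. The helpers below are a minimal sketch of that bookkeeping, assuming a contiguous assignment; `local_head_range` and `local_column_slice` are hypothetical names introduced for this example.

```python
def local_head_range(num_heads: int, tp_size: int, tp_rank: int) -> range:
    """Indices of the attention heads owned by `tp_rank` (contiguous blocks)."""
    if num_heads % tp_size != 0:
        raise ValueError(f"num_heads={num_heads} is not divisible by tp_size={tp_size}")
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank={tp_rank} is outside [0, {tp_size})")
    per_rank = num_heads // tp_size
    return range(tp_rank * per_rank, (tp_rank + 1) * per_rank)

def local_column_slice(num_heads: int, head_dim: int, tp_size: int, tp_rank: int) -> slice:
    """Columns of a (hidden_dim, num_heads * head_dim) weight owned by `tp_rank`."""
    heads = local_head_range(num_heads, tp_size, tp_rank)
    return slice(heads.start * head_dim, heads.stop * head_dim)

# Example: 64 heads of size 128 split over 8 ranks.
for rank in (0, 3, 7):
    heads = local_head_range(64, 8, rank)
    cols = local_column_slice(64, 128, 8, rank)
    print(f"rank {rank}: heads {heads.start}..{heads.stop - 1}, columns {cols.start}..{cols.stop - 1}")
```

Raising on a non-divisible head count makes the constraint explicit; with grouped-query attention the same check would apply separately to the number of KV heads.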
KV Cache with Tensor Parallelism
Each GPU stores KV cache for its local heads only:
    import torch

    class TensorParallelKVCache:
        """
        Distributed KV cache for tensor parallel inference.
        The tensors have no layer dimension, so one instance covers one layer.
        """
        def __init__(self, config, tp_size, tp_rank):
            assert config.num_kv_heads % tp_size == 0, "num_kv_heads must be divisible by tp_size"
            self.tp_size = tp_size
            self.tp_rank = tp_rank
            # Local heads
            self.num_local_kv_heads = config.num_kv_heads // tp_size
            # Each GPU allocates cache for local heads only
            self.k_cache = torch.zeros(
                config.max_batch_size,
                self.num_local_kv_heads,  # Not full heads!
                config.max_seq_len,
                config.head_dim,
                dtype=config.dtype,
                device=f'cuda:{tp_rank}'  # assumes a single node where device index == tp_rank
            )
            self.v_cache = torch.zeros_like(self.k_cache)

        def get_memory_per_gpu(self) -> int:
            """Memory usage per GPU in bytes (reduced by tp_size)."""
            # Factor of 2 accounts for the K and V caches
            return self.k_cache.numel() * 2 * self.k_cache.element_size()
Memory per GPU with Tensor Parallelism (Llama-70B)
| TP Size | Weights/GPU | KV Cache/GPU | Total/GPU |
|---|---|---|---|
| 1 (no TP) | 140GB | 40GB | 180GB |
| 2 | 70GB | 20GB | 90GB |
| 4 | 35GB | 10GB | 45GB |
| 8 | 17.5GB | 5GB | 22.5GB |
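The KV cache column follows from a simple product of dimensions. The sketch below works that arithmetic through for an illustrative configuration (80 layers, 8 KV heads, head_dim 128, batch 8, 4096-token context, FP16); these values are assumptions chosen for the example and are not the settings behind the table, and `kv_cache_bytes_per_gpu` is a hypothetical helper.

```python
def kv_cache_bytes_per_gpu(num_layers, num_kv_heads, head_dim,
                           max_batch_size, max_seq_len, bytes_per_elem, tp_size):
    """Bytes of KV cache held by one GPU when KV heads are split across tp_size GPUs."""
    if num_kv_heads % tp_size != 0:
        raise ValueError("num_kv_heads must be divisible by tp_size")
    local_kv_heads = num_kv_heads // tp_size
    # Factor of 2: one tensor for keys, one for values.
    per_layer = 2 * max_batch_size * local_kv_heads * max_seq_len * head_dim * bytes_per_elem
    return num_layers * per_layer

config = dict(num_layers=80, num_kv_heads=8, head_dim=128,
              max_batch_size=8, max_seq_len=4096, bytes_per_elem=2)  # illustrative values

for tp in (1, 2, 4, 8):
    gib = kv_cache_bytes_per_gpu(tp_size=tp, **config) / 2**30
    print(f"TP={tp}: {gib:.2f} GiB of KV cache per GPU")
```

Because every factor except the local head count is unchanged, the per-GPU cache shrinks exactly in proportion to the TP size, as long as the number of KV heads is divisible by it.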
Efficiency Analysis
Tensor parallelism overhead comes from AllReduce operations:
    def calculate_tp_efficiency(
        model_config,
        tp_size: int,
        batch_size: int,
        seq_len: int
    ) -> dict:
        """
        Calculate tensor parallelism efficiency.
        Relies on estimate_allreduce_time() defined earlier.
        """
        # Compute time (scales inversely with TP size). The FLOPs cover every
        # token in the batch so they match the AllReduce message size below.
        num_tokens = batch_size * seq_len
        flops_per_token = model_config.estimate_flops()
        local_flops = flops_per_token * num_tokens / tp_size
        gpu_tflops = 312  # A100 FP16 peak
        compute_time_us = local_flops / (gpu_tflops * 1e12) * 1e6
        # Communication time
        allreduce_per_layer = 2  # One for attention, one for MLP
        allreduce_bytes = num_tokens * model_config.hidden_dim * 2  # FP16
        comm_time_per_layer = estimate_allreduce_time(allreduce_bytes, tp_size, 600)
        total_comm_time = comm_time_per_layer * allreduce_per_layer * model_config.num_layers
        # Efficiency: fraction of wall time spent on useful compute
        total_time = compute_time_us + total_comm_time
        efficiency = compute_time_us / total_time
        return {
            'compute_time_us': compute_time_us,
            'comm_time_us': total_comm_time,
            'efficiency': efficiency,
            'speedup': tp_size * efficiency
        }
[Figure: TP Efficiency vs TP Size (Llama-70B, batch=1); y-axis: efficiency (%)]
Beyond TP=8, communication overhead dominates. For higher parallelism, combine with pipeline parallelism or use larger batch sizes to amortize AllReduce cost.
Conclusion
Tensor parallelism enables single-model inference across multiple GPUs with:
- Linear memory scaling: Each GPU holds 1/N of weights and KV cache
- Sub-linear throughput scaling: AllReduce overhead increases with TP size
- Sweet spot at TP=4-8 for most model sizes
For models that don’t fit on 8 GPUs, combine with pipeline parallelism. For throughput-focused workloads, use larger batches to amortize communication.