Tensor Parallelism Implementation: Splitting Weights Across GPUs for Training and Inference

Part of the Transformer Anatomy series (36 articles).

DDP replicates the entire model on every GPU. For Llama 70B (140 GB in FP16), every GPU would need 140 GB just for the weights — impossible on a single 80 GB H100. Tensor parallelism (TP) instead splits each layer's weight matrices across N GPUs: each GPU holds 1/N of the weights and performs 1/N of the computation, and an all-reduce synchronizes the results.

Column-Parallel Linear

Split the weight matrix by columns. Each GPU computes a partial output:

import torch
import torch.nn as nn
import torch.distributed as dist

class ColumnParallelLinear(nn.Module):
    """Split weight columns across TP ranks. No communication in forward.

    Full weight: [in_features, out_features]
    Per-GPU weight: [in_features, out_features // tp_size]
    """
    def __init__(self, in_features, out_features, tp_size, tp_rank, bias=False):
        super().__init__()
        assert out_features % tp_size == 0
        self.out_per_rank = out_features // tp_size
        self.weight = nn.Parameter(
            torch.randn(in_features, self.out_per_rank) * 0.02
        )
        self.bias = nn.Parameter(torch.zeros(self.out_per_rank)) if bias else None

    def forward(self, x):
        # x: [B, S, in_features] — same on all GPUs
        # output: [B, S, out_features // tp_size] — different per GPU
        output = x @ self.weight
        if self.bias is not None:
            output = output + self.bias
        return output

No communication needed: each GPU multiplies the full input by its column shard independently.
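This claim is easy to check on a single process with plain NumPy (toy sizes, illustrative only): multiplying the full input by each column shard and concatenating the per-rank outputs reproduces the unsharded matmul exactly.

```python
import numpy as np

# Single-process simulation of column parallelism (illustrative toy sizes).
rng = np.random.default_rng(0)
in_f, out_f, tp = 16, 8, 4
W = rng.standard_normal((in_f, out_f))   # full weight [in_features, out_features]
x = rng.standard_normal((2, 3, in_f))    # [B, S, in_features], replicated on all "ranks"

shard = out_f // tp
# Each simulated rank multiplies the full input by its column shard...
partials = [x @ W[:, r * shard:(r + 1) * shard] for r in range(tp)]
# ...and concatenating the per-rank outputs recovers the full result.
full = np.concatenate(partials, axis=-1)

assert np.allclose(full, x @ W)
```

Concatenation needs no summation, which is exactly why the forward pass is communication-free here.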

Row-Parallel Linear

Split the weight matrix by rows. Each GPU gets a partial input and computes a partial result. All-reduce sums the partial results:

class RowParallelLinear(nn.Module):
    """Split weight rows across TP ranks. All-reduce in forward.

    Full weight: [in_features, out_features]
    Per-GPU weight: [in_features // tp_size, out_features]
    """
    def __init__(self, in_features, out_features, tp_size, tp_rank, bias=False):
        super().__init__()
        assert in_features % tp_size == 0
        self.in_per_rank = in_features // tp_size
        self.weight = nn.Parameter(
            torch.randn(self.in_per_rank, out_features) * 0.02
        )
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        self.tp_group = None  # TP process group; None falls back to the default (world) group

    def forward(self, x):
        # x: [B, S, in_features // tp_size] — partial input (from column-parallel)
        # Partial output: [B, S, out_features]
        output = x @ self.weight

        # All-reduce: sum partial outputs across all TP ranks
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.tp_group)

        if self.bias is not None:
            output = output + self.bias
        return output
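The all-reduce here is just a distributed summation, so the math can be sanity-checked on one process (NumPy sketch, toy sizes, illustrative only): summing the per-shard partial products equals the unsharded matmul.

```python
import numpy as np

# Single-process simulation of row parallelism (illustrative toy sizes).
rng = np.random.default_rng(0)
in_f, out_f, tp = 16, 8, 4
W = rng.standard_normal((in_f, out_f))   # full weight [in_features, out_features]
x = rng.standard_normal((2, 3, in_f))    # full input; each rank only sees a slice

shard = in_f // tp
# Each simulated rank multiplies its input slice by its row shard; the SUM
# of the partial outputs is exactly what dist.all_reduce computes.
partial_sum = sum(
    x[..., r * shard:(r + 1) * shard] @ W[r * shard:(r + 1) * shard, :]
    for r in range(tp)
)

assert np.allclose(partial_sum, x @ W)
```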

The Megatron Pattern

The key insight from Megatron-LM: pair a column-parallel first linear with a row-parallel second linear. The column-parallel output (sharded along d_ff) feeds the row-parallel input directly, so each sublayer needs only ONE all-reduce in the forward pass (and one in the backward pass):

Megatron TP Pattern for FFN

Stage | Per-GPU shape | Communication
Input x (replicated) | [B, S, d_model] | None — same on all TP ranks
Column-parallel W_up | weight [d_model, d_ff/N] | None
Activation (SiLU) | [B, S, d_ff/N] | None — applied locally
Row-parallel W_down | weight [d_ff/N, d_model] | ALL-REDUCE after matmul
Output (replicated) | [B, S, d_model] | Synchronized by the all-reduce

class TPSwiGLUFFN(nn.Module):
    """Tensor-parallel SwiGLU FFN using Megatron pattern."""

    def __init__(self, d_model, d_ff, tp_size, tp_rank):
        super().__init__()
        # Column-parallel: split output dim
        self.w1 = ColumnParallelLinear(d_model, d_ff, tp_size, tp_rank)
        self.w3 = ColumnParallelLinear(d_model, d_ff, tp_size, tp_rank)
        # Row-parallel: split input dim, all-reduce output
        self.w2 = RowParallelLinear(d_ff, d_model, tp_size, tp_rank)

    def forward(self, x):
        # x: [B, S, d_model] — same on all GPUs
        gate = torch.nn.functional.silu(self.w1(x))  # [B, S, d_ff/N]
        up = self.w3(x)                                # [B, S, d_ff/N]
        hidden = gate * up                             # [B, S, d_ff/N]
        output = self.w2(hidden)                       # [B, S, d_model] — all-reduced
        return output
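To see that the pattern is lossless end to end, here is a one-process NumPy simulation of the SwiGLU FFN above: column-shard w1/w3, row-shard w2, and sum the partial outputs (the stand-in for the all-reduce). Toy sizes, illustrative only.

```python
import numpy as np

def silu(z):
    return z / (1.0 + np.exp(-z))  # z * sigmoid(z)

rng = np.random.default_rng(0)
d_model, d_ff, tp = 8, 32, 4
w1 = rng.standard_normal((d_model, d_ff))
w3 = rng.standard_normal((d_model, d_ff))
w2 = rng.standard_normal((d_ff, d_model))
x = rng.standard_normal((2, 3, d_model))

# Reference: unsharded SwiGLU FFN
ref = (silu(x @ w1) * (x @ w3)) @ w2

# Simulated TP: each rank computes its d_ff/N slice; partial outputs are
# summed, standing in for the single all-reduce.
shard = d_ff // tp
out = sum(
    (silu(x @ w1[:, r * shard:(r + 1) * shard]) * (x @ w3[:, r * shard:(r + 1) * shard]))
    @ w2[r * shard:(r + 1) * shard, :]
    for r in range(tp)
)

assert np.allclose(out, ref)
```

Note that the element-wise SiLU and gate × up products stay local because both operands are sharded identically along d_ff; only the final matmul requires the cross-rank sum.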

Same pattern for attention: QKV projection is column-parallel (split heads across GPUs), output projection is row-parallel (all-reduce).

class TPAttention(nn.Module):
    """Tensor-parallel attention using the Megatron pattern (KV cache omitted for brevity)."""

    def __init__(self, d_model, n_heads, n_kv_heads, tp_size, tp_rank):
        super().__init__()
        assert n_heads % tp_size == 0
        assert n_kv_heads % tp_size == 0

        self.heads_per_rank = n_heads // tp_size
        self.kv_heads_per_rank = n_kv_heads // tp_size
        self.d_head = d_model // n_heads

        # Column-parallel: each GPU gets a subset of heads
        self.W_q = ColumnParallelLinear(d_model, self.heads_per_rank * self.d_head, tp_size, tp_rank)
        self.W_k = ColumnParallelLinear(d_model, self.kv_heads_per_rank * self.d_head, tp_size, tp_rank)
        self.W_v = ColumnParallelLinear(d_model, self.kv_heads_per_rank * self.d_head, tp_size, tp_rank)
        # Row-parallel: all-reduce after output projection
        self.W_o = RowParallelLinear(self.heads_per_rank * self.d_head, d_model, tp_size, tp_rank)

    def forward(self, x):
        B, S, _ = x.shape
        # [B, heads_per_rank, S, d_head] — this rank's subset of heads
        Q = self.W_q(x).view(B, S, self.heads_per_rank, self.d_head).transpose(1, 2)
        K = self.W_k(x).view(B, S, self.kv_heads_per_rank, self.d_head).transpose(1, 2)
        V = self.W_v(x).view(B, S, self.kv_heads_per_rank, self.d_head).transpose(1, 2)

        # GQA: repeat each KV head to cover its group of query heads
        group = self.heads_per_rank // self.kv_heads_per_rank
        K = K.repeat_interleave(group, dim=1)
        V = V.repeat_interleave(group, dim=1)

        attn = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
        attn_output = attn.transpose(1, 2).reshape(B, S, self.heads_per_rank * self.d_head)
        return self.W_o(attn_output)  # All-reduced to [B, S, d_model]

Communication Cost

Each (ring) all-reduce on N GPUs transfers 2 × (N−1)/N × M bytes per GPU, where M is the message size. Per transformer layer: 2 all-reduces (one after attention, one after the FFN).

For Llama 70B at TP=8, batch_size x seq_len = 4096 tokens, d_model=8192, FP16:

M = 4096 × 8192 × 2 bytes = 64 MB per all-reduce
Volume = 2 × (7/8) × 64 MB = 112 MB per all-reduce
Time (NVLink, 900 GB/s) = 112 MB / 900 GB/s ≈ 0.124 ms per all-reduce
Per layer: 2 × 0.124 ms ≈ 0.249 ms
80 layers: 80 × 0.249 ms ≈ 19.9 ms total
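The arithmetic generalizes to any TP size. A small helper (the function name is mine; it mirrors the article's loose MB-per-GB/s bookkeeping) makes the numbers easy to recompute:

```python
def allreduce_cost(tp, tokens=4096, d_model=8192, bytes_per_elem=2, bw_gb_s=900):
    """Ring all-reduce volume (MB) and time (ms) for one activation tensor."""
    m_mb = tokens * d_model * bytes_per_elem / 2**20  # message size in MB
    vol_mb = 2 * (tp - 1) / tp * m_mb                 # ring all-reduce bytes moved per GPU
    t_ms = vol_mb / bw_gb_s                           # 1 MB / (1 GB/s) = 1 ms, so MB/(GB/s) gives ms
    return vol_mb, t_ms

vol, t = allreduce_cost(tp=8)
per_layer_ms = 2 * t           # attention + FFN all-reduces
total_ms = 80 * per_layer_ms   # Llama 70B has 80 layers
```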


TP Communication Overhead (Llama 70B, B*S=4096)

TP Size | Volume per All-Reduce | Time (NVLink) | % of Forward Pass
TP=2 | 64 MB | 0.071 ms | 0.6%
TP=4 | 96 MB | 0.107 ms | 1.7%
TP=8 (1 node) | 112 MB | 0.124 ms | 4.0%
TP=8 (InfiniBand) | 112 MB | 2.24 ms | 72% (too slow!)
Note: NVLink is essential for TP. InfiniBand adds 18x latency, making TP impractical across nodes.
⚠️ TP Requires NVLink

At TP=8 on NVLink: 4% overhead (acceptable). At TP=8 over InfiniBand: 72% overhead (unacceptable). This is why TP is always intra-node (8 GPUs connected by NVLink), while pipeline parallelism handles the inter-node dimension: it sends activations point-to-point only at stage boundaries, far less frequently than TP's two all-reduces per layer.

When to Use TP vs DDP vs PP


Parallelism Strategy Decision

Model Size | GPUs Available | Interconnect | Recommended Strategy
7B | 1-8 | Any | DDP only (model fits on one GPU)
70B | 8 (1 node) | NVLink | TP=8 (split across the node)
70B | 64 (8 nodes) | NVLink + IB | TP=8 intra-node, DDP=8 across nodes
405B | 64 (8 nodes) | NVLink + IB | TP=8, PP=4, DP=2
671B MoE | 2048 | NVLink + IB | TP=4, EP=64, PP=4, DP=2
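Composing these strategies means assigning each global rank a coordinate in a (TP, PP, DP) grid. A minimal sketch of one common convention — TP innermost, so TP peers land on the same NVLink-connected node (the function name and ordering are illustrative, not any particular framework's API):

```python
def rank_to_coords(rank, tp_size, pp_size, dp_size):
    """Map a global rank to (tp_rank, pp_rank, dp_rank), TP varying fastest."""
    assert rank < tp_size * pp_size * dp_size
    tp_rank = rank % tp_size                 # innermost: neighbors share NVLink
    pp_rank = (rank // tp_size) % pp_size    # pipeline stage
    dp_rank = rank // (tp_size * pp_size)    # data-parallel replica
    return tp_rank, pp_rank, dp_rank

# 405B example from the table: TP=8, PP=4, DP=2 on 64 GPUs
coords = [rank_to_coords(r, 8, 4, 2) for r in range(64)]
```

With this layout, ranks 0-7 share the same pp_rank and dp_rank and form one TP group, which is exactly why TP's frequent all-reduces stay on intra-node NVLink.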