This post builds a Mixture of Experts layer from scratch. This is not a conceptual overview but a working PyTorch implementation you can copy into a training loop. By the end, you will have an MoELayer class that takes a [batch, seq_len, d_model] tensor, routes each token to its top-$k$ experts, computes SwiGLU FFN outputs per expert, combines the results, computes the load-balancing auxiliary loss, and enforces a capacity factor. Every design choice is grounded in the math.

We start with the interface, then build each component: the gating function, token dispatch, load-balancing loss, and capacity factor enforcement. Section 6 assembles them into a single, copy-paste-ready class.


1. The MoE Layer Interface

Inputs and Outputs

An MoE layer is a drop-in replacement for a standard FFN block inside a transformer. The external contract is identical:

  • Input: x with shape [batch_size, seq_len, d_model]
  • Output: y with shape [batch_size, seq_len, d_model]
  • Side output: aux_loss, a scalar added to the training loss for load balancing

Internally, the MoE layer contains three subsystems: a router (gating network), a set of N expert FFNs, and a combine step that merges expert outputs back into the original token positions.

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple

class MoELayerInterface(nn.Module):
    """
    Interface contract for an MoE layer.
    Drop-in replacement for a standard FFN block.
    """
    def __init__(
        self,
        d_model: int,       # e.g. 4096
        d_ff: int,           # e.g. 11008 (SwiGLU intermediate)
        num_experts: int,    # e.g. 8
        top_k: int,          # e.g. 2
        capacity_factor: float,  # e.g. 1.25
        balance_coeff: float,    # e.g. 0.01
    ):
        super().__init__()
        # Components defined in later sections
        ...

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            aux_loss: scalar tensor for load balancing
        """
        ...

Dataflow

The forward pass proceeds in five stages. Each stage has a well-defined tensor shape.

MoE Layer Dataflow: data shapes through each stage of the MoE forward pass

  1. Input reshape: [B, S, D] -> [B*S, D]. Flatten batch and sequence into a single token dimension.
  2. Router: [B*S, D] @ [D, N] -> [B*S, N] -> top-k -> [B*S, K]. Produce gate weights and expert assignments per token.
  3. Token dispatch: [B*S, D] -> N groups of [tokens_i, D]. Scatter tokens to their assigned experts.
  4. Expert FFNs: [tokens_i, D] -> [tokens_i, D] for each expert i. Each expert processes only its assigned tokens.
  5. Combine: N groups -> [B*S, D] (weighted sum). Gather outputs, multiply by gate weights, and sum.

The key insight: after flattening the batch and sequence dimensions, we have $T = B \times S$ tokens. Each token is independently routed to $k$ of $N$ experts. The router's job is to produce, for each token, the indices of $k$ experts and their corresponding weights (which sum to 1).
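A minimal sketch of the five stages with tiny, made-up sizes (B=2, S=3, D=16, N=4, k=2). It assumes plain nn.Linear layers as stand-in experts rather than SwiGLU, and it dispatches with nonzero and index_add_, which is only one of several ways to scatter and combine; the point is to trace the shapes from the list above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, S, D, N, K = 2, 3, 16, 4, 2          # illustrative sizes, not from a real model

x = torch.randn(B, S, D)
router = nn.Linear(D, N, bias=False)
experts = nn.ModuleList([nn.Linear(D, D) for _ in range(N)])  # stand-in experts

tokens = x.reshape(B * S, D)                               # 1. [B*S, D]
probs = F.softmax(router(tokens), dim=-1)                  # 2. [B*S, N]
weights, idx = torch.topk(probs, K, dim=-1)                #    [B*S, K] each
weights = weights / weights.sum(dim=-1, keepdim=True)

out = torch.zeros_like(tokens)
for e in range(N):
    token_ids, slot = (idx == e).nonzero(as_tuple=True)    # 3. tokens routed to expert e
    if token_ids.numel() == 0:
        continue
    expert_out = experts[e](tokens[token_ids])             # 4. [tokens_e, D]
    gate = weights[token_ids, slot].unsqueeze(-1)          # [tokens_e, 1]
    out.index_add_(0, token_ids, gate * expert_out)        # 5. weighted combine

y = out.reshape(B, S, D)                                   # back to [B, S, D]
print(tokens.shape, probs.shape, idx.shape, y.shape)
```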

Parameter Count

For $N$ experts, each with a SwiGLU FFN of dimensions $d_\text{model} \to d_\text{ff}$, the total parameter count is:

$$P_\text{experts} = N \times (3 \times d_\text{model} \times d_\text{ff})$$

The factor of 3 comes from SwiGLU having three weight matrices: gate projection $W_\text{gate}$, up-projection $W_\text{up}$, and down-projection $W_\text{down}$. The router adds $d_\text{model} \times N$ parameters, negligible compared to the expert weights.

For Mixtral 8x7B: $N = 8$, $d_\text{model} = 4096$, $d_\text{ff} = 14336$. Total expert parameters per MoE layer: $8 \times 3 \times 4096 \times 14336 = 1.41\text{B}$. Activated per token (top-2): $2 \times 3 \times 4096 \times 14336 = 352\text{M}$. The ratio $k/N = 2/8 = 25\%$ gives a 4x reduction in per-token FLOPs compared to a dense FFN with the same total parameter count.
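The same arithmetic as a quick script. The dimensions are the Mixtral values quoted above; the counts cover one MoE layer's expert FFNs plus its router and ignore attention and embeddings.

```python
# Mixtral 8x7B FFN dimensions as quoted in the text (one MoE layer)
N, K = 8, 2
d_model, d_ff = 4096, 14336

params_per_expert = 3 * d_model * d_ff        # W_gate, W_up, W_down
total_expert_params = N * params_per_expert
active_expert_params = K * params_per_expert
router_params = d_model * N

print(f"per expert: {params_per_expert:,}")
print(f"total:      {total_expert_params:,} (~{total_expert_params / 1e9:.2f}B)")
print(f"active:     {active_expert_params:,} (~{active_expert_params / 1e6:.0f}M)")
print(f"router:     {router_params:,}")
print(f"active fraction: {active_expert_params / total_expert_params:.0%}")
```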

FLOPs vs Parameters

An MoE layer with $N = 8$ experts and top-2 routing has 8x the parameters of a single FFN but only 2x the FLOPs. This is the core MoE value proposition: parameter count scales with $N$, but compute scales with $k$. For DeepSeek V3 ($N = 256$, $k = 8$): 256x the parameters and 8x the FLOPs of a single expert.


2. The Gating Function

The Router Linear Layer

The gating network is a single linear projection from the token’s hidden state to a vector of logits over all experts:

class GatingNetwork(nn.Module):
    def __init__(self, d_model: int, num_experts: int, top_k: int):
        super().__init__()
        self.top_k = top_k
        self.num_experts = num_experts
        # Router weight maps d_model -> num_experts (nn.Linear stores it as [num_experts, d_model])
        self.gate = nn.Linear(d_model, num_experts, bias=False)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [num_tokens, d_model]
        Returns:
            gate_weights: [num_tokens, top_k]  (softmax-normalized)
            expert_indices: [num_tokens, top_k] (integer indices)
            router_probs: [num_tokens, num_experts] (full softmax, for aux loss)
        """
        # Step 1: Compute logits
        logits = self.gate(x)  # [num_tokens, num_experts]

        # Step 2: Full softmax over all experts (needed for aux loss)
        router_probs = F.softmax(logits, dim=-1)  # [num_tokens, num_experts]

        # Step 3: Select top-k experts per token
        topk_weights, topk_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )  # both: [num_tokens, top_k]

        # Step 4: Re-normalize top-k weights to sum to 1
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        return topk_weights, topk_indices, router_probs

The matrix multiply x @ W_gate.T costs $2 \times T \times d_\text{model} \times N$ FLOPs. For $T = 4096$ tokens, $d_\text{model} = 4096$, $N = 8$: that is $2 \times 4096 \times 4096 \times 8 = 268\text{M}$ FLOPs, under 0.1% of the total layer FLOPs. The router is computationally negligible.
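A rough check of the "under 0.1%" claim. It assumes $d_\text{ff} = 11008$ (the SwiGLU width used in this post's examples), counts only the matrix multiplies at 2 FLOPs per multiply-add, and ignores the SiLU and the elementwise product.

```python
T, d_model, d_ff, N, K = 4096, 4096, 11008, 8, 2   # d_ff is an assumption (see text)

router_flops = 2 * T * d_model * N             # [T, d_model] x [d_model, N] GEMM
# Per token and selected expert: 3 GEMMs (gate, up, down), each 2 * d_model * d_ff FLOPs
expert_flops = T * K * 3 * 2 * d_model * d_ff

share = router_flops / (router_flops + expert_flops)
print(f"router:  {router_flops / 1e6:.0f} MFLOPs")
print(f"experts: {expert_flops / 1e12:.2f} TFLOPs")
print(f"router share of layer matmul FLOPs: {share:.4%}")
```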

The Gradient Problem: Why argmax Fails

The ideal gating function would be a hard argmax: pick the single best expert, send the full token there, ignore the rest. But argmax has zero gradient almost everywhere:

$$\frac{\partial\, \text{argmax}(z)}{\partial z_i} = 0 \quad \forall i$$

The argmax output is a discrete integer. Perturbing the logits $z_i$ by a small $\epsilon$ does not change the argmax unless it crosses a decision boundary, and at the boundary the function is discontinuous. This means no gradient flows through the expert selection, making it impossible to train the router via backpropagation.
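A quick way to see this in PyTorch: torch.argmax returns an integer tensor that autograd cannot track. A sketch with random logits:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 8, requires_grad=True)   # [num_tokens, num_experts]

choice = logits.argmax(dim=-1)                  # integer expert id per token
print(choice.dtype, choice.requires_grad)       # torch.int64 False
print(choice.grad_fn)                           # None: nothing for autograd to differentiate

# A tiny perturbation almost never crosses a decision boundary, so the
# selection, and any finite-difference estimate of its gradient, is unchanged.
eps = 1e-4 * torch.randn(4, 8)
print(torch.equal(choice, (logits + eps).argmax(dim=-1)))
```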

Softmax + Top-k: A Differentiable Approximation

The solution used in practice is a three-step process:

  1. Apply softmax to get continuous probabilities: $p_i = \frac{\exp(z_i)}{\sum_j \exp(z_j)}$
  2. Select the top-$k$ indices (discrete, non-differentiable)
  3. Use the continuous softmax values as the gate weights

The top-$k$ selection itself is still non-differentiable (it is a discrete operation). But the gate weights $w_i$ (the softmax values $p_i$ of the selected experts, optionally re-normalized) ARE differentiable with respect to the logits $z_i$. This means:

  • The combination weights (how much each expert’s output contributes) have well-defined gradients
  • The expert selection (which experts are chosen) does not have gradients

The router learns through the combination weights: if expert $j$ produces a good output for a token, the gradient through the gate weight $w_j$ increases the logit $z_j$ (and therefore $p_j$) for that token's hidden state. Over time, this makes expert $j$ more likely to be selected for similar tokens.

Definition: Straight-Through Estimator for Top-k Routing

In the forward pass, the top-$k$ mask $m$ is discrete: $m_i = 1$ if expert $i$ is in the top-$k$, else $m_i = 0$. In the backward pass, the mask is treated as a constant: no gradient flows through the selection itself, and gradients reach the logits only through the softmax values of the selected experts. Formally: $\frac{\partial L}{\partial z_i} = \sum_{j:\, m_j = 1} \frac{\partial L}{\partial w_j} \cdot \frac{\partial w_j}{\partial z_i}$, where $w_j$ is the (re-normalized) softmax probability for expert $j$ in the top-$k$ set. The selection of which experts are in the top-$k$ is treated as constant.

This is an implicit straight-through estimator. The forward pass is discontinuous (expert selection can change abruptly), but the backward pass provides smooth gradients through the softmax weights. It works well in practice because:

  1. The softmax probabilities are smooth functions of the logits
  2. Small changes in logits cause small changes in the combination weights
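  3. The softmax couples the experts: without re-normalization, pushing the selected expert's weight down shifts probability mass to the others, and the runner-up, which has the highest remaining probability, gains the most. With re-normalized top-$k$ weights, that coupling cancels out, and unselected experts receive gradient only from the load-balancing loss (Section 4).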
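
# Illustration of the gradient flow (runnable; sizes are illustrative)
import torch
import torch.nn.functional as F

num_tokens, d_model, num_experts, k = 4, 16, 8, 2
x = torch.randn(num_tokens, d_model, requires_grad=True)
W_gate = torch.randn(num_experts, d_model, requires_grad=True)

# Forward:
logits = x @ W_gate.T                                  # Differentiable
probs = F.softmax(logits, dim=-1)                      # Differentiable
top_k_indices = torch.topk(probs, k, dim=-1).indices   # NOT differentiable (discrete)
top_k_weights = probs.gather(-1, top_k_indices)        # Differentiable (indexing by constant)
weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)  # Differentiable

# Backward:
# d_loss/d_logits = d_loss/d_weights * d_weights/d_probs * d_probs/d_logits
# The top_k_indices are treated as constants in the backward pass.
# Gradients flow through probs -> logits -> W_gate and x.
loss = (weights * torch.randn_like(weights)).sum()     # dummy loss for illustration
loss.backward()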
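A small check of what treating the selection as constant means for the re-normalized gate used in GatingNetwork above, with random logits and an arbitrary dummy loss (both assumptions of this sketch). Because re-normalizing over the selected experts cancels the softmax denominator, the logits of unselected experts get zero gradient from this loss, up to floating-point rounding:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, num_experts, k = 4, 8, 2
logits = torch.randn(num_tokens, num_experts, requires_grad=True)

probs = F.softmax(logits, dim=-1)
top_idx = torch.topk(probs, k, dim=-1).indices
w = probs.gather(-1, top_idx)
w = w / w.sum(dim=-1, keepdim=True)                   # re-normalized gate weights

loss = (w * torch.randn(num_tokens, k)).sum()         # dummy downstream loss
loss.backward()

selected = F.one_hot(top_idx, num_experts).sum(dim=1).bool()   # [num_tokens, num_experts]
grad_sel = logits.grad[selected].abs().max().item()
grad_unsel = logits.grad[~selected].abs().max().item()
print(f"max |grad| selected: {grad_sel:.3e}, unselected: {grad_unsel:.3e}")
```

In the full layer, unselected experts still receive router gradient through the load-balancing loss, which uses the full softmax.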

Alternative: Softmax-then-Top-k vs Top-k-then-Softmax

There are two valid orderings:

Softmax-then-Top-k (Switch, GShard)
# Apply softmax over ALL experts first
probs = F.softmax(logits, dim=-1)
# Then select top-k from the probabilities
topk_w, topk_i = torch.topk(probs, k, dim=-1)
# Re-normalize selected weights
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)
Top-k-then-Softmax (Mixtral)
# Select top-k from raw logits first
topk_logits, topk_i = torch.topk(logits, k, dim=-1)
# Apply softmax only over selected experts
topk_w = F.softmax(topk_logits, dim=-1)
# Already normalized (softmax over k values)

The difference: softmax-then-top-$k$ computes the full distribution over all $N$ experts before selecting, while top-$k$-then-softmax discards the non-selected experts before normalizing. Because softmax is monotonic, both pick the same experts, and once the selected probabilities are re-normalized (as in the first snippet) the gate weights are identical: $p_j / \sum_{l \in \text{top-}k} p_l = \exp(z_j) / \sum_{l \in \text{top-}k} \exp(z_l)$. The orderings differ only when softmax-then-top-$k$ skips re-normalization, as Switch Transformer does for top-1 routing; its gates are then smaller, because the rejected experts' probability mass stays in the denominator. Implementations that compute the load-balancing loss need the full softmax regardless, because the loss requires the probability assigned to every expert, not just the selected ones.
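A numerical check of the equivalence, using random logits (an assumption of the sketch): with re-normalization, both orderings choose the same experts with the same weights, while the un-normalized variant keeps gates that sum to less than 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 8)   # [num_tokens, num_experts]
k = 2

# Softmax-then-top-k, re-normalized
probs = F.softmax(logits, dim=-1)
w_a, i_a = torch.topk(probs, k, dim=-1)
w_a = w_a / w_a.sum(dim=-1, keepdim=True)

# Top-k-then-softmax
top_logits, i_b = torch.topk(logits, k, dim=-1)
w_b = F.softmax(top_logits, dim=-1)

# Softmax-then-top-k without re-normalization
w_raw = torch.topk(probs, k, dim=-1).values

print(torch.equal(i_a, i_b), torch.allclose(w_a, w_b, atol=1e-6))
print(w_raw.sum(dim=-1))   # every row sums to less than 1
```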


3. Expert Computation with Token Dispatch

The Core Problem

After the router assigns each of TT tokens to kk experts, we need to actually run the expert FFNs. The difficulty: different tokens go to different experts. Unlike a standard FFN where all tokens go through the same weight matrix (one big batched GEMM), MoE requires routing tokens to their respective experts and collecting the results.

There are three implementation strategies, each with different performance characteristics.

Approach A: Scatter-Gather (Loop Over Experts)

The simplest approach: for each expert, find which tokens are assigned to it, batch those tokens, run the expert’s FFN, and scatter the results back.

def scatter_gather_dispatch(
    x: torch.Tensor,           # [num_tokens, d_model]
    expert_indices: torch.Tensor,  # [num_tokens, top_k]
    gate_weights: torch.Tensor,    # [num_tokens, top_k]
    experts: nn.ModuleList,    # List of N expert FFNs
) -> torch.Tensor:
    """
    Simple loop-based dispatch. Easy to understand, slow to execute.
    """
    num_tokens, d_model = x.shape
    top_k = expert_indices.shape[1]
    output = torch.zeros_like(x)  # [num_tokens, d_model]

    for k_idx in range(top_k):
        for expert_id in range(len(experts)):
            # Find tokens assigned to this expert at this k-position
            mask = (expert_indices[:, k_idx] == expert_id)
            if mask.any():
                # Gather tokens for this expert
                token_subset = x[mask]  # [num_assigned, d_model]
                # Run expert FFN
                expert_out = experts[expert_id](token_subset)
                # Scatter back, weighted by gate
                weights = gate_weights[mask, k_idx].unsqueeze(-1)
                output[mask] += weights * expert_out

    return output

This works but is slow: $N \times k$ kernel launches, each with a small batch. GPU utilization is poor because each expert processes a different number of tokens, and the loop serializes what could be parallel work.

Token Dispatch Strategy Performance

| Strategy | Kernel Launches | GPU Utilization | Latency (A100, T=4096, N=8, k=2) |
|---|---|---|---|
| Scatter-Gather (loop) | N x k = 16 | ~15-25% | ~4.2 ms |
| Permutation + Block GEMM | 3 (sort + GEMM + unsort) | ~60-75% | ~1.8 ms |
| Grouped GEMM (Triton) | 1 (fused) | ~80-90% | ~0.9 ms |

Note: Measured for d_model=4096, d_ff=11008 SwiGLU experts. Grouped GEMM reduces kernel launch overhead and enables better SM occupancy.

Approach B: Permutation-Based Dispatch

Sort all tokens by their expert assignment, then run contiguous blocks through each expert. This replaces the per-(k, expert) boolean masking with a single sort, so each expert is called exactly once on a contiguous slice.

def permutation_dispatch(
    x: torch.Tensor,           # [num_tokens, d_model]
    expert_indices: torch.Tensor,  # [num_tokens, top_k]
    gate_weights: torch.Tensor,    # [num_tokens, top_k]
    experts: nn.ModuleList,
) -> torch.Tensor:
    num_tokens, d_model = x.shape
    num_experts = len(experts)
    top_k = expert_indices.shape[1]

    # Flatten: each token appears top_k times (once per assigned expert)
    flat_indices = expert_indices.view(-1)        # [num_tokens * top_k]
    flat_weights = gate_weights.view(-1)          # [num_tokens * top_k]
    # Repeat each token top_k times to match
    x_repeated = x.repeat_interleave(top_k, dim=0)  # [num_tokens * top_k, d_model]

    # Sort by expert index -> contiguous blocks per expert
    sort_indices = flat_indices.argsort()
    sorted_tokens = x_repeated[sort_indices]      # [num_tokens * top_k, d_model]
    sorted_experts = flat_indices[sort_indices]    # [num_tokens * top_k]
    sorted_weights = flat_weights[sort_indices]    # [num_tokens * top_k]

    # Find boundaries for each expert's block
    # expert_counts[i] = number of tokens assigned to expert i
    expert_counts = torch.bincount(sorted_experts, minlength=num_experts)

    # Process each expert's contiguous block
    output_sorted = torch.empty_like(sorted_tokens)
    offset = 0
    for i in range(num_experts):
        count = expert_counts[i].item()
        if count > 0:
            expert_input = sorted_tokens[offset:offset + count]
            output_sorted[offset:offset + count] = experts[i](expert_input)
        offset += count

    # Weight by gate values
    output_sorted = output_sorted * sorted_weights.unsqueeze(-1)

    # Unsort: scatter back to original positions
    unsort_indices = sort_indices.argsort()
    output_unsorted = output_sorted[unsort_indices]  # [num_tokens * top_k, d_model]

    # Sum over top_k contributions for each token
    output = output_unsorted.view(num_tokens, top_k, d_model).sum(dim=1)

    return output

The permutation approach reduces kernel launches from $N \times k$ to $N + 2$ (one sort, $N$ expert calls, one unsort). Each expert call is a contiguous GEMM, which is much more GPU-friendly. The sort and unsort are $O(T \cdot k \cdot \log(T \cdot k))$, cheap relative to the expert GEMMs.
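The bookkeeping behind the sort and unsort, traced on a toy routing table (values invented for illustration). This sketch uses a stable torch.sort and torch.split in place of the offset loop above; slot index // top_k recovers which token each sorted slot came from, because repeat_interleave places token t's copies at slots t*top_k through t*top_k + top_k - 1.

```python
import torch

# Toy routing table: 4 tokens, top-2, 3 experts (values invented for illustration)
expert_indices = torch.tensor([[0, 2], [1, 0], [2, 1], [0, 1]])
num_experts, top_k = 3, expert_indices.shape[1]

flat = expert_indices.reshape(-1)                      # [T*k] expert per slot, token-major
_, sort_idx = torch.sort(flat, stable=True)            # slots grouped by expert
counts = torch.bincount(flat, minlength=num_experts)   # block length per expert
token_of_slot = sort_idx // top_k                      # which token each sorted slot holds
token_groups = torch.split(token_of_slot, counts.tolist())

unsort_idx = torch.argsort(sort_idx)                   # inverse permutation
restored = flat[sort_idx][unsort_idx]

print(counts.tolist())                                 # [3, 3, 2]
print([g.tolist() for g in token_groups])              # [[0, 1, 3], [1, 2, 3], [0, 2]]
print(torch.equal(restored, flat))                     # True
```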

Approach C: Grouped GEMM (Production Systems)

The production approach eliminates even the per-expert loop. A single grouped GEMM kernel processes all experts simultaneously. The kernel receives:

  • A single concatenated input tensor (sorted by expert)
  • A list of weight matrices (one per expert)
  • A list of group sizes (how many tokens per expert)

The kernel internally dispatches tiles of tokens to the correct weight matrices within a single kernel launch.

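# Grouped GEMM dispatch. Production systems call a custom CUDA/Triton kernel
# (e.g., from vLLM or Megablocks) where this sketch calls grouped_gemm.

def grouped_gemm(x_sorted, weights, group_sizes):
    """
    Hypothetical reference stand-in for a fused grouped GEMM kernel, so this
    sketch runs in plain PyTorch. It loops over groups, so it has none of the
    single-launch speed of a real kernel.
      x_sorted:    [total_rows, in_dim], rows sorted by group
      weights:     [num_groups, out_dim, in_dim]
      group_sizes: [num_groups] row count per group
    """
    chunks = torch.split(x_sorted, group_sizes.tolist(), dim=0)
    outs = [chunk @ weights[g].T for g, chunk in enumerate(chunks)]
    return torch.cat(outs, dim=0)  # [total_rows, out_dim]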

def grouped_gemm_dispatch(
    x: torch.Tensor,           # [num_tokens, d_model]
    expert_indices: torch.Tensor,
    gate_weights: torch.Tensor,
    w_gate: torch.Tensor,      # [num_experts, d_ff, d_model] - all gate projections
    w_up: torch.Tensor,        # [num_experts, d_ff, d_model] - all up projections
    w_down: torch.Tensor,      # [num_experts, d_model, d_ff] - all down projections
) -> torch.Tensor:
    num_tokens, d_model = x.shape
    top_k = expert_indices.shape[1]
    num_experts = w_gate.shape[0]

    # Step 1: Flatten and sort by expert (same as permutation approach)
    flat_indices = expert_indices.view(-1)
    flat_weights = gate_weights.view(-1)
    x_repeated = x.repeat_interleave(top_k, dim=0)

    sort_order = flat_indices.argsort()
    sorted_tokens = x_repeated[sort_order]
    sorted_experts = flat_indices[sort_order]

    # Step 2: Compute expert group sizes
    expert_counts = torch.bincount(sorted_experts, minlength=num_experts)

    # Step 3: Grouped GEMM -- single kernel launch
    # SwiGLU: output = W_down @ (silu(W_gate @ x) * (W_up @ x))
    # In practice, this is a Triton kernel that:
    #   - Iterates over experts in parallel across SMs
    #   - Each SM processes a tile of tokens for one expert
    #   - Weight matrices are indexed by expert_id

    # Compute gate and up projections (fused grouped GEMM)
    gate_out = grouped_gemm(sorted_tokens, w_gate, expert_counts)  # [T*k, d_ff]
    up_out = grouped_gemm(sorted_tokens, w_up, expert_counts)      # [T*k, d_ff]

    # SwiGLU activation
    hidden = F.silu(gate_out) * up_out  # [T*k, d_ff]

    # Down projection (grouped GEMM)
    expert_out = grouped_gemm(hidden, w_down, expert_counts)  # [T*k, d_model]

    # Step 4: Weight and unsort
    expert_out = expert_out * flat_weights[sort_order].unsqueeze(-1)
    unsort_order = sort_order.argsort()
    output = expert_out[unsort_order].view(num_tokens, top_k, d_model).sum(dim=1)

    return output

Grouped GEMM Implementations

Production grouped GEMM kernels are available in: Megablocks (Stanford, Triton-based), vLLM (fused MoE kernel), CUTLASS (NVIDIA, C++ templates), and cuBLAS grouped GEMM (CUDA 12+). The Triton-based Megablocks implementation is the most accessible for research use. It achieves 80-90% of peak FLOP/s on A100 by tiling tokens and experts across SMs.
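Without a grouped GEMM kernel, one plain-PyTorch stopgap is to zero-pad every expert's block of rows to the longest block and issue a single torch.bmm. The helper below (padded_bmm_experts, a hypothetical name) sketches that idea; it is not how the kernels listed above work, and it pays for the padding rows, which is wasted compute whenever the load is uneven.

```python
import torch

def padded_bmm_experts(x_sorted: torch.Tensor, weights: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
    """
    Hypothetical helper: run all experts with ONE torch.bmm call by zero-padding
    every expert's block of rows to the size of the largest block.
      x_sorted: [total_rows, d_in], rows already sorted by expert
      weights:  [num_experts, d_out, d_in]
      counts:   [num_experts] number of rows per expert
    Returns [total_rows, d_out] in the same sorted order.
    """
    num_experts, _, d_in = weights.shape
    device = x_sorted.device
    max_rows = int(counts.max())
    # Owning expert of each sorted row, and the row's position inside that expert's block
    expert_of_row = torch.repeat_interleave(torch.arange(num_experts, device=device), counts)
    block_start = torch.cumsum(counts, dim=0) - counts
    row_in_block = torch.arange(x_sorted.shape[0], device=device) - block_start[expert_of_row]

    padded = x_sorted.new_zeros(num_experts, max_rows, d_in)
    padded[expert_of_row, row_in_block] = x_sorted
    out = torch.bmm(padded, weights.transpose(1, 2))   # [num_experts, max_rows, d_out]
    return out[expert_of_row, row_in_block]            # drop the padding rows

torch.manual_seed(0)
counts = torch.tensor([3, 0, 5])                   # uneven load, one idle expert
x_sorted = torch.randn(int(counts.sum()), 4)
weights = torch.randn(3, 6, 4)

y = padded_bmm_experts(x_sorted, weights, counts)
reference = torch.cat([blk @ weights[e].T for e, blk in enumerate(torch.split(x_sorted, counts.tolist()))])
print(y.shape, torch.allclose(y, reference, atol=1e-5))
```

Here the padded batch holds 3 x 5 = 15 rows to compute 8 real ones, which is the overhead a true grouped GEMM avoids.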

Why Grouped GEMM Wins

The performance advantage comes from three factors:

  1. Single kernel launch: One kernel launch costs 5-10 μs of overhead. The scatter-gather approach launches $N \times k$ kernels, adding 80-160 μs of pure overhead for $N = 8$, $k = 2$. Grouped GEMM launches one.

  2. Better SM utilization: With scatter-gather, each expert's GEMM is its own launch, and an expert that receives only a few tokens produces too few output tiles to keep all 108 SMs of an A100 busy. Because the experts run one after another, the idle SMs cannot pick up another expert's work. Grouped GEMM schedules tiles from all experts in one launch and distributes them across all SMs dynamically.

  3. Memory access patterns: Sorting tokens by expert before the GEMM means each expert’s tokens are contiguous in memory, enabling coalesced reads. The scatter-gather approach reads tokens from scattered memory locations.
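To make the SM-utilization point concrete, here is a rough tile count. It assumes 128x128 output tiles, one tile per SM at a time, the A100's 108 SMs, and $d_\text{ff} = 11008$; real kernels choose tile shapes adaptively, so treat this as an order-of-magnitude model only.

```python
import math

# Rough model (assumptions): 128x128 output tiles, one tile per SM at a time, 108 SMs (A100)
TILE_M, TILE_N, NUM_SMS = 128, 128, 108
D_FF = 11008

def output_tiles(tokens: int, out_dim: int = D_FF) -> int:
    """Number of output tiles in a [tokens, out_dim] GEMM result."""
    return math.ceil(tokens / TILE_M) * math.ceil(out_dim / TILE_N)

for tokens in (64, 256, 1024):
    tiles = output_tiles(tokens)
    busy = min(tiles, NUM_SMS) / NUM_SMS
    print(f"{tokens:5d} tokens -> {tiles:4d} tiles, {busy:.0%} of SMs busy in the first wave")
```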

Figure: Expert Dispatch Strategies: Latency vs Token Count (ms). Scatter-Gather (N=8, k=2): 4.2 ms; Permutation + Loop: 1.8 ms; Grouped GEMM (Triton): 0.9 ms; Grouped GEMM (CUTLASS): 0.7 ms.

For the runnable implementation in Section 6, we use the permutation-based approach (Approach B) because it runs on vanilla PyTorch without custom kernels. In production, you would replace the inner loop with a grouped GEMM call.


4. Load-Balancing Loss

The Collapse Problem

Without any load-balancing mechanism, MoE training converges to a degenerate state: the router sends nearly all tokens to 2-3 “popular” experts while the rest receive almost nothing. This happens because of a positive feedback loop:

  1. Expert $j$ receives many tokens early in training (due to random initialization)
  2. Expert $j$ gets many gradient updates, so it becomes better
  3. The router sends even more tokens to expert $j$ because it produces lower loss
  4. Experts that receive few tokens get few gradient updates and fall further behind

The result: $N - 2$ experts are wasted. The model degrades to a ~2-expert MoE, losing almost all of the capacity advantage.

The Auxiliary Loss

The standard solution is an auxiliary loss term that penalizes uneven load distribution. GShard introduced such a loss, and Switch Transformer simplified it into the form used here:

$$L_\text{balance} = \alpha \cdot N \cdot \sum_{i=1}^{N} f_i \cdot P_i$$

where:

  • $\alpha$ is a hyperparameter (typically 0.01) controlling the strength of load balancing
  • $N$ is the number of experts
  • $f_i$ is the fraction of tokens routed to expert $i$: $f_i = \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}[\text{token } t \text{ routes to expert } i]$
  • $P_i$ is the average router probability assigned to expert $i$: $P_i = \frac{1}{T} \sum_{t=1}^{T} p_{t,i}$, where $p_{t,i}$ is the softmax probability for expert $i$ on token $t$
def load_balancing_loss(
    router_probs: torch.Tensor,    # [num_tokens, num_experts] (full softmax)
    expert_indices: torch.Tensor,  # [num_tokens, top_k]
    num_experts: int,
    top_k: int,
) -> torch.Tensor:
    """
    Compute the auxiliary load-balancing loss.

    Returns a scalar tensor that should be added to the main loss,
    multiplied by balance_coeff.
    """
    num_tokens = router_probs.shape[0]

    # f_i: fraction of tokens routed to expert i
    # Create a one-hot mask of shape [num_tokens, num_experts] for all top-k selections
    expert_mask = torch.zeros(
        num_tokens, num_experts, device=router_probs.device
    )
    # For each top-k slot, mark which expert was selected
    for k in range(top_k):
        expert_mask.scatter_add_(
            1,
            expert_indices[:, k:k+1],
            torch.ones(num_tokens, 1, device=router_probs.device),
        )
    # Clip to binary as a safeguard (torch.topk already returns distinct experts per token)
    expert_mask = expert_mask.clamp(max=1.0)

    # f_i = mean of mask across tokens (fraction of tokens routed to expert i)
    f = expert_mask.mean(dim=0)  # [num_experts]

    # P_i = mean of router probability across tokens for expert i
    P = router_probs.mean(dim=0)  # [num_experts]

    # Loss = N * sum(f_i * P_i)
    aux_loss = num_experts * (f * P).sum()

    return aux_loss
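The scatter_add_ loop can also be written with F.one_hot, which builds the whole [num_tokens, num_experts] mask at once. In this sketch, load_balancing_loss_onehot is a hypothetical name, and the toy batches (uniform router probabilities, perfectly even assignments) show the value the loss takes at perfect balance: 1.0 for top-1 and 2.0 for top-2, because $f$ sums to $k$ across experts.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss_onehot(router_probs, expert_indices, num_experts):
    """Same N * sum(f_i * P_i) loss as above, with the top-k mask built by F.one_hot."""
    # [T, k, N] one-hot summed over k -> [T, N]; top-k indices are distinct, so entries are 0/1
    expert_mask = F.one_hot(expert_indices, num_experts).sum(dim=1).float()
    f = expert_mask.mean(dim=0)        # fraction of tokens routed to each expert
    P = router_probs.mean(dim=0)       # mean router probability per expert
    return num_experts * (f * P).sum()

# Perfectly balanced toy batch: 4 tokens, 4 experts, uniform router probabilities
uniform = torch.full((4, 4), 0.25)
top1 = torch.tensor([[0], [1], [2], [3]])
top2 = torch.tensor([[0, 1], [2, 3], [1, 2], [3, 0]])
print(load_balancing_loss_onehot(uniform, top1, 4))   # balanced value for k=1
print(load_balancing_loss_onehot(uniform, top2, 4))   # balanced value for k=2
```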

Why This Loss Works

The product $f_i \cdot P_i$ is key. Consider what happens when expert $j$ receives a disproportionate share of tokens:

  • $f_j$ is large (many tokens routed to expert $j$)
  • $P_j$ is large (the router assigns high probability to expert $j$)
  • The product $f_j \cdot P_j$ dominates the sum

The gradient of the loss penalizes this by pushing $P_j$ down: $\frac{\partial L_\text{balance}}{\partial P_j} = \alpha \cdot N \cdot f_j$. Since $f_j$ is large, the gradient is strong, reducing the router's probability for expert $j$. Conversely, underutilized experts have small $f_i$, so the gradient pushing their probability down is weak.
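The direction of that gradient can be checked with autograd. This sketch assumes top-1 routing and router logits that favor expert 0 for every token; $f$ is built from the discrete choices, so only $P$ carries gradient:

```python
import torch
import torch.nn.functional as F

T, N = 6, 4
logits = torch.zeros(T, N)
logits[:, 0] = 1.0                              # every token prefers expert 0 (illustrative)
logits.requires_grad_(True)

probs = F.softmax(logits, dim=-1)
top1 = probs.argmax(dim=-1)                     # all tokens pick expert 0
f = F.one_hot(top1, N).float().mean(dim=0)      # [1, 0, 0, 0], carries no gradient
P = probs.mean(dim=0)                           # differentiable
aux = N * (f * P).sum()
aux.backward()

print(f.tolist())
print(logits.grad[0])   # positive for expert 0, negative for the others
```

Gradient descent subtracts this gradient, so it lowers expert 0's logit and raises the others.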

Lemma: Minimum of the Balance Loss

With top-$k$ routing, $\sum_i f_i = k$ and $\sum_i P_i = 1$. The auxiliary loss $L_\text{balance} = N \cdot \sum_i f_i \cdot P_i$ (without $\alpha$) equals $k$ when the load is perfectly balanced, $f_i = k/N$ and $P_i = 1/N$ for all $i$; for top-1 routing, as in Switch Transformer, that value is 1.0. When experts that receive more tokens also have higher average router probability, which is what top-$k$ routing tends to produce, Chebyshev's sum inequality gives $N \sum_i f_i P_i \ge N \cdot \frac{1}{N} \sum_i f_i \sum_i P_i = k$, so deviations from the uniform distribution raise the loss.

Crucially, $f_i$ is not differentiable (it is a count of discrete routing decisions). Only $P_i$ carries gradients. So the loss works by adjusting the router probabilities $P_i$ (differentiable) to reduce imbalance in the routing fractions $f_i$ (non-differentiable but observed). This is why both terms are needed: $f_i$ measures the actual imbalance, and $P_i$ provides the gradient signal to fix it.

DeepSeek’s Alternative: Auxiliary-Loss-Free Balancing

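DeepSeek V3 introduced a different approach that does not rely on the auxiliary loss (it keeps only a very small complementary sequence-wise balance loss). Instead, each expert $i$ has a bias term $b_i$ that is added to its routing score only when choosing the top-$k$ experts:

$$g(h)_i = h \cdot w_i + b_i$$

The gate weights that scale the expert outputs are still computed from the unbiased scores, so the bias changes which experts are selected but not how their outputs are mixed.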

The bias terms are NOT updated by gradient descent. Instead, they are updated by a simple rule after each training step:

$$b_i \leftarrow b_i + \gamma \cdot \text{sign}(\text{target\_load} - \text{actual\_load}_i)$$

where $\gamma$ is a small step size (e.g., 0.001). If expert $i$ is underloaded, $b_i$ increases, making it more likely to be selected. If overloaded, $b_i$ decreases.

def update_expert_biases(
    bias: torch.Tensor,         # [num_experts]
    expert_counts: torch.Tensor, # [num_experts] tokens per expert this step
    target_count: float,         # T * k / N (ideal per-expert count)
    gamma: float = 0.001,
):
    """DeepSeek-style bias update (outside gradient descent)."""
    imbalance = target_count - expert_counts.float()
    bias.data += gamma * imbalance.sign()
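A toy simulation of this rule, under several assumptions: the same fixed batch of skewed router logits at every step, $\gamma = 0.01$ (larger than the 0.001 above so that it settles within a few hundred steps), and gate weights taken as a softmax over the selected, unbiased logits. The bias only changes which experts are picked:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, N, K, gamma = 1024, 8, 2, 0.01
logits = torch.randn(T, N) + torch.linspace(1.0, 0.0, N)   # skewed toward low-numbered experts
bias = torch.zeros(N)
target = T * K / N                                          # ideal tokens per expert

def route(logits, bias):
    # The bias decides WHICH experts are selected...
    idx = torch.topk(logits + bias, K, dim=-1).indices
    # ...but the gate weights use the unbiased logits of those experts
    weights = F.softmax(logits.gather(-1, idx), dim=-1)
    return idx, weights

def expert_load(idx):
    return torch.bincount(idx.reshape(-1), minlength=N).float()

initial_max = expert_load(route(logits, bias)[0]).max().item()
for step in range(300):
    idx, _gates = route(logits, bias)
    bias += gamma * torch.sign(target - expert_load(idx))   # plain update, no autograd
final_max = expert_load(route(logits, bias)[0]).max().item()

print(f"target {target:.0f}, max load before {initial_max:.0f}, after {final_max:.0f}")
```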

Why Remove the Auxiliary Loss?

The auxiliary loss introduces a hyperparameter $\alpha$ that trades off model quality against load balance. Too large: the model prioritizes balance over performance. Too small: the model collapses to a few experts. DeepSeek found that bias-based balancing achieves comparable load distribution with no quality-balance tradeoff, since the bias only affects routing decisions, not the loss landscape. Their ablation showed 0.3-0.5 points better on benchmarks vs. auxiliary loss with optimal $\alpha$.


5. Capacity Factor

The Overflow Problem

Even with load balancing, token distribution across experts is never perfectly uniform. In any given batch, some experts will receive more tokens than others. Without a hard cap, a single expert might receive 3-4x its expected share, causing memory spikes and compute imbalance across GPUs in distributed settings.

The capacity factor CC sets a hard upper bound on how many tokens each expert can process:

$$\text{expert\_capacity} = C \times \frac{T \times k}{N}$$

where $T$ is the total number of tokens, $k$ is the number of experts per token, and $N$ is the number of experts. The term $T \times k / N$ is the expected number of tokens per expert under a uniform distribution.
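The capacity arithmetic as code, using the $T = 4096$, $k = 2$, $N = 8$ setting from the table below. The helper expert_capacity is a hypothetical name; it truncates to an integer, which means a single token ($T = 1$) gets a budget of 0, so tiny batches need a floor.

```python
def expert_capacity(capacity_factor: float, num_tokens: int, top_k: int, num_experts: int) -> int:
    """Hypothetical helper: C * (T * k / N), truncated to an integer."""
    return int(capacity_factor * num_tokens * top_k / num_experts)

for C in (1.0, 1.25, 1.5, 2.0):
    print(C, expert_capacity(C, num_tokens=4096, top_k=2, num_experts=8))

# Truncation bites at tiny batch sizes: one token gives a budget of 0 slots
print(expert_capacity(1.25, num_tokens=1, top_k=2, num_experts=8))
```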

Tokens that exceed an expert’s capacity are dropped — they receive zero contribution from that expert. If a token’s primary expert is at capacity and its secondary expert is also at capacity, the token is effectively skipped by the MoE layer (only receiving the residual connection).

Capacity Factor Values

Capacity Factor Tradeoffs

| CF | Expert Capacity (T=4096, k=2, N=8) | Dropped Tokens (typical) | Memory Overhead | Use Case |
|---|---|---|---|---|
| 1.0 | 1024 | 5-15% | Minimal | Memory-constrained inference |
| 1.25 | 1280 | 1-3% | +25% | Standard training (Switch Transformer) |
| 1.5 | 1536 | <0.5% | +50% | Conservative training |
| 2.0 | 2048 | ~0% | +100% | Debugging / no-drop baseline |

Note: Dropped token rates measured on typical language modeling batches with aux loss balance_coeff=0.01.

With $C = 1.0$, the buffer is exactly the expected size. Any deviation from perfect uniformity causes drops. With $C = 1.25$ (the Switch Transformer default), there is 25% headroom, enough to absorb natural variation in most batches.

Implementation

def apply_capacity_factor(
    expert_indices: torch.Tensor,   # [num_tokens, top_k]
    gate_weights: torch.Tensor,     # [num_tokens, top_k]
    num_experts: int,
    capacity_factor: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Enforce capacity factor by masking tokens that exceed expert capacity.

    Returns:
        masked_indices: expert indices with overflow set to -1
        masked_weights: gate weights with overflow set to 0
        dropped_mask: [num_tokens] boolean, True if ALL experts for this token were full
    """
    num_tokens, top_k = expert_indices.shape
    # Compute capacity per expert
    tokens_per_expert = (num_tokens * top_k) / num_experts
    # At least one slot, so tiny batches (e.g., a single decode token) are not fully dropped
    expert_capacity = max(1, int(capacity_factor * tokens_per_expert))

    # Track how many tokens each expert has received so far
    expert_counts = torch.zeros(num_experts, dtype=torch.long,
                                device=expert_indices.device)

    # Create output tensors
    masked_indices = expert_indices.clone()
    masked_weights = gate_weights.clone()

    # Process tokens in order, respecting capacity
    # Note: in practice this is vectorized, shown as a loop for clarity
    for t in range(num_tokens):
        for k in range(top_k):
            expert_id = expert_indices[t, k].item()
            if expert_counts[expert_id] < expert_capacity:
                expert_counts[expert_id] += 1
            else:
                # Expert is full -- drop this assignment
                masked_indices[t, k] = -1
                masked_weights[t, k] = 0.0

    # Re-normalize remaining weights
    weight_sum = masked_weights.sum(dim=-1, keepdim=True).clamp(min=1e-9)
    masked_weights = masked_weights / weight_sum

    # Identify fully-dropped tokens (all experts were full)
    dropped_mask = (masked_weights.sum(dim=-1) < 1e-6)

    return masked_indices, masked_weights, dropped_mask

Dropped Tokens in Training vs Inference

Dropped tokens are mainly a training concern. Inference implementations commonly run without a capacity limit, and enforcing one on small batches would not help: the per-batch budget $C \times T \times k / N$ shrinks with $T$ to a handful of slots, so drops would become more likely, not less. During training with large batches (thousands of tokens), dropping is common at $C = 1.0$. Monitoring the drop rate is critical: a sustained rate above 5% means the model is losing gradient signal on those tokens.

A vectorized implementation avoids the Python loop by using cumsum to count expert assignments:

def apply_capacity_factor_vectorized(
    expert_indices: torch.Tensor,   # [num_tokens, top_k]
    gate_weights: torch.Tensor,     # [num_tokens, top_k]
    num_experts: int,
    capacity_factor: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Vectorized capacity enforcement using cumulative sums."""
    num_tokens, top_k = expert_indices.shape
    # At least one slot, so tiny batches are not fully dropped
    expert_capacity = max(1, int(capacity_factor * num_tokens * top_k / num_experts))

    # One-hot encode expert assignments: [num_tokens * top_k, num_experts]
    flat_indices = expert_indices.view(-1)
    one_hot = F.one_hot(flat_indices, num_experts).float()  # [T*k, N]

    # Cumulative count per expert (in token order)
    cumsum = one_hot.cumsum(dim=0)  # [T*k, N]

    # For each (token, k) pair, get the cumulative count for its assigned expert
    # This tells us: "this is the Xth token assigned to this expert"
    position_in_expert = cumsum.gather(
        1, flat_indices.unsqueeze(1)
    ).squeeze(1)  # [T*k]

    # Mask: keep only if position <= capacity
    keep_mask = (position_in_expert <= expert_capacity).view(num_tokens, top_k)

    # Apply mask to weights
    masked_weights = gate_weights * keep_mask.float()
    weight_sum = masked_weights.sum(dim=-1, keepdim=True).clamp(min=1e-9)
    masked_weights = masked_weights / weight_sum

    return expert_indices, masked_weights
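The cumsum trick on a tiny, hand-checkable example (routing values invented for illustration): four tokens all pick expert 0 first, and with $C = 1.0$ each expert has room for two assignments.

```python
import torch
import torch.nn.functional as F

# 4 tokens, top-2, 4 experts; expert 0 is everyone's first choice
expert_indices = torch.tensor([[0, 1], [0, 2], [0, 1], [0, 3]])
num_tokens, top_k = expert_indices.shape
num_experts, capacity_factor = 4, 1.0
capacity = int(capacity_factor * num_tokens * top_k / num_experts)   # 2

flat = expert_indices.reshape(-1)
one_hot = F.one_hot(flat, num_experts)
# 1-based arrival order of each (token, slot) pair at its expert, in token-major order
position = one_hot.cumsum(dim=0).gather(1, flat.unsqueeze(1)).squeeze(1)
keep = (position <= capacity).view(num_tokens, top_k)

print(position.view(num_tokens, top_k).tolist())   # [[1, 1], [2, 1], [3, 2], [4, 1]]
print(keep.tolist())   # [[True, True], [True, True], [False, True], [False, True]]
```

Because positions are counted in token order, the later tokens 2 and 3 are the ones that lose their expert-0 slot; their second choices still fit.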

The vectorized version replaces the $O(T \times k)$ Python loop with tensor operations that run entirely on the GPU. For $T = 4096$, $k = 2$: the loop version takes ~12 ms in Python; the vectorized version takes ~0.08 ms on an A100.


6. Complete Implementation

Here is the full, runnable MoE layer. It combines all components from the previous sections into a single module that you can drop into a transformer training loop.

Expert FFN: SwiGLU

class SwiGLUExpert(nn.Module):
    """
    Single expert FFN using SwiGLU activation.
    SwiGLU(x) = W_down @ (silu(W_gate @ x) * (W_up @ x))

    Parameter count per expert: 3 * d_model * d_ff
    """
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [num_tokens, d_model]
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
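A quick CPU check, with toy sizes chosen only for illustration, that the expert's parameter count is $3 \cdot d_{model} \cdot d_{ff}$ and that `forward` computes $W_{down}(\text{silu}(W_{gate} x) \odot W_{up} x)$. It also confirms that an expert can be called on an empty block of tokens, which matters when the router sends nothing to some expert.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUExpert(nn.Module):
    """Same structure as above: W_down(silu(W_gate x) * W_up x)."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))


torch.manual_seed(0)
d_model, d_ff = 64, 172  # toy sizes for a quick CPU check
expert = SwiGLUExpert(d_model, d_ff)
n_params = sum(p.numel() for p in expert.parameters())
print(n_params)  # 33024 == 3 * 64 * 172

x = torch.randn(10, d_model)
with torch.no_grad():
    # nn.Linear computes x @ W.T, so write the formula out by hand
    manual = (F.silu(x @ expert.w_gate.weight.T) * (x @ expert.w_up.weight.T)) @ expert.w_down.weight.T
    print(torch.allclose(expert(x), manual, atol=1e-5))  # True
    print(expert(x[:0]).shape)  # torch.Size([0, 64])
```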

Complete MoE Layer

class MoELayer(nn.Module):
    """
    Complete Mixture of Experts layer with:
    - Top-k gating with softmax routing
    - SwiGLU expert FFNs
    - Load-balancing auxiliary loss
    - Capacity factor enforcement
    - Permutation-based token dispatch

    Drop-in replacement for a standard FFN block in a transformer.

    Args:
        d_model: Model hidden dimension (e.g., 4096)
        d_ff: Expert FFN intermediate dimension (e.g., 11008)
        num_experts: Number of expert FFNs (e.g., 8)
        top_k: Number of experts per token (e.g., 2)
        capacity_factor: Max tokens per expert = CF * (T*k/N). Set 0 to disable.
        balance_coeff: Weight of the auxiliary load-balancing loss (e.g., 0.01)
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        num_experts: int = 8,
        top_k: int = 2,
        capacity_factor: float = 1.25,
        balance_coeff: float = 0.01,
    ):
        super().__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        self.balance_coeff = balance_coeff

        # Router: linear projection from d_model to num_experts
        self.router = nn.Linear(d_model, num_experts, bias=False)

        # Expert FFNs
        self.experts = nn.ModuleList([
            SwiGLUExpert(d_model, d_ff) for _ in range(num_experts)
        ])

    def _compute_routing(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute top-k routing decisions.

        Args:
            x: [num_tokens, d_model]
        Returns:
            gate_weights: [num_tokens, top_k] (re-normalized softmax weights)
            expert_indices: [num_tokens, top_k] (expert assignments)
            router_probs: [num_tokens, num_experts] (full softmax, for aux loss)
        """
        # Router logits
        logits = self.router(x)  # [num_tokens, num_experts]

        # Full softmax (needed for load-balancing loss)
        router_probs = F.softmax(logits, dim=-1)  # [num_tokens, num_experts]

        # Top-k selection
        topk_weights, topk_indices = torch.topk(
            router_probs, self.top_k, dim=-1
        )

        # Re-normalize over selected experts
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        return topk_weights, topk_indices, router_probs

    def _compute_aux_loss(
        self,
        router_probs: torch.Tensor,
        expert_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute load-balancing auxiliary loss.
        L = balance_coeff * N * sum(f_i * P_i)
        """
        num_tokens = router_probs.shape[0]

        # f_i: fraction of tokens that selected expert i in their top-k
        # (sums to top_k over experts). f is not differentiable; the router
        # receives gradient from this loss through P_i below.
        expert_mask = torch.zeros(
            num_tokens, self.num_experts,
            device=router_probs.device, dtype=router_probs.dtype,
        )
        expert_mask.scatter_(1, expert_indices, 1.0)  # one scatter for all k slots
        f = expert_mask.mean(dim=0)  # [num_experts]

        # P_i: average router probability per expert
        P = router_probs.mean(dim=0)  # [num_experts]

        # Auxiliary loss
        aux_loss = self.balance_coeff * self.num_experts * (f * P).sum()
        return aux_loss

    def _apply_capacity(
        self,
        expert_indices: torch.Tensor,
        gate_weights: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Enforce capacity factor. Returns modified weights (overflow tokens get 0).
        """
        if self.capacity_factor <= 0:
            return expert_indices, gate_weights

        num_tokens, top_k = expert_indices.shape
        expert_capacity = int(
            self.capacity_factor * num_tokens * top_k / self.num_experts
        )

        # Flatten and one-hot encode
        flat_indices = expert_indices.view(-1)
        one_hot = F.one_hot(flat_indices, self.num_experts).float()

        # Cumulative count per expert
        cumsum = one_hot.cumsum(dim=0)
        position_in_expert = cumsum.gather(
            1, flat_indices.unsqueeze(1)
        ).squeeze(1)

        # Keep mask: only tokens within capacity
        keep = (position_in_expert <= expert_capacity).view(num_tokens, top_k)

        # Mask weights and re-normalize
        masked_weights = gate_weights * keep.float()
        weight_sum = masked_weights.sum(dim=-1, keepdim=True).clamp(min=1e-9)
        masked_weights = masked_weights / weight_sum

        return expert_indices, masked_weights

    def _dispatch_and_combine(
        self,
        x: torch.Tensor,
        gate_weights: torch.Tensor,
        expert_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Permutation-based token dispatch:
        1. Repeat tokens for each top-k slot
        2. Sort by expert assignment
        3. Process contiguous blocks per expert
        4. Weight, unsort, and sum
        """
        num_tokens, d_model = x.shape
        top_k = expert_indices.shape[1]

        # Flatten: each token appears top_k times
        flat_indices = expert_indices.view(-1)          # [T * k]
        flat_weights = gate_weights.view(-1)            # [T * k]
        x_rep = x.repeat_interleave(top_k, dim=0)      # [T * k, d_model]

        # Sort by expert index
        sort_order = flat_indices.argsort(stable=True)
        sorted_tokens = x_rep[sort_order]               # [T * k, d_model]
        sorted_experts = flat_indices[sort_order]        # [T * k]
        sorted_weights = flat_weights[sort_order]        # [T * k]

        # Per-expert token counts. One .tolist() is a single host sync,
        # instead of one .item() sync per expert inside the loop.
        expert_counts = torch.bincount(
            sorted_experts, minlength=self.num_experts
        ).tolist()

        # Process each expert's contiguous token block
        output_parts = []
        for i, expert_input in enumerate(
            torch.split(sorted_tokens, expert_counts, dim=0)
        ):
            if expert_input.shape[0] > 0:
                output_parts.append(self.experts[i](expert_input))
            else:
                # No tokens for this expert: keep the empty [0, d_model] block
                output_parts.append(expert_input)

        # Concatenate all expert outputs (in sorted order)
        sorted_output = torch.cat(output_parts, dim=0)  # [T * k, d_model]

        # Apply gate weights
        sorted_output = sorted_output * sorted_weights.unsqueeze(-1)

        # Unsort back to original token order
        unsort_order = sort_order.argsort()
        output_flat = sorted_output[unsort_order]        # [T * k, d_model]

        # Sum over top-k contributions for each token
        output = output_flat.view(num_tokens, top_k, d_model).sum(dim=1)

        return output

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Full MoE forward pass.

        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            aux_loss: scalar tensor (add to training loss)
        """
        batch_size, seq_len, d_model = x.shape

        # Flatten batch and sequence dimensions
        x_flat = x.reshape(-1, d_model)  # [B * S, d_model]; also handles non-contiguous x

        # Step 1: Compute routing
        gate_weights, expert_indices, router_probs = self._compute_routing(x_flat)

        # Step 2: Compute auxiliary loss (before capacity masking)
        aux_loss = self._compute_aux_loss(router_probs, expert_indices)

        # Step 3: Apply capacity factor (masks overflow tokens)
        expert_indices, gate_weights = self._apply_capacity(
            expert_indices, gate_weights
        )

        # Step 4: Dispatch tokens to experts, compute, and combine
        output_flat = self._dispatch_and_combine(
            x_flat, gate_weights, expert_indices
        )

        # Reshape back to [batch_size, seq_len, d_model]
        output = output_flat.view(batch_size, seq_len, d_model)

        return output, aux_loss
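The sort-and-unsort logic in `_dispatch_and_combine` is the easiest part to get subtly wrong, so it is worth comparing against a per-token reference on a tiny problem. This sketch is self-contained: `permute_dispatch` condenses the method above into a free function, and plain `nn.Linear` experts stand in for `SwiGLUExpert` (a simplification; the dispatch logic does not depend on the expert type).

```python
import torch
import torch.nn as nn


def permute_dispatch(x, weights, indices, experts):
    """Sort-based dispatch, condensed from _dispatch_and_combine."""
    num_tokens, d = x.shape
    k = indices.shape[1]
    flat_idx = indices.reshape(-1)
    order = flat_idx.argsort(stable=True)
    sorted_x = x.repeat_interleave(k, dim=0)[order]
    counts = torch.bincount(flat_idx, minlength=len(experts)).tolist()
    parts = [experts[e](chunk) for e, chunk in enumerate(torch.split(sorted_x, counts))]
    out = torch.cat(parts) * weights.reshape(-1)[order].unsqueeze(-1)
    out = out[order.argsort()]  # undo the sort
    return out.view(num_tokens, k, d).sum(dim=1)


def naive_dispatch(x, weights, indices, experts):
    """Per-token reference: y_t = sum_j w_tj * E_{i_tj}(x_t)."""
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for j in range(indices.shape[1]):
            e = int(indices[t, j])
            out[t] += weights[t, j] * experts[e](x[t:t + 1]).squeeze(0)
    return out


torch.manual_seed(0)
T, d, N, k = 32, 16, 4, 2
experts = nn.ModuleList([nn.Linear(d, d) for _ in range(N)])
x = torch.randn(T, d)
probs = torch.softmax(torch.randn(T, N), dim=-1)
w, idx = probs.topk(k, dim=-1)
w = w / w.sum(dim=-1, keepdim=True)  # re-normalize as in _compute_routing

with torch.no_grad():
    fast = permute_dispatch(x, w, idx, experts)
    slow = naive_dispatch(x, w, idx, experts)
print(torch.allclose(fast, slow, atol=1e-5))  # True
```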

Verification: Shapes and Smoke Test

if __name__ == "__main__":
    # Configuration (Mixtral-scale)
    d_model = 4096
    d_ff = 14336
    num_experts = 8
    top_k = 2
    batch_size = 2
    seq_len = 512

    # Create layer
    moe = MoELayer(
        d_model=d_model,
        d_ff=d_ff,
        num_experts=num_experts,
        top_k=top_k,
        capacity_factor=1.25,
        balance_coeff=0.01,
    ).cuda()

    # Forward pass
    x = torch.randn(batch_size, seq_len, d_model, device="cuda")
    output, aux_loss = moe(x)

    # Verify shapes
    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected {(batch_size, seq_len, d_model)}, got {output.shape}"
    assert aux_loss.shape == (), f"Expected scalar, got {aux_loss.shape}"

    # Expected aux_loss: f_i sums to top_k, so balanced routing gives
    # N * sum(f_i * P_i) ~= top_k, i.e. aux_loss ~= balance_coeff * top_k = 0.02
    print(f"Output shape: {output.shape}")
    print(f"Aux loss: {aux_loss.item():.4f}")
    print(f"Output norm: {output.norm().item():.4f}")

    # Check gradient flow
    loss = output.sum() + aux_loss
    loss.backward()
    print(f"Router grad norm: {moe.router.weight.grad.norm().item():.6f}")
    print(f"Expert[0] w_gate grad norm: "
          f"{moe.experts[0].w_gate.weight.grad.norm().item():.6f}")
    print("All checks passed.")

Parameter and FLOP Accounting


MoE Layer Resource Budget (Mixtral-Scale: d=4096, d_ff=14336, N=8, k=2)

| Component | Parameters | FLOPs per Token | Memory (FP16) |
|---|---|---|---|
| Router | 4096 x 8 = 32,768 | 65,536 | 64 KiB |
| Per Expert (SwiGLU) | 3 x 4096 x 14336 = 176.2M | 352.3M (when activated) | 336 MiB |
| All 8 Experts (total) | 1.41B | 352.3M x 2 = 704.6M | 2.63 GiB |
| Equivalent Dense FFN | 176.2M | 352.3M | 336 MiB |
| MoE Layer (total) | 1.41B | 704.7M | 2.63 GiB |

Note: each token activates k=2 experts, so total parameters = 8x dense while activated FLOPs = 2x dense. Memory must hold all 8 experts.

The critical insight in this table: the MoE layer stores 8x the parameters of a dense FFN (2.63 GiB vs 336 MiB), but each token's FLOPs are only 2x those of a single expert (because $k = 2$). This is the fundamental MoE tradeoff: memory scales with $N$, compute scales with $k$.
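The table's parameter and FLOP counts can be reproduced with a few lines of arithmetic. This sketch uses the common approximation of 2 FLOPs (one multiply, one add) per weight per token and ignores the softmax, top-k, and elementwise SwiGLU costs.

```python
d_model, d_ff, N, k = 4096, 14336, 8, 2

router_params = d_model * N                     # 32,768
expert_params = 3 * d_model * d_ff              # 176,160,768
moe_params = router_params + N * expert_params  # all experts are stored

# ~2 FLOPs per weight per token for a matmul
router_flops = 2 * router_params                # 65,536
expert_flops = 2 * expert_params                # ~352.3M
moe_flops = router_flops + k * expert_flops     # only k experts run per token

print(f"stored params: {moe_params / 1e9:.2f}B ({moe_params / expert_params:.2f}x one expert)")
print(f"FLOPs/token:   {moe_flops / 1e6:.1f}M ({moe_flops / expert_flops:.2f}x one expert)")
print(f"router share:  {100 * router_flops / moe_flops:.4f}% of layer FLOPs")
# stored params: 1.41B (8.00x one expert)
# FLOPs/token:   704.7M (2.00x one expert)
# router share:  0.0093% of layer FLOPs
```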

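For training, the optimizer states cost far more than the weights themselves. Adam's momentum and variance are both kept in FP32 (4 bytes each per parameter), which is $4 \times 2.625 = 10.5$ GiB per layer, and FP32 master weights add another $2 \times 2.625 = 5.25$ GiB. Together with the 2.625 GiB of FP16 weights, each MoE layer needs ~18.4 GiB before gradients and activations. A 32-layer model with MoE on every other layer (16 MoE layers) therefore requires ~294 GiB for the expert weights and optimizer states alone, which takes at least four A100-80GB GPUs with expert parallelism, and more once gradients and activations are counted.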
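The training-memory estimate can be checked the same way. This sketch assumes FP16 weights (2 bytes), FP32 master weights (4 bytes), and FP32 Adam moments (8 bytes for both) per parameter, counts only expert weights, and leaves out gradients and activations.

```python
import math

GiB = 1024 ** 3
params_per_layer = 8 * 3 * 4096 * 14336  # expert weights; router is negligible
moe_layers = 16
gpu_mem_gib = 80                         # A100-80GB

bytes_per_param = {
    "fp16 weights": 2,
    "fp32 master weights": 4,
    "adam m + v (fp32)": 8,
}
per_layer = {name: b * params_per_layer / GiB for name, b in bytes_per_param.items()}
layer_total = sum(per_layer.values())
model_total = moe_layers * layer_total

for name, gib in per_layer.items():
    print(f"{name:22s} {gib:8.3f} GiB")
print(f"{'per MoE layer':22s} {layer_total:8.3f} GiB")   # 18.375
print(f"{'16 MoE layers':22s} {model_total:8.3f} GiB")   # 294.000
print(f"min A100-80GB GPUs: {math.ceil(model_total / gpu_mem_gib)}")  # 4
```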

Training Loop Integration

# Example: integrating MoE into a training loop.
# TransformerWithMoE and dataloader are placeholders for your own model and data.
import torch
import torch.nn.functional as F

model = TransformerWithMoE(...).cuda()  # Your model with MoE layers
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for batch in dataloader:
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()

    # Forward pass: model returns main logits + aggregated aux loss
    logits, total_aux_loss = model(input_ids)

    # Main language modeling loss (flatten batch and sequence dims)
    lm_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1)
    )

    # Total loss = LM loss + auxiliary balance loss
    # The aux_loss is already scaled by balance_coeff inside MoELayer
    total_loss = lm_loss + total_aux_loss

    total_loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Monitor: log both losses separately
    print(f"LM loss: {lm_loss.item():.4f}, "
          f"Aux loss: {total_aux_loss.item():.4f}, "
          f"Ratio: {total_aux_loss.item() / lm_loss.item():.4f}")

Monitoring Expert Utilization

Track two metrics during training: (1) the coefficient of variation of expert counts per batch, $\text{CV} = \sigma(\text{counts}) / \mu(\text{counts})$. Healthy range: 0.05-0.15; above 0.3 means severe imbalance. (2) The token drop rate: the fraction of tokens for which all $k$ expert slots were capacity-masked. Healthy: below 2%; above 5% means $C$ is too low or balancing is failing.
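Both metrics are cheap to compute from the routing tensors. `routing_health` is a hypothetical helper written for this sketch: it takes the pre-capacity `expert_indices` and the post-capacity gate weights (where `_apply_capacity` sets dropped slots to 0), and it uses the population standard deviation for σ.

```python
import torch


def routing_health(expert_indices: torch.Tensor, kept_weights: torch.Tensor, num_experts: int):
    """Return (CV of per-expert slot counts, fraction of tokens with every slot dropped).

    expert_indices: [T, k] routing decisions before capacity masking
    kept_weights:   [T, k] gate weights after capacity masking (0 = dropped slot)
    """
    counts = torch.bincount(expert_indices.reshape(-1), minlength=num_experts).float()
    mean = counts.mean()
    cv = ((counts - mean).pow(2).mean().sqrt() / mean).item()  # population std / mean
    drop_rate = (kept_weights == 0).all(dim=-1).float().mean().item()
    return cv, drop_rate


torch.manual_seed(0)
T, N, k = 1024, 8, 2
idx = torch.rand(T, N).topk(k, dim=-1).indices
w = torch.full((T, k), 0.5)
w[:10] = 0.0  # pretend 10 tokens lost both slots to capacity
cv, drop = routing_health(idx, w, N)
print(f"CV={cv:.3f}  drop rate={drop:.2%}")  # drop rate=0.98%
```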

Extending to Production

The implementation above is functionally complete but uses a Python loop over experts in _dispatch_and_combine. To reach production performance:

  1. Replace the expert loop with grouped GEMM: Use megablocks.ops.gmm or write a Triton kernel. This single change gives 2-4x speedup.

  2. Add expert parallelism: Shard experts across GPUs with all-to-all communication. Each GPU holds $N / \text{num\_gpus}$ experts and dispatches tokens via torch.distributed.all_to_all.

  3. Fuse the router and capacity logic: The routing, capacity enforcement, and permutation sort can be fused into a single Triton kernel, eliminating 3-4 intermediate tensor materializations.

  4. Add jitter noise during training: Small random noise added to the router logits prevents early commitment and improves exploration:

# Inside MoELayer._compute_routing, after `logits = self.router(x)`
# and before the softmax:
if self.training:
    noise = torch.randn_like(logits) * 0.01
    logits = logits + noise
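Before reaching for a grouped GEMM kernel, a simpler way to remove the Python loop is to pad every expert's block to the largest block and run three `torch.bmm` calls. This is padded batched matmul, not grouped GEMM: it wastes compute on padding when expert loads are skewed, which is exactly what grouped GEMM avoids. The stacked weight layout `[N, d_model, d_ff]` (the transpose of `nn.Linear.weight`, stacked over experts) and the function name `padded_bmm_experts` are assumptions of this sketch.

```python
import torch
import torch.nn.functional as F


def padded_bmm_experts(sorted_tokens, counts, w_gate, w_up, w_down):
    """Run all SwiGLU experts with three torch.bmm calls instead of a Python loop.

    sorted_tokens: [T*k, d_model], grouped by expert (as after the argsort)
    counts:        list[int] tokens per expert, summing to T*k
    w_gate, w_up:  [N, d_model, d_ff];  w_down: [N, d_ff, d_model]
    """
    dev = sorted_tokens.device
    num_experts, d_model = w_gate.shape[0], sorted_tokens.shape[1]
    counts_t = torch.tensor(counts, device=dev)
    expert_id = torch.repeat_interleave(torch.arange(num_experts, device=dev), counts_t)
    starts = torch.cumsum(counts_t, dim=0) - counts_t  # exclusive prefix sum
    slot = torch.arange(sorted_tokens.shape[0], device=dev) - starts[expert_id]

    # Scatter each expert's block into a zero-padded [N, max_count, d_model] buffer
    padded = sorted_tokens.new_zeros(num_experts, max(counts), d_model)
    padded[expert_id, slot] = sorted_tokens
    h = F.silu(torch.bmm(padded, w_gate)) * torch.bmm(padded, w_up)
    out = torch.bmm(h, w_down)  # [N, max_count, d_model]
    return out[expert_id, slot]  # back to [T*k, d_model] in sorted order


torch.manual_seed(0)
N, d_model, d_ff = 4, 16, 32
counts = [5, 0, 7, 3]  # expert 1 receives no tokens
x = torch.randn(sum(counts), d_model)
w_gate = torch.randn(N, d_model, d_ff) * 0.1
w_up = torch.randn(N, d_model, d_ff) * 0.1
w_down = torch.randn(N, d_ff, d_model) * 0.1

fast = padded_bmm_experts(x, counts, w_gate, w_up, w_down)
ref = torch.cat([
    (F.silu(chunk @ w_gate[e]) * (chunk @ w_up[e])) @ w_down[e]
    for e, chunk in enumerate(torch.split(x, counts))
])
print(torch.allclose(fast, ref, atol=1e-5))  # True
```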

MoE Layer Optimization Impact (A100, B=4, S=2048, d=4096, N=8, k=2)

| Configuration | Change | Latency (ms) |
|---|---|---|
| Vanilla PyTorch (this impl) | Python loop over experts | 18.4 |
| + Grouped GEMM (Triton) | Single kernel for all experts | 6.2 |
| + Fused routing | Fused sort + capacity + dispatch | 4.8 |
| + Expert parallel (2 GPU) | 4 experts per GPU + all-to-all | 3.1 |

The vanilla PyTorch implementation in this post is suitable for prototyping, debugging, and understanding the MoE computation. For training runs beyond a few hundred GPU-hours, invest in the grouped GEMM path. The implementation above computes the same function, and therefore the same gradients, as an optimized version (up to floating-point differences from reduction order); only the wall-clock time changes.


Summary

The MoE layer has five components, each with a specific purpose:

  1. Router: A linear projection + softmax + top-k that assigns each token to $k$ experts. Cost: negligible (~0.01% of layer FLOPs).

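  2. Gating weights: Softmax probabilities for the selected experts, re-normalized to sum to 1. They are differentiable with respect to the selected experts' router logits, so the main loss trains the router through them; the discrete top-k choice itself receives no gradient (the auxiliary loss reaches every logit through $P_i$).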

  3. Token dispatch: Sorting tokens by expert assignment and processing contiguous blocks. The grouped GEMM approach processes all experts in a single kernel launch, achieving 80-90% GPU utilization.

  4. Load-balancing loss: $L = \alpha \cdot N \cdot \sum_i f_i \cdot P_i$ penalizes routing imbalance by pushing router probabilities toward a uniform distribution. Without it, expert collapse occurs within the first few hundred training steps.

  5. Capacity factor: A hard cap of $C \times T \cdot k / N$ tokens per expert. Standard value: $C = 1.25$. Prevents memory spikes and ensures bounded compute per expert in distributed settings.
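For the smoke-test configuration ($B = 2$, $S = 512$, so $T = 1024$ tokens, with $k = 2$, $N = 8$, $C = 1.25$), the cap works out to:

$$
\text{capacity} = C \times \frac{T \cdot k}{N} = 1.25 \times \frac{1024 \cdot 2}{8} = 1.25 \times 256 = 320 \text{ slots per expert}
$$

`_apply_capacity` truncates this product with `int()`, so non-integer results round down.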

The complete MoELayer class above is 120 lines of core logic, runs on vanilla PyTorch, and produces correct gradients for all components. Copy it, run the smoke test, and start training.

Part 2 of this series covers expert parallelism: how to shard experts across GPUs, implement the all-to-all dispatch, overlap communication with computation, and handle the interaction between capacity factor and distributed token routing.