Learning How Mixture of Experts Works: A Code Walkthrough

Overview

This article teaches how Mixture of Experts (MoE) works in a practical, top-down way:

  1. Start from the big picture (what problem MoE solves in Transformers).
  2. Use the GShard paper as the conceptual anchor.
  3. Move quickly into code and verify each idea against implementation details.

The key idea is simple: instead of running one dense feed-forward network for every token, an MoE layer uses a router (gating network) to send each token to a small subset of experts (often top-2), which increases model capacity without linearly increasing compute per token.
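
To put rough numbers on that claim, the sketch below counts FFN weights for a dense layer and for an MoE layer, and how many of those weights a single token actually touches. The sizes (dim 512, hidden size 2048, 16 experts, top-2) are illustrative assumptions, and router cost and biases are ignored.

```python
def ffn_params(dim, hidden_dim):
    # Two weight matrices: dim -> hidden_dim and hidden_dim -> dim (no biases).
    return 2 * dim * hidden_dim

def moe_stats(dim, hidden_dim, num_experts, top_k):
    per_expert = ffn_params(dim, hidden_dim)
    total_params = num_experts * per_expert
    # A token only uses the weights of the experts it is routed to.
    active_per_token = top_k * per_expert
    return total_params, active_per_token

dim, hidden_dim = 512, 2048
dense = ffn_params(dim, hidden_dim)
total, active = moe_stats(dim, hidden_dim, num_experts=16, top_k=2)

print(f"dense FFN weights:        {dense:,}")
print(f"MoE weights (16 experts): {total:,} ({total // dense}x dense)")
print(f"MoE weights per token:    {active:,} ({active // dense}x dense)")
```

Under these assumptions the layer holds 16 times the weights of the dense FFN, while each token still passes through only two experts' worth of them.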

This walkthrough uses the original GShard paper and the learning strategy from the referenced video, then works through the complete mixture_of_experts.py implementation step by step.

Learning Strategy (From the Video)

The video’s strategy is strong for fast understanding:

  • Use minimal prerequisites: basic Transformer diagram intuition and familiarity with nn.Module.
  • Read with priorities: find the highest-value parts first (usually architecture diagrams and routing equations).
  • Read by comparison: compare baseline Transformer FFN vs MoE replacement.
  • Jump to code early: verify conceptual understanding quickly using implementation.

This avoids getting blocked by every detail too early. You build an 80% correct mental model first, then tighten it with code and equations.

Core References

Complete Code for reference

import torch
from torch import nn
import torch.nn.functional as F

import math
from inspect import isfunction

# constants

MIN_EXPERT_CAPACITY = 4

# helper functions

def default(val, default_val):
    default_val = default_val() if isfunction(default_val) else default_val
    return val if val is not None else default_val

def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# tensor related helper functions

def top1(t):
    values, index = t.topk(k=1, dim=-1)
    values, index = map(lambda x: x.squeeze(dim=-1), (values, index))
    return values, index

def cumsum_exclusive(t, dim=-1):
    num_dims = len(t.shape)
    num_pad_dims = - dim - 1
    pre_padding = (0, 0) * num_pad_dims
    pre_slice   = (slice(None),) * num_pad_dims
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    return padded_t[(..., slice(None, -1), *pre_slice)]

# pytorch one hot throws an error if there are out of bound indices.
# tensorflow, in contrast, does not throw an error
def safe_one_hot(indexes, max_length):
    max_index = indexes.max() + 1
    return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]

def init_(t):
    dim = t.shape[-1]
    std = 1 / math.sqrt(dim)
    return t.uniform_(-std, std)

# activations

class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# expert class

class Experts(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()

        hidden_dim = default(hidden_dim, dim * 4)
        num_experts = cast_tuple(num_experts)

        w1 = torch.zeros(*num_experts, dim, hidden_dim)
        w2 = torch.zeros(*num_experts, hidden_dim, dim)

        w1 = init_(w1)
        w2 = init_(w2)

        self.w1 = nn.Parameter(w1)
        self.w2 = nn.Parameter(w2)
        self.act = activation()

    def forward(self, x):
        hidden = torch.einsum('...nd,...dh->...nh', x, self.w1)
        hidden = self.act(hidden)
        out    = torch.einsum('...nh,...hd->...nd', hidden, self.w2)
        return out

# the below code is almost all transcribed from the official tensorflow version, from which the papers are written
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py

# gating network

class Top2Gating(nn.Module):
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        outer_expert_dims = tuple(),
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()

        self.eps = eps
        self.num_gates = num_gates
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))

        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

    def forward(self, x, importance = None):
        *_, b, group_size, dim = x.shape
        num_gates = self.num_gates

        if self.training:
            policy = self.second_policy_train
            threshold = self.second_threshold_train
            capacity_factor = self.capacity_factor_train
        else:
            policy = self.second_policy_eval
            threshold = self.second_threshold_eval
            capacity_factor = self.capacity_factor_eval

        raw_gates = torch.einsum('...bnd,...de->...bne', x, self.w_gating)
        raw_gates = raw_gates.softmax(dim=-1)

        # FIND TOP 2 EXPERTS PER POSITION
        # Find the top expert for each position. shape=[batch, group]

        gate_1, index_1 = top1(raw_gates)
        mask_1 = F.one_hot(index_1, num_gates).float()
        density_1_proxy = raw_gates

        if importance is not None:
            equals_one_mask = (importance == 1.).float()
            mask_1 *= equals_one_mask[..., None]
            gate_1 *= equals_one_mask
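            # multiply out of place: density_1_proxy aliases raw_gates, which is still needed below
            density_1_proxy = density_1_proxy * equals_one_mask[..., None]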
            del equals_one_mask

        gates_without_top_1 = raw_gates * (1. - mask_1)

        gate_2, index_2 = top1(gates_without_top_1)
        mask_2 = F.one_hot(index_2, num_gates).float()

        if importance is not None:
            greater_zero_mask = (importance > 0.).float()
            mask_2 *= greater_zero_mask[..., None]
            del greater_zero_mask

        # normalize top2 gate scores
        denom = gate_1 + gate_2 + self.eps
        gate_1 /= denom
        gate_2 /= denom

        # BALANCING LOSSES
        # shape = [batch, experts]
        # We want to equalize the fraction of the batch assigned to each expert
        density_1 = mask_1.mean(dim=-2)
        # Something continuous that is correlated with what we want to equalize.
        density_1_proxy = density_1_proxy.mean(dim=-2)
        loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)

        # Depending on the policy in the hparams, we may drop out some of the
        # second-place experts.
        if policy == "all":
            pass
        elif policy == "none":
            mask_2 = torch.zeros_like(mask_2)
        elif policy == "threshold":
            # gate_2 is [batch, group]; add an expert axis so it broadcasts against mask_2
            mask_2 *= (gate_2 > threshold).float().unsqueeze(-1)
        elif policy == "random":
            probs = torch.zeros_like(gate_2).uniform_(0., 1.)
            mask_2 *= (probs < (gate_2 / max(threshold, self.eps))).float().unsqueeze(-1)
        else:
            raise ValueError(f"Unknown policy {policy}")

        # Each sequence sends (at most?) expert_capacity positions to each expert.
        # Static expert_capacity dimension is needed for expert batch sizes
        expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
        expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
        expert_capacity_f = float(expert_capacity)

        # COMPUTE ASSIGNMENT TO EXPERTS
        # [batch, group, experts]
        # This is the position within the expert's mini-batch for this sequence
        position_in_expert_1 = cumsum_exclusive(mask_1, dim=-2) * mask_1
        # Remove the elements that don't fit. [batch, group, experts]
        mask_1 *= (position_in_expert_1 < expert_capacity_f).float()
        # [batch, experts]
        # How many examples in this sequence go to this expert
        mask_1_count = mask_1.sum(dim=-2, keepdim=True)
        # [batch, group] - mostly ones, but zeros where something didn't fit
        mask_1_flat = mask_1.sum(dim=-1)
        # [batch, group]
        position_in_expert_1 = position_in_expert_1.sum(dim=-1)
        # Weight assigned to first expert.  [batch, group]
        gate_1 *= mask_1_flat

        position_in_expert_2 = cumsum_exclusive(mask_2, dim=-2) + mask_1_count
        position_in_expert_2 *= mask_2
        mask_2 *= (position_in_expert_2 < expert_capacity_f).float()
        mask_2_flat = mask_2.sum(dim=-1)

        position_in_expert_2 = position_in_expert_2.sum(dim=-1)
        gate_2 *= mask_2_flat

        # [batch, group, experts, expert_capacity]
        combine_tensor = (
            gate_1[..., None, None]
            * mask_1_flat[..., None, None]
            * F.one_hot(index_1, num_gates)[..., None]
            * safe_one_hot(position_in_expert_1.long(), expert_capacity)[..., None, :] +
            gate_2[..., None, None]
            * mask_2_flat[..., None, None]
            * F.one_hot(index_2, num_gates)[..., None]
            * safe_one_hot(position_in_expert_2.long(), expert_capacity)[..., None, :]
        )

        dispatch_tensor = combine_tensor.bool().to(combine_tensor)
        return dispatch_tensor, combine_tensor, loss

# plain mixture of experts

class MoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        self.num_experts = num_experts

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
        self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, e = *inputs.shape, self.num_experts
        dispatch_tensor, combine_tensor, loss = self.gate(inputs)
        expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)

        # Now feed the expert inputs through the experts.
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(e, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
        return output, loss * self.loss_coef

# 2-level hierarchical mixture of experts

class HeirarchicalMoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = (4, 4),
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        assert len(num_experts) == 2, 'only 2 levels of hierarchy for experts allowed for now'
        num_experts_outer, num_experts_inner = num_experts
        self.num_experts_outer = num_experts_outer
        self.num_experts_inner = num_experts_inner

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}

        self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
        self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)

        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
        dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
        expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)

        # we construct an "importance" Tensor for the inputs to the second-level
        # gating.  The importance of an input is 1.0 if it represents the
        # first-choice expert-group and 0.5 if it represents the second-choice expert
        # group.  This is used by the second-level gating.
        importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
        importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())

        dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
        expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)

        # Now feed the expert inputs through the experts.
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
        # expert_outputs has shape [outer_experts, inner_experts, batch, capacity, dim]

        expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
        return output, (loss_outer + loss_inner) * self.loss_coef

Code Walkthrough: Step by Step

1) Helper functions: setup and defensive defaults

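  • default(val, default_val): returns val unless it is None, in which case it returns the fallback. If default_val is a function, it is called to build the fallback; the call happens before val is checked, so the factory runs even when val is provided.
  • cast_tuple(el): makes sure values like num_experts can be treated uniformly (single value or tuple).
  • init_(t): fills a weight tensor in place with values drawn uniformly from [-1/sqrt(k), 1/sqrt(k)], where k is the size of the tensor's last dimension (hidden_dim for w1, dim for w2).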

These functions reduce branching and keep the main classes readable.
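
A quick way to see how default behaves is to hand it a factory that records its own calls. The helper below is copied from the listing; make_fallback and the calls list are hypothetical names added only for this demonstration.

```python
from inspect import isfunction

def default(val, default_val):
    # Same logic as the helper in the listing above.
    default_val = default_val() if isfunction(default_val) else default_val
    return val if val is not None else default_val

calls = []

def make_fallback():
    # Hypothetical factory that records each invocation.
    calls.append("called")
    return "fallback"

print(default(None, make_fallback))     # fallback
print(default("given", make_fallback))  # given
print(default(None, 64))                # 64
print(len(calls))                       # 2: the factory ran both times it was passed
```

In MoE.__init__ this means the lambda that builds Experts still runs when a custom experts module is supplied; the freshly built bank is simply discarded.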

2) Tensor utilities: routing building blocks

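  • top1(t): returns the highest score and its index along the last dimension, which is the expert dimension in the gating code.
  • cumsum_exclusive(t, dim=-1): exclusive prefix sum (each entry is the sum of the entries before it). The gate calls it with dim=-2, the token axis, to give each token a slot inside an expert buffer.
  • safe_one_hot(indexes, max_length): one-hot encoding that tolerates indices at or beyond max_length. It encodes with enough classes to cover the largest index, then slices back to max_length, so an out-of-range index becomes an all-zero row.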

These are critical for token-to-expert assignment mechanics.
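
The slot-assignment trick is easier to follow on a single column of mask_1. The sketch below re-implements the two utilities on plain Python lists rather than tensors, which is a simplification for illustration; the mask values and the capacity of 3 are made up.

```python
def cumsum_exclusive(values):
    # Exclusive prefix sum: each entry is the sum of everything before it.
    out, running = [], 0
    for v in values:
        out.append(running)
        running += v
    return out

def safe_one_hot(index, max_length):
    # An index at or beyond max_length gives an all-zero row instead of an error.
    return [1 if i == index else 0 for i in range(max_length)]

# One expert's column of mask_1: tokens 0, 2, 3 and 5 chose this expert.
chose_expert = [1, 0, 1, 1, 0, 1]
slots = [s * m for s, m in zip(cumsum_exclusive(chose_expert), chose_expert)]
print(slots)  # [0, 0, 1, 2, 0, 3]

expert_capacity = 3
kept = [m if s < expert_capacity else 0 for s, m in zip(slots, chose_expert)]
print(kept)   # [1, 0, 1, 1, 0, 0]

print(safe_one_hot(1, expert_capacity))  # [0, 1, 0]
print(safe_one_hot(3, expert_capacity))  # [0, 0, 0]
```

Token 5 would land in slot 3, which does not exist when the capacity is 3, so it is masked out and its one-hot slot row is all zeros.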

3) Experts: parameterized FFN bank

Experts stores expert parameters in big tensors:

  • w1: [*num_experts, dim, hidden_dim], with hidden_dim defaulting to 4 * dim
  • w2: [*num_experts, hidden_dim, dim]

Then it runs expert FFNs with two einsums:

  1. hidden = einsum('...nd,...dh->...nh', x, w1)
  2. activation (GELU by default in Experts; MoE and HeirarchicalMoE pass nn.ReLU unless told otherwise)
  3. out = einsum('...nh,...hd->...nd', hidden, w2)

So each expert is a feed-forward network, and all experts are batched into one parameter structure.
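
The batched einsums can be unrolled into one small matrix product per expert. The sketch below does that on nested lists instead of tensors, with hand-picked weights and ReLU as the activation; all of the values are assumptions chosen so the result is easy to check by eye.

```python
def matmul(a, b):
    # (n x d) times (d x h) -> (n x h), using plain lists.
    return [[sum(x * w for x, w in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    return [[max(0.0, v) for v in row] for row in m]

def experts_forward(x, w1, w2):
    # x: [experts][slots][dim], w1: [experts][dim][hidden], w2: [experts][hidden][dim]
    # Expert e only ever multiplies its own slice x[e] by its own weights.
    return [matmul(relu(matmul(x_e, w1_e)), w2_e)
            for x_e, w1_e, w2_e in zip(x, w1, w2)]

# Two experts with dim = hidden = 2.
w1 = [[[1.0, 0.0], [0.0, 1.0]],   # expert 0: identity
      [[2.0, 0.0], [0.0, 2.0]]]   # expert 1: doubles its input
w2 = [[[1.0, 0.0], [0.0, 1.0]],
      [[1.0, 0.0], [0.0, 1.0]]]
x = [[[1.0, 2.0]],                # one slot routed to expert 0
     [[1.0, 2.0]]]                # the same vector routed to expert 1

out = experts_forward(x, w1, w2)
print(out[0])  # [[1.0, 2.0]]
print(out[1])  # [[2.0, 4.0]]
```

The same input produces different outputs depending on which expert's slice it sits in, which is all the leading expert dimension in w1 and w2 encodes.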

4) Top2Gating.__init__: router configuration

The gate stores:

  • w_gating: linear projection from token dimension to num_gates logits.
  • second-expert policy knobs:
    • second_policy_train/eval: all, none, threshold, or random
    • thresholds and capacity factors for train/eval
  • eps for numerical stability

This matches the top-2 routing spirit from GShard: sparse dispatch with controllable load behavior.

5) Top2Gating.forward: the routing pipeline

Input shape convention from code:

  • x: [..., b, group_size, dim]
  • output tensors:
    • dispatch_tensor: 0/1 routing map with the same shape as combine_tensor
    • combine_tensor: weighted routing map used to merge outputs
    • loss: load-balancing auxiliary loss

Flow:

  1. Compute gate probabilities

    • raw_gates = softmax(einsum(x, w_gating)) gives per-token expert probabilities.
  2. Pick top-1 expert

    • gate_1, index_1 = top1(raw_gates)
    • mask_1 is one-hot expert assignment for first choice.
  3. Pick the second expert

    • zero out top-1 in raw_gates and run top1 again.
    • build mask_2 for second choice.
  4. Normalize top-2 weights

    • gate_1 and gate_2 are renormalized so they sum to ~1 for active tokens.
  5. Load-balancing auxiliary loss

    • density_1: observed first-choice expert usage.
    • density_1_proxy: soft proxy from probabilities.
    • loss = mean(density_1_proxy * density_1) * num_gates^2.
  6. Second expert policy application

    • all: always keep second expert.
    • none: always drop second expert.
    • threshold: keep if gate_2 > threshold.
    • random: keep with probability min(1, gate_2 / threshold), where gate_2 is the renormalized second weight.
  7. Capacity calculation

    • expert_capacity = min(group_size, int(group_size * capacity_factor / num_gates))
    • then raised to at least MIN_EXPERT_CAPACITY (4), so small groups never get a capacity below 4.
  8. Assign positions inside each expert buffer

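    • cumsum_exclusive(mask_1) and cumsum_exclusive(mask_2) compute token slot indices; second-choice slots are offset by mask_1_count, so first choices fill each buffer first.
    • tokens whose slot index reaches expert_capacity are masked out.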
  9. Build combine/dispatch tensors

    • combine_tensor shape: [batch, group, experts, expert_capacity].
    • carries both routing locations and mixing weights.
    • dispatch_tensor = combine_tensor.bool().to(combine_tensor) gives selection mask for scattering inputs.

This is the heart of MoE routing: sparse token dispatch, bounded capacity, and weighted recombination.
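
Steps 1 to 5 of the flow can be reproduced on a single group with plain Python. This is a sketch under simplifying assumptions: no batch dimension, no capacity limit, hypothetical router logits, and lists in place of tensors. The loss follows the same formula as the listing, mean(density_1_proxy * density_1) * num_gates^2.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]

def top2(probs, eps=1e-9):
    # Two largest probabilities, renormalised so the pair sums to ~1.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    i1, i2 = order[0], order[1]
    denom = probs[i1] + probs[i2] + eps
    return (i1, probs[i1] / denom), (i2, probs[i2] / denom)

def balance_loss(all_probs, num_experts):
    n = len(all_probs)
    first = [max(range(num_experts), key=lambda e: p[e]) for p in all_probs]
    # density_1: fraction of tokens whose first choice is each expert.
    density = [first.count(e) / n for e in range(num_experts)]
    # density_1_proxy: mean router probability per expert.
    proxy = [sum(p[e] for p in all_probs) / n for e in range(num_experts)]
    mean_product = sum(d * q for d, q in zip(density, proxy)) / num_experts
    return mean_product * num_experts ** 2

# Hypothetical logits for 4 tokens and 4 experts.
balanced = [softmax(row) for row in ([2.0, 1.0, 0.0, 0.0],
                                     [0.0, 2.0, 1.0, 0.0],
                                     [0.0, 0.0, 2.0, 1.0],
                                     [1.0, 0.0, 0.0, 2.0])]
collapsed = [softmax([3.0, 0.0, 0.0, 0.0]) for _ in range(4)]

(i1, g1), (i2, g2) = top2(balanced[0])
print(i1, round(g1, 3), i2, round(g2, 3))   # 0 0.731 1 0.269
print(round(balance_loss(balanced, 4), 3))  # 1.0
print(round(balance_loss(collapsed, 4), 2)) # about 3.48
```

With this scaling the loss sits near 1.0 when first choices are spread evenly and grows toward num_gates as routing collapses onto one expert, which is what the optimizer is pushed away from.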

6) MoE.forward: single-level MoE execution

Given inputs of shape [b, n, d]:

  1. Router returns dispatch_tensor, combine_tensor, and auxiliary loss.
  2. Dispatch step:
    • expert_inputs = einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)
    • reorganizes tokens by expert and capacity slot.
  3. Experts run FFN transform.
  4. Combine step:
    • output = einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
    • merges weighted expert outputs back to token order.
  5. Return (output, loss * loss_coef).

So: route -> process -> merge, with a balancing regularizer.
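
The two einsums are easier to trust after tracing them by hand once. The sketch below drops the batch dimension and uses nested lists, a hand-written combine tensor, and stand-in "experts" that only scale their input; all of these are assumptions made to keep the arithmetic visible.

```python
def dispatch(tokens, combine):
    # Mirrors einsum('nd,nec->ecd') with dispatch = (combine != 0).
    num_experts, capacity, dim = len(combine[0]), len(combine[0][0]), len(tokens[0])
    buffers = [[[0.0] * dim for _ in range(capacity)] for _ in range(num_experts)]
    for n, token in enumerate(tokens):
        for e in range(num_experts):
            for c in range(capacity):
                if combine[n][e][c] != 0:
                    for d in range(dim):
                        buffers[e][c][d] += token[d]
    return buffers

def combine_outputs(expert_outputs, combine):
    # Mirrors einsum('ecd,nec->nd'): weighted sum back into token order.
    dim = len(expert_outputs[0][0])
    out = [[0.0] * dim for _ in combine]
    for n, routes in enumerate(combine):
        for e, slots in enumerate(routes):
            for c, weight in enumerate(slots):
                for d in range(dim):
                    out[n][d] += weight * expert_outputs[e][c][d]
    return out

tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# combine[n][e][c]: 3 tokens, 2 experts, capacity 2.
combine = [
    [[0.75, 0.0], [0.25, 0.0]],  # token 0: expert 0 slot 0, expert 1 slot 0
    [[0.0, 0.0], [0.0, 1.0]],    # token 1: expert 1 slot 1 only
    [[0.0, 0.0], [0.0, 0.0]],    # token 2: dropped by both experts
]

buffers = dispatch(tokens, combine)
# Stand-in experts: expert 0 is the identity, expert 1 multiplies by 10.
scales = [1.0, 10.0]
expert_outputs = [[[s * v for v in slot] for slot in buf]
                  for s, buf in zip(scales, buffers)]
output = combine_outputs(expert_outputs, combine)

print(buffers[1])  # [[1.0, 2.0], [3.0, 4.0]]
print(output)      # [[3.25, 6.5], [30.0, 40.0], [0.0, 0.0]]
```

Token 0 comes back as a 0.75/0.25 blend of its two experts, token 1 uses a single expert at full weight, and the dropped token 2 comes back as zeros.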

7) HeirarchicalMoE.forward: two-level routing

This class stacks two gating levels:

  1. Outer routing chooses expert groups.
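  2. Build importance from outer combine weights, one value per outer buffer slot:
    • 1.0 where the combine weight is above 0.5 (normally the token's first-choice group),
    • 0.5 where the weight is positive but at most 0.5 (second-choice group),
    • 0.0 for empty slots. The inner gate keeps first-choice routing only where importance is 1.0 and second-choice routing only where it is positive.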
  3. Inner routing then routes within each selected outer group.
  4. Experts run at shape [outer_expert, inner_expert, ...].
  5. Combine inner then outer outputs back to [b, n, d].
  6. Final auxiliary loss is (loss_outer + loss_inner) * loss_coef.

This is a hierarchical extension of the same sparse routing principle.
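
The importance mapping is a two-threshold step function, which a scalar version makes plain. The function names below are hypothetical; the arithmetic mirrors the 0.5 * ((w > 0.5) + (w > 0)) expression in HeirarchicalMoE.forward and the two importance masks in Top2Gating.forward.

```python
def importance_from_weight(weight):
    # 1.0 for weights above 0.5, 0.5 for other positive weights, 0.0 for empty slots.
    return 0.5 * (float(weight > 0.5) + float(weight > 0.0))

def inner_gate_masks(importance):
    # The inner gate keeps first-choice routing only when importance == 1
    # and second-choice routing only when importance > 0.
    keep_first = 1.0 if importance == 1.0 else 0.0
    keep_second = 1.0 if importance > 0.0 else 0.0
    return keep_first, keep_second

for w in (0.73, 0.27, 0.0):
    imp = importance_from_weight(w)
    print(w, "->", imp, inner_gate_masks(imp))
```

A slot holding a token's second-choice copy therefore never uses the inner gate's first choice, and an empty slot is not routed at all.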

How This Maps to the GShard Paper

The implementation aligns directly with the core GShard MoE concepts from the original paper:

  • Top-2 routing: each token can be sent to up to two experts.
  • Expert capacity: per-expert token budget limits overload.
  • Overflow handling: tokens that exceed an expert's capacity are dropped from that expert's processing path.
  • Auxiliary balancing loss: encourages even expert utilization.
  • Sparse conditional compute: model capacity scales with number of experts, while per-token active compute stays limited.

Paper links for direct reading:

Practical Reading Path

If you are learning this for the first time, follow this order:

  1. Read GShard Section 2.1 and 2.2 (Transformer + MoE layer concept).
  2. Scan Algorithm 1 (top-2 gating + auxiliary loss).
  3. Read this Top2Gating.forward implementation end to end.
  4. Verify dispatch/combine einsums in MoE.forward.
  5. Only then inspect hierarchical routing (HeirarchicalMoE).

This mirrors the “priorities -> comparison -> code early” method from the video.

Common Confusions and Checks

  • Why two tensors (dispatch_tensor and combine_tensor)?

    • dispatch_tensor is for selecting where token inputs go.
    • combine_tensor is for weighted reconstruction of outputs.
  • Why normalize gate_1 and gate_2?

    • To keep top-2 contributions stable and interpretable as local mixture weights.
  • Why can tokens be dropped?

    • Capacity enforces bounded expert load. In this code a token dropped by both of its experts gets an all-zero MoE output; in a full Transformer the residual connection around the layer still carries it forward.
  • Why use random second policy?

    • It reduces unnecessary second-expert traffic while keeping stochastic exploration and load flexibility.
  • What is the main training stabilizer here?

    • The auxiliary load-balancing term that discourages collapse to only a few experts.
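
For the random second policy, the comparison probs < gate_2 / threshold with uniform probs amounts to a keep probability of min(1, gate_2 / threshold). The sketch below checks that with Python's random module instead of tensors; the function names, trial count, and seed are assumptions for illustration, and 0.2 is the default threshold from the listing.

```python
import random

def keep_probability(gate_2, threshold=0.2, eps=1e-9):
    # A uniform draw in [0, 1) falls below gate_2 / threshold with this probability.
    return min(1.0, gate_2 / max(threshold, eps))

def simulate_keep_rate(gate_2, threshold=0.2, trials=10_000, seed=0):
    rng = random.Random(seed)
    ratio = gate_2 / max(threshold, 1e-9)
    kept = sum(rng.random() < ratio for _ in range(trials))
    return kept / trials

for g in (0.05, 0.1, 0.3):
    print(g, round(keep_probability(g), 3), simulate_keep_rate(g))
```

A second expert whose renormalized weight is at or above the threshold is always kept, while weaker second choices are dispatched only part of the time.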

Conclusion

Mixture of Experts is easiest to learn when you connect three layers of understanding:

  1. Architecture intuition (replace dense FFN with routed sparse experts).
  2. Paper mechanism (top-2 gating, capacity, auxiliary load balancing).
  3. Code execution path (dispatch -> expert FFN -> combine).

Using that flow, the MoE implementation stops feeling magical and becomes a clear tensor-routing system you can reason about and modify.

From here, a strong next step is to instrument this code with small synthetic inputs and print tensor shapes at each stage to validate your mental model numerically.

To make that concrete, run tiny experiments where you control everything:

  1. Start with b=2, n=8, d=16, num_experts=4.
  2. Print shapes for raw_gates, mask_1, mask_2, dispatch_tensor, combine_tensor, and expert_inputs.
  3. Force different second-expert policies (none, threshold, random) and compare how many tokens each expert receives.
  4. Lower and raise capacity factors to observe overflow behavior.
  5. Track the auxiliary loss while changing routing behavior.

Once you can predict those outputs before you run them, you truly understand the mechanics.
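
One detail to check before the capacity experiment: with n=8 and 4 experts, the MIN_EXPERT_CAPACITY floor hides most changes to the capacity factor. The sketch below reuses the capacity formula from the listing; the min_capacity parameter is a hypothetical knob added here so the floor can be lowered for comparison.

```python
MIN_EXPERT_CAPACITY = 4

def expert_capacity(group_size, capacity_factor, num_gates,
                    min_capacity=MIN_EXPERT_CAPACITY):
    # Same arithmetic as Top2Gating.forward, with the floor made adjustable.
    capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
    return max(capacity, min_capacity)

factors = (0.5, 1.25, 2.0, 4.0)
print([expert_capacity(8, cf, 4) for cf in factors])                  # [4, 4, 4, 8]
print([expert_capacity(8, cf, 4, min_capacity=1) for cf in factors])  # [1, 2, 4, 8]
print([expert_capacity(64, cf, 4) for cf in factors])                 # [8, 20, 32, 64]
```

To see overflow change with the capacity factor, it may help to use a longer sequence (for example n=64) or to lower the floor in a local copy of the code.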

For broader context, these repositories are worth studying next:

Read them in that order: clarity first, then scale.