Learning How Mixture of Experts Works: A Code Walkthrough
Overview
This article teaches how Mixture of Experts (MoE) works in a practical, top-down way:
- Start from the big picture (what problem MoE solves in Transformers).
- Use the GShard paper as the conceptual anchor.
- Move quickly into code and verify each idea against implementation details.
The key idea is simple: instead of running one dense feed-forward network for every token, an MoE layer uses a router (gating network) to send each token to a small subset of experts (often top-2), which increases model capacity without linearly increasing compute per token.
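To make the capacity-vs-compute tradeoff concrete, here is a back-of-the-envelope sketch. The sizes are hypothetical, chosen only for the arithmetic:

```python
# Hypothetical sizes, chosen only to make the arithmetic concrete.
dim, hidden = 512, 2048          # token dim and FFN hidden dim
num_experts, top_k = 8, 2        # 8 experts, top-2 routing

ffn_params = dim * hidden + hidden * dim   # one dense FFN (ignoring biases)
moe_params = num_experts * ffn_params      # total MoE parameter capacity
active_params = top_k * ffn_params         # parameters touched per token

print(moe_params // ffn_params)     # 8x the parameters...
print(active_params // ffn_params)  # ...but only 2x the per-token FFN compute
```

Eight times the weights, twice the per-token work: that asymmetry is the whole point of sparse routing.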
This walkthrough uses the original GShard paper and the referenced video learning strategy, then teaches the complete mixture_of_experts.py implementation step by step.
Learning Strategy (From the Video)
The video’s strategy is strong for fast understanding:
- Use minimal prerequisites: basic Transformer diagram intuition and familiarity with `nn.Module`.
- Read with priorities: find the highest-value parts first (usually architecture diagrams and routing equations).
- Read by comparison: compare baseline Transformer FFN vs MoE replacement.
- Jump to code early: verify conceptual understanding quickly using implementation.
This avoids getting blocked by every detail too early. You build an 80% correct mental model first, then tighten it with code and equations.
Core References
Complete code for reference:

```python
import torch
from torch import nn
import torch.nn.functional as F

import math
from inspect import isfunction

# constants

MIN_EXPERT_CAPACITY = 4

# helper functions

def default(val, default_val):
    default_val = default_val() if isfunction(default_val) else default_val
    return val if val is not None else default_val

def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

# tensor related helper functions

def top1(t):
    values, index = t.topk(k=1, dim=-1)
    values, index = map(lambda x: x.squeeze(dim=-1), (values, index))
    return values, index

def cumsum_exclusive(t, dim=-1):
    num_dims = len(t.shape)
    num_pad_dims = - dim - 1
    pre_padding = (0, 0) * num_pad_dims
    pre_slice = (slice(None),) * num_pad_dims
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    return padded_t[(..., slice(None, -1), *pre_slice)]

# pytorch one hot throws an error if there are out of bound indices.
# tensorflow, in contrast, does not throw an error
def safe_one_hot(indexes, max_length):
    max_index = indexes.max() + 1
    return F.one_hot(indexes, max(max_index + 1, max_length))[..., :max_length]

def init_(t):
    dim = t.shape[-1]
    std = 1 / math.sqrt(dim)
    return t.uniform_(-std, std)

# activations

class GELU_(nn.Module):
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

GELU = nn.GELU if hasattr(nn, 'GELU') else GELU_

# expert class

class Experts(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = GELU):
        super().__init__()

        hidden_dim = default(hidden_dim, dim * 4)
        num_experts = cast_tuple(num_experts)

        w1 = torch.zeros(*num_experts, dim, hidden_dim)
        w2 = torch.zeros(*num_experts, hidden_dim, dim)

        w1 = init_(w1)
        w2 = init_(w2)

        self.w1 = nn.Parameter(w1)
        self.w2 = nn.Parameter(w2)
        self.act = activation()

    def forward(self, x):
        hidden = torch.einsum('...nd,...dh->...nh', x, self.w1)
        hidden = self.act(hidden)
        out = torch.einsum('...nh,...hd->...nd', hidden, self.w2)
        return out

# the below code is almost all transcribed from the official tensorflow version, from which the papers are written
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/research/moe.py

# gating network

class Top2Gating(nn.Module):
    def __init__(
        self,
        dim,
        num_gates,
        eps = 1e-9,
        outer_expert_dims = tuple(),
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.):
        super().__init__()

        self.eps = eps
        self.num_gates = num_gates
        self.w_gating = nn.Parameter(torch.randn(*outer_expert_dims, dim, num_gates))

        self.second_policy_train = second_policy_train
        self.second_policy_eval = second_policy_eval
        self.second_threshold_train = second_threshold_train
        self.second_threshold_eval = second_threshold_eval
        self.capacity_factor_train = capacity_factor_train
        self.capacity_factor_eval = capacity_factor_eval

    def forward(self, x, importance = None):
        *_, b, group_size, dim = x.shape
        num_gates = self.num_gates

        if self.training:
            policy = self.second_policy_train
            threshold = self.second_threshold_train
            capacity_factor = self.capacity_factor_train
        else:
            policy = self.second_policy_eval
            threshold = self.second_threshold_eval
            capacity_factor = self.capacity_factor_eval

        raw_gates = torch.einsum('...bnd,...de->...bne', x, self.w_gating)
        raw_gates = raw_gates.softmax(dim=-1)

        # FIND TOP 2 EXPERTS PER POSITION
        # Find the top expert for each position. shape=[batch, group]

        gate_1, index_1 = top1(raw_gates)
        mask_1 = F.one_hot(index_1, num_gates).float()
        density_1_proxy = raw_gates

        if importance is not None:
            equals_one_mask = (importance == 1.).float()
            mask_1 *= equals_one_mask[..., None]
            gate_1 *= equals_one_mask
            density_1_proxy *= equals_one_mask[..., None]
            del equals_one_mask

        gates_without_top_1 = raw_gates * (1. - mask_1)

        gate_2, index_2 = top1(gates_without_top_1)
        mask_2 = F.one_hot(index_2, num_gates).float()

        if importance is not None:
            greater_zero_mask = (importance > 0.).float()
            mask_2 *= greater_zero_mask[..., None]
            del greater_zero_mask

        # normalize top2 gate scores

        denom = gate_1 + gate_2 + self.eps
        gate_1 /= denom
        gate_2 /= denom

        # BALANCING LOSSES
        # shape = [batch, experts]
        # We want to equalize the fraction of the batch assigned to each expert

        density_1 = mask_1.mean(dim=-2)

        # Something continuous that is correlated with what we want to equalize.
        density_1_proxy = density_1_proxy.mean(dim=-2)
        loss = (density_1_proxy * density_1).mean() * float(num_gates ** 2)

        # Depending on the policy in the hparams, we may drop out some of the
        # second-place experts.
        if policy == "all":
            pass
        elif policy == "none":
            mask_2 = torch.zeros_like(mask_2)
        elif policy == "threshold":
            mask_2 *= (gate_2 > threshold).float()
        elif policy == "random":
            probs = torch.zeros_like(gate_2).uniform_(0., 1.)
            mask_2 *= (probs < (gate_2 / max(threshold, self.eps))).float().unsqueeze(-1)
        else:
            raise ValueError(f"Unknown policy {policy}")

        # Each sequence sends (at most?) expert_capacity positions to each expert.
        # Static expert_capacity dimension is needed for expert batch sizes

        expert_capacity = min(group_size, int((group_size * capacity_factor) / num_gates))
        expert_capacity = max(expert_capacity, MIN_EXPERT_CAPACITY)
        expert_capacity_f = float(expert_capacity)

        # COMPUTE ASSIGNMENT TO EXPERTS
        # [batch, group, experts]
        # This is the position within the expert's mini-batch for this sequence

        position_in_expert_1 = cumsum_exclusive(mask_1, dim=-2) * mask_1

        # Remove the elements that don't fit. [batch, group, experts]
        mask_1 *= (position_in_expert_1 < expert_capacity_f).float()

        # [batch, experts]
        # How many examples in this sequence go to this expert
        mask_1_count = mask_1.sum(dim=-2, keepdim=True)

        # [batch, group] - mostly ones, but zeros where something didn't fit
        mask_1_flat = mask_1.sum(dim=-1)

        # [batch, group]
        position_in_expert_1 = position_in_expert_1.sum(dim=-1)

        # Weight assigned to first expert. [batch, group]
        gate_1 *= mask_1_flat

        position_in_expert_2 = cumsum_exclusive(mask_2, dim=-2) + mask_1_count
        position_in_expert_2 *= mask_2
        mask_2 *= (position_in_expert_2 < expert_capacity_f).float()
        mask_2_flat = mask_2.sum(dim=-1)

        position_in_expert_2 = position_in_expert_2.sum(dim=-1)
        gate_2 *= mask_2_flat

        # [batch, group, experts, expert_capacity]
        combine_tensor = (
            gate_1[..., None, None]
            * mask_1_flat[..., None, None]
            * F.one_hot(index_1, num_gates)[..., None]
            * safe_one_hot(position_in_expert_1.long(), expert_capacity)[..., None, :] +
            gate_2[..., None, None]
            * mask_2_flat[..., None, None]
            * F.one_hot(index_2, num_gates)[..., None]
            * safe_one_hot(position_in_expert_2.long(), expert_capacity)[..., None, :]
        )

        dispatch_tensor = combine_tensor.bool().to(combine_tensor)
        return dispatch_tensor, combine_tensor, loss

# plain mixture of experts

class MoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = 16,
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        self.num_experts = num_experts

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}
        self.gate = Top2Gating(dim, num_gates = num_experts, **gating_kwargs)
        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, e = *inputs.shape, self.num_experts
        dispatch_tensor, combine_tensor, loss = self.gate(inputs)
        expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)

        # Now feed the expert inputs through the experts.
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(e, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        output = torch.einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)
        return output, loss * self.loss_coef

# 2-level hierarchical mixture of experts

class HeirarchicalMoE(nn.Module):
    def __init__(self,
        dim,
        num_experts = (4, 4),
        hidden_dim = None,
        activation = nn.ReLU,
        second_policy_train = 'random',
        second_policy_eval = 'random',
        second_threshold_train = 0.2,
        second_threshold_eval = 0.2,
        capacity_factor_train = 1.25,
        capacity_factor_eval = 2.,
        loss_coef = 1e-2,
        experts = None):
        super().__init__()

        assert len(num_experts) == 2, 'only 2 levels of heirarchy for experts allowed for now'
        num_experts_outer, num_experts_inner = num_experts
        self.num_experts_outer = num_experts_outer
        self.num_experts_inner = num_experts_inner

        gating_kwargs = {'second_policy_train': second_policy_train, 'second_policy_eval': second_policy_eval, 'second_threshold_train': second_threshold_train, 'second_threshold_eval': second_threshold_eval, 'capacity_factor_train': capacity_factor_train, 'capacity_factor_eval': capacity_factor_eval}

        self.gate_outer = Top2Gating(dim, num_gates = num_experts_outer, **gating_kwargs)
        self.gate_inner = Top2Gating(dim, num_gates = num_experts_inner, outer_expert_dims = (num_experts_outer,), **gating_kwargs)

        self.experts = default(experts, lambda: Experts(dim, num_experts = num_experts, hidden_dim = hidden_dim, activation = activation))
        self.loss_coef = loss_coef

    def forward(self, inputs, **kwargs):
        b, n, d, eo, ei = *inputs.shape, self.num_experts_outer, self.num_experts_inner
        dispatch_tensor_outer, combine_tensor_outer, loss_outer = self.gate_outer(inputs)
        expert_inputs_outer = torch.einsum('bnd,bnec->ebcd', inputs, dispatch_tensor_outer)

        # we construct an "importance" Tensor for the inputs to the second-level
        # gating. The importance of an input is 1.0 if it represents the
        # first-choice expert-group and 0.5 if it represents the second-choice expert
        # group. This is used by the second-level gating.
        importance = combine_tensor_outer.permute(2, 0, 3, 1).sum(dim=-1)
        importance = 0.5 * ((importance > 0.5).float() + (importance > 0.).float())

        dispatch_tensor_inner, combine_tensor_inner, loss_inner = self.gate_inner(expert_inputs_outer, importance = importance)
        expert_inputs = torch.einsum('ebnd,ebnfc->efbcd', expert_inputs_outer, dispatch_tensor_inner)

        # Now feed the expert inputs through the experts.
        orig_shape = expert_inputs.shape
        expert_inputs = expert_inputs.reshape(eo, ei, -1, d)
        expert_outputs = self.experts(expert_inputs)
        expert_outputs = expert_outputs.reshape(*orig_shape)

        # NOW COMBINE EXPERT OUTPUTS (reversing everything we have done)
        expert_outputs_outer = torch.einsum('efbcd,ebnfc->ebnd', expert_outputs, combine_tensor_inner)
        output = torch.einsum('ebcd,bnec->bnd', expert_outputs_outer, combine_tensor_outer)
        return output, (loss_outer + loss_inner) * self.loss_coef
```

Code Walkthrough: Step by Step
1) Helper functions: setup and defensive defaults
- `default(val, default_val)`: if `val` is `None`, use a fallback (evaluated lazily if it is a function).
- `cast_tuple(el)`: makes sure values like `num_experts` can be treated uniformly (single value or tuple).
- `init_(t)`: initializes expert weights uniformly with scale `1/sqrt(t.shape[-1])`.
These functions reduce branching and keep the main classes readable.
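A quick sanity check of these helpers, restated here so the snippet runs standalone:

```python
from inspect import isfunction

def default(val, default_val):
    # evaluate the fallback lazily if it is a function (e.g. an expensive constructor)
    default_val = default_val() if isfunction(default_val) else default_val
    return val if val is not None else default_val

def cast_tuple(el):
    return el if isinstance(el, tuple) else (el,)

print(default(None, 4))    # falls back to 4
print(default(16, 4))      # keeps the explicit value
print(cast_tuple(8))       # (8,)   -- flat MoE expert count
print(cast_tuple((4, 4)))  # (4, 4) -- hierarchical expert counts
```

The lazy fallback matters in `MoE.__init__`: `default(experts, lambda: Experts(...))` only constructs the default expert bank when no custom one is passed in.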
2) Tensor utilities: routing building blocks
- `top1(t)`: returns both the top score and the top index along the expert dimension.
- `cumsum_exclusive(t, dim=-1)`: exclusive prefix sum; used to assign each token a slot inside an expert buffer.
- `safe_one_hot(indexes, max_length)`: robust one-hot for cases where indices can exceed expected bounds.
These are critical for token-to-expert assignment mechanics.
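The role of `cumsum_exclusive` in routing is easiest to see on a toy one-hot assignment. This sketch restates the function from the file above and applies it over the token dimension, the way `Top2Gating` does:

```python
import torch
import torch.nn.functional as F

def cumsum_exclusive(t, dim=-1):
    # shift-by-one cumulative sum: position i gets the sum of all elements before i
    num_pad_dims = - dim - 1
    pre_padding = (0, 0) * num_pad_dims
    pre_slice = (slice(None),) * num_pad_dims
    padded_t = F.pad(t, (*pre_padding, 1, 0)).cumsum(dim=dim)
    return padded_t[(..., slice(None, -1), *pre_slice)]

# toy one-hot assignment: 4 tokens, 2 experts
# tokens 0, 2, 3 -> expert 0; token 1 -> expert 1
mask = torch.tensor([[1., 0.],
                     [0., 1.],
                     [1., 0.],
                     [1., 0.]])

# exclusive cumsum over the token dimension = each token's slot index in its expert buffer
slots = cumsum_exclusive(mask, dim=-2) * mask
print(slots)  # token 0 -> slot 0, token 2 -> slot 1, token 3 -> slot 2 of expert 0
```

Each token learns "how many earlier tokens chose my expert," which is exactly its buffer slot.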
3) Experts: parameterized FFN bank
Experts stores expert parameters in big tensors:
- `w1`: `[..., dim, hidden_dim]`
- `w2`: `[..., hidden_dim, dim]`
Then it runs expert FFNs with two einsums:
- `hidden = einsum('...nd,...dh->...nh', x, w1)`
- activation (`GELU` by default)
- `out = einsum('...nh,...hd->...nd', hidden, w2)`
So each expert is a feed-forward network, and all experts are batched into one parameter structure.
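A standalone sketch of this batched-expert pattern, with small hypothetical sizes (4 experts, `dim=8`, `hidden=16`) and `relu` standing in for GELU:

```python
import math
import torch

num_experts, dim, hidden_dim = 4, 8, 16

# one weight tensor per layer holds ALL experts, matching init_'s 1/sqrt(last dim) scale
w1 = torch.empty(num_experts, dim, hidden_dim).uniform_(-1 / math.sqrt(hidden_dim), 1 / math.sqrt(hidden_dim))
w2 = torch.empty(num_experts, hidden_dim, dim).uniform_(-1 / math.sqrt(dim), 1 / math.sqrt(dim))

# pretend 5 tokens have already been dispatched to each expert
x = torch.randn(num_experts, 5, dim)

# every expert applies its own two-layer FFN in a single pair of einsums
hidden = torch.einsum('...nd,...dh->...nh', x, w1)       # (4, 5, 16)
out = torch.einsum('...nh,...hd->...nd', torch.relu(hidden), w2)

print(out.shape)  # torch.Size([4, 5, 8])
```

The leading `...` in the einsums is what lets the same code serve both the flat MoE (one expert axis) and the hierarchical MoE (two expert axes).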
4) Top2Gating.__init__: router configuration
The gate stores:
- `w_gating`: linear projection from the token dimension to `num_gates` logits.
- second-expert policy knobs: `second_policy_train`/`second_policy_eval`, each one of `all`, `none`, `threshold`, or `random`.
- thresholds and capacity factors for train/eval.
- `eps` for numerical stability.
This matches the top-2 routing spirit from GShard: sparse dispatch with controllable load behavior.
5) Top2Gating.forward: the routing pipeline
Input shape convention from code:
- `x`: `[..., b, group_size, dim]`
- output tensors:
  - `dispatch_tensor`: binary-like routing map
  - `combine_tensor`: weighted routing map used to merge outputs
  - `loss`: load-balancing auxiliary loss
Flow:
Compute gate probabilities
`raw_gates = softmax(einsum(x, w_gating))` gives per-token expert probabilities.
Pick top-1 expert
- `gate_1, index_1 = top1(raw_gates)`
- `mask_1` is the one-hot expert assignment for the first choice.
Pick top-2 expert
- zero out the top-1 entry in `raw_gates` and run `top1` again.
- build `mask_2` for the second choice.
Normalize top-2 weights
`gate_1` and `gate_2` are renormalized so they sum to ~1 for active tokens.
Load-balancing auxiliary loss
- `density_1`: observed first-choice expert usage.
- `density_1_proxy`: soft proxy from the probabilities.
- `loss = mean(density_1_proxy * density_1) * num_gates^2`.
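A toy computation of this loss, using hypothetical probabilities for a router that has collapsed onto one expert. With perfectly balanced routing the product works out to exactly 1.0; collapse pushes it higher, which is what the gradient penalizes:

```python
import torch

# 6 tokens, 3 experts; every token prefers expert 0 (rows sum to 1)
num_gates = 3
raw_gates = torch.tensor([[0.7, 0.2, 0.1]] * 6)

mask_1 = torch.nn.functional.one_hot(raw_gates.argmax(dim=-1), num_gates).float()

density_1 = mask_1.mean(dim=-2)           # observed first-choice usage per expert
density_1_proxy = raw_gates.mean(dim=-2)  # differentiable proxy
loss = (density_1_proxy * density_1).mean() * num_gates ** 2

print(loss.item())  # ~2.1, well above the balanced value of 1.0
```

The hard `density_1` term carries no gradient (it comes from an argmax), so the differentiable proxy is what actually steers the router toward balance.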
Second expert policy application
- `all`: always keep the second expert.
- `none`: always drop the second expert.
- `threshold`: keep if `gate_2 > threshold`.
- `random`: keep probabilistically, with probability proportional to `gate_2 / threshold`.
Capacity calculation
- `expert_capacity = min(group_size, int(group_size * capacity_factor / num_gates))`
- clamped from below by `MIN_EXPERT_CAPACITY`.
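The capacity arithmetic, restated as a tiny function with hypothetical sizes:

```python
MIN_EXPERT_CAPACITY = 4

def expert_capacity(group_size, num_gates, capacity_factor):
    # slots per expert: a fair share of the group, scaled by the capacity factor
    cap = min(group_size, int((group_size * capacity_factor) / num_gates))
    return max(cap, MIN_EXPERT_CAPACITY)

print(expert_capacity(1024, 16, 1.25))  # training: 1024 * 1.25 / 16 = 80 slots per expert
print(expert_capacity(1024, 16, 2.0))   # eval: looser budget, 128 slots
print(expert_capacity(8, 16, 1.25))     # tiny group: clamped up to MIN_EXPERT_CAPACITY
```

A factor of 1.0 would give each expert exactly its fair share; 1.25 leaves 25% slack for imbalance, and eval uses an even looser 2.0 to drop fewer tokens.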
Assign positions inside each expert buffer
- `cumsum_exclusive(mask_1)` and `cumsum_exclusive(mask_2)` compute token slot indices.
- tokens beyond capacity are masked out.
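The overflow mechanics on a deliberately overloaded toy case: 5 tokens all routed to one expert with a capacity of only 2 slots (the exclusive cumsum is inlined for a 2-D mask):

```python
import torch
import torch.nn.functional as F

# toy case: 5 tokens all routed to expert 0 of 2
mask_1 = F.one_hot(torch.zeros(5, dtype=torch.long), 2).float()

# exclusive cumsum over the token dimension, then keep only each token's own expert
padded = F.pad(mask_1, (0, 0, 1, 0)).cumsum(dim=-2)[:-1]
position_in_expert = padded * mask_1  # slot index per token

expert_capacity = 2.
mask_1 = mask_1 * (position_in_expert < expert_capacity).float()

print(position_in_expert[:, 0].tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0]
print(mask_1[:, 0].tolist())              # only the first two tokens keep a slot
```

Tokens 3 through 5 claim slots 2, 3, 4, which do not exist in a capacity-2 buffer, so their mask entries are zeroed: those tokens are dropped from this expert.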
Build combine/dispatch tensors
- `combine_tensor` shape: `[batch, group, experts, expert_capacity]`; it carries both routing locations and mixing weights.
- `dispatch_tensor = combine_tensor.bool().to(combine_tensor)` gives the selection mask for scattering inputs.
This is the heart of MoE routing: sparse token dispatch, bounded capacity, and weighted recombination.
6) MoE.forward: single-level MoE execution
Given inputs of shape [b, n, d]:
- Router returns `dispatch_tensor`, `combine_tensor`, and the auxiliary `loss`.
- Dispatch step: `expert_inputs = einsum('bnd,bnec->ebcd', inputs, dispatch_tensor)` reorganizes tokens by expert and capacity slot.
- Experts run FFN transform.
- Combine step: `output = einsum('ebcd,bnec->bnd', expert_outputs, combine_tensor)` merges weighted expert outputs back into token order.
- Return `(output, loss * loss_coef)`.
So: route -> process -> merge, with a balancing regularizer.
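The round trip can be checked with toy tensors. Here a hypothetical one-hot dispatch map doubles as the combine tensor (equivalent to every gate weight being 1), and the expert FFNs are replaced by a multiply-by-2 stand-in:

```python
import torch

b, n, d, e, c = 2, 4, 8, 3, 2  # batch, tokens, dim, experts, capacity

inputs = torch.randn(b, n, d)

# hypothetical routing: token j -> expert j % e, capacity slot j // e
dispatch = torch.zeros(b, n, e, c)
for j in range(n):
    dispatch[:, j, j % e, j // e] = 1.

expert_inputs = torch.einsum('bnd,bnec->ebcd', inputs, dispatch)   # (e, b, c, d)
expert_outputs = expert_inputs * 2                                  # stand-in for the expert FFNs
output = torch.einsum('ebcd,bnec->bnd', expert_outputs, dispatch)   # back to (b, n, d)

print(torch.allclose(output, inputs * 2))  # True: route -> process -> merge round-trips
```

Because every token lands in a distinct (expert, slot) cell, dispatch followed by combine is a pure permutation and back, so the "expert" transform is all that changes the values.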
7) HeirarchicalMoE.forward: two-level routing
This class stacks two gating levels:
- Outer routing chooses expert groups.
- Build `importance` from the outer combine weights:
  - first-choice paths get importance 1.0,
  - second-choice paths get importance 0.5.
- Inner routing then routes within each selected outer group.
- Experts run at shape `[outer_expert, inner_expert, ...]`.
- Combine inner, then outer outputs back to `[b, n, d]`.
- The final auxiliary loss is `(loss_outer + loss_inner) * loss_coef`.
This is a hierarchical extension of the same sparse routing principle.
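The 0.5/1.0 importance mapping is just two thresholds on the outer combine weight: after renormalization the first-choice gate exceeds 0.5, the second-choice gate is at most 0.5, and unused slots carry 0. A small sketch with hypothetical weights:

```python
import torch

# outer combine weight per dispatched position:
# 0 (unused slot), second-choice weight (<= 0.5), or first-choice weight (> 0.5)
w = torch.tensor([0.0, 0.35, 0.65, 1.0])

importance = 0.5 * ((w > 0.5).float() + (w > 0.).float())
print(importance.tolist())  # [0.0, 0.5, 1.0, 1.0]
```

The inner gate then treats importance-1.0 slots as full tokens, halves the influence of importance-0.5 slots, and ignores empty slots entirely.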
How This Maps to the GShard Paper
The implementation aligns directly with the core GShard MoE concepts from the original paper:
- Top-2 routing: each token can be sent to up to two experts.
- Expert capacity: per-expert token budget limits overload.
- Overflow handling: tokens exceeding capacity are effectively dropped from the expert processing path.
- Auxiliary balancing loss: encourages even expert utilization.
- Sparse conditional compute: model capacity scales with number of experts, while per-token active compute stays limited.
For direct reading, see the GShard paper: "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding" (Lepikhin et al., 2020).
Practical Reading Path
If you are learning this for the first time, follow this order:
- Read GShard Section 2.1 and 2.2 (Transformer + MoE layer concept).
- Scan Algorithm 1 (top-2 gating + auxiliary loss).
- Read the `Top2Gating.forward` implementation end to end.
- Verify the dispatch/combine einsums in `MoE.forward`.
- Only then inspect hierarchical routing (`HeirarchicalMoE`).
This mirrors the “priorities -> comparison -> code early” method from the video.
Common Confusions and Checks
Why two tensors (`dispatch_tensor` and `combine_tensor`)?
- `dispatch_tensor` selects where token inputs go.
- `combine_tensor` performs the weighted reconstruction of the outputs.
Why normalize `gate_1` and `gate_2`?
- To keep the top-2 contributions stable and interpretable as local mixture weights.
Why can tokens be dropped?
- Capacity enforces a bounded expert load; dropped tokens still continue through the residual stream in a full Transformer context.
Why use the `random` second policy?
- It reduces unnecessary second-expert traffic while keeping stochastic exploration and load flexibility.
What is the main training stabilizer here?
- The auxiliary load-balancing term that discourages collapse to only a few experts.
Conclusion
Mixture of Experts is easiest to learn when you connect three layers of understanding:
- Architecture intuition (replace dense FFN with routed sparse experts).
- Paper mechanism (top-2 gating, capacity, auxiliary load balancing).
- Code execution path (dispatch -> expert FFN -> combine).
Using that flow, the MoE implementation stops feeling magical and becomes a clear tensor-routing system you can reason about and modify.
From here, a strong next step is to instrument this code with small synthetic inputs and print tensor shapes at each stage to validate your mental model numerically.
To make that concrete, run tiny experiments where you control everything:
- Start with `b=2, n=8, d=16, num_experts=4`.
- Print shapes for `raw_gates`, `mask_1`, `mask_2`, `dispatch_tensor`, `combine_tensor`, and `expert_inputs`.
- Force different second-expert policies (`none`, `threshold`, `random`) and compare how many tokens each expert receives.
- Lower and raise the capacity factors to observe overflow behavior.
- Track the auxiliary loss while changing routing behavior.
Once you can predict those outputs before you run them, you truly understand the mechanics.
For broader context, these repositories are worth studying next:
- MoECollections/mixture-of-experts-2 - the fork used in this walkthrough flow.
- tensorflow/tensor2tensor - historical source of the original `moe.py` implementation style referenced by many ports.
- davidmrau/mixture-of-experts - compact PyTorch re-implementation of sparse-gated MoE that is easy to read and compare.
- microsoft/DeepSpeed - production MoE training infrastructure with expert parallelism.
- NVIDIA/Megatron-LM - large-scale transformer training codebase with MoE support.
- huggingface/transformers - practical model implementations (including MoE-style architectures) used in real inference/training workflows.
Read them in that order: clarity first, then scale.