~/satyajit

Mixture of Experts, from scratch


2026-06-10 · 18 min · deep-learning · transformers · mixture-of-experts · explainer

Scaling a transformer the dense way is a bad trade. Every parameter you add runs on every token. Double the width of the feed-forward layers and you double both the model's capacity and the FLOPs it burns per token — capacity and compute are welded together. You pay for the whole network on every single token, whether that token needs it or not.

Mixture of Experts breaks the weld. The idea is conditional computation: keep a large pile of parameters around, but for any given token, only run a small slice of them. A tiny router looks at each token and picks a couple of sub-networks — the experts — to handle it. The rest sit idle for that token. You get the capacity of a big model at the compute of a small one.

Here is the whole model we'll build, end to end. The only thing that makes it a MoE is one swapped line, inside the sparse MoE block:

Figure: sparse-MoE language model. idx (B, T) → transformer block × 8 → logits (B, T, vocab). Highlighted block: sparse MoE + residual, (B, T, C) → (B, T, C).
x = x + self.smoe(self.ln2(x))   # <- was a single MLP

# inside SparseMoE.forward:
gates, idx = self.router(x)            # pick top-2 of 8
for i, expert in enumerate(self.experts):
    mask = (idx == i).any(-1)          # tokens routed here
    out[mask] += gates[mask, i].unsqueeze(-1) * expert(x[mask])   # (n, 1) gate broadcasts over C

The whole idea, in one swapped line. A router sends each token to 2 of 8 expert MLPs; their gated outputs are summed.

Everything except that one block is a standard decoder-only transformer: token and position embeddings, a stack of blocks, a final norm, an LM head. Attention is untouched. MoE is a surgical replacement for the feed-forward layer inside each block, and nothing else. So the whole thing reduces to three questions: what is an expert, who decides which experts run, and how do you run only the chosen few.

An expert is just an MLP

Start with the thing we're replacing. In a normal transformer block, after attention, every token goes through the same two-layer MLP — expand to 4 * n_embed, nonlinearity, project back. That's the feed-forward network.

An expert is exactly that MLP. Nothing more.

class Expert(nn.Module):
    def __init__(self, n_embed, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )
 
    def forward(self, x):
        return self.net(x)

The move is to keep num_experts copies of this MLP instead of one. With 8 experts you have 8× the feed-forward parameters. If every token went through all 8, you'd have spent 8× the compute and gained nothing but a slow, fat FFN. The whole game is to run only top_k of them — say 2 — per token. So you carry 8 experts' worth of parameters and pay for 2.

The piece that makes that decision is the router.

Who decides? The router

The router's job: look at a token's vector $x$ and produce a weight for each expert, mostly zero, so that only a few experts actually contribute. Build it up in three steps, because the naive versions teach you why the real one looks the way it does.

Attempt 1: send every token to every expert, weighted. A linear layer maps the token to one logit per expert; softmax those logits and take the weighted sum of all expert outputs:

$$g(x) = \mathrm{softmax}(x W_g), \qquad y = \sum_{i=1}^{N} g(x)_i \, E_i(x)$$

Here $W_g$ is the router's weight matrix (n_embed × num_experts) and $E_i$ is the $i$-th expert. This is differentiable and trains fine, but it's dense. Every expert runs on every token. We've built an expensive ensemble, not a sparse model.
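A minimal sketch of attempt 1, to make the "dense" part concrete. The toy sizes and the plain `nn.Linear` layers standing in for experts are assumptions for illustration, not part of the model built later.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
n_embed, num_experts = 16, 4                       # toy sizes, chosen for illustration
x = torch.randn(5, n_embed)                        # 5 tokens

W_g = nn.Linear(n_embed, num_experts, bias=False)  # router: one logit per expert
experts = nn.ModuleList([nn.Linear(n_embed, n_embed) for _ in range(num_experts)])

g = F.softmax(W_g(x), dim=-1)                      # (5, 4), every entry strictly positive
expert_out = torch.stack([e(x) for e in experts], dim=1)   # (5, 4, 16): all experts run
y = (g.unsqueeze(-1) * expert_out).sum(dim=1)      # weighted sum over experts -> (5, 16)

dense_calls = x.size(0) * num_experts              # token-expert evaluations: 5 * 4
```

No gate is ever exactly zero here, so no expert call can be skipped: the cost is always tokens × experts.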

Attempt 2: hard-pick the single best expert. Take the $\arg\max$ of the logits, run only that expert. Now it's sparse and cheap. But $\arg\max$ has zero gradient: the router only ever learns about the one expert it already chose, and never gets a signal to try the others. Routing freezes. Dead end.
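You can see the dead end directly with autograd. The sketch below is an illustration under assumptions (toy sizes, `nn.Linear` stand-ins for experts, an arbitrary loss): it routes by bare argmax, with no gate value multiplied in, and then checks who received a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_embed, num_experts = 16, 4                       # toy sizes, chosen for illustration
x = torch.randn(5, n_embed)

router = nn.Linear(n_embed, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(n_embed, n_embed) for _ in range(num_experts)])

choice = router(x).argmax(dim=-1)                  # integer indices: not differentiable
y = torch.stack([experts[int(c)](tok) for c, tok in zip(choice, x)])

y.pow(2).mean().backward()                         # any scalar loss will do

router_grad = router.weight.grad                   # stays None: no path from loss to router
chosen = set(choice.tolist())
expert_has_grad = [e.weight.grad is not None for e in experts]
```

In this bare form the router gets no gradient at all, and only the experts that were picked get any. Nothing in the backward pass can ever move a token to a different expert.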

Attempt 3: top-$k$ softmax. Keep the largest $k$ logits, set the rest to $-\infty$, then softmax. The $-\infty$ entries become exactly 0, so only $k$ experts contribute (sparse, like attempt 2), but the softmax over the survivors is smooth, so gradients flow to all $k$ chosen experts. This is the real router:

$$g(x) = \mathrm{softmax}\big(\mathrm{KeepTopK}(x W_g,\, k)\big), \qquad \mathrm{KeepTopK}(v, k)_i = \begin{cases} v_i & v_i \text{ in top } k \\ -\infty & \text{otherwise} \end{cases}$$

With $k = 2$ and $N = 8$, six of the eight gate weights are zero for every token, and the two survivors sum to 1. Watch one token go through it: logits, keep the top two, softmax to gates, combine:

Figure: noisy top-2 router, 8 experts (step 1 of 5: router logits, logits = x @ W_g).

E0 1.9 · E1 -0.6 · E2 1.4 · E3 0.8 · E4 -1.2 · E5 2.1 · E6 0.1 · E7 -0.3

One linear layer turns the token into one score per expert.

That stepper is the entire routing mechanism. The bars are the per-expert logits; top-2 keeps two; softmax turns them into weights; the output is just those two experts' outputs scaled by their gates and added.
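The same walk-through can be done by hand on the eight logits shown in the router figure. This is a small worked example without noise; the variable names are just for illustration.

```python
import torch
import torch.nn.functional as F

# Router logits for one token, copied from the figure (E0..E7).
logits = torch.tensor([1.9, -0.6, 1.4, 0.8, -1.2, 2.1, 0.1, -0.3])
k = 2

top_vals, top_idx = logits.topk(k)                 # largest k logits, sorted descending
sparse = torch.full_like(logits, float("-inf"))
sparse[top_idx] = top_vals                         # everything else stays -inf
gates = F.softmax(sparse, dim=-1)

print(top_idx.tolist())                            # [5, 0]
print([round(g, 3) for g in gates.tolist()])
# [0.45, 0.0, 0.0, 0.0, 0.0, 0.55, 0.0, 0.0]
```

E5 and E0 win, with gates of about 0.55 and 0.45. The token's output is then 0.55 · E5(x) + 0.45 · E0(x); the other six experts are never called.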

Why the noise

There's one addition that the bare top-$k$ router needs in practice: noise. Before picking the top $k$, add a learned, per-expert amount of Gaussian noise to the logits:

$$H(x)_i = (x W_g)_i + \varepsilon_i \cdot \mathrm{softplus}\big((x W_{\text{noise}})_i\big), \qquad \varepsilon_i \sim \mathcal{N}(0, 1)$$

The noise scale is itself learned (a second linear layer $W_{\text{noise}}$, passed through softplus to keep it positive). Why bother? Because early in training the router is random, and whichever experts happen to win first get all the gradient and pull ahead: a rich-get-richer collapse. The noise jitters the top-$k$ selection so borderline experts occasionally win, get some tokens, and get a chance to become useful. It's exploration, baked into the forward pass. Hit resample noise in the widget above and you can watch which two experts win flip.
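Here is a rough way to reproduce the "resample noise" effect offline. It reuses the logits from the router figure and assumes a fixed noise scale of softplus(0) ≈ 0.69 per expert as a stand-in for the learned one; in the real router that scale depends on the token.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.tensor([1.9, -0.6, 1.4, 0.8, -1.2, 2.1, 0.1, -0.3])  # from the router figure
noise_scale = F.softplus(torch.zeros(8))    # stand-in for softplus(x @ W_noise), ~0.69 each

winners = set()
for _ in range(200):                         # 200 forward passes of the same token
    noisy = logits + torch.randn(8) * noise_scale
    pair = tuple(sorted(noisy.topk(2).indices.tolist()))
    winners.add(pair)

print((0, 5) in winners, len(winners))       # clean top-2 pair present? how many distinct pairs won?
```

Without noise the winning pair is always (0, 5). With it, close runners-up such as E2 take a slot on some passes, which is exactly the exploration the paragraph describes.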

In code the router is four lines of real work:

class NoisyTopKRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.route = nn.Linear(n_embed, num_experts)   # gate logits
        self.noise = nn.Linear(n_embed, num_experts)   # per-expert noise scale
 
    def forward(self, x):
        logits = self.route(x)
        noisy = logits + torch.randn_like(logits) * F.softplus(self.noise(x))
 
        top_logits, idx = noisy.topk(self.top_k, dim=-1)     # the chosen experts
        sparse = torch.full_like(noisy, float("-inf"))
        sparse.scatter_(-1, idx, top_logits)                 # keep top-k, rest -inf
        return F.softmax(sparse, dim=-1), idx

scatter_ is the one trick worth pausing on: it writes the kept logits back into a tensor of -inf, at the indices the topk chose. After the softmax those -inf slots are 0. The router returns the gate weights and the chosen indices — the indices tell the next stage which experts to actually run.

The sparse forward pass

Now the part that earns the word sparse. We have gate weights and, for each token, the indices of its top-$k$ experts. We want to run each expert on only the tokens routed to it, scale by the gate, and add the result back.

The straightforward way: loop over experts, and for each one, mask out the tokens that picked it.

class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.router = NoisyTopKRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
 
    def forward(self, x):
        gates, idx = self.router(x)            # (B,T,N), (B,T,k)
        out = torch.zeros_like(x)
 
        flat_x = x.view(-1, x.size(-1))        # (B*T, C)
        flat_gates = gates.view(-1, gates.size(-1))
        flat_out = out.view(-1, x.size(-1))
 
        for i, expert in enumerate(self.experts):
            mask = (idx == i).any(dim=-1).view(-1)   # tokens routed to expert i
            if mask.any():
                y = expert(flat_x[mask])             # run on its tokens only
                flat_out[mask] += flat_gates[mask, i:i+1] * y
        return out

The mask = (idx == i).any(dim=-1) line is the dispatch: it's true for exactly the tokens that have expert i somewhere in their top-$k$. We gather those tokens, run the expert once on the batch of them, scale each by its gate weight, and scatter-add back into the output. A token routed to experts 2 and 5 gets contributions from both loop iterations, summed, which is exactly $\sum_i g(x)_i E_i(x)$ with all but $k$ terms zero.
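That equivalence is easy to check numerically. The sketch below compares the masked loop against the dense sum over all experts; the toy sizes, the `nn.Linear` stand-in experts, and the noise-free random logits are assumptions made so both paths see identical routing.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, N, k = 2, 6, 16, 8, 2                   # toy sizes, chosen for illustration
x = torch.randn(B, T, C)
experts = nn.ModuleList([nn.Linear(C, C) for _ in range(N)])   # stand-in experts

# Top-k gates without noise, so both paths use the same routing.
logits = torch.randn(B, T, N)
top_vals, idx = logits.topk(k, dim=-1)
gates = F.softmax(torch.full_like(logits, float("-inf")).scatter(-1, idx, top_vals), dim=-1)

# Dense reference: run every expert on every token; zero gates cancel most terms.
dense_out = sum(gates[..., i:i + 1] * experts[i](x) for i in range(N))

# Sparse dispatch: run each expert only on the tokens that picked it.
flat_x, flat_gates = x.reshape(-1, C), gates.reshape(-1, N)
sparse_out = torch.zeros(B * T, C)
expert_calls = 0
for i, expert in enumerate(experts):
    mask = (idx == i).any(dim=-1).reshape(-1)
    if mask.any():
        sparse_out[mask] += flat_gates[mask, i:i + 1] * expert(flat_x[mask])
        expert_calls += int(mask.sum())
sparse_out = sparse_out.reshape(B, T, C)

print(torch.allclose(dense_out, sparse_out, atol=1e-5), expert_calls, B * T * N)
```

Both paths produce the same output, but the sparse loop makes B·T·k token-expert evaluations instead of B·T·N.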

Picture the dispatch over a short sequence. Each token connects to just two of the eight experts, so most of the grid stays dark — that darkness is the compute you're not spending:

Figure: token → top-2 expert routing for the sequence "The quick brown fox jumps over the dog" (8 tokens, experts 0 to 7).

Tokens routed per expert (this sequence): E0 3 · E1 1 · E2 2 · E3 2 · E4 1 · E5 5 · E6 2 · E7 0

The bars underneath are the per-expert load: how many tokens each expert handled. Notice it's already uneven — some experts attract more traffic than others. Hold that thought; it's the central problem with MoE.
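Per-expert load is cheap to measure from the router's index output. A minimal sketch, assuming random logits as a stand-in for a real router and `torch.bincount` for the tally:

```python
import torch

torch.manual_seed(0)
num_experts, top_k = 8, 2
tokens = 8                                             # one short sequence
logits = torch.randn(tokens, num_experts)              # stand-in for router output
idx = logits.topk(top_k, dim=-1).indices               # (tokens, top_k) chosen experts

load = torch.bincount(idx.reshape(-1), minlength=num_experts)   # assignments per expert
share = load.float() / load.sum()                      # fraction of all assignments

print(load.tolist())
```

The counts always sum to tokens × top_k. The bars in the figure add up the same way: 3 + 1 + 2 + 2 + 1 + 5 + 2 + 0 = 16, which is 8 tokens times 2 experts each.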

The one line that changes

With the experts and the router in hand, dropping MoE into a transformer block is anticlimactic — which is the point. A standard block is attention → FFN, each wrapped in a layer-norm and a residual. MoE swaps the FFN for the SparseMoE module and touches nothing else:

class Block(nn.Module):
    def __init__(self, n_embed, n_head, num_experts, top_k, block_size):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embed, block_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)   # was: FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
 
    def forward(self, x):
        x = x + self.sa(self.ln1(x))      # attention — unchanged
        x = x + self.smoe(self.ln2(x))    # MoE replaces the feed-forward layer
        return x

That's the whole architectural delta. One FeedForward becomes one SparseMoE:

Figure: dense FFN (tokens → one MLP; 100% of params, every token) versus sparse MoE (tokens → router → 8 experts stored, 2 run).
Same slot in the block. Dense runs one MLP on every token; sparse runs a router plus the two chosen experts.

Stack eight of these blocks, add embeddings and an LM head, and you have the model from the top of the page. Train it exactly like a dense transformer — cross-entropy on next-token prediction. The router learns its weights from the same gradient as everything else. No special routing supervision; it figures out a useful assignment on its own.

Run it yourself

Here is everything above assembled into one file — a char-level model that trains on tiny Shakespeare in about 200 lines, with no dependency past PyTorch. The Expert, NoisyTopKRouter, and SparseMoE are exactly the pieces we just built; the rest is the smallest transformer that can hold them. Copy it, run python tinymoe.py, and watch the loss come down.

"""
tinymoe — a tiny Mixture-of-Experts language model in one file.
Char-level, trains on tiny Shakespeare. ~4.5M params, ~1.4M active per token.
Runs on CPU; much faster on a GPU.
 
    python tinymoe.py        # download data, train, then sample
 
It's a small decoder-only transformer where the feed-forward layer of every
block is replaced by a sparse mixture of experts with noisy top-k routing.
"""
import os
import urllib.request
 
import torch
import torch.nn as nn
from torch.nn import functional as F
 
# --------------------------------------------------------------------- config
batch_size = 32          # sequences per step
block_size = 128         # context length (chars)
n_embed = 128            # embedding / residual width
n_head = 4               # attention heads
n_layer = 4              # transformer blocks
num_experts = 8          # experts per MoE layer
top_k = 2                # experts actually run per token
dropout = 0.1
learning_rate = 3e-4
max_iters = 5000
eval_interval = 500
eval_iters = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(1337)
 
# ----------------------------------------------------------- data (shakespeare)
if not os.path.exists("input.txt"):
    url = ("https://raw.githubusercontent.com/karpathy/char-rnn/"
           "master/data/tinyshakespeare/input.txt")
    urllib.request.urlretrieve(url, "input.txt")
text = open("input.txt", encoding="utf-8").read()
 
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda t: "".join(itos[i] for i in t)
 
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]
 
 
def get_batch(split):
    d = train_data if split == "train" else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix])
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)
 
 
# ------------------------------------------------------------------- attention
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.drop = nn.Dropout(dropout)
 
    def forward(self, x):
        B, T, C = x.shape
        k, q = self.key(x), self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = self.drop(F.softmax(wei, dim=-1))
        return wei @ self.value(x)
 
 
class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.drop = nn.Dropout(dropout)
 
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.drop(self.proj(out))
 
 
# --------------------------------------------------------- mixture of experts
class Expert(nn.Module):
    """One expert = one MLP. Same shape as a normal transformer FFN."""
 
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed), nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed), nn.Dropout(dropout),
        )
 
    def forward(self, x):
        return self.net(x)
 
 
class NoisyTopKRouter(nn.Module):
    """Score experts per token, add learned noise, keep top-k, softmax."""
 
    def __init__(self):
        super().__init__()
        self.route = nn.Linear(n_embed, num_experts)
        self.noise = nn.Linear(n_embed, num_experts)
 
    def forward(self, x):
        logits = self.route(x)
        noisy = logits + torch.randn_like(logits) * F.softplus(self.noise(x))
        top_logits, idx = noisy.topk(top_k, dim=-1)
        sparse = torch.full_like(noisy, float("-inf")).scatter(-1, idx, top_logits)
        return F.softmax(sparse, dim=-1), idx
 
 
class SparseMoE(nn.Module):
    """Run only the top-k experts per token; combine them by gate weight."""
 
    def __init__(self):
        super().__init__()
        self.router = NoisyTopKRouter()
        self.experts = nn.ModuleList([Expert() for _ in range(num_experts)])
 
    def forward(self, x):
        gates, idx = self.router(x)                  # (B,T,E), (B,T,k)
        out = torch.zeros_like(x)
        flat_x = x.reshape(-1, x.size(-1))
        flat_gates = gates.reshape(-1, gates.size(-1))
        flat_out = out.reshape(-1, x.size(-1))
        for i, expert in enumerate(self.experts):
            mask = (idx == i).any(dim=-1).reshape(-1)  # tokens routed to expert i
            if mask.any():
                flat_out[mask] += flat_gates[mask, i:i + 1] * expert(flat_x[mask])
        return out
 
 
# ------------------------------------------------------------- block + model
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embed // n_head)
        self.smoe = SparseMoE()                      # <- replaces the FFN
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
 
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x
 
 
class MoELanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embed)
        self.pos_emb = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size)
 
    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
        x = self.ln_f(self.blocks(x))
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        return logits, loss
 
    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _ = self(idx[:, -block_size:])
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx = torch.cat([idx, torch.multinomial(probs, 1)], dim=1)
        return idx
 
 
# --------------------------------------------------------------------- train
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            _, losses[k] = model(x, y)
        out[split] = losses.mean().item()
    model.train()
    return out
 
 
model = MoELanguageModel().to(device)
total = sum(p.numel() for p in model.parameters())
print(f"{total / 1e6:.2f}M params on {device}")
opt = torch.optim.AdamW(model.parameters(), lr=learning_rate)
 
for it in range(max_iters):
    if it % eval_interval == 0:
        l = estimate_loss(model)
        print(f"step {it:5d} | train {l['train']:.3f} | val {l['val']:.3f}")
    x, y = get_batch("train")
    _, loss = model(x, y)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()
 
# -------------------------------------------------------------------- sample
model.eval()  # disable dropout while sampling
ctx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(ctx, 500)[0].tolist()))

At the default size it prints 4.52M params — but only ~1.4M of them run on any given token, because 6 of every 8 experts sit out. That's the parameter-vs-compute split in miniature. Raise num_experts and the total climbs while the active count barely moves; lower top_k to 1 and it gets sparser still. The same lever Mixtral pulls, in a model you can train on a laptop.
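Those two numbers can be reproduced with plain arithmetic from the config. The only value not in the file is `vocab_size = 65`, which assumes the standard tiny Shakespeare text (65 distinct characters); the `linear` helper is hypothetical, written just for this tally.

```python
# Hyperparameters from the config block. vocab_size = 65 is an assumption:
# it is the character count of the standard tiny Shakespeare file.
n_embed, n_layer = 128, 4
num_experts, top_k = 8, 2
block_size, vocab_size = 128, 65

def linear(n_in, n_out, bias=True):
    """Parameter count of an nn.Linear(n_in, n_out)."""
    return n_in * n_out + (n_out if bias else 0)

expert = linear(n_embed, 4 * n_embed) + linear(4 * n_embed, n_embed)
router = 2 * linear(n_embed, num_experts)                  # route + noise layers
# All heads together: key, query, value (no bias) plus the output projection.
attn = 3 * linear(n_embed, n_embed, bias=False) + linear(n_embed, n_embed)
norms = 2 * 2 * n_embed                                    # two LayerNorms per block
shared = (vocab_size * n_embed + block_size * n_embed      # token + position embeddings
          + 2 * n_embed + linear(n_embed, vocab_size))     # final LayerNorm + LM head

total = n_layer * (num_experts * expert + router + attn + norms) + shared
active = n_layer * (top_k * expert + router + attn + norms) + shared

print(f"total  {total / 1e6:.2f}M")     # total  4.52M
print(f"active {active / 1e6:.2f}M")    # active 1.36M
```

Change `num_experts` to 16 and only `total` moves; `active` depends on `top_k` alone.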

One honesty note: this minimal version relies entirely on the routing noise to keep experts balanced — there's no auxiliary loss. At toy scale it trains fine. Scale it up and a few experts quietly take over, which is the next problem.

The catch: load balancing

MoE has one failure mode that dominates everything else, and you saw it forming in the dispatch map: expert collapse. Routing is a positive feedback loop. An expert that wins a few tokens early gets gradient, improves, and so becomes the router's favourite for even more tokens. Meanwhile the experts that lost early get no tokens, no gradient, and never improve. Left alone, a handful of experts end up doing all the work and the rest are dead weight — you're paying to store 8 experts and effectively running 2 or 3.

Figure: expert load distribution, balanced case (ideal 12.5% per expert): E0 12% · E1 13% · E2 11% · E3 14% · E4 12% · E5 13% · E6 12% · E7 13%

Every expert sees a fair share, so all of them keep learning. Getting here is not free — it takes the routing noise plus an auxiliary load-balancing loss that penalises lopsided routing during training.

The noise we added earlier is the first defense — it keeps the routing from hardening too fast. The second, used in every serious MoE, is an auxiliary load-balancing loss: a term added to the training objective that measures how lopsided the routing is across a batch and penalises imbalance, nudging the router toward spreading tokens evenly. It's a soft constraint — you're not forcing exactly equal load, just paying a cost for collapse. Tuning its weight is part of the unglamorous reality of training a MoE: too little and experts collapse, too much and you fight the router's ability to actually specialise.
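The text above doesn't commit to a specific formula, so here is one common formulation as a sketch: the Switch-Transformer-style loss $N \sum_i f_i P_i$, where $f_i$ is the share of routing assignments that went to expert $i$ and $P_i$ is the mean router probability for expert $i$. The function name and the toy inputs are assumptions; tinymoe.py does not include this term.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits, idx, num_experts):
    """Switch-style auxiliary loss: N * sum_i f_i * P_i.

    router_logits: (tokens, N) raw gate logits before top-k masking
    idx:           (tokens, k) experts chosen for each token
    """
    probs = F.softmax(router_logits, dim=-1)                  # full router distribution
    P = probs.mean(dim=0)                                     # mean probability per expert
    counts = torch.bincount(idx.reshape(-1), minlength=num_experts).float()
    f = counts / counts.sum()                                 # share of assignments per expert
    return num_experts * torch.sum(f * P)

N, tokens = 8, 64
# Balanced: flat logits, tokens assigned round-robin.
balanced = load_balancing_loss(torch.zeros(tokens, N),
                               (torch.arange(tokens) % N).unsqueeze(-1), N)
# Collapsed: every token strongly prefers, and is sent to, expert 0.
skewed_logits = torch.zeros(tokens, N)
skewed_logits[:, 0] = 10.0
collapsed = load_balancing_loss(skewed_logits,
                                torch.zeros(tokens, 1, dtype=torch.long), N)

print(f"{balanced.item():.3f} {collapsed.item():.3f}")
```

The loss is 1 when routing is perfectly even and approaches N when one expert takes everything. In training you would add it as `loss = ce_loss + alpha * aux`, with `alpha` a small hypothetical weight to tune. Gradient reaches the router only through `P`, since the counts in `f` are not differentiable.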

This is the honest tradeoff. A dense FFN has no routing, no balance to maintain, no extra loss to tune. MoE buys you cheap capacity and hands you a load-balancing problem in return.

What the experts actually learn

It's tempting to picture expert 3 as "the Python expert" and expert 5 as "the French expert." That's mostly not what happens. When the Mixtral authors inspected their router, they found no clean topic or domain specialization — experts don't map to subjects. What the router learns is lower-level and more syntactic: routing is strongly correlated across consecutive tokens, and individual experts lean toward things like indentation, punctuation, or particular token shapes. The specialization is real, but it's structural, not semantic, and not especially interpretable. "Experts" is a useful name, not a promise that each one becomes a tidy domain specialist.

Beyond the basic router

The router we built is token-choice: each token picks its experts. Three variations are worth knowing, because they're all different answers to the same load-balancing problem:

Same tradeoff surface — cheap capacity versus keeping every expert fed — approached from different sides.

What you actually buy

Why put up with the routing machinery? Because the parameter-vs-compute decoupling is real and large. Mixtral 8×7B is the clean reference: 8 experts per layer, top-2 routing, the exact configuration we just built. It holds about 47B parameters total, but because only 2 of 8 experts run per token, a forward pass touches about 13B active parameters. It runs at the speed and memory-bandwidth cost of a ~13B dense model while matching or beating the dense Llama 2 70B on most of the benchmarks its authors report.

That's the pitch in one line: capacity you don't pay for on every token. The parameters are the model's knowledge; the active fraction is what each token can afford to consult.

There's a cost on the other side of the ledger, and it's worth stating plainly. MoE trades compute for memory. Only $k$ experts run, but all of them have to be resident: you still hold 47B parameters in memory even though each token uses 13B. And at batch scale the router scatters tokens across all experts, so the bandwidth and the all-to-all communication of shipping tokens to the right expert (across GPUs) becomes the real bottleneck, not the matmuls. MoE doesn't make models free. It moves the cost from FLOPs, which you pay per token, to memory and bandwidth, which you pay once. For inference-bound serving at scale, that's usually the trade you want.
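To put rough numbers on that trade, here is the arithmetic for the Mixtral figures quoted above. The 2 bytes per parameter is an assumption (16-bit weights, no quantisation), and activations and KV cache are ignored.

```python
# Parameter counts quoted above for Mixtral 8x7B (rounded).
total_params = 47e9
active_params = 13e9
bytes_per_param = 2            # assumption: 16-bit weights (fp16 / bf16), no quantisation

resident_gb = total_params * bytes_per_param / 1e9     # weights that must sit in memory
touched_gb = active_params * bytes_per_param / 1e9     # weights read for one token

print(f"resident: {resident_gb:.0f} GB")    # resident: 94 GB
print(f"per token: {touched_gb:.0f} GB")    # per token: 26 GB
print(f"active fraction: {active_params / total_params:.0%}")   # active fraction: 28%
```

Under those assumptions you provision memory for about 94 GB of weights while each token reads roughly 26 GB of them.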

The whole thing, in one breath

Strip away the engineering and MoE is small: an expert is the FFN you already had; keep several of them; a one-layer router scores them per token; keep the top two, softmax for weights, run only those two, add a little noise so routing explores and a balancing loss so it doesn't collapse. One line in the transformer block changes. In return, the model's parameter count and its per-token compute stop being the same number — and that decoupling is the entire reason the largest models you can name are built this way.
