# Mixture of Experts, from scratch

> Satyajit Ghana — Head of Engineering @ Inkers Technology
> canonical: https://ai.thesatyajit.com/articles/mixture-of-experts-from-scratch
> date: 2026-06-10
> tags: deep-learning, transformers, mixture-of-experts, explainer

Scaling a transformer the dense way is a bad trade. Every parameter you add runs
on every token. Double the width of the feed-forward layers and you double both
the model's capacity *and* the FLOPs it burns per token — capacity and compute are
welded together. You pay for the whole network on every single token, whether that
token needs it or not.

Mixture of Experts breaks the weld. The idea is **conditional computation**: keep a
large pile of parameters around, but for any given token, only run a small slice of
them. A tiny router looks at each token and picks a couple of sub-networks — the
*experts* — to handle it. The rest sit idle for that token. You get the capacity of
a big model at the compute of a small one.

Here is the whole model we'll build, end to end. The only thing that makes it a MoE
is one swapped line — tap the **sparse MoE** block to see it:

<MoeArchitecture />

Everything except that one block is a standard decoder-only transformer: token and
position embeddings, a stack of blocks, a final norm, an LM head. Attention is
untouched. MoE is a surgical replacement for the feed-forward layer inside each
block, and nothing else. So the whole thing reduces to three questions: what is an
expert, who decides which experts run, and how do you run only the chosen few.

## An expert is just an MLP

Start with the thing we're replacing. In a normal transformer block, after
attention, every token goes through the same two-layer MLP — expand to `4 * n_embed`,
nonlinearity, project back. That's the feed-forward network.

An expert is exactly that MLP. Nothing more.

```python
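import torch
import torch.nn as nn
import torch.nn.functional as F  # used by the router below


class Expert(nn.Module):
    """One expert: the same two-layer MLP a dense transformer block uses."""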
    def __init__(self, n_embed, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
```

The move is to keep `num_experts` copies of this MLP instead of one. With 8 experts
you have 8× the feed-forward parameters. If every token went through all 8, you'd
have spent 8× the compute and gained nothing but a slow, fat FFN. The whole game is
to run only `top_k` of them — say 2 — per token. So you carry 8 experts' worth of
parameters and pay for 2.

The piece that makes that decision is the router.

## Who decides? The router

The router's job: look at a token's vector $x$ and produce a weight for each expert,
mostly zero, so that only a few experts actually contribute. Build it up in three
steps, because the naive versions teach you why the real one looks the way it does.

**Attempt 1 — send every token to every expert, weighted.** A linear layer maps the
token to one logit per expert, softmax over them, take a weighted sum of all expert
outputs:

$$
g(x) = \mathrm{softmax}(x W_g), \qquad y = \sum_{i=1}^{N} g(x)_i \, E_i(x)
$$

Here $W_g$ is the router's weight matrix (`n_embed × num_experts`) and $E_i$ is the
$i$-th expert. This is differentiable and trains fine — but it's *dense*. Every
expert runs on every token. We've built an expensive ensemble, not a sparse model.

**Attempt 2 — hard pick the single best expert.** Take $\arg\max$ of the logits, run
only that expert. Now it's sparse and cheap. But $\arg\max$ is piecewise constant, so its
gradient is zero almost everywhere. The choice itself sends no signal back to $W_g$, so the
router never learns that a different expert would have done better. Routing freezes. Dead end.
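The missing gradient is easy to check. The following is a minimal sketch that uses a bare router matrix `W_g` and random stand-in expert outputs (both hypothetical). The chosen output depends on the logits only through an integer index, so backpropagation never reaches `W_g`:

```python
import torch

torch.manual_seed(0)
n_embed, num_experts = 16, 8
x = torch.randn(n_embed)
W_g = torch.randn(n_embed, num_experts, requires_grad=True)             # router weights
expert_outputs = torch.randn(num_experts, n_embed, requires_grad=True)  # stand-in E_i(x)

logits = x @ W_g
choice = logits.argmax()          # an integer index: no gradient path back to W_g
y = expert_outputs[choice]        # run only the chosen expert
y.sum().backward()

print(W_g.grad)                                # None: the router learns nothing here
print(expert_outputs.grad.abs().sum(dim=-1))   # only the chosen row gets gradient
```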

**Attempt 3 — top-$k$ softmax.** Keep the largest $k$ logits, set the rest to
$-\infty$, *then* softmax. The $-\infty$ entries become exactly 0, so only $k$ experts
contribute (sparse, like attempt 2), but the softmax over the survivors is smooth, so
gradients flow back to the router through all $k$ surviving logits. This is the real router:

$$
g(x) = \mathrm{softmax}\big(\mathrm{KeepTopK}(x W_g,\, k)\big), \qquad
\mathrm{KeepTopK}(v, k)_i = \begin{cases} v_i & v_i \text{ in top } k \\ -\infty & \text{otherwise} \end{cases}
$$

With $k = 2$ and $N = 8$, six of the eight gate weights are zero for every token, and
the two survivors sum to 1. Watch one token go through it — logits, keep the top two,
softmax to gates, combine:

<MoeRouter />

That stepper is the entire routing mechanism. The bars are the per-expert logits;
top-2 keeps two; softmax turns them into weights; the output is just those two
experts' outputs scaled by their gates and added.
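The following sketch runs those steps on one made-up vector of eight logits. The numbers are arbitrary and were chosen so that experts 2 and 5 win. The `-inf` mask leaves six exact zeros, and a toy loss sends gradient to only the two kept logits:

```python
import torch
import torch.nn.functional as F

# One token's router logits over N = 8 experts (arbitrary illustrative values).
logits = torch.tensor([0.3, -1.2, 2.1, 0.0, -0.5, 1.7, -2.0, 0.4], requires_grad=True)
k = 2

top_vals, top_idx = logits.topk(k)                                   # keep the top two
masked = torch.full_like(logits, float("-inf")).scatter(0, top_idx, top_vals)
gates = F.softmax(masked, dim=0)                                     # six exact zeros

print(top_idx.tolist())  # [2, 5]
print(gates)

# Toy loss: pretend each expert returns one scalar, then combine by gate weight.
expert_out = torch.tensor([1.0, 1.0, 3.0, 1.0, 1.0, -1.0, 1.0, 1.0])
(gates * expert_out).sum().backward()
print(logits.grad)       # non-zero only at the two kept logits
```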

## Why the noise

There's one addition that the bare top-$k$ router needs in practice: noise. Before
picking the top $k$, add a learned, per-expert amount of Gaussian noise to the
logits:

$$
H(x)_i = (x W_g)_i + \varepsilon_i \cdot \mathrm{softplus}\big((x W_{\text{noise}})_i\big), \qquad \varepsilon_i \sim \mathcal{N}(0, 1)
$$

The noise scale is itself learned (a second linear layer $W_{\text{noise}}$, passed
through `softplus` to keep it positive). Why bother? Because early in training the
router is random, and whichever experts happen to win first get all the gradient and
pull ahead — a rich-get-richer collapse. The noise jitters the top-$k$ selection so
borderline experts occasionally win, get some tokens, and get a chance to become
useful. It's exploration, baked into the forward pass. Hit *resample noise* in the
widget above and you can watch which two experts win flip.
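You can also measure the exploration effect outside the widget. The following minimal sketch assumes fixed logits in which experts 2 and 5 lead and expert 0 is a close third. A constant noise scale of 0.5 stands in for the learned `softplus(x W_noise)` term. Without noise, expert 0 never makes the top two. With noise, it wins a real share of trials:

```python
import torch

torch.manual_seed(0)
# Fixed router logits: experts 2 and 5 lead, expert 0 is a close third.
logits = torch.tensor([1.5, -1.0, 2.1, 0.0, -0.5, 1.7, -2.0, 0.4])
noise_scale = 0.5  # stand-in for softplus((x W_noise)_i), held constant here
k, trials = 2, 10_000

noisy = logits + torch.randn(trials, logits.numel()) * noise_scale  # one row per trial
chosen = noisy.topk(k, dim=-1).indices                              # (trials, k)
share = torch.bincount(chosen.flatten(), minlength=logits.numel()) / trials

print(logits.topk(k).indices.tolist())  # without noise: always [2, 5]
print(share)                            # fraction of trials each expert is selected
```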

In code, the router is only a handful of lines of real work:

```python
class NoisyTopKRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        self.route = nn.Linear(n_embed, num_experts)   # gate logits
        self.noise = nn.Linear(n_embed, num_experts)   # per-expert noise scale

    def forward(self, x):
        noisy = self.route(x)                                # gate logits
        if self.training:                                    # explore only while training
            noisy = noisy + torch.randn_like(noisy) * F.softplus(self.noise(x))

        top_logits, idx = noisy.topk(self.top_k, dim=-1)     # the chosen experts
        sparse = torch.full_like(noisy, float("-inf"))
        sparse.scatter_(-1, idx, top_logits)                 # keep top-k, rest -inf
        return F.softmax(sparse, dim=-1), idx
```

`scatter_` is the one trick worth pausing on: it writes the kept logits back into a
tensor of `-inf`, at the indices the `topk` chose. After the softmax those `-inf`
slots are 0. The router returns the gate weights and the chosen indices — the
indices tell the next stage which experts to actually run.
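If `scatter_` feels opaque, there is an equivalent formulation that masks every logit below each row's $k$-th largest value. This is a sketch for comparison, not what the router above uses. Both build the same KeepTopK tensor as long as there are no ties at the threshold. With ties, the threshold version would keep more than $k$ entries:

```python
import torch

torch.manual_seed(0)
noisy = torch.randn(4, 8)  # (tokens, experts) router logits
k = 2

# Version used above: write the top-k values into a tensor of -inf.
top_logits, idx = noisy.topk(k, dim=-1)
via_scatter = torch.full_like(noisy, float("-inf")).scatter(-1, idx, top_logits)

# Alternative: threshold each row at its k-th largest value.
kth = top_logits[..., -1:]                                  # (tokens, 1)
via_mask = noisy.masked_fill(noisy < kth, float("-inf"))

print(torch.equal(via_scatter, via_mask))
```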

## The sparse forward pass

Now the part that earns the word *sparse*. We have gate weights and, for each token,
the indices of its top-$k$ experts. We want to run each expert on only the tokens
routed to it, scale by the gate, and add the result back.

The straightforward way: loop over experts, and for each one, mask out the tokens
that picked it.

```python
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.router = NoisyTopKRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])

    def forward(self, x):
        gates, idx = self.router(x)            # (B,T,N), (B,T,k)
        out = torch.zeros_like(x)

        flat_x = x.view(-1, x.size(-1))        # (B*T, C)
        flat_gates = gates.view(-1, gates.size(-1))
        flat_out = out.view(-1, x.size(-1))

        for i, expert in enumerate(self.experts):
            mask = (idx == i).any(dim=-1).view(-1)   # tokens routed to expert i
            if mask.any():
                y = expert(flat_x[mask])             # run on its tokens only
                flat_out[mask] += flat_gates[mask, i:i+1] * y
        return out
```

The `mask = (idx == i).any(dim=-1)` line is the dispatch: it's true for exactly the
tokens that have expert `i` somewhere in their top-$k$. We gather those tokens, run
the expert once on the batch of them, scale each by its gate weight, and scatter-add
back into the output. A token routed to experts 2 and 5 gets contributions from both
loop iterations, summed — which is exactly $\sum_i g(x)_i E_i(x)$ with all but $k$
terms zero.
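A quick way to check that the masked loop really computes $\sum_i g(x)_i E_i(x)$ is to run it next to the dense sum and compare. The following sketch uses plain `nn.Linear` stand-in experts and a noise-free top-k gate, so both paths see the same routing:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C, N, k = 16, 8, 2
x = torch.randn(10, C)                                        # 10 tokens
experts = nn.ModuleList([nn.Linear(C, C) for _ in range(N)])  # stand-in experts
logits = nn.Linear(C, N)(x)                                   # router logits, no noise

top, idx = logits.topk(k, dim=-1)
gates = F.softmax(torch.full_like(logits, float("-inf")).scatter(-1, idx, top), dim=-1)

# Sparse path: the masked loop from SparseMoE.
sparse = torch.zeros_like(x)
for i, expert in enumerate(experts):
    mask = (idx == i).any(dim=-1)
    if mask.any():
        sparse[mask] += gates[mask, i:i + 1] * expert(x[mask])

# Dense reference: every expert on every token, weighted by its (mostly zero) gate.
dense = sum(gates[:, i:i + 1] * experts[i](x) for i in range(N))

print(torch.allclose(sparse, dense, atol=1e-5))
```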

Picture the dispatch over a short sequence. Each token connects to just two of the
eight experts, so most of the grid stays dark — that darkness is the compute you're
*not* spending:

<MoeRouting />

The bars underneath are the per-expert load: how many tokens each expert handled.
Notice it's already uneven — some experts attract more traffic than others. Hold that
thought; it's the central problem with MoE.
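The load bars come from a single `torch.bincount` over the router's `idx` output. In the following sketch, random logits stand in for a real router's decisions on a `(B, T)` batch:

```python
import torch

torch.manual_seed(0)
B, T, num_experts, k = 2, 16, 8, 2
logits = torch.randn(B, T, num_experts)  # stand-in for the router's logits
_, idx = logits.topk(k, dim=-1)          # (B, T, k), the shape the router returns

load = torch.bincount(idx.flatten(), minlength=num_experts)  # tokens per expert
ideal = B * T * k / num_experts
print(load.tolist())
print(f"ideal {ideal:.0f} per expert, busiest expert at {load.max().item() / ideal:.2f}x")
```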

<Callout type="note">
  This masked loop is the *teaching* implementation. It's correct but it runs every
  expert as a separate kernel and materialises a mask per expert. Production MoE
  instead sorts/permutes tokens by expert and does one grouped matmul, and in the
  distributed case each expert lives on a different GPU and tokens are shipped to
  them (expert parallelism). Same math, very different plumbing.
</Callout>
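Here is a rough sketch of the permute step the callout describes. It still runs on one device and still calls each expert separately, so it is not a grouped-matmul kernel. It flattens the (token, expert) assignments and sorts them by expert id, which makes each expert's tokens contiguous. It then runs each slice and scatter-adds the gated results back with `index_add_`. `sorted_dispatch` is a hypothetical name, and plain `nn.Linear` layers stand in for the experts:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def sorted_dispatch(x, gates, idx, experts):
    """Sort-by-expert dispatch (sketch).

    x: (S, C) tokens, gates: (S, N) sparse gate weights, idx: (S, k) chosen experts.
    """
    S, k = idx.shape
    token_ids = torch.arange(S, device=x.device).repeat_interleave(k)  # owner of each slot
    expert_ids = idx.reshape(-1)                                       # expert of each slot
    order = torch.argsort(expert_ids)             # group slots by expert
    token_ids, expert_ids = token_ids[order], expert_ids[order]
    counts = torch.bincount(expert_ids, minlength=len(experts)).tolist()

    out = torch.zeros_like(x)
    start = 0
    for e, n in enumerate(counts):
        if n:
            tok = token_ids[start:start + n]      # contiguous block of tokens for expert e
            y = experts[e](x[tok]) * gates[tok, e:e + 1]
            out.index_add_(0, tok, y)             # scatter-add back to token positions
        start += n
    return out


# Compare against the dense weighted sum over all experts.
torch.manual_seed(0)
C, N, k = 16, 8, 2
x = torch.randn(10, C)
experts = nn.ModuleList([nn.Linear(C, C) for _ in range(N)])  # stand-in experts
logits = torch.randn(10, N)
top, idx = logits.topk(k, dim=-1)
gates = F.softmax(torch.full_like(logits, float("-inf")).scatter(-1, idx, top), dim=-1)

dense = sum(gates[:, i:i + 1] * experts[i](x) for i in range(N))
out = sorted_dispatch(x, gates, idx, experts)
print(torch.allclose(out, dense, atol=1e-5))
```

Sorting turns N scattered masks into N contiguous slices. That layout is what lets a real implementation replace the Python loop with a single grouped matmul.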

## The one line that changes

With the experts and the router in hand, dropping MoE into a transformer block is
anticlimactic — which is the point. A standard block is `attention → FFN`, each
wrapped in a layer-norm and a residual. MoE swaps the FFN for the `SparseMoE` module
and touches nothing else:

```python
class Block(nn.Module):
    def __init__(self, n_embed, n_head, num_experts, top_k, block_size):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embed, block_size)
        self.smoe = SparseMoE(n_embed, num_experts, top_k)   # was: FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))      # attention — unchanged
        x = x + self.smoe(self.ln2(x))    # MoE replaces the feed-forward layer
        return x
```

That's the whole architectural delta. One `FeedForward` becomes one `SparseMoE`:

<Diagram caption="Same slot in the block. Dense runs one MLP on every token; sparse runs a router plus the two chosen experts.">
  <svg viewBox="0 0 640 250" role="img" aria-label="A dense feed-forward layer applies one MLP to every token; the sparse MoE layer routes each token to two of eight experts." style={{ width: "100%", height: "auto" }}>
    <defs>
      <marker id="moe-arrow" viewBox="0 0 10 10" refX="8" refY="5" markerWidth="6" markerHeight="6" orient="auto-start-reverse">
        <path d="M0,0 L10,5 L0,10 z" fill="var(--muted-foreground)" />
      </marker>
    </defs>

    {/* dense side */}
    <text x="150" y="24" textAnchor="middle" fontFamily="monospace" fontSize="13" fill="var(--foreground)">dense FFN</text>
    <rect x="110" y="44" width="80" height="22" rx="5" fill="var(--background)" stroke="var(--border)" />
    <text x="150" y="59" textAnchor="middle" fontFamily="monospace" fontSize="10" fill="var(--foreground)">tokens</text>
    <line x1="150" y1="66" x2="150" y2="92" stroke="var(--muted-foreground)" strokeWidth="1.2" markerEnd="url(#moe-arrow)" />
    <rect x="95" y="94" width="110" height="50" rx="8" fill="oklch(0.72 0.13 250)" opacity="0.9" />
    <text x="150" y="116" textAnchor="middle" fontFamily="monospace" fontSize="11" fill="oklch(0.2 0 0)">one MLP</text>
    <text x="150" y="132" textAnchor="middle" fontFamily="monospace" fontSize="9" fill="oklch(0.2 0 0)">every token</text>
    <line x1="150" y1="144" x2="150" y2="170" stroke="var(--muted-foreground)" strokeWidth="1.2" markerEnd="url(#moe-arrow)" />
    <text x="150" y="190" textAnchor="middle" fontFamily="monospace" fontSize="10" fill="var(--muted-foreground)">100% of params, every token</text>

    {/* divider */}
    <line x1="320" y1="30" x2="320" y2="210" stroke="var(--border)" strokeDasharray="3 4" />

    {/* sparse side */}
    <text x="490" y="24" textAnchor="middle" fontFamily="monospace" fontSize="13" fill="var(--foreground)">sparse MoE</text>
    <rect x="450" y="44" width="80" height="22" rx="5" fill="var(--background)" stroke="var(--border)" />
    <text x="490" y="59" textAnchor="middle" fontFamily="monospace" fontSize="10" fill="var(--foreground)">tokens</text>
    <line x1="490" y1="66" x2="490" y2="84" stroke="var(--muted-foreground)" strokeWidth="1.2" markerEnd="url(#moe-arrow)" />
    <rect x="448" y="86" width="84" height="20" rx="5" fill="var(--background)" stroke="var(--foreground)" strokeOpacity="0.4" />
    <text x="490" y="100" textAnchor="middle" fontFamily="monospace" fontSize="10" fill="var(--foreground)">router</text>

    {/* 8 experts, 2 lit */}
    {[0,1,2,3,4,5,6,7].map((i) => {
      const x = 372 + i * 30
      const lit = i === 2 || i === 5
      return (
        <g key={i}>
          <line x1="490" y1="106" x2={x + 11} y2="150" stroke={`oklch(0.72 0.13 ${(i*45)%360})`} strokeWidth={lit ? 2 : 1} opacity={lit ? 0.9 : 0.12} />
          <rect x={x} y="150" width="22" height="34" rx="4" fill={`oklch(0.72 0.13 ${(i*45)%360})`} opacity={lit ? 1 : 0.16} />
        </g>
      )
    })}
    <text x="490" y="204" textAnchor="middle" fontFamily="monospace" fontSize="10" fill="var(--muted-foreground)">8 experts stored · 2 run</text>
  </svg>
</Diagram>

Stack several of these blocks (four in the script below), add embeddings and an LM head, and you have the model
from the top of the page. Train it exactly like a dense transformer — cross-entropy
on next-token prediction. The router learns its weights from the same gradient as
everything else. No special routing supervision; it figures out a useful assignment
on its own.

## Run it yourself

Here is everything above assembled into one file: a char-level model, about 200 lines with
no dependency past PyTorch, that trains on tiny Shakespeare. The `Expert`,
`NoisyTopKRouter`, and `SparseMoE` are exactly the pieces we just built; the rest is
the smallest transformer that can hold them. Copy it, run `python tinymoe.py`, and
watch the loss come down.

```python
"""
tinymoe — a tiny Mixture-of-Experts language model in one file.
Char-level, trains on tiny Shakespeare. ~4.5M params, ~1.4M active per token.
Runs on CPU; much faster on a GPU.

    python tinymoe.py        # download data, train, then sample

It's a small decoder-only transformer where the feed-forward layer of every
block is replaced by a sparse mixture of experts with noisy top-k routing.
"""
import os
import urllib.request

import torch
import torch.nn as nn
from torch.nn import functional as F

# --------------------------------------------------------------------- config
batch_size = 32          # sequences per step
block_size = 128         # context length (chars)
n_embed = 128            # embedding / residual width
n_head = 4               # attention heads
n_layer = 4              # transformer blocks
num_experts = 8          # experts per MoE layer
top_k = 2                # experts actually run per token
dropout = 0.1
learning_rate = 3e-4
max_iters = 5000
eval_interval = 500
eval_iters = 100
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(1337)

# ----------------------------------------------------------- data (shakespeare)
if not os.path.exists("input.txt"):
    url = ("https://raw.githubusercontent.com/karpathy/char-rnn/"
           "master/data/tinyshakespeare/input.txt")
    urllib.request.urlretrieve(url, "input.txt")
text = open("input.txt", encoding="utf-8").read()

chars = sorted(set(text))
vocab_size = len(chars)
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for i, c in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda t: "".join(itos[i] for i in t)

data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]


def get_batch(split):
    d = train_data if split == "train" else val_data
    ix = torch.randint(len(d) - block_size, (batch_size,))
    x = torch.stack([d[i:i + block_size] for i in ix])
    y = torch.stack([d[i + 1:i + block_size + 1] for i in ix])
    return x.to(device), y.to(device)


# ------------------------------------------------------------------- attention
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k, q = self.key(x), self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        wei = self.drop(F.softmax(wei, dim=-1))
        return wei @ self.value(x)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_head, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.drop(self.proj(out))


# --------------------------------------------------------- mixture of experts
class Expert(nn.Module):
    """One expert = one MLP. Same shape as a normal transformer FFN."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed), nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed), nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class NoisyTopKRouter(nn.Module):
    """Score experts per token, add learned noise, keep top-k, softmax."""

    def __init__(self):
        super().__init__()
        self.route = nn.Linear(n_embed, num_experts)
        self.noise = nn.Linear(n_embed, num_experts)

    def forward(self, x):
        noisy = self.route(x)
        if self.training:  # exploration noise while training; deterministic routing at eval
            noisy = noisy + torch.randn_like(noisy) * F.softplus(self.noise(x))
        top_logits, idx = noisy.topk(top_k, dim=-1)
        sparse = torch.full_like(noisy, float("-inf")).scatter(-1, idx, top_logits)
        return F.softmax(sparse, dim=-1), idx


class SparseMoE(nn.Module):
    """Run only the top-k experts per token; combine them by gate weight."""

    def __init__(self):
        super().__init__()
        self.router = NoisyTopKRouter()
        self.experts = nn.ModuleList([Expert() for _ in range(num_experts)])

    def forward(self, x):
        gates, idx = self.router(x)                  # (B,T,E), (B,T,k)
        out = torch.zeros_like(x)
        flat_x = x.reshape(-1, x.size(-1))
        flat_gates = gates.reshape(-1, gates.size(-1))
        flat_out = out.reshape(-1, x.size(-1))
        for i, expert in enumerate(self.experts):
            mask = (idx == i).any(dim=-1).reshape(-1)  # tokens routed to expert i
            if mask.any():
                flat_out[mask] += flat_gates[mask, i:i + 1] * expert(flat_x[mask])
        return out


# ------------------------------------------------------------- block + model
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embed // n_head)
        self.smoe = SparseMoE()                      # <- replaces the FFN
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.smoe(self.ln2(x))
        return x


class MoELanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, n_embed)
        self.pos_emb = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        x = self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
        x = self.ln_f(self.blocks(x))
        logits = self.head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, _ = self(idx[:, -block_size:])
            probs = F.softmax(logits[:, -1, :], dim=-1)
            idx = torch.cat([idx, torch.multinomial(probs, 1)], dim=1)
        return idx


# --------------------------------------------------------------------- train
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ("train", "val"):
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x, y = get_batch(split)
            _, losses[k] = model(x, y)
        out[split] = losses.mean().item()
    model.train()
    return out


model = MoELanguageModel().to(device)
total = sum(p.numel() for p in model.parameters())
print(f"{total / 1e6:.2f}M params on {device}")
opt = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for it in range(max_iters):
    if it % eval_interval == 0:
        l = estimate_loss(model)
        print(f"step {it:5d} | train {l['train']:.3f} | val {l['val']:.3f}")
    x, y = get_batch("train")
    _, loss = model(x, y)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    opt.step()

# -------------------------------------------------------------------- sample
model.eval()  # turn off dropout for sampling
ctx = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(ctx, 500)[0].tolist()))
```

At the default size it prints `4.52M params` — but only **~1.4M of them run on any
given token**, because 6 of every 8 experts sit out. That's the parameter-vs-compute
split in miniature. Raise `num_experts` and the total climbs while the active count
barely moves; lower `top_k` to 1 and it gets sparser still. The same lever Mixtral
pulls, in a model you can train on a laptop.
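You can check the 4.52M / 1.4M split by hand. The following sketch recomputes both numbers from the config. It assumes the 65-character vocabulary of tiny Shakespeare. It also counts embeddings and the LM head as active, which is a common convention even though a token reads only one embedding row. `tinymoe_param_counts` is a hypothetical helper, not part of the script:

```python
def tinymoe_param_counts(vocab_size=65, block_size=128, n_embed=128,
                         n_layer=4, num_experts=8, top_k=2):
    """Total and per-token-active parameters for the tinymoe layout above."""
    C = n_embed
    expert = (C * 4 * C + 4 * C) + (4 * C * C + C)  # Linear(C,4C) + Linear(4C,C)
    attention = 3 * C * C + (C * C + C)             # q, k, v over all heads (no bias) + proj
    router = 2 * (C * num_experts + num_experts)    # route + noise linears
    norms = 2 * (2 * C)                             # ln1, ln2 (weight + bias)
    shared_per_block = attention + router + norms

    embeddings = vocab_size * C + block_size * C    # token + position tables
    tail = 2 * C + (C * vocab_size + vocab_size)    # ln_f + LM head

    total = embeddings + tail + n_layer * (shared_per_block + num_experts * expert)
    active = embeddings + tail + n_layer * (shared_per_block + top_k * expert)
    return total, active


total, active = tinymoe_param_counts()
print(f"{total / 1e6:.2f}M total, {active / 1e6:.2f}M active per token")
```

It reports 4.52M total and 1.36M active. Doubling `num_experts` to 16 roughly doubles the total (to about 8.74M). The active count rises only by the wider router: 8,256 parameters across the four layers.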

One honesty note: this minimal version relies entirely on the routing noise to keep
experts balanced — there's no auxiliary loss. At toy scale it trains fine. Scale it
up and a few experts quietly take over, which is the next problem.

## The catch: load balancing

MoE has one failure mode that dominates everything else, and you saw it forming in the
dispatch map: **expert collapse**. Routing is a positive feedback loop. An expert that
wins a few tokens early gets gradient, improves, and so becomes the router's favourite
for even more tokens. Meanwhile the experts that lost early get no tokens, no
gradient, and never improve. Left alone, a handful of experts end up doing all the
work and the rest are dead weight — you're paying to store 8 experts and effectively
running 2 or 3.

<MoeLoadBalance />

The noise we added earlier is the first defense — it keeps the routing from hardening
too fast. The second, used in every serious MoE, is an **auxiliary load-balancing
loss**: a term added to the training objective that measures how lopsided the routing
is across a batch and penalises imbalance, nudging the router toward spreading tokens
evenly. It's a soft constraint — you're not forcing exactly equal load, just paying a
cost for collapse. Tuning its weight is part of the unglamorous reality of training a
MoE: too little and experts collapse, too much and you fight the router's ability to
actually specialise.
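One widely used form of that loss comes from the Switch Transformer paper: $\alpha \cdot N \sum_{i=1}^{N} f_i P_i$. Here $f_i$ is the fraction of routing slots sent to expert $i$, and $P_i$ is the router's mean probability for expert $i$. The following is a minimal sketch, and `load_balancing_loss` is a hypothetical name. Dividing the slot count by $k$, which extends the loss from top-1 to top-$k$, is an assumption of this sketch. The function also needs the full softmax logits (before KeepTopK), which the `NoisyTopKRouter` above does not return:

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(logits: torch.Tensor, idx: torch.Tensor, num_experts: int) -> torch.Tensor:
    """Switch-style auxiliary loss, adapted to top-k (an assumption of this sketch).

    logits: (S, N) router logits for S tokens, before any top-k masking.
    idx:    (S, k) indices of the experts each token was routed to.
    Equals 1.0 for perfectly balanced routing; rises toward N / k as routing
    collapses onto the same k experts.
    """
    k = idx.size(-1)
    P = F.softmax(logits, dim=-1).mean(dim=0)             # mean router prob per expert
    dispatch = F.one_hot(idx, num_experts).sum(dim=-2)    # (S, N): 1 where a slot went
    f = dispatch.float().mean(dim=0) / k                  # fraction of slots per expert
    return num_experts * torch.sum(f * P)


S, N, k = 64, 8, 2
# Balanced: token t goes to experts t % N and (t + 1) % N; router is uniform.
balanced_idx = torch.stack([torch.arange(S) % N, (torch.arange(S) + 1) % N], dim=-1)
balanced = load_balancing_loss(torch.zeros(S, N), balanced_idx, N)

# Collapsed: the router strongly prefers experts 0 and 1 for every token.
collapsed_logits = torch.zeros(S, N)
collapsed_logits[:, :2] = 5.0
collapsed_idx = torch.tensor([[0, 1]]).expand(S, 2)
collapsed = load_balancing_loss(collapsed_logits, collapsed_idx, N)

print(balanced.item(), collapsed.item())  # 1.0, then close to N / k = 4
```

Only $P_i$ carries gradient, because $f_i$ comes from a hard count. The loss therefore nudges router probabilities toward under-used experts. In a training loop it would be added as `loss = ce_loss + alpha * aux_loss`. The Switch Transformer paper used $\alpha = 0.01$, but the right weight for a given model is something you would tune.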

This is the honest tradeoff. A dense FFN has no routing, no balance to maintain, no
extra loss to tune. MoE buys you cheap capacity and hands you a load-balancing problem
in return.

## What the experts actually learn

It's tempting to picture expert 3 as "the Python expert" and expert 5 as "the French
expert." That's mostly not what happens. When the Mixtral authors inspected their
router, they found no clean topic or domain specialization — experts don't map to
subjects. What the router learns is lower-level and more syntactic: routing is
strongly correlated across consecutive tokens, and individual experts lean toward
things like indentation, punctuation, or particular token shapes. The specialization
is real, but it's structural, not semantic, and not especially interpretable.
"Experts" is a useful name, not a promise that each one becomes a tidy domain
specialist.

## Beyond the basic router

The router we built is *token-choice*: each token picks its experts. Three variations
are worth knowing, because they're all different answers to the same load-balancing
problem:

- **Expert-choice routing** flips the selection — each expert picks its top tokens.
  Load is balanced by construction (every expert takes a fixed budget), at the cost of
  some tokens getting chosen by many experts and others by none.
- **Shared experts** (as in DeepSeek-MoE) keep one or two experts always on for every
  token, so the routed experts don't burn capacity re-learning common patterns and can
  specialize at the margin.
- **Capacity and token dropping** — in batched or distributed training each expert gets
  a fixed number of slots per batch; tokens that overflow their chosen expert are
  dropped and pass through on the residual alone. A blunt cap that keeps the per-expert
  matmuls a fixed, rectangular shape.
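Both balancing mechanisms in that list fit in a few lines. The following is a minimal single-device sketch with two hypothetical functions. `expert_choice` lets each expert take its top `capacity` tokens by router score. `apply_capacity` keeps token-choice routing but drops any assignment that arrives after its expert is full. The first-come drop order (in token order) and the capacity formula are assumptions for illustration. The formula uses a common choice, $\lceil \text{capacity\_factor} \cdot S \cdot k / N \rceil$:

```python
import math

import torch
import torch.nn.functional as F


def expert_choice(logits: torch.Tensor, capacity: int):
    """Expert-choice routing. logits: (S, N) for S tokens and N experts.

    Each expert takes its `capacity` highest-affinity tokens, so load is equal by
    construction; a token may be picked by several experts or by none.
    """
    affinity = F.softmax(logits, dim=-1)              # token-to-expert scores
    gate, token_idx = affinity.topk(capacity, dim=0)  # (capacity, N): picks per expert
    return gate, token_idx


def apply_capacity(idx: torch.Tensor, num_experts: int, capacity: int) -> torch.Tensor:
    """Token-choice with a capacity cap. idx: (S, k) chosen experts per token.

    Returns a (S, k) bool mask: False marks an assignment that arrived after its
    expert was already full (first come, first served, in token order).
    """
    slots = F.one_hot(idx, num_experts).reshape(-1, num_experts)  # (S*k, N)
    position = slots.cumsum(dim=0) * slots  # 1-based place in each expert's queue
    kept = (position > 0) & (position <= capacity)
    return kept.any(dim=-1).reshape(idx.shape)


torch.manual_seed(0)
S, N, k, capacity_factor = 32, 8, 2, 1.0
logits = torch.randn(S, N)
capacity = math.ceil(capacity_factor * S * k / N)  # 8 slots per expert here

gate, token_idx = expert_choice(logits, capacity)
print(token_idx.shape)  # torch.Size([8, 8])

_, idx = logits.topk(k, dim=-1)
kept = apply_capacity(idx, N, capacity)
print(f"dropped {(~kept).sum().item()} of {idx.numel()} token-expert assignments")
```

A dropped assignment contributes nothing from that expert, so a token that loses both of its slots passes through on the residual alone.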

Same tradeoff surface — cheap capacity versus keeping every expert fed — approached
from different sides.

## What you actually buy

Why put up with the routing machinery? Because the parameter-vs-compute decoupling is
real and large. Mixtral 8×7B is the clean reference: 8 experts per layer and top-2
routing, the same layout we just built (its experts are larger SwiGLU MLPs and its router
adds no noise). It holds **47B parameters total**, but because only 2 of 8 experts run per
token, a forward pass touches **about 13B active parameters**. Its per-token compute is
roughly that of a ~13B dense model, yet it matches or beats the dense Llama 2 70B across
the benchmarks its authors report.
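The 47B and 13B figures follow from a back-of-envelope count. The following sketch assumes the configuration Mistral AI published for Mixtral 8×7B:

- hidden size 4096
- expert FFN size 14336, with three SwiGLU matrices per expert
- 32 layers
- 32 query heads and 8 key/value heads, each of size 128
- a 32,000-token vocabulary
- untied input and output embeddings

Norms are included. Biases are omitted because Mixtral's linear layers have none. `mixtral_params` is a hypothetical helper:

```python
def mixtral_params(hidden=4096, ffn=14336, layers=32, experts=8, top_k=2,
                   n_heads=32, n_kv_heads=8, head_dim=128, vocab=32_000):
    expert = 3 * hidden * ffn                                      # SwiGLU: gate, up, down
    attention = hidden * head_dim * (2 * n_heads + 2 * n_kv_heads)  # q, o + k, v (GQA)
    router = hidden * experts
    norms = 2 * hidden                                              # two RMSNorms per layer
    shared = layers * (attention + router + norms) + hidden + 2 * vocab * hidden
    total = shared + layers * experts * expert
    active = shared + layers * top_k * expert
    return total, active


total, active = mixtral_params()
print(f"{total / 1e9:.1f}B total, {active / 1e9:.1f}B active per token")
```

It prints 46.7B total and 12.9B active. Everything outside the experts is shared, so the 33.8B gap is exactly six idle experts per layer.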

That's the pitch in one line: **capacity you don't pay for on every token.** The
parameters are the model's knowledge; the active fraction is what each token can
afford to consult.

There's a cost on the other side of the ledger, and it's worth stating plainly. MoE
trades **compute for memory**. Only $k$ experts run, but *all* of them have to be
resident — you still hold 47B parameters in memory even though each token uses 13B.
And at batch scale the router scatters tokens across all experts, so memory bandwidth and
the all-to-all communication of shipping tokens to the right expert (across GPUs) become
the real bottleneck, not the matmuls. MoE doesn't make models free. It moves the cost from
FLOPs, which you pay per token, to memory, which you pay once, and to bandwidth and
communication, which you pay on every batch. For inference-bound serving at scale, that's
usually the trade you want.
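The following sketch puts rough numbers on that trade. It uses the round figures above (47B resident, 13B active), counts weights only (no KV cache or activations), and applies the usual approximation of about 2 FLOPs per active parameter per token:

```python
total_params, active_params = 47e9, 13e9  # Mixtral 8x7B, rounded

bytes_per_param = {"bf16": 2.0, "int8": 1.0, "int4": 0.5}
memory_gb = {}
for dtype, nbytes in bytes_per_param.items():
    resident = total_params * nbytes / 1e9  # every expert must be loaded
    touched = active_params * nbytes / 1e9  # weights a single token actually reads
    memory_gb[dtype] = (resident, touched)
    print(f"{dtype}: hold {resident:.0f} GB, one token reads ~{touched:.1f} GB")

flops_per_token = 2 * active_params  # ~1 multiply-add per active weight
print(f"~{flops_per_token / 1e9:.0f} GFLOPs per token, vs ~{2 * total_params / 1e9:.0f} if dense")
```

At batch size one, a token reads only its own active weights. Across a large batch, the tokens collectively hit every expert. That is why the full resident footprint must stay in memory and stream through bandwidth.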

## The whole thing, in one breath

Strip away the engineering and MoE is small: an expert is the FFN you already had;
keep several of them; a one-layer router scores them per token; keep the top two,
softmax for weights, run only those two, add a little noise so routing explores and a
balancing loss so it doesn't collapse. One line in the transformer block changes. In
return, the model's parameter count and its per-token compute stop being the same
number, and that decoupling is the reason so many of the largest models you can name
are built this way.
