~/satyajit

Multi-token prediction: training a model to see further than one step


2026-06-27 · 9 min · llm · multi-token-prediction · inference-optimization · speculative-decoding · explainer

Standard language models are trained on a deceptively narrow task: given everything so far, predict the single next token. The objective is one cross-entropy term per position,

$$L_{\text{next}} \;=\; -\sum_t \log P_\theta\big(x_{t+1}\mid x_{\le t}\big).$$

Multi-token prediction (MTP) changes one thing: from each position, predict the next $n$ tokens at once. That small change buys two unrelated-looking wins, a model that trains better and a model that decodes faster, and the second one is why it's now in shipping products from DeepSeek to Google.

A provenance note up front, because the "Google multi-token prediction" framing gets the credit wrong. The modern MTP training objective was defined by Meta (FAIR) in Gloeckle et al., 2024. The predict-several-then-verify decoding idea it reuses traces to Google Brain's 2018 blockwise parallel decoding. DeepSeek-V3 productionized a sequential variant at pretraining scale. And Google's genuine 2026 contribution is applied: MTP-based speculative decoding in Gemma 4 and on-device Gemini Nano. I'll attribute as I go.

The objective: predict n futures

Keep one shared transformer trunk that turns the context into a latent $z_t$. Attach $n$ output heads. The loss sums cross-entropy over the next $n$ positions:

$$L_{\text{MTP}} \;=\; -\sum_t \sum_{i=1}^{n} \log P_\theta\big(x_{t+i}\mid z_t\big).$$

From position $t$, head $i$ predicts $x_{t+i}$. Pick $n$ and a flavor and read the loss off directly:

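[Interactive figure: multi-token prediction with $n$ heads. The context "The cat sat on the" ends at position $t$ and passes through the shared transformer trunk to give $z_t$. Heads 1 to 4 predict $x_{t+1}$ to $x_{t+4}$ (here "mat", "and", "the", "dog"), and the loss reads $L = -\log P(x_{t+1}\mid z_t) - \log P(x_{t+2}\mid z_t) - \log P(x_{t+3}\mid z_t) - \log P(x_{t+4}\mid z_t)$. Controls: $n$ and flavor. Note: n=4 is the sweet spot for ~7B on code.]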

Meta's heads are independent off the same trunk — they don't see each other's predictions. Cheaper, and at inference you can discard the extra heads for a zero-overhead next-token model, or keep them to self-speculate.
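To make the objective concrete, here is a minimal sketch of the MTP loss in plain Python. It is not the paper's implementation: `mtp_loss` and `head_probs` are hypothetical names, each head's output is represented as a ready-made probability table rather than a real model, and targets that would fall past the end of the sequence are simply skipped (an assumption on my part).

```python
import math

def mtp_loss(head_probs, tokens, n):
    """Sum of -log P(x_{t+i} | z_t) over positions t and heads i = 1..n.

    head_probs[t][i - 1] maps token -> probability: head i's predicted
    distribution at position t (a stand-in for a real model's output).
    Targets past the end of the sequence are skipped.
    """
    total = 0.0
    for t in range(len(tokens)):
        for i in range(1, n + 1):
            if t + i >= len(tokens):
                break
            target = tokens[t + i]
            total -= math.log(head_probs[t][i - 1][target])
    return total

tokens = ["the", "cat", "sat", "on", "the", "mat"]
vocab = sorted(set(tokens))

# Toy "model": every head predicts a uniform distribution over the vocabulary.
uniform = {tok: 1.0 / len(vocab) for tok in vocab}
head_probs = [[uniform] * 4 for _ in tokens]

loss_n1 = mtp_loss(head_probs, tokens, n=1)  # ordinary next-token loss
loss_n4 = mtp_loss(head_probs, tokens, n=4)  # up to four loss terms per position
```

With $n=1$ this reduces to the ordinary next-token loss; with $n=4$ the same six-token sequence contributes 14 loss terms instead of 5.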

The obvious worry is memory: materializing $n$ full vocabulary-sized logit tensors at once is brutal. The fix is mundane and important: compute each head's forward/backward sequentially and accumulate gradients at the trunk, so peak memory stays flat in $n$. You pay a little time, not a lot of VRAM.
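The sequential trick can be sketched without any framework. The toy below assumes each head is a plain linear unembedding `W` applied to the trunk output `z` (a simplification; real heads are larger), and all function names are hypothetical. The point is only that one head's logit vector exists at a time while the gradient with respect to `z` is summed across heads.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def head_grad_wrt_z(W, z, target):
    """Gradient of -log softmax(W z)[target] with respect to z.

    Only this one head's vocabulary-sized logit vector is alive here.
    """
    logits = [sum(w * zj for w, zj in zip(row, z)) for row in W]
    probs = softmax(logits)
    # d loss / d logits = probs - onehot(target)
    dlogits = [p - (1.0 if v == target else 0.0) for v, p in enumerate(probs)]
    # d loss / d z = W^T dlogits
    return [sum(W[v][j] * dlogits[v] for v in range(len(W))) for j in range(len(z))]

def accumulate_trunk_grad(heads, z, targets):
    """Process heads one at a time, summing their gradients at the trunk output."""
    grad_z = [0.0] * len(z)
    for W, target in zip(heads, targets):
        g = head_grad_wrt_z(W, z, target)  # this head's logits are freed after the call
        grad_z = [a + b for a, b in zip(grad_z, g)]
    return grad_z

# Toy setup: trunk output of width 2, vocabulary of 3, two heads.
z = [0.5, -1.0]
heads = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[0.5, -0.5], [2.0, 0.0], [0.0, 0.3]],
]
targets = [2, 0]
grad_z = accumulate_trunk_grad(heads, z, targets)
```

In a real training loop the accumulated `grad_z` would then be backpropagated through the trunk once, so the trunk's backward pass is not repeated per head.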

Two flavors: parallel heads vs sequential modules

The flavor toggle above is the real architectural fork.

[Figure: Meta's MTP architecture (Gloeckle et al., 2024, Figure 1). A shared trunk feeds $n$ parallel output heads sharing one unembedding, each predicting one of the next $n$ tokens. At training, all heads predict; at inference, the next-token head generates and the extra heads draft for self-speculative decoding.]
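DeepSeek-V3 instead chains sequential modules. At depth $k$, a projection $M_k$ combines the previous depth's hidden state with the embedding of the token $k$ positions ahead:

$$h'^{\,k}_i \;=\; M_k\big[\operatorname{RMSNorm}(h^{\,k-1}_i)\,;\ \operatorname{RMSNorm}(\operatorname{Emb}(t_{i+k}))\big]$$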
[Figure: DeepSeek-V3's MTP (Figure 3). The main model plus sequential MTP modules each predict one further token, with a shared embedding and output head. Hidden states are passed forward to keep the complete causal chain: each module conditions on the previous depth's prediction, unlike Meta's independent heads.]

The trade is exactly what you'd guess: parallel heads are cheaper and discardable; sequential modules draft more coherent blocks because each step sees the last.
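The combining step of a sequential module is small enough to write out. The sketch below is a toy in plain Python, with two loud simplifications: `rms_norm` has no learned gain, and the transformer block and shared output head that follow the projection are omitted. The names `rms_norm` and `mtp_combine` are mine, not DeepSeek's.

```python
import math

def rms_norm(x, eps=1e-6):
    """RMSNorm without a learned gain (a simplification for illustration)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]

def mtp_combine(M_k, h_prev, emb_next):
    """h' = M_k [RMSNorm(h_prev) ; RMSNorm(emb_next)], with M_k of shape d x 2d."""
    concat = rms_norm(h_prev) + rms_norm(emb_next)  # list concatenation, length 2d
    return [sum(m * c for m, c in zip(row, concat)) for row in M_k]

# Toy example with d = 2: this M_k just adds the two normalized halves.
M_k = [
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0, 1.0],
]
h_prev = [3.0, 4.0]    # hidden state from the previous depth
emb_next = [1.0, 1.0]  # embedding of the token k positions ahead
h_combined = mtp_combine(M_k, h_prev, emb_next)
```

Because `emb_next` carries the token chosen at the previous step, each depth sees what came before it, which is the "causal chain" the parallel-head design gives up.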

Win one: it trains a better model

The training objective is the whole reason for the quality gain. Slide the window: from each position the model predicts the next $n$ tokens, so $n$ loss terms fire where an ordinary model gets one. Flip $n$ between 1 and 4 to feel the supervision get denser:

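[Interactive figure: multi-token training objective. On the sentence "the cat sat on a warm mat by the fire.", with position $t$ at "the", offsets +1 to +4 target "cat", "sat", "on", "a", so $L(t) = -\log P(\text{cat}\mid z_t) - \log P(\text{sat}\mid z_t) - \log P(\text{on}\mid z_t) - \log P(\text{a}\mid z_t)$: 4 signals at this position. Control: $n$.]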

At n=4 each position supplies 4 loss terms, forcing the trunk to encode information about tokens further ahead — a denser signal that, at scale, yields a better model even after the extra heads are thrown away.

The effect shows up on benchmarks. Meta's 13B model, trained with MTP, solves about 17% more MBPP problems and 12% more HumanEval problems than the matched next-token model on the same data and compute:

13B MTP vs matched next-token model, relative gain (%)

| Benchmark | More problems solved |
| --- | --- |
| MBPP | 17% |
| HumanEval | 12% |
[Figure: Scaling (Gloeckle et al., Figure 3). MBPP and HumanEval performance across six model sizes from 300M to 13B. The MTP advantage on code grows with model size: small models barely benefit, but by multi-billion scale MTP clearly overtakes next-token training.]

Two caveats keep this honest. The gains scale with model size: small models barely benefit, and very large $n$ erodes quality ($n=4$ is the sweet spot for ~7B on code). And whether the quality gain transfers beyond pretraining is genuinely contested: "Multi-Token Prediction Needs Registers" (MuToR, NeurIPS 2025) exists precisely because the benefit hasn't consistently generalized to fine-tuning without help. MTP is not a free quality lunch in every regime.

Win two: it decodes ~3x faster

The extra heads are a built-in draft model. They cheaply propose the next $n-1$ tokens in one pass; the model then verifies all of them in a single batched forward and accepts the longest correct prefix. This is self-speculative decoding, and it is lossless: rejected drafts fall back to the true next-token distribution, so output is unchanged.
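A minimal sketch of the accept-the-longest-prefix rule, for the greedy case only. Everything here is illustrative: `target_next` is a hypothetical callable standing in for the target model's greedy next-token choice, `toy_target` is an arbitrary deterministic rule, and a real system would run the checks in one batched forward rather than a Python loop. Sampling-based decoding needs a probabilistic acceptance rule that this sketch does not cover.

```python
def speculative_step(target_next, context, draft):
    """One greedy draft-and-verify step.

    Returns the tokens to append: the longest draft prefix the target agrees
    with, followed by the target's own token at the first disagreement (or
    one extra target token if the whole draft was accepted).
    """
    accepted = []
    ctx = list(context)
    for tok in draft:
        expected = target_next(ctx)    # batched into a single forward in practice
        if tok != expected:
            accepted.append(expected)  # fall back to the target's token, drop the rest
            return accepted
        accepted.append(tok)
        ctx.append(tok)
    accepted.append(target_next(ctx))  # bonus token after a fully accepted draft
    return accepted

def toy_target(ctx):
    """Arbitrary deterministic 'model' over token ids 0..4, for illustration."""
    return (sum(ctx) + len(ctx)) % 5

def greedy_decode(target_next, context, steps):
    """Plain one-token-at-a-time greedy decoding, used as the reference."""
    ctx = list(context)
    for _ in range(steps):
        ctx.append(target_next(ctx))
    return ctx

context = [1, 2]
partly_right = speculative_step(toy_target, context, [0, 1, 4])  # third draft token is wrong
all_right = speculative_step(toy_target, context, [0, 1, 3])     # whole draft matches
```

Whatever the draft contains, the appended tokens are exactly what plain greedy decoding would have produced; a bad draft only costs speed, never correctness.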

That verify-and-accept scheme is the part with Google's fingerprints — Stern, Shazeer & Uszkoreit's 2018 blockwise parallel decoding at Google Brain is the ancestor: predict several future positions with auxiliary heads, then accept the longest correct prefix. MTP simply folds those auxiliary heads into the training objective.

How much you gain depends entirely on the acceptance rate: how often a drafted token matches what the target would have produced. Because acceptance compounds along the block, the marginal value of the $i$-th drafted token decays, and the whole scheme hits a ceiling no matter how far you draft. That's why a more coherent drafter is worth more than a longer one, and why DeepSeek's sequential modules (higher acceptance) and DSpark's semi-autoregressive head exist at all. Drag the acceptance rate and block size:

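[Interactive figure: self-speculative decoding, expected tokens per target forward against block size (n=1 to n=8), with the ceiling marked. Reading shown at per-position acceptance α = 0.70: 2.77 tokens per forward, marginal gain (n→n+1) +0.17, ceiling (n→∞) 3.3×. Controls: acceptance rate α and block size n.]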

The gain from the $i$-th drafted token decays like $\alpha^i$, so the curve saturates at $1 + \alpha/(1-\alpha)$ no matter how long you draft. High acceptance pushes the ceiling up; low acceptance means even n=8 barely beats n=2. That decay is exactly why n≈4 is the practical sweet spot, and why a more coherent drafter (higher $\alpha$) buys more than a longer one.
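The saturation is easy to check numerically. This sketch assumes the simplest model consistent with the text: the first token of a block always counts, and each drafted token is accepted with the same independent probability `alpha`, provided every earlier draft was accepted. The function names are mine.

```python
def expected_tokens(alpha, n):
    """Expected tokens per target forward for block size n:
    1 + alpha + alpha**2 + ... + alpha**(n - 1)."""
    return sum(alpha ** i for i in range(n))

def ceiling(alpha):
    """Limit as n grows without bound: 1 / (1 - alpha) = 1 + alpha / (1 - alpha)."""
    return 1.0 / (1.0 - alpha)

alpha = 0.70
per_forward = expected_tokens(alpha, 5)
marginal = expected_tokens(alpha, 6) - expected_tokens(alpha, 5)
limit = ceiling(alpha)
```

With `alpha = 0.70` and `n = 5` this gives about 2.77 tokens per forward, a marginal gain of about 0.17 for one more drafted token, and a ceiling of about 3.33.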

The speedups, across the lineage:

Self-speculative decoding speedup (×, lossless unless noted)

| Setup | Speedup |
| --- | --- |
| Meta MTP, n=4 (code) | 3× |
| Meta, 8-byte model | 6.4× |
| DeepSeek-V3 (D=1) | 1.8× |
| Google Brain 2018 (lossless) | 4× |

DeepSeek-V3 is the cleanest production data point: with MTP depth $D=1$ (predict two tokens total), the second token is accepted 85–90% of the time, and repurposing the module for speculative decoding gives ~1.8× tokens/sec. The training loss weight was annealed, $\lambda = 0.3$ for the first 10T tokens and then $0.1$ for the remaining 4.8T, a detail worth noting because MTP at pretraining is a secondary objective, not the main one.
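As a sketch of what "secondary objective" means in practice, the schedule can be written as a step function of tokens seen. The way the weight enters the total loss below (main loss plus $\lambda$ times the mean of the per-depth MTP losses) is my reading of the setup and should be treated as an assumption; the function names are hypothetical.

```python
def mtp_weight(tokens_seen):
    """Loss weight lambda: 0.3 for the first 10T training tokens, then 0.1."""
    return 0.3 if tokens_seen < 10e12 else 0.1

def total_loss(main_loss, mtp_losses, tokens_seen):
    """Main next-token loss plus the lambda-weighted mean of the per-depth MTP losses.

    Averaging over depths is an assumption of this sketch.
    """
    lam = mtp_weight(tokens_seen)
    return main_loss + lam * sum(mtp_losses) / len(mtp_losses)

early = total_loss(2.0, [3.0], tokens_seen=5e12)   # lambda = 0.3
late = total_loss(2.0, [3.0], tokens_seen=12e12)   # lambda = 0.1
```

On the decoding side, one drafted token accepted 85–90% of the time means roughly 1.85 to 1.90 tokens per forward on average, which is in line with the ~1.8× throughput figure once overhead is counted.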

Google's actual 2026 role: applied MTP

Where Google genuinely shows up is deployment, not the objective:

The on-device framing is the interesting one: a separate draft model is a non-starter when you're counting megabytes on a phone, so folding the drafter into the main model as a frozen head is exactly the right move.

The open questions

MTP isn't a closed book — two recent threads are worth knowing because they bound where the simple story breaks.

Both reinforce the same lesson the speedup curve shows: the value is in acceptance and coherence, and the active research is about getting more of both without paying a full retrain.

Who did what

| Work | Org | Contribution |
| --- | --- | --- |
| Blockwise parallel decoding (2018) | Google Brain | predict-several-then-verify/accept; the decoding ancestor |
| Better & Faster LLMs via MTP (2024) | Meta / FAIR | the canonical MTP training objective (n parallel heads) |
| Medusa (2024) | academic | multiple decoding heads + tree attention (not Google) |
| DeepSeek-V3 MTP (2024) | DeepSeek-AI | sequential MTP modules at pretraining; ~1.8× TPS |
| MuToR, "MTP needs registers" (2025) | academic | register tokens so MTP helps in fine-tuning |
| Gemma 4 / Gemini Nano MTP (2026) | Google | applied MTP speculative decoding, incl. on-device |

What I make of it


Built on Meta's Better & Faster Large Language Models via Multi-token Prediction, the DeepSeek-V3 Technical Report (§2.2), Google Brain's Blockwise Parallel Decoding, MuToR, and Google's 2026 Gemini Nano frozen-MTP work.
