~/satyajit

iLLaDA: how far a masked-diffusion language model scales


2026-06-28 · 11 min · llm · diffusion · language-models · architecture · explainer

Almost every language model you use is autoregressive: it factorizes text left-to-right, $p(x) = \prod_i p(x_i \mid x_{<i})$, with a causal attention mask, and generates one token per forward pass. It works so well that the alternatives barely get airtime.

iLLaDA is one of the alternatives, scaled up until it's hard to ignore. It's an 8B masked diffusion language model with fully bidirectional attention, trained from scratch on 12 trillion tokens by a team from Renmin University and ByteDance Seed — the direct successor to LLaDA. No causal mask, no left-to-right factorization. The question it's built to answer: can a bidirectional diffusion model, trained from scratch, actually keep up with a strong autoregressive model? The honest answer turns out to be yes for base models, not yet for instruct — and the path to that answer is worth understanding.

How a masked diffusion language model works

Forget Gaussian noise. The "diffusion" here is masking — a discrete, absorbing-state process over tokens.

The forward process: corrupt by masking

Pick a masking ratio $t \sim \mathcal{U}[0,1]$. Each token is independently replaced by a special [MASK] token with probability $t$. At $t=0$ the sequence is clean; at $t=1$ it's fully masked. As $t$ varies, both the amount of corruption and the loss weighting change:

[Interactive figure: forward masking process, with the loss computed on masked positions and weighted by 1/t. Shown state: masking ratio t = 0.40, 5 of 13 tokens masked, 5 scored positions, loss weight 1/t = 2.50×.]

The “noise” here is masking, not Gaussian — an absorbing-state diffusion over discrete tokens. The model sees the corrupted sequence and predicts the originals, but the loss only counts the 5 masked positions, scaled by 1/t = 2.50 so heavily- and lightly-masked samples both pull their weight. Averaged over all t, this is an upper bound on the negative log-likelihood — the same objective iLLaDA keeps through pre-training and fine-tuning.
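To make the forward process concrete, here is a minimal sketch in plain Python. The function name `forward_mask`, the example sentence, and the fixed seed are my own illustrative choices, not anything from the paper; a real pipeline would operate on token ids in batches.

```python
import random

MASK = "[MASK]"

def forward_mask(tokens, t, rng):
    """Independently replace each token with MASK with probability t."""
    return [MASK if rng.random() < t else tok for tok in tokens]

rng = random.Random(0)  # fixed seed so the demo is repeatable
tokens = "a masked diffusion model predicts every masked token in one pass".split()

for t in (0.0, 0.4, 1.0):
    corrupted = forward_mask(tokens, t, rng)
    n_masked = corrupted.count(MASK)
    print(f"t={t:.1f}  masked {n_masked}/{len(tokens)}  ->  {' '.join(corrupted)}")
```

At t = 0.0 nothing is masked and at t = 1.0 everything is, because `rng.random()` returns values in [0, 1). In between, the number of masked tokens is random with mean t times the sequence length.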

The model $p_\theta$ sees the corrupted sequence $x_t$ and is trained to predict the original tokens at every masked position at once. The objective is a masked cross-entropy, computed only on masked positions and reweighted by $1/t$:

$$\mathcal{L}(\theta) \;=\; -\,\mathbb{E}_{t,\,x_0,\,x_t}\!\left[\frac{1}{t}\sum_{i=1}^{L} \mathbf{1}\!\left[x_t^{i} = \mathrm{M}\right]\,\log p_\theta\!\left(x_0^{i} \mid x_t\right)\right]$$

The indicator $\mathbf{1}[x_t^i = \mathrm{M}]$ restricts the loss to masked positions; the $1/t$ factor re-normalizes so heavily- and lightly-masked samples both contribute correctly. Averaged over $t$, this is a Monte Carlo upper bound on the negative log-likelihood, which makes it a principled training objective, not a heuristic. iLLaDA keeps this same objective through pre-training and supervised fine-tuning.
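As a rough illustration of the objective, the sketch below evaluates the bracketed term for a single sample. The function `masked_diffusion_loss` and the hand-written probability tables are hypothetical stand-ins; a real model would produce a softmax over the full vocabulary at every position.

```python
import math

MASK = "[MASK]"

def masked_diffusion_loss(x0, xt, probs, t):
    """One-sample estimate of the 1/t-weighted masked cross-entropy.

    x0:    clean tokens
    xt:    corrupted tokens (some replaced by MASK)
    probs: probs[i][tok] = model probability of tok at position i, given xt
    t:     masking ratio used to produce xt (must be > 0)
    """
    total = 0.0
    for i, (clean, noisy) in enumerate(zip(x0, xt)):
        if noisy == MASK:  # the indicator: only masked positions are scored
            total += math.log(probs[i][clean])
    return -total / t      # the 1/t reweighting

x0 = ["the", "cat", "sat"]
xt = ["the", MASK, MASK]
probs = [                  # made-up model outputs for this toy example
    {"the": 1.0},
    {"cat": 0.5, "dog": 0.5},
    {"sat": 0.25, "ran": 0.75},
]
loss = masked_diffusion_loss(x0, xt, probs, t=0.5)
print(f"{loss:.4f}")       # -(ln 0.5 + ln 0.25) / 0.5 = 6 ln 2
```

The unmasked first position contributes nothing, however good or bad the model's prediction there is. Training averages this quantity over many draws of t and many masks.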

Bidirectional attention comes for free

An autoregressive model must hide the future: if token $i$ could attend to token $i{+}1$, it would just read the answer it's supposed to predict. A masked diffusion model predicts masked positions anywhere in the sequence, not "the next" one, so there's nothing to hide. Every position attends to every other, left and right:

[Interactive figure: attention mask over the 6-token sequence "a masked LM reads both sides". Row = query, column = key it attends to.]

Bidirectional: every query attends to all 6 positions, future included. A masked diffusion LM predicts masked positions, not “the next” one, so there’s nothing to hide. It gets full left-and-right context at every layer, which is exactly what helps on infilling and global-structure tasks.

That full context is the structural argument for diffusion LMs: on infilling and tasks where later text disambiguates earlier text, seeing both sides at every layer should help.
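The difference between the two attention patterns is easy to see as boolean matrices. This is a small sketch with helper names of my own (`causal_mask`, `bidirectional_mask`, `render`); real implementations pass an equivalent mask, or none at all, to the attention kernel.

```python
def causal_mask(n):
    """mask[q][k] is True when query q may attend to key k (only k <= q)."""
    return [[k <= q for k in range(n)] for q in range(n)]

def bidirectional_mask(n):
    """Every query may attend to every key."""
    return [[True] * n for _ in range(n)]

def render(mask):
    """Draw allowed pairs as 'x' and blocked pairs as '.'."""
    return "\n".join(" ".join("x" if ok else "." for ok in row) for row in mask)

n = 6
print("causal:")
print(render(causal_mask(n)))
print("bidirectional:")
print(render(bidirectional_mask(n)))
```

For 6 positions the causal mask allows 21 of the 36 query-key pairs (the lower triangle); the bidirectional one allows all 36.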

Generation: unmask in parallel, over a few steps

Generation runs the process backward. Start from a block of all-[MASK] tokens. At each denoising step, the model predicts every masked position, commits its most confident predictions, and re-masks the low-confidence ones to try again next step. A whole block resolves over a handful of steps, in confidence order rather than reading order:

[Interactive figure: generation order. The 12-token sequence "a diffusion LM unmasks many tokens at once not one by one" is decoded over 4 denoising steps. Diffusion mode: 4 forward passes, many tokens per pass, bidirectional attention.]

Diffusion decodes the whole block at once and commits its most confident predictions each step, so a 12-token answer lands in ~4 passes — and because there's no causal mask, every position conditions on the entire sequence, left and right. The cost is that quality depends on how many denoising steps you spend.

This is the crux. Autoregression spends one forward pass per output token, in series. Diffusion spends a fixed, smaller number of denoising passes over the whole block — the promise being fewer sequential steps, at the cost of needing enough steps for quality.

training: clean x₀ → mask (t) → corrupted xₜ → predict → x̂₀ on masked
generation: all [MASK] → unmask conf. → partly filled → repeat → complete (re-mask low-confidence positions)
The two directions of the same model. Training: mask a fraction t of the tokens and predict the originals (loss on masked positions only). Generation: start fully masked and iteratively unmask the confident predictions, re-masking the rest, until the block resolves.
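The generation loop can be sketched in a few lines. Everything model-specific here is an assumption for illustration: `toy_predict` is a hypothetical stand-in that already knows the target sentence and emits arbitrary confidences, and the "commit an equal share of the remaining positions each step" schedule is my own simplification, not necessarily the schedule iLLaDA uses.

```python
import math

MASK = "[MASK]"
TARGET = "a diffusion LM unmasks many tokens at once not one by one".split()

def toy_predict(seq):
    """Hypothetical model: returns {position: (token, confidence)} for each
    masked position. A real model would return a softmax over the vocabulary."""
    preds = {}
    for i, tok in enumerate(seq):
        if tok == MASK:
            confidence = ((i * 7) % len(seq) + 1) / len(seq)  # arbitrary, deterministic
            preds[i] = (TARGET[i], confidence)
    return preds

def generate(length, steps):
    seq = [MASK] * length
    for step in range(steps):
        preds = toy_predict(seq)          # one forward pass over the whole block
        if not preds:
            break
        # Assumed schedule: commit an equal share of what is still masked.
        k = math.ceil(len(preds) / (steps - step))
        most_confident = sorted(preds, key=lambda i: preds[i][1], reverse=True)[:k]
        for i in most_confident:
            seq[i] = preds[i][0]
        # Everything not committed stays [MASK] and is retried next step.
        print(f"step {step + 1}/{steps}: {' '.join(seq)}")
    return seq

result = generate(length=len(TARGET), steps=4)
```

With 12 positions and 4 steps, three tokens are committed per pass, and they land in confidence order rather than left to right. An autoregressive decoder would need 12 sequential passes for the same output.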

What iLLaDA changes over LLaDA

iLLaDA is, more than anything, a careful scale-up of LLaDA — proof that the recipe keeps paying off with more tokens and a better post-training pass.

A bigger, leaner backbone

The architecture is a standard dense Transformer (RMSNorm, SwiGLU, RoPE, no biases), but re-tuned for cheaper inference:

| | iLLaDA | LLaDA |
|---|---|---|
| Attention heads | 32 | 32 |
| Key/Value heads | 8 (GQA) | 32 (MHA) |
| FFN dim | 14,336 | 12,288 |
| Vocabulary | 155,136 | 126,464 |
| Max sequence length | 8192 | 4096 |
| Embedding / LM head | tied | untied |
| Total parameters | 7.62B | 8.02B |

The load-bearing change is grouped-query attention (8 KV heads instead of 32), adopted to shrink the cached key/value footprint at inference — plus a larger vocab, doubled context, and tied embeddings.
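A back-of-the-envelope calculation shows why the KV-head count matters. The head counts and sequence length come from the table above; the layer count, head dimension, and 16-bit storage are assumptions I picked for illustration and are not stated in this post.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    return 2 * seq_len * n_layers * n_kv_heads * head_dim * bytes_per_value  # 2 = K and V

# ASSUMED for illustration, not taken from the source:
ASSUMED_LAYERS = 32
ASSUMED_HEAD_DIM = 128

# From the table: 8 KV heads (GQA) vs 32 (MHA), at the 8192-token maximum length.
gqa = kv_cache_bytes(8192, ASSUMED_LAYERS, n_kv_heads=8, head_dim=ASSUMED_HEAD_DIM)
mha = kv_cache_bytes(8192, ASSUMED_LAYERS, n_kv_heads=32, head_dim=ASSUMED_HEAD_DIM)

print(f"GQA: {gqa / 2**30:.2f} GiB, MHA: {mha / 2**30:.2f} GiB, ratio: {mha / gqa:.0f}x")
```

The absolute sizes depend on the assumed layer count and head dimension, but the 4x ratio does not: cache size is linear in the number of KV heads, so going from 32 to 8 cuts it to a quarter at any sequence length.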

The headline spend: 12T tokens

And the fine-tuning clearly hadn't saturated. The SFT-epoch ablation rises monotonically through all 12 epochs (they stopped on compute, not convergence):

[Figure: accuracy (%) vs. SFT epoch (3, 6, 9, 12) for GSM8K, MATH, and MMLU-Pro.]
SFT-epoch ablation (from the paper, Figure 1), redrawn. Accuracy on GSM8K, MATH, and MMLU-Pro keeps climbing through 12 epochs of fine-tuning — the curve had not flattened where they stopped.

Two inference-side tricks

The results, honestly

Two stories live in these tables, and they point in different directions.

Base models: genuine parity with Qwen2.5

As a base model, iLLaDA improves broadly over LLaDA and lands even with Qwen2.5-7B on average — winning several benchmarks outright:

Base models: average over 8 benchmarks (%)

| Model | Average |
|---|---|
| iLLaDA 8B | 63.9 |
| Qwen2.5 7B | 63.3 |
| Dream 7B | 61.4 |
| LLaDA 8B | 51.1 |

The per-benchmark picture, with the gains over LLaDA that the abstract leads on:

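| Base | iLLaDA | LLaDA | Qwen2.5 | Δ vs LLaDA |
|---|---|---|---|---|
| MMLU | 74.8 | 65.9 | 71.9 | +8.9 |
| BBH | 71.3 | 49.7 | 63.9 | +21.6 |
| ARC-C | 60.8 | 45.9 | 51.5 | +14.9 |
| HellaSwag | 76.6 | 70.5 | 79.0 | +6.1 |
| GSM8K | 81.9 | 70.3 | 78.9 | +11.6 |
| MATH | 38.4 | 31.4 | 41.1 | +7.0 |
| HumanEval | 50.0 | 35.4 | 56.7 | +14.6 |
| MBPP | 57.8 | 40.0 | 63.6 | +17.8 |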

iLLaDA-Base beats Qwen2.5-Base on MMLU, BBH, ARC-C, and GSM8K; Qwen still wins on HellaSwag, MATH, and code. But the average edges ahead — and that's the real result: a from-scratch bidirectional diffusion model matching a strong autoregressive base.
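The averages and the win/loss split can be recomputed directly from the per-benchmark rows. This is just arithmetic on the numbers quoted above; the variable names are mine.

```python
base = {
    # benchmark: (iLLaDA, LLaDA, Qwen2.5)
    "MMLU":      (74.8, 65.9, 71.9),
    "BBH":       (71.3, 49.7, 63.9),
    "ARC-C":     (60.8, 45.9, 51.5),
    "HellaSwag": (76.6, 70.5, 79.0),
    "GSM8K":     (81.9, 70.3, 78.9),
    "MATH":      (38.4, 31.4, 41.1),
    "HumanEval": (50.0, 35.4, 56.7),
    "MBPP":      (57.8, 40.0, 63.6),
}
names = ("iLLaDA", "LLaDA", "Qwen2.5")

# Unweighted mean over the 8 benchmarks, per model.
averages = {
    name: sum(row[j] for row in base.values()) / len(base)
    for j, name in enumerate(names)
}
# Benchmarks where iLLaDA-Base scores higher than Qwen2.5-Base.
illada_wins = [b for b, (ours, _, qwen) in base.items() if ours > qwen]

print({name: round(avg, 2) for name, avg in averages.items()})
print("iLLaDA ahead of Qwen2.5 on:", illada_wins)
```

The unrounded means come out to about 63.95 for iLLaDA, 63.33 for Qwen2.5, and 51.14 for LLaDA, so the lead over Qwen2.5 is roughly 0.6 points: a tie for practical purposes, which is the point.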

Instruct models: the gap that's left

This is the part the abstract's "competitive on several benchmarks" softens. After instruction tuning, iLLaDA trails Qwen2.5 by ~10 points on average, with double-digit gaps on the hard reasoning and coding tasks:

Instruct models: average over 7 benchmarks (%)

| Model | Average |
|---|---|
| Qwen2.5 7B | 77.1 |
| iLLaDA 8B | 67.1 |
| Dream 7B | 60.2 |
| LLaDA 8B | 54.5 |

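| Instruct | iLLaDA | LLaDA | Qwen2.5 | Δ vs LLaDA |
|---|---|---|---|---|
| MMLU | 71.6 | 65.5 | 76.6 | +6.1 |
| MMLU-Pro | 52.3 | 37.0 | 56.3 | +15.3 |
| GSM8K | 89.0 | 77.5 | 91.6 | +11.5 |
| MATH | 56.7 | 42.2 | 75.5 | +14.5 |
| HumanEval | 65.9 | 49.4 | 84.8 | +16.5 |
| MBPP | 58.0 | 41.0 | 79.2 | +17.0 |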

The improvement over LLaDA is huge and real (+12.6 average). The gap to Qwen on MATH (56.7 vs 75.5), HumanEval (65.9 vs 84.8), and MBPP (58.0 vs 79.2) is also real, and the authors don't hide it — they point to the lack of RL alignment as part of the cause.

What I make of it

The fair summary: bidirectional masked diffusion is now a scalable paradigm at parity with autoregressive base models — no longer something you can wave off — but not yet a proven win on post-trained quality, and not yet a proven efficiency advantage. That's a meaningful place to have gotten to, stated without the gloss.


Built on Improved Large Language Diffusion Models (Nie et al., Renmin University & ByteDance Seed, 2026) and its predecessor LLaDA. All numbers are from the paper's Tables 1–3; the SFT-epoch curves are redrawn from its Figure 1.
