Principle: Lucidrains x-transformers Iterative Masked Generation
Metadata
| Field | Value |
|---|---|
| Papers | Mask-Predict, MaskGIT |
| Repository | x-transformers |
| Domains | Deep_Learning, Generative_Models, Inference |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Iterative refinement generation strategy that produces sequences by progressively unmasking tokens from a fully masked initial sequence over multiple denoising steps.
Description
Unlike autoregressive generation (one token at a time, left to right), iterative masked generation starts from a fully masked sequence and refines it over T steps. At each step the algorithm:
- Predicts all masked positions by running the full sequence through the bidirectional model.
- Samples tokens using temperature-scaled probabilities, where the temperature is annealed from `start_temperature` down toward 0 over the course of generation.
- Computes confidence scores for each prediction — either `(1 - softmax(logits))` gathered at the sampled token indices, or the token critic's output combined with Gumbel noise.
- Keeps the k most confident predictions and re-masks the rest, where k increases each step according to the masking schedule.
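The per-step loop above can be sketched in numpy. This is a minimal single-sequence illustration, not the x-transformers implementation: `generate`, `toy_logits`, and the fixed `mask_id=-1` are hypothetical stand-ins, and the "model" is just random logits.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def cosine_schedule(t):
    # fraction of positions still masked at progress t in [0, 1]
    return float(np.cos(t * np.pi / 2))

def generate(logits_fn, seq_len, vocab_size, steps=18,
             start_temperature=1.0, mask_id=-1, seed=0):
    rng = np.random.default_rng(seed)
    seq = np.full(seq_len, mask_id)
    for step in range(steps):
        is_masked = seq == mask_id
        logits = logits_fn(seq)                       # (seq_len, vocab_size)
        # anneal the sampling temperature down toward 0 over the run
        temperature = max(start_temperature * (steps - step) / steps, 1e-8)
        probs = softmax(logits / temperature)
        sampled = np.array([rng.choice(vocab_size, p=p) for p in probs])
        confidence = probs[np.arange(seq_len), sampled]
        # fill every masked slot, then decide what to re-mask
        seq = np.where(is_masked, sampled, seq)
        num_masked = int(cosine_schedule((step + 1) / steps) * seq_len)
        if num_masked > 0:
            # previously unmasked positions are never re-masked here
            remask_score = np.where(is_masked, 1 - confidence, -np.inf)
            remask = np.argsort(-remask_score)[:num_masked]
            seq[remask] = mask_id
    return seq

# toy stand-in for a bidirectional model: random logits per position
def toy_logits(seq, _rng=np.random.default_rng(1)):
    return _rng.normal(size=(len(seq), 10))

generated = generate(toy_logits, seq_len=12, vocab_size=10, steps=6)
```

Because the cosine schedule is monotonically decreasing, the number of positions to re-mask never exceeds the number currently masked, so previously revealed tokens stay fixed in this sketch.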
This enables parallel prediction at each step, making it significantly faster than autoregressive generation for fixed-length sequences. An optional token critic provides better confidence estimates than raw logit probabilities, leading to higher quality generations.
The method also supports self-conditioning, where the model's embeddings from the previous step are fed back as an additive conditioning signal, and an option to prevent re-masking of previously unmasked positions (controlled by the `can_mask_prev_unmasked` flag).
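The self-conditioning feedback can be illustrated with a toy sketch. This is an assumption-laden illustration of the wiring only: `trunk` and `W_proj` are hypothetical stand-ins for the transformer body and the conditioning projection, not x-transformers internals.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, dim, steps = 8, 16, 3
W_proj = rng.normal(size=(dim, dim)) * 0.01   # hypothetical projection weights

def trunk(token_embeds, self_cond):
    # stand-in for the bidirectional transformer: the previous step's
    # output embeddings enter as an additive conditioning signal
    return np.tanh(token_embeds + self_cond @ W_proj)

self_cond = np.zeros((seq_len, dim))          # zeros on the first step
for step in range(steps):
    token_embeds = rng.normal(size=(seq_len, dim))
    out_embeds = trunk(token_embeds, self_cond)
    self_cond = out_embeds                    # fed back on the next step
```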
Usage
Use for generating fixed-length discrete sequences where speed is important. Typical settings use `steps=18`, which is sufficient for good quality. Higher `start_temperature` values increase diversity in the generated outputs:
```python
# Generate a batch of sequences
generated = model.generate(batch_size=16, start_temperature=1.0)
print(generated.shape)  # (16, max_seq_len)
```
Theoretical Basis
Iterative refinement: at step s of T total steps, the number of tokens that remain masked is:

`mask_count = schedule(s / T) * seq_len`
The schedule function (linear or cosine) determines how quickly tokens are unmasked. With the linear schedule, tokens are revealed at a constant rate. With the cosine schedule (MaskGIT), more tokens are unmasked in later steps, concentrating early steps on high-level structure.
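The two schedules can be compared directly. A small sketch under assumed toy values (`steps=8`, `seq_len=64`): each list holds the number of tokens still masked after each step.

```python
import numpy as np

def linear_schedule(t):
    # constant unmasking rate
    return 1.0 - t

def cosine_schedule(t):
    # slow reveal early, fast reveal late (MaskGIT-style)
    return float(np.cos(t * np.pi / 2))

steps, seq_len = 8, 64
masked = {}
for name, sched in [("linear", linear_schedule), ("cosine", cosine_schedule)]:
    masked[name] = [round(sched((s + 1) / steps) * seq_len)
                    for s in range(steps)]

print(masked["linear"])  # tokens masked after each step, constant reveal rate
print(masked["cosine"])  # stays high early, drops quickly late
```

The cosine list starts higher and falls faster at the end, i.e. more tokens are unmasked in the later steps, matching the description above.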
Confidence scoring uses one of two strategies:
- Logit-based (default): confidence = `1 - softmax(logits)` gathered at the sampled token positions. A lower softmax probability at the sampled index means lower confidence, so those positions are more likely to be re-masked.
- Token critic: a learned binary classifier scores whether each token is correct. Gumbel noise, scaled by the annealing factor, is added to the critic scores to inject stochasticity: `score = critic(seq) + noise_scale * gumbel_noise * (steps_until_x0 / T)`.
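Both scoring strategies can be sketched in a few lines of numpy. The critic here is a random placeholder, and `noise_scale`, `steps_until_x0`, and `total_steps` are assumed example values, not library defaults.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, vocab_size = 6, 10
logits = rng.normal(size=(seq_len, vocab_size))
sampled = rng.integers(0, vocab_size, size=seq_len)

# Strategy 1: logit-based. Gather 1 - softmax probability at the sampled
# index; a higher score means lower confidence, i.e. re-mask first.
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)
logit_score = 1 - probs[np.arange(seq_len), sampled]

# Strategy 2: token critic. A random stand-in for a learned classifier,
# plus Gumbel noise scaled by the annealing factor steps_until_x0 / T.
critic_out = rng.random(seq_len)                 # placeholder for critic(seq)
gumbel = -np.log(-np.log(rng.random(seq_len)))   # standard Gumbel samples
noise_scale, steps_until_x0, total_steps = 1.0, 5, 18
critic_score = critic_out + noise_scale * gumbel * (steps_until_x0 / total_steps)
```

In either strategy the positions with the highest scores (least confident) are the ones re-masked for the next step.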
Temperature annealing reduces randomness over the course of generation:
`T_s = T_start * (T - s) / T`
This ensures that early steps explore diverse predictions while final steps converge to high-confidence greedy selections.
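The annealing formula produces a simple linear decay. A small sketch under the assumed values `start_temperature=1.0` and `T=18` (matching the typical `steps=18` setting above):

```python
# linear temperature annealing: T_s = T_start * (T - s) / T
start_temperature, total_steps = 1.0, 18
temps = [start_temperature * (total_steps - s) / total_steps
         for s in range(total_steps)]
# the first step samples at the full start temperature; the last step is
# nearly greedy, at start_temperature / total_steps
```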