Principle:Lucidrains X transformers Iterative Masked Generation

From Leeroopedia


Metadata

Field | Value
Paper | Mask-Predict
Paper | MaskGIT
Repository | x-transformers
Domains | Deep_Learning, Generative_Models, Inference
Last Updated | 2026-02-08 18:00 GMT

Overview

An iterative-refinement generation strategy that starts from a fully masked sequence and progressively unmasks tokens over multiple denoising steps.

Description

Unlike autoregressive generation (one token at a time, left to right), iterative masked generation starts from a fully masked sequence and refines it over T steps. At each step the algorithm:

  1. Predicts every position in parallel by running the full, partially masked sequence through the bidirectional model.
  2. Samples a token for each masked position from temperature-scaled probabilities, where the temperature is annealed from start_temperature down toward 0 over the course of generation.
  3. Scores each prediction for re-masking: either 1 - softmax(logits) gathered at the sampled token indices, or the token critic's output plus annealed Gumbel noise. A higher score marks a less trustworthy token.
  4. Re-masks the mask_count highest-scoring positions and keeps the rest, where mask_count shrinks each step according to the masking schedule, so more tokens stay fixed as generation proceeds.
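The four steps above can be sketched end to end in plain Python. This is a minimal, non-authoritative sketch, not the x-transformers code: `toy_model` is a hypothetical stand-in for the bidirectional transformer, `MASK = -1` is a hypothetical mask id, the linear schedule is used, the temperature is clamped at an assumed minimum of 1e-3, there is no token critic, and already-revealed positions are never re-masked.

```python
import math
import random

MASK = -1  # hypothetical mask id used only in this sketch


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical(probs, rng):
    # Inverse-CDF sampling from a categorical distribution
    r = rng.random()
    acc = 0.0
    for idx, p in enumerate(probs):
        acc += p
        if r < acc:
            return idx
    return len(probs) - 1


def toy_model(seq, vocab_size=4):
    # Hypothetical stand-in for a bidirectional transformer: it ignores the
    # input and strongly prefers token (position % vocab_size) at each position.
    return [[5.0 if v == i % vocab_size else 0.0 for v in range(vocab_size)]
            for i in range(len(seq))]


def iterative_masked_generate(model, seq_len, steps, start_temperature=1.0, seed=0):
    rng = random.Random(seed)
    seq = [MASK] * seq_len
    masked = [True] * seq_len
    for s in range(1, steps + 1):
        # 1. Predict logits for every position in one parallel forward pass
        logits = model(seq)
        # 2. Annealed temperature, clamped so the last step stays finite
        temperature = max(start_temperature * (steps - s) / steps, 1e-3)
        scores = [float("-inf")] * seq_len  # revealed positions are never re-masked
        for i in range(seq_len):
            if not masked[i]:
                continue
            seq[i] = sample_categorical(softmax([x / temperature for x in logits[i]]), rng)
            # 3. Re-masking score: 1 - p(sampled token) under the untempered softmax
            scores[i] = 1.0 - softmax(logits[i])[seq[i]]
        # 4. Linear schedule (integer form): re-mask the mask_count highest scores
        mask_count = (steps - s) * seq_len // steps
        worst = sorted(range(seq_len), key=lambda j: scores[j], reverse=True)[:mask_count]
        masked = [i in worst for i in range(seq_len)]
        seq = [MASK if masked[i] else tok for i, tok in enumerate(seq)]
    return seq


print(iterative_masked_generate(toy_model, seq_len=8, steps=4, start_temperature=0.0))
# -> [0, 1, 2, 3, 0, 1, 2, 3]
```

With start_temperature=0.0 the clamped temperature makes every sample greedy, so the toy model's preferred tokens are recovered exactly; with a higher start_temperature early steps may sample other tokens, which then tend to receive high scores and be re-masked.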

Because every position is predicted in parallel, a sequence of length N needs only T forward passes rather than N, which can make this approach much faster than autoregressive generation for fixed-length sequences when T is much smaller than N. An optional token critic is intended to give better confidence estimates than raw logit probabilities, which can lead to higher-quality generations.

The method also supports self-conditioning, where the model's embeddings from the previous step are fed back as an additive conditioning signal, and the option to prevent re-masking of previously unmasked positions (controlled by the can_mask_prev_unmasked flag).
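The effect of the can_mask_prev_unmasked flag on the re-masking step can be illustrated with a small helper. This is a hedged sketch: `select_positions_to_mask` is a hypothetical name, scores follow the convention that higher means less trustworthy, and the flag defaults to False here. Excluded positions are given a score of negative infinity so a top-k selection never picks them.

```python
def select_positions_to_mask(scores, masked, mask_count, can_mask_prev_unmasked=False):
    """Return a boolean list marking the mask_count highest-scoring positions.

    scores: per-position re-masking scores (higher = less trustworthy token)
    masked: True where the position was still masked before this step
    """
    if not can_mask_prev_unmasked:
        # Already-revealed positions get -inf so the top-k never picks them
        scores = [s if m else float("-inf") for s, m in zip(scores, masked)]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    chosen = set(ranked[:mask_count])
    return [i in chosen for i in range(len(scores))]


scores = [0.9, 0.1, 0.8, 0.2]
masked = [False, True, True, True]  # position 0 was revealed at an earlier step
print(select_positions_to_mask(scores, masked, mask_count=2))
# -> [False, False, True, True]
print(select_positions_to_mask(scores, masked, mask_count=2, can_mask_prev_unmasked=True))
# -> [True, False, True, False]
```

With the flag off, position 0 stays fixed even though it has the worst score; with it on, the model may revise a token it committed to earlier.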

Usage

Use it to generate fixed-length discrete sequences when generation speed matters. A typical setting is steps=18, which is usually sufficient for good quality. Higher start_temperature values increase the diversity of the generated outputs:

# Generate a batch of sequences (`model` is assumed to be a trained x-transformers NonAutoregressiveWrapper)
generated = model.generate(batch_size=16, start_temperature=1.0)
print(generated.shape)  # (16, max_seq_len)

Theoretical Basis

Iterative refinement: after step s of T total steps (s = 1 to T), the number of tokens that remain masked is:

mask_count = schedule(s / T) * seq_len

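The schedule function maps progress t = s / T to the fraction of positions that stay masked, falling from 1 at t = 0 to 0 at t = 1, and so determines how quickly tokens are unmasked. With the linear schedule (1 - t), tokens are revealed at a constant rate. With the cosine schedule from MaskGIT (cos(π t / 2)), the masked fraction falls slowly at first and faster later, so few tokens are committed in early steps and more are unmasked in later ones, concentrating the early steps on high-level structure.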
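The mask-count formula and the two schedules can be compared numerically. This sketch assumes the schedules are 1 - t (linear) and cos(π t / 2) (cosine) and that the product schedule(s / T) * seq_len is truncated to an integer; the helper names are hypothetical.

```python
import math


def linear_schedule(t):
    # Fraction of positions still masked at progress t in [0, 1]
    return 1.0 - t


def cosine_schedule(t):
    # MaskGIT-style cosine schedule: masked fraction falls slowly, then faster
    return math.cos(t * math.pi / 2)


def mask_counts(schedule, seq_len, steps):
    # mask_count after each step s = 1..T, truncated to an integer
    return [int(schedule(s / steps) * seq_len) for s in range(1, steps + 1)]


def revealed_per_step(counts, seq_len):
    # Number of tokens newly fixed at each step
    previous = seq_len
    revealed = []
    for count in counts:
        revealed.append(previous - count)
        previous = count
    return revealed


for name, fn in [("linear", linear_schedule), ("cosine", cosine_schedule)]:
    counts = mask_counts(fn, seq_len=16, steps=4)
    print(name, counts, revealed_per_step(counts, 16))
# linear [12, 8, 4, 0] [4, 4, 4, 4]
# cosine [14, 11, 6, 0] [2, 3, 5, 6]
```

For 16 tokens and 4 steps, the linear schedule fixes 4 tokens per step, while the cosine schedule fixes only 2 in the first step and 6 in the last.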

Confidence scoring uses one of two strategies:

  • Logit-based (default): score = 1 - softmax(logits), gathered at the sampled token indices. A low softmax probability for the sampled token gives a high score, so that position is more likely to be re-masked.
  • Token critic: a learned binary classifier estimates whether each token is wrong, so the highest-scoring tokens are re-masked. Gumbel noise, scaled by the annealing factor, is added to the critic scores to inject stochasticity: score = critic(seq) + noise_scale * gumbel_noise * (steps_until_x0 / T), where steps_until_x0 is the number of steps remaining after the current one.
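Both scoring strategies above can be sketched in a few lines. This is an illustrative sketch under stated assumptions: `critic_out` stands in for hypothetical critic outputs (higher meaning "more likely wrong"), `noise_scale` is a hypothetical parameter name, and the Gumbel sample uses the standard -log(-log(u)) construction with u clamped away from 0 and 1.

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def logit_scores(logits, sampled):
    # Re-masking score = 1 - p(sampled token); a high score means likely re-masked
    return [1.0 - softmax(row)[tok] for row, tok in zip(logits, sampled)]


def gumbel_noise(rng):
    # Standard Gumbel sample: -log(-log(u)), u ~ Uniform(0, 1), clamped for safety
    u = min(max(rng.random(), 1e-10), 1.0 - 1e-10)
    return -math.log(-math.log(u))


def critic_scores(critic_out, steps_until_x0, total_steps, noise_scale=1.0, seed=0):
    # Noise shrinks linearly as generation approaches the final step
    rng = random.Random(seed)
    anneal = steps_until_x0 / total_steps
    return [c + noise_scale * gumbel_noise(rng) * anneal for c in critic_out]


logits = [[2.0, 0.0], [2.0, 0.0]]
print(logit_scores(logits, sampled=[0, 1]))  # sampling the less likely token scores higher
print(critic_scores([0.3, -1.2], steps_until_x0=0, total_steps=18))
# -> [0.3, -1.2]
```

At the last step the annealing factor is 0, so the critic scores are used without noise; earlier steps perturb them to avoid always re-masking the same positions.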

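Temperature annealing reduces randomness over the course of generation. Writing temperature_s for the sampling temperature at step s (T still denotes the total number of steps):

temperature_s = start_temperature * (T - s) / T

Early steps therefore sample diverse predictions, while at the final step (s = T) the temperature reaches 0 and sampling becomes effectively greedy.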
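The annealing formula can be checked by computing the per-step temperatures and their effect on a fixed set of logits. This sketch assumes a minimum temperature of 1e-3 to avoid dividing by zero at the final step; the helper names are hypothetical.

```python
import math


def annealed_temperatures(start_temperature, steps):
    # Temperature at step s = 1..T: start_temperature * (T - s) / T
    return [start_temperature * (steps - s) / steps for s in range(1, steps + 1)]


def tempered_softmax(logits, temperature, min_temperature=1e-3):
    # Clamp so the final, zero-temperature step does not divide by zero
    t = max(temperature, min_temperature)
    scaled = [x / t for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


temps = annealed_temperatures(1.0, steps=4)
print(temps)  # -> [0.75, 0.5, 0.25, 0.0]
for t in temps:
    # Probability of the top token rises as the temperature falls
    print(t, round(max(tempered_softmax([2.0, 1.0, 0.0], t)), 3))
```

The top-token probability climbs toward 1 as the temperature falls, which is what makes the last step behave like greedy decoding.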

