
Heuristic:Lucidrains X transformers MaskGIT Generation Tuning

From Leeroopedia



Knowledge Sources
Domains Deep_Learning, Optimization
Last Updated 2026-02-08 18:00 GMT

Overview

Tuning guide for non-autoregressive masked generation: number of demasking steps, schedule selection, BERT-style masking ratios, and MDLM loss weighting.

Description

The NonAutoregressiveWrapper implements MaskGIT-style iterative unmasking generation. The quality and speed of generation depend on several interrelated hyperparameters: the number of demasking steps, the masking schedule (linear vs cosine), the BERT-style token replacement ratios, and whether to use the Simple MDLM loss weighting. This heuristic captures the default values and their reasoning.

Usage

Use this heuristic when configuring non-autoregressive generation or tuning the quality-versus-speed trade-off of sampling. It is also useful when training a non-autoregressive model and choosing a loss weighting strategy.

The Insight (Rule of Thumb)

  • Demasking Steps: Default `steps=18`. More steps = higher quality but slower generation.
  • Schedule: `linear` (default) or `cosine` (from the MaskGIT paper, arXiv:2202.04200). Cosine keeps more tokens masked in early steps and unmasks an increasing number per step later, once context has accumulated.
  • BERT Masking Ratios:
    • `no_replace_prob=0.15` (15% of masked tokens stay unchanged)
    • `random_token_prob=0.1` (10% of masked tokens get random replacement)
    • These mirror the corruption scheme of the original BERT paper (BERT itself used an 80% mask / 10% unchanged / 10% random split among selected tokens)
  • Self-Conditioning: `self_cond_train_prob=0.75` (use self-conditioning 75% of training steps)
  • MDLM Loss Weighting: `use_simple_mdlm_loss_weight=True` (enabled by default). Applies eq. 10 from Sahoo et al., weighting loss by schedule gradient divided by (1 - schedule value).
  • Token Critic: Optional second model that scores token confidence; `critic_loss_weight=1.0` balances generator and critic losses.
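Taken together, these defaults can be sketched as a wrapper configuration. This is a hedged sketch assuming the public x-transformers API as of the file cited below; the `num_tokens`, `mask_id`, and architecture sizes are illustrative, not recommendations:

```python
import torch
from x_transformers import TransformerWrapper, Encoder, NonAutoregressiveWrapper

# Illustrative sizes; one extra vocabulary slot is reserved for the mask token.
model = TransformerWrapper(
    num_tokens = 256 + 1,
    max_seq_len = 1024,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)

model = NonAutoregressiveWrapper(
    model,
    mask_id = 256,                    # id of the reserved mask token
    steps = 18,                       # demasking steps (quality vs. speed)
    schedule = 'cosine',              # or 'linear' (the default)
    no_replace_prob = 0.15,           # BERT-style: some masked tokens stay unchanged
    random_token_prob = 0.1,          # BERT-style: some masked tokens randomized
    use_simple_mdlm_loss_weight = True
)

x = torch.randint(0, 256, (1, 1024))
loss = model(x)              # training: weighted masked-token cross entropy
loss.backward()
sampled = model.generate()   # inference: iterative demasking over `steps` rounds
```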

Reasoning

The 18-step default balances generation quality with inference speed. Each step unmasks a fraction of tokens according to the schedule, so 18 steps with a linear schedule unmask roughly 5.6% (1/18) of tokens per step. The cosine schedule back-loads unmasking: few tokens are revealed in the earliest, least-informed steps and progressively more as context accumulates, which empirically produced better results in the MaskGIT paper.
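The difference between the two schedules can be made concrete by tabulating the fraction of tokens still masked at each step. This sketch assumes the MaskGIT convention (masked fraction cos(t·π/2) for cosine, 1 − t for linear) with t swept from 0 to 1 over 18 steps:

```python
import math

STEPS = 18

def linear_schedule(t):
    # fraction of tokens still masked at normalized time t
    return 1.0 - t

def cosine_schedule(t):
    # MaskGIT-style schedule; decays slowly at first, quickly at the end
    return math.cos(t * math.pi * 0.5)

times = [i / STEPS for i in range(STEPS + 1)]
linear = [linear_schedule(t) for t in times]
cosine = [cosine_schedule(t) for t in times]

# per-step fraction of tokens newly revealed
linear_revealed = [a - b for a, b in zip(linear, linear[1:])]
cosine_revealed = [a - b for a, b in zip(cosine, cosine[1:])]

# Linear reveals a constant 1/18 per step; cosine reveals fewer tokens
# in early steps and progressively more in later steps.
print(f"linear per step:   {linear_revealed[0]:.3f}")
print(f"cosine first step: {cosine_revealed[0]:.3f}, last step: {cosine_revealed[-1]:.3f}")
```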

The BERT-style masking ratios (15% unchanged, 10% random, 75% actual mask token) prevent the model from learning the shortcut that corrupted positions always contain the `[MASK]` token, improving its robustness at positions that hold real (possibly incorrect) tokens during iterative generation.
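The ratio split can be sketched as an independent per-position draw. This is a simplification for illustration, not the library's exact sampling code, and `MASK_ID` / `VOCAB_SIZE` here are illustrative:

```python
import random

NO_REPLACE_PROB = 0.15    # masked position keeps its original token
RANDOM_TOKEN_PROB = 0.10  # masked position gets a random vocabulary token
VOCAB_SIZE = 256
MASK_ID = 256             # illustrative mask token id, outside the vocabulary

def corrupt(tokens, rng):
    """Apply BERT-style corruption to every position selected for masking."""
    out = []
    for tok in tokens:
        r = rng.random()
        if r < NO_REPLACE_PROB:
            out.append(tok)                          # ~15%: unchanged
        elif r < NO_REPLACE_PROB + RANDOM_TOKEN_PROB:
            out.append(rng.randrange(VOCAB_SIZE))    # ~10%: random token
        else:
            out.append(MASK_ID)                      # ~75%: mask token
    return out

rng = random.Random(0)
tokens = [rng.randrange(VOCAB_SIZE) for _ in range(10_000)]
corrupted = corrupt(tokens, rng)
masked_frac = sum(c == MASK_ID for c in corrupted) / len(tokens)
print(f"fraction replaced by MASK: {masked_frac:.3f}")  # close to 0.75
```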

The Simple MDLM loss weighting (Sahoo et al. 2024) uses automatic differentiation to compute `schedule'(t) / (1 - schedule(t))` as loss weights, emphasizing tokens masked at later timesteps where the model should be more confident. This is enabled by default as it improves training.
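The weight can be sanity-checked without autograd via finite differences. This is a sketch only: the library computes the derivative with torch.func's `grad_and_value`, and the sign follows whatever convention the configured schedule function uses (here, an illustrative linear schedule returning the masked fraction 1 − t):

```python
def linear_schedule(t):
    # illustrative convention: fraction of tokens still masked at time t
    return 1.0 - t

def mdlm_loss_weight(schedule_fn, t, eps=1e-6):
    """eq. (10): schedule'(t) / (1 - schedule(t)), derivative via central difference."""
    grad = (schedule_fn(t + eps) - schedule_fn(t - eps)) / (2 * eps)
    value = schedule_fn(t)
    return grad / (1.0 - value)

# For the linear schedule 1 - t, the derivative is -1 and 1 - value is t,
# so the weight reduces to -1 / t.
print(mdlm_loss_weight(linear_schedule, 0.5))   # close to -2.0
print(mdlm_loss_weight(linear_schedule, 0.25))  # close to -4.0
```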

Code Evidence

Default hyperparameters from `nonautoregressive_wrapper.py:110-120`:

def __init__(
    self,
    net,
    *,
    mask_id,
    steps = 18,
    self_cond = False,
    self_cond_train_prob = 0.75,
    no_replace_prob = 0.15,          # which percentage of the tokens masked will stay the same
    random_token_prob = 0.1,         # which percentage of tokens to be replaced with random token
    schedule = 'linear',
    can_mask_prev_unmasked = False,
    token_critic = None,
    self_token_critic = False,
    critic_loss_weight = 1.,
    use_simple_mdlm_loss_weight = True
):

MDLM loss weight computation (eq. 10) from `nonautoregressive_wrapper.py:155-164`:

if use_simple_mdlm_loss_weight:
    grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))
    # eq (10)
    def loss_weight_fn(times):
        grad, value = grad_and_value_schedule_fn(times)
        return grad / (1. - value)
    self.loss_weight_fn = loss_weight_fn

Temperature annealing during generation from `nonautoregressive_wrapper.py:240-245`:

annealing_scale = steps_until_x0 / self.steps
temperature = start_temperature * annealing_scale
probs = (logits / max(temperature, 1e-3)).softmax(dim = -1)
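The annealing above can be traced step by step without the model. A sketch, with `start_temperature = 1.0` chosen for illustration:

```python
STEPS = 18
START_TEMPERATURE = 1.0  # illustrative starting value
MIN_TEMPERATURE = 1e-3   # floor applied before the softmax

temperatures = []
for steps_until_x0 in range(STEPS, 0, -1):
    annealing_scale = steps_until_x0 / STEPS
    temperatures.append(max(START_TEMPERATURE * annealing_scale, MIN_TEMPERATURE))

# Temperature decays linearly from 1.0 toward 1/18, sharpening the
# softmax so late-step predictions commit to high-confidence tokens.
print(temperatures[0], temperatures[-1])
```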
