Heuristic:Lucidrains X transformers MaskGIT Generation Tuning
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Optimization |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Tuning guide for non-autoregressive masked generation: number of demasking steps, schedule selection, BERT-style masking ratios, and MDLM loss weighting.
Description
The NonAutoregressiveWrapper implements MaskGIT-style iterative unmasking generation. The quality and speed of generation depend on several interrelated hyperparameters: the number of demasking steps, the masking schedule (linear vs cosine), the BERT-style token replacement ratios, and whether to use the Simple MDLM loss weighting. This heuristic captures the default values and their reasoning.
Usage
Use this heuristic when configuring non-autoregressive generation or tuning the trade-off between generation quality and speed. Also useful when training a non-autoregressive model and choosing a loss weighting strategy.
The Insight (Rule of Thumb)
- Demasking Steps: Default `steps=18`. More steps = higher quality but slower generation.
- Schedule: `linear` (default) or `cosine` (from the MaskGIT paper, arxiv 2202.04200). Cosine unmasks fewer tokens in the early, high-uncertainty steps, then reveals more per step as generation progresses.
- BERT Masking Ratios:
  - `no_replace_prob=0.15` (15% of masked tokens stay unchanged)
  - `random_token_prob=0.1` (10% of masked tokens get a random replacement)
  - These follow the original BERT paper convention.
- Self-Conditioning: `self_cond_train_prob=0.75` (use self-conditioning 75% of training steps)
- MDLM Loss Weighting: `use_simple_mdlm_loss_weight=True` (enabled by default). Applies eq. 10 from Sahoo et al., weighting loss by schedule gradient divided by (1 - schedule value).
- Token Critic: Optional second model that scores token confidence; `critic_loss_weight=1.0` balances generator and critic losses.
Reasoning
The 18-step default balances generation quality with inference speed. Each step unmasks a fraction of tokens according to the schedule, so 18 steps with a linear schedule unmask approximately 5.5% of tokens per step. The cosine schedule back-loads unmasking, revealing fewer tokens in the early, high-uncertainty steps and more once context has accumulated; this schedule empirically produced the best results in the MaskGIT paper.
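A small sketch makes the per-step counts concrete. The schedule functions below are illustrative shapes (assuming `1 - t` for linear and `cos(t * pi / 2)` for the MaskGIT-style cosine, mapping normalized time to the fraction of tokens still masked); the sequence length of 256 is arbitrary.

```python
import math

def linear_schedule(t):
    # fraction of tokens still masked at normalized time t in [0, 1]
    return 1.0 - t

def cosine_schedule(t):
    # MaskGIT-style cosine: decays slowly at first, quickly near the end
    return math.cos(t * math.pi / 2)

def tokens_unmasked_per_step(schedule, steps=18, seq_len=256):
    masked = [round(schedule(s / steps) * seq_len) for s in range(steps + 1)]
    # tokens newly revealed at each of the `steps` iterations
    return [masked[i] - masked[i + 1] for i in range(steps)]

linear_counts = tokens_unmasked_per_step(linear_schedule)
cosine_counts = tokens_unmasked_per_step(cosine_schedule)
# linear reveals roughly 256/18 ≈ 14 tokens every step; cosine reveals
# only a handful in the first steps and progressively more near the end
```

Both lists sum to the full sequence length; only the distribution across steps differs.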
The BERT-style masking ratios (15% unchanged, 10% random, 75% actual mask token) prevent the model from learning a shortcut of only predicting `[MASK]` tokens, improving robustness during generation when no mask tokens are present.
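The ratio arithmetic above can be sketched as follows. This is a simplified illustration (one categorical draw per masked position), not the wrapper's exact implementation; the function name and signature are hypothetical.

```python
import random

def corrupt_tokens(tokens, mask_positions, mask_id, vocab_size,
                   no_replace_prob=0.15, random_token_prob=0.1):
    # BERT-style corruption of the positions selected for masking:
    # ~15% keep the original token, ~10% get a random token, and the
    # remaining ~75% become the actual mask id
    out = list(tokens)
    for pos in mask_positions:
        r = random.random()
        if r < no_replace_prob:
            pass  # keep the original token unchanged
        elif r < no_replace_prob + random_token_prob:
            out[pos] = random.randrange(vocab_size)  # random replacement
        else:
            out[pos] = mask_id  # actual mask token
    return out
```

Because a quarter of "masked" positions still hold real tokens during training, the model cannot rely on the mask id alone to decide which positions to predict.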
The Simple MDLM loss weighting (Sahoo et al. 2024) uses automatic differentiation to compute `schedule'(t) / (1 - schedule(t))` as loss weights, emphasizing tokens masked at later timesteps where the model should be more confident. This is enabled by default as it improves training.
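The weight computation can be reproduced without autodiff. The sketch below substitutes a finite-difference derivative for the wrapper's `vmap(grad_and_value(...))` and assumes a MaskGIT-style cosine masking schedule; it is a numerical illustration of eq. (10), not the wrapper's code.

```python
import math

def cosine_schedule(t):
    # fraction of tokens masked at normalized time t (MaskGIT-style)
    return math.cos(t * math.pi / 2)

def loss_weight(schedule, t, eps=1e-6):
    # eq. (10) of Sahoo et al.: schedule'(t) / (1 - schedule(t)),
    # with the derivative taken by central finite differences here
    # instead of torch.func.grad_and_value as in the wrapper
    grad = (schedule(t + eps) - schedule(t - eps)) / (2 * eps)
    return grad / (1.0 - schedule(t))
```

Since the masked fraction is decreasing, the derivative (and hence the raw weight) is negative; combined with the negative log-likelihood this yields a positive contribution to the loss.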
Code Evidence
Default hyperparameters from `nonautoregressive_wrapper.py:110-120`:
```python
def __init__(
    self,
    net,
    *,
    mask_id,
    steps = 18,
    self_cond = False,
    self_cond_train_prob = 0.75,
    no_replace_prob = 0.15,    # which percentage of the tokens masked will stay the same
    random_token_prob = 0.1,   # which percentage of tokens to be replaced with random token
    schedule = 'linear',
    can_mask_prev_unmasked = False,
    token_critic = None,
    self_token_critic = False,
    critic_loss_weight = 1.,
    use_simple_mdlm_loss_weight = True
):
```
MDLM loss weight computation (eq. 10) from `nonautoregressive_wrapper.py:155-164`:
```python
if use_simple_mdlm_loss_weight:
    grad_and_value_schedule_fn = vmap(grad_and_value(self.schedule_fn))

    # eq (10)
    def loss_weight_fn(times):
        grad, value = grad_and_value_schedule_fn(times)
        return grad / (1. - value)

    self.loss_weight_fn = loss_weight_fn
```
Temperature annealing during generation from `nonautoregressive_wrapper.py:240-245`:
```python
annealing_scale = steps_until_x0 / self.steps
temperature = start_temperature * annealing_scale
probs = (logits / max(temperature, 1e-3)).softmax(dim = -1)
```
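The annealing above can be isolated into a standalone sketch. The function name and step convention are illustrative (assuming `steps_until_x0` counts down from the total step count), mirroring the snippet's linear decay with a `1e-3` floor that keeps the softmax well-defined.

```python
def annealed_temperature(step, total_steps=18, start_temperature=1.0, floor=1e-3):
    # steps_until_x0 counts down as generation proceeds, so the sampling
    # temperature decays linearly from start_temperature toward the floor,
    # making late demasking steps nearly greedy
    steps_until_x0 = total_steps - step
    temperature = start_temperature * (steps_until_x0 / total_steps)
    return max(temperature, floor)
```

Early steps sample at full temperature to keep diversity high; by the final step the temperature has collapsed to the floor, so the model commits to its most confident predictions.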
Related Pages
- Implementation:Lucidrains_X_transformers_NonAutoregressiveWrapper_Init
- Implementation:Lucidrains_X_transformers_NonAutoregressiveWrapper_Forward
- Implementation:Lucidrains_X_transformers_NonAutoregressiveWrapper_Generate
- Principle:Lucidrains_X_transformers_Non_Autoregressive_Wrapper_Setup
- Principle:Lucidrains_X_transformers_Masked_Token_Prediction_Training
- Principle:Lucidrains_X_transformers_Iterative_Masked_Generation