Principle: Lucidrains x-transformers NonAutoregressiveWrapper Setup
Metadata
| Field | Value |
|---|---|
| Papers | Mask-Predict, MaskGIT, Simple Masked Diffusion LM (MDLM) |
| Repository | x-transformers |
| Domains | Deep_Learning, Generative_Models, NLP |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Task wrapper pattern that transforms a bidirectional encoder into a non-autoregressive model with masked token prediction training and iterative refinement generation.
Description
The NonAutoregressiveWrapper wraps a TransformerWrapper (with Encoder layers) to enable MaskGIT/Mask-Predict style training and generation. During training, random tokens are masked according to a schedule (linear or cosine) and the model predicts the masked tokens. During generation, the model starts from a fully masked sequence and iteratively unmasks tokens over a fixed number of steps, choosing the most confident predictions first.
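The training-time masking described above can be sketched in plain, framework-free Python. This is an illustrative sketch, not the wrapper's actual API: `MASK_ID`, the function name, and the schedule handling are assumptions for demonstration (the real wrapper takes the mask token id as a constructor argument and operates on tensors).

```python
import math
import random

MASK_ID = 0  # illustrative mask token id; the real wrapper receives this as a parameter

def mask_for_training(tokens, schedule="cosine", rng=random):
    """Mask a random subset of tokens; return (masked_tokens, target_positions).

    The model is then trained to predict the original token at each
    masked position, with unmasked positions excluded from the loss.
    """
    t = rng.random()  # sample a schedule progress value in [0, 1)
    if schedule == "linear":
        mask_ratio = 1.0 - t
    else:  # cosine
        mask_ratio = math.cos(math.pi * t / 2)
    num_mask = max(1, round(mask_ratio * len(tokens)))
    positions = rng.sample(range(len(tokens)), num_mask)
    masked = list(tokens)
    for p in positions:
        masked[p] = MASK_ID
    return masked, sorted(positions)
```

Sampling `t` fresh for every batch exposes the model to all masking rates, which is what later allows generation to start from a fully masked sequence.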
Supported features:
- Self-conditioning on embeddings
- BERT-style no-replace and random-token augmentation
- Token critic for improved generation quality
- MDLM loss weighting
Usage
Use when building non-autoregressive generative models. Ideal for tasks where parallel generation is desired (image tokens, discrete diffusion, etc.) or where generation speed is critical (all tokens predicted simultaneously, refined iteratively).
Theoretical Basis
Mask-Predict algorithm:
- Start with a fully masked sequence.
- At each step t, predict all masked positions.
- Keep the k most confident predictions, re-mask the rest.
- k increases with each step according to a schedule (linear or cosine).
- After T steps, all tokens are unmasked.
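The steps above can be sketched as a small pure-Python loop. `predict` stands in for the bidirectional model and is an assumption here: it must return a `(token, confidence)` pair for every position. The cosine re-masking count mirrors the schedule given below; this is a simplified illustration, not the library's implementation.

```python
import math

def mask_predict(seq_len, steps, predict, mask_id=-1):
    """Iteratively unmask a fully masked sequence, keeping confident tokens."""
    tokens = [mask_id] * seq_len
    for step in range(1, steps + 1):
        preds = predict(tokens)
        # cosine schedule: how many tokens remain masked after this step
        num_masked = math.floor(seq_len * math.cos(math.pi * step / (2 * steps)))
        # fill every currently masked position with its predicted token
        for i, (tok, _) in enumerate(preds):
            if tokens[i] == mask_id:
                tokens[i] = tok
        if num_masked > 0:
            # re-mask the least confident positions
            order = sorted(range(seq_len), key=lambda i: preds[i][1])
            for i in order[:num_masked]:
                tokens[i] = mask_id
    return tokens
```

At the final step the cosine term reaches zero, so `num_masked` is 0 and every position holds a committed token.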
Masking schedule controls how many tokens remain masked at each step:
- Linear: mask_ratio = 1 - t
- Cosine: mask_ratio = cos(pi * t / 2)
MDLM loss weighting (Sahoo et al.):
weight = schedule'(t) / (1 - schedule(t))
This upweights heavily masked timesteps, where prediction is hardest.
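For a linear schedule with schedule(t) = 1 - t, the formula reduces to -1/t, so the weight magnitude grows as t approaches 0 (the almost fully masked regime). A short numeric check, where the `eps` clamp is an illustrative choice to keep the weight finite:

```python
def mdlm_weight(t, eps=1e-3):
    """MDLM-style loss weight for a linear masking schedule.

    With schedule(t) = 1 - t, schedule'(t) = -1, so
    weight = schedule'(t) / (1 - schedule(t)) = -1 / t.
    `eps` clamps t away from zero (an illustrative choice).
    """
    t = max(t, eps)
    d_schedule = -1.0                      # derivative of schedule(t) = 1 - t
    return d_schedule / (1.0 - (1.0 - t))  # = -1 / t
```

The magnitude |weight| = 1/t is largest near t = 0, where nearly all tokens are masked, matching the intuition that the hardest timesteps carry the most weight.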