Workflow:Lucidrains X transformers Non Autoregressive Masked Generation
| Knowledge Sources | |
|---|---|
| Domains | Non_Autoregressive_Generation, Deep_Learning, Masked_Language_Modeling |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
End-to-end process for training a non-autoregressive masked token prediction model using x-transformers' NonAutoregressiveWrapper with iterative refinement generation.
Description
This workflow covers building and training a BERT-style masked language model that generates sequences through iterative refinement rather than left-to-right autoregression. It uses a TransformerWrapper wrapping an Encoder (bidirectional attention), wrapped by a NonAutoregressiveWrapper that implements the MaskGIT / Mask-Predict algorithm. During training, random tokens are masked and the model learns to predict them. During generation, the model starts from a fully masked sequence and iteratively unmasks tokens over multiple steps, selecting the most confident predictions first. Because every position is predicted in parallel, generation needs one forward pass per refinement step rather than one per token, so it can be considerably faster than autoregressive decoding for fixed-length outputs when the step count is much smaller than the sequence length.
Usage
Execute this workflow when you need to generate fixed-length discrete sequences without the sequential bottleneck of autoregressive decoding. This is useful for tasks where generation speed is critical, such as image token generation (MaskGIT-style), text infilling, or any scenario where all output positions can be predicted in parallel. Iterative refinement costs a small, fixed number of forward passes (one per refinement step) instead of one pass per generated token.
Execution Steps
Step 1: Install Dependencies
Install the x-transformers package. The NonAutoregressiveWrapper is part of the core library and is exported from the top-level package.
Key considerations:
- The wrapper depends on einx for masking operations
- Self-conditioning and the token critic are optional features, enabled when configuring the wrapper (see Step 4)
Step 2: Prepare Masked Training Data
Create a dataset that yields fixed-length sequences of token IDs. Unlike autoregressive training, sequences must all be exactly max_seq_len in length. The NonAutoregressiveWrapper handles the masking internally during training.
Key considerations:
- All sequences must be the same length (equal to max_seq_len)
- A dedicated mask token ID must be reserved in the vocabulary
- The wrapper randomly masks tokens at varying rates during training (following a linear or cosine schedule)
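The data requirements above can be sketched in plain Python. This is an illustration only, not part of x-transformers: the constants and the helper `make_fixed_length_sequences` are hypothetical names, and dropping the trailing remainder (rather than padding it) is an assumption made for simplicity.

```python
# Hypothetical sizes chosen for illustration only.
MAX_SEQ_LEN = 8                    # must equal the model's max_seq_len
NUM_DATA_TOKENS = 256              # real token ids are 0..255
MASK_ID = NUM_DATA_TOKENS          # reserved id that never appears in raw data
VOCAB_SIZE = NUM_DATA_TOKENS + 1   # vocabulary size including the mask token


def make_fixed_length_sequences(token_stream, seq_len=MAX_SEQ_LEN, mask_id=MASK_ID):
    """Split a flat list of token ids into windows of exactly seq_len tokens.

    A trailing remainder shorter than seq_len is dropped rather than padded.
    """
    if any(tok == mask_id for tok in token_stream):
        raise ValueError("mask token id found in raw data")
    usable = len(token_stream) - len(token_stream) % seq_len
    return [token_stream[i:i + seq_len] for i in range(0, usable, seq_len)]


stream = [i % NUM_DATA_TOKENS for i in range(20)]
sequences = make_fixed_length_sequences(stream)
```

Checking for the mask id up front keeps a real data token from ever being mistaken for a masked position.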
Step 3: Configure Encoder Model
Instantiate a TransformerWrapper with an Encoder (bidirectional self-attention). Configure vocabulary size (including the mask token), maximum sequence length, model dimension, depth, and heads. The encoder uses bidirectional attention so all positions can attend to all other positions.
Key considerations:
- Use Encoder (not Decoder) since the model needs bidirectional context
- The vocabulary size must account for the special mask token
- The model outputs logits over the full vocabulary for each position
Step 4: Wrap with NonAutoregressiveWrapper
Wrap the TransformerWrapper in a NonAutoregressiveWrapper, providing the mask token ID and configuring the number of refinement steps, masking schedule, and optional features like self-conditioning or a token critic.
What happens:
- The wrapper manages random masking during training with configurable schedules (linear or cosine)
- Optional self-conditioning allows the model to condition on its own previous predictions
- An optional token critic (self or external) provides confidence scores for the iterative unmasking process
- The simple MDLM loss weighting from Sahoo et al. can be applied for improved training
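As a rough sketch of what a masking schedule does, the functions below map a time `t` in [0, 1] to the fraction of tokens that are masked. The exact formulas (`1 - t` for linear, a quarter cosine wave for cosine, as in the MaskGIT paper) and the helper `num_tokens_to_mask` are assumptions made for illustration; the wrapper's own implementation may differ in detail.

```python
import math
import random


def linear_schedule(t):
    """Mask fraction falls linearly from 1 to 0 as t goes from 0 to 1."""
    return 1.0 - t


def cosine_schedule(t):
    """Mask fraction follows a quarter cosine wave from 1 down to 0."""
    return math.cos(t * math.pi / 2)


def num_tokens_to_mask(seq_len, schedule, rng):
    """Draw a random time in [0, 1) and turn it into a mask count of at least 1."""
    t = rng.random()
    return max(1, round(seq_len * schedule(t)))


rng = random.Random(0)
count = num_tokens_to_mask(16, cosine_schedule, rng)
```

Compared with the linear schedule, the cosine schedule keeps the mask fraction higher for the same `t`, so training examples are on average more heavily masked.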
Step 5: Train the Model
Run the training loop feeding token sequences into the NonAutoregressiveWrapper. The wrapper internally applies random masking at each step (varying the mask rate according to the schedule), runs the model, and computes the cross-entropy loss on the masked positions.
Key considerations:
- The forward() method returns a Losses namedtuple with total loss, generator loss, and optional critic loss
- Following the original BERT approach, some masked positions keep their original token (no_replace_prob) and some get random tokens (random_token_prob)
- When training with a token critic, the generator and the critic can be trained separately
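The BERT-style corruption described above can be illustrated with a simplified, per-position sketch in plain Python. The function `corrupt` is hypothetical, the probability values are placeholders, and the wrapper's actual selection logic may differ. The point is only that every selected position is a loss target, whether it ends up holding the mask token, a random token, or its original token.

```python
import random


def corrupt(tokens, num_to_mask, mask_id, num_data_tokens,
            no_replace_prob=0.15, random_token_prob=0.1, rng=random):
    """Return (corrupted_tokens, target_positions) for one training sequence."""
    positions = rng.sample(range(len(tokens)), num_to_mask)
    corrupted = list(tokens)
    for pos in positions:
        roll = rng.random()
        if roll < no_replace_prob:
            continue  # keep the original token; it is still a prediction target
        if roll < no_replace_prob + random_token_prob:
            corrupted[pos] = rng.randrange(num_data_tokens)  # random real token
        else:
            corrupted[pos] = mask_id  # the common case: replace with the mask token
    return corrupted, sorted(positions)


original = list(range(10))
corrupted, targets = corrupt(original, num_to_mask=4, mask_id=256,
                             num_data_tokens=256, rng=random.Random(0))
```

The cross-entropy loss would then be computed only at the positions in `targets`, comparing the model's predictions there against `original`.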
Step 6: Generate Sequences
Use the NonAutoregressiveWrapper's generate() method to produce sequences through iterative refinement. Starting from a fully masked sequence, the model predicts all positions in parallel, then re-masks the least confident predictions, repeating for a configured number of steps.
What happens:
- The sequence starts as all mask tokens
- At each step, the model predicts logits for all positions
- Tokens are sampled with Gumbel sampling and temperature annealing
- Confidence scores determine which tokens to keep and which to re-mask
- The number of tokens re-masked decreases at each step following the schedule
- After all steps, the sequence is fully unmasked
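The refinement loop above can be sketched without any deep-learning dependency. This is a toy illustration, not the wrapper's generate() method: `toy_predict` stands in for the network by returning a random token and a random confidence per position, and Gumbel sampling and temperature annealing are omitted. The sketch also assumes that tokens unmasked in an earlier step are never re-masked.

```python
import random


def toy_predict(seq, mask_id, rng):
    """Stand-in for the model: one (token, confidence) pair per position."""
    return [(rng.randrange(10), rng.random()) for _ in seq]


def generate_sketch(seq_len, steps, mask_id, predict, rng,
                    schedule=lambda t: 1.0 - t):
    """Iteratively unmask a sequence; returns (sequence, masked_count_per_step)."""
    seq = [mask_id] * seq_len
    history = []
    for step in range(steps):
        preds = predict(seq, mask_id, rng)
        masked = [i for i, tok in enumerate(seq) if tok == mask_id]
        for i in masked:
            seq[i] = preds[i][0]  # tentatively fill every masked position
        # Fraction of the sequence that should stay masked after this step.
        num_remask = min(int(seq_len * schedule((step + 1) / steps)), len(masked))
        # Re-mask the least confident of the positions filled in this step.
        for i in sorted(masked, key=lambda i: preds[i][1])[:num_remask]:
            seq[i] = mask_id
        history.append(num_remask)
    return seq, history


tokens, masked_counts = generate_sketch(16, 4, mask_id=99,
                                        predict=toy_predict, rng=random.Random(0))
```

Because the schedule reaches zero at the final step, no position is re-masked on the last pass and the returned sequence contains no mask tokens.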