Implementation: lucidrains/x-transformers NonAutoregressiveWrapper `__init__`
Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | Generative_Models, NLP |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool from the x-transformers library for wrapping encoder models with non-autoregressive masked-prediction training and iterative demasking generation.
Description
NonAutoregressiveWrapper wraps a TransformerWrapper (with Encoder) to add:
- Masked token prediction training with configurable schedule
- Iterative demasking generation (MaskGIT-style)
- Optional self-conditioning on embeddings
- BERT-style no-replace and random-token augmentation
- Optional token critic for generation quality
- MDLM loss weighting
The steps parameter controls the number of iterative demasking steps during generation.
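The interaction between the schedule and steps can be sketched in plain Python. This is an illustrative approximation of the common MaskGIT-style linear and cosine schedules, not the library's exact internals; mask_fraction and masked_counts are hypothetical helper names.

```python
import math

def mask_fraction(t: float, schedule: str = 'linear') -> float:
    """Fraction of tokens still masked at normalized time t in [0, 1].

    'linear' and 'cosine' follow the usual MaskGIT-style schedules;
    the exact functions inside x-transformers may differ slightly.
    """
    if schedule == 'linear':
        return 1.0 - t
    if schedule == 'cosine':
        return math.cos(t * math.pi / 2)
    raise ValueError(f'unknown schedule: {schedule}')

def masked_counts(seq_len: int, steps: int, schedule: str = 'linear'):
    """Number of positions left masked after each of `steps` demasking steps."""
    return [round(mask_fraction((i + 1) / steps, schedule) * seq_len)
            for i in range(steps)]
```

With a cosine schedule the early steps commit only a few tokens and the later steps commit many, which is why more steps generally trade speed for quality.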
Usage
Create a TransformerWrapper with an Encoder, then wrap it with NonAutoregressiveWrapper, supplying the token ID reserved for masking.
Code Reference
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/nonautoregressive_wrapper.py |
| Lines | L99-186 |
Signature:
```python
class NonAutoregressiveWrapper(Module):
    def __init__(
        self,
        net,
        *,
        mask_id,
        steps = 18,
        self_cond = False,
        self_cond_train_prob = 0.75,
        no_replace_prob = 0.15,
        random_token_prob = 0.1,
        schedule = 'linear',
        can_mask_prev_unmasked = False,
        token_critic: TransformerWrapper | None = None,
        self_token_critic = False,
        critic_loss_weight = 1.,
        use_simple_mdlm_loss_weight = True
    ):
```
Import:
```python
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| net | TransformerWrapper | Yes | Encoder model to wrap |
| mask_id | int | Yes | Token ID used for masking |
| steps | int | No | Number of iterative demasking steps (default 18) |
| schedule | str | No | Masking schedule: 'linear' or 'cosine' (default 'linear') |
| self_cond | bool | No | Enable self-conditioning on embeddings (default False) |
| self_token_critic | bool | No | Use the wrapped model itself as token critic during generation (default False) |
| no_replace_prob | float | No | Fraction of masked tokens kept unchanged (BERT-style, default 0.15) |
| random_token_prob | float | No | Fraction of masked tokens replaced with random tokens (BERT-style, default 0.1) |
| use_simple_mdlm_loss_weight | bool | No | Apply MDLM loss weighting (default True) |
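The no_replace_prob and random_token_prob inputs implement the familiar BERT-style corruption of positions selected for masking. The helper below is a hypothetical, framework-free illustration of that scheme, not the wrapper's internal code:

```python
import random

def corrupt(tokens, mask_positions, mask_id, num_tokens,
            no_replace_prob=0.15, random_token_prob=0.1, rng=None):
    """BERT-style corruption of the positions in `mask_positions`.

    Of the selected positions: `no_replace_prob` keep the original token,
    `random_token_prob` receive a random token, and the remainder become
    `mask_id`. Illustrative only; seeded RNG used for reproducibility.
    """
    rng = rng or random.Random(0)
    out = list(tokens)
    for i in mask_positions:
        r = rng.random()
        if r < no_replace_prob:
            continue  # keep original token (BERT-style no-replace)
        elif r < no_replace_prob + random_token_prob:
            out[i] = rng.randrange(num_tokens)  # random-token augmentation
        else:
            out[i] = mask_id
    return out
```

Keeping a small fraction of targets unchanged or randomized forces the model to produce useful predictions even at unmasked positions, exactly as in BERT pretraining.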
Outputs
| Name | Type | Description |
|---|---|---|
| wrapper | NonAutoregressiveWrapper | Model with .forward() for training and .generate() for iterative generation |
Usage Examples
```python
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

NUM_TOKENS = 256
MASK_ID = NUM_TOKENS  # reserve one extra token id for the mask

model = TransformerWrapper(
    num_tokens = NUM_TOKENS + 1,
    max_seq_len = 512,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

model = NonAutoregressiveWrapper(
    model,
    mask_id = MASK_ID,
    steps = 18,
    schedule = 'cosine',
    self_cond = True
).cuda()
```
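The generation side can be illustrated with a framework-free sketch of the MaskGIT-style demasking loop: start fully masked, commit the most confident predictions each step, and leave the rest masked according to a cosine schedule. demask_generate and score_fn are hypothetical names for illustration; the wrapper's actual .generate() additionally handles temperature, self-conditioning, and the optional token critic.

```python
import math

def demask_generate(seq_len, steps, score_fn, mask_id):
    """Toy MaskGIT-style iterative demasking (illustration only).

    score_fn(tokens) returns a (token, confidence) pair per position.
    Each step commits the most confident masked positions; previously
    committed tokens are never re-masked (as with
    can_mask_prev_unmasked=False).
    """
    tokens = [mask_id] * seq_len
    for step in range(steps):
        t = (step + 1) / steps
        # cosine schedule: how many positions should remain masked
        keep_masked = round(math.cos(t * math.pi / 2) * seq_len)
        preds = score_fn(tokens)
        masked = [i for i in range(seq_len) if tokens[i] == mask_id]
        # most confident masked positions first
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        num_commit = max(len(masked) - keep_masked, 0)
        for i in masked[:num_commit]:
            tokens[i] = preds[i][0]
    return tokens
```

At t = 1 the schedule reaches zero, so the final step commits every remaining masked position and the output contains no mask tokens.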
Related Pages
Implements Principle
Requires Environment
Uses Heuristic