Implementation:Lucidrains X transformers NonAutoregressiveWrapper Init

From Leeroopedia


Metadata

Field         Value
Repository    x-transformers
Domains       Generative_Models, NLP
Last Updated  2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library that wraps encoder models to add non-autoregressive masked-prediction training and iterative demasking generation.

Description

NonAutoregressiveWrapper wraps a TransformerWrapper (with Encoder) to add:

  • Masked token prediction training with configurable schedule
  • Iterative demasking generation (MaskGIT-style)
  • Optional self-conditioning on embeddings
  • BERT-style no-replace and random-token augmentation
  • Optional token critic for generation quality
  • MDLM loss weighting

The steps parameter controls the number of iterative demasking steps during generation.
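The schedule parameter determines how quickly tokens are unmasked across those demasking steps. As a rough standalone sketch (not the library's exact code; the helper names below are illustrative), a schedule maps normalized time t in [0, 1] to the fraction of the sequence still masked:

```python
import math

def linear_schedule(t):
    # fraction of tokens still masked at time t in [0, 1]
    return 1.0 - t

def cosine_schedule(t):
    # cosine schedule (MaskGIT-style): unmask few tokens early, many late
    return math.cos(t * math.pi / 2)

def masked_counts(seq_len, steps, schedule):
    # number of tokens left masked after each of `steps` demasking steps
    times = [i / steps for i in range(steps + 1)]
    return [max(0, min(seq_len, round(schedule(t) * seq_len))) for t in times]

print(masked_counts(16, 4, linear_schedule))   # [16, 12, 8, 4, 0]
print(masked_counts(16, 4, cosine_schedule))   # [16, 15, 11, 6, 0]
```

Because the cosine curve is nearly flat at t = 0, it commits only a few tokens in early steps and most of them in late steps, once more context is available.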

Usage

Import the wrapper after creating a TransformerWrapper built on an Encoder, then wrap the model, passing the reserved mask token ID.

Code Reference

Field       Value
Repository  x-transformers
File        x_transformers/nonautoregressive_wrapper.py
Lines       L99-186

Signature:

class NonAutoregressiveWrapper(Module):
    def __init__(
        self,
        net,
        *,
        mask_id,
        steps = 18,
        self_cond = False,
        self_cond_train_prob = 0.75,
        no_replace_prob = 0.15,
        random_token_prob = 0.1,
        schedule = 'linear',
        can_mask_prev_unmasked = False,
        token_critic: TransformerWrapper | None = None,
        self_token_critic = False,
        critic_loss_weight = 1.,
        use_simple_mdlm_loss_weight = True
    ):

Import:

from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

I/O Contract

Inputs

Name                         Type                Required  Description
net                          TransformerWrapper  Yes       Encoder model to wrap
mask_id                      int                 Yes       Token ID used for masking
steps                        int                 No        Number of iterative demasking steps (default 18)
schedule                     str                 No        Masking schedule: 'linear' or 'cosine' (default 'linear')
self_cond                    bool                No        Enable self-conditioning on embeddings (default False)
self_token_critic            bool                No        Use the model itself as the token critic during generation (default False)
no_replace_prob              float               No        Fraction of masked tokens kept unchanged (BERT-style, default 0.15)
random_token_prob            float               No        Fraction of masked tokens replaced with random tokens (BERT-style, default 0.1)
use_simple_mdlm_loss_weight  bool                No        Apply MDLM loss weighting (default True)
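The no_replace_prob and random_token_prob options mirror BERT's corruption recipe for the positions selected for masking. A minimal sketch of that recipe, with a hypothetical corrupt helper (the library's actual implementation operates on tensors):

```python
import random

def corrupt(tokens, positions, mask_id, num_tokens,
            no_replace_prob=0.15, random_token_prob=0.1, rng=None):
    # BERT-style corruption of positions selected for masking:
    # keep some tokens unchanged, randomize some, mask the rest
    rng = rng or random.Random(0)
    out = list(tokens)
    for pos in positions:
        r = rng.random()
        if r < no_replace_prob:
            continue                                 # keep original token
        elif r < no_replace_prob + random_token_prob:
            out[pos] = rng.randrange(num_tokens)     # random real token
        else:
            out[pos] = mask_id                       # replace with [MASK]
    return out
```

Keeping some targets unchanged and randomizing others forces the model to predict at every selected position rather than only where it sees the mask token.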

Outputs

Name     Type                      Description
wrapper  NonAutoregressiveWrapper  Model with .forward() for training and .generate() for iterative generation

Usage Examples

from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

NUM_TOKENS = 256
MASK_ID = NUM_TOKENS

model = TransformerWrapper(
    num_tokens = NUM_TOKENS + 1,
    max_seq_len = 512,
    attn_layers = Encoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

model = NonAutoregressiveWrapper(
    model,
    mask_id = MASK_ID,
    steps = 18,
    schedule = 'cosine',
    self_cond = True
).cuda()
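Once wrapped, .forward() computes the masked-prediction training loss and .generate() runs the iterative demasking loop. The core of that loop can be sketched standalone (a simplification under the cosine schedule; predict is a stand-in for the wrapped network, the helper names are illustrative, and committed tokens stay fixed, i.e. can_mask_prev_unmasked = False):

```python
import math

def demask_generate(seq_len, steps, mask_id, predict,
                    schedule=lambda t: math.cos(t * math.pi / 2)):
    # MaskGIT-style generation: start fully masked; each step, let `predict`
    # propose (token, confidence) for every masked position, then commit the
    # highest-confidence proposals until only the schedule's quota stays masked
    seq = [mask_id] * seq_len
    for step in range(1, steps + 1):
        keep_masked = round(schedule(step / steps) * seq_len)
        proposals = predict(seq)  # {pos: (token, confidence)}
        ranked = sorted(proposals, key=lambda p: proposals[p][1], reverse=True)
        n_unmask = (seq_len - keep_masked) - sum(t != mask_id for t in seq)
        for pos in ranked[:max(0, n_unmask)]:
            seq[pos] = proposals[pos][0]
    return seq

# stand-in for the network: deterministic token and confidence per masked position
def toy_predict(seq):
    return {i: (i % 5, (i * 7) % 11) for i, t in enumerate(seq) if t == 99}

out = demask_generate(seq_len=16, steps=4, mask_id=99, predict=toy_predict)
```

The library's .generate() adds temperature, logits handling, self-conditioning, and the optional token critic on top of this basic procedure.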

Related Pages

  • Implements Principle
  • Requires Environment
  • Uses Heuristic
