
Implementation:Lucidrains X transformers NonAutoregressiveWrapper Forward

From Leeroopedia


Metadata

Field Value
Repository x-transformers
Domains Generative_Models, Training
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library for training non-autoregressive masked-prediction models.

Description

The forward method of NonAutoregressiveWrapper applies random masking to input sequences, runs the masked input through the encoder, and computes cross-entropy loss on masked positions.

The method performs the following steps:

  1. Sample masking ratio — for each batch element, sample t ~ U(0, 1) and compute num_tokens_mask = schedule(t) * seq_len.
  2. Apply BERT-style augmentation — optionally keep some masked positions unchanged (no-replace) or replace them with random tokens.
  3. Self-conditioning — optionally run a no-gradient forward pass and feed the resulting embeddings back as a conditioning signal.
  4. Forward pass — run the masked sequence through the wrapped TransformerWrapper to get logits.
  5. Compute generator loss — cross-entropy on masked positions only, optionally weighted by MDLM loss weights.
  6. Compute critic loss (optional) — sample tokens from logits via Gumbel sampling, feed the generated sequence to the token critic, and compute binary cross-entropy against whether each token matches the original.
  7. Return a Losses namedtuple with total loss, generator loss, and optional critic loss.
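Steps 1 and 5 above can be sketched in plain Python. The helper names (`cosine_schedule`, `num_tokens_to_mask`) are illustrative, not the library's internals, and the real implementation operates on batched tensors rather than scalars:

```python
import math
import random

def cosine_schedule(t):
    # maps t in [0, 1] to a masking fraction; cosine is one of the
    # schedules the wrapper supports (linear, 1 - t, is another)
    return math.cos(t * math.pi / 2)

def num_tokens_to_mask(seq_len, schedule=cosine_schedule):
    # step 1: sample t ~ U(0, 1), then mask schedule(t) * seq_len
    # positions, clamped so at least one token is always masked
    t = random.uniform(0, 1)
    return max(1, int(schedule(t) * seq_len))
```

For step 5, the library then computes cross-entropy only over the masked positions, so the model is never penalized for tokens it was shown in the clear.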

Usage

Call during each training step. Pass a batch of complete (unmasked) token ids of shape (batch, max_seq_len). Returns a Losses namedtuple.

losses = model(tokens)           # returns Losses namedtuple
losses.loss.backward()           # backprop on total loss

Code Reference

Field Value
Repository x-transformers
File x_transformers/nonautoregressive_wrapper.py
Lines L275–380

Signature:

def forward(
    self,
    x,
    only_train_generator = False,
    only_train_critic = False,
    generator_sample_temperature = None,
    **kwargs
) -> Losses:  # namedtuple(loss, generator_loss, critic_loss)

Import:

from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

I/O Contract

Inputs

Name Type Required Description
x Tensor Yes Complete unmasked token ids of shape (batch, max_seq_len)
only_train_generator bool No If True, returns generator loss only, ignores critic. Default: False
only_train_critic bool No If True, runs generator in no-grad mode and returns critic loss only. Default: False
generator_sample_temperature float or None No Temperature for Gumbel sampling when generating tokens for critic training. Default: random value in [0, 1)

Outputs

Name Type Description
loss Tensor Total scalar loss (generator + weighted critic, or just one if selective training)
generator_loss Tensor or None Masked cross-entropy loss for token prediction (None if only_train_critic=True)
critic_loss Tensor or None Binary cross-entropy loss for token critic (None if no critic or only_train_generator=True)
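How the three output fields relate can be sketched as follows. This is an illustrative combination only: `total_loss` is a hypothetical helper, and the critic weighting is a constructor-level choice of the wrapper, not a fixed constant:

```python
def total_loss(generator_loss, critic_loss,
               critic_loss_weight=1.0,
               only_train_generator=False,
               only_train_critic=False):
    # no critic, or critic explicitly ignored: total is just the generator loss
    if only_train_generator or critic_loss is None:
        return generator_loss
    # critic-only training: generator ran in no-grad mode, so only
    # the critic term contributes
    if only_train_critic:
        return critic_loss
    # joint training: generator loss plus weighted critic loss
    return generator_loss + critic_loss * critic_loss_weight
```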

Usage Examples

Basic Training Loop

from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
import torch

# Setup model
net = TransformerWrapper(
    num_tokens = 256 + 1,    # reserve one extra id for the mask token
    max_seq_len = 512,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)
model = NonAutoregressiveWrapper(net, mask_id = 256).cuda()  # mask id outside the real vocab

# Training loop
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
tokens = torch.randint(0, 256, (32, 512)).cuda()

losses = model(tokens)
losses.loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"Total: {losses.loss.item()}, Generator: {losses.generator_loss.item()}")

Selective Critic Training

# Train only the critic head; this requires the wrapper to have been
# constructed with a token critic, and runs the generator in no-grad mode
losses = model(tokens, only_train_critic=True)
losses.loss.backward()  # only critic parameters receive gradients
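The critic's training signal from step 6 of the Description can be sketched in plain Python. `gumbel_sample` and `critic_targets` are illustrative helpers, not the library's API; the real code works on logit tensors over the whole batch:

```python
import math
import random

def gumbel_sample(logits, temperature=1.0):
    # perturb each logit with Gumbel noise scaled by temperature,
    # then take the argmax; this draws a sample from the softmax
    noised = [
        l + temperature * -math.log(-math.log(random.uniform(1e-9, 0.999999)))
        for l in logits
    ]
    return max(range(len(noised)), key=noised.__getitem__)

def critic_targets(generated, original):
    # the critic is trained with binary cross-entropy to predict whether
    # each sampled token matches the ground-truth token at that position
    return [float(g == o) for g, o in zip(generated, original)]
```

When `generator_sample_temperature` is left as None, the wrapper draws a random temperature in [0, 1) for this sampling step, as noted in the inputs table above.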

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
