Implementation: lucidrains x-transformers NonAutoregressiveWrapper Forward
Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | Generative_Models, Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Concrete tool for training non-autoregressive masked prediction models provided by the x-transformers library.
Description
The forward method of NonAutoregressiveWrapper applies random masking to input sequences, runs the masked input through the encoder, and computes cross-entropy loss on masked positions.
The method performs the following steps:
- Sample masking ratio — for each batch element, sample `t ~ U(0, 1)` and compute `num_tokens_mask = schedule(t) * seq_len`.
- Apply BERT-style augmentation — optionally keep some masked positions unchanged (no-replace) or replace them with random tokens.
- Self-conditioning — optionally run a no-gradient forward pass and feed the resulting embeddings back as a conditioning signal.
- Forward pass — run the masked sequence through the wrapped `TransformerWrapper` to get logits.
- Compute generator loss — cross-entropy on masked positions only, optionally weighted by MDLM loss weights.
- Compute critic loss (optional) — sample tokens from the logits via Gumbel sampling, feed the generated sequence to the token critic, and compute binary cross-entropy against whether each token matches the original.
- Return a `Losses` namedtuple with the total loss, generator loss, and optional critic loss.
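The masking step above can be sketched as follows. This is a minimal illustration, not the library's exact code: the cosine schedule and the `random_mask` helper are assumptions chosen for clarity.

```python
import math
import torch

def cosine_schedule(t):
    # maps t in [0, 1] to a mask fraction; a cosine schedule is one common choice
    return math.cos(t * math.pi / 2)

def random_mask(seq, mask_id, schedule=cosine_schedule):
    # hypothetical helper: per-example random masking as described above
    batch, seq_len = seq.shape
    masked = seq.clone()
    mask = torch.zeros(batch, seq_len, dtype=torch.bool)
    for b in range(batch):
        t = torch.rand(1).item()                       # t ~ U(0, 1)
        num_mask = max(1, int(schedule(t) * seq_len))  # num_tokens_mask = schedule(t) * seq_len
        idx = torch.randperm(seq_len)[:num_mask]       # random positions to mask
        masked[b, idx] = mask_id
        mask[b, idx] = True
    return masked, mask
```

The boolean `mask` is what restricts the cross-entropy loss to masked positions in the generator-loss step.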
Usage
Call during each training step. Pass a batch of complete (unmasked) token ids of shape `(batch, max_seq_len)`. Returns a `Losses` namedtuple.

```python
losses = model(tokens)   # returns Losses namedtuple
losses.loss.backward()   # backprop on total loss
```
Code Reference
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/nonautoregressive_wrapper.py |
| Lines | L275–380 |
Signature:
```python
def forward(
    self,
    x,
    only_train_generator = False,
    only_train_critic = False,
    generator_sample_temperature = None,
    **kwargs
) -> Losses:  # namedtuple(loss, generator_loss, critic_loss)
```
Import:
```python
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| `x` | `Tensor` | Yes | Complete unmasked token ids of shape `(batch, max_seq_len)` |
| `only_train_generator` | `bool` | No | If True, returns the generator loss only and ignores the critic. Default: `False` |
| `only_train_critic` | `bool` | No | If True, runs the generator in no-grad mode and returns the critic loss only. Default: `False` |
| `generator_sample_temperature` | `float` or `None` | No | Temperature for Gumbel sampling when generating tokens for critic training. Default: random value in `[0, 1)` |
Outputs
| Name | Type | Description |
|---|---|---|
| `loss` | `Tensor` | Total scalar loss (generator + weighted critic, or a single component when training selectively) |
| `generator_loss` | `Tensor` or `None` | Masked cross-entropy loss for token prediction (`None` if `only_train_critic=True`) |
| `critic_loss` | `Tensor` or `None` | Binary cross-entropy loss for the token critic (`None` if no critic or `only_train_generator=True`) |
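To make the critic objective concrete, here is a hedged sketch of the Gumbel sampling and binary cross-entropy target construction described above. The `gumbel_sample` and `critic_loss` helpers are stand-ins, not the library's internals, and the target convention (1 for tokens that differ from the original) is an assumption.

```python
import torch
import torch.nn.functional as F

def gumbel_sample(logits, temperature=1.0):
    # sample token ids from logits via the Gumbel-max trick
    u = torch.rand_like(logits).clamp(min=1e-10)
    noise = -torch.log(-torch.log(u))
    return (logits / max(temperature, 1e-10) + noise).argmax(dim=-1)

def critic_loss(critic_logits, sampled_ids, original_ids):
    # assumed convention: target is 1 where the generated token differs from the original
    targets = (sampled_ids != original_ids).float()
    return F.binary_cross_entropy_with_logits(critic_logits, targets)
```

The critic thus learns to flag positions where the generator's sample disagrees with the ground-truth sequence.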
Usage Examples
Basic Training Loop
```python
import torch
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

# Setup model
net = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)
model = NonAutoregressiveWrapper(net, mask_id = 255).cuda()

# Training loop — token id 255 is reserved for the mask, so sample data from [0, 255)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
tokens = torch.randint(0, 255, (32, 512)).cuda()

losses = model(tokens)
losses.loss.backward()
optimizer.step()
optimizer.zero_grad()

print(f"Total: {losses.loss.item()}, Generator: {losses.generator_loss.item()}")
```
Selective Critic Training
```python
# Train only the critic head (generator runs without gradients)
losses = model(tokens, only_train_critic=True)
losses.loss.backward()  # only critic parameters receive gradients
```
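One common way to use the two flags is to alternate generator and critic updates across steps. The alternation schedule below is a choice for illustration, not something the library prescribes, and `step_kwargs` is a hypothetical helper.

```python
def step_kwargs(step, critic_interval=2):
    # hypothetical helper: pick which head to train on this step
    if step % critic_interval == 0:
        return {"only_train_generator": True}
    return {"only_train_critic": True}

# usage inside a training loop (model and tokens as in the examples above):
# losses = model(tokens, **step_kwargs(step))
# losses.loss.backward()
```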
Related Pages
- Implements Principle
- Requires Environment
- Uses Heuristic