Implementation:Lucidrains X transformers NonAutoregressiveWrapper Generate

From Leeroopedia


Metadata

Field         Value
Repository    x-transformers
Domains       Generative_Models, Inference
Last Updated  2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library for iterative masked sequence generation via progressive unmasking.

Description

The generate method of NonAutoregressiveWrapper starts from a fully masked sequence and iteratively unmasks tokens over self.steps steps. The method is decorated with @torch.no_grad() and automatically switches the model to eval mode during generation.

At each step, the method:

  1. Forward pass — runs the current sequence (with remaining masks) through the wrapped TransformerWrapper to obtain logits and embeddings.
  2. Top-k filtering — optionally filters logits to retain only the top (1 - filter_thres) fraction of the vocabulary at each position.
  3. Temperature-scaled sampling — samples tokens using Gumbel sampling with an annealing temperature that decreases linearly from start_temperature to 0 over the generation steps.
  4. Token placement — fills masked positions with the sampled tokens.
  5. Confidence scoring — computes a score per position, either as 1 - softmax(logits) gathered at the sampled token ids (so higher scores mark less confident positions) or via the token critic with added Gumbel noise.
  6. Re-masking — selects the least confident positions (up to mask_count for the current step) and re-masks them for the next iteration.

The method restores the original training state before returning.
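The six steps above can be sketched as a toy pure-Python loop. This is an illustration of the progressive-unmasking idea under simplifying assumptions, not the library's implementation: `logits_fn` is a hypothetical stand-in for the wrapped transformer's forward pass, top-k filtering and the token critic are omitted, and the re-masking schedule is assumed linear.

```python
import math
import random

def toy_generate(logits_fn, seq_len, vocab_size, mask_id, steps, start_temperature=1.0):
    """Toy progressive-unmasking loop (illustration only, not x-transformers code).

    logits_fn(seq) must return a seq_len x vocab_size list of logits; it is a
    hypothetical stand-in for the wrapped transformer's forward pass.
    """
    def gumbel():
        # sample standard Gumbel noise, clamping u away from 0 and 1
        u = min(max(random.random(), 1e-10), 1.0 - 1e-10)
        return -math.log(-math.log(u))

    seq = [mask_id] * seq_len  # start from a fully masked sequence
    for step in range(steps):
        # temperature anneals linearly from start_temperature toward 0
        temperature = start_temperature * (1.0 - step / steps)
        logits = logits_fn(seq)  # step 1: forward pass on the partly masked seq
        sampled, scores = [], []
        for pos in range(seq_len):
            if temperature > 0:
                # step 3: Gumbel-max sampling, argmax(logits / T + Gumbel noise)
                perturbed = [l / temperature + gumbel() for l in logits[pos]]
            else:
                perturbed = logits[pos]  # T = 0 degenerates to greedy argmax
            tok = max(range(vocab_size), key=lambda v: perturbed[v])
            sampled.append(tok)
            # step 5: uncertainty = 1 - softmax probability of the sampled token
            m = max(logits[pos])
            exps = [math.exp(l - m) for l in logits[pos]]
            scores.append(1.0 - exps[tok] / sum(exps))
        seq = sampled  # step 4: fill every position with its sampled token
        # step 6: re-mask the least confident positions for the next iteration
        # (assumed linear schedule; the library supports other schedules)
        mask_count = round(seq_len * (1.0 - (step + 1) / steps))
        worst = sorted(range(seq_len), key=lambda p: scores[p], reverse=True)
        for pos in worst[:mask_count]:
            seq[pos] = mask_id
    return seq
```

On the final step `mask_count` reaches 0, so the returned sequence contains no remaining mask tokens.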

Usage

Call on a trained NonAutoregressiveWrapper model to generate sequences. No input tokens are needed; generation starts from a fully masked sequence of length max_seq_len.

generated = model.generate(batch_size=16, start_temperature=1.0)
print(generated.shape)  # (16, max_seq_len)

Code Reference

Field       Value
Repository  x-transformers
File        x_transformers/nonautoregressive_wrapper.py
Lines       L188–273

Signature:

@torch.no_grad()
def generate(
    self,
    batch_size = None,
    start_temperature = 1.,
    filter_thres = 0.7,
    noise_level_scale = 1.,
    **kwargs
) -> Tensor:

Import:

from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

I/O Contract

Inputs

batch_size (int or None, optional)
    Number of sequences to generate. If None, generates a single sequence and squeezes the batch dimension. Default: None (generates 1 sequence).

start_temperature (float, optional)
    Initial sampling temperature, annealed linearly to 0 over the generation steps. Higher values increase diversity. Default: 1.0.

filter_thres (float, optional)
    Top-k filtering threshold. Retains the top ceil((1 - filter_thres) * vocab_size) logits. Default: 0.7.

noise_level_scale (float, optional)
    Scaling factor for the Gumbel noise added to token critic scores. Only used when a token critic is present. Default: 1.0.
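The filter_thres semantics above can be made concrete with a small helper. This is a sketch of the retention-count arithmetic only, not a function from x-transformers:

```python
import math

def top_k_count(vocab_size, filter_thres):
    """Number of logits retained per position by top-k filtering:
    ceil((1 - filter_thres) * vocab_size), per the contract above."""
    return math.ceil((1 - filter_thres) * vocab_size)

# With the defaults on this page (vocab of 256, filter_thres = 0.7),
# sampling considers only the 77 highest-scoring tokens per position.
```

Raising filter_thres therefore narrows the candidate set: at filter_thres = 0.9 only 26 of 256 tokens survive filtering.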

Outputs

seq (Tensor)
    Generated token ids. Shape (batch_size, max_seq_len) when batch_size is provided, or (max_seq_len,) when batch_size is None.
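The shape contract can be stated as a tiny helper, a toy sketch of the squeeze behavior described above and not part of the library:

```python
def generate_output_shape(max_seq_len, batch_size=None):
    """Sketch of the generate() output-shape contract: the batch
    dimension is squeezed away when batch_size is None."""
    if batch_size is None:
        return (max_seq_len,)
    return (batch_size, max_seq_len)
```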

Usage Examples

Batch Generation

from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
import torch

# Setup model
net = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)
model = NonAutoregressiveWrapper(net, mask_id = 255, steps = 18).cuda()

# Generate sequences
generated = model.generate(
    batch_size = 16,
    start_temperature = 1.0,
    filter_thres = 0.7
)
print(generated.shape)  # (16, 512)

Single Sequence Generation

# Generate a single sequence (batch dimension squeezed)
single = model.generate()
print(single.shape)  # (512,)

High-Diversity Generation with Token Critic

# Higher temperature and noise scale for more diverse outputs;
# noise_level_scale only takes effect if the wrapper was built with a token critic
generated = model.generate(
    batch_size = 8,
    start_temperature = 2.0,
    noise_level_scale = 1.5
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
