Implementation: Lucidrains x-transformers NonAutoregressiveWrapper.generate
Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | Generative_Models, Inference |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool for iterative masked-sequence generation via progressive unmasking, provided by the x-transformers library.
Description
The generate method of NonAutoregressiveWrapper starts from a fully masked sequence and iteratively unmasks tokens over self.steps steps. The method is decorated with @torch.no_grad() and automatically switches the model to eval mode during generation.
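The save-eval-restore behavior can be sketched with a minimal stand-in object; `TinyModel` and `generate` below are hypothetical illustrations of the pattern, not the library's code:

```python
class TinyModel:
    # hypothetical stand-in with a training flag, mimicking nn.Module
    def __init__(self):
        self.training = True

    def eval(self):
        self.training = False

    def train(self, mode=True):
        self.training = mode


def generate(model):
    was_training = model.training   # remember the original mode
    model.eval()                    # inference mode during generation
    try:
        return "tokens"             # the demasking loop would run here
    finally:
        model.train(was_training)   # restore state before returning


m = TinyModel()
out = generate(m)
assert m.training and out == "tokens"  # original training mode restored
```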
At each step, the method:
- Forward pass: runs the current sequence (with remaining masks) through the wrapped `TransformerWrapper` to obtain logits and embeddings.
- Top-k filtering: optionally filters logits to retain only the top `(1 - filter_thres)` fraction of the vocabulary at each position.
- Temperature-scaled sampling: samples tokens via Gumbel sampling with an annealing temperature that decreases linearly from `start_temperature` to 0 over the generation steps.
- Token placement: fills masked positions with the sampled tokens.
- Confidence scoring: computes a confidence score for each position, either via `(1 - softmax(logits))` gathered at the sampled indices or via the token critic with added Gumbel noise.
- Re-masking: selects the least confident positions (up to `mask_count` for the current step) and re-masks them for the next iteration.
The method restores the original training state before returning.
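The loop above can be sketched in NumPy. This is a minimal illustration of the technique, not the library's internals: `toy_logits`, the linear mask schedule, and the constants are all assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, SEQ_LEN, STEPS = 8, 16, 4
MASK_ID = VOCAB  # mask token sits outside the ordinary vocab here

def toy_logits(seq):
    # stand-in for the wrapped transformer's forward pass
    return rng.normal(size=(len(seq), VOCAB))

def gumbel(shape):
    u = rng.uniform(1e-9, 1.0, size=shape)
    return -np.log(-np.log(u))

seq = np.full(SEQ_LEN, MASK_ID)                      # start fully masked
for step in range(STEPS):
    temperature = 1.0 * (1 - step / STEPS)           # linear anneal toward 0
    logits = toy_logits(seq)
    # temperature-scaled Gumbel sampling (becomes greedy as temperature -> 0)
    sampled = np.argmax(logits + temperature * gumbel(logits.shape), axis=-1)
    seq = np.where(seq == MASK_ID, sampled, seq)     # fill masked positions
    # confidence: 1 - softmax prob of the sampled token (higher = less confident)
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    score = 1 - probs[np.arange(SEQ_LEN), sampled]
    num_mask = int(SEQ_LEN * (1 - (step + 1) / STEPS))  # shrinking mask budget
    if num_mask > 0:
        seq[np.argsort(score)[-num_mask:]] = MASK_ID    # re-mask least confident
```

After the final step the mask budget reaches zero, so every position holds an ordinary vocabulary token.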
Usage
Call on a trained NonAutoregressiveWrapper model to generate sequences. No input tokens are needed; generation starts from a fully masked sequence of length max_seq_len.
generated = model.generate(batch_size=16, start_temperature=1.0)
print(generated.shape) # (16, max_seq_len)
Code Reference
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/nonautoregressive_wrapper.py |
| Lines | L188–273 |
Signature:
@torch.no_grad()
def generate(
self,
batch_size = None,
start_temperature = 1.,
filter_thres = 0.7,
noise_level_scale = 1.,
**kwargs
) -> Tensor:
Import:
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| batch_size | int or None | No | Number of sequences to generate. If None, generates a single sequence and squeezes the batch dimension. Default: None (generates 1 sequence) |
| start_temperature | float | No | Initial sampling temperature, annealed linearly to 0 over the generation steps. Higher values increase diversity. Default: 1.0 |
| filter_thres | float | No | Top-k filtering threshold. Retains the top ceil((1 - filter_thres) * vocab_size) logits. Default: 0.7 |
| noise_level_scale | float | No | Scaling factor for Gumbel noise added to token critic scores. Only used when a token critic is present. Default: 1.0 |
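The filter_thres parameter's top-k rule can be sketched as follows; `top_k_filter` is a hypothetical helper written for this example, not the library's function:

```python
import math
import numpy as np

def top_k_filter(logits, filter_thres=0.7, neg_value=-1e10):
    # keep the top ceil((1 - filter_thres) * vocab_size) logits per position,
    # pushing the rest to a large negative value so softmax ignores them
    vocab = logits.shape[-1]
    k = math.ceil((1 - filter_thres) * vocab)
    kth = np.sort(logits, axis=-1)[..., -k][..., None]  # k-th largest per row
    return np.where(logits >= kth, logits, neg_value)

logits = np.array([[3.0, 1.0, 2.0, 0.5]])
filtered = top_k_filter(logits, filter_thres=0.7)
# vocab=4, k=ceil(0.3 * 4)=2 -> only the two largest logits (3.0, 2.0) survive
```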
Outputs
| Name | Type | Description |
|---|---|---|
| seq | Tensor | Generated token ids. Shape (batch_size, max_seq_len) if batch_size is provided, or (max_seq_len,) if batch_size is None |
Usage Examples
Batch Generation
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper
import torch
# Setup model
net = TransformerWrapper(
num_tokens = 256,
max_seq_len = 512,
attn_layers = Encoder(dim = 512, depth = 6, heads = 8)
)
model = NonAutoregressiveWrapper(net, mask_id = 255, steps = 18).cuda()
# Generate sequences
generated = model.generate(
batch_size = 16,
start_temperature = 1.0,
filter_thres = 0.7
)
print(generated.shape) # (16, 512)
Single Sequence Generation
# Generate a single sequence (batch dimension squeezed)
single = model.generate()
print(single.shape) # (512,)
High-Diversity Generation with Token Critic
# Higher temperature for more diverse outputs; noise_level_scale only
# takes effect if the wrapper was constructed with a token critic
generated = model.generate(
batch_size = 8,
start_temperature = 2.0,
noise_level_scale = 1.5
)