Implementation:Lucidrains X transformers XTransformer Generate
Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | NLP, Inference |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Concrete tool for generating output sequences from encoder-decoder transformer models provided by the x-transformers library.
Description
The generate method of XTransformer performs conditional sequence generation. It first encodes the source sequence with the encoder, called with return_embeddings=True so that it returns hidden states rather than logits. It then passes those encodings as context, together with the source mask as context_mask, to self.decoder.generate() (the AutoregressiveWrapper's generate method). The decoder generates tokens autoregressively while cross-attending to the encoded source. Any extra keyword arguments are forwarded unchanged, so all sampling strategies supported by AutoregressiveWrapper.generate are available.
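The two-stage flow (encode once, then decode autoregressively against fixed encodings) can be sketched without any deep-learning dependency. This is a toy illustration under stated assumptions, not the library's code: `conditional_generate`, `toy_encode`, and `toy_copy_step` are hypothetical names. The "encoder" is the identity, and the "decoder step" copies the source token at the current position in place of cross-attention and sampling.

```python
from typing import Callable, List, Sequence

# Hypothetical sketch of encoder-decoder generation; not x-transformers code.
def conditional_generate(
    src: Sequence[int],
    start: Sequence[int],
    seq_len: int,
    encode: Callable[[Sequence[int]], List[int]],
    step: Callable[[List[int], List[int], int], int],
) -> List[int]:
    context = encode(src)  # run the encoder once (cf. return_embeddings=True)
    out = list(start)      # decoder prefix (cf. seq_out_start)
    for _ in range(seq_len):
        # each step sees the tokens produced so far plus the fixed encoder context
        out.append(step(out, context, len(start)))
    return out[len(start):]  # return only the new tokens, prefix stripped

def toy_encode(src: Sequence[int]) -> List[int]:
    # Stand-in "encoder": the identity mapping.
    return list(src)

def toy_copy_step(prefix: List[int], context: List[int], start_len: int) -> int:
    # Stand-in "decoder step" for a copy task: emit the source token whose
    # position equals the number of tokens generated so far.
    i = len(prefix) - start_len
    return context[i % len(context)]

print(conditional_generate([5, 7, 9], [1], 3, toy_encode, toy_copy_step))  # [5, 7, 9]
```

Because the encoder output does not depend on the decoder prefix, it is computed once outside the loop; only the decoder runs at every step.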
Usage
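Call on a trained XTransformer to generate output sequences given source sequences. Decoding options such as temperature, the logit filter (filter_logits_fn, e.g. top-k), and an eos_token for early stopping are passed as extra keyword arguments and forwarded to AutoregressiveWrapper.generate.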
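Temperature scaling followed by top-k filtering is one common sampling strategy of the kind selected through the keyword arguments forwarded to AutoregressiveWrapper.generate. The sketch below is a generic, dependency-free illustration of that technique for a single decoding step. `sample_top_k` is a hypothetical helper; the library's own filtering functions operate on batched tensors and may choose k differently.

```python
import math
import random
from typing import List, Optional

def sample_top_k(
    logits: List[float],
    k: int,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Sample one token id from the k highest-scoring logits (illustrative only)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive in this sketch")
    rng = rng or random.Random()
    # Indices of the k largest logits; every other token gets zero probability.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature-scaled softmax over the kept logits (max subtracted for stability).
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top, weights=weights, k=1)[0]

logits = [0.1, 2.0, 0.5, -1.0]
print(sample_top_k(logits, k=1))                        # 1 (k=1 reduces to argmax)
print(sample_top_k(logits, k=2, rng=random.Random(0)))  # either 1 or 2
```

Lower temperatures sharpen the distribution over the kept tokens; smaller k restricts sampling to fewer candidates.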
Code Reference
- Repository: x-transformers
- File: x_transformers/x_transformers.py (Lines: L3875–3878)
Signature:
```python
@torch.no_grad()
def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):
```
Import:
```python
from x_transformers import XTransformer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
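| seq_in | Tensor | Yes | Source token ids of shape (batch, enc_seq_len) |
| seq_out_start | Tensor | Yes | Decoder start/prefix token ids of shape (batch, prefix_len), typically (batch, 1) holding a single start token |
| seq_len | int | Yes | Number of new tokens to generate |
| mask | Tensor or None | No | Boolean source padding mask of shape (batch, enc_seq_len), True at real tokens; passed to the encoder and reused as the decoder's cross-attention context_mask |
| attn_mask | Tensor or None | No | Source attention mask, passed to the encoder only |
| **kwargs | keyword arguments | No | Forwarded unchanged to AutoregressiveWrapper.generate (e.g. temperature, filter_logits_fn, eos_token) |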
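For a batch of variable-length sources, the mask is True at real-token positions and False at padding. A minimal pure-Python sketch, assuming right-padding, follows. `make_src_mask` is a hypothetical helper; with PyTorch the same boolean mask is commonly written as `torch.arange(max_len)[None, :] < lengths[:, None]` for a 1-D `lengths` tensor.

```python
from typing import List, Sequence

def make_src_mask(lengths: Sequence[int], max_len: int) -> List[List[bool]]:
    """Hypothetical helper: True for real source tokens, False for right-padding."""
    return [[pos < n for pos in range(max_len)] for n in lengths]

# Two sources of lengths 3 and 1, right-padded to length 4.
print(make_src_mask([3, 1], 4))
# [[True, True, True, False], [True, False, False, False]]
```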
Outputs
| Name | Type | Description |
|---|---|---|
| generated | Tensor | Newly generated token ids of shape (batch, seq_len); the seq_out_start prefix is not included |
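If an eos_token is forwarded through **kwargs, recent versions of AutoregressiveWrapper.generate may stop early once every row has produced EOS, and they fill positions after each row's first EOS with the wrapper's pad value. A consumer that wants variable-length results can trim each row. The sketch below does so in plain Python; `trim_at_eos` is a hypothetical helper, and the EOS id and pad value used here are assumptions of the example.

```python
from typing import List, Sequence

def trim_at_eos(tokens: Sequence[int], eos_id: int) -> List[int]:
    """Hypothetical helper: keep tokens up to and including the first eos_id."""
    for i, tok in enumerate(tokens):
        if tok == eos_id:
            return list(tokens[: i + 1])
    return list(tokens)  # no EOS: keep the full sequence

batch = [[4, 5, 2, 0, 0], [6, 7, 8, 9, 3]]  # assumes eos_id=2 and pad value 0
print([trim_at_eos(row, eos_id=2) for row in batch])  # [[4, 5, 2], [6, 7, 8, 9, 3]]
```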
Usage Examples
From train_copy.py
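```python
import torch
from x_transformers import XTransformer

# After training: `model` is the XTransformer trained in train_copy.py
# (content token ids 2..17, source length 32), already on the GPU.
model.eval()

src = torch.randint(2, 18, (1, 32)).long().cuda()      # random source sequence to copy
src_mask = torch.ones(1, 32).bool().cuda()             # no padding: every position is valid
start_tokens = (torch.ones((1, 1)) * 1).long().cuda()  # decoder start token id 1
sample = model.generate(src, start_tokens, 32, mask=src_mask)  # shape (1, 32)

# Check accuracy: count positions where the generated copy differs from the source
incorrects = (src != sample).long().sum()
print(f"input: {src}")
print(f"predicted output: {sample}")
print(f"incorrects: {incorrects}")
```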
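The incorrects count is the number of mismatched positions summed over the batch. For a copy task it can also help to report token accuracy and exact-match rate. The sketch below works on plain nested lists (for example obtained with `tensor.tolist()`); `copy_metrics` is a hypothetical helper, not part of x-transformers.

```python
from typing import Dict, Sequence

def copy_metrics(src: Sequence[Sequence[int]], out: Sequence[Sequence[int]]) -> Dict[str, float]:
    """Compare generated rows against source rows for a copy task (hypothetical helper)."""
    if len(src) != len(out):
        raise ValueError("batch sizes differ")
    mismatches, total, exact = 0, 0, 0
    for s, o in zip(src, out):
        if len(s) != len(o):
            raise ValueError("sequence lengths differ")
        diff = sum(a != b for a, b in zip(s, o))  # per-row analogue of (src != sample).sum()
        mismatches += diff
        total += len(s)
        exact += int(diff == 0)
    return {
        "incorrects": mismatches,
        "token_accuracy": 1 - mismatches / total,
        "exact_match": exact / len(src),
    }

print(copy_metrics([[2, 3, 4], [5, 6, 7]], [[2, 3, 4], [5, 0, 7]]))
```

Exact match is the stricter signal for a copy task, since a single wrong token makes the whole copy wrong.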