
Implementation:Lucidrains X transformers XTransformer Generate

From Leeroopedia


Metadata

Field | Value
Repository | x-transformers
Domains | NLP, Inference
Last Updated | 2026-02-08 18:00 GMT

Overview

The generate method of XTransformer, an encoder-decoder transformer provided by the x-transformers library, produces output sequences conditioned on source sequences.

Description

The generate method of XTransformer performs conditional sequence generation. It first encodes the source sequence through the encoder (with return_embeddings=True), then passes the encoder output as context to self.decoder.generate() (the AutoregressiveWrapper's generate method), which generates tokens autoregressively with cross-attention to the encoded source. All sampling strategies from AutoregressiveWrapper.generate are available.
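The encode-then-autoregressive-decode flow described above can be sketched without any dependencies. The encode and decode_step stand-ins below are illustrative toys, not the library's internals; the "decoder" simply copies the source, mimicking the copy task from the repository's train_copy.py:

```python
# Dependency-free sketch of the flow inside XTransformer.generate:
# encode the source once, then decode autoregressively with the
# encoder output as context. All names here are illustrative.

def encode(src_tokens):
    # Stand-in for the encoder: the real encoder returns per-token
    # embeddings (return_embeddings=True); here we just echo tokens.
    return list(src_tokens)

def decode_step(context, prefix):
    # Stand-in for one decoder step with cross-attention: emit the
    # source token at the current position (a "copy" decoder).
    return context[len(prefix) - 1]

def generate(src_tokens, start_token, seq_len):
    context = encode(src_tokens)               # encoder pass
    out = [start_token]                        # decoder start/prefix
    for _ in range(seq_len):                   # autoregressive loop
        out.append(decode_step(context, out))  # next-token prediction
    return out[1:]                             # drop the start token

print(generate([5, 7, 9], 1, 3))  # → [5, 7, 9]
```

In the real method, decode_step is self.decoder.generate() (the AutoregressiveWrapper), which applies the chosen sampling strategy instead of a deterministic copy.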

Usage

Call on a trained XTransformer to generate output sequences given source sequences.

Code Reference

  • Repository: x-transformers
  • File: x_transformers/x_transformers.py
  • Lines: L3875–3878

Signature:

@torch.no_grad()
def generate(self, seq_in, seq_out_start, seq_len, mask = None, attn_mask = None, **kwargs):

Import:

from x_transformers import XTransformer

I/O Contract

Inputs

Name | Type | Required | Description
seq_in | Tensor | Yes | Source token ids of shape (batch, enc_seq_len)
seq_out_start | Tensor | Yes | Decoder start/prefix tokens of shape (batch, 1)
seq_len | int | Yes | Number of tokens to generate
mask | Tensor or None | No | Source padding mask of shape (batch, enc_seq_len)
attn_mask | Tensor or None | No | Source attention mask
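The optional mask marks real (non-padding) source positions. A minimal sketch of building such a boolean mask for a batch of variable-length sources (the lengths and padded width here are made up for illustration):

```python
import torch

# Two sequences of true lengths 5 and 3, padded to a common width of 8.
lengths = torch.tensor([5, 3])
max_len = 8

# True where a position holds a real token, False where it is padding.
mask = torch.arange(max_len)[None, :] < lengths[:, None]
print(mask.shape)  # torch.Size([2, 8])
print(mask[1])     # first 3 entries True, remaining 5 False
```

The resulting (batch, enc_seq_len) boolean tensor can be passed directly as the mask argument to generate.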

Outputs

Name | Type | Description
generated | Tensor | Generated token ids of shape (batch, seq_len)

Usage Examples

From train_copy.py

from x_transformers import XTransformer
import torch

# After training...
model.eval()
src = torch.randint(2, 18, (1, 32)).long().cuda()
src_mask = torch.ones(1, 32).bool().cuda()
start_tokens = (torch.ones((1, 1)) * 1).long().cuda()

sample = model.generate(src, start_tokens, 32, mask=src_mask)

# Check accuracy
incorrects = (src != sample).long().abs().sum()
print(f"input: {src}")
print(f"predicted output: {sample}")
print(f"incorrects: {incorrects}")

Related Pages

Implements Principle

Requires Environment
