
Implementation:Lucidrains X transformers XTransformer Forward



Metadata

Field         Value
Repository    x-transformers
Domains       NLP, Training
Last Updated  2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library for training encoder-decoder sequence-to-sequence models.

Description

The forward method of XTransformer runs the full encoder-decoder forward pass: it encodes the source sequence through the encoder (returning embeddings), optionally applies cross-attention token dropout for regularization, and passes the target sequence and encoder context to the decoder (an AutoregressiveWrapper), which computes the cross-entropy loss. The decoder internally shifts the target for teacher forcing.
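
A minimal sketch of that flow (illustrative only, not the library source; the encoder and decoder calls mirror the TransformerWrapper and AutoregressiveWrapper interfaces, and the cross-attention token dropout step is omitted):

# Schematic of the forward pass. Names and structure are illustrative,
# not a reproduction of the library's internals.
def xtransformer_forward_sketch(encoder, decoder, src, tgt, mask=None):
    # 1. Encode source tokens into context embeddings of shape
    #    (batch, enc_seq_len, dim).
    context = encoder(src, mask=mask, return_embeddings=True)

    # 2. During training, a fraction of context tokens may be dropped
    #    (cross-attention token dropout) for regularization -- omitted here.

    # 3. The decoder is an AutoregressiveWrapper: it shifts tgt right for
    #    teacher forcing and returns the cross-entropy loss over tgt.
    return decoder(tgt, context=context, context_mask=mask)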

Usage

Call during each training step with source tokens, target tokens, and an optional source mask. Returns a scalar loss tensor.

loss = model(src, tgt, mask=src_mask)
loss.backward()
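
The mask marks real tokens as True and padding as False. A minimal way to derive it from the source ids, assuming a hypothetical padding id of 0 (the library does not fix one):

pad_id = 0                  # hypothetical padding token id
src_mask = (src != pad_id)  # (batch, enc_seq_len) boolean mask, True at real tokens
loss = model(src, tgt, mask=src_mask)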

Code Reference

Field       Value
Repository  x-transformers
File        x_transformers/x_transformers.py
Lines       L3880–3891

Signature:

def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):

Import:

from x_transformers import XTransformer

I/O Contract

Inputs

Name                Type            Required  Description
src                 Tensor          Yes       Source token ids of shape (batch, enc_seq_len)
tgt                 Tensor          Yes       Target token ids of shape (batch, dec_seq_len); auto-shifted internally
mask                Tensor or None  No        Source padding mask of shape (batch, enc_seq_len)
attn_mask           Tensor or None  No        Custom attention mask
src_prepend_embeds  Tensor or None  No        Embeddings to prepend to the source sequence

Outputs

Name  Type    Description
loss  Tensor  Scalar cross-entropy loss on the target sequence
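
A shape-level sketch of this contract, given a model configured as in the Usage Examples section below (dim = 128, enc_max_seq_len = 32, dec_max_seq_len = 65). The (batch, num_prepend, dim) shape assumed for src_prepend_embeds is inferred from its description and not documented here:

import torch

batch, enc_len, dec_len, dim = 4, 32, 65, 128

src  = torch.randint(0, 18, (batch, enc_len))          # source token ids
tgt  = torch.randint(0, 18, (batch, dec_len))          # target token ids
mask = torch.ones(batch, enc_len, dtype=torch.bool)    # True at real (non-pad) positions

# Assumed shape (batch, num_prepend, dim): embeddings prepended to the source.
prepend = torch.randn(batch, 2, dim)

loss = model(src, tgt, mask=mask, src_prepend_embeds=prepend)
assert loss.ndim == 0  # scalar loss, per the output contract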

Usage Examples

From train_copy.py

from x_transformers import XTransformer
import torch

model = XTransformer(
    dim = 128,
    tie_token_emb = True,    # share token embeddings between encoder and decoder
    return_tgt_loss = True,  # forward returns the cross-entropy loss on tgt
    enc_num_tokens = 18,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = 32,
    dec_num_tokens = 18,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = 65
).cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Training step for a toy copy task: the target is a start token followed
# by the source sequence repeated twice (1 + 32 + 32 = 65 tokens).
src = torch.randint(2, 18, (32, 32)).long().cuda()
prefix = torch.ones((32, 1)).long().cuda()
tgt = torch.cat((prefix, src, src), 1)
src_mask = torch.ones(32, 32).bool().cuda()  # no padding, so all positions are real

loss = model(src, tgt, mask=src_mask)
loss.backward()
optimizer.step()
optimizer.zero_grad()
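
For inference, the companion generate method decodes the target autoregressively from the encoder context. A sketch continuing the example above; the call pattern follows the repository's encoder-decoder copy example, and the 64-token output length is an assumption for this toy task:

# Autoregressive decoding with the trained model (sketch).
model.eval()
start_tokens = torch.ones((32, 1)).long().cuda()  # same start token as the target prefix
with torch.no_grad():
    sample = model.generate(src, start_tokens, 64, mask=src_mask)
# sample holds the generated target tokens, excluding the start token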

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
