Implementation: lucidrains/x-transformers XTransformer.forward
Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | NLP, Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Concrete entry point for training encoder-decoder sequence-to-sequence models, provided by the x-transformers library.
Description
The forward method of XTransformer runs the full encoder-decoder forward pass: it encodes the source sequence through the encoder (returning embeddings), optionally applies cross-attention token dropout for regularization, and passes the target sequence and encoder context to the decoder (AutoregressiveWrapper) which computes the cross-entropy loss. The decoder internally shifts the target for teacher forcing.
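The internal target shift is the standard next-token setup. The snippet below is a minimal sketch of that shift on plain tensors, for intuition only; it is not the library's own code.
import torch

# Minimal illustration of the teacher-forcing shift described above:
# the decoder conditions on tgt[:, :-1] and is scored against tgt[:, 1:].
tgt = torch.tensor([[1, 5, 7, 9, 2]])   # (batch=1, dec_seq_len=5)
decoder_input = tgt[:, :-1]             # tokens fed to the decoder
labels = tgt[:, 1:]                     # next-token targets for cross-entropy
print(decoder_input)   # tensor([[1, 5, 7, 9]])
print(labels)          # tensor([[5, 7, 9, 2]])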
Usage
Call during each training step with source tokens, target tokens, and an optional source mask. Returns a scalar loss tensor.
loss = model(src, tgt, mask=src_mask)
loss.backward()
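If the source batch contains padding, the boolean mask can be derived directly from the token ids. The PAD_ID below is an assumed placeholder for illustration, not a value defined by the library.
import torch

PAD_ID = 0                              # assumed padding token id
src = torch.tensor([[4, 7, 9, 0, 0],
                    [3, 5, 0, 0, 0]])   # (batch, enc_seq_len), right-padded
src_mask = src != PAD_ID                # True where the encoder should attend
print(src_mask)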
Code Reference
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/x_transformers.py |
| Lines | L3880–3891 |
Signature:
def forward(self, src, tgt, mask = None, attn_mask = None, src_prepend_embeds = None):
Import:
from x_transformers import XTransformer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| src | Tensor | Yes | Source token ids of shape (batch, enc_seq_len) |
| tgt | Tensor | Yes | Target token ids of shape (batch, dec_seq_len), auto-shifted internally for teacher forcing |
| mask | Tensor or None | No | Source padding mask of shape (batch, enc_seq_len) |
| attn_mask | Tensor or None | No | Custom attention mask |
| src_prepend_embeds | Tensor or None | No | Embeddings to prepend to the source sequence |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Scalar cross-entropy loss on the target sequence |
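As a quick sanity check of this contract, a small CPU model can be built with the same constructor arguments used in the training example below; the reduced sizes here are illustrative assumptions, not recommended hyperparameters.
import torch
from x_transformers import XTransformer

model = XTransformer(
    dim = 64,
    return_tgt_loss = True,
    enc_num_tokens = 18, enc_depth = 1, enc_heads = 2, enc_max_seq_len = 32,
    dec_num_tokens = 18, dec_depth = 1, dec_heads = 2, dec_max_seq_len = 65
)
src = torch.randint(2, 18, (2, 32)).long()
tgt = torch.randint(2, 18, (2, 65)).long()
loss = model(src, tgt, mask = torch.ones(2, 32).bool())
print(loss.shape)   # torch.Size([]) - a scalar loss tensor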
Usage Examples
From train_copy.py
from x_transformers import XTransformer
import torch
model = XTransformer(
    dim = 128,
    tie_token_emb = True,
    return_tgt_loss = True,
    enc_num_tokens = 18,
    enc_depth = 3,
    enc_heads = 8,
    enc_max_seq_len = 32,
    dec_num_tokens = 18,
    dec_depth = 3,
    dec_heads = 8,
    dec_max_seq_len = 65
).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Training step
src = torch.randint(2, 18, (32, 32)).long().cuda()   # random source tokens: batch 32, length 32
prefix = torch.ones((32, 1)).long().cuda()           # start token prepended to every target
tgt = torch.cat((prefix, src, src), 1)               # copy task: target is the source repeated twice
src_mask = torch.ones(32, 32).bool().cuda()          # attend to all source positions (no padding)
loss = model(src, tgt, mask=src_mask)
loss.backward()
optimizer.step()
optimizer.zero_grad()
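For evaluation, decoding typically goes through the wrapper's generate method rather than forward. Continuing the example above, the call below mirrors the pattern used in train_copy.py; the exact generate signature and the 64-token generation length are assumptions for illustration.
# Generation sketch (assumed signature: generate(seq_in, seq_out_start, seq_len, mask=...))
model.eval()
with torch.no_grad():
    start_tokens = torch.ones((32, 1)).long().cuda()   # same start prefix as the targets above
    sample = model.generate(src, start_tokens, 64, mask = src_mask)
    print(sample.shape)   # expected (32, 64): generated tokens only, start prefix excluded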