Workflow:Lucidrains X transformers Encoder Decoder Sequence to Sequence
| Knowledge Sources | |
|---|---|
| Domains | Sequence_to_Sequence, Deep_Learning, Transformer_Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
End-to-end process for building and training a full encoder-decoder transformer for sequence-to-sequence tasks using x-transformers' XTransformer class.
Description
This workflow covers building and training a complete encoder-decoder transformer with x-transformers. The XTransformer convenience class bundles an Encoder and a Decoder (with cross-attention) into a single model that accepts source and target sequences and returns the loss. The encoder processes the input sequence bidirectionally, while the decoder attends to the encoder output via cross-attention and generates the target sequence autoregressively. This architecture suits tasks where input and output are different sequences, such as translation, summarization, and sequence copying. The model supports tied token embeddings between encoder and decoder and provides a generate() method for inference.
Usage
Use this workflow when you need to map an input sequence to an output sequence of potentially different length or vocabulary, that is, when the task requires an encoder to read the full input context and a decoder to generate output conditioned on that context. Common applications include machine translation, text summarization, code generation from natural language, and other sequence transduction tasks.
Execution Steps
Step 1: Install Dependencies
Install the x-transformers package. The XTransformer class is part of the core package and requires no additional dependencies beyond the standard x-transformers installation.
Key considerations:
- A CUDA-enabled PyTorch build is needed only for GPU training; the model also runs on CPU
- The XTransformer class is imported directly from x_transformers
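A quick environment check along these lines can confirm the installation before continuing. It assumes the package was installed from PyPI under the name x-transformers (for example with pip install x-transformers); the helper name check_environment is illustrative and not part of the library.

```python
import importlib.util


def check_environment():
    """Report whether the packages this workflow relies on are importable."""
    status = {
        "torch": importlib.util.find_spec("torch") is not None,
        "x_transformers": importlib.util.find_spec("x_transformers") is not None,
        "cuda": False,
    }
    if status["torch"]:
        import torch  # imported only when the package is present

        status["cuda"] = bool(torch.cuda.is_available())
    return status


print(check_environment())
```

If "x_transformers" is reported as False, install the package before moving on; "cuda" being False only means training will run on CPU.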
Step 2: Prepare Paired Dataset
Create a dataset that yields (source, target, source_mask) tuples. The source sequence is the encoder input, the target sequence is fed to the decoder under teacher forcing, and the source mask marks which source positions are valid. The target should begin with a start-of-sequence token.
Key considerations:
- Source and target may have different lengths, bounded by enc_max_seq_len and dec_max_seq_len respectively
- Source masks should be boolean tensors marking valid (non-padding) positions
- The XTransformer forward() handles the decoder target shifting internally when return_tgt_loss=True
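One way to turn variable-length token pairs into the (source, target, source_mask) tuple described above is a small collate function. This is a minimal sketch, not the library's own data pipeline: the PAD_ID and SOS_ID values and the collate_pairs name are assumptions made for illustration, and how padded target positions are treated by the loss depends on the model configuration and is not covered here.

```python
import torch

PAD_ID = 0  # hypothetical padding token id
SOS_ID = 1  # hypothetical start-of-sequence token id


def collate_pairs(pairs):
    """Pad a list of (source_ids, target_ids) pairs into batch tensors.

    Returns (src, tgt, src_mask); src_mask is True at real source tokens.
    """
    sources = [list(s) for s, _ in pairs]
    targets = [[SOS_ID] + list(t) for _, t in pairs]  # prepend SOS to each target

    src_len = max(len(s) for s in sources)
    tgt_len = max(len(t) for t in targets)

    src = torch.full((len(pairs), src_len), PAD_ID, dtype=torch.long)
    tgt = torch.full((len(pairs), tgt_len), PAD_ID, dtype=torch.long)
    src_mask = torch.zeros((len(pairs), src_len), dtype=torch.bool)

    for i, (s, t) in enumerate(zip(sources, targets)):
        src[i, : len(s)] = torch.tensor(s, dtype=torch.long)
        tgt[i, : len(t)] = torch.tensor(t, dtype=torch.long)
        src_mask[i, : len(s)] = True  # mark non-padding source positions

    return src, tgt, src_mask


src, tgt, src_mask = collate_pairs([([5, 6, 7], [8, 9]), ([4, 3], [2, 2, 2, 2])])
print(src.shape, tgt.shape)
print(src_mask)
```

A function like this can be passed as collate_fn to a torch.utils.data.DataLoader so that each batch arrives already padded and masked.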
Step 3: Configure XTransformer Model
Instantiate XTransformer with separate encoder and decoder configurations. Encoder and decoder settings are passed as keyword arguments with enc_ and dec_ prefixes (vocabulary size, maximum sequence length, depth, number of heads), while the model dimension dim is a single argument shared by both. Cross-attention applies to the decoder only.
What happens:
- An Encoder stack is created with bidirectional self-attention
- A Decoder stack is created with causal self-attention and cross-attention layers
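- Token embedding tables are created for both encoder and decoder, and the decoder's output head projects to logits over the target vocabulary
- Setting tie_token_emb=True shares the token embedding weights between encoder and decoder, which presupposes a shared vocabulary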
Step 4: Train the Model
Run the training loop feeding source sequences, target sequences, and source masks into the XTransformer. The model returns the cross-entropy loss directly. The loss is computed over the target sequence, training the decoder to predict each next token given the encoder output and previous target tokens.
Key considerations:
- The forward() call returns a scalar loss when return_tgt_loss=True is set during initialization
- Use a standard Adam optimizer with a learning rate around 3e-4
- Monitor training loss to verify convergence
Step 5: Generate Output Sequences
Use the XTransformer's generate() method to produce output sequences given source inputs. Provide the source sequence, start tokens for the decoder, the desired output length, and the source mask. The decoder generates tokens autoregressively, conditioned on the encoded source.
Key considerations:
- Start tokens initialize the decoder generation (typically a dedicated SOS token ID)
- The generate method returns the predicted output sequence
- Compare generated output against expected output to evaluate model quality