Principle: Lucidrains x-transformers Encoder-Decoder Training
Metadata
| Field | Value |
|---|---|
| Paper | Attention Is All You Need |
| Repository | x-transformers |
| Domains | Deep_Learning, NLP, Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Training procedure for encoder-decoder transformer models that computes cross-entropy loss on target sequences conditioned on encoded source sequences.
Description
Training an encoder-decoder model involves three key steps:
- Encoding the source sequence — the bidirectional encoder processes the full source sequence and produces contextual representations for every source token.
- Cross-attention conditioning — the encoder outputs are fed as cross-attention context to the decoder, allowing each decoder layer to attend over the encoded source.
- Autoregressive loss computation — the decoder computes cross-entropy loss on the target sequence in an autoregressive fashion, predicting each target token given all previous target tokens and the full encoder output.
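The three steps above can be sketched with plain PyTorch, using `torch.nn.Transformer` as a stand-in rather than the x-transformers internals (all module names and dimensions here are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

V, D = 100, 32                      # toy vocab size and model dim
emb = nn.Embedding(V, D)
transformer = nn.Transformer(d_model=D, nhead=4, num_encoder_layers=2,
                             num_decoder_layers=2, batch_first=True)
to_logits = nn.Linear(D, V)

src = torch.randint(0, V, (2, 10))  # random tokens stand in for real data
tgt = torch.randint(0, V, (2, 9))

# 1. Encode the source bidirectionally (no causal mask on the encoder).
memory = transformer.encoder(emb(src))

# 2. The decoder cross-attends to the encoder output ("memory") at every layer,
#    while a causal mask keeps target self-attention autoregressive.
causal = nn.Transformer.generate_square_subsequent_mask(8)
out = transformer.decoder(emb(tgt[:, :-1]), memory, tgt_mask=causal)

# 3. Cross-entropy between next-token predictions and the shifted targets.
loss = F.cross_entropy(to_logits(out).reshape(-1, V), tgt[:, 1:].reshape(-1))
```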
The `XTransformer.forward()` method handles this end-to-end: it encodes the source, optionally applies cross-attention token dropout for regularization, and passes the encoding to the decoder (an `AutoregressiveWrapper` that computes the loss). The target sequence is automatically shifted internally by the decoder wrapper.
Usage
Use during the training loop for sequence-to-sequence tasks. Pass source tokens, target tokens, and an optional source mask:
```python
loss = model(src, tgt, mask=src_mask)
loss.backward()
```
Theoretical Basis
Sequence-to-sequence training: the loss is the negative log-likelihood of the target sequence conditioned on the encoded source:
L = -∑_t log P(y_t | y_{<t}, Encoder(x))
The encoder produces contextual representations of the source sequence. The decoder predicts target tokens autoregressively while cross-attending to the encoder output at every layer.
Teacher forcing: during training, ground-truth target tokens are used as decoder input at each time step rather than the model's own predictions. This avoids compounding errors and enables fully parallel computation across the target sequence.
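The input/label shift that teacher forcing implies can be shown in isolation; the logits tensor below is a random stand-in for actual decoder output:

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 8, 100                 # batch, target length, vocab size
tgt = torch.randint(0, V, (B, T))

# Teacher forcing: feed ground-truth tokens, predict the next token.
dec_input = tgt[:, :-1]             # decoder sees tokens 0..T-2
labels = tgt[:, 1:]                 # and is scored on tokens 1..T-1

logits = torch.randn(B, T - 1, V)   # stand-in for decoder output
loss = F.cross_entropy(logits.reshape(-1, V), labels.reshape(-1))
```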