# Principle: Lucidrains x-transformers Autoregressive Training Loss
## Metadata
| Field | Value |
|---|---|
| Paper | Attention Is All You Need |
| Repository | x-transformers |
| Domains | Deep_Learning, NLP, Training |
| Last Updated | 2026-02-08 18:00 GMT |
## Overview
Next-token prediction training objective that computes cross-entropy loss between model predictions and shifted target sequences for autoregressive language modeling.
## Description
The forward pass of an autoregressive language model takes a sequence of tokens, splits it into input (all but last) and target (all but first), runs the input through the transformer to get logits, and computes cross-entropy loss against the target. This is the standard teacher-forcing training paradigm.
The wrapper automatically handles:
- Input/target splitting — the sequence is split into offset-by-one pairs so the model learns to predict the next token at each position.
- Optional token masking (MLM augmentation) — randomly replaces a fraction of input tokens with a mask token during training, which has been shown to improve autoregressive training.
- Optional attention z-loss — a regularization term that penalizes large logits in the attention mechanism to stabilize training.
- Optional next-embed continuous prediction loss — an auxiliary objective that predicts the next token's embedding vector in continuous space rather than via the discrete softmax.
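The core split-and-predict step above can be sketched in plain NumPy. This is an illustrative stand-alone function, not the library's implementation; the random `fake_model` is a placeholder for a transformer that maps input tokens to logits:

```python
import numpy as np

def autoregressive_loss(tokens, logits_fn, vocab_size):
    """Next-token cross-entropy loss. tokens: (batch, seq_len + 1) ints."""
    inp = tokens[:, :-1]      # (batch, seq_len) - model input
    target = tokens[:, 1:]    # (batch, seq_len) - shifted-by-one targets
    logits = logits_fn(inp)   # (batch, seq_len, vocab_size)
    # Numerically stable softmax over the vocabulary dimension
    logits = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    # Negative log-likelihood of the true next token at each position
    b, t = target.shape
    nll = -np.log(probs[np.arange(b)[:, None], np.arange(t)[None, :], target])
    return nll.mean()

# Toy check: a "model" emitting random logits (stand-in for a transformer)
rng = np.random.default_rng(0)
vocab = 10
tokens = rng.integers(0, vocab, size=(2, 9))  # (batch=2, seq_len + 1 = 9)
fake_model = lambda x: rng.standard_normal((*x.shape, vocab))
loss = autoregressive_loss(tokens, fake_model, vocab)
print(float(loss))  # random logits give a loss near log(vocab) = log 10
```

An untrained (random) model scores close to log(vocab_size) nats per token; training pushes the loss below this uniform-guessing baseline.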
## Usage
Use during the training loop. Pass batches of token sequences of shape (batch, seq_len + 1) to the wrapped model's forward method. The extra token is needed because the wrapper internally splits the sequence into input and target:
```python
# tokens shape: (batch, seq_len + 1)
# The wrapper splits into:
#   input  = tokens[:, :-1]   # (batch, seq_len)
#   target = tokens[:, 1:]    # (batch, seq_len)
loss = model(tokens)
loss.backward()
```
## Theoretical Basis
Next-token prediction: given x_1 ... x_{t-1}, predict x_t.
The training loss is the standard cross-entropy:
Loss = CrossEntropy(logits, target) = -∑_t log P(x_t | x_1 ... x_{t-1})
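A worked toy instance of the sum above, in pure Python. The per-position probabilities are made-up numbers for illustration:

```python
import math

# Hypothetical probabilities the model assigns to the CORRECT next token
# at each of 3 positions:
p_correct = [0.5, 0.25, 0.8]

# Loss = -sum_t log P(x_t | x_1 ... x_{t-1})
loss = -sum(math.log(p) for p in p_correct)
print(round(loss, 4))  # → 2.3026
```

Since 0.5 × 0.25 × 0.8 = 0.1, the loss equals -log(0.1) ≈ 2.3026: the negative log-probability of the whole sequence under the model.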
Teacher forcing feeds ground-truth tokens as input at each time step (rather than the model's own predictions). This avoids compounding errors during training and allows fully parallel computation across the sequence.
Optional MLM masking: randomly replace some input tokens with a mask token during training. This technique was shown to improve autoregressive training by Tay et al. (2022), who demonstrated that exposing the model to corrupted inputs during training acts as a form of regularization and improves generalization.
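A minimal sketch of the masking augmentation in NumPy. The function name, `mask_prob` parameter, and mask-token id are illustrative choices, not the library's exact API; the key point is that only the inputs are corrupted while the targets stay intact:

```python
import numpy as np

def mask_inputs(inp, mask_token_id, mask_prob, rng):
    """Randomly replace a fraction of input tokens with a mask token.
    Targets are untouched, so the model must still predict the true
    next token even where its input was corrupted."""
    mask = rng.random(inp.shape) < mask_prob
    return np.where(mask, mask_token_id, inp)

rng = np.random.default_rng(0)
inp = rng.integers(0, 10, size=(2, 8))  # (batch, seq_len) token ids in [0, 10)
masked = mask_inputs(inp, mask_token_id=10, mask_prob=0.15, rng=rng)
print((masked == 10).mean())  # roughly mask_prob of positions are masked
```

Masking is applied only during training; at inference the model sees uncorrupted tokens.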