Implementation:Lucidrains X transformers AutoregressiveWrapper Forward
Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | NLP, Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool, provided by the x-transformers library, for computing the autoregressive next-token prediction training loss.
Description
The forward method of AutoregressiveWrapper takes token sequences, internally splits them into input and target (offset by 1), runs the input through the wrapped TransformerWrapper to get logits, and computes cross-entropy loss.
Supports optional features:
- Token masking (MLM augmentation) — randomly replaces a fraction of input tokens with a mask token during training to act as regularization.
- Attention z-loss regularization — penalizes large attention logits to stabilize training dynamics.
- Next-embedding continuous prediction loss — an auxiliary objective that predicts the next token's embedding in continuous space alongside the discrete softmax prediction.
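The core split-and-shift behavior described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's actual code; the helper name `autoregressive_loss` and the `logit_fn` callable are invented for the sketch:

```python
import torch
import torch.nn.functional as F

def autoregressive_loss(logit_fn, tokens):
    # Split (batch, seq_len + 1) tokens into input and target, offset by 1:
    # position i of the input predicts token i + 1 of the sequence.
    inp, target = tokens[:, :-1], tokens[:, 1:]
    logits = logit_fn(inp)  # (batch, seq_len, num_tokens)
    # cross_entropy expects class scores along dim 1, so move them there
    return F.cross_entropy(logits.transpose(1, 2), target)
```

AutoregressiveWrapper performs this split internally, which is why callers pass sequences of length seq_len + 1 and receive a single scalar loss back.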
Usage
Call during each training step. Pass a batch of token ids of shape (batch, seq_len + 1). Returns a scalar loss tensor.
```python
loss = model(tokens)  # returns scalar loss
loss.backward()
```
Code Reference
| Field | Value |
|---|---|
| Repository | x-transformers |
| File | x_transformers/autoregressive_wrapper.py |
| Lines | L511–585 |
Signature:
```python
def forward(
    self,
    x: Tensor,
    return_outputs: bool = False,
    prepend_embeds: Tensor | None = None,
    **kwargs
) -> Tensor:
```
Import:
```python
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | Tensor | Yes | Token ids of shape (batch, seq_len + 1); includes the target token |
| return_outputs | bool | No | If True, returns a (loss, (logits, cache)) tuple |
| prepend_embeds | Tensor or None | No | Optional embeddings to prepend before the sequence |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Scalar cross-entropy loss (default) |
| (loss, (logits, cache)) | Tuple | If return_outputs=True, also returns logits and the KV cache |
Usage Examples
Basic Training Step
```python
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

# Setup model
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
model = AutoregressiveWrapper(model).cuda()

# Training step
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
tokens = torch.randint(0, 256, (4, 1025)).cuda()  # batch = 4, seq_len + 1 = 1025
loss = model(tokens)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```
Gradient Accumulation (from train_enwik8.py)
```python
for _ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = model(next(train_loader))
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
optimizer.zero_grad()
```
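Dividing each micro-batch loss by GRADIENT_ACCUMULATE_EVERY is what makes accumulation equivalent to one larger batch: for a mean-reduced loss, the accumulated gradients match a single backward pass over the concatenated batch. A small self-contained check of this property, using a toy linear model rather than x-transformers:

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
data = torch.randn(8, 3)
target = torch.randn(8)

def mse(x, y):
    # mean-reduced loss, as in cross-entropy's default reduction
    return ((x @ w - y) ** 2).mean()

# One backward pass over the full batch
mse(data, target).backward()
big_grad = w.grad.clone()

# Two micro-batches of 4, each scaled loss accumulated into w.grad
w.grad = None
for chunk_x, chunk_y in zip(data.chunk(2), target.chunk(2)):
    (mse(chunk_x, chunk_y) / 2).backward()

assert torch.allclose(w.grad, big_grad, atol=1e-5)
```

The same reasoning applies to the training loop above with GRADIENT_ACCUMULATE_EVERY micro-batches in place of two.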
Related Pages
- Implements Principle
- Requires Environment
- Uses Heuristic