
Implementation:Lucidrains X transformers AutoregressiveWrapper Forward

From Leeroopedia


Metadata

Repository: x-transformers
Domains: NLP, Training
Last Updated: 2026-02-08 18:00 GMT

Overview

A concrete tool provided by the x-transformers library for computing the autoregressive next-token prediction training loss.

Description

The forward method of AutoregressiveWrapper takes token sequences, internally splits them into input and target (offset by 1), runs the input through the wrapped TransformerWrapper to get logits, and computes cross-entropy loss.
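The offset-by-one split and cross-entropy computation described above can be sketched in plain Python. This is a simplified illustration of the idea, not the library's implementation; `next_token_loss` and `logits_fn` are hypothetical names.

```python
import math

def next_token_loss(tokens, logits_fn):
    """Average next-token cross-entropy over one sequence (illustrative sketch)."""
    inp, target = tokens[:-1], tokens[1:]   # split input/target, offset by 1
    logits = logits_fn(inp)                 # one list of vocab logits per position
    total = 0.0
    for pos_logits, t in zip(logits, target):
        m = max(pos_logits)                 # stabilize the log-sum-exp
        log_z = m + math.log(sum(math.exp(l - m) for l in pos_logits))
        total += log_z - pos_logits[t]      # negative log-likelihood of the target
    return total / len(target)
```

With uniform logits over a vocabulary of size V, this loss comes out to log V, the entropy of a uniform distribution over the vocabulary.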

Supports optional features:

  • Token masking (MLM augmentation) — randomly replaces a fraction of input tokens with a mask token during training to act as regularization.
  • Attention z-loss regularization — penalizes large attention logits to stabilize training dynamics.
  • Next-embedding continuous prediction loss — an auxiliary objective that predicts the next token's embedding in continuous space alongside the discrete softmax prediction.
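The token-masking augmentation in the first bullet amounts to replacing each input token with a mask id with some probability. A minimal sketch of that idea follows; the function and argument names are hypothetical, not the library's API.

```python
import random

def mask_input_tokens(input_ids, mask_id, mask_prob, rng=random):
    """Randomly replace a fraction of input tokens with mask_id (regularization)."""
    return [mask_id if rng.random() < mask_prob else t for t in input_ids]
```

Only the inputs are masked; the targets stay intact, so the model must still predict the original next token.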

Usage

Call during each training step. Pass a batch of token ids of shape (batch, seq_len + 1). Returns a scalar loss tensor.

loss = model(tokens)        # returns scalar loss
loss.backward()

Code Reference

Repository: x-transformers
File: x_transformers/autoregressive_wrapper.py
Lines: L511–585

Signature:

def forward(
    self,
    x: Tensor,
    return_outputs: bool = False,
    prepend_embeds: Tensor | None = None,
    **kwargs
) -> Tensor:

Import:

from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

I/O Contract

Inputs

x (Tensor, required): token ids of shape (batch, seq_len + 1); includes the target token
return_outputs (bool, optional): if True, returns a (loss, (logits, cache)) tuple
prepend_embeds (Tensor or None, optional): optional embeddings to prepend before the sequence

Outputs

loss (Tensor): scalar cross-entropy loss (default)
(loss, (logits, cache)) (tuple): returned when return_outputs=True; also includes the logits and KV cache

Usage Examples

Basic Training Step

from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
import torch

# Setup model
model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)
model = AutoregressiveWrapper(model).cuda()

# Training step
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
tokens = torch.randint(0, 256, (4, 1025)).cuda()  # batch=4, seq_len+1=1025

loss = model(tokens)
loss.backward()
optimizer.step()
optimizer.zero_grad()

Gradient Accumulation (from train_enwik8.py)

for _ in range(GRADIENT_ACCUMULATE_EVERY):
    loss = model(next(train_loader))
    (loss / GRADIENT_ACCUMULATE_EVERY).backward()

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
optimizer.step()
optimizer.zero_grad()
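Dividing each micro-batch loss by GRADIENT_ACCUMULATE_EVERY before backward() makes the accumulated gradient equal the gradient of the mean loss over the effective batch. A toy numeric check of that equivalence (the gradient values are made up for illustration):

```python
ACCUM = 4
micro_grads = [1.0, 2.0, 3.0, 4.0]      # hypothetical per-micro-batch gradients

# each backward() call on (loss / ACCUM) adds grad / ACCUM into .grad
accumulated = sum(g / ACCUM for g in micro_grads)

# one backward() on the mean loss over all micro-batches gives the same value
mean_loss_grad = sum(micro_grads) / ACCUM
assert abs(accumulated - mean_loss_grad) < 1e-12
```

This is why the effective learning-rate behavior matches a single large batch, up to batch-statistics effects such as those in normalization layers.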

Related Pages

  • Implements Principle
  • Requires Environment
  • Uses Heuristic
