
Implementation: Lucidrains x-transformers AutoregressiveWrapper Init


Metadata

Page Type: Implementation (API Doc)
Knowledge Sources: Repo (x-transformers)
Domains: NLP, Training
Last Updated: 2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library for wrapping a transformer model with autoregressive training and generation capabilities.

Description

AutoregressiveWrapper wraps a TransformerWrapper instance to add the essential machinery for autoregressive language modeling:

  • Automatic teacher-forcing: The forward() method splits the input sequence into inp = seq[:, :-1] and target = seq[:, 1:], feeds the input through the wrapped model, and computes the loss against the target (see the sketch after this list).
  • Cross-entropy loss computation: The forward pass returns a scalar cross-entropy loss (using F.cross_entropy with a configurable ignore_index for padding). When attention z-loss regularization is enabled, the auxiliary z-loss is added to the main loss.
  • Multiple generation strategies: The generate() method supports top-k, top-p (nucleus), min-p, top-a, contrastive decoding, and beam search. Generation uses @torch.no_grad() and automatically switches to eval mode.
  • Optional MLM-style masking: When mask_prob > 0, a fraction of input tokens are randomly masked during training. This technique, from the paper Filling in the Blanks (arXiv:2210.13432), augments autoregressive training with a masked language modeling signal to improve model quality.
  • Attention z-loss regularization: When add_attn_z_loss=True, the wrapper extracts and adds the attention z-loss from the model's intermediates to the primary cross-entropy loss, providing a regularization signal for attention logit stability.
  • Next embedding loss: When the wrapped model has add_continuous_pred_head=True, an additional continuous prediction loss (weighted by next_embed_loss_weight) is computed alongside the discrete token prediction loss.
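To make the teacher-forcing behavior above concrete, the following is a minimal sketch of the shift-and-score pattern described in the first two bullets. It is illustrative only, not the library's exact implementation; net stands for any module that maps token ids of shape (batch, seq) to logits of shape (batch, seq, num_tokens).

import torch
import torch.nn.functional as F

def autoregressive_loss_sketch(net, seq, ignore_index = -100):
    # Shift by one position: each token is asked to predict its successor.
    inp, target = seq[:, :-1], seq[:, 1:]

    # Forward pass through the wrapped model -> (batch, seq_len - 1, num_tokens)
    logits = net(inp)

    # F.cross_entropy expects the class dimension second, so transpose it in.
    return F.cross_entropy(
        logits.transpose(1, 2),
        target,
        ignore_index = ignore_index   # padded targets contribute nothing to loss or gradient
    )

Calling the wrapper itself (loss = model(tokens)) performs this split internally, and optionally adds the z-loss and next-embedding terms described above.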

Code Reference

Source Location

x-transformers repo, file x_transformers/autoregressive_wrapper.py, lines 156-183.

Signature

class AutoregressiveWrapper(Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0,
        mask_prob = 0.,
        add_attn_z_loss = False,
        next_embed_loss_weight = 0.1
    ):

Import

from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
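In recent releases the class is also re-exported at the package top level, so the shorter form below should work as well (verify against your installed version):

from x_transformers import AutoregressiveWrapper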

I/O Contract

Constructor Inputs

  • net (TransformerWrapper, required): The base transformer model to wrap. Must be a TransformerWrapper configured with Decoder attention layers.
  • ignore_index (int, optional, default -100): Token index to ignore in the cross-entropy loss computation. Used to mask padding tokens so they do not contribute to the loss or gradient.
  • pad_value (int, optional, default 0): The token value used for padding input sequences. Used when aligning variable-length sequences within a batch.
  • mask_prob (float, optional, default 0.0): Probability of masking each input token during training. When greater than 0, enables MLM-style augmentation during autoregressive training. Must be strictly less than 1.0.
  • add_attn_z_loss (bool, optional, default False): Whether to add attention z-loss regularization to the training loss. When enabled, the auxiliary z-loss from the attention layers is extracted and summed with the cross-entropy loss.
  • next_embed_loss_weight (float, optional, default 0.1): Weight for the continuous next-embedding prediction loss. Only active when the wrapped model has add_continuous_pred_head=True.
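As an illustration of these arguments, a construction that overrides the defaults might look like the sketch below. The values are arbitrary examples rather than recommendations, and net is assumed to be a TransformerWrapper built with Decoder attention layers, as in the Usage Examples section.

model = AutoregressiveWrapper(
    net,
    ignore_index = -100,      # target positions with this id are excluded from the loss
    pad_value = 0,            # token id used when padding variable-length sequences
    mask_prob = 0.15,         # randomly mask 15% of input tokens (MLM-style augmentation)
    add_attn_z_loss = True    # add the attention z-loss term to the training loss
)

next_embed_loss_weight is left at its default here; it only takes effect when the wrapped model was built with add_continuous_pred_head=True.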

Constructor Outputs

  • instance (AutoregressiveWrapper): A wrapped model with .forward() returning the training loss and .generate() for autoregressive sampling. The instance also exposes the max_seq_len attribute from the wrapped model.

Forward Behavior

  • Input: A batch of token sequences as a torch.Tensor of shape (batch, seq_len).
  • Output: A scalar torch.Tensor containing the cross-entropy loss (plus optional z-loss and embedding loss).
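For example, a minimal forward call might look like this (a sketch assuming model is the wrapped instance from the Usage Examples section and lives on the same device as the tensor; the vocabulary size of 256 matches that setup):

import torch

seq = torch.randint(0, 256, (4, 1024))   # (batch, seq_len) of token ids
loss = model(seq)                        # scalar loss, ready for loss.backward()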

Generate Behavior

  • Input: A prompt tensor of shape (batch, prompt_len) and a target sequence length.
  • Output: A tensor of shape (batch, seq_len) containing the newly generated tokens; the prompt itself is stripped from the returned tensor.

Usage Examples

Basic Setup: Wrapping a Decoder Model

from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

model = AutoregressiveWrapper(model)
model.cuda()

Training Loop

import torch
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=3e-4)

for batch in dataloader:
    tokens = batch.cuda()          # shape: (batch_size, seq_len)
    loss = model(tokens)           # forward returns scalar loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
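Because forward() returns a mean cross-entropy loss (when z-loss and the next-embedding loss are disabled), validation perplexity can be reported as its exponential. A minimal sketch, assuming val_tokens is a batch of held-out token ids on the same device as the model:

model.eval()
with torch.no_grad():
    val_loss = model(val_tokens)        # mean cross-entropy over non-ignored positions
    perplexity = torch.exp(val_loss)    # perplexity = exp(cross-entropy)
model.train()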

Generation

prompt = torch.tensor([[1, 2, 3]]).cuda()
generated = model.generate(prompt, seq_len=100)  # shape: (1, 100) -- the generated continuation only
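generate() also accepts sampling controls; the sketch below shows two of them (keyword names assumed from the generate() signature in the repository; verify against your installed version):

generated = model.generate(
    prompt,
    seq_len = 100,
    temperature = 0.8,   # <1 sharpens, >1 flattens the sampling distribution
    eos_token = 2        # optional: stop a sequence early once this token is sampled
)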

With MLM Augmentation

model = AutoregressiveWrapper(
    net,                  # a TransformerWrapper instance, as in the basic setup above
    mask_prob = 0.15      # mask 15% of input tokens during training
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
