Implementation: AutoregressiveWrapper_Init
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (x-transformers) |
| Domains | NLP, Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Concrete tool from the x-transformers library that wraps a transformer model with autoregressive training and generation capabilities.
Description
AutoregressiveWrapper wraps a TransformerWrapper instance to add the essential machinery for autoregressive language modeling:
- Automatic teacher forcing: the `forward()` method splits the input sequence into `inp = seq[:, :-1]` and `target = seq[:, 1:]`, feeds the input through the wrapped model, and computes the loss against the target.
- Cross-entropy loss computation: the forward pass returns a scalar cross-entropy loss (using `F.cross_entropy` with a configurable `ignore_index` for padding). When attention z-loss regularization is enabled, the auxiliary z-loss is added to the main loss.
- Multiple generation strategies: the `generate()` method supports top-k, top-p (nucleus), min-p, top-a, contrastive decoding, and beam search. Generation runs under `@torch.no_grad()` and automatically switches the model to eval mode.
- Optional MLM-style masking: when `mask_prob > 0`, a fraction of input tokens are randomly masked during training. This technique, from the paper Filling in the Blanks (arXiv:2210.13432), augments autoregressive training with a masked-language-modeling signal to improve model quality.
- Attention z-loss regularization: when `add_attn_z_loss=True`, the wrapper extracts the attention z-loss from the model's intermediates and adds it to the primary cross-entropy loss, providing a regularization signal for attention-logit stability.
- Next-embedding loss: when the wrapped model has `add_continuous_pred_head=True`, an additional continuous prediction loss (weighted by `next_embed_loss_weight`) is computed alongside the discrete token-prediction loss.
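The teacher-forcing split described above can be sketched in plain Python. This is a toy illustration of the shift-and-score pattern, not the library's tensor code; the names `teacher_forcing_loss` and `logits_fn` are hypothetical, and the real wrapper operates on batched tensors via `F.cross_entropy`.

```python
import math

def teacher_forcing_loss(seq, logits_fn, ignore_index=-100):
    """Illustrative sketch of the wrapper's teacher-forcing split.

    seq: list of token ids for one sequence.
    logits_fn: toy stand-in for the wrapped model; maps the input
    prefix to one probability distribution per input position.
    """
    inp, target = seq[:-1], seq[1:]   # same shift as seq[:, :-1] / seq[:, 1:]
    probs = logits_fn(inp)            # one distribution per input position
    losses = [
        -math.log(p[t])
        for p, t in zip(probs, target)
        if t != ignore_index          # padding positions contribute nothing
    ]
    return sum(losses) / max(len(losses), 1)
```

With a uniform toy model over a 4-token vocabulary, every predicted position contributes `log 4`, and positions whose target equals `ignore_index` are dropped from the average, mirroring the `ignore_index` behavior documented below.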
Code Reference
Source Location
x-transformers repo, file `x_transformers/autoregressive_wrapper.py`, lines L156-183.
Signature
```python
class AutoregressiveWrapper(Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0,
        mask_prob = 0.,
        add_attn_z_loss = False,
        next_embed_loss_weight = 0.1
    ):
```
Import
```python
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| `net` | TransformerWrapper | Yes | -- | The base transformer model to wrap. Must be a TransformerWrapper configured with Decoder attention layers. |
| `ignore_index` | int | No | -100 | Token index to ignore in the cross-entropy loss computation. Used to mask padding tokens so they do not contribute to the loss or gradient. |
| `pad_value` | int | No | 0 | The token value used for padding input sequences. Used when aligning variable-length sequences within a batch. |
| `mask_prob` | float | No | 0.0 | Probability of masking each input token during training. When greater than 0, enables MLM-style augmentation during autoregressive training. Must be strictly less than 1.0. |
| `add_attn_z_loss` | bool | No | False | Whether to add attention z-loss regularization to the training loss. When enabled, the auxiliary z-loss from the attention layers is extracted and summed with the cross-entropy loss. |
| `next_embed_loss_weight` | float | No | 0.1 | Weight for the continuous next-embedding prediction loss. Only active when the wrapped model has `add_continuous_pred_head=True`. |
Constructor Outputs
| Output | Type | Description |
|---|---|---|
| instance | AutoregressiveWrapper | A wrapped model whose `.forward()` returns the training loss and whose `.generate()` performs autoregressive sampling. The instance also exposes the `max_seq_len` attribute of the wrapped model. |
Forward Behavior
- Input: a batch of token sequences as a `torch.Tensor` of shape `(batch, seq_len)`.
- Output: a scalar `torch.Tensor` containing the cross-entropy loss (plus the optional z-loss and next-embedding loss).
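The z-loss term folded into this scalar is commonly formulated as a penalty on the squared log-partition of the logits. Below is a minimal pure-Python sketch of that idea; the function name `z_loss` and the coefficient are illustrative, and the library's exact formulation over attention intermediates may differ.

```python
import math

def z_loss(logits, coeff=1e-4):
    """Sketch of a z-loss term for one row of logits: penalize the
    squared log-sum-exp, discouraging logits from drifting large."""
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return coeff * lse ** 2
```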
Generate Behavior
- Input: a prompt tensor of shape `(batch, prompt_len)` and a target sequence length `seq_len`.
- Output: a tensor of shape `(batch, prompt_len + seq_len)` containing the prompt followed by the generated tokens.
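The shape contract above follows from the autoregressive loop: the prompt is extended one token at a time for `seq_len` steps. A greedy pure-Python sketch of that loop shape (the real `generate()` works on tensors, applies the configured filtering strategy, crops context to `max_seq_len`, and runs under `no_grad` in eval mode; `next_token_fn` is a hypothetical stand-in for one model step):

```python
def greedy_generate(prompt, next_token_fn, seq_len):
    """Sketch of the loop behind generate(): repeatedly feed the
    sequence so far and append one predicted token."""
    out = list(prompt)
    for _ in range(seq_len):
        out.append(next_token_fn(out))  # argmax or sampled token
    return out                          # length: prompt_len + seq_len
```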
Usage Examples
Basic Setup: Wrapping a Decoder Model
```python
from x_transformers import TransformerWrapper, Decoder
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 6, heads = 8)
)

model = AutoregressiveWrapper(model)
model.cuda()
```
Training Loop
```python
import torch
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=3e-4)

for batch in dataloader:
    tokens = batch.cuda()  # shape: (batch_size, seq_len)
    loss = model(tokens)   # forward returns scalar loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```
Generation
```python
prompt = torch.tensor([[1, 2, 3]]).cuda()
generated = model.generate(prompt, seq_len=100)  # shape: (1, 103)
```
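Among the sampling strategies listed in the Description, top-p (nucleus) sampling keeps the smallest set of tokens whose cumulative probability reaches `p` and renormalizes over that set. A pure-Python sketch of that selection rule (illustrative only, not the library's tensor implementation; `nucleus_filter` is a hypothetical name):

```python
def nucleus_filter(probs, p=0.9):
    """Sketch of top-p (nucleus) filtering: keep the smallest set of
    tokens whose cumulative probability reaches p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    keep, cum = [], 0.0
    for i in order:
        keep.append(i)
        cum += probs[i]
        if cum >= p:          # smallest prefix reaching p is kept
            break
    total = sum(probs[i] for i in keep)
    return {i: probs[i] / total for i in keep}
```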
With MLM Augmentation
```python
model = AutoregressiveWrapper(
    net,
    mask_prob = 0.15  # mask 15% of input tokens during training
)
```
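The position-selection behind `mask_prob` can be sketched as drawing an independent coin flip per input position. This is only an illustration of choosing which positions to hide; how hidden positions are treated (token replacement versus attention masking) is an implementation detail of the wrapper, and `random_token_mask` is a hypothetical name.

```python
import random

def random_token_mask(seq_len, mask_prob, seed=None):
    """Sketch of MLM-style augmentation: pick ~mask_prob of the input
    positions to hide during training."""
    assert 0 <= mask_prob < 1.0  # mask_prob must be strictly below 1
    rng = random.Random(seed)
    return [rng.random() < mask_prob for _ in range(seq_len)]
```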
Related Pages
- Implements Principle
- Requires Environment
- Uses Heuristic