
Implementation:Lucidrains X transformers XLAutoregressiveWrapper

From Leeroopedia


Knowledge Sources
Domains: NLP, Language_Modeling, Long_Context
Last Updated: 2026-02-08 18:00 GMT

Overview

A concrete tool, provided by the x-transformers library, for autoregressive training and generation over sequences longer than the model's max_seq_len: the input is segmented and memory is propagated across segment boundaries.

Description

The XLAutoregressiveWrapper implements Transformer-XL style segment-based processing for sequences that exceed the model's maximum sequence length. During training, the input is split into chunks of max_seq_len, and each chunk is processed sequentially with memory (mems) propagated from the previous chunk. The loss is weighted proportionally to each chunk's length. During generation, the wrapper first warms up memories by processing all preceding segments, then generates token-by-token while maintaining segment boundaries and memory state. It supports top-k and top-p sampling, EOS token stopping, and KV caching within segments.
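The chunk-length-proportional loss weighting described above can be sketched in a few lines. This is an illustrative reconstruction, not the library's actual code; the function name `weighted_segment_loss` is hypothetical.

```python
import torch

def weighted_segment_loss(chunk_losses, chunk_lengths):
    # Illustrative sketch (not the library's actual code): combine
    # per-chunk mean cross-entropy losses, weighting each chunk by
    # its share of the total number of tokens.
    total = sum(chunk_lengths)
    return sum(loss * (n / total) for loss, n in zip(chunk_losses, chunk_lengths))

# A 384-token sequence split into chunks of 256 and 128: the longer
# chunk contributes twice the weight of the shorter one.
loss = weighted_segment_loss(
    [torch.tensor(2.0), torch.tensor(4.0)],  # per-chunk mean losses
    [256, 128],                              # chunk lengths
)
```

A final chunk shorter than max_seq_len therefore contributes proportionally less to the total loss, so the result matches the mean loss over all tokens.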

Usage

Import this class when you need to train or generate with sequences longer than the TransformerWrapper's max_seq_len. This is the standard approach for handling long documents, extended contexts, or streaming generation where the full sequence cannot fit in a single forward pass.

Code Reference

Signature

class XLAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0
    ):
        """
        Args:
            net: TransformerWrapper with memory support (max_mem_len > 0).
            ignore_index: Label index to ignore in cross-entropy loss.
            pad_value: Padding value for masked output during generation.
        """

Import

from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

I/O Contract

forward() Inputs

Name | Type | Required | Description
x | Tensor (b, n) | Yes | Full token sequence (can be longer than max_seq_len)
mems | list of Tensor | No | Initial memory tensors (None for a fresh start)

forward() Outputs

Name | Type | Description
returns | Tensor (scalar) | Weighted cross-entropy loss across all segments

generate() Inputs

Name | Type | Required | Description
start_tokens | Tensor (b, t) or (t,) | Yes | Prompt tokens (can span multiple segments)
seq_len | int | Yes | Number of new tokens to generate
eos_token | int | No | Stop generation when all sequences produce this token
temperature | float | No | Sampling temperature (default 1.0)
filter_logits_fn | callable | No | Logit filtering function (default top_k)
mems | list of Tensor | No | Initial memory tensors

generate() Outputs

Name | Type | Description
returns | Tensor (b, seq_len) or (seq_len,) | Generated token sequence (prompt excluded)

Usage Examples

Training on Long Sequences

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

# Model with max_seq_len=256 but we'll train on longer sequences
model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=256,
    attn_layers=Decoder(dim=256, depth=6, heads=8),
    max_mem_len=256  # memory for Transformer-XL recurrence
)

wrapper = XLAutoregressiveWrapper(model)

# Train on a 1024-token sequence (4 segments of 256)
long_seq = torch.randint(0, 256, (4, 1024))
loss = wrapper(long_seq)
loss.backward()
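The segmentation the wrapper performs on this 1024-token batch can be sketched with plain tensor ops (the real logic lives in the library source; this is illustrative only):

```python
import torch

# Sketch of the internal segmentation: shift for next-token prediction,
# then split the inputs into max_seq_len-sized chunks.
max_seq_len = 256
seq = torch.randint(0, 256, (4, 1024))
inputs, labels = seq[:, :-1], seq[:, 1:]   # 1023 tokens each after the shift
chunks = inputs.split(max_seq_len, dim=-1)
print([c.shape[-1] for c in chunks])       # → [256, 256, 256, 255]
```

Note the one-token shift means the final chunk is slightly shorter, which is why the loss is weighted by chunk length rather than averaged naively.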

Long-Context Generation

# Generate with segment-level memory propagation
prompt = torch.randint(0, 256, (1, 500))  # prompt spans 2 segments

generated = wrapper.generate(
    start_tokens=prompt,
    seq_len=200,
    temperature=0.8,
    eos_token=0
)
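generate() defaults to top-k filtering of the logits before sampling. A minimal pure-PyTorch sketch of the idea (the library ships its own implementation; the function name here is illustrative):

```python
import torch

def top_k_filter(logits, k):
    # Keep the k largest logits; set the rest to -inf so softmax
    # assigns them zero probability. (Sketch of the idea only.)
    vals, _ = torch.topk(logits, k)
    threshold = vals[..., -1, None]
    return logits.masked_fill(logits < threshold, float('-inf'))

filtered = top_k_filter(torch.tensor([1.0, 3.0, 2.0, 0.5]), k=2)
# only the logits 3.0 and 2.0 survive; the others become -inf
```

Passing a different callable as filter_logits_fn swaps in an alternative scheme such as top-p (nucleus) sampling.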
