Overview
A concrete tool from the x-transformers library for autoregressive training and generation over sequences longer than the model's max_seq_len, achieved by splitting the input into segments and propagating memory across segment boundaries.
Description
The XLAutoregressiveWrapper implements Transformer-XL style segment-based processing for sequences that exceed the model's maximum sequence length. During training, the input is split into chunks of at most max_seq_len tokens, and each chunk is processed sequentially with memory (mems) propagated from the previous chunk; each chunk's cross-entropy loss is weighted by its share of the total sequence length. During generation, the wrapper first warms up memories by processing all preceding segments, then generates token-by-token while maintaining segment boundaries and memory state. It supports top-k and top-p sampling, early stopping on an EOS token, and KV caching within segments.
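To make the training path concrete, here is a minimal sketch of the segmented loss, not the library's exact code. It assumes the wrapped net returns (logits, mems) when called with return_mems=True, as TransformerWrapper does; segmented_loss is a hypothetical name used only for illustration.
import torch.nn.functional as F

def segmented_loss(net, x, max_seq_len, ignore_index=-100):
    # shift inputs/labels for next-token prediction, then split into segments
    inp, labels = x[:, :-1], x[:, 1:]
    total_len = inp.shape[-1]
    mems = None
    loss = 0.
    for seg, seg_labels in zip(inp.split(max_seq_len, dim=-1),
                               labels.split(max_seq_len, dim=-1)):
        # run one segment, carrying memories forward across the boundary
        logits, mems = net(seg, mems=mems, return_mems=True)
        seg_loss = F.cross_entropy(
            logits.transpose(1, 2),  # (b, n, c) -> (b, c, n)
            seg_labels,
            ignore_index=ignore_index
        )
        # weight each segment's loss by its share of the total length
        loss = loss + seg_loss * (seg.shape[-1] / total_len)
    return loss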
Usage
Import this class when you need to train or generate with sequences longer than the TransformerWrapper's max_seq_len. This is the standard approach for handling long documents, extended contexts, or streaming generation where the full sequence cannot fit in a single forward pass.
Code Reference
Source Location
x_transformers/xl_autoregressive_wrapper.py
Signature
class XLAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net,
        ignore_index = -100,
        pad_value = 0
    ):
        """
        Args:
            net: TransformerWrapper with memory support (max_mem_len > 0).
            ignore_index: Label index to ignore in cross-entropy loss.
            pad_value: Padding value for masked output during generation.
        """
Import
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper
I/O Contract
forward() Inputs
| Name | Type | Required | Description |
|------|------|----------|-------------|
| x | Tensor (b, n) | Yes | Full token sequence (can be longer than max_seq_len) |
| mems | list of Tensor | No | Initial memory tensors (None for a fresh start) |
forward() Outputs
| Name | Type | Description |
|------|------|-------------|
| returns | Tensor (scalar) | Weighted cross-entropy loss across all segments |
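As a sketch of this contract, both call forms below return a single scalar loss; prev_mems is a hypothetical list of per-layer memory tensors carried over from an earlier pass:
loss = wrapper(long_seq)                  # fresh start (mems=None)
loss = wrapper(long_seq, mems=prev_mems)  # seed the first segment with prior memories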
generate() Inputs
| Name | Type | Required | Description |
|------|------|----------|-------------|
| start_tokens | Tensor (b, t) or (t,) | Yes | Prompt tokens (can span multiple segments) |
| seq_len | int | Yes | Number of new tokens to generate |
| eos_token | int | No | Stop generation once all sequences have produced this token |
| temperature | float | No | Sampling temperature (default 1.0) |
| filter_logits_fn | callable | No | Logit filtering function (default top_k) |
| mems | list of Tensor | No | Initial memory tensors |
generate() Outputs
| Name | Type | Description |
|------|------|-------------|
| returns | Tensor (b, seq_len) or (seq_len,) | Generated token sequence (prompt excluded) |
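The filter_logits_fn parameter accepts any logit-filtering callable. A hedged example using nucleus sampling, assuming top_p is importable from x_transformers.autoregressive_wrapper alongside the default top_k; wrapper and prompt are the objects from the examples below:
from x_transformers.autoregressive_wrapper import top_p

sampled = wrapper.generate(
    start_tokens=prompt,
    seq_len=100,
    temperature=0.9,
    filter_logits_fn=top_p  # nucleus sampling instead of the default top_k
)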
Usage Examples
Training on Long Sequences
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.xl_autoregressive_wrapper import XLAutoregressiveWrapper

# Model with max_seq_len=256, but we'll train on longer sequences
model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=256,
    attn_layers=Decoder(dim=256, depth=6, heads=8),
    max_mem_len=256  # memory for Transformer-XL recurrence
)

wrapper = XLAutoregressiveWrapper(model)

# Train on 1024-token sequences (4 segments of 256)
long_seq = torch.randint(0, 256, (4, 1024))
loss = wrapper(long_seq)
loss.backward()
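In practice the weighted loss drops into a standard optimization step; the optimizer and learning rate here are illustrative choices, not requirements of the wrapper:
optim = torch.optim.Adam(wrapper.parameters(), lr=3e-4)

optim.zero_grad()
loss = wrapper(long_seq)
loss.backward()
optim.step()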
Long-Context Generation
# Generate with segment-level memory propagation
prompt = torch.randint(0, 256, (1, 500))  # prompt spans 2 segments
generated = wrapper.generate(
    start_tokens=prompt,
    seq_len=200,
    temperature=0.8,
    eos_token=0
)
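Since generate returns only the newly sampled tokens, continuing beyond the first burst means re-prompting with everything decoded so far; the wrapper rebuilds memories from the concatenated prompt. A sketch:
full_context = torch.cat((prompt, generated), dim=-1)  # prompt plus new tokens
more = wrapper.generate(
    start_tokens=full_context,
    seq_len=200,
    temperature=0.8
)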