# Implementation: lucidrains/x-transformers AutoregressiveWrapper.generate
## Metadata
| Field | Value |
|---|---|
| Repository | x-transformers |
| Domains | NLP, Inference |
| Last Updated | 2026-02-08 18:00 GMT |
## Overview
A concrete tool for autoregressive text generation with multiple sampling strategies, provided by the x-transformers library.
## Description
The `generate` method of `AutoregressiveWrapper` produces token sequences autoregressively. It supports:
- Variable-length prompts — accepts either a list of tensors (variable lengths) or a single padded tensor
- Multiple logit filtering strategies — `top_k`, `top_p`, `top_a`, `min_p`
- Temperature-controlled sampling — adjustable randomness via the `temperature` parameter
- KV caching — efficient generation by reusing previous key-value computations
- Contrastive decoding — improved output quality using an amateur model
- Early stopping — halt generation when an EOS token is produced

Prompts can be right-aligned for batched generation with different prompt lengths.
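Right-alignment means shorter prompts are left-padded so that every prompt ends at the same position, letting generation start from a common column. A minimal sketch of the idea in plain Python (the library does this internally on tensors; the pad id of 0 here is an assumption for illustration):

```python
def right_align(prompts, pad_id=0):
    """Left-pad variable-length prompts so they all end at the same column."""
    max_len = max(len(p) for p in prompts)
    return [[pad_id] * (max_len - len(p)) + list(p) for p in prompts]

batch = right_align([[1, 5, 10], [1, 5, 10, 22, 33]])
# every row now has length 5, with padding on the left of the shorter prompt
```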
## Usage
Call `generate` on a trained `AutoregressiveWrapper` model. Pass the prompt tokens and the number of new tokens to generate.
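The autoregressive pattern behind `generate` can be sketched as a loop that samples one token at a time and feeds it back in, stopping early on an EOS token. This toy version (with `toy_next_token` as a stand-in for a real model's forward pass) also mirrors the method's contract of returning only the newly generated tokens:

```python
def toy_next_token(tokens):
    """Stand-in for a model: emits the previous token + 1, capped at an EOS id of 9."""
    return min(tokens[-1] + 1, 9)

def generate(prompt, seq_len, eos_token=None):
    tokens = list(prompt)
    for _ in range(seq_len):
        nxt = toy_next_token(tokens)
        tokens.append(nxt)
        if eos_token is not None and nxt == eos_token:
            break  # early stopping, as with the eos_token argument
    return tokens[len(prompt):]  # new tokens only, no prompt

print(generate([1, 5], seq_len=10, eos_token=9))
```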
## Code Reference
- Repository: x-transformers
- File: `x_transformers/autoregressive_wrapper.py`
- Lines: L351–509
Signature:

```python
@torch.no_grad()
@eval_decorator
def generate(
    self,
    prompts: list[Tensor] | Tensor,
    seq_len: int,
    eos_token: int | None = None,
    temperature: float = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len: bool = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(beta = 0.5, alpha = 0.1),
    cache_kv: bool = True,
    **kwargs
) -> Tensor:
```
Import:

```python
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
```
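The `beta` and `alpha` defaults in `contrastive_decode_kwargs` correspond to the two knobs of contrastive decoding: an amateur-penalty weight and a plausibility cutoff. The exact formula inside the library may differ; this illustrative sketch follows the common form of boosting expert logits against amateur logits while masking tokens the expert itself finds implausible:

```python
import math

def contrastive_logits(expert, amateur, beta=0.5, alpha=0.1):
    """Combine expert and amateur logits; mask tokens the expert deems implausible."""
    m = max(expert)
    probs = [math.exp(l - m) for l in expert]
    total = sum(probs)
    probs = [p / total for p in probs]          # expert probabilities
    cutoff = alpha * max(probs)                 # plausibility threshold
    return [
        (1 + beta) * e - beta * a if p >= cutoff else float('-inf')
        for e, a, p in zip(expert, amateur, probs)
    ]

scores = contrastive_logits([3.0, 2.0, -2.0], [2.5, 0.5, -1.0])
# the third token falls below the plausibility cutoff and is masked out
```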
## I/O Contract
### Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| prompts | list[Tensor] or Tensor | Yes | Prompt token ids (variable-length list or fixed-size tensor) |
| seq_len | int | Yes | Number of new tokens to generate |
| eos_token | int or None | No | Stop generation when this token is produced |
| temperature | float | No | Sampling temperature (0 = greedy, 1 = standard, >1 = more random) |
| filter_logits_fn | str or Callable | No | Filtering strategy: 'top_k', 'top_p', 'top_a', 'min_p' |
| cache_kv | bool | No | Enable KV caching for faster generation (default True) |
| amateur_model | Module or None | No | Amateur model for contrastive decoding |
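A minimal sketch of how temperature and top-k filtering combine before sampling: surviving logits are divided by the temperature and softmaxed, while filtered logits are masked to negative infinity so they receive zero probability. (The library applies this to tensor logits; this plain-float version is only illustrative, and a temperature of exactly 0 is treated by such libraries as greedy argmax rather than a division.)

```python
import math

def top_k_filter(logits, k):
    """Keep the k largest logits; mask the rest to -inf."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [l if l >= threshold else float('-inf') for l in logits]

def softmax_with_temperature(logits, temperature=1.0):
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(l - m) for l in scaled]    # exp(-inf) == 0.0, so masked tokens vanish
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]
probs = softmax_with_temperature(top_k_filter(logits, k=2), temperature=0.8)
# only the two largest logits keep nonzero probability
```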
### Outputs
| Name | Type | Description |
|---|---|---|
| generated | Tensor | Generated token ids of shape (batch, seq_len), new tokens only (no prompt) |
## Usage Examples
### Basic Generation

```python
import torch

# Assuming model is a trained AutoregressiveWrapper

# Single prompt generation
prompt = torch.tensor([1, 5, 10, 22]).cuda()

generated = model.generate(
    prompts = prompt,
    seq_len = 256,
    temperature = 0.8,
    filter_logits_fn = 'top_k',
    filter_kwargs = dict(k = 25),
    cache_kv = True
)

# Variable-length batch generation
prompts = [
    torch.tensor([1, 5, 10]).cuda(),
    torch.tensor([1, 5, 10, 22, 33]).cuda(),
]

generated = model.generate(
    prompts = prompts,
    seq_len = 128,
    temperature = 1.0,
    filter_logits_fn = 'top_p',
    filter_kwargs = dict(thres = 0.9)
)
```
### From train_enwik8.py Example

```python
sample = model.generate(
    prompts = inp,
    seq_len = GENERATE_LENGTH,
    cache_kv = True
)

output_str = decode_tokens(sample)
```
## Related Pages
- Implements Principle
- Requires Environment
- Uses Heuristic