
Implementation:Lucidrains X transformers AutoregressiveWrapper Generate

From Leeroopedia


Metadata

Field Value
Repository x-transformers
Domains NLP, Inference
Last Updated 2026-02-08 18:00 GMT

Overview

Concrete tool for autoregressive text generation with multiple sampling strategies provided by the x-transformers library.

Description

The generate method of AutoregressiveWrapper produces token sequences autoregressively. It supports:

  • Variable-length prompts — accepts either a list of tensors (variable lengths) or a single padded tensor
  • Multiple logit filtering strategies — top_k, top_p, top_a, min_p
  • Temperature-controlled sampling — adjustable randomness via the temperature parameter
  • KV caching — efficient generation by reusing previous key-value computations
  • Contrastive decoding — improved output quality using an amateur model
  • Early stopping — halt generation when an EOS token is produced

Prompts can be right-aligned for batched generation with different prompt lengths.
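Right-alignment means shorter prompts are left-padded so that every prompt in the batch ends at the same position, which lets generation continue from a common index. A minimal sketch of this padding step using plain Python lists (the helper name and pad id below are illustrative, not the library's internals):

```python
def right_align(prompts, pad_id=0):
    """Left-pad token lists so every prompt ends at the same index."""
    max_len = max(len(p) for p in prompts)
    return [[pad_id] * (max_len - len(p)) + list(p) for p in prompts]

batch = right_align([[1, 5, 10], [1, 5, 10, 22, 33]])
# Both rows now have length 5; the shorter prompt is padded on the left.
```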

Usage

Call generate on a trained AutoregressiveWrapper to produce text. Pass the prompt tokens and the desired number of new tokens.

Code Reference

  • Repository: x-transformers
  • File: x_transformers/autoregressive_wrapper.py
  • Lines: L351–509

Signature:

@torch.no_grad()
@eval_decorator
def generate(
    self,
    prompts: list[Tensor] | Tensor,
    seq_len: int,
    eos_token: int | None = None,
    temperature: float = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len: bool = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(beta=0.5, alpha=0.1),
    cache_kv: bool = True,
    **kwargs
) -> Tensor:

Import:

from x_transformers.autoregressive_wrapper import AutoregressiveWrapper

I/O Contract

Inputs

Name Type Required Description
prompts list[Tensor] or Tensor Yes Prompt token ids (variable-length list or fixed-size tensor)
seq_len int Yes Number of new tokens to generate
eos_token int or None No Stop generation when this token is produced
temperature float No Sampling temperature (0=greedy, 1=standard, >1=more random)
filter_logits_fn str or Callable No Filtering strategy: 'top_k', 'top_p', 'top_a', 'min_p'
cache_kv bool No Enable KV caching for faster generation (default True)
amateur_model Module or None No Amateur model for contrastive decoding
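Conceptually, a filter such as top_k keeps only the k largest logits before sampling, while temperature rescales the logits prior to the softmax. A rough sketch of both steps in plain Python (illustrative only, not the library's actual implementation):

```python
import math

def top_k_filter(logits, k):
    """Keep the k largest logits; mask the rest to -inf."""
    threshold = sorted(logits, reverse=True)[k - 1]
    return [l if l >= threshold else float('-inf') for l in logits]

def softmax_with_temperature(logits, temperature=1.0):
    """Rescale logits by temperature, then normalize to probabilities."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax_with_temperature(top_k_filter([2.0, 1.0, 0.5, -1.0], k=2), 0.8)
# Only the two largest logits receive non-zero probability.
```

A lower temperature sharpens the distribution toward the largest logit; a higher one flattens it, which is why temperature > 1 yields more random samples.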

Outputs

Name Type Description
generated Tensor Generated token ids of shape (batch, seq_len), new tokens only (no prompt)
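Because the return value contains only the newly generated tokens, the full sequence is recovered by concatenating the prompt with the output. A sketch with plain lists standing in for tensors (the generated values below are a placeholder, not real model output):

```python
prompt = [1, 5, 10, 22]
generated = [7, 3, 9]          # stand-in for model.generate(...) output
full_sequence = prompt + generated
# With tensors this would be torch.cat((prompt, generated), dim=-1).
```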

Usage Examples

Basic Generation

import torch
# Assuming model is a trained AutoregressiveWrapper

# Single prompt generation
prompt = torch.tensor([1, 5, 10, 22]).cuda()
generated = model.generate(
    prompts = prompt,
    seq_len = 256,
    temperature = 0.8,
    filter_logits_fn = 'top_k',
    filter_kwargs = dict(k = 25),
    cache_kv = True
)

# Variable-length batch generation
prompts = [
    torch.tensor([1, 5, 10]).cuda(),
    torch.tensor([1, 5, 10, 22, 33]).cuda(),
]
generated = model.generate(
    prompts = prompts,
    seq_len = 128,
    temperature = 1.0,
    filter_logits_fn = 'top_p',
    filter_kwargs = dict(thres = 0.9)
)

From train_enwik8.py Example

sample = model.generate(
    prompts = inp,
    seq_len = GENERATE_LENGTH,
    cache_kv = True
)
output_str = decode_tokens(sample)
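decode_tokens is defined in the training script rather than in the library; for the byte-level enwik8 setup it maps each token id back to a character, roughly along these lines (a sketch of the idea, not the script's exact code):

```python
def decode_token(token):
    """Map a byte-level token id to a printable character."""
    return chr(max(32, int(token)))

def decode_tokens(tokens):
    """Join decoded characters into a string."""
    return ''.join(decode_token(t) for t in tokens)

print(decode_tokens([72, 105]))  # "Hi"
```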
