
Implementation:Lucidrains X transformers ContinuousTransformerWrapper

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Continuous_Modeling, Time_Series
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library for applying transformer attention layers to continuous (real-valued) input sequences.

Description

The ContinuousTransformerWrapper wraps an AttentionLayers module to process continuous-valued inputs instead of discrete tokens. It handles positional embeddings, input/output linear projections, memory tokens, and optional probabilistic (Gaussian mean-variance) output. The companion ContinuousAutoregressiveWrapper adds autoregressive training and generation for continuous sequences, supporting MSE, L1, or Gaussian NLL loss, and multi-step rollout training as used in world model papers.

Usage

Import these classes when working with continuous-valued sequences such as time series, audio features, or world model states where the input and output are real-valued vectors rather than discrete tokens.

Code Reference

Source Location

x_transformers/continuous.py
Signature

class ContinuousTransformerWrapper(Module):
    def __init__(
        self,
        *,
        max_seq_len,
        attn_layers: AttentionLayers,
        dim_in = None,
        dim_out = None,
        emb_dim = None,
        max_mem_len = 0,
        num_memory_tokens = None,
        post_emb_norm = False,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        average_pool_embed = False,
        probabilistic = False,
    ):
        """
        Args:
            max_seq_len: Maximum sequence length for positional embeddings.
            attn_layers: AttentionLayers instance (Encoder or Decoder).
            dim_in: Input projection dimension. None means no projection.
            dim_out: Output projection dimension. None means no projection.
            emb_dim: Embedding dimension override.
            max_mem_len: Maximum memory length for Transformer-XL style recurrence.
            num_memory_tokens: Number of learnable memory tokens.
            post_emb_norm: Apply LayerNorm after embedding.
            emb_dropout: Dropout rate on embeddings.
            use_abs_pos_emb: Use absolute positional embeddings.
            scaled_sinu_pos_emb: Use scaled sinusoidal positional embeddings.
            average_pool_embed: Average pool output embeddings over sequence.
            probabilistic: Output mean and variance for Gaussian predictions.
        """

class ContinuousAutoregressiveWrapper(Module):
    def __init__(
        self,
        net: ContinuousTransformerWrapper,
        loss_fn: Module | None = None,
        use_l1_loss = False,
        equal_loss_weight_batch = False,
    ):
        """
        Args:
            net: ContinuousTransformerWrapper to wrap.
            loss_fn: Custom loss function. Defaults to MSE, L1, or GaussianNLL.
            use_l1_loss: Use L1 loss instead of MSE when no custom loss_fn.
            equal_loss_weight_batch: Weight each sequence equally regardless of length.
        """

Import

from x_transformers.continuous import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper

I/O Contract

ContinuousTransformerWrapper Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| x | Tensor (b, n, d_in) | Yes | Continuous input sequence |
| mask | Tensor (b, n) | No | Boolean attention mask |
| lens | Tensor (b,) | No | Sequence lengths (alternative to mask) |
| return_embeddings | bool | No | Return raw embeddings instead of projected output |
| mems | list of Tensor | No | Memory tensors for Transformer-XL recurrence |

ContinuousTransformerWrapper Outputs

| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (b, n, d_out) | Projected output; or tuple (mean, variance) if probabilistic |
| with return_mems | (Tensor, tuple) | Output plus memory tensors for recurrence |

ContinuousAutoregressiveWrapper Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| x | Tensor (b, n, d) | Yes | Full continuous sequence (model predicts x[:, 1:] from x[:, :-1]) |
| rollout_steps | int | No | Multi-step rollout training (default 1) |

ContinuousAutoregressiveWrapper Outputs

| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (scalar) | Mean loss (MSE, L1, or Gaussian NLL) |
| generate() returns | Tensor (b, seq_len, d) | Autoregressively generated continuous sequence |

Usage Examples

Basic Continuous Autoregressive Training

import torch
from x_transformers import Decoder
from x_transformers.continuous import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper

# Build continuous transformer
model = ContinuousTransformerWrapper(
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8),
    dim_in=64,
    dim_out=64
)

wrapper = ContinuousAutoregressiveWrapper(model)

# Training: predict next continuous vector from previous
x = torch.randn(4, 128, 64)  # batch of continuous sequences
loss = wrapper(x)
loss.backward()

Generation

# Generate 50 steps from a seed sequence
seed = torch.randn(1, 10, 64)
generated = wrapper.generate(seed, seq_len=50)
# generated.shape == (1, 50, 64)

Probabilistic Output

# Gaussian mean-variance output
prob_model = ContinuousTransformerWrapper(
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8),
    dim_in=64,
    dim_out=64,
    probabilistic=True
)

prob_wrapper = ContinuousAutoregressiveWrapper(prob_model)

x = torch.randn(4, 128, 64)
loss = prob_wrapper(x)
loss.backward()
