Overview
Wrappers from the x-transformers library for applying transformer attention layers to continuous (real-valued) input sequences.
Description
The ContinuousTransformerWrapper wraps an AttentionLayers module to process continuous-valued inputs instead of discrete tokens. It handles positional embeddings, input/output linear projections, memory tokens, and optional probabilistic (Gaussian mean-variance) output. The companion ContinuousAutoregressiveWrapper adds autoregressive training and generation for continuous sequences, supporting MSE, L1, or Gaussian NLL loss, and multi-step rollout training as used in world model papers.
Usage
Import these classes when working with continuous-valued sequences such as time series, audio features, or world-model states, where inputs and outputs are real-valued vectors rather than discrete tokens.
Code Reference
Source Location
x_transformers/continuous.py
Signature
class ContinuousTransformerWrapper(Module):
    def __init__(
        self,
        *,
        max_seq_len,
        attn_layers: AttentionLayers,
        dim_in = None,
        dim_out = None,
        emb_dim = None,
        max_mem_len = 0,
        num_memory_tokens = None,
        post_emb_norm = False,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False,
        average_pool_embed = False,
        probabilistic = False,
    ):
        """
        Args:
            max_seq_len: Maximum sequence length for positional embeddings.
            attn_layers: AttentionLayers instance (Encoder or Decoder).
            dim_in: Input projection dimension. None means no projection.
            dim_out: Output projection dimension. None means no projection.
            emb_dim: Embedding dimension override.
            max_mem_len: Maximum memory length for Transformer-XL style recurrence.
            num_memory_tokens: Number of learnable memory tokens.
            post_emb_norm: Apply LayerNorm after embedding.
            emb_dropout: Dropout rate on embeddings.
            use_abs_pos_emb: Use absolute positional embeddings.
            scaled_sinu_pos_emb: Use scaled sinusoidal positional embeddings.
            average_pool_embed: Average pool output embeddings over sequence.
            probabilistic: Output mean and variance for Gaussian predictions.
        """

class ContinuousAutoregressiveWrapper(Module):
    def __init__(
        self,
        net: ContinuousTransformerWrapper,
        loss_fn: Module | None = None,
        use_l1_loss = False,
        equal_loss_weight_batch = False,
    ):
        """
        Args:
            net: ContinuousTransformerWrapper to wrap.
            loss_fn: Custom loss function. Defaults to MSE, L1, or GaussianNLL.
            use_l1_loss: Use L1 loss instead of MSE when no custom loss_fn.
            equal_loss_weight_batch: Weight each sequence equally regardless of length.
        """
Import
from x_transformers.continuous import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper
I/O Contract
ContinuousTransformerWrapper Inputs
| Name | Type | Required | Description |
|------|------|----------|-------------|
| x | Tensor (b, n, d_in) | Yes | Continuous input sequence |
| mask | Tensor (b, n) | No | Boolean attention mask |
| lens | Tensor (b,) | No | Sequence lengths (alternative to mask) |
| return_embeddings | bool | No | Return raw embeddings instead of projected output |
| mems | list of Tensor | No | Memory tensors for Transformer-XL recurrence |
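mask and lens are alternative ways to mark valid positions; a minimal sketch, assuming model is a ContinuousTransformerWrapper built as in the Usage Examples below (dim_in=64):

import torch

feats = torch.randn(4, 128, 64)       # (b, n, d_in)

# boolean mask: True marks positions the model should attend to
mask = torch.ones(4, 128, dtype=torch.bool)
out = model(feats, mask=mask)         # (b, n, d_out)

# equivalently, pass per-sequence lengths and let the wrapper derive the mask
lens = torch.tensor([128, 100, 64, 32])
out = model(feats, lens=lens)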
ContinuousTransformerWrapper Outputs
| Name | Type | Description |
|------|------|-------------|
| forward() returns | Tensor (b, n, d_out) | Projected output, or tuple (mean, variance) if probabilistic |
| with return_mems | (Tensor, tuple) | Output plus memory tensors for recurrence |
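A sketch of the return_mems round trip for Transformer-XL style recurrence, assuming model was constructed with max_mem_len > 0 so that memories are actually retained:

import torch

chunk_a = torch.randn(2, 64, 64)
chunk_b = torch.randn(2, 64, 64)

# first segment: ask for the memories back alongside the output
out_a, mems = model(chunk_a, return_mems=True)

# second segment: feed the memories in so attention can reach back
out_b, mems = model(chunk_b, mems=mems, return_mems=True)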
ContinuousAutoregressiveWrapper Inputs
| Name | Type | Required | Description |
|------|------|----------|-------------|
| x | Tensor (b, n, d) | Yes | Full continuous sequence (model predicts x[:, 1:] from x[:, :-1]) |
| rollout_steps | int | No | Number of steps for multi-step rollout training (default 1) |
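Per the table above, multi-step rollout training only requires threading rollout_steps through the forward call; a minimal sketch using the wrapper from the Usage Examples below:

x = torch.randn(4, 128, 64)

# feed the model's own predictions back in for 4 consecutive steps
loss = wrapper(x, rollout_steps=4)
loss.backward()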
ContinuousAutoregressiveWrapper Outputs
| Name | Type | Description |
|------|------|-------------|
| forward() returns | Tensor (scalar) | Mean loss (MSE, L1, or Gaussian NLL) |
| generate() returns | Tensor (b, seq_len, d) | Autoregressively generated continuous sequence |
Usage Examples
Basic Continuous Autoregressive Training
import torch
from x_transformers import Decoder
from x_transformers.continuous import ContinuousTransformerWrapper, ContinuousAutoregressiveWrapper
# Build continuous transformer
model = ContinuousTransformerWrapper(
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8),
    dim_in=64,
    dim_out=64
)
wrapper = ContinuousAutoregressiveWrapper(model)
# Training: predict next continuous vector from previous
x = torch.randn(4, 128, 64) # batch of continuous sequences
loss = wrapper(x)
loss.backward()
Generation
# Generate 50 steps from a seed sequence
seed = torch.randn(1, 10, 64)
generated = wrapper.generate(seed, seq_len=50)
# generated.shape == (1, 50, 64)
Probabilistic Output
# Gaussian mean-variance output
prob_model = ContinuousTransformerWrapper(
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8),
    dim_in=64,
    dim_out=64,
    probabilistic=True
)
prob_wrapper = ContinuousAutoregressiveWrapper(prob_model)
x = torch.randn(4, 128, 64)
loss = prob_wrapper(x)
loss.backward()
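Since a probabilistic model's forward returns a (mean, variance) tuple (see the I/O contract above), samples can be drawn with the usual Gaussian reparameterization; a sketch with illustrative variable names:

# direct forward pass returns the Gaussian parameters
mean, var = prob_model(x)                            # each (b, n, d_out)

# draw one sample per position
sample = mean + var.sqrt() * torch.randn_like(mean)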