
Implementation:Lucidrains X transformers XValTransformerWrapper

From Leeroopedia


Knowledge Sources
Domains NLP, Numerical_Reasoning, Model_Architecture
Last Updated 2026-02-08 18:00 GMT

Overview

A concrete tool from the x-transformers library: a transformer wrapper for hybrid discrete-continuous token sequences that handles numerical values alongside discrete tokens.

Description

The XValTransformerWrapper implements the xVal architecture, which extends a standard discrete-token transformer to jointly process numerical values. Each token position has both a discrete token ID and a continuous numerical value. When a token is the designated numerical_token_id, the token embedding is scaled by the numerical value, allowing the model to represent arbitrary real numbers within the standard transformer framework. The output produces both token logits and numerical predictions. The companion XValAutoregressiveWrapper provides autoregressive training (cross-entropy + MSE loss) and generation that returns both token sequences and numerical values.
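The embedding trick at the heart of xVal can be illustrated with a minimal pure-Python sketch (a hypothetical toy, not the library's implementation): wherever the token equals the designated numerical token ID, that position's embedding vector is multiplied by the continuous value, so one learned embedding can represent any real number.

```python
NUM_TOKEN = 3  # designated numerical token ID (toy value)

def embed(token_id, dim=4):
    # toy deterministic "embedding"; stands in for a learned embedding table
    return [float(token_id + 1)] * dim

def xval_embed(tokens, values, dim=4):
    # scale the embedding by the numerical value only at numerical positions
    out = []
    for t, v in zip(tokens, values):
        e = embed(t, dim)
        scale = v if t == NUM_TOKEN else 1.0
        out.append([x * scale for x in e])
    return out

# position 1 holds the number 2.5, so its embedding is scaled by 2.5
embs = xval_embed([7, 3, 5], [0.0, 2.5, 0.0])
```

Non-numerical positions pass through unscaled, so the model behaves like a standard discrete-token transformer everywhere except at numerical tokens.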

Usage

Import these classes when building models that need to process sequences containing both discrete tokens and continuous numbers, such as mathematical reasoning, scientific data, financial time series, or any domain where arithmetic generalization is important.

Code Reference

Source Location

Signature

class XValTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        numerical_token_id,
        attn_layers: AttentionLayers,
        emb_dim = None,
        logits_dim = None,
        tie_embedding = False,
        max_mem_len = 0,
        num_memory_tokens = None,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        """
        Args:
            num_tokens: Size of discrete vocabulary.
            max_seq_len: Maximum sequence length.
            numerical_token_id: Token ID that indicates a numerical value position.
            attn_layers: AttentionLayers instance (Encoder or Decoder).
            emb_dim: Embedding dimension override.
            logits_dim: Output logits dimension (defaults to num_tokens).
            tie_embedding: Tie input/output embeddings.
            max_mem_len: Maximum memory length for Transformer-XL recurrence.
            num_memory_tokens: Number of learnable memory tokens.
            emb_dropout: Dropout rate on embeddings.
            use_abs_pos_emb: Use absolute positional embeddings.
            scaled_sinu_pos_emb: Use scaled sinusoidal positional embeddings.
        """

class XValAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net: XValTransformerWrapper,
        ignore_index = -100,
        pad_value = 0,
        numerical_loss_weight = 1.
    ):
        """
        Args:
            net: XValTransformerWrapper to wrap for autoregressive training.
            ignore_index: Label index to ignore in cross-entropy loss.
            pad_value: Padding value for generated sequences.
            numerical_loss_weight: Weight for numerical MSE loss relative to cross-entropy.
        """

Import

from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

I/O Contract

XValTransformerWrapper Inputs

Name | Type | Required | Description
x | Tensor (b, n) of int | Yes | Discrete token IDs
x_num | Tensor (b, n) of float | Yes | Numerical values (used where x == numerical_token_id)
mask | Tensor (b, n) of bool | No | Boolean attention mask
return_embeddings | bool | No | Return raw embeddings instead of logits

XValTransformerWrapper Outputs

Name | Type | Description
forward() returns | (Tensor, Tensor) | Tuple of (token_logits (b, n, vocab), numerical_pred (b, n))

XValAutoregressiveWrapper forward()

Name | Type | Description
loss | Tensor (scalar) | Combined cross-entropy + weighted MSE loss
with return_loss_breakdown | (Tensor, LossBreakdown) | Total loss plus named tuple of (cross_entropy_loss, numerical_mse_loss)
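How the combined objective is formed can be sketched in plain Python (a simplified illustration with made-up helper names, not the library's code): the MSE term is computed only at positions whose target is a numerical token, then added to the cross-entropy term with the configured weight.

```python
def masked_mse(num_pred, num_target, is_num_mask):
    # squared error averaged over numerical positions only
    errs = [(p - t) ** 2 for p, t, m in zip(num_pred, num_target, is_num_mask) if m]
    return sum(errs) / max(len(errs), 1)

def xval_loss(ce_loss, num_pred, num_target, is_num_mask, numerical_loss_weight=1.0):
    # total objective: cross-entropy plus weighted numerical MSE
    return ce_loss + numerical_loss_weight * masked_mse(num_pred, num_target, is_num_mask)
```

With `numerical_loss_weight=0` this reduces to ordinary autoregressive cross-entropy training; larger weights push the model to prioritize numerical accuracy.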

XValAutoregressiveWrapper generate()

Name | Type | Description
returns | GenerateReturn | Named tuple of (sampled_token_ids, sampled_numbers, is_number_mask)
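The shape of this return value can be mimicked with a toy function (hypothetical, for illustration only): positions whose sampled token is not the numerical token carry NaN in sampled_numbers, and is_number_mask marks the numerical positions.

```python
import math

NUM_TOKEN = 3  # assumed numerical token ID for this toy example

def split_generate_output(token_ids, raw_numbers):
    # mark which positions are numerical, and blank out numbers elsewhere
    is_number_mask = [t == NUM_TOKEN for t in token_ids]
    sampled_numbers = [n if m else math.nan
                       for n, m in zip(raw_numbers, is_number_mask)]
    return token_ids, sampled_numbers, is_number_mask

ids, nums, mask = split_generate_output([5, 3, 8], [0.0, 1.5, 0.0])
```

Downstream code can then recover the numeric readout by indexing sampled_numbers with is_number_mask.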

Usage Examples

Training

import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

NUMERICAL_TOKEN = 3  # designate token ID 3 as the "number" token

model = XValTransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    numerical_token_id=NUMERICAL_TOKEN,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)

wrapper = XValAutoregressiveWrapper(model, numerical_loss_weight=1.0)

# Sequence with some positions being numerical
x = torch.randint(0, 256, (4, 128))
x_num = torch.randn(4, 128)  # numerical values (only used where x == NUMERICAL_TOKEN)

loss = wrapper(x, x_num)
loss.backward()

Generation

start_tokens = torch.randint(0, 256, (1, 5))
start_numbers = torch.zeros(1, 5)

result = wrapper.generate(start_tokens, start_numbers, seq_len=50)
# result.sampled_token_ids: generated discrete tokens
# result.sampled_numbers: predicted numerical values (NaN where not numerical)
# result.is_number_mask: boolean mask of numerical positions
