Overview
A concrete wrapper, provided by the x-transformers library, for hybrid discrete-continuous token transformers that process numerical values alongside discrete tokens.
Description
The XValTransformerWrapper implements the xVal architecture, which extends a standard discrete-token transformer to jointly process numerical values. Each token position has both a discrete token ID and a continuous numerical value. When a token is the designated numerical_token_id, the token embedding is scaled by the numerical value, allowing the model to represent arbitrary real numbers within the standard transformer framework. The output produces both token logits and numerical predictions. The companion XValAutoregressiveWrapper provides autoregressive training (cross-entropy + MSE loss) and generation that returns both token sequences and numerical values.
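The embedding-scaling trick at the heart of xVal can be sketched in a few lines of plain PyTorch. This is an illustrative stand-in, not the library's actual implementation; the sizes and token ID below are arbitrary:

```python
import torch
import torch.nn as nn

# Illustrative sketch of xVal-style embedding scaling (not the library's code).
NUM_TOKENS, DIM, NUMERICAL_TOKEN_ID = 10, 8, 3  # hypothetical sizes

token_emb = nn.Embedding(NUM_TOKENS, DIM)

x = torch.tensor([[1, 3, 2, 3]])               # token IDs; ID 3 marks "a number goes here"
x_num = torch.tensor([[1.0, 2.5, 1.0, -0.7]])  # continuous values for those slots

emb = token_emb(x)                      # (b, n, d) standard token embeddings
is_number = x == NUMERICAL_TOKEN_ID     # (b, n) boolean mask of numerical positions
# scale the shared "number token" embedding by the actual value; leave others untouched
scale = torch.where(is_number, x_num, torch.ones_like(x_num))
emb = emb * scale.unsqueeze(-1)         # (b, n, d); numbers are now encoded by magnitude
```

Because every real number reuses a single learned embedding direction, the model can represent arbitrary magnitudes without enlarging the vocabulary.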
Usage
Import these classes when building models that need to process sequences containing both discrete tokens and continuous numbers, such as mathematical reasoning, scientific data, financial time series, or any domain where arithmetic generalization is important.
Code Reference
Source Location
Signature
```python
class XValTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        max_seq_len,
        numerical_token_id,
        attn_layers: AttentionLayers,
        emb_dim = None,
        logits_dim = None,
        tie_embedding = False,
        max_mem_len = 0,
        num_memory_tokens = None,
        emb_dropout = 0.,
        use_abs_pos_emb = True,
        scaled_sinu_pos_emb = False
    ):
        """
        Args:
            num_tokens: Size of discrete vocabulary.
            max_seq_len: Maximum sequence length.
            numerical_token_id: Token ID that indicates a numerical value position.
            attn_layers: AttentionLayers instance (Encoder or Decoder).
            emb_dim: Embedding dimension override.
            logits_dim: Output logits dimension (defaults to num_tokens).
            tie_embedding: Tie input/output embeddings.
            max_mem_len: Maximum memory length for Transformer-XL recurrence.
            num_memory_tokens: Number of learnable memory tokens.
            emb_dropout: Dropout rate on embeddings.
            use_abs_pos_emb: Use absolute positional embeddings.
            scaled_sinu_pos_emb: Use scaled sinusoidal positional embeddings.
        """
```
```python
class XValAutoregressiveWrapper(nn.Module):
    def __init__(
        self,
        net: XValTransformerWrapper,
        ignore_index = -100,
        pad_value = 0,
        numerical_loss_weight = 1.
    ):
        """
        Args:
            net: XValTransformerWrapper to wrap for autoregressive training.
            ignore_index: Label index to ignore in cross-entropy loss.
            pad_value: Padding value for generated sequences.
            numerical_loss_weight: Weight for numerical MSE loss relative to cross-entropy.
        """
```
Import
```python
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper
```
I/O Contract
XValTransformerWrapper Inputs
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| x | Tensor (b, n) of int | Yes | Discrete token IDs |
| x_num | Tensor (b, n) of float | Yes | Numerical values (used where x == numerical_token_id) |
| mask | Tensor (b, n) of bool | No | Boolean attention mask |
| return_embeddings | bool | No | Return raw embeddings instead of logits |
XValTransformerWrapper Outputs
| Name | Type | Description |
| --- | --- | --- |
| forward() returns | (Tensor, Tensor) | Tuple of (token_logits (b, n, vocab), numerical_pred (b, n)) |
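The following sketch shows how a caller might combine the two outputs at inference time. Stand-in tensors replace a real forward pass, and the NaN convention for non-numerical positions is an illustrative choice:

```python
import torch

# Stand-in outputs in place of a real forward pass (hypothetical shapes).
NUMERICAL_TOKEN_ID = 3
logits = torch.randn(1, 6, 10)   # token_logits (b, n, vocab)
num_pred = torch.randn(1, 6)     # numerical_pred (b, n)

token_ids = logits.argmax(dim = -1)           # greedy token choice
is_number = token_ids == NUMERICAL_TOKEN_ID   # which positions decoded to the number token
# keep the predicted value at numerical positions, NaN elsewhere
numbers = torch.where(is_number, num_pred, torch.full_like(num_pred, float('nan')))
```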
XValAutoregressiveWrapper forward()
| Name | Type | Description |
| --- | --- | --- |
| loss | Tensor (scalar) | Combined cross-entropy + weighted MSE loss |
| with return_loss_breakdown | (Tensor, LossBreakdown) | Total loss plus named tuple of (cross_entropy_loss, numerical_mse_loss) |
XValAutoregressiveWrapper generate()
| Name | Type | Description |
| --- | --- | --- |
| returns | GenerateReturn | Named tuple of (sampled_token_ids, sampled_numbers, is_number_mask) |
Usage Examples
Training
```python
import torch
from x_transformers import Decoder
from x_transformers.xval import XValTransformerWrapper, XValAutoregressiveWrapper

NUMERICAL_TOKEN = 3  # designate token ID 3 as the "number" token

model = XValTransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    numerical_token_id=NUMERICAL_TOKEN,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)

wrapper = XValAutoregressiveWrapper(model, numerical_loss_weight=1.0)

# sequence with some positions being numerical
x = torch.randint(0, 256, (4, 128))
x_num = torch.randn(4, 128)  # numerical values (only used where x == NUMERICAL_TOKEN)

loss = wrapper(x, x_num)
loss.backward()
```
Generation
```python
start_tokens = torch.randint(0, 256, (1, 5))
start_numbers = torch.zeros(1, 5)

result = wrapper.generate(start_tokens, start_numbers, seq_len=50)

# result.sampled_token_ids: generated discrete tokens
# result.sampled_numbers: predicted numerical values (NaN where not numerical)
# result.is_number_mask: boolean mask of numerical positions
```
Related Pages