Principle:Lucidrains X transformers Hybrid Discrete Continuous Tokens
| Knowledge Sources | |
|---|---|
| Domains | NLP, Numerical_Reasoning, Model_Architecture |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A technique that lets a transformer process discrete tokens and continuous numerical values jointly, by scaling the embedding of a dedicated number token with the real value it stands for.
Description
The Hybrid Discrete-Continuous Token approach extends a standard discrete-vocabulary transformer to handle continuous numbers. The key mechanism is simple: one token in the vocabulary is designated as the "numerical token", and each number in the input is replaced by it, with the actual value kept in a parallel per-position sequence. At positions holding this token, the embedding is multiplied by that numerical value; all other embeddings are left unchanged. The same transformer therefore processes text and numbers without a separate input encoder for numbers. At the output, the model has two heads: a token head that produces logits over the vocabulary (classification) and a numerical head that predicts a scalar (regression) at every position. Training adds the cross-entropy loss on the tokens to an MSE loss on the numerical predictions, with the MSE computed only at positions whose target is the numerical token.
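The combined training objective can be sketched in plain Python as below. This is a minimal illustration, not the x-transformers code: `NUM_TOKEN_ID`, `hybrid_loss`, and the `numerical_loss_weight` knob are hypothetical names, and sequences are plain lists of per-position logits and scalars.

```python
import math

NUM_TOKEN_ID = 3  # hypothetical vocabulary id of the numerical token

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target class at one position."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def hybrid_loss(token_logits, numerical_pred, target_tokens, target_numbers,
                numerical_loss_weight=1.0):
    """Mean cross-entropy over all positions plus mean squared error over the
    positions whose target token is the numerical token."""
    ce = sum(cross_entropy(l, t) for l, t in zip(token_logits, target_tokens))
    ce /= len(target_tokens)

    # Regression error only where the target is a number; other values are ignored
    sq_errors = [(p - y) ** 2
                 for p, y, t in zip(numerical_pred, target_numbers, target_tokens)
                 if t == NUM_TOKEN_ID]
    mse = sum(sq_errors) / len(sq_errors) if sq_errors else 0.0
    return ce + numerical_loss_weight * mse

loss = hybrid_loss(
    token_logits=[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],
    numerical_pred=[9.9, 2.0],
    target_tokens=[1, NUM_TOKEN_ID],
    target_numbers=[0.0, 2.5],  # first entry is ignored: its target is not a number
)
# uniform logits give log(4) per position; the MSE term is (2.0 - 2.5)**2 = 0.25
```

Masking matters here: without it, the regression head would be penalized at text positions, where the stored value (1.0 or a placeholder) carries no meaning.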
Usage
Use this principle when designing transformer models for sequences that interleave discrete tokens with continuous values, particularly for mathematical reasoning tasks. This approach, known as xVal, encodes each number with a single token and is reported to generalize better on arithmetic than approaches that tokenize numbers into digit strings.
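To make the contrast with digit-string tokenization concrete, the sketch below encodes the same text both ways. It assumes whitespace-separated input and a hypothetical placeholder string `[NUM]`; real preprocessing would use a proper tokenizer and typically normalizes values first, which is omitted here.

```python
import re

NUM_TOKEN = "[NUM]"  # hypothetical placeholder string for the numerical token
NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

def xval_encode(text):
    """Replace every number with NUM_TOKEN and record its value in a parallel
    list. Non-number tokens get 1.0, so scaling leaves their embeddings unchanged."""
    tokens, values = [], []
    for piece in text.split():
        if NUMBER_RE.fullmatch(piece):
            tokens.append(NUM_TOKEN)
            values.append(float(piece))
        else:
            tokens.append(piece)
            values.append(1.0)
    return tokens, values

def digit_encode(text):
    """Baseline: split each number into single characters (digits, sign, point)."""
    tokens = []
    for piece in text.split():
        if NUMBER_RE.fullmatch(piece):
            tokens.extend(piece)  # iterating a string yields its characters
        else:
            tokens.append(piece)
    return tokens

print(xval_encode("12.5 + 3"))   # (['[NUM]', '+', '[NUM]'], [12.5, 1.0, 3.0])
print(digit_encode("12.5 + 3"))  # ['1', '2', '.', '5', '+', '3']
```

Every number costs one position under the hybrid encoding, regardless of how many digits it has, while the digit baseline spends one position per character.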
Theoretical Basis
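The embedding modification for numerical tokens, where $x_i$ is the token at position $i$ and $x_{\text{num},i}$ is the numerical value stored for that position: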
$$
\mathbf{e}_i =
\begin{cases}
x_{\text{num},i} \cdot \text{Embed}(x_i) & \text{if } x_i = \text{NUM\_TOKEN} \\
\text{Embed}(x_i) & \text{otherwise}
\end{cases}
$$
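A worked instance of the formula, using a toy embedding table with made-up values (vocabulary of 4, dimension 2) and a hypothetical `NUM_TOKEN_ID`:

```python
NUM_TOKEN_ID = 3  # hypothetical vocabulary id of NUM_TOKEN

# Toy embedding table: 4 tokens, embedding dimension 2 (illustrative values)
EMBED = [
    [0.1, 0.2],
    [0.3, -0.1],
    [-0.2, 0.4],
    [0.5, 0.5],  # row for NUM_TOKEN
]

def hybrid_embed(token_ids, num_values):
    """e_i = x_num_i * Embed(x_i) if x_i is NUM_TOKEN, else Embed(x_i)."""
    out = []
    for tok, val in zip(token_ids, num_values):
        scale = val if tok == NUM_TOKEN_ID else 1.0  # value is ignored for text tokens
        out.append([scale * e for e in EMBED[tok]])
    return out

emb = hybrid_embed([1, 3, 3], [0.0, 2.0, -4.0])
# emb == [[0.3, -0.1], [1.0, 1.0], [-2.0, -2.0]]
```

Every number embedding points along the single direction Embed(NUM_TOKEN); only its magnitude and sign carry the value, which is why 2.0 and -4.0 map to opposite, differently scaled vectors.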
Pseudo-code Logic:
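```python
# Abstract algorithm (NOT real implementation)
embedding = token_embed(token_ids)                # (batch, seq, dim)
is_number = (token_ids == NUMERICAL_TOKEN_ID)     # (batch, seq)
scale = where(is_number, numerical_values, 1.0)   # (batch, seq)
embedding = embedding * scale[..., None]          # scale only numerical positions

# Dual-head output
token_logits = to_logits(hidden)                  # discrete prediction
numerical_pred = to_numerical(hidden)             # continuous prediction, (batch, seq)

# Combined loss: MSE only where the target token is the numerical token
target_is_number = (target_tokens == NUMERICAL_TOKEN_ID)
loss = cross_entropy(token_logits, target_tokens) \
     + mse(numerical_pred[target_is_number], target_numbers[target_is_number])
```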
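The section above describes only training. One natural way to read predictions out of the two heads at inference (an assumption here, not stated in the source) is to pick the most likely token and, if it is the numerical token, take the value from the numerical head. `read_out` and `NUM_TOKEN_ID` are hypothetical names; greedy argmax stands in for whatever sampling scheme is actually used.

```python
NUM_TOKEN_ID = 3  # hypothetical vocabulary id of the numerical token

def read_out(token_logits, numerical_pred):
    """Greedy decode one position: choose the argmax token; attach the
    regression head's value only when that token is the numerical token."""
    token = max(range(len(token_logits)), key=token_logits.__getitem__)
    value = numerical_pred if token == NUM_TOKEN_ID else None
    return token, value

print(read_out([0.1, 0.2, 0.0, 1.5], 7.25))  # (3, 7.25)
print(read_out([2.0, 0.2, 0.0, 1.5], 7.25))  # (0, None)
```

Under this rule the token head decides whether a number appears, and the numerical head decides which number it is.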