Implementation: Lucidrains x-transformers UniversalPretrainWrapper
| Knowledge Sources | |
|---|---|
| Domains | NLP, Language_Modeling, Pretraining |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
A concrete tool, provided by the x-transformers library, for pretraining transformers on synthetic data generated by lightweight recurrent "Turing machines" (LSTM/GRU), without requiring any real text corpus.
Description
The UniversalPretrainWrapper implements Universal Pretraining (UP) from Bloem (2025). Instead of pretraining on real text data, it uses a lightweight SyntheticDataGenerator (LSTM or GRU) to produce synthetic training sequences. The generator maintains a buffer of sequences that are iteratively "enriched" by generating new sequences conditioned on existing ones. The transformer is trained autoregressively on samples from this buffer. The approach follows Algorithm 1 from the paper:
1. Sample conditions and seeds from the buffer.
2. Generate new sequences via the Turing machine.
3. Place the generated sequences back into the buffer.
4. Periodically reset portions of the buffer with random sequences.
5. Train the transformer on buffer samples.
The Turing machine weights can also be periodically reset to prevent the synthetic distribution from becoming too narrow.
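The enrichment loop above can be sketched in plain Python. This is a minimal illustration of the buffer dynamics, not the library's implementation: `up_enrichment_step` and its parameters are hypothetical names, and `generate` stands in for the LSTM/GRU Turing machine.

```python
import random

def up_enrichment_step(buffer, generate, *, batch_size, num_reset,
                       vocab_size, seq_len, seed_length=8):
    # Hypothetical sketch of one Algorithm 1 iteration. `generate` maps a
    # (condition, seed) pair to a new token sequence, standing in for the
    # learned Turing machine.
    conditions = random.choices(buffer, k=batch_size)
    seeds = [seq[:seed_length] for seq in random.choices(buffer, k=batch_size)]
    # Generate new sequences conditioned on existing buffer contents
    fresh = [generate(c, s) for c, s in zip(conditions, seeds)]
    # Place the generated sequences back into random buffer slots
    for seq in fresh:
        buffer[random.randrange(len(buffer))] = seq
    # Reset a few slots with uniformly random sequences to keep diversity
    for _ in range(num_reset):
        buffer[random.randrange(len(buffer))] = [
            random.randrange(vocab_size) for _ in range(seq_len)
        ]
    # The transformer would then train autoregressively on this batch
    return random.choices(buffer, k=batch_size)
```

Note that the buffer size stays fixed: generation and resets overwrite slots in place rather than growing the buffer.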
Usage
Import these classes when you want to pretrain a transformer model without access to a real text corpus, or when you want to augment standard pretraining with synthetic data. The approach can serve as a warm-start or as the sole pretraining signal for small-scale experiments.
Code Reference
Source Location
- Repository: Lucidrains_X_transformers
- File: x_transformers/up_wrapper.py
- Lines: 145-260
Signature
class UniversalPretrainWrapper(Module):
    def __init__(
        self,
        model: TransformerWrapper,
        data_generator: SyntheticDataGenerator | Module | None = None,
        buffer_size = None,
        num_reset = 20,
        batch_size = 32,
        seq_len = 512,
        seed_length = 8,
        reset_turing_machine_every = 0,
        keep_buffer_on_cpu = False
    ):
        """
        Args:
            model: TransformerWrapper (causal decoder) to pretrain.
            data_generator: Turing machine for synthetic data. Defaults to SyntheticDataGenerator.
            buffer_size: Size of the synthetic data buffer (defaults to batch_size * 20).
            num_reset: Number of buffer entries to reset with random sequences each step.
            batch_size: Number of sequences per training step.
            seq_len: Sequence length for generated data.
            seed_length: Length of seed sequences for the Turing machine.
            reset_turing_machine_every: Reset generator weights every N steps (0 = never).
            keep_buffer_on_cpu: Keep the data buffer on CPU to save GPU memory.
        """

class SyntheticDataGenerator(Module):
    def __init__(
        self,
        dim,
        num_tokens,
        max_seq_len = 512,
        hidden_size = None,
        use_gru = False,
        network_klass = None
    ):
        """
        Args:
            dim: Embedding dimension.
            num_tokens: Vocabulary size.
            max_seq_len: Maximum sequence length for generation.
            hidden_size: RNN hidden size (defaults to dim).
            use_gru: Use GRU instead of LSTM.
            network_klass: Custom network class (overrides LSTM/GRU).
        """
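Conceptually, the generator extends a short seed autoregressively until it reaches the target length. The toy stand-in below illustrates this shape only; `toy_generate` and `step_fn` are hypothetical names, and the real SyntheticDataGenerator samples each next token from a learned LSTM/GRU rather than a deterministic function.

```python
def toy_generate(seed, seq_len, vocab_size, step_fn):
    # Toy stand-in for the RNN generator: extend the seed one token at a
    # time until the sequence reaches seq_len. The real generator would
    # sample from an LSTM/GRU over its hidden state instead of step_fn.
    seq = list(seed)
    while len(seq) < seq_len:
        seq.append(step_fn(seq[-1]) % vocab_size)
    return seq
```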
Import
from x_transformers.up_wrapper import UniversalPretrainWrapper, SyntheticDataGenerator
I/O Contract
UniversalPretrainWrapper Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (no args) | — | — | forward() takes no arguments; it self-generates training data |
Outputs
| Name | Type | Description |
|---|---|---|
| forward() returns | Tensor (scalar) | Autoregressive cross-entropy loss on synthetic data |
Usage Examples
Universal Pretraining Loop
import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper
# Build the model to pretrain
model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)

# Wrap for universal pretraining
up_wrapper = UniversalPretrainWrapper(
    model=model,
    buffer_size=640,
    batch_size=32,
    seq_len=512,
    seed_length=8,
    num_reset=20,
    reset_turing_machine_every=100
)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Training loop - no data loading needed!
for step in range(1000):
    loss = up_wrapper()  # generates its own data
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
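With reset_turing_machine_every=100 as above, the generator weights are re-initialized on a fixed cadence to keep the synthetic distribution from collapsing. The helper below just visualizes that cadence; it is a hypothetical sketch (the wrapper handles resets internally, and the exact step at which a reset fires may differ in the implementation).

```python
def turing_machine_reset_steps(num_steps, reset_every):
    # Steps at which the generator would be re-initialized under a simple
    # modulo schedule; reset_every = 0 means the generator is never reset.
    if reset_every <= 0:
        return []
    return [step for step in range(1, num_steps + 1)
            if step % reset_every == 0]
```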