

Implementation:Lucidrains X transformers UniversalPretrainWrapper

From Leeroopedia


Knowledge Sources
Domains: NLP, Language_Modeling, Pretraining
Last Updated: 2026-02-08 18:00 GMT

Overview

Concrete tool, provided by the x-transformers library, for pretraining transformers on synthetic data generated by lightweight "Turing machines" (small LSTM or GRU networks), without requiring any real text corpus.

Description

The UniversalPretrainWrapper implements Universal Pretraining (UP) from Bloem (2025). Instead of pretraining on real text, it uses a lightweight SyntheticDataGenerator (a small LSTM or GRU, referred to as a "Turing machine") to produce synthetic training sequences. The wrapper maintains a buffer of sequences that is iteratively "enriched": new sequences are generated conditioned on existing buffer entries and written back into the buffer. Each training step follows Algorithm 1 from the paper:

1. Sample conditions and seeds from the buffer.
2. Generate new sequences with the Turing machine from those conditions and seeds.
3. Write the generated sequences back into the buffer.
4. Overwrite num_reset buffer entries with fresh random sequences.
5. Sample a batch from the buffer and train the transformer on it autoregressively.

The Turing machine's weights can also be reset every reset_turing_machine_every steps, which keeps the synthetic distribution from becoming too narrow.
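The per-step data flow can be sketched in plain Python. This is a toy illustration under stated assumptions, not the library's implementation: the buffer is a list of token lists, and the hypothetical make_toy_generator builds a random lookup table that stands in for the LSTM/GRU Turing machine (it only mimics a (condition, seed, length) interface). Writing each generated sequence into the slot its condition came from is also an assumption of this sketch. The comments follow the order of the Algorithm 1 steps described above.

```python
import random


def random_sequence(num_tokens: int, seq_len: int, rng: random.Random) -> list[int]:
    """Uniformly random token sequence (used to initialise and reset the buffer)."""
    return [rng.randrange(num_tokens) for _ in range(seq_len)]


def make_toy_generator(num_tokens: int, rng: random.Random):
    """Hypothetical stand-in for the LSTM/GRU 'Turing machine'.

    A fixed random lookup table picks the next token from the previous token and
    an offset derived from the condition. It only mimics the interface
    (condition, seed, length) -> sequence, not the real network.
    """
    table = [[rng.randrange(num_tokens) for _ in range(num_tokens)] for _ in range(num_tokens)]

    def generate(condition: list[int], seed: list[int], length: int) -> list[int]:
        offset = sum(condition) % num_tokens
        seq = list(seed)
        while len(seq) < length:
            seq.append(table[seq[-1]][offset])
        return seq[:length]

    return generate


def up_step(buffer, generate, num_tokens, batch_size, seq_len, seed_length, num_reset, rng):
    """One data step in the spirit of Algorithm 1; mutates `buffer`, returns a training batch."""
    # sample conditions and seeds from the buffer
    cond_idx = rng.sample(range(len(buffer)), batch_size)
    seed_idx = rng.sample(range(len(buffer)), batch_size)

    generated = []
    for ci, si in zip(cond_idx, seed_idx):
        # crop a random window of `seed_length` tokens to serve as the seed
        start = rng.randrange(seq_len - seed_length + 1)
        seed = buffer[si][start:start + seed_length]
        # generate a new sequence from (condition, seed)
        generated.append(generate(buffer[ci], seed, seq_len))

    # place the "enriched" sequences back into the slots the conditions came from
    for ci, seq in zip(cond_idx, generated):
        buffer[ci] = seq

    # overwrite `num_reset` random entries with fresh random sequences
    for ri in rng.sample(range(len(buffer)), num_reset):
        buffer[ri] = random_sequence(num_tokens, seq_len, rng)

    # sample the batch the transformer would be trained on
    return [list(buffer[i]) for i in rng.sample(range(len(buffer)), batch_size)]


rng = random.Random(0)
num_tokens, seq_len, batch_size = 16, 32, 4
buffer = [random_sequence(num_tokens, seq_len, rng) for _ in range(batch_size * 20)]
generate = make_toy_generator(num_tokens, rng)
batch = up_step(buffer, generate, num_tokens, batch_size, seq_len,
                seed_length=8, num_reset=2, rng=rng)
print(len(batch), len(batch[0]))  # 4 32
```

Because generation reads the buffer before anything is written back, a slot that is used both as a condition and as a seed source in the same step is read in its old state.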

Usage

Import these classes when you want to pretrain a transformer without access to a real text corpus, or when you want to supplement standard pretraining with synthetic data. The approach can serve either as a warm start before training on real data or as the sole pretraining signal in small-scale experiments.

Code Reference

Source Location

Signature

class UniversalPretrainWrapper(Module):
    def __init__(
        self,
        model: TransformerWrapper,
        data_generator: SyntheticDataGenerator | Module | None = None,
        buffer_size = None,
        num_reset = 20,
        batch_size = 32,
        seq_len = 512,
        seed_length = 8,
        reset_turing_machine_every = 0,
        keep_buffer_on_cpu = False
    ):
        """
        Args:
            model: TransformerWrapper (causal decoder) to pretrain.
            data_generator: Turing machine for synthetic data. Defaults to SyntheticDataGenerator.
            buffer_size: Size of the synthetic data buffer (defaults to batch_size * 20).
            num_reset: Number of buffer entries to reset with random sequences each step.
            batch_size: Number of sequences per training step.
            seq_len: Sequence length for generated data.
            seed_length: Length of seed sequences for the Turing machine.
            reset_turing_machine_every: Reset generator weights every N steps (0 = never).
            keep_buffer_on_cpu: Keep the data buffer on CPU to save GPU memory.
        """

class SyntheticDataGenerator(Module):
    def __init__(
        self,
        dim,
        num_tokens,
        max_seq_len = 512,
        hidden_size = None,
        use_gru = False,
        network_klass = None
    ):
        """
        Args:
            dim: Embedding dimension.
            num_tokens: Vocabulary size.
            max_seq_len: Maximum sequence length for generation.
            hidden_size: RNN hidden size (defaults to dim).
            use_gru: Use GRU instead of LSTM.
            network_klass: Custom network class (overrides LSTM/GRU).
        """

Import

from x_transformers.up_wrapper import UniversalPretrainWrapper, SyntheticDataGenerator
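To make the "lightweight Turing machine" concrete, the sketch below is a toy, dependency-free analogue of SyntheticDataGenerator. It is not the library's code: it swaps the LSTM/GRU for a plain tanh (Elman) recurrent cell built from Python lists, and the class name ToyRecurrentGenerator, the prefix argument, the reset_weights() method and the 1e-4 default temperature are illustrative assumptions. What it shares with the description above is the overall shape: embed tokens, run a randomly initialised recurrent network over a prefix, sample new tokens from its output logits, and optionally redraw all weights (the role played by reset_turing_machine_every).

```python
import math
import random


class ToyRecurrentGenerator:
    """Hypothetical, dependency-free analogue of SyntheticDataGenerator.

    A randomly initialised tanh (Elman) RNN replaces the LSTM/GRU:
    token embedding -> recurrent cell -> output logits -> near-greedy sampling.
    """

    def __init__(self, dim, num_tokens, hidden_size=None, seed=0):
        self.dim = dim
        self.num_tokens = num_tokens
        self.hidden_size = hidden_size or dim  # like the documented default: hidden_size = dim
        self.rng = random.Random(seed)
        self.reset_weights()

    def _matrix(self, rows, cols):
        scale = 1.0 / math.sqrt(cols)
        return [[self.rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

    def reset_weights(self):
        # redraw every weight: a brand-new random "machine"
        self.embed = self._matrix(self.num_tokens, self.dim)
        self.w_in = self._matrix(self.hidden_size, self.dim)
        self.w_hh = self._matrix(self.hidden_size, self.hidden_size)
        self.w_out = self._matrix(self.num_tokens, self.hidden_size)

    @staticmethod
    def _matvec(m, v):
        return [sum(a * b for a, b in zip(row, v)) for row in m]

    def _step(self, token, h):
        x = self.embed[token]
        pre = [a + b for a, b in zip(self._matvec(self.w_in, x), self._matvec(self.w_hh, h))]
        h = [math.tanh(p) for p in pre]
        return self._matvec(self.w_out, h), h

    def generate(self, prefix, length, temperature=1e-4):
        """Read `prefix` (e.g. condition + seed), then sample `length` new tokens."""
        assert prefix, "prefix must contain at least one token"
        h = [0.0] * self.hidden_size
        for tok in prefix:
            logits, h = self._step(tok, h)
        out = []
        for _ in range(length):
            scaled = [l / temperature for l in logits]
            top = max(scaled)
            weights = [math.exp(s - top) for s in scaled]  # softmax numerators
            tok = self.rng.choices(range(self.num_tokens), weights=weights)[0]
            out.append(tok)
            logits, h = self._step(tok, h)
        return out


gen = ToyRecurrentGenerator(dim=8, num_tokens=16, seed=0)
sample = gen.generate(prefix=[1, 2, 3], length=10)
gen.reset_weights()  # same interface, different random machine
```

In this sketch a very low temperature makes sampling nearly greedy, so the output is driven mostly by the random weights and the prefix rather than by sampling noise.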

I/O Contract

UniversalPretrainWrapper Inputs

Name | Type | Required | Description
(none) | n/a | n/a | forward() takes no arguments; it generates its own training data.

Outputs

Name | Type | Description
return value of forward() | Tensor (scalar) | Autoregressive cross-entropy loss on a batch of synthetic data
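For readers who want the loss spelled out, the sketch below computes autoregressive (next-token) cross-entropy for a single sequence in plain Python: the inputs are the sequence without its last token, and the targets are the sequence shifted left by one. It illustrates what the returned loss measures but is not x-transformers' code; the function name autoregressive_ce and the predict_logits callback are hypothetical.

```python
import math


def autoregressive_ce(tokens, predict_logits):
    """Mean next-token cross-entropy for one sequence.

    Inputs are tokens[:-1] and targets are tokens[1:]: the prediction made
    after reading tokens[:i+1] is scored against tokens[i+1].
    """
    inputs, targets = tokens[:-1], tokens[1:]
    total = 0.0
    for i, target in enumerate(targets):
        logits = predict_logits(inputs[: i + 1])
        top = max(logits)
        log_z = top + math.log(sum(math.exp(l - top) for l in logits))  # log-sum-exp
        total += log_z - logits[target]  # -log softmax(logits)[target]
    return total / len(targets)


# a predictor with equal logits over a 4-token vocabulary
uniform = lambda prefix: [0.0] * 4
loss = autoregressive_ce([0, 3, 1, 2], uniform)
print(round(loss, 4))  # 1.3863
```

A uniform predictor over V tokens scores exactly ln V, which gives a rough baseline: for the 256-token vocabulary in the usage example, ln 256 is about 5.55.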

Usage Examples

Universal Pretraining Loop

import torch
from x_transformers import TransformerWrapper, Decoder
from x_transformers.up_wrapper import UniversalPretrainWrapper

# Build the model to pretrain
model = TransformerWrapper(
    num_tokens=256,
    max_seq_len=512,
    attn_layers=Decoder(dim=256, depth=6, heads=8)
)

# Wrap for universal pretraining
up_wrapper = UniversalPretrainWrapper(
    model=model,
    buffer_size=640,
    batch_size=32,
    seq_len=512,
    seed_length=8,
    num_reset=20,
    reset_turing_machine_every=100
)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Training loop - no data loading needed!
for step in range(1000):
    loss = up_wrapper()  # generates its own data
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
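The example's settings can be sanity-checked with a little arithmetic. The sketch below shows that buffer_size=640 is simply the documented default of batch_size * 20 for batch_size=32. It also assumes that the Turing machine is reset whenever the 1-based step count is a multiple of reset_turing_machine_every; under that assumption the 1000-step loop above resets it 10 times. Both helper functions are hypothetical.

```python
def default_buffer_size(batch_size, buffer_size=None):
    # documented default: batch_size * 20 when buffer_size is None
    return buffer_size if buffer_size is not None else batch_size * 20


def turing_machine_reset_steps(total_steps, every):
    # assumption: a reset fires whenever the 1-based step count is a multiple of `every`;
    # every == 0 disables resets, as documented
    if every <= 0:
        return []
    return [step for step in range(1, total_steps + 1) if step % every == 0]


print(default_buffer_size(32))  # 640
resets = turing_machine_reset_steps(1000, 100)
print(len(resets), resets[:3])  # 10 [100, 200, 300]
```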
