Implementation:Lucidrains X transformers TextSamplerDataset Pattern

From Leeroopedia


Field         Value
Repo          x-transformers
Domains       Data_Engineering, NLP
Last Updated  2026-02-08 18:00 GMT

Overview

Pattern specification for creating token sequence datasets for autoregressive training, as demonstrated in the x-transformers training examples.

Description

This is a Pattern Doc: it documents a user-defined interface, not a library API. Users create a torch.utils.data.Dataset subclass whose __getitem__ returns an integer token sequence of shape (seq_len + 1,). The reference implementation is TextSamplerDataset from train_enwik8.py.

The pattern requires:

  • Subclassing torch.utils.data.Dataset.
  • Storing the full tokenized corpus as a 1D tensor.
  • Returning random contiguous subsequences of length seq_len + 1.
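  • Reporting a logical length equal to the corpus size divided by seq_len, rounded down. Because sampling is random, this length only sets how many samples count as one pass over the DataLoader.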
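
As a minimal sketch of the storage and length requirements above, assuming a byte-level vocabulary in which each UTF-8 byte is one token, the corpus can be held as a flat uint8 tensor and the logical length computed by integer division. The names text, corpus, and logical_len are illustrative and do not come from train_enwik8.py.

```python
import torch

# Hypothetical toy corpus; a real run would load a much larger file.
text = "the quick brown fox jumps over the lazy dog. " * 20

# Byte-level "tokenization": every UTF-8 byte becomes one token ID in [0, 255].
corpus = torch.tensor(list(text.encode("utf-8")), dtype=torch.uint8)

seq_len = 32
# Logical length as reported by __len__ in the pattern: floor(corpus size / seq_len).
logical_len = corpus.size(0) // seq_len

print(corpus.shape, corpus.dtype, logical_len)
```

Keeping the corpus in a compact integer dtype and casting to long only when a window is sampled is one way to limit memory use; whether that matters depends on the corpus size.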

Code Reference

File: train_enwik8.py, Lines: L71-88

Interface Specification

TextSamplerDataset (Reference Implementation)

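import torch
from torch.utils.data import Dataset

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data        # full tokenized corpus as a 1D tensor
        self.seq_len = seq_len

    def __getitem__(self, index):
        # index is unused; a random window start is drawn on every call
        rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
        # slice seq_len + 1 tokens and cast to int64
        full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0) // self.seq_len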

Key details:

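  • The index parameter in __getitem__ is ignored; the start position is always random, so repeated calls with the same index return different windows.
  • The upper bound passed to torch.randint is data.size(0) - seq_len - 1. That bound is exclusive, so the largest possible start is data.size(0) - seq_len - 2. The slice therefore always stays inside the data, and the final token of the corpus is never sampled.
  • The returned tensor is of type LongTensor (dtype torch.int64, via .long()).
  • The .cuda() call moves each sample to the GPU inside the dataset; in practice, users may prefer to return CPU tensors and handle device placement in the training loop instead.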
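
The device-placement point can be illustrated with a variant that returns CPU tensors and leaves the transfer to the caller. This is a sketch only: CPUTextSamplerDataset is a hypothetical name, not a class from x-transformers, and the toy corpus is chosen so that token i equals i, which makes the window bounds easy to inspect.

```python
import torch
from torch.utils.data import Dataset


class CPUTextSamplerDataset(Dataset):
    """Hypothetical variant of TextSamplerDataset that returns CPU tensors."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # index is still ignored; the window start is drawn at random.
        # torch.randint excludes `high`, so start <= len(data) - seq_len - 2.
        high = self.data.size(0) - self.seq_len - 1
        start = torch.randint(0, high, (1,)).item()
        # Slice seq_len + 1 tokens and cast to int64; no .cuda() here.
        return self.data[start : start + self.seq_len + 1].long()

    def __len__(self):
        return self.data.size(0) // self.seq_len


data = torch.arange(100, dtype=torch.uint8)  # toy corpus where token i equals i
dataset = CPUTextSamplerDataset(data, seq_len=8)
sample = dataset[0]

# Device placement happens outside the dataset, e.g. in the training loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample = sample.to(device)
print(sample.shape, sample.dtype, len(dataset))
```

Returning CPU tensors also keeps the dataset usable on machines without a GPU, at the cost of one explicit .to(device) call per batch.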

Cycling DataLoader Pattern

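from torch.utils.data import DataLoader

def cycle(loader):
    # loop over the DataLoader forever, restarting it after each full pass
    while True:
        for data in loader:
            yield data

# data_train, SEQ_LEN, and BATCH_SIZE are defined elsewhere in the training script
train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, drop_last=True))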

Key details:

  • The cycle generator wraps a standard DataLoader and restarts it whenever the underlying iterator is exhausted.
  • drop_last=True ensures all batches have exactly BATCH_SIZE samples.
  • The training loop pulls each batch by calling next(train_loader); the generator never raises StopIteration.
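
The restart behavior can be checked on a small stand-in dataset. The sketch below assumes a hypothetical TensorDataset of 10 integers in place of TextSamplerDataset, so that the batches are deterministic and the effect of drop_last=True is visible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def cycle(loader):
    # Restart the DataLoader each time it is exhausted, yielding batches forever.
    while True:
        for data in loader:
            yield data


# Hypothetical stand-in dataset: 10 samples, so batch_size=4 with drop_last=True
# gives 2 full batches per pass and discards the last 2 samples.
toy = TensorDataset(torch.arange(10))
loader = DataLoader(toy, batch_size=4, drop_last=True)
stream = cycle(loader)

# Pull 5 batches, which is more than one pass (2 batches) can provide.
batches = [next(stream)[0] for _ in range(5)]
print(len(loader), [b.tolist() for b in batches])
```

Without shuffling, the third batch repeats the first, which shows that the generator restarted the loader instead of stopping.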

Input / Output

Direction  Name     Type                                  Description
Input      data     1D torch.Tensor (integer)             Raw tokenized corpus as a flat tensor of token IDs
Input      seq_len  int                                   Desired sequence length for training (samples will be seq_len + 1)
Output     batch    (batch_size, seq_len + 1) LongTensor  Batched token sequences from the cycling DataLoader
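
The extra token in seq_len + 1 is what allows a next-token objective: a batch can be split into inputs and targets that are offset by one position. The sketch below shows that shift in plain PyTorch with illustrative sizes; it is not code from train_enwik8.py, and a model wrapper may perform the same split internally.

```python
import torch

batch_size, seq_len, vocab_size = 4, 16, 256  # illustrative values

# Stand-in for one batch from the cycling DataLoader: (batch_size, seq_len + 1).
batch = torch.randint(0, vocab_size, (batch_size, seq_len + 1))

# Shift by one position: the model sees tokens 0..seq_len-1 and is trained
# to predict tokens 1..seq_len.
inputs = batch[:, :-1]
targets = batch[:, 1:]

print(inputs.shape, targets.shape)
```

Both halves have length seq_len, so every input position has exactly one target.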
