Implementation:Lucidrains X transformers TextSamplerDataset Pattern
| Field | Value |
|---|---|
| Repo | x-transformers |
| Domains | Data_Engineering, NLP |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Pattern specification for creating token sequence datasets for autoregressive training, as demonstrated in the x-transformers training examples.
Description
This is a Pattern Doc — it documents a user-defined interface, not a library API. Users create a torch.utils.data.Dataset subclass that returns integer token sequences of shape (seq_len + 1,). The reference implementation is TextSamplerDataset from train_enwik8.py.
The pattern requires:
- Subclassing `torch.utils.data.Dataset`.
- Storing the full tokenized corpus as a 1D tensor.
- Returning random contiguous subsequences of length `seq_len + 1`.
- Reporting a logical length based on the corpus size divided by `seq_len`.
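The extra token in `seq_len + 1` is what lets a training step derive inputs and next-token targets from one sample by shifting it one position. A minimal sketch of that split follows; the token values are illustrative, and where the split happens (in a wrapper or in the training loop) is left open here.

```python
import torch

seq_len = 8
# Hypothetical sample: seq_len + 1 consecutive token IDs, as the dataset would return.
sample = torch.arange(100, 100 + seq_len + 1)

# Shift by one position: the model reads tokens 0..seq_len-1
# and is trained to predict tokens 1..seq_len.
inputs, targets = sample[:-1], sample[1:]
```

Both `inputs` and `targets` have length `seq_len`, and each target is the token that follows the input at the same position.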
Code Reference
File: train_enwik8.py, Lines: L71-88
Interface Specification
TextSamplerDataset (Reference Implementation)
    class TextSamplerDataset(Dataset):
        def __init__(self, data, seq_len):
            super().__init__()
            self.data = data
            self.seq_len = seq_len

        def __getitem__(self, index):
            rand_start = torch.randint(0, self.data.size(0) - self.seq_len - 1, (1,))
            full_seq = self.data[rand_start: rand_start + self.seq_len + 1].long()
            return full_seq.cuda()

        def __len__(self):
            return self.data.size(0) // self.seq_len
Key details:
- The `index` parameter in `__getitem__` is ignored; the start position is always random.
- The upper bound passed to `torch.randint` for `rand_start` is `data.size(0) - seq_len - 1`. Because that bound is exclusive, the slice always ends at least one token before the end of the data, so the subsequence never exceeds the data boundary.
- The returned tensor is of type `LongTensor` (via `.long()`).
- The `.cuda()` call moves data to GPU; in practice, users may prefer to handle device placement in the training loop instead.
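As a sketch of the last point, the variant below drops `.cuda()` from `__getitem__` and leaves device placement to the caller. The class name `CpuTextSamplerDataset`, the toy corpus, and the `device` selection are assumptions made for illustration; they are not part of train_enwik8.py.

```python
import torch
from torch.utils.data import Dataset


class CpuTextSamplerDataset(Dataset):
    """Hypothetical variant of TextSamplerDataset that leaves samples on the CPU."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        # `index` is ignored, as in the reference implementation.
        # The upper bound of torch.randint is exclusive.
        high = self.data.size(0) - self.seq_len - 1
        rand_start = torch.randint(0, high, (1,)).item()
        return self.data[rand_start : rand_start + self.seq_len + 1].long()

    def __len__(self):
        return self.data.size(0) // self.seq_len


# Toy corpus of byte-sized token IDs (illustrative values only).
data = torch.randint(0, 256, (1000,), dtype=torch.uint8)
dataset = CpuTextSamplerDataset(data, seq_len=16)

# Device placement happens at the call site, so the same dataset
# works on machines with or without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sample = dataset[0].to(device)
```

With this arrangement the dataset itself has no CUDA dependency, and batches can be moved with `.to(device)` after they come out of the `DataLoader`.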
Cycling DataLoader Pattern
    def cycle(loader):
        while True:
            for data in loader:
                yield data

    train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
    train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE, drop_last=True))
Key details:
- The `cycle` generator wraps a standard `DataLoader` and restarts it whenever the underlying iterator is exhausted.
- `drop_last=True` ensures all batches have exactly `BATCH_SIZE` samples.
- Consuming data is done via `next(train_loader)` in the training loop.
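The restart behaviour can be checked with a small CPU-only sketch. It uses a `TensorDataset` of toy values as a stand-in for the text dataset (an assumption made to keep the example self-contained) and requests more batches than one pass provides.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def cycle(loader):
    while True:
        for data in loader:
            yield data


# Toy stand-in dataset: 10 rows of 5 token IDs (not the TextSamplerDataset itself).
toy = TensorDataset(torch.arange(50).reshape(10, 5))
loader = cycle(DataLoader(toy, batch_size=4, drop_last=True))

# 10 samples with batch_size=4 and drop_last=True give 2 full batches per pass,
# so requesting 5 batches forces the DataLoader to be restarted twice.
# TensorDataset yields one-element tuples, hence the [0].
shapes = [next(loader)[0].shape for _ in range(5)]
```

Every batch has the same shape because the incomplete final batch of each pass is dropped, and `next(loader)` never raises `StopIteration`.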
Input / Output
| Direction | Name | Type | Description |
|---|---|---|---|
| Input | data | 1D torch.Tensor (integer) | Raw tokenized corpus as a flat tensor of token IDs |
| Input | seq_len | int | Desired sequence length for training (samples will be seq_len + 1) |
| Output | batch | (batch_size, seq_len + 1) LongTensor | Batched token sequences from the cycling DataLoader |
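One way to produce the `data` input is sketched below. It assumes byte-level tokenization, where each UTF-8 byte is one token ID; the sample text and the 90/10 split ratio are illustrative choices, not values taken from the source.

```python
import torch

text = "hello world"  # illustrative corpus

# Byte-level "tokenization" (an assumption): each UTF-8 byte is a token ID in [0, 255].
data = torch.tensor(list(text.encode("utf-8")), dtype=torch.uint8)

# Split the flat tensor into training and validation portions.
split = int(0.9 * data.size(0))
data_train, data_val = data[:split], data[split:]
```

Storing the corpus as `uint8` keeps it compact; the dataset's `.long()` call converts each sampled slice to the integer type expected downstream.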