Implementation:Lucidrains X transformers Masked Dataset Pattern
| Field | Value |
|---|---|
| Repo | x-transformers |
| Domains | Data_Engineering, Generative_Models |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Pattern specification for creating token sequence datasets for non-autoregressive masked prediction training with x-transformers.
Description
This is a Pattern Doc. No reference training script exists in the repository. The interface is derived from NonAutoregressiveWrapper.forward(), which expects x of shape (batch, max_seq_len) as integer tokens. The forward method asserts n == self.max_seq_len, meaning sequences must be exactly the configured length. Masking is applied internally by the wrapper using a schedule-based strategy.
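The dataset has to hand over clean tokens because the wrapper masks them itself. The sketch below simplifies what schedule-based masking looks like. It assumes a cosine schedule that maps a random per-sequence time to a masking ratio, with at least one masked position per row. The names cosine_schedule and sketch_schedule_mask are hypothetical, and the wrapper's real code and default schedule may differ.

```python
import math
import torch

def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
    # Assumed schedule: maps a time t in [0, 1) to a masking ratio in (0, 1].
    return torch.cos(t * math.pi / 2)

def sketch_schedule_mask(x: torch.Tensor, mask_id: int, schedule=cosine_schedule):
    """Hypothetical re-implementation of schedule-based masking (not library code).

    x: LongTensor of shape (batch, seq_len) holding clean token IDs.
    Returns (masked_input, mask), where mask is True at masked positions.
    """
    b, n = x.shape
    # One random time per sequence, turned into a number of positions to mask.
    rand_times = torch.rand(b, device=x.device)
    num_to_mask = (schedule(rand_times) * n).clamp(min=1.0)
    # argsort of random values gives each position a distinct random rank in [0, n).
    ranks = torch.rand(b, n, device=x.device).argsort(dim=-1).float()
    # Positions whose rank falls below the per-row budget are masked.
    mask = ranks < num_to_mask.unsqueeze(-1)
    masked_input = x.masked_fill(mask, mask_id)
    return masked_input, mask
```

If the dataset already replaced tokens with mask_id, the wrapper would mask an already-masked input and train against corrupted targets. This is why the dataset returns raw IDs only.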
The pattern requires:
- Subclassing torch.utils.data.Dataset.
- Returning sequences of exactly max_seq_len tokens (no more, no less).
- Returning unmasked integer token IDs (masking is handled by the wrapper).
- Ensuring token values do not include the reserved mask_id.
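These requirements can be checked mechanically on a single dataset item before training starts. The sketch assumes the item is a 1D tensor, and the helper name nar_sample_problems is hypothetical. Instead of raising on the first violation, it returns every violation it finds, so one call reports all problems at once.

```python
import torch

def nar_sample_problems(seq: torch.Tensor, max_seq_len: int, num_tokens: int, mask_id: int) -> list:
    """Hypothetical checker: returns a list of contract violations (empty means OK)."""
    problems = []
    # Exact length: the wrapper asserts n == max_seq_len after batching.
    if tuple(seq.shape) != (max_seq_len,):
        problems.append(f"expected shape ({max_seq_len},), got {tuple(seq.shape)}")
    # Integer token IDs, as produced by .long().
    if seq.dtype != torch.long:
        problems.append(f"expected torch.long, got {seq.dtype}")
    if seq.numel() > 0:
        # Every ID must index into the embedding table.
        if seq.min().item() < 0 or seq.max().item() >= num_tokens:
            problems.append("token IDs fall outside [0, num_tokens - 1]")
        # mask_id is reserved for the wrapper's masking step.
        if (seq == mask_id).any():
            problems.append(f"sample contains the reserved mask_id {mask_id}")
    return problems
```

A typical call is nar_sample_problems(dataset[0], MAX_SEQ_LEN, num_tokens, mask_id) == [].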
Code Reference
File: x_transformers/nonautoregressive_wrapper.py, Lines: L275-284 (NonAutoregressiveWrapper.forward())
Interface Specification
MaskedTokenDataset (Derived Interface)
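```python
import torch
from torch.utils.data import Dataset

class MaskedTokenDataset(Dataset):
    """Dataset for NonAutoregressiveWrapper training.

    Must return integer token sequences of exactly max_seq_len length.
    Do NOT apply masking; the wrapper handles this internally.
    Token values must be in range [0, num_tokens - 1] (excluding mask_id).
    """

    def __init__(self, data, max_seq_len):
        assert len(data) >= max_seq_len, "corpus must hold at least one full window"
        self.data = data
        self.max_seq_len = max_seq_len

    def __getitem__(self, index):
        # index is ignored: sample a random window start instead.
        # randint's upper bound is exclusive, so + 1 lets the last full window be drawn.
        start = torch.randint(0, len(self.data) - self.max_seq_len + 1, (1,)).item()
        # Return exactly max_seq_len tokens (no +1 needed, unlike autoregressive training)
        return self.data[start:start + self.max_seq_len].long()

    def __len__(self):
        # Nominal epoch size: the number of non-overlapping windows in the corpus
        return len(self.data) // self.max_seq_len
```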
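Random sampling suits training, but an evaluation pass usually needs to be reproducible. The sketch below is a deterministic variant that uses index to select the i-th non-overlapping window. The class name ChunkedMaskedTokenDataset is hypothetical, and the variant assumes trailing tokens that do not fill a whole window can be dropped.

```python
import torch
from torch.utils.data import Dataset

class ChunkedMaskedTokenDataset(Dataset):
    """Hypothetical deterministic variant: item i is the i-th non-overlapping window."""

    def __init__(self, data: torch.Tensor, max_seq_len: int):
        self.data = data
        self.max_seq_len = max_seq_len

    def __len__(self):
        # Only full windows count, so every item has exactly max_seq_len tokens.
        return len(self.data) // self.max_seq_len

    def __getitem__(self, index):
        start = index * self.max_seq_len
        # Same contract as MaskedTokenDataset: unmasked, long, exact length.
        return self.data[start:start + self.max_seq_len].long()
```

The masking inside the wrapper stays random, so losses can still vary between evaluation runs. Only the set of windows is fixed.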
Key details:
- Unlike TextSamplerDataset for autoregressive training, this dataset returns sequences of exactly max_seq_len (not max_seq_len + 1).
- The index parameter is ignored in favor of random sampling, similar to the autoregressive pattern.
- No masking is applied; the NonAutoregressiveWrapper handles masking internally during the forward pass.
- Token values must not include the mask_id token, as this is reserved for the masking mechanism.
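One way to guarantee that mask_id never occurs in the data is to reserve the ID just past the data vocabulary. This sketch assumes the corpus uses IDs 0 to vocab_size - 1, for example vocab_size = 256 for raw bytes. The helper name reserve_mask_id is hypothetical.

```python
import torch

def reserve_mask_id(corpus: torch.Tensor, vocab_size: int):
    """Hypothetical helper: reserve the ID just past the data vocabulary for masking.

    Returns (num_tokens, mask_id) for TransformerWrapper / NonAutoregressiveWrapper.
    """
    if corpus.min().item() < 0 or corpus.max().item() >= vocab_size:
        raise ValueError("corpus contains IDs outside [0, vocab_size - 1]")
    mask_id = vocab_size          # never produced by the data, so it is safe to reserve
    num_tokens = vocab_size + 1   # the embedding table must also cover mask_id
    return num_tokens, mask_id
```

With byte-level data this gives num_tokens = 257 and mask_id = 256, so every byte value 0 to 255 remains usable as data.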
Usage with NonAutoregressiveWrapper
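```python
import torch
from torch.utils.data import DataLoader
from x_transformers import TransformerWrapper, Encoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

MAX_SEQ_LEN = 1024   # placeholder values; tune for your corpus and hardware
BATCH_SIZE = 4

model = TransformerWrapper(
    num_tokens = 257,                                  # 256 data tokens + 1 reserved mask token
    max_seq_len = MAX_SEQ_LEN,
    attn_layers = Encoder(dim=512, depth=6, heads=8)   # bidirectional attention for masked prediction
)
wrapper = NonAutoregressiveWrapper(model, mask_id=256, steps=18)
optimizer = torch.optim.Adam(wrapper.parameters(), lr=3e-4)

# data: 1D integer tensor of token IDs in [0, 255], e.g. a byte-level corpus
dataset = MaskedTokenDataset(data, MAX_SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True)

for batch in loader:
    out = wrapper(batch)
    # Some versions return a named tuple whose .loss field holds the total loss
    loss = out.loss if hasattr(out, "loss") else out
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```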
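Before the loop runs, pulling one collated batch can catch a length or dtype mismatch earlier than the wrapper's assertion would, and with a clearer message. The helper name preflight_batch is hypothetical. It checks only the shape and dtype that NonAutoregressiveWrapper.forward() relies on.

```python
import torch
from torch.utils.data import DataLoader

def preflight_batch(loader: DataLoader, max_seq_len: int) -> torch.Tensor:
    """Hypothetical pre-training check on the first collated batch."""
    batch = next(iter(loader))
    # The wrapper reads (b, n) from batch.shape and requires n == max_seq_len
    if batch.dim() != 2 or batch.shape[1] != max_seq_len:
        raise ValueError(f"expected (batch_size, {max_seq_len}), got {tuple(batch.shape)}")
    # Token IDs are looked up in an embedding table, so they must be integers
    if batch.dtype != torch.long:
        raise ValueError(f"expected torch.long token IDs, got {batch.dtype}")
    return batch
```

For example, calling preflight_batch(loader, MAX_SEQ_LEN) right after building the DataLoader runs this check once.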
Input / Output
| Direction | Name | Type | Shape | Description |
|---|---|---|---|---|
| Input | data | 1D torch.Tensor (integer) | (corpus_len,) | Raw tokenized corpus as a flat tensor of token IDs |
| Input | max_seq_len | int | scalar | Exact sequence length required by NonAutoregressiveWrapper |
| Output | batch | LongTensor | (batch_size, max_seq_len) | Batched unmasked token sequences from the DataLoader |
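The shapes in the table determine how much data one epoch covers. The sketch below does that bookkeeping for a corpus of corpus_len tokens. It assumes MaskedTokenDataset's __len__ as defined in this pattern and a DataLoader with drop_last=True. The helper name nar_epoch_stats is hypothetical.

```python
def nar_epoch_stats(corpus_len: int, max_seq_len: int, batch_size: int) -> dict:
    """Hypothetical bookkeeping for MaskedTokenDataset + DataLoader(drop_last=True)."""
    if corpus_len < max_seq_len:
        raise ValueError("corpus is shorter than one max_seq_len window")
    items_per_epoch = corpus_len // max_seq_len         # MaskedTokenDataset.__len__
    batches_per_epoch = items_per_epoch // batch_size   # drop_last=True discards the remainder
    valid_window_starts = corpus_len - max_seq_len + 1  # distinct full windows that fit
    return {
        "items_per_epoch": items_per_epoch,
        "batches_per_epoch": batches_per_epoch,
        "valid_window_starts": valid_window_starts,
    }
```

For a 10,000-token corpus with max_seq_len = 1024 and batch_size = 4, an epoch is 9 items, or 2 full batches. Random sampling can still draw from 8,977 distinct windows, so __len__ sets only the epoch length and does not limit coverage of the corpus.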