

Implementation:Lucidrains X transformers Masked Dataset Pattern

From Leeroopedia


Field Value
Repo x-transformers
Domains Data_Engineering, Generative_Models
Last Updated 2026-02-08 18:00 GMT

Overview

Pattern specification for creating token sequence datasets for non-autoregressive masked prediction training with x-transformers.

Description

This is a Pattern Doc. No reference training script exists in the repository. The interface is derived from NonAutoregressiveWrapper.forward(), which expects x of shape (batch, max_seq_len) as integer tokens. The forward method asserts n == self.max_seq_len, meaning sequences must be exactly the configured length. Masking is applied internally by the wrapper using a schedule-based strategy.
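To make "schedule-based masking" concrete, here is a minimal single-sequence sketch in plain Python. It is a simplified illustration of the general idea, not the wrapper's code: the cosine schedule, the rounding rule, and the names cosine_schedule and mask_sequence are assumptions chosen for readability.

```python
import math
import random

def cosine_schedule(t: float) -> float:
    # Fraction of tokens to mask at "time" t in [0, 1]:
    # t = 0 masks everything, t = 1 masks almost nothing.
    return math.cos(t * math.pi / 2)

def mask_sequence(tokens, mask_id, t, rng):
    """Replace a schedule-determined number of positions with mask_id."""
    n = len(tokens)
    # Always mask at least one position so there is something to predict.
    num_masked = max(1, int(cosine_schedule(t) * n))
    positions = set(rng.sample(range(n), num_masked))
    masked = [mask_id if i in positions else tok for i, tok in enumerate(tokens)]
    return masked, sorted(positions)

rng = random.Random(0)
tokens = list(range(10, 18))  # 8 unmasked token IDs, none equal to mask_id
masked, positions = mask_sequence(tokens, mask_id=255, t=0.5, rng=rng)
```

With t = 0.5 the schedule gives a fraction of about 0.71, so five of the eight tokens are replaced. The dataset never performs this step itself; it only has to supply clean sequences that do not already contain mask_id.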

The pattern requires:

  • Subclassing torch.utils.data.Dataset.
  • Returning sequences of exactly max_seq_len tokens (no more, no less).
  • Returning unmasked integer token IDs (masking is handled by the wrapper).
  • Ensuring token values do not include the reserved mask_id.
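The requirements above can be checked once on the raw corpus before any training starts. The helper below is a hypothetical sketch (check_corpus and its parameters are not part of x-transformers); it assumes the corpus is a flat torch tensor, as described in this pattern.

```python
import torch

def check_corpus(data: torch.Tensor, max_seq_len: int, mask_id: int, num_tokens: int) -> None:
    """Raise ValueError if the corpus cannot feed the masked-prediction pattern."""
    if data.dim() != 1:
        raise ValueError("corpus must be a flat 1D tensor of token IDs")
    if data.is_floating_point() or data.is_complex() or data.dtype == torch.bool:
        raise ValueError("corpus must hold integer token IDs")
    if data.numel() < max_seq_len:
        raise ValueError("corpus is shorter than max_seq_len")
    if data.min() < 0 or data.max() >= num_tokens:
        raise ValueError("token IDs must lie in [0, num_tokens - 1]")
    if (data == mask_id).any():
        raise ValueError("corpus contains the reserved mask_id")
```

Running such a check up front turns a silent training problem (real tokens that look like masks) into an immediate error.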

Code Reference

File: x_transformers/nonautoregressive_wrapper.py, Lines: L275-284 (NonAutoregressiveWrapper.forward())

Interface Specification

MaskedTokenDataset (Derived Interface)

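import torch
from torch.utils.data import Dataset


class MaskedTokenDataset(Dataset):
    """Dataset for NonAutoregressiveWrapper training.

    Must return integer token sequences of exactly max_seq_len length.
    Do NOT apply masking - the wrapper handles this internally.
    Token values must be in range [0, num_tokens - 1] (excluding mask_id).
    """
    def __init__(self, data, max_seq_len):
        # data: flat 1D tensor of token IDs with at least max_seq_len elements
        self.data = data
        self.max_seq_len = max_seq_len

    def __getitem__(self, index):
        # index is ignored; each call draws a random window.
        # Return exactly max_seq_len tokens (no +1 needed).
        # randint's upper bound is exclusive, so the +1 keeps the last valid
        # start reachable and avoids an empty range when
        # len(data) == max_seq_len.
        high = len(self.data) - self.max_seq_len + 1
        start = torch.randint(0, high, (1,)).item()
        return self.data[start:start + self.max_seq_len].long()

    def __len__(self):
        # Nominal epoch size: number of non-overlapping windows in the corpus
        return len(self.data) // self.max_seq_len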

Key details:

  • Unlike TextSamplerDataset for autoregressive training, this dataset returns sequences of exactly max_seq_len (not max_seq_len + 1).
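  • The index parameter is ignored in favor of random sampling, similar to the autoregressive pattern. As a result, __len__ only sets the nominal number of samples per epoch; it does not guarantee that every position in the corpus is visited.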
  • No masking is applied — the NonAutoregressiveWrapper handles masking internally during the forward pass.
  • Token values must not include the mask_id token, as this is reserved for the masking mechanism.
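If deterministic, index-respecting access is preferred over random sampling (for example, for a validation split), the same contract can be met with non-overlapping chunks. This is a hypothetical variant, not something defined in the repository; the class name ChunkedTokenDataset is an assumption.

```python
import torch
from torch.utils.data import Dataset

class ChunkedTokenDataset(Dataset):
    """Hypothetical variant: item i is the i-th non-overlapping chunk."""

    def __init__(self, data: torch.Tensor, max_seq_len: int):
        self.data = data
        self.max_seq_len = max_seq_len

    def __len__(self) -> int:
        # Trailing tokens that do not fill a whole chunk are dropped.
        return len(self.data) // self.max_seq_len

    def __getitem__(self, index: int) -> torch.Tensor:
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.max_seq_len
        return self.data[start:start + self.max_seq_len].long()

data = torch.arange(10)
dataset = ChunkedTokenDataset(data, max_seq_len=4)
```

Here every item still has exactly max_seq_len tokens and no masking is applied, so it can be passed to the wrapper in the same way as the randomly sampled version.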

Usage with NonAutoregressiveWrapper

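import torch
from torch.utils.data import DataLoader
from x_transformers import TransformerWrapper, Decoder
from x_transformers.nonautoregressive_wrapper import NonAutoregressiveWrapper

# MAX_SEQ_LEN, BATCH_SIZE and data (a flat 1D tensor of token IDs) are
# assumed to be defined by the caller.

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = MAX_SEQ_LEN,
    attn_layers = Decoder(dim=512, depth=6, heads=8)
)

# mask_id = 255 is reserved, so the corpus may only contain IDs 0..254
wrapper = NonAutoregressiveWrapper(model, mask_id=255, steps=18)

# Example optimizer and learning rate; not prescribed by the source
optimizer = torch.optim.Adam(wrapper.parameters(), lr=1e-4)

dataset = MaskedTokenDataset(data, MAX_SEQ_LEN)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, drop_last=True)

for batch in loader:
    out = wrapper(batch)
    # Depending on the x-transformers version, forward may return a
    # namedtuple of losses instead of a bare tensor; unwrap .loss if present.
    loss = getattr(out, "loss", out)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()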

Input / Output

Direction | Name | Type | Shape | Description
Input | data | 1D torch.Tensor (integer) | (corpus_len,) | Raw tokenized corpus as a flat tensor of token IDs
Input | max_seq_len | int | scalar | Exact sequence length required by NonAutoregressiveWrapper
Output | batch | LongTensor | (batch_size, max_seq_len) | Batched unmasked token sequences from the DataLoader
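The Output row can be verified by pulling a single batch before training. The sketch below uses a random stand-in corpus and illustrative constants (all values here are assumptions); it pre-chunks the corpus into a 2D tensor, which DataLoader can index directly.

```python
import torch
from torch.utils.data import DataLoader

MAX_SEQ_LEN, BATCH_SIZE, MASK_ID = 16, 4, 255  # illustrative values

# Stand-in corpus with IDs in [0, 254], so MASK_ID never appears.
corpus = torch.randint(0, MASK_ID, (1000,))

# Drop the tail that does not fill a whole sequence, then reshape to
# (num_chunks, MAX_SEQ_LEN).
num_chunks = corpus.numel() // MAX_SEQ_LEN
sequences = corpus[: num_chunks * MAX_SEQ_LEN].view(num_chunks, MAX_SEQ_LEN)

loader = DataLoader(sequences, batch_size=BATCH_SIZE, drop_last=True)
batch = next(iter(loader))

print(batch.shape, batch.dtype)
```

A batch that is not a LongTensor of shape (batch_size, max_seq_len), or that contains mask_id, indicates a problem in the dataset rather than in the wrapper.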
