Principle: Lucidrains x-transformers Masked Token Prediction Training
Metadata
| Field | Value |
|---|---|
| Papers | MaskGIT, MDLM |
| Repository | x-transformers |
| Domains | Deep_Learning, Generative_Models, Training |
| Last Updated | 2026-02-08 18:00 GMT |
Overview
Training objective that randomly masks tokens according to a schedule and trains a bidirectional model to predict the original tokens at masked positions.
Description
During training, a random fraction of input tokens is replaced with a mask token, with the fraction given by a time-dependent schedule (linear or cosine) evaluated at a time t ~ U(0, 1) sampled per batch element. The model predicts the original tokens at all positions, but loss is computed only on masked positions. This resembles BERT MLM training, except the masking ratio is continuous rather than fixed at 15%.
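The masking step above can be sketched in plain Python. The function names, the seeding, and the list-based representation here are illustrative, not the wrapper's actual API:

```python
import math
import random

def cosine_schedule(t: float) -> float:
    # MaskGIT-style cosine schedule: fraction of positions to mask at time t
    return math.cos(t * math.pi / 2)

def mask_tokens(tokens: list[int], mask_id: int, seed: int = 0) -> tuple[list[int], list[bool]]:
    """Sample t ~ U(0, 1), mask schedule(t) of the positions, and return
    the corrupted sequence plus a boolean mask over positions."""
    rng = random.Random(seed)
    t = rng.random()
    ratio = cosine_schedule(t)
    num_mask = max(1, round(ratio * len(tokens)))   # always mask at least one position
    positions = set(rng.sample(range(len(tokens)), num_mask))
    masked = [mask_id if i in positions else tok for i, tok in enumerate(tokens)]
    is_masked = [i in positions for i in range(len(tokens))]
    return masked, is_masked
```

The boolean mask is what restricts the cross-entropy loss to masked positions.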
The training procedure supports:
- BERT-style augmentation — a fraction of masked positions keep their original token (no-replace, 15%) and another fraction are replaced with a random token (random-token, 10%), following the original BERT paper. These augmentations are applied stochastically via coin-flip gating.
- Self-conditioning — the model can condition on its own prior predictions by running a no-gradient forward pass and feeding the resulting embeddings back as an additive bias on the next forward pass.
- Token critic training — an optional auxiliary binary classifier (either a separate `TransformerWrapper` or a `SelfCritic` head) is trained to distinguish correctly predicted tokens from incorrectly predicted ones, providing better confidence scores at generation time.
- MDLM loss weighting — following Sahoo et al. (2024), loss is weighted by schedule'(t) / (1 - schedule(t)), which upweights positions at higher noise levels where reconstruction is harder.
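The self-conditioning item above follows a two-pass pattern that can be sketched with a toy PyTorch module. `SelfCondDenoiser` and all of its layer names are hypothetical stand-ins, not the x-transformers API:

```python
import torch
import torch.nn as nn

class SelfCondDenoiser(nn.Module):
    """Toy bidirectional denoiser showing only the self-conditioning pattern;
    the class and layer names here are illustrative assumptions."""

    def __init__(self, num_tokens: int, dim: int):
        super().__init__()
        self.embed = nn.Embedding(num_tokens, dim)
        self.to_cond_bias = nn.Linear(dim, dim)  # maps prior embeddings to an additive bias
        self.body = nn.Linear(dim, dim)          # stand-in for the transformer trunk
        self.to_logits = nn.Linear(dim, num_tokens)

    def forward(self, ids, cond_embed=None, return_embed=False):
        x = self.embed(ids)
        if cond_embed is not None:
            x = x + self.to_cond_bias(cond_embed)  # additive bias from the prior pass
        h = self.body(x)
        logits = self.to_logits(h)
        return (logits, h) if return_embed else logits

model = SelfCondDenoiser(num_tokens=100, dim=16)
ids = torch.randint(0, 100, (2, 8))

# first pass: no gradients, just produce embeddings to condition on
with torch.no_grad():
    _, prior_embed = model(ids, return_embed=True)

# second pass: condition on the detached embeddings
logits = model(ids, cond_embed=prior_embed.detach())
```

Detaching the first pass keeps the extra forward from doubling the backward cost.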
The method returns a `Losses` namedtuple containing the total loss, the generator loss (masked cross-entropy), and an optional critic loss (binary cross-entropy).
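A minimal sketch of such a return type. Only the `loss` field is confirmed by the usage shown below; the other field names are assumptions:

```python
from typing import NamedTuple, Optional

class Losses(NamedTuple):
    loss: float                    # total: generator loss (+ critic loss when enabled)
    generator_loss: float          # masked cross-entropy
    critic_loss: Optional[float]   # binary cross-entropy; None when no critic is configured
```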
Usage
Use during training of non-autoregressive models. Pass complete (unmasked) token sequences of shape `(batch, max_seq_len)`; masking is applied internally by the wrapper. The returned `Losses.loss` field is the scalar to call `.backward()` on:

```python
# tokens shape: (batch, max_seq_len) — complete, unmasked sequences
losses = model(tokens)
losses.loss.backward()
```
Theoretical Basis
Masked language model objective:
L = -∑_{i ∈ masked} log P(x_i | x_masked)
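In practice this objective is a per-position cross-entropy averaged only over masked positions. A hedged PyTorch sketch with toy shapes:

```python
import torch
import torch.nn.functional as F

# toy shapes: logits (batch, seq, vocab); targets are the original tokens;
# is_masked marks the positions that were replaced with the mask token
logits = torch.randn(2, 8, 100)
targets = torch.randint(0, 100, (2, 8))
is_masked = torch.rand(2, 8) < 0.5

# cross-entropy at every position, then averaged over masked positions only
per_tok = F.cross_entropy(logits.transpose(1, 2), targets, reduction="none")
loss = (per_tok * is_masked).sum() / is_masked.sum().clamp(min=1)
```

The `clamp(min=1)` guards against a batch where no position happens to be masked.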
Time-based masking: sample t ~ U(0, 1), then compute the mask ratio as schedule(t). For the linear schedule: schedule(t) = 1 - t. For the cosine schedule (from MaskGIT): schedule(t) = cos(t * π / 2).
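Both schedules map t = 0 to a fully masked sequence and t = 1 to a fully unmasked one, with cosine keeping the mask ratio higher for longer. A quick sketch:

```python
import math

def linear_schedule(t: float) -> float:
    return 1.0 - t

def cosine_schedule(t: float) -> float:
    # MaskGIT cosine schedule
    return math.cos(t * math.pi / 2)

# both schedules start fully masked (t = 0) and end fully unmasked (t = 1)
for schedule in (linear_schedule, cosine_schedule):
    assert schedule(0.0) == 1.0
    assert abs(schedule(1.0)) < 1e-12

# cosine stays above linear at intermediate times: cos(pi/4) ~ 0.707 > 0.5
assert cosine_schedule(0.5) > linear_schedule(0.5)
```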
MDLM loss weight: the per-sample weight is computed as:
w(t) = schedule'(t) / (1 - schedule(t))
This reweighting follows equation (10) of Sahoo et al. (2024) and upweights samples at higher noise levels where reconstruction is more difficult.
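A sketch of the weight exactly as written, using analytic derivatives of the two schedules. Note the raw value is negative because both schedules decrease; in the MDLM objective this negative weight multiplies a log-likelihood term, so the overall loss stays positive. The helper names here are illustrative:

```python
import math

def linear_schedule(t): return 1.0 - t
def d_linear_schedule(t): return -1.0

def cosine_schedule(t): return math.cos(t * math.pi / 2)
def d_cosine_schedule(t): return -(math.pi / 2) * math.sin(t * math.pi / 2)

def mdlm_weight(t, schedule, d_schedule):
    # w(t) = schedule'(t) / (1 - schedule(t)); its magnitude blows up as
    # t -> 0, i.e. as the mask ratio approaches 1 (higher noise)
    return d_schedule(t) / (1.0 - schedule(t))

# linear schedule: w(t) = -1 / t, so w(0.25) = -4
assert abs(mdlm_weight(0.25, linear_schedule, d_linear_schedule) + 4.0) < 1e-9

# magnitude grows toward high noise (small t) for the cosine schedule too
assert abs(mdlm_weight(0.1, cosine_schedule, d_cosine_schedule)) > \
       abs(mdlm_weight(0.9, cosine_schedule, d_cosine_schedule))
```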
BERT augmentation: of the masked positions, 15% keep their original token unchanged (no-replace) and 10% are replaced with a uniformly random token. Both augmentations are gated by an independent coin flip per training step, following the stochastic strategy from the original BERT paper.
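The augmentation logic can be sketched as below. The per-step gating probability (`gate_prob`) is an assumption, as is the exact way the two bands are carved out of the uniform draw; the 15% / 10% fractions follow the text above:

```python
import random

def bert_augment(tokens, masked_positions, mask_id, num_tokens, rng,
                 no_replace_prob=0.15, random_token_prob=0.10,
                 gate_prob=0.5):
    """Apply BERT-style augmentation over positions already chosen for masking.
    gate_prob (the per-step coin-flip probability) is an assumption."""
    no_replace_on = rng.random() < gate_prob    # per-step gate for no-replace
    random_token_on = rng.random() < gate_prob  # per-step gate for random-token
    out = list(tokens)
    for i in masked_positions:
        r = rng.random()
        if no_replace_on and r < no_replace_prob:
            continue                               # keep the original token
        if random_token_on and r < no_replace_prob + random_token_prob:
            out[i] = rng.randrange(num_tokens)     # replace with a random token
            continue
        out[i] = mask_id                           # ordinary mask-token replacement
    return out
```

Positions not selected for masking are never touched, which the loss masking relies on.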