Principle:Neuml Txtai Replaced Token Detection
| Knowledge Sources | |
|---|---|
| Domains | Model_Training, Pre_Training |
| Last Updated | 2026-02-09 17:00 GMT |
Overview
ELECTRA-style replaced token detection (RTD) pre-training uses a generator-discriminator architecture where a generator replaces tokens and a discriminator learns to detect replaced tokens, achieving greater sample efficiency than masked language modeling.
Description
Masked language modeling (MLM), as used in BERT, trains a model to predict randomly masked tokens in the input. While effective, MLM computes its loss only on the small fraction of tokens that are masked (typically 15%); the remaining tokens serve as context but contribute no direct training signal. Replaced token detection (RTD), introduced by ELECTRA, addresses this inefficiency by computing a loss at every token position in the input sequence.
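To make the difference in supervised positions concrete, the following is a minimal sketch that counts how many positions each objective computes a loss on for a single sequence. The helper name `mlm_supervised_positions`, the 512-token sequence length, and the fixed seed are assumptions chosen for illustration; they are not part of txtai.

```python
import random


def mlm_supervised_positions(seq_len, mask_prob=0.15, seed=0):
    """Pick the positions an MLM objective would compute a loss on.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    n_masked = max(1, round(seq_len * mask_prob))
    return sorted(rng.sample(range(seq_len), n_masked))


seq_len = 512  # assumed sequence length

# MLM: loss only on the masked subset
mlm_positions = mlm_supervised_positions(seq_len)

# RTD: the discriminator is scored on every position
rtd_positions = list(range(seq_len))

# Ratio of supervised positions per sequence (about 1 / 0.15)
ratio = len(rtd_positions) / len(mlm_positions)
print(len(mlm_positions), len(rtd_positions), round(ratio, 2))
```

The ratio counts supervised positions only. Each RTD position carries a binary label rather than a prediction over the full vocabulary, so it should not be read as a guaranteed speedup factor.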
RTD employs a two-model architecture consisting of a small generator and a larger discriminator. The generator is a small masked language model: a random subset of input positions is masked, and the generator samples a replacement token for each masked position. The discriminator then receives the corrupted input, in which those positions hold generator-sampled tokens, and must predict for every token position whether it contains the original token or a replacement. When the generator happens to sample the original token, that position counts as original. Because the generator produces plausible replacements rather than random noise, the discrimination task is hard enough that the discriminator must model context at every position rather than spot obvious outliers.
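The corruption step described above can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not txtai's implementation: `toy_generator` is a hypothetical stand-in that samples from a tiny fixed vocabulary instead of a learned MLM distribution, and the labels follow the y_t convention of this document's loss formula (1 = original, 0 = replaced). Some libraries use the opposite convention; for example, Hugging Face's `ElectraForPreTraining` expects 1 for replaced tokens.

```python
import random

MASK = "[MASK]"


def toy_generator(masked_tokens, position, rng):
    """Hypothetical stand-in for the small MLM generator.

    A real generator would sample from its predicted distribution at
    `position` given `masked_tokens`; this one samples uniformly from a
    tiny vocabulary and ignores the context.
    """
    vocab = ["the", "chef", "cooked", "ate", "meal", "a"]
    return rng.choice(vocab)


def corrupt(tokens, mask_prob=0.15, seed=0):
    """Build the discriminator input and per-token labels."""
    rng = random.Random(seed)
    n_masked = max(1, round(len(tokens) * mask_prob))
    positions = sorted(rng.sample(range(len(tokens)), n_masked))

    # Step 1: mask the selected positions (this is the generator's input)
    masked = [MASK if i in positions else tok for i, tok in enumerate(tokens)]

    # Step 2: the generator samples a token for each masked position
    corrupted = list(tokens)
    for i in positions:
        corrupted[i] = toy_generator(masked, i, rng)

    # Step 3: label every position; a sampled token that equals the
    # original counts as original (1), anything else as replaced (0)
    labels = [1 if c == o else 0 for c, o in zip(corrupted, tokens)]
    return corrupted, labels, positions


tokens = "the chef cooked the meal".split()
corrupted, labels, positions = corrupt(tokens, mask_prob=0.4, seed=1)
print(corrupted, labels, positions)
```

Note that the labels are derived by comparing the corrupted sequence with the original rather than from the mask itself, which is what makes a correct generator guess count as an original token.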
In txtai, the RTD training pipeline implements this generator-discriminator setup for pre-training transformer models on custom data. The generator and discriminator share the same tokenizer and are trained jointly, with the generator's MLM loss and the discriminator's binary classification loss combined into a single objective (in the ELECTRA paper, a weighted sum in which the discriminator term is multiplied by 50). After pre-training, the generator is discarded and only the discriminator is retained for downstream fine-tuning on tasks such as classification, named entity recognition, or sentence similarity. Because the discriminator receives a training signal from all tokens rather than just the masked subset, RTD models can match or exceed comparable MLM-based models at substantially lower pre-training compute.
Usage
Apply replaced token detection pre-training when you need to train a language model from scratch, or continue pre-training on domain-specific data, under a limited compute budget. It is also a good fit when sample efficiency is a priority because unlabeled domain text is scarce. It is especially beneficial for building domain-adapted models in specialized fields such as biomedical, legal, or scientific text, where large-scale corpora may not be readily available.
Theoretical Basis
1. Generator-discriminator architecture -- A small generator network (the ELECTRA paper reports the best results with a generator 1/4 to 1/2 the size of the discriminator) is trained with MLM to produce plausible token replacements, while the larger discriminator network is trained on the binary classification task of identifying which tokens in the sequence have been replaced by the generator.
2. RTD loss function -- The discriminator loss is the sum of binary cross-entropy losses over all input tokens, where each token is classified as original or replaced: L = -sum over t of [y_t * log(D(x_t)) + (1 - y_t) * log(1 - D(x_t))], where D(x_t) is the discriminator's predicted probability that token t is original, and y_t is 1 when token t is original and 0 when it has been replaced.
3. Sample efficiency vs MLM -- MLM computes loss only on the approximately 15% of masked tokens, while RTD computes loss on 100% of tokens, giving roughly 6-7 times as many supervised positions per example. Empirically, the ELECTRA paper reports that its large model performs comparably to RoBERTa and XLNet while using less than one quarter of their pre-training compute.
4. Discriminator fine-tuning -- After pre-training, the generator is discarded and the discriminator serves as the pre-trained encoder for downstream tasks. The discriminator's learned representations are directly usable for classification, sequence labeling, and other NLP tasks via standard fine-tuning with task-specific output heads appended to the encoder.
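The loss in item 2 can be checked numerically with a short sketch. The function names `rtd_loss` and `combined_loss`, the example probabilities, and the MLM loss value are assumptions for illustration. The default weight of 50 is the value used in the ELECTRA paper; whether a given training setup uses the same weight, or averages rather than sums over tokens, should be checked against that setup.

```python
import math


def rtd_loss(probs_original, labels):
    """Binary cross-entropy summed over all tokens.

    probs_original[t] is D(x_t), the predicted probability that token t
    is original; labels[t] is y_t (1 = original, 0 = replaced).
    """
    total = 0.0
    for p, y in zip(probs_original, labels):
        total -= y * math.log(p) + (1 - y) * math.log(1 - p)
    return total


def combined_loss(mlm_loss, disc_loss, weight=50.0):
    """Joint objective: generator MLM loss plus weighted RTD loss.

    weight=50.0 follows the ELECTRA paper; treat it as an assumption
    for any other setup.
    """
    return mlm_loss + weight * disc_loss


# Five tokens, the third one was replaced by the generator
labels = [1, 1, 0, 1, 1]

confident = [0.9, 0.9, 0.1, 0.9, 0.9]  # flags the replacement correctly
unsure = [0.5, 0.5, 0.5, 0.5, 0.5]     # no information at any position
fooled = [0.9, 0.9, 0.9, 0.9, 0.9]     # misses the replacement

loss_confident = rtd_loss(confident, labels)
loss_unsure = rtd_loss(unsure, labels)
loss_fooled = rtd_loss(fooled, labels)
print(loss_confident, loss_fooled, loss_unsure)

# Assumed generator MLM loss of 2.0, for illustration only
total = combined_loss(mlm_loss=2.0, disc_loss=loss_confident)
print(total)
```

Every position contributes a term to the sum, including the four unchanged tokens, which is the source of the extra supervised positions described in item 3.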