Implementation:Neuml Txtai TokenDetection
| Knowledge Sources | |
|---|---|
| Domains | Model_Training, NLP |
| Last Updated | 2026-02-09 17:00 GMT |
Overview
The TokenDetection class implements ELECTRA-style replaced token detection training, combining a generator and discriminator model for efficient pre-training of language models.
Description
The TokenDetection class inherits from Hugging Face's PreTrainedModel and orchestrates the ELECTRA training procedure. It pairs a small generator (a masked language model that proposes replacement tokens for masked positions) with a larger discriminator that learns to detect which tokens were replaced. Because the discriminator's binary classification objective covers every input token, it is more sample-efficient than traditional masked language model training, which learns only from the small fraction of positions that are masked (typically about 15%). The weight parameter scales the discriminator loss before it is added to the generator loss (default: 50.0).
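The corruption step described above can be sketched in plain Python. This is a toy illustration rather than txtai's implementation: the function name `corrupt`, the list-based inputs, the token ids, and the convention that `labels` holds -100 at unmasked positions (the usual Hugging Face masked language modeling convention) are all assumptions made for the example.

```python
IGNORE = -100  # assumed marker for positions that were not masked


def corrupt(input_ids, labels, sampled):
    """
    Builds discriminator inputs and targets for one sequence.

    input_ids: token ids with mask tokens at the masked positions
    labels: original token id at masked positions, IGNORE elsewhere
    sampled: token ids proposed by the generator for every position
    """
    dinputs, dlabels = [], []
    for token, label, guess in zip(input_ids, labels, sampled):
        if label == IGNORE:
            # Unmasked position: keep the input token, target is "original"
            dinputs.append(token)
            dlabels.append(0)
        else:
            # Masked position: substitute the generator's proposal.
            # A correct guess reproduces the original token, so it is
            # not counted as a replacement.
            dinputs.append(guess)
            dlabels.append(0 if guess == label else 1)
    return dinputs, dlabels


# Toy sequence: positions 1 and 3 were masked (mask id 103 is illustrative)
input_ids = [101, 103, 2003, 103, 102]
labels = [IGNORE, 2023, IGNORE, 2307, IGNORE]
sampled = [0, 2023, 0, 2204, 0]  # the generator guessed position 1 correctly

dinputs, dlabels = corrupt(input_ids, labels, sampled)
print(dinputs)  # [101, 2023, 2003, 2204, 102]
print(dlabels)  # [0, 0, 0, 1, 0]
```

Every position receives a target, which is why the discriminator gets a training signal from the whole sequence instead of only the masked positions.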
Usage
Use the TokenDetection class to pre-train or fine-tune a language model with the ELECTRA replaced token detection objective. The objective is particularly useful for training smaller models that must reach strong performance with limited compute. The class is normally used through txtai's HFTrainer pipeline with task="token-detection", and it requires both a generator and a discriminator model.
Code Reference
Source Location
- Repository: Neuml_Txtai
- File: src/python/txtai/models/tokendetection.py
- Lines: 1-122
Signature
class TokenDetection(PreTrainedModel):
    def __init__(self, generator, discriminator, tokenizer, weight=50.0):
        """
        Creates a TokenDetection training model.

        Args:
            generator: small masked language model that proposes replacement tokens
            discriminator: larger model that detects replaced tokens
            tokenizer: tokenizer shared by both models
            weight: discriminator loss weight relative to generator loss (default: 50.0)
        """

    def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None):
        """
        Forward pass combining generator and discriminator.

        Args:
            input_ids: input token IDs
            labels: original token IDs at masked positions, -100 elsewhere
            attention_mask: attention mask tensor
            token_type_ids: token type IDs for segment encoding

        Returns:
            tuple of (loss, discriminator logits, discriminator labels), where
            loss = generator MLM loss + weight * discriminator detection loss
        """

    def save_pretrained(self, output, state_dict=None, **kwargs):
        """
        Saves the combined model to the output directory, along with the generator
        and discriminator (each with the tokenizer) in their own subdirectories.

        Args:
            output: output directory path
            state_dict: optional state dict override
            **kwargs: additional arguments passed to parent save method
        """
Import
from txtai.models import TokenDetection
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| generator | PreTrainedModel | Yes | Small masked language model used to generate replacement token candidates |
| discriminator | PreTrainedModel | Yes | Larger model trained to detect which tokens were replaced by the generator |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer shared by both generator and discriminator models |
| weight | float | No | Weight applied to discriminator loss relative to generator loss (default: 50.0) |
| input_ids | torch.Tensor | Yes (for forward) | Input token IDs tensor of shape (batch_size, seq_length) |
| labels | torch.Tensor | Yes (for forward) | Original token IDs at masked positions and -100 elsewhere; used for the generator loss and to locate the positions to replace |
| attention_mask | torch.Tensor | No | Attention mask tensor of shape (batch_size, seq_length) |
| token_type_ids | torch.Tensor | No | Token type IDs for segment-level encoding |
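The relationship between input_ids and labels can be illustrated with a simplified masking routine. This is a sketch under stated assumptions, not the collator txtai uses: `mask_tokens` is a hypothetical helper, it always substitutes the mask token (real collators also leave some selected tokens unchanged or swap in random ones, and skip special tokens), and -100 is assumed as the ignore value.

```python
import random

IGNORE = -100  # assumed label value for positions excluded from the generator loss


def mask_tokens(token_ids, mask_id, probability=0.15, seed=0):
    """Returns (input_ids, labels) for one sequence using simplified masking."""
    rng = random.Random(seed)
    input_ids, labels = [], []
    for token in token_ids:
        if rng.random() < probability:
            # Masked position: hide the token, remember the original as the label
            input_ids.append(mask_id)
            labels.append(token)
        else:
            # Unmasked position: keep the token, exclude it from the generator loss
            input_ids.append(token)
            labels.append(IGNORE)
    return input_ids, labels


# Illustrative token ids; 103 stands in for the tokenizer's mask token id
tokens = list(range(1000, 1020))
input_ids, labels = mask_tokens(tokens, mask_id=103)

# Both lists have one entry per token, matching the (batch_size, seq_length) layout per row
print(len(input_ids), len(labels))
```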
Outputs
| Name | Type | Description |
|---|---|---|
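| loss | torch.Tensor | Combined training loss: generator_loss + (weight * discriminator_loss) |
| discriminator logits | torch.Tensor | Discriminator prediction for each input token |
| discriminator labels | torch.Tensor | Discriminator targets: 1 for replaced tokens, 0 otherwise |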
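As a worked example of the combined loss formula, the following sketch applies it to plain floats. The function name `combined_loss` and the loss values are illustrative assumptions; in the real model both terms are tensors produced by the generator and discriminator.

```python
def combined_loss(generator_loss, discriminator_loss, weight=50.0):
    """Weighted sum of the generator and discriminator losses."""
    return generator_loss + weight * discriminator_loss


# Illustrative values only
print(combined_loss(2.0, 0.5))              # 27.0
print(combined_loss(2.0, 0.5, weight=1.0))  # 2.5
```

A per-token binary loss is usually much smaller than a cross-entropy loss over the whole vocabulary, so a large weight keeps the discriminator term from being dwarfed by the generator term.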
Usage Examples
Basic Usage
from transformers import AutoModelForMaskedLM, AutoModelForPreTraining, AutoTokenizer
from txtai.models import TokenDetection

# Load generator (small) and discriminator (larger) models.
# AutoModelForPreTraining loads the checkpoint's pre-trained replaced token detection head.
generator = AutoModelForMaskedLM.from_pretrained("google/electra-small-generator")
discriminator = AutoModelForPreTraining.from_pretrained("google/electra-small-discriminator")
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")

# Create the TokenDetection training model
model = TokenDetection(
    generator=generator,
    discriminator=discriminator,
    tokenizer=tokenizer,
    weight=50.0
)

# The model can now be passed to a Hugging Face Trainer for ELECTRA-style pre-training
print("Model ready for replaced token detection training")
Training with txtai
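import pandas as pd

from txtai.pipeline import HFTrainer

# Load training data; the text is read from the "text" column by default
data = pd.read_csv("training_data.csv")

# Configure ELECTRA-style training through txtai
trainer = HFTrainer()

# Train with the token detection objective. HFTrainer builds both the generator
# and the discriminator from the base model. Additional keyword arguments
# (such as output_dir) are forwarded to Hugging Face TrainingArguments.
model, tokenizer = trainer(
    "google/electra-small-discriminator",
    data,
    task="token-detection",
    output_dir="models/my-electra"
)

# Save the trained models
model.save_pretrained("models/my-electra-trained")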