

Implementation:Neuml Txtai TokenDetection

From Leeroopedia
Revision as of 16:03, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Neuml_Txtai_TokenDetection.md)


Knowledge Sources
Domains: Model_Training, NLP
Last Updated: 2026-02-09 17:00 GMT

Overview

The TokenDetection class implements ELECTRA-style replaced token detection training, combining a generator and discriminator model for efficient pre-training of language models.

Description

The TokenDetection class inherits from HuggingFace's PreTrainedModel and orchestrates the ELECTRA training procedure. It pairs a small generator model (which proposes replacement tokens for masked positions) with a larger discriminator model (which learns to detect which tokens were replaced). The discriminator's binary classification objective over all input tokens is far more sample-efficient than traditional masked language model training, which only learns from the small fraction of masked positions. The weight parameter controls the relative importance of the discriminator loss versus the generator loss.
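The corruption step that turns generator proposals into a discriminator task can be sketched in plain Python. This is a simplified illustration, not txtai's code. The function name `corrupt` and the `sample` callable (standing in for the generator) are hypothetical. The sketch assumes the HuggingFace convention that unmasked positions carry the label -100. It follows the ELECTRA rule that a sampled token which happens to equal the original token is labeled as original (0), not replaced (1).

```python
IGNORE = -100  # assumed HuggingFace convention: label ignored by the loss


def corrupt(input_ids, labels, sample):
    """Build discriminator inputs and per-token labels (hypothetical helper).

    input_ids: token ids whose masked positions hold a mask token id
    labels:    original ids at masked positions, IGNORE elsewhere
    sample:    callable(position) -> token id proposed by the generator
    """
    dinputs = list(input_ids)
    dlabels = []
    for i, original in enumerate(labels):
        if original == IGNORE:
            # Unmasked position: the token is untouched and labeled "original"
            dlabels.append(0)
            continue
        token = sample(i)
        dinputs[i] = token  # replace the mask token with the generator's proposal
        # A proposal that matches the original counts as "original", not "replaced"
        dlabels.append(0 if token == original else 1)
    return dinputs, dlabels


# Example: the generator guesses position 1 correctly and position 3 incorrectly
proposals = {1: 2023, 3: 2204}
dinputs, dlabels = corrupt(
    [101, 103, 2003, 103, 102],
    [IGNORE, 2023, IGNORE, 2307, IGNORE],
    proposals.__getitem__,
)
```

The discriminator is then trained on every entry of `dlabels`, not only on the masked positions. That per-token signal is what makes the objective more sample-efficient than masked language modeling alone.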

Usage

Use the TokenDetection class when you need to pre-train or fine-tune a language model using the ELECTRA replaced token detection objective. This is particularly efficient for training smaller models that need to achieve strong performance with limited compute. It is used through txtai's training pipelines and requires both a generator and discriminator model architecture to be configured.

Code Reference

Source Location

Signature

class TokenDetection(PreTrainedModel):
    def __init__(self, generator, discriminator, tokenizer, weight=50.0):
        """
        Creates a TokenDetection training model.

        Args:
            generator: small masked language model that proposes replacement tokens
            discriminator: larger model that detects replaced tokens
            tokenizer: tokenizer shared by both models
            weight: discriminator loss weight relative to generator loss (default: 50.0)
        """

    def forward(self, input_ids=None, labels=None, attention_mask=None, token_type_ids=None):
        """
        Forward pass combining generator and discriminator.

        Args:
            input_ids: input token IDs
            labels: original token IDs for masked positions
            attention_mask: attention mask tensor
            token_type_ids: token type IDs for segment encoding

        Returns:
            loss combining generator MLM loss and weighted discriminator detection loss
        """

    def save_pretrained(self, output, state_dict=None, **kwargs):
        """
        Saves the discriminator model to disk.

        Args:
            output: output directory path
            state_dict: optional state dict override
            **kwargs: additional arguments passed to parent save method
        """

Import

from txtai.models import TokenDetection

I/O Contract

Inputs

Name | Type | Required | Description
generator | PreTrainedModel | Yes | Small masked language model used to generate replacement token candidates
discriminator | PreTrainedModel | Yes | Larger model trained to detect which tokens were replaced by the generator
tokenizer | PreTrainedTokenizer | Yes | Tokenizer shared by both generator and discriminator models
weight | float | No | Weight applied to discriminator loss relative to generator loss (default: 50.0)
input_ids | torch.Tensor | Yes (for forward) | Input token IDs tensor of shape (batch_size, seq_length)
labels | torch.Tensor | No | Original token IDs for computing generator loss at masked positions
attention_mask | torch.Tensor | No | Attention mask tensor of shape (batch_size, seq_length)
token_type_ids | torch.Tensor | No | Token type IDs for segment-level encoding
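The `labels` input holds the original token IDs at masked positions. Here is a minimal plain-Python sketch of producing a matching (masked input, labels) pair for one sequence. It assumes that unmasked positions carry -100, following the HuggingFace convention. It uses a 15% masking rate, which is the default `mlm_probability` of HuggingFace's `DataCollatorForLanguageModeling`. The helper name `mask_tokens` is hypothetical. For brevity it always substitutes the mask token and skips the 80/10/10 split used by BERT-style collators.

```python
import random

IGNORE = -100  # assumed convention: positions with this label are skipped by the loss


def mask_tokens(input_ids, mask_id, special_ids, probability=0.15, seed=None):
    """Return (masked_input_ids, labels) for one sequence (hypothetical helper)."""
    rng = random.Random(seed)
    masked = list(input_ids)
    labels = [IGNORE] * len(input_ids)
    for i, token in enumerate(input_ids):
        if token in special_ids:
            continue  # never mask special tokens such as [CLS], [SEP] or padding
        if rng.random() < probability:
            labels[i] = token    # the generator must recover this id
            masked[i] = mask_id  # hide it in the input
    return masked, labels


masked, labels = mask_tokens([101, 7592, 2088, 102], mask_id=103, special_ids={101, 102}, seed=0)
```

In a real pipeline, a data collator builds these tensors batch-wise, and TokenDetection consumes them in its forward pass.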

Outputs

Name | Type | Description
loss | torch.Tensor | Combined training loss: generator_loss + (weight * discriminator_loss)
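The combined objective can be checked with a short worked sketch. It uses plain Python floats instead of tensors, and the helper names are hypothetical. The discriminator term is a mean binary cross-entropy over every token. The generator term is taken as a precomputed masked-LM loss. The default weight of 50.0 matches the λ = 50 reported in the ELECTRA paper. A large weight is plausible because the per-token binary loss is typically much smaller in magnitude than a cross-entropy over the whole vocabulary.

```python
import math


def binary_cross_entropy(probs, labels):
    """Mean BCE over all tokens: the discriminator scores every position."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(probs)


def combined_loss(generator_loss, discriminator_loss, weight=50.0):
    """generator_loss + weight * discriminator_loss, as in the Outputs table."""
    return generator_loss + weight * discriminator_loss


# Hypothetical discriminator probabilities of "replaced" and true labels (1 = replaced)
dloss = binary_cross_entropy([0.9, 0.2, 0.1, 0.8], [1, 0, 0, 1])
total = combined_loss(2.0, dloss)
```

For example, a generator loss of 2.0 and a discriminator loss of 0.05 give 2.0 + 50 * 0.05 = 4.5. Even at this small value, the discriminator term contributes more than the generator term.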

Usage Examples

Basic Usage

from transformers import AutoModelForMaskedLM, AutoModelForTokenClassification, AutoTokenizer
from txtai.models import TokenDetection

# Load generator (small) and discriminator (larger) models
generator = AutoModelForMaskedLM.from_pretrained("google/electra-small-generator")
discriminator = AutoModelForTokenClassification.from_pretrained("google/electra-small-discriminator")
tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")

# Create the TokenDetection training model
model = TokenDetection(
    generator=generator,
    discriminator=discriminator,
    tokenizer=tokenizer,
    weight=50.0
)

# The model can now be used with HuggingFace Trainer
# for ELECTRA-style pre-training
print("Model ready for replaced token detection training")

Training with txtai

from txtai.pipeline import HFTrainer

# Configure ELECTRA-style training through txtai
trainer = HFTrainer()

# Train with token detection objective
model, tokenizer = trainer(
    "google/electra-small-discriminator",
    task="token-detection",
    train="training_data.csv",
    columns=("text",),
    output="models/my-electra",
    generator="google/electra-small-generator"
)

# Save the trained discriminator
model.save_pretrained("models/my-electra-trained")
