Implementation:NVIDIA NeMo Curator BaseClassifierStage

Revision as of 13:19, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/NVIDIA_NeMo_Curator_BaseClassifierStage.md)
Knowledge Sources
Domains: NLP, Classification, Distributed Computing, Deep Learning
Last Updated: 2026-02-14 00:00 GMT

Overview

Provides the base classifier infrastructure for NeMo Curator: a DeBERTa model wrapper, a classifier-specific model inference stage, and a reusable DistributedDataClassifier composite stage that chains tokenization, model inference, and optional filtering for distributed text classification.

Description

This module defines three key classes that form the foundation of NeMo Curator's DeBERTa-based classification system:

  • Deberta - A PyTorch module that wraps a pretrained transformer model (loaded via HuggingFace's AutoModel) and adds a classification head consisting of a dropout layer and a fully connected linear layer. The forward pass extracts the CLS token representation, applies dropout and the linear layer, then returns softmax probabilities. It also integrates PyTorchModelHubMixin to support loading from and saving to the HuggingFace Hub.
  • ClassifierModelStage - Extends ModelStage to handle the full lifecycle of classifier inference: loading the DeBERTa model and its configuration onto the GPU, processing model output tensors into label predictions (via argmax) and probability scores, and constructing the final output DataFrame. It supports configurable label and score fields, sequence ordering for performance optimization, and autocast for faster inference with minor accuracy trade-offs.
  • DistributedDataClassifier - A dataclass-based CompositeStage that orchestrates the classification pipeline. It decomposes into a TokenizerStage (for text tokenization with configurable max characters, max sequence length, and padding side), followed by a ClassifierModelStage (for inference), and optionally a Filter stage (when filter_by categories are specified). This composite stage is the parent class for concrete classifiers like QualityClassifier, DomainClassifier, and ContentTypeClassifier.

Usage

Use DistributedDataClassifier as a base class when building a new DeBERTa-based text classifier. Subclass it and configure the model identifier, default field names, maximum sequence length, and maximum character limit for the specific classification task. For direct use, instantiate it with the HuggingFace identifier of a model that follows the expected DeBERTa architecture with a classification head.

Code Reference

Source Location

  • Repository: NeMo-Curator
  • File: nemo_curator/stages/text/classifiers/base.py
  • Lines: 1-229

Signature

class Deberta(nn.Module, PyTorchModelHubMixin):
    def __init__(self, config: dataclass): ...
    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: ...

class ClassifierModelStage(ModelStage):
    def __init__(
        self,
        model_identifier: str,
        cache_dir: str | None = None,
        label_field: str = "preds",
        score_field: str | None = None,
        model_inference_batch_size: int = 256,
        has_seq_order: bool = True,
        padding_side: Literal["left", "right"] = "right",
        autocast: bool = True,
    ): ...

@dataclass(kw_only=True)
class DistributedDataClassifier(CompositeStage[DocumentBatch, DocumentBatch]):
    model_identifier: str
    cache_dir: str | None = None
    label_field: str = "preds"
    score_field: str | None = None
    text_field: str = "text"
    filter_by: list[str] | None = None
    max_chars: int | None = None
    max_seq_length: int | None = None
    padding_side: Literal["left", "right"] = "right"
    sort_by_length: bool = True
    model_inference_batch_size: int = 256
    autocast: bool = True

Import

from nemo_curator.stages.text.classifiers.base import (
    Deberta,
    ClassifierModelStage,
    DistributedDataClassifier,
)

I/O Contract

Inputs

Name | Type | Required | Description
model_identifier | str | Yes | HuggingFace model identifier for the DeBERTa classifier
cache_dir | str or None | No | HuggingFace cache directory for storing downloaded models
label_field | str | No | Name of the output prediction column (default: "preds")
score_field | str or None | No | Name of the output probability column; if None, probabilities are not retained
text_field | str | No | Name of the text field in the input DocumentBatch (default: "text")
filter_by | list[str] or None | No | List of category labels to keep; if provided, adds a Filter stage
max_chars | int or None | No | Maximum characters to feed to the tokenizer; None means no truncation
max_seq_length | int or None | No | Maximum sequence length for the tokenizer (default: None, which falls back to 512 via the base config)
padding_side | Literal["left", "right"] | No | Side to pad tokens (default: "right")
sort_by_length | bool | No | Whether to sort input data by token length for performance (default: True)
model_inference_batch_size | int | No | Batch size for model inference (default: 256)
autocast | bool | No | Use autocast for faster inference at minor accuracy cost (default: True)

Outputs

Name | Type | Description
DocumentBatch | DocumentBatch | The input batch augmented with a prediction label column and, optionally, a probability score column
label_field column | str (per row) | The predicted category label for each document
score_field column | list[float] (per row) | Softmax probability vector across all classes (only if score_field is specified)

Usage Examples

Basic Usage

from nemo_curator.stages.text.classifiers.base import DistributedDataClassifier

# Create a classifier with a custom HuggingFace DeBERTa model
classifier = DistributedDataClassifier(
    model_identifier="nvidia/quality-classifier-deberta",
    label_field="quality_pred",
    text_field="text",
    max_chars=6000,
    max_seq_length=1024,
    model_inference_batch_size=256,
)

With Filtering

from nemo_curator.stages.text.classifiers.base import DistributedDataClassifier

# Classify and filter to keep only "High" and "Medium" quality documents
classifier = DistributedDataClassifier(
    model_identifier="nvidia/quality-classifier-deberta",
    label_field="quality_pred",
    score_field="quality_prob",
    filter_by=["High", "Medium"],
    max_chars=6000,
    max_seq_length=1024,
)

Subclassing for a New Classifier

from nemo_curator.stages.text.classifiers.base import DistributedDataClassifier

class MyCustomClassifier(DistributedDataClassifier):
    def __init__(self, cache_dir=None, label_field="my_pred", **kwargs):
        super().__init__(
            model_identifier="my-org/my-deberta-classifier",
            cache_dir=cache_dir,
            label_field=label_field,
            max_chars=3000,
            max_seq_length=512,
            **kwargs,
        )

Internal Architecture

Pipeline Decomposition

The DistributedDataClassifier decomposes into an ordered list of processing stages:

  1. TokenizerStage - Tokenizes the text field using the HuggingFace tokenizer associated with the model, producing input_ids and attention_mask tensors.
  2. ClassifierModelStage - Runs the Deberta model on the tokenized batches, producing label predictions and probability scores.
  3. Filter (optional) - If filter_by is specified, applies a category-based filter using the filter_by_category method, which checks whether the predicted label is in the allowed set.

Deberta Model Architecture

The Deberta class adds a classification head on top of a pretrained transformer:

# Internal architecture
self.model = AutoModel.from_pretrained(config["base_model"])  # Pretrained transformer
self.dropout = nn.Dropout(config["fc_dropout"])                # Regularization
self.fc = nn.Linear(self.model.config.hidden_size, len(config["id2label"]))  # Classification head

The forward pass uses the CLS token (position 0) from the last hidden state, applies dropout and the linear layer, then softmax normalization.
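The arithmetic of that forward pass can be traced for a single sequence with a dependency-free sketch. It is not the module's code: classify_sequence and softmax are hypothetical helpers, plain lists stand in for tensors, and dropout is left out on the assumption that the model is in evaluation mode, where dropout acts as the identity.

```python
import math


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max logit first for numerical stability.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_sequence(
    last_hidden_state: list[list[float]],  # shape [seq_len][hidden_size]
    fc_weight: list[list[float]],          # shape [num_labels][hidden_size]
    fc_bias: list[float],                  # shape [num_labels]
) -> list[float]:
    cls_vector = last_hidden_state[0]  # CLS token sits at position 0
    # Linear layer: one dot product plus bias per label.
    logits = [
        sum(w * h for w, h in zip(row, cls_vector)) + b
        for row, b in zip(fc_weight, fc_bias)
    ]
    return softmax(logits)


hidden = [[1.0, 0.0], [5.0, 5.0]]  # two tokens, hidden_size = 2
weight = [[2.0, 0.0], [0.0, 2.0]]  # two labels
probs = classify_sequence(hidden, weight, [0.0, 0.0])
```

Only the first row of hidden influences the result; the second token's values are ignored, which is the practical meaning of CLS pooling.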
