Implementation:Snorkel team Snorkel Trainer Fit

From Leeroopedia
Domains: Training, Multi_Task_Learning, PyTorch
Last Updated: 2026-02-14 20:00 GMT

Overview

Concrete tool for training multi-task classifiers with configurable optimization, logging, and checkpointing, provided by the Snorkel library.

Description

The Trainer class provides a complete training loop for MultitaskClassifier and its subclasses (including SliceAwareClassifier). It handles:

  • Optimizer creation (SGD, Adam, Adamax)
  • Learning rate scheduling (constant, linear, exponential, step)
  • Batch scheduling across multiple tasks (shuffled or sequential)
  • Gradient clipping
  • Optional metric logging (JSON or TensorBoard)
  • Optional model checkpointing (saving best model by metric)
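The two batch schedulers differ only in the order in which batches from different tasks reach the model. The framework-free sketch below illustrates that difference. It is not Snorkel's implementation: plain lists of batch labels stand in for DictDataLoaders, and `sequential_schedule` and `shuffled_schedule` are hypothetical names. The sequential variant exhausts one task before moving to the next; the shuffled variant randomizes which task each batch is drawn from while keeping every task's own batches in order.

```python
import random
from typing import Dict, Iterator, List, Tuple

# Hypothetical stand-ins: each task maps to a list of batch labels.
# Real code would pass DictDataLoader objects instead.
Loaders = Dict[str, List[str]]


def sequential_schedule(loaders: Loaders) -> Iterator[Tuple[str, str]]:
    """Yield all batches of the first task, then all of the next, and so on."""
    for task, batches in loaders.items():
        for batch in batches:
            yield task, batch


def shuffled_schedule(loaders: Loaders, seed: int = 0) -> Iterator[Tuple[str, str]]:
    """Interleave tasks in random order while keeping each task's batch order."""
    rng = random.Random(seed)
    # One entry per batch, recording which task that batch will be drawn from.
    draw_order = [task for task, batches in loaders.items() for _ in batches]
    rng.shuffle(draw_order)
    iterators = {task: iter(batches) for task, batches in loaders.items()}
    for task in draw_order:
        yield task, next(iterators[task])


loaders = {"task_a": ["a0", "a1", "a2"], "task_b": ["b0", "b1"]}
print(list(sequential_schedule(loaders)))
# [('task_a', 'a0'), ('task_a', 'a1'), ('task_a', 'a2'), ('task_b', 'b0'), ('task_b', 'b1')]
print(list(shuffled_schedule(loaders, seed=1)))  # same five pairs, seed-dependent order
```

Shuffling task order rather than individual examples keeps each batch single-task, which is what lets a multi-task model route a batch to the right task head.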

Usage

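Import this class when training any Snorkel MultitaskClassifier. Configure it by passing keyword arguments to the constructor: they are merged into TrainerConfig, overriding the matching defaults, while every field you do not pass keeps its default value. Then call fit() with the model and its dataloaders.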
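The merge behaviour can be pictured with a minimal stand-in built on the standard `dataclasses` module. `TrainerConfigSketch` and `merge_kwargs` are hypothetical names; only the field names and defaults come from the Trainer signature on this page. Snorkel's real TrainerConfig has more fields (for example nested optimizer and scheduler settings) and may handle unknown keys differently. In this sketch, keyword arguments override matching defaults, untouched fields keep their defaults, and a misspelled key raises an error.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class TrainerConfigSketch:
    """Hypothetical subset of TrainerConfig, using the defaults listed on this page."""

    n_epochs: int = 1
    lr: float = 0.01
    l2: float = 0.0
    grad_clip: float = 1.0
    optimizer: str = "adam"
    lr_scheduler: str = "constant"
    batch_scheduler: str = "shuffled"
    checkpointing: bool = False
    logging: bool = False
    log_writer: str = "tensorboard"


def merge_kwargs(base: TrainerConfigSketch, **kwargs: Any) -> TrainerConfigSketch:
    """Return a copy of `base` with the given fields overridden."""
    # dataclasses.replace raises TypeError for a field name that does not exist,
    # so a misspelled option fails loudly instead of being silently dropped.
    return replace(base, **kwargs)


config = merge_kwargs(TrainerConfigSketch(), n_epochs=10, lr=0.001)
print(config.n_epochs, config.lr, config.optimizer)  # 10 0.001 adam
```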

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/classification/training/trainer.py
  • Lines: L108-586 (class), L138-142 (__init__), L144-248 (fit)

Signature

```python
from typing import Any, List, Optional

from snorkel.classification import DictDataLoader, MultitaskClassifier


class Trainer:
    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            name: Trainer name.
            **kwargs: Merged into TrainerConfig; fields not passed keep their defaults:
                n_epochs (int, default 1), lr (float, default 0.01),
                l2 (float, default 0.0), grad_clip (float, default 1.0),
                optimizer (str, default "adam"; "sgd", "adam", or "adamax"),
                lr_scheduler (str, default "constant"; "constant", "linear",
                    "exponential", or "step"),
                batch_scheduler (str, default "shuffled"; "shuffled" or "sequential"),
                checkpointing (bool, default False),
                logging (bool, default False),
                log_writer (str, default "tensorboard"; "json" or "tensorboard").
        """

    def fit(
        self,
        model: MultitaskClassifier,
        dataloaders: List[DictDataLoader],
    ) -> None:
        """
        Train model in place on the provided dataloaders.

        Args:
            model: MultitaskClassifier (or SliceAwareClassifier) to train.
            dataloaders: List of DictDataLoaders covering the train/valid/test splits.
        """
```

Import

```python
from snorkel.classification import Trainer
```

I/O Contract

Inputs

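| Name | Type | Required | Description |
|---|---|---|---|
| model | MultitaskClassifier | Yes | Model to train; passed to fit() and modified in place |
| dataloaders | List[DictDataLoader] | Yes | Train/valid/test dataloaders; passed to fit() |
| n_epochs | int | No | Number of training epochs (default 1); passed to Trainer() |
| lr | float | No | Learning rate (default 0.01); passed to Trainer() |
| l2 | float | No | L2 regularization strength (default 0.0); passed to Trainer() |
| grad_clip | float | No | Gradient clipping threshold (default 1.0); passed to Trainer() |
| optimizer | str | No | "sgd", "adam", or "adamax" (default "adam"); passed to Trainer() |
| lr_scheduler | str | No | "constant", "linear", "exponential", or "step" (default "constant"); passed to Trainer() |
| checkpointing | bool | No | Save the best model checkpoint (default False); passed to Trainer() |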
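The numeric options lr, l2, and grad_clip (all listed in the constructor signature) interact in every update. The pure-Python sketch below performs one plain SGD step on a list of weights. It assumes, as is common with PyTorch, that clipping is applied to the loss gradients before the optimizer step and that l2 acts as the optimizer's weight decay; treat both as assumptions to check against trainer.py. Adam and Adamax rescale the step adaptively, so only the SGD case is shown, and `clip_by_global_norm` and `sgd_step` are hypothetical helpers.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float) -> List[float]:
    """Scale all gradients by one factor so their L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    # In this sketch a non-positive max_norm disables clipping.
    if max_norm <= 0 or norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]


def sgd_step(weights: List[float], loss_grads: List[float],
             lr: float, l2: float, grad_clip: float) -> List[float]:
    """One hypothetical SGD update: clip loss gradients, add weight decay, step."""
    clipped = clip_by_global_norm(loss_grads, grad_clip)
    # Weight decay adds l2 * w, the gradient of the penalty (l2 / 2) * ||w||^2.
    full = [g + l2 * w for g, w in zip(clipped, weights)]
    return [w - lr * g for w, g in zip(weights, full)]


# The gradient [3, 4] has norm 5, so grad_clip=1.0 scales it to about [0.6, 0.8].
new_w = sgd_step([1.0, -2.0], [3.0, 4.0], lr=0.1, l2=0.01, grad_clip=1.0)
print(new_w)  # approximately [0.939, -2.078]
```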

Outputs

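| Name | Type | Description |
|---|---|---|
| Trained model | MultitaskClassifier | The model passed to fit(), updated in place with trained parameters (fit() returns None) |
| Checkpoints | Files | Best-model checkpoints saved to disk when checkpointing=True |
| Logs | Files | Training logs written when logging=True, in JSON or TensorBoard format according to log_writer |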
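With checkpointing enabled, the trainer keeps the best model seen so far according to a monitored metric. The in-memory sketch below shows only that bookkeeping. `BestCheckpointSketch` is hypothetical: it "saves" by deep-copying a dict rather than writing files, and its `"name:min"` / `"name:max"` metric spec is an assumption of the sketch, not a documented Snorkel format.

```python
import copy
from typing import Any, Dict, Optional


class BestCheckpointSketch:
    """Hypothetical tracker that keeps a copy of the best state seen so far."""

    def __init__(self, metric_spec: str) -> None:
        # "valid/accuracy:max" -> monitor "valid/accuracy", higher is better.
        self.metric, mode = metric_spec.rsplit(":", 1)
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.mode = mode
        self.best_value: Optional[float] = None
        self.best_state: Optional[Dict[str, Any]] = None

    def _is_better(self, value: float) -> bool:
        if self.best_value is None:
            return True
        if self.mode == "min":
            return value < self.best_value
        return value > self.best_value

    def update(self, metrics: Dict[str, float], state: Dict[str, Any]) -> bool:
        """Record `state` if its monitored metric improved; return True if saved."""
        value = metrics[self.metric]
        if self._is_better(value):
            self.best_value = value
            self.best_state = copy.deepcopy(state)  # a real trainer would write to disk
            return True
        return False


ckpt = BestCheckpointSketch("valid/accuracy:max")
for epoch, acc in enumerate([0.70, 0.82, 0.79]):
    ckpt.update({"valid/accuracy": acc}, {"epoch": epoch})
print(ckpt.best_value, ckpt.best_state)  # 0.82 {'epoch': 1}
```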

Usage Examples

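```python
from snorkel.classification import Trainer

# Assumes `model` (here a SliceAwareClassifier) and the DictDataLoaders
# `train_dl` and `valid_dl` were built earlier.
trainer = Trainer(
    n_epochs=10,
    lr=0.001,
    l2=0.01,
    optimizer="adam",
    lr_scheduler="step",
    progress_bar=True,
)

# Trains `model` in place; fit() returns None.
trainer.fit(
    model=model,
    dataloaders=[train_dl, valid_dl],
)
```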
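The example selects lr_scheduler="step". A step schedule multiplies the learning rate by a factor gamma once every step_size scheduler steps, so the rate at step t is lr * gamma ** (t // step_size). The pure-Python sketch below evaluates that formula next to the other three schedule names from the description. The gamma and step_size values, the linear decay to zero over total_steps, and stepping the scheduler once per epoch are all illustrative assumptions rather than Snorkel's defaults, and warmup is ignored; `lr_at` is a hypothetical helper.

```python
from typing import List


def lr_at(step: int, base_lr: float, schedule: str, total_steps: int,
          gamma: float = 0.5, step_size: int = 3) -> float:
    """Learning rate at a given scheduler step (hypothetical sketch, no warmup)."""
    if schedule == "constant":
        return base_lr
    if schedule == "linear":
        # Decays linearly from base_lr toward 0 at total_steps.
        return base_lr * (1 - step / total_steps)
    if schedule == "exponential":
        # Multiplies by gamma after every step.
        return base_lr * gamma ** step
    if schedule == "step":
        # Multiplies by gamma once every step_size steps.
        return base_lr * gamma ** (step // step_size)
    raise ValueError(f"unknown schedule: {schedule!r}")


# Mirror the usage example: lr=0.001 over 10 epochs, one scheduler step per epoch.
schedule: List[float] = [lr_at(t, 0.001, "step", total_steps=10) for t in range(10)]
print(schedule)
```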
