Implementation:Snorkel team Snorkel Trainer Fit Multitask

From Leeroopedia
Knowledge Sources
Domains Training, Multi_Task_Learning, PyTorch
Last Updated 2026-02-14 20:00 GMT

Overview

Concrete tool for training MultitaskClassifier models with configurable optimization and monitoring, provided by the Snorkel library. This is the same Trainer class used for slice-aware training, documented here for the general multi-task classification context.

Description

The Trainer class provides the training loop for MultitaskClassifier. For the general multi-task context, the key difference from slice-aware usage is that tasks may have independent dataloaders and different label spaces.

Key features for multi-task training:

  • Shuffled batch scheduler: Draws batches from all task dataloaders in a random, interleaved order within each epoch
  • Sequential batch scheduler: Works through one dataloader's batches completely before moving on to the next
  • Per-task loss computation: Each task uses its own configured loss function
  • Shared optimizer: A single optimizer updates both the shared and the task-specific parameters
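
The difference between the two batch schedulers listed above can be illustrated with a minimal, library-free sketch. The names `sequential_schedule`, `shuffled_schedule`, and `task_batches` are hypothetical stand-ins rather than Snorkel APIs, and the strings stand in for real batches; the sketch only shows the ordering behavior under those assumptions.

```python
import random
from typing import Dict, Iterator, List, Tuple

# Hypothetical stand-in: each task name maps to its list of batches.
Batches = Dict[str, List[str]]


def sequential_schedule(task_batches: Batches) -> Iterator[Tuple[str, str]]:
    """Yield every batch of one task before moving on to the next task."""
    for task, batches in task_batches.items():
        for batch in batches:
            yield task, batch


def shuffled_schedule(task_batches: Batches, seed: int = 0) -> Iterator[Tuple[str, str]]:
    """Yield the same batches, but with the tasks interleaved at random."""
    # One "draw" per batch, labelled with the task it will come from.
    draws = [task for task, batches in task_batches.items() for _ in batches]
    random.Random(seed).shuffle(draws)
    # Each task keeps its own iterator, so within-task order is preserved.
    iterators = {task: iter(batches) for task, batches in task_batches.items()}
    for task in draws:
        yield task, next(iterators[task])


task_batches = {"spam": ["s0", "s1", "s2"], "topic": ["t0", "t1"]}
sequential = list(sequential_schedule(task_batches))
shuffled = list(shuffled_schedule(task_batches))
```

Both schedules visit every batch exactly once per epoch; only the order in which tasks take turns differs.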

Usage

Import this class when training a MultitaskClassifier. It is the same class that is used for slice-aware training.

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/classification/training/trainer.py
  • Lines: L108-586 (Trainer class), L35-105 (TrainerConfig)

Signature

class Trainer:
    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            name: Trainer name.
            **kwargs: TrainerConfig overrides (n_epochs, lr, l2,
                grad_clip, optimizer, lr_scheduler, batch_scheduler,
                checkpointing, logging, log_writer, train_split,
                valid_split, test_split).
        """

    def fit(
        self,
        model: MultitaskClassifier,
        dataloaders: List[DictDataLoader],
    ) -> None:
        """
        Train multi-task model.

        Args:
            model: MultitaskClassifier to train.
            dataloaders: DictDataLoaders for train/valid/test splits.
        """

Import

from snorkel.classification import Trainer

I/O Contract

Inputs

Name | Type | Required | Description
model | MultitaskClassifier | Yes | Model to train in-place
dataloaders | List[DictDataLoader] | Yes | Dataloaders with named splits
n_epochs | int | No | Training epochs (default 1)
lr | float | No | Learning rate (default 0.01)
batch_scheduler | str | No | "shuffled" or "sequential" (default "shuffled")
checkpointing | bool | No | Save best model (default False)
logging | bool | No | Enable metric logging (default False)
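
The optional inputs are configuration overrides: any value that is not supplied falls back to its default. A minimal sketch of that pattern follows, using only the defaults listed in the table above. `SketchTrainerConfig` and `build_config` are hypothetical names for illustration and are not part of Snorkel; the real TrainerConfig has more fields than are shown here.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class SketchTrainerConfig:
    """Hypothetical stand-in holding only the defaults listed in the table."""

    n_epochs: int = 1
    lr: float = 0.01
    batch_scheduler: str = "shuffled"
    checkpointing: bool = False
    logging: bool = False


def build_config(**overrides: Any) -> SketchTrainerConfig:
    """Apply keyword overrides on top of the defaults.

    Unknown field names raise TypeError instead of being silently ignored.
    """
    return replace(SketchTrainerConfig(), **overrides)


# Override two settings; everything else keeps its default value.
config = build_config(n_epochs=20, lr=0.001)
```

Rejecting unknown names early is a design choice of this sketch; it catches typos such as `epochs=20` before training starts.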

Outputs

Name | Type | Description
Trained model | MultitaskClassifier | Model updated in-place
Checkpoints | Files | Optional saved model states
Logs | Files | Optional JSON or TensorBoard logs

Usage Examples

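from snorkel.classification import Trainer

# Configure the training loop; each keyword overrides a TrainerConfig default.
trainer = Trainer(
    n_epochs=20,
    lr=0.001,
    optimizer="adam",
    batch_scheduler="shuffled",
    checkpointing=True,
    checkpointer_config={"checkpoint_dir": "./checkpoints"},
    logging=True,
    log_writer="tensorboard",
)

# `model` is an existing MultitaskClassifier; `train_dl`, `valid_dl`, and
# `test_dl` are DictDataLoaders for the respective splits, built elsewhere.
# fit() updates the model in place and returns None.
trainer.fit(
    model=model,
    dataloaders=[train_dl, valid_dl, test_dl],
)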
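
The list passed to fit mixes loaders for several splits, and the train_split, valid_split, and test_split settings name which split plays which role. A minimal sketch of routing loaders by split name is given below. It assumes each loader exposes its split as `dataset.split`; `SketchDataset`, `SketchDataLoader`, and `select_split` are hypothetical stand-ins, not Snorkel classes, and the trainer's actual selection logic may differ.

```python
from typing import List, NamedTuple


class SketchDataset(NamedTuple):
    """Hypothetical stand-in for a dataset that knows its task and split."""

    name: str
    split: str


class SketchDataLoader(NamedTuple):
    """Hypothetical stand-in for a loader wrapping one dataset."""

    dataset: SketchDataset


def select_split(dataloaders: List[SketchDataLoader], split: str) -> List[SketchDataLoader]:
    """Keep only the loaders whose dataset belongs to the requested split."""
    return [dl for dl in dataloaders if dl.dataset.split == split]


# Two tasks with independent loaders, plus a validation loader for one of them.
loaders = [
    SketchDataLoader(SketchDataset("spam", "train")),
    SketchDataLoader(SketchDataset("topic", "train")),
    SketchDataLoader(SketchDataset("spam", "valid")),
]
train_loaders = select_split(loaders, "train")
valid_loaders = select_split(loaders, "valid")
```

With independent dataloaders per task, several loaders can share the same split name, which is why the selection returns a list rather than a single loader.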

Related Pages

Implements Principle

Requires Environment