Implementation: Snorkel Trainer Fit
| Knowledge Sources | |
|---|---|
| Domains | Training, Multi_Task_Learning, PyTorch |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
Concrete tool for training multi-task classifiers with configurable optimization, logging, and checkpointing, provided by the Snorkel library.
Description
The Trainer class provides a complete training loop for MultitaskClassifier and its subclasses (including SliceAwareClassifier). It handles:
- Optimizer creation (SGD, Adam, Adamax)
- Learning rate scheduling (constant, linear, exponential, step)
- Batch scheduling across multiple tasks (shuffled or sequential)
- Gradient clipping
- Optional metric logging (JSON or TensorBoard)
- Optional model checkpointing (saving best model by metric)
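The scheduling and clipping pieces behave like standard training-loop utilities. A framework-free sketch of the two (the `gamma`, `step_size`, and clipping threshold values here are illustrative defaults, not Snorkel's internals):

```python
# Illustrative sketch (not Snorkel's actual internals): how per-epoch
# learning-rate schedules and gradient clipping typically behave.

def scheduled_lr(base_lr: float, epoch: int, scheduler: str = "constant",
                 gamma: float = 0.9, step_size: int = 2, n_epochs: int = 10) -> float:
    """Return the learning rate for a given epoch under common schedules."""
    if scheduler == "constant":
        return base_lr
    if scheduler == "exponential":
        return base_lr * gamma ** epoch
    if scheduler == "step":
        # Decay by gamma once every step_size epochs
        return base_lr * gamma ** (epoch // step_size)
    if scheduler == "linear":
        # Anneal linearly toward zero over n_epochs
        return base_lr * (1 - epoch / n_epochs)
    raise ValueError(f"Unrecognized scheduler: {scheduler}")

def clip_gradients(grads: list, max_norm: float = 1.0) -> list:
    """Scale gradients so their L2 norm does not exceed max_norm."""
    norm = sum(g * g for g in grads) ** 0.5
    if norm > max_norm:
        grads = [g * max_norm / norm for g in grads]
    return grads

print(scheduled_lr(0.01, epoch=4, scheduler="step"))   # 0.01 * 0.9**2
print(clip_gradients([3.0, 4.0], max_norm=1.0))        # norm 5 scaled to [0.6, 0.8]
```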
Usage
Import this class when training any Snorkel MultitaskClassifier. Configure via keyword arguments that are merged into TrainerConfig.
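The merge follows an ordinary defaults-plus-overrides pattern. A minimal sketch of that pattern (the dict below mimics a few TrainerConfig defaults for illustration; it is not Snorkel's actual config class):

```python
# Sketch of the defaults-plus-overrides pattern used by configs like
# TrainerConfig (illustrative dict, not the real config object).

TRAINER_DEFAULTS = {
    "n_epochs": 1,
    "lr": 0.01,
    "optimizer": "adam",
    "lr_scheduler": "constant",
    "checkpointing": False,
}

def build_config(**kwargs):
    """Start from the defaults, overwriting only the keys the caller passes."""
    unknown = set(kwargs) - set(TRAINER_DEFAULTS)
    if unknown:
        raise ValueError(f"Unrecognized config keys: {unknown}")
    return {**TRAINER_DEFAULTS, **kwargs}

config = build_config(n_epochs=10, lr=0.001)
print(config["n_epochs"], config["lr"], config["optimizer"])  # 10 0.001 adam
```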
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/classification/training/trainer.py
- Lines: L108-586 (class), L138-142 (__init__), L144-248 (fit)
Signature
class Trainer:
    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            name: Trainer name.
            **kwargs: Merged into TrainerConfig:
                n_epochs (int, default 1), lr (float, default 0.01),
                l2 (float, default 0.0), grad_clip (float, default 1.0),
                optimizer (str, default "adam"),
                lr_scheduler (str, default "constant"),
                batch_scheduler (str, default "shuffled"),
                checkpointing (bool, default False),
                logging (bool, default False),
                log_writer (str, default "tensorboard").
        """

    def fit(
        self,
        model: MultitaskClassifier,
        dataloaders: List[DictDataLoader],
    ) -> None:
        """
        Train the model on the provided dataloaders.

        Args:
            model: MultitaskClassifier (or SliceAwareClassifier) to train.
            dataloaders: List of DictDataLoaders with train/valid/test splits.
        """
Import
from snorkel.classification import Trainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | MultitaskClassifier | Yes | Model to train (in-place modification) |
| dataloaders | List[DictDataLoader] | Yes | Train/valid/test dataloaders |
| n_epochs | int | No | Training epochs (default 1) |
| lr | float | No | Learning rate (default 0.01) |
| optimizer | str | No | "sgd", "adam", or "adamax" (default "adam") |
| checkpointing | bool | No | Save best model checkpoint (default False) |
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | MultitaskClassifier | Model updated in-place with trained parameters |
| Checkpoints | Files | Optional model checkpoints saved to disk |
| Logs | Files | Optional training logs (JSON or TensorBoard) |
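"Saving the best model by metric" amounts to tracking the best validation score seen so far and writing a checkpoint only when it improves. A pure-Python sketch of that logic (the `BestCheckpointer` class and its file format here are illustrative; Snorkel's own checkpointer handles paths and model serialization itself):

```python
import json
import os
import tempfile

class BestCheckpointer:
    """Illustrative best-by-metric checkpointing sketch (not Snorkel's
    checkpointer): persist state only when the tracked metric improves."""

    def __init__(self, ckpt_dir: str, mode: str = "max"):
        self.ckpt_dir = ckpt_dir
        self.mode = mode          # "max" for accuracy-like, "min" for loss-like
        self.best_score = None

    def maybe_save(self, score: float, state: dict) -> bool:
        improved = (
            self.best_score is None
            or (self.mode == "max" and score > self.best_score)
            or (self.mode == "min" and score < self.best_score)
        )
        if improved:
            self.best_score = score
            with open(os.path.join(self.ckpt_dir, "best_model.json"), "w") as f:
                json.dump(state, f)
        return improved

ckpt = BestCheckpointer(tempfile.mkdtemp())
print(ckpt.maybe_save(0.80, {"epoch": 1}))  # True  (first score seen)
print(ckpt.maybe_save(0.75, {"epoch": 2}))  # False (worse, not saved)
print(ckpt.maybe_save(0.90, {"epoch": 3}))  # True  (improved, saved)
```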
Usage Examples
from snorkel.classification import Trainer

# Train a MultitaskClassifier (or SliceAwareClassifier subclass)
trainer = Trainer(
    n_epochs=10,
    lr=0.001,
    l2=0.01,
    optimizer="adam",
    lr_scheduler="step",
    progress_bar=True,
)
trainer.fit(
    model=model,
    dataloaders=[train_dl, valid_dl],
)
Related Pages
Implements Principle
Requires Environment