Implementation: Snorkel Trainer Fit (Multi-Task)

| Knowledge Sources | |
|---|---|
| Team | Snorkel team |
| Domains | Training, Multi_Task_Learning, PyTorch |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
Concrete tool for training MultitaskClassifier models with configurable optimization and monitoring, provided by the Snorkel library. This is the same Trainer class used for slice-aware training, documented here for the general multi-task classification context.
Description
The Trainer class provides the training loop for MultitaskClassifier. For the general multi-task context, the key difference from slice-aware usage is that tasks may have independent dataloaders and different label spaces.
Key features for multi-task training:
- Shuffled batch scheduler: Interleaves batches from all tasks randomly
- Sequential batch scheduler: Processes one task at a time
- Per-task loss computation: Each task uses its configured loss function
- Shared optimizer: Single optimizer updates all shared and task-specific parameters
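The two batch schedulers can be illustrated with a small standalone sketch. This is illustrative pseudologic only, not Snorkel's implementation; the function names (`shuffled_schedule`, `sequential_schedule`) are invented for this example:

```python
import random

def shuffled_schedule(task_batches, seed=0):
    """Interleave batches from all tasks in random order ("shuffled").

    task_batches maps a task name to its list of batches. Within each
    task, batch order is preserved; across tasks, order is randomized.
    """
    order = [name for name, batches in task_batches.items() for _ in batches]
    random.Random(seed).shuffle(order)
    iters = {name: iter(batches) for name, batches in task_batches.items()}
    return [(name, next(iters[name])) for name in order]

def sequential_schedule(task_batches):
    """Exhaust one task's batches before moving to the next ("sequential")."""
    return [(name, b) for name, batches in task_batches.items() for b in batches]

batches = {"task_a": [0, 1, 2], "task_b": [0, 1]}
```

Both schedulers visit every batch exactly once per epoch; they differ only in the order tasks contribute gradient updates, which can matter when tasks have very different dataset sizes.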
Usage
Import this class when training a MultitaskClassifier; it is the same class used for slice-aware training.
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/classification/training/trainer.py
- Lines: L108-586 (Trainer class), L35-105 (TrainerConfig)
Signature
class Trainer:
    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        """
        Args:
            name: Trainer name.
            **kwargs: TrainerConfig overrides (n_epochs, lr, l2,
                grad_clip, optimizer, lr_scheduler, batch_scheduler,
                checkpointing, logging, log_writer, train_split,
                valid_split, test_split).
        """

    def fit(
        self,
        model: MultitaskClassifier,
        dataloaders: List[DictDataLoader],
    ) -> None:
        """
        Train multi-task model.

        Args:
            model: MultitaskClassifier to train.
            dataloaders: DictDataLoaders for train/valid/test splits.
        """
Import
from snorkel.classification import Trainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | MultitaskClassifier | Yes | Model to train in-place |
| dataloaders | List[DictDataLoader] | Yes | Dataloaders with named splits |
| n_epochs | int | No | Training epochs (default 1) |
| lr | float | No | Learning rate (default 0.01) |
| batch_scheduler | str | No | "shuffled" or "sequential" (default "shuffled") |
| checkpointing | bool | No | Save best model (default False) |
| logging | bool | No | Enable metric logging (default False) |
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | MultitaskClassifier | Model updated in-place |
| Checkpoints | Files | Optional saved model states |
| Logs | Files | Optional JSON or TensorBoard logs |
Usage Examples
from snorkel.classification import Trainer

trainer = Trainer(
    n_epochs=20,
    lr=0.001,
    optimizer="adam",
    batch_scheduler="shuffled",
    checkpointing=True,
    checkpointer_config={"checkpoint_dir": "./checkpoints"},
    logging=True,
    log_writer="tensorboard",
)

trainer.fit(
    model=model,
    dataloaders=[train_dl, valid_dl, test_dl],
)