Principle: Snorkel Multitask Model Training
| Field | Value |
|---|---|
| Domains | Multi_Task_Learning, Training, Deep_Learning |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
A training procedure that jointly optimizes a neural network across multiple related tasks, sharing representations while allowing task-specific specialization.
Description
Multitask model training is the optimization step for models with multiple task heads. In Snorkel, it is used both for general multi-task classification and for slice-aware training. The trainer:
- Iterates over batches drawn from multiple dataloaders
- Computes per-task losses using task-specific loss functions
- Aggregates losses across tasks
- Applies gradient clipping and optimization
- Supports configurable batch scheduling (shuffled or sequential across tasks)
- Optionally logs metrics and checkpoints the best model
The training loop handles the complexity of multi-task optimization: different tasks may have different numbers of examples, different loss scales, and different convergence rates.
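The loop described above can be sketched in plain Python. This is a toy illustration, not the Snorkel implementation: it shows shuffled batch scheduling across tasks, task-specific loss functions, and summed loss aggregation, with the optimizer and gradient-clipping steps marked only as comments.

```python
import random

def shuffled_schedule(batches, seed=0):
    """Interleave (task, batch_index) pairs across all tasks in random order."""
    order = [(task, i)
             for task, task_batches in batches.items()
             for i in range(len(task_batches))]
    random.Random(seed).shuffle(order)
    return order

def train_epoch(batches, loss_fns):
    """One pass over every task's batches; returns the summed loss."""
    total_loss = 0.0
    for task, i in shuffled_schedule(batches):
        loss = loss_fns[task](batches[task][i])  # task-specific loss function
        total_loss += loss                       # aggregate losses across tasks
        # gradient clipping and optimizer.step() would happen here
    return total_loss
```

A sequential scheduler would simply skip the shuffle, exhausting one task's batches before moving to the next; shuffling interleaves tasks so that shared parameters see gradients from all tasks throughout the epoch.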
Usage
Use this principle when training any MultitaskClassifier or SliceAwareClassifier. Configure training hyperparameters (epochs, learning rate, optimizer) and optionally enable logging and checkpointing.
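As a configuration sketch: the keyword names below follow Snorkel's classification `Trainer` as commonly documented, but treat the exact arguments as assumptions and check them against your installed version. `tasks` and `dataloaders` are assumed to be built elsewhere with Snorkel's task and dataloader utilities.

```python
from snorkel.classification import MultitaskClassifier, Trainer

# `tasks` and `dataloaders` are assumed to exist; building them is out of
# scope for this sketch.
model = MultitaskClassifier(tasks)

# Hyperparameter names are assumptions based on Snorkel's Trainer config;
# verify against the version you have installed.
trainer = Trainer(
    n_epochs=10,
    lr=1e-3,
    optimizer="adam",
)
trainer.fit(model, dataloaders)
```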
Theoretical Basis
Multi-task training minimizes the sum of task losses:

$$\mathcal{L}(\theta_s, \theta_1, \dots, \theta_T) = \sum_{t=1}^{T} \mathcal{L}_t(\theta_s, \theta_t)$$

where $\theta_s$ are shared parameters and $\theta_t$ are the task-specific parameters of task $t$. The gradient with respect to the shared parameters receives contributions from all tasks:

$$\nabla_{\theta_s} \mathcal{L} = \sum_{t=1}^{T} \nabla_{\theta_s} \mathcal{L}_t$$
This encourages the shared representation to be useful across all tasks while allowing specialization in task-specific heads.
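A small numeric illustration of that additivity (a toy example, not Snorkel code): with two task losses sharing one scalar parameter, the finite-difference gradient of the summed loss equals the sum of the per-task gradients.

```python
def grad(f, w, eps=1e-6):
    """Central-difference derivative of f at w."""
    return (f(w + eps) - f(w - eps)) / (2 * eps)

# Two toy task losses sharing a single scalar parameter w.
loss_1 = lambda w: (w * 2.0 - 1.0) ** 2
loss_2 = lambda w: (w * 3.0 + 0.5) ** 2
total = lambda w: loss_1(w) + loss_2(w)

w = 0.7
g_total = grad(total, w)
g_sum = grad(loss_1, w) + grad(loss_2, w)
# g_total and g_sum agree: shared parameters accumulate gradient
# contributions from every task.
```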