Implementation:Snorkel team Snorkel MultitaskClassifier Init
| Knowledge Sources | |
|---|---|
| Domains | Multi_Task_Learning, Model_Architecture, PyTorch |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
Concrete tool for constructing a multi-task neural network classifier from Task definitions, provided by the Snorkel library.
Description
The MultitaskClassifier extends nn.Module and accepts a list of Task objects. During construction, it:
- Merges all module pools into a combined nn.ModuleDict
- Stores operation sequences, loss functions, output functions, and scorers per task
- Moves the model to the configured device
- Exposes add_task() so further tasks can be added after construction
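The construction steps above can be sketched in plain Python. This is a hypothetical toy, not Snorkel's implementation: `ToyTask` and `ToyMultitaskModel` are invented names, plain objects stand in for `nn.Module`s, and the sketch assumes module pools are merged by key with the first registration winning, so a later task that reuses a module name shares the existing module. Rejecting duplicate task names is a choice of the sketch.

```python
from typing import Dict, List, Set


class ToyTask:
    """Hypothetical stand-in for Task: a name plus a pool of named modules."""

    def __init__(self, name: str, module_pool: Dict[str, object]) -> None:
        self.name = name
        self.module_pool = module_pool


class ToyMultitaskModel:
    """Hypothetical sketch of merging per-task module pools by module name."""

    def __init__(self, tasks: List[ToyTask]) -> None:
        self.module_pool: Dict[str, object] = {}
        self.task_names: Set[str] = set()
        # Construction simply registers each task in turn.
        for task in tasks:
            self.add_task(task)

    def add_task(self, task: ToyTask) -> None:
        if task.name in self.task_names:
            raise ValueError(f"Duplicate task name: {task.name}")
        for key, module in list(task.module_pool.items()):
            if key in self.module_pool:
                # Name already registered (assumption): reuse the existing
                # module so the tasks share it, and point the task's pool at it.
                task.module_pool[key] = self.module_pool[key]
            else:
                self.module_pool[key] = module
        self.task_names.add(task.name)


# Usage: two tasks that both register a module named "encoder".
encoder = object()
sentiment = ToyTask("sentiment", {"encoder": encoder, "sentiment_head": object()})
topic = ToyTask("topic", {"encoder": object(), "topic_head": object()})

model = ToyMultitaskModel([sentiment])
model.add_task(topic)  # Tasks can also be added after construction.
print(sorted(model.module_pool))
```

Because the merged pool is keyed by name, the combined model holds one "encoder" entry no matter how many tasks list it.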
The forward pass executes each task's operation sequence, producing task-specific outputs for loss computation or prediction.
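The forward pass can be sketched the same way. Again this is a hypothetical, torch-free toy: `ToyOperation` and `run_op_sequences` are invented names, plain functions stand in for modules, and the sketch assumes (consistent with the usage example on this page) that each operation's output is stored under its name, that a string input refers to an earlier operation's output, and that a `("_input_", key)` pair reads one field of the input dictionary. Skipping operations whose output is already cached is how this sketch lets tasks share an encoder without recomputing it.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

# A string input names an earlier operation; a ("_input_", key) pair
# reads one field of the input dictionary.
OpInput = Union[str, Tuple[str, str]]


class ToyOperation:
    """Hypothetical stand-in for Operation: a module name and its inputs."""

    def __init__(
        self, module_name: str, inputs: Sequence[OpInput], name: Optional[str] = None
    ) -> None:
        self.module_name = module_name
        self.inputs = list(inputs)
        # The output key defaults to the module name.
        self.name = name or module_name


def run_op_sequences(
    module_pool: Dict[str, Callable[..., Any]],
    op_sequences: Dict[str, List[ToyOperation]],
    X_dict: Dict[str, Any],
    task_names: Sequence[str],
) -> Dict[str, Any]:
    """Run each requested task's op sequence, caching outputs by op name."""
    outputs: Dict[str, Any] = {"_input_": X_dict}
    for task_name in task_names:
        for op in op_sequences[task_name]:
            if op.name in outputs:
                continue  # Already computed for an earlier task (shared op).
            args = []
            for inp in op.inputs:
                if isinstance(inp, tuple):
                    source, key = inp  # Index into a multi-field output.
                    args.append(outputs[source][key])
                else:
                    args.append(outputs[inp])
            outputs[op.name] = module_pool[op.module_name](*args)
    return outputs


# Usage: a shared "encoder" and "relu" feeding two task heads.
pool = {
    "encoder": lambda x: [2 * v for v in x],
    "relu": lambda h: [max(0, v) for v in h],
    "sentiment_head": lambda h: sum(h),
    "topic_head": lambda h: max(h),
}
shared_ops = [
    ToyOperation("encoder", [("_input_", "features")]),
    ToyOperation("relu", ["encoder"]),
]
op_sequences = {
    "sentiment": shared_ops + [ToyOperation("sentiment_head", ["relu"])],
    "topic": shared_ops + [ToyOperation("topic_head", ["relu"])],
}
out = run_op_sequences(pool, op_sequences, {"features": [1, -2, 3]}, ["sentiment", "topic"])
print(out["sentiment_head"], out["topic_head"])  # 8 6
```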
Usage
Import this class when building multi-task models from Task definitions. For slice-aware models, use SliceAwareClassifier instead (which extends this class).
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/classification/multitask_classifier.py
- Lines: L51-534 (class), L81-113 (__init__)
Signature
class MultitaskClassifier(nn.Module):
    def __init__(
        self,
        tasks: List[Task],
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            tasks: List of Task objects defining the multi-task model.
            name: Classifier name.
            **kwargs: Merged into ClassifierConfig (device, dataparallel).
        """

    def add_task(self, task: Task) -> None:
        """Add a task to an existing model."""
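The `**kwargs` merge into a config object can be pictured with a frozen dataclass. `ToyClassifierConfig` and `merge_kwargs` are hypothetical names, not Snorkel's `ClassifierConfig`; the `device` default of 0 follows the Inputs table, while the `dataparallel` default of `True` is an assumption. Using `dataclasses.replace` means an unrecognized keyword raises `TypeError`, which is a design choice of this sketch rather than documented Snorkel behavior.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class ToyClassifierConfig:
    """Hypothetical config mirroring the two fields named in the docstring."""

    device: int = 0  # CUDA device index (default 0, per the Inputs table).
    dataparallel: bool = True  # Assumed default; check your Snorkel version.


def merge_kwargs(config: ToyClassifierConfig, **kwargs: Any) -> ToyClassifierConfig:
    """Return a copy of `config` with keyword overrides applied.

    dataclasses.replace passes the overrides to the constructor, so a
    misspelled field name raises TypeError instead of being ignored.
    """
    return replace(config, **kwargs)


cfg = merge_kwargs(ToyClassifierConfig(), device=1)
print(cfg)  # ToyClassifierConfig(device=1, dataparallel=True)
```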
Import
from snorkel.classification import MultitaskClassifier
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tasks | List[Task] | Yes | Task definitions with module pools and operation sequences |
| name | Optional[str] | No | Model name |
| device | int | No | CUDA device index, passed through **kwargs (default 0) |
Outputs
| Name | Type | Description |
|---|---|---|
| MultitaskClassifier | MultitaskClassifier | nn.Module with a combined module pool and per-task operation sequences |
| task_names | Set[str] | Names of all registered tasks |
Usage Examples
import torch.nn as nn
from snorkel.analysis import Scorer
from snorkel.classification import MultitaskClassifier, Operation, Task

# Shared encoder: the same module objects are registered under the same
# names in both tasks, so both tasks use one set of encoder weights.
shared_encoder = nn.Linear(100, 64)
relu = nn.ReLU()

# Task 1: Sentiment (2 classes)
task1 = Task(
    name="sentiment",
    module_pool=nn.ModuleDict({
        "encoder": shared_encoder,
        "relu": relu,
        "sentiment_head": nn.Linear(64, 2),
    }),
    op_sequence=[
        # ("_input_", "features") reads the "features" field of the input dict.
        Operation(module_name="encoder", inputs=[("_input_", "features")]),
        Operation(module_name="relu", inputs=["encoder"]),
        Operation(module_name="sentiment_head", inputs=["relu"]),
    ],
    scorer=Scorer(metrics=["accuracy"]),
)

# Task 2: Topic (5 classes)
task2 = Task(
    name="topic",
    module_pool=nn.ModuleDict({
        "encoder": shared_encoder,  # Shared!
        "relu": relu,
        "topic_head": nn.Linear(64, 5),
    }),
    op_sequence=[
        Operation(module_name="encoder", inputs=[("_input_", "features")]),
        Operation(module_name="relu", inputs=["encoder"]),
        Operation(module_name="topic_head", inputs=["relu"]),
    ],
    scorer=Scorer(metrics=["accuracy"]),
)

# Build the multi-task model
model = MultitaskClassifier(tasks=[task1, task2], name="my_mtl_model")
print(f"Tasks: {model.task_names}")
Related Pages
Implements Principle
Requires Environment
Uses Heuristic