Implementation:Snorkel team Snorkel MultitaskClassifier Init

From Leeroopedia
Domains: Multi_Task_Learning, Model_Architecture, PyTorch
Last Updated: 2026-02-14 20:00 GMT

Overview

Concrete tool for constructing a multi-task neural network classifier from Task definitions, provided by the Snorkel library.

Description

The MultitaskClassifier extends nn.Module and accepts a list of Task objects. During construction, it:

  • Merges all module pools into a combined nn.ModuleDict keyed by module name, so a key that appears in several tasks resolves to a single shared module
  • Stores operation sequences, loss functions, output functions, and scorers per task
  • Moves the model to the configured device
  • Supports dynamic task addition via add_task()
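To make the merge step concrete, here is a minimal pure-Python sketch of the bookkeeping described above. It is not Snorkel's code. `ToyTask` and `ToyMultitaskModel` are hypothetical names, plain dicts stand in for nn.ModuleDict, and plain objects stand in for modules. The collision rule is also an assumption: the module registered first under a key wins, and later tasks are rebound to it.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Set


@dataclass
class ToyTask:
    """Hypothetical stand-in for snorkel.classification.Task."""

    name: str
    module_pool: Dict[str, Any]
    op_sequence: List[str]


@dataclass
class ToyMultitaskModel:
    """Hypothetical, simplified model of the classifier's task bookkeeping."""

    module_pool: Dict[str, Any] = field(default_factory=dict)
    task_names: Set[str] = field(default_factory=set)
    op_sequences: Dict[str, List[str]] = field(default_factory=dict)

    def add_task(self, task: ToyTask) -> None:
        if task.name in self.task_names:
            raise ValueError(f"Duplicate task name: {task.name}")
        for key, module in list(task.module_pool.items()):
            if key in self.module_pool:
                # Assumed rule: the existing entry wins, and the task's own
                # pool is rebound to it so every task shares one object.
                task.module_pool[key] = self.module_pool[key]
            else:
                self.module_pool[key] = module
        self.task_names.add(task.name)
        self.op_sequences[task.name] = task.op_sequence


# Usage: two tasks that both register a module under the key "encoder".
encoder = object()
model = ToyMultitaskModel()
model.add_task(ToyTask("sentiment", {"encoder": encoder, "sentiment_head": object()},
                       ["encoder", "sentiment_head"]))
model.add_task(ToyTask("topic", {"encoder": object(), "topic_head": object()},
                       ["encoder", "topic_head"]))
print(sorted(model.module_pool))  # ['encoder', 'sentiment_head', 'topic_head']
print(model.module_pool["encoder"] is encoder)  # True
```

Under this assumed rule, sharing depends on the key rather than on object identity. Two tasks that register different objects under the same key still end up using the first one.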

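The forward pass executes each task's operation sequence and produces task-specific outputs for loss computation or prediction. Outputs are stored by operation name, so an operation shared by several tasks (such as a common encoder) is computed once per forward pass and then reused.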
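The execution loop can be sketched as follows. This is a simplified, hypothetical re-implementation in plain Python, not Snorkel's forward method. `ToyOp`, `resolve`, and `toy_forward` are illustrative names, and plain callables stand in for modules. Caching outputs by operation name is the assumed mechanism for reusing shared operations. As in this page's usage example, a bare string input names an earlier operation, and a `("_input_", key)` tuple reads one field of the input dict.

```python
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

# An input is either an earlier op's name or an (op_name, key) pair.
InputRef = Union[str, Tuple[str, str]]


class ToyOp:
    """Hypothetical stand-in for snorkel.classification.Operation."""

    def __init__(self, module_name: str, inputs: Sequence[InputRef],
                 name: Optional[str] = None) -> None:
        self.module_name = module_name
        self.inputs = list(inputs)
        self.name = name or module_name  # op name defaults to the module name


def resolve(outputs: Dict[str, Any], ref: InputRef) -> Any:
    """Look up an input: a whole earlier output, or one field of it."""
    if isinstance(ref, tuple):
        op_name, key = ref
        return outputs[op_name][key]
    return outputs[ref]


def toy_forward(module_pool: Dict[str, Callable[..., Any]],
                op_sequences: Dict[str, List[ToyOp]],
                X_dict: Dict[str, Any],
                task_names: Iterable[str]) -> Dict[str, Any]:
    outputs: Dict[str, Any] = {"_input_": X_dict}
    for task_name in task_names:
        for op in op_sequences[task_name]:
            if op.name in outputs:
                continue  # computed earlier (e.g. by another task): reuse it
            args = [resolve(outputs, ref) for ref in op.inputs]
            outputs[op.name] = module_pool[op.module_name](*args)
    return outputs


# Usage: the encoder op appears in both tasks but runs only once.
calls = {"encoder": 0}


def encoder(x: List[int]) -> List[int]:
    calls["encoder"] += 1
    return [2 * v for v in x]


pool = {"encoder": encoder, "sentiment_head": sum, "topic_head": max}
ops = {
    "sentiment": [ToyOp("encoder", [("_input_", "features")]),
                  ToyOp("sentiment_head", ["encoder"])],
    "topic": [ToyOp("encoder", [("_input_", "features")]),
              ToyOp("topic_head", ["encoder"])],
}
out = toy_forward(pool, ops, {"features": [1, 2, 3]}, ["sentiment", "topic"])
print(out["sentiment_head"], out["topic_head"], calls["encoder"])  # 12 6 1
```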

Usage

Import this class when building multi-task models from Task definitions. For slice-aware models, use SliceAwareClassifier instead (which extends this class).

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/classification/multitask_classifier.py
  • Lines: L51-534 (class), L81-113 (__init__)

Signature

class MultitaskClassifier(nn.Module):
    def __init__(
        self,
        tasks: List[Task],
        name: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            tasks: List of Task objects defining the multi-task model.
            name: Classifier name.
            **kwargs: Merged into ClassifierConfig (device, dataparallel).
        """

    def add_task(self, task: Task) -> None:
        """Add a task to an existing model."""
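
The `**kwargs` line in the signature says that keyword arguments such as `device` and `dataparallel` are merged into a `ClassifierConfig`. The sketch below shows that override pattern with a frozen dataclass. It is hypothetical: `ToyClassifierConfig` and `merge_kwargs` are illustrative names. The defaults (device 0, dataparallel True), the use of -1 for CPU, and the rejection of unknown keys are all assumptions, not Snorkel's confirmed behavior.

```python
from dataclasses import dataclass, fields, replace
from typing import Any


@dataclass(frozen=True)
class ToyClassifierConfig:
    """Hypothetical stand-in for Snorkel's ClassifierConfig."""

    device: int = 0  # CUDA device index; -1 is assumed to mean CPU
    dataparallel: bool = True  # assumed: wrap modules in nn.DataParallel


def merge_kwargs(base: ToyClassifierConfig, **kwargs: Any) -> ToyClassifierConfig:
    """Return a copy of `base` with the given fields overridden."""
    known = {f.name for f in fields(base)}
    unknown = set(kwargs) - known
    if unknown:
        raise ValueError(f"Unknown config keys: {sorted(unknown)}")
    # replace() builds a new frozen instance; unspecified fields keep defaults.
    return replace(base, **kwargs)


cfg = merge_kwargs(ToyClassifierConfig(), device=-1)
print(cfg)  # ToyClassifierConfig(device=-1, dataparallel=True)
```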

Import

from snorkel.classification import MultitaskClassifier

I/O Contract

Inputs

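Name          Type           Required  Description
tasks         List[Task]     Yes       Task definitions with module pools and operation sequences
name          Optional[str]  No        Model name (defaults to the class name when omitted)
device        int            No        CUDA device index, passed via **kwargs (default 0; -1 for CPU)
dataparallel  bool           No        Whether to wrap modules in nn.DataParallel, passed via **kwargs (default True)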

Outputs

Name                 Type                 Description
MultitaskClassifier  MultitaskClassifier  nn.Module with the combined module pool and per-task operation sequences
task_names           Set[str]             Attribute holding the names of all registered tasks

Usage Examples

import torch.nn as nn
from snorkel.classification import MultitaskClassifier, Task, Operation
from snorkel.analysis import Scorer

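# Shared encoder and activation: the same objects are registered under the
# same keys in both module pools, so the merged model holds one copy of each
shared_encoder = nn.Linear(100, 64)
relu = nn.ReLU()

# Task 1: Sentiment (2 classes)
task1 = Task(
    name="sentiment",
    module_pool=nn.ModuleDict({
        "encoder": shared_encoder,
        "relu": relu,
        "sentiment_head": nn.Linear(64, 2),
    }),
    op_sequence=[
        # ("_input_", "features") reads the "features" field of the input dict;
        # plain strings name earlier operations (op names default to module names)
        Operation(module_name="encoder", inputs=[("_input_", "features")]),
        Operation(module_name="relu", inputs=["encoder"]),
        Operation(module_name="sentiment_head", inputs=["relu"]),
    ],
    scorer=Scorer(metrics=["accuracy"]),
)

# Task 2: Topic (5 classes)
task2 = Task(
    name="topic",
    module_pool=nn.ModuleDict({
        "encoder": shared_encoder,  # shared with task1
        "relu": relu,
        "topic_head": nn.Linear(64, 5),
    }),
    op_sequence=[
        Operation(module_name="encoder", inputs=[("_input_", "features")]),
        Operation(module_name="relu", inputs=["encoder"]),
        Operation(module_name="topic_head", inputs=["relu"]),
    ],
    scorer=Scorer(metrics=["accuracy"]),
)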

# Build multi-task model
model = MultitaskClassifier(tasks=[task1, task2], name="my_mtl_model")
print(f"Tasks: {model.task_names}")
