

Workflow:Snorkel team Snorkel Multitask Classification

From Leeroopedia
Domains: Multi_Task_Learning, Deep_Learning, Classification
Last Updated: 2026-02-14 20:00 GMT

Overview

End-to-end process for building and training a multi-task neural network classifier using Snorkel's dynamic task graph architecture with shared modules, configurable operation sequences, and probabilistic label support.

Description

This workflow covers the construction and training of Snorkel's MultitaskClassifier, a PyTorch-based model in which multiple tasks share a common set of neural network modules. Each Task defines a sequence of Operations that route data through a shared module pool, enabling flexible weight sharing across tasks. The system supports training with probabilistic (soft) labels via a custom cross-entropy loss, which makes it well suited to training on the probabilistic labels produced by the LabelModel. The Trainer orchestrates the multi-task training loop with configurable batch scheduling, checkpointing, logging, learning rate scheduling, and gradient clipping.
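To make the soft-label loss concrete, the sketch below computes cross-entropy against probabilistic targets in plain Python: for each example the loss is -Σ_k p_k · log softmax(z)_k, averaged over the batch, which reduces to ordinary cross-entropy when p is one-hot. This is the quantity that cross_entropy_with_probs targets (without class weights), but `soft_cross_entropy` and `log_softmax` here are hypothetical helpers written without PyTorch, not Snorkel's implementation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_norm for z in logits]


def soft_cross_entropy(logits: Sequence[Sequence[float]],
                       target_probs: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: mean over rows of -sum_k p_k * log softmax(z)_k."""
    if len(logits) != len(target_probs):
        raise ValueError("logits and targets must have the same number of rows")
    total = 0.0
    for z, p in zip(logits, target_probs):
        log_q = log_softmax(z)
        total += -sum(p_k * lq_k for p_k, lq_k in zip(p, log_q))
    return total / len(logits)


# A one-hot target reduces to ordinary (hard-label) cross-entropy.
hard = soft_cross_entropy([[2.0, 0.0]], [[1.0, 0.0]])
# A soft target such as [0.7, 0.3] spreads the penalty across both classes.
soft = soft_cross_entropy([[2.0, 0.0]], [[0.7, 0.3]])
print(round(hard, 4), round(soft, 4))
```

Because the soft target places 0.3 of its mass on the class the logits consider unlikely, its loss is higher than the one-hot loss for the same logits.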

Usage

Execute this workflow when you need to train a neural network on one or more related tasks, especially when using probabilistic labels from Snorkel's weak supervision pipeline. This is appropriate for end-model training after label model inference, multi-task learning scenarios where tasks share representations, or any classification task that benefits from Snorkel's training infrastructure (checkpointing, logging, multi-device support).

Execution Steps

Step 1: Prepare DictDatasets

Create DictDataset instances for train, validation, and test splits. Each DictDataset wraps an X_dict (mapping field names to feature tensors) and a Y_dict (mapping task names to label tensors). Labels can be hard integer tensors or soft probability tensors from the LabelModel. Wrap datasets in DictDataLoader instances with appropriate batch sizes and sampling options.
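The dict-of-features layout can be illustrated without PyTorch. The sketch below is a hypothetical, stdlib-only stand-in: `ToyDictDataset` mirrors the name, split, X_dict, and Y_dict fields described above, and `collate_dicts_sketch` shows the idea behind dict-based collation (turning a list of per-example dicts into per-key batches). Plain lists take the place of tensors; the real DictDataLoader delegates batching to PyTorch's DataLoader.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple

Row = Tuple[Dict[str, Any], Dict[str, Any]]


@dataclass
class ToyDictDataset:
    """Hypothetical stand-in for a dict-based dataset (lists instead of tensors)."""
    name: str
    split: str
    X_dict: Dict[str, List[Any]]  # field name -> per-example features
    Y_dict: Dict[str, List[Any]]  # task name -> per-example labels

    def __len__(self) -> int:
        return len(next(iter(self.Y_dict.values())))

    def __getitem__(self, index: int) -> Row:
        # One example = one value per input field and one label per task.
        x = {k: v[index] for k, v in self.X_dict.items()}
        y = {k: v[index] for k, v in self.Y_dict.items()}
        return x, y


def collate_dicts_sketch(rows: List[Row]) -> Tuple[Dict[str, list], Dict[str, list]]:
    """Turn a list of (x_dict, y_dict) examples into per-key batches."""
    x_batch: Dict[str, list] = {k: [] for k in rows[0][0]}
    y_batch: Dict[str, list] = {k: [] for k in rows[0][1]}
    for x, y in rows:
        for k, v in x.items():
            x_batch[k].append(v)
        for k, v in y.items():
            y_batch[k].append(v)
    return x_batch, y_batch


train = ToyDictDataset(
    name="toy",
    split="train",
    X_dict={"coordinates": [[0.1, 0.2], [0.9, 0.4], [0.3, 0.8]]},
    # Hard labels for one task, soft (probabilistic) labels for another.
    Y_dict={"task_a": [0, 1, 0], "task_b": [[0.8, 0.2], [0.3, 0.7], [0.5, 0.5]]},
)
x_batch, y_batch = collate_dicts_sketch([train[i] for i in range(2)])
print(x_batch["coordinates"], y_batch["task_a"])
```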

Key considerations:

  • X_dict keys must match the input references used in Operation definitions
  • Y_dict keys must match the task names defined in the Task objects
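  • Probabilities from LabelModel.predict_proba (a NumPy array) can serve as soft labels once converted to a float tensor and paired with a soft-label loss such as cross_entropy_with_probs
  • Each dataset takes a name and a split identifier (e.g. "train", "valid", "test"); both appear in the reported metric keys, and the Trainer uses the split to select its training dataloaders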
  • DictDataLoader is a wrapper around PyTorch DataLoader with dict-based collation
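A cheap sanity check before training is to confirm that these keys line up. `check_dict_keys` below is a hypothetical helper (not part of Snorkel) that works on plain dicts; `input_fields` stands for the field names that Operations read through `_input_`, and `task_names` for the names given to the Task objects.

```python
from typing import Dict, Iterable, List, Set


def check_dict_keys(X_dict: Dict[str, list], Y_dict: Dict[str, list],
                    input_fields: Iterable[str], task_names: Set[str]) -> List[str]:
    """Return a list of problems; an empty list means the keys line up."""
    problems: List[str] = []
    # Every field an operation reads from the raw input must exist in X_dict.
    for field in input_fields:
        if field not in X_dict:
            problems.append(f"input field '{field}' is missing from X_dict")
    # Every label set in Y_dict must belong to a known task.
    for label_name in Y_dict:
        if label_name not in task_names:
            problems.append(f"Y_dict key '{label_name}' matches no task")
    # All features and labels should describe the same number of examples.
    lengths = {len(v) for v in [*X_dict.values(), *Y_dict.values()]}
    if len(lengths) > 1:
        problems.append(f"inconsistent lengths: {sorted(lengths)}")
    return problems


X = {"coordinates": [[0.1, 0.2], [0.9, 0.4]]}
Y = {"task_a": [0, 1], "task_c": [1, 0]}
print(check_dict_keys(X, Y, input_fields=["coordinates"], task_names={"task_a", "task_b"}))
```

Here the check flags `task_c`, a label set with no matching task, which would otherwise surface only once training starts.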

Step 2: Define Tasks and Operations

Define one or more Task objects, each specifying a module pool (ModuleDict of PyTorch modules), an operation sequence (ordered list of Operations defining the forward pass), a scorer, and loss/output functions. Operations reference modules by name and specify their inputs as either raw input keys or outputs of previous operations, creating a dynamic computation graph.

Key considerations:

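  • Each Task has its own module_pool (an nn.ModuleDict); when tasks are combined, modules registered under the same name are shared
  • Each Operation names a module_name and its inputs: either the name of an earlier operation or an (op_name, key) pair, where ("_input_", field) reads a field from X_dict
  • The default loss function is cross_entropy, which expects hard integer labels; pass cross_entropy_with_probs as the loss function for soft labels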
  • The default output function is softmax over the last dimension
  • Tasks are composable: multiple tasks can share body modules while having different heads
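The routing described above can be sketched in plain Python. `ToyOperation` and `run_ops` are hypothetical stand-ins and plain functions replace nn.Module instances; the sketch assumes only what the considerations above state, namely that an input is either a `("_input_", field)` pair or the name of an earlier operation. Two tasks reuse the same `body_op` and differ only in their heads, and this sketch computes each named operation at most once per call.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union

# An input is either an earlier op's name or a (source, key) pair such as ("_input_", "x").
InputRef = Union[str, Tuple[str, str]]


@dataclass
class ToyOperation:
    """Hypothetical operation record: output name, module to call, and its inputs."""
    op_name: str
    module_name: str
    inputs: List[InputRef]


def run_ops(module_pool: Dict[str, Callable[..., Any]],
            op_sequences: Sequence[Sequence[ToyOperation]],
            X_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Run each task's ops in order, computing every named op at most once."""
    outputs: Dict[str, Any] = {"_input_": X_dict}
    for op_sequence in op_sequences:
        for op in op_sequence:
            if op.op_name in outputs:  # already computed for an earlier task
                continue
            args = []
            for ref in op.inputs:
                if isinstance(ref, tuple):
                    source, key = ref
                    args.append(outputs[source][key])
                else:
                    args.append(outputs[ref])
            outputs[op.op_name] = module_pool[op.module_name](*args)
    return outputs


calls = {"body": 0}


def body(x: List[float]) -> List[float]:
    calls["body"] += 1  # count how often the shared module runs
    return [2 * v for v in x]


# One shared body and two task-specific heads; plain functions stand in for nn.Modules.
pool = {"body": body, "head_a": sum, "head_b": max}
body_op = ToyOperation("body_op", "body", [("_input_", "x")])
task_a_ops = [body_op, ToyOperation("head_a_op", "head_a", ["body_op"])]
task_b_ops = [body_op, ToyOperation("head_b_op", "head_b", ["body_op"])]

outputs = run_ops(pool, [task_a_ops, task_b_ops], {"x": [1, 2, 3]})
print(outputs["head_a_op"], outputs["head_b_op"], calls["body"])
```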

Step 3: Build MultitaskClassifier

Instantiate the MultitaskClassifier with the list of Tasks. The classifier merges all module pools into a single shared pool, extracts operation sequences per task, and configures itself for multi-device execution (DataParallel if enabled). The model is placed on the configured device (CPU or GPU).

Key considerations:

  • Modules with identical names across tasks are shared (stored once in the merged pool)
  • DataParallel wrapping is applied automatically when dataparallel=True and GPUs are available
  • The device configuration defaults to GPU 0 but falls back to CPU if unavailable
  • Task names must be unique across all provided tasks
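The name-based merge can be illustrated with a small stand-in. `ToyTask` and `merge_module_pools` are hypothetical (not Snorkel classes), and `object()` placeholders stand in for nn.Module instances. This sketch keeps the first module registered under each name, so later tasks that use the same name share it, and it rejects duplicate task names.

```python
from typing import Any, Dict, List, NamedTuple


class ToyTask(NamedTuple):
    """Hypothetical task record: a unique name plus the task's own module pool."""
    name: str
    module_pool: Dict[str, Any]


def merge_module_pools(tasks: List[ToyTask]) -> Dict[str, Any]:
    """Merge per-task pools by module name, rejecting duplicate task names."""
    seen_names = set()
    merged: Dict[str, Any] = {}
    for task in tasks:
        if task.name in seen_names:
            raise ValueError(f"duplicate task name: {task.name}")
        seen_names.add(task.name)
        for module_name, module in task.module_pool.items():
            # setdefault keeps the first module registered under a name,
            # so later tasks that reuse the name share that single object.
            merged.setdefault(module_name, module)
    return merged


shared_body = object()  # stands in for one nn.Module instance
tasks = [
    ToyTask("task_a", {"body": shared_body, "head_a": object()}),
    ToyTask("task_b", {"body": object(), "head_b": object()}),
]
pool = merge_module_pools(tasks)
print(sorted(pool), pool["body"] is shared_body)
```

In this sketch, task_b's own "body" placeholder is dropped in favor of task_a's. In practice, tasks that are meant to share a body usually pass the same module instance in the first place.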

Step 4: Configure and Run Training

Create a Trainer, passing training hyperparameters (epochs, learning rate, optimizer, scheduler) as keyword arguments that override the TrainerConfig defaults, then call trainer.fit() with the model and a list of DictDataLoaders. The Trainer handles the full training loop: drawing batches from the training-split dataloaders in the order set by the batch scheduler, computing the loss, backpropagating, updating weights, and periodically evaluating and checkpointing.
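The arithmetic of one training step can be sketched without PyTorch on a toy problem: one shared parameter `w` and two hypothetical tasks whose losses are (w - target)^2. The sketch sums the per-task losses, clips the gradient by its global L2 norm (the same rescaling rule as PyTorch's `torch.nn.utils.clip_grad_norm_`, ignoring its small epsilon), applies an SGD update, and keeps the best parameters seen so far. Its "best" criterion is simply the summed training loss; all names here are illustrative.

```python
import math
from typing import Dict


def clip_by_global_norm(grads: Dict[str, float], max_norm: float) -> Dict[str, float]:
    """Rescale all gradients together if their combined L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads.values()))
    if total_norm <= max_norm:
        return dict(grads)
    scale = max_norm / total_norm
    return {name: g * scale for name, g in grads.items()}


# Toy problem: one shared parameter w and two "tasks" whose losses are (w - target)^2.
task_targets = {"task_a": 1.0, "task_b": 3.0}
params = {"w": 0.0}
lr, max_norm = 0.1, 1.0
best = {"loss": math.inf, "w": params["w"]}

for _ in range(100):
    w = params["w"]
    # Sum the per-task losses and their analytic gradients d/dw (w - t)^2 = 2 (w - t).
    loss = sum((w - t) ** 2 for t in task_targets.values())
    grads = {"w": sum(2 * (w - t) for t in task_targets.values())}
    grads = clip_by_global_norm(grads, max_norm)
    params = {name: p - lr * grads[name] for name, p in params.items()}
    # "Checkpoint" the best parameters seen so far (here: by summed training loss).
    if loss < best["loss"]:
        best = {"loss": loss, "w": w}

print(round(params["w"], 4), round(best["loss"], 4))
```

Early steps are clipped to a fixed step size because the initial gradient is large. Once w is near 2 (the minimizer of the summed loss), the unclipped updates converge geometrically.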

Key considerations:

  • The batch scheduler (sequential or shuffled) controls the order of task batches within each epoch
  • When checkpointing is enabled, the Trainer keeps the best model according to a configurable checkpoint metric (for example, a validation-split score)
  • Logging supports JSON file output or TensorBoard visualization
  • Gradient clipping (by gradient norm) helps guard against exploding gradients in deep networks
  • The optimizer choice (SGD, Adam, Adamax) and learning rate scheduler are fully configurable
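The difference between the two scheduling orders is easy to see with toy data. These generator functions are hypothetical stand-ins, not Snorkel's scheduler classes, and each dataloader is represented by a list of batch labels. In this sketch, the sequential order exhausts one dataloader before moving to the next. The shuffled order randomizes which dataloader supplies each batch but preserves the batch order within each dataloader.

```python
import random
from typing import Dict, Iterator, List, Tuple

# Each "dataloader" is represented by its name and a list of batch labels.
Loaders = Dict[str, List[str]]


def sequential_schedule(loaders: Loaders) -> Iterator[Tuple[str, str]]:
    """Yield every batch of the first loader, then the second, and so on."""
    for name, batches in loaders.items():
        for batch in batches:
            yield name, batch


def shuffled_schedule(loaders: Loaders, seed: int = 0) -> Iterator[Tuple[str, str]]:
    """Randomly interleave loaders while keeping each loader's own batch order."""
    order = [name for name, batches in loaders.items() for _ in batches]
    random.Random(seed).shuffle(order)
    iterators = {name: iter(batches) for name, batches in loaders.items()}
    for name in order:
        yield name, next(iterators[name])


loaders = {"task_a": ["a0", "a1", "a2"], "task_b": ["b0", "b1"]}
print(list(sequential_schedule(loaders)))
print(list(shuffled_schedule(loaders, seed=1)))
```

Interleaving tasks keeps a long run of one task's batches from dominating consecutive updates of the shared modules. That is a common reason to prefer a shuffled order in multi-task training.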

Step 5: Evaluate and Predict

Use the trained MultitaskClassifier to make predictions and evaluate performance. The score method computes metrics across all tasks and datasets. The predict method runs the model over a dataloader and returns gold labels and class probabilities for each task labeled in that dataloader (plus hard predictions when return_preds=True). Probabilities can also be converted to hard predictions with probs_to_preds.

Key considerations:

  • score returns a dictionary of metrics keyed by task/dataset/split/metric
  • Passing as_dataframe=True to score returns the results as a pandas DataFrame for easy inspection
  • predict returns NumPy arrays, grouped into dictionaries keyed by task name
  • Use the Scorer class to compute custom metrics beyond the defaults
  • The model supports evaluation on multiple dataloaders simultaneously
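The shape of these results can be illustrated with a plain-Python sketch. `probs_to_hard_preds`, `accuracy`, and `metric_key` are hypothetical helpers: the first takes an argmax per row and breaks ties toward the lowest index, whereas Snorkel's probs_to_preds has its own configurable tie-break policy. The sketch builds a flat "task/dataset/split/metric" key and then splits it back into columns, roughly the view a DataFrame of scores provides.

```python
from typing import Dict, List, Sequence


def probs_to_hard_preds(probs: Sequence[Sequence[float]]) -> List[int]:
    """Argmax per row; max() returns the first maximum, so ties go to the lowest index."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


def accuracy(golds: Sequence[int], preds: Sequence[int]) -> float:
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)


def metric_key(task: str, dataset: str, split: str, metric: str) -> str:
    """Build a flat 'task/dataset/split/metric' key."""
    return "/".join([task, dataset, split, metric])


probs = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]]
golds = [0, 1, 1, 1]
preds = probs_to_hard_preds(probs)
scores: Dict[str, float] = {metric_key("task_a", "toy", "valid", "accuracy"): accuracy(golds, preds)}

# A DataFrame-style view: one row per key, with the key split into its parts.
rows = [dict(zip(["task", "dataset", "split", "metric"], k.split("/")), score=v)
        for k, v in scores.items()]
print(preds, scores, rows)
```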
