
Implementation:Snorkel team Snorkel SliceAwareClassifier Init

Domains Data_Slicing, Multi_Task_Learning, PyTorch
Last Updated 2026-02-14 20:00 GMT

Overview

Concrete tool for building a slice-aware multi-task classifier with per-slice indicator and predictor heads, provided by the Snorkel library.

Description

The SliceAwareClassifier extends MultitaskClassifier to support slice-aware training. It takes a base neural network architecture and automatically creates:

  • A base task with the provided architecture and prediction head
  • Per-slice indicator tasks (binary: in/out of slice)
  • Per-slice predictor tasks (classification on slice members)
  • A master head using SliceCombinerModule for attention-based aggregation
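The attention-based aggregation in the master head can be pictured with a small pure-Python sketch for a single example. It assumes that each slice's attention score is the indicator head's in-slice probability multiplied by the slice predictor's confidence, and that the scores are turned into weights with a temperature-scaled softmax. The real SliceCombinerModule operates on batched PyTorch tensors and its exact scoring may differ. The function names here are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(scores: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over scores scaled by 1/temperature."""
    scaled = [s / temperature for s in scores]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def combine_slice_representations(
    slice_reps: Sequence[Sequence[float]],
    indicator_probs: Sequence[float],
    predictor_confidences: Sequence[float],
    temperature: float = 1.0,
) -> List[float]:
    """Hypothetical: attention-weighted sum of per-slice representations for ONE example.

    Assumed scoring rule: score_k = P(example in slice k) * confidence of predictor k.
    """
    if not (len(slice_reps) == len(indicator_probs) == len(predictor_confidences)):
        raise ValueError("need one representation, indicator prob and confidence per slice")
    scores = [p * c for p, c in zip(indicator_probs, predictor_confidences)]
    weights = softmax(scores, temperature)  # attention weights, sum to 1
    dim = len(slice_reps[0])
    # Weighted sum over slices, computed feature by feature
    return [sum(w * rep[d] for w, rep in zip(weights, slice_reps)) for d in range(dim)]


# Two slices with 2-d representations; the first slice is far more relevant here
combined = combine_slice_representations(
    slice_reps=[[1.0, 0.0], [0.0, 1.0]],
    indicator_probs=[0.9, 0.1],
    predictor_confidences=[0.8, 0.8],
)
print(combined)
```

Lowering the temperature sharpens the weights toward the single most relevant slice; raising it moves the combination toward a plain average.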

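Internally, it wraps base_architecture and a linear prediction head, nn.Linear(head_dim, 2), into a base Task, then calls convert_to_slice_tasks to expand that task into the multi-task graph. Because the prediction head has two outputs, the classifier currently supports binary classification only.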
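To see roughly what graph convert_to_slice_tasks produces, the hypothetical helper below lists the task names a slice-aware model is expected to contain. The naming scheme is an assumption modeled on Snorkel 0.9.x conventions: a "base" slice covering every example is appended if absent, each slice gets "<task>_slice:<slice>_ind" and "<task>_slice:<slice>_pred" tasks, and the master head keeps the base task's name. Compare its output with model.task_names on your installed version before relying on it.

```python
from typing import List


def expected_slice_task_names(task_name: str, slice_names: List[str]) -> List[str]:
    """Hypothetical: list the tasks a slice-aware model is expected to contain.

    Naming scheme is an assumption; verify against model.task_names.
    """
    names = list(slice_names)
    if "base" not in names:
        names.append("base")  # assumed: a slice covering all examples is added
    tasks: List[str] = []
    for slice_name in names:
        tasks.append(f"{task_name}_slice:{slice_name}_ind")   # indicator: in/out of slice
        tasks.append(f"{task_name}_slice:{slice_name}_pred")  # predictor: label on slice members
    tasks.append(task_name)  # master head, assumed to reuse the base task name
    return tasks


print(expected_slice_task_names("task", ["sf_short", "sf_has_link", "sf_urgent"]))
```

With three user-defined slices plus the assumed base slice, this yields eight slice tasks and one master task.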

Usage

Import this class when you want to train a model that is aware of critical data slices. Provide a base neural network architecture, its output dimension (head_dim), and the list of slice names, which must match the names of your slicing functions (SFs).
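One way to keep slice_names consistent with the SFs is to derive the names from the functions themselves. The sketch below uses plain Python predicates as hypothetical stand-ins for SFs (in Snorkel, SFs are normally declared with the slicing_function decorator and applied with an applier such as PandasSFApplier) and builds the 0/1 slice-membership matrix that the indicator heads learn to predict.

```python
from typing import Callable, Dict, List

Example = Dict[str, str]


# Hypothetical slicing functions, written as plain predicates for illustration
def sf_short(x: Example) -> bool:
    return len(x["text"].split()) < 5


def sf_has_link(x: Example) -> bool:
    return "http" in x["text"]


def sf_urgent(x: Example) -> bool:
    return "urgent" in x["text"].lower()


SFS: List[Callable[[Example], bool]] = [sf_short, sf_has_link, sf_urgent]

# Derive slice_names from the SFs so the two lists cannot drift apart
slice_names = [sf.__name__ for sf in SFS]


def slice_membership(
    examples: List[Example], sfs: List[Callable[[Example], bool]]
) -> List[List[int]]:
    """Return an n_examples x n_slices matrix of 0/1 slice-membership labels."""
    return [[int(sf(x)) for sf in sfs] for x in examples]


examples = [
    {"text": "URGENT: see http://example.com"},
    {"text": "a perfectly ordinary message with no special features"},
]
print(slice_names)
print(slice_membership(examples, SFS))
```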

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/slicing/sliceaware_classifier.py
  • Lines: L16-91 (class L16-179, __init__ L46-91)

Signature

from typing import Any, List

import torch.nn as nn
from snorkel.analysis import Scorer
from snorkel.classification import MultitaskClassifier

class SliceAwareClassifier(MultitaskClassifier):
    def __init__(
        self,
        base_architecture: nn.Module,
        head_dim: int,
        slice_names: List[str],
        input_data_key: str = "input_data",
        task_name: str = "task",
        scorer: Scorer = Scorer(metrics=["accuracy", "f1"]),
        **multitask_kwargs: Any,
    ) -> None:
        """
        Args:
            base_architecture: Shared feature extractor (nn.Module).
            head_dim: Output dimension of base_architecture; also the input
                dimension of the internal prediction head, nn.Linear(head_dim, 2).
            slice_names: Slice names; must match the slicing function (SF) names.
            input_data_key: Key of the input data in X_dict.
            task_name: Name of the base classification task.
            scorer: Scorer used to evaluate the tasks.
            **multitask_kwargs: Extra keyword arguments passed to MultitaskClassifier.
        """

Import

from snorkel.slicing import SliceAwareClassifier

I/O Contract

Inputs

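Name | Type | Required | Description
base_architecture | nn.Module | Yes | Shared feature extractor network
head_dim | int | Yes | Output dimension of base_architecture
slice_names | List[str] | Yes | Names of slices (must match SF names)
input_data_key | str | No | Key of the input data in X_dict (default: "input_data")
task_name | str | No | Name of the base classification task (default: "task")
scorer | Scorer | No | Evaluation scorer (default: accuracy + f1)
**multitask_kwargs | Any | No | Extra keyword arguments passed to MultitaskClassifier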
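Because a mismatch between slice_names and the SF names typically surfaces only later, when slice labels are looked up, a cheap pre-flight check can help. The helper below is hypothetical and not part of Snorkel; it assumes you have the list of SF names at hand and checks only the constraints stated in the table above plus basic sanity (positive head_dim, no duplicates).

```python
from typing import Iterable, List


def check_slice_args(head_dim: int, slice_names: List[str], sf_names: Iterable[str]) -> None:
    """Hypothetical pre-flight check run before constructing SliceAwareClassifier."""
    if head_dim <= 0:
        raise ValueError(f"head_dim must be positive, got {head_dim}")
    duplicates = sorted({s for s in slice_names if slice_names.count(s) > 1})
    if duplicates:
        raise ValueError(f"duplicate slice names: {duplicates}")
    missing = sorted(set(slice_names) - set(sf_names))
    if missing:
        raise ValueError(f"slice names with no matching SF: {missing}")


# Passes: every slice name has a matching SF
check_slice_args(32, ["sf_short", "sf_urgent"], ["sf_short", "sf_has_link", "sf_urgent"])
```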

Outputs

Name | Type | Description
(constructed instance) | SliceAwareClassifier | Multi-task model with slice indicator, predictor, and master heads
base_task | Task | The base classification task, stored as an attribute
slice_names | List[str] | The slice names, stored as an attribute

Usage Examples

Initialize Slice-Aware Classifier

import torch.nn as nn
from snorkel.slicing import SliceAwareClassifier

# Define base architecture
base_architecture = nn.Sequential(
    nn.Linear(100, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
)

# Initialize with slice names matching your SFs
slice_names = ["sf_short", "sf_has_link", "sf_urgent"]

model = SliceAwareClassifier(
    base_architecture=base_architecture,
    head_dim=32,  # Output dim of base_architecture
    slice_names=slice_names,
)

print(f"Tasks: {sorted(model.task_names)}")
# Lists the master task plus one indicator and one predictor task per slice
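head_dim must equal the output dimension of base_architecture. With PyTorch installed, the most direct check is a dummy forward pass such as base_architecture(torch.randn(1, 100)).shape[-1]. The dependency-free sketch below instead follows the feature dimension through a layer stack described as (in_features, out_features) tuples, with None standing for shape-preserving activations like ReLU; this layer encoding and the function name are hypothetical.

```python
from typing import Optional, Sequence, Tuple

# (in_features, out_features) for a linear layer; None for a shape-preserving activation
Layer = Optional[Tuple[int, int]]


def infer_output_dim(input_dim: int, layers: Sequence[Layer]) -> int:
    """Follow the feature dimension through a stack of layers, checking each link."""
    dim = input_dim
    for i, layer in enumerate(layers):
        if layer is None:
            continue  # e.g. ReLU: output has the same size as its input
        in_features, out_features = layer
        if in_features != dim:
            raise ValueError(f"layer {i} expects {in_features} features but receives {dim}")
        dim = out_features
    return dim


# Mirrors the example above: Linear(100, 64) -> ReLU -> Linear(64, 32)
base_layers = [(100, 64), None, (64, 32)]
head_dim = infer_output_dim(100, base_layers)
print(head_dim)  # 32: the value to pass as head_dim
```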

