Implementation:Snorkel team Snorkel SliceAwareClassifier Init
| Knowledge Sources | |
|---|---|
| Domains | Data_Slicing, Multi_Task_Learning, PyTorch |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
Concrete tool for building a slice-aware multi-task classifier with per-slice indicator and predictor heads, provided by the Snorkel library.
Description
The SliceAwareClassifier extends MultitaskClassifier to support slice-aware training. It takes a base neural network architecture and automatically creates:
- A base task with the provided architecture and prediction head
- Per-slice indicator tasks (binary: in/out of slice)
- Per-slice predictor tasks (classification on slice members)
- A master head using SliceCombinerModule for attention-based aggregation
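The master head's attention-based aggregation can be illustrated without PyTorch. The following is a simplified, hypothetical rendering of the idea, not SliceCombinerModule's actual code. It assumes that each slice contributes three things: a feature vector, an indicator probability (how likely the example is to belong to the slice), and a predictor confidence. Slices are weighted by a softmax over indicator probability times confidence, scaled by a temperature. The exact scoring rule and the `temperature` parameter are assumptions of this sketch.

```python
import math
from typing import List


def softmax(scores: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def combine_slice_features(
    slice_features: List[List[float]],   # one feature vector per slice (hypothetical input)
    indicator_probs: List[float],        # P(example belongs to slice), per slice
    predictor_confidences: List[float],  # confidence of each slice predictor head
    temperature: float = 1.0,            # assumed knob: lower values sharpen the weights
) -> List[float]:
    """Return an attention-weighted sum of per-slice feature vectors."""
    # Assumed scoring rule: a slice matters when it applies AND its predictor is confident
    scores = [p * c for p, c in zip(indicator_probs, predictor_confidences)]
    weights = softmax(scores, temperature)
    dim = len(slice_features[0])
    return [
        sum(w * feats[d] for w, feats in zip(weights, slice_features))
        for d in range(dim)
    ]


# Two slices with 3-dim features; the first slice is far more likely to apply.
combined = combine_slice_features(
    slice_features=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    indicator_probs=[0.9, 0.1],
    predictor_confidences=[1.0, 1.0],
    temperature=0.1,
)
print([round(x, 3) for x in combined])
```

Because the weights come from a softmax, the combined representation is always a convex mixture of the slice features. An example that falls mostly in one slice is therefore represented mostly by that slice's features.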
Internally, it calls convert_to_slice_tasks to build the multi-task graph from the base task.
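To make the expansion concrete, here is a pure-Python sketch that lists the task names such a graph would contain. The naming scheme (`<task>_slice:<slice>_ind` and `<task>_slice:<slice>_pred`), the extra `base` slice, and the master task reusing the base task's name reflect our reading of `convert_to_slice_tasks` and should be treated as assumptions. The helper `expected_slice_task_names` is hypothetical and builds no modules.

```python
from typing import List


def expected_slice_task_names(task_name: str, slice_names: List[str]) -> List[str]:
    """List the task names a slice-aware model is expected to expose.

    Assumptions: a "base" slice covering all examples is appended if absent,
    and each slice gets one indicator ("_ind") and one predictor ("_pred") task.
    """
    all_slices = list(slice_names) if "base" in slice_names else slice_names + ["base"]
    names = [task_name]  # master task keeps the base task's name (assumption)
    for s in all_slices:
        names.append(f"{task_name}_slice:{s}_ind")
        names.append(f"{task_name}_slice:{s}_pred")
    return names


names = expected_slice_task_names("task", ["sf_short", "sf_has_link", "sf_urgent"])
print(len(names))  # 1 master + 2 * (3 slices + base) = 9
print(names[:3])   # ['task', 'task_slice:sf_short_ind', 'task_slice:sf_short_pred']
```

The model size therefore grows linearly with the number of slices: every added slicing function contributes two extra heads.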
Usage
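Use this class when you want to train a model that is aware of critical data slices. Provide a base neural network architecture, its output dimension (head_dim), and the list of slice names. Each name must match the name of the slicing function (SF) that labels that slice.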
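Slice names are the contract between your slicing functions and the model. The stdlib sketch below mimics slicing functions with plain Python predicates. It does not use Snorkel's `@slicing_function` decorator, and `sf_short`, `sf_has_link`, and `sf_urgent` are hypothetical functions over hypothetical text records. It shows two things: deriving `slice_names` from the function names, so the two cannot drift apart, and producing the per-slice 0/1 membership labels that the indicator tasks learn to predict.

```python
from typing import Callable, Dict, List


# Hypothetical slicing functions: each returns True if the text is in its slice.
def sf_short(text: str) -> bool:
    return len(text.split()) < 5


def sf_has_link(text: str) -> bool:
    return "http" in text


def sf_urgent(text: str) -> bool:
    return "urgent" in text.lower()


SFS: List[Callable[[str], bool]] = [sf_short, sf_has_link, sf_urgent]

# Take slice names from the function names so they always match.
slice_names = [sf.__name__ for sf in SFS]


def slice_membership(texts: List[str]) -> Dict[str, List[int]]:
    """Map each slice name to 0/1 membership labels, one per text."""
    return {sf.__name__: [int(sf(t)) for t in texts] for sf in SFS}


texts = ["URGENT: reply now", "see http://example.com for the full report today"]
labels = slice_membership(texts)
print(slice_names)  # ['sf_short', 'sf_has_link', 'sf_urgent']
print(labels)       # {'sf_short': [1, 0], 'sf_has_link': [0, 1], 'sf_urgent': [1, 0]}
```

A single text can belong to several slices at once (the first text is both short and urgent), which is why each slice gets its own binary indicator task rather than sharing one multi-class head.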
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/slicing/sliceaware_classifier.py
- Lines: class definition L16-179; __init__ L46-91
Signature
```python
class SliceAwareClassifier(MultitaskClassifier):
    def __init__(
        self,
        base_architecture: nn.Module,
        head_dim: int,
        slice_names: List[str],
        input_data_key: str = "input_data",
        task_name: str = "task",
        scorer: Scorer = Scorer(metrics=["accuracy", "f1"]),
        **multitask_kwargs: Any,
    ) -> None:
        """
        Args:
            base_architecture: Shared feature extractor (nn.Module).
            head_dim: Output dimension of base_architecture.
            slice_names: List of slice names matching SF names.
            input_data_key: Key for input data in X_dict.
            task_name: Name of the base classification task.
            scorer: Evaluation scorer.
            **multitask_kwargs: Passed to MultitaskClassifier.
        """
```
Import
```python
from snorkel.slicing import SliceAwareClassifier
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_architecture | nn.Module | Yes | Shared feature extractor network |
| head_dim | int | Yes | Output dimension of base_architecture |
| slice_names | List[str] | Yes | Names of slices (must match SF names) |
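| input_data_key | str | No | Key for the input data in X_dict (default: "input_data") |
| task_name | str | No | Name of the base classification task (default: "task") |
| scorer | Scorer | No | Evaluation scorer (default: accuracy + f1) |
| `**multitask_kwargs` | Any | No | Extra keyword arguments passed to MultitaskClassifier |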
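The default scorer reports accuracy and F1. For binary labels, these metrics reduce to the arithmetic below. This is a stdlib illustration of the standard definitions, assuming label 1 is the positive class; it is not Snorkel's Scorer code.

```python
from typing import List


def accuracy(golds: List[int], preds: List[int]) -> float:
    """Fraction of predictions that match the gold labels."""
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)


def binary_f1(golds: List[int], preds: List[int], positive: int = 1) -> float:
    """Harmonic mean of precision and recall for the positive class."""
    tp = sum(g == positive and p == positive for g, p in zip(golds, preds))
    fp = sum(g != positive and p == positive for g, p in zip(golds, preds))
    fn = sum(g == positive and p != positive for g, p in zip(golds, preds))
    if tp == 0:
        return 0.0  # no true positives: precision and recall are both 0 (or undefined)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


golds = [1, 0, 1, 1, 0]
preds = [1, 0, 0, 1, 1]
print(accuracy(golds, preds))   # 0.6 (3 of 5 correct)
print(binary_f1(golds, preds))  # tp=2, fp=1, fn=1 -> precision = recall = 2/3
```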
Outputs
| Name | Type | Description |
|---|---|---|
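| model | SliceAwareClassifier | Multi-task model with slice indicator, slice predictor, and master heads |
| model.base_task | Task | The base classification task (base_architecture followed by a prediction head) from which the slice tasks are derived |
| model.slice_names | List[str] | The slice_names passed to the constructor |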
Usage Examples
Initialize Slice-Aware Classifier
```python
import torch.nn as nn

from snorkel.slicing import SliceAwareClassifier

# Define the base architecture: maps 100-dim inputs to 32-dim features
base_architecture = nn.Sequential(
    nn.Linear(100, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
)

# Initialize with slice names matching your SFs
slice_names = ["sf_short", "sf_has_link", "sf_urgent"]
model = SliceAwareClassifier(
    base_architecture=base_architecture,
    head_dim=32,  # Must equal the output dim of base_architecture
    slice_names=slice_names,
)

# Sort the names for a stable, readable printout
print(f"Tasks: {sorted(model.task_names)}")
# Includes the base task plus an indicator task and a predictor task for each slice
```
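After training, the main question is whether performance holds up on each slice, not only overall. Snorkel's SliceAwareClassifier offers a score_slices method for this; the stdlib sketch below does not use it. Instead, it illustrates the underlying idea: restrict a metric to the members of each slice. The function per_slice_accuracy and the sample data are hypothetical.

```python
from typing import Dict, List


def per_slice_accuracy(
    golds: List[int], preds: List[int], membership: Dict[str, List[int]]
) -> Dict[str, float]:
    """Accuracy overall and restricted to the members of each slice.

    Slices with no members are skipped rather than reported as 0.
    """
    results = {"overall": sum(g == p for g, p in zip(golds, preds)) / len(golds)}
    for name, mask in membership.items():
        pairs = [(g, p) for g, p, m in zip(golds, preds, mask) if m]
        if pairs:
            results[name] = sum(g == p for g, p in pairs) / len(pairs)
    return results


golds = [1, 0, 1, 1]
preds = [1, 0, 0, 0]
membership = {"sf_short": [1, 1, 0, 0], "sf_urgent": [0, 0, 1, 1], "sf_has_link": [0, 0, 0, 0]}
print(per_slice_accuracy(golds, preds, membership))
# {'overall': 0.5, 'sf_short': 1.0, 'sf_urgent': 0.0}
```

In this example, the overall accuracy of 0.5 hides a slice where the model is always wrong. Surfacing such slices is the failure mode that slice-aware training targets.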