Implementation:Snorkel team Snorkel SliceAwareClassifier Make Slice Dataloader
| Knowledge Sources | |
|---|---|
| Domains | Data_Slicing, Data_Preparation |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
Concrete tool for creating slice-aware dataloaders with indicator and prediction labels from a dataset and slice matrix, provided by the Snorkel library.
Description
The SliceAwareClassifier.make_slice_dataloader() method combines a DictDataset with a slice matrix (from PandasSFApplier) to produce a DictDataLoader augmented with slice-specific labels. Internally, it calls add_slice_labels(), which modifies the dataset's Y_dict in place, adding two label tensors per slice:
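- task_slice:{name}_ind - indicator labels marking whether each example belongs to the slice
- task_slice:{name}_pred - prediction labels: the base task labels, masked for examples outside the slice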
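As a rough illustration of what these two label sets contain, the sketch below derives them for a single slice, with NumPy arrays standing in for the torch tensors stored in Y_dict. The helper name build_slice_labels and the mask value of -1 for out-of-slice examples are assumptions made for this sketch; it is not a copy of add_slice_labels().

```python
import numpy as np

MASK_VALUE = -1  # assumed marker for "ignore this example"; not stated in the text above


def build_slice_labels(labels, indicators, task_name, slice_name):
    """Return {key: array} with indicator and masked prediction labels for one slice."""
    labels = np.asarray(labels)
    ind_labels = np.asarray(indicators).astype(np.int64)

    # Prediction labels copy the base task labels, masked outside the slice.
    pred_labels = labels.copy()
    pred_labels[ind_labels == 0] = MASK_VALUE

    prefix = f"{task_name}_slice:{slice_name}"
    return {f"{prefix}_ind": ind_labels, f"{prefix}_pred": pred_labels}


y = np.array([1, 0, 1, 1])         # base task labels
in_slice = np.array([1, 0, 0, 1])  # 1 where the example belongs to the slice
slice_labels = build_slice_labels(y, in_slice, task_name="task", slice_name="sf_short")
print(slice_labels["task_slice:sf_short_pred"])
```

The base labels are copied rather than edited, so the original task labels stay available to the base task while each slice gets its own masked view.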
Usage
Use this method after initializing a SliceAwareClassifier and obtaining slice matrices. Create a separate dataloader for each of the train/valid/test splits, passing the slice matrix computed on that split.
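One way to keep the per-split calls consistent is a small wrapper such as the sketch below. The helper make_split_dataloaders and its dictionary arguments are hypothetical and not part of Snorkel; the only Snorkel call it relies on is model.make_slice_dataloader(dataset=..., S=..., **dataloader_kwargs) as documented on this page.

```python
from typing import Any, Dict


def make_split_dataloaders(
    model: Any,
    datasets: Dict[str, Any],
    slice_matrices: Dict[str, Any],
    batch_size: int = 32,
) -> Dict[str, Any]:
    """Build one slice-aware dataloader per split (hypothetical helper)."""
    # Every dataset needs the slice matrix computed on the same split.
    missing = set(datasets) - set(slice_matrices)
    if missing:
        raise KeyError(f"No slice matrix for splits: {sorted(missing)}")

    dataloaders = {}
    for split, dataset in datasets.items():
        dataloaders[split] = model.make_slice_dataloader(
            dataset=dataset,
            S=slice_matrices[split],
            batch_size=batch_size,
            shuffle=(split == "train"),  # shuffle only the training split
        )
    return dataloaders
```

With datasets keyed by "train", "valid", and "test" and slice matrices under the same keys, the result is a dictionary of dataloaders keyed the same way.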
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/slicing/sliceaware_classifier.py (make_slice_dataloader L93-125), snorkel/slicing/utils.py (add_slice_labels L15-55)
Signature
# snorkel/slicing/sliceaware_classifier.py
class SliceAwareClassifier(MultitaskClassifier):
    def make_slice_dataloader(
        self,
        dataset: DictDataset,
        S: np.recarray,
        **dataloader_kwargs: Any,
    ) -> DictDataLoader:
        """
        Create DictDataLoader with slice labels.

        Args:
            dataset: Base DictDataset with features and labels.
            S: [n_examples, n_slices] slice recarray from PandasSFApplier.
            **dataloader_kwargs: Passed to DictDataLoader (batch_size, shuffle, etc.).

        Returns:
            DictDataLoader with slice-specific labels added to Y_dict.
        """

# snorkel/slicing/utils.py
def add_slice_labels(
    dataloader: DictDataLoader,
    base_task: Task,
    S: np.recarray,
) -> None:
    """
    Modify dataloader in-place, adding indicator and prediction labels per slice.
    """
Import
from snorkel.slicing import SliceAwareClassifier
from snorkel.slicing.utils import add_slice_labels
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataset | DictDataset | Yes | Base dataset with X_dict and Y_dict |
| S | np.recarray | Yes | Slice matrix from PandasSFApplier |
| **dataloader_kwargs | Any | No | batch_size, shuffle, num_workers, etc. |
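To make the shape of S more concrete, the sketch below builds a small recarray by hand, assuming the slice matrix carries one named integer field per slicing function and one record per example. The slice names sf_short and sf_question and their values are invented for illustration; in practice S would come from PandasSFApplier rather than being constructed manually.

```python
import numpy as np

# Two hypothetical slicing functions evaluated on four examples.
sf_short = np.array([1, 0, 0, 1])
sf_question = np.array([0, 0, 1, 1])

# One named integer field per slice; each record corresponds to one example.
S = np.rec.fromarrays([sf_short, sf_question], names=["sf_short", "sf_question"])

slice_names = S.dtype.names  # field names, one per slice
print(slice_names)
print(S["sf_short"])         # membership column for a single slice
print(len(S))                # number of examples
```

Looking up a column by field name is what makes it possible to derive a label key per slice from S alone.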
Outputs
| Name | Type | Description |
|---|---|---|
| dataloader | DictDataLoader | Augmented dataloader with slice indicator/prediction labels in Y_dict |
Usage Examples
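import pandas as pd
import torch
from snorkel.classification import DictDataset
from snorkel.slicing import PandasSFApplier, SliceAwareClassifier, slicing_function

# Create base dataset; the "task" key holds the base task labels
X_dict = {"input_data": torch.randn(100, 50)}
Y_dict = {"task": torch.randint(0, 2, (100,))}
dataset = DictDataset(name="train_data", split="train", X_dict=X_dict, Y_dict=Y_dict)

# S from PandasSFApplier (illustrative slicing function and DataFrame, one row per example)
@slicing_function()
def sf_short(x):
    return len(x.text) < 10

df_train = pd.DataFrame({"text": ["short", "a noticeably longer text"] * 50})
S_train = PandasSFApplier([sf_short]).apply(df_train)

# model is a SliceAwareClassifier instance (illustrative architecture)
model = SliceAwareClassifier(
    base_architecture=torch.nn.Linear(50, 16), head_dim=16, slice_names=["sf_short"]
)
train_dl = model.make_slice_dataloader(
    dataset=dataset,
    S=S_train,
    batch_size=32,
    shuffle=True,
)
# The dataloader now contains labels like:
# "task_slice:sf_short_ind", "task_slice:sf_short_pred", etc.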