Workflow:Snorkel team Snorkel Slice Aware Training

From Leeroopedia
Knowledge Sources
Domains Data_Slicing, Multi_Task_Learning, Model_Robustness
Last Updated 2026-02-14 20:00 GMT

Overview

End-to-end process for training a classifier that maintains high performance across critical data slices by using slice-specific indicator and predictor heads with attention-based combination.

Description

This workflow addresses the problem of models that perform well on average but fail on important subsets (slices) of the data. Users define slicing functions (SFs) that identify critical data subpopulations. These SFs are applied to produce a slice matrix indicating which data points belong to which slices. A SliceAwareClassifier extends the MultitaskClassifier with per-slice indicator heads (to detect slice membership) and per-slice predictor heads (to make predictions specialized to each slice). An attention-based SliceCombinerModule reweights the slice-specific representations to produce the final prediction, ensuring the model attends to relevant slice expertise for each data point.

Usage

Execute this workflow when your model needs to perform well on specific, identifiable subsets of data that may be underrepresented or have distinct characteristics. This is appropriate when you can define programmatic rules that identify these critical slices (e.g., short texts, rare categories, specific domains) and when aggregate metrics mask poor performance on important subpopulations. Currently supports binary classification tasks only.

Execution Steps

Step 1: Define Slicing Functions

Create slicing functions that identify critical data subsets. Each SF is a function that takes a data point and returns a boolean indicating slice membership. The SF class inherits from LabelingFunction internally, so SFs use the same decorator pattern and preprocessor infrastructure as labeling functions.

Key considerations:

  • Each SF should target a meaningful data subpopulation where model performance matters
  • SFs return boolean values (True/False) indicating slice membership
  • A data point can belong to multiple slices simultaneously
  • Use NLPSlicingFunction for text data requiring spaCy preprocessing
  • Name SFs descriptively for interpretable monitoring
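
The boolean contract described above can be sketched with plain Python predicates. This is a minimal illustration, not Snorkel code: in Snorkel these functions would be wrapped with the library's slicing-function decorator. The record fields (text, category), the five-token threshold, and the category set are assumptions chosen for the example.

```python
from typing import Callable, Dict

# Hypothetical record type: a dict with "text" and "category" keys.
Record = Dict[str, str]


def short_text(x: Record) -> bool:
    """In-slice when the text has fewer than 5 whitespace-separated tokens."""
    return len(x["text"].split()) < 5


def rare_category(x: Record) -> bool:
    """In-slice when the category is one of a few (assumed) rare categories."""
    return x["category"] in {"legal", "medical"}


# Descriptive names double as slice names for later monitoring.
SLICING_FUNCTIONS: Dict[str, Callable[[Record], bool]] = {
    sf.__name__: sf for sf in (short_text, rare_category)
}

# A single data point can belong to several slices at once.
example = {"text": "refund please", "category": "legal"}
membership = {name: sf(example) for name, sf in SLICING_FUNCTIONS.items()}
```

Keeping each predicate small and named after the subpopulation it targets makes the per-slice metrics in later steps easier to read.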

Step 2: Apply Slicing Functions

Execute all slicing functions across the dataset using PandasSFApplier (or the Dask/Spark variants) to produce a slice matrix S. S is a NumPy recarray with one record per data point and one named field per SF, so it can be read as an [n_datapoints, n_slices] matrix in which each column is named after its SF and holds binary membership indicators.

Key considerations:

  • The applier reuses the same infrastructure as the LF applier
  • The output is a numpy recarray with named fields corresponding to SF names
  • A "base" slice is automatically added to represent the full dataset
  • The slice matrix will be used to create slice-specific labels in the dataloader
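
What the applier computes can be sketched without the library. The example below stands in for the recarray with a dict that maps each SF name to a column of 0/1 indicators; apply_slicing_functions is a hypothetical helper, and the records and SFs are illustrative assumptions.

```python
from typing import Callable, Dict, List

Record = Dict[str, str]


def short_text(x: Record) -> bool:
    return len(x["text"].split()) < 5


def rare_category(x: Record) -> bool:
    return x["category"] in {"legal", "medical"}


def apply_slicing_functions(
    sfs: List[Callable[[Record], bool]], records: List[Record]
) -> Dict[str, List[int]]:
    """Return one named column per SF, with 1 for in-slice and 0 otherwise."""
    return {sf.__name__: [int(bool(sf(x))) for x in records] for sf in sfs}


records = [
    {"text": "refund please", "category": "billing"},
    {"text": "the contract clause on liability needs a second review", "category": "legal"},
    {"text": "thanks", "category": "medical"},
]

# Columns are addressed by SF name, mirroring the named fields of the recarray.
S = apply_slicing_functions([short_text, rare_category], records)
```

The third record appears in both columns, which reflects that slices may overlap.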

Step 3: Initialize SliceAwareClassifier

Create a SliceAwareClassifier by providing a base neural network architecture, the output dimension of that architecture, and the list of slice names. The classifier automatically constructs slice-specific tasks: for each slice, it creates an indicator task (binary: in-slice or not) and a predictor task (the main classification, trained only on in-slice examples). A master head with a SliceCombinerModule uses attention to combine slice-specific representations.

Key considerations:

  • The base_architecture module must output a fixed-size representation vector
  • head_dim must match the output dimensionality of base_architecture
  • The classifier automatically adds a "base" slice for the overall dataset
  • Each slice generates two tasks: an indicator task (uses F1 metric) and a predictor task
  • The master head uses learned attention weights to combine slice representations
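
The attention-based combination can be illustrated with a small pure-Python sketch. It assumes that each slice's attention score is the product of its indicator probability and its predictor confidence, normalized with a softmax; the exact scoring used by SliceCombinerModule may differ, and combine_slices and its temperature parameter are hypothetical names.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def combine_slices(
    indicator_probs: List[float],
    predictor_confidences: List[float],
    representations: List[List[float]],
    temperature: float = 1.0,
) -> List[float]:
    """Attention-weighted sum of per-slice representation vectors."""
    # Assumed scoring: slices that likely contain the example and whose
    # predictor is confident receive more attention.
    scores = [p * c / temperature for p, c in zip(indicator_probs, predictor_confidences)]
    weights = softmax(scores)
    dim = len(representations[0])
    return [
        sum(w * rep[d] for w, rep in zip(weights, representations))
        for d in range(dim)
    ]


# Two slices with identical scores contribute equally.
combined = combine_slices([0.5, 0.5], [0.8, 0.8], [[1.0, 0.0], [0.0, 1.0]])
# A slice with a higher indicator probability dominates the mix.
skewed = combine_slices([0.9, 0.1], [0.9, 0.9], [[1.0, 0.0], [0.0, 1.0]])
```

The combined vector is what a master head would then classify, so every example is scored with a mix of the slice experts most relevant to it.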

Step 4: Prepare Slice Aware DataLoaders

Convert standard DictDatasets into slice-aware DictDataLoaders using the make_slice_dataloader method. This augments the dataset with per-slice indicator labels and per-slice prediction labels. Indicator labels mark whether each example is in each slice. Prediction labels are the original labels masked to -1 for out-of-slice examples, so slice predictors only train on relevant data.

Key considerations:

  • The base task labels must already exist in the dataset Y_dict
  • The slice matrix S from Step 2 is used to create per-slice labels
  • Out-of-slice prediction labels are set to -1 (ignored during training)
  • Create separate dataloaders for train, validation, and test splits
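
The label augmentation can be sketched as follows. This is a simplified stand-in for what make_slice_dataloader does to Y_dict, assuming list-based labels; the helper name make_slice_labels and the key naming scheme are illustrative assumptions.

```python
from typing import Dict, List


def make_slice_labels(
    labels: List[int], slice_matrix: Dict[str, List[int]], task_name: str = "task"
) -> Dict[str, List[int]]:
    """Add per-slice indicator and masked prediction labels to a label dict."""
    y_dict = {task_name: list(labels)}
    for slice_name, membership in slice_matrix.items():
        # Indicator labels: 1 if the example is in the slice, else 0.
        y_dict[f"{task_name}_slice:{slice_name}_ind"] = list(membership)
        # Prediction labels: original label in-slice, -1 (ignored) out-of-slice.
        y_dict[f"{task_name}_slice:{slice_name}_pred"] = [
            label if in_slice else -1 for label, in_slice in zip(labels, membership)
        ]
    return y_dict


labels = [1, 0, 1]
S = {"short_text": [1, 0, 1], "base": [1, 1, 1]}
Y = make_slice_labels(labels, S)
```

Because the base slice covers every example, its prediction labels equal the original labels, while the short_text predictor only sees the first and third examples.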

Step 5: Train the Slice Aware Model

Train the SliceAwareClassifier using the Trainer class with the slice-aware dataloaders. The Trainer handles the multi-task training loop, cycling through all slice tasks (indicators and predictors) plus the master head. It supports checkpointing, logging, learning rate scheduling, and gradient clipping.

Key considerations:

  • Use the standard Trainer with TrainerConfig for training configuration
  • All slice tasks are trained jointly in a multi-task fashion
  • The batch scheduler controls the order of task sampling (sequential or shuffled)
  • Monitor both per-slice and overall metrics during training
  • Configure checkpointing to save the best model based on validation performance
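
The difference between sequential and shuffled scheduling can be sketched as below. This is an illustration of the idea only, not the Trainer's scheduler: schedule_batches is a hypothetical helper, the source names are made up, and the sketch assumes that shuffling changes which source each batch is drawn from while preserving batch order within a source.

```python
import random
from typing import Dict, List, Tuple


def schedule_batches(
    num_batches: Dict[str, int], mode: str = "shuffled", seed: int = 0
) -> List[Tuple[str, int]]:
    """Return (source_name, batch_index) pairs in the order they would be trained on."""
    if mode not in ("sequential", "shuffled"):
        raise ValueError(f"unknown scheduler mode: {mode}")
    # One entry per batch, grouped by source.
    names = [name for name, n in num_batches.items() for _ in range(n)]
    if mode == "shuffled":
        random.Random(seed).shuffle(names)
    # Draw the next unused batch from whichever source comes up.
    counters = {name: 0 for name in num_batches}
    order = []
    for name in names:
        order.append((name, counters[name]))
        counters[name] += 1
    return order


sizes = {"train_a": 2, "train_b": 1}
sequential = schedule_batches(sizes, mode="sequential")
shuffled = schedule_batches(sizes, mode="shuffled", seed=0)
```

Both modes visit exactly the same batches once per epoch; only the interleaving differs.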

Step 6: Evaluate Slice Performance

Use the score_slices method to evaluate model performance on each slice independently. For evaluation, the method maps each slice's prediction labels onto the base task head, so per-slice accuracy and F1 are computed from the single master prediction head rather than from the individual slice heads.

Key considerations:

  • score_slices uses the master head (not individual slice heads) for all predictions
  • Results can be returned as a dictionary or pandas DataFrame
  • Monitor slice-specific metrics to ensure no critical slice is underperforming
  • Compare against a non-slice-aware baseline to quantify improvement
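
Per-slice scoring of a single set of predictions can be sketched in plain Python. This is a simplified illustration of the idea behind score_slices, assuming binary labels with 1 as the positive class; score_by_slice and binary_f1 are hypothetical helpers and the data is made up.

```python
from typing import Dict, List


def binary_f1(golds: List[int], preds: List[int]) -> float:
    """F1 for the positive class (label 1); 0.0 when undefined."""
    tp = sum(1 for g, p in zip(golds, preds) if g == 1 and p == 1)
    fp = sum(1 for g, p in zip(golds, preds) if g == 0 and p == 1)
    fn = sum(1 for g, p in zip(golds, preds) if g == 1 and p == 0)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def score_by_slice(
    golds: List[int], preds: List[int], slice_matrix: Dict[str, List[int]]
) -> Dict[str, Dict[str, float]]:
    """Score the same predictions restricted to the members of each slice."""
    results = {}
    for name, membership in slice_matrix.items():
        g = [y for y, m in zip(golds, membership) if m]
        p = [y for y, m in zip(preds, membership) if m]
        if not g:
            continue  # skip empty slices rather than report undefined metrics
        accuracy = sum(1 for a, b in zip(g, p) if a == b) / len(g)
        results[name] = {"accuracy": accuracy, "f1": binary_f1(g, p), "n": len(g)}
    return results


golds = [1, 0, 1, 1]
preds = [1, 0, 0, 1]  # predictions from the single master head
S = {"short_text": [1, 0, 1, 0], "base": [1, 1, 1, 1]}
scores = score_by_slice(golds, preds, S)
```

Here the base slice reports 0.75 accuracy while short_text reports 0.5, the kind of gap that an aggregate metric alone would hide.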
