Heuristic:Snorkel team Snorkel Binary Only Slicing
| Knowledge Sources | |
|---|---|
| Domains | Slicing, Multi_Task_Learning |
| Last Updated | 2026-02-14 21:00 GMT |
Overview
The slice-aware classification system (SliceAwareClassifier and SliceCombinerModule) supports binary classification only. Multi-class slicing is not supported. SliceCombinerModule raises NotImplementedError when its predictor heads do not produce exactly two output columns.
Description
Snorkel's slice-aware classification pipeline hardcodes a binary output dimension in two places. First, the SliceAwareClassifier prediction head is `nn.Linear(head_dim, 2)`. Second, SliceCombinerModule checks that predictor outputs have shape `[..., 2]`. Because the head size is fixed, SliceAwareClassifier cannot be set up for more than 2 classes. If SliceCombinerModule receives predictor outputs with more (or fewer) than 2 columns, it raises a `NotImplementedError`.
Usage
Be aware of this limitation when designing slice-aware models. If your task has more than 2 classes, you cannot use the SliceAwareClassifier. Instead, use the base MultitaskClassifier directly and implement your own slice combination logic.
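Here is one possible starting point for your own combination logic. It is a minimal sketch under stated assumptions, not part of Snorkel. The function `combine_multiclass_slices` is hypothetical and works on plain NumPy arrays rather than a MultitaskClassifier's outputs. It assumes you already have two inputs for each slice: a membership probability and K-class logits from a slice-specific head. Each slice gets a softmax attention score equal to membership times a K-class confidence (top-1 minus top-2 probability). The per-slice class distributions are then mixed using those scores. Snorkel's combiner weights per-slice prediction *features* rather than class distributions, so treat this as a simplification.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def combine_multiclass_slices(
    membership: np.ndarray,    # [N, S] probability that each example belongs to each slice
    slice_logits: np.ndarray,  # [N, S, K] per-slice K-class logits (hypothetical heads)
    tau: float = 1.0,          # softmax temperature over slices (assumed hyperparameter)
) -> np.ndarray:
    """Hypothetical K-class slice combiner; returns [N, K] class probabilities."""
    probs = softmax(slice_logits, axis=-1)           # [N, S, K]
    top2 = np.sort(probs, axis=-1)[..., -2:]         # two largest probabilities per slice
    confidence = top2[..., 1] - top2[..., 0]         # margin in [0, 1], defined for any K >= 2
    attention = softmax(membership * confidence / tau, axis=1)  # [N, S], sums to 1 per example
    return np.einsum("ns,nsk->nk", attention, probs)  # attention-weighted mix of distributions


rng = np.random.default_rng(0)
out = combine_multiclass_slices(rng.random((4, 3)), rng.normal(size=(4, 3, 5)))
print(out.shape)  # (4, 5)
```

The top-1 minus top-2 margin keeps the confidence well defined for any number of classes. With K = 2 it reduces to |p1 - p0|.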
The Insight (Rule of Thumb)
- Action: Only use SliceAwareClassifier and SliceCombinerModule for binary classification tasks.
- Value: Output dimension is hardcoded to 2 in the prediction head.
- Trade-off: Multi-class support is not available. For multi-class tasks, fall back to the base MultitaskClassifier without slice-aware combination.
- Workaround: For multi-class problems, you can still use SlicingFunctions for analysis (via `score_slices`), but cannot use the attention-based slice combination mechanism.
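The analysis-only workaround scores each slice on just the examples it covers. The sketch below is a plain-NumPy stand-in for `score_slices` to show that idea. The helper `per_slice_accuracy` and the slice names are hypothetical, not Snorkel APIs. It assumes a 0/1 slice-membership matrix `S` of shape [N, num_slices], such as the result of applying slicing functions, plus integer gold labels and predictions with any number of classes.

```python
import numpy as np


def per_slice_accuracy(S, golds, preds, slice_names):
    """Hypothetical helper: accuracy of multi-class `preds` within each slice.

    S: [N, num_slices] 0/1 membership matrix (1 = example is in the slice).
    golds, preds: [N] integer class labels (any number of classes).
    """
    S = np.asarray(S, dtype=bool)
    golds = np.asarray(golds)
    preds = np.asarray(preds)
    scores = {"overall": float(np.mean(golds == preds))}
    for j, name in enumerate(slice_names):
        mask = S[:, j]
        # An empty slice gets NaN rather than a misleading score.
        scores[name] = float(np.mean(golds[mask] == preds[mask])) if mask.any() else float("nan")
    return scores


S = [[1, 0], [1, 1], [0, 1], [0, 0]]
golds = [0, 2, 1, 2]
preds = [0, 1, 1, 2]
print(per_slice_accuracy(S, golds, preds, ["short_text", "has_link"]))
# {'overall': 0.75, 'short_text': 0.5, 'has_link': 0.5}
```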
Reasoning
The slice combination mechanism uses an attention-weighted sum of per-slice prediction features to produce a final prediction. The current implementation assumes binary classification in the attention weight computation and the final output shape. This is documented as a known limitation in the code comments.
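The following simplified NumPy sketch illustrates what a binary assumption in the attention computation can look like. It is not Snorkel's code. The function name, the temperature `tau`, and the confidence score are all illustrative assumptions. Here the confidence is the distance of the positive-class probability from 0.5. A confidence read off two fixed columns like this has no direct meaning for outputs with more than two classes.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along `axis`."""
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def combine_binary_slice_features(
    membership: np.ndarray,   # [N, S] slice-membership probabilities (indicator heads)
    pred_logits: np.ndarray,  # [N, S, 2] per-slice binary predictor logits
    features: np.ndarray,     # [N, S, D] per-slice prediction features
    tau: float = 1.0,         # softmax temperature over slices (assumed)
) -> np.ndarray:
    """Illustrative binary slice combiner; returns an [N, D] reweighted representation."""
    if pred_logits.shape[-1] != 2:
        # This confidence score is only defined for exactly two columns.
        raise NotImplementedError("binary predictor outputs required")
    p_pos = softmax(pred_logits, axis=-1)[..., 1]  # [N, S] positive-class probability
    confidence = np.abs(p_pos - 0.5) * 2           # 0 = unsure, 1 = certain (binary-only)
    attention = softmax(membership * confidence / tau, axis=1)  # [N, S], sums to 1 over slices
    return (attention[..., None] * features).sum(axis=1)        # weighted sum over slices


rng = np.random.default_rng(0)
rep = combine_binary_slice_features(
    membership=rng.random((2, 3)),
    pred_logits=rng.normal(size=(2, 3, 2)),
    features=rng.normal(size=(2, 3, 4)),
)
print(rep.shape)  # (2, 4)
```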
Code evidence from `sliceaware_classifier.py:19,61`:
```python
NOTE: This model currently only supports binary classification.
...
"prediction_head": nn.Linear(head_dim, 2),
```
Code evidence from `slice_combiner.py:19,94-101`:
```python
NOTE: This module currently only handles binary labels.
...
if predictor_outputs[0].shape[1] > 2:
    raise NotImplementedError(
        "SliceCombiner does not support more than 2 classes yet."
    )
elif predictor_outputs[0].shape[1] < 2:
    raise NotImplementedError(
        "SliceCombiner currently requires output shape [..., 2] for predictor heads."
    )
```
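If you wire custom predictor heads into SliceCombinerModule, a pre-flight check can surface this error before training starts. The helper `assert_binary_predictor_outputs` below is hypothetical and not part of Snorkel. It re-implements the two-column condition from the excerpt above for a list of array-like outputs. The excerpt inspects only `predictor_outputs[0]`, whereas this helper checks every head.

```python
import numpy as np


def assert_binary_predictor_outputs(predictor_outputs):
    """Hypothetical pre-flight check mirroring SliceCombinerModule's shape guard."""
    for i, out in enumerate(predictor_outputs):
        n_cols = np.shape(out)[1]
        if n_cols != 2:
            raise NotImplementedError(
                f"predictor head {i} has {n_cols} output columns; slicing requires exactly 2"
            )


assert_binary_predictor_outputs([np.zeros((8, 2)), np.zeros((8, 2))])  # passes silently
try:
    assert_binary_predictor_outputs([np.zeros((8, 3))])
except NotImplementedError as err:
    print(err)  # predictor head 0 has 3 output columns; slicing requires exactly 2
```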