Heuristic: Snorkel MultitaskClassifier DataParallel Default Behavior
| Knowledge Sources | |
|---|---|
| Domains | Multi_Task_Learning, Optimization |
| Last Updated | 2026-02-14 21:00 GMT |
Overview
The MultitaskClassifier enables `nn.DataParallel` by default, wrapping all module pool entries. This causes complications when accessing module attributes (e.g., `in_features`) and requires unwrapping in downstream code like slicing utilities.
Description
When `dataparallel=True` (the default), every module added to the MultitaskClassifier's module pool is wrapped in `nn.DataParallel`. This enables multi-GPU training but introduces a layer of indirection: the actual module is only reachable through the wrapper's `.module` attribute. The slicing utilities must therefore explicitly unwrap DataParallel modules to read `in_features` and `out_features` when constructing per-slice prediction heads.
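The indirection can be seen in a minimal PyTorch sketch, independent of Snorkel's actual classes. `nn.DataParallel` does not proxy arbitrary attributes of the wrapped module, so a plain attribute read fails until you go through `.module`:

```python
import torch.nn as nn

# A plain module exposes its attributes directly.
linear = nn.Linear(in_features=64, out_features=2)
assert linear.in_features == 64

# After wrapping, the attribute lives only on the inner module.
wrapped = nn.DataParallel(linear)
try:
    wrapped.in_features
    raise RuntimeError("expected AttributeError")
except AttributeError:
    pass  # DataParallel does not forward plain attributes

# The underlying module is reachable via `.module`.
assert wrapped.module.in_features == 64
```

On a machine with no visible GPUs, `nn.DataParallel` still constructs (and its forward falls through to the inner module), so the attribute problem exists even where the parallelism never kicks in.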
Usage
Be aware of this default when:
- Single-GPU systems: Set `dataparallel=False` to avoid unnecessary overhead.
- Accessing module attributes: After model construction, module pool entries are DataParallel wrappers. Access the underlying module via `.module`.
- Custom slicing code: If writing custom task conversion code, always check for and unwrap DataParallel modules before accessing layer dimensions.
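A defensive pattern for custom code is a tiny unwrap helper that works whether or not DataParallel was enabled. This is a sketch; `unwrap` is a hypothetical name, not part of Snorkel's API:

```python
import torch.nn as nn

def unwrap(module: nn.Module) -> nn.Module:
    """Return the inner module if wrapped in DataParallel, else the module itself.

    Hypothetical helper; mirrors the unwrap check used in the slicing utilities.
    """
    return module.module if isinstance(module, nn.DataParallel) else module

# Works uniformly for wrapped and unwrapped pool entries.
head = nn.Linear(128, 4)
assert unwrap(head).in_features == 128
assert unwrap(nn.DataParallel(head)).in_features == 128
```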
The Insight (Rule of Thumb)
- Action: Set `dataparallel=False` in MultitaskClassifier kwargs on single-GPU or CPU systems to avoid overhead and attribute access complications.
- Value: Default is `True`.
- Trade-off: DataParallel distributes computation across available GPUs but adds overhead on single-GPU systems. More importantly, it wraps modules, making direct attribute access (e.g., `linear.in_features`) impossible without unwrapping.
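The effect of the flag can be mirrored in a few lines. This is a sketch of the wrapping decision under illustrative names (`register_module` and the plain-dict pool are hypothetical, not Snorkel's code):

```python
import torch.nn as nn

def register_module(module_pool: dict, key: str, module: nn.Module,
                    dataparallel: bool = True) -> None:
    # Mirrors the classifier default: wrap every pool entry unless disabled.
    module_pool[key] = nn.DataParallel(module) if dataparallel else module

pool: dict = {}
register_module(pool, "head", nn.Linear(32, 2), dataparallel=False)

# With dataparallel=False, attributes are reachable without `.module`.
assert pool["head"].in_features == 32
```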
Reasoning
The DataParallel wrapping is applied indiscriminately to all module pool entries when tasks are added. The slicing infrastructure must handle this by explicitly checking `isinstance(head_module, nn.DataParallel)` and unwrapping. This coupling between the classifier default and downstream consumers is a source of subtle bugs if new code does not account for it.
Code evidence from `multitask_classifier.py:43-48,147-155`:
```python
class ClassifierConfig(Config):
    device: int = 0
    dataparallel: bool = True

...

if key in self.module_pool.keys():
    if self.config.dataparallel:
        task.module_pool[key] = nn.DataParallel(self.module_pool[key])
    else:
        task.module_pool[key] = self.module_pool[key]
```
Unwrapping in slicing from `slicing/utils.py:107-108`:
```python
if isinstance(head_module, nn.DataParallel):
    head_module = head_module.module
```