
Heuristic: Snorkel DataParallel Default Behavior

From Leeroopedia
Knowledge Sources
Domains: Multi_Task_Learning, Optimization
Last Updated: 2026-02-14 21:00 GMT

Overview

The MultitaskClassifier enables `nn.DataParallel` by default, wrapping all module pool entries. This causes complications when accessing module attributes (e.g., `in_features`) and requires unwrapping in downstream code like slicing utilities.

Description

When `dataparallel=True` (the default), every module added to the MultitaskClassifier's module pool is wrapped in `nn.DataParallel`. This enables multi-GPU training but introduces a layer of indirection: the underlying module is only reachable through the wrapper's `.module` attribute. The slicing utilities must explicitly unwrap DataParallel modules to access `in_features` and `out_features` when constructing per-slice prediction heads.
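A minimal illustration in plain PyTorch (independent of Snorkel) of why the indirection matters: `in_features` is an ordinary Python attribute on `nn.Linear`, not a registered parameter, buffer, or submodule, so the `nn.DataParallel` wrapper does not forward it.

```python
import torch.nn as nn

# Hypothetical stand-in for a module-pool entry: a linear prediction head.
head = nn.Linear(in_features=128, out_features=2)
wrapped = nn.DataParallel(head)  # what the classifier does when dataparallel=True

# Attribute access through the wrapper fails...
try:
    wrapped.in_features
except AttributeError:
    pass  # plain attributes are not forwarded by nn.DataParallel

# ...but the underlying module is reachable via `.module`.
assert wrapped.module.in_features == 128
```

On a CPU-only machine `nn.DataParallel` still constructs (and simply delegates to the wrapped module on forward), so the attribute-access problem exists regardless of whether any GPU is present.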

Usage

Be aware of this default when:

  • Single-GPU systems: Set `dataparallel=False` to avoid unnecessary overhead.
  • Accessing module attributes: After model construction, module pool entries are DataParallel wrappers. Access the underlying module via `.module`.
  • Custom slicing code: If writing custom task conversion code, always check for and unwrap DataParallel modules before accessing layer dimensions.
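The wrapping decision described above can be sketched with a small helper; the helper name and signature are ours for illustration, not Snorkel API, but the branch mirrors the classifier's behavior when tasks are added.

```python
import torch.nn as nn

def build_module_pool(modules: dict, dataparallel: bool) -> dict:
    # Hypothetical helper mirroring the classifier's wrapping decision:
    # with dataparallel=True every entry is wrapped in nn.DataParallel.
    if dataparallel:
        return {name: nn.DataParallel(m) for name, m in modules.items()}
    return dict(modules)

pool = build_module_pool({"head": nn.Linear(16, 2)}, dataparallel=True)
assert isinstance(pool["head"], nn.DataParallel)

pool = build_module_pool({"head": nn.Linear(16, 2)}, dataparallel=False)
assert pool["head"].in_features == 16  # direct attribute access works
```

With `dataparallel=False`, pool entries stay plain `nn.Module` objects, so downstream code never needs to unwrap them.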

The Insight (Rule of Thumb)

  • Action: Set `dataparallel=False` in MultitaskClassifier kwargs on single-GPU or CPU systems to avoid overhead and attribute access complications.
  • Value: Default is `True`.
  • Trade-off: DataParallel distributes computation across available GPUs but adds overhead on single-GPU systems. More importantly, it wraps modules, making direct attribute access (e.g., `linear.in_features`) impossible without unwrapping.

Reasoning

The DataParallel wrapping is applied indiscriminately to all module pool entries when tasks are added. The slicing infrastructure must handle this by explicitly checking `isinstance(head_module, nn.DataParallel)` and unwrapping. This coupling between the classifier default and downstream consumers is a source of subtle bugs if new code does not account for it.

Code evidence from `multitask_classifier.py:43-48,147-155`:

class ClassifierConfig(Config):
    device: int = 0
    dataparallel: bool = True
    ...
            if key in self.module_pool.keys():
                if self.config.dataparallel:
                    task.module_pool[key] = nn.DataParallel(self.module_pool[key])
                else:
                    task.module_pool[key] = self.module_pool[key]

Unwrapping in slicing from `slicing/utils.py:107-108`:

    if isinstance(head_module, nn.DataParallel):
        head_module = head_module.module
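The same unwrapping pattern as a self-contained sketch in plain PyTorch; the helper name is ours, not from the Snorkel codebase.

```python
import torch.nn as nn

def unwrap_dataparallel(module: nn.Module) -> nn.Module:
    """Return the underlying module if wrapped in nn.DataParallel,
    otherwise return the module unchanged (hypothetical helper)."""
    if isinstance(module, nn.DataParallel):
        return module.module
    return module

head = nn.DataParallel(nn.Linear(64, 3))
base = unwrap_dataparallel(head)
assert base.in_features == 64 and base.out_features == 3

# Idempotent on unwrapped modules:
assert unwrap_dataparallel(base) is base
```

Custom task-conversion or slicing code can route every module-pool access through such a helper so it works regardless of the `dataparallel` setting.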
