
Implementation:Snorkel team Snorkel DictDataset Init

From Leeroopedia
Knowledge Sources
Domains Data_Preparation, Multi_Task_Learning, PyTorch
Last Updated 2026-02-14 20:00 GMT

Overview

Concrete tool, provided by the Snorkel library, for building dictionary-keyed datasets and dataloaders for multi-task training.

Description

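DictDataset extends the PyTorch Dataset class, storing features in X_dict (field name to data) and labels in Y_dict (task name to label tensor). Indexing a DictDataset returns a single example as an (x_dict, y_dict) tuple, with the same row index applied to every field and every task. DictDataLoader extends the PyTorch DataLoader with a custom collate_dicts function that merges these per-example dicts into batched dicts.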
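The indexing and batching behavior described above can be sketched without PyTorch. This is a simplified, dependency-free model, not Snorkel's implementation: plain Python lists stand in for tensors, the names MiniDictDataset and mini_collate_dicts are hypothetical, and taking the length from the first label column is an assumption of the sketch. Snorkel's collate_dicts goes further and combines tensor values into batched tensors; the sketch keeps plain lists so it runs on its own.

```python
from typing import Any, Dict, List, Tuple

# One example: (per-field features, per-task labels)
Example = Tuple[Dict[str, Any], Dict[str, Any]]


class MiniDictDataset:
    """Hypothetical stand-in for DictDataset; lists replace tensors."""

    def __init__(self, name: str, split: str,
                 X_dict: Dict[str, List[Any]], Y_dict: Dict[str, List[Any]]) -> None:
        self.name = name
        self.split = split
        self.X_dict = X_dict
        self.Y_dict = Y_dict

    def __len__(self) -> int:
        # Assumption of this sketch: length comes from the first label column.
        return len(next(iter(self.Y_dict.values()))) if self.Y_dict else 0

    def __getitem__(self, index: int) -> Example:
        # The same row index is applied to every feature field and every task.
        x = {field: values[index] for field, values in self.X_dict.items()}
        y = {task: labels[index] for task, labels in self.Y_dict.items()}
        return x, y


def mini_collate_dicts(batch: List[Example]) -> Example:
    """Hypothetical collate: regroup per-example dicts into dicts of lists."""
    X_batch: Dict[str, List[Any]] = {}
    Y_batch: Dict[str, List[Any]] = {}
    for x, y in batch:
        for field, value in x.items():
            X_batch.setdefault(field, []).append(value)
        for task, label in y.items():
            Y_batch.setdefault(task, []).append(label)
    return X_batch, Y_batch


ds = MiniDictDataset(
    name="toy", split="train",
    X_dict={"features": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]},
    Y_dict={"sentiment": [1, 0, 1], "topic": [2, 4, 0]},
)
X_batch, Y_batch = mini_collate_dicts([ds[0], ds[2]])
print(X_batch)  # {'features': [[0.1, 0.2], [0.5, 0.6]]}
print(Y_batch)  # {'sentiment': [1, 1], 'topic': [2, 0]}
```

The batched dicts keep the same keys as X_dict and Y_dict, which is what lets a multi-task model look up each input field and each task's labels by name.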

The from_tensors classmethod is a convenience constructor for single-task datasets: it wraps one feature tensor and one label tensor under the keys given by input_data_key and task_name.
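The mapping that from_tensors performs can be sketched as follows, using the parameter names and defaults from its signature ("input_data", "task", "SnorkelDataset"). The helper from_tensors_kwargs is hypothetical: it returns the DictDataset constructor arguments as a plain dict instead of building a dataset, so it runs without Snorkel or PyTorch.

```python
from typing import Any, Dict


def from_tensors_kwargs(
    X_tensor: Any,
    Y_tensor: Any,
    split: str,
    input_data_key: str = "input_data",
    task_name: str = "task",
    dataset_name: str = "SnorkelDataset",
) -> Dict[str, Any]:
    """Hypothetical helper: the DictDataset arguments a from_tensors call implies."""
    return {
        "name": dataset_name,
        "split": split,
        "X_dict": {input_data_key: X_tensor},  # exactly one feature field
        "Y_dict": {task_name: Y_tensor},       # exactly one task
    }


kwargs = from_tensors_kwargs([[1.0, 2.0]], [0], split="train")
print(kwargs["X_dict"], kwargs["Y_dict"])  # {'input_data': [[1.0, 2.0]]} {'task': [0]}
```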

Usage

Import these classes when setting up data for MultitaskClassifier or SliceAwareClassifier training.

Code Reference

Source Location

  • Repository: snorkel
  • File: snorkel/classification/data.py
  • Lines: L19-177

Signature

class DictDataset(Dataset):
    def __init__(
        self,
        name: str,
        split: str,
        X_dict: Dict[str, Any],
        Y_dict: Dict[str, Tensor],
    ) -> None:
        """
        Args:
            name: Dataset name (used in metric reporting).
            split: Split name (e.g. "train", "valid", "test").
            X_dict: Feature dict mapping field names to data.
            Y_dict: Label dict mapping task names to label tensors.
        """

    @classmethod
    def from_tensors(
        cls,
        X_tensor: Tensor,
        Y_tensor: Tensor,
        split: str,
        input_data_key: str = "input_data",
        task_name: str = "task",
        dataset_name: str = "SnorkelDataset",
    ) -> "DictDataset": ...

class DictDataLoader(DataLoader):
    def __init__(
        self,
        dataset: DictDataset,
        collate_fn: Callable[..., Any] = collate_dicts,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            dataset: DictDataset to wrap.
            collate_fn: Batch collation function (default: collate_dicts).
            **kwargs: Passed to DataLoader (batch_size, shuffle, etc.).
        """

Import

from snorkel.classification import DictDataset, DictDataLoader

I/O Contract

Inputs

Name | Type | Required | Description
name | str | Yes | Dataset name for metric reporting
split | str | Yes | Split name, e.g. "train", "valid", or "test"
X_dict | Dict[str, Any] | Yes | Feature dict (field name -> data)
Y_dict | Dict[str, Tensor] | Yes | Label dict (task name -> label tensor)
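Because every example is read with the same row index across all entries of X_dict and Y_dict, those entries should share the same leading length. The check below is a hypothetical pre-flight helper, not part of Snorkel; it relies only on len(), which for a PyTorch tensor returns the size of the first dimension, so it applies to lists and tensors alike.

```python
from typing import Any, Dict


def check_aligned(X_dict: Dict[str, Any], Y_dict: Dict[str, Any]) -> int:
    """Hypothetical check: every feature field and label must have equal length.

    Returns the shared length (0 if both dicts are empty) or raises
    ValueError listing the length of each entry.
    """
    lengths = {f"X_dict[{k!r}]": len(v) for k, v in X_dict.items()}
    lengths.update({f"Y_dict[{k!r}]": len(v) for k, v in Y_dict.items()})
    distinct = set(lengths.values())
    if len(distinct) > 1:
        raise ValueError(f"Mismatched lengths: {lengths}")
    return distinct.pop() if distinct else 0


n = check_aligned(
    {"features": [[0.0]] * 100, "metadata": [[1.0]] * 100},
    {"sentiment": [0] * 100, "topic": [3] * 100},
)
print(n)  # 100
```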

Outputs

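Name | Type | Description
DictDataset | DictDataset | Indexable dataset; dataset[i] returns one example as an (x_dict, y_dict) tuple
DictDataLoader | DictDataLoader | Iterable over batches; each batch is an (X_batch, Y_batch) tuple of dicts produced by collate_dicts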

Usage Examples

import torch
from snorkel.classification import DictDataset, DictDataLoader

# Multi-field, multi-task dataset
X_dict = {
    "features": torch.randn(100, 50),
    "metadata": torch.randn(100, 10),
}
Y_dict = {
    "sentiment": torch.randint(0, 2, (100,)),
    "topic": torch.randint(0, 5, (100,)),
}

dataset = DictDataset(
    name="my_data",
    split="train",
    X_dict=X_dict,
    Y_dict=Y_dict,
)

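# batch_size and shuffle are forwarded to the PyTorch DataLoader
dataloader = DictDataLoader(dataset, batch_size=32, shuffle=True)

# Simple single-task dataset: features are stored under the default key
# "input_data" and labels under the default task name "task"
single_task_dataset = DictDataset.from_tensors(
    X_tensor=torch.randn(100, 50),
    Y_tensor=torch.randint(0, 2, (100,)),
    split="train",
)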
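With 100 examples, batch_size=32, and PyTorch's default drop_last=False, the multi-task loader in the example yields four batches per epoch: three of 32 examples and a final batch of 4. The dependency-free sketch below reproduces that arithmetic by chunking an index list; it models batch sizes only, not tensor contents, and the helper name batch_sizes is hypothetical.

```python
import math
import random
from typing import List


def batch_sizes(n_examples: int, batch_size: int, shuffle: bool = True,
                drop_last: bool = False, seed: int = 0) -> List[int]:
    """Hypothetical helper: sizes of the index batches one epoch would produce."""
    indices = list(range(n_examples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # changes order, not batch sizes
    batches = [indices[i:i + batch_size] for i in range(0, n_examples, batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # mirror drop_last=True: discard the short final batch
    return [len(b) for b in batches]


print(batch_sizes(100, 32))                  # [32, 32, 32, 4]
print(math.ceil(100 / 32))                   # 4
print(batch_sizes(100, 32, drop_last=True))  # [32, 32, 32]
```

Passing drop_last=True through DictDataLoader's **kwargs would make every batch the same size, at the cost of skipping the last 4 examples in each epoch.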

Related Pages

Implements Principle

Requires Environment
