Implementation: Snorkel DictDataset Init
| Knowledge Sources | |
|---|---|
| Domains | Data_Preparation, Multi_Task_Learning, PyTorch |
| Last Updated | 2026-02-14 20:00 GMT |
Overview
Concrete tool for creating dictionary-indexed datasets and dataloaders for multi-task training, provided by the Snorkel library.
Description
DictDataset extends PyTorch's Dataset to store features in an X_dict and labels in a Y_dict. DictDataLoader extends PyTorch's DataLoader with a custom collate_dicts function that merges per-example dicts into batched dicts.
The from_tensors classmethod provides a convenience constructor for single-task datasets.
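The merging behavior of collate_dicts can be sketched in plain PyTorch. This is an illustrative stand-in, not Snorkel's actual implementation (the function name collate_dicts_sketch is invented here):

```python
from typing import Dict, List, Tuple

import torch
from torch import Tensor


def collate_dicts_sketch(
    batch: List[Tuple[Dict[str, Tensor], Dict[str, Tensor]]],
) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
    """Merge a list of per-example (x_dict, y_dict) pairs into batched dicts."""
    X_lists: Dict[str, List[Tensor]] = {}
    Y_lists: Dict[str, List[Tensor]] = {}
    for x_dict, y_dict in batch:
        for field, value in x_dict.items():
            X_lists.setdefault(field, []).append(value)
        for task, label in y_dict.items():
            Y_lists.setdefault(task, []).append(label)
    # Stack each field's per-example tensors along a new batch dimension
    return (
        {field: torch.stack(values) for field, values in X_lists.items()},
        {task: torch.stack(values) for task, values in Y_lists.items()},
    )
```

Under this sketch, four examples whose "features" field has shape (50,) collate into a single tensor of shape (4, 50), keyed by the same field name.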
Usage
Import these classes when setting up data for MultitaskClassifier or SliceAwareClassifier training.
Code Reference
Source Location
- Repository: snorkel
- File: snorkel/classification/data.py
- Lines: L19-177
Signature
```python
class DictDataset(Dataset):
    def __init__(
        self,
        name: str,
        split: str,
        X_dict: Dict[str, Any],
        Y_dict: Dict[str, Tensor],
    ) -> None:
        """
        Args:
            name: Dataset name (used in metric reporting).
            split: Split name ("train", "valid", "test").
            X_dict: Feature dict mapping field names to data.
            Y_dict: Label dict mapping task names to label tensors.
        """

    @classmethod
    def from_tensors(
        cls,
        X_tensor: Tensor,
        Y_tensor: Tensor,
        split: str,
        input_data_key: str = "input_data",
        task_name: str = "task",
        dataset_name: str = "SnorkelDataset",
    ) -> "DictDataset": ...


class DictDataLoader(DataLoader):
    def __init__(
        self,
        dataset: DictDataset,
        collate_fn: Callable[..., Any] = collate_dicts,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            dataset: DictDataset to wrap.
            collate_fn: Batch collation function (default: collate_dicts).
            **kwargs: Passed to DataLoader (batch_size, shuffle, etc.).
        """
```
Import
```python
from snorkel.classification import DictDataset, DictDataLoader
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | Dataset name for metric reporting |
| split | str | Yes | "train", "valid", or "test" |
| X_dict | Dict[str, Any] | Yes | Feature dict (field_name -> data) |
| Y_dict | Dict[str, Tensor] | Yes | Label dict (task_name -> Tensor) |
Outputs
| Name | Type | Description |
|---|---|---|
| DictDataset | DictDataset | Indexable dataset yielding (X_dict, Y_dict) tuples |
| DictDataLoader | DictDataLoader | Batched iteration over DictDataset |
Usage Examples
```python
import torch
from snorkel.classification import DictDataset, DictDataLoader

# Multi-field, multi-task dataset
X_dict = {
    "features": torch.randn(100, 50),
    "metadata": torch.randn(100, 10),
}
Y_dict = {
    "sentiment": torch.randint(0, 2, (100,)),
    "topic": torch.randint(0, 5, (100,)),
}

dataset = DictDataset(
    name="my_data",
    split="train",
    X_dict=X_dict,
    Y_dict=Y_dict,
)
dataloader = DictDataLoader(dataset, batch_size=32, shuffle=True)

# Simple single-task dataset
dataset = DictDataset.from_tensors(
    X_tensor=torch.randn(100, 50),
    Y_tensor=torch.randint(0, 2, (100,)),
    split="train",
)
```
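Iterating the dataloader above yields one (X_batch_dict, Y_batch_dict) pair per batch. The same shape behavior can be demonstrated without Snorkel installed, using only torch.utils.data; TinyDictDataset below is a hypothetical stand-in for DictDataset, and PyTorch's default collate stacks tensor values under matching dict keys much as collate_dicts does:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class TinyDictDataset(Dataset):
    """Stand-in for DictDataset: indexing yields a (x_dict, y_dict) pair."""

    def __init__(self, X_dict, Y_dict):
        self.X_dict = X_dict
        self.Y_dict = Y_dict

    def __len__(self):
        # All label tensors share the same first (example) dimension
        return len(next(iter(self.Y_dict.values())))

    def __getitem__(self, idx):
        return (
            {field: data[idx] for field, data in self.X_dict.items()},
            {task: labels[idx] for task, labels in self.Y_dict.items()},
        )


ds = TinyDictDataset(
    {"features": torch.randn(100, 50)},
    {"sentiment": torch.randint(0, 2, (100,))},
)
# Default collate recurses into the (dict, dict) tuple and stacks per key
loader = DataLoader(ds, batch_size=32)
X_batch, Y_batch = next(iter(loader))
print(X_batch["features"].shape)  # torch.Size([32, 50])
print(Y_batch["sentiment"].shape)  # torch.Size([32])
```

With the real DictDataLoader the loop body is the same: each batch unpacks into a feature dict and a label dict, keyed by field and task names respectively.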
Related Pages
Implements Principle
Requires Environment