Implementation:Huggingface Datasets Split Dataset By Node
| Knowledge Sources | |
|---|---|
| Domains | Data_Engineering, NLP |
| Last Updated | 2026-02-14 18:00 GMT |
Overview
Concrete tool, provided by the HuggingFace Datasets library, for partitioning a dataset across the nodes of a distributed training job.
Description
split_dataset_by_node is a top-level function that accepts either a Dataset (map-style) or an IterableDataset (streaming) and returns the partition assigned to the node identified by rank in a pool of world_size nodes.
Internally, the function dispatches based on the dataset type:
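- For a Dataset: delegates to _split_by_node_map_style_dataset(dataset, rank=rank, world_size=world_size), which computes a contiguous range of rows for the node and returns that slice of the dataset.
- For an IterableDataset: delegates to _split_by_node_iterable_dataset(dataset, rank=rank, world_size=world_size), which takes one of two paths. If dataset.num_shards % world_size == 0, it assigns whole shards to each node. Otherwise, it configures interleaved example selection, where each node keeps every world_size-th example starting at offset rank. The shard path is the more efficient one, because each node reads only its own shards. In the interleaved path, every node reads the full stream and skips the examples that belong to other nodes.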
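To make the map-style path concrete, the sketch below computes a balanced contiguous block of row indices for each rank. It illustrates the idea and is not the library's code. The helper name contiguous_range is hypothetical. The balancing rule (the first num_rows % world_size ranks each get one extra row) is one common choice, assumed here but not guaranteed to match the library's exact block boundaries.

```python
def contiguous_range(num_rows: int, rank: int, world_size: int) -> range:
    """Hypothetical helper: contiguous block of row indices owned by `rank`.

    Rows are split into world_size blocks whose sizes differ by at most one;
    the first num_rows % world_size ranks each get one extra row (assumed rule).
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank must be in [0, {world_size}), got {rank}")
    base, extra = divmod(num_rows, world_size)
    start = base * rank + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return range(start, stop)


if __name__ == "__main__":
    # 10 rows over 3 nodes
    for r in range(3):
        print(r, list(contiguous_range(10, r, 3)))
    # 0 [0, 1, 2, 3]
    # 1 [4, 5, 6]
    # 2 [7, 8, 9]
```

Concatenating the blocks for ranks 0 through world_size - 1 reproduces every row index exactly once. That is the property a per-node split needs.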
The function is designed to be called identically on every node, with only the rank parameter varying. Each node receives a dataset object that, when iterated, yields only that node's assigned portion.
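The two iterable strategies can be simulated in plain Python, along with the fact that every node makes the same call with a different rank. This is a simplified model, not the library implementation. The names shards_for_node and interleaved_examples are hypothetical helpers. The strided (round-robin) shard-to-node mapping is also an assumption, since the library may map shard indices to nodes differently.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shards_for_node(num_shards: int, rank: int, world_size: int) -> List[int]:
    """Hypothetical helper: shard indices read by `rank` in whole-shard mode.

    Only valid when num_shards % world_size == 0. Uses a strided
    (round-robin) mapping, which is an assumption about the library.
    """
    if num_shards % world_size != 0:
        raise ValueError("num_shards must be a multiple of world_size")
    return list(range(rank, num_shards, world_size))


def interleaved_examples(stream: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Keep every world_size-th example, starting at offset rank."""
    return islice(stream, rank, None, world_size)


if __name__ == "__main__":
    world_size = 3
    # Every node makes the same call; only `rank` differs.
    print([shards_for_node(6, r, world_size) for r in range(world_size)])
    # [[0, 3], [1, 4], [2, 5]]
    print([list(interleaved_examples(range(10), r, world_size)) for r in range(world_size)])
    # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
```

In both modes, the per-rank outputs are disjoint and together cover every shard or example. This is what lets each node train on its own slice without coordinating with the others.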
Usage
Use split_dataset_by_node at the beginning of a distributed training script, after loading the dataset but before iterating. It is typically called once per process during initialization.
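Sometimes the rank and world size are needed before a process group has been initialized. In that case they can be read from the launcher's environment. The sketch assumes a launcher such as torchrun, which exports RANK and WORLD_SIZE for each process. The helper name rank_and_world_size and the single-process fallback are illustrative choices, not part of the Datasets API.

```python
import os
from typing import Mapping, Tuple


def rank_and_world_size(env: Mapping[str, str] = os.environ) -> Tuple[int, int]:
    """Hypothetical helper: return (rank, world_size) for the current process.

    Assumes a launcher such as torchrun exported RANK and WORLD_SIZE;
    falls back to a single process (rank 0 of 1) when they are absent.
    """
    rank = int(env.get("RANK", "0"))
    world_size = int(env.get("WORLD_SIZE", "1"))
    if not 0 <= rank < world_size:
        raise ValueError(f"invalid RANK={rank} for WORLD_SIZE={world_size}")
    return rank, world_size


if __name__ == "__main__":
    rank, world_size = rank_and_world_size()
    # The pair is what split_dataset_by_node expects:
    # ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)
    print(f"rank {rank} of {world_size}")
```

Validating the pair up front surfaces a misconfigured launch immediately. Otherwise the error would show up later as a missing or overlapping partition.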
Code Reference
Source Location
- Repository: datasets
- File: src/datasets/distributed.py
- Lines: L10-L43
Signature
```python
def split_dataset_by_node(dataset: DatasetType, rank: int, world_size: int) -> DatasetType:
```
Import
```python
from datasets.distributed import split_dataset_by_node
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataset | Dataset or IterableDataset | Yes | The dataset to split across nodes. |
| rank | int | Yes | Rank (0-indexed identifier) of the current node. |
| world_size | int | Yes | Total number of nodes in the distributed pool. |
Outputs
| Name | Type | Description |
|---|---|---|
| dataset | Dataset or IterableDataset | The partition of the dataset assigned to the node at the given rank. Same type as the input. |
Usage Examples
Basic Usage
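```python
import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

# In a distributed training script (e.g. launched with torchrun), the default
# process group must be initialized before querying rank and world size
if not dist.is_initialized():
    dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

# "my_large_dataset" is a placeholder for a real dataset name
ds = load_dataset("my_large_dataset", split="train", streaming=True)

# Each node gets its own partition
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)

# If shuffling, use the same seed on all nodes
ds = ds.shuffle(seed=42, buffer_size=10_000)

for example in ds:
    train_step(example)  # train_step is the user's own training-step function
```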
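The advice to use the same seed can be illustrated for the interleaved mode, where every node reads the full stream and keeps every world_size-th example. The model below is simplified. The helper names node_view and check_partition are hypothetical. A plain random.Random shuffle of the whole stream stands in for the streaming shuffle buffer, which is an assumption about the mechanism. With one shared seed, all nodes see the same order, so their selections partition the data. With per-node seeds, the orders diverge and the selections can overlap or leave gaps.

```python
import random
from typing import List, Sequence, Set


def node_view(items: Sequence[int], seed: int, rank: int, world_size: int) -> List[int]:
    """Hypothetical helper: simplified model of one node in interleaved mode.

    The node shuffles the full stream with its own RNG (standing in for the
    streaming shuffle buffer), then keeps every world_size-th element
    starting at offset rank.
    """
    order = list(items)
    random.Random(seed).shuffle(order)
    return order[rank::world_size]


def check_partition(items: Sequence[int], seeds: Sequence[int]) -> bool:
    """True if the nodes' views are pairwise disjoint and cover every item.

    seeds[r] is the shuffle seed used by rank r; world_size is len(seeds).
    Assumes items contains no duplicates.
    """
    world_size = len(seeds)
    views = [set(node_view(items, seeds[r], r, world_size)) for r in range(world_size)]
    union: Set[int] = set().union(*views)
    disjoint = sum(len(v) for v in views) == len(union)
    return disjoint and union == set(items)


if __name__ == "__main__":
    data = range(20)
    print(check_partition(data, seeds=[42, 42, 42]))  # True: one shared seed
    # Per-node seeds: views can overlap and miss items, so this may print False
    print(check_partition(data, seeds=[0, 1, 2]))
```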