Principle: AllenAI open-instruct Distributed Data Loading
| Knowledge Sources | |
|---|---|
| Domains | Distributed Computing, Data Engineering |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Distributed data loading is the technique of partitioning a dataset across multiple parallel workers with deterministic sharding, epoch-based reshuffling, and support for checkpointing and resumption.
Description
When training a model across multiple GPUs or nodes, each worker must receive a disjoint subset of the data to avoid redundant computation. A distributed data loader must solve several problems simultaneously:
- Sharding: The dataset is divided into non-overlapping partitions, one per data-parallel rank. Each rank processes only its assigned partition.
- Deterministic shuffling: At each epoch boundary, the data is reshuffled using a seed derived from the base seed plus the epoch number. This ensures that (a) different epochs see different orderings, and (b) the shuffling is reproducible given the same seed.
- Strided sampling: Rather than assigning contiguous blocks of data to each rank, a strided approach distributes examples from each global batch across ranks. This ensures diversity within each rank's local batches, which is important for stable training.
- Checkpointing: The data loader must be able to save and restore its state (epoch, batch position, excluded indices) to support training resumption from checkpoints.
- Index exclusion: In RLVR training, prompts that the model has fully mastered (100% solve rate) can be excluded from future sampling to focus training on harder examples.
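The checkpointing and exclusion requirements above boil down to a small, serializable piece of loader state. A minimal sketch follows; the class and method names (`ShardedDataLoaderState`, `state_dict`, `load_state_dict`) are illustrative, modeled on PyTorch conventions, and are not the actual open-instruct API:

```python
class ShardedDataLoaderState:
    """Minimal checkpointable state for a distributed data loader.

    Names here are illustrative, not the real open-instruct API.
    """

    def __init__(self, seed: int):
        self.seed = seed
        self.epoch = 0
        self.batch_position = 0          # batches already consumed this epoch
        self.excluded_indices: set[int] = set()

    def state_dict(self) -> dict:
        # Everything needed to resume deterministically after a restart:
        # the shuffle is a pure function of (seed, epoch), so storing these
        # plus the batch cursor and exclusions fully recovers the stream.
        return {
            "seed": self.seed,
            "epoch": self.epoch,
            "batch_position": self.batch_position,
            "excluded_indices": sorted(self.excluded_indices),
        }

    def load_state_dict(self, state: dict) -> None:
        self.seed = state["seed"]
        self.epoch = state["epoch"]
        self.batch_position = state["batch_position"]
        self.excluded_indices = set(state["excluded_indices"])
```

On resume, a loader would rebuild the epoch permutation from `(seed, epoch)` and skip `batch_position` batches to land exactly where training stopped.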
Usage
Distributed data loading is used in any multi-GPU training scenario. In the GRPO pipeline, it is used by the DataPreparationActor to iterate over prompts that are then sent to vLLM engines for generation. It is also used in the DPO and supervised fine-tuning pipelines.
Theoretical Basis
The key insight is that global batches should be constructed first, then distributed to ranks, rather than each rank independently sampling from its shard. This ensures that each global batch contains a representative sample of the full dataset.
Given a dataset of size N, world size W, and global batch size B:
```python
import numpy as np

def iter_epoch(dataset, world_size, global_batch_size, rank, seed, epoch):
    N, W, B = len(dataset), world_size, global_batch_size
    rng = np.random.default_rng(seed + epoch)  # epoch-dependent, reproducible
    indices = rng.permutation(N)
    # Keep only complete global batches
    total_batches = N // B
    batched = indices[: total_batches * B].reshape(total_batches, B)
    # Strided slice: rank r takes columns r, r + W, r + 2W, ...
    rank_indices = batched[:, rank::W]  # shape (total_batches, B // W)
    # Each row is one local batch of size B // W for this rank
    for batch in rank_indices:
        yield [dataset[i] for i in batch]
```
This strided distribution ensures that within each global batch, every rank gets examples from different positions in the shuffled dataset, maximizing batch diversity.
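The contrast with contiguous assignment can be made concrete with a toy example (shuffling disabled so the indices are readable); the variable names are illustrative:

```python
import numpy as np

N, W, B = 8, 2, 4                          # dataset size, world size, global batch
indices = np.arange(N)                     # stand-in for a shuffled permutation
batched = indices.reshape(N // B, B)       # two global batches of size 4

# Strided: rank r takes columns r, r + W, ... of every global batch
strided = {r: batched[:, r::W].flatten().tolist() for r in range(W)}

# Contiguous: rank r takes a solid block of each global batch
contiguous = {r: batched[:, r * (B // W):(r + 1) * (B // W)].flatten().tolist()
              for r in range(W)}

print(strided)      # {0: [0, 2, 4, 6], 1: [1, 3, 5, 7]}
print(contiguous)   # {0: [0, 1, 4, 5], 1: [2, 3, 6, 7]}
```

Under striding, each rank's local batch interleaves positions spread across the whole global batch rather than one solid block of it.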
The exclusion mechanism enables a form of curriculum learning where the model naturally progresses from easier to harder examples as it masters simpler ones. The excluded indices are tracked per data loader instance and respected during reshuffling.
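A sketch of how exclusion can be folded into the epoch reshuffle, assuming the loader tracks a set of excluded dataset indices (the function name is illustrative):

```python
import numpy as np

def shuffled_active_indices(N, seed, epoch, excluded):
    """Permute the dataset for this epoch, then drop mastered examples."""
    rng = np.random.default_rng(seed + epoch)  # same seeding as the main loop
    perm = rng.permutation(N)
    mask = ~np.isin(perm, list(excluded))
    return perm[mask]  # epoch-shuffle order preserved, exclusions removed

# Example: excluding fully-solved prompts 3 and 7 from a 10-example dataset
active = shuffled_active_indices(10, seed=42, epoch=0, excluded={3, 7})
assert set(active.tolist()) == set(range(10)) - {3, 7}
```

Because the permutation is still a pure function of `(seed, epoch)`, exclusion stays compatible with checkpoint resumption: restoring the excluded set reproduces the same filtered stream.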