Heuristic:Mlfoundations Open flamingo Deterministic Shard Shuffling
| Knowledge Sources | |
|---|---|
| Domains | Data_Loading, Distributed_Training, Reproducibility |
| Last Updated | 2026-02-08 03:30 GMT |
Overview
Deterministic shard shuffling using a seed-plus-epoch scheme that ensures reproducible data ordering across all workers and nodes while still providing different data each epoch.
Description
OpenFlamingo uses a custom `detshuffle2` pipeline stage (borrowed from open_clip) to shuffle WebDataset shards deterministically. The shuffle seed is `seed + epoch`, so every worker and node computes the same shard order within an epoch (required for correct distributed training), while the order changes from one epoch to the next. A `SharedEpoch` object keeps the epoch value in sync between the main process and the dataloader worker processes. Individual samples are then shuffled with WebDataset's standard buffered shuffle, whose buffer sizes are configurable.
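The core of the seed-plus-epoch idea can be shown in a few lines. This is a minimal sketch, assuming a full in-memory `random.Random.shuffle` in place of the buffered `_shuffle` that `detshuffle2` actually uses; `shard_order` and the shard names are hypothetical.

```python
import random

def shard_order(shards, seed, epoch):
    """Return the shard order for one epoch (hypothetical helper).

    Any process calling this with the same (seed, epoch) gets the same list,
    because its RNG is seeded identically.
    """
    rng = random.Random(seed + epoch)  # seed + epoch, as in detshuffle2
    order = list(shards)
    rng.shuffle(order)
    return order

shards = [f"shard-{i:05d}.tar" for i in range(16)]

# Two "workers" (or nodes) computing the order independently agree.
worker_a = shard_order(shards, seed=42, epoch=3)
worker_b = shard_order(shards, seed=42, epoch=3)
assert worker_a == worker_b

# A later epoch uses a different seed, so the order generally changes.
next_epoch = shard_order(shards, seed=42, epoch=4)
```

One consequence of adding the epoch to the seed rather than hashing the pair is that `(seed=42, epoch=3)` and `(seed=43, epoch=2)` yield identical orders, so runs with neighbouring seeds share shifted copies of the same sequence of shard orders.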
Usage
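Apply this heuristic when training with WebDataset across multiple nodes or dataloader workers. OpenFlamingo enables it automatically for non-resampled datasets, and the `--seed` argument sets the base seed. Resampled datasets use `ResampledShards2` with `deterministic=True` instead, which likewise reseeds from the epoch for reproducibility; because shards are drawn with replacement there, each worker uses its own seed (derived from the PyTorch worker seed) rather than a single shared shard order.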
The Insight (Rule of Thumb)
- Action: Use `detshuffle2` with `seed=args.seed` and `epoch=shared_epoch` for shard-level shuffling. Use `wds.shuffle(bufsize=5000, initial=1000)` for sample-level shuffling.
- Value: Shard shuffle buffer: 2000 (initial: 500). Sample shuffle buffer: 5000 (initial: 1000).
- Trade-off: The shuffle is still pseudo-random, but runs that share a `--seed` see exactly the same data order, so run-to-run variation must come from changing the seed. In return, training is reproducible and distributed shard splitting stays correct.
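To make the `bufsize` and `initial` values concrete, here is a simplified, pure-Python buffered shuffle modeled on the behaviour of WebDataset's shuffle stage. It is a sketch, not the library's code; `buffered_shuffle` and `_pick` are hypothetical names.

```python
import random

_END = object()  # sentinel marking an exhausted iterator

def _pick(buf, rng):
    """Remove and return a uniformly random element of buf in O(1)."""
    k = rng.randrange(len(buf))
    buf[k], buf[-1] = buf[-1], buf[k]  # move the chosen item to the end
    return buf.pop()

def buffered_shuffle(items, bufsize, initial, rng):
    """Approximate shuffle using a bounded buffer (simplified sketch).

    Nothing is emitted until the buffer holds `initial` items; after that the
    buffer keeps growing by one item per step until it reaches `bufsize`.
    Each output is drawn uniformly from whatever is currently buffered.
    """
    it = iter(items)
    initial = min(initial, bufsize)
    buf = []
    for item in it:
        buf.append(item)
        if len(buf) < bufsize:
            extra = next(it, _END)  # read ahead to grow the buffer
            if extra is not _END:
                buf.append(extra)
        if len(buf) >= initial:
            yield _pick(buf, rng)
    while buf:  # drain what is left once the input is exhausted
        yield _pick(buf, rng)

out = list(buffered_shuffle(range(100), bufsize=8, initial=4, rng=random.Random(0)))
```

Because an item can only be emitted after it enters the buffer, the output stays roughly local: with `initial=4` the first output is always one of the first four inputs. Larger buffers (such as 5000 for samples) mix more thoroughly at the cost of memory and a longer warm-up before the first item is produced.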
Reasoning
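In distributed training with WebDataset, each worker processes a subset of shards, obtained by shuffling the full shard list and then splitting it across nodes and workers. If the shuffle differed between workers, each would take its slice of a different permutation, so some shards would be read more than once and others skipped entirely. The seed+epoch scheme ensures that:
- Every node and worker computes the same shard order, so `wds.split_by_node` and `wds.split_by_worker` partition it into disjoint subsets that together cover every shard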
- Each epoch has a different shard order (preventing memorization of data ordering)
- Training is reproducible given the same seed
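The coverage argument can be checked with a small simulation. This sketch assumes the node and worker splits are strided slices (every `world_size`-th shard, then every `num_workers`-th of those); `shuffled`, `assign`, and `collect` are hypothetical helpers, not OpenFlamingo or WebDataset functions.

```python
import random

def shuffled(shards, seed):
    order = list(shards)
    random.Random(seed).shuffle(order)
    return order

def assign(order, rank, world_size, worker_id, num_workers):
    """Node-level then worker-level strided split (assumed splitter behaviour)."""
    per_node = order[rank::world_size]
    return per_node[worker_id::num_workers]

shards = [f"shard-{i:03d}" for i in range(24)]
WORLD, WORKERS = 2, 3
SEED, EPOCH = 42, 1

def collect(seed_for):
    """Gather the shards every (node, worker) pair would read.

    seed_for(rank, worker_id) returns the shuffle seed that pair uses.
    """
    seen = []
    for rank in range(WORLD):
        for wid in range(WORKERS):
            order = shuffled(shards, seed_for(rank, wid))
            seen.extend(assign(order, rank, WORLD, wid, WORKERS))
    return seen

# Shared seed (seed + epoch identical everywhere): a clean partition.
shared = collect(lambda rank, wid: SEED + EPOCH)
# Per-process seeds: each pair slices a *different* permutation.
divergent = collect(lambda rank, wid: 1000 * rank + wid)
```

With the shared seed every shard is read exactly once; with per-process seeds the same number of shards is read in total, but some appear twice and others never.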
The `SharedEpoch` class stores the epoch counter in a `multiprocessing.Value`, that is, in shared memory. This matters because PyTorch `DataLoader` workers run in separate processes: a plain integer updated in the main process after the workers have started would not be visible to them.
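The difference between a shared `Value` and an ordinary Python object can be demonstrated directly. This standalone sketch is not OpenFlamingo code: it assumes the POSIX `"fork"` start method (unavailable on Windows), and `worker` and `demo` are hypothetical names. The child process is started first, then the parent bumps the epoch, mimicking persistent dataloader workers that outlive an epoch boundary.

```python
import multiprocessing as mp

def worker(shared_epoch, plain, started, go, results):
    started.set()  # tell the parent the worker process is running
    go.wait()      # wait until the parent has advanced the epoch
    results.put((shared_epoch.value, plain["epoch"]))

def demo():
    ctx = mp.get_context("fork")      # assumption: POSIX fork start method
    shared_epoch = ctx.Value("i", 0)  # lives in shared memory
    plain = {"epoch": 0}              # ordinary object: copied at fork time
    started, go = ctx.Event(), ctx.Event()
    results = ctx.Queue()
    p = ctx.Process(target=worker, args=(shared_epoch, plain, started, go, results))
    p.start()
    started.wait()
    # The main process advances the epoch after the worker already exists.
    shared_epoch.value = 1
    plain["epoch"] = 1
    go.set()
    seen = results.get()  # read before join to avoid blocking on the queue
    p.join()
    return seen

if __name__ == "__main__":
    print(demo())  # prints (1, 0)
```

The worker sees the new value through the shared `Value` but still holds its own stale copy of the plain dictionary.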
Code Evidence
Deterministic shuffling constants from `open_flamingo/train/data.py:23-27`:
```python
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
```
detshuffle2 seed computation from `open_flamingo/train/data_utils.py:172-188`:
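```python
def run(self, src):
    if isinstance(self.epoch, SharedEpoch):
        # Epoch is set by the main process and read by every worker
        epoch = self.epoch.get_value()
    else:
        # Plain int: advance a local counter on each call
        self.epoch += 1
        epoch = self.epoch
    rng = random.Random()
    if self.seed < 0:
        # Negative seed: fall back to a per-worker seed derived from PyTorch
        seed = pytorch_worker_seed(epoch)
    else:
        # This seed is deterministic AND the same across all nodes/workers
        seed = self.seed + epoch
    rng.seed(seed)
    return _shuffle(src, self.bufsize, self.initial, rng)
```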
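A self-contained sketch of how the epoch source affects `run` follows. `DetShuffleSketch` is a hypothetical stand-in, and it simplifies on purpose: only the `seed >= 0` branch is modeled, a full in-memory shuffle replaces the buffered `_shuffle`, and the default `epoch=-1` is an assumption. `SharedEpoch` is repeated from the excerpt on this page so the block runs on its own.

```python
import random
from multiprocessing import Value

class SharedEpoch:
    """Epoch counter in shared memory (same shape as the class on this page)."""
    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value

class DetShuffleSketch:
    """Hypothetical stand-in for detshuffle2's seeding logic (seed >= 0 only)."""
    def __init__(self, seed, epoch=-1):
        self.seed = seed
        self.epoch = epoch  # either an int or a SharedEpoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()  # controlled by the training loop
        else:
            self.epoch += 1                 # local counter, advances per call
            epoch = self.epoch
        rng = random.Random(self.seed + epoch)
        items = list(src)
        rng.shuffle(items)  # full shuffle instead of the buffered _shuffle
        return items

shards = [f"shard-{i}" for i in range(10)]
shared = SharedEpoch(epoch=0)
a, b = DetShuffleSketch(42, shared), DetShuffleSketch(42, shared)
shared.set_value(5)  # the training loop announces epoch 5
assert a.run(shards) == b.run(shards)  # both "workers" agree for epoch 5
```

With a plain integer, each call to `run` advances only that object's own counter, which is why the shared counter is used when the shuffler lives inside worker processes.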
SharedEpoch for cross-worker sync from `open_flamingo/train/data_utils.py:34-43`:
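```python
from multiprocessing import Value

class SharedEpoch:
    def __init__(self, epoch: int = 0):
        # "i" = C int stored in shared memory, visible to all worker processes
        self.shared_epoch = Value("i", epoch)

    def set_value(self, epoch):
        self.shared_epoch.value = epoch

    def get_value(self):
        return self.shared_epoch.value
```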