Implementation:Huggingface Datasets Stratify Utils
Overview
Stratify Utils provides stratified sampling utilities for generating train/test splits that preserve class distributions. The module implements an approximate mode computation for the multivariate hypergeometric distribution and a stratified shuffle-split index generator, closely following the approach used in scikit-learn's StratifiedShuffleSplit.
This module is part of the huggingface/datasets repository.
- Source file: src/datasets/utils/stratify.py (107 lines)
- Domain: Sampling, Statistics
- Import:
from datasets.utils.stratify import stratified_shuffle_split_generate_indices
Functions
approximate_mode
Computes the approximate mode of a multivariate hypergeometric distribution. Given a set of class counts and a number of draws, this function returns the most likely outcome of drawing n_draws samples from the population defined by class_counts. The result should not be off by more than one from the true mode.
Parameters:
- class_counts (ndarray of int) -- Population count per class.
- n_draws (int) -- Total number of samples to draw from the overall population.
- rng (random state) -- A NumPy random state used to break ties randomly, avoiding biases.
Returns:
- sampled_classes (ndarray of int) -- Number of samples drawn from each class. The sum of the returned array equals n_draws.
Algorithm:
- Computes a continuous proportional allocation: n_draws * class_counts / class_counts.sum().
- Floors the continuous values to avoid overshooting.
- Distributes the remaining draws in descending order of fractional remainder, breaking ties via random selection.
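import numpy as np
def approximate_mode(class_counts, n_draws, rng):
    # Proportional (generally non-integer) share of the draws for each class
    continuous = n_draws * class_counts / class_counts.sum()
    # Flooring never overshoots n_draws, but usually undershoots it
    floored = np.floor(continuous)
    need_to_add = int(n_draws - floored.sum())
    if need_to_add > 0:
        remainder = continuous - floored
        # Distinct remainder values, largest first
        values = np.sort(np.unique(remainder))[::-1]
        for value in values:
            (inds,) = np.where(remainder == value)
            # Take every tied class if enough draws remain, else a random subset
            add_now = min(len(inds), need_to_add)
            inds = rng.choice(inds, size=add_now, replace=False)
            floored[inds] += 1
            need_to_add -= add_now
            if need_to_add == 0:
                break
    return floored.astype(np.int64)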
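To make the floor-then-distribute steps concrete, here is a minimal pure-Python sketch of the same idea. It is an illustration only, not the library's code: the name approximate_mode_sketch is hypothetical, it uses the standard library's random.Random and exact fractions instead of a NumPy random state, and the class counts are made up.

```python
import random
from fractions import Fraction


def approximate_mode_sketch(class_counts, n_draws, rng):
    """Hypothetical stdlib-only illustration of floor-then-distribute."""
    total = sum(class_counts)
    # Exact proportional share of the draws for each class.
    shares = [Fraction(n_draws * c, total) for c in class_counts]
    # int() truncates, which equals floor for non-negative shares.
    allocation = [int(s) for s in shares]
    remainders = [s - a for s, a in zip(shares, allocation)]
    need_to_add = n_draws - sum(allocation)
    # Visit distinct remainder values from largest to smallest.
    for value in sorted(set(remainders), reverse=True):
        if need_to_add == 0:
            break
        tied = [i for i, r in enumerate(remainders) if r == value]
        # Random subset of the tied classes, so no class is favoured.
        for i in rng.sample(tied, min(len(tied), need_to_add)):
            allocation[i] += 1
            need_to_add -= 1
    return allocation


rng = random.Random(0)
# Shares are whole numbers (5, 3, 2), so nothing is left to distribute.
exact = approximate_mode_sketch([50, 30, 20], 10, rng)
# Each share is 4/3: floors give 1 each, and one extra draw goes to a random class.
tied = approximate_mode_sketch([3, 3, 3], 4, rng)
print(exact, sorted(tied))
```

In the second call the three classes have identical remainders, so which class receives the extra draw depends on the random state; only the multiset of counts (two classes with 1, one class with 2) is fixed.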
stratified_shuffle_split_generate_indices
Provides train/test indices to split data into train/test sets while preserving class distributions. The implementation is based on scikit-learn's StratifiedShuffleSplit.
Parameters:
- y (array-like) -- Target labels used for stratification.
- n_train (int) -- Absolute number of train samples.
- n_test (int) -- Absolute number of test samples.
- rng (random state) -- A NumPy random state for reproducibility.
- n_splits (int, default=10) -- Number of re-shuffling and splitting iterations.
Yields:
- A tuple (train, test) of index arrays for each split iteration.
Validation:
- Raises ValueError if any class has fewer than 2 samples.
- Raises ValueError if n_train or n_test is less than the number of classes.
def stratified_shuffle_split_generate_indices(y, n_train, n_test, rng, n_splits=10):
    # Map labels to integer class ids and count the members of each class
    classes, y_indices = np.unique(y, return_inverse=True)
    n_classes = classes.shape[0]
    class_counts = np.bincount(y_indices)
    if np.min(class_counts) < 2:
        raise ValueError("Minimum class count error")
    if n_train < n_classes:
        raise ValueError(
            "The train_size = %d should be greater or equal to the number of classes = %d" % (n_train, n_classes)
        )
    if n_test < n_classes:
        raise ValueError(
            "The test_size = %d should be greater or equal to the number of classes = %d" % (n_test, n_classes)
        )
    # Group sample indices by class (stable sort keeps the original order within a class)
    class_indices = np.split(np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1])
    for _ in range(n_splits):
        # Per-class train counts, then per-class test counts from what is left
        n_i = approximate_mode(class_counts, n_train, rng)
        class_counts_remaining = class_counts - n_i
        t_i = approximate_mode(class_counts_remaining, n_test, rng)
        train = []
        test = []
        for i in range(n_classes):
            # Shuffle within the class, then slice off the train and test portions
            permutation = rng.permutation(class_counts[i])
            perm_indices_class_i = class_indices[i].take(permutation, mode="clip")
            train.extend(perm_indices_class_i[: n_i[i]])
            test.extend(perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]])
        # Shuffle again so the output is not ordered by class
        train = rng.permutation(train)
        test = rng.permutation(test)
        yield train, test
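The two allocation calls inside the loop can be traced by hand. The sketch below is an illustration under stated assumptions, not part of the library: proportional_shares is a hypothetical helper, the class counts are made up, and it reproduces only the proportional-share arithmetic (the random tie-breaking step is left out). It shows that the test allocation is computed from the counts left over after the train allocation.

```python
from fractions import Fraction


def proportional_shares(counts, n_draws):
    """Hypothetical helper: exact proportional share of n_draws per class."""
    total = sum(counts)
    return [Fraction(n_draws * c, total) for c in counts]


# Made-up example: three classes with 6, 4 and 2 samples.
class_counts = [6, 4, 2]

# Stage 1: train allocation for n_train = 6. The shares are whole numbers.
train_shares = proportional_shares(class_counts, 6)
n_i = [int(s) for s in train_shares]

# Stage 2: test allocation for n_test = 3, drawn from the remaining counts.
remaining = [c - n for c, n in zip(class_counts, n_i)]
test_shares = proportional_shares(remaining, 3)

print(n_i, remaining, test_shares)
```

Here the test shares are 3/2, 1 and 1/2. Flooring gives 1, 1 and 0, which leaves one draw to distribute; the first and third classes tie with a remainder of 1/2, so the function's random tie-breaking decides which of them receives it.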
Dependencies
| Dependency | Type | Purpose |
|---|---|---|
| numpy | External | Array operations, random state, bincount, and permutation |
Usage Example
import numpy as np
from datasets.utils.stratify import stratified_shuffle_split_generate_indices
# Labels with 3 classes
y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
rng = np.random.RandomState(42)
for train_indices, test_indices in stratified_shuffle_split_generate_indices(
    y, n_train=6, n_test=3, rng=rng, n_splits=1
):
    print("Train indices:", train_indices)
    print("Test indices:", test_indices)
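One way to sanity-check a split is to confirm that the two index sets are disjoint and to count the classes on each side. The helper below is a hypothetical sketch, not part of the library, and the index lists are hand-written for illustration rather than actual output of the generator; they have the per-class shape expected for 4 samples per class with n_train=6 and n_test=3.

```python
from collections import Counter


def summarize_split(labels, train_indices, test_indices):
    """Hypothetical helper: overlap and per-class counts of a train/test split."""
    overlap = set(train_indices) & set(test_indices)
    train_counts = Counter(labels[i] for i in train_indices)
    test_counts = Counter(labels[i] for i in test_indices)
    return overlap, train_counts, test_counts


labels = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
# Hand-written indices for illustration only (not produced by the library).
train = [0, 1, 4, 5, 8, 9]
test = [2, 6, 10]

overlap, train_counts, test_counts = summarize_split(labels, train, test)
print(overlap, dict(train_counts), dict(test_counts))
```

With balanced classes like these, a stratified split should put 2 samples of each class in the train set and 1 of each in the test set, with no index appearing on both sides.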
Design Notes
- The approximate_mode function uses a floor-then-distribute approach rather than computing the exact mode of the multivariate hypergeometric distribution, which would be computationally expensive. The result should not be off by more than one from the true mode.
- Tie-breaking is handled via random selection (rng.choice), which prevents systematic bias toward particular classes.
- The generator pattern (yield) allows memory-efficient iteration over multiple splits without computing all splits upfront.
- This implementation mirrors scikit-learn's StratifiedShuffleSplit but is self-contained within the datasets library to avoid a hard dependency on scikit-learn.
File Location
- Repository: huggingface/datasets
- Full path: src/datasets/utils/stratify.py