Implementation:Huggingface Datasets Stratify Utils

From Leeroopedia
Revision as of 13:00, 16 February 2026 by Admin (auto-imported from implementations/Huggingface_Datasets_Stratify_Utils.md)

Overview

Stratify Utils provides stratified sampling utilities for generating train/test splits that preserve class distributions. The module implements an approximate mode computation for the multivariate hypergeometric distribution and a stratified shuffle-split index generator, closely following the approach used in scikit-learn's StratifiedShuffleSplit.

This module is part of the huggingface/datasets repository.

  • Source file: src/datasets/utils/stratify.py (107 lines)
  • Domain: Sampling, Statistics
  • Import: from datasets.utils.stratify import stratified_shuffle_split_generate_indices

Functions

approximate_mode

Computes an approximate mode of a multivariate hypergeometric distribution. Given per-class population counts and a number of draws, the function approximates the most likely per-class outcome of drawing n_draws samples without replacement from the population defined by class_counts. The result should not be off by more than one from the true mode.
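To make "mode" concrete, the sketch below finds the exact mode of a small multivariate hypergeometric distribution by enumerating every feasible allocation and scoring it with the probability mass function P(k) = prod_i C(K_i, k_i) / C(N, n). The function exact_mode is hypothetical and not part of datasets; because enumeration grows combinatorially, it is only practical for a few small classes.

```python
from itertools import product
from math import comb, prod


def exact_mode(class_counts, n_draws):
    """Brute-force mode of a multivariate hypergeometric distribution.

    Hypothetical illustration only; not part of the datasets library.
    """
    total = sum(class_counts)
    best, best_p = None, -1.0
    # Enumerate every way of drawing k_i items from class i (0 <= k_i <= K_i).
    for alloc in product(*(range(k + 1) for k in class_counts)):
        if sum(alloc) != n_draws:
            continue
        # P(alloc) = prod_i C(K_i, k_i) / C(N, n)
        p = prod(comb(K, k) for K, k in zip(class_counts, alloc)) / comb(total, n_draws)
        if p > best_p:
            best, best_p = alloc, p
    return best, best_p


mode, p = exact_mode([5, 3, 2], 4)
print(mode, round(p, 4))  # (2, 1, 1) 0.2857
```

For these counts, the floor-then-distribute procedure described below also arrives at (2, 1, 1), without enumerating any allocations.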

Parameters:

  • class_counts (ndarray of int) -- Population count per class.
  • n_draws (int) -- Total number of samples to draw from the overall population.
  • rng (random state) -- A NumPy random state used to break ties randomly, avoiding biases.

Returns:

  • sampled_classes (ndarray of int) -- Number of samples drawn from each class. The sum of the returned array equals n_draws.

Algorithm:

  1. Computes a continuous proportional allocation: n_draws * class_counts / class_counts.sum().
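  2. Floors the continuous values, so the total never exceeds n_draws (it usually falls short).
  3. Adds at most one extra draw to each class, visiting classes in descending order of fractional remainder until the total reaches n_draws; when more classes share a remainder than there are draws left, the recipients are chosen at random with rng.choice.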
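import numpy as np

def approximate_mode(class_counts, n_draws, rng):
    # Proportional (continuous) allocation of n_draws across classes.
    continuous = n_draws * class_counts / class_counts.sum()
    # Flooring never overshoots n_draws but usually undershoots it.
    floored = np.floor(continuous)
    need_to_add = int(n_draws - floored.sum())
    if need_to_add > 0:
        remainder = continuous - floored
        # Visit distinct remainders from largest to smallest.
        values = np.sort(np.unique(remainder))[::-1]
        for value in values:
            (inds,) = np.where(remainder == value)
            # Take every class with this remainder, or a random subset
            # when fewer draws are left than tied classes.
            add_now = min(len(inds), need_to_add)
            inds = rng.choice(inds, size=add_now, replace=False)
            floored[inds] += 1
            need_to_add -= add_now
            if need_to_add == 0:
                break
    return floored.astype(np.int64)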
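A step-by-step trace of the same three steps on class_counts = [5, 3, 2] and n_draws = 4 shows where each draw comes from. This standalone illustration inlines the logic rather than importing it; the seed is arbitrary, and because only one class has the largest remainder here, the random tie-break does not affect the result.

```python
import numpy as np

rng = np.random.RandomState(0)  # arbitrary seed; no ties arise in this example
class_counts = np.array([5, 3, 2])
n_draws = 4

# Step 1: proportional (continuous) allocation.
continuous = n_draws * class_counts / class_counts.sum()  # [2.0, 1.2, 0.8]

# Step 2: floor, which can only undershoot n_draws.
floored = np.floor(continuous)                             # [2.0, 1.0, 0.0]
need_to_add = int(n_draws - floored.sum())                 # 1 draw left over

# Step 3: hand leftover draws to the largest fractional remainders first.
remainder = continuous - floored                           # [0.0, ~0.2, 0.8]
for value in np.sort(np.unique(remainder))[::-1]:
    if need_to_add == 0:
        break
    (inds,) = np.where(remainder == value)
    add_now = min(len(inds), need_to_add)
    floored[rng.choice(inds, size=add_now, replace=False)] += 1
    need_to_add -= add_now

result = floored.astype(np.int64)
print(result)  # [2 1 1]
```

The class with remainder 0.8 receives the single leftover draw, so the allocation sums exactly to n_draws.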

stratified_shuffle_split_generate_indices

Provides train/test indices to split data into train/test sets while preserving class distributions. The implementation is based on scikit-learn's StratifiedShuffleSplit.

Parameters:

  • y (array-like) -- Target labels used for stratification.
  • n_train (int) -- Absolute number of train samples.
  • n_test (int) -- Absolute number of test samples.
  • rng (random state) -- A NumPy random state for reproducibility.
  • n_splits (int, default=10) -- Number of re-shuffling and splitting iterations.

Yields:

  • A tuple (train, test) of index arrays for each split iteration.

Validation:

  • Raises ValueError if any class has fewer than 2 samples.
  • Raises ValueError if n_train or n_test is less than the number of classes.
import numpy as np

def stratified_shuffle_split_generate_indices(y, n_train, n_test, rng, n_splits=10):
    # Encode labels as integer class codes 0..n_classes-1 and count each class.
    classes, y_indices = np.unique(y, return_inverse=True)
    n_classes = classes.shape[0]
    class_counts = np.bincount(y_indices)
    if np.min(class_counts) < 2:
        raise ValueError("Minimum class count error")
    if n_train < n_classes:
        raise ValueError(
            "The train_size = %d should be greater or equal to the number of classes = %d" % (n_train, n_classes)
        )
    if n_test < n_classes:
        raise ValueError(
            "The test_size = %d should be greater or equal to the number of classes = %d" % (n_test, n_classes)
        )
    # Group sample indices by class: stable sort by class code, cut at cumulative counts.
    class_indices = np.split(np.argsort(y_indices, kind="mergesort"), np.cumsum(class_counts)[:-1])
    for _ in range(n_splits):
        # Per-class train counts, then per-class test counts from what remains
        # (approximate_mode is the function defined above).
        n_i = approximate_mode(class_counts, n_train, rng)
        class_counts_remaining = class_counts - n_i
        t_i = approximate_mode(class_counts_remaining, n_test, rng)
        train = []
        test = []
        for i in range(n_classes):
            # Shuffle this class's samples; the first n_i go to train, the next t_i to test.
            permutation = rng.permutation(class_counts[i])
            perm_indices_class_i = class_indices[i].take(permutation, mode="clip")
            train.extend(perm_indices_class_i[: n_i[i]])
            test.extend(perm_indices_class_i[n_i[i] : n_i[i] + t_i[i]])
        # Shuffle across classes so the output is not grouped by label.
        train = rng.permutation(train)
        test = rng.permutation(test)
        yield train, test
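The class_indices line is the least obvious part of the generator: a stable argsort of the integer class codes lists sample positions class by class, and the cumulative class counts (without the final total) give the cut points between classes. The standalone snippet below traces that idiom on a small, made-up string-labelled y.

```python
import numpy as np

y = np.array(["b", "a", "b", "c", "a", "b"])

# Map labels to integer codes 0..n_classes-1 (classes come back sorted).
classes, y_indices = np.unique(y, return_inverse=True)
class_counts = np.bincount(y_indices)             # [2, 3, 1]

# A stable sort keeps the original order of samples within each class.
order = np.argsort(y_indices, kind="mergesort")   # [1, 4, 0, 2, 5, 3]

# Cut points between classes: cumulative counts without the final total.
cuts = np.cumsum(class_counts)[:-1]               # [2, 5]
class_indices = np.split(order, cuts)

for label, idx in zip(classes, class_indices):
    print(label, idx.tolist())
# a [1, 4]
# b [0, 2, 5]
# c [3]
```

Because each group is a separate array of original positions, the generator can shuffle and slice every class independently.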

Dependencies

| Dependency | Type | Purpose |
|---|---|---|
| numpy | External | Array operations, random state, bincount, and permutation |

Usage Example

import numpy as np
from datasets.utils.stratify import stratified_shuffle_split_generate_indices

# 12 labels: 4 samples in each of 3 classes
y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
rng = np.random.RandomState(42)

# With 4 samples per class, n_train=6 and n_test=3 allocate
# 2 train samples and 1 test sample to every class.
for train_indices, test_indices in stratified_shuffle_split_generate_indices(
    y, n_train=6, n_test=3, rng=rng, n_splits=1
):
    print("Train indices:", train_indices)
    print("Test indices:", test_indices)
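One way to confirm that a split is stratified is to count labels on each side. The helper below, summarize_split, is hypothetical (not part of datasets): it rejects overlapping train/test indices and returns per-class counts. With 4 samples per class, n_train=6 and n_test=3, the allocation works out to 2 train samples and 1 test sample per class, so each split from the example above should summarize to those counts. The index arrays at the bottom are hand-written stand-ins for one generator output.

```python
from collections import Counter

import numpy as np


def summarize_split(y, train, test):
    """Hypothetical helper: per-class label counts for a train/test split."""
    y = np.asarray(y)
    # A valid split must never place the same sample on both sides.
    overlap = set(np.asarray(train).tolist()) & set(np.asarray(test).tolist())
    if overlap:
        raise ValueError(f"train and test share indices: {sorted(overlap)}")
    train_counts = Counter(y[train].tolist())
    test_counts = Counter(y[test].tolist())
    return dict(train_counts), dict(test_counts)


y = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
# Hand-written stand-ins for one (train, test) pair from the generator.
train = np.array([3, 0, 5, 9, 6, 10])
test = np.array([1, 7, 11])
print(summarize_split(y, train, test))
# ({0: 2, 1: 2, 2: 2}, {0: 1, 1: 1, 2: 1})
```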

Design Notes

  • The approximate_mode function uses a floor-then-distribute (largest-remainder) approach rather than computing the exact mode of the multivariate hypergeometric distribution, which would be computationally expensive. The approximation should not be off by more than one from the true mode.
  • Tie-breaking is handled via random selection (rng.choice), which prevents systematic biases toward particular classes.
  • The generator pattern (yield) allows memory-efficient iteration over multiple splits without computing all splits upfront.
  • This implementation mirrors scikit-learn's StratifiedShuffleSplit but is self-contained within the datasets library to avoid a hard dependency on scikit-learn.
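The tie-breaking note can be checked empirically. In the sketch below, three classes of one sample each and a single draw produce three equal remainders of 1/3. A deterministic rule that always takes the lowest tied index (a hypothetical alternative, not what datasets does) hands the draw to class 0 every time, while random selection among the tied classes, as approximate_mode does with rng.choice, spreads it across classes. The seed and trial count are arbitrary.

```python
import numpy as np

# Three classes of one sample each and a single draw: the continuous
# allocation is [1/3, 1/3, 1/3], every floor is 0, and all three classes
# tie for the one leftover draw.
class_counts = np.array([1, 1, 1])
continuous = 1 * class_counts / class_counts.sum()
remainder = continuous - np.floor(continuous)
(tied,) = np.where(remainder == remainder.max())  # indices 0, 1, 2

rng = np.random.RandomState(0)
n_trials = 3000
first_index = np.zeros(3, dtype=np.int64)
randomized = np.zeros(3, dtype=np.int64)
for _ in range(n_trials):
    # Hypothetical deterministic rule: always take the lowest tied index.
    first_index[tied[0]] += 1
    # Rule used by approximate_mode: pick among tied classes with rng.choice.
    randomized[rng.choice(tied, size=1, replace=False)] += 1

print("first-index ties:", first_index)  # class 0 every time
print("random ties:     ", randomized)   # roughly 1000 per class
```

Over many splits, a fixed tie-break would systematically over-represent low-index classes in train; random selection avoids that drift.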
