Implementation:Scikit learn contrib Imbalanced learn BalancedBatchGenerator
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Data_Preprocessing, Imbalanced_Learning |
| Last Updated | 2026-02-09 03:00 GMT |
Overview
A concrete tool, provided by the imbalanced-learn library, for generating balanced mini-batches when training Keras models.
Description
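The BalancedBatchGenerator class is a Keras Sequence (or PyDataset, on newer Keras versions) that yields balanced mini-batches. On initialization, it fits a sampler (default: RandomUnderSampler) to select a balanced set of sample indices. It then shuffles those indices, because the sampler typically returns them grouped by class. Each batch is a consecutive slice of batch_size shuffled indices. Because the class implements the Sequence interface, an instance can be passed directly to model.fit().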
Usage
Use this class when training Keras models on imbalanced data, passing an instance directly to model.fit() as the training data.
Code Reference
Source Location
- Repository: imbalanced-learn
- File: imblearn/keras/_generator.py
- Lines: L64-202
Signature
```python
class BalancedBatchGenerator(*ParentClass):
    def __init__(
        self,
        X,
        y,
        *,
        sample_weight=None,
        sampler=None,
        batch_size=32,
        keep_sparse=False,
        random_state=None,
    ):
        """
        Args:
            X: ndarray of shape (n_samples, n_features) - Training data.
            y: ndarray of shape (n_samples,) or (n_samples, n_classes) - Targets.
            sample_weight: ndarray or None - Sample weights.
            sampler: sampler with sample_indices_ or None - Balancing sampler
                (default: RandomUnderSampler).
            batch_size: int - Samples per batch (default: 32).
            keep_sparse: bool - Preserve sparse input (default: False).
            random_state: int, RandomState, or None - Seed.
        """
```
Import
from imblearn.keras import BalancedBatchGenerator
I/O Contract
Inputs
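| Name | Type | Required | Description |
|---|---|---|---|
| X | ndarray of shape (n_samples, n_features) | Yes | Training features |
| y | ndarray of shape (n_samples,) or (n_samples, n_classes) | Yes | Targets (can be one-hot encoded) |
| sample_weight | ndarray of shape (n_samples,) or None | No | Per-sample weights; when given, each batch also carries the matching weights (default: None) |
| sampler | sampler with sample_indices_ or None | No | Balancing sampler, cloned before fitting (default: RandomUnderSampler) |
| batch_size | int | No | Samples per gradient update (default: 32) |
| keep_sparse | bool | No | If True, sparse inputs stay sparse in the batches; otherwise batches are dense (default: False) |
| random_state | int, RandomState, or None | No | Seeds the default sampler and the shuffling of indices (default: None) |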
Outputs
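| Name | Type | Description |
|---|---|---|
| __getitem__(index) | tuple | Balanced mini-batch: (X_batch, y_batch), or (X_batch, y_batch, sample_weight_batch) when sample_weight is given |
| __len__() | int | Number of full batches per epoch (number of resampled samples // batch_size); a final partial batch is dropped |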
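The following sketch makes the output contract concrete by reimplementing the slicing logic in a small class. BatchContractSketch is hypothetical. It does not subclass any Keras type, and it uses plain NumPy/SciPy indexing in place of the library's internal helpers. It takes a precomputed index array (here simply np.arange(95), not a balanced selection) so that only the batching behaviour is shown.

```python
import numpy as np
import scipy.sparse as sp


class BatchContractSketch:
    """Hypothetical stand-in (no Keras) for the documented output contract."""

    def __init__(self, X, y, indices, batch_size=32, sample_weight=None, keep_sparse=False):
        self.X, self.y = X, y
        self.indices = np.asarray(indices)
        self.batch_size = batch_size
        self.sample_weight = sample_weight
        self.keep_sparse = keep_sparse

    def __len__(self):
        # Floor division: a trailing partial batch is never served.
        return self.indices.size // self.batch_size

    def __getitem__(self, index):
        rows = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        X_batch, y_batch = self.X[rows], self.y[rows]
        if sp.issparse(X_batch) and not self.keep_sparse:
            X_batch = X_batch.toarray()  # dense batches unless keep_sparse=True
        if self.sample_weight is None:
            return X_batch, y_batch
        return X_batch, y_batch, self.sample_weight[rows]


X = sp.csr_matrix(np.arange(500, dtype=float).reshape(100, 5))
y = np.arange(100) % 2
w = np.linspace(0.5, 1.5, 100)
gen = BatchContractSketch(X, y, indices=np.arange(95), batch_size=10, sample_weight=w)
```

With 95 indices and batch_size=10, len(gen) is 9, and indices 90 to 94 are never served in an epoch.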
Usage Examples
```python
import tensorflow
from sklearn.datasets import load_iris
from imblearn.datasets import make_imbalance
from imblearn.keras import BalancedBatchGenerator

# Prepare data: make iris imbalanced (30 / 50 / 40 samples per class)
iris = load_iris()
X, y = make_imbalance(
    iris.data, iris.target,
    sampling_strategy={0: 30, 1: 50, 2: 40},
)
y_cat = tensorflow.keras.utils.to_categorical(y, 3)  # one-hot targets

# Build a single-layer softmax classifier
model = tensorflow.keras.models.Sequential([
    tensorflow.keras.Input(shape=(X.shape[1],)),
    tensorflow.keras.layers.Dense(y_cat.shape[1], activation="softmax"),
])
model.compile(optimizer="sgd", loss="categorical_crossentropy", metrics=["accuracy"])

# Train with balanced batches. The default RandomUnderSampler keeps
# 30 samples per class (90 in total), giving 90 // 10 = 9 batches per epoch.
training_generator = BalancedBatchGenerator(X, y_cat, batch_size=10, random_state=42)
model.fit(training_generator, epochs=10, verbose=0)
```