Implementation: Scikit-learn-contrib Imbalanced-learn estimator_checks
Source: imblearn/utils/estimator_checks.py (lines 1-961)
Purpose: Provides sampler and classifier compatibility checks analogous to scikit-learn's check_estimator. These checks verify that imbalanced-learn estimators comply with the library's API contract across a wide range of input formats, edge cases, and behavioral requirements.
Import:
from imblearn.utils.estimator_checks import parametrize_with_checks
Key Public Functions
parametrize_with_checks
def parametrize_with_checks(estimators, *, legacy=True, expected_failed_checks=None):
Returns a pytest.mark.parametrize decorator that generates (estimator, check) test cases for every estimator in the input list. Checks that are expected to fail (specified via a callable returning {check_name: reason}) are marked as xfail.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| estimators | list of estimator instances | Estimators to generate checks for |
| legacy | bool, default=True | Whether to include legacy checks |
| expected_failed_checks | callable or None | Callable mapping estimator to dict of {check_name: reason} |
Usage example:
from imblearn.utils.estimator_checks import parametrize_with_checks
from imblearn.over_sampling import SMOTE
@parametrize_with_checks([SMOTE()])
def test_smote_compatible(estimator, check):
    check(estimator)
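The `expected_failed_checks` argument is described above as a callable that maps an estimator to a `{check_name: reason}` dictionary. A minimal sketch of such a callable follows. `MyRatioSampler` and the reason string are hypothetical placeholders, not part of imbalanced-learn, and the decorator call is shown only as a comment.

```python
class MyRatioSampler:
    """Hypothetical stand-in for a custom sampler (not an imbalanced-learn class)."""


def expected_failures(estimator):
    """Return {check_name: reason} for checks known to fail on this estimator."""
    if isinstance(estimator, MyRatioSampler):
        return {
            # Key: name of a check function. Value: reason reported for the xfail.
            "check_samplers_sparse": "sparse input is not supported yet (hypothetical)",
        }
    # An empty dict means no check is expected to fail.
    return {}


# Intended use, assuming MyRatioSampler were a real sampler:
#
# @parametrize_with_checks([MyRatioSampler()], expected_failed_checks=expected_failures)
# def test_sampler(estimator, check):
#     check(estimator)
```

Keeping the mapping in one function makes it easy to see which checks are knowingly excluded for each estimator.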
estimator_checks_generator
def estimator_checks_generator(estimator, *, legacy=True, expected_failed_checks=None, mark=None):
A generator that yields (estimator, check) tuples for a single estimator. The mark parameter controls how expected failures are handled: "xfail" wraps them as pytest.param(..., marks=pytest.mark.xfail), while "skip" wraps the check in a function that raises SkipTest.
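Outside pytest, the yielded pairs can be consumed in a plain loop. The sketch below shows one way to do that while tolerating checks that raise `SkipTest`. It assumes `SkipTest` is `unittest.SkipTest`; `fake_checks_generator` and both check functions are hypothetical stand-ins so the example runs without imbalanced-learn installed.

```python
from unittest import SkipTest


def check_passes(estimator):
    """Hypothetical check that succeeds."""
    assert estimator is not None


def check_skipped(estimator):
    """Hypothetical check behaving like one wrapped under mark="skip"."""
    raise SkipTest("known failure (hypothetical reason)")


def fake_checks_generator(estimator):
    """Stand-in for estimator_checks_generator(estimator, mark="skip")."""
    yield estimator, check_passes
    yield estimator, check_skipped


def run_checks(pairs):
    """Run every (estimator, check) pair and record passed/skipped check names."""
    outcome = {"passed": [], "skipped": []}
    for estimator, check in pairs:
        try:
            check(estimator)
        except SkipTest:
            outcome["skipped"].append(check.__name__)
        else:
            outcome["passed"].append(check.__name__)
    return outcome


result = run_checks(fake_checks_generator(object()))
```

With the real generator, `fake_checks_generator(...)` would be replaced by a call to `estimator_checks_generator` on an actual estimator instance.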
Internal Dispatch
The module uses internal generator functions to dispatch checks based on estimator capabilities:
def _yield_all_checks(estimator, legacy=True):
    tags = get_tags(estimator)
    if tags._skip_test:
        warnings.warn(...)
        return
    if hasattr(estimator, "fit_resample"):
        for check in _yield_sampler_checks(estimator):
            yield check
    if hasattr(estimator, "predict"):
        for check in _yield_classifier_checks(estimator):
            yield check
Sampler checks are further filtered by the estimator's tags (sparse support, DataFrame support, string support, NaN tolerance).
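The capability-based dispatch can be illustrated with a small, self-contained sketch. The three classes below are hypothetical stand-ins, and the generator yields labels rather than real check functions. The point is that the two `hasattr` tests are independent, so an estimator exposing both methods receives both families of checks.

```python
def yield_checks_by_capability(estimator):
    """Yield a label for each check family the estimator qualifies for."""
    if hasattr(estimator, "fit_resample"):
        yield "sampler checks"
    if hasattr(estimator, "predict"):
        yield "classifier checks"


class OnlySampler:
    """Hypothetical estimator exposing only fit_resample."""

    def fit_resample(self, X, y):
        return X, y


class OnlyClassifier:
    """Hypothetical estimator exposing only predict."""

    def predict(self, X):
        return [0 for _ in X]


class Both(OnlySampler, OnlyClassifier):
    """Hypothetical estimator exposing both methods."""
```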
Individual Check Functions
Sampler Checks
| Check Function | What It Verifies |
|---|---|
| check_target_type | Rejects continuous targets with ValueError("Unknown label type:") and rejects multilabel targets |
| check_samplers_one_label | Raises ValueError when only a single class is present in y |
| check_samplers_fit | After fit_resample, the fitted attribute sampling_strategy_ exists |
| check_samplers_fit_resample | Over-samplers increase minority to at least majority count; under-samplers reduce to minority count; cleaning samplers remove noisy samples from non-minority classes |
| check_samplers_sampling_strategy_fit_resample | A custom sampling_strategy dict/list leaves untargeted classes unchanged |
| check_samplers_sparse | Sparse CSR input produces sparse output matching dense results (conditional on input_tags.sparse) |
| check_samplers_pandas | DataFrame/Series input returns DataFrame/Series with preserved column names and series name (conditional on input_tags.dataframe) |
| check_samplers_pandas_sparse | Pandas sparse DataFrame input returns sparse DataFrame with preserved dtypes (conditional on input_tags.dataframe) |
| check_samplers_list | Python list inputs return Python list outputs with matching values |
| check_samplers_multiclass_ova | Multiclass target and OVA-binarized target produce consistent resampled results |
| check_samplers_preserve_dtype | float32/int32 dtypes in X and y are preserved through resampling |
| check_samplers_sample_indices | sample_indices_ attribute exists if and only if sampler_tags.sample_indices is True |
| check_samplers_string | Object-dtype string feature arrays are handled correctly (conditional on input_tags.string) |
| check_samplers_nan | NaN values in features are preserved through resampling (conditional on input_tags.allow_nan) |
| check_samplers_2d_target | A 2D column-vector target y.reshape(-1, 1) is accepted without error |
| check_sampler_get_feature_names_out | get_feature_names_out returns a numpy array of strings with length matching output features |
| check_sampler_get_feature_names_out_pandas | When fit on a DataFrame, get_feature_names_out() matches column names and rejects invalid input features |
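To make the check_samplers_fit_resample row concrete, the sketch below expresses the over-sampling and under-sampling conditions as class-count comparisons. It is an illustration of the described behavior rather than the library's code; the helper names are hypothetical, and the library may apply different comparisons for particular samplers.

```python
from collections import Counter


def oversampling_ok(y, y_res):
    """True if every resampled class reaches at least the original majority count."""
    n_majority = max(Counter(y).values())
    return all(count >= n_majority for count in Counter(y_res).values())


def undersampling_ok(y, y_res):
    """True if every resampled class is reduced to the original minority count."""
    n_minority = min(Counter(y).values())
    return all(count == n_minority for count in Counter(y_res).values())


y = [0] * 2 + [1] * 3 + [2] * 5        # imbalanced toy target: counts 2, 3, 5
y_over = [0] * 5 + [1] * 5 + [2] * 5   # shape of an over-sampled target
y_under = [0] * 2 + [1] * 2 + [2] * 2  # shape of an under-sampled target
```

Comparing counts rather than exact samples keeps the condition independent of which samples a given sampler generates or discards.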
Classifier Checks
| Check Function | What It Verifies |
|---|---|
| check_classifier_on_multilabel_or_multioutput_targets | fit() raises ValueError("Multilabel and multioutput targets are not supported.") for multilabel data |
| check_classifiers_with_encoded_labels | Categorical string labels (e.g. iris species names) are handled correctly; classes_ and predict output match the original categories (regression test for GitHub issue #709) |
Parameter and Feature Name Checks
| Check Function | What It Verifies |
|---|---|
| check_param_validation | Every constructor parameter has a matching entry in _parameter_constraints; invalid types and values raise ValueError with informative messages |
| check_dataframe_column_names_consistency | After fitting on a DataFrame, feature_names_in_ is set; prediction methods raise ValueError for reordered, renamed, or missing feature names |
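The first half of check_param_validation, matching constructor parameters against `_parameter_constraints`, can be sketched with `inspect`. `ToySampler`, its constraint values, and `constraint_mismatches` are hypothetical and only illustrate the comparison; the second half of the check (feeding invalid values and expecting ValueError) is not covered here.

```python
import inspect


class ToySampler:
    """Hypothetical estimator following the _parameter_constraints convention."""

    # Illustrative constraint values; only the keys matter for this sketch.
    _parameter_constraints = {"ratio": [float], "random_state": [int, None]}

    def __init__(self, ratio=1.0, random_state=None):
        self.ratio = ratio
        self.random_state = random_state


def constraint_mismatches(estimator_cls):
    """Return (missing, unexpected) parameter names as sorted lists.

    missing: constructor parameters without a constraint entry.
    unexpected: constraint entries without a constructor parameter.
    """
    init_params = set(inspect.signature(estimator_cls.__init__).parameters) - {"self"}
    constrained = set(estimator_cls._parameter_constraints)
    return sorted(init_params - constrained), sorted(constrained - init_params)
```

For a compliant estimator both lists are empty.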
Test Data Generation
The module generates imbalanced test data using scikit-learn's make_classification:
def sample_dataset_generator():
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    return X, y
This produces a 3-class dataset with class proportions of approximately 20%/30%/50%, providing a realistic imbalanced scenario for testing.
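The nominal class sizes follow directly from the `weights` and `n_samples` arguments. The short calculation below is a sketch of that arithmetic only; it does not call make_classification, and the counts in the generated data can deviate slightly from these nominal values.

```python
n_samples = 1000
weights = [0.2, 0.3, 0.5]

# Nominal class sizes implied by the weights.
expected_counts = [round(w * n_samples) for w in weights]

# Ratio between the largest and the smallest class.
imbalance_ratio = max(expected_counts) / min(expected_counts)
```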
Estimator Parameter Tuning for Tests
The helper _set_checking_parameters adjusts estimator parameters to make checks faster or deterministic:
def _set_checking_parameters(estimator):
    params = estimator.get_params()
    name = estimator.__class__.__name__
    if "n_estimators" in params:
        estimator.set_params(n_estimators=min(5, estimator.n_estimators))
    if name == "ClusterCentroids":
        estimator.set_params(
            voting="soft",
            estimator=KMeans(random_state=0, algorithm="lloyd", n_init=1),
        )
    if name == "KMeansSMOTE":
        estimator.set_params(kmeans_estimator=12)
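The `n_estimators` rule only ever shrinks an ensemble: `min(5, ...)` caps large values and leaves smaller ones untouched. The sketch below isolates that rule. `ToyEnsemble` is a hypothetical class that mimics the `get_params`/`set_params` interface, and `cap_n_estimators` is an illustrative helper, not a library function.

```python
class ToyEnsemble:
    """Hypothetical estimator exposing get_params/set_params."""

    def __init__(self, n_estimators=100):
        self.n_estimators = n_estimators

    def get_params(self):
        return {"n_estimators": self.n_estimators}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


def cap_n_estimators(estimator, limit=5):
    """Reduce n_estimators to at most `limit`; never increase it."""
    if "n_estimators" in estimator.get_params():
        estimator.set_params(n_estimators=min(limit, estimator.n_estimators))
    return estimator
```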
Expected Failure Handling
The _maybe_mark and _should_be_skipped_or_marked functions allow checks to be marked as xfail or skip when they are known to fail for a given estimator. This integrates with pytest's parametrization system:
def _maybe_mark(estimator, check, expected_failed_checks=None, mark=None, pytest=None):
    should_be_marked, reason = _should_be_skipped_or_marked(
        estimator, check, expected_failed_checks
    )
    if not should_be_marked or mark is None:
        return estimator, check
    # Name used in the skip message below.
    estimator_name = estimator.__class__.__name__
    if mark == "xfail":
        return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason))
    else:
        @wraps(check)
        def wrapped(*args, **kwargs):
            raise SkipTest(
                f"Skipping {_check_name(check)} for {estimator_name}: {reason}"
            )
        return estimator, wrapped
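The excerpt above does not show how the decision itself is made. One plausible reading, sketched here under stated assumptions, is a lookup of the check's name in the `{check_name: reason}` dictionary. The assumptions are that the dictionary has already been resolved for the estimator and that checks may arrive wrapped in `functools.partial`. The names `check_name` and `should_be_marked` are illustrative and are not the library's private helpers.

```python
from functools import partial


def check_name(check):
    """Return a check's function name, looking through functools.partial."""
    return check.func.__name__ if isinstance(check, partial) else check.__name__


def should_be_marked(check, expected_failed_checks=None):
    """Return (flag, reason) for a check given a {check_name: reason} dict."""
    expected = expected_failed_checks or {}
    name = check_name(check)
    if name in expected:
        return True, expected[name]
    return False, "check is not expected to fail"


def check_samplers_sparse(name, estimator):
    """Placeholder with the same name as a real check, used for the lookup."""


def check_samplers_list(name, estimator):
    """Placeholder with the same name as a real check, used for the lookup."""


failures = {"check_samplers_sparse": "sparse input unsupported (hypothetical)"}
marked = should_be_marked(partial(check_samplers_sparse, "Toy"), failures)
unmarked = should_be_marked(check_samplers_list, failures)
```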
Tag-based Conditional Check Yielding
Sampler checks are conditionally yielded based on the estimator's tags, read via get_tags(sampler):
def _yield_sampler_checks(sampler):
    tags = get_tags(sampler)
    accept_sparse = tags.input_tags.sparse
    accept_dataframe = tags.input_tags.dataframe
    accept_string = tags.input_tags.string
    allow_nan = tags.input_tags.allow_nan
    yield check_target_type
    yield check_samplers_one_label
    yield check_samplers_fit
    yield check_samplers_fit_resample
    yield check_samplers_sampling_strategy_fit_resample
    if accept_sparse:
        yield check_samplers_sparse
    if accept_dataframe:
        yield check_samplers_pandas
        yield check_samplers_pandas_sparse
    if accept_string:
        yield check_samplers_string
    if allow_nan:
        yield check_samplers_nan
    yield check_samplers_list
    yield check_samplers_multiclass_ova
    yield check_samplers_preserve_dtype
    yield check_samplers_sample_indices
    yield check_samplers_2d_target
    yield check_sampler_get_feature_names_out
    yield check_sampler_get_feature_names_out_pandas
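The effect of the four input tags on the optional checks can be reproduced in isolation. In the sketch below, `InputTags` is a hypothetical dataclass holding only the four flags named above (the real tag object comes from `get_tags`), and the function returns check names as strings instead of yielding the check functions.

```python
from dataclasses import dataclass


@dataclass
class InputTags:
    """Hypothetical stand-in carrying the four flags read by the generator."""

    sparse: bool = False
    dataframe: bool = False
    string: bool = False
    allow_nan: bool = False


def select_optional_checks(input_tags):
    """Return the names of the optional sampler checks enabled by the tags."""
    selected = []
    if input_tags.sparse:
        selected.append("check_samplers_sparse")
    if input_tags.dataframe:
        selected.extend(["check_samplers_pandas", "check_samplers_pandas_sparse"])
    if input_tags.string:
        selected.append("check_samplers_string")
    if input_tags.allow_nan:
        selected.append("check_samplers_nan")
    return selected
```

A sampler that declares none of these capabilities still receives every unconditional check; the tags only add to that baseline.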