

Implementation: Scikit-learn-contrib Imbalanced-learn estimator_checks

Source: imblearn/utils/estimator_checks.py (lines 1-961)

Purpose: Provides sampler and classifier compatibility checks analogous to scikit-learn's check_estimator. These checks verify that imbalanced-learn estimators comply with the library's API contract across a wide range of input formats, edge cases, and behavioral requirements.

Import:

from imblearn.utils.estimator_checks import parametrize_with_checks

Key Public Functions

parametrize_with_checks

def parametrize_with_checks(estimators, *, legacy=True, expected_failed_checks=None):

Returns a pytest.mark.parametrize decorator that generates (estimator, check) test cases for every estimator in the input list. Checks that are expected to fail (specified via a callable returning {check_name: reason}) are marked as xfail.

Parameters:

| Parameter | Type | Description |
|---|---|---|
| estimators | list of estimator instances | Estimators to generate checks for |
| legacy | bool, default=True | Whether to include legacy checks |
| expected_failed_checks | callable or None, default=None | Callable mapping an estimator to a dict of {check_name: reason} |
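As a rough illustration of how an expected_failed_checks callable is consulted, the stdlib-only sketch below pairs every estimator with every check and attaches a reason to the pairs that the callable lists. ToySampler, toy_check_fit, toy_check_sparse, expected_failures and build_cases are hypothetical names. The real decorator goes further: it wraps each marked pair in pytest.param(..., marks=pytest.mark.xfail(reason=...)) rather than returning a plain tuple.

```python
def toy_check_fit(estimator):
    """Hypothetical check: passes if the estimator exposes fit_resample."""
    assert hasattr(estimator, "fit_resample")


def toy_check_sparse(estimator):
    """Hypothetical check that this toy sampler is known to fail."""
    raise AssertionError("sparse input not supported")


class ToySampler:
    """Hypothetical sampler used only to drive the sketch."""

    def fit_resample(self, X, y):
        return X, y


def expected_failures(estimator):
    # Same shape as the expected_failed_checks callable: estimator -> {check_name: reason}.
    if isinstance(estimator, ToySampler):
        return {"toy_check_sparse": "ToySampler does not accept sparse input"}
    return {}


def build_cases(estimators, checks, expected_failed_checks=None):
    """Return one (estimator, check, xfail_reason_or_None) tuple per pair."""
    cases = []
    for estimator in estimators:
        failures = expected_failed_checks(estimator) if expected_failed_checks else {}
        for check in checks:
            cases.append((estimator, check, failures.get(check.__name__)))
    return cases


cases = build_cases([ToySampler()], [toy_check_fit, toy_check_sparse], expected_failures)
for estimator, check, reason in cases:
    print(check.__name__, "-> xfail" if reason else "-> run")
```

The callable runs once per estimator. Its result is a dict lookup keyed by check name, so a check that the dict does not mention runs normally.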

Usage example:

from imblearn.utils.estimator_checks import parametrize_with_checks
from imblearn.over_sampling import SMOTE

@parametrize_with_checks([SMOTE()])
def test_smote_compatible(estimator, check):
    check(estimator)

estimator_checks_generator

def estimator_checks_generator(estimator, *, legacy=True, expected_failed_checks=None, mark=None):

A generator that yields (estimator, check) tuples for a single estimator. The mark parameter controls how expected failures are handled. With "xfail", each one is wrapped as pytest.param(estimator, check, marks=pytest.mark.xfail(reason=...)). With "skip", the check is replaced by a wrapper that raises SkipTest. With None (the default), it is yielded unchanged.
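The "skip" behaviour can be sketched without pytest. In the minimal sketch below, check_a, check_b, toy_checks_generator and run_all are hypothetical names. The sketch handles only mark=None and mark="skip", because the "xfail" branch needs pytest.param.

```python
from functools import wraps
from unittest import SkipTest


def check_a(estimator):
    """Hypothetical check that passes."""


def check_b(estimator):
    """Hypothetical check known to fail for this estimator."""
    raise AssertionError("known failure")


def _make_skipper(check, reason):
    # Keep the original __name__/__doc__ so reports still show the check's name.
    @wraps(check)
    def wrapped(*args, **kwargs):
        raise SkipTest(f"Skipping {check.__name__}: {reason}")

    return wrapped


def toy_checks_generator(estimator, checks, expected_failed_checks=None, mark=None):
    """Yield (estimator, check) pairs; known failures are wrapped when mark == "skip"."""
    expected_failed_checks = expected_failed_checks or {}
    for check in checks:
        reason = expected_failed_checks.get(check.__name__)
        if reason is None or mark is None:
            yield estimator, check  # not a known failure, or marking disabled
        else:
            yield estimator, _make_skipper(check, reason)


def run_all(pairs):
    """Plain-Python runner: record whether each check passed or was skipped."""
    results = {}
    for estimator, check in pairs:
        try:
            check(estimator)
            results[check.__name__] = "passed"
        except SkipTest:
            results[check.__name__] = "skipped"
    return results


failures = {"check_b": "documented limitation"}
print(run_all(toy_checks_generator(object(), [check_a, check_b], failures, mark="skip")))
# {'check_a': 'passed', 'check_b': 'skipped'}
```

With mark=None the original check_b is yielded as-is. Running it would raise its AssertionError, which is why a marking mode is needed for estimators with documented limitations.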

Internal Dispatch

The module uses internal generator functions to dispatch checks based on estimator capabilities:

def _yield_all_checks(estimator, legacy=True):
    tags = get_tags(estimator)
    if tags._skip_test:
        warnings.warn(...)
        return
    if hasattr(estimator, "fit_resample"):
        for check in _yield_sampler_checks(estimator):
            yield check
    if hasattr(estimator, "predict"):
        for check in _yield_classifier_checks(estimator):
            yield check

Sampler checks are further filtered by the estimator's tags (sparse support, DataFrame support, string support, NaN tolerance).
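The duck-typed dispatch in _yield_all_checks can be mimicked with plain classes. In the sketch below, the skip_test attribute is a hypothetical stand-in for the _skip_test tag. The generator yields placeholder check names rather than the real check functions, which are assumptions of the sketch.

```python
import warnings


class SamplerLike:
    """Hypothetical object exposing only fit_resample."""

    def fit_resample(self, X, y):
        return X, y


class ClassifierLike:
    """Hypothetical object exposing fit and predict."""

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [0] * len(X)


class SkippedSampler(SamplerLike):
    skip_test = True  # hypothetical stand-in for the _skip_test tag


def yield_all_check_names(estimator):
    """Duck-typed dispatch: which groups of checks apply to this object?"""
    if getattr(estimator, "skip_test", False):
        warnings.warn(f"Explicit SKIP for estimator {type(estimator).__name__}")
        return
    if hasattr(estimator, "fit_resample"):
        yield from ("sampler_check_1", "sampler_check_2")
    if hasattr(estimator, "predict"):
        yield "classifier_check_1"


print(list(yield_all_check_names(SamplerLike())))     # ['sampler_check_1', 'sampler_check_2']
print(list(yield_all_check_names(ClassifierLike())))  # ['classifier_check_1']
```

The two hasattr tests are independent. An object exposing both fit_resample and predict therefore receives both groups of checks.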

Individual Check Functions

Sampler Checks

| Check function | What it verifies |
|---|---|
| check_target_type | Rejects continuous targets with ValueError("Unknown label type:") and rejects multilabel targets |
| check_samplers_one_label | Raises ValueError when only a single class is present in y |
| check_samplers_fit | After fit_resample, the fitted attribute sampling_strategy_ exists |
| check_samplers_fit_resample | Over-samplers increase minority to at least majority count; under-samplers reduce to minority count; cleaning samplers remove noisy samples from non-minority classes |
| check_samplers_sampling_strategy_fit_resample | A custom sampling_strategy dict/list leaves untargeted classes unchanged |
| check_samplers_sparse | Sparse CSR input produces sparse output matching dense results (conditional on input_tags.sparse) |
| check_samplers_pandas | DataFrame/Series input returns DataFrame/Series with preserved column names and series name (conditional on input_tags.dataframe) |
| check_samplers_pandas_sparse | Pandas sparse DataFrame input returns sparse DataFrame with preserved dtypes (conditional on input_tags.dataframe) |
| check_samplers_list | Python list inputs return Python list outputs with matching values |
| check_samplers_multiclass_ova | Multiclass target and OVA-binarized target produce consistent resampled results |
| check_samplers_preserve_dtype | float32/int32 dtypes in X and y are preserved through resampling |
| check_samplers_sample_indices | sample_indices_ attribute exists if and only if sampler_tags.sample_indices is True |
| check_samplers_string | Object-dtype string feature arrays are handled correctly (conditional on input_tags.string) |
| check_samplers_nan | NaN values in features are preserved through resampling (conditional on input_tags.allow_nan) |
| check_samplers_2d_target | A 2D column-vector target y.reshape(-1, 1) is accepted without error |
| check_sampler_get_feature_names_out | get_feature_names_out returns a numpy array of strings with length matching output features |
| check_sampler_get_feature_names_out_pandas | When fit on a DataFrame, get_feature_names_out() matches column names and rejects invalid input features |
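To make the intent of a few rows concrete, the sketch below runs two toy checks against a tiny pure-Python under-sampler. The checks are in the spirit of check_samplers_one_label, check_samplers_fit and the under-sampling branch of check_samplers_fit_resample. ToyRandomUnderSampler, toy_check_one_label and toy_check_under_sampling are hypothetical names. The real checks operate on NumPy arrays and real imbalanced-learn samplers.

```python
import random
from collections import Counter


class ToyRandomUnderSampler:
    """Hypothetical under-sampler: keeps as many samples per class as the minority has."""

    def __init__(self, random_state=0):
        self.random_state = random_state

    def fit_resample(self, X, y):
        counts = Counter(y)
        if len(counts) < 2:
            raise ValueError("Need samples of at least 2 classes; got 1 class.")
        # Fitted attribute whose presence check_samplers_fit looks for.
        self.sampling_strategy_ = {label: min(counts.values()) for label in counts}
        rng = random.Random(self.random_state)
        keep = []
        for label, n_keep in self.sampling_strategy_.items():
            idx = [i for i, target in enumerate(y) if target == label]
            keep.extend(rng.sample(idx, n_keep))
        keep.sort()
        return [X[i] for i in keep], [y[i] for i in keep]


def toy_check_one_label(sampler):
    # A target containing a single class must be rejected with ValueError.
    try:
        sampler.fit_resample([[0.0]] * 5, [0] * 5)
    except ValueError:
        return
    raise AssertionError("sampler accepted a target with only one class")


def toy_check_under_sampling(sampler):
    # Every class should end up with the original minority count.
    X = [[float(i)] for i in range(30)]
    y = [0] * 20 + [1] * 7 + [2] * 3
    X_res, y_res = sampler.fit_resample(X, y)
    assert hasattr(sampler, "sampling_strategy_")
    assert set(Counter(y_res).values()) == {min(Counter(y).values())}
    assert len(X_res) == len(y_res)


sampler = ToyRandomUnderSampler()
toy_check_one_label(sampler)
toy_check_under_sampling(sampler)
print("toy checks passed")
```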

Classifier Checks

| Check function | What it verifies |
|---|---|
| check_classifier_on_multilabel_or_multioutput_targets | fit() raises ValueError("Multilabel and multioutput targets are not supported.") for multilabel data |
| check_classifiers_with_encoded_labels | Categorical string labels (e.g. iris species names) are handled correctly; classes_ and predict output match the original categories (regression test for GitHub issue #709) |
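The two classifier checks can be illustrated with a toy estimator. In the sketch below, ToyMajorityClassifier, toy_check_multilabel and toy_check_encoded_labels are hypothetical. The real encoded-labels check uses pandas categorical labels from the iris dataset, whereas this sketch uses plain Python strings.

```python
from collections import Counter


class ToyMajorityClassifier:
    """Hypothetical classifier that always predicts the most frequent label."""

    def fit(self, X, y):
        if y and isinstance(y[0], (list, tuple)):
            raise ValueError("Multilabel and multioutput targets are not supported.")
        self.classes_ = sorted(set(y))
        self.majority_ = Counter(y).most_common(1)[0][0]
        return self

    def predict(self, X):
        return [self.majority_] * len(X)


def toy_check_multilabel(clf):
    # A 2D (multilabel) target must be rejected with the documented message.
    try:
        clf.fit([[0.0], [1.0]], [[0, 1], [1, 0]])
    except ValueError as exc:
        assert "Multilabel and multioutput targets are not supported." in str(exc)
        return
    raise AssertionError("multilabel target was accepted")


def toy_check_encoded_labels(clf):
    # String labels must survive fit/predict unchanged (no leaked integer codes).
    y = ["setosa"] * 3 + ["versicolor"] * 5 + ["virginica"] * 2
    X = [[float(i)] for i in range(len(y))]
    clf.fit(X, y)
    assert clf.classes_ == ["setosa", "versicolor", "virginica"]
    assert set(clf.predict(X)) <= set(clf.classes_)


toy_check_multilabel(ToyMajorityClassifier())
toy_check_encoded_labels(ToyMajorityClassifier())
```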

Parameter and Feature Name Checks

| Check function | What it verifies |
|---|---|
| check_param_validation | Every constructor parameter has a matching entry in _parameter_constraints; invalid types and values raise ValueError with informative messages |
| check_dataframe_column_names_consistency | After fitting on a DataFrame, feature_names_in_ is set; prediction methods raise ValueError for reordered, renamed, or missing feature names |
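The first half of check_param_validation compares the constructor signature with the constraint table. The sketch below shows that comparison, plus a simplified type-only validator, using inspect. ToySampler, missing_constraints and validate_params are hypothetical names. Real scikit-learn constraint tables use richer objects such as Interval and StrOptions, and raise InvalidParameterError, a ValueError subclass, rather than the plain ValueError used here.

```python
import inspect


class ToySampler:
    # Hypothetical, deliberately incomplete constraint table (types only).
    _parameter_constraints = {
        "sampling_strategy": [str, dict],
        "k_neighbors": [int],
    }

    def __init__(self, sampling_strategy="auto", k_neighbors=5, random_state=None):
        self.sampling_strategy = sampling_strategy
        self.k_neighbors = k_neighbors
        self.random_state = random_state


def missing_constraints(estimator_cls):
    """Return constructor parameters that have no entry in _parameter_constraints."""
    params = [
        name
        for name in inspect.signature(estimator_cls.__init__).parameters
        if name != "self"
    ]
    return [name for name in params if name not in estimator_cls._parameter_constraints]


def validate_params(estimator):
    """Raise ValueError when a parameter value matches none of its allowed types."""
    for name, allowed in estimator._parameter_constraints.items():
        value = getattr(estimator, name)
        if not isinstance(value, tuple(allowed)):
            raise ValueError(
                f"The '{name}' parameter of {type(estimator).__name__} must be one of "
                f"{[t.__name__ for t in allowed]}. Got {value!r} instead."
            )


print(missing_constraints(ToySampler))  # ['random_state']
```

A check in this spirit would fail for ToySampler, because random_state has no constraint entry.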

Test Data Generation

The module generates imbalanced test data using scikit-learn's make_classification:

from sklearn.datasets import make_classification

def sample_dataset_generator():
    # 3 classes with nominal weights 20% / 30% / 50%
    X, y = make_classification(
        n_samples=1000,
        n_classes=3,
        n_informative=4,
        weights=[0.2, 0.3, 0.5],
        random_state=0,
    )
    return X, y

This produces a 3-class dataset with class proportions of approximately 20%/30%/50%, providing a realistic imbalanced scenario for testing.
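The "approximately" can be made concrete with a little arithmetic. The worked sketch below shows the nominal class sizes implied by the weights and the resulting imbalance ratio. It also shows make_classification's constraint that n_classes * n_clusters_per_class must not exceed 2 ** n_informative. With the default n_clusters_per_class=2, a 3-class request needs more than the default n_informative=2. No data is generated here; only the parameter values are reused.

```python
# Parameters passed to make_classification in sample_dataset_generator.
n_samples = 1000
weights = [0.2, 0.3, 0.5]
n_classes, n_informative = 3, 4
n_clusters_per_class = 2  # make_classification default

# Nominal class sizes implied by the weights; the default flip_y=0.01 then
# assigns about 1% of labels at random, so observed counts drift slightly.
expected = [round(n_samples * w) for w in weights]
print(expected)  # [200, 300, 500]

# Majority-to-minority imbalance ratio.
print(expected[-1] / expected[0])  # 2.5

# make_classification requires n_classes * n_clusters_per_class <= 2 ** n_informative.
print(n_classes * n_clusters_per_class, "<=", 2 ** n_informative)  # 6 <= 16
print(n_classes * n_clusters_per_class <= 2 ** 2)  # False: default n_informative=2 is too small
```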

Estimator Parameter Tuning for Tests

The helper _set_checking_parameters adjusts estimator parameters to make checks faster or deterministic:

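from sklearn.cluster import KMeans

def _set_checking_parameters(estimator):
    params = estimator.get_params()
    name = estimator.__class__.__name__
    # Shrink ensembles (never grow them) so that checks run quickly.
    if "n_estimators" in params:
        estimator.set_params(n_estimators=min(5, estimator.n_estimators))
    # Use a seeded, single-initialization KMeans for deterministic centroids.
    if name == "ClusterCentroids":
        estimator.set_params(
            voting="soft",
            estimator=KMeans(random_state=0, algorithm="lloyd", n_init=1),
        )
    # An integer kmeans_estimator is interpreted as the number of clusters.
    if name == "KMeansSMOTE":
        estimator.set_params(kmeans_estimator=12)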
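The n_estimators branch is the one that applies across many estimators. The sketch below isolates it with a toy class that provides only the get_params/set_params pair the helper relies on. ToyEnsemble and cap_n_estimators are hypothetical. Unlike _set_checking_parameters, which mutates in place and returns nothing, this version returns the estimator for convenience.

```python
class ToyEnsemble:
    """Hypothetical estimator exposing a minimal get_params/set_params pair."""

    def __init__(self, n_estimators=100, random_state=None):
        self.n_estimators = n_estimators
        self.random_state = random_state

    def get_params(self):
        return {"n_estimators": self.n_estimators, "random_state": self.random_state}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self


def cap_n_estimators(estimator, cap=5):
    # min() means large ensembles shrink to `cap`, while small ones are left alone.
    if "n_estimators" in estimator.get_params():
        estimator.set_params(n_estimators=min(cap, estimator.n_estimators))
    return estimator


print(cap_n_estimators(ToyEnsemble(100)).n_estimators)  # 5
print(cap_n_estimators(ToyEnsemble(3)).n_estimators)    # 3
```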

Expected Failure Handling

The _maybe_mark and _should_be_skipped_or_marked functions allow checks to be marked as xfail or skip when they are known to fail for a given estimator. This integrates with pytest's parametrization system:

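from functools import wraps
from unittest import SkipTest

def _maybe_mark(estimator, check, expected_failed_checks=None, mark=None, pytest=None):
    # Decide whether this (estimator, check) pair is a known failure.
    should_be_marked, reason = _should_be_skipped_or_marked(
        estimator, check, expected_failed_checks
    )
    if not should_be_marked or mark is None:
        # Not a known failure, or marking disabled: return the pair unchanged.
        return estimator, check

    estimator_name = estimator.__class__.__name__
    if mark == "xfail":
        # pytest still runs the check but reports a failure as expected.
        return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason))
    else:
        # mark == "skip": replace the check with one that skips immediately.
        @wraps(check)
        def wrapped(*args, **kwargs):
            raise SkipTest(
                f"Skipping {_check_name(check)} for {estimator_name}: {reason}"
            )
        return estimator, wrapped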
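The lookup in expected_failed_checks is keyed by a check's name, but checks may arrive wrapped in functools.partial or in @wraps layers. The sketch below shows one way to resolve the underlying name through both and then perform the lookup. check_name, should_be_marked and toy_check_nan are hypothetical stand-ins for _check_name and _should_be_skipped_or_marked, and the "not expected to fail" message is invented for illustration.

```python
from functools import partial, wraps


def check_name(check):
    """Resolve a readable name through @wraps layers and functools.partial."""
    if hasattr(check, "__wrapped__"):
        return check_name(check.__wrapped__)
    return check.func.__name__ if isinstance(check, partial) else check.__name__


def should_be_marked(check, expected_failed_checks):
    """Return (bool, reason) by looking the resolved name up in the dict."""
    expected_failed_checks = expected_failed_checks or {}
    name = check_name(check)
    if name in expected_failed_checks:
        return True, expected_failed_checks[name]
    return False, "Check is not expected to fail"


def toy_check_nan(name, estimator):
    """Hypothetical check taking (name, estimator), like the module's checks."""


bound = partial(toy_check_nan, "ToySampler")


@wraps(bound)  # update_wrapper tolerates partial's missing __name__ and sets __wrapped__
def skipper(*args, **kwargs):
    raise RuntimeError("skipped")


print(check_name(bound))    # toy_check_nan
print(check_name(skipper))  # toy_check_nan
print(should_be_marked(bound, {"toy_check_nan": "NaN not supported"}))
# (True, 'NaN not supported')
```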

Tag-based Conditional Check Yielding

Sampler checks are conditionally yielded based on the estimator's tags, read via get_tags(sampler):

def _yield_sampler_checks(sampler):
    tags = get_tags(sampler)
    accept_sparse = tags.input_tags.sparse
    accept_dataframe = tags.input_tags.dataframe
    accept_string = tags.input_tags.string
    allow_nan = tags.input_tags.allow_nan

    yield check_target_type
    yield check_samplers_one_label
    yield check_samplers_fit
    yield check_samplers_fit_resample
    yield check_samplers_sampling_strategy_fit_resample
    if accept_sparse:
        yield check_samplers_sparse
    if accept_dataframe:
        yield check_samplers_pandas
        yield check_samplers_pandas_sparse
    if accept_string:
        yield check_samplers_string
    if allow_nan:
        yield check_samplers_nan
    yield check_samplers_list
    yield check_samplers_multiclass_ova
    yield check_samplers_preserve_dtype
    yield check_samplers_sample_indices
    yield check_samplers_2d_target
    yield check_sampler_get_feature_names_out
    yield check_sampler_get_feature_names_out_pandas
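The effect of the four tag flags on the number of generated checks can be computed directly. The sketch below mirrors the yield order above, using check names as strings. InputTags and Tags are hypothetical dataclasses with only the fields _yield_sampler_checks reads, and their all-False defaults are an assumption of the sketch, not the library's defaults.

```python
from dataclasses import dataclass, field


@dataclass
class InputTags:
    # Hypothetical mirror of the four input_tags fields; defaults are illustrative.
    sparse: bool = False
    dataframe: bool = False
    string: bool = False
    allow_nan: bool = False


@dataclass
class Tags:
    input_tags: InputTags = field(default_factory=InputTags)


def yield_sampler_check_names(tags):
    """Yield check names in the same order and under the same conditions as above."""
    yield from (
        "check_target_type", "check_samplers_one_label", "check_samplers_fit",
        "check_samplers_fit_resample", "check_samplers_sampling_strategy_fit_resample",
    )
    if tags.input_tags.sparse:
        yield "check_samplers_sparse"
    if tags.input_tags.dataframe:
        yield from ("check_samplers_pandas", "check_samplers_pandas_sparse")
    if tags.input_tags.string:
        yield "check_samplers_string"
    if tags.input_tags.allow_nan:
        yield "check_samplers_nan"
    yield from (
        "check_samplers_list", "check_samplers_multiclass_ova",
        "check_samplers_preserve_dtype", "check_samplers_sample_indices",
        "check_samplers_2d_target", "check_sampler_get_feature_names_out",
        "check_sampler_get_feature_names_out_pandas",
    )


print(len(list(yield_sampler_check_names(Tags()))))  # 12 unconditional checks
print(len(list(yield_sampler_check_names(Tags(InputTags(True, True, True, True))))))  # 17
```

check_samplers_sample_indices is yielded regardless of tags. This is consistent with its table entry: it must also verify that sample_indices_ is absent when the tag is off.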
