Principle:Scikit learn contrib Imbalanced learn Sampler Compatibility Checking

From Leeroopedia
Revision as of 18:01, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/Scikit_learn_contrib_Imbalanced_learn_Sampler_Compatibility_Checking.md)


Principle: Scikit-learn-contrib Imbalanced-learn Sampler Compatibility Checking

Theory: The sampler compatibility checking framework extends scikit-learn's estimator check infrastructure to validate that sampler implementations comply with the imbalanced-learn API contract. It provides a systematic, automated approach to verifying behavioral correctness across a comprehensive set of input formats, edge cases, and semantic requirements.

Motivation

Scikit-learn provides check_estimator to verify that estimators conform to its API conventions (cloneability, get/set params, fit/predict interface). However, imbalanced-learn introduces a distinct API surface centered on the fit_resample method and the concept of sampling strategies. The standard scikit-learn checks cannot validate:

  • Whether over-samplers actually increase minority class counts
  • Whether under-samplers actually decrease majority class counts
  • Whether cleaning samplers remove samples from non-minority classes
  • Whether sampling_strategy_ is set after fitting
  • Whether resampled output preserves input container types (DataFrame, sparse matrix, list)
  • Whether sample_indices_ tracking is consistent with advertised tags

The estimator_checks module fills this gap by providing a parallel set of checks tailored to the sampler protocol.

Architectural Design

The framework is organized around three layers:

Layer | Components | Responsibility
Public API | parametrize_with_checks, estimator_checks_generator | Entry points for test authors; integrate with pytest parametrization
Dispatch | _yield_all_checks, _yield_sampler_checks, _yield_classifier_checks | Route estimators to applicable check suites based on capabilities and tags
Individual checks | check_target_type, check_samplers_fit_resample, etc. | Self-contained test functions validating one specific aspect of the API contract
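The dispatch layer can be pictured with a small, self-contained sketch. It assumes that routing is based on which methods an estimator exposes (fit_resample for samplers, predict for classifiers); the function names, the two placeholder checks, and ToySampler are hypothetical stand-ins, not the module's actual code.

```python
from typing import Callable, Iterator


def check_has_fit_resample(estimator) -> None:
    """Placeholder sampler check (hypothetical)."""
    assert callable(getattr(estimator, "fit_resample", None))


def check_has_predict(estimator) -> None:
    """Placeholder classifier check (hypothetical)."""
    assert callable(getattr(estimator, "predict", None))


def yield_sampler_checks(estimator) -> Iterator[Callable]:
    yield check_has_fit_resample


def yield_classifier_checks(estimator) -> Iterator[Callable]:
    yield check_has_predict


def yield_all_checks(estimator) -> Iterator[Callable]:
    """Route an estimator to every suite whose capability it exposes."""
    if hasattr(estimator, "fit_resample"):
        yield from yield_sampler_checks(estimator)
    if hasattr(estimator, "predict"):
        yield from yield_classifier_checks(estimator)


class ToySampler:
    """Hypothetical sampler that returns its input unchanged."""

    def fit_resample(self, X, y):
        return X, y


# Only the sampler suite is selected, because ToySampler has no predict method.
selected = [check.__name__ for check in yield_all_checks(ToySampler())]
print(selected)
```

Each yielded check is an ordinary callable, so the public entry points only need to pair it with the estimator and hand the pair to the test runner.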

Tag-driven Conditional Testing

Not all samplers support all input formats. The framework reads the estimator's tags via get_tags(sampler) to determine which checks apply:

tags = get_tags(sampler)
accept_sparse = tags.input_tags.sparse
accept_dataframe = tags.input_tags.dataframe
accept_string = tags.input_tags.string
allow_nan = tags.input_tags.allow_nan

Checks for sparse matrices, pandas DataFrames, string features, and NaN handling are only yielded when the corresponding tag is True. This prevents false failures for samplers that legitimately do not support certain input formats, while still ensuring that samplers which claim to support a format are actually tested for it.
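A minimal sketch of this gating logic, using only the standard library, might look like the following. InputTags here is a hypothetical stand-in dataclass rather than the real tag object, and the names check_samplers_string and check_samplers_nan are assumed for illustration; checks are represented by their names so the example stays runnable without imbalanced-learn.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class InputTags:
    """Hypothetical stand-in for the input tags read from a sampler."""

    sparse: bool = False
    dataframe: bool = False
    string: bool = False
    allow_nan: bool = False


def yield_format_checks(input_tags: InputTags) -> Iterator[str]:
    """Yield the names of the checks that apply for the given tags."""
    # The dense-array check applies to every sampler.
    yield "check_samplers_fit_resample"
    if input_tags.sparse:
        yield "check_samplers_sparse"
    if input_tags.dataframe:
        yield "check_samplers_pandas"
    if input_tags.string:
        yield "check_samplers_string"  # assumed name
    if input_tags.allow_nan:
        yield "check_samplers_nan"  # assumed name


# A sampler that declares no optional input format gets only the base check.
dense_only = list(yield_format_checks(InputTags()))

# A sampler that claims sparse and DataFrame support is tested for both.
sparse_and_frames = list(yield_format_checks(InputTags(sparse=True, dataframe=True)))

print(dense_only)
print(sparse_and_frames)
```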

Sampler Behavioral Contract

The core behavioral check (check_samplers_fit_resample) validates the fundamental invariant of each sampler category:

Sampler Type | Base Class | Expected Behavior
Over-sampler | BaseOverSampler | After resampling, every class has at least as many samples as the original majority class
Under-sampler | BaseUnderSampler | After resampling, every class has exactly the same count as the original minority class (with an exception for InstanceHardnessThreshold which approximates)
Cleaning sampler | BaseCleaningSampler | After resampling, every non-minority class has fewer samples than before (noisy samples removed)

This type-based dispatch uses isinstance checks:

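from collections import Counter

X_res, y_res = sampler.fit_resample(X, y)
target_stats = Counter(y)          # class counts before resampling
target_stats_res = Counter(y_res)  # class counts after resampling

if isinstance(sampler, BaseOverSampler):
    # Every class must reach at least the original majority count.
    n_samples = max(target_stats.values())
    assert all(value >= n_samples for value in Counter(y_res).values())
elif isinstance(sampler, BaseUnderSampler):
    # Every class must be reduced to exactly the original minority count.
    n_samples = min(target_stats.values())
    assert all(value == n_samples for value in Counter(y_res).values())
elif isinstance(sampler, BaseCleaningSampler):
    # Every class except the minority must lose at least one sample.
    class_minority = min(target_stats, key=target_stats.get)
    assert all(
        target_stats[class_sample] > target_stats_res[class_sample]
        for class_sample in target_stats.keys()
        if class_sample != class_minority
    )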
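To make the three invariants concrete, the sketch below applies them to hand-written label lists instead of real samplers. The string argument kind is a hypothetical replacement for the isinstance dispatch, and the resampled targets are made up for illustration.

```python
from collections import Counter


def check_resampled_counts(kind, y, y_res):
    """Assert the class-count invariant for one sampler family.

    `kind` is a hypothetical tag ("over", "under", "clean") standing in for
    the isinstance checks against the sampler base classes.
    """
    target_stats = Counter(y)
    target_stats_res = Counter(y_res)
    if kind == "over":
        n_samples = max(target_stats.values())
        assert all(v >= n_samples for v in target_stats_res.values())
    elif kind == "under":
        n_samples = min(target_stats.values())
        assert all(v == n_samples for v in target_stats_res.values())
    elif kind == "clean":
        class_minority = min(target_stats, key=target_stats.get)
        assert all(
            target_stats[c] > target_stats_res[c]
            for c in target_stats
            if c != class_minority
        )
    else:
        raise ValueError(f"unknown sampler kind: {kind!r}")


y = [0] * 6 + [1] * 2  # six majority samples, two minority samples

check_resampled_counts("over", y, [0] * 6 + [1] * 6)   # minority raised to 6
check_resampled_counts("under", y, [0] * 2 + [1] * 2)  # majority cut to 2
check_resampled_counts("clean", y, [0] * 5 + [1] * 2)  # one majority sample removed

# An under-sampler that leaves three majority samples violates the invariant.
try:
    check_resampled_counts("under", y, [0] * 3 + [1] * 2)
    violation_detected = False
except AssertionError:
    violation_detected = True
```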

Input Format Consistency

A key design principle is that samplers must return the same container type they receive:

Input Format | Expected Output Format | Check Function
NumPy array | NumPy array | check_samplers_fit_resample
Sparse CSR matrix | Sparse matrix (same format) | check_samplers_sparse
Pandas DataFrame / Series | DataFrame / Series with preserved column names | check_samplers_pandas
Pandas sparse DataFrame | Sparse DataFrame with preserved SparseDtype | check_samplers_pandas_sparse
Python list | Python list | check_samplers_list

Additionally, check_samplers_preserve_dtype verifies that non-default dtypes (float32, int32) are not silently upcast to float64/int64.
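The container-preservation rule can be illustrated for the plain Python list case without any third-party dependency. The sketch below uses toy_under_sample, a hypothetical stand-in for a random under-sampler, and checks that the output type follows the input type while the selected rows stay the same; it is not the logic of check_samplers_list itself.

```python
import random
from collections import Counter


def toy_under_sample(X, y, seed=0):
    """Hypothetical under-sampler that preserves the input container type."""
    rng = random.Random(seed)
    n_min = min(Counter(y).values())
    keep = []
    for label in sorted(set(y)):
        indices = [i for i, value in enumerate(y) if value == label]
        # Keep n_min randomly chosen rows of this class, in original order.
        keep.extend(sorted(rng.sample(indices, n_min)))
    X_res = [X[i] for i in keep]
    y_res = [y[i] for i in keep]
    # Rebuild the outputs with the same container type as the inputs.
    return type(X)(X_res), type(y)(y_res)


X_list = [[0.0], [0.1], [0.2], [0.3], [1.0], [1.1]]
y_list = [0, 0, 0, 0, 1, 1]

X_res_list, y_res_list = toy_under_sample(X_list, y_list)
X_res_tuple, y_res_tuple = toy_under_sample(tuple(X_list), tuple(y_list))

# The container type must follow the input.
assert type(X_res_list) is list and type(y_res_list) is list
assert type(X_res_tuple) is tuple and type(y_res_tuple) is tuple

# The content must not depend on the container type.
assert list(X_res_tuple) == X_res_list
assert list(y_res_tuple) == y_res_list
```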

OVA Encoding Consistency

The check_samplers_multiclass_ova check ensures that resampling with a multiclass integer target produces results consistent with resampling using the equivalent one-vs-all binarized target. This guarantees that samplers handle both label representations identically:

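import numpy as np
from numpy.testing import assert_allclose
from sklearn.preprocessing import label_binarize

# One-vs-all encoding of the integer target
y_ova = label_binarize(y, classes=np.unique(y))
X_res, y_res = sampler.fit_resample(X, y)
X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
# Both encodings must yield the same resampled features and labels
assert_allclose(X_res, X_res_ova)
assert_allclose(y_res, y_res_ova.argmax(axis=1))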
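The final comparison relies on argmax over the binarized target recovering the integer labels. A pure-Python sketch of that round trip is shown below; binarize and argmax_rows are hypothetical helpers standing in for label_binarize and the NumPy argmax call, and the example assumes three classes labelled 0, 1, 2.

```python
def binarize(y):
    """One-hot encode integer labels, one column per sorted class."""
    classes = sorted(set(y))
    return [[1 if label == c else 0 for c in classes] for label in y]


def argmax_rows(rows):
    """Return the index of the largest entry in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


y = [0, 2, 1, 2, 0]
y_ova = binarize(y)

# Each row has a single 1 in the column of its class.
print(y_ova)

# argmax maps every one-hot row back to its column index, which equals the
# original label only because the classes are exactly 0, 1, 2.
assert argmax_rows(y_ova) == y
```

With labels that do not start at 0 or are not contiguous, the column index would have to be mapped back through the sorted class list before comparing.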

Feature Names Protocol

Two checks validate the get_feature_names_out method:

  1. check_sampler_get_feature_names_out: Verifies that the method returns a numpy array of strings with length matching the number of output features, and that it raises ValueError for mismatched input feature counts.
  2. check_sampler_get_feature_names_out_pandas: When fitted on a DataFrame, verifies that default feature names match the DataFrame columns, that explicitly passed matching names produce the same result, and that non-matching names raise ValueError.
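The behavior these two checks look for can be sketched with a toy class. ToySampler below is hypothetical and returns a plain list instead of a NumPy array; the default names x0, x1, ... follow the scikit-learn convention for inputs without column names, and the error message text is made up.

```python
class ToySampler:
    """Hypothetical sampler exposing a get_feature_names_out method."""

    def fit_resample(self, X, y):
        # Record the number of features seen during fitting.
        self.n_features_in_ = len(X[0])
        return X, y

    def get_feature_names_out(self, input_features=None):
        if input_features is None:
            # Default names when the input carried no column names.
            return [f"x{i}" for i in range(self.n_features_in_)]
        if len(input_features) != self.n_features_in_:
            raise ValueError(
                "input_features should have length equal to the number of "
                f"features ({self.n_features_in_}), got {len(input_features)}"
            )
        return list(input_features)


sampler = ToySampler()
sampler.fit_resample([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], [0, 1])

default_names = sampler.get_feature_names_out()
explicit_names = sampler.get_feature_names_out(["a", "b", "c"])

# A mismatched number of names must be rejected with ValueError.
try:
    sampler.get_feature_names_out(["a", "b"])
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True
```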

Parameter Validation

The check_param_validation check verifies two things. First, every constructor parameter listed in get_params(deep=False) must have a corresponding entry in _parameter_constraints. Second, passing an invalid parameter type or value must raise a ValueError whose message matches the pattern:

rf"The '{param_name}' parameter of {name} must be .* Got .* instead."

This ensures users receive actionable error messages rather than cryptic failures deep in the fitting logic.
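As a rough illustration of how such a pattern is applied, the sketch below matches it against an error message with re.search, which is also how pytest.raises(match=...) compares messages. The message text, the parameter name, and the helper expected_message_pattern are made up for the example.

```python
import re


def expected_message_pattern(param_name, name):
    """Build the pattern an invalid-parameter message is expected to match."""
    return rf"The '{param_name}' parameter of {name} must be .* Got .* instead."


pattern = expected_message_pattern("k_neighbors", "SMOTE")

# Hypothetical message that names the parameter, the constraint, and the value.
good_message = (
    "The 'k_neighbors' parameter of SMOTE must be an int in the range "
    "[1, inf). Got 'three' instead."
)
# Hypothetical message that gives the user nothing to act on.
bad_message = "invalid value encountered"

assert re.search(pattern, good_message) is not None
assert re.search(pattern, bad_message) is None
```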

Integration with pytest

The framework integrates with pytest through two mechanisms:

  1. parametrize_with_checks: A decorator that calls pytest.mark.parametrize with generated (estimator, check) pairs, using human-readable test IDs.
  2. estimator_checks_generator: A lower-level generator for manual iteration, supporting mark="xfail" or mark="skip" for known failures.

This allows test suites to be as simple as:

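from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.utils.estimator_checks import parametrize_with_checks

@parametrize_with_checks([SMOTE(), RandomUnderSampler()])
def test_common_checks(estimator, check):
    check(estimator)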

while still supporting granular control over expected failures via the expected_failed_checks callback.
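A callback of this kind could be sketched as follows. It assumes the same convention as scikit-learn's parametrize_with_checks, where the callable receives an estimator and returns a dictionary mapping check names to the reason for the expected failure; ToySampler and the listed reason are hypothetical.

```python
class ToySampler:
    """Hypothetical sampler with one known, documented limitation."""

    def fit_resample(self, X, y):
        return X, y


def expected_failed_checks(estimator):
    """Return {check_name: reason} for checks known to fail on this estimator."""
    known_failures = {
        "ToySampler": {
            "check_samplers_sparse": "toy sampler converts sparse input to dense",
        },
    }
    # Estimators without known failures get an empty mapping.
    return known_failures.get(type(estimator).__name__, {})


failures_for_toy = expected_failed_checks(ToySampler())
failures_for_other = expected_failed_checks(object())
print(failures_for_toy)
```

Under that assumption, the function would be passed as parametrize_with_checks([...], expected_failed_checks=expected_failed_checks), so that listed checks are marked as expected failures instead of breaking the suite.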
