Principle: Scikit-learn-contrib Imbalanced-learn Sampler Compatibility Checking
Theory: The sampler compatibility checking framework extends scikit-learn's estimator check infrastructure to validate that sampler implementations comply with the imbalanced-learn API contract. It provides a systematic, automated approach to verifying behavioral correctness across a comprehensive set of input formats, edge cases, and semantic requirements.
Motivation
Scikit-learn provides check_estimator to verify that estimators conform to its API conventions (cloneability, get_params/set_params, the fit/predict interface). However, imbalanced-learn introduces a distinct API surface centered on the fit_resample method and the concept of sampling strategies. The standard scikit-learn checks cannot validate:
- Whether over-samplers actually increase minority class counts
- Whether under-samplers actually decrease majority class counts
- Whether cleaning samplers remove samples from non-minority classes
- Whether sampling_strategy_ is set after fitting
- Whether resampled output preserves input container types (DataFrame, sparse matrix, list)
- Whether sample_indices_ tracking is consistent with advertised tags
The estimator_checks module fills this gap by providing a parallel set of checks tailored to the sampler protocol.
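To make the protocol concrete, the sketch below defines a hypothetical ToyRandomUnderSampler (not imbalanced-learn's RandomUnderSampler) that exposes the three things these checks look at: a fit_resample method, a fitted sampling_strategy_ attribute, and a sample_indices_ attribute recording which original rows were kept. It assumes NumPy inputs and always shrinks every class to the minority count.

```python
import numpy as np


class ToyRandomUnderSampler:
    """Hypothetical minimal sampler; NOT imbalanced-learn's RandomUnderSampler.

    Accepts NumPy-compatible inputs only and always reduces every class
    to the size of the smallest class.
    """

    def __init__(self, random_state=None):
        self.random_state = random_state

    def fit_resample(self, X, y):
        X, y = np.asarray(X), np.asarray(y)
        rng = np.random.default_rng(self.random_state)
        classes, counts = np.unique(y, return_counts=True)
        n_minority = int(counts.min())
        # Fitted attribute: number of samples to keep for each class.
        self.sampling_strategy_ = {cls: n_minority for cls in classes.tolist()}
        kept = [
            rng.choice(np.flatnonzero(y == cls), size=n, replace=False)
            for cls, n in self.sampling_strategy_.items()
        ]
        # Fitted attribute: positions of the original rows that were kept.
        self.sample_indices_ = np.sort(np.concatenate(kept))
        return X[self.sample_indices_], y[self.sample_indices_]


X = np.arange(20).reshape(10, 2)
y = np.array([0] * 7 + [1] * 3)
sampler = ToyRandomUnderSampler(random_state=0)
X_res, y_res = sampler.fit_resample(X, y)
```

A check in the spirit of the list above would then assert that sampling_strategy_ exists after fitting and that X[sampler.sample_indices_] reproduces X_res.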
Architectural Design
The framework is organized around three layers:
| Layer | Components | Responsibility |
|---|---|---|
| Public API | parametrize_with_checks, estimator_checks_generator | Entry points for test authors; integrate with pytest parametrization |
| Dispatch | _yield_all_checks, _yield_sampler_checks, _yield_classifier_checks | Route estimators to applicable check suites based on capabilities and tags |
| Individual checks | check_target_type, check_samplers_fit_resample, etc. | Self-contained test functions, each validating one specific aspect of the API contract |
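The dispatch layer can be pictured as a chain of generators. The sketch below is a simplified illustration, not imbalanced-learn's _yield_all_checks: it routes by duck typing on public methods (fit_resample, predict) instead of inspecting tags, and check_generic and check_classifier are hypothetical placeholder checks with empty bodies.

```python
from typing import Callable, Iterator

Check = Callable[[object], None]


def check_generic(estimator) -> None:
    """Hypothetical placeholder for a check that applies to every estimator."""


def check_samplers_fit_resample(estimator) -> None:
    """Placeholder body; the real check validates the resampling invariants."""


def check_classifier(estimator) -> None:
    """Hypothetical placeholder for a classifier-specific check."""


def yield_sampler_checks(estimator) -> Iterator[Check]:
    yield check_samplers_fit_resample


def yield_classifier_checks(estimator) -> Iterator[Check]:
    yield check_classifier


def yield_all_checks(estimator) -> Iterator[Check]:
    # Every estimator gets the generic checks first.
    yield check_generic
    # Capability-based routing, simplified to duck typing on public methods.
    if hasattr(estimator, "fit_resample"):
        yield from yield_sampler_checks(estimator)
    if hasattr(estimator, "predict"):
        yield from yield_classifier_checks(estimator)


class Sampler:
    def fit_resample(self, X, y):
        return X, y


class Classifier:
    def predict(self, X):
        return X
```

Keeping each suite in its own generator lets the top-level function compose them with yield from, so adding a new estimator category only means adding one more branch.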
Tag-driven Conditional Testing
Not all samplers support all input formats. The framework reads the estimator's tags via get_tags(sampler) to determine which checks apply:
```python
tags = get_tags(sampler)
accept_sparse = tags.input_tags.sparse
accept_dataframe = tags.input_tags.dataframe
accept_string = tags.input_tags.string
allow_nan = tags.input_tags.allow_nan
```
Checks for sparse matrices, pandas DataFrames, string features, and NaN handling are only yielded when the corresponding tag is True. This prevents false failures for samplers that legitimately do not support certain input formats, while still ensuring that samplers which claim to support a format are actually tested for it.
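The gating logic amounts to a series of conditional yields. In this sketch, InputTags is a hypothetical dataclass standing in for tags.input_tags, and the names check_samplers_string and check_samplers_nan are assumed for illustration; the function only yields check names so it runs without imbalanced-learn installed.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class InputTags:
    """Hypothetical stand-in for tags.input_tags as returned by get_tags()."""

    sparse: bool = False
    dataframe: bool = False
    string: bool = False
    allow_nan: bool = False


def yield_format_checks(input_tags: InputTags) -> Iterator[str]:
    # Dense NumPy input is always exercised.
    yield "check_samplers_fit_resample"
    if input_tags.sparse:
        yield "check_samplers_sparse"
    if input_tags.dataframe:
        yield "check_samplers_pandas"
    # The two check names below are assumed for illustration.
    if input_tags.string:
        yield "check_samplers_string"
    if input_tags.allow_nan:
        yield "check_samplers_nan"


print(list(yield_format_checks(InputTags(sparse=True))))
```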
Sampler Behavioral Contract
The core behavioral check (check_samplers_fit_resample) validates the fundamental invariant of each sampler category:
| Sampler Type | Base Class | Expected Behavior |
|---|---|---|
| Over-sampler | BaseOverSampler | After resampling, every class has at least as many samples as the original majority class |
| Under-sampler | BaseUnderSampler | After resampling, every class has exactly as many samples as the original minority class (InstanceHardnessThreshold is an exception because it only approximates the target counts) |
| Cleaning sampler | BaseCleaningSampler | After resampling, every non-minority class has fewer samples than before (noisy samples removed) |
This type-based dispatch uses isinstance checks:
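```python
from collections import Counter

from imblearn.over_sampling.base import BaseOverSampler
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler

# sampler, X and y come from the surrounding check.
target_stats = Counter(y)  # class counts before resampling
X_res, y_res = sampler.fit_resample(X, y)
target_stats_res = Counter(y_res)  # class counts after resampling

if isinstance(sampler, BaseOverSampler):
    # Every class must reach at least the original majority count.
    n_samples = max(target_stats.values())
    assert all(value >= n_samples for value in target_stats_res.values())
elif isinstance(sampler, BaseUnderSampler):
    # Every class must shrink to the minority count (IHT is checked more loosely).
    n_samples = min(target_stats.values())
    assert all(value == n_samples for value in target_stats_res.values())
elif isinstance(sampler, BaseCleaningSampler):
    # Every non-minority class must lose at least one sample.
    class_minority = min(target_stats, key=target_stats.get)
    assert all(
        target_stats[class_sample] > target_stats_res[class_sample]
        for class_sample in target_stats
        if class_sample != class_minority
    )
```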
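As a worked example, take a three-class target with counts 100, 30 and 10. The hypothetical helper satisfies_contract below applies the same three invariants to before/after class counts. Its exact=False branch is an assumption about what the looser InstanceHardnessThreshold check could look like (no class may grow), not a statement of the library's code.

```python
from collections import Counter


def satisfies_contract(kind, y, y_res, exact=True):
    """Hypothetical helper mirroring the per-category invariants above."""
    before, after = Counter(y), Counter(y_res)
    if kind == "over":
        n_majority = max(before.values())
        return all(count >= n_majority for count in after.values())
    if kind == "under":
        if not exact:
            # Assumed relaxed form for approximate under-samplers: no class grows.
            return all(after[c] <= before[c] for c in before)
        n_minority = min(before.values())
        return all(count == n_minority for count in after.values())
    if kind == "cleaning":
        minority = min(before, key=before.get)
        return all(after[c] < before[c] for c in before if c != minority)
    raise ValueError(f"unknown sampler kind: {kind!r}")


y = [0] * 100 + [1] * 30 + [2] * 10
y_over = [0] * 100 + [1] * 100 + [2] * 100       # every class raised to 100
y_under = [0] * 10 + [1] * 10 + [2] * 10         # every class cut to 10
y_clean = [0] * 92 + [1] * 27 + [2] * 10         # classes 0 and 1 lose samples
y_clean_bad = [0] * 100 + [1] * 27 + [2] * 10    # class 0 left untouched
```

The last target fails the cleaning invariant because class 0 is a non-minority class that kept all 100 samples.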
Input Format Consistency
A key design principle is that samplers must return the same container type they receive:
| Input Format | Expected Output Format | Check Function |
|---|---|---|
| NumPy array | NumPy array | check_samplers_fit_resample |
| Sparse CSR matrix | Sparse matrix (same format) | check_samplers_sparse |
| Pandas DataFrame / Series | DataFrame / Series with preserved column names | check_samplers_pandas |
| Pandas sparse DataFrame | Sparse DataFrame with preserved SparseDtype | check_samplers_pandas_sparse |
| Python list | Python list | check_samplers_list |
Additionally, check_samplers_preserve_dtype verifies that non-default dtypes (float32, int32) are not silently upcast to float64/int64.
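Container and dtype preservation usually come down to selecting rows with an indexing operation that keeps the input type. The helpers take_rows and assert_same_container below are hypothetical; they rely on duck typing (.iloc for pandas objects, plain indexing for NumPy arrays and SciPy CSR matrices) so the sketch needs only NumPy.

```python
import numpy as np


def take_rows(X, indices):
    """Hypothetical helper: select rows while keeping the container type."""
    if isinstance(X, list):
        return [X[i] for i in indices]
    if hasattr(X, "iloc"):  # pandas DataFrame / Series keep their labels
        return X.iloc[indices]
    return X[indices]  # NumPy arrays and SciPy CSR matrices


def assert_same_container(X, X_res):
    """Hypothetical assertion covering container type and dtype."""
    assert type(X_res) is type(X), f"{type(X).__name__} became {type(X_res).__name__}"
    if hasattr(X, "dtype"):
        # Guards against silent upcasts such as float32 -> float64.
        assert X_res.dtype == X.dtype, f"{X.dtype} became {X_res.dtype}"


X32 = np.arange(6, dtype=np.float32).reshape(3, 2)
X32_res = take_rows(X32, [0, 2])
assert_same_container(X32, X32_res)
```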
OVA Encoding Consistency
The check_samplers_multiclass_ova check ensures that resampling with a multiclass integer target produces results consistent with resampling using the equivalent one-vs-all binarized target. This guarantees that samplers handle both label representations identically:
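```python
import numpy as np
from numpy.testing import assert_allclose
from sklearn.preprocessing import label_binarize

# sampler, X and y come from the surrounding check; y is a multiclass target.
y_ova = label_binarize(y, classes=np.unique(y))
X_res, y_res = sampler.fit_resample(X, y)
X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
assert_allclose(X_res, X_res_ova)
assert_allclose(y_res, y_res_ova.argmax(axis=1))
```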
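The final assertion works because, for a target with three or more classes, the argmax of the one-hot encoding recovers the integer class codes. Note that label_binarize returns a single column for a binary target, so this comparison is only meaningful in the multiclass case. The sketch below replaces label_binarize with a small NumPy one_hot helper and uses a hypothetical deterministic sampler, first_k_per_class, that accepts either label form; the argmax round trip also assumes the labels are already 0..n-1.

```python
import numpy as np


def one_hot(y):
    """Stand-in for label_binarize on a multiclass (3+ classes) target."""
    classes, codes = np.unique(y, return_inverse=True)
    return np.eye(len(classes), dtype=int)[codes]


def first_k_per_class(X, y, k):
    """Hypothetical deterministic under-sampler accepting 1-D or one-hot y."""
    labels = y.argmax(axis=1) if y.ndim == 2 else y
    keep = np.sort(
        np.concatenate([np.flatnonzero(labels == c)[:k] for c in np.unique(labels)])
    )
    return X[keep], y[keep]


X = np.arange(16, dtype=float).reshape(8, 2)
y = np.array([0, 1, 2, 0, 0, 1, 2, 0])
X_res, y_res = first_k_per_class(X, y, k=2)
X_res_ova, y_res_ova = first_k_per_class(X, one_hot(y), k=2)

# Same rows selected, and the one-hot labels decode back to the integer ones.
np.testing.assert_allclose(X_res, X_res_ova)
np.testing.assert_allclose(y_res, y_res_ova.argmax(axis=1))
```

Because the sampler is deterministic, any disagreement between the two runs would point to the label representation affecting which rows are kept, which is exactly what the OVA check is designed to catch.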
Feature Names Protocol
Two checks validate the get_feature_names_out method:
- check_sampler_get_feature_names_out: verifies that the method returns a NumPy array of strings whose length matches the number of output features, and that it raises ValueError when the number of input features does not match.
- check_sampler_get_feature_names_out_pandas: when the sampler is fitted on a DataFrame, verifies that the default feature names match the DataFrame columns, that explicitly passing matching names produces the same result, and that passing non-matching names raises ValueError.
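Because samplers add or remove rows but never columns, their output feature names equal their input feature names. The hypothetical function feature_names_out below sketches the behavior these two checks probe; the x0, x1, ... fallback for estimators fitted without column names follows scikit-learn's usual convention, and the exact error messages are assumptions.

```python
import numpy as np


def feature_names_out(n_features_in, feature_names_in=None, input_features=None):
    """Hypothetical stand-in for a sampler's get_feature_names_out.

    Samplers drop or add rows, never columns, so output names equal input names.
    """
    if input_features is not None:
        input_features = np.asarray(input_features, dtype=object)
        if len(input_features) != n_features_in:
            raise ValueError(
                f"input_features should have length {n_features_in}, "
                f"got {len(input_features)}"
            )
        if feature_names_in is not None and list(input_features) != list(feature_names_in):
            raise ValueError("input_features is not equal to feature_names_in")
        return input_features
    if feature_names_in is not None:
        # Fitted on a DataFrame: default to its column names.
        return np.asarray(feature_names_in, dtype=object)
    # Fitted on an array: generic names x0, x1, ...
    return np.asarray([f"x{i}" for i in range(n_features_in)], dtype=object)
```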
Parameter Validation
The check_param_validation check verifies that every constructor parameter listed in get_params(deep=False) has a corresponding entry in _parameter_constraints, and that invalid parameter types or values raise informative ValueError messages matching the pattern:
```python
rf"The '{param_name}' parameter of {name} must be .* Got .* instead."
```
This ensures users receive actionable error messages rather than cryptic failures deep in the fitting logic.
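A toy version of the mechanism might look like the following. ToySampler and its validate_params method are hypothetical, and the (predicate, description) pairs are a simplification: scikit-learn's real _parameter_constraints entries are lists of constraint objects. The error message is built so that it matches the pattern quoted above.

```python
import re


class ToySampler:
    """Hypothetical sampler whose parameters all have a constraint entry."""

    # Simplified constraints: parameter -> (predicate, description).
    _parameter_constraints = {
        "k_neighbors": (
            lambda v: isinstance(v, int) and not isinstance(v, bool) and v >= 1,
            "an int in the range [1, inf)",
        ),
        "random_state": (
            lambda v: v is None or isinstance(v, int),
            "an int or None",
        ),
    }

    def __init__(self, k_neighbors=5, random_state=None):
        self.k_neighbors = k_neighbors
        self.random_state = random_state

    def get_params(self, deep=False):
        return {"k_neighbors": self.k_neighbors, "random_state": self.random_state}

    def validate_params(self):
        for name, value in self.get_params(deep=False).items():
            predicate, description = self._parameter_constraints[name]
            if not predicate(value):
                raise ValueError(
                    f"The '{name}' parameter of {type(self).__name__} must be "
                    f"{description}. Got {value!r} instead."
                )


# First half of the check: every constructor parameter has a constraint entry.
assert set(ToySampler().get_params(deep=False)) == set(ToySampler._parameter_constraints)

# Second half: an invalid value raises a message matching the documented pattern.
param_name, name = "k_neighbors", "ToySampler"
pattern = rf"The '{param_name}' parameter of {name} must be .* Got .* instead."
try:
    ToySampler(k_neighbors=0).validate_params()
except ValueError as exc:
    assert re.match(pattern, str(exc))
else:
    raise AssertionError("invalid k_neighbors was accepted")
```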
Integration with pytest
The framework integrates with pytest through two mechanisms:
- parametrize_with_checks: a decorator that calls pytest.mark.parametrize with the generated (estimator, check) pairs, using human-readable test IDs.
- estimator_checks_generator: a lower-level generator for manual iteration, supporting mark="xfail" or mark="skip" for known failures.
This allows test suites to be as simple as:
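```python
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.utils.estimator_checks import parametrize_with_checks
@parametrize_with_checks([SMOTE(), RandomUnderSampler()])
def test_common_checks(estimator, check):
    check(estimator)
```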
while still supporting granular control over expected failures via the expected_failed_checks callback.
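The expansion behind these entry points can be sketched in plain Python. build_cases is hypothetical and not the library's implementation; it assumes, as in scikit-learn 1.6, that the expected_failed_checks callback takes an estimator and returns a dict mapping check names to reasons. Marks are kept as plain tuples so the sketch runs without pytest; in a real suite each case would typically become pytest.param(estimator, check, id=test_id, marks=pytest.mark.xfail(reason=reason)).

```python
def build_cases(estimators, checks, expected_failed_checks=None):
    """Hypothetical sketch: expand estimators x checks into labelled cases.

    Each case is (test_id, estimator, check, mark), where mark is None or
    ("xfail", reason).
    """
    cases = []
    for estimator in estimators:
        expected = expected_failed_checks(estimator) if expected_failed_checks else {}
        for check in checks:
            reason = expected.get(check.__name__)
            test_id = f"{type(estimator).__name__}-{check.__name__}"
            mark = None if reason is None else ("xfail", reason)
            cases.append((test_id, estimator, check, mark))
    return cases


class SamplerA:
    pass


class SamplerB:
    pass


def check_samplers_fit_resample(estimator):
    pass


def check_samplers_preserve_dtype(estimator):
    pass


def expected_failures(estimator):
    # Hypothetical callback: SamplerB is known to upcast float32 input.
    if isinstance(estimator, SamplerB):
        return {"check_samplers_preserve_dtype": "upcasts float32 to float64"}
    return {}


cases = build_cases(
    [SamplerA(), SamplerB()],
    [check_samplers_fit_resample, check_samplers_preserve_dtype],
    expected_failed_checks=expected_failures,
)
```

Keeping the expected failures in a callback rather than a static list lets the reason depend on the estimator's configuration, so the same test module can cover several parametrizations of one sampler.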