Implementation:Online ml River Checks Framework

From Leeroopedia


Knowledge Sources
Domains Online_Learning, Testing, Quality_Assurance, API_Compliance
Last Updated 2026-02-08 16:00 GMT

Overview

A unit-testing framework for validating that River estimators follow the library's API conventions and behave correctly.

Description

The checks framework provides a systematic way to validate estimator implementations through a suite of automated unit tests. The main entry point is check_estimator, which runs all applicable tests for a given model, while yield_checks generates individual test functions that can be run selectively.
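The split between the two entry points can be illustrated with a small, framework-independent sketch. Every name in it (ToyModel, check_repr, check_learn_one, yield_checks_sketch, check_estimator_sketch) is hypothetical and only mimics the behaviour described above; it is not River's implementation. The sketch assumes that checks needing data are bound to a dataset with functools.partial, with functools.update_wrapper copying the original __name__ so each check can still be selected by name.

```python
import copy
import functools
import typing


class ToyModel:
    """Hypothetical estimator: predicts the running mean of the targets it has seen."""

    def __init__(self):
        self.n = 0
        self.mean = 0.0

    def learn_one(self, x, y):
        self.n += 1
        self.mean += (y - self.mean) / self.n

    def predict_one(self, x):
        return self.mean


def check_repr(model):
    # General check: needs only the model
    assert isinstance(repr(model), str)


def check_learn_one(model, dataset):
    # Data-driven check: learn_one must not modify its input
    for x, y in dataset:
        x_before = copy.deepcopy(x)
        model.learn_one(x, y)
        assert x == x_before


def _bind(check, **kwargs):
    # Bind the dataset but keep the original __name__ so the check stays selectable
    bound = functools.partial(check, **kwargs)
    functools.update_wrapper(bound, check)
    return bound


def yield_checks_sketch(model) -> typing.Iterator[typing.Callable]:
    yield check_repr
    datasets = [[({"a": 1.0}, 2.0), ({"a": 3.0}, 4.0)]]
    for dataset in datasets:  # one bound copy of each data-driven check per dataset
        yield _bind(check_learn_one, dataset=dataset)


def check_estimator_sketch(model, only=None):
    # Run every check (or only the named ones), each on a fresh copy of the model
    ran = []
    for check in yield_checks_sketch(model):
        if only is None or check.__name__ in only:
            check(copy.deepcopy(model))
            ran.append(check.__name__)
    return ran
```

Calling check_estimator_sketch(ToyModel()) runs both checks, while passing only={"check_learn_one"} runs just the data-driven one. Because each check receives a deep copy, the model passed in stays untrained.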

The framework tests multiple aspects of estimators: basic Python conventions (repr, str, clone), parameter handling (defaults, mutability), state management (pickling, memory addresses), feature handling (emerging, disappearing, shuffled features), and stochastic behavior (seeding idempotence). The specific tests run depend on the estimator type (classifier, regressor, clusterer, etc.).
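To show what the feature-handling checks guard against, the sketch below implements simplified versions of three of them against a toy online linear model. ToyLinear, make_stream and the three check functions are hypothetical: their names mirror the feature-handling categories described above, but their bodies are assumptions rather than River's code.

```python
import copy
import math
import random


class ToyLinear:
    """Hypothetical online linear regressor that stores one weight per feature name."""

    def __init__(self, lr=0.01):
        self.lr = lr
        self.weights = {}

    def predict_one(self, x):
        # Unseen features get weight 0, so new or missing keys never raise
        return sum(self.weights.get(k, 0.0) * v for k, v in x.items())

    def learn_one(self, x, y):
        err = y - self.predict_one(x)  # compute the error once, before updating
        for k, v in x.items():
            self.weights[k] = self.weights.get(k, 0.0) + self.lr * err * v


def make_stream(n=50, seed=0):
    # Small synthetic stream with five numeric features
    rng = random.Random(seed)
    stream = []
    for _ in range(n):
        x = {f"f{i}": rng.uniform(-1, 1) for i in range(5)}
        stream.append((x, 2 * x["f0"] - x["f3"]))
    return stream


def check_shuffle_features_no_impact(model, dataset, seed=42):
    # Same features in a different key order must give the same predictions
    rng = random.Random(seed)
    shuffled = copy.deepcopy(model)
    for x, y in dataset:
        keys = list(x)
        rng.shuffle(keys)
        x_shuffled = {k: x[k] for k in keys}
        assert math.isclose(model.predict_one(x), shuffled.predict_one(x_shuffled), abs_tol=1e-9)
        model.learn_one(x, y)
        shuffled.learn_one(x_shuffled, y)


def check_emerging_features(model, dataset):
    # Only two features exist at first; the rest "emerge" halfway through the stream
    half = len(dataset) // 2
    for i, (x, y) in enumerate(dataset):
        x_seen = x if i >= half else {k: x[k] for k in ("f0", "f1")}
        model.predict_one(x_seen)
        model.learn_one(x_seen, y)


def check_disappearing_features(model, dataset, seed=7):
    # Randomly drop features from each sample; the model must keep working
    rng = random.Random(seed)
    for x, y in dataset:
        x_partial = {k: v for k, v in x.items() if rng.random() > 0.3}
        model.predict_one(x_partial)
        model.learn_one(x_partial, y)
```

The design point is predict_one's use of self.weights.get(k, 0.0): a model that indexed self.weights[k] directly would raise a KeyError as soon as an unseen feature appeared.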

For each estimator type, the framework generates appropriate test datasets using _yield_datasets, which creates suitable data based on the model's characteristics. For example, classifiers get classification datasets, multi-output regressors get multi-target datasets, and recommendation models get user-item-rating data.
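The type-based dispatch performed by _yield_datasets can be sketched as follows. The base classes, the tiny datasets and the assumption that multi-class data is generated only for models whose _multiclass flag is true are all made up for illustration; the real framework uses River's own estimator hierarchy and builds its own data.

```python
import typing


class Estimator:
    """Hypothetical stand-in for a common estimator base class."""


class Classifier(Estimator):
    _multiclass = False  # assumption: binary-only unless a subclass says otherwise


class Regressor(Estimator):
    pass


class Recommender(Regressor):
    """In this sketch a recommender is also a regressor, so it must be matched first."""


def yield_datasets_sketch(model: Estimator) -> typing.Iterator[list]:
    # Dispatch from the most specific type to the least specific one
    if isinstance(model, Recommender):
        yield [({"user": "Alice", "item": "Superman"}, 8.0),
               ({"user": "Bob", "item": "Batman"}, 6.5)]
    elif isinstance(model, Regressor):
        yield [({"x1": 0.5, "x2": -1.0}, 1.2), ({"x1": 1.5, "x2": 0.0}, 3.1)]
    elif isinstance(model, Classifier):
        yield [({"x1": 0.5}, True), ({"x1": -0.5}, False)]  # binary labels
        if model._multiclass:
            yield [({"x1": 0.1}, "a"), ({"x1": 0.7}, "b"), ({"x1": 0.9}, "c")]
```

Ordering matters here: because Recommender subclasses Regressor in this sketch, testing isinstance(model, Regressor) first would hand a recommender plain regression data without user and item keys.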

Tests can be skipped by implementing the _unit_test_skips method on an estimator. The method returns a set of check names to exclude (for example, 'check_emerging_features'), which lets a model opt out of checks that do not apply to its implementation.
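A minimal sketch of how a runner can honour _unit_test_skips, assuming skips are matched against each check's __name__. Base, NoFeatureHandling, the three check stubs and run_checks are hypothetical names used only for illustration.

```python
import copy


class Base:
    """Hypothetical base class: by default no checks are skipped."""

    def _unit_test_skips(self):
        return set()


class NoFeatureHandling(Base):
    """Hypothetical model that opts out of two feature-handling checks."""

    def _unit_test_skips(self):
        return {"check_emerging_features", "check_disappearing_features"}


def check_repr(model):
    assert isinstance(repr(model), str)


def check_emerging_features(model):
    # Stub that always fails, to show the effect of skipping
    raise AssertionError("would fail for this model")


def check_disappearing_features(model):
    raise AssertionError("would fail for this model")


ALL_CHECKS = [check_repr, check_emerging_features, check_disappearing_features]


def run_checks(model):
    # Skips are matched on each check's __name__
    skips = model._unit_test_skips()
    passed, skipped = [], []
    for check in ALL_CHECKS:
        if check.__name__ in skips:
            skipped.append(check.__name__)
            continue
        check(copy.deepcopy(model))
        passed.append(check.__name__)
    return passed, skipped
```

Matching on names keeps the opt-out declarative, but it also means a misspelled name is silently ignored and the check still runs.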

The framework includes specialized test modules for different estimator types: common for checks that apply to all models, clf for classifiers, reco for recommenders, anomaly for anomaly detectors, and model_selection for model selectors. Because the type-specific checks live in their own modules, a model only receives the checks that match its type.
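As an example of the kind of assertion a classifier-specific module such as clf might make, here is a hypothetical probability check: after every update, predict_proba_one should return a dict of probabilities in [0, 1] that sum to 1, without mutating its input. PriorClassifier and check_predict_proba_one are illustrative names, and the exact assertions River makes may differ.

```python
import copy
import math


class PriorClassifier:
    """Hypothetical classifier that predicts the empirical class frequencies."""

    def __init__(self):
        self.counts = {}

    def learn_one(self, x, y):
        self.counts[y] = self.counts.get(y, 0) + 1

    def predict_proba_one(self, x):
        total = sum(self.counts.values())
        return {c: n / total for c, n in self.counts.items()} if total else {}


def check_predict_proba_one(classifier, dataset):
    # After each update, predict_proba_one must return a probability distribution
    # and must leave its input untouched
    for x, y in dataset:
        x_before = copy.deepcopy(x)
        classifier.learn_one(x, y)
        proba = classifier.predict_proba_one(x)
        assert isinstance(proba, dict)
        assert all(0.0 <= p <= 1.0 for p in proba.values())
        assert math.isclose(sum(proba.values()), 1.0)
        assert x == x_before
```

A classifier that returned a score of 1.0 for every class would pass on the first sample but fail as soon as a second class appeared, since the values would no longer sum to 1.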

Usage

Use check_estimator during development to ensure your custom estimator follows River's conventions. Run it in unit tests to catch API violations and behavioral issues. The framework is essential for contributing new estimators to River and for maintaining code quality in extensions.

Code Reference

Source Location

Signature

import typing

from river.base import Estimator

def check_estimator(model: Estimator) -> None:
    """Check if a model adheres to River's conventions."""
    ...

def yield_checks(model: Estimator) -> typing.Iterator[typing.Callable]:
    """Generates unit tests for a given model."""
    ...

Import

from river import checks

I/O Contract

Input
Parameter | Type | Description
model | Estimator | The estimator instance to check

Output
Function | Return Type | Description
check_estimator(model) | None | Runs all applicable checks (except those listed in _unit_test_skips); raises on failure
yield_checks(model) | Iterator[Callable] | Generates individual test functions

Test Categories
Category | Description
General | repr, str, tags, clone behavior, params, documentation
State Management | Pickling, memory addresses, mutable attributes
Feature Handling | Shuffling, emerging features, disappearing features
Stochastic | Seeding idempotence
Type-Specific | Classifier probabilities, anomaly detection metrics, etc.
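The state-management and stochastic categories can be illustrated with two simplified checks on a toy seeded model: a pickle round trip must not change predictions, and two clones with the same seed must agree. NoisyMean, its clone method and both check functions are hypothetical sketches, not River's implementations; the sketch assumes randomness is controlled through a seed constructor parameter.

```python
import pickle
import random


class NoisyMean:
    """Hypothetical stochastic model: running mean plus seeded Gaussian jitter."""

    def __init__(self, seed=None):
        self.seed = seed
        self._rng = random.Random(seed)
        self.n = 0
        self.mean = 0.0

    def learn_one(self, x, y):
        self.n += 1
        self.mean += (y - self.mean) / self.n

    def predict_one(self, x):
        return self.mean + self._rng.gauss(0, 0.1)

    def clone(self):
        # A fresh, untrained instance with the same constructor parameters
        return self.__class__(seed=self.seed)


def check_pickling(model, dataset):
    # A pickled-then-restored model must behave exactly like the original,
    # including the internal state of its random number generator
    for x, y in dataset:
        model.learn_one(x, y)
    restored = pickle.loads(pickle.dumps(model))
    for x, _ in dataset:
        assert restored.predict_one(x) == model.predict_one(x)


def check_seeding_is_idempotent(model, dataset):
    # Two clones built with the same seed must produce identical predictions
    a, b = model.clone(), model.clone()
    for x, y in dataset:
        assert a.predict_one(x) == b.predict_one(x)
        a.learn_one(x, y)
        b.learn_one(x, y)
```

With seed=None each clone draws from a differently seeded generator, so check_seeding_is_idempotent fails, which is the behaviour such a check is meant to catch.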

Usage Examples

from river import checks
from river import linear_model
from river import preprocessing
from river import compose

# Example 1: Check a simple estimator
model = linear_model.LinearRegression()

# This runs all applicable checks
checks.check_estimator(model)
print("All checks passed!")

# Example 2: Check a pipeline
pipeline = (
    preprocessing.StandardScaler() |
    linear_model.LogisticRegression()
)

checks.check_estimator(pipeline)

# Example 3: View all checks for a model
model = linear_model.LogisticRegression()

print("Checks for LogisticRegression:")
for i, check in enumerate(checks.yield_checks(model), 1):
    print(f"{i}. {check.__name__}")

# Example 4: Run individual checks
from river import tree

model = tree.HoeffdingTreeClassifier()

# Run just one check
from river.checks import common

try:
    common.check_repr(model.clone())
    print("check_repr passed")
except AssertionError as e:
    print(f"check_repr failed: {e}")

# Example 5: Custom estimator with unit tests
from river import base

class MyRegressor(base.Regressor):
    """Online linear regressor trained with stochastic gradient descent."""

    def __init__(self, alpha=0.01):
        self.alpha = alpha
        self.weights = {}

    def learn_one(self, x, y):
        # Compute the error once, before any weight is updated
        error = y - self.predict_one(x)
        for feature, value in x.items():
            self.weights[feature] = self.weights.get(feature, 0.0) + self.alpha * error * value

    def predict_one(self, x):
        # Features that have not been seen yet contribute nothing
        return sum(self.weights.get(f, 0.0) * v for f, v in x.items())

# Test it
model = MyRegressor()
checks.check_estimator(model)

# Example 6: Skipping specific tests
class MyCustomClassifier(base.Classifier):
    """Predicts a uniform distribution over the classes seen so far."""

    def __init__(self):
        self.classes = set()

    @property
    def _multiclass(self):
        return True

    def learn_one(self, x, y):
        self.classes.add(y)

    def predict_proba_one(self, x):
        # Equal probability for all classes
        n = len(self.classes)
        return {c: 1/n for c in self.classes} if n > 0 else {}

    def _unit_test_skips(self):
        # Skip tests that don't apply
        return {
            'check_emerging_features',
            'check_disappearing_features'
        }

model = MyCustomClassifier()
checks.check_estimator(model)

# Example 7: Running checks in pytest
def test_my_model():
    """Unit test using checks framework"""
    from river import linear_model
    from river import checks

    model = linear_model.LinearRegression()
    checks.check_estimator(model)

# Example 8: Debugging a failing check
model = linear_model.LogisticRegression()

for check in checks.yield_checks(model):
    try:
        check(model.clone())
        print(f"✓ {check.__name__}")
    except Exception as e:
        print(f"✗ {check.__name__}: {e}")

# Example 9: Check custom transformer
from river import base

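class MyTransformer(base.Transformer):
    """Multiplies every feature by a constant factor."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def learn_one(self, x):
        # Stateless transformer: there is nothing to learn
        pass

    def transform_one(self, x):
        return {k: v * self.scale for k, v in x.items()}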

transformer = MyTransformer()
checks.check_estimator(transformer)

# Example 10: Check model with default params
from river import naive_bayes

# The framework checks that default params work
model = naive_bayes.GaussianNB()
checks.check_estimator(model)

# Example 11: Check ensemble model
from river import ensemble

model = ensemble.BaggingClassifier(
    model=tree.HoeffdingTreeClassifier(),
    n_models=5
)

checks.check_estimator(model)

# Example 12: Development workflow
class MyNewEstimator(base.Classifier):
    """Custom classifier implementation"""

    def __init__(self, param1=1.0):
        self.param1 = param1
        self._model_state = {}

    @property
    def _multiclass(self):
        return True

    def learn_one(self, x, y):
        # Implementation
        pass

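    def predict_proba_one(self, x):
        # Implementation: must return a dict mapping each class to a probability,
        # with probabilities summing to 1; this empty placeholder is likely to
        # fail the classifier probability checks
        return {}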

    @classmethod
    def _unit_test_params(cls):
        """Provide default params for testing"""
        yield {'param1': 1.0}
        yield {'param1': 2.0}

    def _unit_test_skips(self):
        """Skip inapplicable tests"""
        return set()

# Test during development
model = MyNewEstimator()
try:
    checks.check_estimator(model)
    print("Ready for production!")
except AssertionError as e:
    print(f"Fix needed: {e}")

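# Example 13: Check all models in a module
import inspect
from river import base
from river import naive_bayes

print("Checking all naive_bayes models:")
for name, obj in inspect.getmembers(naive_bayes, inspect.isclass):
    # Only concrete estimator classes can be instantiated and checked
    if not issubclass(obj, base.Estimator) or obj is base.Estimator or inspect.isabstract(obj):
        continue
    print(f"\nChecking {name}...")
    # _unit_test_params yields the constructor arguments intended for testing
    for params in obj._unit_test_params():
        try:
            checks.check_estimator(obj(**params))
            print(f"  ✓ {name}({params}) passes all checks")
        except Exception as e:
            print(f"  ✗ {name}({params}) failed: {e}")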

# Example 14: List available check modules
from river.checks import common, clf, anomaly, reco, model_selection

print("Available check modules:")
print("- common: General checks for all estimators")
print("- clf: Classifier-specific checks")
print("- anomaly: Anomaly detector checks")
print("- reco: Recommender system checks")
print("- model_selection: Model selector checks")

# Example 15: Generate test datasets for inspection
from river.checks import _yield_datasets

model = linear_model.LinearRegression()

print("Datasets generated for LinearRegression:")
for i, dataset in enumerate(_yield_datasets(model), 1):
    print(f"\nDataset {i}:")
    for j, (x, y) in enumerate(dataset):
        print(f"  Sample {j}: x={x}, y={y}")
        if j >= 2:
            break
