Implementation:Online ml River Checks Framework
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Testing, Quality_Assurance, API_Compliance |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
Comprehensive unit testing framework for validating that River estimators adhere to API conventions and behave correctly.
Description
The checks framework provides a systematic way to validate estimator implementations through a suite of automated unit tests. The main entry point is check_estimator, which runs all applicable tests for a given model, while yield_checks generates individual test functions that can be run selectively.
The framework tests multiple aspects of estimators: basic Python conventions (repr, str, clone), parameter handling (defaults, mutability), state management (pickling, memory addresses), feature handling (emerging, disappearing, shuffled features), and stochastic behavior (seeding idempotence). The specific tests run depend on the estimator type (classifier, regressor, clusterer, etc.).
For each estimator type, the framework generates appropriate test datasets using _yield_datasets, which creates suitable data based on the model's characteristics. For example, classifiers get classification datasets, multi-output regressors get multi-target datasets, and recommendation models get user-item-rating data.
Checks can be skipped by overriding the _unit_test_skips method, which returns a set of check function names (for example check_emerging_features) to exclude. check_estimator runs only the yielded checks whose names are not in that set. This lets a model opt out of checks that do not fit its implementation constraints.
The framework includes specialized check modules for different estimator types: common for checks shared by all models, clf for classifiers, reco for recommenders, anomaly for anomaly detectors, and model_selection for model selectors. yield_checks draws on these modules according to the type of the model being checked.
Usage
Use check_estimator during development to verify that a custom estimator follows River's conventions, and call it from your unit tests to catch API violations and behavioral regressions early. New estimators contributed to River are expected to pass these checks. The checks are equally useful for estimators maintained in third-party extensions.
Code Reference
Source Location
- Repository: Online_ml_River
- File: river/checks/__init__.py
Signature
```python
def check_estimator(model: Estimator):
    """Check if a model adheres to River's conventions."""
    ...


def yield_checks(model: Estimator) -> typing.Iterator[typing.Callable]:
    """Generates unit tests for a given model."""
    ...
```
Import
```python
from river import checks
```
I/O Contract

Parameters

| Parameter | Type | Description |
|---|---|---|
| model | Estimator | The estimator instance to check |

Returns

| Function | Return Type | Description |
|---|---|---|
| check_estimator(model) | None | Runs all applicable checks, raises on failure |
| yield_checks(model) | Iterator[Callable] | Generates individual test functions |

Check categories

| Category | Description |
|---|---|
| General | repr, str, tags, clone behavior, params, documentation |
| State Management | Pickling, memory addresses, mutable attributes |
| Feature Handling | Shuffling, emerging features, disappearing features |
| Stochastic | Seeding idempotence |
| Type-Specific | Classifier probabilities, anomaly detection metrics, etc. |
Usage Examples
```python
# Example 1: Check a simple estimator
from river import checks, linear_model

model = linear_model.LinearRegression()

# Runs all applicable checks; raises on the first failure
checks.check_estimator(model)
print("All checks passed!")
```
```python
# Example 2: Check a pipeline
from river import checks, linear_model, preprocessing

pipeline = (
    preprocessing.StandardScaler()
    | linear_model.LogisticRegression()
)
checks.check_estimator(pipeline)
```
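```python
# Example 3: View all checks for a model
from river import checks, linear_model

model = linear_model.LogisticRegression()
print("Checks for LogisticRegression:")
# Checks that use datasets may be listed once per generated dataset
for i, check in enumerate(checks.yield_checks(model), 1):
    print(f"{i}. {check.__name__}")
```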
```python
# Example 4: Run individual checks
from river import tree
from river.checks import common

model = tree.HoeffdingTreeClassifier()

# Run just one check
try:
    common.check_repr(model.clone())
    print("check_repr passed")
except AssertionError as e:
    print(f"check_repr failed: {e}")
```
```python
# Example 5: Custom estimator with unit tests
from river import base, checks


class MyRegressor(base.Regressor):
    """Linear regressor trained with plain stochastic gradient descent."""

    def __init__(self, alpha=0.01):
        self.alpha = alpha
        self.weights = {}

    def learn_one(self, x, y):
        # Compute the error once, before any weight changes; otherwise the
        # update would depend on the order in which features are visited
        error = y - self.predict_one(x)
        for feature, value in x.items():
            self.weights[feature] = self.weights.get(feature, 0.0) + self.alpha * error * value

    def predict_one(self, x):
        return sum(self.weights.get(f, 0.0) * v for f, v in x.items())


# Test it (plain SGD on unscaled inputs can diverge, which may make some checks fail)
model = MyRegressor()
checks.check_estimator(model)
```
```python
# Example 6: Skipping specific tests
from river import base, checks


class MyCustomClassifier(base.Classifier):
    """Predicts a uniform distribution over the classes seen so far."""

    def __init__(self):
        self.classes = set()

    @property
    def _multiclass(self):
        return True

    def learn_one(self, x, y):
        self.classes.add(y)

    def predict_proba_one(self, x):
        # Equal probability for all classes seen so far
        n = len(self.classes)
        return {c: 1 / n for c in self.classes} if n > 0 else {}

    def _unit_test_skips(self):
        # Names of check functions that check_estimator should not run
        return {
            "check_emerging_features",
            "check_disappearing_features",
        }


model = MyCustomClassifier()
checks.check_estimator(model)
```
```python
# Example 7: Running checks in pytest
from river import checks, linear_model


def test_my_model():
    """Unit test using the checks framework."""
    model = linear_model.LinearRegression()
    checks.check_estimator(model)
```
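```python
# Example 8: Debugging a failing check
from river import checks, linear_model

model = linear_model.LogisticRegression()
skips = model._unit_test_skips()

for check in checks.yield_checks(model):
    if check.__name__ in skips:
        continue  # check_estimator skips these as well
    try:
        check(model.clone())
        print(f"✓ {check.__name__}")
    except Exception as e:
        print(f"✗ {check.__name__}: {e!r}")
```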
```python
# Example 9: Check custom transformer
from river import base, checks


class MyTransformer(base.Transformer):
    """Multiplies every feature value by a constant factor."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def learn_one(self, x):
        # Stateless transformer: nothing to learn
        pass

    def transform_one(self, x):
        return {k: v * self.scale for k, v in x.items()}


transformer = MyTransformer()
checks.check_estimator(transformer)
```
```python
# Example 10: Check model with default params
from river import checks, naive_bayes

# The framework checks that default params work
model = naive_bayes.GaussianNB()
checks.check_estimator(model)
```
```python
# Example 11: Check ensemble model
from river import checks, ensemble, tree

model = ensemble.BaggingClassifier(
    model=tree.HoeffdingTreeClassifier(),
    n_models=5,
)
checks.check_estimator(model)
```
```python
# Example 12: Development workflow
from river import base, checks


class MyNewEstimator(base.Classifier):
    """Custom classifier implementation."""

    def __init__(self, param1=1.0):
        self.param1 = param1
        self._model_state = {}

    @property
    def _multiclass(self):
        return True

    def learn_one(self, x, y):
        # Implementation goes here
        pass

    def predict_proba_one(self, x):
        # Stub: an empty dict is not a valid probability distribution,
        # so the classifier checks will likely fail until this is implemented
        return {}

    @classmethod
    def _unit_test_params(cls):
        """Parameter sets used to instantiate the model during testing."""
        yield {"param1": 1.0}
        yield {"param1": 2.0}

    def _unit_test_skips(self):
        """Names of checks that do not apply to this model."""
        return set()


# Test during development
model = MyNewEstimator()
try:
    checks.check_estimator(model)
    print("All checks passed")
except AssertionError as e:
    print(f"Fix needed: {e!r}")
```
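```python
# Example 13: Check all models in a module
import inspect

from river import base, checks, naive_bayes

print("Checking all naive_bayes models:")
for name, obj in inspect.getmembers(naive_bayes, inspect.isclass):
    # Keep only concrete estimator classes
    if not issubclass(obj, base.Estimator) or obj is base.Estimator or inspect.isabstract(obj):
        continue
    print(f"\nChecking {name}...")
    try:
        # Instantiate with the first parameter set the class provides for testing
        params = next(iter(obj._unit_test_params()))
        checks.check_estimator(obj(**params))
        print(f"  ✓ {name} passes all checks")
    except Exception as e:
        print(f"  ✗ {name} failed: {e!r}")
```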
```python
# Example 14: List available check modules
from river.checks import common, clf, anomaly, reco, model_selection

print("Available check modules:")
print("- common: General checks for all estimators")
print("- clf: Classifier-specific checks")
print("- anomaly: Anomaly detector checks")
print("- reco: Recommender system checks")
print("- model_selection: Model selector checks")
```
```python
# Example 15: Generate test datasets for inspection
from river import linear_model
from river.checks import _yield_datasets  # private helper; may change between releases

model = linear_model.LinearRegression()
print("Datasets generated for LinearRegression:")
for i, dataset in enumerate(_yield_datasets(model), 1):
    print(f"\nDataset {i}:")
    # Show only the first three samples of each dataset
    for j, (x, y) in enumerate(dataset):
        print(f"  Sample {j}: x={x}, y={y}")
        if j >= 2:
            break
```