

Implementation:Sktime Pytorch forecasting Check Estimator

From Leeroopedia


Knowledge Sources
Domains Time_Series, Forecasting, Deep_Learning
Last Updated 2026-02-08 08:00 GMT

Overview

check_estimator and parametrize_with_checks are utility functions for running the unified API conformance test suite against any pytorch-forecasting estimator or object.

Description

check_estimator runs all applicable API conformance tests on a given estimator class or instance, returning a dictionary of pass/fail results. It resolves the appropriate test classes for the estimator's type hierarchy, supports filtering by test name or fixture, and can either raise exceptions directly or collect them in the results dictionary. parametrize_with_checks is a companion decorator that generates pytest parametrize decorators for running conformance tests across one or more estimator objects, inspired by scikit-learn's utility of the same name.
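
The two reporting modes controlled by raise_exceptions (collect outcomes in a dictionary, or stop at the first failure) can be illustrated with a small stand-alone sketch. This is not the library's implementation: run_checks, passing_check and failing_check are hypothetical names, and each conformance test is modeled as a zero-argument callable keyed by a "test_name[fixture]" string.

```python
from typing import Callable, Dict, Union

Outcome = Union[str, Exception]


def run_checks(
    checks: Dict[str, Callable[[], None]],
    raise_exceptions: bool = False,
) -> Dict[str, Outcome]:
    """Run zero-argument check callables and record each outcome.

    Hypothetical helper: stores "PASSED" on success, otherwise the
    exception object, unless raise_exceptions=True, in which case the
    exception propagates to the caller.
    """
    results: Dict[str, Outcome] = {}
    for name, check in checks.items():
        try:
            check()
        except Exception as err:
            if raise_exceptions:
                raise  # surface the first failure immediately
            results[name] = err  # keep going and store the exception
        else:
            results[name] = "PASSED"
    return results


# Hypothetical checks standing in for the real conformance tests.
def passing_check() -> None:
    assert 1 + 1 == 2


def failing_check() -> None:
    raise AssertionError("forecast has wrong shape")


results = run_checks(
    {"test_ok[Model]": passing_check, "test_shape[Model]": failing_check}
)
```

Collecting exceptions suits exploratory runs where you want the full picture; raising suits pytest, where the first exception should fail the test with a traceback.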

Usage

Use check_estimator during development and CI to verify that a new or modified model, metric, or other pytorch-forecasting object conforms to the unified API contract. Use parametrize_with_checks in pytest test suites to systematically parametrize conformance tests across multiple estimators with minimal boilerplate.

Code Reference

Source Location

Signature

def check_estimator(
    estimator,
    raise_exceptions=False,
    tests_to_run=None,
    fixtures_to_run=None,
    verbose=True,
    tests_to_exclude=None,
    fixtures_to_exclude=None,
):
    ...


def parametrize_with_checks(objs, obj_varname="obj", check_varname="test_name"):
    ...

Import

from pytorch_forecasting.utils import check_estimator
from pytorch_forecasting.utils import parametrize_with_checks

I/O Contract

Inputs (check_estimator)

Name | Type | Required | Description
estimator | class or instance | Yes | Any pytorch-forecasting estimator class or instance for which suite tests exist
raise_exceptions | bool | No | If True, raises exceptions as they occur; if False, returns them in the results dict; defaults to False
tests_to_run | str or list[str] | No | Names of specific tests to run; defaults to running all applicable tests
fixtures_to_run | str or list[str] | No | Pytest test-fixture combination codes to run; combined as a union with tests_to_run if both are provided
verbose | int or bool | No | Verbosity level: 0/False for no output, 1/True for a summary, 2 for full test output; defaults to True
tests_to_exclude | str or list[str] | No | Names of tests to exclude after subsetting; defaults to None
fixtures_to_exclude | str or list[str] | No | Test-fixture combinations to exclude after subsetting; defaults to None
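
The subsetting rules in the table above (string-or-list arguments, tests_to_run and fixtures_to_run combined as a union, exclusions applied afterwards) can be sketched as plain string filtering. This is an illustrative model only: select_checks and _as_list are hypothetical helpers, it assumes check identifiers have the form "test_name[fixture_id]" as in the usage examples, and every identifier in ALL_CHECKS except test_pkg_linkage[NBeats_pkg-NBeats] is invented.

```python
from typing import Iterable, List, Optional, Union

StrOrList = Optional[Union[str, List[str]]]


def _as_list(value: StrOrList) -> List[str]:
    """Normalize a str-or-list[str] argument to a list."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def _test_name(check: str) -> str:
    # "test_pkg_linkage[NBeats_pkg-NBeats]" -> "test_pkg_linkage"
    return check.split("[", 1)[0]


def select_checks(
    all_checks: Iterable[str],
    tests_to_run: StrOrList = None,
    fixtures_to_run: StrOrList = None,
    tests_to_exclude: StrOrList = None,
    fixtures_to_exclude: StrOrList = None,
) -> List[str]:
    """Hypothetical subsetting of "test_name[fixture_id]" strings."""
    checks = list(all_checks)
    tests = set(_as_list(tests_to_run))
    fixtures = set(_as_list(fixtures_to_run))

    # Subsetting: keep a check matching either filter (union).
    # With no filters at all, every check is kept.
    if tests or fixtures:
        checks = [c for c in checks if _test_name(c) in tests or c in fixtures]

    # Exclusions are applied after subsetting.
    excluded_tests = set(_as_list(tests_to_exclude))
    excluded_fixtures = set(_as_list(fixtures_to_exclude))
    return [
        c
        for c in checks
        if _test_name(c) not in excluded_tests and c not in excluded_fixtures
    ]


# Hypothetical check identifiers for one estimator.
ALL_CHECKS = [
    "test_pkg_linkage[NBeats_pkg-NBeats]",
    "test_fit[NBeats-0]",
    "test_fit[NBeats-1]",
    "test_predict[NBeats-0]",
]
```

For example, select_checks(ALL_CHECKS, tests_to_run="test_fit", fixtures_to_run="test_predict[NBeats-0]") keeps both test_fit fixtures plus the single named test_predict fixture.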

Inputs (parametrize_with_checks)

Name | Type | Required | Description
objs | class, instance, or list thereof | Yes | Estimator objects to generate parametrized test names for
obj_varname | str | No | Variable name for objects in the parametrization; defaults to "obj"
check_varname | str | No | Variable name for test name strings in the parametrization; defaults to "test_name"
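
Conceptually, parametrize_with_checks expands its objects into one (object, test name) pair per applicable check and names the two parameters after obj_varname and check_varname. The sketch below builds those arguments without importing pytest. It is a hypothetical model, not the library's code: build_parametrize_args, hypothetical_test_names, ModelA and ModelB are invented, and the real function looks up each object's applicable tests itself.

```python
from typing import Callable, List, Sequence, Tuple, Union


def build_parametrize_args(
    objs: Union[object, Sequence[object]],
    get_test_names: Callable[[object], List[str]],
    obj_varname: str = "obj",
    check_varname: str = "test_name",
) -> Tuple[str, List[Tuple[object, str]], List[str]]:
    """Return (argnames, argvalues, ids) for a pytest parametrization."""
    # A single class or instance is treated like a one-element list.
    if not isinstance(objs, (list, tuple)):
        objs = [objs]
    argnames = f"{obj_varname},{check_varname}"
    argvalues: List[Tuple[object, str]] = []
    ids: List[str] = []
    for obj in objs:
        # Classes carry __name__; instances fall back to their class name.
        name = getattr(obj, "__name__", type(obj).__name__)
        for test_name in get_test_names(obj):
            argvalues.append((obj, test_name))
            ids.append(f"{name}-{test_name}")
    return argnames, argvalues, ids


# Hypothetical estimators and test lookup used only for illustration.
class ModelA:
    pass


class ModelB:
    pass


def hypothetical_test_names(obj: object) -> List[str]:
    return ["test_fit", "test_predict"]


argnames, argvalues, ids = build_parametrize_args(
    [ModelA, ModelB], hypothetical_test_names
)
```

In a pytest module, such a triple could be handed to pytest.mark.parametrize(argnames, argvalues, ids=ids), which yields one collected test per estimator-check pair, each with a readable ID.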

Outputs

Name | Type | Description
check_estimator return | dict | Dictionary mapping test/fixture strings to "PASSED" or the exception raised
parametrize_with_checks return | pytest.mark.parametrize | A pytest parametrize decorator for use on test functions
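
Because failed entries in the returned dictionary hold exception objects rather than strings, a small post-processing step is useful in CI. The helpers below are a hypothetical sketch that relies only on the documented contract (values are "PASSED" or an exception); split_results, summarize and example_results are invented, and the failure message is made up.

```python
from typing import Dict, List, Tuple, Union

Outcome = Union[str, Exception]


def split_results(
    results: Dict[str, Outcome],
) -> Tuple[List[str], Dict[str, Outcome]]:
    """Separate passing check IDs from failing ones (with their exceptions)."""
    passed = [name for name, outcome in results.items() if outcome == "PASSED"]
    failed = {
        name: outcome for name, outcome in results.items() if outcome != "PASSED"
    }
    return passed, failed


def summarize(results: Dict[str, Outcome]) -> str:
    """One summary line, then one indented line per failing check."""
    passed, failed = split_results(results)
    lines = [f"{len(passed)} passed, {len(failed)} failed"]
    for name, err in failed.items():
        lines.append(f"  {name}: {type(err).__name__}: {err}")
    return "\n".join(lines)


# Hypothetical results dict shaped like the documented return value.
example_results: Dict[str, Outcome] = {
    "test_fit[NBeats-0]": "PASSED",
    "test_predict[NBeats-0]": ValueError("hypothetical failure"),
}
```

A CI job could then fail with a readable report, for example via assert not split_results(results)[1], summarize(results).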

Usage Examples

from pytorch_forecasting.models import NBeats
from pytorch_forecasting.utils import check_estimator

# Run all conformance tests for the NBeats model class
results = check_estimator(NBeats)
# Output: All tests PASSED!

# Run a specific test across all fixtures
results = check_estimator(NBeats, tests_to_run="test_pkg_linkage")

# Run a single test-fixture combination
results = check_estimator(
    NBeats, fixtures_to_run="test_pkg_linkage[NBeats_pkg-NBeats]"
)

# In a pytest module: parametrize conformance tests over several estimators
from pytorch_forecasting.utils import parametrize_with_checks, check_estimator
from pytorch_forecasting.models import DecoderMLP, NBeats

@parametrize_with_checks([NBeats, DecoderMLP])
def test_sktime_compatible_estimators(obj, test_name):
    check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)
