Implementation:Sktime Pytorch forecasting Check Estimator
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Forecasting, Deep_Learning |
| Last Updated | 2026-02-08 08:00 GMT |
Overview
check_estimator and parametrize_with_checks are utility functions for running the unified API conformance test suite against any pytorch-forecasting estimator or object.
Description
check_estimator runs all applicable API conformance tests on a given estimator class or instance, returning a dictionary of pass/fail results. It resolves the appropriate test classes for the estimator's type hierarchy, supports filtering by test name or fixture, and can either raise exceptions directly or collect them in the results dictionary. parametrize_with_checks is a companion decorator that generates pytest parametrize decorators for running conformance tests across one or more estimator objects, inspired by scikit-learn's utility of the same name.
Usage
Use check_estimator during development and CI to verify that a new or modified model, metric, or other pytorch-forecasting object conforms to the unified API contract. Use parametrize_with_checks in pytest test suites to systematically parametrize conformance tests across multiple estimators with minimal boilerplate.
Code Reference
Source Location
- Repository: Sktime_Pytorch_forecasting
- File: pytorch_forecasting/utils/_estimator_checks.py
- Lines: 1-250
Signature
def check_estimator(
    estimator,
    raise_exceptions=False,
    tests_to_run=None,
    fixtures_to_run=None,
    verbose=True,
    tests_to_exclude=None,
    fixtures_to_exclude=None,
):

def parametrize_with_checks(objs, obj_varname="obj", check_varname="test_name"):
Import
from pytorch_forecasting.utils import check_estimator
from pytorch_forecasting.utils import parametrize_with_checks
I/O Contract
Inputs (check_estimator)
| Name | Type | Required | Description |
|---|---|---|---|
| estimator | class or instance | Yes | Any pytorch-forecasting estimator class or instance for which suite tests exist |
| raise_exceptions | bool | No | If True, raises exceptions as they occur; if False, returns them in the results dict; defaults to False |
| tests_to_run | str or list[str] | No | Names of specific tests to run; defaults to running all applicable tests |
| fixtures_to_run | str or list[str] | No | Pytest test-fixture combination codes to run; if tests_to_run is also provided, the union of both selections is run; defaults to None |
| verbose | int or bool | No | Verbosity level: 0/False for no output, 1/True for summary, 2 for full test output; defaults to True |
| tests_to_exclude | str or list[str] | No | Names of tests to exclude after subsetting; defaults to None |
| fixtures_to_exclude | str or list[str] | No | Test-fixture combinations to exclude after subsetting; defaults to None |
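
The interaction of the four filter arguments can be sketched in plain Python. This is a minimal illustration of the semantics described in the table (union of `tests_to_run` and `fixtures_to_run`, then exclusion after subsetting), not the library's implementation. The helper `select_fixtures` is hypothetical, and it assumes fixture codes have the form `test_name[fixture_id]`, as in the `fixtures_to_run` usage example on this page. The names `test_other` and `Other_pkg-Other` are invented for illustration.

```python
def _as_list(value):
    """Normalize None, a single string, or an iterable of strings to a list."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    return list(value)


def select_fixtures(
    all_fixtures,
    tests_to_run=None,
    fixtures_to_run=None,
    tests_to_exclude=None,
    fixtures_to_exclude=None,
):
    """Hypothetical sketch: pick which test-fixture codes would be run."""
    tests_to_run = _as_list(tests_to_run)
    fixtures_to_run = _as_list(fixtures_to_run)
    tests_to_exclude = set(_as_list(tests_to_exclude))
    fixtures_to_exclude = set(_as_list(fixtures_to_exclude))

    def test_name(code):
        # Assumed format: "test_name[fixture_id]"
        return code.split("[", 1)[0]

    if not tests_to_run and not fixtures_to_run:
        # No subsetting requested: start from everything applicable.
        selected = list(all_fixtures)
    else:
        # Union of the two selections.
        selected = [
            code
            for code in all_fixtures
            if test_name(code) in tests_to_run or code in fixtures_to_run
        ]

    # Exclusions are applied after subsetting.
    return [
        code
        for code in selected
        if test_name(code) not in tests_to_exclude
        and code not in fixtures_to_exclude
    ]


# Illustrative fixture codes; only the first appears elsewhere on this page.
codes = [
    "test_pkg_linkage[NBeats_pkg-NBeats]",
    "test_pkg_linkage[Other_pkg-Other]",
    "test_other[NBeats_pkg-NBeats]",
]
subset = select_fixtures(
    codes,
    tests_to_run="test_other",
    fixtures_to_run="test_pkg_linkage[NBeats_pkg-NBeats]",
)
print(subset)
```

Under these assumptions, the call above selects the single requested fixture of `test_pkg_linkage` plus every fixture of `test_other`.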
Inputs (parametrize_with_checks)
| Name | Type | Required | Description |
|---|---|---|---|
| objs | class, instance, or list thereof | Yes | Estimator objects to generate parametrized test names for |
| obj_varname | str | No | Variable name for objects in the parametrization; defaults to "obj" |
| check_varname | str | No | Variable name for test name strings in the parametrization; defaults to "test_name" |
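
To show why a test function decorated with `parametrize_with_checks` receives both an object and a test name, the sketch below expands a list of objects into `(obj, test_name)` pairs. It is an illustration only: `build_check_pairs` and the `tests_for` lookup are hypothetical stand-ins, and how the library actually resolves the test names for each object is not described on this page. Such pairs are the kind of argument list that `pytest.mark.parametrize` accepts.

```python
def build_check_pairs(objs, tests_for):
    """Hypothetical sketch: expand objects into (obj, test_name) pairs.

    objs      : a single object or a list/tuple of objects
    tests_for : callable returning the applicable test names for an object
                (an assumed stand-in for the library's test resolution)
    """
    if not isinstance(objs, (list, tuple)):
        objs = [objs]
    return [(obj, name) for obj in objs for name in tests_for(obj)]


# Dummy classes and test names, for illustration only.
class ModelA:
    pass


class ModelB:
    pass


def dummy_tests_for(obj):
    return ["test_pkg_linkage", "test_other"]


pairs = build_check_pairs([ModelA, ModelB], dummy_tests_for)
for obj, test_name in pairs:
    print(obj.__name__, test_name)
```

Each pair corresponds to one parametrized test case, which is why the decorated function takes arguments named by `obj_varname` and `check_varname`.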
Outputs
| Name | Type | Description |
|---|---|---|
| check_estimator return | dict | Dictionary mapping test/fixture strings to "PASSED" or the exception raised |
| parametrize_with_checks return | pytest.mark.parametrize | A pytest parametrize decorator for use on test functions |
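
When `raise_exceptions=False`, failures are collected in the returned dictionary instead of being raised, so a caller usually needs to separate passes from failures. The helper below is a minimal sketch of that step; `split_results` is a hypothetical name, and the sample dictionary is hand-built to match the documented shape (values are either `"PASSED"` or the exception raised) rather than produced by a real run.

```python
def split_results(results):
    """Split a check_estimator-style results dict into passes and failures."""
    passed = [
        key
        for key, value in results.items()
        if isinstance(value, str) and value == "PASSED"
    ]
    failed = {
        key: value
        for key, value in results.items()
        if not (isinstance(value, str) and value == "PASSED")
    }
    return passed, failed


# Hand-built example in the documented shape; the second key is invented.
sample_results = {
    "test_pkg_linkage[NBeats_pkg-NBeats]": "PASSED",
    "test_other[NBeats_pkg-NBeats]": ValueError("illustrative failure"),
}

passed, failed = split_results(sample_results)
for key, exc in failed.items():
    print(f"{key}: {type(exc).__name__}: {exc}")
```

In a CI script, a non-empty `failed` mapping could be turned into a failing exit status; alternatively, passing `raise_exceptions=True` makes the first failure surface directly.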
Usage Examples
from pytorch_forecasting.models import NBeats
from pytorch_forecasting.utils import check_estimator
# Run all conformance tests for the NBeats model class
results = check_estimator(NBeats)
# Output: All tests PASSED!
# Run a specific test across all fixtures
results = check_estimator(NBeats, tests_to_run="test_pkg_linkage")
# Run a single test-fixture combination
results = check_estimator(
    NBeats, fixtures_to_run="test_pkg_linkage[NBeats_pkg-NBeats]"
)

# Parametrize a pytest test over the conformance checks of several models
from pytorch_forecasting.utils import parametrize_with_checks, check_estimator
from pytorch_forecasting.models import DecoderMLP, NBeats


@parametrize_with_checks([NBeats, DecoderMLP])
def test_sktime_compatible_estimators(obj, test_name):
    # raise_exceptions=True lets pytest report each failing check directly
    check_estimator(obj, tests_to_run=test_name, raise_exceptions=True)