Heuristic:DistrictDataLabs Yellowbrick Model Fitted State Detection
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Compatibility |
| Last Updated | 2026-02-08 05:00 GMT |
Overview
Multi-strategy fallback approach for detecting whether a scikit-learn estimator has been fitted, handling edge cases for clustering models and third-party estimators.
Description
Determining whether a scikit-learn estimator has been fitted is non-trivial because sklearn does not provide a universal `is_fitted` method (as of the versions Yellowbrick supports). Yellowbrick implements a cascading detection strategy: (1) call `predict()` on a small dummy array and report "not fitted" if it raises `NotFittedError`; (2) if the call raises `AttributeError` because the model has no `predict` (e.g., PCA, LDA, agglomerative clustering), fall back to checking for common fitted attributes (`coef_`, `labels_`, `components_`, etc.); (3) if `predict()` succeeds or raises any other exception, assume the model is fitted. A configurable `is_fitted_by` parameter lets users override automatic detection for third-party estimators that don't follow sklearn conventions.
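The three stages can be illustrated without scikit-learn installed. The sketch below is a simplified, dependency-free re-creation, not Yellowbrick's code. Assumptions: `NotFittedError` is a local stand-in for `sklearn.exceptions.NotFittedError`, `hasattr` replaces `check_is_fitted`, only four of the attribute names are checked, and `is_fitted_sketch`, `ToyRegressor`, and `ToyTransformer` are hypothetical names.

```python
# Stand-in for sklearn.exceptions.NotFittedError, which (like this class)
# inherits from both ValueError and AttributeError.
class NotFittedError(ValueError, AttributeError):
    pass


# Simplified subset of Yellowbrick's attribute list, for illustration only.
FITTED_ATTRS = ("coef_", "labels_", "components_", "mean_")


def is_fitted_sketch(estimator):
    """Hypothetical re-creation of the three-stage cascade."""
    try:
        # Stage 1: probe predict() with a small 7x3 dummy input.
        estimator.predict([[0.0, 0.0, 0.0]] * 7)
    except NotFittedError:
        return False
    except AttributeError:
        # Stage 2: no predict(); fitted if ANY common fitted attribute exists.
        return any(hasattr(estimator, name) for name in FITTED_ATTRS)
    except Exception:
        # Stage 3: some other error (e.g. a feature-count mismatch) means
        # the model got past its own fitted check.
        return True
    return True


class ToyRegressor:
    """Hypothetical estimator that refuses to predict before fit()."""

    def fit(self, X, y):
        self.coef_ = [1.0] * len(X[0])
        return self

    def predict(self, X):
        if not hasattr(self, "coef_"):
            raise NotFittedError("call fit() first")
        if len(X[0]) != len(self.coef_):
            raise ValueError("feature count mismatch")
        return [sum(c * v for c, v in zip(self.coef_, row)) for row in X]


class ToyTransformer:
    """Hypothetical transformer with no predict(), similar in spirit to PCA."""

    def fit(self, X, y=None):
        self.components_ = [row[:] for row in X[:1]]
        return self


X, y = [[1.0, 2.0], [3.0, 4.0]], [1.0, 2.0]
print(is_fitted_sketch(ToyRegressor()))            # False (stage 1)
print(is_fitted_sketch(ToyRegressor().fit(X, y)))  # True  (stage 3: 3 vs 2 features)
print(is_fitted_sketch(ToyTransformer()))          # False (stage 2)
print(is_fitted_sketch(ToyTransformer().fit(X)))   # True  (stage 2)
```

Note that the fitted `ToyRegressor` is detected only because its `ValueError` falls through to the catch-all branch; the dummy shape never has to be valid.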
Usage
Apply this heuristic in any Yellowbrick visualizer that needs to decide whether to call `fit()` on a model or use it as-is. This is critical for the `ModelVisualizer` base class, which supports both the "fit-then-visualize" and "visualize-already-fitted" workflows. The detection is called automatically via `check_fitted()` when `is_fitted_by="auto"` (the default).
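The fit-or-reuse decision can be sketched as below. This is a hypothetical illustration, not the `ModelVisualizer` source: `prepare_model`, `check_fitted_sketch`, and `CountingModel` are made-up names, and `auto_detect` is a deliberately crude stand-in (presence of `coef_`) for the full cascade.

```python
def auto_detect(estimator):
    # Crude stand-in for the cascading is_fitted() check.
    return hasattr(estimator, "coef_")


def check_fitted_sketch(estimator, is_fitted_by="auto"):
    # Mirrors the shape of check_fitted(): "auto" delegates, anything else is coerced.
    if isinstance(is_fitted_by, str) and is_fitted_by.lower() == "auto":
        return auto_detect(estimator)
    return bool(is_fitted_by)


class CountingModel:
    """Hypothetical estimator that records how often fit() is called."""

    def __init__(self):
        self.fit_calls = 0

    def fit(self, X, y):
        self.fit_calls += 1
        self.coef_ = [0.0] * len(X[0])
        return self


def prepare_model(model, X, y, is_fitted_by="auto"):
    """Fit the model only if it does not already look fitted."""
    if not check_fitted_sketch(model, is_fitted_by=is_fitted_by):
        model.fit(X, y)
    return model


X, y = [[1.0, 2.0]], [1.0]
model = prepare_model(CountingModel(), X, y)     # unfitted: fit() is called
model = prepare_model(model, X, y)               # already fitted: reused as-is
print(model.fit_calls)                           # 1
prepare_model(model, X, y, is_fitted_by=False)   # override forces a refit
print(model.fit_calls)                           # 2
```

Passing `is_fitted_by=False` in this sketch forces a refit even when detection says the model is fitted.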
The Insight (Rule of Thumb)
- Action: Use the cascading detection: probe `predict()` first, fall back to an attribute check when the model has no `predict`, and otherwise assume the model is fitted. Provide a manual `is_fitted_by` override for edge cases.
- Value: The attribute check list covers 12 common fitted attributes: `coef_`, `estimator_`, `labels_`, `n_clusters_`, `children_`, `components_`, `n_components_`, `n_iter_`, `n_batch_iter_`, `explained_variance_`, `singular_values_`, `mean_`. Matching uses `all_or_any=any` rather than `all`, so a single present attribute is enough to count as fitted.
- Trade-off: The `predict(np.zeros((7, 3)))` probe builds a small temporary array and actually invokes the model, which adds overhead for models whose `predict()` is expensive. It is nonetheless the most reliable probe, because `NotFittedError` is sklearn's canonical "not fitted" signal. The shape `(7, 3)` is arbitrary: a fitted model trained on a different number of features will typically reject it with a `ValueError`, which the final `except Exception` branch treats as "fitted".
- Override: Set `is_fitted_by=True` or `is_fitted_by=False` for third-party estimators (Keras, XGBoost, CatBoost) that may not implement the sklearn API precisely.
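The difference between `any` and `all` matching over the attribute list can be shown with a toy clustering object. In this sketch `hasattr` is a simplified stand-in for sklearn's `check_is_fitted(..., all_or_any=any)`, and `has_fitted_attrs` and `FittedClusterer` are hypothetical names.

```python
# The 12 attribute names from the Yellowbrick list.
FITTED_ATTRS = [
    "coef_", "estimator_", "labels_", "n_clusters_",
    "children_", "components_", "n_components_",
    "n_iter_", "n_batch_iter_", "explained_variance_",
    "singular_values_", "mean_",
]


def has_fitted_attrs(estimator, all_or_any=any):
    """Apply any() or all() to the presence of each listed attribute."""
    return all_or_any(hasattr(estimator, name) for name in FITTED_ATTRS)


class FittedClusterer:
    """Hypothetical fitted clustering model that sets only two of the 12 names."""

    def __init__(self):
        self.labels_ = [0, 1, 1]
        self.n_clusters_ = 2


print(len(FITTED_ATTRS))                         # 12
print(has_fitted_attrs(FittedClusterer(), any))  # True
print(has_fitted_attrs(FittedClusterer(), all))  # False
```

No single estimator type sets all 12 attributes, so `all` matching would report almost every fitted model as unfitted.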
Reasoning
The sklearn ecosystem lacks a standardized "is fitted" check because the convention of trailing-underscore attributes (`coef_`, `classes_`) is a naming convention, not an enforced API contract. Different model types store different attributes after fitting: linear models store `coef_`, clustering models store `labels_`, dimensionality reduction stores `components_`, and so on. Rather than maintaining an exhaustive mapping of model types to attributes, Yellowbrick takes the pragmatic approach of trying `predict()` first (which, in most sklearn estimators, calls sklearn's own `check_is_fitted` internally) and falling back to a broad attribute scan.
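A naive detector that relies only on the trailing-underscore convention shows why the convention alone is not enough. This is a hypothetical heuristic for illustration, not Yellowbrick's method; `looks_fitted_by_convention`, `SklearnStyleModel`, and `ThirdPartyStyleModel` are made-up names.

```python
def looks_fitted_by_convention(estimator):
    """True if the instance has any attribute named like `name_` (ignoring dunders)."""
    return any(
        name.endswith("_") and not name.startswith("__")
        for name in vars(estimator)
    )


class SklearnStyleModel:
    def fit(self, X, y):
        self.coef_ = [0.5]  # follows the trailing-underscore convention
        return self


class ThirdPartyStyleModel:
    def fit(self, X, y):
        self.booster = object()  # fitted state stored without an underscore
        return self


print(looks_fitted_by_convention(SklearnStyleModel()))                     # False
print(looks_fitted_by_convention(SklearnStyleModel().fit([[1]], [1])))     # True
print(looks_fitted_by_convention(ThirdPartyStyleModel().fit([[1]], [1])))  # False (false negative)
```

Because nothing forces an estimator to follow the convention, any convention-only check can miss fitted models, which motivates probing behavior through `predict()` first.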
The `check_fitted()` wrapper adds a "manual override" layer because real-world pipelines often include non-sklearn estimators from libraries like XGBoost, CatBoost, or Keras that may raise unexpected exceptions when `predict()` is called on unfitted models. Since `is_fitted()` treats any exception other than `NotFittedError` as evidence of fitting, such an unfitted model can be misreported as fitted. The `is_fitted_by` parameter allows users to bypass automatic detection entirely.
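The false positive and its manual fix can be reproduced with a toy model. Assumptions: `NotFittedError` is a stand-in for sklearn's exception, `auto_is_fitted` is a simplified cascade with the attribute scan omitted, and `ThirdPartyModel` and `check_fitted_sketch` are hypothetical names.

```python
class NotFittedError(ValueError, AttributeError):
    """Stand-in for sklearn.exceptions.NotFittedError."""


def auto_is_fitted(estimator):
    """Simplified cascade: only NotFittedError means 'not fitted'."""
    try:
        estimator.predict([[0.0, 0.0, 0.0]] * 7)
    except NotFittedError:
        return False
    except AttributeError:
        return False  # simplified: the attribute scan is omitted in this sketch
    except Exception:
        return True   # any other error is read as "fitted"
    return True


def check_fitted_sketch(estimator, is_fitted_by="auto"):
    if isinstance(is_fitted_by, str) and is_fitted_by.lower() == "auto":
        return auto_is_fitted(estimator)
    return bool(is_fitted_by)


class ThirdPartyModel:
    """Hypothetical unfitted model that signals its state with a non-sklearn error."""

    def predict(self, X):
        raise RuntimeError("model has not been trained yet")


model = ThirdPartyModel()
print(check_fitted_sketch(model))                      # True  (false positive)
print(check_fitted_sketch(model, is_fitted_by=False))  # False (manual override)
```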
Code Evidence
Cascading fitted detection from `yellowbrick/utils/helpers.py:38-81`:
```python
import numpy as np
import sklearn.exceptions
from sklearn.utils.validation import check_is_fitted


def is_fitted(estimator):
    try:
        # Probe predict() on a small dummy array; unfitted sklearn
        # estimators raise NotFittedError here.
        estimator.predict(np.zeros((7, 3)))
    except sklearn.exceptions.NotFittedError:
        return False
    except AttributeError:
        # Some models (LDA, PCA, agglomerative clustering) don't implement predict
        try:
            check_is_fitted(
                estimator,
                [
                    "coef_", "estimator_", "labels_", "n_clusters_",
                    "children_", "components_", "n_components_",
                    "n_iter_", "n_batch_iter_", "explained_variance_",
                    "singular_values_", "mean_",
                ],
                all_or_any=any,
            )
            return True
        except sklearn.exceptions.NotFittedError:
            return False
    except Exception:
        # Assume it's fitted, since NotFittedError wasn't raised
        return True
    return True
```
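The order of the `except` clauses in `is_fitted()` matters: sklearn's `NotFittedError` inherits from both `ValueError` and `AttributeError`, so it must be caught before the `AttributeError` branch. The sketch below uses a local stand-in with the same bases; `which_branch_correct` and `which_branch_swapped` are hypothetical helpers that report which handler fires.

```python
class NotFittedError(ValueError, AttributeError):
    """Stand-in with the same bases as sklearn.exceptions.NotFittedError."""


def which_branch_correct(exc):
    # Clause order as in is_fitted().
    try:
        raise exc
    except NotFittedError:
        return "not fitted"
    except AttributeError:
        return "attribute scan"
    except Exception:
        return "assume fitted"


def which_branch_swapped(exc):
    # Hypothetical wrong order: AttributeError first.
    try:
        raise exc
    except AttributeError:   # also catches NotFittedError
        return "attribute scan"
    except NotFittedError:   # never reached for NotFittedError
        return "not fitted"
    except Exception:
        return "assume fitted"


print(issubclass(NotFittedError, AttributeError))             # True
print(which_branch_correct(NotFittedError("unfitted")))       # not fitted
print(which_branch_swapped(NotFittedError("unfitted")))       # attribute scan
print(which_branch_correct(ValueError("X has 3 features")))  # assume fitted
```

With the swapped order, an unfitted estimator that does have `predict` would be routed to the attribute scan instead of being reported as unfitted immediately.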
Manual override wrapper from `yellowbrick/utils/helpers.py:84-117`:
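```python
def check_fitted(estimator, is_fitted_by="auto", **kwargs):
    # "auto" (case-insensitive) delegates to is_fitted(); any other value
    # is coerced with bool() and returned as the answer directly.
    if isinstance(is_fitted_by, str) and is_fitted_by.lower() == "auto":
        return is_fitted(estimator)
    return bool(is_fitted_by)
```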
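Because non-`"auto"` values pass through `bool()`, any other non-empty string counts as `True`, including `"false"`. The example below copies `check_fitted()` as shown above and pairs it with a hypothetical placeholder `is_fitted()` that always returns `False`, so the override behavior is visible in isolation.

```python
def is_fitted(estimator):
    # Hypothetical placeholder for the automatic cascade: always "not fitted".
    return False


def check_fitted(estimator, is_fitted_by="auto", **kwargs):
    if isinstance(is_fitted_by, str) and is_fitted_by.lower() == "auto":
        return is_fitted(estimator)
    return bool(is_fitted_by)


est = object()
print(check_fitted(est))                        # False: "auto" delegates to is_fitted()
print(check_fitted(est, is_fitted_by="AUTO"))   # False: the match is case-insensitive
print(check_fitted(est, is_fitted_by=True))     # True
print(check_fitted(est, is_fitted_by="false"))  # True: a non-empty string is truthy
print(check_fitted(est, is_fitted_by=0))        # False
```

Passing real booleans (`True` or `False`) avoids the string pitfall.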
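Proxy attribute limitation note from `yellowbrick/classifier/base.py:219`. The visualizer proxies attribute lookups to the wrapped estimator, so `hasattr` on the visualizer cannot tell whether the visualizer itself has set `classes_`:
```python
# NOTE: cannot test if hasattr(self, "classes_") because it will be proxied.
```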
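The proxying effect can be reproduced with a minimal wrapper. This sketch assumes the visualizer forwards unknown attribute lookups to the wrapped estimator through `__getattr__`; `ProxyWrapper` and `FittedClassifier` are hypothetical names, not Yellowbrick classes.

```python
class ProxyWrapper:
    """Hypothetical minimal proxy around an estimator."""

    def __init__(self, wrapped):
        self._wrapped = wrapped

    def __getattr__(self, name):
        # Called only when normal lookup on the wrapper itself fails.
        return getattr(self._wrapped, name)


class FittedClassifier:
    def __init__(self):
        self.classes_ = ["cat", "dog"]


viz = ProxyWrapper(FittedClassifier())
print(hasattr(viz, "classes_"))  # True, even though viz never set classes_
print("classes_" in vars(viz))   # False: only the wrapped estimator owns it
```

Checking the instance's own `vars()` distinguishes attributes set on the wrapper from proxied ones, which plain `hasattr` cannot do.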