Implementation:DistrictDataLabs Yellowbrick StatsModelsWrapper
| Knowledge Sources | |
|---|---|
| Domains | Regression, Utilities |
| Last Updated | 2026-02-08 05:00 GMT |
Overview
Wrapper class that adapts a statsmodels GLM to the scikit-learn estimator interface so that it can be used with Yellowbrick visualizers.
Description
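The StatsModelsWrapper wraps a statsmodels generalized linear model (GLM) so that it can be passed to Yellowbrick visualizers and other code that relies on the scikit-learn fit/predict/score interface. It is not given a model object. Instead, it takes a functools.partial of the GLM constructor with the non-data arguments (such as family) already bound, and it supplies the training data when fit is called. statsmodels models take endog (y) before exog (X), which is the reverse of scikit-learn's fit(X, y), so the wrapper passes the target first and the features second. It implements fit, predict, and score following scikit-learn conventions, and score uses R² (sklearn.metrics.r2_score) by default.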
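The partial pattern can be shown without statsmodels installed. In the sketch below, toy_glm is a hypothetical stand-in for sm.GLM that only records its arguments. functools.partial fixes the non-data argument (family) up front, and the data are supplied later in statsmodels order, y before X.

```python
from functools import partial


def toy_glm(endog, exog, family=None):
    """Hypothetical stand-in for sm.GLM: it only records how it was called."""
    return {"endog": endog, "exog": exog, "family": family}


# Bind the non-data argument now, as one would bind family=sm.families.Gaussian()
glm_partial = partial(toy_glm, family="gaussian")

# An sklearn-style fit(X, y) receives the features first and the target second...
X = [[1.0], [2.0], [3.0]]
y = [2.0, 4.0, 6.0]

# ...but must call the statsmodels-style constructor as (endog=y, exog=X)
model = glm_partial(y, X)
print(model["endog"] is y, model["exog"] is X, model["family"])  # True True gaussian
```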
Usage
Use StatsModelsWrapper when you want to pass a statsmodels GLM, rather than a scikit-learn estimator, to a Yellowbrick regression visualizer such as ResidualsPlot or PredictionError.
Code Reference
Source Location
- Repository: DistrictDataLabs_Yellowbrick
- File: yellowbrick/contrib/statsmodels/base.py
- Lines: 1-85
Signature
```python
from sklearn.base import BaseEstimator
from sklearn.metrics import r2_score


class StatsModelsWrapper(BaseEstimator):
    def __init__(self, glm_partial, stated_estimator_type="regressor", scorer=r2_score):
        """Wraps a statsmodels GLM as a scikit-learn-compatible estimator."""

    def fit(self, X, y): ...
    def predict(self, X): ...
    def score(self, X, y): ...
```
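The following is a minimal sketch of how the three methods can delegate to the partial. It assumes only the statsmodels shape of a model object whose fit() returns a results object with predict(X). It is not the library's source. SketchWrapper, ToyGLM, ToyResults, and r2 are hypothetical names. ToyGLM is a one-feature least-squares stand-in for sm.GLM, and r2 reimplements the r2_score formula so that the example runs without statsmodels or scikit-learn.

```python
from functools import partial


def r2(y_true, y_pred):
    """Hand-written R^2 = 1 - SS_res / SS_tot (stand-in for sklearn's r2_score)."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


class ToyResults:
    """Hypothetical results object exposing predict(), like a fitted GLM's results."""

    def __init__(self, slope, intercept):
        self.slope, self.intercept = slope, intercept

    def predict(self, X):
        return [self.intercept + self.slope * row[0] for row in X]


class ToyGLM:
    """Hypothetical stand-in for sm.GLM: least squares on a single feature."""

    def __init__(self, endog, exog, family=None):
        self.endog, self.exog, self.family = endog, exog, family

    def fit(self):
        xs = [row[0] for row in self.exog]
        ys = self.endog
        mx, my = sum(xs) / len(xs), sum(ys) / len(ys)
        cov = sum((x - mx) * (v - my) for x, v in zip(xs, ys))
        var = sum((x - mx) ** 2 for x in xs)
        slope = cov / var
        return ToyResults(slope, my - slope * mx)


class SketchWrapper:
    """Illustrative adapter mirroring the documented fit/predict/score interface."""

    def __init__(self, glm_partial, stated_estimator_type="regressor", scorer=r2):
        self.glm_partial = glm_partial
        # Yellowbrick's type checks read _estimator_type to identify regressors
        self._estimator_type = stated_estimator_type
        self.scorer = scorer

    def fit(self, X, y):
        # statsmodels order: endog (y) first, then exog (X)
        self.results_ = self.glm_partial(y, X).fit()
        return self

    def predict(self, X):
        return self.results_.predict(X)

    def score(self, X, y):
        return self.scorer(y, self.predict(X))


X = [[0.0], [1.0], [2.0], [3.0]]
y = [1.0, 3.0, 5.0, 7.0]  # exactly y = 2x + 1
est = SketchWrapper(partial(ToyGLM, family="gaussian")).fit(X, y)
print(est.predict([[4.0]]))  # [9.0]
print(est.score(X, y))       # 1.0
```

Because fit() builds a fresh model from the partial on each call, the wrapper can be refit on new data. All GLM-specific configuration stays in the partial.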
Import
```python
from yellowbrick.contrib.statsmodels import StatsModelsWrapper
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
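| glm_partial | functools.partial | Yes | Partial of statsmodels.api.GLM with the non-data arguments (e.g. family) bound; the wrapper supplies y and X in fit() |
| stated_estimator_type | str | No | Estimator type the wrapper reports to Yellowbrick (default: "regressor") |
| scorer | callable | No | Metric called as scorer(y_true, y_pred) by score() (default: sklearn.metrics.r2_score) |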
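The scorer can be any callable with the scikit-learn metric signature scorer(y_true, y_pred), which is the argument order r2_score expects. As a hypothetical example, the function below returns the negated mean absolute error. This follows the scikit-learn convention that larger scores are better.

```python
def neg_mean_absolute_error(y_true, y_pred):
    """Hypothetical custom scorer: negated MAE, so higher is better (like R^2)."""
    errors = [abs(t - p) for t, p in zip(y_true, y_pred)]
    return -sum(errors) / len(errors)


# With the real wrapper this would be passed as:
#   StatsModelsWrapper(glm_partial, scorer=neg_mean_absolute_error)
print(neg_mean_absolute_error([3.0, 5.0], [2.0, 7.0]))  # -1.5
```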
Outputs
| Name | Type | Description |
|---|---|---|
| predict() | array-like | Predictions of the fitted GLM for X |
| score() | float | Result of scorer(y, predict(X)); R² by default |
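For reference, the default metric is the coefficient of determination as defined by sklearn.metrics.r2_score. Here $y_i$ are the observed targets, $\hat{y}_i$ the predictions, and $\bar{y}$ the mean of the observed targets:

$$
R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2}
$$

It equals 1 for perfect predictions and 0 for a model that always predicts $\bar{y}$. It is negative for a model that does worse than that.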
Usage Examples
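```python
from functools import partial
import statsmodels.api as sm
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from yellowbrick.contrib.statsmodels import StatsModelsWrapper
from yellowbrick.regressor import ResidualsPlot

# Synthetic data; statsmodels GLM does not add an intercept, so add a constant column
X, y = make_regression(n_samples=200, n_features=3, noise=10.0, random_state=42)
X = sm.add_constant(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Bind the family now; the wrapper passes (y, X) to sm.GLM when fit() is called
glm_partial = partial(sm.GLM, family=sm.families.Gaussian())
model = StatsModelsWrapper(glm_partial)

viz = ResidualsPlot(model)
viz.fit(X_train, y_train)  # fits the wrapped GLM on the training split
viz.score(X_test, y_test)  # scores the test split (R² by default)
viz.show()
```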