Overview
LinearRegression and LogisticRegression are interpretable glassbox model wrappers around scikit-learn's linear models, providing coefficient-based global and local explanations within the InterpretML framework.
Description
This module wraps scikit-learn's linear models to provide them with the InterpretML explainer API:
- BaseLinear: Abstract base class that implements the functionality shared by the linear models: fitting, prediction, and generation of both global and local explanations. For visualization, it stores a density histogram for each feature and the unique values of each categorical feature.
- LinearRegression: Concrete class for regression tasks. Wraps scikit-learn's LinearRegression (or any compatible linear class) via RegressorMixin and BaseLinear.
- LogisticRegression: Concrete class for classification tasks. Wraps scikit-learn's LogisticRegression (or any compatible linear class) via ClassifierMixin and BaseLinear, and provides predict_proba for probability estimates.
- LinearExplanation: Custom explanation class extending FeatureValueExplanation. It visualizes global explanations as a horizontal bar chart of the coefficients (the top 15 by absolute value) and per-feature effects as line or bar plots of the coefficient multiplied by the feature values. It supports both the standard visualization and an MLI-compatible provider.
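The "top 15 by absolute value" rule used for the global bar chart can be illustrated with a small NumPy sketch. The helper `top_k_by_abs` is hypothetical and only mirrors the selection rule described above; it is not the library's own code, and whether the intercept takes part in the ranking is not specified here.

```python
import numpy as np


def top_k_by_abs(names, coefs, k=15):
    """Hypothetical helper: return (name, coef) pairs for the k largest |coef| values."""
    coefs = np.asarray(coefs, dtype=float)
    # Sorting by -|coef| puts the largest magnitudes first; "stable" keeps ties in input order
    order = np.argsort(-np.abs(coefs), kind="stable")[:k]
    return [(names[i], float(coefs[i])) for i in order]


names = [f"f{i}" for i in range(20)]
coefs = np.linspace(-1.0, 0.9, 20)  # 20 illustrative coefficients, more than k
top = top_k_by_abs(names, coefs)
for name, c in top:
    print(f"{name:>4} {c:+.2f}")
```

Ranking by magnitude rather than signed value keeps strongly negative coefficients in the chart alongside strongly positive ones.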
Global explanations show the model coefficients with an intercept term. Local explanations show per-instance breakdowns where each feature's contribution is its coefficient multiplied by the feature value.
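The local breakdown can be checked by hand: each feature's contribution is its coefficient times its value, and the contributions plus the intercept add up to the linear prediction. The parameter values below are made up for illustration and are not output from the library.

```python
import numpy as np

# Illustrative fitted parameters (made-up values, not produced by the library)
intercept = 0.5
coef = np.array([3.0, -2.0, 0.0, 1.5])
names = ["f0", "f1", "f2", "f3"]

x = np.array([1.0, 2.0, 4.0, 2.0])  # a single instance to explain

# Local contribution of each feature = coefficient * feature value
contributions = coef * x
prediction = intercept + contributions.sum()

for name, c in zip(names, contributions):
    print(f"{name}: {c:+.2f}")
print(f"intercept: {intercept:+.2f}")
print(f"prediction: {prediction:.2f}")
```

Because the model is additive, the breakdown is exact: nothing is left over once the intercept is included.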
Usage
Use LinearRegression for regression tasks and LogisticRegression for classification tasks when you need a fully transparent linear model with built-in explanation capabilities. They are well suited to interpretable baseline models, regulatory compliance scenarios, and cases where the model coefficients themselves must be communicated.
Code Reference
Source Location
Signature
class BaseLinear(ExplainerMixin):
    available_explanations = ["local", "global"]
    explainer_type = "model"
    def __init__(self, feature_names=None, feature_types=None,
                 linear_class=SKLinear, **kwargs): ...
    def fit(self, X, y): ...
    def predict(self, X): ...
    def explain_local(self, X, y=None, name=None): ...
    def explain_global(self, name=None): ...
class LinearRegression(RegressorMixin, BaseLinear):
    def __init__(self, feature_names=None, feature_types=None,
                 linear_class=SKLinear, **kwargs): ...
    def fit(self, X, y): ...
class LogisticRegression(ClassifierMixin, BaseLinear):
    def __init__(self, feature_names=None, feature_types=None,
                 linear_class=SKLogistic, **kwargs): ...
    def fit(self, X, y): ...
    def predict_proba(self, X): ...
Import
from interpret.glassbox import LinearRegression, LogisticRegression
I/O Contract
Constructor Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| feature_names | list of str | No | List of feature names |
| feature_types | list of str | No | List of feature types (e.g. "continuous", "nominal", "ordinal") |
| linear_class | class | No | A scikit-learn linear model class. Defaults to SKLinear (scikit-learn's LinearRegression) for LinearRegression and SKLogistic (scikit-learn's LogisticRegression) for LogisticRegression |
| **kwargs | varies | No | Additional keyword arguments passed to the scikit-learn linear model constructor |
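The linear_class / **kwargs pair describes a delegation pattern: the wrapper stores the class and its keyword arguments, then instantiates the underlying model at fit time. The sketch below illustrates that pattern with hypothetical names (`OLS`, `LinearWrapperSketch`); it uses a NumPy least-squares stand-in instead of scikit-learn and is not the library's implementation. With the real classes one would, in principle, pass a scikit-learn linear estimator and its own constructor arguments the same way.

```python
import numpy as np


class OLS:
    """Hypothetical stand-in for a scikit-learn-style linear class (NumPy only)."""

    def __init__(self, fit_intercept=True):
        self.fit_intercept = fit_intercept

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        if self.fit_intercept:
            X = np.column_stack([np.ones(len(X)), X])  # prepend a column of ones
        w, *_ = np.linalg.lstsq(X, np.asarray(y, dtype=float), rcond=None)
        if self.fit_intercept:
            self.intercept_, self.coef_ = w[0], w[1:]
        else:
            self.intercept_, self.coef_ = 0.0, w
        return self

    def predict(self, X):
        return np.asarray(X, dtype=float) @ self.coef_ + self.intercept_


class LinearWrapperSketch:
    """Hypothetical wrapper: keeps linear_class and **kwargs, builds the model in fit."""

    def __init__(self, feature_names=None, linear_class=OLS, **kwargs):
        self.feature_names = feature_names
        self.linear_class = linear_class
        self.kwargs = kwargs  # forwarded untouched to linear_class(...)

    def fit(self, X, y):
        self.model_ = self.linear_class(**self.kwargs).fit(X, y)
        return self

    def predict(self, X):
        return self.model_.predict(X)


X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * X[:, 0] + 1.0
wrapped = LinearWrapperSketch(linear_class=OLS, fit_intercept=True).fit(X, y)
print(wrapped.model_.coef_, wrapped.model_.intercept_)
```

Deferring construction to fit means the wrapper does not need to know the underlying class's parameters; it only assumes the class exposes fit, predict, coef_ and intercept_.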
fit Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | numpy array or compatible | Yes | Training feature matrix |
| y | numpy array | Yes | Training labels (1-dimensional) |
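The shape requirements in the table (a 2-D feature matrix and 1-dimensional labels) can be made concrete with a small pre-fit check. `as_fit_inputs` is a hypothetical helper written for this page; the library and scikit-learn perform their own validation, which may differ in details such as warnings or error messages.

```python
import numpy as np


def as_fit_inputs(X, y):
    """Hypothetical pre-fit check: X must be 2-D, y must be (or flatten to) 1-D."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-dimensional, got shape {X.shape}")
    if y.ndim == 2 and y.shape[1] == 1:
        y = y.ravel()  # flatten an (n, 1) column vector to shape (n,)
    if y.ndim != 1:
        raise ValueError(f"y must be 1-dimensional, got shape {y.shape}")
    if X.shape[0] != y.shape[0]:
        raise ValueError(f"X has {X.shape[0]} rows but y has {y.shape[0]} entries")
    return X, y


X, y = as_fit_inputs([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]], [[1], [0], [1]])
print(X.shape, y.shape)  # (3, 2) (3,)
```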
explain_global Outputs
| Name | Type | Description |
|---|---|---|
| explanation | LinearExplanation | Global explanation showing the coefficients as a bar chart, plus per-feature effects |
explain_local Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | numpy array or compatible | Yes | Instances to explain |
| y | numpy array | No | True labels for performance metrics |
| name | str | No | User-defined explanation name |
explain_local Outputs
| Name | Type | Description |
|---|---|---|
| explanation | FeatureValueExplanation | Local explanation showing per-instance coefficient * value breakdowns |
Usage Examples
Regression Example
from interpret.glassbox import LinearRegression
import numpy as np
X = np.random.randn(200, 4)
y = 3 * X[:, 0] - 2 * X[:, 1] + np.random.randn(200) * 0.1
lr = LinearRegression(feature_names=["f0", "f1", "f2", "f3"])
lr.fit(X, y)
# Global explanation
global_exp = lr.explain_global(name="Linear Regression")
global_exp.visualize(key=None) # Overall coefficients bar chart
# Local explanation
local_exp = lr.explain_local(X[:5], y[:5], name="Local Linear")
local_exp.visualize(key=0)
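Since the regression example's target is generated as 3 * f0 - 2 * f1 plus small noise, one would expect the coefficients in its global bar chart to be close to 3, -2, 0 and 0 with an intercept near 0. The sketch below checks that expectation independently with a plain NumPy least-squares fit on data drawn from the same process (using a seeded generator, which is an assumption for reproducibility); it does not call the library.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 4))
y = 3 * X[:, 0] - 2 * X[:, 1] + rng.standard_normal(200) * 0.1

# Ordinary least squares with an explicit intercept column
A = np.column_stack([np.ones(len(X)), X])
w, *_ = np.linalg.lstsq(A, y, rcond=None)
intercept, coef = w[0], w[1:]
print(np.round(coef, 2), round(float(intercept), 2))
```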
Classification Example
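from interpret.glassbox import LogisticRegression
from sklearn.datasets import load_breast_cancer
X, y = load_breast_cancer(return_X_y=True)
# Extra keyword arguments are forwarded to scikit-learn's LogisticRegression;
# a higher max_iter helps the default lbfgs solver converge on these unscaled features
log_reg = LogisticRegression(max_iter=10000)
log_reg.fit(X, y)
# Class probabilities for the first five samples, shape (5, 2)
proba = log_reg.predict_proba(X[:5])
global_exp = log_reg.explain_global(name="Logistic Regression")
global_exp.visualize(key=None) # Top 15 coefficients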
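For a binary logistic model, the coefficient * value contributions add up in log-odds space rather than probability space: scikit-learn's binary LogisticRegression computes the positive-class probability as the sigmoid of intercept + coef · x. The sketch below shows that relationship with made-up parameters; it assumes the local contributions are read in log-odds units and does not call the library.

```python
import numpy as np


def sigmoid(z):
    """Logistic function mapping log-odds to a probability."""
    return 1.0 / (1.0 + np.exp(-z))


# Made-up binary logistic model parameters (illustrative only)
intercept = -1.0
coef = np.array([0.8, -1.2, 0.5])
x = np.array([2.0, 0.5, 1.0])

contributions = coef * x                # per-feature terms, in log-odds units
logit = intercept + contributions.sum()  # decision value for this instance
p_positive = sigmoid(logit)              # probability of the positive class
print(contributions, logit, p_positive)
```

A positive contribution therefore raises the odds of the positive class multiplicatively (by a factor of exp(contribution)) rather than adding a fixed amount to the probability.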
Related Pages