Implementation:DistrictDataLabs Yellowbrick ClassPredictionError Visualizer
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Classification, Visualization |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A concrete tool, provided by the Yellowbrick library, for visualizing per-class prediction errors as a stacked bar chart that shows how instances of each true class are distributed across predicted classes.
Description
The ClassPredictionError class is a classification score visualizer that generates a stacked bar chart where each bar represents a true class and the segments within each bar represent the predicted class distribution for that class's instances. The visualizer computes a predictions matrix by cross-tabulating true and predicted labels, then renders this as a stacked bar plot. Correctly classified instances form one segment, while misclassified instances appear as additional colored segments, making it straightforward to identify which classes are most frequently confused.
The companion quick method class_prediction_error() provides a one-call interface that instantiates, fits, scores, and renders the visualizer.
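The cross-tabulation step described above can be sketched in plain NumPy. This is only an illustrative approximation under simple assumptions (hashable, sortable labels; binary or multiclass targets), not the library's actual code, and the helper name `crosstab_predictions` is hypothetical.

```python
import numpy as np


def crosstab_predictions(y_true, y_pred):
    """Count predictions per (true class, predicted class) pair.

    Hypothetical helper for illustration; not part of Yellowbrick.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Sorted union of every label seen in either array.
    labels = np.unique(np.concatenate([y_true, y_pred]))
    # Row i = true class labels[i]; column j = predicted class labels[j].
    table = np.array(
        [
            [np.sum((y_true == t) & (y_pred == p)) for p in labels]
            for t in labels
        ]
    )
    return labels, table


# Made-up labels, used only to exercise the helper.
y_true = ["cat", "cat", "dog", "dog", "dog", "bird"]
y_pred = ["cat", "dog", "dog", "dog", "cat", "bird"]
labels, table = crosstab_predictions(y_true, y_pred)
print(labels)
print(table)
```

In this sketch the diagonal of `table` holds the correctly classified counts, and every off-diagonal entry is a misclassification that would show up as an extra segment in the chart.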
Usage
Use ClassPredictionError when you want a bar-chart view of classification results, for example to communicate error patterns to stakeholders who prefer bar charts over heatmaps, or to quickly identify which classes absorb misclassified instances.
Code Reference
Source Location
- Repository: yellowbrick
- File: yellowbrick/classifier/class_prediction_error.py
- Class Lines: L107-186 (ClassPredictionError class)
- Quick Method Lines: L244-361 (class_prediction_error function)
Signature
class ClassPredictionError(ClassificationScoreVisualizer):
    def __init__(
        self,
        estimator,
        ax=None,
        classes=None,
        encoder=None,
        is_fitted="auto",
        force_model=False,
        **kwargs
    )
    def score(self, X, y)

def class_prediction_error(
    estimator,
    X_train,
    y_train,
    X_test=None,
    y_test=None,
    ax=None,
    classes=None,
    encoder=None,
    is_fitted="auto",
    force_model=False,
    show=True,
    **kwargs
)
Import
from yellowbrick.classifier import ClassPredictionError
from yellowbrick.classifier.class_prediction_error import class_prediction_error
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| estimator | sklearn classifier | Yes | A scikit-learn classifier to evaluate |
| ax | matplotlib Axes | No | Axes object on which to draw the bar chart; uses current axes if not provided |
| classes | list of str | No | Human-readable class labels for the x-axis and legend |
| encoder | dict or LabelEncoder | No | Mapping from target values to human-readable labels |
| is_fitted | bool or str | No | Whether the estimator is already fitted; with the default "auto", the visualizer checks whether the estimator is fitted before fitting it again |
| force_model | bool | No | If True, skip the classifier type check on the estimator |
Outputs
| Name | Type | Description |
|---|---|---|
| score_ | float | Global accuracy score from the underlying estimator |
| predictions_ | ndarray | Matrix of shape (n_classes, n_classes) where rows are true classes and columns are predicted classes, containing counts |
| ax | matplotlib Axes | The axes with the rendered stacked bar chart |
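Once a matrix like predictions_ is available, per-class summaries can be derived from it directly. The sketch below assumes the orientation stated in the table above (rows are true classes, columns are predicted classes) and uses made-up counts in place of a real `viz.predictions_`; the helper names `per_class_recall` and `most_confused_pair` are hypothetical and not part of Yellowbrick.

```python
import numpy as np


def per_class_recall(predictions):
    """Fraction of each true class (row) that was predicted correctly."""
    predictions = np.asarray(predictions, dtype=float)
    row_totals = predictions.sum(axis=1)
    # Classes with no true instances get a recall of 0 instead of NaN.
    return np.divide(
        np.diag(predictions),
        row_totals,
        out=np.zeros_like(row_totals),
        where=row_totals > 0,
    )


def most_confused_pair(predictions):
    """Return (true_index, predicted_index) of the largest off-diagonal count."""
    off_diag = np.array(predictions, dtype=float)  # copy, so the input is untouched
    np.fill_diagonal(off_diag, 0)
    true_idx, pred_idx = np.unravel_index(np.argmax(off_diag), off_diag.shape)
    return int(true_idx), int(pred_idx)


# Made-up counts for three classes, standing in for viz.predictions_.
example = np.array([[8, 1, 1], [2, 6, 0], [0, 3, 9]])
recall = per_class_recall(example)
pair = most_confused_pair(example)
print(recall)
print(pair)
```

The indices returned by `most_confused_pair` refer to positions in the matrix; mapping them back to class names would use whatever label ordering the visualizer reports.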
Usage Examples
Basic Usage
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from yellowbrick.classifier import ClassPredictionError
from yellowbrick.datasets import load_occupancy
# Load the data and hold out a test split for scoring
X, y = load_occupancy()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
# Wrap the classifier, fit on the training split, then score on the test split
viz = ClassPredictionError(RandomForestClassifier(n_estimators=10))
viz.fit(X_train, y_train)
viz.score(X_test, y_test)
# Finalize and display the stacked bar chart
viz.show()
Quick Method
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from yellowbrick.classifier.class_prediction_error import class_prediction_error
from yellowbrick.datasets import load_occupancy
# Load the data and hold out a test split for scoring
X, y = load_occupancy()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
# Instantiate, fit, score, and render in a single call
class_prediction_error(RandomForestClassifier(n_estimators=10), X_train, y_train, X_test, y_test)