Implementation: DistrictDataLabs Yellowbrick DiscriminationThreshold Visualizer
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Classification, Visualization |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Concrete tool for visualizing how precision, recall, F-score, and queue rate change across discrimination thresholds for binary classifiers, provided by the Yellowbrick library.
Description
The DiscriminationThreshold class is a model visualizer that evaluates a binary probabilistic classifier by running multiple trials of shuffled train-test splits and computing precision, recall, F-score, and queue rate at a uniform set of thresholds from 0 to 1. The results are aggregated across trials using quantile-based statistics, producing median curves with confidence bands. The visualizer can annotate the threshold that maximizes a selected metric (defaulting to F-score) and supports excluding specific metrics from the display.
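The per-threshold computation described above can be sketched in plain NumPy. This is a hedged illustration of the idea, not Yellowbrick's internal code; the helper name `threshold_metrics` is invented here. Queue rate is the fraction of all instances predicted positive, i.e. the share of cases that would be "queued" for review at that threshold.

```python
import numpy as np

def threshold_metrics(y_true, y_scores, threshold, beta=1.0):
    """Precision, recall, F-beta, and queue rate at one threshold (sketch)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_scores) >= threshold

    tp = np.sum(y_pred & (y_true == 1))
    fp = np.sum(y_pred & (y_true == 0))
    fn = np.sum(~y_pred & (y_true == 1))

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    if precision + recall:
        fscore = (1 + beta**2) * precision * recall / (beta**2 * precision + recall)
    else:
        fscore = 0.0
    queue_rate = np.mean(y_pred)  # fraction of instances flagged positive
    return precision, recall, fscore, queue_rate

# Toy data: 6 instances scored by a hypothetical classifier
y_true = [1, 1, 1, 0, 0, 0]
scores = [0.9, 0.7, 0.4, 0.6, 0.3, 0.1]
p, r, f, q = threshold_metrics(y_true, scores, threshold=0.5)
# → precision = recall = F1 ≈ 0.667, queue_rate = 0.5
```

The visualizer evaluates this kind of computation at a uniform grid of thresholds, for each of the shuffled train-test trials.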
Unlike most Yellowbrick classification visualizers, DiscriminationThreshold does not inherit from ClassificationScoreVisualizer. Instead, it extends ModelVisualizer directly and performs its own internal splitting and fitting. This is because it needs to evaluate the model across multiple random splits to produce stable threshold curves. The visualizer requires a binary classifier that implements either predict_proba or decision_function.
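The quantile aggregation across trials amounts to something like the following sketch, using simulated trial scores rather than Yellowbrick's actual internals:

```python
import numpy as np

rng = np.random.default_rng(0)
# Simulated F-scores: 50 shuffled trials x 101 threshold points
trial_scores = rng.uniform(0.6, 0.8, size=(50, 101))

# Default quantiles [0.1, 0.5, 0.9] give a lower band, the median
# curve that gets plotted, and an upper band, per threshold
lower, median, upper = np.quantile(trial_scores, [0.1, 0.5, 0.9], axis=0)
```

Plotting the median with the band between the lower and upper quantiles yields the confidence-band curves described above.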
The companion quick method discrimination_threshold() provides a one-call interface that instantiates the visualizer, fits it, and renders the plot.
Usage
Use DiscriminationThreshold when tuning the operating threshold of a binary probabilistic classifier to optimize a specific metric. It is most useful when the default 0.5 decision threshold is suboptimal and you need to visualize the precision/recall/queue-rate tradeoff across candidate thresholds.
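Once the plot suggests an operating point, apply it by thresholding the classifier's scores yourself rather than relying on the default `predict` cutoff. A minimal sketch, where `t_star` and the probabilities are hypothetical values standing in for your tuned threshold and model output:

```python
import numpy as np

t_star = 0.35                               # threshold chosen from the plot
probs = np.array([0.2, 0.34, 0.36, 0.9])    # hypothetical P(y=1) scores
y_pred = (probs >= t_star).astype(int)      # → array([0, 0, 1, 1])
```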
Code Reference
Source Location
- Repository: yellowbrick
- File: yellowbrick/classifier/threshold.py
- Class Lines: L182-314 (DiscriminationThreshold class)
- Quick Method Lines: L526-698 (discrimination_threshold function)
Signature
class DiscriminationThreshold(ModelVisualizer):
    def __init__(
        self,
        estimator,
        ax=None,
        n_trials=50,
        cv=0.1,
        fbeta=1.0,
        argmax="fscore",
        exclude=None,
        quantiles=np.array([0.1, 0.5, 0.9]),
        random_state=None,
        is_fitted="auto",
        force_model=False,
        **kwargs
    )

    def fit(self, X, y, **kwargs)

def discrimination_threshold(
    estimator,
    X,
    y,
    ax=None,
    n_trials=50,
    cv=0.1,
    fbeta=1.0,
    argmax="fscore",
    exclude=None,
    quantiles=np.array([0.1, 0.5, 0.9]),
    random_state=None,
    is_fitted="auto",
    force_model=False,
    show=True,
    **kwargs
)
Import
from yellowbrick.classifier import DiscriminationThreshold
from yellowbrick.classifier.threshold import discrimination_threshold
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| estimator | sklearn classifier | Yes | A probabilistic binary classifier with predict_proba or decision_function |
| ax | matplotlib Axes | No | Axes object on which to draw the plot; uses current axes if not provided |
| n_trials | int | No | Number of shuffled train-test splits to average over; defaults to 50 |
| cv | float or CV generator | No | Test split fraction (float) or cross-validation generator; defaults to 0.1 |
| fbeta | float | No | Beta parameter for F-beta score weighting; defaults to 1.0 (standard F1) |
| argmax | str or None | No | Metric to maximize and annotate on the plot; defaults to "fscore" |
| exclude | str or list | No | Metrics to omit from the visualization; can include "precision", "recall", "queue_rate", "fscore" |
| quantiles | array-like of 3 floats | No | Lower, median, and upper quantiles for confidence bands; defaults to [0.1, 0.5, 0.9] |
| random_state | int or None | No | Seed for reproducible shuffling of train-test splits |
| is_fitted | bool or str | No | Whether the estimator is already fitted; defaults to "auto" |
| force_model | bool | No | If True, skip the classifier and probabilistic type checks |
Outputs
| Name | Type | Description |
|---|---|---|
| thresholds_ | ndarray | Uniform array of threshold values from 0.0 to 1.0 |
| cv_scores_ | dict | Dictionary mapping metric names (and their _lower/_upper variants) to arrays of aggregated scores at each threshold |
| ax | matplotlib Axes | The axes with the rendered threshold curves and confidence bands |
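Given the fitted attributes above, the threshold annotated by argmax corresponds to an argmax over the aggregated metric curve. A self-contained sketch with mocked attributes (on a real visualizer these would be `viz.thresholds_` and `viz.cv_scores_` after `viz.fit(X, y)`; the synthetic curve here is purely illustrative):

```python
import numpy as np

# Mocked fitted attributes
thresholds_ = np.linspace(0.0, 1.0, 101)
cv_scores_ = {"fscore": np.exp(-((thresholds_ - 0.4) ** 2) / 0.02)}  # peaks at 0.4

# Threshold maximizing the median F-score curve
best_t = thresholds_[np.argmax(cv_scores_["fscore"])]
# → 0.4
```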
Usage Examples
Basic Usage
from sklearn.linear_model import LogisticRegression
from yellowbrick.classifier import DiscriminationThreshold
from yellowbrick.datasets import load_spam

X, y = load_spam()

# Run 50 shuffled train/test trials and plot the aggregated curves
viz = DiscriminationThreshold(LogisticRegression(), n_trials=50, random_state=42)
viz.fit(X, y)  # performs its own internal splitting and fitting
viz.show()
Quick Method
from sklearn.linear_model import LogisticRegression
from yellowbrick.classifier.threshold import discrimination_threshold
from yellowbrick.datasets import load_spam

X, y = load_spam()

# One call: instantiates the visualizer, fits it, and renders the plot
discrimination_threshold(LogisticRegression(), X, y, random_state=42)