Implementation:Online ml River Drift DriftRetrainingClassifier
| Knowledge Sources | Domains | Last Updated |
|---|---|---|
| River, River Docs | Online Machine Learning, Concept Drift, Meta-Learning | 2026-02-08 16:00 GMT |
Overview
Concrete tool for wrapping any base classifier with automatic drift detection and model retraining, providing a modular meta-learning approach to concept drift adaptation.
Description
The drift.DriftRetrainingClassifier is a wrapper class that composes a base classifier with a drift detector. At each learn_one call, the wrapper first makes a prediction using the current model, computes a binary error indicator (0 = correct, 1 = incorrect), and feeds it to the drift detector. Depending on the detector's state and the train_in_background setting:
- Warning detected + background training enabled: The background model (bkg_model) is trained on the current sample.
- Drift detected + background training enabled: The primary model is replaced by the background model, and a fresh background model is cloned.
- Drift detected + background training disabled: The primary model is cloned (reset to a fresh state).
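The three cases above can be sketched in plain Python without River. This is a minimal illustration, not River's source: WindowErrorDetector, MajorityClass, RetrainingSketch, and run are hypothetical stand-ins, and the sliding-window thresholds are arbitrary. Two details are assumptions of the sketch: the detector is skipped while the model has no prediction yet, and the primary model learns from every sample after the drift check.

```python
from collections import deque


class WindowErrorDetector:
    """Hypothetical detector: thresholds the error rate over a sliding window."""

    def __init__(self, window=10, warning_rate=0.3, drift_rate=0.6):
        self.window = window
        self.warning_rate = warning_rate
        self.drift_rate = drift_rate
        self._errors = deque(maxlen=window)
        self.warning_detected = False
        self.drift_detected = False

    def update(self, error):
        if self.drift_detected:
            self._errors.clear()  # start over after a drift was signalled
        self._errors.append(error)
        self.warning_detected = self.drift_detected = False
        if len(self._errors) < self.window:
            return
        rate = sum(self._errors) / self.window
        if rate >= self.drift_rate:
            self.drift_detected = True
        elif rate >= self.warning_rate:
            self.warning_detected = True


class MajorityClass:
    """Hypothetical base model: predicts the most frequent label seen so far."""

    def __init__(self):
        self.counts = {}

    def clone(self):
        return MajorityClass()  # fresh, untrained copy

    def learn_one(self, x, y):
        self.counts[y] = self.counts.get(y, 0) + 1

    def predict_one(self, x):
        return max(self.counts, key=self.counts.get) if self.counts else None


class RetrainingSketch:
    """Simplified stand-in for the wrapper's warning/drift handling."""

    def __init__(self, model, drift_detector, train_in_background=True):
        self.model = model
        self.drift_detector = drift_detector
        self.train_in_background = train_in_background
        if train_in_background:
            self.bkg_model = model.clone()
        self.n_resets = 0  # bookkeeping for this sketch only

    def predict_one(self, x):
        return self.model.predict_one(x)

    def learn_one(self, x, y):
        y_pred = self.model.predict_one(x)
        if y_pred is not None:  # assumption: skip the detector until a prediction exists
            self.drift_detector.update(int(y_pred != y))  # 0 = correct, 1 = incorrect
            self._adapt(x, y)
        self.model.learn_one(x, y)  # assumption: the primary model always learns

    def _adapt(self, x, y):
        det = self.drift_detector
        if self.train_in_background and det.warning_detected:
            self.bkg_model.learn_one(x, y)  # warm up the replacement
        elif det.drift_detected:
            if self.train_in_background:
                self.model = self.bkg_model  # promote the background model
                self.bkg_model = self.model.clone()
            else:
                self.model = self.model.clone()  # reset to a fresh state
            self.n_resets += 1


def run(train_in_background):
    """Feed 30 samples of class 'a' followed by 30 of class 'b'."""
    wrapper = RetrainingSketch(MajorityClass(), WindowErrorDetector(), train_in_background)
    for y in ["a"] * 30 + ["b"] * 30:
        wrapper.learn_one({}, y)
    return wrapper
```

On this toy stream, both run(True) and run(False) end with a model that predicts "b" after a single replacement, whereas a bare MajorityClass trained on the same 60 samples still predicts "a" because the old counts are never discarded. With background training, the promoted model has already seen the samples from the warning phase, so it starts less cold than a plain reset.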
The class inherits from both base.Wrapper and base.Classifier, making it compatible with River's evaluation and pipeline infrastructure.
Usage
Import drift.DriftRetrainingClassifier when you want to add drift adaptation to any River classifier (e.g., HoeffdingTreeClassifier, LogisticRegression, GaussianNB) without modifying the classifier itself.
Code Reference
Source Location
river/drift/retrain.py:L6-L100
Signature
class DriftRetrainingClassifier(base.Wrapper, base.Classifier):
    def __init__(
        self,
        model: base.Classifier,
        drift_detector: base.DriftAndWarningDetector
        | base.BinaryDriftAndWarningDetector
        | None = None,
        train_in_background: bool = True,
    )
Import
from river import drift
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | base.Classifier | (required) | The base classifier to wrap. Will be cloned upon drift detection |
| drift_detector | DriftAndWarningDetector or None | None (defaults to drift.binary.DDM()) | The drift detector for monitoring prediction errors. Must support both warning and drift signals when train_in_background=True |
| train_in_background | bool | True | If True, train a background model during the warning phase and swap it in on drift. If False, simply reset the model on drift |
I/O Contract
Inputs
| Method | Parameter | Type | Description |
|---|---|---|---|
| learn_one | x | dict | Feature dictionary |
| learn_one | y | ClfTarget | Target class label |
| learn_one | **kwargs | dict | Additional keyword arguments passed to the base model |
| predict_proba_one | x | dict | Feature dictionary |
| predict_proba_one | **kwargs | dict | Additional keyword arguments passed to the base model |
Outputs
| Method | Return Type | Description |
|---|---|---|
| predict_proba_one(x) | dict | Class probability distribution (delegated to the wrapped model) |
| predict_one(x) | ClfTarget | Predicted class label (delegated to the wrapped model) |
Internal State
| Attribute | Type | Description |
|---|---|---|
| model | base.Classifier | The current primary model (may be swapped on drift) |
| bkg_model | base.Classifier | The background model (exists only when train_in_background=True) |
| drift_detector | DriftDetector | The drift detector instance; inspect drift_detector.drift_detected for state |
Usage Examples
Basic Usage with DDM Drift Detector
from river import datasets, evaluate, drift, metrics, tree
dataset = datasets.Elec2().take(3000)
model = drift.DriftRetrainingClassifier(
    model=tree.HoeffdingTreeClassifier(),
    drift_detector=drift.binary.DDM()
)
metric = metrics.Accuracy()
evaluate.progressive_val_score(dataset, model, metric)
# Accuracy: 86.46%
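The accuracy above comes from progressive (test-then-train) validation: each sample is first used to score a prediction and only afterwards used for learning. The idea can be sketched in plain Python. This is an illustration rather than River's implementation; progressive_accuracy and LastLabel are hypothetical names, and skipping samples for which the model returns no prediction is an assumption of the sketch.

```python
def progressive_accuracy(stream, model):
    """Test-then-train accuracy: score each prediction before the label is learned."""
    correct = 0
    scored = 0
    for x, y in stream:
        y_pred = model.predict_one(x)
        if y_pred is not None:  # assumption: unscored until the model can predict
            scored += 1
            correct += int(y_pred == y)
        model.learn_one(x, y)  # the label is revealed only after scoring
    return correct / scored if scored else 0.0


class LastLabel:
    """Hypothetical toy model: predicts the most recent label it was trained on."""

    def __init__(self):
        self.last = None

    def predict_one(self, x):
        return self.last

    def learn_one(self, x, y):
        self.last = y


stream = [({}, "up"), ({}, "up"), ({}, "down"), ({}, "down"), ({}, "up")]
accuracy = progressive_accuracy(stream, LastLabel())
print(f"Accuracy: {accuracy:.2%}")  # 2 of the 4 scored predictions are correct
```

Because the wrapper updates its drift detector inside learn_one, no extra plumbing is needed in such a loop: any object exposing predict_one and learn_one can be evaluated this way.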
Without Background Training
from river import drift, tree
model = drift.DriftRetrainingClassifier(
    model=tree.HoeffdingTreeClassifier(),
    drift_detector=drift.binary.DDM(),
    train_in_background=False
)
# On drift, the model is immediately reset (no background training)
With Logistic Regression and Pipeline
from river import drift, linear_model, preprocessing
model = drift.DriftRetrainingClassifier(
    model=preprocessing.StandardScaler() | linear_model.LogisticRegression(),
    drift_detector=drift.binary.DDM()
)
Inspecting Drift Detection State
from river import drift, tree, datasets
model = drift.DriftRetrainingClassifier(
    model=tree.HoeffdingTreeClassifier(),
    drift_detector=drift.binary.DDM()
)
n_drifts = 0
for x, y in datasets.Elec2().take(5000):
    model.learn_one(x, y)
    # The flag reflects the detector's state after the most recent update
    if model.drift_detector.drift_detected:
        n_drifts += 1
        print(f"Drift #{n_drifts} detected")