Overview
The DriftDetector classes define the interface for concept drift detection algorithms in River, supporting detection of distributional changes in data streams with optional warning capabilities.
Description
River provides four public drift detector base classes, built on two private ones. The private _BaseDriftDetector and _BaseDriftAndWarningDetector classes manage internal state. The public DriftDetector class defines the interface for detectors that monitor numeric values, while BinaryDriftDetector is specialized for boolean inputs. Each has a variant (DriftAndWarningDetector and BinaryDriftAndWarningDetector, respectively) that can issue an early warning before drift is confirmed. Every detector exposes a drift_detected property, and the warning variants add warning_detected; both report the detection state as of the most recent update call.
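To make the flag handling concrete, here is a minimal pure-Python sketch of the pattern described above. The classes BaseDriftDetectorSketch, BaseDriftAndWarningDetectorSketch, and ThresholdDetector are hypothetical stand-ins written for illustration; they mirror the signatures listed on this page but are not River's actual implementation, and the thresholds are arbitrary.

```python
import abc


class BaseDriftDetectorSketch(abc.ABC):
    """Hypothetical stand-in for the private base class: owns the drift flag."""

    def __init__(self) -> None:
        self._drift_detected = False

    def _reset(self) -> None:
        self._drift_detected = False

    @property
    def drift_detected(self) -> bool:
        return self._drift_detected

    @abc.abstractmethod
    def update(self, x) -> None: ...


class BaseDriftAndWarningDetectorSketch(BaseDriftDetectorSketch):
    """Hypothetical stand-in that adds a warning flag on top of the drift flag."""

    def __init__(self) -> None:
        super().__init__()
        self._warning_detected = False

    def _reset(self) -> None:
        super()._reset()
        self._warning_detected = False

    @property
    def warning_detected(self) -> bool:
        return self._warning_detected


class ThresholdDetector(BaseDriftAndWarningDetectorSketch):
    """Toy detector (assumed thresholds): warns above 0.5, flags drift above 0.9."""

    def update(self, x: float) -> None:
        # Clear state left over from a drift flagged by the previous call
        if self.drift_detected:
            self._reset()
        self._warning_detected = x > 0.5
        self._drift_detected = x > 0.9


detector = ThresholdDetector()
states = []
for value in [0.1, 0.6, 0.95, 0.2]:
    detector.update(value)
    states.append((detector.warning_detected, detector.drift_detected))
```

Each entry in states is the (warning_detected, drift_detected) pair read after one update call, which is how calling code is expected to consume the flags.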
Usage
Use DriftDetector as the parent class when implementing drift detection algorithms for continuous data streams. Use BinaryDriftDetector for algorithms that monitor binary signals like prediction correctness. Extend the "AndWarning" variants if your algorithm can provide early warning signals. All detectors must implement an update method that accepts new observations and internally updates the drift detection state.
Code Reference
Source Location
Signature
class DriftDetector(_BaseDriftDetector):
    """A drift detector."""

    @abc.abstractmethod
    def update(self, x: int | float) -> None: ...

class DriftAndWarningDetector(DriftDetector, _BaseDriftAndWarningDetector):
    """A drift detector that is also capable of issuing warnings."""

class BinaryDriftDetector(_BaseDriftDetector):
    """A drift detector for binary data."""

    @abc.abstractmethod
    def update(self, x: bool) -> None: ...

class BinaryDriftAndWarningDetector(BinaryDriftDetector, _BaseDriftAndWarningDetector):
    """A binary drift detector that is also capable of issuing warnings."""

# Private base classes (not exposed)
class _BaseDriftDetector(base.Base):
    def __init__(self) -> None: ...
    def _reset(self) -> None: ...

    @property
    def drift_detected(self) -> bool: ...

class _BaseDriftAndWarningDetector(_BaseDriftDetector):
    def __init__(self) -> None: ...
    def _reset(self) -> None: ...

    @property
    def warning_detected(self) -> bool: ...
Import
from river.base import DriftDetector, DriftAndWarningDetector
from river.base import BinaryDriftDetector, BinaryDriftAndWarningDetector
I/O Contract
DriftDetector.update
| Parameter | Type | Description |
|---|---|---|
| x | int or float | A numeric value to monitor for drift |
BinaryDriftDetector.update
| Parameter | Type | Description |
|---|---|---|
| x | bool | A boolean value to monitor for drift (e.g., prediction correctness) |
Properties
| Property | Type | Description |
|---|---|---|
| drift_detected | bool | True if drift was detected after the last update call |
| warning_detected | bool | True if a warning was detected (only for "AndWarning" variants) |
Internal Methods
| Method | Description |
|---|---|
| _reset() | Reset the detector's internal state (called when drift is detected) |
Usage Examples
from river import datasets
from river import drift
from river import tree

# Using ADWIN, a numeric detector, on a stream of prediction outcomes
detector = drift.ADWIN()
model = tree.HoeffdingTreeClassifier()

for x, y in datasets.Phishing():
    # Make a prediction
    y_pred = model.predict_one(x)

    # Update the detector with prediction correctness (True/False is read as 1/0)
    is_correct = (y_pred == y)
    detector.update(is_correct)

    # Check whether drift was detected
    if detector.drift_detected:
        print("Drift detected! Resetting model...")
        model = tree.HoeffdingTreeClassifier()

    # Train the model
    model.learn_one(x, y)

# Using a detector with warnings (DDM lives in the drift.binary module)
detector = drift.binary.DDM()
model = tree.HoeffdingTreeClassifier()

for x, y in datasets.Phishing():
    y_pred = model.predict_one(x)

    # DDM tracks the error rate, so feed it True when the prediction is wrong
    is_error = (y_pred != y)
    detector.update(is_error)

    if detector.warning_detected:
        print("Warning: potential drift ahead")
    if detector.drift_detected:
        print("Drift confirmed!")
        model = tree.HoeffdingTreeClassifier()

    model.learn_one(x, y)

# Implementing a custom drift detector
import collections

from river.base import BinaryDriftAndWarningDetector

class SimpleErrorRateDetector(BinaryDriftAndWarningDetector):
    def __init__(self, warning_threshold=0.3, drift_threshold=0.5, window_size=100):
        super().__init__()
        self.warning_threshold = warning_threshold
        self.drift_threshold = drift_threshold
        self.window_size = window_size
        self._reset()

    def _reset(self):
        # Clear the flags (handled by the base class) and the error window
        super()._reset()
        self.errors = collections.deque(maxlen=self.window_size)

    def update(self, x):
        # Reset only at the start of the next call, so the drift flag set by
        # the previous call stays readable in between
        if self.drift_detected:
            self._reset()

        # Record an error (1) when the observation is False, i.e. incorrect
        self.errors.append(0 if x else 1)

        # Check the error rate once the window is full
        if len(self.errors) == self.window_size:
            error_rate = sum(self.errors) / len(self.errors)
            self._warning_detected = error_rate > self.warning_threshold
            self._drift_detected = error_rate > self.drift_threshold