
Implementation:Online ml River Forest ARFClassifier

From Leeroopedia


Knowledge Sources: River, River Docs, Adaptive Random Forests for Evolving Data Stream Classification
Domains: Online Machine Learning, Concept Drift, Ensemble Learning
Last Updated: 2026-02-08 16:00 GMT

Overview

A concrete tool for building an online adaptive random forest classifier: a drift-adaptive ensemble that combines per-tree drift detection, background tree replacement, Poisson resampling, and random feature subsets.

Description

The forest.ARFClassifier class implements the Adaptive Random Forest ensemble. It extends BaseForest (which inherits from base.Ensemble) and uses BaseTreeClassifier (a feature-randomized Hoeffding Tree) as its base learner. The ensemble manages n_models trees, each with its own ADWIN-based drift and warning detectors.

During learn_one, each tree receives the instance with a Poisson-distributed weight. Per-tree metrics track individual accuracy for weighted voting. Warning detection triggers background tree training, and drift detection triggers tree replacement. The predict_proba_one method aggregates predictions across all trees using accuracy-weighted soft voting, normalizing the result to a valid probability distribution.
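The accuracy-weighted soft-voting step can be sketched in plain Python. This is a minimal illustration of the aggregation idea, not River's internal code; the per-tree probability dicts and accuracy weights below are invented for the example:

```python
def weighted_soft_vote(tree_probas, tree_weights):
    """Sum per-tree class probabilities scaled by each tree's accuracy
    weight, then normalize into a valid probability distribution."""
    combined = {}
    for probas, weight in zip(tree_probas, tree_weights):
        for label, p in probas.items():
            combined[label] = combined.get(label, 0.0) + weight * p
    total = sum(combined.values())
    if total > 0:
        combined = {label: p / total for label, p in combined.items()}
    return combined

# Three hypothetical trees; weights stand in for tracked per-tree accuracy.
tree_probas = [
    {"a": 0.9, "b": 0.1},
    {"a": 0.4, "b": 0.6},
    {"a": 0.7, "b": 0.3},
]
tree_weights = [0.8, 0.5, 0.7]
print(weighted_soft_vote(tree_probas, tree_weights))
```

With disable_weighted_vote=True the same aggregation would simply use a weight of 1 for every tree.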

The tracking methods n_warnings_detected() and n_drifts_detected() report how many warnings and drifts were detected, either across the whole ensemble or for an individual tree.

Usage

Import forest.ARFClassifier when you need a high-performing drift-adaptive ensemble classifier for online classification on non-stationary data streams. It is one of the most competitive online classifiers available in River.

Code Reference

Source Location

river/forest/adaptive_random_forest.py:L448-L713

Signature

class ARFClassifier(BaseForest, base.Classifier):
    def __init__(
        self,
        n_models: int = 10,
        max_features: bool | str | int = "sqrt",
        lambda_value: int = 6,
        metric: metrics.base.MultiClassMetric | None = None,
        disable_weighted_vote=False,
        drift_detector: base.DriftDetector | None = None,
        warning_detector: base.DriftDetector | None = None,
        # Tree parameters
        grace_period: int = 50,
        max_depth: int | None = None,
        split_criterion: str = "info_gain",
        delta: float = 0.01,
        tau: float = 0.05,
        leaf_prediction: str = "nba",
        nb_threshold: int = 0,
        nominal_attributes: list | None = None,
        splitter: Splitter | None = None,
        binary_split: bool = False,
        min_branch_fraction: float = 0.01,
        max_share_to_split: float = 0.99,
        max_size: float = 100.0,
        memory_estimate_period: int = 2_000_000,
        stop_mem_management: bool = False,
        remove_poor_attrs: bool = False,
        merit_preprune: bool = True,
        seed: int | None = None,
    )

Import

from river import forest

Key Parameters (Ensemble-Specific)

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_models | int | 10 | Number of trees in the ensemble |
| max_features | str, int, float, or None | "sqrt" | Number of features considered per split: "sqrt", "log2", an int, a float (percentage of features), or None (all features) |
| lambda_value | int | 6 | Poisson parameter for online bagging; lambda=6 corresponds to Leveraging Bagging |
| metric | MultiClassMetric or None | None (defaults to metrics.Accuracy()) | Metric for tracking per-tree performance and weighting votes |
| disable_weighted_vote | bool | False | If True, uses unweighted voting instead of metric-weighted voting |
| drift_detector | DriftDetector or None | None (defaults to ADWIN(delta=0.001)) | Per-tree drift detector |
| warning_detector | DriftDetector or None | None (defaults to ADWIN(delta=0.01)) | Per-tree warning detector |
| seed | int or None | None | Random seed for reproducibility |
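To make lambda_value concrete: during learn_one each tree weights the incoming instance by a draw from Poisson(lambda_value), so a tree may effectively see an instance zero, one, or several times. The sketch below samples such weights with Knuth's algorithm in pure Python; it is an illustration of the resampling idea, not River's implementation:

```python
import math
import random

def poisson_weight(lam, rng):
    """Sample an instance weight from Poisson(lam) using Knuth's algorithm."""
    limit = math.exp(-lam)
    k, p = 0, 1.0
    while True:
        p *= rng.random()
        if p <= limit:
            return k
        k += 1

rng = random.Random(42)
weights = [poisson_weight(6, rng) for _ in range(10_000)]
print(sum(weights) / len(weights))  # close to lambda = 6
```

A weight of 0 means the tree skips the instance entirely, which is what gives each ensemble member a different view of the stream.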

I/O Contract

Inputs

| Method | Parameter | Type | Description |
|---|---|---|---|
| learn_one | x | dict | Feature dictionary |
| learn_one | y | ClfTarget | Target class label |
| predict_proba_one | x | dict | Feature dictionary |
| predict_one | x | dict | Feature dictionary |

Outputs

| Method | Return Type | Description |
|---|---|---|
| predict_proba_one(x) | dict[ClfTarget, float] | Normalized probability distribution aggregated via weighted voting across all trees |
| predict_one(x) | ClfTarget | Predicted class label (argmax of aggregated probabilities) |
| n_warnings_detected(tree_id=None) | int | Total warnings detected (all trees) or for a specific tree |
| n_drifts_detected(tree_id=None) | int | Total drifts detected (all trees) or for a specific tree |
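The contract between the two prediction methods can be sketched as follows. The feature dict and labels here are made up; the point is the shapes: inputs are plain dicts, predict_proba_one returns a label-to-probability dict, and predict_one is effectively its argmax:

```python
# A single instance is a plain feature dict; values may be numeric or nominal.
x = {"temperature": 21.5, "humidity": 0.4, "weekday": "Mon"}

def predict_from_probas(probas: dict):
    """predict_one is the argmax of the predict_proba_one distribution."""
    return max(probas, key=probas.get) if probas else None

# Shape of a predict_proba_one result: sums to 1 over the seen labels.
probas = {"rain": 0.7, "sun": 0.3}
print(predict_from_probas(probas))  # -> rain
```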

Usage Examples

Basic Adaptive Random Forest Classification

from river import evaluate, forest, metrics
from river.datasets import synth

dataset = synth.ConceptDriftStream(
    seed=42, position=500, width=40
).take(1000)

model = forest.ARFClassifier(seed=8, leaf_prediction="mc")
metric = metrics.Accuracy()

evaluate.progressive_val_score(dataset, model, metric)
# Accuracy: 67.97%

Inspecting Per-Tree Drift Statistics

from river import forest, datasets, evaluate, metrics

dataset = datasets.Elec2().take(10000)
model = forest.ARFClassifier(n_models=10, seed=42)
metric = metrics.Accuracy()

evaluate.progressive_val_score(dataset, model, metric)

# Total warnings and drifts across all trees
print(f"Total warnings: {model.n_warnings_detected()}")
print(f"Total drifts: {model.n_drifts_detected()}")

# Per-tree statistics
for i in range(model.n_models):
    print(f"Tree {i}: {model.n_warnings_detected(i)} warnings, {model.n_drifts_detected(i)} drifts")

Custom Drift Detectors

from river import forest, drift

# Use ADWIN with more conservative delta thresholds than the defaults
model = forest.ARFClassifier(
    n_models=15,
    drift_detector=drift.ADWIN(delta=0.0005),
    warning_detector=drift.ADWIN(delta=0.005),
    lambda_value=8,
    seed=42
)

Evaluation on Insects Dataset

from river import datasets, evaluate, forest, metrics

dataset = datasets.Insects(variant="abrupt_balanced")
model = forest.ARFClassifier(n_models=10, seed=42)
metric = metrics.Accuracy()

evaluate.progressive_val_score(dataset, model, metric, print_every=10000)
