Implementation:Online ml River Forest ARFClassifier

From Leeroopedia
Revision as of 16:08, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Online_ml_River_Forest_ARFClassifier.md)

Knowledge Sources: River, River Docs, Adaptive Random Forests for Evolving Data Stream Classification
Domains: Online Machine Learning, Concept Drift, Ensemble Learning
Last Updated: 2026-02-08 16:00 GMT

Overview

ARFClassifier is River's online Adaptive Random Forest classifier. It combines per-tree drift detection, background tree replacement, Poisson resampling, and random feature subsets to keep an ensemble accurate on drifting data streams.

Description

The forest.ARFClassifier class implements the Adaptive Random Forest ensemble. It extends BaseForest (which inherits from base.Ensemble) and uses BaseTreeClassifier (a feature-randomized Hoeffding Tree) as its base learner. The ensemble manages n_models trees, each with its own ADWIN-based drift and warning detectors.

During learn_one, each tree receives the instance with a Poisson-distributed weight. Per-tree metrics track individual accuracy for weighted voting. Warning detection triggers background tree training, and drift detection triggers tree replacement. The predict_proba_one method aggregates predictions across all trees using accuracy-weighted soft voting, normalizing the result to a valid probability distribution.
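The Poisson weighting described above can be illustrated with a small standalone sketch. It does not use River: the `poisson` helper below is a hypothetical stand-in written for this example, and the rule that a tree skips an instance when its draw is zero is an assumption of the sketch.

```python
import math
import random


def poisson(rng: random.Random, lam: float) -> int:
    """Draw one Poisson(lam) sample with Knuth's multiplication method."""
    threshold = math.exp(-lam)
    k, product = 0, rng.random()
    while product > threshold:
        k += 1
        product *= rng.random()
    return k


rng = random.Random(8)
n_models, lambda_value = 10, 6

# One incoming instance: every tree draws its own weight independently.
weights = [poisson(rng, lambda_value) for _ in range(n_models)]

# Assumption of this sketch: a tree whose draw is 0 skips the instance.
trained = [i for i, w in enumerate(weights) if w > 0]

# Over many draws the average weight approaches lambda_value,
# and zero draws are rare (exp(-6) is roughly 0.25%).
samples = [poisson(rng, lambda_value) for _ in range(20_000)]
mean_weight = sum(samples) / len(samples)
zero_fraction = samples.count(0) / len(samples)
```

With lambda_value=6, almost every tree trains on almost every instance, but with a different weight, which is what keeps the trees diverse.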

The tracking methods n_warnings_detected() and n_drifts_detected() report how many warnings and drifts have been detected, either across the whole ensemble or for a single tree.
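A toy model of the warning and drift bookkeeping may make the two counters easier to picture. `TreeSlot` and `ToyForest` are hypothetical names invented for this sketch, trees are represented by plain strings, and the rule that a drift promotes an existing background tree (or otherwise falls back to a fresh one) is an assumption taken from the general ARF design, not from River's code.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class TreeSlot:
    """Hypothetical stand-in for one ensemble member and its counters."""

    tree: str
    background: str | None = None
    n_warnings: int = 0
    n_drifts: int = 0

    def on_warning(self, fresh_tree: str) -> None:
        # A warning starts a background tree that trains alongside the current one.
        self.background = fresh_tree
        self.n_warnings += 1

    def on_drift(self, fresh_tree: str) -> None:
        # Assumed rule: promote the background tree if present, else start fresh.
        self.tree = self.background if self.background is not None else fresh_tree
        self.background = None
        self.n_drifts += 1


class ToyForest:
    def __init__(self, n_models: int) -> None:
        self.slots = [TreeSlot(tree=f"tree-{i}") for i in range(n_models)]

    def n_warnings_detected(self, tree_id: int | None = None) -> int:
        if tree_id is None:
            return sum(slot.n_warnings for slot in self.slots)
        return self.slots[tree_id].n_warnings

    def n_drifts_detected(self, tree_id: int | None = None) -> int:
        if tree_id is None:
            return sum(slot.n_drifts for slot in self.slots)
        return self.slots[tree_id].n_drifts


toy = ToyForest(n_models=3)
toy.slots[0].on_warning("tree-0-background")
toy.slots[0].on_drift("tree-0-fresh")  # promotes the background tree
toy.slots[2].on_drift("tree-2-fresh")  # no background tree, so starts fresh
```

Calling the counters with no argument sums over all slots, while passing a tree index returns that slot's count, mirroring the tree_id=None convention of the real methods.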

Usage

Import forest.ARFClassifier when you need a high-performing drift-adaptive ensemble classifier for online classification on non-stationary data streams. It is one of the most competitive online classifiers available in River.

Code Reference

Source Location

river/forest/adaptive_random_forest.py:L448-L713

Signature

class ARFClassifier(BaseForest, base.Classifier):
    def __init__(
        self,
        n_models: int = 10,
        max_features: bool | str | int = "sqrt",
        lambda_value: int = 6,
        metric: metrics.base.MultiClassMetric | None = None,
        disable_weighted_vote=False,
        drift_detector: base.DriftDetector | None = None,
        warning_detector: base.DriftDetector | None = None,
        # Tree parameters
        grace_period: int = 50,
        max_depth: int | None = None,
        split_criterion: str = "info_gain",
        delta: float = 0.01,
        tau: float = 0.05,
        leaf_prediction: str = "nba",
        nb_threshold: int = 0,
        nominal_attributes: list | None = None,
        splitter: Splitter | None = None,
        binary_split: bool = False,
        min_branch_fraction: float = 0.01,
        max_share_to_split: float = 0.99,
        max_size: float = 100.0,
        memory_estimate_period: int = 2_000_000,
        stop_mem_management: bool = False,
        remove_poor_attrs: bool = False,
        merit_preprune: bool = True,
        seed: int | None = None,
    )

Import

from river import forest

Key Parameters (Ensemble-Specific)

| Parameter | Type | Default | Description |
|---|---|---|---|
| n_models | int | 10 | Number of trees in the ensemble |
| max_features | str, int, float, or None | "sqrt" | Number of features considered per split: "sqrt", "log2", an int (absolute count), a float (percentage of the features), or None (all features) |
| lambda_value | int | 6 | Poisson parameter for online bagging. lambda_value=6 corresponds to Leveraging Bagging |
| metric | MultiClassMetric or None | None (defaults to metrics.Accuracy()) | Metric for tracking per-tree performance and weighting votes |
| disable_weighted_vote | bool | False | If True, uses unweighted voting instead of metric-weighted voting |
| drift_detector | DriftDetector or None | None (defaults to ADWIN(delta=0.001)) | Per-tree drift detector. Passing None selects the default ADWIN detector |
| warning_detector | DriftDetector or None | None (defaults to ADWIN(delta=0.01)) | Per-tree warning detector. Passing None selects the default ADWIN detector |
| seed | int or None | None | Random seed for reproducibility |
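The max_features options can be made concrete with a small helper that turns a setting into a feature count. `resolve_max_features` is a hypothetical function written for this page, not part of River; the rounding and clamping choices are assumptions and River's own handling may differ in edge cases.

```python
import math


def resolve_max_features(max_features, n_features: int) -> int:
    """Hypothetical helper: map a max_features setting to a feature count."""
    if max_features is None:
        k = n_features  # use every feature
    elif max_features == "sqrt":
        k = round(math.sqrt(n_features))
    elif max_features == "log2":
        k = round(math.log2(n_features))
    elif isinstance(max_features, float):
        k = int(max_features * n_features)  # percentage of the features
    elif isinstance(max_features, int):
        k = max_features  # absolute count
    else:
        raise ValueError(f"Unsupported max_features: {max_features!r}")
    # Assumed clamping: at least one feature, at most all of them.
    return max(1, min(k, n_features))


n_features = 16
resolved = {
    "sqrt": resolve_max_features("sqrt", n_features),
    "log2": resolve_max_features("log2", n_features),
    "half": resolve_max_features(0.5, n_features),
    "three": resolve_max_features(3, n_features),
    "all": resolve_max_features(None, n_features),
}
```

For a 16-feature stream, "sqrt" and "log2" both give 4 candidate features per split, 0.5 gives 8, and None gives all 16.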

I/O Contract

Inputs

| Method | Parameter | Type | Description |
|---|---|---|---|
| learn_one | x | dict | Feature dictionary |
| learn_one | y | ClfTarget | Target class label |
| predict_proba_one | x | dict | Feature dictionary |
| predict_one | x | dict | Feature dictionary |

Outputs

| Method | Return Type | Description |
|---|---|---|
| predict_proba_one(x) | dict[ClfTarget, float] | Normalized probability distribution aggregated via weighted voting across all trees |
| predict_one(x) | ClfTarget | Predicted class label (argmax of aggregated probabilities) |
| n_warnings_detected(tree_id=None) | int | Total warnings detected (all trees) or for a specific tree |
| n_drifts_detected(tree_id=None) | int | Total drifts detected (all trees) or for a specific tree |
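The predict_proba_one and predict_one outputs listed above can be pictured with a minimal, River-free sketch of accuracy-weighted soft voting. `weighted_soft_vote` is a hypothetical helper, and the per-tree probabilities and accuracies are made-up numbers chosen only to make the arithmetic easy to follow.

```python
from collections import defaultdict


def weighted_soft_vote(tree_probas, tree_weights):
    """Sum each tree's class probabilities scaled by its weight, then normalize."""
    combined = defaultdict(float)
    for proba, weight in zip(tree_probas, tree_weights):
        for label, p in proba.items():
            combined[label] += weight * p
    total = sum(combined.values())
    if total == 0:
        return dict(combined)  # nothing to normalize
    return {label: score / total for label, score in combined.items()}


# Illustrative values only: three trees voting on classes "a" and "b".
tree_probas = [
    {"a": 0.9, "b": 0.1},
    {"a": 0.2, "b": 0.8},
    {"a": 0.6, "b": 0.4},
]
tree_accuracies = [0.80, 0.50, 0.70]

# Weighted vote: more accurate trees pull the result harder.
y_proba = weighted_soft_vote(tree_probas, tree_accuracies)
# Predicted label is the argmax of the aggregated distribution.
y_pred = max(y_proba, key=y_proba.get)

# Equal weights correspond to the disable_weighted_vote=True behavior.
y_proba_unweighted = weighted_soft_vote(tree_probas, [1.0, 1.0, 1.0])
```

Here the weighted scores are 1.24 for "a" and 0.76 for "b", which normalize to 0.62 and 0.38, so the predicted label is "a".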

Usage Examples

Basic Adaptive Random Forest Classification

from river import evaluate, forest, metrics
from river.datasets import synth

dataset = synth.ConceptDriftStream(
    seed=42, position=500, width=40
).take(1000)

model = forest.ARFClassifier(seed=8, leaf_prediction="mc")
metric = metrics.Accuracy()

evaluate.progressive_val_score(dataset, model, metric)
# Accuracy: 67.97%

Inspecting Per-Tree Drift Statistics

from river import forest, datasets, evaluate, metrics

dataset = datasets.Elec2().take(10000)
model = forest.ARFClassifier(n_models=10, seed=42)
metric = metrics.Accuracy()

evaluate.progressive_val_score(dataset, model, metric)

# Total warnings and drifts across all trees
print(f"Total warnings: {model.n_warnings_detected()}")
print(f"Total drifts: {model.n_drifts_detected()}")

# Per-tree statistics
for i in range(model.n_models):
    print(f"Tree {i}: {model.n_warnings_detected(i)} warnings, {model.n_drifts_detected(i)} drifts")

Custom Drift Detectors

from river import forest, drift

# Use ADWIN detectors with smaller delta values than the defaults
model = forest.ARFClassifier(
    n_models=15,
    drift_detector=drift.ADWIN(delta=0.0005),
    warning_detector=drift.ADWIN(delta=0.005),
    lambda_value=8,
    seed=42
)

Evaluation on Insects Dataset

from river import datasets, evaluate, forest, metrics

dataset = datasets.Insects(variant="abrupt_balanced")
model = forest.ARFClassifier(n_models=10, seed=42)
metric = metrics.Accuracy()

evaluate.progressive_val_score(dataset, model, metric, print_every=10000)
