Implementation:Online ml River Tree Utils

From Leeroopedia


Knowledge Sources
Domains Online_Learning, Decision_Trees, Utilities
Last Updated 2026-02-08 16:00 GMT

Overview

Utility functions and data structures shared across tree implementations in River. Provides Naive Bayes prediction, branch factories, gradient/hessian tracking, memory calculation, and mathematical helpers.

Description

The utils module contains common functionality needed by various tree algorithms:

  • Naive Bayes prediction for classification trees
  • Branch assembly factories for split candidates
  • Gradient and hessian statistics for Stochastic Gradient Trees
  • Memory size calculation utilities
  • Dictionary operations for statistics aggregation
  • Numerical rounding helpers

These utilities promote code reuse and maintain consistency across different tree implementations.

Code Reference

Source Location: /tmp/kapso_repo_178qi9vb/river/tree/utils.py

Import:

from river.tree.utils import (
    do_naive_bayes_prediction,
    BranchFactory,
    GradHess,
    GradHessStats,
    GradHessMerit,
    calculate_object_size,
    add_dict_values,
    round_sig_fig
)

Functions

do_naive_bayes_prediction

Perform Naive Bayes prediction using observed class distributions and attribute statistics.

Signature:

def do_naive_bayes_prediction(
    x: dict,
    observed_class_distribution: dict,
    splitters: dict
) -> dict

Parameters:

  • x (dict): Feature values
  • observed_class_distribution (dict): Class counts at node
  • splitters (dict): Attribute observers with conditional probability methods

Returns:

  • Dictionary mapping class labels to probabilities

Algorithm:

  1. Compute the prior: P(c) = count(c) / total_count
  2. For each feature, obtain P(x_f | c) from the splitter's cond_proba method
  3. Compute the log-likelihood: log P(c) + Σ log P(x_f | c)
  4. Apply the log-sum-exp trick for numerical stability
  5. Normalize to obtain probabilities

Example:

x = {'age': 35, 'income': 50000}
class_dist = {0: 100, 1: 150}
splitters = {
    'age': GaussianSplitter(...),
    'income': GaussianSplitter(...)
}

proba = do_naive_bayes_prediction(x, class_dist, splitters)
# Returns: {0: 0.35, 1: 0.65} (example values)
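
The algorithm steps can be sketched in a self-contained form. This is an illustration of the listed steps, not River's source. `ToyGaussianObserver` and `naive_bayes_sketch` are hypothetical names; the only property assumed of real splitters is the `cond_proba(value, class_label)` method mentioned in the parameter list. Skipping zero densities is a simplification made for this sketch.

```python
import math


class ToyGaussianObserver:
    """Hypothetical stand-in for a splitter: one (mean, std) pair per class."""

    def __init__(self, params):
        self.params = params  # {class_label: (mean, std)}

    def cond_proba(self, value, class_label):
        mean, std = self.params[class_label]
        z = (value - mean) / std
        return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))


def naive_bayes_sketch(x, class_dist, observers):
    total = sum(class_dist.values())
    if total == 0:
        return {}
    log_lik = {}
    for label, count in class_dist.items():
        if count <= 0:
            continue  # unseen classes get no vote in this sketch
        ll = math.log(count / total)  # step 1: log prior
        for feature, obs in observers.items():
            if feature not in x:
                continue  # missing features are skipped
            p = obs.cond_proba(x[feature], label)  # step 2
            if p > 0:  # simplification: zero densities are ignored
                ll += math.log(p)  # step 3
        log_lik[label] = ll
    # Steps 4 and 5: log-sum-exp normalization
    max_ll = max(log_lik.values())
    lse = max_ll + math.log(sum(math.exp(v - max_ll) for v in log_lik.values()))
    return {label: math.exp(v - lse) for label, v in log_lik.items()}


observers = {
    "age": ToyGaussianObserver({0: (30.0, 5.0), 1: (40.0, 5.0)}),
    "income": ToyGaussianObserver({0: (40000.0, 10000.0), 1: (60000.0, 10000.0)}),
}
proba = naive_bayes_sketch({"age": 35, "income": 50000}, {0: 100, 1: 150}, observers)
```

In this toy setup the sample sits exactly midway between the two class means for both features, so the likelihood terms cancel and the result is approximately the prior (about 0.4 for class 0 and 0.6 for class 1).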

calculate_object_size

Iteratively calculate object size by visiting all related properties and objects.

Signature:

def calculate_object_size(obj: typing.Any, unit: str = "byte") -> int

Parameters:

  • obj (Any): Object to evaluate
  • unit (str): Unit for result ('byte', 'KiB', 'MiB')

Returns:

  • Size in specified unit

Algorithm:

  1. Use breadth-first traversal to visit all reachable objects
  2. Track visited objects to avoid double-counting
  3. Sum sys.getsizeof() for each unique object
  4. Follow dicts, instance attributes, and iterables to reach nested objects
  5. Convert to the requested unit

Example:

tree_size_bytes = calculate_object_size(tree)
tree_size_mb = calculate_object_size(tree, unit="MiB")
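
A minimal sketch of the traversal described above, using only the standard library. `object_size_sketch` is a hypothetical name and the code is an approximation of the listed steps, not the library function. It assumes the unit labels documented above, does not look inside `__slots__`, and consumes one-shot iterators it encounters.

```python
import sys
from collections import deque


def object_size_sketch(obj, unit="byte"):
    seen = set()          # ids of objects already counted
    queue = deque([obj])  # FIFO queue gives breadth-first order
    total = 0
    while queue:
        current = queue.popleft()
        if id(current) in seen:
            continue  # shared object: count it only once
        seen.add(id(current))
        total += sys.getsizeof(current)
        if isinstance(current, dict):
            queue.extend(current.keys())
            queue.extend(current.values())
        elif hasattr(current, "__dict__"):
            queue.append(current.__dict__)
        elif hasattr(current, "__iter__") and not isinstance(
            current, (str, bytes, bytearray)
        ):
            queue.extend(current)
    divisors = {"byte": 1, "KiB": 1024, "MiB": 2**20}
    return total / divisors[unit]


shared = list(range(1000, 1100))
once = object_size_sketch({"a": shared})
twice = object_size_sketch({"a": shared, "b": shared})
# The second reference to `shared` adds only the extra key, not the list again
```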

add_dict_values

Add two dictionaries element-wise, summing values with matching keys.

Signature:

def add_dict_values(
    dict_a: dict,
    dict_b: dict,
    inplace: bool = False
) -> dict

Parameters:

  • dict_a (dict): First dictionary
  • dict_b (dict): Second dictionary (added to first)
  • inplace (bool): If True, modify dict_a; otherwise create new dict

Returns:

  • Dictionary with summed values

Example:

stats_a = {0: 10, 1: 20}
stats_b = {0: 5, 1: 15, 2: 8}
result = add_dict_values(stats_a, stats_b)
# Result: {0: 15, 1: 35, 2: 8}
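
The difference between the two modes of the `inplace` flag can be illustrated with a small stand-in. `add_dict_values_sketch` is a hypothetical re-implementation written for this page, assuming that the non-in-place mode works on a copy of `dict_a`; it is not the library code.

```python
import copy


def add_dict_values_sketch(dict_a, dict_b, inplace=False):
    # Work on dict_a itself, or on a deep copy so the input stays untouched
    result = dict_a if inplace else copy.deepcopy(dict_a)
    for key, value in dict_b.items():
        if key in result:
            result[key] += value
        else:
            result[key] = value
    return result


stats_a = {0: 10, 1: 20}
stats_b = {0: 5, 1: 15, 2: 8}

merged = add_dict_values_sketch(stats_a, stats_b)
untouched = dict(stats_a)  # stats_a still holds its original values here

same_obj = add_dict_values_sketch(stats_a, stats_b, inplace=True)
# same_obj is stats_a itself, now holding the summed values
```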

round_sig_fig

Round to specified number of significant figures.

Signature:

def round_sig_fig(x: float, significant_digits: int = 2) -> float

Examples:

round_sig_fig(1.2345)                        # 1.2
round_sig_fig(1.2345, significant_digits=3)  # 1.23
round_sig_fig(1999, significant_digits=1)    # 2000
round_sig_fig(0.025, significant_digits=3)   # 0.025
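
One common way to round to significant figures is to locate the leading digit with `log10` and shift the precision passed to `round` accordingly. The sketch below shows that technique under the hypothetical name `round_sig_fig_sketch`; whether the library function is written this way is an assumption.

```python
import math


def round_sig_fig_sketch(x, significant_digits=2):
    if x == 0:
        return 0  # log10 is undefined at zero
    # Position of the leading digit: 0 for 1.23, 3 for 1999, -2 for 0.025
    magnitude = math.floor(math.log10(abs(x)))
    return round(x, significant_digits - magnitude - 1)


a = round_sig_fig_sketch(1.2345)
b = round_sig_fig_sketch(1.2345, significant_digits=3)
c = round_sig_fig_sketch(1999, significant_digits=1)
d = round_sig_fig_sketch(0.025, significant_digits=3)
```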

Data Structures

BranchFactory

Helper dataclass for assembling branch nodes from split candidates.

Attributes:

  • merit (float): Split merit (default: -inf for null split)
  • feature (FeatureName | None): Split feature
  • split_info (Hashable | list | tuple): Threshold or category info
  • children_stats (list): Statistics for child nodes
  • numerical_feature (bool): Is feature numeric?
  • multiway_split (bool): Is split multiway?

Methods:

  • assemble(branch, stats, depth, *children, **kwargs): Create branch node

Comparison:

Implements `__lt__` and `__eq__` based on merit for sorting.

Example:

# Create split candidate
candidate = BranchFactory(
    merit=0.15,
    feature='age',
    split_info=35.0,
    children_stats=[{0: 50, 1: 30}, {0: 20, 1: 60}],
    numerical_feature=True,
    multiway_split=False
)

# Sort candidates by merit
candidates = [candidate1, candidate2, candidate3]
candidates.sort()  # Ascending order by merit
best = candidates[-1]

# Assemble into branch node (stats and depth are passed positionally
# so that the child nodes can follow as *children)
branch = best.assemble(
    NumericBinaryBranch,
    {0: 70, 1: 90},  # stats
    2,               # depth
    left_leaf,
    right_leaf,
    splitter=splitter_instance
)
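
The merit-based ordering can be reproduced with a small stand-in class. `SplitCandidateSketch` is hypothetical and keeps only two of the attributes listed above; it is meant to show why a default merit of -inf makes the null split rank last, not to mirror the real dataclass.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class SplitCandidateSketch:
    merit: float = -math.inf       # null split by default
    feature: Optional[str] = None

    def __lt__(self, other):
        return self.merit < other.merit

    def __eq__(self, other):
        return self.merit == other.merit


null_split = SplitCandidateSketch()
candidates = [
    SplitCandidateSketch(0.15, "age"),
    null_split,
    SplitCandidateSketch(0.40, "income"),
]

ranked = sorted(candidates)  # ascending by merit, null split first
best = ranked[-1]
worth_splitting = best.merit > null_split.merit
```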

GradHess

Container for gradient and hessian values in SGT.

Attributes:

  • gradient (float): First derivative
  • hessian (float): Second derivative

Operations:

  • Addition: `gh1 + gh2`
  • Subtraction: `gh1 - gh2`
  • In-place: `gh1 += gh2`, `gh1 -= gh2`

Example:

gh1 = GradHess(gradient=0.5, hessian=0.3)
gh2 = GradHess(gradient=0.2, hessian=0.4)

total = gh1 + gh2
# GradHess(gradient=0.7, hessian=0.7)

gh1 += gh2  # In-place addition

GradHessStats

Aggregate gradient/hessian statistics with mean, variance, and covariance tracking.

Attributes:

  • g_var (Var): Gradient variance tracker
  • h_var (Var): Hessian variance tracker
  • gh_cov (Cov): Gradient-hessian covariance tracker

Methods:

  • update(gh, w): Update with new GradHess observation
  • mean (property): Returns GradHess with mean values
  • variance (property): Returns GradHess with variances
  • covariance (property): Returns gradient-hessian covariance
  • total_weight (property): Total weight of observations
  • delta_loss_mean_var(delta_pred): Compute expected loss change statistics

Operations:

  • Addition: `stats1 + stats2`
  • Subtraction: `stats1 - stats2`
  • In-place: `stats1 += stats2`, `stats1 -= stats2`

Example:

stats = GradHessStats()

# Update with observations (compute_grad_hess, y_true and y_pred are placeholders)
for i in range(100):
    gh = compute_grad_hess(y_true[i], y_pred[i])
    stats.update(gh, w=1.0)

# Access statistics
mean_g = stats.mean.gradient
mean_h = stats.mean.hessian
var_g = stats.variance.gradient
cov_gh = stats.covariance

# Compute prediction update (lambda_value is a regularization term set by the caller)
delta_pred = -mean_g / (mean_h + lambda_value)
delta_loss = stats.delta_loss_mean_var(delta_pred)
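
The page does not give the formula behind `delta_loss_mean_var`. A plausible reading, assumed here, is the second-order Taylor approximation of the loss change, Δloss ≈ δ·g + ½·δ²·h, whose mean and variance follow directly from the tracked mean, variance, and covariance. The sketch below uses plain lists and hypothetical helper names (`sample_moments`, `delta_loss_sketch`) instead of the running statistics objects.

```python
def sample_moments(gradients, hessians):
    """Return means, unbiased variances and covariance of paired samples."""
    n = len(gradients)
    mean_g = sum(gradients) / n
    mean_h = sum(hessians) / n
    var_g = sum((g - mean_g) ** 2 for g in gradients) / (n - 1)
    var_h = sum((h - mean_h) ** 2 for h in hessians) / (n - 1)
    cov_gh = sum(
        (g - mean_g) * (h - mean_h) for g, h in zip(gradients, hessians)
    ) / (n - 1)
    return mean_g, mean_h, var_g, var_h, cov_gh


def delta_loss_sketch(gradients, hessians, delta_pred):
    mean_g, mean_h, var_g, var_h, cov_gh = sample_moments(gradients, hessians)
    # Mean of delta*g + 0.5*delta^2*h
    loss_mean = delta_pred * mean_g + 0.5 * delta_pred**2 * mean_h
    # Var(a*g + b*h) = a^2 Var(g) + b^2 Var(h) + 2ab Cov(g, h)
    loss_var = (
        delta_pred**2 * var_g
        + 0.25 * delta_pred**4 * var_h
        + delta_pred**3 * cov_gh
    )
    return loss_mean, max(0.0, loss_var)


gradients = [0.5, -0.25, 0.75, 0.25]
hessians = [1.0, 0.5, 1.5, 1.0]
lambda_value = 0.1  # hypothetical regularization strength

mean_g, mean_h, *_ = sample_moments(gradients, hessians)
delta_pred = -mean_g / (mean_h + lambda_value)
loss_mean, loss_var = delta_loss_sketch(gradients, hessians, delta_pred)
```

With these toy values the estimated mean loss change is negative, which under this convention indicates that applying `delta_pred` is expected to reduce the loss.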

GradHessMerit

Dataclass storing split merit based on gradient/hessian statistics.

Attributes:

  • loss_mean (float): Mean loss reduction
  • loss_var (float): Variance of loss reduction
  • delta_pred (float | dict | None): Prediction update(s)

Comparison:

Implements `__lt__` and `__eq__` based on loss_mean.

Example:

merit = GradHessMerit(
    loss_mean=-0.05,  # Negative means improvement
    loss_var=0.001,
    delta_pred=0.15
)

# Used in F-test (n is the number of observations behind the statistics)
f_value = n * (merit.loss_mean ** 2) / merit.loss_var

Naive Bayes Implementation Details

Log-Sum-Exp Trick:

To avoid numerical underflow when computing probabilities:

  1. Compute log-likelihoods: LL[c] = log P(c) + Σ log P(x_f | c)
  2. Find the maximum: max_ll = max(LL)
  3. Compute log-sum-exp: lse = max_ll + log(Σ exp(LL[c] - max_ll))
  4. Normalize: P(c) = exp(LL[c] - lse)

This ensures numerical stability even when probabilities are very small.
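
A short demonstration of why the shift by `max_ll` matters. The log-likelihood values are made up for illustration, and `normalize_log_likelihoods` is a hypothetical helper that follows the four steps listed above.

```python
import math

# Made-up log-likelihoods, small enough that exp() underflows to 0.0
log_lik = {0: -1050.0, 1: -1052.0}

# Naive normalization fails: every exp() underflows, so the sum is zero
naive_sum = sum(math.exp(v) for v in log_lik.values())


def normalize_log_likelihoods(log_lik):
    max_ll = max(log_lik.values())
    # After the shift the largest exponent is exactly 0, so no underflow
    lse = max_ll + math.log(sum(math.exp(v - max_ll) for v in log_lik.values()))
    return {c: math.exp(v - lse) for c, v in log_lik.items()}


proba = normalize_log_likelihoods(log_lik)
# Only the difference between log-likelihoods matters:
# proba[0] equals 1 / (1 + exp(-2)) up to rounding
```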

Jeffreys Prior:

For smoothing, many Hoeffding Tree variants use:

P(c | x) ∝ (count(c) + α) × Π P(x_f | c)

where α is a Dirichlet parameter (often 0.5 for Jeffreys prior).
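
The effect of α on the prior term can be seen in a small sketch. `smoothed_priors` is a hypothetical helper, and normalizing by the total count plus α times the number of classes is an assumption (standard additive smoothing) rather than something stated on this page.

```python
def smoothed_priors(class_counts, alpha=0.5):
    # Each class receives alpha pseudo-observations before normalizing
    total = sum(class_counts.values()) + alpha * len(class_counts)
    return {c: (n + alpha) / total for c, n in class_counts.items()}


counts = {0: 0, 1: 3}
priors = smoothed_priors(counts)
# Class 0 was never observed, yet keeps a non-zero prior: 0.5 / 4.0 = 0.125
```

Without smoothing, the zero count for class 0 would give a prior of zero, and its logarithm would be undefined.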

Memory Management

Size Calculation Strategy:

The `calculate_object_size` function is crucial for tree memory management:

  1. Visited Set: Prevents counting shared objects multiple times
  2. Graph Exploration: Follows object graphs via __dict__ and __iter__
  3. Special Cases: Handles dicts, iterables, strings, and bytes separately
  4. Unit Conversion: bytes ÷ 1024 for KiB, bytes ÷ 2²⁰ for MiB

Usage in Trees:

# Periodic size checks
if n_samples % memory_estimate_period == 0:
    tree_size = calculate_object_size(tree, unit="MiB")
    if tree_size > max_size:
        enforce_size_limit()

Performance Considerations

Naive Bayes Prediction:

  • O(k × d) where k = number of classes, d = number of features
  • Log-space computation prevents underflow
  • Missing features are skipped (not penalized)

Memory Calculation:

  • O(n) where n = number of objects in tree
  • Can be expensive for large trees
  • Should be called periodically, not per sample

Dictionary Addition:

  • O(|dict_b|) when in-place; creating a new dict also requires copying dict_a
  • In-place mode saves memory allocations
  • Used frequently in node statistics aggregation
