Implementation:Online ml River Tree Utils
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Decision_Trees, Utilities |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
Utility functions and data structures shared across tree implementations in River. Provides Naive Bayes prediction, branch factories, gradient/hessian tracking, memory calculation, and mathematical helpers.
Description
The utils module contains common functionality needed by various tree algorithms:
- Naive Bayes prediction for classification trees
- Branch assembly factories for split candidates
- Gradient and hessian statistics for Stochastic Gradient Trees
- Memory size calculation utilities
- Dictionary operations for statistics aggregation
- Numerical rounding helpers
These utilities promote code reuse and maintain consistency across different tree implementations.
Code Reference
Source Location:
river/tree/utils.py
Import:
from river.tree.utils import (
    do_naive_bayes_prediction,
    BranchFactory,
    GradHess,
    GradHessStats,
    GradHessMerit,
    calculate_object_size,
    add_dict_values,
    round_sig_fig,
)
Functions
do_naive_bayes_prediction
Perform Naive Bayes prediction using observed class distributions and attribute statistics.
Signature:
def do_naive_bayes_prediction(
    x: dict,
    observed_class_distribution: dict,
    splitters: dict
) -> dict
Parameters:
- x (dict): Feature values
- observed_class_distribution (dict): Class counts at node
- splitters (dict): Attribute observers with conditional probability methods
Returns:
- Dictionary mapping class labels to probabilities
Algorithm:
1. Compute the prior: P(c) = count(c) / total_count
2. For each feature: obtain P(x_f | c) from the splitter's cond_proba method
3. Compute the log-likelihood: log P(c) + Σ log P(x_f | c)
4. Apply the log-sum-exp trick for numerical stability
5. Normalize to get probabilities
Example:
x = {'age': 35, 'income': 50000}
class_dist = {0: 100, 1: 150}
splitters = {
    'age': GaussianSplitter(...),
    'income': GaussianSplitter(...)
}
proba = do_naive_bayes_prediction(x, class_dist, splitters)
# Returns: {0: 0.35, 1: 0.65} (example values)
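The algorithm steps above can be sketched in plain Python. This is an illustration, not River's implementation: `ToyGaussian` is a hypothetical stand-in for a splitter, and the sketch assumes a `cond_proba(value, class_label)` method as described in the Parameters list. Skipping zero densities, instead of letting a class score fall to -inf, is a choice of this sketch.

```python
import math


class ToyGaussian:
    """Hypothetical splitter stand-in: one (mean, std) pair per class label."""

    def __init__(self, params):
        self.params = params  # {class_label: (mean, std)}

    def cond_proba(self, value, class_label):
        mean, std = self.params[class_label]
        z = (value - mean) / std
        return math.exp(-0.5 * z * z) / (std * math.sqrt(2 * math.pi))


def naive_bayes_sketch(x, class_dist, splitters):
    total = sum(class_dist.values())
    log_scores = {}
    for c, count in class_dist.items():
        if count <= 0:
            continue  # log(0) is undefined; unseen classes get no score here
        score = math.log(count / total)  # step 1: log prior
        for feature, value in x.items():
            if feature not in splitters:
                continue  # features without an observer are skipped
            p = splitters[feature].cond_proba(value, c)  # step 2
            if p > 0:
                score += math.log(p)  # step 3: accumulate the log-likelihood
        log_scores[c] = score
    if not log_scores:
        return {}
    max_ll = max(log_scores.values())  # step 4: log-sum-exp
    lse = max_ll + math.log(sum(math.exp(s - max_ll) for s in log_scores.values()))
    return {c: math.exp(s - lse) for c, s in log_scores.items()}  # step 5


splitters = {
    "age": ToyGaussian({0: (30.0, 5.0), 1: (45.0, 8.0)}),
    "income": ToyGaussian({0: (40000.0, 8000.0), 1: (60000.0, 12000.0)}),
}
proba = naive_bayes_sketch({"age": 35, "income": 50000}, {0: 100, 1: 150}, splitters)
print(proba)
```

With no features, the result reduces to the class prior, which is a quick sanity check for any implementation of these steps.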
calculate_object_size
Iteratively calculate object size by visiting all related properties and objects.
Signature:
def calculate_object_size(obj: typing.Any, unit: str = "byte") -> int
Parameters:
- obj (Any): Object to evaluate
- unit (str): Unit for result ('byte', 'KiB', 'MiB')
Returns:
- Size in specified unit
Algorithm:
1. Traverse all reachable objects iteratively, breadth-first
2. Track visited objects to avoid double-counting shared references
3. Sum sys.getsizeof() for each unique object
4. Expand dicts, object attributes, and iterables, queuing the objects they reference
5. Convert the total to the requested unit
Example:
tree_size_bytes = calculate_object_size(tree)
tree_size_mb = calculate_object_size(tree, unit="MiB")
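A stdlib sketch of the traversal described under Algorithm, for illustration only. The function name is hypothetical, identity is tracked with `id()`, and the special cases (strings and bytes are not iterated, dicts contribute both keys and values) are assumptions of this sketch rather than a copy of River's code. It iterates over any object exposing `__iter__`, so it is meant for plain containers, not one-shot iterators.

```python
import sys
from collections import deque


def object_size_sketch(obj, unit="byte"):
    seen = set()          # ids of objects already counted
    queue = deque([obj])  # breadth-first work queue
    total = 0
    while queue:
        current = queue.popleft()
        if id(current) in seen:
            continue  # shared object: count it only once
        seen.add(id(current))
        total += sys.getsizeof(current)
        if isinstance(current, dict):
            queue.extend(current.keys())
            queue.extend(current.values())
        elif isinstance(current, (str, bytes, bytearray)):
            continue  # getsizeof already covers the characters / bytes
        elif hasattr(current, "__dict__"):
            queue.append(vars(current))  # follow instance attributes
        elif hasattr(current, "__iter__"):
            queue.extend(current)  # lists, tuples, sets, ...
    if unit == "KiB":
        return total / 2**10
    if unit == "MiB":
        return total / 2**20
    return total


class Node:
    def __init__(self, children=None):
        self.stats = {0: 10, 1: 20}
        self.children = children or []


shared = Node()
root = Node([shared, shared])  # the same child referenced twice is counted once
print(object_size_sketch(root), object_size_sketch(root, unit="KiB"))
```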
add_dict_values
Add two dictionaries element-wise, summing values with matching keys.
Signature:
def add_dict_values(
    dict_a: dict,
    dict_b: dict,
    inplace: bool = False
) -> dict
Parameters:
- dict_a (dict): First dictionary
- dict_b (dict): Second dictionary (added to first)
- inplace (bool): If True, modify dict_a; otherwise create new dict
Returns:
- Dictionary with summed values
Example:
stats_a = {0: 10, 1: 20}
stats_b = {0: 5, 1: 15, 2: 8}
result = add_dict_values(stats_a, stats_b)
# Result: {0: 15, 1: 35, 2: 8}
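A sketch of the element-wise addition, assuming the semantics described above. The function name is hypothetical, and using `copy.deepcopy` in the non-in-place mode is a choice of this sketch (a shallow copy would suffice for plain numeric values).

```python
import copy


def add_dict_values_sketch(dict_a, dict_b, inplace=False):
    # Mutate dict_a directly, or work on a deep copy so the caller's dict is untouched
    result = dict_a if inplace else copy.deepcopy(dict_a)
    for key, value in dict_b.items():
        if key in result:
            result[key] += value  # matching key: sum the values
        else:
            result[key] = value  # key only in dict_b: copy it over
    return result


stats_a = {0: 10, 1: 20}
stats_b = {0: 5, 1: 15, 2: 8}
merged = add_dict_values_sketch(stats_a, stats_b)
print(merged)   # {0: 15, 1: 35, 2: 8}
print(stats_a)  # {0: 10, 1: 20} (unchanged)
add_dict_values_sketch(stats_a, stats_b, inplace=True)
print(stats_a)  # {0: 15, 1: 35, 2: 8}
```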
round_sig_fig
Round to specified number of significant figures.
Signature:
def round_sig_fig(x: float, significant_digits: int = 2) -> float
Examples:
round_sig_fig(1.2345) # 1.2
round_sig_fig(1.2345, significant_digits=3) # 1.23
round_sig_fig(1999, significant_digits=1) # 2000
round_sig_fig(0.025, significant_digits=3) # 0.025
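A common way to implement significant-figure rounding is to round to `significant_digits - floor(log10(|x|)) - 1` decimal places, so that the leading digit's position sets the precision. A sketch under that assumption follows; the function name is hypothetical and the zero guard is an addition of this sketch, since log10(0) is undefined.

```python
import math


def round_sig_fig_sketch(x, significant_digits=2):
    if x == 0:
        return 0.0  # log10(0) is undefined; nothing to round
    # Position of the leading digit: 1.2345 -> 0, 1999 -> 3, 0.025 -> -2
    magnitude = math.floor(math.log10(abs(x)))
    return round(x, significant_digits - magnitude - 1)


print(round_sig_fig_sketch(1.2345))                         # 1.2
print(round_sig_fig_sketch(1.2345, significant_digits=3))   # 1.23
print(round_sig_fig_sketch(1999, significant_digits=1))     # 2000
print(round_sig_fig_sketch(0.025, significant_digits=3))    # 0.025
```

A negative ndigits argument to `round` is what lets large values such as 1999 round to the nearest thousand.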
Data Structures
BranchFactory
Helper dataclass for assembling branch nodes from split candidates.
Attributes:
- merit (float): Split merit (default: -inf for null split)
- feature (FeatureName | None): Split feature
- split_info (Hashable | list | tuple): Threshold or category info
- children_stats (list): Statistics for child nodes
- numerical_feature (bool): Is feature numeric?
- multiway_split (bool): Is split multiway?
Methods:
- assemble(branch, stats, depth, *children, **kwargs): Create branch node
Comparison:
Implements `__lt__` and `__eq__` based on merit for sorting.
Example:
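from river.tree.nodes.branch import NumericBinaryBranch
from river.tree.utils import BranchFactory

# Create split candidate
candidate = BranchFactory(
    merit=0.15,
    feature='age',
    split_info=35.0,
    children_stats=[{0: 50, 1: 30}, {0: 20, 1: 60}],
    numerical_feature=True,
    multiway_split=False
)
# Sort candidates by merit (other_candidate_a/b stand for further BranchFactory objects)
candidates = [candidate, other_candidate_a, other_candidate_b]
candidates.sort()  # Ascending order by merit
best = candidates[-1]
# Assemble into branch node: stats and depth are positional because the
# children follow as *children (left_leaf, right_leaf and splitter_instance
# are placeholders for objects built by the tree)
branch = best.assemble(
    NumericBinaryBranch,
    {0: 70, 1: 90},  # stats
    2,               # depth
    left_leaf,
    right_leaf,
    splitter=splitter_instance
)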
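The merit-based ordering can be illustrated without River. `CandidateSketch` is a hypothetical stand-in that keeps only the fields needed here; it mirrors the documented default (a merit of -inf for the null split) and the merit-only `__lt__`/`__eq__` comparisons. Because comparisons look only at merit, sorting places the null split first and the best candidate last, and Hoeffding-style trees typically compare the merits of the two best candidates when deciding whether to split.

```python
from __future__ import annotations

import dataclasses
import math


@dataclasses.dataclass
class CandidateSketch:
    """Hypothetical stand-in mirroring BranchFactory's merit-based ordering."""

    merit: float = -math.inf  # -inf marks the null split
    feature: str | None = None
    split_info: object = None

    def __lt__(self, other):
        return self.merit < other.merit

    def __eq__(self, other):
        return self.merit == other.merit  # features are ignored on purpose


candidates = [
    CandidateSketch(merit=0.15, feature="age", split_info=35.0),
    CandidateSketch(),  # null split
    CandidateSketch(merit=0.42, feature="income", split_info=52000.0),
]
candidates.sort()  # ascending by merit
best, second_best = candidates[-1], candidates[-2]
print(best.feature, best.merit - second_best.merit)
```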
GradHess
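Container for gradient and hessian values in Stochastic Gradient Trees (SGT).
Attributes:
- gradient (float): First derivative of the loss with respect to the prediction
- hessian (float): Second derivative of the loss with respect to the prediction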
Operations:
- Addition: `gh1 + gh2`
- Subtraction: `gh1 - gh2`
- In-place: `gh1 += gh2`, `gh1 -= gh2`
Example:
from river.tree.utils import GradHess

gh1 = GradHess(gradient=0.5, hessian=0.3)
gh2 = GradHess(gradient=0.2, hessian=0.4)
total = gh1 + gh2
# total.gradient == 0.7, total.hessian == 0.7 (up to floating-point rounding)
gh1 += gh2  # In-place addition: modifies gh1
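A minimal pure-Python sketch of such a container, for illustration. `GradHessSketch` is hypothetical, and using `__slots__` to keep the many small per-observation objects cheap is a design choice of this sketch, not something stated above. Binary operators return new objects, while the augmented forms mutate the left operand.

```python
class GradHessSketch:
    """Hypothetical minimal mirror of GradHess: a (gradient, hessian) pair."""

    __slots__ = ("gradient", "hessian")

    def __init__(self, gradient=0.0, hessian=0.0):
        self.gradient = gradient
        self.hessian = hessian

    def __add__(self, other):
        return GradHessSketch(self.gradient + other.gradient, self.hessian + other.hessian)

    def __sub__(self, other):
        return GradHessSketch(self.gradient - other.gradient, self.hessian - other.hessian)

    def __iadd__(self, other):
        self.gradient += other.gradient
        self.hessian += other.hessian
        return self  # same object, updated in place

    def __isub__(self, other):
        self.gradient -= other.gradient
        self.hessian -= other.hessian
        return self

    def __repr__(self):
        return f"GradHessSketch(gradient={self.gradient}, hessian={self.hessian})"


gh1 = GradHessSketch(gradient=0.5, hessian=0.3)
gh2 = GradHessSketch(gradient=0.2, hessian=0.4)
total = gh1 + gh2  # new object; gh1 and gh2 unchanged
same_obj = gh1
gh1 += gh2         # mutates gh1 in place
print(total, gh1 is same_obj)
```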
GradHessStats
Aggregate gradient/hessian statistics with mean, variance, and covariance tracking.
Attributes:
- g_var (Var): Gradient variance tracker
- h_var (Var): Hessian variance tracker
- gh_cov (Cov): Gradient-hessian covariance tracker
Methods:
- update(gh, w): Update with new GradHess observation
- mean (property): Returns GradHess with mean values
- variance (property): Returns GradHess with variances
- covariance (property): Returns gradient-hessian covariance
- total_weight (property): Total weight of observations
- delta_loss_mean_var(delta_pred): Estimate the mean and variance of the loss change produced by the prediction update delta_pred
Operations:
- Addition: `stats1 + stats2`
- Subtraction: `stats1 - stats2`
- In-place: `stats1 += stats2`, `stats1 -= stats2`
Example:
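from river.tree.utils import GradHess, GradHessStats

# Hypothetical helper for squared error 0.5 * (y_pred - y_true) ** 2
def compute_grad_hess(y_true, y_pred):
    return GradHess(gradient=y_pred - y_true, hessian=1.0)

y_true = [float(i % 5) for i in range(100)]  # toy targets
y_pred = [1.5] * 100                         # current predictions
lambda_value = 0.1                           # illustrative L2 regularization strength

stats = GradHessStats()
# Update with observations
for i in range(100):
    gh = compute_grad_hess(y_true[i], y_pred[i])
    stats.update(gh, w=1.0)
# Access statistics
mean_g = stats.mean.gradient
mean_h = stats.mean.hessian
var_g = stats.variance.gradient
cov_gh = stats.covariance
# Compute prediction update (regularized Newton step)
delta_pred = -mean_g / (mean_h + lambda_value)
delta_loss = stats.delta_loss_mean_var(delta_pred)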
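Assuming SGT's second-order Taylor approximation of the loss change, ΔL ≈ g·δ + ½·h·δ² for a prediction update δ, the mean and variance of ΔL follow from the tracked moments. The mean is δ·E[g] + ½·δ²·E[h]; for the variance, Var(aX + bY) = a²Var(X) + b²Var(Y) + 2ab·Cov(X, Y) with a = δ and b = ½δ² gives δ²Var(g) + ¼δ⁴Var(h) + δ³Cov(g, h). The sketch below computes these from raw samples with the standard library; like the derivation, it treats δ as fixed. The function name and the `lam` value are hypothetical, and this is an illustration rather than River's `delta_loss_mean_var`.

```python
from statistics import fmean, pvariance


def delta_loss_mean_var_sketch(grads, hessians, delta):
    """Mean and variance of dL = g*delta + 0.5*h*delta**2 over equally weighted samples."""
    g_mean, h_mean = fmean(grads), fmean(hessians)
    g_var, h_var = pvariance(grads), pvariance(hessians)
    # Population covariance of (g, h)
    gh_cov = fmean((g - g_mean) * (h - h_mean) for g, h in zip(grads, hessians))
    mean = delta * g_mean + 0.5 * h_mean * delta**2
    var = delta**2 * g_var + 0.25 * delta**4 * h_var + delta**3 * gh_cov
    return mean, max(0.0, var)  # clip tiny negative values caused by rounding


y_true = [1.0, 2.0, 0.5, 3.0, 2.5]
y_pred = [1.5] * 5
# Squared error 0.5*(y_pred - y_true)**2: gradient y_pred - y_true, hessian 1
grads = [p - t for p, t in zip(y_pred, y_true)]
hessians = [1.0] * len(grads)
lam = 0.1  # hypothetical regularization strength
delta = -fmean(grads) / (fmean(hessians) + lam)
print(delta_loss_mean_var_sketch(grads, hessians, delta))
```

A negative mean signals that the update is expected to reduce the loss.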
GradHessMerit
Dataclass storing split merit based on gradient/hessian statistics.
Attributes:
- loss_mean (float): Mean of the estimated loss change (negative values indicate improvement)
- loss_var (float): Variance of the estimated loss change
- delta_pred (float | dict | None): Prediction update(s)
Comparison:
Implements `__lt__` and `__eq__` based on loss_mean.
Example:
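from river.tree.utils import GradHessMerit

merit = GradHessMerit(
    loss_mean=-0.05,  # Negative means improvement
    loss_var=0.001,
    delta_pred=0.15
)
# Used in the F-test (n: number of observations behind the estimate; 200 is illustrative)
n = 200
f_value = n * (merit.loss_mean ** 2) / merit.loss_var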
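The F statistic above is undefined when `loss_var` is zero. A guarded helper, as a sketch: the function name and the conventions of returning `math.inf` for a nonzero mean with zero variance (and 0.0 when both are zero) are assumptions of this example, not documented River behaviour.

```python
import math


def f_statistic_sketch(loss_mean, loss_var, n):
    """n * mean**2 / var, guarding the zero-variance case (assumed conventions)."""
    if loss_var > 0:
        return n * loss_mean**2 / loss_var
    return math.inf if loss_mean != 0 else 0.0


print(f_statistic_sketch(-0.05, 0.001, 200))
```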
Naive Bayes Implementation Details
Log-Sum-Exp Trick:
To avoid numerical underflow when computing probabilities:
1. Compute log-likelihoods: LL[c] = log P(c) + Σ log P(x_f | c)
2. Find the maximum: max_ll = max_c LL[c]
3. Compute the log-sum-exp: lse = max_ll + log(Σ_c exp(LL[c] - max_ll))
4. Normalize: P(c | x) = exp(LL[c] - lse)
This ensures numerical stability even when probabilities are very small.
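A small stdlib demonstration of why the trick matters. With log-likelihoods around -1000, `math.exp` underflows to 0.0, so naive normalization would divide zero by zero, while the shifted computation recovers the probabilities. The score values and the function name are illustrative.

```python
import math


def normalise_log_scores(log_scores):
    """Turn {class: log-likelihood} into probabilities via log-sum-exp."""
    max_ll = max(log_scores.values())
    lse = max_ll + math.log(sum(math.exp(ll - max_ll) for ll in log_scores.values()))
    return {c: math.exp(ll - lse) for c, ll in log_scores.items()}


scores = {0: -1000.0, 1: -1001.0}
print([math.exp(v) for v in scores.values()])  # [0.0, 0.0]: naive exp underflows
print(normalise_log_scores(scores))            # approximately {0: 0.731, 1: 0.269}
```

Subtracting max_ll makes the largest exponent exactly exp(0) = 1, so the sum can never underflow to zero.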
Jeffreys Prior:
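Some Naive Bayes leaf variants smooth the class counts with a Dirichlet pseudo-count α:
P(c | x) ∝ (count(c) + α) × Π P(x_f | c)
where α is the Dirichlet concentration parameter (α = 0.5 gives the Jeffreys prior). Step 1 of do_naive_bayes_prediction, as described above, uses the unsmoothed prior count(c) / total_count instead.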
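As a sketch of the smoothed prior term, normalizing over k classes and N total observations gives P(c) = (count(c) + α) / (N + kα). This normalization is the standard one for an additive (Dirichlet) prior and is shown for illustration; the function name is hypothetical and it is not taken from River's code.

```python
def smoothed_prior(class_counts, alpha=0.5):
    """Additive (Dirichlet) smoothing of class priors; alpha=0.5 is the Jeffreys choice."""
    total = sum(class_counts.values())
    k = len(class_counts)
    return {c: (n + alpha) / (total + k * alpha) for c, n in class_counts.items()}


print(smoothed_prior({0: 3, 1: 0}))  # {0: 0.875, 1: 0.125}: class 1 keeps nonzero mass
```

Only classes present as keys receive mass, so a class that has never been recorded at the node still needs an explicit zero count to be smoothed.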
Memory Management
Size Calculation Strategy:
The `calculate_object_size` function is crucial for tree memory management:
1. Visited set: prevents counting shared objects multiple times
2. Graph exploration: follows object graphs via __dict__ and __iter__
3. Special cases: handles dicts, iterables, strings, and bytes separately
4. Unit conversion: byte → KiB (÷2¹⁰) → MiB (÷2²⁰)
Usage in Trees:
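# Periodic size checks inside a tree's learning loop; all names other than
# calculate_object_size stand for the tree's own state and pruning logic
if n_samples % memory_estimate_period == 0:
    tree_size = calculate_object_size(tree, unit="MiB")
    if tree_size > max_size:
        enforce_size_limit()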
Related Pages
- Online_ml_River_Tree_HoeffdingTree
- Online_ml_River_Tree_StochasticGradientTree
- Online_ml_River_Tree_Base_Nodes
- Online_ml_River_Tree_SGT_Losses
- Online_ml_River_Stats
Performance Considerations
Naive Bayes Prediction:
- O(k × d) where k = number of classes, d = number of features
- Log-space computation prevents underflow
- Missing features are skipped (not penalized)
Memory Calculation:
- O(n) where n = number of objects reachable from the tree
- Can be expensive for large trees
- Should be called periodically, not per sample
Dictionary Addition:
- O(|dict_b|) when inplace=True; otherwise dict_a is copied first, adding O(|dict_a|)
- In-place mode saves memory allocations
- Used frequently in node statistics aggregation