Implementation:Online ml River Tree Base Nodes
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Decision_Trees, Abstract_Base_Class |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
Generic branch and leaf abstractions intended as the foundation for River's tree-based models. These classes define standardized interfaces for tree manipulation, traversal, and visualization.
Description
The base module provides abstract base classes (ABCs) that establish common functionality for tree nodes. This design promotes code reuse across different tree implementations (Hoeffding Trees, Mondrian Trees, SGT, etc.) and gives them consistent traversal, inspection, and visualization behavior.
Key abstractions:
- Branch: Base class for internal split nodes
- Leaf: Base class for terminal prediction nodes
These classes are intentionally generic and framework-oriented, not meant for direct instantiation. Instead, specific tree algorithms subclass them to implement their particular node behaviors.
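`abc.ABC` is what enforces the "not meant for direct instantiation" contract for `Branch`. A class that still has unimplemented abstract methods raises `TypeError` when instantiated. Here is a minimal standard-library sketch of that pattern. `AbstractBranch` and `ThresholdBranch` are hypothetical stand-ins, not River classes, and plain strings stand in for child nodes.

```python
import abc


class AbstractBranch(abc.ABC):
    """Hypothetical stand-in for an abstract branch (not River's class)."""

    def __init__(self, *children):
        self.children = children

    @abc.abstractmethod
    def next(self, x):
        """Return the child to visit for input x."""


class ThresholdBranch(AbstractBranch):
    """Concrete subclass: implements every abstract method, so it can be built."""

    def __init__(self, feature, threshold, left, right):
        super().__init__(left, right)
        self.feature = feature
        self.threshold = threshold

    def next(self, x):
        # Route left when the feature value is at most the threshold
        return self.children[0] if x[self.feature] <= self.threshold else self.children[1]


def can_instantiate_abstract():
    """Return True if the abstract base can be instantiated directly."""
    try:
        AbstractBranch("left", "right")
    except TypeError:
        return False
    return True


branch = ThresholdBranch("age", 30.5, "young", "old")
print(can_instantiate_abstract())  # False
print(branch.next({"age": 42}))    # old
```

Specific tree algorithms follow the same route: they subclass, fill in the abstract methods, and only then become instantiable.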
Code Reference
Source Location:
river/tree/base.py
Signatures:
```python
from __future__ import annotations  # allows the forward references in the hints

import abc

from river import base


class Branch(base.Base, abc.ABC):
    def __init__(self, *children):
        self.children = children

    @abc.abstractmethod
    def next(self, x) -> Branch | Leaf:
        """Move to the next node down the tree."""

    @abc.abstractmethod
    def most_common_path(self) -> tuple[int, Leaf | Branch]:
        """Return branch index and child for most traversed path."""

    @property
    @abc.abstractmethod
    def repr_split(self):
        """String representation of the split."""


class Leaf(base.Base):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @property
    @abc.abstractmethod
    def __repr__(self):
        """String representation for visualization."""
```
Import:
```python
from river.tree.base import Branch, Leaf
```
Branch Class
Abstract Methods:
- next(x): Return the child node to traverse given feature vector x
- most_common_path(): Return (branch_index, child) for the most frequently traversed path (used when split feature is missing)
- repr_split: String describing the split condition
Concrete Methods:
- walk(x, until_leaf=True): Iterator over nodes along the path induced by x
- traverse(x, until_leaf=True): Return the final node reached by following x
- iter_dfs(): Depth-first traversal iterator
- iter_bfs(): Breadth-first traversal iterator
- iter_leaves(): Iterator over all leaf nodes
- iter_branches(): Iterator over all branch nodes
- iter_edges(): Iterator over (parent, child) edges
- to_dataframe(): Export tree structure to pandas DataFrame
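The iterators listed above need only the `children` tuple plus leaf base cases, where a leaf yields either itself or nothing. Below is a stdlib-only sketch that re-implements them on hypothetical `MiniBranch`/`MiniLeaf` stand-ins. It mirrors the documented behavior but is not River's source, and the `name` attribute exists only for illustration.

```python
import collections


class MiniLeaf:
    """Hypothetical leaf: keyword arguments become attributes."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def iter_dfs(self):
        yield self

    def iter_leaves(self):
        yield self

    def iter_branches(self):
        yield from ()

    def iter_edges(self):
        yield from ()


class MiniBranch:
    """Hypothetical branch holding its children in a tuple."""

    def __init__(self, *children, **kwargs):
        self.children = children
        self.__dict__.update(kwargs)

    def iter_dfs(self):
        # Pre-order: this node, then each child's subtree from left to right
        yield self
        for child in self.children:
            yield from child.iter_dfs()

    def iter_bfs(self):
        # Level order, using a FIFO queue
        queue = collections.deque([self])
        while queue:
            node = queue.popleft()
            yield node
            if isinstance(node, MiniBranch):
                queue.extend(node.children)

    def iter_leaves(self):
        # Leaves from the left-most to the right-most
        for child in self.children:
            yield from child.iter_leaves()

    def iter_branches(self):
        yield self
        for child in self.children:
            yield from child.iter_branches()

    def iter_edges(self):
        # (parent, child) pairs in depth-first order
        for child in self.children:
            yield self, child
            yield from child.iter_edges()


# Tree: root -> (inner -> (A, B), C)
root = MiniBranch(
    MiniBranch(MiniLeaf(name="A"), MiniLeaf(name="B"), name="inner"),
    MiniLeaf(name="C"),
    name="root",
)
print([n.name for n in root.iter_dfs()])  # ['root', 'inner', 'A', 'B', 'C']
print([n.name for n in root.iter_bfs()])  # ['root', 'inner', 'C', 'A', 'B']
```

On this tree the two orders differ only in where `C` appears. Depth-first exhausts `inner` before visiting `C`, while breadth-first visits every node at depth 1 before any node at depth 2.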
Properties:
- n_nodes: Total descendant count (including self)
- n_branches: Branch node count (including self)
- n_leaves: Leaf node count
- height: Number of nodes on the longest path from this node down to a leaf (a leaf alone has height 1)
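A leaf's counts are fixed (`n_nodes = 1`, `n_branches = 0`, `n_leaves = 1`, `height = 1`), so each branch property can be written as a one-line recursion over the children. The sketch below uses hypothetical stand-in classes and assumes a branch adds exactly one level to the height of its tallest child, which is consistent with a leaf having height 1.

```python
class MiniLeaf:
    """Hypothetical leaf: the base cases of the recursive counts."""

    n_nodes = 1
    n_branches = 0
    n_leaves = 1
    height = 1


class MiniBranch:
    """Hypothetical branch whose counts recurse over its children."""

    def __init__(self, *children):
        self.children = children

    @property
    def n_nodes(self):
        # This node plus every node below it
        return 1 + sum(child.n_nodes for child in self.children)

    @property
    def n_branches(self):
        return 1 + sum(child.n_branches for child in self.children)

    @property
    def n_leaves(self):
        return sum(child.n_leaves for child in self.children)

    @property
    def height(self):
        # One level for this node plus the tallest child subtree
        return 1 + max(child.height for child in self.children)


root = MiniBranch(MiniLeaf(), MiniBranch(MiniLeaf(), MiniLeaf()))
print(root.n_nodes, root.n_branches, root.n_leaves, root.height)  # 5 2 3 3
```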
Special Methods:
- _repr_html_(): HTML representation for Jupyter notebooks (uses viz module)
Leaf Class
Constructor:
Accepts any keyword arguments and stores them as instance attributes. This flexible design allows different tree types to store different leaf-specific data.
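Below is a short sketch of what `self.__dict__.update(kwargs)` means in practice: every keyword argument becomes an attribute, so each subclass decides what its leaves carry. `MiniLeaf` mirrors the documented constructor, and `ClassificationLeaf` with its `class_counts` field is a hypothetical example of leaf-specific data.

```python
class MiniLeaf:
    """Hypothetical stand-in mirroring the documented constructor."""

    def __init__(self, **kwargs):
        # Every keyword argument becomes an instance attribute
        self.__dict__.update(kwargs)


class ClassificationLeaf(MiniLeaf):
    """Hypothetical subclass that stores per-class counts."""

    def __init__(self, depth, class_counts=None):
        super().__init__(depth=depth, class_counts=dict(class_counts or {}))

    def majority_class(self):
        # Most frequent class seen so far, or None if the leaf is empty
        if not self.class_counts:
            return None
        return max(self.class_counts, key=self.class_counts.get)


generic = MiniLeaf(depth=2, mean=3.5)
leaf = ClassificationLeaf(depth=1, class_counts={"spam": 3, "ham": 5})
print(generic.depth, generic.mean)  # 2 3.5
print(leaf.majority_class())        # ham
```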
Properties:
- n_nodes: Always returns 1
- n_branches: Always returns 0
- n_leaves: Always returns 1
- height: Always returns 1
Methods:
- walk(x, until_leaf=True): Yields self only
- iter_dfs(): Yields self only
- iter_leaves(): Yields self only
- iter_branches(): Yields nothing
- iter_edges(): Yields nothing
Tree Traversal
Path Walking:
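```python
from river.tree.base import Branch

# `root` is a tree's root node (e.g. `tree._root`) and `x` is a feature dict
for node in root.walk(x, until_leaf=True):
    if isinstance(node, Branch):
        # repr_split is part of the Branch interface; fields like `feature` depend on the subclass
        print(f"Branch: {node.repr_split}")
    else:
        print(f"Reached leaf: {node}")
```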
Missing Features:
When a split feature is missing from the input, the traversal:
1. Calls `most_common_path()` on the branch.
2. Follows the most frequently traversed child.
3. Continues until reaching a leaf.
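The fallback can be sketched as a recursive `walk`. The sketch rests on two assumptions this page does not spell out: a missing feature surfaces as a `KeyError` raised by `next()`, and `until_leaf=False` stops at the branch whose feature is missing. `ThresholdBranch`, `MiniLeaf` and the `n_seen` counts are hypothetical stand-ins, not River code.

```python
class MiniLeaf:
    """Hypothetical leaf that only yields itself when walked."""

    def __init__(self, name):
        self.name = name

    def walk(self, x, until_leaf=True):
        yield self


class ThresholdBranch:
    """Hypothetical binary split; n_seen counts samples routed to each child."""

    def __init__(self, feature, threshold, left, right, n_seen=(0, 0)):
        self.feature = feature
        self.threshold = threshold
        self.children = (left, right)
        self.n_seen = n_seen

    def next(self, x):
        # Indexing raises KeyError when the split feature is absent (assumption)
        return self.children[0] if x[self.feature] <= self.threshold else self.children[1]

    def most_common_path(self):
        index = 0 if self.n_seen[0] >= self.n_seen[1] else 1
        return index, self.children[index]

    def walk(self, x, until_leaf=True):
        yield self
        try:
            child = self.next(x)
        except KeyError:
            if not until_leaf:
                return  # assumption: stop at the branch whose feature is missing
            _, child = self.most_common_path()
        yield from child.walk(x, until_leaf)

    def traverse(self, x, until_leaf=True):
        # The final node of the walk
        node = None
        for node in self.walk(x, until_leaf):
            pass
        return node


tree = ThresholdBranch(
    "age", 30.5,
    MiniLeaf("young"),
    ThresholdBranch("income", 50000, MiniLeaf("low"), MiniLeaf("high"), n_seen=(2, 7)),
    n_seen=(10, 9),
)
print(tree.traverse({"age": 45, "income": 20000}).name)  # low
print(tree.traverse({"age": 45}).name)                   # high (income missing)
print(tree.traverse({}).name)                            # young (age missing)
```

The sketch catches `KeyError` only around `next()`, not around the recursive call. That way, a lookup failure deeper in the tree is handled by the branch that actually made the lookup.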
Tree Iteration:
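```python
# Depth-first traversal
for node in root.iter_dfs():
    print(node)  # replace with your own per-node processing

# Breadth-first traversal
for node in root.iter_bfs():
    print(node)

# All leaves (`stats` is provided by the concrete leaf type, e.g. Hoeffding tree leaves)
for leaf in root.iter_leaves():
    print(leaf.stats)

# All branches (repr_split works for every Branch; `feature`/`threshold` only on numeric splits)
for branch in root.iter_branches():
    print(branch.repr_split)

# All edges
for parent, child in root.iter_edges():
    print(f"{parent} -> {child}")
```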
DataFrame Export
The `to_dataframe()` method creates a pandas DataFrame with one row per node:
```python
df = tree._root.to_dataframe()
# Columns: node (index), parent, is_leaf, depth, plus all node attributes
```
Example output:
| node | parent | is_leaf | depth | feature | threshold |
|---|---|---|---|---|---|
| 0 | NA | False | 0 | 'age' | 30.5 |
| 1 | 0 | True | 1 | NA | NA |
| 2 | 0 | False | 1 | 'income' | 50000 |
| 3 | 2 | True | 2 | NA | NA |
| 4 | 2 | True | 2 | NA | NA |
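Below is a pandas-free sketch of the record-building step. It assumes nodes are numbered in depth-first (pre-order) visiting order; the table above is consistent with that numbering, but this page does not state it. `Node` is a hypothetical stand-in, treated as a branch when it has children and as a leaf otherwise.

```python
class Node:
    """Hypothetical node: a branch if it has children, otherwise a leaf."""

    def __init__(self, *children, **attributes):
        self.children = children
        self.attributes = attributes


def to_records(root):
    """Return one dict per node: node id, parent id, is_leaf, depth, attributes."""
    records = []

    def visit(node, parent_id, depth):
        node_id = len(records)  # ids follow depth-first (pre-order) visiting
        records.append({
            "node": node_id,
            "parent": parent_id,
            "is_leaf": not node.children,
            "depth": depth,
            **node.attributes,
        })
        for child in node.children:
            visit(child, node_id, depth + 1)

    visit(root, None, 0)
    return records


tree = Node(
    Node(),
    Node(Node(), Node(), feature="income", threshold=50000),
    feature="age",
    threshold=30.5,
)
for record in to_records(tree):
    print(record)
```

Passing these records to `pandas.DataFrame(records).set_index("node")` would produce a frame shaped like the table above. Attributes a node lacks, such as `feature` on leaves, show up as missing values.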
HTML Visualization
In Jupyter notebooks, branches automatically render as interactive HTML trees:
```python
tree._root  # as the last expression in a cell, this displays the HTML tree visualization
```
The visualization:
- Shows hierarchical structure
- Displays split conditions at branches
- Shows leaf information
- Uses CSS for styling (defined in viz module)
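Jupyter renders any object that defines `_repr_html_()` as the HTML string that method returns. The sketch below shows one way a branch could build that string recursively, using nested collapsible `<details>`/`<summary>` elements. The markup and the `MiniBranch`/`MiniLeaf`/`to_html` names are hypothetical; River's viz module may produce different HTML and CSS.

```python
import html


class MiniLeaf:
    """Hypothetical leaf rendered as a single styled block."""

    def __init__(self, label):
        self.label = label

    def to_html(self):
        return f"<div class='leaf'>{html.escape(self.label)}</div>"


class MiniBranch:
    """Hypothetical branch rendered as a collapsible block."""

    def __init__(self, split, *children):
        self.split = split
        self.children = children

    def to_html(self):
        # The split text becomes the summary; children are nested inside
        inner = "".join(child.to_html() for child in self.children)
        return f"<details open><summary>{html.escape(self.split)}</summary>{inner}</details>"

    def _repr_html_(self):
        # IPython/Jupyter calls this method to display the object as HTML
        return self.to_html()


root = MiniBranch("age ≤ 30.5", MiniLeaf("Mean: 1.20 (n=4)"), MiniLeaf("Mean: 3.40 (n=6)"))
print(root._repr_html_())
```

Escaping labels with `html.escape` keeps split descriptions or leaf summaries that contain `<` or `&` from corrupting the markup.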
Usage Patterns
Subclassing Branch:
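```python
from river.tree.base import Branch


class NumericBinaryBranch(Branch):
    """Simplified illustration; not River's own NumericBinaryBranch."""

    def __init__(self, feature, threshold, left, right):
        super().__init__(left, right)
        self.feature = feature
        self.threshold = threshold

    def next(self, x):
        # x[self.feature] raises KeyError when the feature is missing
        if x[self.feature] <= self.threshold:
            return self.children[0]
        return self.children[1]

    def most_common_path(self):
        # Return the child whose subtree has seen more samples
        # (children may be branches, so sum over their leaves' n_samples)
        left_n, right_n = (
            sum(leaf.n_samples for leaf in child.iter_leaves())
            for child in self.children
        )
        if left_n > right_n:
            return 0, self.children[0]
        return 1, self.children[1]

    @property
    def repr_split(self):
        return f"{self.feature} ≤ {self.threshold}"
```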
Subclassing Leaf:
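```python
from river.tree.base import Leaf


class RegressionLeaf(Leaf):
    def __init__(self, depth):
        super().__init__(depth=depth)
        self.mean = 0.0
        self.n_samples = 0

    def update(self, y):
        # Incremental (running) mean of the targets routed to this leaf
        self.n_samples += 1
        self.mean += (y - self.mean) / self.n_samples

    # Defined as a plain method: wrapping __repr__ in @property makes repr() raise TypeError
    def __repr__(self):
        return f"Mean: {self.mean:.2f} (n={self.n_samples})"
```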
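Below is a self-contained sketch of how a branch subclass and a leaf subclass like those above might cooperate during online learning. Each sample is routed to a leaf, and that leaf updates its running mean. `MiniBinaryBranch`, `MiniRegressionLeaf` and `leaf_for` are hypothetical stand-ins that avoid importing River. `__repr__` is a plain method, so `repr()` and `print()` can call it.

```python
class MiniRegressionLeaf:
    """Hypothetical leaf that keeps a running mean of its targets."""

    def __init__(self, depth):
        self.depth = depth
        self.mean = 0.0
        self.n_samples = 0

    def update(self, y):
        # Incremental mean: m_n = m_(n-1) + (y - m_(n-1)) / n
        self.n_samples += 1
        self.mean += (y - self.mean) / self.n_samples

    def __repr__(self):
        return f"Mean: {self.mean:.2f} (n={self.n_samples})"


class MiniBinaryBranch:
    """Hypothetical numeric split (routing only)."""

    def __init__(self, feature, threshold, left, right):
        self.feature = feature
        self.threshold = threshold
        self.children = (left, right)

    def next(self, x):
        return self.children[0] if x[self.feature] <= self.threshold else self.children[1]


def leaf_for(node, x):
    # Follow next() until reaching a node without children (a leaf)
    while hasattr(node, "children"):
        node = node.next(x)
    return node


stump = MiniBinaryBranch("age", 30.5, MiniRegressionLeaf(depth=1), MiniRegressionLeaf(depth=1))
stream = [({"age": 25}, 1.0), ({"age": 40}, 5.0), ({"age": 22}, 2.0), ({"age": 35}, 7.0)]
for x, y in stream:
    leaf_for(stump, x).update(y)  # learn one sample at a time

print(stump.children)                     # (Mean: 1.50 (n=2), Mean: 6.00 (n=2))
print(leaf_for(stump, {"age": 50}).mean)  # 6.0
```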
Related Pages
- Online_ml_River_Tree_HoeffdingTree
- Online_ml_River_Tree_MondrianTreeNodes
- Online_ml_River_Tree_Utils
- Online_ml_River_Tree_Viz
Design Philosophy
The base node classes follow several key principles:
1. Separation of Concerns: split logic (Branch) vs. prediction logic (Leaf)
2. Extensibility: abstract methods allow diverse implementations
3. Reusability: common operations are implemented once
4. Consistency: all River trees share the same interface
5. Debuggability: rich introspection and visualization support