Implementation:Rapidsai Cuml Traversal Forest
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Forest_Inference |
| Last Updated | 2026-02-08 12:00 GMT |
Overview
Defines the generic abstract forest traversal framework in cuML, providing a polymorphic interface for traversing tree ensembles in multiple orderings (depth-first, breadth-first, and layered) with support for node-level lambda callbacks.
Description
The traversal_forest.hpp header provides the core traversal abstraction for tree ensemble models in the ML::forest namespace:
- detail::traversal_container: An internal template that adapts between a stack (depth-first) and a queue (breadth-first) based on the forest_order template parameter. It provides add(), next(), peek(), empty(), and size() operations.
- traversal_forest: An abstract template struct parameterized on node type and tree ID type. Key members and methods:
  - get_node(tree_id, node_id): Pure virtual method that returns a node given its tree and node identifiers. This is the main extension point for concrete implementations (e.g., Treelite integration).
  - for_each<order>(lambda): The core traversal method that visits every node in the forest, calling the provided lambda with (tree_id, node, depth, parent_index) for each node. Supports four traversal orders:
    - depth_first: Standard DFS using a stack.
    - breadth_first: Standard BFS using a queue.
    - layered_children_segregated: Layer-by-layer traversal where hot children are visited before distant children at each layer.
    - layered_children_together: Layer-by-layer traversal where both children are visited together at each layer.
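The stack-versus-queue adaptation described above can be illustrated with a minimal, self-contained sketch. It does not use the cuML headers: order_kind, simple_container, toy_node, and visit_order are hypothetical names invented for this illustration, only the depth-first and breadth-first orders are modeled, and treating the left child as the "hot" child is an assumption of the sketch rather than a statement about cuML's internals.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <vector>

// Hypothetical stand-in for forest_order; only two orders are modeled here.
enum class order_kind { depth_first, breadth_first };

// Behaves as a stack (LIFO) for depth-first and a queue (FIFO) for breadth-first.
template <order_kind order, typename T>
struct simple_container {
  void add(T const& val) { data_.push_back(val); }

  // Add both children so that the hot child is returned by next() first.
  void add(T const& hot, T const& distant) {
    if constexpr (order == order_kind::depth_first) {
      data_.push_back(distant);  // pushed first, popped last
      data_.push_back(hot);
    } else {
      data_.push_back(hot);  // enqueued first, dequeued first
      data_.push_back(distant);
    }
  }

  T next() {
    assert(!data_.empty());
    if constexpr (order == order_kind::depth_first) {
      T result = data_.back();
      data_.pop_back();
      return result;
    } else {
      T result = data_.front();
      data_.pop_front();
      return result;
    }
  }

  [[nodiscard]] bool empty() const { return data_.empty(); }
  std::size_t size() const { return data_.size(); }

 private:
  std::deque<T> data_;
};

// Toy binary tree node: child indices into a vector, -1 meaning "no child".
// Nodes are assumed to have either two children or none.
struct toy_node {
  int left;
  int right;
};

// Returns node ids in the order they are visited, starting from node 0.
template <order_kind order>
std::vector<int> visit_order(std::vector<toy_node> const& tree) {
  std::vector<int> visited;
  if (tree.empty()) { return visited; }
  simple_container<order, int> pending;
  pending.add(0);
  while (!pending.empty()) {
    int id = pending.next();
    visited.push_back(id);
    toy_node const& node = tree[static_cast<std::size_t>(id)];
    if (node.left >= 0) { pending.add(node.left, node.right); }
  }
  return visited;
}
```

The same loop body yields either ordering; only the template argument changes, which is the design benefit of hiding the stack/queue choice behind one container interface.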
The framework tracks parent indices throughout traversal, enabling reconstruction of tree topology during iteration.
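As a rough illustration of how parent indices allow topology to be rebuilt, the sketch below walks a toy forest depth-first and reports, for each node, the visit-sequence position of its parent. It is a standalone model, not the cuML implementation: toy_node, for_each_depth_first, and children_by_index are hypothetical names, and both the interpretation of parent_index as a position in the visit sequence and the convention that a root reports its own index are assumptions made for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy binary tree node: child indices into a vector, -1 meaning "no child".
// Nodes are assumed to have either two children or none.
struct toy_node {
  int left;
  int right;
};

// Visits every node of every tree depth-first, calling
// lambda(tree_id, node, depth, parent_index).
// Assumption: parent_index is the parent's position in the overall visit
// sequence, and a root reports its own position.
template <typename lambda_t>
void for_each_depth_first(std::vector<std::vector<toy_node>> const& forest,
                          lambda_t&& lambda) {
  struct pending_t {
    int node_id;
    std::size_t depth;
    std::size_t parent_index;
  };
  std::size_t cur_index = 0;  // position of the node currently being visited
  for (std::size_t tree_id = 0; tree_id < forest.size(); ++tree_id) {
    auto const& tree = forest[tree_id];
    if (tree.empty()) { continue; }
    std::vector<pending_t> stack;
    stack.push_back({0, 0, cur_index});  // root: parent index is its own index
    while (!stack.empty()) {
      pending_t cur = stack.back();
      stack.pop_back();
      toy_node const& node = tree[static_cast<std::size_t>(cur.node_id)];
      lambda(tree_id, node, cur.depth, cur.parent_index);
      if (node.left >= 0) {
        // Push right first so that the left child is visited first.
        stack.push_back({node.right, cur.depth + 1, cur_index});
        stack.push_back({node.left, cur.depth + 1, cur_index});
      }
      ++cur_index;
    }
  }
}

// Rebuilds, for each visit position, the visit positions of its children.
std::vector<std::vector<std::size_t>> children_by_index(
  std::vector<std::vector<toy_node>> const& forest) {
  std::vector<std::vector<std::size_t>> children;
  for_each_depth_first(
    forest,
    [&children](std::size_t, toy_node const&, std::size_t, std::size_t parent_index) {
      std::size_t my_index = children.size();
      children.emplace_back();
      // A parent is always visited before its children, so its slot exists.
      assert(parent_index < children.size());
      if (parent_index != my_index) { children[parent_index].push_back(my_index); }
    });
  return children;
}
```

Because each node is reported together with its parent's position, a single pass suffices to recover the parent-child links without a second lookup into the source model.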
Usage
Use this framework as the base for implementing forest model traversal in cuML's Forest Inference Library (FIL). Concrete implementations (like treelite_traversal_forest) inherit from traversal_forest and implement get_node(). The for_each template method then enables generic algorithms to operate on any tree model regardless of its source format. The different traversal orders support different layout strategies for GPU-optimized inference.
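The extension pattern described here (a pure virtual node accessor plus a generic traversal method) can be sketched without the cuML headers. Everything in the block is hypothetical and simplified for illustration: mini_node, mini_forest, explicit_forest, complete_forest, and max_depth are invented names, the callback takes only (tree_id, node, depth), and every tree is assumed to be rooted at node 0. The point is that one generic algorithm runs unchanged over two different storage formats.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical minimal node: a leaf flag plus child ids (ignored for leaves).
struct mini_node {
  bool leaf;
  std::size_t left;
  std::size_t right;
};

// Hypothetical stand-in for an abstract forest; not the cuML class.
struct mini_forest {
  explicit mini_forest(std::size_t num_trees) : num_trees_{num_trees} {}
  virtual ~mini_forest() = default;

  // Extension point: each backend decides how nodes are stored or computed.
  virtual mini_node get_node(std::size_t tree_id, std::size_t node_id) const = 0;

  // Generic depth-first walk written only in terms of get_node().
  template <typename lambda_t>
  void for_each(lambda_t&& lambda) const {
    for (std::size_t tree_id = 0; tree_id < num_trees_; ++tree_id) {
      std::vector<std::pair<std::size_t, std::size_t>> pending;  // (node_id, depth)
      pending.emplace_back(0, 0);
      while (!pending.empty()) {
        auto [node_id, depth] = pending.back();
        pending.pop_back();
        mini_node node = get_node(tree_id, node_id);
        lambda(tree_id, node, depth);
        if (!node.leaf) {
          pending.emplace_back(node.right, depth + 1);
          pending.emplace_back(node.left, depth + 1);
        }
      }
    }
  }

 private:
  std::size_t num_trees_;
};

// Backend 1: nodes stored explicitly, one vector per tree.
struct explicit_forest : mini_forest {
  explicit explicit_forest(std::vector<std::vector<mini_node>> trees)
    : mini_forest{trees.size()}, trees_{std::move(trees)} {}
  mini_node get_node(std::size_t tree_id, std::size_t node_id) const override {
    return trees_.at(tree_id).at(node_id);
  }

 private:
  std::vector<std::vector<mini_node>> trees_;
};

// Backend 2: complete binary trees in heap layout; nodes are computed, not stored.
struct complete_forest : mini_forest {
  complete_forest(std::size_t num_trees, std::size_t depth)
    : mini_forest{num_trees}, depth_{depth} {}
  mini_node get_node(std::size_t, std::size_t node_id) const override {
    assert(depth_ < 8 * sizeof(std::size_t));
    std::size_t first_leaf = (std::size_t{1} << depth_) - 1;
    return {node_id >= first_leaf, 2 * node_id + 1, 2 * node_id + 2};
  }

 private:
  std::size_t depth_;
};

// A generic algorithm that works with any backend through the base interface.
std::size_t max_depth(mini_forest const& forest) {
  std::size_t result = 0;
  forest.for_each([&result](std::size_t, mini_node const&, std::size_t depth) {
    result = std::max(result, depth);
  });
  return result;
}
```

Here max_depth never learns how nodes are stored; swapping the backend requires no change to the algorithm, which mirrors the role get_node() plays for concrete forest implementations.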
Code Reference
Source Location
- Repository: Rapidsai_Cuml
- File: cpp/include/cuml/forest/traversal/traversal_forest.hpp
Signature
namespace ML {
namespace forest {
namespace detail {

template <forest_order order, typename T>
struct traversal_container {
  void add(T const& val);
  void add(T const& hot, T const& distant);
  auto next();
  auto peek();
  [[nodiscard]] auto empty();
  auto size();
};

}  // namespace detail

template <typename node_t = traversal_node<std::size_t>,
          typename tree_id_t = std::size_t>
struct traversal_forest {
  using node_type     = node_t;
  using node_id_type  = typename node_type::id_type;
  using tree_id_type  = tree_id_t;
  using node_uid_type = std::pair<tree_id_type, node_id_type>;
  using index_type    = std::size_t;

  virtual node_type get_node(tree_id_type tree_id,
                             node_id_type node_id) const = 0;

  traversal_forest(std::vector<node_uid_type>&& root_node_uids);

  template <forest_order order, typename lambda_t>
  void for_each(lambda_t&& lambda) const;
};

}  // namespace forest
}  // namespace ML
Import
#include <cuml/forest/traversal/traversal_forest.hpp>
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| root_node_uids | std::vector<node_uid_type>&& | Yes (constructor) | Vector of (tree_id, root_node_id) pairs for all trees in the forest |
| tree_id | tree_id_type | Yes (get_node) | Identifier of the tree to access |
| node_id | node_id_type | Yes (get_node) | Identifier of the node within the tree |
| lambda | lambda_t | Yes (for_each) | Callback receiving (tree_id, node, depth, parent_index) for each visited node |
| order | forest_order (template) | Yes (for_each) | Traversal order: depth_first, breadth_first, layered_children_segregated, or layered_children_together |
Outputs
| Name | Type | Description |
|---|---|---|
| get_node | node_type | The traversal node at the specified tree and node ID |
| for_each | void | Invokes the lambda for every node in the forest; no return value |
Usage Examples
#include <cstddef>
#include <utility>
#include <vector>

#include <cuml/forest/traversal/traversal_forest.hpp>
#include <cuml/forest/traversal/traversal_node.hpp>

// Define a concrete forest implementation
struct my_forest : ML::forest::traversal_forest<> {
  using base = ML::forest::traversal_forest<>;

  my_forest(std::vector<base::node_uid_type>&& roots)
    : base(std::move(roots)) {}

  base::node_type get_node(base::tree_id_type tree_id,
                           base::node_id_type node_id) const override {
    // Return the node from your internal tree structure
    // ...
  }
};

// Traverse depth-first
// The constructor takes an rvalue vector of (tree_id, root_node_id) pairs.
std::vector<my_forest::node_uid_type> roots{/* root nodes */};
my_forest forest(std::move(roots));
std::size_t node_count = 0;
forest.for_each<ML::forest::forest_order::depth_first>(
  [&node_count](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
    ++node_count;  // Count every node in the forest
  });

// Traverse breadth-first
forest.for_each<ML::forest::forest_order::breadth_first>(
  [](auto&& tree_id, auto&& node, auto&& depth, auto&& parent_index) {
    // Process nodes layer by layer
  });