Implementation:Online ml River Tree StochasticGradientTree
| Knowledge Sources | |
|---|---|
| Domains | Online_Learning, Decision_Trees, Classification, Regression, Gradient_Boosting |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
Stochastic Gradient Trees (SGT; Gouk, Pfahringer & Frank, 2019) are incremental decision trees that directly minimize a loss function, both to decide how the tree grows and to update its predictions. Traditional trees instead rely on impurity-based heuristics. River provides a binary classification variant (SGTClassifier) and a regression variant (SGTRegressor).
Description
SGT minimizes the loss directly: it accumulates the gradient and hessian of the loss and uses them to compute a regularized second-order (Newton-style) update of each leaf's prediction. Instead of information gain or variance reduction, SGT evaluates the same gradient and hessian statistics for every candidate split and uses F-tests to determine when a split is statistically significant. The tree grows by:
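1. Computing the gradient and hessian of the loss at the current prediction for each incoming instance
2. Aggregating these statistics at each leaf and for each candidate split point
3. Using an F-test to determine whether the expected loss reduction is statistically significant
4. Splitting when the p-value < delta and the mean change in loss is negative (loss_mean < 0)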
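Steps 1 and 2 can be illustrated with a minimal, stdlib-only sketch for a single leaf and a single numeric feature. It assumes squared-error loss (gradient = y_pred - y_true, hessian = 1) and fixed, hand-picked bin edges; `BinnedGradHess` and its methods are hypothetical names, not River's API, and River's quantizers choose their bins in their own way. Accumulating the per-bin sums from left to right yields the left/right child statistics for every candidate threshold in one pass.

```python
import bisect


class BinnedGradHess:
    """Hypothetical per-leaf, per-feature accumulator of gradient/hessian sums by bin."""

    def __init__(self, edges: list[float]):
        self.edges = sorted(edges)  # bin upper bounds; the last bin is open-ended
        n_bins = len(self.edges) + 1
        self.grad = [0.0] * n_bins
        self.hess = [0.0] * n_bins
        self.weight = [0.0] * n_bins

    def update(self, x_val: float, y_true: float, y_pred: float, w: float = 1.0) -> None:
        # Step 1: squared-error derivatives at the leaf's current prediction.
        g, h = y_pred - y_true, 1.0
        # Step 2: add them to the bin containing x_val (bin i holds edges[i-1] < x <= edges[i]).
        i = bisect.bisect_left(self.edges, x_val)
        self.grad[i] += w * g
        self.hess[i] += w * h
        self.weight[i] += w

    def candidate_splits(self):
        """Yield (threshold, left (G, H, W), right (G, H, W)) for 'x <= threshold' splits."""
        tot_g, tot_h, tot_w = sum(self.grad), sum(self.hess), sum(self.weight)
        lg = lh = lw = 0.0
        for i, threshold in enumerate(self.edges):
            lg += self.grad[i]
            lh += self.hess[i]
            lw += self.weight[i]
            yield threshold, (lg, lh, lw), (tot_g - lg, tot_h - lh, tot_w - lw)
```

Each yielded pair of sums is what a split criterion would then score; the sketch leaves the scoring and the F-test out.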
Key innovations:
- Direct loss minimization rather than impurity reduction
- Statistical significance testing via F-tests
- Regularization via lambda (prediction penalty) and gamma (split penalty)
- Feature quantization for efficient split evaluation
- Support for both static and dynamic quantizers
The loss functions used are:
- Classification: Binary cross-entropy with logistic transfer
- Regression: Squared error
Usage
Classification:
```python
from river import datasets
from river import evaluate
from river import metrics
from river import tree

dataset = datasets.Phishing()
model = tree.SGTClassifier(
    feature_quantizer=tree.splitter.StaticQuantizer(
        n_bins=32, warm_start=10
    )
)
metric = metrics.Accuracy()

# progressive_val_score returns the updated metric; print it when running as a script
print(evaluate.progressive_val_score(dataset, model, metric))
# Accuracy: 82.24%
```
Regression:
```python
from river import datasets
from river import evaluate
from river import metrics
from river import tree

dataset = datasets.TrumpApproval()
model = tree.SGTRegressor(
    delta=0.01,
    lambda_value=0.01,
    grace_period=20,
    feature_quantizer=tree.splitter.DynamicQuantizer(std_prop=0.1)
)
metric = metrics.MAE()

# progressive_val_score returns the updated metric; print it when running as a script
print(evaluate.progressive_val_score(dataset, model, metric))
# MAE: 1.721818
```
Code Reference
Source Location:
river/tree/stochastic_gradient_tree.py
Signature:
```python
class SGTClassifier(StochasticGradientTree, base.Classifier):
    def __init__(
        self,
        delta: float = 1e-7,
        grace_period: int = 200,
        init_pred: float = 0.0,
        max_depth: int | None = None,
        lambda_value: float = 0.1,
        gamma: float = 1.0,
        nominal_attributes: list | None = None,
        feature_quantizer: tree.splitter.Quantizer | None = None,
    ): ...


class SGTRegressor(StochasticGradientTree, base.Regressor):
    def __init__(
        self,
        delta: float = 1e-7,
        grace_period: int = 200,
        init_pred: float = 0.0,
        max_depth: int | None = None,
        lambda_value: float = 0.1,
        gamma: float = 1.0,
        nominal_attributes: list | None = None,
        feature_quantizer: tree.splitter.Quantizer | None = None,
    ): ...
```
Import:
```python
from river.tree import SGTClassifier, SGTRegressor
```
I/O Contract
Input:
- x (dict): Feature dictionary with attribute names as keys
- y (float/int/bool): Target value (continuous for regression; a binary label, bool or 0/1, for classification)
- w (float, optional): Sample weight (default: 1.0)
Output:
- SGTClassifier.predict_proba_one(x): Dict {True: p, False: 1-p}, where p is the sigmoid of the leaf's raw prediction
- SGTRegressor.predict_one(x): Predicted value (float)
Key Parameters
- delta (float): Significance level for F-tests (lower = more conservative splits)
- grace_period (int): Interval between split attempts
- init_pred (float): Initial prediction value at root
- max_depth (int): Maximum tree depth (None = unlimited)
- lambda_value (float): L2 regularization on predictions (≥0)
- gamma (float): Penalty on splits to discourage excessive splitting (≥0)
- nominal_attributes (list): List of categorical feature names
- feature_quantizer (Quantizer): Algorithm for discretizing numeric features (None = a StaticQuantizer with default settings)
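A simplified sketch of how `grace_period` and `max_depth` might gate split attempts at a leaf. `LeafCounter` and `next_action` are hypothetical; the per-leaf weight counter and the behavior at the depth limit (only the leaf's own prediction may still change) are assumptions of this sketch, not a description of River's internals.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LeafCounter:
    """Hypothetical per-leaf bookkeeping used only to illustrate the gating."""
    depth: int
    total_weight: float = 0.0
    last_attempt_at: float = 0.0


def next_action(leaf: LeafCounter, grace_period: int, max_depth: Optional[int], w: float = 1.0) -> str:
    """Return what the leaf would do after receiving one more (weighted) observation."""
    leaf.total_weight += w
    # Wait until another grace_period worth of weight has arrived since the last attempt.
    if leaf.total_weight - leaf.last_attempt_at < grace_period:
        return "wait"
    leaf.last_attempt_at = leaf.total_weight
    # Assumed behavior at the depth limit: no new children, only a prediction update.
    if max_depth is not None and leaf.depth >= max_depth:
        return "update_prediction_only"
    # Otherwise evaluate candidate splits; one is applied only if p-value < delta
    # and the mean change in loss is negative.
    return "evaluate_splits"
```

With `grace_period=200`, a leaf receiving 1,000 unit-weight observations attempts a split five times in this sketch.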
Implementation Details
Key Methods:
- learn_one(x, y, w=1.0): Update tree with one instance
- _compute_p_value(merit, n_observations): Compute F-test p-value
- _target_transform(y): Transform target (identity for regression, float for classification)
Node Types:
- SGTLeaf: Leaf node storing gradient/hessian statistics
- NumericBinaryBranch: Binary split on numeric feature
- NominalMultiwayBranch: Multiway split on categorical feature
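Routing an instance to a leaf through these node types can be sketched with minimal, hypothetical classes (the `*Sketch` names and their fields are illustrative, not River's internal attributes). The sketch assumes numeric branches send `x[feature] <= threshold` to the left child and nominal branches pick the child keyed by the category value; missing features and unseen categories are not handled.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class SGTLeafSketch:
    prediction: float  # raw (pre-transfer) leaf value


@dataclass
class NumericBinaryBranchSketch:
    feature: str
    threshold: float
    left: "Node"   # receives x[feature] <= threshold
    right: "Node"  # receives x[feature] > threshold


@dataclass
class NominalMultiwayBranchSketch:
    feature: str
    children: dict  # category value -> child node


Node = Union[SGTLeafSketch, NumericBinaryBranchSketch, NominalMultiwayBranchSketch]


def traverse(node: Node, x: dict) -> SGTLeafSketch:
    """Follow branches until a leaf is reached."""
    while not isinstance(node, SGTLeafSketch):
        if isinstance(node, NumericBinaryBranchSketch):
            node = node.left if x[node.feature] <= node.threshold else node.right
        else:
            node = node.children[x[node.feature]]
    return node
```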
Gradient and Hessian Tracking:
Gradient and hessian information is represented by three helper structures:
- GradHess: Pair of gradient and hessian values
- GradHessStats: Aggregated statistics including mean, variance, covariance
- GradHessMerit: Split merit based on expected loss reduction
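The aggregated statistics can be maintained in a single pass. The class below is a hypothetical stand-in for `GradHessStats` (not River's code): it uses weighted Welford-style updates for the gradient and hessian means, their variances and their covariance, normalizing by total weight (population form, an assumption of this sketch). Its `delta_loss_mean_var` method shows why all three moments are needed: for a fixed update d, the variance of d·g + 0.5·h·d² is d²·Var(g) + 0.25·d⁴·Var(h) + d³·Cov(g, h).

```python
class GradHessStatsSketch:
    """Hypothetical one-pass accumulator of weighted gradient/hessian moments."""

    def __init__(self):
        self.total_weight = 0.0
        self.mean_g = 0.0
        self.mean_h = 0.0
        self._m2_g = 0.0  # weighted sum of squared gradient deviations
        self._m2_h = 0.0  # weighted sum of squared hessian deviations
        self._c_gh = 0.0  # weighted gradient/hessian co-moment

    def update(self, g: float, h: float, w: float = 1.0) -> None:
        self.total_weight += w
        dg = g - self.mean_g  # deviations from the *old* means
        dh = h - self.mean_h
        self.mean_g += w * dg / self.total_weight
        self.mean_h += w * dh / self.total_weight
        self._m2_g += w * dg * (g - self.mean_g)
        self._m2_h += w * dh * (h - self.mean_h)
        self._c_gh += w * dg * (h - self.mean_h)  # old g deviation x new h deviation

    @property
    def variance(self) -> tuple[float, float, float]:
        """(var_g, var_h, cov_gh), normalized by total weight."""
        if self.total_weight == 0:
            return 0.0, 0.0, 0.0
        W = self.total_weight
        return self._m2_g / W, self._m2_h / W, self._c_gh / W

    def delta_loss_mean_var(self, delta_pred: float) -> tuple[float, float]:
        """Mean and variance of d*g + 0.5*h*d^2 implied by the stored moments."""
        var_g, var_h, cov_gh = self.variance
        d = delta_pred
        mean = d * self.mean_g + 0.5 * self.mean_h * d * d
        var = d * d * var_g + 0.25 * d**4 * var_h + d**3 * cov_gh
        return mean, var
```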
Split Evaluation:
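For each candidate split:
1. Compute each child's prediction update: delta_pred = -G / (H + lambda), where G and H are the child's gradient and hessian statistics
2. Estimate the change in loss with a second-order Taylor approximation: delta_loss = delta_pred * G + 0.5 * H * delta_pred²
3. Calculate loss_mean and loss_var from the aggregated statistics
4. Compute the F-statistic: F = n * loss_mean² / loss_var
5. Compute the p-value from the F-distribution with (1, n-1) degrees of freedom
6. Split if p < delta AND loss_mean < 0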
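The split-evaluation steps can be sketched for one candidate binary split. This stdlib-only `split_merit` function is illustrative, not River's implementation: it assumes G and H are each child's mean gradient and hessian, computes every instance's Taylor-approximated loss change under its child's update, and uses the sample variance of those changes as loss_var. The gamma split penalty is deliberately left out, and converting F into a p-value is a separate step.

```python
import statistics


def split_merit(grads, hess, goes_left, lambda_value=0.1):
    """Return (loss_mean, loss_var, F) for one candidate binary split.

    grads, hess: per-instance gradient and hessian at the leaf's current prediction.
    goes_left:   per-instance booleans (True = left child); both children must be non-empty.
    """
    # Step 1: regularized Newton step per child, from the child's mean G and H.
    delta_pred = {}
    for side in (True, False):
        g_side = [g for g, s in zip(grads, goes_left) if s == side]
        h_side = [h for h, s in zip(hess, goes_left) if s == side]
        delta_pred[side] = -statistics.mean(g_side) / (statistics.mean(h_side) + lambda_value)

    # Step 2: second-order Taylor estimate of each instance's change in loss.
    losses = [
        delta_pred[s] * g + 0.5 * h * delta_pred[s] ** 2
        for g, h, s in zip(grads, hess, goes_left)
    ]

    # Steps 3-4: sample mean and variance of the loss change, then F = n * mean^2 / var.
    n = len(losses)
    loss_mean = statistics.mean(losses)
    loss_var = statistics.variance(losses)
    f_stat = n * loss_mean**2 / loss_var if loss_var > 0 else 0.0
    return loss_mean, loss_var, f_stat
```

With squared-error gradients from two well-separated groups of targets, loss_mean comes out negative, which is the precondition for a split in step 6.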
Feature Quantization:
StaticQuantizer (original implementation):
- Uses a fixed number of bins (n_bins)
- Buffers the first warm_start observations (warm-up period) before the bins are defined
- Bins defined by equal-frequency quantiles
DynamicQuantizer (enhancement):
- Incrementally adjusts bin boundaries
- Based on Quantization Observer (QO)
- More adaptive to data distribution
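Hash-based binning in the spirit of QO can be sketched as follows, assuming bins are keyed by floor(x / radius), each bin keeps the running mean of the x values it has received, and candidate thresholds lie midway between the x-means of neighboring occupied bins. `HashQuantizerSketch` is a hypothetical name; how River's DynamicQuantizer derives its radius (for example from `std_prop`) is not reproduced here.

```python
import math


class HashQuantizerSketch:
    """Hypothetical hash-based binning in the spirit of QO (not River's DynamicQuantizer)."""

    def __init__(self, radius: float):
        self.radius = radius
        self.bins = {}  # key -> [weight, mean_x, grad_sum, hess_sum]

    def update(self, x_val: float, g: float, h: float, w: float = 1.0) -> None:
        key = math.floor(x_val / self.radius)  # bins are created only where data arrives
        b = self.bins.setdefault(key, [0.0, 0.0, 0.0, 0.0])
        b[0] += w
        b[1] += w * (x_val - b[1]) / b[0]  # running weighted mean of x in this bin
        b[2] += w * g
        b[3] += w * h

    def candidate_thresholds(self) -> list[float]:
        """Midpoints between the x-means of consecutive occupied bins."""
        means = [self.bins[k][1] for k in sorted(self.bins)]
        return [(lo + hi) / 2.0 for lo, hi in zip(means, means[1:])]
```

Because thresholds follow the observed x-means rather than fixed edges, the candidate split points move as the data distribution shifts.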
Loss Functions
BinaryCrossEntropyLoss (classification):
- gradient = sigmoid(y_pred) - y_true
- hessian = sigmoid(y_pred) * (1 - sigmoid(y_pred))
- transfer = sigmoid(y) = 1/(1 + exp(-y))
SquaredErrorLoss (regression):
- gradient = y_pred - y_true
- hessian = 1.0
- transfer = identity(y) = y
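The derivative formulas above can be written as two small loss objects. The class names are hypothetical and only mirror the documented formulas; the sigmoid is evaluated in a numerically stable form that is algebraically identical to 1/(1 + exp(-y)). The helper at the end maps a raw leaf value to the `{True: p, False: 1-p}` format returned by `SGTClassifier.predict_proba_one`.

```python
import math


class SquaredErrorLossSketch:
    """Squared error: gradient = y_pred - y_true, hessian = 1, identity transfer."""

    def compute_derivatives(self, y_true: float, y_pred: float) -> tuple[float, float]:
        return y_pred - y_true, 1.0

    def transfer(self, y: float) -> float:
        return y


class BinaryCrossEntropyLossSketch:
    """Binary cross-entropy on the raw (logit) prediction, with sigmoid transfer."""

    def compute_derivatives(self, y_true: float, y_pred: float) -> tuple[float, float]:
        p = self.transfer(y_pred)
        return p - y_true, p * (1.0 - p)

    def transfer(self, y: float) -> float:
        # Stable form of 1 / (1 + exp(-y)); avoids OverflowError for large negative y.
        if y >= 0:
            return 1.0 / (1.0 + math.exp(-y))
        e = math.exp(y)
        return e / (1.0 + e)


def predict_proba_from_raw(raw: float) -> dict:
    """Map a raw leaf value to the documented {True: p, False: 1 - p} output."""
    p = BinaryCrossEntropyLossSketch().transfer(raw)
    return {True: p, False: 1.0 - p}
```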
Statistical Test
The F-test evaluates:
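- Null hypothesis: the expected change in loss (loss_mean) is zero
- Alternative: the expected change in loss is non-zero
- Test statistic: F = n * (loss_mean)² / loss_var
- Degrees of freedom: (1, n-1)
- Reject the null hypothesis if p-value < delta; the split is applied only if, in addition, loss_mean < 0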
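The p-value can be computed with the standard library alone, using the identity P(F > f) = I_{d2/(d2 + d1·f)}(d2/2, d1/2) for an F(d1, d2) variable, where I is the regularized incomplete beta function. The sketch evaluates I with the classic continued fraction (modified Lentz's method, as in Numerical Recipes); the function names are hypothetical. Where SciPy is available, `scipy.stats.f.sf(f, 1, n - 1)` computes the same upper tail.

```python
import math


def _nonzero(v: float, tiny: float = 1e-300) -> float:
    # Guard against division by (almost) zero, as in Lentz's method.
    return v if abs(v) > tiny else tiny


def _beta_cf(a: float, b: float, x: float, max_iter: int = 300, eps: float = 1e-14) -> float:
    """Continued fraction for the incomplete beta function (modified Lentz's method)."""
    qab, qap, qam = a + b, a + 1.0, a - 1.0
    c = 1.0
    d = 1.0 / _nonzero(1.0 - qab * x / qap)
    h = d
    for m in range(1, max_iter + 1):
        m2 = 2 * m
        # Even term of the continued fraction.
        aa = m * (b - m) * x / ((qam + m2) * (a + m2))
        d = 1.0 / _nonzero(1.0 + aa * d)
        c = _nonzero(1.0 + aa / c)
        h *= d * c
        # Odd term.
        aa = -(a + m) * (qab + m) * x / ((a + m2) * (qap + m2))
        d = 1.0 / _nonzero(1.0 + aa * d)
        c = _nonzero(1.0 + aa / c)
        step = d * c
        h *= step
        if abs(step - 1.0) < eps:
            break
    return h


def reg_incomplete_beta(a: float, b: float, x: float) -> float:
    """Regularized incomplete beta function I_x(a, b) for a, b > 0."""
    if x <= 0.0:
        return 0.0
    if x >= 1.0:
        return 1.0
    log_front = (math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
                 + a * math.log(x) + b * math.log1p(-x))
    front = math.exp(log_front)
    # Use I_x(a, b) = 1 - I_{1-x}(b, a) on the side where the fraction converges fast.
    if x < (a + 1.0) / (a + b + 2.0):
        return front * _beta_cf(a, b, x) / a
    return 1.0 - front * _beta_cf(b, a, 1.0 - x) / b


def f_test_p_value(f_value: float, n: float) -> float:
    """Upper-tail p-value of an F statistic with (1, n - 1) degrees of freedom."""
    if f_value <= 0.0 or n <= 1:
        return 1.0
    d1, d2 = 1.0, n - 1.0
    return reg_incomplete_beta(d2 / 2.0, d1 / 2.0, d2 / (d2 + d1 * f_value))
```

For example, with n = 1,000 observations an F-statistic of 50 yields a p-value far below the default delta = 1e-7, so the significance part of the test would pass.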
Properties
- n_splits: Number of splits performed
- n_node_updates: Number of node updates
- n_observations: Total samples processed
- height: Tree height
- n_nodes: Total node count
- n_branches: Branch node count
- n_leaves: Leaf node count
Related Pages
- Online_ml_River_Tree_HoeffdingTreeClassifier
- Online_ml_River_Tree_HoeffdingTreeRegressor
- Online_ml_River_Tree_SGT_Losses
- Online_ml_River_Tree_Utils
- Online_ml_River_Tree_Splitter
References
Gouk, H., Pfahringer, B., & Frank, E. (2019, October). "Stochastic Gradient Trees." In Asian Conference on Machine Learning (pp. 1094-1109).