
Implementation:Rapidsai Cuml Tree SHAP



Domains: Machine_Learning, Explainability
Last Updated: 2026-02-08 12:00 GMT

Overview

Provides GPU-accelerated TreeSHAP computations for explaining the predictions of tree-based models, supporting standard, interventional, interaction, and Shapley-Taylor interaction SHAP methods.

Description

This header defines the API for GPU TreeSHAP, which computes SHAP (SHapley Additive exPlanations) values for tree ensemble models efficiently on the GPU. It consists of the following components:

  • TreePathInfo<T>: A forward-declared template class that stores the internal path information extracted from a Treelite model, parameterized on float or double.
  • TreePathHandle: A std::variant holding either std::shared_ptr<TreePathInfo<float>> or std::shared_ptr<TreePathInfo<double>>, providing type-erased access to the path information.
  • FloatPointer: A std::variant<float*, double*> for type-erased data pointers.
  • extract_path_info: Extracts the tree path information from a Treelite model handle, returning a TreePathHandle.
  • gpu_treeshap: Computes standard (path-dependent) TreeSHAP values, which require no background dataset.
  • gpu_treeshap_interventional: Computes interventional TreeSHAP values using a background dataset.
  • gpu_treeshap_interactions: Computes pairwise SHAP interaction values.
  • gpu_treeshap_taylor_interactions: Computes Shapley-Taylor interaction values, an alternative pairwise interaction index.
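Because TreePathHandle and FloatPointer are std::variant types, code that consumes them must recover the concrete precision, typically with std::visit over both variants at once. The following is a standard-library-only sketch of that pattern. MockTreePathInfo, MockPathHandle, dispatch, and the precision-mismatch check are hypothetical illustrations, not cuML code; how cuML itself handles a float model with double data (or vice versa) is not described on this page.

```cpp
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <variant>

// Hypothetical stand-in for TreePathInfo<T>, which the real header only
// forward-declares; the `bias` member exists purely for this illustration.
template <typename T>
struct MockTreePathInfo {
  T bias;
};

// Same shape as TreePathHandle and FloatPointer in the header.
using MockPathHandle = std::variant<std::shared_ptr<MockTreePathInfo<float>>,
                                    std::shared_ptr<MockTreePathInfo<double>>>;
using FloatPointer = std::variant<float*, double*>;

// Hypothetical dispatcher: std::visit recovers the concrete alternative of both
// variants at once; mismatched precisions are rejected (an assumed policy).
// Returns sizeof the element type that was dispatched on.
std::size_t dispatch(const MockPathHandle& handle, FloatPointer data) {
  return std::visit(
      [](const auto& info, auto* ptr) -> std::size_t {
        using InfoT = std::decay_t<decltype(info->bias)>;
        using DataT = std::remove_pointer_t<decltype(ptr)>;
        if constexpr (!std::is_same_v<InfoT, DataT>) {
          throw std::invalid_argument("data precision does not match path info");
        } else {
          // A real implementation would launch a kernel templated on InfoT here.
          return sizeof(InfoT);
        }
      },
      handle, data);
}
```

Visiting both variants in one call lets a single generic lambda handle all four float/double combinations, with `if constexpr` pruning the invalid ones at compile time.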

Usage

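Use these functions to explain predictions from tree ensemble models (e.g., XGBoost, LightGBM, or random forests) that have been loaded as Treelite models. First call extract_path_info to prepare the tree structure, then call the appropriate gpu_treeshap* variant to compute SHAP values on the GPU. The feature matrix, any background dataset, and the output buffer are all passed as device pointers, and the caller allocates the output buffer.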

Code Reference

Source Location

  • Repository: Rapidsai_Cuml
  • File: cpp/include/cuml/explainer/tree_shap.hpp

Signature

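#include <cstddef>
#include <memory>
#include <variant>

// TreeliteModelHandle is Treelite's opaque model handle type; its declaration
// comes from headers included by tree_shap.hpp and is not reproduced here.

namespace ML {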
namespace Explainer {

template <typename T>
class TreePathInfo;

using TreePathHandle =
  std::variant<std::shared_ptr<TreePathInfo<float>>,
               std::shared_ptr<TreePathInfo<double>>>;

using FloatPointer = std::variant<float*, double*>;

TreePathHandle extract_path_info(TreeliteModelHandle model);

void gpu_treeshap(TreePathHandle path_info,
                  const FloatPointer data,
                  std::size_t n_rows,
                  std::size_t n_cols,
                  FloatPointer out_preds,
                  std::size_t out_preds_size);

void gpu_treeshap_interventional(TreePathHandle path_info,
                                 const FloatPointer data,
                                 std::size_t n_rows,
                                 std::size_t n_cols,
                                 const FloatPointer background_data,
                                 std::size_t background_n_rows,
                                 std::size_t background_n_cols,
                                 FloatPointer out_preds,
                                 std::size_t out_preds_size);

void gpu_treeshap_interactions(TreePathHandle path_info,
                               const FloatPointer data,
                               std::size_t n_rows,
                               std::size_t n_cols,
                               FloatPointer out_preds,
                               std::size_t out_preds_size);

void gpu_treeshap_taylor_interactions(TreePathHandle path_info,
                                      const FloatPointer data,
                                      std::size_t n_rows,
                                      std::size_t n_cols,
                                      FloatPointer out_preds,
                                      std::size_t out_preds_size);

}  // namespace Explainer
}  // namespace ML

Import

#include <cuml/explainer/tree_shap.hpp>

I/O Contract

Inputs (extract_path_info)

Name | Type | Required | Description
model | TreeliteModelHandle | Yes | Handle to a loaded Treelite model

Inputs (gpu_treeshap, gpu_treeshap_interactions, gpu_treeshap_taylor_interactions)

Name | Type | Required | Description
path_info | TreePathHandle | Yes | Tree path information returned by extract_path_info
data | const FloatPointer | Yes | Device pointer to the input feature matrix [n_rows x n_cols]
n_rows | std::size_t | Yes | Number of rows in the data
n_cols | std::size_t | Yes | Number of columns (features) in the data
out_preds_size | std::size_t | Yes | Number of elements in the out_preds buffer

Additional inputs (gpu_treeshap_interventional)

Name | Type | Required | Description
background_data | const FloatPointer | Yes | Device pointer to the background dataset [background_n_rows x background_n_cols]
background_n_rows | std::size_t | Yes | Number of rows in the background dataset
background_n_cols | std::size_t | Yes | Number of columns in the background dataset
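Interventional TreeSHAP defines the value of a feature coalition S as the model output averaged over the background rows, where features in S take their values from the explained row and all other features take them from the background row. The brute-force sketch below computes exact Shapley values for that value function by enumerating all 2^n coalitions, purely to illustrate what the interventional variant computes. The toy tree and the helper names (toy_tree, coalition_value, interventional_shap) are hypothetical, and the GPU implementation presumably uses a much more efficient tree-specific algorithm rather than this exponential enumeration.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical depth-2 toy tree over three features (not a Treelite model).
double toy_tree(const std::vector<double>& x) {
  if (x[0] < 0.5) {
    return x[1] < 0.5 ? 1.0 : 2.0;
  }
  return x[2] < 0.5 ? 3.0 : 5.0;
}

// Value of coalition `mask`: features whose bit is set come from the explained
// row `x`, all others from a background row; the result is averaged over rows.
double coalition_value(const std::vector<double>& x,
                       const std::vector<std::vector<double>>& background,
                       unsigned mask) {
  double sum = 0.0;
  for (const auto& b : background) {
    std::vector<double> z = b;
    for (std::size_t i = 0; i < x.size(); ++i) {
      if (mask & (1u << i)) z[i] = x[i];
    }
    sum += toy_tree(z);
  }
  return sum / static_cast<double>(background.size());
}

double factorial(std::size_t k) {
  double r = 1.0;
  for (std::size_t i = 2; i <= k; ++i) r *= static_cast<double>(i);
  return r;
}

// Exact interventional Shapley values by enumerating all 2^n coalitions.
// Exponential in n: only suitable for illustrating the definition.
std::vector<double> interventional_shap(
    const std::vector<double>& x,
    const std::vector<std::vector<double>>& background) {
  const std::size_t n = x.size();
  std::vector<double> phi(n, 0.0);
  for (std::size_t i = 0; i < n; ++i) {
    for (unsigned mask = 0; mask < (1u << n); ++mask) {
      if (mask & (1u << i)) continue;  // coalitions that exclude feature i
      std::size_t s = 0;               // coalition size |S|
      for (std::size_t j = 0; j < n; ++j) s += (mask >> j) & 1u;
      const double weight = factorial(s) * factorial(n - s - 1) / factorial(n);
      phi[i] += weight * (coalition_value(x, background, mask | (1u << i)) -
                          coalition_value(x, background, mask));
    }
  }
  return phi;
}
```

By the efficiency property of Shapley values, the returned attributions sum to the model output on x minus the average output over the background rows.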

Outputs

Name | Type | Description
Return value of extract_path_info | TreePathHandle | Extracted tree path information for SHAP computation
out_preds | FloatPointer | Caller-allocated device buffer of out_preds_size elements that receives the SHAP values
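This page does not spell out how out_preds_size relates to the data shape. The helpers below sketch buffer sizing and indexing under the layout used by the GPUTreeShap library, which this header is assumed to build on: values are stored per row and per output group (n_groups, e.g. the number of classes of a multiclass model), each block holding one value per feature followed by a bias term, while the interaction variants store an (n_cols + 1) x (n_cols + 1) block instead. The helper names and the n_groups parameter are hypothetical; verify the layout against the implementation before relying on it.

```cpp
#include <cstddef>

// Output elements for gpu_treeshap / gpu_treeshap_interventional, assuming a
// [row][group][feature, then bias] layout.
std::size_t shap_output_size(std::size_t n_rows, std::size_t n_cols,
                             std::size_t n_groups) {
  return n_rows * n_groups * (n_cols + 1);
}

// Output elements for the two interaction variants, assuming a
// [row][group][feature or bias][feature or bias] layout.
std::size_t interaction_output_size(std::size_t n_rows, std::size_t n_cols,
                                    std::size_t n_groups) {
  return n_rows * n_groups * (n_cols + 1) * (n_cols + 1);
}

// Flat index of the value for (row, group, feature) in the non-interaction
// layout; feature == n_cols addresses the assumed bias slot.
std::size_t shap_index(std::size_t row, std::size_t group, std::size_t feature,
                       std::size_t n_cols, std::size_t n_groups) {
  return (row * n_groups + group) * (n_cols + 1) + feature;
}
```

With n_groups = 1, shap_output_size reduces to the n_rows * (n_cols + 1) used in the usage example.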

Usage Examples

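#include <cuml/explainer/tree_shap.hpp>

#include <cuda_runtime.h>

#include <cstddef>
#include <stdexcept>
#include <vector>

// `data` must be a device pointer to an n_rows x n_cols feature matrix
// (row-major layout assumed). Returns the SHAP values copied to host memory.
std::vector<float> explain_tree_model(TreeliteModelHandle model,
                                      float* data,
                                      std::size_t n_rows,
                                      std::size_t n_cols) {
    // Step 1: Extract path information from the Treelite model
    auto path_info = ML::Explainer::extract_path_info(model);

    // Step 2: Allocate a device buffer for the SHAP values.
    // For a single-output model (e.g., regression): n_rows * (n_cols + 1) values,
    // one per feature plus a bias term for each row.
    std::size_t out_size = n_rows * (n_cols + 1);
    float* out_preds = nullptr;
    if (cudaMalloc(&out_preds, out_size * sizeof(float)) != cudaSuccess) {
        throw std::runtime_error("cudaMalloc failed");
    }

    // Step 3: Compute standard TreeSHAP values on the GPU
    ML::Explainer::gpu_treeshap(
        path_info,
        ML::Explainer::FloatPointer{data},
        n_rows,
        n_cols,
        ML::Explainer::FloatPointer{out_preds},
        out_size);

    // Step 4: Wait for the GPU, then copy the results to the host before freeing
    std::vector<float> shap_values(out_size);
    cudaDeviceSynchronize();
    cudaMemcpy(shap_values.data(), out_preds, out_size * sizeof(float),
               cudaMemcpyDeviceToHost);
    cudaFree(out_preds);
    return shap_values;
}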
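Once the SHAP values are on the host, a quick sanity check is SHAP's local accuracy property: for each row, the feature attributions plus the bias term should add up to the model's raw prediction. The helper below is a hypothetical sketch for a single-output model; it assumes the n_rows x (n_cols + 1) layout with the bias stored last in each row, and that `margin` holds untransformed model outputs (for example, values before a sigmoid).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: returns true if every row of `shap` (laid out as
// n_rows x (n_cols + 1), bias assumed last) sums to the matching entry of
// `margin` within `tol`, and false on any size mismatch.
bool check_local_accuracy(const std::vector<float>& shap,
                          const std::vector<float>& margin,
                          std::size_t n_cols,
                          float tol) {
  const std::size_t stride = n_cols + 1;
  if (shap.size() != margin.size() * stride) return false;
  for (std::size_t r = 0; r < margin.size(); ++r) {
    float sum = 0.0f;
    for (std::size_t f = 0; f < stride; ++f) sum += shap[r * stride + f];
    const float diff = sum - margin[r];
    if (diff > tol || diff < -tol) return false;
  }
  return true;
}
```

Float accumulation error grows with n_cols, so the tolerance should be loosened for wide feature matrices.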
