Overview
Provides GPU-accelerated TreeSHAP computations for explaining tree-based model predictions, supporting standard, interventional, interaction, and Taylor interaction SHAP methods.
Description
This header defines the API for GPU TreeSHAP, which computes SHAP (SHapley Additive exPlanations) values for tree ensemble models efficiently on the GPU. It consists of the following components:
- TreePathInfo<T>: A forward-declared class template that stores the path information extracted from a Treelite model, parameterized on float or double.
- TreePathHandle: A std::variant holding either std::shared_ptr<TreePathInfo<float>> or std::shared_ptr<TreePathInfo<double>>, giving type-erased access to the path information.
- FloatPointer: A std::variant<float*, double*> used for type-erased data pointers.
- extract_path_info: Extracts the tree path information from a Treelite model handle and returns it as a TreePathHandle.
- gpu_treeshap: Computes standard (path-dependent) TreeSHAP values.
- gpu_treeshap_interventional: Computes interventional TreeSHAP values against a background dataset.
- gpu_treeshap_interactions: Computes pairwise SHAP interaction values.
- gpu_treeshap_taylor_interactions: Computes Shapley-Taylor interaction values.
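TreePathHandle and FloatPointer are ordinary std::variant types, so code that receives them can recover the concrete precision with std::visit. The sketch below illustrates that pattern. It is not cuML's internal dispatch: TreePathInfoStub, PathHandle, FloatPtr, path_precision, and precisions_match are hypothetical names, and the stub stands in for TreePathInfo<T> because the real class is only forward-declared in this header.

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <type_traits>
#include <variant>

// Hypothetical stand-in for ML::Explainer::TreePathInfo<T>; the real class
// is only forward-declared in tree_shap.hpp, so its members are unknown here.
template <typename T>
struct TreePathInfoStub {
  std::size_t num_paths = 0;
};

// Mirrors the shape of TreePathHandle and FloatPointer from the header.
using PathHandle = std::variant<std::shared_ptr<TreePathInfoStub<float>>,
                                std::shared_ptr<TreePathInfoStub<double>>>;
using FloatPtr = std::variant<float*, double*>;

// Reports which precision a path handle currently holds.
std::string path_precision(const PathHandle& handle) {
  return std::visit(
      [](const auto& info) -> std::string {
        using Info = typename std::decay_t<decltype(info)>::element_type;
        return std::is_same_v<Info, TreePathInfoStub<float>> ? "float" : "double";
      },
      handle);
}

// Visits both variants at once and reports whether the data pointer's element
// type matches the precision of the path information.
bool precisions_match(const PathHandle& handle, FloatPtr data) {
  return std::visit(
      [](const auto& info, auto* ptr) {
        using Info = typename std::decay_t<decltype(info)>::element_type;
        using Elem = std::remove_pointer_t<decltype(ptr)>;
        return std::is_same_v<Info, TreePathInfoStub<Elem>>;
      },
      handle, data);
}
```

Visiting two variants in a single std::visit call lets one lambda see every (path precision, data precision) combination. Whether cuML itself requires the two precisions to match is not stated in this header.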
Usage
Use these functions to explain predictions from tree ensemble models (e.g., XGBoost, LightGBM, Random Forest) that have been loaded as Treelite models. First call extract_path_info to prepare the tree structure, then call the appropriate gpu_treeshap* variant to compute SHAP values on the GPU. The input data, the background data, and the output buffer are all passed as device pointers.
Code Reference
Source Location
- Repository: Rapidsai_Cuml
- File: cpp/include/cuml/explainer/tree_shap.hpp
Signature
namespace ML {
namespace Explainer {

template <typename T>
class TreePathInfo;

using TreePathHandle =
  std::variant<std::shared_ptr<TreePathInfo<float>>,
               std::shared_ptr<TreePathInfo<double>>>;

using FloatPointer = std::variant<float*, double*>;

TreePathHandle extract_path_info(TreeliteModelHandle model);

void gpu_treeshap(TreePathHandle path_info,
                  const FloatPointer data,
                  std::size_t n_rows,
                  std::size_t n_cols,
                  FloatPointer out_preds,
                  std::size_t out_preds_size);

void gpu_treeshap_interventional(TreePathHandle path_info,
                                 const FloatPointer data,
                                 std::size_t n_rows,
                                 std::size_t n_cols,
                                 const FloatPointer background_data,
                                 std::size_t background_n_rows,
                                 std::size_t background_n_cols,
                                 FloatPointer out_preds,
                                 std::size_t out_preds_size);

void gpu_treeshap_interactions(TreePathHandle path_info,
                               const FloatPointer data,
                               std::size_t n_rows,
                               std::size_t n_cols,
                               FloatPointer out_preds,
                               std::size_t out_preds_size);

void gpu_treeshap_taylor_interactions(TreePathHandle path_info,
                                      const FloatPointer data,
                                      std::size_t n_rows,
                                      std::size_t n_cols,
                                      FloatPointer out_preds,
                                      std::size_t out_preds_size);

}  // namespace Explainer
}  // namespace ML
Import
#include <cuml/explainer/tree_shap.hpp>
I/O Contract
Inputs (extract_path_info)
| Name | Type | Required | Description |
|---|---|---|---|
| model | TreeliteModelHandle | Yes | Handle to a loaded Treelite model |
Inputs (gpu_treeshap, gpu_treeshap_interactions, gpu_treeshap_taylor_interactions)
| Name | Type | Required | Description |
|---|---|---|---|
| path_info | TreePathHandle | Yes | Tree path information returned by extract_path_info |
| data | const FloatPointer | Yes | Device pointer to the input feature matrix [n_rows x n_cols] |
| n_rows | std::size_t | Yes | Number of rows (samples) in data |
| n_cols | std::size_t | Yes | Number of columns (features) in data |
| out_preds_size | std::size_t | Yes | Length of the out_preds buffer, in elements |
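The out_preds_size argument has to agree with the buffer the caller allocates. The sketch below computes expected element counts under two assumptions: the per-feature methods write n_cols + 1 values per row and model output (the extra slot holding a bias term, as the usage example's n_rows * (n_cols + 1) suggests), and the interaction methods write an (n_cols + 1) x (n_cols + 1) block per row and output, following the GPUTreeShap library's convention. ShapMethod and expected_out_size are hypothetical helpers, not part of the cuML API.

```cpp
#include <cstddef>
#include <stdexcept>

// Hypothetical enum naming the four gpu_treeshap* variants.
enum class ShapMethod { Standard, Interventional, Interactions, TaylorInteractions };

// Expected number of elements in out_preds (assumed layout, see lead-in):
// per-feature methods: n_rows * n_outputs * (n_cols + 1)
// interaction methods: n_rows * n_outputs * (n_cols + 1)^2
std::size_t expected_out_size(ShapMethod method, std::size_t n_rows,
                              std::size_t n_cols, std::size_t n_outputs = 1) {
  const std::size_t width = n_cols + 1;  // one slot per feature plus the bias
  switch (method) {
    case ShapMethod::Standard:
    case ShapMethod::Interventional:
      return n_rows * n_outputs * width;
    case ShapMethod::Interactions:
    case ShapMethod::TaylorInteractions:
      return n_rows * n_outputs * width * width;
  }
  throw std::invalid_argument("unknown ShapMethod");
}
```

For example, a single-output model with 100 rows and 10 features would need 1100 elements for gpu_treeshap under these assumptions.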
Inputs (gpu_treeshap_interventional, in addition to the gpu_treeshap inputs)
| Name | Type | Required | Description |
|---|---|---|---|
| background_data | const FloatPointer | Yes | Device pointer to the background dataset [background_n_rows x background_n_cols] |
| background_n_rows | std::size_t | Yes | Number of rows in the background dataset |
| background_n_cols | std::size_t | Yes | Number of columns in the background dataset |
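Interventional SHAP defines the value of a feature coalition S as the model's average output when the features in S are taken from the explained row x and every other feature is taken from a background row: v(S) = (1/|B|) * sum over r in B of f(x_S, r_not_S). The sketch below is a brute-force CPU reference for that definition, enumerating all 2^M coalitions, which is only practical for a handful of features. It is not the algorithm gpu_treeshap_interventional runs (interventional TreeSHAP exploits the tree structure to avoid this enumeration), and the names Row, Model, coalition_value, and interventional_shap are hypothetical. It can help check small cases by hand.

```cpp
#include <bitset>
#include <cstddef>
#include <functional>
#include <vector>

using Row = std::vector<double>;
using Model = std::function<double(const Row&)>;

// v(S): mean model output when features in `mask` come from x and all other
// features come from each background row in turn.
double coalition_value(const Model& f, const Row& x,
                       const std::vector<Row>& background, unsigned mask) {
  double total = 0.0;
  for (const Row& r : background) {
    Row z = r;
    for (std::size_t j = 0; j < x.size(); ++j) {
      if (mask & (1u << j)) z[j] = x[j];
    }
    total += f(z);
  }
  return total / static_cast<double>(background.size());
}

// Exact Shapley values of v by enumerating every coalition (small M only).
// Returns M feature attributions followed by the base value v(empty set).
std::vector<double> interventional_shap(const Model& f, const Row& x,
                                        const std::vector<Row>& background) {
  const std::size_t m = x.size();
  std::vector<double> fact(m + 1, 1.0);  // fact[k] = k!
  for (std::size_t k = 1; k <= m; ++k) fact[k] = fact[k - 1] * static_cast<double>(k);

  std::vector<double> phi(m + 1, 0.0);
  for (std::size_t i = 0; i < m; ++i) {
    const unsigned bit = 1u << i;
    for (unsigned s = 0; s < (1u << m); ++s) {
      if (s & bit) continue;  // only coalitions that exclude feature i
      const std::size_t size = std::bitset<32>(s).count();
      // Shapley weight |S|! (M - |S| - 1)! / M!
      const double weight = fact[size] * fact[m - size - 1] / fact[m];
      phi[i] += weight * (coalition_value(f, x, background, s | bit) -
                          coalition_value(f, x, background, s));
    }
  }
  phi[m] = coalition_value(f, x, background, 0u);
  return phi;
}
```

By construction the attributions plus the base value sum to f(x), so the mean background prediction plays the role of the bias term.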
Outputs
| Name | Type | Description |
|---|---|---|
| Return value of extract_path_info | TreePathHandle | Extracted tree path information for SHAP computation |
| out_preds | FloatPointer | Caller-allocated device buffer of out_preds_size elements that receives the SHAP values |
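out_preds is a flat buffer, so reading a particular value requires an index calculation. The helpers below assume a row-major layout of [n_rows][n_outputs][n_cols + 1] for per-feature methods and [n_rows][n_outputs][n_cols + 1][n_cols + 1] for interaction methods, with the bias in the last feature slot. That ordering follows the GPUTreeShap library's convention and is an assumption here, not something this header states; shap_index and interaction_index are hypothetical names.

```cpp
#include <cstddef>

// Flat index of the value for (row, output, col) in an assumed
// [n_rows][n_outputs][n_cols + 1] buffer; col == n_cols addresses the bias.
std::size_t shap_index(std::size_t row, std::size_t output, std::size_t col,
                       std::size_t n_outputs, std::size_t n_cols) {
  return (row * n_outputs + output) * (n_cols + 1) + col;
}

// Flat index of interaction value (i, j) for (row, output) in an assumed
// [n_rows][n_outputs][n_cols + 1][n_cols + 1] buffer.
std::size_t interaction_index(std::size_t row, std::size_t output,
                              std::size_t i, std::size_t j,
                              std::size_t n_outputs, std::size_t n_cols) {
  const std::size_t width = n_cols + 1;
  return ((row * n_outputs + output) * width + i) * width + j;
}
```

Under this layout the last valid index, shap_index(n_rows - 1, n_outputs - 1, n_cols, ...), equals n_rows * n_outputs * (n_cols + 1) - 1, which is a quick consistency check against out_preds_size.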
Usage Examples
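#include <cuml/explainer/tree_shap.hpp>

#include <cuda_runtime.h>

#include <cstddef>
#include <vector>

// d_data must point to device memory holding the n_rows x n_cols feature matrix.
std::vector<float> explain_tree_model(TreeliteModelHandle model,
                                      float* d_data,
                                      std::size_t n_rows,
                                      std::size_t n_cols) {
  // Step 1: Extract path information from the Treelite model
  auto path_info = ML::Explainer::extract_path_info(model);

  // Step 2: Allocate a device buffer for the SHAP values.
  // For a single-output model, standard TreeSHAP produces n_rows * (n_cols + 1)
  // values: one per feature plus a bias term for each row.
  std::size_t out_size = n_rows * (n_cols + 1);
  float* d_out = nullptr;
  if (cudaMalloc(&d_out, out_size * sizeof(float)) != cudaSuccess) { return {}; }

  // Step 3: Compute standard TreeSHAP values on the GPU
  ML::Explainer::gpu_treeshap(path_info,
                              ML::Explainer::FloatPointer{d_data},
                              n_rows,
                              n_cols,
                              ML::Explainer::FloatPointer{d_out},
                              out_size);

  // Step 4: Copy the SHAP values to the host before freeing the device buffer
  std::vector<float> shap_values(out_size);
  cudaMemcpy(shap_values.data(), d_out, out_size * sizeof(float),
             cudaMemcpyDeviceToHost);
  cudaFree(d_out);
  return shap_values;
}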
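Once out_preds has been copied to the host (for example with cudaMemcpy), a common sanity check is local accuracy: for each row, the feature attributions plus the bias term should sum to the model's raw (margin) prediction, up to floating-point error. The sketch assumes a single-output buffer laid out as [n_rows][n_cols + 1] with the bias last; row_sums and locally_accurate are hypothetical helpers, the raw predictions must come from elsewhere, and details such as how a model's base score is folded in are not covered here.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Sums each row's n_cols contributions plus its bias term, assuming a
// single-output buffer laid out as [n_rows][n_cols + 1] with the bias last.
std::vector<double> row_sums(const std::vector<float>& shap, std::size_t n_rows,
                             std::size_t n_cols) {
  std::vector<double> sums(n_rows, 0.0);
  for (std::size_t r = 0; r < n_rows; ++r) {
    for (std::size_t c = 0; c <= n_cols; ++c) {
      sums[r] += shap[r * (n_cols + 1) + c];
    }
  }
  return sums;
}

// True when the buffer has the expected size and every row sum matches the
// corresponding raw prediction within tol.
bool locally_accurate(const std::vector<float>& shap,
                      const std::vector<float>& raw_preds, std::size_t n_cols,
                      double tol = 1e-4) {
  if (shap.size() != raw_preds.size() * (n_cols + 1)) return false;
  const std::vector<double> sums = row_sums(shap, raw_preds.size(), n_cols);
  for (std::size_t r = 0; r < sums.size(); ++r) {
    if (std::abs(sums[r] - raw_preds[r]) > tol) return false;
  }
  return true;
}
```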