Implementation:Rapidsai Cuml Permutation SHAP
| Knowledge Sources | |
|---|---|
| Domains | Machine_Learning, Explainability |
| Last Updated | 2026-02-08 12:00 GMT |
Overview
Provides GPU-accelerated functions for generating permutation-based SHAP datasets and aggregating SHAP value estimates for model explainability.
Description
This header defines three functions used in the Permutation SHAP algorithm:
- permutation_shap_dataset: Generates a dataset by tiling the background matrix and inserting a forward and backward permutation pass of an observation row at positions defined by a permutation index array. The resulting dataset has dimensions (2 * ncols * nrows_bg + nrows_bg) x ncols and supports both row-major and column-major output.
- shap_main_effect_dataset: Similar to permutation_shap_dataset, but generates data for computing main effects, where each feature is individually toggled between the observation and background values according to the permutation indices.
- update_perm_shap_values: Aggregates model prediction differences into SHAP value estimates. For each position i in the permutation, it computes the forward difference y_hat[i+1] - y_hat[i] and the backward difference y_hat[i+ncols] - y_hat[i+ncols+1], and accumulates both into the SHAP values array.
All functions operate on single-precision float data on the GPU.
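The accumulation described for update_perm_shap_values can be sketched on the host as follows. This is a minimal CPU reference written from the description above, not the cuML kernel itself; the function name update_perm_shap_values_host is hypothetical, and the assumption that each pair of differences is added to the entry shap_values[idx[i]] is inferred from the role of the permutation index array.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference sketch (hypothetical helper, not part of cuML).
// Assumes y_hat holds at least 2 * ncols + 1 values and that the differences
// for permutation position i are credited to feature idx[i].
void update_perm_shap_values_host(std::vector<float>& shap_values,
                                  const std::vector<float>& y_hat,
                                  const std::vector<int>& idx)
{
  const std::size_t ncols = idx.size();
  assert(shap_values.size() == ncols);
  assert(y_hat.size() >= 2 * ncols + 1);
  for (std::size_t i = 0; i < ncols; ++i) {
    // Forward pass: effect of switching feature idx[i] to the observation value.
    const float forward = y_hat[i + 1] - y_hat[i];
    // Backward pass: effect of switching feature idx[i] back to the background.
    const float backward = y_hat[i + ncols] - y_hat[i + ncols + 1];
    shap_values[idx[i]] += forward + backward;
  }
}
```

In this sketch the forward differences telescope to y_hat[ncols] - y_hat[0] and the backward differences to y_hat[ncols] - y_hat[2 * ncols], so each call adds two passes' worth of contributions. A caller would presumably normalize the accumulated values afterwards; the description above does not say how.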
Usage
Use these functions as part of a Permutation SHAP pipeline. First call permutation_shap_dataset (or shap_main_effect_dataset) to generate the perturbation data, run the model on that data to obtain predictions, then call update_perm_shap_values to accumulate the SHAP value estimates. Repeat across multiple permutations for more accurate estimates.
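Repeating the pipeline across permutations requires a fresh index array for each pass. A minimal host-side sketch, assuming the permutation is simply a shuffle of the feature indices 0..ncols-1 that is then copied to the device; the helper names make_permutation and perm_dataset_rows are hypothetical and not part of cuML.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical helper: number of rows in the generated dataset,
// following the dimensions stated in the Description.
inline long long perm_dataset_rows(int nrows_bg, int ncols)
{
  return 2LL * ncols * nrows_bg + nrows_bg;
}

// Hypothetical helper: one random permutation of the feature indices 0..ncols-1.
std::vector<int> make_permutation(int ncols, std::mt19937& rng)
{
  assert(ncols > 0);
  std::vector<int> idx(ncols);
  std::iota(idx.begin(), idx.end(), 0);       // 0, 1, ..., ncols - 1
  std::shuffle(idx.begin(), idx.end(), rng);  // uniform random reordering
  return idx;
}
```

Each permutation produced this way would be copied into the device idx buffer (for example with cudaMemcpy) before the dataset is generated, and the same idx would be passed to update_perm_shap_values for that pass.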
Code Reference
Source Location
- Repository: Rapidsai_Cuml
- File: cpp/include/cuml/explainer/permutation_shap.hpp
Signature
namespace ML {
namespace Explainer {

void permutation_shap_dataset(const raft::handle_t& handle,
                              float* dataset,
                              const float* background,
                              int nrows_bg,
                              int ncols,
                              const float* row,
                              int* idx,
                              bool row_major);

void shap_main_effect_dataset(const raft::handle_t& handle,
                              float* dataset,
                              const float* background,
                              int nrows_bg,
                              int ncols,
                              const float* row,
                              int* idx,
                              bool row_major);

void update_perm_shap_values(const raft::handle_t& handle,
                             float* shap_values,
                             const float* y_hat,
                             const int ncols,
                             const int* idx);

}  // namespace Explainer
}  // namespace ML
Import
#include <cuml/explainer/permutation_shap.hpp>
I/O Contract
Inputs (permutation_shap_dataset / shap_main_effect_dataset)
| Name | Type | Required | Description |
|---|---|---|---|
| handle | const raft::handle_t& | Yes | cuML handle for GPU resource management |
| background | const float* | Yes | Background dataset on device [nrows_bg x ncols] |
| nrows_bg | int | Yes | Number of rows in the background dataset |
| ncols | int | Yes | Number of columns (features) |
| row | const float* | Yes | Observation row to explain on device [ncols] |
| idx | int* | Yes | Permutation index array on device [ncols] |
| row_major | bool | Yes | Whether to produce row-major or column-major output |
Inputs (update_perm_shap_values)
| Name | Type | Required | Description |
|---|---|---|---|
| handle | const raft::handle_t& | Yes | cuML handle for GPU resource management |
| y_hat | const float* | Yes | Model predictions on device [2 * ncols + 2] (forward and backward pass results) |
| ncols | int | Yes | Number of features |
| idx | const int* | Yes | Permutation index array on device [ncols] |
Outputs
| Name | Type | Description |
|---|---|---|
| dataset | float* | Device pointer to generated permutation dataset [(2 * ncols * nrows_bg + nrows_bg) x ncols] |
| shap_values | float* | Device pointer to SHAP values array [ncols], aggregated in-place |
Usage Examples
#include <cuda_runtime.h>
#include <cuml/explainer/permutation_shap.hpp>
#include <raft/core/handle.hpp>

void compute_permutation_shap() {
  raft::handle_t handle;
  int ncols    = 3;
  int nrows_bg = 3;

  // Device buffers
  float* background;   // [nrows_bg x ncols]
  float* row;          // [ncols]
  int* idx;            // permutation indices [ncols]
  float* dataset;      // output [(2 * ncols * nrows_bg + nrows_bg) x ncols]
  float* shap_values;  // SHAP values [ncols]
  float* y_hat;        // model predictions [2 * ncols + 2]

  int dataset_rows = 2 * ncols * nrows_bg + nrows_bg;

  // Allocate device memory
  cudaMalloc(&background, nrows_bg * ncols * sizeof(float));
  cudaMalloc(&row, ncols * sizeof(float));
  cudaMalloc(&idx, ncols * sizeof(int));
  cudaMalloc(&dataset, dataset_rows * ncols * sizeof(float));
  cudaMalloc(&shap_values, ncols * sizeof(float));
  cudaMalloc(&y_hat, (2 * ncols + 2) * sizeof(float));

  // Initialize background, row, idx on device...

  // Zero-initialize shap_values, since the update accumulates in place
  cudaMemset(shap_values, 0, ncols * sizeof(float));

  // Step 1: Generate the permutation SHAP dataset (row-major output)
  ML::Explainer::permutation_shap_dataset(
    handle, dataset, background, nrows_bg, ncols, row, idx, true);
  handle.sync_stream();

  // Step 2: Run model on dataset to fill y_hat...

  // Step 3: Aggregate SHAP values
  ML::Explainer::update_perm_shap_values(handle, shap_values, y_hat, ncols, idx);
  handle.sync_stream();

  // shap_values now contains the estimated SHAP contributions

  // Release device memory
  cudaFree(background);
  cudaFree(row);
  cudaFree(idx);
  cudaFree(dataset);
  cudaFree(shap_values);
  cudaFree(y_hat);
}
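Step 2 in the example leaves open how one prediction per dataset row becomes the much shorter y_hat array. One plausible reduction, sketched here on the host, is to average the predictions over each block of nrows_bg consecutive rows. This is an assumption rather than something the header states: it presumes that every tiled copy of the background occupies nrows_bg contiguous rows, and the helper name average_blocks is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reduction (not part of cuML): average per-row model predictions
// over each block of nrows_bg consecutive rows.
// Assumes preds is ordered block by block, one block per tiled background copy.
std::vector<float> average_blocks(const std::vector<float>& preds, std::size_t nrows_bg)
{
  assert(nrows_bg > 0 && preds.size() % nrows_bg == 0);
  const std::size_t nblocks = preds.size() / nrows_bg;
  std::vector<float> y_hat(nblocks, 0.0f);
  for (std::size_t b = 0; b < nblocks; ++b) {
    float sum = 0.0f;
    for (std::size_t r = 0; r < nrows_bg; ++r) {
      sum += preds[b * nrows_bg + r];
    }
    y_hat[b] = sum / static_cast<float>(nrows_bg);
  }
  return y_hat;
}
```

With a dataset of 2 * ncols * nrows_bg + nrows_bg rows, this reduction yields 2 * ncols + 1 averages. The I/O contract lists y_hat as [2 * ncols + 2], so the exact buffer layout should be checked against the header before relying on this sketch.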