Implementation:Microsoft Onnxruntime CPU GatherNDGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete implementation of the GatherND gradient kernel for the CPU execution provider in the ONNX Runtime training framework.
Description
This file implements the GatherNDGrad kernel, which computes the gradient of the ONNX GatherND operation with respect to its data input. Given the upstream gradient (which has the shape of the forward GatherND output) and the indices tensor, it scatters gradient slices into a zero-initialized output tensor that has the original data shape. The kernel calls GatherNDBase::PrepareForCompute to compute one slice offset per index tuple, then dispatches to GatherNDGradComputeImpl, which accumulates gradient values at those offsets: output[slice_offset + j] += update[i * slice_size + j], where i indexes the gathered slice, j the element within it, and slice_size is the number of elements per slice. Because values are accumulated rather than overwritten, positions gathered more than once receive the sum of their gradients. The kernel supports float and double data with int32 or int64 indices, and respects the batch_dims attribute.
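The accumulation step described above can be sketched in isolation. This is a minimal illustration, not the ONNX Runtime source: `ScatterAddSlices` and its parameters are hypothetical names, and the slice offsets are assumed to be already computed (in the real kernel that is the job of GatherNDBase::PrepareForCompute).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical standalone helper, not part of the ONNX Runtime API.
// slice_offsets[i] is the flat start offset in the output of the i-th
// gathered slice; update holds slice_offsets.size() * slice_size values,
// one contiguous slice per offset.
template <typename T>
std::vector<T> ScatterAddSlices(std::size_t output_size,
                                const std::vector<std::size_t>& slice_offsets,
                                std::size_t slice_size,
                                const std::vector<T>& update) {
  // The gradient buffer starts at zero so untouched positions stay zero.
  std::vector<T> output(output_size, T{0});
  for (std::size_t i = 0; i < slice_offsets.size(); ++i) {
    for (std::size_t j = 0; j < slice_size; ++j) {
      // += rather than =, so repeated offsets sum their contributions.
      output[slice_offsets[i] + j] += update[i * slice_size + j];
    }
  }
  return output;
}
```

With offsets {0, 4, 0}, a slice size of 2, and updates {1, 2, 3, 4, 5, 6}, the first and third slices both land at offset 0, so their values are summed there while the second slice is written at offset 4.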
Usage
This kernel is invoked during the backward pass when a GatherND operation was used in the forward pass. It distributes gradient slices back to the positions from which data was gathered.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/tensor/gather_nd_grad.cc
- Lines: 1-83
Signature
template <typename InputT>
struct GatherNDGradComputeImpl {
  void operator()(GatherNDBase::Prepare& p, const Tensor* update_tensor) const;
};

Status GatherNDGrad::Compute(OpKernelContext* context) const;
Import
#include "orttraining/training_ops/cpu/tensor/gather_nd_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| shape | Tensor(int64) | Yes | Shape of the original data tensor |
| indices | Tensor(Tind) | Yes | Indices used in the forward GatherND |
| update | Tensor(T) | Yes | Upstream gradient (same shape as GatherND output) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(T) | Gradient w.r.t. original data (zero-initialized, then accumulated) |
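To make the contract concrete, the following reference sketch maps the three inputs to the output for the simplest case. It is an illustration under stated assumptions, not the kernel's code: `GatherNDGradReference` and `index_rank` are hypothetical names, batch_dims is assumed to be 0, indices are assumed to be non-negative and in range, and tensors are passed as flattened row-major vectors.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper (not part of ONNX Runtime).
// Assumes batch_dims == 0, index_rank >= 1, and in-range non-negative indices.
// shape:      dimensions of the original data tensor
// indices:    flattened index tuples, each of length index_rank
// index_rank: size of the last dimension of the indices tensor
// update:     flattened upstream gradient, one slice per index tuple
std::vector<float> GatherNDGradReference(const std::vector<int64_t>& shape,
                                         const std::vector<int64_t>& indices,
                                         std::size_t index_rank,
                                         const std::vector<float>& update) {
  // Row-major strides of the data tensor.
  std::vector<int64_t> strides(shape.size(), 1);
  for (std::size_t k = shape.size(); k-- > 1;) {
    strides[k - 1] = strides[k] * shape[k];
  }
  int64_t total = 1;
  for (int64_t d : shape) total *= d;
  // Elements per slice: product of the trailing dims an index tuple does not address.
  int64_t slice_elems = 1;
  for (std::size_t k = index_rank; k < shape.size(); ++k) slice_elems *= shape[k];
  const std::size_t slice_size = static_cast<std::size_t>(slice_elems);

  std::vector<float> output(static_cast<std::size_t>(total), 0.0f);
  const std::size_t num_slices = indices.size() / index_rank;
  for (std::size_t i = 0; i < num_slices; ++i) {
    // Flat offset of the slice addressed by the i-th index tuple.
    int64_t offset = 0;
    for (std::size_t k = 0; k < index_rank; ++k) {
      offset += indices[i * index_rank + k] * strides[k];
    }
    const std::size_t base = static_cast<std::size_t>(offset);
    for (std::size_t j = 0; j < slice_size; ++j) {
      output[base + j] += update[i * slice_size + j];
    }
  }
  return output;
}
```

For example, with shape {2, 3}, indices {1, 0, 1} read as three tuples of length 1, and update {1, ..., 9}, row 1 of the output accumulates the first and third slices while row 0 receives the second.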
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
    GatherNDGrad, kMSDomain, 1, kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
                              DataTypeImpl::GetTensorType<double>()})
        .TypeConstraint("Tind", {DataTypeImpl::GetTensorType<int64_t>(),
                                 DataTypeImpl::GetTensorType<int32_t>()}),
    GatherNDGrad);