Implementation:Microsoft Onnxruntime CUDA GatherNDGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel that computes the gradient of GatherND in the ONNX Runtime training framework.
Description
Implements the GatherNDGrad operator for CUDA, which scatters upstream gradients (the update input) back into a tensor of the original data shape according to multi-dimensional indices. The output is first zero-initialized (the source notes this as potentially expensive), then GatherNDGradImpl atomically adds each update value at the position given by its index tuple. Before launching, the implementation validates that the indices tensor has rank greater than 0, that batch_dims plus the last dimension of indices does not exceed the input rank, and that the batch dimensions of the data shape and indices match. Slice offsets are precomputed via PrepareCompute. The operation is inherently non-deterministic: when index tuples repeat, the atomic additions may be applied in a different order from run to run, and floating-point addition is not associative. A warning is logged when deterministic compute is requested. The kernel is registered for the int64_t index type with float, MLFloat16, double, and BFloat16 data types.
Usage
Invoked during the backward pass when the model uses GatherND operations for multi-dimensional indexed access.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/gather_nd_grad.cc
- Lines: 1-100
Signature
template <typename TIndex>
class GatherNDGrad : public CudaKernel {
  Status ComputeInternal(OpKernelContext* context) const;
};
Import
#include "orttraining/training_ops/cuda/tensor/gather_nd_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| shape | Tensor(int64_t) | Yes | Shape of original data tensor (CPU memory) |
| indices | Tensor(Tind) | Yes | Multi-dimensional index tensor from forward pass |
| update | Tensor(T) | Yes | Upstream gradient values to scatter |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(T) | Gradient with respect to data (zero-initialized then scattered) |
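The shape checks that the kernel applies to these inputs can be sketched on the host as follows. This is an illustration under stated assumptions, not the ONNX Runtime code: the function name `CheckGatherNDGradShapes`, its string return convention, and the error messages are hypothetical, and the guard on the range of `batch_dims` is added here only to keep the sketch memory-safe. Because the shape input is held in CPU memory, checks of this kind can run before any kernel launch.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical host-side shape check. Returns an empty string when the shapes
// are consistent, otherwise a short description of the first problem found.
std::string CheckGatherNDGradShapes(const std::vector<int64_t>& data_shape,
                                    const std::vector<int64_t>& indices_shape,
                                    int64_t batch_dims) {
  assert(batch_dims >= 0);  // precondition assumed by this sketch
  const int64_t data_rank = static_cast<int64_t>(data_shape.size());
  const int64_t indices_rank = static_cast<int64_t>(indices_shape.size());

  // Rule 1: indices must have rank greater than 0.
  if (indices_rank == 0) {
    return "indices must have rank greater than 0";
  }

  // Guard added for this sketch so the batch loop below stays in bounds.
  if (batch_dims >= indices_rank || batch_dims > data_rank) {
    return "batch_dims is too large for the given shapes";
  }

  // Rule 2: batch_dims plus the last indices dimension must fit in the data rank.
  if (batch_dims + indices_shape.back() > data_rank) {
    return "last dimension of indices exceeds data rank minus batch_dims";
  }

  // Rule 3: the leading batch dimensions of data and indices must agree.
  for (int64_t i = 0; i < batch_dims; ++i) {
    if (data_shape[i] != indices_shape[i]) {
      return "batch dimensions of data and indices must match";
    }
  }
  return "";
}
```

For example, a data shape of {2, 3, 4} with an indices shape of {2, 5, 1} and `batch_dims` of 1 passes all three rules, while changing the indices shape to {3, 5, 1} fails the batch-dimension rule.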
Usage Examples
REGISTER_KERNEL_TYPED_GATHER_ND_GRAD(int64_t)
// Dispatches to GatherNDGradImpl<CudaT> for float, MLFloat16, double, BFloat16