Implementation:Microsoft Onnxruntime CPU GatherElementsGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernel that computes the gradient of the ONNX GatherElements operator in the ONNX Runtime training framework.
Description
This file implements the GatherElementsGrad kernel, which computes the gradient of the ONNX GatherElements operation. Given the upstream gradient (dY), the shape of the original data tensor, and the indices used during the forward GatherElements, it zero-initializes an output tensor (dX) of the original data shape and scatters the gradient values into it along the gather axis. Before scattering, the kernel validates that indices and dY have the same rank and dimensions, and that no dimension of indices exceeds the corresponding data dimension (the gather axis is exempt from this check). It supports float and double data types with int32 and int64 index types. The actual scatter operation is delegated to GatherElementsGradImpl.
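The scatter step described above can be sketched in plain C++. This is a minimal reference under stated assumptions, not the ONNX Runtime implementation: `GatherElementsGradRef` is a hypothetical name, tensors are flat row-major vectors, and the sketch accumulates with `+=` so that a position gathered more than once receives the sum of its upstream gradients (which is what the mathematical gradient requires; whether GatherElementsGradImpl does it this way is assumed here, not confirmed by this page). Negative indices are wrapped, as the ONNX GatherElements specification permits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper, NOT the ONNX Runtime GatherElementsGradImpl.
// Scatter-adds dY into a zero-filled dX of shape data_dims along `axis`.
// indices and dY share the shape idx_dims; all tensors are row-major.
// Assumes the shapes were validated beforehand (idx_dims[d] <= data_dims[d]
// for every d != axis).
template <typename T, typename Tind>
std::vector<T> GatherElementsGradRef(const std::vector<T>& dY,
                                     const std::vector<Tind>& indices,
                                     const std::vector<int64_t>& idx_dims,
                                     const std::vector<int64_t>& data_dims,
                                     std::size_t axis) {
  const std::size_t rank = data_dims.size();
  assert(rank >= 1 && idx_dims.size() == rank && axis < rank);
  assert(dY.size() == indices.size());

  // Row-major strides of the data-shaped output tensor.
  std::vector<int64_t> strides(rank, 1);
  for (std::size_t d = rank - 1; d > 0; --d) {
    strides[d - 1] = strides[d] * data_dims[d];
  }

  int64_t total = 1;
  for (int64_t dim : data_dims) total *= dim;
  std::vector<T> dX(static_cast<std::size_t>(total), T(0));  // zero-initialized

  std::vector<int64_t> coord(rank, 0);  // multi-index into indices / dY
  for (std::size_t flat = 0; flat < indices.size(); ++flat) {
    int64_t idx = static_cast<int64_t>(indices[flat]);
    if (idx < 0) idx += data_dims[axis];  // wrap negative indices
    assert(idx >= 0 && idx < data_dims[axis]);

    // Same coordinates as in dY, except that the gather axis uses the index.
    int64_t offset = 0;
    for (std::size_t d = 0; d < rank; ++d) {
      offset += (d == axis ? idx : coord[d]) * strides[d];
    }
    dX[static_cast<std::size_t>(offset)] += dY[flat];  // duplicates accumulate

    // Advance coord like an odometer over idx_dims.
    for (std::size_t d = rank; d-- > 0;) {
      if (++coord[d] < idx_dims[d]) break;
      coord[d] = 0;
    }
  }
  return dX;
}
```

For a 3x2 data shape, axis 0, indices [[0, 2], [0, 1]] and dY [[1, 2], [3, 4]], the sketch yields dX = [[4, 0], [0, 4], [0, 2]]: element (0, 0) was gathered twice, so it receives 1 + 3.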
Usage
This kernel is invoked during the backward pass when a GatherElements operation was used in the forward pass. It distributes upstream gradients back to the positions from which elements were gathered.
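To illustrate how the backward pass relates to the forward pass, the sketch below pairs a 2-D forward gather with its gradient and provides a dot-product helper for checking the adjoint identity ⟨gather(x), dy⟩ = ⟨x, grad(dy)⟩. All function names are hypothetical and the code is restricted to axis = 1 with non-negative, in-range indices; it is an illustration of the relationship, not ONNX Runtime code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 2-D helpers for axis = 1, row-major storage; not ONNX Runtime APIs.
// Assumes non-negative, in-range indices.

// Forward pass: y[i][j] = x[i][idx[i][j]]
std::vector<double> Gather2DAxis1(const std::vector<double>& x, std::size_t x_cols,
                                  const std::vector<int64_t>& idx,
                                  std::size_t rows, std::size_t idx_cols) {
  std::vector<double> y(rows * idx_cols);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < idx_cols; ++j) {
      const std::size_t src = static_cast<std::size_t>(idx[i * idx_cols + j]);
      y[i * idx_cols + j] = x[i * x_cols + src];
    }
  }
  return y;
}

// Backward pass: dx[i][idx[i][j]] += dy[i][j], starting from zeros.
std::vector<double> Gather2DAxis1Grad(const std::vector<double>& dy,
                                      const std::vector<int64_t>& idx,
                                      std::size_t rows, std::size_t idx_cols,
                                      std::size_t x_cols) {
  std::vector<double> dx(rows * x_cols, 0.0);
  for (std::size_t i = 0; i < rows; ++i) {
    for (std::size_t j = 0; j < idx_cols; ++j) {
      const std::size_t dst = static_cast<std::size_t>(idx[i * idx_cols + j]);
      dx[i * x_cols + dst] += dy[i * idx_cols + j];
    }
  }
  return dx;
}

// Plain dot product, used to compare <y, dy> with <x, dx>.
double Dot(const std::vector<double>& a, const std::vector<double>& b) {
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t k = 0; k < a.size(); ++k) sum += a[k] * b[k];
  return sum;
}
```

With x = [[1, 2, 3], [4, 5, 6]], idx = [[2, 2], [0, 1]] and dy = [[10, 20], [30, 40]], the forward result is y = [[3, 3], [4, 5]] and the gradient is dx = [[0, 0, 30], [30, 40, 0]]; both dot products equal 410.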
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/tensor/gather_elements_grad.cc
- Lines: 1-88
Signature
Status GatherElementsGrad::Compute(OpKernelContext* context) const;
Import
#include "orttraining/training_ops/cpu/tensor/gather_elements_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dY | Tensor(T) | Yes | Upstream gradient (same shape as GatherElements output) |
| shape | Tensor(int64) | Yes | Shape of the original data tensor (determines the shape of dX) |
| indices | Tensor(Tind) | Yes | Indices used in the forward GatherElements |
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(T) | Gradient w.r.t. original data (zero-initialized, then scattered) |
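The shape constraints that tie these inputs together can be sketched as a standalone check. `CheckGatherElementsGradShapes` is a hypothetical helper, and its error strings are illustrative rather than the kernel's actual messages. Two points are assumptions taken from the ONNX GatherElements specification rather than from this page: indices must have the same rank as the data tensor, and a negative axis counts from the last dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper mirroring the checks described on this page.
// Returns an empty string on success, otherwise a description of the
// first violated constraint (messages are illustrative only).
std::string CheckGatherElementsGradShapes(const std::vector<int64_t>& dy_dims,
                                          const std::vector<int64_t>& indices_dims,
                                          const std::vector<int64_t>& data_dims,
                                          int64_t axis) {
  if (indices_dims.size() != dy_dims.size()) {
    return "indices and dY must have the same rank";
  }
  for (std::size_t i = 0; i < indices_dims.size(); ++i) {
    if (indices_dims[i] != dy_dims[i]) {
      return "indices and dY must have the same dimensions";
    }
  }
  // Assumption from the ONNX spec: indices rank equals data rank.
  if (data_dims.size() != indices_dims.size()) {
    return "indices must have the same rank as the data shape";
  }
  const int64_t rank = static_cast<int64_t>(data_dims.size());
  if (axis < -rank || axis >= rank) {
    return "axis is out of range";
  }
  if (axis < 0) axis += rank;  // normalize a negative axis
  for (int64_t i = 0; i < rank; ++i) {
    const std::size_t d = static_cast<std::size_t>(i);
    // The gather axis is exempt: indices may be longer than data there.
    if (i != axis && indices_dims[d] > data_dims[d]) {
      return "indices dimension exceeds data dimension off the gather axis";
    }
  }
  return "";
}
```

For example, with a data shape of {3, 2} and axis 0, an indices shape of {4, 2} passes (the extra length lies on the gather axis), while {2, 3} is rejected because dimension 1 exceeds the data tensor.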
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
    GatherElementsGrad, kMSDomain, 1, kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
                              DataTypeImpl::GetTensorType<double>()})
        .TypeConstraint("Tind", std::vector<MLDataType>{
                                    DataTypeImpl::GetTensorType<int32_t>(),
                                    DataTypeImpl::GetTensorType<int64_t>()}),
    GatherElementsGrad);