Implementation:Microsoft Onnxruntime CUDA GatherGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel that computes the gradient of the Gather operator with respect to its data input in the ONNX Runtime training framework.
Description
Implements the GatherGrad operator for CUDA, which scatters the upstream gradient dY back into a tensor with the shape of the original Gather input. The output dX is first zero-initialized; gradients from dY are then written to the positions along the gather axis selected by gathered_indices. Positions that were never gathered keep a zero gradient, and an index that appears more than once receives the sum of its contributions. The implementation dispatches through a two-level type dispatcher: first by data type (float, MLFloat16, BFloat16; the kernel registration shown under Usage Examples also lists double), then by index type (int32_t, int64_t). It calls GatherGradImpl and uses CudaScratchBufferAllocator for workspace allocation. The gather axis is handled by computing num_batches, gather_dimension_size, and num_gathered_per_index from the original input shape and the axis attribute.
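The zero-initialize-then-scatter step can be illustrated with a small CPU reference. This is a minimal sketch, not the CUDA kernel: the names GatherGradSizes, ComputeSizes, and GatherGradReference are hypothetical, and the meaning given to num_batches, gather_dimension_size, and num_gathered_per_index (product of dimensions before the axis, size of the axis, product of dimensions after the axis) is an assumption inferred from their names. Indices are assumed to be non-negative and the tensors row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference; it mirrors the description, not the CUDA kernel.
struct GatherGradSizes {
  int64_t num_batches;             // assumed: product of dims before the axis
  int64_t gather_dimension_size;   // assumed: size of the gathered axis
  int64_t num_gathered_per_index;  // assumed: product of dims after the axis
};

inline GatherGradSizes ComputeSizes(const std::vector<int64_t>& x_shape, int64_t axis) {
  const int64_t rank = static_cast<int64_t>(x_shape.size());
  if (axis < 0) axis += rank;  // a negative axis counts from the last dimension
  assert(axis >= 0 && axis < rank);
  const std::size_t a = static_cast<std::size_t>(axis);
  GatherGradSizes s{1, x_shape[a], 1};
  for (std::size_t d = 0; d < a; ++d) s.num_batches *= x_shape[d];
  for (std::size_t d = a + 1; d < x_shape.size(); ++d) s.num_gathered_per_index *= x_shape[d];
  return s;
}

template <typename T, typename TIndex>
std::vector<T> GatherGradReference(const std::vector<int64_t>& x_shape, int64_t axis,
                                   const std::vector<TIndex>& indices,
                                   const std::vector<T>& dY) {
  const GatherGradSizes s = ComputeSizes(x_shape, axis);
  const int64_t num_indices = static_cast<int64_t>(indices.size());
  assert(static_cast<int64_t>(dY.size()) ==
         s.num_batches * num_indices * s.num_gathered_per_index);

  // Step 1: zero-initialize dX with the element count of the original input.
  std::vector<T> dX(static_cast<std::size_t>(
                        s.num_batches * s.gather_dimension_size * s.num_gathered_per_index),
                    T(0));

  // Step 2: add each gathered slice of dY into the slot it was read from.
  for (int64_t b = 0; b < s.num_batches; ++b) {
    for (int64_t i = 0; i < num_indices; ++i) {
      const int64_t idx = static_cast<int64_t>(indices[static_cast<std::size_t>(i)]);
      assert(idx >= 0 && idx < s.gather_dimension_size);  // negative indices not handled here
      const int64_t src = (b * num_indices + i) * s.num_gathered_per_index;
      const int64_t dst = (b * s.gather_dimension_size + idx) * s.num_gathered_per_index;
      for (int64_t k = 0; k < s.num_gathered_per_index; ++k) {
        dX[static_cast<std::size_t>(dst + k)] += dY[static_cast<std::size_t>(src + k)];
      }
    }
  }
  return dX;
}
```

For a 4x2 input gathered along axis 0 with indices {1, 1, 3}, row 1 of dX receives the sum of the first two rows of dY, row 3 receives the third, and rows 0 and 2 stay zero. A sequential loop makes the accumulation trivially safe; a parallel kernel has to resolve duplicate indices some other way, which this sketch does not attempt to show.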
Usage
Invoked during the backward pass when the model uses Gather operations, such as embedding lookups.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/gather_grad.cc
- Lines: 1-123
Signature
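class GatherGrad : public CudaKernel {
 public:
  // Overrides CudaKernel::ComputeInternal; zero-initializes dX and scatters dY into it.
  Status ComputeInternal(OpKernelContext* context) const override;
};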
Import
#include "orttraining/training_ops/cuda/tensor/gather_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X_shape | Tensor(int64_t) | Yes | Shape of original data tensor (CPU memory) |
| indices | Tensor(Tind) | Yes | Index tensor from forward Gather |
| dY | Tensor(T) | Yes | Upstream gradient |
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(T) | Gradient with respect to data input (zero-initialized then scattered) |
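The three inputs are tied together by the forward Gather shape rule: the output of Gather replaces the gathered axis of the data tensor with the full shape of indices, so dY is expected to have that shape while dX has the shape given by X_shape. The helper below is a hypothetical sketch of that relationship (ExpectedDyShape is not an ONNX Runtime function) and can be handy when checking test fixtures.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: shape that dY should have for a given X_shape,
// indices shape, and axis, following the Gather output shape rule
//   X_shape[0:axis] + indices_shape + X_shape[axis+1:].
inline std::vector<int64_t> ExpectedDyShape(const std::vector<int64_t>& x_shape,
                                            const std::vector<int64_t>& indices_shape,
                                            int64_t axis) {
  const int64_t rank = static_cast<int64_t>(x_shape.size());
  if (axis < 0) axis += rank;  // a negative axis counts from the last dimension
  assert(axis >= 0 && axis < rank);
  const auto a = static_cast<std::ptrdiff_t>(axis);

  // Leading dimensions of X, up to but excluding the gather axis.
  std::vector<int64_t> dy_shape(x_shape.begin(), x_shape.begin() + a);
  // The gather axis is replaced by every dimension of the indices tensor.
  dy_shape.insert(dy_shape.end(), indices_shape.begin(), indices_shape.end());
  // Trailing dimensions of X, after the gather axis.
  dy_shape.insert(dy_shape.end(), x_shape.begin() + a + 1, x_shape.end());
  return dy_shape;
}
```

For an embedding table of shape {100, 16} looked up along axis 0 with token IDs of shape {2, 7}, the upstream gradient dY has shape {2, 7, 16} and dX has shape {100, 16}.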
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
    GatherGrad, kMSDomain, 1, kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 0)
        .TypeConstraint("I", DataTypeImpl::GetTensorType<int64_t>())
        .TypeConstraint("T", BuildKernelDefConstraints<MLFloat16, float, double, BFloat16>())
        .TypeConstraint("Tind", std::vector<MLDataType>{
                                    DataTypeImpl::GetTensorType<int32_t>(),
                                    DataTypeImpl::GetTensorType<int64_t>()}),
    GatherGrad);
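The registration constrains two independent type parameters, T and Tind, which is why the kernel needs a two-level dispatch at run time: the outer level selects the data type and the inner level selects the index type. The sketch below illustrates the pattern with plain switch statements. It is a stand-in under stated assumptions, not the dispatcher used in the source: the enums DataType and IndexType and the functions RunImpl, DispatchIndex, and DispatchData are all hypothetical, and only float and double are modeled because standard C++ has no built-in half-precision types.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical runtime tags standing in for the tensor element types.
enum class DataType { kFloat, kDouble };
enum class IndexType { kInt32, kInt64 };

// Innermost target: a fully typed implementation. Here it only reports the
// element sizes so the selected instantiation is observable.
template <typename T, typename TIndex>
std::string RunImpl() {
  return std::to_string(sizeof(T)) + "x" + std::to_string(sizeof(TIndex));
}

// Second level: T is already fixed, choose the index type.
template <typename T>
std::string DispatchIndex(IndexType index_type) {
  switch (index_type) {
    case IndexType::kInt32:
      return RunImpl<T, int32_t>();
    case IndexType::kInt64:
      return RunImpl<T, int64_t>();
  }
  assert(false && "unsupported index type");
  return {};
}

// First level: choose the data type, then hand off to the index dispatch.
inline std::string DispatchData(DataType data_type, IndexType index_type) {
  switch (data_type) {
    case DataType::kFloat:
      return DispatchIndex<float>(index_type);
    case DataType::kDouble:
      return DispatchIndex<double>(index_type);
  }
  assert(false && "unsupported data type");
  return {};
}
```

Nesting the dispatch keeps each level a flat list of cases, so supporting another data type adds one case to the outer switch rather than one case per index type.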