Implementation:Microsoft Onnxruntime CUDA DivGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel that computes the gradients of element-wise division in the ONNX Runtime training framework.
Description
Implements the DivGrad operator for CUDA, which computes the gradients of the division out = a / b. Given the upstream gradient dY and the original operands a and b, it computes da = dY / b and db = -dY * a / (b * b). The implementation handles all ONNX broadcasting cases, dispatched as NoBroadcast, LeftScalar, RightScalar, RightPerChannelBatch1, RightPerChannelBatchN, and the general multi-dimensional case. When an operand was broadcast in the forward pass, its gradient is first computed at the broadcast (output) shape and then sum-reduced back to the operand's original shape using ReduceKernelShared with cuDNN. The helper function prepended_dimension_1 pads a lower-rank shape with leading dimensions of size 1 so that the shapes align for broadcasting. The kernel is registered for MLFloat16, float, and double.
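The formulas and the reduce-after-broadcast step can be illustrated with a small CPU reference. This is a minimal sketch, not the ONNX Runtime kernel: the names DivGradRef, div_grad_no_broadcast, and div_grad_right_scalar are hypothetical, plain std::vector<double> stands in for tensors, and only the NoBroadcast and RightScalar cases are covered.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for illustration only; not part of ONNX Runtime.
struct DivGradRef {
  std::vector<double> da;  // gradient w.r.t. the numerator a
  std::vector<double> db;  // gradient w.r.t. the denominator b
};

// NoBroadcast case: dY, a, and b all have the same number of elements.
DivGradRef div_grad_no_broadcast(const std::vector<double>& dY,
                                 const std::vector<double>& a,
                                 const std::vector<double>& b) {
  assert(dY.size() == a.size() && a.size() == b.size());
  DivGradRef r;
  r.da.resize(dY.size());
  r.db.resize(dY.size());
  for (std::size_t i = 0; i < dY.size(); ++i) {
    r.da[i] = dY[i] / b[i];                     // da = dY / b
    r.db[i] = -dY[i] * a[i] / (b[i] * b[i]);    // db = -dY * a / (b * b)
  }
  return r;
}

// RightScalar case: b is a single value broadcast across a.
// db is computed per element at the broadcast shape, then summed to one value.
DivGradRef div_grad_right_scalar(const std::vector<double>& dY,
                                 const std::vector<double>& a,
                                 double b) {
  assert(dY.size() == a.size());
  DivGradRef r;
  r.da.resize(dY.size());
  double db_sum = 0.0;
  for (std::size_t i = 0; i < dY.size(); ++i) {
    r.da[i] = dY[i] / b;
    db_sum += -dY[i] * a[i] / (b * b);  // accumulate: the "reduce" step
  }
  r.db.assign(1, db_sum);  // gradient has the original (scalar) shape of b
  return r;
}
```

In the right-scalar variant, every element of the output depends on the same b, so its gradient is the sum of the per-element contributions. The same reasoning motivates the sum reduction in the other broadcasting cases.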
Usage
Invoked during the backward pass when the model contains element-wise division operations.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/math/div_grad.cc
- Lines: 1-255
Signature
template <typename T>
class DivGrad : public CudaKernel {
  Status ComputeInternal(OpKernelContext* context) const;
};
TensorShapeVector prepended_dimension_1(const TensorShape& shape, size_t total_rank);
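The padding behavior described for prepended_dimension_1 can be sketched as follows. This is an illustrative stand-in under stated assumptions, not the source implementation: the function name prepend_ones is hypothetical, and std::vector<int64_t> replaces TensorShape and TensorShapeVector so the example compiles without ONNX Runtime headers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in: pad `shape` with leading 1s until it has `total_rank` dims.
std::vector<int64_t> prepend_ones(const std::vector<int64_t>& shape,
                                  std::size_t total_rank) {
  assert(shape.size() <= total_rank);  // padding can only raise the rank
  // Start with the required number of leading 1-dimensions...
  std::vector<int64_t> out(total_rank - shape.size(), 1);
  // ...then append the original dimensions unchanged.
  out.insert(out.end(), shape.begin(), shape.end());
  return out;
}
```

For example, padding a shape of {3, 4} to rank 4 gives {1, 1, 3, 4}, which lines up dimension by dimension with a rank-4 output shape. A shape that already has the target rank is returned unchanged.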
Import
#include "orttraining/training_ops/cuda/math/div_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dY | Tensor(T) | Yes | Upstream gradient with broadcast output shape |
| a | Tensor(T) | Yes | Left operand (numerator) from forward pass |
| b | Tensor(T) | Yes | Right operand (denominator) from forward pass |
Outputs
| Name | Type | Description |
|---|---|---|
| da | Tensor(T) | Gradient with respect to the numerator a (optional; same shape as a) |
| db | Tensor(T) | Gradient with respect to the denominator b (optional; same shape as b) |
Usage Examples
DIVGRAD_REGISTER_KERNEL_TYPED(float)
DIVGRAD_REGISTER_KERNEL_TYPED(MLFloat16)
DIVGRAD_REGISTER_KERNEL_TYPED(double)