Implementation:Microsoft Onnxruntime CUDA ResizeGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for computing the gradient of 2D resize (interpolation) operations in the ONNX Runtime CUDA training framework.
Description
Implements the ResizeGrad operator for CUDA, which computes the gradient of bilinear resize for 4-D tensors in NCHW layout. The kernel reads the height and width scale factors from indices 2 and 3 of the optional scales input, supports the align_corners coordinate transformation mode, and calls ResizeGradImpl to distribute the upstream gradient dY back over the input spatial dimensions. When the input and output shapes are identical (a scale factor of 1), dY is copied into dX with a simple memcpy. Otherwise, dX is first zero-initialized, and the gradients are then accumulated into it. The kernel is registered for the MLFloat16, float, and double types. The roi and scales inputs are kept in CPU memory.
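The accumulation step can be made concrete with a single-threaded CPU reference, written here as a sketch rather than as ONNX Runtime code. The names `ResizeGradBilinearRef`, `SourceRatio`, and `SourceCoord` are hypothetical. The coordinate mapping is an assumption modeled on the common PyTorch-style convention: with align_corners, `src = dst * (in - 1) / (out - 1)`; otherwise a half-pixel mapping clamped at 0, using `1 / scale` when a scale factor is supplied. It has not been confirmed to match ResizeGradImpl exactly. The sketch mirrors the structure described above: a copy when the shapes match, otherwise zero-initialization followed by scattering each dY element onto its four bilinear neighbours.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Ratio mapping an output coordinate to an input coordinate.
// Assumed PyTorch-style convention, not confirmed against ResizeGradImpl.
float SourceRatio(int64_t in_size, int64_t out_size, bool align_corners,
                  std::optional<float> scale) {
  if (align_corners) {
    return out_size > 1
               ? static_cast<float>(in_size - 1) / static_cast<float>(out_size - 1)
               : 0.0f;
  }
  if (scale.has_value() && *scale > 0.0f) {
    return 1.0f / *scale;  // explicit scale factor from the scales input
  }
  return static_cast<float>(in_size) / static_cast<float>(out_size);
}

// Continuous source coordinate for output index dst.
float SourceCoord(float ratio, int64_t dst, bool align_corners) {
  if (align_corners) {
    return ratio * static_cast<float>(dst);
  }
  const float src = ratio * (static_cast<float>(dst) + 0.5f) - 0.5f;
  return src < 0.0f ? 0.0f : src;  // clamp to the first input pixel
}

// Hypothetical CPU reference for the bilinear resize gradient on NCHW data.
// dy has shape [n, c, h_out, w_out]; the result has shape [n, c, h_in, w_in].
std::vector<float> ResizeGradBilinearRef(const std::vector<float>& dy,
                                         int64_t n, int64_t c,
                                         int64_t h_in, int64_t w_in,
                                         int64_t h_out, int64_t w_out,
                                         bool align_corners,
                                         std::optional<float> scale_h,
                                         std::optional<float> scale_w) {
  assert(dy.size() == static_cast<std::size_t>(n * c * h_out * w_out));
  if (h_in == h_out && w_in == w_out) {
    return dy;  // same shape: the gradient passes through unchanged
  }
  // Zero-initialize, because several dY elements accumulate into one dX cell.
  std::vector<float> dx(static_cast<std::size_t>(n * c * h_in * w_in), 0.0f);
  const float rh = SourceRatio(h_in, h_out, align_corners, scale_h);
  const float rw = SourceRatio(w_in, w_out, align_corners, scale_w);
  for (int64_t plane = 0; plane < n * c; ++plane) {
    const float* g = dy.data() + plane * h_out * w_out;
    float* d = dx.data() + plane * h_in * w_in;
    for (int64_t oy = 0; oy < h_out; ++oy) {
      const float sy = SourceCoord(rh, oy, align_corners);
      const int64_t y0 = std::min<int64_t>(static_cast<int64_t>(sy), h_in - 1);
      const int64_t y1 = std::min<int64_t>(y0 + 1, h_in - 1);
      const float ly = sy - static_cast<float>(y0);
      const float hy = 1.0f - ly;
      for (int64_t ox = 0; ox < w_out; ++ox) {
        const float sx = SourceCoord(rw, ox, align_corners);
        const int64_t x0 = std::min<int64_t>(static_cast<int64_t>(sx), w_in - 1);
        const int64_t x1 = std::min<int64_t>(x0 + 1, w_in - 1);
        const float lx = sx - static_cast<float>(x0);
        const float hx = 1.0f - lx;
        const float grad = g[oy * w_out + ox];
        // Scatter the upstream gradient to the four neighbours with the
        // same bilinear weights the forward pass used to gather them.
        d[y0 * w_in + x0] += hy * hx * grad;
        d[y0 * w_in + x1] += hy * lx * grad;
        d[y1 * w_in + x0] += ly * hx * grad;
        d[y1 * w_in + x1] += ly * lx * grad;
      }
    }
  }
  return dx;
}
```

For example, with align_corners and a 1x2 input resized to 1x3, `ResizeGradBilinearRef({1, 2, 3}, 1, 1, 1, 2, 1, 3, true, std::nullopt, std::nullopt)` yields `{2, 4}`: the middle output pixel splits its gradient evenly between the two input pixels. Because the four weights of each output pixel sum to 1, the total gradient is conserved. A GPU version that parallelizes over dY elements would need atomic additions (or an equivalent gather formulation), since neighbouring output pixels write to the same dX cells.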
Usage
Invoked during the backward pass when the model contains resize/upsample operations, commonly found in image segmentation networks.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc
- Lines: 1-81
Signature
template <typename T>
class ResizeGrad : public CudaKernel {
 public:
  Status ComputeInternal(OpKernelContext* context) const override;
};
Import
#include "orttraining/training_ops/cuda/tensor/resize_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dY | Tensor(T) | Yes | Upstream gradient with output shape [N, C, H_out, W_out] |
| X | Tensor(T) | Yes | Original input tensor [N, C, H_in, W_in] |
| roi | Tensor(T) | No | Region of interest (CPU memory) |
| scales | Tensor(float) | No | Scale factors [1, 1, scale_h, scale_w] (CPU memory) |
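Because scales lives in CPU memory, host code can read the two spatial factors directly, without a device-to-host copy. The following is a minimal C++ sketch of that extraction. It assumes that a null pointer models the absent optional input, and `ReadSpatialScales` is a hypothetical name, not an ONNX Runtime API.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical helper: extracts the height and width factors from an
// optional NCHW-ordered scales input. nullptr models an absent input.
std::pair<std::optional<float>, std::optional<float>> ReadSpatialScales(
    const std::vector<float>* scales) {
  if (scales == nullptr) {
    return {std::nullopt, std::nullopt};  // no explicit factors supplied
  }
  assert(scales->size() == 4);  // one factor per NCHW dimension
  return {(*scales)[2], (*scales)[3]};  // scale_h, scale_w
}
```

Returning `std::optional` rather than a sentinel such as 0 keeps "no scales supplied" distinct from a real factor.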
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(T) | Gradient with respect to input [N, C, H_in, W_in] |
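The output contract can be stated explicitly. The notation here is introduced for illustration and is not taken from the source. Let $w^{(h)}_{p,i}$ be the bilinear weight that output row $p$ assigns to input row $i$ in the forward resize (nonzero only for the two rows bracketing the source coordinate), and define $w^{(w)}_{q,j}$ likewise for columns. The forward pass computes $Y_{n,c,p,q} = \sum_{i,j} w^{(h)}_{p,i}\, w^{(w)}_{q,j}\, X_{n,c,i,j}$, so by the chain rule:

$$
dX_{n,c,i,j} = \sum_{p=0}^{H_{out}-1} \sum_{q=0}^{W_{out}-1} w^{(h)}_{p,i}\, w^{(w)}_{q,j}\, dY_{n,c,p,q}
$$

Since $\sum_i w^{(h)}_{p,i} = 1$ and $\sum_j w^{(w)}_{q,j} = 1$ for every $p$ and $q$, it follows that $\sum_{i,j} dX_{n,c,i,j} = \sum_{p,q} dY_{n,c,p,q}$ for each $(n, c)$ plane. This is a convenient sanity check when testing the kernel.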
Usage Examples
Kernel registrations, one per supported element type:
REGISTER_RESIZEGRAD_KERNEL_TYPED(float)
REGISTER_RESIZEGRAD_KERNEL_TYPED(MLFloat16)
REGISTER_RESIZEGRAD_KERNEL_TYPED(double)