Implementation:Microsoft Onnxruntime CUDA ConvGrad

From Leeroopedia


Knowledge Sources
Domains Training, CUDA_Kernels
Last Updated 2026-02-10 04:00 GMT

Overview

Concrete tool for computing convolution gradients in the ONNX Runtime CUDA training framework.

Description

Implements the ConvGrad operator for CUDA, which computes the gradients of the standard convolution backward pass. The PrepareArgs method validates the input shapes, resolves the kernel shape, padding, dilations, and strides, and sets up the cuDNN tensor and convolution descriptors. It caches dimension information so that descriptors are re-created only when the dimensions change. The ComputeInternal method computes up to three gradient outputs using cuDNN: dB (bias gradient, via backward bias), dW (weight gradient, via backward filter with algorithm search), and dX (input gradient, via backward data with algorithm search). Algorithm selection uses AlgoIterator to try multiple cuDNN algorithms for best performance. The kernel is registered for the float, double, and MLFloat16 types.
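The three gradients named above can be illustrated with a plain CPU reference. The following is a minimal sketch for the 1-D case, assuming row-major layouts, symmetric padding, and cross-correlation semantics; `ConvGrads1D` and `ConvGradReference1D` are hypothetical names introduced for illustration and are not part of ONNX Runtime, which delegates this work to cuDNN.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for a 1-D convolution backward pass.
// This is NOT the ONNX Runtime implementation; it only spells out the math.
// Layouts (row-major): x is [N, C, L], w is [M, C, K], dY is [N, M, Lout].
struct ConvGrads1D {
  std::vector<float> dX;  // same shape as x: [N, C, L]
  std::vector<float> dW;  // same shape as w: [M, C, K]
  std::vector<float> dB;  // one value per output channel: [M]
};

ConvGrads1D ConvGradReference1D(const std::vector<float>& x,
                                const std::vector<float>& w,
                                const std::vector<float>& dY,
                                int N, int C, int L, int M, int K,
                                int stride, int pad, int dilation) {
  // Output length of the forward convolution (symmetric padding assumed).
  const int Lout = (L + 2 * pad - dilation * (K - 1) - 1) / stride + 1;
  assert(dY.size() == static_cast<std::size_t>(N) * M * Lout);

  ConvGrads1D g{std::vector<float>(x.size(), 0.f),
                std::vector<float>(w.size(), 0.f),
                std::vector<float>(M, 0.f)};

  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      for (int o = 0; o < Lout; ++o) {
        const float dy = dY[(n * M + m) * Lout + o];
        g.dB[m] += dy;  // bias gradient: sum of dY over batch and positions
        for (int c = 0; c < C; ++c) {
          for (int k = 0; k < K; ++k) {
            const int i = o * stride - pad + k * dilation;  // input position
            if (i < 0 || i >= L) continue;                  // falls in padding
            g.dW[(m * C + c) * K + k] += dy * x[(n * C + c) * L + i];
            g.dX[(n * C + c) * L + i] += dy * w[(m * C + c) * K + k];
          }
        }
      }
    }
  }
  return g;
}
```

For example, with x = {1, 2, 3}, w = {10, 20}, and dY = {1, 1} (one batch, one channel, stride 1, no padding), this sketch yields dB = {2}, dW = {3, 5}, and dX = {10, 30, 20}.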

Usage

Invoked during the backward pass of training for standard convolution layers (Conv1d, Conv2d, Conv3d).

Code Reference

Source Location

Signature

template <typename T>
class ConvGrad : public CudaKernel {
  Status PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w,
                     Tensor* dB, Tensor* dX, Tensor* dW, cudnnHandle_t cudnn_handle) const;
  Status ComputeInternal(OpKernelContext* context) const;
};
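The dimension caching attributed to PrepareArgs can be sketched independently of cuDNN. The example below is an assumption-laden illustration: `CachedArgs`, `PrepareArgsSketch`, and `rebuild_count` are hypothetical names, and the counter stands in for the descriptor setup that the real kernel would perform.

```cpp
#include <cstdint>
#include <vector>

using Dims = std::vector<std::int64_t>;

// Hypothetical stand-in for the kernel's cached state. The real kernel would
// keep cuDNN descriptors here; this sketch only counts rebuilds.
struct CachedArgs {
  Dims last_x_dims;
  Dims last_dy_dims;
  Dims last_w_dims;
  int rebuild_count = 0;  // how many times the "descriptors" were rebuilt
};

// Returns true when the cached state had to be rebuilt, false on a cache hit.
bool PrepareArgsSketch(CachedArgs& cache, const Dims& x_dims,
                       const Dims& dy_dims, const Dims& w_dims) {
  const bool changed = x_dims != cache.last_x_dims ||
                       dy_dims != cache.last_dy_dims ||
                       w_dims != cache.last_w_dims;
  if (!changed) {
    return false;  // same shapes as last call: reuse existing descriptors
  }
  // A real implementation would validate shapes and reconfigure the cuDNN
  // tensor, filter, and convolution descriptors at this point.
  cache.last_x_dims = x_dims;
  cache.last_dy_dims = dy_dims;
  cache.last_w_dims = w_dims;
  ++cache.rebuild_count;
  return true;
}
```

Because training loops usually feed the same batch shape on every step, a check of this kind means the setup cost is paid once and skipped on later iterations.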

Import

#include "orttraining/training_ops/cuda/nn/conv_grad.h"

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| dY | Tensor(T) | Yes | Gradient of loss with respect to convolution output |
| X | Tensor(T) | Yes | Original input tensor from forward pass |
| W | Tensor(T) | Yes | Convolution weight tensor from forward pass |

Outputs

| Name | Type | Description |
|------|------|-------------|
| dX | Tensor(T) | Gradient with respect to input (optional) |
| dW | Tensor(T) | Gradient with respect to weights (optional) |
| dB | Tensor(T) | Gradient with respect to bias (optional) |
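The contract implies some shape relations: each gradient has the shape of the tensor it differentiates (dX matches X, dW matches W), dB has one element per output channel, and the spatial extent of dY follows the usual convolution output-size formula. A minimal sketch of that arithmetic, using the hypothetical helper names `ConvOutputSize` and `BiasGradShape` (not ONNX Runtime functions):

```cpp
#include <cstdint>
#include <vector>

// Spatial output size of a convolution along one axis, using the standard
// floor formula. Hypothetical helper for illustration only.
std::int64_t ConvOutputSize(std::int64_t in, std::int64_t kernel,
                            std::int64_t stride, std::int64_t pad_begin,
                            std::int64_t pad_end, std::int64_t dilation) {
  const std::int64_t effective_kernel = dilation * (kernel - 1) + 1;
  return (in + pad_begin + pad_end - effective_kernel) / stride + 1;
}

// dB holds one value per output channel, i.e. the first dimension of W,
// which is laid out as [M, C/group, k1, k2, ...].
std::vector<std::int64_t> BiasGradShape(const std::vector<std::int64_t>& w_dims) {
  return {w_dims.at(0)};
}
```

For instance, a 3x3 kernel with stride 1 and padding 1 on each side keeps a 32-wide axis at 32, while the same kernel with stride 2 and no padding maps a 7-wide axis to 3.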

Usage Examples

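// Registers the float specialization of ConvGrad (Microsoft domain, version 1)
// with the CUDA execution provider.
ONNX_OPERATOR_TYPED_KERNEL_EX(ConvGrad, kMSDomain, 1, float,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    ConvGrad<float>);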
