Implementation:Microsoft Onnxruntime CUDA BatchNormGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel that computes batch normalization gradients in the ONNX Runtime training framework.
Description
Implements the BatchNormalizationGrad operator for CUDA, which computes the batch normalization backward pass by calling cuDNN's cudnnBatchNormalizationBackward. The operator takes the upstream gradient dY, the original input X, the scale parameter, and the mean and inverse standard deviation saved during the forward pass. It produces three gradient outputs: dX (input gradient), dScale (scale gradient), and dBias (bias gradient). The implementation supports spatial mode (CUDNN_BATCHNORM_SPATIAL) with the CUDNN_BATCHNORM_SPATIAL_PERSISTENT flag and sets up the tensor descriptors for the per-channel statistics. The template parameters T, T1, and T2 give the element types of the data tensors (dY, X, dX), the scale tensors (Scale, dScale, dBias), and the saved statistics (saved_mean, saved_inv_std), respectively. Supported mixed precision configurations include float/float/float, double/double/double, MLFloat16/MLFloat16/MLFloat16, MLFloat16/MLFloat16/float, and MLFloat16/float/float.
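For orientation, the three outputs correspond to the standard batch normalization backward formulas, written here per channel c. The notation is introduced only for this sketch and does not come from the source file: m is the number of elements reduced per channel (N·H·W for a 4-D input), μ_c is saved_mean, s_c is saved_inv_std, and γ_c is Scale.

```latex
\begin{aligned}
\hat{x}_i &= (x_i - \mu_c)\, s_c \\
\mathrm{dBias}_c &= \sum_{i=1}^{m} dy_i \\
\mathrm{dScale}_c &= \sum_{i=1}^{m} dy_i\, \hat{x}_i \\
dx_i &= \gamma_c\, s_c \left( dy_i - \frac{\mathrm{dBias}_c}{m} - \hat{x}_i\, \frac{\mathrm{dScale}_c}{m} \right)
\end{aligned}
```

Because the forward pass already saved μ_c and s_c, the backward pass does not need to recompute the batch statistics from X.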
Usage
Invoked during the backward pass of training whenever the model contains batch normalization layers.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/nn/batch_norm_grad.cc
- Lines: 1-135
Signature
template <typename T, typename T1, typename T2>
class BatchNormalizationGrad : public CudaKernel {
 public:
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
Import
#include "orttraining/training_ops/cuda/nn/batch_norm_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dY | Tensor(T) | Yes | Gradient of loss with respect to batch norm output |
| X | Tensor(T) | Yes | Original input tensor from forward pass |
| Scale | Tensor(T1) | Yes | Scale parameter (gamma) |
| saved_mean | Tensor(T2) | Yes | Batch mean saved during forward pass |
| saved_inv_std | Tensor(T2) | Yes | Batch inverse standard deviation saved during forward pass |
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(T) | Gradient with respect to input |
| dScale | Tensor(T1) | Gradient with respect to scale parameter |
| dBias | Tensor(T1) | Gradient with respect to bias parameter |
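The contract above can be illustrated with a small CPU reference. This is a minimal sketch, not the ONNX Runtime kernel or the cuDNN routine: the names BatchNormGradOutputs and BatchNormGradReference are hypothetical, tensors are assumed to be flat row-major [N, C, S] buffers (S being the product of the spatial dimensions), and plain float/double stand in for MLFloat16.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container for the three outputs listed in the table above.
template <typename T, typename T1>
struct BatchNormGradOutputs {
  std::vector<T> dX;       // same shape as X
  std::vector<T1> dScale;  // one value per channel
  std::vector<T1> dBias;   // one value per channel
};

// Hypothetical CPU reference for spatial batch-norm backward (illustration only).
// T:  element type of dY, X, dX
// T1: element type of Scale, dScale, dBias
// T2: element type of saved_mean, saved_inv_std
template <typename T, typename T1, typename T2>
BatchNormGradOutputs<T, T1> BatchNormGradReference(
    const std::vector<T>& dY, const std::vector<T>& X,
    const std::vector<T1>& scale, const std::vector<T2>& saved_mean,
    const std::vector<T2>& saved_inv_std,
    std::size_t N, std::size_t C, std::size_t S) {
  assert(dY.size() == N * C * S && X.size() == N * C * S);
  assert(scale.size() == C && saved_mean.size() == C && saved_inv_std.size() == C);

  BatchNormGradOutputs<T, T1> out{std::vector<T>(N * C * S),
                                  std::vector<T1>(C), std::vector<T1>(C)};
  const double m = static_cast<double>(N * S);  // elements reduced per channel

  for (std::size_t c = 0; c < C; ++c) {
    const double mean = static_cast<double>(saved_mean[c]);
    const double inv_std = static_cast<double>(saved_inv_std[c]);

    // First pass: per-channel reductions, accumulated in double.
    double sum_dy = 0.0;       // becomes dBias[c]
    double sum_dy_xhat = 0.0;  // becomes dScale[c]
    for (std::size_t n = 0; n < N; ++n) {
      for (std::size_t s = 0; s < S; ++s) {
        const std::size_t i = (n * C + c) * S + s;
        const double x_hat = (static_cast<double>(X[i]) - mean) * inv_std;
        sum_dy += static_cast<double>(dY[i]);
        sum_dy_xhat += static_cast<double>(dY[i]) * x_hat;
      }
    }
    out.dBias[c] = static_cast<T1>(sum_dy);
    out.dScale[c] = static_cast<T1>(sum_dy_xhat);

    // Second pass: dx = scale * inv_std * (dy - mean(dy) - x_hat * mean(dy * x_hat)).
    const double g = static_cast<double>(scale[c]) * inv_std;
    for (std::size_t n = 0; n < N; ++n) {
      for (std::size_t s = 0; s < S; ++s) {
        const std::size_t i = (n * C + c) * S + s;
        const double x_hat = (static_cast<double>(X[i]) - mean) * inv_std;
        const double dy = static_cast<double>(dY[i]);
        out.dX[i] = static_cast<T>(g * (dy - sum_dy / m - x_hat * sum_dy_xhat / m));
      }
    }
  }
  return out;
}
```

The three template parameters mirror the T/T1/T2 split in the tables, so a call such as BatchNormGradReference<float, double, double>(...) models a configuration where the data tensors use a narrower type than the scale and statistics. Accumulating in double is a choice made for this sketch, not a statement about how the GPU kernel accumulates.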
Usage Examples
// Register the mixed precision kernel: MLFloat16 data (T), float scale (T1), float saved statistics (T2)
REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16, float, float)
// At run time, ComputeInternal calls cudnnBatchNormalizationBackward with CUDNN_BATCHNORM_SPATIAL
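In spatial mode, cuDNN expects the per-channel tensors (scale, bias gradients, saved statistics) to be described as 1xCx1x1 for 4-D data and 1xCx1x1x1 for 5-D data, which is what cudnnDeriveBNTensorDescriptor documents. The helper below is a hypothetical, CPU-only sketch of that shape rule; SpatialStatsDims is not an ONNX Runtime or cuDNN function, and the source file may set up its descriptors differently.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper: given the dimensions of X in [N, C, ...] layout,
// return the dimensions of the per-channel tensors used in spatial mode.
std::vector<std::int64_t> SpatialStatsDims(const std::vector<std::int64_t>& x_dims) {
  // cuDNN batch normalization works with 4-D or 5-D tensor descriptors.
  if (x_dims.size() != 4 && x_dims.size() != 5) {
    throw std::invalid_argument("expected a 4-D or 5-D input shape");
  }
  std::vector<std::int64_t> dims(x_dims.size(), 1);  // all ones ...
  dims[1] = x_dims[1];                               // ... except the channel axis
  return dims;
}
```

For an input of shape [8, 3, 32, 32], the helper yields [1, 3, 1, 1], so Scale, dScale, dBias, saved_mean, and saved_inv_std each hold three values.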