Implementation:Microsoft Onnxruntime CPU BatchNormGrad
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for computing batch normalization gradients on CPU in the ONNX Runtime training framework.
Description
This file implements the BatchNormalizationGrad kernel for CPU training. Given the upstream gradient (dY), it computes gradients with respect to the input (dX), scale (d_scale), and bias (d_bias). The kernel uses Eigen array operations to process each channel's data in bulk. It first computes x_hat = (X - saved_mean) * saved_inv_std, then reduces per channel over the batch and spatial dimensions: d_scale = sum(dY * x_hat) and d_bias = sum(dY). The input gradient follows the standard batch normalization backward formula: dX = scale * saved_inv_std * (dY - mean(dY) - x_hat * mean(dY * x_hat)), where each mean is taken over the same per-channel elements as the sums. The kernel supports both float and double types, registered under kMSDomain opset 9.
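The computation described above can be sketched with plain loops instead of Eigen. This is a minimal reference under stated assumptions, not the kernel's actual code: the struct `BatchNormGrads` and the function `BatchNormGradReference` are hypothetical names, the tensors are assumed to be contiguous in [N, C, S] order (S being the product of the spatial dimensions), and the input gradient uses the standard batch normalization backward formula with plain per-channel means and no additional correction factor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical container for the three outputs; not an ONNX Runtime type.
struct BatchNormGrads {
  std::vector<double> dX;       // same layout as X: [N, C, S]
  std::vector<double> d_scale;  // [C]
  std::vector<double> d_bias;   // [C]
};

// Reference computation over contiguous [N, C, S] buffers (assumed layout).
BatchNormGrads BatchNormGradReference(const std::vector<double>& dY,
                                      const std::vector<double>& X,
                                      const std::vector<double>& scale,
                                      const std::vector<double>& saved_mean,
                                      const std::vector<double>& saved_inv_std,
                                      std::size_t N, std::size_t C, std::size_t S) {
  assert(X.size() == N * C * S && dY.size() == X.size());
  assert(scale.size() == C && saved_mean.size() == C && saved_inv_std.size() == C);

  BatchNormGrads g{std::vector<double>(X.size(), 0.0),
                   std::vector<double>(C, 0.0),
                   std::vector<double>(C, 0.0)};

  // Pass 1: per-channel sums over the batch and spatial positions.
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t s = 0; s < S; ++s) {
        const std::size_t i = (n * C + c) * S + s;
        const double x_hat = (X[i] - saved_mean[c]) * saved_inv_std[c];
        g.d_bias[c] += dY[i];
        g.d_scale[c] += dY[i] * x_hat;
      }
    }
  }

  // Pass 2: dX = scale * inv_std * (dY - mean(dY) - x_hat * mean(dY * x_hat)).
  const double M = static_cast<double>(N * S);  // elements reduced per channel
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t s = 0; s < S; ++s) {
        const std::size_t i = (n * C + c) * S + s;
        const double x_hat = (X[i] - saved_mean[c]) * saved_inv_std[c];
        g.dX[i] = scale[c] * saved_inv_std[c] *
                  (dY[i] - g.d_bias[c] / M - x_hat * g.d_scale[c] / M);
      }
    }
  }
  return g;
}
```

Two passes are needed because every element of dX depends on the completed per-channel sums. A useful sanity check on any implementation of this formula is that, for each channel, both sum(dX) and sum(dX * x_hat) come out as zero whenever saved_mean and saved_inv_std are the exact statistics of X.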
Usage
This kernel is invoked during the backward pass when a batch normalization layer is present in the training graph. It receives the upstream gradient together with the mean and inverse standard deviation saved by the forward pass, and uses them to compute the parameter and input gradients.
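To make the saved statistics concrete, the sketch below shows one way a forward pass could produce them. It is illustrative only: `SavedStats` and `ComputeSavedStats` are hypothetical names, and the use of the biased (divide-by-count) variance and an `epsilon` term inside the square root are assumptions based on the usual batch normalization definition, not details confirmed by this page.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical holder for the per-channel statistics handed to the backward pass.
struct SavedStats {
  std::vector<double> mean;     // [C]
  std::vector<double> inv_std;  // [C], assumed to be 1 / sqrt(variance + epsilon)
};

// Computes per-channel statistics of a contiguous [N, C, S] buffer (assumed layout).
SavedStats ComputeSavedStats(const std::vector<double>& X,
                             std::size_t N, std::size_t C, std::size_t S,
                             double epsilon) {
  assert(X.size() == N * C * S);
  const double M = static_cast<double>(N * S);  // elements per channel

  SavedStats stats{std::vector<double>(C, 0.0), std::vector<double>(C, 0.0)};
  for (std::size_t c = 0; c < C; ++c) {
    // Mean over batch and spatial positions of channel c.
    double sum = 0.0;
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t s = 0; s < S; ++s)
        sum += X[(n * C + c) * S + s];
    stats.mean[c] = sum / M;

    // Biased variance (assumption), then inverse standard deviation.
    double sq = 0.0;
    for (std::size_t n = 0; n < N; ++n)
      for (std::size_t s = 0; s < S; ++s) {
        const double d = X[(n * C + c) * S + s] - stats.mean[c];
        sq += d * d;
      }
    stats.inv_std[c] = 1.0 / std::sqrt(sq / M + epsilon);
  }
  return stats;
}
```

Saving the inverse standard deviation rather than the variance lets the backward pass multiply instead of recomputing a square root and a division for every channel.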
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/nn/batch_norm_grad.cc
- Lines: 1-84
Signature
template <typename T>
Status BatchNormalizationGrad<T>::Compute(OpKernelContext* context) const;
Import
#include "orttraining/orttraining/training_ops/cpu/nn/batch_norm_grad.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dY | Tensor(T) | Yes | Upstream gradient [N, C, spatial...] |
| X | Tensor(T) | Yes | Input tensor from forward pass [N, C, spatial...] |
| scale | Tensor(T) | Yes | Scale parameter [C] |
| saved_mean | Tensor(float) | Yes | Saved batch mean from forward [C] |
| saved_inv_std | Tensor(float) | Yes | Saved inverse standard deviation from forward [C] |
Outputs
| Name | Type | Description |
|---|---|---|
| dX | Tensor(T) | Gradient w.r.t. input X [N, C, spatial...] |
| d_scale | Tensor(T) | Gradient w.r.t. scale [C] |
| d_bias | Tensor(T) | Gradient w.r.t. bias [C] |
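The shape relationships in the tables above can be expressed as a small checker. This is a hypothetical helper written for illustration (`Shape` and `CheckBatchNormGradShapes` are not ONNX Runtime names), and it does not claim to mirror whatever validation the kernel itself performs.

```cpp
#include <cstdint>
#include <string>
#include <vector>

using Shape = std::vector<std::int64_t>;  // hypothetical alias for tensor dimensions

// Returns an empty string if the input shapes match the I/O contract,
// otherwise a short description of the first mismatch found.
std::string CheckBatchNormGradShapes(const Shape& dY, const Shape& X,
                                     const Shape& scale,
                                     const Shape& saved_mean,
                                     const Shape& saved_inv_std) {
  if (X.size() < 2) return "X must have rank >= 2 ([N, C, spatial...])";
  if (dY != X) return "dY must have the same shape as X";

  const Shape per_channel{X[1]};  // [C]
  if (scale != per_channel) return "scale must have shape [C]";
  if (saved_mean != per_channel) return "saved_mean must have shape [C]";
  if (saved_inv_std != per_channel) return "saved_inv_std must have shape [C]";
  return "";
}
```

When the check passes, the outputs follow directly: dX takes the shape of X, while d_scale and d_bias each take the shape [C].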
Usage Examples
ONNX_OPERATOR_TYPED_KERNEL_EX(
    BatchNormalizationGrad, kMSDomain, 9, float, kCpuExecutionProvider,
    KernelDefBuilder()
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    BatchNormalizationGrad<float>);