Overview
CPU kernels that compute pooling gradients (MaxPoolGrad and AveragePoolGrad) in the ONNX Runtime training framework.
Description
This file implements two pooling gradient kernels: MaxPoolGrad and AveragePoolGrad. The MaxPoolGrad kernel scatters the upstream gradient to the positions given by the indices tensor, which the MaxPool forward pass produces to record which input element was the maximum in each window. The AveragePoolGrad kernel distributes the upstream gradient evenly across the input elements that contributed to each output element, with support for 1D, 2D, and 3D pooling. When computing the divisor for each window, the average pool gradient respects the count_include_pad setting. Both kernels support the NCHW storage order and are registered under opset 9.
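The scatter step described for MaxPoolGrad can be illustrated with a minimal sketch. The helper name `ScatterMaxPoolGrad` and the use of `std::vector` are assumptions made for illustration and are not part of ONNX Runtime. The sketch also assumes that indices are flat offsets into dX and that gradients landing on the same position accumulate (which matters when pooling windows overlap).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not an ONNX Runtime API): scatters each upstream
// gradient value dy[i] into a zero-initialized dX buffer at the flat
// position indices[i] recorded by the forward MaxPool pass.
std::vector<float> ScatterMaxPoolGrad(const std::vector<float>& dy,
                                      const std::vector<int64_t>& indices,
                                      std::size_t dx_size) {
  assert(dy.size() == indices.size());
  std::vector<float> dx(dx_size, 0.0f);  // positions that were never a max stay 0
  for (std::size_t i = 0; i < dy.size(); ++i) {
    const int64_t idx = indices[i];
    assert(idx >= 0 && static_cast<std::size_t>(idx) < dx_size);
    // Assumption: accumulate, so an input that was the max of several
    // overlapping windows receives the sum of their gradients.
    dx[static_cast<std::size_t>(idx)] += dy[i];
  }
  return dx;
}
```

For example, calling `ScatterMaxPoolGrad({1.0f, 2.0f, 3.0f}, {1, 1, 3}, 5)` leaves positions 0, 2, and 4 at zero, because no window selected them as its maximum.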
Usage
These kernels are invoked during the backward pass when MaxPool or AveragePool layers are present in the training graph. MaxPoolGrad uses the pooling indices to route each gradient value to the input element that produced the maximum, while AveragePoolGrad spreads each gradient value evenly over the input elements of its pooling window.
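How AveragePoolGrad distributes a gradient, and how count_include_pad changes the divisor, can be sketched for a single 1D channel. The function `AveragePoolGrad1D` and its parameters below are hypothetical and chosen for illustration; the actual kernel also iterates over batch and channel dimensions and may be structured differently.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical single-channel 1D sketch (not the ONNX Runtime kernel).
// Each dy[o] is divided by the window size and added to every input
// element that the window covers.
std::vector<float> AveragePoolGrad1D(const std::vector<float>& dy,
                                     int64_t input_len, int64_t kernel,
                                     int64_t stride, int64_t pad_begin,
                                     bool count_include_pad) {
  std::vector<float> dx(static_cast<std::size_t>(input_len), 0.0f);
  for (std::size_t o = 0; o < dy.size(); ++o) {
    // Window position in input coordinates; may start inside the padding.
    const int64_t raw_start = static_cast<int64_t>(o) * stride - pad_begin;
    const int64_t start = std::max<int64_t>(raw_start, 0);
    const int64_t end = std::min<int64_t>(raw_start + kernel, input_len);
    if (end <= start) continue;  // window lies entirely in padding
    // With count_include_pad the divisor is the full kernel size;
    // otherwise only the elements inside the input are counted.
    const int64_t divisor = count_include_pad ? kernel : (end - start);
    const float share = dy[o] / static_cast<float>(divisor);
    for (int64_t i = start; i < end; ++i) {
      dx[static_cast<std::size_t>(i)] += share;
    }
  }
  return dx;
}
```

With an input of length 3, kernel 2, stride 2, and one element of leading padding, the first window covers one real element and one padded element. In this sketch the first input element receives all of dy[0] when count_include_pad is false and half of it when count_include_pad is true.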
Code Reference
Source Location
Signature
std::vector<VectorInt64> InferOutputShapes(OpKernelInfo info);
template <typename T>
Status MaxPoolGrad<T>::Compute(OpKernelContext* context) const;
template <typename T>
Status AveragePoolGrad<T>::Compute(OpKernelContext* context) const;
template <typename T>
Status AveragePoolGrad<T>::Compute1DAveragePoolGrad(OpKernelContext* context) const;
template <typename T>
Status AveragePoolGrad<T>::Compute2DAveragePoolGrad(OpKernelContext* context) const;
template <typename T>
Status AveragePoolGrad<T>::Compute3DAveragePoolGrad(OpKernelContext* context) const;
Import
#include "orttraining/training_ops/cpu/nn/pool_gradient_op.h"
I/O Contract
Inputs (MaxPoolGrad)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| dY | Tensor(float) | Yes | Upstream gradient |
| indices | Tensor(int64) | Yes | Indices of max elements from forward pass |
Outputs (MaxPoolGrad)
| Name | Type | Description |
|------|------|-------------|
| dX | Tensor(float) | Gradient w.r.t. pooling input |
Inputs (AveragePoolGrad)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| dY | Tensor(float) | Yes | Upstream gradient |
Outputs (AveragePoolGrad)
| Name | Type | Description |
|------|------|-------------|
| dX | Tensor(float) | Gradient w.r.t. pooling input |
Usage Examples
ONNX_CPU_OPERATOR_KERNEL(
MaxPoolGrad, 9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
MaxPoolGrad<float>);
ONNX_CPU_OPERATOR_KERNEL(
AveragePoolGrad, 9,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
AveragePoolGrad<float>);