Implementation:Microsoft Onnxruntime CPU BroadcastGradArgs
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for computing broadcast gradient reduction axes on CPU in the ONNX Runtime training framework.
Description
This file implements the BroadcastGradientArgs kernel. When two tensors were broadcast together during the forward pass, it computes the axes along which their gradients must be reduced. Given the two input shapes (A_shape and B_shape), it produces two lists of axes, one per input. Each list names the dimensions along which that input was broadcast, which are the dimensions its gradient must be reduced over in the backward pass. The algorithm right-aligns the two shapes and compares dimensions from the innermost outward. An axis is marked for reduction on an input in two cases: that input's dimension is 1 while the other input's is not, so it was broadcast; or the dimension is absent because that input has lower rank. The kernel is registered under kMSDomain opset 1.
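The right-aligned comparison can be sketched as a standalone C++ function. This is a minimal illustration, not the ONNX Runtime source. Several details are assumptions: the name `ComputeBroadcastGradAxes`, the use of `std::vector` in place of ORT tensors, and throwing `std::invalid_argument` on incompatible shapes. Axes are numbered in the rank of the broadcast result and are emitted from the innermost axis outward.

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical helper (not the ORT kernel): returns {a_axes, b_axes}.
// Axes index the broadcast result, whose rank is max(rank(A), rank(B)),
// and are appended from the last axis toward the first.
std::pair<std::vector<int64_t>, std::vector<int64_t>> ComputeBroadcastGradAxes(
    const std::vector<int64_t>& a_shape, const std::vector<int64_t>& b_shape) {
  std::vector<int64_t> a_axes, b_axes;
  int64_t i = static_cast<int64_t>(a_shape.size()) - 1;  // innermost dim of A
  int64_t j = static_cast<int64_t>(b_shape.size()) - 1;  // innermost dim of B
  int64_t k = static_cast<int64_t>(std::max(a_shape.size(), b_shape.size())) - 1;

  // Walk both shapes right-aligned while both still have dimensions.
  for (; i >= 0 && j >= 0; --i, --j, --k) {
    const int64_t a_dim = a_shape[i];
    const int64_t b_dim = b_shape[j];
    if (a_dim == b_dim) continue;  // same size: no broadcasting on this axis
    if (a_dim == 1) {
      a_axes.push_back(k);  // A was stretched along axis k
    } else if (b_dim == 1) {
      b_axes.push_back(k);  // B was stretched along axis k
    } else {
      throw std::invalid_argument("shapes are not broadcast-compatible");
    }
  }

  // Any remaining leading axes exist only in the higher-rank input,
  // so the lower-rank input was broadcast along all of them.
  std::vector<int64_t>& missing = (i < 0) ? a_axes : b_axes;
  for (; k >= 0; --k) missing.push_back(k);

  return {a_axes, b_axes};
}
```

For example, take A_shape = [2, 3, 4] and B_shape = [3, 1]. B is stretched along axis 2 and lacks axis 0, so the sketch returns an empty list for A and [2, 0] for B.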
Usage
This kernel is used during the backward pass of any elementwise binary operation that involved broadcasting. The output axes are then used to reduce (sum) the gradient along the broadcast dimensions.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/nn/broadcast_grad_args.cc
- Lines: 1-91
Signature
```cpp
Status BroadcastGradientArgs::Compute(OpKernelContext* context) const;
```
Import
```cpp
#include "orttraining/orttraining/training_ops/cpu/nn/broadcast_grad_args.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A_shape | Tensor(int64) | Yes | Shape of the first input tensor |
| B_shape | Tensor(int64) | Yes | Shape of the second input tensor |
Outputs
| Name | Type | Description |
|---|---|---|
| A_axes | Tensor(int64) | Axes along which A must be reduced |
| B_axes | Tensor(int64) | Axes along which B must be reduced |
Usage Examples
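```cpp
// Register the CPU kernel for BroadcastGradientArgs under kMSDomain
// ("com.microsoft"), opset version 1.
ONNX_OPERATOR_KERNEL_EX(
    BroadcastGradientArgs,   // op name
    kMSDomain,               // domain
    1,                       // since version
    kCpuExecutionProvider,   // execution provider
    KernelDefBuilder()
        // Shape inputs and axes outputs are all int64 tensors.
        .TypeConstraint("T", DataTypeImpl::GetTensorType<int64_t>()),
    BroadcastGradientArgs);  // kernel class
```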