Implementation:Microsoft Onnxruntime CUDA ReductionOps
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel implementation of the training-specific ReduceSum operator (ReduceSumTraining) in the ONNX Runtime training framework.
Description
Implements the ReduceSumTraining operator for CUDA. ReduceSumTraining is a training variant of ReduceSum that takes the reduction axes as a runtime input tensor instead of a fixed attribute. The implementation overrides ComputeImplEx from ReduceKernel so that the axes are read from the second input tensor. If the axes tensor is empty and noop_with_empty_axes is set, the input is passed through to the output unchanged. Otherwise the reduction is performed with cuDNN's cudnnReduceTensor, with configurable keepdims and fast_reduction flags. Because cuDNN does not directly support int32 reductions, a specialized int32_t implementation casts the input to float, performs the cuDNN reduction, and casts the result back to int32_t. Supported element types are MLFloat16, float, double, and int32_t.
Usage
Used during training when dynamic reduction axes are needed, such as in gradient computation where reduction dimensions may vary.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/reduction/reduction_ops.cc
- Lines: 1-158
Signature
```cpp
template <bool allow_multi_axes>
template <typename T, cudnnReduceTensorIndices_t ReduceTensorIndices>
Status ReduceKernel<allow_multi_axes>::ComputeImplEx(
    OpKernelContext* ctx, cudnnReduceTensorOp_t cudnn_reduce_op) const;
```
Import
```cpp
#include "orttraining/training_ops/cuda/reduction/reduction_ops.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| X | Tensor(T) | Yes | Input tensor to reduce |
| axes | Tensor(int64_t) | Yes | 1D tensor specifying axes to reduce (CPU memory) |
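Because the axes input is a 1-D int64 tensor read from CPU memory, its values can be checked on the host before any GPU work is launched. The sketch below shows the kind of validation involved, assuming the ONNX ReduceSum convention that axes lie in [-r, r-1] for an input of rank r. `NormalizeAxes` is a hypothetical name, and rejecting duplicate axes is a choice made by this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch: validate a runtime axes vector against the input rank and map
// negative axes into [0, rank). Rejecting duplicates is this sketch's choice.
std::vector<int64_t> NormalizeAxes(const std::vector<int64_t>& axes, int64_t rank) {
  assert(rank >= 0);
  std::vector<bool> seen(static_cast<size_t>(rank), false);
  std::vector<int64_t> out;
  out.reserve(axes.size());
  for (int64_t a : axes) {
    if (a < -rank || a >= rank) throw std::out_of_range("axis out of range");
    const int64_t n = a < 0 ? a + rank : a;  // e.g. -1 -> rank - 1
    if (seen[static_cast<size_t>(n)]) throw std::invalid_argument("duplicate axis");
    seen[static_cast<size_t>(n)] = true;
    out.push_back(n);
  }
  return out;
}
```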
Outputs
| Name | Type | Description |
|---|---|---|
| Y | Tensor(T) | Reduced output tensor |
Usage Examples
The kernel is registered once for each supported element type:

```cpp
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, float)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, MLFloat16)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, double)
REGISTER_MS_KERNEL_TYPED(ReduceSumTraining, int32_t)
```