Implementation:Microsoft Onnxruntime CPU Adasum
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for performing Adasum all-reduce collective operations on CPU in the ONNX Runtime training framework.
Description
This file implements the AdasumAllReduce kernel, which performs a distributed all-reduce over MPI using the Adasum (Adaptive Sum) algorithm. The kernel accepts a variable number of input tensors and processes them in three steps: it copies the inputs into one contiguous fused buffer, runs the Adasum reduction on that buffer across all processes in the global parallel worker group by calling AdasumReducer::DispatchFusedAllreduce, and copies the reduced values back into the individual output tensors.
Rather than averaging, Adasum computes an adaptive combination of the gradients from different workers, which is intended to give better convergence than simple averaging. The kernel is compiled only when USE_MPI is defined. It is registered for all IEEE float tensor types (DataTypeImpl::AllIEEEFloatTensorTypes()) and uses VariadicAlias(0, 0) so that each output aliases the input at the same position, allowing the result to be written in place.
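The pack and unpack steps around the reduction can be sketched in plain C++. This is an illustrative sketch under stated assumptions, not the ONNX Runtime code: `PackFused` and `UnpackFused` are hypothetical helper names, tensors are modeled as `std::vector<float>`, and the MPI reduction that the real kernel runs between the two steps is left out.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: copy every tensor, back to back, into one contiguous buffer.
std::vector<float> PackFused(const std::vector<std::vector<float>>& tensors) {
  std::size_t total = 0;
  for (const auto& t : tensors) total += t.size();

  std::vector<float> fused;
  fused.reserve(total);
  for (const auto& t : tensors) fused.insert(fused.end(), t.begin(), t.end());
  return fused;
}

// Hypothetical helper: scatter the fused buffer back into the outputs.
// Each output must already have the same element count as its matching input.
void UnpackFused(const std::vector<float>& fused,
                 std::vector<std::vector<float>>& outputs) {
  std::size_t offset = 0;
  for (auto& out : outputs) {
    assert(offset + out.size() <= fused.size());
    const auto first = fused.begin() + static_cast<std::ptrdiff_t>(offset);
    std::copy(first, first + static_cast<std::ptrdiff_t>(out.size()), out.begin());
    offset += out.size();
  }
  // Every element of the fused buffer must have been consumed.
  assert(offset == fused.size());
}
```

Fusing the tensors means the collective runs once over a single buffer instead of once per tensor, and the fixed packing order is what keeps outputs one-to-one with inputs.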
Usage
Use this kernel in distributed training when Adasum gradient reduction is preferred over simple gradient averaging for its convergence behavior. It is available only in builds with MPI enabled (USE_MPI).
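For intuition about how Adasum differs from averaging, the pairwise combination rule commonly given for Adasum is Adasum(g1, g2) = (1 - g1·g2 / (2‖g1‖²)) g1 + (1 - g1·g2 / (2‖g2‖²)) g2. The following is a minimal single-process sketch of that rule. The names `Dot` and `AdasumPair` are hypothetical, the zero-norm guard is an assumption of this sketch, and nothing here reflects how AdasumReducer applies the rule across MPI ranks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Dot product of two equally sized vectors.
double Dot(const std::vector<double>& a, const std::vector<double>& b) {
  assert(a.size() == b.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) sum += a[i] * b[i];
  return sum;
}

// Hypothetical helper: combine two gradients with the pairwise Adasum rule.
std::vector<double> AdasumPair(const std::vector<double>& g1,
                               const std::vector<double>& g2) {
  assert(g1.size() == g2.size());
  const double dot = Dot(g1, g2);
  const double norm1_sq = Dot(g1, g1);
  const double norm2_sq = Dot(g2, g2);

  // Assumption of this sketch: a zero-norm gradient gets coefficient 1,
  // which avoids dividing by zero and leaves the other gradient unchanged.
  const double c1 = norm1_sq > 0.0 ? 1.0 - dot / (2.0 * norm1_sq) : 1.0;
  const double c2 = norm2_sq > 0.0 ? 1.0 - dot / (2.0 * norm2_sq) : 1.0;

  std::vector<double> out(g1.size());
  for (std::size_t i = 0; i < g1.size(); ++i) out[i] = c1 * g1[i] + c2 * g2[i];
  return out;
}
```

Under this rule, orthogonal gradients are summed (both coefficients are 1), while identical gradients are averaged (both coefficients are 0.5), so the result adapts to how much the workers' gradients agree.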
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/collective/adasum_kernels.cc
- Lines: 1-65
Signature
Status AdasumAllReduce::Compute(OpKernelContext* context) const;
Import
#include "orttraining/orttraining/training_ops/cpu/collective/adasum_kernels.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensors (variadic) | Tensor(IEEE float) | Yes | Variable number of tensors to all-reduce |
Outputs
| Name | Type | Description |
|---|---|---|
| reduced_tensors (variadic) | Tensor(IEEE float) | Adasum-reduced tensors (one-to-one with inputs) |
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
    AdasumAllReduce, kMSDomain, 1, kCpuExecutionProvider,
    KernelDefBuilder()
        .VariadicAlias(0, 0)  // outputs and inputs are mapped one to one
        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
    AdasumAllReduce);