Implementation:Microsoft Onnxruntime CUDA Adasum
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for the Adasum (Adaptive Summation) all-reduce collective operation in the ONNX Runtime CUDA training framework.
Description
Implements the AdasumAllReduce operator for the CUDA execution provider, which performs an adaptive-summation all-reduce across distributed workers. Adasum preserves orthogonal gradient components better than simple averaging, which can improve convergence. The kernel copies the input tensors from GPU memory into pinned CPU memory, then dispatches a fused all-reduce through adasum_reducer_ using MPI communicators. For GPU hierarchical reduction, the VHDD (Vector Halving, Distance Doubling) start level is set to the node-local data-parallel group size. The reduced results are then copied from the CPU back into the GPU output tensors. The operator registers variadic aliasing (VariadicAlias(0, 0)) so that inputs and outputs map one-to-one, and it supports all IEEE float tensor types.
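To make the "preserves orthogonal gradient components" property concrete, the following is a minimal CPU-only sketch of the pairwise Adasum combination rule as it is commonly stated. The helper name AdasumPair and the use of std::vector<double> are assumptions made for illustration; this is not the code behind adasum_reducer_, which operates on fused buffers across MPI ranks.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of ONNX Runtime). Combines two gradient
// vectors with the pairwise Adasum rule:
//   Adasum(a, b) = (1 - a.b / (2 |a|^2)) * a + (1 - a.b / (2 |b|^2)) * b
std::vector<double> AdasumPair(const std::vector<double>& a,
                               const std::vector<double>& b) {
  assert(a.size() == b.size());

  // Accumulate the dot product and both squared norms in one pass.
  double dot = 0.0, a_sq = 0.0, b_sq = 0.0;
  for (std::size_t i = 0; i < a.size(); ++i) {
    dot += a[i] * b[i];
    a_sq += a[i] * a[i];
    b_sq += b[i] * b[i];
  }

  // A zero-norm vector contributes nothing, so its coefficient is left at 1
  // to avoid dividing by zero.
  const double a_coeff = a_sq > 0.0 ? 1.0 - dot / (2.0 * a_sq) : 1.0;
  const double b_coeff = b_sq > 0.0 ? 1.0 - dot / (2.0 * b_sq) : 1.0;

  std::vector<double> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a_coeff * a[i] + b_coeff * b[i];
  }
  return out;
}
```

Under this rule, orthogonal inputs such as (1, 0) and (0, 1) are simply summed to (1, 1), while two identical inputs are averaged back to the same vector, which is the behavior the description contrasts with plain averaging.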
Usage
Used during distributed training as an alternative to a standard AllReduce for gradient aggregation. It is particularly effective for large-batch training.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/collective/adasum_kernels.cc
- Lines: 1-73
Signature
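class AdasumAllReduce : public CudaKernel {
 public:
  // Stages inputs on the host, runs the fused Adasum all-reduce, and copies
  // the results back to the GPU outputs.
  Status ComputeInternal(OpKernelContext* context) const override;
};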
Import
#include "orttraining/training_ops/cuda/collective/adasum_kernels.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensors | Tensor(T)... | Yes | Variadic input tensors to reduce (IEEE float types) |
Outputs
| Name | Type | Description |
|---|---|---|
| reduced_tensors | Tensor(T)... | Reduced output tensors, one per input (mapped one-to-one) |
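The one-to-one input/output contract combined with a fused reduction can be pictured as packing every input into one contiguous staging buffer, reducing that buffer, and slicing it back out in the same order. The sketch below illustrates that idea on the CPU only. The names Tensor, PackFused, and UnpackFused are hypothetical and do not correspond to ONNX Runtime APIs, and the actual kernel's buffer layout may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a flat float tensor; not an ONNX Runtime type.
using Tensor = std::vector<float>;

// Packs all inputs back to back into one contiguous buffer and records the
// starting offset of each input so the layout can be reversed later.
std::vector<float> PackFused(const std::vector<Tensor>& inputs,
                             std::vector<std::size_t>* offsets) {
  assert(offsets != nullptr);
  offsets->clear();
  std::vector<float> fused;
  for (const Tensor& t : inputs) {
    offsets->push_back(fused.size());
    fused.insert(fused.end(), t.begin(), t.end());
  }
  return fused;
}

// Slices the (already reduced) fused buffer so that output i has the same
// shape and position as input i.
std::vector<Tensor> UnpackFused(const std::vector<float>& fused,
                                const std::vector<std::size_t>& offsets) {
  std::vector<Tensor> outputs;
  for (std::size_t i = 0; i < offsets.size(); ++i) {
    const std::size_t begin = offsets[i];
    const std::size_t end =
        (i + 1 < offsets.size()) ? offsets[i + 1] : fused.size();
    outputs.emplace_back(fused.begin() + static_cast<std::ptrdiff_t>(begin),
                         fused.begin() + static_cast<std::ptrdiff_t>(end));
  }
  return outputs;
}
```

A caller would pack the inputs, apply the reduction to the fused buffer in place, and then unpack, so the number and sizes of the outputs always match the inputs.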
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
    AdasumAllReduce,
    kMSDomain,
    1,
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .VariadicAlias(0, 0)  // inputs and outputs are mapped one-to-one
        .TypeConstraint("T", DataTypeImpl::AllIEEEFloatTensorTypes()),
    AdasumAllReduce);