Implementation:Microsoft Onnxruntime CUDA NcclCommon
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Provides shared NCCL context management and data-type conversion utilities for the CUDA collective operators in the ONNX Runtime training framework.
Description
Implements the NcclContext class and shared NCCL utilities.

GetNcclDataType maps ORT element types (uint8, int8, int32, int64, MLFloat16, float, double) to the corresponding NCCL data types.

NcclContext manages one NCCL communicator for each of five worker group types: GlobalParallel, DataParallel, HorizontalParallel, NodeLocalDataParallel, and CrossNodeDataParallel. Its constructor initializes MPI if it has not already been initialized, creates an MPI sub-group and sub-communicator for each worker group, broadcasts an NCCL unique ID within each group, and initializes each group's NCCL communicator with ncclCommInitRank. The destructor destroys all NCCL communicators and finalizes MPI.

NcclKernel is the base class for NCCL-based operators. It stores a pointer to a single, shared NcclContext instance (a function-local static) and reads the group_type attribute, which selects the worker group whose communicator the operator uses.
Usage
Serves as the foundation for all NCCL-based collective operators. A single NcclContext is constructed the first time any NcclKernel is constructed, and it is then shared by every NCCL kernel instance, including throughout the training session, until the process exits.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/collective/nccl_common.cc
- Lines: 1-156
Signature
```cpp
ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type);

class NcclContext {
 public:
  NcclContext();
  ~NcclContext();
  ncclComm_t Comm(training::WorkerGroupType group_type);
};

class NcclKernel : public CudaKernel {
 public:
  NcclKernel(const OpKernelInfo& info);

 protected:
  NcclContext* nccl_;
  training::WorkerGroupType group_type_;
};
```
Import
```cpp
#include "orttraining/training_ops/cuda/collective/nccl_common.h"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| group_type | int64_t (attribute) | No | Worker group type (default 0 = GlobalParallel) |
| type | MLDataType | Yes | Data type to convert to NCCL type (for GetNcclDataType) |
Outputs
| Name | Type | Description |
|---|---|---|
| nccl_type | ncclDataType_t | Corresponding NCCL data type |
| comm | ncclComm_t | NCCL communicator for the specified group type |
Usage Examples
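```cpp
// NcclKernel base-class constructor
NcclKernel::NcclKernel(const OpKernelInfo& info) : CudaKernel(info) {
  // Function-local static: constructed on first use, shared by all NCCL kernels.
  static NcclContext context;
  nccl_ = &context;

  // group_type selects the worker group; defaults to 0 (GlobalParallel).
  int64_t group_type;
  info.GetAttrOrDefault("group_type", &group_type, static_cast<int64_t>(0));
  group_type_ = static_cast<training::WorkerGroupType>(group_type);
}

// Inside a derived kernel's compute method (OpKernelContext* ctx):
const Tensor* input = ctx->Input<Tensor>(0);
ncclDataType_t dtype = GetNcclDataType(input->DataType());  // type conversion
ncclComm_t comm = nccl_->Comm(group_type_);                  // communicator lookup
```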