Implementation:NVIDIA TransformerEngine CommOverlapCore
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, Distributed_Computing |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements the CommOverlapCore class that orchestrates overlapping NCCL/userbuffer communication with GEMM computation for tensor-parallel Transformer training.
Description
comm_gemm_overlap.cpp creates a userbuffers communicator for inter-GPU shared memory communication, manages multiple CUDA streams with configurable priorities, and uses CUDA events for synchronization. It supports atomic GEMM mode for Hopper GPUs with persistent CTA execution, and partitions SMs between communication and compute workloads.
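The next sketch illustrates one way to spread chunked GEMM work over a fixed pool of streams. It assigns chunk i to one of num_max_streams compute streams in round-robin order. The helper names stream_for_chunk and assign_streams are hypothetical. The round-robin policy is also an assumption made for illustration, not a description of the exact logic in comm_gemm_overlap.cpp.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of TransformerEngine): map chunk index `chunk`
// onto a pool of `num_max_streams` streams in round-robin order, so that
// consecutive chunks land on different streams and can overlap.
int stream_for_chunk(int chunk, int num_max_streams) {
  assert(num_max_streams > 0 && chunk >= 0);
  return chunk % num_max_streams;
}

// Return the stream index chosen for each of `num_splits` chunks.
std::vector<int> assign_streams(int num_splits, int num_max_streams) {
  std::vector<int> assignment;
  assignment.reserve(num_splits);
  for (int i = 0; i < num_splits; ++i) {
    assignment.push_back(stream_for_chunk(i, num_max_streams));
  }
  return assignment;
}
```

When there are more chunks than streams, the indices wrap around. Work enqueued on the same CUDA stream runs in order, so a wrapped-around chunk waits for the earlier chunk on that stream. Cross-stream ordering, by contrast, has to be expressed with events.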
Key features:
- Userbuffers communicator: Creates either an MPI-based or an external-callback-based communicator for inter-GPU communication.
- SM partitioning: Configurable allocation of streaming multiprocessors (SMs) between communication and compute kernels via num_comm_sm and set_sm_margin.
- CUDA stream management: Creates prioritized CUDA streams for compute and communication, with configurable priority levels.
- Fast Dependent Launch: Uses CUDA events to schedule communication kernels before the GEMM on Hopper GPUs when CUDA_DEVICE_MAX_CONNECTIONS > 1.
- Atomic GEMM: Supports atomic GEMM counters for fine-grained overlap with persistent CTA execution.
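The SM-partitioning and stream-priority features above both come down to simple arithmetic, which the following sketch spells out. The helper names gemm_sm_budget and clamp_stream_priority are hypothetical. The sketch assumes that, when set_sm_margin is enabled, the GEMM is limited to the device's SM count minus num_comm_sm. The "at least one SM" guard was added for this example and is not taken from the source.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: SMs left for the GEMM once `num_comm_sm` SMs are
// reserved for communication kernels. Assumption: with set_sm_margin == false
// the GEMM may use every SM on the device.
int gemm_sm_budget(int device_sm_count, int num_comm_sm, bool set_sm_margin) {
  assert(device_sm_count > 0 && num_comm_sm >= 0);
  if (!set_sm_margin) return device_sm_count;
  // Guard added for this sketch: never leave the GEMM with zero SMs.
  return std::max(1, device_sm_count - num_comm_sm);
}

// CUDA stream priorities come from cudaDeviceGetStreamPriorityRange as a
// (least, greatest) pair where a numerically LOWER value means a HIGHER
// priority, so greatest <= least. Clamp a requested value into that range.
int clamp_stream_priority(int requested, int least, int greatest) {
  assert(greatest <= least);
  return std::min(least, std::max(greatest, requested));
}
```

For example, on a device with 132 SMs, num_comm_sm = 16 and set_sm_margin enabled would leave 116 SMs for the GEMM under this assumption. The explicit clamp is only illustrative, because cudaStreamCreateWithPriority already clamps out-of-range priorities to the supported range.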
Usage
CommOverlapCore serves as the base class for CommOverlapBase and CommOverlapP2PBase. These derived classes implement the specific overlap strategies (bulk, split-pipelined, atomic GEMM) for tensor-parallel communication.
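The split-pipelined strategy can be pictured as a GEMM broken into num_splits row chunks. While chunk i is being computed, the reduce-scatter (rs) of chunk i-1 is already in flight. The sketch below only models this schedule on the host. Chunk, split_rows and pipeline_schedule are hypothetical names, and the remainder-handling rule in split_rows is an assumption.

```cpp
#include <cassert>
#include <string>
#include <vector>

// A contiguous block of output rows handled as one pipeline stage.
struct Chunk {
  int offset;  // first row of the chunk
  int rows;    // number of rows in the chunk
};

// Split m rows into num_splits contiguous chunks. If m is not divisible,
// the first (m % num_splits) chunks get one extra row (sketch assumption).
std::vector<Chunk> split_rows(int m, int num_splits) {
  assert(m >= 0 && num_splits > 0);
  std::vector<Chunk> chunks;
  int offset = 0;
  for (int i = 0; i < num_splits; ++i) {
    int rows = m / num_splits + (i < m % num_splits ? 1 : 0);
    chunks.push_back({offset, rows});
    offset += rows;
  }
  return chunks;
}

// Pipelined GEMM + reduce-scatter: at step i the GEMM for chunk i runs on a
// compute stream while rs of chunk i-1 runs on the communication stream.
// "||" marks work that may overlap; the final step drains the last rs.
std::vector<std::string> pipeline_schedule(int num_splits) {
  std::vector<std::string> steps;
  for (int i = 0; i <= num_splits; ++i) {
    std::string step;
    if (i < num_splits) step += "gemm[" + std::to_string(i) + "]";
    if (i > 0) {
      if (!step.empty()) step += " || ";
      step += "rs[" + std::to_string(i - 1) + "]";
    }
    steps.push_back(step);
  }
  return steps;
}
```

With num_splits = 3 the schedule has four steps. Only the first GEMM and the last reduce-scatter run without a partner to overlap with.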
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
- Lines: 1-1220
Signature
```cpp
class CommOverlapCore {
 public:
  CommOverlapCore(int myrank, int numranks, int mylocal, int numlocal,
                  int mynode, int numnodes, int tp_size,
                  ExtAllgatherOp allgather_handle,
                  ExtBarrierOp barrier_handle,
                  int num_splits, int num_max_streams,
                  int comm_cga_size, int gemm_priority,
                  int comm_priority, int num_comm_sm,
                  bool set_sm_margin, bool use_ce, bool atomic_gemm);
  virtual ~CommOverlapCore();
  virtual void bulk_overlap(...);
  virtual void atomic_gemm_overlap_rs(...);
  virtual void split_overlap_rs(...);
  virtual void atomic_gemm_overlap_ag(...);
  virtual void split_overlap_ag(...);
};
```
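The signature declares every overlap strategy as a virtual method, so each derived class overrides only the strategies it supports. The following self-contained mock shows that dispatch pattern. OverlapCoreSketch and BulkOnlySketch are hypothetical stand-ins, not TransformerEngine types. The sketch also assumes that the base versions reject unsupported strategies by throwing; the real base implementations may behave differently.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the core class: every strategy is virtual and the
// base version reports that it is not implemented (assumption of this sketch).
class OverlapCoreSketch {
 public:
  virtual ~OverlapCoreSketch() = default;
  virtual std::string bulk_overlap() {
    throw std::runtime_error("bulk_overlap is not implemented");
  }
  virtual std::string split_overlap_ag() {
    throw std::runtime_error("split_overlap_ag is not implemented");
  }
};

// Hypothetical derived class that supports only the bulk strategy.
class BulkOnlySketch : public OverlapCoreSketch {
 public:
  std::string bulk_overlap() override { return "bulk"; }
};
```

Callers can hold a reference to the base class and select a strategy at run time. Requesting an unsupported strategy then fails loudly instead of doing nothing.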
Import
```cpp
#include <transformer_engine/comm_gemm_overlap.h>
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| myrank | int | Yes | Global rank of the current process |
| numranks | int | Yes | Total number of ranks |
| tp_size | int | Yes | Tensor parallel world size |
| num_splits | int | Yes | Number of pipeline splits for communication |
| num_comm_sm | int | Yes | Number of SMs dedicated to communication |
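The relationships among these inputs can be checked before any communicator is built. The following is a minimal sketch: OverlapInputs, inputs_look_valid, tp_rank and tp_group are hypothetical names. The checked invariants are plausible assumptions rather than checks copied from comm_gemm_overlap.cpp. In particular, tp_rank and tp_group assume that tensor-parallel groups are formed from consecutive global ranks.

```cpp
#include <cassert>

// Hypothetical bundle of the constructor inputs listed in the table.
struct OverlapInputs {
  int myrank;
  int numranks;
  int tp_size;
  int num_splits;
  int num_comm_sm;
};

// Sanity checks one might apply up front (assumed invariants, not the
// library's actual validation).
bool inputs_look_valid(const OverlapInputs& in) {
  return in.numranks > 0 && in.myrank >= 0 && in.myrank < in.numranks &&
         in.tp_size > 0 && in.numranks % in.tp_size == 0 &&
         in.num_splits > 0 && in.num_comm_sm >= 0;
}

// Rank inside the tensor-parallel group, assuming groups of consecutive
// global ranks: {0..tp_size-1}, {tp_size..2*tp_size-1}, ...
int tp_rank(const OverlapInputs& in) { return in.myrank % in.tp_size; }

// Index of the tensor-parallel group, under the same assumption.
int tp_group(const OverlapInputs& in) { return in.myrank / in.tp_size; }
```

For example, with 16 ranks and tp_size = 4, global rank 5 would be rank 1 of group 1 under this layout.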
Outputs
| Name | Type | Description |
|---|---|---|
| overlap result | TensorWrapper | Result of the overlapped GEMM+communication operation |
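The output's shape depends on which collective is overlapped with the GEMM. The sketch below assumes Megatron-style sequence parallelism. Under that assumption, the all-gather before a column-parallel GEMM gathers the row (token) dimension, and the reduce-scatter after a row-parallel GEMM splits it. Shape2D and both helper names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

using Shape2D = std::pair<std::size_t, std::size_t>;  // {rows, cols}

// Reduce-scatter after a row-parallel GEMM: each of tp_size ranks produces a
// full [m, n] partial result and keeps a reduced [m / tp_size, n] slice.
Shape2D reduce_scatter_output(std::size_t m, std::size_t n, std::size_t tp_size) {
  assert(tp_size > 0 && m % tp_size == 0);
  return {m / tp_size, n};
}

// All-gather before a column-parallel GEMM: each rank holds [local_m, k] rows
// and the gathered GEMM input is [local_m * tp_size, k].
Shape2D all_gather_output(std::size_t local_m, std::size_t k, std::size_t tp_size) {
  assert(tp_size > 0);
  return {local_m * tp_size, k};
}
```

The two shapes are inverses of each other. Gathering 8 slices of 512 rows yields 4096 rows, and reduce-scattering 4096 rows over 8 ranks leaves 512 rows per rank.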
Usage Examples
```cpp
#include <transformer_engine/comm_gemm_overlap.h>
// CommOverlapCore is typically used via derived classes
// CommOverlapBase or CommOverlapP2PBase
```
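CommOverlapP2PBase works with point-to-point transfers rather than a single collective. A common way to build an all-gather from point-to-point sends is a ring, which the next sketch simulates on the host. This is a generic ring all-gather, not TransformerEngine code. Whether the P2P classes use exactly this send ordering is an assumption, and ring_all_gather is a hypothetical name.

```cpp
#include <cassert>
#include <set>
#include <vector>

// Simulate a ring all-gather over tp_size ranks. Rank r starts with chunk r.
// At step s (0 <= s < tp_size - 1), rank r forwards chunk (r - s) mod tp_size,
// which is the chunk it received in the previous step, to rank (r + 1) mod tp_size.
// Returns, for each rank, the set of chunks it holds at the end.
std::vector<std::set<int>> ring_all_gather(int tp_size) {
  assert(tp_size > 0);
  std::vector<std::set<int>> held(tp_size);
  for (int r = 0; r < tp_size; ++r) held[r].insert(r);
  for (int s = 0; s < tp_size - 1; ++s) {
    // Decide all sends first so every rank uses the state from the same step.
    std::vector<int> sent(tp_size);
    for (int r = 0; r < tp_size; ++r) {
      sent[r] = ((r - s) % tp_size + tp_size) % tp_size;
      assert(held[r].count(sent[r]) == 1);  // a rank only forwards what it has
    }
    for (int r = 0; r < tp_size; ++r) held[(r + 1) % tp_size].insert(sent[r]);
  }
  return held;
}
```

After tp_size - 1 steps every rank holds all tp_size chunks. In an overlapped version, a rank could run the GEMM on each newly received chunk while the next transfer is in flight.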