Implementation:NVIDIA TransformerEngine PyTorch Ext Comm Overlap
| Field | Value |
|---|---|
| Sources | TransformerEngine |
| Domains | Deep_Learning, PyTorch, Distributed |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Implements the communication-computation overlap classes (CommOverlapHelper, CommOverlap, CommOverlapP2P) that overlap all-gather and reduce-scatter communication with GEMM computation in distributed training.
Description
CommOverlapHelper manages PyTorch distributed process groups. It extracts rank, size, and node topology information, and provides ub_allgather and ub_barrier wrappers built on PyTorch's c10d collective operations. The helper supports both MPI-based and NCCL-based backends, selected by the NVTE_UB_WITH_MPI compile flag. The CommOverlap and CommOverlapP2P classes wrap the TE core's communication overlap infrastructure and provide methods for setting up userbuffer-based communication that runs concurrently with GEMM kernels on separate CUDA streams.
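The rank/size/node topology that the helper extracts can be illustrated with a minimal sketch. It assumes ranks are assigned contiguously per node and that every node hosts the same number of processes. The names `Topology` and `derive_topology` are hypothetical and are not part of TransformerEngine.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Topology:
    """Hypothetical container for the values a helper might derive."""
    world_rank: int
    world_size: int
    local_rank: int
    local_size: int
    node_id: int
    num_nodes: int


def derive_topology(world_rank: int, world_size: int, local_size: int) -> Topology:
    """Derive node-level coordinates from global rank information.

    Assumption: ranks are laid out contiguously per node, so ranks
    [0, local_size) live on node 0, the next local_size on node 1, etc.
    """
    if local_size <= 0 or world_size % local_size != 0:
        raise ValueError("world_size must be a positive multiple of local_size")
    if not 0 <= world_rank < world_size:
        raise ValueError("world_rank must lie in [0, world_size)")
    return Topology(
        world_rank=world_rank,
        world_size=world_size,
        local_rank=world_rank % local_size,   # position within the node
        local_size=local_size,
        node_id=world_rank // local_size,     # which node this rank is on
        num_nodes=world_size // local_size,
    )


# Example: rank 5 of 8 processes with 4 processes per node.
topo = derive_topology(world_rank=5, world_size=8, local_size=4)
```

In a real job the three inputs would come from the process group rather than being passed by hand; the arithmetic is the only part this sketch is meant to show.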
Usage
Use these classes to overlap all-gather and reduce-scatter operations with matrix multiplications during distributed training. Hiding communication latency behind computation improves multi-GPU scaling.
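To see why an all-gather can be overlapped with a GEMM at all, the following pure-Python sketch simulates a ring exchange: each rank multiplies the row block it currently holds while that block is conceptually in flight to its neighbour. It runs sequentially on lists, with no GPUs, streams, or real communication, and the function names are hypothetical. It only demonstrates that the chunked schedule yields the same result as gathering everything first and then multiplying.

```python
def matmul(a, b):
    """Plain-Python matrix product of a (m x k) and b (k x n)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def ring_allgather_gemm(shards, b):
    """Simulate a ring all-gather interleaved with per-block GEMMs.

    shards[r] is the row block of A initially owned by rank r.
    Returns one output per rank; each equals (all blocks stacked) @ b.
    """
    n = len(shards)
    held = list(shards)                    # held[r]: block resident on rank r
    out = [[None] * n for _ in range(n)]   # out[r][src]: rows computed from src's block
    for step in range(n):
        for r in range(n):
            src = (r - step) % n           # original owner of the block rank r holds
            # On real hardware this GEMM would run while the block is
            # being forwarded to the next rank on a separate stream.
            out[r][src] = matmul(held[r], b)
        if step < n - 1:
            # Every rank passes its block to the next rank in the ring.
            held = [held[(r - 1) % n] for r in range(n)]
    # Stitch the per-block results back together in row order.
    return [[row for block in per_rank for row in block] for per_rank in out]


shards = [[[1, 2]], [[3, 4]], [[5, 6]]]    # three ranks, one row each
weights = [[1, 2], [3, 4]]
results = ring_allgather_gemm(shards, weights)
```

Each GEMM depends only on a block that has already arrived, so the transfer of the next block can proceed in parallel with it. That data-dependency structure is what makes the overlap possible.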
Code Reference
Source Location
- Repository: NVIDIA/TransformerEngine
- File: transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
- Lines: 1-320
Signature
namespace transformer_engine::pytorch {

class CommOverlapHelper {
 public:
  CommOverlapHelper(py::handle world_group);
  void ub_allgather(at::Tensor input, at::Tensor output, ...);
  void ub_barrier(at::Tensor tensor);
};

class CommOverlap {
 public:
  CommOverlap(CommOverlapHelper &helper, ...);
  // Methods for overlapped AG+GEMM and GEMM+RS
};

class CommOverlapP2P {
 public:
  CommOverlapP2P(CommOverlapHelper &helper, ...);
  // Methods for P2P overlapped communication
};

}  // namespace transformer_engine::pytorch
Import
#include "../extensions.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| world_group | py::handle | Yes | PyTorch distributed process group |
| input | at::Tensor | Yes | Input tensor for communication |
| output | at::Tensor | Yes | Output tensor for gathered/scattered data |
Outputs
| Name | Type | Description |
|---|---|---|
| N/A | N/A | Operations are performed in-place on the provided output tensors |
Usage Examples
import transformer_engine_torch as tex

# world_group: an existing PyTorch distributed process group
# Create comm overlap helper
helper = tex.CommOverlapHelper(world_group)

# Create overlap object for AG+GEMM fusion (remaining arguments elided)
comm_overlap = tex.CommOverlap(helper, tensor_shape, ...)