Implementation:NVIDIA TransformerEngine PyTorch Ext Comm Overlap


Field | Value
Sources | TransformerEngine
Domains | Deep_Learning, PyTorch, Distributed
Last Updated | 2026-02-07 14:00 GMT

Overview

Implements communication-computation overlap helpers (CommOverlapHelper, CommOverlap, CommOverlapP2P) that overlap NCCL collectives with GEMM computation for distributed training.

Description

CommOverlapHelper manages PyTorch distributed process groups: it extracts rank, size, and node topology information and provides ub_allgather and ub_barrier wrappers built on PyTorch's c10d collective operations. Both MPI-based and NCCL-based backends are supported; the choice is controlled by the NVTE_UB_WITH_MPI compile flag. The CommOverlap and CommOverlapP2P classes wrap the TE core's communication overlap infrastructure and provide methods to set up userbuffer-based communication that runs concurrently with GEMM kernels on separate CUDA streams.
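The rank and node bookkeeping mentioned above can be illustrated with a small sketch. This is plain Python and not the TransformerEngine code: the names `Topology` and `derive_topology` are hypothetical, and the sketch assumes ranks are placed contiguously on nodes with the same number of ranks per node.

```python
from typing import NamedTuple


class Topology(NamedTuple):
    world_rank: int
    world_size: int
    local_rank: int  # rank within its node
    local_size: int  # number of ranks per node
    node_id: int     # index of the node that owns this rank
    num_nodes: int


def derive_topology(world_rank: int, world_size: int, local_size: int) -> Topology:
    """Hypothetical helper: assumes contiguous rank placement and equal-sized nodes."""
    if local_size <= 0 or world_size % local_size != 0:
        raise ValueError("world_size must be a positive multiple of local_size")
    if not 0 <= world_rank < world_size:
        raise ValueError("world_rank out of range")
    return Topology(
        world_rank=world_rank,
        world_size=world_size,
        local_rank=world_rank % local_size,
        local_size=local_size,
        node_id=world_rank // local_size,
        num_nodes=world_size // local_size,
    )


# Rank 11 of 16, with 8 ranks per node, sits on the second node.
topo = derive_topology(world_rank=11, world_size=16, local_size=8)
```

Under these assumptions, integer division and remainder are enough to recover the node index and the rank within the node; a real deployment may use a different placement.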

Usage

This extension is used to overlap all-gather and reduce-scatter operations with matrix multiplications during distributed training, hiding communication latency behind computation and improving multi-GPU scaling.
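To make the all-gather case concrete, the following is a minimal pure-Python sketch of why the overlap is possible. It is not the TransformerEngine implementation: `matmul` and `allgather_gemm_chunked` are hypothetical names, the ring arrival order is an assumption, and the steps run sequentially here, whereas on a GPU the transfer of the next shard would proceed while the GEMM for the current shard runs on another stream.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive dense matrix product, standing in for a GEMM kernel."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def allgather_gemm_chunked(shards: List[Matrix], weight: Matrix, rank: int) -> Matrix:
    """Compute concat(shards) @ weight one shard at a time.

    shards[r] is the row block of the input owned by rank r. Each row block
    can be multiplied as soon as it is available, because the rows of the
    output depend only on the matching rows of the input.
    """
    world_size = len(shards)
    out_blocks: List[Matrix] = [[] for _ in range(world_size)]
    for step in range(world_size):
        # Assumed ring order: the local shard first, then one neighbour per step.
        src = (rank - step) % world_size
        out_blocks[src] = matmul(shards[src], weight)
    # Reassemble the output blocks in rank order.
    return [row for block in out_blocks for row in block]


shards = [[[1, 2]], [[3, 4]], [[5, 6]]]  # three ranks, one row each
weight = [[1, 2], [3, 4]]
result = allgather_gemm_chunked(shards, weight, rank=1)
```

The chunked result matches the product computed on the fully gathered input, which is the property that lets communication be hidden behind computation.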

Code Reference

Source Location

Repository: NVIDIA/TransformerEngine
File: transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
Lines: 1-320

Signature

namespace transformer_engine::pytorch {

class CommOverlapHelper {
  public:
    CommOverlapHelper(py::handle world_group);
    void ub_allgather(at::Tensor input, at::Tensor output, ...);
    void ub_barrier(at::Tensor tensor);
};

class CommOverlap {
  public:
    CommOverlap(CommOverlapHelper &helper, ...);
    // Methods for overlapped AG+GEMM and GEMM+RS
};

class CommOverlapP2P {
  public:
    CommOverlapP2P(CommOverlapHelper &helper, ...);
    // Methods for P2P overlapped communication
};

}
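The "GEMM+RS" case named in the signature comment (a GEMM followed by a reduce-scatter) can be sketched in the same spirit. This is an illustrative pure-Python model, not the library's method: `gemm_reduce_scatter_chunked` is a hypothetical name, and the sketch assumes each rank holds a column shard of the input and the matching row shard of the weight, and that the number of output rows is divisible by the number of ranks.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive dense matrix product, standing in for a GEMM kernel."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gemm_reduce_scatter_chunked(inputs: List[Matrix], weights: List[Matrix]) -> List[Matrix]:
    """Return out[k], the k-th row chunk of sum_r inputs[r] @ weights[r].

    The output is produced chunk by chunk, so the reduction of one chunk
    could be in flight while the GEMM for the next chunk is computed.
    """
    world_size = len(inputs)
    rows = len(inputs[0])
    if rows % world_size != 0:
        raise ValueError("row count must be divisible by the number of ranks")
    chunk = rows // world_size
    outputs: List[Matrix] = []
    for k in range(world_size):
        lo, hi = k * chunk, (k + 1) * chunk
        # Every rank computes its partial product for row chunk k only.
        partials = [matmul(inputs[r][lo:hi], weights[r]) for r in range(world_size)]
        # Reduce: element-wise sum across ranks; the result belongs to rank k.
        reduced = [[sum(vals) for vals in zip(*row_group)] for row_group in zip(*partials)]
        outputs.append(reduced)
    return outputs


inputs = [[[1], [2]], [[3], [4]]]  # two ranks, each with one input column
weights = [[[1, 1]], [[1, 2]]]     # matching weight rows
out = gemm_reduce_scatter_chunked(inputs, weights)
```

Each entry of `out` is the chunk that the corresponding rank would own after the reduce-scatter.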

Import

#include "../extensions.h"

I/O Contract

Inputs

Name | Type | Required | Description
world_group | py::handle | Yes | PyTorch distributed process group
input | at::Tensor | Yes | Input tensor for communication
output | at::Tensor | Yes | Output tensor for gathered/scattered data

Outputs

Name | Type | Description
N/A | N/A | Operations are performed in-place on the provided output tensors

Usage Examples

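import transformer_engine_torch as tex

# Create the comm overlap helper; world_group is a PyTorch distributed process group
helper = tex.CommOverlapHelper(world_group)

# Create the overlap object for AG+GEMM fusion (remaining arguments are elided in the source)
comm_overlap = tex.CommOverlap(helper, tensor_shape, ...)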
