
Workflow:NVIDIA TransformerEngine Comm GEMM Overlap Training

From Leeroopedia


Knowledge Sources
Domains: LLMs, FP8_Training, Distributed_Training, Performance_Optimization
Last Updated: 2026-02-07 21:00 GMT

Overview

End-to-end process for training Transformer Engine models with overlapped communication and GEMM operations in tensor-parallel configurations, maximizing GPU utilization by hiding collective communication latency.

Description

This workflow demonstrates how to use TE's Userbuffers system to overlap tensor-parallel collective communication (all-gather, reduce-scatter) with GEMM (matrix multiplication) operations during training. In standard tensor-parallel training, each forward and backward pass requires several collective communication steps that block compute while they run. Overlapping them with GEMM execution hides much of the communication latency and improves training throughput on multi-GPU systems; the size of the gain depends on how large the communication time is relative to the GEMM time.

Key outputs:

  • A tensor-parallel training setup with communication-compute overlap
  • Reduced training step latency through latency hiding
  • Support for both single-node tensor parallelism and mixed data/tensor parallelism

Usage

Use this workflow when running tensor-parallel training across multiple GPUs on a single node (or across nodes with NVLink/NVSwitch connectivity) and you want to raise throughput by overlapping all-gather and reduce-scatter communication with the linear-layer GEMM computations. It is most beneficial for large models, where communication accounts for a significant fraction of the training step time.

Execution Steps

Step 1: Initialize Tensor Parallel Groups

Set up the distributed process group with NCCL backend and create tensor-parallel communication groups. Assign GPUs to tensor-parallel groups based on the desired parallelism degree. Optionally create data-parallel groups if using mixed data/tensor parallelism. Initialize the CUDA RNG state tracker for reproducible results.

Key considerations:

  • Tensor parallel group size typically matches GPUs per node (e.g., 8 for DGX)
  • Data parallel groups span across tensor-parallel groups
  • Use torch.distributed.new_group to create sub-groups for tensor and data parallelism
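The rank layout described above can be sketched as follows. This is a minimal illustration, not code from the workflow repository: the helper names build_group_ranks and create_process_groups are hypothetical, and the layout assumes that consecutive global ranks sit on the same node, so each tensor-parallel group is a contiguous block of ranks.

```python
from typing import List, Tuple


def build_group_ranks(world_size: int, tp_size: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (tp_groups, dp_groups) as lists of global ranks.

    Assumption: consecutive ranks share a node, so every tensor-parallel
    group is a contiguous block of tp_size ranks.
    """
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(dp_size)]
    # A data-parallel group collects the ranks that hold the same weight shard.
    dp_groups = [list(range(j, world_size, tp_size)) for j in range(tp_size)]
    return tp_groups, dp_groups


def create_process_groups(tp_size: int):
    """Create the sub-groups; call on every rank after init_process_group."""
    import torch.distributed as dist  # lazy import: the helper above needs no GPU

    rank, world_size = dist.get_rank(), dist.get_world_size()
    tp_ranks, dp_ranks = build_group_ranks(world_size, tp_size)
    tp_group = dp_group = None
    # new_group is collective: every rank must create every group, in the same order.
    for ranks in tp_ranks:
        group = dist.new_group(ranks=ranks, backend="nccl")
        if rank in ranks:
            tp_group = group
    for ranks in dp_ranks:
        group = dist.new_group(ranks=ranks, backend="nccl")
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group
```

With 16 GPUs and a tensor-parallel size of 8, this layout gives two tensor-parallel groups (ranks 0-7 and 8-15) and eight data-parallel groups of two ranks each, such as ranks 0 and 8.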

Step 2: Build Tensor Parallel TE Model

Construct the model using TE modules configured for tensor parallelism. Set tp_group and tp_size parameters on TE layers (TransformerLayer, Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention) to enable automatic weight sharding and collective communication. Each TE layer internally manages the distribution of weights and the required all-gather/reduce-scatter operations.

Key considerations:

  • Set tp_group and tp_size on each TE module
  • Column-parallel layers (QKV projection, MLP fc1) shard along the output dimension
  • Row-parallel layers (output projection, MLP fc2) shard along the input dimension
  • The set_tensor_parallel_group method can set the group after construction; calling it on a container module such as TransformerLayer configures its sub-layers at once
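To make the two sharding directions concrete, the sketch below computes the per-rank weight shape for a column-parallel and a row-parallel linear layer. The function local_weight_shape is a hypothetical helper written for this illustration; it assumes PyTorch's (out_features, in_features) weight layout, and its "column" and "row" strings are chosen to mirror the parallel_mode argument of TE's Linear.

```python
from typing import Tuple


def local_weight_shape(in_features: int, out_features: int,
                       tp_size: int, parallel_mode: str) -> Tuple[int, int]:
    """Per-rank weight shape in PyTorch's (out_features, in_features) layout."""
    if parallel_mode == "column":
        # Column-parallel: each rank owns a slice of the output dimension.
        if out_features % tp_size != 0:
            raise ValueError("out_features must be divisible by tp_size")
        return (out_features // tp_size, in_features)
    if parallel_mode == "row":
        # Row-parallel: each rank owns a slice of the input dimension.
        if in_features % tp_size != 0:
            raise ValueError("in_features must be divisible by tp_size")
        return (out_features, in_features // tp_size)
    raise ValueError("parallel_mode must be 'column' or 'row'")


# Example MLP: hidden size 4096, FFN size 16384, sharded across 8 GPUs.
fc1_shape = local_weight_shape(4096, 16384, tp_size=8, parallel_mode="column")
fc2_shape = local_weight_shape(16384, 4096, tp_size=8, parallel_mode="row")
```

In this example each rank holds a (2048, 4096) slice of fc1 and a (4096, 2048) slice of fc2, so the output shards of fc1 line up with the input shards of fc2 and no communication is needed between the two GEMMs.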

Step 3: Configure Communication Overlap

Enable communication-GEMM overlap by setting the ub_overlap flags on TE modules. This activates the Userbuffers system, which uses IPC-based shared memory for zero-copy multi-GPU communication. Then configure the overlap method (bulk overlap, ring-exchange, or pipeline) and the communication type (all-gather or reduce-scatter) for each overlapped GEMM.

What happens:

  • Userbuffers allocates shared memory regions accessible by all GPUs in the group
  • During forward: the all-gather of the sharded input activations is overlapped with the GEMM computation (the weights stay sharded and are not gathered)
  • During backward: the reduce-scatter of the input (activation) gradient is overlapped with the weight gradient GEMM
  • The overlap is managed at the CUDA stream level using events and stream synchronization
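The benefit of overlap can be estimated with a toy analytic model. The sketch below is not how Userbuffers schedules work internally; it is an idealized two-stage pipeline that assumes the communication and the GEMM can each be split into equal chunks, that communication for one chunk runs while the GEMM of the previous chunk executes, and that there is no contention or launch overhead. The function names are hypothetical.

```python
def serialized_time(comm_ms: float, gemm_ms: float) -> float:
    """Step time when the GEMM waits for the whole collective to finish."""
    return comm_ms + gemm_ms


def pipelined_time(comm_ms: float, gemm_ms: float, num_chunks: int) -> float:
    """Idealized time when communication and GEMM are pipelined over equal chunks."""
    if num_chunks < 1:
        raise ValueError("num_chunks must be at least 1")
    c = comm_ms / num_chunks  # communication time per chunk
    g = gemm_ms / num_chunks  # GEMM time per chunk
    # The first chunk's communication is exposed, the steady state is limited by
    # the slower of the two stages, and the last chunk's GEMM runs alone.
    return c + (num_chunks - 1) * max(c, g) + g


baseline = serialized_time(2.0, 6.0)       # 8.0 ms without overlap
overlapped = pipelined_time(2.0, 6.0, 4)   # 6.5 ms with 4 chunks
```

In this model only the first chunk's communication stays exposed when the GEMM is the slower stage, which is why the gain grows with the share of the step spent in communication. With a single chunk the model reduces to the serialized case.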

Step 4: Optional DDP Wrapping for Data Parallelism

If using mixed data/tensor parallelism, wrap the model with PyTorch's DistributedDataParallel (DDP) using the data-parallel process group. DDP handles gradient synchronization across data-parallel replicas while TE handles tensor-parallel communication internally. This configuration scales beyond a single tensor-parallel group: the total GPU count is the tensor-parallel size multiplied by the number of data-parallel replicas.

Key considerations:

  • Use the data-parallel group (not the global group) for DDP
  • DDP and TE's tensor-parallel communication operate on different process groups
  • Gradient accumulation can be used to increase the effective batch size
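A minimal sketch of the wrapping step and the batch-size arithmetic follows. The helper names wrap_with_ddp and effective_batch_size are hypothetical; the only library call used is DistributedDataParallel with its process_group argument, and dp_group is assumed to be the data-parallel group created in Step 1.

```python
def effective_batch_size(micro_batch_size: int, dp_size: int, grad_accum_steps: int) -> int:
    """Samples contributing to one optimizer step.

    Ranks inside one tensor-parallel group process the same samples, so only
    the data-parallel size multiplies the batch, not the tensor-parallel size.
    """
    return micro_batch_size * dp_size * grad_accum_steps


def wrap_with_ddp(model, dp_group):
    """Wrap a tensor-parallel TE model so gradients sync only across replicas."""
    from torch.nn.parallel import DistributedDataParallel as DDP  # lazy import

    # Passing the data-parallel group keeps DDP's all-reduce off the
    # tensor-parallel ranks, which hold different weight shards.
    return DDP(model, process_group=dp_group)


# 16 GPUs with tp_size=8 gives dp_size=2; micro-batch 4, 8 accumulation steps.
samples_per_step = effective_batch_size(4, dp_size=2, grad_accum_steps=8)
```

In this example one optimizer step covers 64 samples even though 16 GPUs participate.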

Step 5: Execute Overlapped Training Loop

Run the training loop with FP8 autocast enabled. Each training iteration performs a forward pass where all-gather and GEMM are overlapped, computes the loss, runs the backward pass where reduce-scatter and GEMM are overlapped, and updates the optimizer. The overlap is transparent to the training loop code.

Key considerations:

  • Wrap the forward pass with te.autocast (te.fp8_autocast in older TE releases) for FP8 precision, and run the backward pass outside the autocast context
  • The communication overlap is handled internally by the TE layers
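  • Profile a few training steps (for example with a timeline profiler such as Nsight Systems) to verify that communication kernels run concurrently with GEMM kernels; a GPU utilization figure alone does not show whether communication is hidden
  • The environment variable CUDA_DEVICE_MAX_CONNECTIONS=1 may be needed for proper stream scheduling; set it before the process initializes CUDA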
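The loop structure can be sketched as below. The train_step function is a hypothetical, framework-agnostic helper: it takes the autocast context as a parameter so that the same code runs with or without FP8. With Transformer Engine one would presumably pass something like functools.partial(te.autocast, enabled=True); treat that exact call as an assumption to check against the installed TE version.

```python
import contextlib
import os
from typing import Any, Callable

# Set before CUDA is initialized; setdefault keeps a value the user already exported.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")


def train_step(model: Callable, batch: Any, targets: Any, loss_fn: Callable,
               optimizer: Any, autocast: Callable = contextlib.nullcontext) -> Any:
    """One iteration: forward under autocast, then loss, backward, optimizer step."""
    optimizer.zero_grad()
    with autocast():
        # The overlap of communication and GEMM happens inside the TE layers;
        # nothing in the loop refers to it.
        output = model(batch)
    loss = loss_fn(output, targets)
    # Backward runs outside the autocast context.
    loss.backward()
    optimizer.step()
    return loss
```

Keeping the autocast factory as an argument is a design choice for this sketch only; it makes it easy to compare an FP8 run against a higher-precision baseline using the same loop.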
