Workflow:NVIDIA TransformerEngine Comm GEMM Overlap Training
| Knowledge Sources | |
|---|---|
| Domains | LLMs, FP8_Training, Distributed_Training, Performance_Optimization |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
End-to-end process for training Transformer Engine models with overlapped communication and GEMM operations in tensor-parallel configurations, maximizing GPU utilization by hiding collective communication latency.
Description
This workflow demonstrates how to use TE's Userbuffers system to overlap tensor-parallel collective communication (all-gather, reduce-scatter) with GEMM (matrix multiplication) operations during training. In standard tensor-parallel training, each forward and backward pass requires several collectives that would otherwise run as blocking NCCL calls and stall compute. Overlapping them hides the communication latency behind GEMM execution, which improves training throughput on multi-GPU systems whenever communication is a meaningful share of the step time.
Key outputs:
- A tensor-parallel training setup with communication-compute overlap
- Reduced training step latency through latency hiding
- Support for both single-node tensor parallelism and mixed data/tensor parallelism
Usage
Execute this workflow when running tensor-parallel training across multiple GPUs on a single node (or across nodes with NVLink/NVSwitch connectivity) and you want to maximize throughput by overlapping the all-gather and reduce-scatter communication with the linear layer GEMM computations. This is most beneficial for large models where communication overhead is a significant fraction of the training step time.
Execution Steps
Step 1: Initialize Tensor Parallel Groups
Set up the distributed process group with NCCL backend and create tensor-parallel communication groups. Assign GPUs to tensor-parallel groups based on the desired parallelism degree. Optionally create data-parallel groups if using mixed data/tensor parallelism. Initialize the CUDA RNG state tracker for reproducible results.
Key considerations:
- Tensor parallel group size typically matches GPUs per node (e.g., 8 for DGX)
- Each data-parallel group contains the ranks that hold the same tensor-parallel rank in different tensor-parallel groups
- Use torch.distributed.new_group to create sub-groups for tensor and data parallelism
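The rank layout described above can be sketched in plain Python. This is a minimal illustration, not TE's own code: the helper name `build_parallel_groups` is hypothetical, and it assumes consecutive global ranks sit on the same node so that each tensor-parallel group stays within one node.

```python
def build_parallel_groups(world_size: int, tp_size: int):
    """Return (tp_groups, dp_groups) as lists of global-rank lists."""
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    # Consecutive ranks share a tensor-parallel group (assumed to be one node).
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(dp_size)]
    # A data-parallel group collects the ranks at the same position in every TP group.
    dp_groups = [list(range(j, world_size, tp_size)) for j in range(tp_size)]
    return tp_groups, dp_groups


tp_groups, dp_groups = build_parallel_groups(world_size=16, tp_size=8)
print(tp_groups[1])  # [8, 9, 10, 11, 12, 13, 14, 15]
print(dp_groups[0])  # [0, 8]
```

Each rank list would then be passed to torch.distributed.new_group(ranks=...). Every process has to call new_group for every group, including groups it does not belong to, and keeps the handles for the groups that contain its own rank.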
Step 2: Build Tensor Parallel TE Model
Construct the model using TE modules configured for tensor parallelism. Set tp_group and tp_size parameters on TE layers (TransformerLayer, Linear, LayerNormLinear, LayerNormMLP, MultiheadAttention) to enable automatic weight sharding and collective communication. Each TE layer internally manages the distribution of weights and the required all-gather/reduce-scatter operations.
Key considerations:
- Set tp_group and tp_size on each TE module
- Column-parallel layers (QKV projection, MLP fc1) shard along the output dimension
- Row-parallel layers (output projection, MLP fc2) shard along the input dimension
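- The set_tensor_parallel_group method on TE modules can assign the group after construction, and calling it on a composite module such as TransformerLayer applies it to the layers inside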
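To make the sharding rules concrete, the sketch below computes the per-rank weight shape for column-parallel and row-parallel layers. The helper `local_weight_shape` is hypothetical and written only for illustration; it assumes the PyTorch (out_features, in_features) weight layout and reuses the "column" and "row" mode names.

```python
def local_weight_shape(out_features: int, in_features: int, tp_size: int, parallel_mode: str):
    """Shape of the weight shard held by one rank, in (out, in) layout."""
    if parallel_mode == "column":
        # Column-parallel: split the output dimension across ranks.
        if out_features % tp_size != 0:
            raise ValueError("out_features must be divisible by tp_size")
        return (out_features // tp_size, in_features)
    if parallel_mode == "row":
        # Row-parallel: split the input dimension across ranks.
        if in_features % tp_size != 0:
            raise ValueError("in_features must be divisible by tp_size")
        return (out_features, in_features // tp_size)
    raise ValueError("parallel_mode must be 'column' or 'row'")


# Example MLP with hidden size 4096, FFN size 16384, and tp_size 8.
print(local_weight_shape(16384, 4096, 8, "column"))  # fc1 -> (2048, 4096)
print(local_weight_shape(4096, 16384, 8, "row"))     # fc2 -> (4096, 2048)
```

Because fc1's output shard width matches fc2's input shard width, the intermediate activation never has to be gathered between the two layers.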
Step 3: Configure Communication Overlap
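Enable communication-GEMM overlap by setting the ub_* overlap flags on TE modules (for example ub_overlap_ag and ub_overlap_rs). This activates the Userbuffers system, which uses CUDA IPC to map communication buffers across the GPUs in the tensor-parallel group so that peers can read and write them directly. Userbuffers must be initialized before the first forward pass (TE provides initialize_ub for this), and the overlapped collectives are the sequence-parallel all-gather and reduce-scatter, so the modules should also have sequence parallelism enabled. Configure the overlap method (bulk overlap, ring-exchange, or pipeline) for each overlapped GEMM; which collective is paired with which GEMM depends on the layer type and on whether the pass is forward or backward.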
What happens:
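- Userbuffers allocates communication buffers in GPU memory that the peer GPUs in the group can access directly
- During forward: column-parallel layers overlap the all-gather of their sequence-sharded input activations with the GEMM (the weights stay sharded and are not gathered); row-parallel layers overlap the reduce-scatter of their output with the GEMM that produces it
- During backward: the pattern is mirrored; for example, the reduce-scatter of a column-parallel layer's input gradient can be overlapped with the weight-gradient GEMM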
- The overlap is managed at the CUDA stream level using events and stream synchronization
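The ring-exchange idea behind the all-gather overlap can be illustrated with a small single-process simulation. This is a conceptual sketch only, not how Userbuffers is implemented: the function names are hypothetical, matrices are plain lists of rows, and each weight shard is assumed to be stored already transposed as (in_features x local_out_features). At every step a rank multiplies the chunk it currently holds while (on real hardware, concurrently) forwarding that chunk to its neighbor.

```python
def matmul(a, b):
    """Multiply a (m x k) by b (k x n), both given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def ring_allgather_gemm(x_chunks, w_shards):
    """Simulate p ranks; rank r owns x_chunks[r] and w_shards[r].

    Returns outputs[r] = full_x @ w_shards[r], assembled chunk by chunk.
    """
    p = len(x_chunks)
    held = list(x_chunks)  # chunk currently sitting in each rank's buffer
    partial = [[None] * p for _ in range(p)]
    for step in range(p):
        for rank in range(p):
            src = (rank - step) % p  # original owner of the chunk held now
            # On a GPU this GEMM would run while the chunk is sent onward.
            partial[rank][src] = matmul(held[rank], w_shards[rank])
        # Ring send: every rank passes its chunk to the next rank.
        # (The send after the final step is redundant and kept only for brevity.)
        held = [held[(rank - 1) % p] for rank in range(p)]
    # Stitch the partial results back together in sequence order.
    return [[row for chunk in out for row in chunk] for out in partial]


x_chunks = [[[1, 2]], [[3, 4]]]            # two ranks, one sequence row each
w_shards = [[[1], [0]], [[0], [1]]]        # each rank owns one output column
result = ring_allgather_gemm(x_chunks, w_shards)
print(result)  # [[[1], [3]], [[2], [4]]]
```

After p steps every rank has produced its full output slice without ever waiting for a complete all-gather, which is why the communication time can hide behind the per-chunk GEMMs.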
Step 4: Optional DDP Wrapping for Data Parallelism
If using mixed data/tensor parallelism, wrap the model with PyTorch's DistributedDataParallel (DDP) using the data-parallel process group. DDP handles gradient synchronization across data-parallel replicas while TE handles tensor-parallel communication internally. The total GPU count is then tp_size multiplied by the number of data-parallel replicas, so the job can grow beyond a single node by adding replicas.
Key considerations:
- Use the data-parallel group (not the global group) for DDP
- DDP and TE's tensor-parallel communication operate on different process groups
- Gradient accumulation can be used to increase the effective batch size
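The relationship between GPU count, parallelism degrees, and batch size can be written out as simple arithmetic. The helper below is a hypothetical sketch; `micro_batch_size` here means the number of samples each data-parallel replica processes per forward pass.

```python
def parallel_layout(world_size: int, tp_size: int, micro_batch_size: int, grad_accum_steps: int):
    """Derive the data-parallel degree and the effective global batch size."""
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    # All ranks in one TP group see the same samples, so only dp_size multiplies the batch.
    effective_batch_size = micro_batch_size * dp_size * grad_accum_steps
    return {"dp_size": dp_size, "effective_batch_size": effective_batch_size}


layout = parallel_layout(world_size=64, tp_size=8, micro_batch_size=2, grad_accum_steps=4)
print(layout)  # {'dp_size': 8, 'effective_batch_size': 64}
```

Raising grad_accum_steps increases the effective batch size without changing memory use per GPU, at the cost of more forward and backward passes per optimizer step.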
Step 5: Execute Overlapped Training Loop
Run the training loop with FP8 autocast enabled. Each training iteration performs a forward pass where all-gather and GEMM are overlapped, computes the loss, runs the backward pass where reduce-scatter and GEMM are overlapped, and updates the optimizer. The overlap is transparent to the training loop code.
Key considerations:
- Wrap the forward pass with te.autocast for FP8 precision (named te.fp8_autocast in older TE releases)
- The communication overlap is handled internally by the TE layers
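- Profile a few steps (for example with Nsight Systems or the PyTorch profiler) to verify that communication kernels run concurrently with GEMMs; a GPU utilization number alone does not show whether communication is hidden
- The environment variable CUDA_DEVICE_MAX_CONNECTIONS=1 may be needed so that kernels on different streams are scheduled in the order they were launched; set it before the process initializes CUDA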
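Two of the considerations above can be sketched as small helpers. Both names are hypothetical. The first sets the environment variable only if the user has not already chosen a value, and is meant to be called before importing torch. The second turns three timings that you would measure yourself (GEMM alone, communication alone, and the overlapped step) into a rough estimate of how much communication time was hidden.

```python
import os


def ensure_single_connection(env=os.environ):
    """Set CUDA_DEVICE_MAX_CONNECTIONS=1 unless a value is already present."""
    env.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")
    return env["CUDA_DEVICE_MAX_CONNECTIONS"]


def hidden_comm_fraction(t_gemm: float, t_comm: float, t_overlapped: float) -> float:
    """Fraction of communication time hidden behind compute, clamped to [0, 1]."""
    if t_comm <= 0:
        raise ValueError("t_comm must be positive")
    # Time saved relative to running the GEMM and the collective back to back.
    hidden = (t_gemm + t_comm) - t_overlapped
    return max(0.0, min(1.0, hidden / t_comm))


ensure_single_connection()  # call this before `import torch`
# Illustrative numbers only: 10 ms GEMM, 4 ms collective, 11 ms when overlapped.
print(hidden_comm_fraction(10.0, 4.0, 11.0))  # 0.75
```

A value near 1.0 suggests the collective is almost fully hidden; a value near 0.0 suggests the overlap is not taking effect and the stream scheduling or overlap flags are worth rechecking.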