Implementation:NVIDIA TransformerEngine TP Group Setup
Overview
Pattern for initializing tensor-parallel and data-parallel process groups for TE training.
Description
This pattern initializes torch.distributed, partitions the global ranks into tensor-parallel process groups of tp_size consecutive ranks each, and optionally creates data-parallel groups. The resulting tp_group and tp_size are passed to TE TransformerLayer construction.
This is a Pattern Doc. It describes a common initialization pattern rather than a single function or class. The pattern is extracted from TransformerEngine's example scripts and represents the recommended approach for setting up distributed process groups.
The pattern consists of the following steps:
- Initialize the distributed backend using torch.distributed.init_process_group with NCCL.
- Set the CUDA device for the current rank based on LOCAL_RANK.
- Create TP subgroups by enumerating rank lists, where each sublist contains tp_size consecutive ranks.
- Create DP subgroups (optional) by grouping ranks that share the same TP-local position across replicas.
- Extract TP metadata (tp_rank, tp_size) for use in model construction and RNG seeding.
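The rank layouts described in the TP and DP steps can be sketched in plain Python, without touching torch.distributed. The helper names tp_rank_lists and dp_rank_lists are hypothetical and do not come from the source script. The sketch assumes that world_size is a multiple of tp_size and that TP groups are made of consecutive ranks, as stated above.

```python
from typing import List


def tp_rank_lists(world_size: int, tp_size: int) -> List[List[int]]:
    """Return one list of tp_size consecutive global ranks per TP group."""
    if tp_size < 1 or world_size % tp_size != 0:
        raise ValueError("world_size must be a positive multiple of tp_size")
    return [list(range(start, start + tp_size))
            for start in range(0, world_size, tp_size)]


def dp_rank_lists(world_size: int, tp_size: int) -> List[List[int]]:
    """Return one list per TP-local position, collecting that position from every replica."""
    if tp_size < 1 or world_size % tp_size != 0:
        raise ValueError("world_size must be a positive multiple of tp_size")
    return [list(range(position, world_size, tp_size))
            for position in range(tp_size)]


if __name__ == "__main__":
    # 8 GPUs with tp_size=4: two TP groups, four DP groups.
    print(tp_rank_lists(8, 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
    print(dp_rank_lists(8, 4))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Either list of lists has the shape expected by the ranks_per_replica_list argument of torch.distributed.new_subgroups_by_enumeration, since every rank appears in exactly one sublist.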
Source
examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py, L193-257
Interface
# Pattern: TP Group Initialization
import os

import torch
import transformer_engine.pytorch as te

# Step 1: Read environment variables (set by torchrun or equivalent launcher)
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
tp_size = 4  # Configure as needed; must divide world_size evenly

# Step 2: Initialize distributed backend
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
nccl_world = torch.distributed.new_group(backend="nccl")

# Step 3: Create tensor-parallel subgroups
# Each subgroup contains tp_size consecutive ranks
ranks_per_replica = []
for i in range(0, world_size, tp_size):
    ranks_per_replica.append([j for j in range(i, i + tp_size)])
tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
    ranks_per_replica_list=ranks_per_replica,
    backend="nccl",
)

# Step 4: Extract TP metadata
tp_rank = torch.distributed.get_rank(tp_group)

# Step 5: Use tp_group and tp_size in TE module construction
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    tp_group=tp_group,
    tp_size=tp_size,
    set_parallel_mode=True,
)
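The interface above reads the launcher variables directly and assumes tp_size divides world_size. If it does not, the loop in Step 3 produces rank lists that name ranks beyond world_size - 1. A minimal sketch of a pre-flight check that can run before any torch.distributed call is shown below. LaunchInfo, read_launch_info, and check_tp_size are hypothetical names introduced for illustration and are not part of TransformerEngine or PyTorch.

```python
import os
from typing import Mapping, NamedTuple, Optional


class LaunchInfo(NamedTuple):
    rank: int        # global rank, from RANK
    world_size: int  # total number of processes, from WORLD_SIZE
    local_rank: int  # GPU index on this node, from LOCAL_RANK


def read_launch_info(env: Optional[Mapping[str, str]] = None) -> LaunchInfo:
    """Parse the launcher variables, failing early with a readable message."""
    env = os.environ if env is None else env
    required = ("RANK", "WORLD_SIZE", "LOCAL_RANK")
    missing = [name for name in required if name not in env]
    if missing:
        raise RuntimeError(
            "missing launcher variables: " + ", ".join(missing)
            + " (launch with torchrun or set them manually)"
        )
    return LaunchInfo(int(env["RANK"]), int(env["WORLD_SIZE"]), int(env["LOCAL_RANK"]))


def check_tp_size(world_size: int, tp_size: int) -> None:
    """Reject tp_size values that would leave a partial TP group."""
    if tp_size < 1 or world_size % tp_size != 0:
        raise ValueError(
            f"tp_size={tp_size} must be a positive divisor of world_size={world_size}"
        )
```

Calling read_launch_info() followed by check_tp_size(info.world_size, tp_size) at the top of the script surfaces a misconfigured launch as a Python exception on every rank rather than as a failure inside group creation.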
I/O
| Direction | Description |
|---|---|
| Input | Environment variables: RANK (global rank), WORLD_SIZE (total GPUs), LOCAL_RANK (GPU index on current node). Configuration: tp_size (tensor-parallel group size). |
| Output | tp_group (ProcessGroup): The tensor-parallel process group for this rank. dp_group (ProcessGroup, optional): The data-parallel process group. tp_size (int): Tensor-parallel group size. tp_rank (int): This rank's position within its TP group. |
Key Components
| Component | Description |
|---|---|
| torch.distributed.init_process_group | Initializes the NCCL distributed backend. Must be called before any collective operations. |
| torch.distributed.new_group | Creates a new process group. Used to create the world-level NCCL group. |
| torch.distributed.new_subgroups_by_enumeration | Creates multiple subgroups from explicit rank lists. Each rank list becomes a TP group. |
| torch.cuda.set_device | Binds the current process to its local GPU. Must be called before any CUDA operations. |
| tp_group | The resulting process group passed to te.TransformerLayer(tp_group=...). |
| tp_size | The group size passed to te.TransformerLayer(tp_size=...). |
Example: Full Initialization with RNG Tracking
import os

import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker


def initialize_distributed(tp_size: int):
    """Complete distributed initialization for TE tensor-parallel training."""
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    # Initialize backend
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend="nccl")

    # Create TP groups
    ranks_per_replica = []
    for i in range(0, world_size, tp_size):
        ranks_per_replica.append([j for j in range(i, i + tp_size)])
    tp_group, _ = torch.distributed.new_subgroups_by_enumeration(
        ranks_per_replica_list=ranks_per_replica,
        backend="nccl",
    )
    tp_rank = torch.distributed.get_rank(tp_group)

    # Set up RNG tracking for reproducible model-parallel dropout
    base_seed = 1234
    rng_tracker = CudaRNGStatesTracker()
    rng_tracker.add("model-parallel-rng", base_seed + tp_rank)

    def get_rng_tracker():
        return rng_tracker

    return tp_group, tp_size, tp_rank, get_rng_tracker
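To see what the base_seed + tp_rank rule in the example implies across a whole job, the seed assignment can be tabulated in plain Python. This is an illustrative sketch only. The helper model_parallel_seeds is a hypothetical name, and it assumes the consecutive-rank TP layout used above, under which a rank's position inside its TP group equals rank % tp_size.

```python
from typing import Dict


def model_parallel_seeds(world_size: int, tp_size: int, base_seed: int = 1234) -> Dict[int, int]:
    """Map each global rank to base_seed + tp_rank for consecutive-rank TP groups."""
    if tp_size < 1 or world_size % tp_size != 0:
        raise ValueError("world_size must be a positive multiple of tp_size")
    # With groups [0..tp_size-1], [tp_size..2*tp_size-1], ... the group-local
    # rank of a process is its global rank modulo tp_size.
    return {rank: base_seed + (rank % tp_size) for rank in range(world_size)}


if __name__ == "__main__":
    for rank, seed in model_parallel_seeds(world_size=8, tp_size=4).items():
        print(f"rank {rank}: seed {seed}")
```

Under these assumptions, ranks inside one TP group receive distinct seeds, while ranks at the same TP-local position in different replicas receive the same seed.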