Implementation:FMInference FlexLLMGen Initialize Distributed
| Knowledge Sources | |
|---|---|
| Domains | Distributed Computing, Pipeline Parallelism |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Initializes PyTorch distributed communication and creates pairwise process groups for pipeline-parallel inference across multiple GPUs.
Description
The dist_utils module provides the distributed communication primitives used by FlexLLMGen's multi-GPU inference engine. The central function, initialize_distributed(), binds the process to its local CUDA device and initializes the PyTorch distributed backend through a TCP rendezvous at the head node. The backend is NCCL for GPU-based communication or Gloo for CPU-based communication. The function then constructs pairwise process groups that link each adjacent pair of pipeline stages.
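The backend choice described above is a small lookup. The following sketch is illustrative only. The helper name `select_backend` is hypothetical, and rejecting unknown values with `ValueError` is an assumption about how invalid input is handled.

```python
def select_backend(comm_device: str) -> str:
    """Map a comm_device string to a torch.distributed backend name.

    Hypothetical helper: "cpu" selects Gloo and "gpu" selects NCCL,
    matching the behaviour described for initialize_distributed().
    """
    backends = {"cpu": "gloo", "gpu": "nccl"}
    try:
        return backends[comm_device]
    except KeyError:
        # Assumption: unknown device strings are rejected rather than defaulted.
        raise ValueError(f"Unknown comm_device: {comm_device!r}") from None
```

The returned string is what would be passed as the `backend=` argument of `torch.distributed.init_process_group`.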
For a world of N processes, the module creates N process groups. Each group contains one pair of ranks (pred, succ), where succ = (pred + 1) % N, so the last rank is paired with rank 0 and the ring is closed. Each rank stores references to its predecessor group (for receiving) and its successor group (for sending). As a result, every pair of consecutive pipeline stages has its own two-member group for point-to-point tensor transfers.
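The ring layout can be checked without launching any processes. This pure-Python sketch only computes the rank pairs. In the real module, each pair would be passed to `torch.distributed.new_group`. PyTorch requires every process to call that function for every group, in the same order, including groups it does not belong to. That requirement is why all N pairs are enumerated on every rank. The helper names `ring_pairs` and `neighbor_pairs` are hypothetical.

```python
from typing import List, Tuple


def ring_pairs(world_size: int) -> List[Tuple[int, int]]:
    """Return the (pred, succ) rank pair for each of the world_size groups."""
    return [(pred, (pred + 1) % world_size) for pred in range(world_size)]


def neighbor_pairs(rank: int, world_size: int) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    """Return (pair shared with the previous stage, pair shared with the next stage)."""
    pairs = ring_pairs(world_size)
    # The pair whose succ is `rank` links this rank to the stage before it.
    from_prev = next(p for p in pairs if p[1] == rank)
    # The pair whose pred is `rank` links this rank to the stage after it.
    to_next = next(p for p in pairs if p[0] == rank)
    return from_prev, to_next
```

For four processes, rank 0 shares the pair (3, 0) with the last stage and the pair (0, 1) with rank 1. Each rank therefore belongs to exactly two of the N groups.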
The module also provides the following helpers:
- get_pipeline_parallel_pred_group() and get_pipeline_parallel_succ_group() return the appropriate process group for inter-stage communication.
- get_comm_device() returns the configured communication device type.
- suppress_output() overrides the built-in print function so that only forced output or rank-0 output is displayed.
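The print override can be sketched as a factory that wraps an existing print function. This sketch assumes that forcing is requested with a `force=True` keyword, as in the common `setup_for_distributed` pattern. Tagging forced output with the rank number is also an assumption. The names `make_rank_print` and `suppress_output_sketch` are hypothetical.

```python
import builtins
from typing import Callable


def make_rank_print(rank: int, base_print: Callable = builtins.print) -> Callable:
    """Build a print replacement that is silent on non-zero ranks unless forced."""

    def rank_print(*args, **kwargs):
        # Assumed keyword: force=True makes any rank print.
        force = kwargs.pop("force", False)
        if force:
            # Forced output appears on every rank, prefixed with the rank number.
            base_print(f"rank #{rank}:", *args, **kwargs)
        elif rank == 0:
            base_print(*args, **kwargs)
        # Otherwise the message is silently dropped.

    return rank_print


def suppress_output_sketch(rank: int) -> None:
    """Install the rank-aware print globally by replacing builtins.print."""
    builtins.print = make_rank_print(rank, builtins.print)
```

Because the override replaces `builtins.print`, it affects every module in the process. This is why the description says it should run only after the rank is known.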
Usage
Call initialize_distributed() once on each process at startup, before any distributed inference begins. Every process must pass the same head_ip, port, and world_size, together with its own rank and local_rank. The accessor functions are then used throughout the distributed inference pipeline to retrieve the correct communication groups and device setting.
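One way to supply the per-process values is through command-line flags. The sketch below is an assumption and not FlexLLMGen's own launcher. The flag names and the helper `parse_dist_args` are hypothetical. The sketch returns keyword arguments that match the initialize_distributed() signature and checks that the rank falls inside the world.

```python
import argparse
from typing import Dict, List, Optional, Union


def parse_dist_args(argv: Optional[List[str]] = None) -> Dict[str, Union[str, int]]:
    """Parse hypothetical launcher flags into kwargs for initialize_distributed()."""
    parser = argparse.ArgumentParser(description="Per-process distributed settings")
    parser.add_argument("--head-ip", required=True, help="IP address of the rank-0 node")
    parser.add_argument("--port", type=int, required=True, help="Rendezvous port on the head node")
    parser.add_argument("--world-size", type=int, required=True)
    parser.add_argument("--rank", type=int, required=True)
    parser.add_argument("--local-rank", type=int, required=True)
    parser.add_argument("--comm-device", choices=["cpu", "gpu"], required=True)
    args = parser.parse_args(argv)
    # Reject ranks outside [0, world_size) before any rendezvous is attempted.
    if not 0 <= args.rank < args.world_size:
        parser.error("--rank must be in [0, world_size)")
    return {
        "head_ip": args.head_ip,
        "port": args.port,
        "world_size": args.world_size,
        "rank": args.rank,
        "local_rank": args.local_rank,
        "comm_device": args.comm_device,
    }
```

Each process would then call `initialize_distributed(**parse_dist_args())`. All processes share the same head IP, port, and world size, and each one receives its own rank and local rank.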
Code Reference
Source Location
- Repository: FMInference_FlexLLMGen
- File: flexllmgen/dist_utils.py
- Lines: 1-64
Signature
```python
def initialize_distributed(head_ip, port, world_size, rank, local_rank, comm_device):
    ...

def get_pipeline_parallel_pred_group():
    ...

def get_pipeline_parallel_succ_group():
    ...

def get_comm_device():
    ...

def suppress_output(rank):
    ...
```
Import
```python
from flexllmgen.dist_utils import (
    initialize_distributed,
    get_pipeline_parallel_pred_group,
    get_pipeline_parallel_succ_group,
    get_comm_device,
    suppress_output,
)
```
I/O Contract
Inputs (initialize_distributed)
| Name | Type | Required | Description |
|---|---|---|---|
| head_ip | str | Yes | IP address of the head (rank-0) node used for the TCP rendezvous endpoint. |
| port | int | Yes | Port number on the head node for the rendezvous endpoint. |
| world_size | int | Yes | Total number of processes participating in distributed inference. |
| rank | int | Yes | Global rank of the current process (0-indexed). |
| local_rank | int | Yes | Local rank on the current node, used to select the CUDA device. |
| comm_device | str | Yes | Communication device type: "cpu" selects the Gloo backend; "gpu" selects the NCCL backend. |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Sets the current CUDA device from local_rank. Sets the module-level globals _COMM_DEVICE, _PIPELINE_PARALLEL_PRED_GROUP, and _PIPELINE_PARALLEL_SUCC_GROUP. Initializes the default PyTorch distributed process group. Overrides the built-in print via suppress_output(). |
Usage Examples
```python
from flexllmgen.dist_utils import (
    initialize_distributed,
    get_pipeline_parallel_pred_group,
    get_pipeline_parallel_succ_group,
    get_comm_device,
)

# Initialize on each process at startup.
# Example values: a single node with 4 GPUs, so local_rank equals rank.
initialize_distributed(
    head_ip="192.168.1.1",
    port=29500,
    world_size=4,
    rank=2,
    local_rank=2,
    comm_device="gpu",
)

# Later, retrieve the process groups for send/recv.
pred_group = get_pipeline_parallel_pred_group()
succ_group = get_pipeline_parallel_succ_group()
comm_dev = get_comm_device()  # returns "gpu"
```