Implementation:FMInference FlexLLMGen DeepSpeed Comm
| Field | Value |
|---|---|
| Sources | Repo: FlexLLMGen |
| Domains | Distributed_Communication, Collective_Operations |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Vendored DeepSpeed communication backend abstraction that provides a unified API for distributed collective operations, wrapping PyTorch's torch.distributed with additional profiling, timing, and environment auto-discovery capabilities.
Description
comm.py is the central communication module for DeepSpeed. It defines a drop-in replacement for torch.distributed that adds communication profiling, timing decorators, and automatic environment setup for various cloud platforms (Azure ML, AWS SageMaker, MPI).
Key components:
- Backend abstraction -- A global cdb (current DeepSpeed backend) variable holds the active communication backend. By default this is DeepSpeed's TorchBackend, which wraps torch.distributed; the abstraction leaves room for future custom backends (NCCL, MPI, Gloo).
- Timed operations -- The @timed_op decorator wraps all collective operations (all_reduce, all_gather, broadcast, etc.) with optional timing via CUDA events and communication logging via CommsLogger.
- Collective operations -- Provides broadcast, all_gather, all_reduce, reduce, reduce_scatter, scatter, gather, send, recv, isend, irecv, barrier, and all_to_all_single. Each delegates to the underlying backend while adding profiling hooks.
- Compatibility functions -- reduce_scatter_fn and allgather_fn use the optimized _reduce_scatter_base and _all_gather_base operations when the installed PyTorch version provides them, and otherwise fall back to the list-based reduce_scatter and all_gather.
- Environment auto-discovery -- init_distributed automatically detects the execution environment (MPI, Azure ML, AWS SageMaker) and configures environment variables (RANK, WORLD_SIZE, MASTER_ADDR, etc.) accordingly via mpi_discovery and platform-specific patching functions.
- ReduceOp enum -- Defines standard reduction operations (SUM, PRODUCT, MIN, MAX, AVG, etc.) compatible with the torch.distributed API.
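The backend abstraction and the @timed_op decorator work together: every public collective is a thin function that forwards to the global cdb, and the decorator wraps that function to record timings. The sketch below shows this pattern in a single process. It is an illustration under stated assumptions, not the comm.py code: timing uses time.perf_counter instead of CUDA-event timers, comms_log is a plain dict standing in for CommsLogger, and SingleProcessBackend and init_backend are hypothetical names.

```python
import functools
import time
from collections import defaultdict

# Hypothetical stand-in for CommsLogger: op name -> recorded latencies in seconds.
comms_log = defaultdict(list)


def timed_op(func):
    """Time a collective and record it only when the caller passes prof=True.

    Assumption: wall-clock timing with time.perf_counter; the real decorator
    uses DeepSpeed's CUDA-event timers and a CommsLogger instance.
    """
    @functools.wraps(func)
    def log_wrapper(*args, **kwargs):
        profile = kwargs.get("prof", False)
        start = time.perf_counter() if profile else None
        try:
            return func(*args, **kwargs)
        finally:
            if profile:
                comms_log[func.__name__].append(time.perf_counter() - start)
    return log_wrapper


class SingleProcessBackend:
    """Hypothetical backend for a world of size 1, where collectives are no-ops."""

    def all_reduce(self, tensor, op=None, group=None, async_op=False):
        # With a single rank, the reduced value is already the local value.
        return None


# Module-level "current backend", playing the role of cdb in comm.py.
cdb = None


def init_backend(backend):
    """Hypothetical initializer; in comm.py, init_distributed assigns cdb."""
    global cdb
    cdb = backend


@timed_op
def all_reduce(tensor, op=None, group=None, async_op=False, prof=False):
    # The public function only forwards to whatever backend is active.
    return cdb.all_reduce(tensor, op=op, group=group, async_op=async_op)
```

Because callers only ever touch the module-level functions, swapping the backend object changes the transport without changing any call sites, and profiling stays opt-in per call.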
This is AUTO_KEEP vendored code from DeepSpeed.
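The environment auto-discovery step from the component list can be approximated as a pure function that maps launcher-provided variables onto the ones torch.distributed's env:// initialization reads. This sketch swaps in a different technique from comm.py: it reads OpenMPI's OMPI_COMM_WORLD_* variables instead of querying MPI through mpi4py, and it takes MASTER_ADDR as an argument rather than discovering it. The helper name discover_env and the 29500 default port are assumptions for illustration.

```python
from typing import Dict, Mapping

# Variables torch.distributed's env:// init expects, plus LOCAL_RANK.
REQUIRED_ENV = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK")


def discover_env(environ: Mapping[str, str], master_addr: str,
                 master_port: int = 29500) -> Dict[str, str]:
    """Hypothetical helper: derive torch.distributed variables from an OpenMPI launch.

    Assumes the job was started with OpenMPI's mpirun, which exports
    OMPI_COMM_WORLD_RANK, OMPI_COMM_WORLD_SIZE and OMPI_COMM_WORLD_LOCAL_RANK.
    """
    # If another launcher already exported everything, keep its values.
    if all(key in environ for key in REQUIRED_ENV):
        return {key: environ[key] for key in REQUIRED_ENV}
    try:
        return {
            "RANK": environ["OMPI_COMM_WORLD_RANK"],
            "WORLD_SIZE": environ["OMPI_COMM_WORLD_SIZE"],
            "LOCAL_RANK": environ["OMPI_COMM_WORLD_LOCAL_RANK"],
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": str(master_port),
        }
    except KeyError as missing:
        raise RuntimeError(f"not an OpenMPI launch: {missing} is unset") from None
```

A caller would apply the result with os.environ.update(discover_env(os.environ, addr)) before initializing the process group; returning a dict instead of mutating os.environ keeps the mapping easy to test.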
Code Reference
| Field | Value |
|---|---|
| Repository | FlexLLMGen |
| File | benchmark/third_party/DeepSpeed/deepspeed/comm/comm.py |
| Lines | 1-771 |
Key Functions and Classes:
```python
class ReduceOp(Enum):
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3
    ...

def init_distributed(dist_backend="nccl", auto_mpi_discovery=True, ...): ...
def broadcast(tensor, src, group=None, async_op=False, ...): ...
def all_gather(tensor_list, tensor, group=None, ...): ...
def all_reduce(tensor, op=ReduceOp.SUM, group=None, ...): ...
def reduce_scatter_fn(output_tensor, tensor, op=ReduceOp.SUM, ...): ...
def get_rank(group=None): ...
def get_world_size(group=None) -> int: ...
def get_local_rank(): ...
def mpi_discovery(distributed_port, verbose=True): ...
```
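The reduce_scatter_fn fallback rests on one equivalence: reduce-scattering a flat buffer gives the same per-rank result as chunking that buffer into world_size pieces and calling the list-based reduce_scatter. The sketch below checks this in a single process that holds every rank's buffer in one list. It uses plain Python lists instead of tensors, and all function bodies are hypothetical simulations; unlike torch.chunk, the chunk helper here rejects buffers that do not divide evenly.

```python
from typing import List, Sequence

Vector = List[float]


def chunk(flat: Sequence[float], world_size: int) -> List[Vector]:
    """Split a flat buffer into world_size equal slices."""
    if len(flat) % world_size:
        raise ValueError("buffer length must be divisible by world_size")
    size = len(flat) // world_size
    return [list(flat[i * size:(i + 1) * size]) for i in range(world_size)]


def reduce_scatter_base(inputs: Sequence[Sequence[float]]) -> List[Vector]:
    """Fused path: rank r receives slice r of the element-wise sum of all buffers."""
    world_size = len(inputs)
    size = len(inputs[0]) // world_size
    summed = [sum(vals) for vals in zip(*inputs)]
    return [summed[r * size:(r + 1) * size] for r in range(world_size)]


def reduce_scatter(input_lists: Sequence[Sequence[Vector]]) -> List[Vector]:
    """List path: rank r receives the element-wise sum of every rank's r-th chunk."""
    world_size = len(input_lists)
    return [[sum(vals) for vals in zip(*(lst[r] for lst in input_lists))]
            for r in range(world_size)]


def reduce_scatter_fn(inputs: Sequence[Sequence[float]], has_base: bool) -> List[Vector]:
    """Prefer the fused op; otherwise chunk each rank's buffer and use the list op."""
    if has_base:
        return reduce_scatter_base(inputs)
    world_size = len(inputs)
    return reduce_scatter([chunk(buf, world_size) for buf in inputs])
```

Both paths return identical per-rank slices; comm.py logs a warning that the list-based fallback is slower, which is why it uses the fused operation whenever PyTorch provides it.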
I/O Contract
Key Functions
| Function | Inputs | Output | Description |
|---|---|---|---|
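| init_distributed | dist_backend, auto_mpi_discovery, distributed_port, timeout, init_method | None | Initializes the communication backend (cdb); if auto_mpi_discovery is enabled and required variables such as RANK or MASTER_ADDR are missing, first populates them via MPI, Azure ML, or AWS SageMaker discovery |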
| all_reduce | tensor, op, group, async_op | Work handle or None | In-place all-reduce across the process group |
| broadcast | tensor, src, group, async_op | Work handle or None | Broadcasts tensor from src rank to all ranks |
| get_rank | group (optional) | int | Returns the rank of the current process within the group (the default group if none is given) |
| get_world_size | group (optional) | int | Returns number of processes in the group |
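The in-place semantics and the "Work handle or None" return value in the table above can be illustrated with a single-process simulation in which rank_tensors[r] stands for rank r's tensor. The ReduceOp class below copies only the first four members and values from the Key Functions listing; CompletedWork is a hypothetical stand-in for the handle torch.distributed returns when async_op=True, and these functions are simulations, not the comm.py implementations.

```python
import operator
from enum import Enum
from functools import reduce
from typing import List, Optional


class ReduceOp(Enum):
    # Subset of the comm.py members; values copied from the listing.
    SUM = 0
    PRODUCT = 1
    MIN = 2
    MAX = 3


_COMBINE = {
    ReduceOp.SUM: operator.add,
    ReduceOp.PRODUCT: operator.mul,
    ReduceOp.MIN: min,
    ReduceOp.MAX: max,
}


class CompletedWork:
    """Hypothetical handle: a simulated collective is already finished."""

    def is_completed(self) -> bool:
        return True

    def wait(self) -> bool:
        return True


def all_reduce(rank_tensors: List[List[float]], op: ReduceOp = ReduceOp.SUM,
               async_op: bool = False) -> Optional[CompletedWork]:
    """Overwrite every rank's tensor with the element-wise reduction across ranks."""
    combine = _COMBINE[op]
    reduced = [reduce(combine, column) for column in zip(*rank_tensors)]
    for tensor in rank_tensors:
        tensor[:] = reduced  # in place: each rank ends with the same values
    return CompletedWork() if async_op else None


def broadcast(rank_tensors: List[List[float]], src: int,
              async_op: bool = False) -> Optional[CompletedWork]:
    """Overwrite every rank's tensor with a copy of rank src's tensor."""
    source = list(rank_tensors[src])
    for tensor in rank_tensors:
        tensor[:] = source
    return CompletedWork() if async_op else None
```

With async_op=False the result is written in place and nothing is returned; with async_op=True the caller receives a handle and should call wait() before reading the tensor.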