
Implementation:Hpcaitech ColossalAI Ray Broadcast Tensor Dict

From Leeroopedia


Knowledge Sources
Domains: Distributed_Computing, Infrastructure
Last Updated: 2026-02-09 00:00 GMT

Overview

A concrete tool from ColossalChat for broadcasting tensor dictionaries across Ray actors using collective communication.

Description

ray_broadcast_tensor_dict() broadcasts a dictionary of PyTorch tensors from a source rank to all other ranks in a Ray collective group. It handles metadata broadcasting, device placement, and special handling of bfloat16 tensors on the Gloo backend, which lacks native bfloat16 support.
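The metadata-then-payload pattern described above can be sketched in a single process. This is an illustrative assumption about the mechanism, not the actual Ray code: the `wire` list and the two helper functions below are hypothetical stand-ins for `ray.util.collective` broadcast calls.

```python
# Hypothetical single-process sketch: the source rank first broadcasts
# metadata (names, shapes, dtypes) so receivers can allocate buffers,
# then broadcasts each tensor payload in a fixed order. A plain list
# ("wire") stands in for the collective group.

def send_tensor_dict(tensor_dict, wire):
    # Step 1: broadcast metadata so non-source ranks know what to expect.
    metadata = [(name, t["shape"], t["dtype"]) for name, t in tensor_dict.items()]
    wire.append(metadata)
    # Step 2: broadcast each payload in the same fixed key order.
    for name, _, _ in metadata:
        wire.append(tensor_dict[name]["data"])

def recv_tensor_dict(wire):
    # Non-source side: read metadata first, then match each payload
    # that follows it back to its name.
    metadata = wire[0]
    return {
        name: {"shape": shape, "dtype": dtype, "data": wire[1 + i]}
        for i, (name, shape, dtype) in enumerate(metadata)
    }
```

Because metadata travels first, non-source ranks can start with an empty dict, which is why the usage example below passes `tensor_dict={}` on producers.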

Usage

Called by the consumer after each training step to distribute updated weights to producers; producers call the same function to receive them.
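This sync cadence can be sketched as follows. All names here (`consumer_step`, `producer_sync`, the `channel` dict) are hypothetical stand-ins: the real code calls ray_broadcast_tensor_dict from both sides of a Ray collective group.

```python
# Hypothetical sketch of the weight-sync cadence: the consumer
# broadcasts weights after every training step; producers receive them
# before their next rollout. A dict replaces the real collective group.

def consumer_step(weights, channel):
    weights["step"] = weights.get("step", 0) + 1  # stand-in for an optimizer update
    channel["latest"] = dict(weights)             # source-rank side of the broadcast

def producer_sync(channel):
    # Non-source side of the broadcast: receive the latest weights.
    return dict(channel["latest"])
```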

Code Reference

Source Location

  • Repository: ColossalAI
  • File: applications/ColossalChat/coati/distributed/comm.py
  • Lines: 36-75

Signature

def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor],
    src: int = 0,
    device=None,
    group_name: str = "default",
    backend: str = "nccl",
    offload_to_cpu: bool = False,
    pin_memory: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Broadcast tensor dict from src rank to all ranks in the group.

    Args:
        tensor_dict: Dictionary of tensors to broadcast
        src: Source rank (default: 0)
        device: Target device for tensors
        group_name: Ray collective group name
        backend: Communication backend ("nccl" or "gloo")
        offload_to_cpu: Move received tensors to CPU
        pin_memory: Pin memory for faster GPU transfer

    Returns:
        Dict[str, torch.Tensor]: the broadcast tensors (the input dict on
        the source rank, the received tensors on other ranks).
    """

Import

from coati.distributed.comm import ray_broadcast_tensor_dict

I/O Contract

Inputs

Name Type Required Description
tensor_dict Dict[str, Tensor] Yes Model state_dict to broadcast
src int No Source rank (default: 0)
device torch.device No Target device for received tensors
group_name str No Ray collective group name (default: "default")
backend str No "nccl" or "gloo" (default: "nccl")
offload_to_cpu bool No Move received tensors to CPU (default: False)
pin_memory bool No Pin memory for faster GPU transfer (default: False)

Outputs

Name Type Description
tensor_dict Dict[str, Tensor] The broadcast tensors: the input dict on the source rank, the received copies on all other ranks

Usage Examples

from coati.distributed.comm import ray_broadcast_tensor_dict

# On consumer (src=0): broadcast weights
state_dict = model.state_dict()
ray_broadcast_tensor_dict(
    tensor_dict=state_dict,
    src=0,
    group_name="sync_group",
    backend="nccl",
)

# On producer (src!=0): receive weights
received_weights = ray_broadcast_tensor_dict(
    tensor_dict={},  # Empty dict on non-source ranks
    src=0,
    group_name="sync_group",
    backend="nccl",
)
model.load_state_dict(received_weights)
