Implementation:Hpcaitech ColossalAI Ray Broadcast Tensor Dict
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Infrastructure |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool for broadcasting tensor dictionaries across Ray actors using collective communication, provided by ColossalChat.
Description
ray_broadcast_tensor_dict() broadcasts a dictionary of PyTorch tensors from a source rank to all other ranks in a Ray collective group. It broadcasts tensor metadata first, manages device placement, and applies special bfloat16 handling for the Gloo backend, which does not support that dtype natively.
Usage
Called after each consumer training step to distribute updated weights to producers. Also called by producers to receive weights.
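Conceptually, the broadcast proceeds in two phases: the source rank first broadcasts the dict's metadata (its keys, and in the real implementation also shapes and dtypes) so receivers know what to allocate, then broadcasts each tensor in turn. The sketch below illustrates that pattern with a stand-in `Comm` object in place of `ray.util.collective`; all names here are illustrative, not the actual ColossalChat implementation.

```python
from dataclasses import dataclass, field


@dataclass
class Comm:
    """Stand-in for a collective group: whatever the source rank
    publishes under a key becomes visible to every rank."""
    store: dict = field(default_factory=dict)

    def broadcast_object(self, obj, src, key):
        # The source publishes; receivers read what the source published.
        if obj is not None:
            self.store[key] = obj
        return self.store[key]


def broadcast_tensor_dict(comm, tensor_dict, rank, src=0):
    """Two-phase broadcast: metadata first, then per-key payloads."""
    # Phase 1: broadcast the key list so non-source ranks know what to expect.
    meta = list(tensor_dict) if rank == src else None
    keys = comm.broadcast_object(meta, src, "meta")
    # Phase 2: broadcast each entry; non-source ranks fill an empty dict.
    out = {}
    for k in keys:
        payload = tensor_dict[k] if rank == src else None
        out[k] = comm.broadcast_object(payload, src, k)
    return out
```

In this toy model, rank 0 passes its full dict and non-source ranks pass an empty one, mirroring the real call pattern shown in Usage Examples below.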
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/distributed/comm.py
- Lines: 36-75
Signature
```python
def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor],
    src: int = 0,
    device=None,
    group_name: str = "default",
    backend: str = "nccl",
    offload_to_cpu: bool = False,
    pin_memory: bool = False,
) -> Dict[str, torch.Tensor]:
    """
    Broadcast tensor dict from src rank to all ranks in the group.

    Args:
        tensor_dict: Dictionary of tensors to broadcast
        src: Source rank (default: 0)
        device: Target device for tensors
        group_name: Ray collective group name
        backend: Communication backend ("nccl" or "gloo")
        offload_to_cpu: Move received tensors to CPU
        pin_memory: Pin memory for faster GPU transfer
    """
```
Import
```python
from coati.distributed.comm import ray_broadcast_tensor_dict
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| tensor_dict | Dict[str, Tensor] | Yes | Model state_dict to broadcast (empty dict on non-source ranks) |
| src | int | No | Source rank (default: 0) |
| device | torch.device | No | Target device for received tensors |
| group_name | str | No | Ray collective group name (default: "default") |
| backend | str | No | "nccl" or "gloo" (default: "nccl") |
| offload_to_cpu | bool | No | Move received tensors to CPU (default: False) |
| pin_memory | bool | No | Pin memory for faster GPU transfer (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| tensor_dict | Dict[str, Tensor] | Received tensors on all non-source ranks |
Usage Examples
```python
from coati.distributed.comm import ray_broadcast_tensor_dict

# On the consumer (rank 0): broadcast weights
state_dict = model.state_dict()
ray_broadcast_tensor_dict(
    tensor_dict=state_dict,
    src=0,
    group_name="sync_group",
    backend="nccl",
)

# On a producer (rank != 0): receive weights
received_weights = ray_broadcast_tensor_dict(
    tensor_dict={},  # empty dict on non-source ranks
    src=0,
    group_name="sync_group",
    backend="nccl",
)
model.load_state_dict(received_weights)
```
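The Description notes special bfloat16 handling for the Gloo backend: Gloo cannot transport bfloat16 tensors, so a common workaround is to cast such tensors to a supported dtype before the collective call and restore the original dtype afterwards. The sketch below shows that cast-and-restore pattern with tensors modeled as plain dicts; it is an illustration of the technique, not the actual comm.py code.

```python
def prepare_for_gloo(tensor):
    """Cast a bfloat16 tensor to float32 for the wire and remember
    its original dtype (tensor modeled as {"dtype": ..., "data": ...})."""
    if tensor["dtype"] == "bfloat16":
        return {"dtype": "float32", "data": tensor["data"]}, "bfloat16"
    return tensor, tensor["dtype"]


def restore_after_gloo(tensor, orig_dtype):
    """Cast a received tensor back to its original dtype."""
    return {"dtype": orig_dtype, "data": tensor["data"]}
```

NCCL supports bfloat16 directly, so this round trip is only needed on the Gloo path.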
Related Pages
- Implements Principle
- Requires Environment