Implementation:Hpcaitech ColossalAI Zero Bubble Consumer
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, RLHF, Zero_Bubble_Pipeline |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
consumer.py defines the BaseConsumer class, the base training worker in the zero-bubble distributed RL pipeline that receives rollout data from producers, performs policy optimization, and synchronizes model weights.
Description
BaseConsumer manages the full training loop on the consumer side of the zero-bubble architecture. It receives batched rollout data from producers through the shared SharedVariableActor (shared_sync_data_actor) via Ray, buffers the incoming data while dynamically filtering it by reward range, and dispatches minibatches for gradient accumulation. The class also handles ColossalAI HybridParallelPlugin setup (supporting TP, PP, and ZeRO), distributed process group initialization via Ray collective communication, and asynchronous broadcasting of model weights to producers/distributors on a background thread. Key methods include setup (initializes torch DDP and the ColossalAI booster), loop (the main training loop, which receives data and invokes step), prepare_mini_batch (builds minibatches from the effective group mapping), and the abstract methods state_dict and step, which subclasses implement.
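The reward-range filtering described above can be pictured with the following minimal sketch. It assumes that each group holds the rewards of all responses sampled for one prompt, that rewards are 0/1 accuracy scores, and that a group is kept only when its mean reward lies inside a `[lower, upper]` range. The function name `reward_range_mask` and the default bounds are hypothetical; this is not the code in consumer.py.

```python
from typing import List, Sequence


def reward_range_mask(
    group_rewards: Sequence[Sequence[float]],
    lower: float = 0.01,  # hypothetical default bound
    upper: float = 0.99,  # hypothetical default bound
) -> List[bool]:
    """Hypothetical helper: True for each rollout group that survives filtering.

    Each inner sequence holds the rewards of all responses sampled for one
    prompt. A group is kept when its mean reward lies in [lower, upper].
    """
    mask: List[bool] = []
    for rewards in group_rewards:
        if not rewards:
            # An empty group carries no training signal at all.
            mask.append(False)
            continue
        mean_reward = sum(rewards) / len(rewards)
        # With 0/1 rewards, a mean of exactly 0 or 1 means every response got
        # the same reward, which yields zero group-relative advantage in
        # GRPO-style training; the range excludes such groups.
        mask.append(lower <= mean_reward <= upper)
    return mask
```

For example, `reward_range_mask([[1, 1, 1, 1], [0, 1, 0, 1], [0, 0, 0, 0]])` returns `[False, True, False]`: only the group with mixed outcomes is kept.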
Usage
This class is not used directly but serves as the base class for algorithm-specific consumers such as GRPOConsumer. It is instantiated as a Ray remote actor within the launch_distributed function.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/distributed/zero_bubble/consumer.py
- Lines: 1-347
Signature
class BaseConsumer:
    def __init__(
        self,
        shared_sync_data_actor: SharedVariableActor,
        shared_signal_actor: SharedVariableActor,
        num_producers: int,
        num_episodes: int,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        train_dataset_size: int,
        batch_size: int,
        model_config: Dict[str, Any],
        plugin_config: Dict[str, Any],
        minibatch_size: int = 1,
        save_interval: int = 100,
        save_dir: str = "./model",
        enable_profiling: bool = False,
    )
Key Methods
def setup(self) -> None
def get_ddp_config(self) -> Dict[str, Any]
def init_collective_group(self, world_size, rank, backend="nccl", group_name="default", gloo_timeout=3000000)
def state_dict(self) -> Dict[str, torch.Tensor] # abstract
def step(self, **kwargs) -> Optional[float] # abstract
def prepare_mini_batch(self, effective_group_to_raw_group_mapping) -> Tuple[Dict, Dict]
def calculate_effective_group_to_raw_group_mapping(self) -> Dict[int, int]
def loop(self) -> None
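The interplay between calculate_effective_group_to_raw_group_mapping and prepare_mini_batch can be sketched as follows. This is an illustration under stated assumptions, not the library code: the buffer is modeled as a plain list in which filtered-out groups are stored as `None`, a minibatch is assumed to consist of `minibatch_size * dp_size` effective groups (where `dp_size` is a hypothetical data-parallel size parameter), and the helpers are free functions returning lists rather than the `Tuple[Dict, Dict]` of the real method. The names `effective_to_raw_mapping` and `take_minibatch` are hypothetical.

```python
from typing import Any, Dict, List, Optional, Tuple


def effective_to_raw_mapping(buffer: List[Optional[Any]]) -> Dict[int, int]:
    """Hypothetical helper: map the i-th surviving (non-None) group to its raw buffer index."""
    mapping: Dict[int, int] = {}
    for raw_idx, group in enumerate(buffer):
        if group is not None:
            mapping[len(mapping)] = raw_idx
    return mapping


def take_minibatch(
    buffer: List[Optional[Any]],
    mapping: Dict[int, int],
    minibatch_size: int,
    dp_size: int,  # assumed data-parallel size
) -> Tuple[List[Any], List[Optional[Any]]]:
    """Hypothetical helper: take the first minibatch_size * dp_size effective groups.

    Returns the selected groups and the remaining buffer, which starts just
    after the raw position of the last consumed group.
    """
    needed = minibatch_size * dp_size
    if len(mapping) < needed:
        # Not enough effective groups yet; wait for more producer data.
        return [], buffer
    selected = [buffer[mapping[i]] for i in range(needed)]
    last_raw_idx = mapping[needed - 1]
    return selected, buffer[last_raw_idx + 1 :]
```

With `buffer = ["g0", None, "g2", "g3", None, "g5"]`, the mapping is `{0: 0, 1: 2, 2: 3, 3: 5}`, and `take_minibatch(buffer, mapping, minibatch_size=1, dp_size=2)` selects `["g0", "g2"]` and leaves `["g3", None, "g5"]`. Keeping placeholders in the raw buffer lets the mapping translate between "the n-th usable group" and its physical position without reshuffling data.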
Import
from coati.distributed.zero_bubble.consumer import BaseConsumer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| shared_sync_data_actor | SharedVariableActor | Yes | Ray actor for buffered data exchange between producers and consumers |
| shared_signal_actor | SharedVariableActor | Yes | Ray actor for signal-based coordination (model sync, global step) |
| num_producers | int | Yes | Number of producer workers generating rollouts |
| num_episodes | int | Yes | Number of training episodes |
| rank | int | Yes | Rank of this consumer in the distributed group |
| world_size | int | Yes | Total number of consumer processes |
| master_addr | str | Yes | Master address for torch DDP initialization |
| master_port | int | Yes | Master port for torch DDP initialization |
| train_dataset_size | int | Yes | Size of the training dataset |
| batch_size | int | Yes | Training batch size |
| model_config | Dict[str, Any] | Yes | Model configuration for loading the policy model |
| plugin_config | Dict[str, Any] | Yes | ColossalAI plugin configuration (tp_size, pp_size, zero_stage, etc.) |
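| minibatch_size | int | No | Minibatch size for gradient accumulation (default: 1) |
| save_interval | int | No | Interval between checkpoint saves (default: 100) |
| save_dir | str | No | Directory where checkpoints are saved (default: "./model") |
| enable_profiling | bool | No | Whether to enable profiling (default: False) |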
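To illustrate what gradient accumulation over minibatches means, here is a generic, framework-free sketch on a one-parameter least-squares model. It is not ColossalAI's implementation, and how BaseConsumer maps batch_size and minibatch_size onto accumulation steps is an assumption here: the sketch splits one batch into chunks of `minibatch_size`, sums their size-weighted gradients, and applies a single update. The names `mse_grad` and `accumulate_and_step` are hypothetical.

```python
from typing import Sequence, Tuple


def mse_grad(w: float, batch: Sequence[Tuple[float, float]]) -> float:
    """Gradient of mean((w*x - y)**2) with respect to w over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulate_and_step(
    w: float,
    batch: Sequence[Tuple[float, float]],
    minibatch_size: int,
    lr: float,
) -> float:
    """Generic gradient accumulation: one parameter update per full batch.

    Each chunk's gradient is weighted by its share of the batch, so the
    accumulated value equals the full-batch mean gradient.
    """
    if len(batch) % minibatch_size != 0:
        raise ValueError("batch size must be divisible by minibatch_size")
    accumulated = 0.0
    for start in range(0, len(batch), minibatch_size):
        chunk = batch[start : start + minibatch_size]
        # Analogous to calling backward() on a minibatch without stepping.
        accumulated += mse_grad(w, chunk) * len(chunk) / len(batch)
    # A single update after all minibatches, analogous to optimizer.step().
    return w - lr * accumulated
```

With four samples `(x, 2x)` for `x` in 1..4, `w = 0.0`, and `lr = 0.01`, minibatch sizes of 1, 2, and 4 all yield the same updated weight (about 0.3). Accumulation therefore trades peak memory for extra forward/backward passes without changing the update.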
Outputs
| Name | Type | Description |
|---|---|---|
| (none) | None | The loop() method runs until all episodes complete, signaling termination via shared_signal_actor |
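The combination of a training loop, asynchronous weight broadcasting on a background thread, and a termination signal can be sketched with the standard library alone. In this sketch a plain dict stands in for shared_signal_actor, `broadcast_fn` stands in for the Ray collective broadcast to producers, and "training" is a trivial weight increment. The function name and the signal keys `global_step` and `consumer_done` are hypothetical.

```python
import queue
import threading
from typing import Callable, Dict


def run_consumer_loop(
    num_steps: int,
    signals: Dict[str, object],  # stand-in for shared_signal_actor
    broadcast_fn: Callable[[int, Dict[str, float]], None],
) -> None:
    """Hypothetical consumer loop: train, hand weights to a background
    broadcaster, then signal termination."""
    weight_queue: queue.Queue = queue.Queue()

    def broadcaster() -> None:
        # Runs off the training thread so that sending weights to producers
        # does not block the next optimization step.
        while True:
            item = weight_queue.get()
            if item is None:  # sentinel: no more weights to send
                break
            step, state = item
            broadcast_fn(step, state)

    worker = threading.Thread(target=broadcaster, daemon=True)
    worker.start()

    weight = 0.0
    for step in range(num_steps):
        weight += 1.0  # stand-in for one optimizer step
        signals["global_step"] = step + 1
        # Enqueue a fresh dict so later updates do not alter what is sent.
        weight_queue.put((step, {"w": weight}))

    weight_queue.put(None)
    worker.join()  # make sure every queued broadcast has finished
    signals["consumer_done"] = True
```

For example, running three steps with a `broadcast_fn` that records its arguments yields broadcasts for steps 0, 1, and 2 with weights 1.0, 2.0, and 3.0, and leaves `signals` as `{"global_step": 3, "consumer_done": True}`. The sentinel plus `join()` ensures the termination signal is only set after all pending weights have been delivered.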
Usage Examples
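# BaseConsumer is typically not instantiated directly.
# It is used as a base class:
from coati.distributed.zero_bubble.consumer import BaseConsumer


class MyConsumer(BaseConsumer):
    # self.booster is created in setup(); self.policy_model and
    # self.optimizer are assumed to be defined by the subclass.
    def step(self, **kwargs):
        # Implement training step logic.
        self.optimizer.zero_grad()
        # compute_loss is a placeholder for the algorithm-specific loss
        # (e.g. a GRPO objective); it is not part of ColossalAI.
        loss = compute_loss(kwargs)
        self.booster.backward(loss, self.optimizer)
        self.optimizer.step()
        return loss.item()

    def state_dict(self):
        return self.policy_model.state_dict()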