
Implementation:Hpcaitech ColossalAI Zero Bubble Consumer

From Leeroopedia
Revision as of 15:10, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Hpcaitech_ColossalAI_Zero_Bubble_Consumer.md)


Knowledge Sources
Domains: Distributed_Training, RLHF, Zero_Bubble_Pipeline
Last Updated: 2026-02-09 00:00 GMT

Overview

consumer.py defines the BaseConsumer class, the base training worker in the zero-bubble distributed RL pipeline that receives rollout data from producers, performs policy optimization, and synchronizes model weights.

Description

BaseConsumer manages the full training loop on the consumer side of the zero-bubble architecture. It receives batched rollout data through a SharedVariableActor (a Ray actor), buffers the incoming data while dynamically filtering it by reward range, and dispatches minibatches for gradient accumulation. The class also sets up the ColossalAI HybridParallelPlugin (supporting TP, PP, and ZeRO), initializes the distributed process group via Ray collective communication, and broadcasts model weights to producers/distributors asynchronously from a background thread.

Key methods include setup (initializes torch DDP and the ColossalAI booster), loop (the main training loop, which receives data and steps the optimizer), prepare_mini_batch (creates minibatches from the effective group mapping), and the abstract methods state_dict and step, which subclasses implement.
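The buffering-with-filtering step described above can be illustrated with a minimal sketch. This is not the code in consumer.py: the names RewardGroup, filter_group, and buffer_incoming are hypothetical, and it assumes that a group is kept when its mean reward falls inside an inclusive range, which the source does not confirm.

```python
from typing import List, Optional, Tuple

# Hypothetical stand-in for one group of rollouts from a producer:
# the per-sample rewards obtained for the same prompt.
RewardGroup = List[float]


def filter_group(group: RewardGroup, reward_range: Tuple[float, float]) -> Optional[RewardGroup]:
    """Return the group if its mean reward lies inside reward_range, else None."""
    low, high = reward_range
    mean_reward = sum(group) / len(group)
    return group if low <= mean_reward <= high else None


def buffer_incoming(
    buffer: List[Optional[RewardGroup]],
    incoming: List[RewardGroup],
    reward_range: Tuple[float, float],
) -> int:
    """Append every incoming group to the buffer, storing None for filtered-out groups.

    Returns the number of groups that survived the filter.
    """
    kept = 0
    for group in incoming:
        filtered = filter_group(group, reward_range)
        buffer.append(filtered)
        if filtered is not None:
            kept += 1
    return kept


buffer: List[Optional[RewardGroup]] = []
incoming = [
    [0.0, 0.0, 0.0, 0.0],  # mean 0.0: below the range, filtered out
    [1.0, 0.0, 1.0, 0.0],  # mean 0.5: inside the range, kept
    [1.0, 1.0, 1.0, 1.0],  # mean 1.0: above the range, filtered out
]
num_kept = buffer_incoming(buffer, incoming, reward_range=(0.2, 0.8))
print(num_kept, buffer)
```

Keeping a None placeholder for rejected groups, rather than dropping them, preserves the raw buffer positions, which is one plausible reason a separate effective-to-raw mapping is needed later.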

Usage

This class is not used directly but serves as the base class for algorithm-specific consumers such as GRPOConsumer. It is instantiated as a Ray remote actor within the launch_distributed function.
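The base-class relationship can be sketched with a toy example that needs neither Ray nor ColossalAI. ToyBaseConsumer and ToyGRPOConsumer are hypothetical stand-ins, and raising NotImplementedError for the abstract methods is an assumption about how the real class enforces them.

```python
from typing import Dict, List, Optional


class ToyBaseConsumer:
    """Hypothetical stand-in for BaseConsumer: owns the loop, defers the algorithm."""

    def __init__(self, num_steps: int) -> None:
        self.num_steps = num_steps

    def step(self, **kwargs) -> Optional[float]:
        raise NotImplementedError("subclasses implement the training step")

    def state_dict(self) -> Dict[str, float]:
        raise NotImplementedError("subclasses expose their model weights")

    def loop(self) -> List[float]:
        """Call step() once per iteration and collect the losses it reports."""
        losses: List[float] = []
        for step_idx in range(self.num_steps):
            loss = self.step(step_idx=step_idx)
            if loss is not None:
                losses.append(loss)
        return losses


class ToyGRPOConsumer(ToyBaseConsumer):
    """Hypothetical algorithm-specific subclass with a fake, shrinking loss."""

    def __init__(self, num_steps: int) -> None:
        super().__init__(num_steps)
        self.weight = 1.0

    def step(self, **kwargs) -> Optional[float]:
        self.weight *= 0.5  # pretend an optimizer update happened
        return self.weight

    def state_dict(self) -> Dict[str, float]:
        return {"weight": self.weight}


consumer = ToyGRPOConsumer(num_steps=3)
losses = consumer.loop()
print(losses, consumer.state_dict())
```

Calling loop() on the toy base class itself fails at the first step() call, which mirrors why the base class is not meant to be used on its own.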

Code Reference

Source Location

Signature

class BaseConsumer:
    def __init__(
        self,
        shared_sync_data_actor: SharedVariableActor,
        shared_signal_actor: SharedVariableActor,
        num_producers: int,
        num_episodes: int,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        train_dataset_size: int,
        batch_size: int,
        model_config: Dict[str, Any],
        plugin_config: Dict[str, Any],
        minibatch_size: int = 1,
        save_interval: int = 100,
        save_dir: str = "./model",
        enable_profiling: bool = False,
    )

Key Methods

def setup(self) -> None
def get_ddp_config(self) -> Dict[str, Any]
def init_collective_group(self, world_size, rank, backend="nccl", group_name="default", gloo_timeout=3000000)
def state_dict(self) -> Dict[str, torch.Tensor]  # abstract
def step(self, **kwargs) -> Optional[float]  # abstract
def prepare_mini_batch(self, effective_group_to_raw_group_mapping) -> Tuple[Dict, Dict]
def calculate_effective_group_to_raw_group_mapping(self) -> Dict[int, int]
def loop(self) -> None
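The Dict[int, int] returned by calculate_effective_group_to_raw_group_mapping suggests a mapping from consecutive "effective" indices to raw buffer positions. The sketch below shows one way such a mapping and a minibatch selection could work. It is an assumption-laden illustration: the function names are hypothetical, filtered-out groups are assumed to be stored as None, and the return shape differs from the real prepare_mini_batch, which returns a pair of dicts.

```python
from typing import Dict, List, Optional, Tuple


def effective_to_raw_mapping(buffer: List[Optional[dict]]) -> Dict[int, int]:
    """Map consecutive effective indices to the raw buffer slots holding usable groups."""
    mapping: Dict[int, int] = {}
    for raw_idx, group in enumerate(buffer):
        if group is not None:
            # The next effective index is simply the number of groups mapped so far.
            mapping[len(mapping)] = raw_idx
    return mapping


def take_mini_batch(
    buffer: List[Optional[dict]],
    mapping: Dict[int, int],
    minibatch_size: int,
) -> Tuple[List[dict], List[int]]:
    """Collect the first minibatch_size effective groups and the raw indices they came from."""
    effective_ids = sorted(mapping)[:minibatch_size]
    raw_ids = [mapping[i] for i in effective_ids]
    return [buffer[r] for r in raw_ids], raw_ids


# Raw buffer with two filtered-out slots (None) and three usable groups.
raw_buffer = [None, {"id": "a"}, None, {"id": "b"}, {"id": "c"}]
mapping = effective_to_raw_mapping(raw_buffer)
mini_batch, used_raw_ids = take_mini_batch(raw_buffer, mapping, minibatch_size=2)
print(mapping, mini_batch, used_raw_ids)
```

Returning the raw indices alongside the minibatch lets the caller clear exactly those buffer slots once the groups have been consumed.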

Import

from coati.distributed.zero_bubble.consumer import BaseConsumer

I/O Contract

Inputs

Name Type Required Description
shared_sync_data_actor SharedVariableActor Yes Ray actor for buffered data exchange between producers and consumers
shared_signal_actor SharedVariableActor Yes Ray actor for signal-based coordination (model sync, global step)
num_producers int Yes Number of producer workers generating rollouts
num_episodes int Yes Number of training episodes
rank int Yes Rank of this consumer in the distributed group
world_size int Yes Total number of consumer processes
master_addr str Yes Master address for torch DDP initialization
master_port int Yes Master port for torch DDP initialization
train_dataset_size int Yes Size of the training dataset
batch_size int Yes Training batch size
model_config Dict[str, Any] Yes Model configuration for loading the policy model
plugin_config Dict[str, Any] Yes ColossalAI plugin configuration (tp_size, pp_size, zero_stage, etc.)
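minibatch_size int No Minibatch size for gradient accumulation (default: 1)
save_interval int No Interval between model saves (default: 100)
save_dir str No Directory where the model is saved (default: "./model")
enable_profiling bool No Whether to enable profiling (default: False)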

Outputs

Name Type Description
(none) None The loop() method runs until all episodes complete, signaling termination via shared_signal_actor
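The relationship between the batch_size and minibatch_size inputs can be made concrete with a little arithmetic. The sketch below assumes that batch_size is split evenly into minibatches and that each minibatch contribution is scaled by the number of accumulation steps. Both are common conventions for gradient accumulation, not something the source confirms for BaseConsumer, and the function names are hypothetical.

```python
from typing import List


def accumulation_steps(batch_size: int, minibatch_size: int = 1) -> int:
    """Number of minibatches needed to cover one training batch."""
    if minibatch_size <= 0:
        raise ValueError("minibatch_size must be positive")
    if batch_size % minibatch_size != 0:
        raise ValueError("batch_size must be divisible by minibatch_size")
    return batch_size // minibatch_size


def accumulated_mean(per_sample_values: List[float], minibatch_size: int) -> float:
    """Average a batch by summing scaled minibatch means, as gradient accumulation does."""
    steps = accumulation_steps(len(per_sample_values), minibatch_size)
    total = 0.0
    for i in range(steps):
        chunk = per_sample_values[i * minibatch_size:(i + 1) * minibatch_size]
        total += (sum(chunk) / minibatch_size) / steps
    return total


values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
steps = accumulation_steps(batch_size=8, minibatch_size=2)
full_batch_mean = sum(values) / len(values)
print(steps, accumulated_mean(values, minibatch_size=2), full_batch_mean)
```

The accumulated result matches the full-batch mean, which is the property that lets a small minibatch_size trade memory for extra forward/backward passes without changing the update.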

Usage Examples

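# BaseConsumer is typically not instantiated directly.
# It is used as a base class:
from typing import Dict, Optional

import torch
from coati.distributed.zero_bubble.consumer import BaseConsumer


class MyConsumer(BaseConsumer):
    def step(self, **kwargs) -> Optional[float]:
        # Implement training step logic. compute_loss is a placeholder for the
        # algorithm-specific loss function; BaseConsumer does not provide it.
        loss = compute_loss(kwargs)
        self.booster.backward(loss, self.optimizer)
        self.optimizer.step()
        self.optimizer.zero_grad()  # clear gradients before the next step
        return loss.item()

    def state_dict(self) -> Dict[str, torch.Tensor]:
        # self.policy_model and self.optimizer are assumed to be created
        # by the subclass before training starts.
        return self.policy_model.state_dict()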
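The description states that model weights are broadcast to producers asynchronously from a background thread. The pattern can be sketched with the standard library alone. Here a queue.Queue is a stand-in for the real Ray collective broadcast, plain floats are a stand-in for tensors, and broadcast_in_background is a hypothetical helper, not part of coati.

```python
import queue
import threading
from typing import Dict


def broadcast_in_background(state_dict: Dict[str, float], channel: queue.Queue) -> threading.Thread:
    """Send a snapshot of the weights on a worker thread so the caller can keep training."""
    # Copy before starting the thread so later updates do not change what is sent.
    snapshot = dict(state_dict)

    def _worker() -> None:
        for name, value in snapshot.items():
            channel.put((name, value))
        channel.put(None)  # sentinel: broadcast finished

    thread = threading.Thread(target=_worker, daemon=True)
    thread.start()
    return thread


weights = {"layer.weight": 0.5, "layer.bias": -0.1}
channel: queue.Queue = queue.Queue()
sync_thread = broadcast_in_background(weights, channel)

weights["layer.weight"] = 0.75  # training continues while the broadcast runs
sync_thread.join()

# The receiving side (a producer, in the real pipeline) drains the channel.
received: Dict[str, float] = {}
while (item := channel.get()) is not None:
    name, value = item
    received[name] = value
print(received)
```

Taking the snapshot on the calling thread is the design choice that matters: the receiver sees a consistent set of weights even though the trainer has already moved on.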
