
Implementation:Hpcaitech ColossalAI GRPOConsumer

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, Distributed_Computing
Last Updated 2026-02-09 00:00 GMT

Overview

Ray remote actor for GRPO policy training with ColossalAI distributed optimization, provided by ColossalChat.

Description

GRPOConsumer extends BaseConsumer to implement the GRPO training step. It initializes a ColossalAI-boosted training pipeline, receives experience batches from producers, computes the clipped policy loss with KL penalty, and performs gradient updates.
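The loss described above can be sketched as follows. This is a minimal NumPy illustration of a PPO-style clipped surrogate combined with a KL penalty toward a frozen reference policy, not the actual ColossalChat implementation; the function name, arguments, and the choice of the k3 KL estimator are assumptions for illustration.

```python
import numpy as np

def grpo_policy_loss(log_probs, old_log_probs, advantages, ref_log_probs,
                     clip_eps=0.2, kl_coef=0.01):
    """Illustrative GRPO objective: clipped surrogate + KL penalty.

    All arguments are per-token (or per-sequence) arrays of the same shape.
    Hypothetical helper, not the ColossalChat API.
    """
    # Importance ratio between the current and the behavior policy.
    ratio = np.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = np.clip(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    clipped = -np.minimum(surr1, surr2)
    # Low-variance "k3" KL estimator: exp(ref - pi) - (ref - pi) - 1 >= 0.
    delta = ref_log_probs - log_probs
    kl = np.exp(delta) - delta - 1.0
    return float(np.mean(clipped + kl_coef * kl))
```

When the current, behavior, and reference policies coincide, the ratio is 1 and the KL term vanishes, so the loss reduces to the negative mean advantage.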

Usage

Created automatically by launch_distributed() with appropriate GPU allocation.

Code Reference

Source Location

  • Repository: ColossalAI
  • File: applications/ColossalChat/coati/distributed/grpo_consumer.py
  • Lines: 20-612

Signature

@ray.remote
class GRPOConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers: int,
        num_episodes: int,
        rank: int,
        world_size: int,
        master_addr: str,
        master_port: int,
        num_update_per_episode: int,
        num_recv_per_update: int,
        batch_size: int,
        model_config: Dict,
        plugin_config: Dict,
        minibatch_size: int = 1,
        num_generations: int = 8,
        grpo_config: Dict = {},
        save_interval: int = 100,
        save_dir: str = "./model",
    ):
        """GRPO training consumer with ColossalAI Booster."""

    def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
        """Execute one GRPO training step on a mini-batch."""

Import

from coati.distributed.grpo_consumer import GRPOConsumer

I/O Contract

Inputs

Name Type Required Description
model_config Dict Yes Model pretrained path and config
plugin_config Dict Yes ColossalAI plugin config (zero_stage, etc.)
grpo_config Dict Yes GRPO parameters (num_generations, temperature, etc.)
Experience batches Dict[str, Tensor] Yes From producers: input_ids, log_probs, rewards, advantages
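To make the experience-batch contract concrete, here is a shape-only sketch of the dictionary a producer might send. The shapes and dtypes are assumptions (NumPy stands in for tensors); only the keys come from the I/O contract above.

```python
import numpy as np

# Each prompt yields num_generations completions, so the leading
# dimension is batch_size * num_generations.  Shapes are illustrative.
batch_size, num_generations, seq_len = 2, 8, 16
rollouts = batch_size * num_generations

experience_batch = {
    "input_ids":  np.zeros((rollouts, seq_len), dtype=np.int64),   # token ids
    "log_probs":  np.zeros((rollouts, seq_len), dtype=np.float32), # behavior-policy log-probs
    "rewards":    np.zeros((rollouts,), dtype=np.float32),         # scalar reward per rollout
    "advantages": np.zeros((rollouts,), dtype=np.float32),         # group-relative advantages
}
```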

Outputs

Name Type Description
Updated weights Dict[str, Tensor] Broadcast to producers after each update
Training metrics Dict loss, reward, KL divergence, entropy
Checkpoints Files Periodic model/optimizer checkpoints

Usage Examples

# Consumers are created internally by launch_distributed();
# step() is then called in the consumer's training loop.

# Pseudocode sketch of GRPOConsumer.loop() (inherited from BaseConsumer);
# receive_from_producers() and broadcast_weights_to_producers() are
# illustrative placeholders, not real method names:
for episode in range(num_episodes):
    experience_batch = receive_from_producers()   # pull rollouts from producers
    loss = self.step(step_idx=episode, pbar=progress_bar)
    broadcast_weights_to_producers()              # sync updated policy weights
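The advantages consumed in each step are group-relative: GRPO normalizes each prompt's num_generations completions against their own group statistics rather than using a learned value baseline. A minimal sketch, assuming rewards arrive as a flat array ordered group by group (the helper name is hypothetical):

```python
import numpy as np

def group_relative_advantages(rewards, num_generations=8, eps=1e-6):
    """Normalize each group of num_generations rewards to zero mean,
    unit std.  Illustrative helper, not the ColossalChat API."""
    grouped = np.asarray(rewards, dtype=np.float64).reshape(-1, num_generations)
    mean = grouped.mean(axis=1, keepdims=True)
    std = grouped.std(axis=1, keepdims=True)
    # eps guards against zero std when all rewards in a group are equal.
    return ((grouped - mean) / (std + eps)).reshape(-1)
```

Because the baseline comes from sibling completions of the same prompt, no critic network is needed, which is the main memory saving GRPO offers over PPO.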

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
