Implementation:Hpcaitech ColossalAI GRPOConsumer
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Distributed_Computing |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A Ray remote actor for GRPO policy training with ColossalAI distributed optimization, provided by ColossalChat.
Description
GRPOConsumer extends BaseConsumer to implement the GRPO training step. On initialization it builds a ColossalAI-boosted training pipeline; during training it receives experience batches from producers, computes the clipped policy loss with a KL penalty, and applies gradient updates.
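To make the objective concrete, here is a minimal sketch of a clipped surrogate loss with a KL penalty, in the spirit of GRPO. The tensor shapes, the k3-style KL estimator, and the default coefficients are assumptions for illustration; the actual values and reduction logic live in grpo_consumer.py and grpo_config.
import torch

def grpo_policy_loss(log_probs, old_log_probs, ref_log_probs, advantages,
                     clip_eps=0.2, kl_coef=0.01):
    # log-prob tensors: (batch, seq_len); advantages: (batch, 1), broadcast
    # across tokens. clip_eps and kl_coef are illustrative defaults; the
    # real values come from grpo_config.
    ratio = torch.exp(log_probs - old_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    policy_loss = -torch.min(surr1, surr2)
    # k3 KL estimator against the reference policy (an assumption;
    # implementations differ in which estimator they use).
    kl = torch.exp(ref_log_probs - log_probs) - (ref_log_probs - log_probs) - 1.0
    return (policy_loss + kl_coef * kl).mean()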
Usage
Instances are created automatically by launch_distributed(), which handles GPU allocation and producer/consumer wiring; direct construction is rarely needed.
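For orientation, the sketch below constructs the actor by hand using the constructor arguments documented in the Signature section. This mirrors what launch_distributed() does internally; the dict contents are illustrative assumptions, not required keys.
import ray
from coati.distributed.grpo_consumer import GRPOConsumer

ray.init()
consumer = GRPOConsumer.options(num_gpus=1).remote(
    num_producers=2,
    num_episodes=1,
    rank=0,
    world_size=1,
    master_addr="localhost",
    master_port=29505,
    num_update_per_episode=10,
    num_recv_per_update=1,
    batch_size=32,
    model_config={"path": "your/pretrained-model"},  # illustrative keys
    plugin_config={"zero_stage": 2},                 # illustrative keys
    minibatch_size=4,
    num_generations=8,
    grpo_config={"temperature": 1.0},                # illustrative keys
)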
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/distributed/grpo_consumer.py
- Lines: 20-612
Signature
@ray.remote
class GRPOConsumer(BaseConsumer):
def __init__(
self,
num_producers: int,
num_episodes: int,
rank: int,
world_size: int,
master_addr: str,
master_port: int,
num_update_per_episode: int,
num_recv_per_update: int,
batch_size: int,
model_config: Dict,
plugin_config: Dict,
minibatch_size: int = 1,
num_generations: int = 8,
grpo_config: Dict = {},
save_interval: int = 100,
save_dir: str = "./model",
):
"""GRPO training consumer with ColossalAI Booster."""
def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
"""Execute one GRPO training step on a mini-batch."""
Import
from coati.distributed.grpo_consumer import GRPOConsumer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_config | Dict | Yes | Pretrained model path and model configuration |
| plugin_config | Dict | Yes | ColossalAI plugin config (zero_stage, etc.) |
| grpo_config | Dict | Yes | GRPO parameters (num_generations, temperature, etc.) |
| Experience batches | Dict[str, Tensor] | Yes | From producers: input_ids, log_probs, rewards, advantages (see the sketch below) |
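The field names below come from the table above; the dtypes and shapes are assumptions for illustration, and real batches may carry additional fields such as attention masks.
import torch

B, G, S = 4, 8, 512  # prompts per batch, generations per prompt, sequence length (illustrative)
experience_batch = {
    "input_ids": torch.randint(0, 32000, (B * G, S)),  # prompt + response token IDs
    "log_probs": torch.randn(B * G, S),                # behavior-policy log-probs per token
    "rewards": torch.randn(B * G),                     # scalar reward per sampled response
    "advantages": torch.randn(B * G),                  # group-relative advantage per response
}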
Outputs
| Name | Type | Description |
|---|---|---|
| Updated weights | Dict[str, Tensor] | Broadcast to producers after each update |
| Training metrics | Dict | loss, reward, KL divergence, entropy |
| Checkpoints | Files | Periodic model/optimizer checkpoints |
Usage Examples
# Consumers are created internally by launch_distributed().
# Schematic view of the training loop in GRPOConsumer.loop(),
# inherited from BaseConsumer; receive_from_producers() and
# broadcast_weights_to_producers() are placeholders, not real API:
for episode in range(num_episodes):
    experience_batch = receive_from_producers()            # gather rollouts
    loss = self.step(step_idx=episode, pbar=progress_bar)  # one GRPO update
    broadcast_weights_to_producers()                       # sync updated policy
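Because GRPO normalizes rewards within each group of num_generations samples drawn from the same prompt, a minimal sketch of that computation follows. Whether the standard deviation appears in the denominator varies by configuration, so treat the details as assumptions.
import torch

def group_relative_advantages(rewards, num_generations, eps=1e-6):
    # rewards: (batch * num_generations,); consecutive num_generations
    # entries share the same prompt. Std scaling is illustrative and
    # varies across GRPO configurations.
    grouped = rewards.view(-1, num_generations)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + eps)).view(-1)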