Implementation:Hpcaitech ColossalAI Zero Bubble GRPOConsumer
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, RLHF, GRPO, Zero_Bubble_Pipeline |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
grpo_consumer.py implements the GRPOConsumer Ray remote actor, a specialized consumer for Group Relative Policy Optimization (GRPO) training with support for KL divergence penalties, dynamic batching, and pipeline parallelism.
Description
GRPOConsumer extends BaseConsumer to implement the GRPO training algorithm. It maintains a policy model and, when beta > 0, a reference model used for the KL divergence penalty. From the reward signals it computes advantages normalized within each group of generations, then performs clipped policy-gradient updates using PolicyLoss. The step method runs micro-batched forward passes through the policy model and, if present, the reference model, computes token-level log probabilities via memory_efficient_logprob, computes per-token KL divergence, and applies the GRPO loss with configurable lower and upper clipping bounds. Both pipeline-parallel (PP > 1) and non-PP execution paths are supported.
The consumer can also filter out truncated responses and groups that fall outside the configured filter_range. Training metrics (loss, reward, accuracy, entropy, KL divergence, and sample utilization) are logged to Weights & Biases.
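The group-wise advantage normalization mentioned above can be illustrated with a minimal, framework-free sketch. It assumes each group's rewards are centered on the group mean and divided by the group standard deviation plus a small epsilon. The helper name `group_advantages`, the use of the population standard deviation, and the `eps` value are illustrative assumptions and are not taken from grpo_consumer.py.

```python
from statistics import mean, pstdev
from typing import List


def group_advantages(rewards: List[List[float]], eps: float = 1e-4) -> List[List[float]]:
    """Normalize rewards within each group (hypothetical helper).

    `rewards` holds one inner list per prompt, with one reward per generation.
    """
    advantages = []
    for group in rewards:
        mu = mean(group)        # group mean reward
        sigma = pstdev(group)   # population std; the real code may use another estimator
        advantages.append([(r - mu) / (sigma + eps) for r in group])
    return advantages


# Two prompts with four generations each.
adv = group_advantages([[1.0, 0.0, 1.0, 0.0], [2.0, 2.0, 2.0, 2.0]])
```

In this sketch a group whose generations all receive the same reward yields all-zero advantages and therefore contributes no policy gradient, which is one reason group filtering can be useful.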
Usage
GRPOConsumer is the training consumer in the zero-bubble distributed GRPO training pipeline. It is instantiated as a Ray remote actor by launch_distributed and paired with SimpleProducer inference workers and Distributor weight-synchronization actors.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/distributed/zero_bubble/grpo_consumer.py
- Lines: 1-535
Signature
@ray.remote
class GRPOConsumer(BaseConsumer):
    def __init__(
        self,
        shared_sync_data_actor: SharedVariableActor,
        shared_signal_actor: SharedVariableActor,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        train_dataset_size,
        batch_size,
        model_config,
        plugin_config,
        minibatch_size=1,
        num_generations=8,
        tokenizer_config=None,
        generate_config=None,
        grpo_config={},
        save_interval: int = 100,
        save_dir="./model",
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        enable_profiling: bool = False,
    )
Key Methods
def setup(self) -> None
def step(self, pbar: Any, **kwargs) -> Optional[float]
def state_dict(self) -> Dict[str, torch.Tensor]
Import
from coati.distributed.zero_bubble.grpo_consumer import GRPOConsumer
I/O Contract
Inputs (step method)
| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs of shape [minibatch_size, num_generations, seq_len] |
| attention_mask | torch.Tensor | Yes | Attention mask of same shape as input_ids |
| action_mask | torch.Tensor | Yes | Mask indicating response tokens |
| action_log_probs | torch.Tensor | Yes | Log probabilities from the rollout policy |
| reward | torch.Tensor | Yes | Reward scores of shape [minibatch_size, num_generations, 1] |
| format_acc | torch.Tensor | Yes | Format accuracy scores |
| ans_acc | torch.Tensor | Yes | Answer accuracy scores |
| response_idx | torch.Tensor | Yes | Response start/end indices |
| raw_train_mini_batch_reward | List[torch.Tensor] | Yes | Raw reward tensors for logging |
| raw_train_mini_batch_format_acc | List[torch.Tensor] | Yes | Raw format accuracy for logging |
| raw_train_mini_batch_ans_acc | List[torch.Tensor] | Yes | Raw answer accuracy for logging |
| raw_train_mini_batch_response_len | List[torch.Tensor] | Yes | Raw response lengths for logging |
Outputs (step method)
| Name | Type | Description |
|---|---|---|
| loss | Optional[float] | The accumulated loss scalar if gradient update occurred; None if still accumulating gradients |
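The Optional[float] return value means callers should check for None before logging or aggregating the loss. The toy class below is a sketch of that contract only: `ToyAccumulator` and `accumulate_steps` are hypothetical names, and the real consumer decides when to update based on its own batching logic rather than a fixed counter.

```python
from typing import List, Optional


class ToyAccumulator:
    """Hypothetical stand-in that mimics the step() return contract."""

    def __init__(self, accumulate_steps: int) -> None:
        self.accumulate_steps = accumulate_steps
        self._pending: List[float] = []

    def step(self, micro_loss: float) -> Optional[float]:
        # Keep accumulating until enough micro-steps have been collected.
        self._pending.append(micro_loss)
        if len(self._pending) < self.accumulate_steps:
            return None  # no optimizer update yet
        # An "update" happens here: report the mean accumulated loss and reset.
        loss = sum(self._pending) / len(self._pending)
        self._pending.clear()
        return loss


acc = ToyAccumulator(accumulate_steps=2)
logged = []
for micro_loss in [1.0, 3.0, 5.0]:
    loss = acc.step(micro_loss)
    if loss is not None:  # only log when an update actually occurred
        logged.append(loss)
```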
Key Configuration (grpo_config)
| Name | Type | Default | Description |
|---|---|---|---|
| lr | float | 1e-6 | Learning rate for HybridAdam optimizer |
| clip_eps_low | float | 0.2 | Lower clipping bound for policy ratio |
| clip_eps_high | float | 0.2 | Upper clipping bound for policy ratio |
| beta | float | 0.01 | KL penalty coefficient (0 disables reference model) |
| loss_variation | str | "sample_level" | Loss aggregation strategy |
| filter_range | List[float] | None | [min, max] answer accuracy range for group filtering |
| filter_truncated_response | bool | False | Whether to filter out overlength responses |
| dynamic_batching | bool | True | Whether to filter groups before training step |
| train_microbatch_size | int | (full batch) | Micro-batch size for forward passes |
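To show how clip_eps_low, clip_eps_high, and beta interact, here is a minimal per-token sketch of a clipped surrogate loss with an optional KL penalty. It is not the PolicyLoss implementation: the function name `clipped_token_loss`, the scalar (single-token) formulation, and the choice of KL estimator (`exp(d) - d - 1` with `d = logp_ref - logp_new`, a form often used with GRPO) are assumptions made for illustration.

```python
import math
from typing import Optional


def clipped_token_loss(
    logp_new: float,
    logp_old: float,
    advantage: float,
    clip_eps_low: float = 0.2,
    clip_eps_high: float = 0.2,
    beta: float = 0.01,
    logp_ref: Optional[float] = None,
) -> float:
    """Per-token clipped policy loss with optional KL penalty (illustrative)."""
    # Probability ratio between the current policy and the rollout policy.
    ratio = math.exp(logp_new - logp_old)
    # Clip the ratio into [1 - clip_eps_low, 1 + clip_eps_high].
    clipped = min(max(ratio, 1.0 - clip_eps_low), 1.0 + clip_eps_high)
    # Pessimistic (minimum) surrogate objective; the loss is its negation.
    surrogate = min(ratio * advantage, clipped * advantage)
    loss = -surrogate
    # KL penalty against the reference model, applied only when beta > 0.
    if beta > 0 and logp_ref is not None:
        log_ratio = logp_ref - logp_new
        kl = math.exp(log_ratio) - log_ratio - 1.0  # assumed estimator, always >= 0
        loss += beta * kl
    return loss
```

With equal lower and upper bounds this reduces to the familiar symmetric PPO-style clip. Setting beta to 0 skips the penalty entirely, which matches the table's note that no reference model is needed in that case.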
Usage Examples
import ray
from coati.distributed.zero_bubble.grpo_consumer import GRPOConsumer

# Created by launch_distributed, not typically instantiated manually.
# data_actor and signal_actor are assumed to be existing SharedVariableActor handles.
consumer = GRPOConsumer.options(num_gpus=1, num_cpus=4).remote(
    shared_sync_data_actor=data_actor,
    shared_signal_actor=signal_actor,
    num_producers=2,
    num_episodes=3,
    rank=0,
    world_size=8,
    master_addr="10.0.0.1",
    master_port=29500,
    train_dataset_size=10000,
    batch_size=8,
    model_config={"path": "Qwen/Qwen2.5-7B"},
    plugin_config={"tp_size": 1, "pp_size": 1, "zero_stage": 2},
    grpo_config={"lr": 1e-6, "beta": 0.01, "clip_eps_low": 0.2},
    num_generations=8,
)