
Implementation:Hpcaitech ColossalAI Zero Bubble GRPOConsumer

From Leeroopedia


Knowledge Sources
Domains: Distributed_Training, RLHF, GRPO, Zero_Bubble_Pipeline
Last Updated: 2026-02-09 00:00 GMT

Overview

grpo_consumer.py implements the GRPOConsumer Ray remote actor, a specialized consumer for Group Relative Policy Optimization (GRPO) training with support for KL divergence penalties, dynamic batching, and pipeline parallelism.

Description

GRPOConsumer extends BaseConsumer to implement the GRPO training algorithm. It maintains a policy model and, when beta > 0, a reference model used for KL divergence computation. From the reward signals it computes group-normalized advantages, then performs clipped policy gradient updates using PolicyLoss. The step method runs micro-batched forward passes through the policy and reference models, computes token-level log probabilities via memory_efficient_logprob, computes per-token KL divergence, and applies the GRPO loss with configurable clipping bounds. Both pipeline-parallel (PP > 1) and non-PP execution paths are supported. Training metrics, including loss, reward, accuracy, entropy, KL divergence, and sample utilization, are logged to Weights & Biases. The consumer can also filter out truncated responses and groups that fall outside the configured filter_range.
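The core arithmetic described above can be illustrated with a small, framework-free sketch. It is not the code in grpo_consumer.py: the function names, the `eps` stabilizer, and the specific KL estimator (the non-negative "k3" form) are assumptions chosen for illustration, and the real implementation operates on torch tensors through PolicyLoss.

```python
import math
from statistics import fmean, pstdev


def group_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group of responses to the same prompt.

    eps is a hypothetical stabilizer that avoids division by zero when
    every response in the group received the same reward.
    """
    mean = fmean(rewards)
    std = pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


def clipped_token_loss(logp_new, logp_old, advantage,
                       clip_eps_low=0.2, clip_eps_high=0.2):
    """Clipped surrogate loss for a single token (to be minimized)."""
    ratio = math.exp(logp_new - logp_old)
    # Clamp the probability ratio to [1 - clip_eps_low, 1 + clip_eps_high].
    clipped = min(max(ratio, 1.0 - clip_eps_low), 1.0 + clip_eps_high)
    return -min(ratio * advantage, clipped * advantage)


def k3_kl(logp_policy, logp_ref):
    """Per-token KL estimate exp(d) - d - 1 with d = logp_ref - logp_policy.

    This estimator is an assumption here; it is always >= 0.
    """
    diff = logp_ref - logp_policy
    return math.exp(diff) - diff - 1.0


# One group of four responses with binary rewards.
advantages = group_advantages([0.0, 1.0, 0.0, 1.0])
# Per-token objective with a KL penalty weighted by beta.
beta = 0.01
token_loss = (clipped_token_loss(-1.0, -1.2, advantages[1])
              + beta * k3_kl(-1.0, -1.1))
```

Note that a group whose rewards are all identical yields zero advantages for every response, which is why filtering such groups before the training step can save compute.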

Usage

GRPOConsumer serves as the training consumer in the zero-bubble distributed GRPO training pipeline. launch_distributed instantiates it as a Ray remote actor and pairs it with SimpleProducer inference workers and Distributor weight-synchronization actors.

Code Reference

Source Location

Signature

@ray.remote
class GRPOConsumer(BaseConsumer):
    def __init__(
        self,
        shared_sync_data_actor: SharedVariableActor,
        shared_signal_actor: SharedVariableActor,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        train_dataset_size,
        batch_size,
        model_config,
        plugin_config,
        minibatch_size=1,
        num_generations=8,
        tokenizer_config=None,
        generate_config=None,
        grpo_config={},
        save_interval: int = 100,
        save_dir="./model",
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        enable_profiling: bool = False,
    )

Key Methods

def setup(self) -> None
def step(self, pbar: Any, **kwargs) -> Optional[float]
def state_dict(self) -> Dict[str, torch.Tensor]

Import

from coati.distributed.zero_bubble.grpo_consumer import GRPOConsumer

I/O Contract

Inputs (step method)

| Name | Type | Required | Description |
|---|---|---|---|
| input_ids | torch.Tensor | Yes | Token IDs of shape [minibatch_size, num_generations, seq_len] |
| attention_mask | torch.Tensor | Yes | Attention mask with the same shape as input_ids |
| action_mask | torch.Tensor | Yes | Mask indicating response tokens |
| action_log_probs | torch.Tensor | Yes | Log probabilities from the rollout policy |
| reward | torch.Tensor | Yes | Reward scores of shape [minibatch_size, num_generations, 1] |
| format_acc | torch.Tensor | Yes | Format accuracy scores |
| ans_acc | torch.Tensor | Yes | Answer accuracy scores |
| response_idx | torch.Tensor | Yes | Response start/end indices |
| raw_train_mini_batch_reward | List[torch.Tensor] | Yes | Raw reward tensors for logging |
| raw_train_mini_batch_format_acc | List[torch.Tensor] | Yes | Raw format accuracy for logging |
| raw_train_mini_batch_ans_acc | List[torch.Tensor] | Yes | Raw answer accuracy for logging |
| raw_train_mini_batch_response_len | List[torch.Tensor] | Yes | Raw response lengths for logging |

Outputs (step method)

| Name | Type | Description |
|---|---|---|
| loss | Optional[float] | The accumulated loss scalar if a gradient update occurred; None if gradients are still accumulating |
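The Optional[float] return value means callers need to distinguish "still accumulating" from "update performed". The sketch below mimics that contract with a hypothetical stand-in class. `AccumulatingStepper`, `accumulate_every`, and the choice to report the mean of the pending losses are all assumptions for illustration; the actual trigger for an optimizer update inside GRPOConsumer.step is not specified on this page.

```python
from typing import List, Optional


class AccumulatingStepper:
    """Hypothetical stand-in that returns None until an update happens."""

    def __init__(self, accumulate_every: int) -> None:
        self.accumulate_every = accumulate_every
        self._pending: List[float] = []

    def step(self, micro_loss: float) -> Optional[float]:
        self._pending.append(micro_loss)
        if len(self._pending) < self.accumulate_every:
            # Still accumulating gradients: no optimizer update yet.
            return None
        # An update would happen here; report the mean pending loss.
        mean_loss = sum(self._pending) / len(self._pending)
        self._pending.clear()
        return mean_loss


stepper = AccumulatingStepper(accumulate_every=2)
logged = []
for micro_loss in [1.0, 3.0, 5.0]:
    loss = stepper.step(micro_loss)
    if loss is not None:  # only log when an update actually occurred
        logged.append(loss)
```

Comparing against None explicitly, rather than relying on truthiness, matters because a legitimate loss of 0.0 is falsy in Python.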

Key Configuration (grpo_config)

| Name | Type | Default | Description |
|---|---|---|---|
| lr | float | 1e-6 | Learning rate for the HybridAdam optimizer |
| clip_eps_low | float | 0.2 | Lower clipping bound for the policy ratio |
| clip_eps_high | float | 0.2 | Upper clipping bound for the policy ratio |
| beta | float | 0.01 | KL penalty coefficient (0 disables the reference model) |
| loss_variation | str | "sample_level" | Loss aggregation strategy |
| filter_range | List[float] | None | [min, max] answer accuracy range for group filtering |
| filter_truncated_response | bool | False | Whether to filter out overlength responses |
| dynamic_batching | bool | True | Whether to filter groups before the training step |
| train_microbatch_size | int | (full batch) | Micro-batch size for forward passes |
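To make the filter_range option more concrete, here is a minimal sketch of group filtering by mean answer accuracy. The helper `keep_group_mask` is hypothetical, and treating the bounds as exclusive is an assumption; this page does not state whether the real comparison is inclusive.

```python
from statistics import fmean
from typing import List, Sequence, Tuple


def keep_group_mask(
    ans_acc_groups: Sequence[Sequence[float]],
    filter_range: Tuple[float, float],
) -> List[bool]:
    """Return True for each group whose mean accuracy lies inside the range.

    Assumption: bounds are exclusive, so groups that are entirely wrong
    (mean 0.0) or entirely right (mean 1.0) are dropped for range (0.0, 1.0).
    """
    low, high = filter_range
    return [low < fmean(group) < high for group in ans_acc_groups]


# Three groups of four generations each: all wrong, mixed, all right.
groups = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 1.0, 1.0],
]
mask = keep_group_mask(groups, filter_range=(0.0, 1.0))
kept = [g for g, keep in zip(groups, mask) if keep]
```

Under this assumption only the mixed group survives, which matches the intuition that groups with uniform outcomes contribute no relative signal between responses.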

Usage Examples

import ray
from coati.distributed.zero_bubble.grpo_consumer import GRPOConsumer

# Created by launch_distributed, not typically instantiated manually.
# data_actor and signal_actor are SharedVariableActor handles assumed to exist already.
consumer = GRPOConsumer.options(num_gpus=1, num_cpus=4).remote(
    shared_sync_data_actor=data_actor,
    shared_signal_actor=signal_actor,
    num_producers=2,
    num_episodes=3,
    rank=0,
    world_size=8,
    master_addr="10.0.0.1",
    master_port=29500,
    train_dataset_size=10000,
    batch_size=8,
    model_config={"path": "Qwen/Qwen2.5-7B"},
    plugin_config={"tp_size": 1, "pp_size": 1, "zero_stage": 2},
    grpo_config={"lr": 1e-6, "beta": 0.01, "clip_eps_low": 0.2},
    num_generations=8,
)
