
Implementation:Hpcaitech ColossalAI Zero Bubble Producer

From Leeroopedia


Knowledge Sources
Domains: Distributed_Training, RLHF, Inference, Zero_Bubble_Pipeline
Last Updated: 2026-02-09 00:00 GMT

Overview

producer.py defines BaseProducer and SimpleProducer, the inference workers in the zero-bubble distributed RL pipeline that generate rollouts, compute rewards, and send experience data to consumers.
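The generate, score, and send cycle can be sketched in plain Python. This is a toy model, not ColossalAI code. `producer_loop`, `rollout_fn`, `reward_fn`, and the `deque` standing in for the shared data actor are all hypothetical. The strided split of the dataset by `producer_idx` is only an assumption about how distributed sampling divides prompts between producers.

```python
from collections import deque
from typing import Any, Callable, Deque, Dict, List


def producer_loop(
    dataset: List[str],
    producer_idx: int,
    num_producers: int,
    num_episodes: int,
    batch_size: int,
    rollout_fn: Callable[[List[str]], List[str]],
    reward_fn: Callable[[str, str], float],
    buffer: Deque[Dict[str, Any]],
) -> int:
    """Toy producer: shard prompts, generate, score, and push batches to `buffer`.

    Returns the number of batches sent. `buffer` stands in for the shared data actor.
    """
    # Hypothetical strided sharding: producer k sees prompts k, k + P, k + 2P, ...
    shard = dataset[producer_idx::num_producers]
    sent = 0
    for episode in range(num_episodes):
        for start in range(0, len(shard), batch_size):
            prompts = shard[start:start + batch_size]
            responses = rollout_fn(prompts)  # "rollout": generate one response per prompt
            rewards = [reward_fn(p, r) for p, r in zip(prompts, responses)]  # score each one
            # "send": hand the scored batch over to the consumer side
            buffer.append({"episode": episode, "prompts": prompts,
                           "responses": responses, "rewards": rewards})
            sent += 1
    return sent


# Example: producer 0 of 2 over five prompts, two episodes, batches of two.
buf: Deque[Dict[str, Any]] = deque()
n_sent = producer_loop(
    ["a", "b", "c", "d", "e"], producer_idx=0, num_producers=2, num_episodes=2,
    batch_size=2, rollout_fn=lambda ps: [p.upper() for p in ps],
    reward_fn=lambda p, r: float(r == p.upper()), buffer=buf,
)
```

Because each producer works on a disjoint shard, several producers running in parallel never generate for the same prompt within an episode.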

Description

BaseProducer manages model inference with a configurable backend (vLLM or transformers), dataset loading via RawConversationDataset with distributed sampling, reward computation using VerifiableReward (supporting math, boxed-math, and code reward functions), and asynchronous model weight synchronization. Its loop method iterates through training episodes, performing rollouts, computing rewards, and sending batched outputs to consumers via the shared data actor. It also runs periodic model evaluation, logging the results to Weights & Biases and to JSONL files.

SimpleProducer extends BaseProducer as a Ray remote actor. It adds vLLM-specific evaluation sampling parameters and rollout logging, and provides concrete implementations of the rollout (generate + decode) and load_state_dict methods. The producer also supports temperature annealing during the first training episode and vLLM sleep/wake mode for memory management.
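The description only says that temperature is annealed during the first episode. The sketch below shows one way such a schedule could work. The linear shape, the `start_temp`/`target_temp` endpoints, and the name `annealed_temperature` are assumptions for illustration, not the producer's actual values.

```python
def annealed_temperature(step: int, steps_per_episode: int, episode: int,
                         start_temp: float, target_temp: float) -> float:
    """Hypothetical schedule: linear interpolation during episode 0, constant afterwards."""
    if episode > 0 or steps_per_episode <= 0:
        return target_temp
    ratio = min(step / steps_per_episode, 1.0)  # fraction of the first episode completed
    return (1.0 - ratio) * start_temp + ratio * target_temp


# Temperatures over a 4-step first episode, then the first step of episode 1.
schedule = [annealed_temperature(s, 4, 0, start_temp=1.2, target_temp=1.0) for s in range(4)]
after = annealed_temperature(0, 4, 1, start_temp=1.2, target_temp=1.0)
```

Restricting the ramp to episode 0 means every later episode samples at one fixed temperature.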

Usage

Used as the inference worker in the zero-bubble distributed GRPO/DAPO training pipeline. It is instantiated as a Ray remote actor by launch_distributed and paired with GRPOConsumer training workers and Distributor weight synchronization actors.
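Asynchronous weight synchronization means a producer does not stall on each training step. It keeps generating with the weights it has and swaps in newer ones when they become available. The sketch below illustrates that idea with a version counter. `ToyWeightStore`, `ToyProducer`, and the polling protocol are hypothetical. The real pipeline moves weights through Ray actors and collective groups (see init_collective_group), and those details are not modeled here.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple


@dataclass
class ToyWeightStore:
    """Hypothetical stand-in for the actor that publishes trained weights."""
    version: int = 0
    state_dict: Dict[str, Any] = field(default_factory=dict)

    def publish(self, state_dict: Dict[str, Any]) -> None:
        # Trainer side: bump the version each time new weights are available.
        self.version += 1
        self.state_dict = dict(state_dict)

    def fetch_if_newer(self, seen: int) -> Optional[Tuple[int, Dict[str, Any]]]:
        # Producer side: return (version, weights) only if newer than `seen`.
        if self.version > seen:
            return self.version, dict(self.state_dict)
        return None


class ToyProducer:
    def __init__(self) -> None:
        self.weights_version = 0
        self.weights: Dict[str, Any] = {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.weights = state_dict

    def maybe_sync(self, store: ToyWeightStore) -> bool:
        """Called between rollouts; never blocks waiting for the trainer."""
        update = store.fetch_if_newer(self.weights_version)
        if update is None:
            return False  # keep generating with the current (possibly stale) weights
        self.weights_version, state_dict = update
        self.load_state_dict(state_dict)
        return True
```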

Code Reference

Source Location

Signature

class BaseProducer:
    def __init__(
        self,
        shared_sync_data_actor: SharedVariableActor,
        shared_signal_actor: SharedVariableActor,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        train_dataset_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
        consumer_plugin_config: Dict[str, Any] = None,
        eval_dataset_config=None,
        eval_interval=-1,
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
    )

@ray.remote
class SimpleProducer(BaseProducer):
    def __init__(
        self,
        ...,  # same as BaseProducer plus:
        num_generations: int = 8,
        eval_generation_config={},
    )
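The log_rollout_interval and rollout_log_file parameters suggest that rollouts are sampled into a JSONL file at a fixed interval. Below is a minimal sketch of that pattern. The record fields, the modulo trigger, and the name `maybe_log_rollouts` are assumptions, not the producer's actual log format.

```python
import json
from typing import List


def maybe_log_rollouts(step: int, log_rollout_interval: int, rollout_log_file: str,
                       prompts: List[str], responses: List[str], rewards: List[float]) -> bool:
    """Append one JSON object per sample every `log_rollout_interval` steps."""
    if log_rollout_interval <= 0 or step % log_rollout_interval != 0:
        return False
    # JSONL: one self-contained JSON document per line, appended across calls
    with open(rollout_log_file, "a", encoding="utf-8") as f:
        for prompt, response, reward in zip(prompts, responses, rewards):
            record = {"step": step, "prompt": prompt, "response": response, "reward": reward}
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return True
```

Appending line-delimited records keeps the file readable while training is still running and makes it easy to stream back with one `json.loads` per line.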

Key Methods

# BaseProducer
def init_collective_group(self, world_size, rank, backend, group_name, gloo_timeout)
def rollout(self, input_ids, attention_mask, **kwargs) -> Dict[str, torch.Tensor]  # abstract
def load_state_dict(self, state_dict) -> None  # abstract
def loop(self) -> None

# SimpleProducer
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs) -> Dict[str, torch.Tensor]
def load_state_dict(self, state_dict) -> None
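BaseProducer leaves rollout and load_state_dict abstract, so the shared loop can drive any backend, and SimpleProducer fills them in. The toy classes below illustrate that split with Python's abc module. The class names, the string-based rollout signature, and the "suffix" weight are hypothetical simplifications of the real tensor-based interface.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ToyBaseProducer(ABC):
    """Owns the shared driver logic; subclasses supply generation and weight loading."""

    @abstractmethod
    def rollout(self, prompts: List[str]) -> Dict[str, Any]: ...

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ...

    def step(self, prompts: List[str]) -> Dict[str, Any]:
        # The driver depends only on the abstract interface, not on the backend.
        return self.rollout(prompts)


class EchoProducer(ToyBaseProducer):
    """Concrete toy backend: 'generates' by appending a suffix taken from its weights."""

    def __init__(self) -> None:
        self.suffix = ""

    def rollout(self, prompts: List[str]) -> Dict[str, Any]:
        return {"responses": [p + self.suffix for p in prompts]}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.suffix = state_dict["suffix"]
```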

Import

from coati.distributed.zero_bubble.producer import BaseProducer, SimpleProducer

I/O Contract

Inputs

Name Type Required Description
shared_sync_data_actor SharedVariableActor Yes Ray actor for buffered data exchange to consumers
shared_signal_actor SharedVariableActor Yes Ray actor for signal-based coordination
producer_idx int Yes Index of this producer in the producer group
num_producers int Yes Total number of producer workers
num_consumer_procs int Yes Number of consumer processes
num_episodes int Yes Number of training episodes to run
batch_size int Yes Inference batch size per producer
train_dataset_config Dict[str, Any] Yes Dataset configuration with "path" key
model_config Dict[str, Any] Yes Model configuration with "path" key for model loading
generate_config Dict[str, Any] Yes Generation parameters (temperature, top_p, max_tokens, etc.)
grpo_config Dict[str, Any] Yes GRPO config including reward_fn_type ("think_answer_tags", "boxed", or "code")
backend str No Inference backend, "transformers" or "vllm" (default: "transformers")
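To show what two of the reward_fn_type choices imply, here is a toy reward that returns a (reward, format_acc, ans_acc) triple, mirroring the reward, format_acc, and ans_acc outputs. The regular expressions, the exact-string answer comparison, and the rule that reward equals answer correctness are assumptions for illustration. VerifiableReward's actual checks, including the "code" type (which needs execution), are not reproduced here.

```python
import re
from typing import Optional, Tuple

# Hypothetical answer containers for the two text-based reward types.
THINK_ANSWER = re.compile(r"<think>.*?</think>\s*<answer>(.*?)</answer>", re.DOTALL)
BOXED = re.compile(r"\\boxed\{([^{}]*)\}")


def extract_answer(response: str, reward_fn_type: str) -> Optional[str]:
    if reward_fn_type == "think_answer_tags":
        match = THINK_ANSWER.search(response)
        return match.group(1).strip() if match else None
    if reward_fn_type == "boxed":
        found = BOXED.findall(response)
        return found[-1].strip() if found else None  # use the last \boxed{...}
    raise ValueError(f"unsupported reward_fn_type: {reward_fn_type}")


def toy_reward(response: str, gold: str, reward_fn_type: str) -> Tuple[float, float, float]:
    """Return (reward, format_acc, ans_acc) for a single response."""
    answer = extract_answer(response, reward_fn_type)
    format_acc = 1.0 if answer is not None else 0.0
    ans_acc = 1.0 if answer is not None and answer == gold.strip() else 0.0
    return ans_acc, format_acc, ans_acc  # hypothetical: reward == correctness
```

Tracking format and answer accuracy separately makes it possible to tell a response that used the wrong container from one that gave a wrong answer in the right container.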

Outputs (rollout method)

Name Type Description
input_ids torch.Tensor Generated token sequences [batch_size, num_generations, seq_len]
attention_mask torch.Tensor Attention masks for generated sequences
action_log_probs torch.Tensor Log probabilities of generated actions
action_mask torch.Tensor Mask indicating action (response) tokens
response_idx torch.Tensor Start and end indices of responses
reward torch.Tensor Computed reward scores [batch_size, num_generations, 1]
format_acc torch.Tensor Format accuracy scores
ans_acc torch.Tensor Answer accuracy scores
temperature torch.Tensor Current generation temperature values
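The [batch_size, num_generations, ...] shapes keep all generations for one prompt together, which GRPO needs in order to compare responses within a group. Below is a minimal sketch of regrouping a flat list of per-sample values into that layout. It assumes the generations for each prompt are contiguous, and the helper name `group_generations` is hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def group_generations(flat: Sequence[T], batch_size: int, num_generations: int) -> List[List[T]]:
    """Reshape batch_size * num_generations items into [batch_size][num_generations]."""
    if len(flat) != batch_size * num_generations:
        raise ValueError("expected exactly batch_size * num_generations items")
    return [list(flat[i * num_generations:(i + 1) * num_generations]) for i in range(batch_size)]


# Two prompts, three generations each; wrap each reward to mimic the trailing dim of 1.
rewards = group_generations([1.0, 0.0, 1.0, 0.0, 0.0, 1.0], batch_size=2, num_generations=3)
rewards_with_unit_dim = [[[r] for r in group] for group in rewards]
```

With this layout, a consumer can compute group-relative statistics, such as the mean reward per prompt, by reducing over each inner list.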

Usage Examples

import ray
from coati.distributed.zero_bubble.producer import SimpleProducer

# Normally created by launch_distributed rather than instantiated manually.
# data_actor and signal_actor must be existing SharedVariableActor handles.
producer = SimpleProducer.options(num_gpus=4, num_cpus=4).remote(
    shared_sync_data_actor=data_actor,
    shared_signal_actor=signal_actor,
    producer_idx=0,
    num_producers=2,
    num_consumer_procs=8,
    num_episodes=3,
    batch_size=16,
    train_dataset_config={"path": "/data/train.jsonl"},
    model_config={"path": "Qwen/Qwen2.5-7B"},
    generate_config={"temperature": 1.0, "top_p": 0.95, "max_tokens": 2048},
    grpo_config={"reward_fn_type": "think_answer_tags"},
    backend="vllm",
    num_generations=8,
)

# Ray actor methods are invoked with .remote(); loop() runs the training episodes.
# In a real run, the consumers and distributors must be running as well.
ray.get(producer.loop.remote())
