Implementation:Hpcaitech ColossalAI Zero Bubble Producer
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, RLHF, Inference, Zero_Bubble_Pipeline |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
producer.py defines BaseProducer and SimpleProducer, the inference workers in the zero-bubble distributed RL pipeline that generate rollouts, compute rewards, and send experience data to consumers.
Description
BaseProducer manages model inference with configurable backends (vLLM or transformers), dataset loading via RawConversationDataset with distributed sampling, reward computation using VerifiableReward (supporting math, boxed math, and code reward functions), and asynchronous model weight synchronization. The loop method iterates through training episodes, performing rollouts, computing rewards, and sending batched outputs to consumers via the shared data actor. It also handles periodic model evaluation with result logging to Weights & Biases and JSONL files.

SimpleProducer extends BaseProducer as a Ray remote actor, adding vLLM-specific evaluation sampling parameters, rollout logging, and concrete implementations of the rollout (generate + decode) and load_state_dict methods.

The producer supports temperature annealing during the first training episode and vLLM sleep/wake mode for memory management.
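The temperature annealing mentioned above can be pictured with a small schedule function. This is a minimal sketch, not the code in producer.py: the function name `annealed_temperature`, the linear interpolation, the `start` and `end` values, and the choice to hold `end` after the first episode are all assumptions made for illustration.

```python
def annealed_temperature(
    step: int,
    steps_in_episode: int,
    episode: int,
    start: float = 1.0,  # assumed initial sampling temperature
    end: float = 0.9,    # assumed temperature once annealing finishes
) -> float:
    """Interpolate linearly from `start` to `end` during episode 0.

    Later episodes simply return `end` (an assumption of this sketch).
    """
    if steps_in_episode <= 0:
        raise ValueError("steps_in_episode must be positive")
    if episode > 0:
        return end
    # Fraction of the first episode completed, clamped to [0, 1].
    ratio = min(max(step / steps_in_episode, 0.0), 1.0)
    return (1.0 - ratio) * start + ratio * end
```

A producer loop could call such a function once per step and write the result into its generation parameters before the next rollout.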
Usage
SimpleProducer serves as the inference worker in the zero-bubble distributed GRPO/DAPO training pipeline. It is instantiated as a Ray remote actor by launch_distributed and paired with GRPOConsumer training workers and Distributor weight-synchronization actors.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/distributed/zero_bubble/producer.py
- Lines: 1-540
Signature
```python
class BaseProducer:
    def __init__(
        self,
        shared_sync_data_actor: SharedVariableActor,
        shared_signal_actor: SharedVariableActor,
        producer_idx: int,
        num_producers: int,
        num_consumer_procs: int,
        num_episodes: int,
        batch_size: int,
        train_dataset_config: Dict[str, Any],
        model_config: Dict[str, Any],
        generate_config: Dict[str, Any],
        tokenizer_config: Optional[Dict[str, Any]] = None,
        microbatch_size: int = 1,
        backend: str = "transformers",
        consumer_plugin_config: Dict[str, Any] = None,
        eval_dataset_config=None,
        eval_interval=-1,
        grpo_config: Dict[str, Any] = None,
        eval_save_dir: str = "./eval",
        project_name: str = None,
        run_name: str = None,
        wandb_group_name: str = None,
        log_rollout_interval: int = 20,
        rollout_log_file: str = "./rollout_log.jsonl",
        enable_profiling: bool = False,
    )

@ray.remote
class SimpleProducer(BaseProducer):
    def __init__(
        self,
        ...,  # same as BaseProducer plus:
        num_generations: int = 8,
        eval_generation_config={},
    )
```
Key Methods
```python
# BaseProducer
def init_collective_group(self, world_size, rank, backend, group_name, gloo_timeout)
def rollout(self, input_ids, attention_mask, **kwargs) -> Dict[str, torch.Tensor]  # abstract
def load_state_dict(self, state_dict) -> None  # abstract
def loop(self) -> None

# SimpleProducer
@torch.no_grad()
def rollout(self, input_ids, attention_mask, **kwargs) -> Dict[str, torch.Tensor]
def load_state_dict(self, state_dict) -> None
```
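The split between abstract hooks on the base class and concrete implementations on the subclass follows a template-method pattern. The sketch below illustrates that pattern only: `ProducerSketch` and `EchoProducer` are hypothetical stand-ins, the hooks take a plain string instead of input_ids and attention_mask, and a Python list replaces the shared data actor.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ProducerSketch(ABC):
    """Hypothetical stand-in for a base producer: loop() drives the hooks."""

    def __init__(self, prompts: List[str], num_episodes: int) -> None:
        self.prompts = prompts
        self.num_episodes = num_episodes
        self.sent: List[Dict[str, Any]] = []  # stands in for the data actor

    @abstractmethod
    def rollout(self, prompt: str) -> Dict[str, Any]:
        """Generate an output for one prompt."""

    @abstractmethod
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Replace the current weights with synchronized ones."""

    def loop(self) -> None:
        # One pass over all prompts per episode; each output is "sent".
        for episode in range(self.num_episodes):
            for prompt in self.prompts:
                out = self.rollout(prompt)
                out["episode"] = episode
                self.sent.append(out)


class EchoProducer(ProducerSketch):
    """Toy subclass: its 'model' appends a suffix stored in its weights."""

    def __init__(self, prompts: List[str], num_episodes: int) -> None:
        super().__init__(prompts, num_episodes)
        self.weights: Dict[str, Any] = {"suffix": "!"}

    def rollout(self, prompt: str) -> Dict[str, Any]:
        return {"response": prompt + self.weights["suffix"]}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.weights = dict(state_dict)
```

Keeping loop on the base class means a new backend only has to supply generation and weight loading; the episode iteration and hand-off to consumers stay in one place.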
Import
```python
from coati.distributed.zero_bubble.producer import BaseProducer, SimpleProducer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| shared_sync_data_actor | SharedVariableActor | Yes | Ray actor for buffered data exchange to consumers |
| shared_signal_actor | SharedVariableActor | Yes | Ray actor for signal-based coordination |
| producer_idx | int | Yes | Index of this producer in the producer group |
| num_producers | int | Yes | Total number of producer workers |
| num_consumer_procs | int | Yes | Number of consumer processes |
| num_episodes | int | Yes | Number of training episodes the loop iterates over |
| batch_size | int | Yes | Inference batch size per producer |
| train_dataset_config | Dict[str, Any] | Yes | Dataset configuration with "path" key |
| model_config | Dict[str, Any] | Yes | Model configuration with "path" key for model loading |
| generate_config | Dict[str, Any] | Yes | Generation parameters (temperature, top_p, max_tokens, etc.) |
| grpo_config | Dict[str, Any] | Yes (signature default: None) | GRPO config including reward_fn_type ("think_answer_tags", "boxed", or "code") |
| backend | str | No | Inference backend, "transformers" or "vllm" (default: "transformers") |
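To make the roles of producer_idx and num_producers concrete, the sketch below shows one common way to give each producer a disjoint slice of a dataset: strided indexing by rank. It assumes a strided partition without padding; the helper name `shard_indices` is hypothetical, and the sampler actually used in producer.py may shuffle or pad differently.

```python
from typing import List


def shard_indices(dataset_len: int, producer_idx: int, num_producers: int) -> List[int]:
    """Return the dataset indices assigned to one producer.

    Producer k takes indices k, k + num_producers, k + 2 * num_producers, ...
    so the shards are disjoint and together cover the whole dataset.
    """
    if num_producers <= 0:
        raise ValueError("num_producers must be positive")
    if not 0 <= producer_idx < num_producers:
        raise ValueError("producer_idx must be in [0, num_producers)")
    return list(range(producer_idx, dataset_len, num_producers))
```

With two producers and ten samples, producer 0 would read the even indices and producer 1 the odd ones, so no prompt is rolled out twice within a pass.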
Outputs (rollout method)
| Name | Type | Description |
|---|---|---|
| input_ids | torch.Tensor | Generated token sequences [batch_size, num_generations, seq_len] |
| attention_mask | torch.Tensor | Attention masks for generated sequences |
| action_log_probs | torch.Tensor | Log probabilities of generated actions |
| action_mask | torch.Tensor | Mask indicating action (response) tokens |
| response_idx | torch.Tensor | Start and end indices of responses |
| reward | torch.Tensor | Computed reward scores [batch_size, num_generations, 1] |
| format_acc | torch.Tensor | Format accuracy scores |
| ans_acc | torch.Tensor | Answer accuracy scores |
| temperature | torch.Tensor | Current generation temperature values |
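The reward shape [batch_size, num_generations, 1] is easier to see with a toy scorer. The sketch below is not VerifiableReward: it is a stand-in that extracts the last \boxed{...} answer from each response and gives 1.0 for an exact string match with the ground truth, using nested lists instead of tensors. The function names `extract_boxed` and `boxed_rewards` are hypothetical.

```python
from typing import List, Optional


def extract_boxed(text: str) -> Optional[str]:
    # Return the contents of the last \boxed{...} in text, or None if absent.
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    depth = 1
    chars: List[str] = []
    for ch in text[start + len(marker):]:
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:  # matching close brace found
                return "".join(chars).strip()
        chars.append(ch)
    return None  # braces never balanced


def boxed_rewards(responses: List[List[str]], answers: List[str]) -> List[List[List[float]]]:
    # responses[b][g] is generation g for prompt b; answers[b] is its ground truth.
    # The result is nested as [batch_size][num_generations][1].
    return [
        [[1.0 if extract_boxed(resp) == gt.strip() else 0.0] for resp in group]
        for group, gt in zip(responses, answers)
    ]
```

Each prompt keeps its own group of num_generations scores, which is the grouping that GRPO-style training normalizes over.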
Usage Examples
```python
import ray
from coati.distributed.zero_bubble.producer import SimpleProducer

# Created by launch_distributed, not typically instantiated manually.
# data_actor and signal_actor are SharedVariableActor handles that must
# already exist before the producer is created.
producer = SimpleProducer.options(num_gpus=4, num_cpus=4).remote(
    shared_sync_data_actor=data_actor,
    shared_signal_actor=signal_actor,
    producer_idx=0,
    num_producers=2,
    num_consumer_procs=8,
    num_episodes=3,
    batch_size=16,
    train_dataset_config={"path": "/data/train.jsonl"},
    model_config={"path": "Qwen/Qwen2.5-7B"},
    generate_config={"temperature": 1.0, "top_p": 0.95, "max_tokens": 2048},
    grpo_config={"reward_fn_type": "think_answer_tags"},
    backend="vllm",
    num_generations=8,
)
```