Implementation:Hpcaitech ColossalAI Train Prompts On Ray
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement Learning, RLHF, Distributed Training, Ray |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
train_prompts_on_ray.py is a community-contributed script that implements distributed PPO training for RLHF using Ray as the orchestration framework across multiple GPU nodes.
Description
This script defines a complete distributed PPO training pipeline built on Ray. It places the four core RLHF models (actor, critic, reward model, and initial/reference model) in separate Ray actor groups. Each group manages a set of GPU workers that are scheduled through Ray placement groups and communicate through torch.distributed. Experience collection uses asynchronous Ray remote calls: sequence generation, action log probability computation, value estimation, and reward calculation run in parallel across workers, and their results are aggregated before the learning step.
Usage
Use this script when you need to scale PPO-based RLHF training across multiple nodes or GPUs with Ray handling resource management. It supports the GPT-2, BLOOM, and OPT model architectures (--model) and the DDP, Gemini, and ZeRO-2 strategies (--strategy). Run it from the command line with arguments for the model type, the strategy, and the cluster configuration.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/examples/community/ray/train_prompts_on_ray.py
- Lines: 1-569
Signature
# Key classes and functions defined in this script:
class ExperienceCompositionRefs:
    def __init__(self, sequences_attention_mask_action_mask_ref, action_log_probs_ref,
                 base_action_log_probs_ref, value_ref, r_ref) -> None
class ExperienceMaker:
    def __init__(self, kl_coef) -> None
    def make_experience(self, experiment_computation_refs: ExperienceCompositionRefs)
class DistributedTorchRayActor:
    def __init__(self, world_size, rank, local_rank, master_addr, master_port)
class BasePPORole(DistributedTorchRayActor): ...
class TrainablePPORole(BasePPORole): ...
@ray.remote(num_gpus=1)
class RayPPOActor(TrainablePPORole): ...
@ray.remote(num_gpus=1)
class RayPPOCritic(TrainablePPORole): ...
@ray.remote(num_gpus=1)
class RayPPORewardModel(BasePPORole): ...
@ray.remote(num_gpus=1)
class RayPPOInitialModel(BasePPORole): ...
class PPORayActorGroup: ...
class PPOActorRayActorGroup(TrainableModelRayActorGroup): ...
class PPOCriticRayActorGroup(TrainableModelRayActorGroup): ...
class PPOInitialRayActorGroup(PPORayActorGroup): ...
class PPORewardRayActorGroup(PPORayActorGroup): ...
def main(args): ...
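ExperienceMaker is constructed with a kl_coef and consumes references to the action log probabilities, the base action log probabilities, and the reward. That combination suggests the usual RLHF reward shaping, where the reward model score is reduced by a KL penalty between the actor and the reference model. The following is a minimal plain-Python sketch under that assumption. The helper name kl_shaped_reward and the mean-of-log-ratios estimator are illustrative and are not taken from the script.

```python
from typing import Sequence


def kl_shaped_reward(
    r: float,
    action_log_probs: Sequence[float],
    base_action_log_probs: Sequence[float],
    kl_coef: float,
) -> float:
    """Return r - kl_coef * mean(log pi - log pi_ref) over the action tokens.

    Hypothetical helper: the averaging estimator is an assumption and may
    differ from what the script actually computes.
    """
    if len(action_log_probs) != len(base_action_log_probs):
        raise ValueError("log-prob sequences must have the same length")
    if not action_log_probs:
        return r
    log_ratios = [a - b for a, b in zip(action_log_probs, base_action_log_probs)]
    approx_kl = sum(log_ratios) / len(log_ratios)
    return r - kl_coef * approx_kl


# Example: the actor assigns higher log-probs than the reference model,
# so the raw reward is reduced by the penalty.
shaped = kl_shaped_reward(
    r=1.0,
    action_log_probs=[-1.0, -2.0],
    base_action_log_probs=[-1.5, -2.5],
    kl_coef=0.1,
)
```

When the actor and reference model agree exactly, the penalty vanishes and the shaped reward equals the raw reward.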
Import
# This is a standalone training script, typically run directly:
# python train_prompts_on_ray.py --model gpt2 --strategy ddp --prompt_csv_url <url>
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --prompt_csv_url | str | Yes | URL to CSV file containing prompts for training |
| --strategy | str | No | Training strategy: ddp, colossalai_gemini, or colossalai_zero2 (default: ddp) |
| --model | str | No | Model architecture: gpt2, bloom, or opt (default: gpt2) |
| --pretrain | str | No | Pretrained model name or path (default: gpt2) |
| --save_path | str | No | Path to save actor checkpoint (default: actor_checkpoint_prompts.pt) |
| --num_episodes | int | No | Number of training episodes (default: 10) |
| --max_timesteps | int | No | Maximum timesteps per episode (default: 10) |
| --update_timesteps | int | No | Timesteps between learning updates (default: 10) |
| --experience_batch_size | int | No | Batch size for experience collection (default: 8) |
| --num_actor_nodes | int | No | Number of nodes for actor model (default: 1) |
| --num_critic_nodes | int | No | Number of nodes for critic model (default: 1) |
| --num_initial_nodes | int | No | Number of nodes for initial model (default: 1) |
| --num_reward_nodes | int | No | Number of nodes for reward model (default: 1) |
| --num_gpus_per_node | int | No | Number of GPUs per node (default: 1) |
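The table above can be expressed as an argparse parser. This is a sketch that mirrors the documented names, types, and defaults. The use of choices to restrict --strategy and --model, and the function name build_parser, are assumptions rather than a copy of the script's own parser.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser matching the documented flags (hypothetical helper)."""
    parser = argparse.ArgumentParser(description="Distributed PPO training on Ray")
    parser.add_argument("--prompt_csv_url", type=str, required=True)
    # Restricting values with `choices` is an assumption of this sketch.
    parser.add_argument(
        "--strategy",
        choices=["ddp", "colossalai_gemini", "colossalai_zero2"],
        default="ddp",
    )
    parser.add_argument("--model", choices=["gpt2", "bloom", "opt"], default="gpt2")
    parser.add_argument("--pretrain", type=str, default="gpt2")
    parser.add_argument("--save_path", type=str, default="actor_checkpoint_prompts.pt")
    # Integer options and their documented defaults.
    int_options = {
        "num_episodes": 10,
        "max_timesteps": 10,
        "update_timesteps": 10,
        "experience_batch_size": 8,
        "num_actor_nodes": 1,
        "num_critic_nodes": 1,
        "num_initial_nodes": 1,
        "num_reward_nodes": 1,
        "num_gpus_per_node": 1,
    }
    for name, default in int_options.items():
        parser.add_argument(f"--{name}", type=int, default=default)
    return parser


# Only the required flag is supplied; everything else falls back to defaults.
args = build_parser().parse_args(["--prompt_csv_url", "https://example.com/prompts.csv"])
```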
Outputs
| Name | Type | Description |
|---|---|---|
| checkpoint | file | Actor model checkpoint saved to --save_path |
Usage Examples
# Run distributed PPO training on Ray with GPT-2:
# python train_prompts_on_ray.py \
# --prompt_csv_url <url> \
# --model gpt2 \
# --pretrain gpt2 \
# --strategy colossalai_zero2 \
# --num_episodes 10 \
# --num_actor_nodes 2 \
# --num_gpus_per_node 4
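Before launching a command like the one above, it helps to estimate how many GPUs the Ray cluster must provide. The sketch below assumes that each group starts num_nodes * num_gpus_per_node workers and that every worker reserves one GPU, which is consistent with the @ray.remote(num_gpus=1) decorators in the signature. The function name group_world_size is hypothetical.

```python
from typing import Dict


def group_world_size(num_nodes: int, num_gpus_per_node: int) -> int:
    """Workers in one model group, assuming one worker per GPU (assumption)."""
    if num_nodes < 1 or num_gpus_per_node < 1:
        raise ValueError("node and GPU counts must be positive")
    return num_nodes * num_gpus_per_node


# Values from the example command: two actor nodes, four GPUs per node,
# and the documented default of one node for the other three groups.
num_gpus_per_node = 4
group_nodes: Dict[str, int] = {"actor": 2, "critic": 1, "initial": 1, "reward": 1}

world_sizes = {
    role: group_world_size(nodes, num_gpus_per_node)
    for role, nodes in group_nodes.items()
}
total_gpus = sum(world_sizes.values())
```

Under these assumptions the actor group has 8 workers, each other group has 4, and the cluster needs 20 GPUs in total.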
Architecture
The script places each of the four models in its own Ray actor group, and each group manages multiple GPU workers:
- PPOActorRayActorGroup - Manages actor model replicas for sequence generation and policy training
- PPOCriticRayActorGroup - Manages critic model replicas for value estimation and value training
- PPOInitialRayActorGroup - Manages reference model replicas for base action log probability computation
- PPORewardRayActorGroup - Manages reward model replicas for reward scoring
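Experience collection is pipelined asynchronously. The actor generates sequences first. The actor, critic, initial model, and reward model then compute action log probabilities, values, base action log probabilities, and rewards in parallel, and the results are passed around as Ray object references rather than materialized tensors until the experience is assembled.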
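The fan-out pattern described above can be illustrated without a Ray cluster. The sketch below swaps in concurrent.futures from the standard library as a stand-in for Ray: a Future plays the role of an object reference, and result() plays the role of ray.get. All four model functions are toy placeholders invented for this example and do not reflect the real models or the script's method names.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List


def generate(prompt: str) -> List[int]:
    """Toy stand-in for actor generation: returns fake token ids."""
    return [ord(ch) % 7 for ch in prompt]


def action_log_probs(sequences: List[int]) -> List[float]:
    """Toy stand-in for the actor's log-probabilities."""
    return [-0.1 * (token + 1) for token in sequences]


def base_action_log_probs(sequences: List[int]) -> List[float]:
    """Toy stand-in for the initial (reference) model's log-probabilities."""
    return [-0.2 * (token + 1) for token in sequences]


def value(sequences: List[int]) -> float:
    """Toy stand-in for the critic."""
    return float(len(sequences))


def reward(sequences: List[int]) -> float:
    """Toy stand-in for the reward model."""
    return float(sum(sequences))


def collect_experience(prompt: str, pool: ThreadPoolExecutor) -> Dict[str, object]:
    # Step 1: generation must finish before anything else can start.
    sequences = pool.submit(generate, prompt).result()
    # Step 2: the four downstream computations are submitted together and
    # run concurrently; each Future is analogous to a Ray object reference.
    futures = {
        "action_log_probs": pool.submit(action_log_probs, sequences),
        "base_action_log_probs": pool.submit(base_action_log_probs, sequences),
        "value": pool.submit(value, sequences),
        "reward": pool.submit(reward, sequences),
    }
    # Step 3: resolve every reference only when the experience is assembled.
    experience: Dict[str, object] = {name: f.result() for name, f in futures.items()}
    experience["sequences"] = sequences
    return experience


with ThreadPoolExecutor(max_workers=4) as pool:
    experience = collect_experience("hi", pool)
```

One difference from Ray is worth noting: Ray can pass an unresolved object reference straight into another remote call, whereas this stand-in resolves the generated sequences before fanning out.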