
Implementation:Allenai Open instruct PolicyTrainerRayProcess



Type Class (Ray Actor)
Source open_instruct/grpo_fast.py:L141-898
Dependencies ray, deepspeed, torch, transformers, numpy
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete Ray actor class from the Open Instruct library that runs one GPU's share of distributed GRPO policy training, managing the DeepSpeed engine, reference model, optimizer, and weight synchronization.

Description

PolicyTrainerRayProcess is a Ray actor that extends RayProcess and manages one GPU's share of the distributed GRPO training. Each instance:

  1. Initializes the training environment: Sets CUDA device, seeds, and DeepSpeed distributed backend.
  2. Loads the policy model: Creates the policy via AutoModelForCausalLM.from_pretrained() with gradient checkpointing and dropout disabled.
  3. Creates the DeepSpeed engine: Wraps the policy model with DeepSpeed for ZeRO optimization, gradient accumulation, and mixed-precision training.
  4. Loads the reference model: If load_ref_policy=True, loads a separate copy of the model for KL penalty computation. The reference model uses a DeepSpeed inference config (no optimizer).
  5. Sets up weight synchronization: Rank 0 creates a process group linking to all vLLM engine workers for weight broadcast.
  6. Creates the streaming data loader: Builds a StreamingDataLoader that pulls pre-prepared training data from the DataPreparationActor.
  7. Executes training steps: The step() method fetches a batch, computes reference log-probs, iterates over mini-batches with gradient accumulation, computes the GRPO loss, and returns training metrics (a loss sketch follows this list).
  8. Handles checkpointing: Saves DeepSpeed checkpoint states and model weights at configured intervals.
  9. Supports checkpoint resumption: Loads DeepSpeed optimizer, scheduler, and RNG states from a checkpoint directory.
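
To make step 7 concrete, the following is a minimal sketch of a GRPO-style per-token loss, assuming log-probabilities under the current, old, and reference policies and group-normalized advantages have already been computed. The function and argument names are illustrative and do not reflect the actual open_instruct internals.

import torch

def grpo_loss_sketch(new_logprobs, old_logprobs, ref_logprobs, advantages,
                     response_mask, clip_eps=0.2, kl_coef=0.05):
    # Hypothetical stand-in for the loss computed inside step(); the real
    # implementation lives in open_instruct/grpo_fast.py.
    ratio = torch.exp(new_logprobs - old_logprobs)
    pg_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * advantages,
    )
    # k3-style KL estimate against the frozen reference policy (one common
    # choice of estimator; the library may use a different one).
    log_ratio = ref_logprobs - new_logprobs
    kl = torch.exp(log_ratio) - log_ratio - 1.0
    per_token_loss = pg_loss + kl_coef * kl
    return (per_token_loss * response_mask).sum() / response_mask.sum()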

Usage

PolicyTrainerRayProcess instances are created by the main GRPO training script rather than instantiated directly by the user: the script launches them as Ray remote actors and coordinates them through the ModelGroup abstraction.

Code Reference

Source Location

open_instruct/grpo_fast.py:L141-898

Signature

class PolicyTrainerRayProcess(RayProcess):
    def __init__(
        self,
        world_size: int,
        rank: int,
        local_rank: int,
        master_addr: str | None,
        master_port: int | None,
        args: grpo_utils.ExperimentConfig,
        streaming_config: data_loader_lib.StreamingDataLoaderConfig,
        vllm_config: data_loader_lib.VLLMConfig,
        data_prep_actor_name: str,
        tokenizer: PreTrainedTokenizer,
    ):

Import

from open_instruct.grpo_fast import PolicyTrainerRayProcess

I/O Contract

Constructor Inputs

Name Type Description
world_size int Total number of learner GPUs across all nodes.
rank int Global rank of this learner (0 to world_size-1).
local_rank int Local GPU index on this node.
master_addr str | None IP address for torch.distributed initialization.
master_port int | None Port for torch.distributed initialization.
args ExperimentConfig Full experiment configuration.
streaming_config StreamingDataLoaderConfig Generation and data loading configuration.
vllm_config VLLMConfig vLLM engine configuration.
data_prep_actor_name str Name of the DataPreparationActor to pull training data from.
tokenizer PreTrainedTokenizer Tokenizer instance.

Key Methods

Method Return Type Description
from_pretrained(args, model_config, beaker_config, wandb_url, tokenizer) int Load model, create DeepSpeed engine, set up optimizer and scheduler. Returns the number of completed optimization steps (for checkpoint resumption).
setup_model_update_group(vllm_engines) None Create NCCL process group for weight synchronization with vLLM engines.
step() tuple[dict, dict] Execute one training step. Returns (scalar_metrics, array_metrics).
broadcast_to_vllm() list[ObjectRef] Broadcast current policy weights to all vLLM engines.
update_ref_policy() None Update reference policy via Polyak averaging with current policy.
save_model(output_dir, chat_template_name, tokenizer) None Save model weights and tokenizer to disk.
save_checkpoint_state(checkpoint_state_dir, training_step) None Save full training state (model, optimizer, scheduler, RNG states).
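
For intuition on how setup_model_update_group and broadcast_to_vllm fit together, here is a minimal sketch of an NCCL weight broadcast, assuming the rank-0 learner shares a torch.distributed process group with every vLLM worker. This is a simplification, not the library's implementation (the real code must also handle ZeRO-partitioned parameters).

import torch.distributed as dist

def broadcast_weights_sketch(policy_model, model_update_group):
    # Hypothetical: stream each parameter tensor from the rank-0 learner to
    # the vLLM workers, which load it into their copy of the weights.
    for name, param in policy_model.named_parameters():
        dist.broadcast(param.data, src=0, group=model_update_group)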

Usage Examples

import ray
from open_instruct.grpo_fast import PolicyTrainerRayProcess

# Create learner actors (done by the main training script)
learner = ray.remote(num_gpus=1)(PolicyTrainerRayProcess).remote(
    world_size=2,
    rank=0,
    local_rank=0,
    master_addr="10.0.0.1",
    master_port=29500,
    args=experiment_config,
    streaming_config=streaming_config,
    vllm_config=vllm_config,
    data_prep_actor_name="data_prep_actor",
    tokenizer=tokenizer,
)

# Initialize model and optimizer
steps_done = ray.get(learner.from_pretrained.remote(
    args=experiment_config,
    model_config=model_config,
    beaker_config=beaker_config,
    wandb_url="https://wandb.ai/...",
    tokenizer=tokenizer,
))

# Set up weight sync
ray.get(learner.setup_model_update_group.remote(vllm_engines))

# Execute training step
metrics, array_metrics = ray.get(learner.step.remote())
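
Continuing the example, a sketch of the surrounding loop that the main training script normally drives; num_training_steps, checkpoint_interval, and the checkpoint path are illustrative placeholders.

num_training_steps = 1000   # illustrative
checkpoint_interval = 100   # illustrative

for training_step in range(1, num_training_steps + 1):
    metrics, array_metrics = ray.get(learner.step.remote())

    # Push the updated policy weights to the vLLM generation engines; per
    # the Key Methods table, this returns ObjectRefs tracking the broadcast.
    broadcast_refs = ray.get(learner.broadcast_to_vllm.remote())
    ray.get(broadcast_refs)  # wait for all engines to finish loading weights

    if training_step % checkpoint_interval == 0:
        ray.get(learner.save_checkpoint_state.remote(
            checkpoint_state_dir="/tmp/grpo_checkpoints",  # illustrative path
            training_step=training_step,
        ))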

Related Pages

Implements Principle
