

Implementation:Hpcaitech ColossalAI Train RM Script

From Leeroopedia


Knowledge Sources
Domains: Reward Modeling, RLHF, Distributed Training
Last Updated: 2026-02-09 00:00 GMT

Overview

train_rm.py is a script for training the reward models used in the RLHF pipeline. It supports both the LogSigLoss and LogExpLoss objectives on paired preference data.
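Both objectives score a (chosen, rejected) pair through its reward margin r_c - r_r. The plain-Python sketch below is an illustration, not the ColossalAI code. It assumes the standard pairwise forms: LogSigLoss = -log(sigmoid(r_c - r_r)) and LogExpLoss = log(1 + exp(r_r - r_c)). Under that assumption the two are algebraically the same quantity and differ only numerically, because a literal exp can overflow when the margin is large and negative. The function names are hypothetical.

```python
import math
from typing import Sequence


def log_sig_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of -log(sigmoid(c - r)), evaluated in an overflow-safe way."""
    total = 0.0
    for c, r in zip(chosen, rejected):
        margin = c - r
        # -log(sigmoid(m)) == log(1 + exp(-m)); branch on the sign of m so exp never overflows.
        if margin >= 0:
            total += math.log1p(math.exp(-margin))
        else:
            total += -margin + math.log1p(math.exp(margin))
    return total / len(chosen)


def log_exp_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    """Mean of log(1 + exp(r - c)), written literally; math.exp overflows for large r - c."""
    total = 0.0
    for c, r in zip(chosen, rejected):
        total += math.log(1.0 + math.exp(r - c))
    return total / len(chosen)


# Hypothetical scalar rewards for three preference pairs.
chosen_rewards = [2.0, 0.5, -1.0]
rejected_rewards = [0.0, 1.0, -1.0]
sig = log_sig_loss(chosen_rewards, rejected_rewards)
lexp = log_exp_loss(chosen_rewards, rejected_rewards)
```

Both values agree up to floating-point rounding. Each loss falls toward zero as the chosen reward pulls ahead of the rejected one, which is what pushes the model to rank preferred responses higher.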

Description

This script implements a reward model training pipeline using ColossalAI's distributed training infrastructure. It initializes a RewardModel from a pretrained language model backbone, configures a ranking loss function (either log-sigmoid or log-exp), sets up distributed training with Gemini, ZeRO-2, DDP, or 3D hybrid parallelism plugins, and invokes the RewardModelTrainer for training on preference datasets. The reward model learns to assign higher scores to preferred responses through pairwise comparison training. It uses custom shard policies for 3D parallelism via get_autopolicy.
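A reward model of this kind maps a whole sequence to one scalar. The plain-Python sketch below assumes a common design that this page does not confirm for ColossalAI's RewardModel: a linear value head applied to the backbone's final hidden state at the last non-padding token. With right padding, which the script configures, that token's index is the count of ones in the attention mask minus one. `last_token_index` and `value_head_score` are hypothetical helpers, and the hidden states are toy numbers rather than real model outputs.

```python
from typing import List


def last_token_index(attention_mask: List[int]) -> int:
    """Index of the last real token, assuming right padding (all 1s, then all 0s)."""
    length = sum(attention_mask)
    if length == 0:
        raise ValueError("sequence has no real tokens")
    return length - 1


def value_head_score(hidden_states: List[List[float]], attention_mask: List[int],
                     weight: List[float], bias: float) -> float:
    """Apply a linear value head (weight . h + bias) to the last real token's hidden state."""
    h = hidden_states[last_token_index(attention_mask)]
    return sum(w * x for w, x in zip(weight, h)) + bias


# Toy example: 4 positions, hidden size 2; the final position is padding and must be ignored.
hidden = [[0.1, 0.2], [0.3, 0.4], [0.5, -0.5], [9.0, 9.0]]
mask = [1, 1, 1, 0]
score = value_head_score(hidden, mask, weight=[1.0, 2.0], bias=0.1)
```

The shortcut `sum(mask) - 1` is only valid for right-padded inputs. With left padding it would select a position in the middle of the sequence, which is one reason the padding side matters for reward models.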

Usage

Use this script to train a reward model on paired preference data (chosen vs. rejected responses) as a prerequisite for PPO-based RLHF training. The resulting reward model provides the reward signal during PPO training. Launch with torchrun for distributed execution.
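Because the data is paired, each training example carries both a chosen and a rejected token sequence. The sketch below shows one way to right-pad both sides to a shared length and build matching attention masks. It illustrates the idea and is not DataCollatorForPreferenceDataset itself. The field names `chosen_input_ids` and `rejected_input_ids`, the `collate_pairs` function, and truncation to `max_length` are assumptions.

```python
from typing import Dict, List


def _right_pad(ids: List[int], length: int, pad_id: int) -> List[int]:
    """Append pad_id until the sequence reaches the target length."""
    return ids + [pad_id] * (length - len(ids))


def collate_pairs(batch: List[Dict[str, List[int]]], pad_id: int,
                  max_length: int) -> Dict[str, List[List[int]]]:
    """Hypothetical collator: right-pad chosen and rejected ids to one common length."""
    chosen = [ex["chosen_input_ids"][:max_length] for ex in batch]
    rejected = [ex["rejected_input_ids"][:max_length] for ex in batch]
    # A single shared length keeps chosen[i] and rejected[i] shape-compatible.
    length = max(len(ids) for ids in chosen + rejected)
    out: Dict[str, List[List[int]]] = {}
    for name, seqs in (("chosen", chosen), ("rejected", rejected)):
        out[f"{name}_input_ids"] = [_right_pad(s, length, pad_id) for s in seqs]
        out[f"{name}_attention_mask"] = [[1] * len(s) + [0] * (length - len(s)) for s in seqs]
    return out


batch = [
    {"chosen_input_ids": [5, 6, 7], "rejected_input_ids": [5, 8]},
    {"chosen_input_ids": [9], "rejected_input_ids": [9, 10, 11, 12]},
]
collated = collate_pairs(batch, pad_id=0, max_length=2048)
```

Padding both halves of the batch to one length lets the chosen and rejected sequences run through the model as tensors of the same shape.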

Code Reference

Source Location

Signature

def train(args) -> None

Import

# This is a standalone training script, typically run directly:
# torchrun --nproc_per_node=<N> train_rm.py --pretrain <model_path> --dataset <data_path>

I/O Contract

Inputs

Name                  Type           Required  Description
--pretrain            str            Yes       Path to the pretrained model backbone
--dataset             str (nargs=+)  Yes       Paths to tokenized training dataset(s) with preference pairs
--plugin              str            No        Plugin: gemini, gemini_auto, zero2, zero2_cpu, 3d, ddp (default: gemini)
--loss_fn             str            No        Loss function: log_sig or log_exp (default: log_sig)
--eval_dataset        str (nargs=+)  No        Paths to evaluation dataset(s)
--checkpoint_path     str            No        Path to resume training from checkpoint
--lora_config         str            No        Path to LoRA configuration file
--max_length          int            No        Maximum sequence length (default: 2048)
--max_epochs          int            No        Maximum training epochs (default: 3)
--batch_size          int            No        Batch size per process (default: 4)
--lr                  float          No        Learning rate (default: 5e-6)
--accumulation_steps  int            No        Gradient accumulation steps (default: 8)
--mixed_precision     str            No        Mixed precision: fp16 or bf16 (default: fp16)
--save_interval       int            No        Steps between checkpoints (default: 1000)
--tp                  int            No        Tensor parallelism size (default: 1)
--pp                  int            No        Pipeline parallelism size (default: 1)
--sp                  int            No        Sequence parallelism size (default: 1)
--grad_checkpoint     flag           No        Enable gradient checkpointing
--use_flash_attn      flag           No        Enable flash attention
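The options in the table map directly onto Python's argparse. The sketch below mirrors only the flags and defaults listed above. It is not the script's actual parser, which may define more options (for example --save_dir and --config_file, both referenced under Outputs). `build_parser` is a hypothetical name, and the `None` defaults for the optional paths are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented train_rm.py inputs."""
    p = argparse.ArgumentParser(description="Reward model training (sketch)")
    p.add_argument("--pretrain", type=str, required=True)
    p.add_argument("--dataset", type=str, nargs="+", required=True)
    p.add_argument("--plugin", type=str, default="gemini",
                   choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"])
    p.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
    # Optional paths: None defaults are an assumption, not taken from the source.
    p.add_argument("--eval_dataset", type=str, nargs="+", default=None)
    p.add_argument("--checkpoint_path", type=str, default=None)
    p.add_argument("--lora_config", type=str, default=None)
    p.add_argument("--max_length", type=int, default=2048)
    p.add_argument("--max_epochs", type=int, default=3)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-6)
    p.add_argument("--accumulation_steps", type=int, default=8)
    p.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"])
    p.add_argument("--save_interval", type=int, default=1000)
    p.add_argument("--tp", type=int, default=1)
    p.add_argument("--pp", type=int, default=1)
    p.add_argument("--sp", type=int, default=1)
    # Boolean switches: present means enabled.
    p.add_argument("--grad_checkpoint", action="store_true")
    p.add_argument("--use_flash_attn", action="store_true")
    return p


args = build_parser().parse_args(
    ["--pretrain", "./base_model", "--dataset", "./data_a", "./data_b", "--plugin", "zero2"]
)
```

Because `--dataset` uses `nargs="+"`, several tokenized dataset directories can be passed after a single flag.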

Outputs

Name         Type       Description
checkpoint   directory  Reward model checkpoint saved to --save_dir/modeling
config_file  JSON       Training configuration saved to --config_file
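Saving the configuration alongside the checkpoint makes a run reproducible. The minimal sketch below assumes the configuration is simply the parsed arguments serialized with the standard json module. The `save_config` helper and the example field values are hypothetical, and the real script may write a different set of keys.

```python
import argparse
import json
import os
import tempfile


def save_config(args: argparse.Namespace, config_file: str) -> None:
    """Write every parsed argument to a JSON file (hypothetical helper)."""
    with open(config_file, "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)


# Example namespace standing in for the script's parsed arguments.
args = argparse.Namespace(pretrain="./base_model", loss_fn="log_sig", lr=5e-6, max_epochs=3)
path = os.path.join(tempfile.mkdtemp(), "train_config.json")
save_config(args, path)
with open(path, encoding="utf-8") as f:
    restored = json.load(f)
```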

Usage Examples

# Train reward model with ZeRO-2 on 4 GPUs:
# torchrun --nproc_per_node=4 train_rm.py \
#     --pretrain meta-llama/Llama-2-7b-hf \
#     --dataset ./preference_data \
#     --plugin zero2 \
#     --loss_fn log_sig \
#     --lr 5e-6 \
#     --max_epochs 3 \
#     --save_dir ./rm_checkpoint

# Train with 3D parallelism and gradient checkpointing:
# torchrun --nproc_per_node=8 train_rm.py \
#     --pretrain meta-llama/Llama-2-7b-hf \
#     --dataset ./preference_data \
#     --plugin 3d \
#     --tp 2 --pp 2 \
#     --grad_checkpoint \
#     --use_flash_attn
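The two commands above imply different data-parallel widths. The back-of-the-envelope sketch below assumes sequence parallelism is off (sp = 1) and that data parallelism uses whatever GPUs remain after tensor and pipeline parallelism, so dp = world_size / (tp * pp). Under that assumption, the number of preference pairs per optimizer step is batch_size * accumulation_steps * dp. The helper names are hypothetical, and the plugin's own micro-batching under pipeline parallelism is not modelled.

```python
def data_parallel_size(world_size: int, tp: int = 1, pp: int = 1) -> int:
    """GPUs left for data parallelism once tp * pp is carved out (sp assumed to be 1)."""
    model_parallel = tp * pp
    if world_size % model_parallel != 0:
        raise ValueError(f"world_size {world_size} is not divisible by tp*pp={model_parallel}")
    return world_size // model_parallel


def pairs_per_optimizer_step(world_size: int, batch_size: int = 4,
                             accumulation_steps: int = 8, tp: int = 1, pp: int = 1) -> int:
    """Preference pairs consumed per optimizer step under the stated assumptions."""
    return batch_size * accumulation_steps * data_parallel_size(world_size, tp, pp)


# First command: 4 GPUs with ZeRO-2 (no tp/pp) and the default batch_size=4, accumulation_steps=8.
zero2_pairs = pairs_per_optimizer_step(world_size=4)                # 4 * 8 * 4 = 128
# Second command: 8 GPUs with tp=2 and pp=2 leaves dp=2.
hybrid_pairs = pairs_per_optimizer_step(world_size=8, tp=2, pp=2)   # 4 * 8 * 2 = 64
```

Under these assumptions, the 3D run takes half as many pairs per optimizer step as the ZeRO-2 run even though it uses twice as many GPUs. The extra GPUs go toward fitting the model rather than widening the batch.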

Key Features

  • Dual Loss Functions - Supports LogSigLoss (log-sigmoid) and LogExpLoss (log-exp) for pairwise preference ranking
  • RewardModel Wrapper - Uses RewardModel class that adds a value head on top of a language model backbone
  • Custom Shard Policy - Applies get_autopolicy for the reward model's inner model when using 3D parallelism
  • Preference Data Collation - Uses DataCollatorForPreferenceDataset for batching chosen/rejected response pairs
  • LoRA Support - Optional low-rank adaptation with weight merging at evaluation time
  • Right Padding - Sets tokenizer padding_side to "right" for reward model training
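The LoRA feature listed above merges the adapter weights at evaluation time. In the standard LoRA formulation, a frozen weight W receives a low-rank update, and merging folds that update back in: W' = W + (alpha / r) * B A, where B has shape (d_out x r) and A has shape (r x d_in). The plain-Python sketch below shows only that arithmetic. It is not ColossalAI's LoRA code, and `merge_lora` and the toy matrices are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive product of an (m x k) matrix and a (k x n) matrix."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(w: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float) -> Matrix:
    """Return W + (alpha / r) * B @ A, where r is the LoRA rank (rows of A)."""
    r = len(lora_a)
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # (d_out x r) @ (r x d_in) -> (d_out x d_in)
    return [[w[i][j] + scale * delta[i][j] for j in range(len(w[0]))] for i in range(len(w))]


# Toy 2x2 weight with a rank-1 adapter.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 2.0]]       # r=1, d_in=2
B = [[0.5], [-1.0]]    # d_out=2, r=1
merged = merge_lora(W, A, B, alpha=2.0)
```

After merging, inference needs only one matrix multiply per layer instead of the base product plus the adapter path. The trade-off is that the adapter can no longer be detached from the merged weight.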
