Implementation:Hpcaitech ColossalAI Train RM Script
| Knowledge Sources | |
|---|---|
| Domains | Reward Modeling, RLHF, Distributed Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
train_rm.py trains reward models for the RLHF pipeline, supporting both the LogSigLoss and LogExpLoss objectives on paired preference data.
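The two objectives can be sketched in plain Python, assuming the standard pairwise forms LogSigLoss = -log σ(r_chosen - r_rejected) and LogExpLoss = log(1 + exp(r_rejected - r_chosen)), each averaged over the batch. The function names below are hypothetical, and the real loss classes operate on torch tensors rather than lists. Under these assumed definitions the two losses are algebraically identical; they differ only numerically, because a direct exp() can overflow for large reward gaps while a stable log-sigmoid does not.

```python
import math
from typing import Sequence


def log_sigmoid(x: float) -> float:
    # Numerically stable log(sigmoid(x)): never calls exp() on a large positive number.
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def log_sig_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    # Hypothetical sketch: mean of -log(sigmoid(r_chosen - r_rejected)) over all pairs.
    diffs = [c - r for c, r in zip(chosen, rejected)]
    return -sum(log_sigmoid(d) for d in diffs) / len(diffs)


def log_exp_loss(chosen: Sequence[float], rejected: Sequence[float]) -> float:
    # Hypothetical sketch: mean of log(1 + exp(r_rejected - r_chosen)).
    # Equal to log_sig_loss in exact arithmetic, but exp() overflows for very large gaps.
    terms = [math.log1p(math.exp(r - c)) for c, r in zip(chosen, rejected)]
    return sum(terms) / len(terms)
```

For example, with no reward gap both losses equal log 2 ≈ 0.693, and the loss shrinks as the chosen response's margin over the rejected one grows.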
Description
This script implements a reward model training pipeline on ColossalAI's distributed training infrastructure. It initializes a RewardModel from a pretrained language model backbone, configures a pairwise ranking loss (log-sigmoid or log-exp), sets up distributed training with one of the Gemini, ZeRO-2, DDP, or 3D hybrid parallelism plugins, and invokes RewardModelTrainer to train on preference datasets. Through pairwise comparison training, the reward model learns to assign higher scores to preferred (chosen) responses than to rejected ones. When the 3D plugin is selected, the script applies a custom shard policy obtained from get_autopolicy.
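The plugin choice can be pictured as a dispatch from the --plugin value to a ColossalAI booster plugin. The sketch below is hypothetical: it does not import ColossalAI and only returns a class name plus keyword arguments. TorchDDPPlugin, GeminiPlugin, LowLevelZeroPlugin, and HybridParallelPlugin are real ColossalAI plugin classes, but the keyword values chosen here (for example the Gemini placement policies) are assumptions, not the script's actual settings.

```python
from typing import Any, Dict, Tuple


def plugin_config(name: str, tp: int = 1, pp: int = 1, sp: int = 1,
                  precision: str = "fp16") -> Tuple[str, Dict[str, Any]]:
    """Hypothetical mapping from a --plugin value to (plugin class name, kwargs)."""
    if name == "ddp":
        # Plain data parallelism: every rank holds a full model replica.
        return "TorchDDPPlugin", {}
    if name in ("gemini", "gemini_auto"):
        # Assumption: "gemini" uses a static placement policy, "gemini_auto" lets Gemini decide.
        policy = "auto" if name == "gemini_auto" else "static"
        return "GeminiPlugin", {"placement_policy": policy, "precision": precision}
    if name in ("zero2", "zero2_cpu"):
        # ZeRO stage 2 shards optimizer states and gradients; the _cpu variant offloads to host memory.
        return "LowLevelZeroPlugin", {"stage": 2, "precision": precision,
                                      "cpu_offload": name == "zero2_cpu"}
    if name == "3d":
        # Tensor, pipeline, and sequence parallel sizes come from --tp, --pp, --sp.
        return "HybridParallelPlugin", {"tp_size": tp, "pp_size": pp, "sp_size": sp,
                                        "precision": precision}
    raise ValueError(f"Unknown plugin: {name}")
```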
Usage
Use this script to train a reward model on paired preference data (chosen vs. rejected responses) as a prerequisite for PPO-based RLHF training. The resulting reward model provides the reward signal during PPO training. Launch with torchrun for distributed execution.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/examples/training_scripts/train_rm.py
- Lines: 1-347
Signature
def train(args) -> None
Import
# This is a standalone training script, typically run directly:
# torchrun --nproc_per_node=<N> train_rm.py --pretrain <model_path> --dataset <data_path>
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| --pretrain | str | Yes | Path to the pretrained model backbone |
| --dataset | str (nargs=+) | Yes | Paths to tokenized training dataset(s) with preference pairs |
| --plugin | str | No | Plugin: gemini, gemini_auto, zero2, zero2_cpu, 3d, ddp (default: gemini) |
| --loss_fn | str | No | Loss function: log_sig or log_exp (default: log_sig) |
| --eval_dataset | str (nargs=+) | No | Paths to evaluation dataset(s) |
| --checkpoint_path | str | No | Path to resume training from checkpoint |
| --lora_config | str | No | Path to LoRA configuration file |
| --max_length | int | No | Maximum sequence length (default: 2048) |
| --max_epochs | int | No | Maximum training epochs (default: 3) |
| --batch_size | int | No | Batch size per process (default: 4) |
| --lr | float | No | Learning rate (default: 5e-6) |
| --accumulation_steps | int | No | Gradient accumulation steps (default: 8) |
| --mixed_precision | str | No | Mixed precision: fp16 or bf16 (default: fp16) |
| --save_interval | int | No | Steps between checkpoints (default: 1000) |
| --tp | int | No | Tensor parallelism size (default: 1) |
| --pp | int | No | Pipeline parallelism size (default: 1) |
| --sp | int | No | Sequence parallelism size (default: 1) |
| --grad_checkpoint | flag | No | Enable gradient checkpointing |
| --use_flash_attn | flag | No | Enable flash attention |
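The inputs above map naturally onto an argparse parser. The following is a reconstruction from this table only, not the script's own parser. --save_dir and --config_file are included because the Outputs table and the usage examples refer to them, but their default of None is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Reconstructed from the Inputs table; not the script's actual parser.
    p = argparse.ArgumentParser(description="Reward model training (sketch)")
    p.add_argument("--pretrain", type=str, required=True)
    p.add_argument("--dataset", type=str, nargs="+", required=True)
    p.add_argument("--plugin", type=str, default="gemini",
                   choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d", "ddp"])
    p.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"])
    p.add_argument("--eval_dataset", type=str, nargs="+", default=None)
    p.add_argument("--checkpoint_path", type=str, default=None)
    p.add_argument("--lora_config", type=str, default=None)
    p.add_argument("--max_length", type=int, default=2048)
    p.add_argument("--max_epochs", type=int, default=3)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-6)
    p.add_argument("--accumulation_steps", type=int, default=8)
    p.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"])
    p.add_argument("--save_interval", type=int, default=1000)
    p.add_argument("--tp", type=int, default=1)
    p.add_argument("--pp", type=int, default=1)
    p.add_argument("--sp", type=int, default=1)
    p.add_argument("--grad_checkpoint", action="store_true")
    p.add_argument("--use_flash_attn", action="store_true")
    # Assumption: these two flags exist (they appear in Outputs/examples); defaults are guesses.
    p.add_argument("--save_dir", type=str, default=None)
    p.add_argument("--config_file", type=str, default=None)
    return p
```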
Outputs
| Name | Type | Description |
|---|---|---|
| checkpoint | directory | Reward model checkpoint saved to --save_dir/modeling |
| config_file | JSON | Training configuration saved to --config_file |
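Saving the training configuration typically amounts to serializing the parsed arguments. The helpers below are hypothetical and assume the configuration is simply vars(args) written as JSON; the script may record additional or different fields.

```python
import argparse
import json


def config_to_json(args: argparse.Namespace) -> str:
    # vars() turns the Namespace into a plain dict of flag name -> parsed value.
    return json.dumps(vars(args), indent=4)


def save_config(args: argparse.Namespace, config_file: str) -> None:
    # Hypothetical helper: write the serialized arguments to the path given by --config_file.
    with open(config_file, "w", encoding="utf-8") as f:
        f.write(config_to_json(args))
```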
Usage Examples
# Train reward model with ZeRO-2 on 4 GPUs:
# torchrun --nproc_per_node=4 train_rm.py \
# --pretrain meta-llama/Llama-2-7b \
# --dataset ./preference_data \
# --plugin zero2 \
# --loss_fn log_sig \
# --lr 5e-6 \
# --max_epochs 3 \
# --save_dir ./rm_checkpoint
# Train with 3D parallelism and gradient checkpointing:
# torchrun --nproc_per_node=8 train_rm.py \
# --pretrain meta-llama/Llama-2-7b \
# --dataset ./preference_data \
# --plugin 3d \
# --tp 2 --pp 2 \
# --grad_checkpoint \
# --use_flash_attn
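With gradient accumulation, the number of preference pairs consumed per optimizer step is batch_size × accumulation_steps × data-parallel size. The helper below is a hypothetical worked example that assumes batch_size counts pairs per process and that the data-parallel size is the world size divided by tp × pp × sp (sequence-parallel modes that share the tensor-parallel group would change this). With the default batch_size 4 and accumulation_steps 8, the ZeRO-2 example on 4 GPUs gives 4 × 8 × 4 = 128 pairs per step, and the 3D example on 8 GPUs with tp=2, pp=2 leaves a data-parallel size of 2, giving 4 × 8 × 2 = 64.

```python
def effective_batch_size(world_size: int, batch_size: int = 4, accumulation_steps: int = 8,
                         tp: int = 1, pp: int = 1, sp: int = 1) -> int:
    # Assumption: every GPU not used for model parallelism is a data-parallel replica.
    model_parallel = tp * pp * sp
    if world_size % model_parallel != 0:
        raise ValueError("world_size must be divisible by tp * pp * sp")
    dp_size = world_size // model_parallel
    # Preference pairs per optimizer step, summed over all data-parallel replicas.
    return batch_size * accumulation_steps * dp_size
```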
Key Features
- Dual Loss Functions - Supports LogSigLoss (log-sigmoid) and LogExpLoss (log-exp) for pairwise preference ranking
- RewardModel Wrapper - Uses RewardModel class that adds a value head on top of a language model backbone
- Custom Shard Policy - Applies get_autopolicy for the reward model's inner model when using 3D parallelism
- Preference Data Collation - Uses DataCollatorForPreferenceDataset for batching chosen/rejected response pairs
- LoRA Support - Optional low-rank adaptation with weight merging at evaluation time
- Right Padding - Sets tokenizer padding_side to "right" for reward model training
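Right padding and pairwise collation can be illustrated with plain lists. This is a hypothetical sketch, not DataCollatorForPreferenceDataset itself, whose field names and behaviour may differ. It also assumes a common reward-model convention: the scalar reward is read at the last non-padding token, and with right padding that position is simply the number of real tokens minus one.

```python
from typing import Dict, List


def right_pad(sequences: List[List[int]], pad_id: int) -> Dict[str, List[List[int]]]:
    # Pad every sequence on the right to the batch maximum and build a 1/0 attention mask.
    max_len = max(len(s) for s in sequences)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def collate_pairs(chosen: List[List[int]], rejected: List[List[int]],
                  pad_id: int) -> Dict[str, Dict[str, List[List[int]]]]:
    # Hypothetical collator: chosen and rejected responses are padded as two separate batches.
    return {"chosen": right_pad(chosen, pad_id), "rejected": right_pad(rejected, pad_id)}


def last_token_index(mask: List[int]) -> int:
    # With right padding the real tokens form a prefix, so the last one sits at sum(mask) - 1.
    return sum(mask) - 1
```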