

Implementation:Hpcaitech ColossalAI Train ORPO Script

From Leeroopedia
Revision as of 15:09, 16 February 2026 by Admin (Auto-imported from implementations/Hpcaitech_ColossalAI_Train_ORPO_Script.md)


Knowledge Sources
Domains: Preference Optimization, ORPO, Distributed Training, RLHF
Last Updated: 2026-02-09 00:00 GMT

Overview

train_orpo.py is a training script that implements Odds Ratio Preference Optimization (ORPO) for aligning language models using paired preference data without a reference model.

Description

This script sets up a complete ORPO training pipeline using ColossalAI's distributed training infrastructure. Unlike DPO or KTO, ORPO eliminates the need for a separate reference model by incorporating an odds ratio-based penalty directly into the supervised fine-tuning loss. The script initializes a single model, configures distributed training plugins (Gemini, ZeRO-2, or 3D hybrid parallelism), loads tokenized preference datasets with chosen/rejected pairs, and invokes the ORPOTrainer for alignment training.
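
The combined objective described above can be illustrated with a small sketch. It follows the formulation in the ORPO paper (SFT loss plus a lambda-weighted odds-ratio penalty) and is not taken from ColossalAI's ORPOTrainer, whose internals may differ. The function names and the use of length-averaged log-probabilities as plain floats are assumptions made for illustration.

```python
import math


def log_odds(avg_logp: float) -> float:
    """Return log(p / (1 - p)) for p = exp(avg_logp).

    avg_logp is the length-averaged log-probability of a response and
    must be strictly negative so that p < 1.
    """
    # log1p(-exp(x)) evaluates log(1 - p) without forming 1 - p directly.
    return avg_logp - math.log1p(-math.exp(avg_logp))


def orpo_loss(chosen_avg_logp: float, rejected_avg_logp: float, lam: float = 0.1) -> float:
    """Sketch of the per-pair ORPO objective: SFT loss + lam * odds-ratio penalty."""
    # Supervised fine-tuning term: negative log-likelihood of the chosen response.
    sft_loss = -chosen_avg_logp
    # Log odds ratio between the chosen and the rejected response.
    log_odds_ratio = log_odds(chosen_avg_logp) - log_odds(rejected_avg_logp)
    # -log(sigmoid(x)) written as log(1 + exp(-x)).
    or_penalty = math.log1p(math.exp(-log_odds_ratio))
    return sft_loss + lam * or_penalty


if __name__ == "__main__":
    # The penalty shrinks as the rejected response becomes less likely.
    print(orpo_loss(-0.5, -0.5))
    print(orpo_loss(-0.5, -2.0))
```

Because both terms are computed from the policy model's own log-probabilities, no second forward pass through a frozen reference model is needed, which is the property the description relies on.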

Usage

Use this script when you have paired preference data (chosen vs. rejected responses) and want to align a model without maintaining a separate frozen reference model. Because only one model is held in memory, ORPO needs less GPU memory than DPO. Launch the script with torchrun for distributed execution.

Code Reference

Source Location

Signature

def train(args) -> None

Import

# This is a standalone training script, typically run directly:
# torchrun --nproc_per_node=<N> train_orpo.py --pretrain <model_path> --dataset <data_path>

I/O Contract

Inputs

Name                 Type           Required  Description
--pretrain           str            Yes       Path to the pretrained model
--dataset            str (nargs=+)  Yes       Paths to tokenized training dataset(s)
--plugin             str            No        Plugin: gemini, gemini_auto, zero2, zero2_cpu, 3d (default: gemini)
--lam                float          No        Lambda coefficient in ORPO loss (default: 0.1)
--eval_dataset       str (nargs=+)  No        Paths to evaluation dataset(s)
--checkpoint_path    str            No        Path to resume training from checkpoint
--lora_config        str            No        Path to LoRA configuration file
--max_length         int            No        Maximum sequence length (default: 2048)
--max_epochs         int            No        Maximum training epochs (default: 3)
--batch_size         int            No        Batch size per process (default: 4)
--lr                 float          No        Learning rate (default: 5e-6)
--accumulation_steps int            No        Gradient accumulation steps (default: 8)
--mixed_precision    str            No        Mixed precision: fp16 or bf16 (default: fp16)
--save_interval      int            No        Steps between checkpoints (default: 1000)
--disable_loss_mask  flag           No        Disable loss masking on padding tokens
--grad_checkpoint    flag           No        Enable gradient checkpointing
--use_flash_attn     flag           No        Enable flash attention
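
As a rough illustration of the contract above, the listed flags could be declared with argparse as follows. This is a sketch reconstructed from the table only, not a copy of the parser in train_orpo.py. The helper name build_parser and the None defaults for optional paths are assumptions, and flags that appear elsewhere on this page but not in the table (such as --save_dir and --config_file) are left out.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the input table; not the script's own code."""
    p = argparse.ArgumentParser(description="ORPO training arguments (sketch)")
    # Required inputs.
    p.add_argument("--pretrain", type=str, required=True)
    p.add_argument("--dataset", type=str, nargs="+", required=True)
    # Distributed plugin and ORPO-specific coefficient.
    p.add_argument("--plugin", type=str, default="gemini",
                   choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"])
    p.add_argument("--lam", type=float, default=0.1)
    # Optional paths; the None defaults are an assumption.
    p.add_argument("--eval_dataset", type=str, nargs="+", default=None)
    p.add_argument("--checkpoint_path", type=str, default=None)
    p.add_argument("--lora_config", type=str, default=None)
    # Training hyperparameters with the defaults listed in the table.
    p.add_argument("--max_length", type=int, default=2048)
    p.add_argument("--max_epochs", type=int, default=3)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--lr", type=float, default=5e-6)
    p.add_argument("--accumulation_steps", type=int, default=8)
    p.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"])
    p.add_argument("--save_interval", type=int, default=1000)
    # Boolean switches.
    p.add_argument("--disable_loss_mask", action="store_true")
    p.add_argument("--grad_checkpoint", action="store_true")
    p.add_argument("--use_flash_attn", action="store_true")
    return p


if __name__ == "__main__":
    args = build_parser().parse_args(["--pretrain", "model_dir", "--dataset", "data_a", "data_b"])
    print(args.plugin, args.lam, args.dataset)
```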

Outputs

Name         Type       Description
checkpoint   directory  Final model checkpoint saved to --save_dir/modeling
config_file  JSON       Training configuration saved to --config_file

Usage Examples

# Train ORPO with ZeRO-2 on 4 GPUs:
# torchrun --nproc_per_node=4 train_orpo.py \
#     --pretrain meta-llama/Llama-2-7b \
#     --dataset ./preference_data \
#     --plugin zero2 \
#     --lam 0.1 \
#     --lr 5e-6 \
#     --max_epochs 3 \
#     --save_dir ./orpo_checkpoint
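
The command above leaves --batch_size and --accumulation_steps at their documented defaults (4 and 8). Assuming each of the 4 processes acts as a data-parallel replica, which is how ZeRO-style sharding usually works but would not hold for the 3d plugin with tensor or pipeline parallelism, the number of preference pairs contributing to one optimizer step can be estimated as below. The helper name is hypothetical.

```python
def pairs_per_optimizer_step(batch_size: int, accumulation_steps: int, num_processes: int) -> int:
    """Estimate preference pairs per optimizer update under pure data parallelism."""
    return batch_size * accumulation_steps * num_processes


if __name__ == "__main__":
    # Defaults from the input table, 4 processes from --nproc_per_node=4.
    print(pairs_per_optimizer_step(batch_size=4, accumulation_steps=8, num_processes=4))  # 128
```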

Key Features

  • Reference-Free - No separate reference model required, reducing GPU memory footprint compared to DPO/KTO
  • Lambda Parameter - Controls the weight of the odds ratio penalty term in the combined loss
  • Single Booster - Only one Booster instance needed (no ref_booster), simplifying the setup
  • LoRA Integration - Optional LoRA adaptation with weight merging at evaluation time
  • Preference Data Collation - Uses DataCollatorForPreferenceDataset for handling chosen/rejected pairs
  • Dropout Disabled - Dropout is disabled in the model during ORPO training for stable preference learning
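
To make the idea of collating chosen/rejected pairs concrete, here is a minimal pure-Python sketch of padding both sides of each pair to a common width and building attention masks. It is not the DataCollatorForPreferenceDataset implementation. The function name, the field names (chosen_input_ids, rejected_input_ids), and the pad_id default are all assumptions chosen for illustration.

```python
from typing import Dict, List, Tuple


def pad_batch(seqs: List[List[int]], pad_id: int, max_length: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Truncate to max_length, then right-pad to the longest sequence in the batch."""
    seqs = [s[:max_length] for s in seqs]
    width = max(len(s) for s in seqs)
    ids = [s + [pad_id] * (width - len(s)) for s in seqs]
    # 1 marks real tokens, 0 marks padding.
    mask = [[1] * len(s) + [0] * (width - len(s)) for s in seqs]
    return ids, mask


def collate_preference_pairs(samples: List[Dict[str, List[int]]],
                             pad_id: int = 0,
                             max_length: int = 2048) -> Dict[str, List[List[int]]]:
    """Hypothetical collator: pads chosen and rejected responses separately."""
    chosen_ids, chosen_mask = pad_batch([s["chosen_input_ids"] for s in samples], pad_id, max_length)
    rejected_ids, rejected_mask = pad_batch([s["rejected_input_ids"] for s in samples], pad_id, max_length)
    return {
        "chosen_input_ids": chosen_ids,
        "chosen_attention_mask": chosen_mask,
        "rejected_input_ids": rejected_ids,
        "rejected_attention_mask": rejected_mask,
    }


if __name__ == "__main__":
    batch = collate_preference_pairs([
        {"chosen_input_ids": [5, 6, 7], "rejected_input_ids": [5, 8]},
        {"chosen_input_ids": [9], "rejected_input_ids": [9, 10, 11, 12]},
    ])
    print(batch["chosen_input_ids"])
    print(batch["rejected_attention_mask"])
```

In a real pipeline the padded lists would be converted to tensors, and a loss mask would typically also be carried alongside the attention mask.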

Related Pages
