Implementation: Alibaba ROLL DPOTrainer
| Knowledge Sources | |
|---|---|
| Domains | Training, DPO, Preference_Optimization |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
A DPO (Direct Preference Optimization) and ORPO (Odds Ratio Preference Optimization) trainer for Megatron-Core distributed training, with optional reference-model support.
Description
DPOTrainer extends McaTrainer to implement preference-based optimization for aligning language models with human preferences in a distributed Megatron-Core training environment. It supports two loss functions, selected via train_config.pref_loss: "sigmoid" (standard DPO) and "orpo" (ORPO).
Core design: The trainer processes paired data where each batch contains both chosen and rejected responses (batch dimension is doubled). It tracks eight metrics: loss, chosen/rejected rewards, reward accuracies, reward margins, chosen/rejected log probabilities, and SFT loss.
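The paired layout can be sketched as follows. This is an illustrative example of the stacking convention described above (tensor shapes and names are assumptions, not the trainer's actual collation code):

```python
import torch

# Chosen responses are stacked before rejected responses along the batch
# dimension, so a logical batch of N preference pairs becomes a physical
# batch of 2N sequences.
num_pairs, seq_len = 4, 16
chosen_ids = torch.randint(0, 1000, (num_pairs, seq_len))
rejected_ids = torch.randint(0, 1000, (num_pairs, seq_len))

batch = torch.cat([chosen_ids, rejected_ids], dim=0)  # shape: (2N, seq_len)

# After the forward pass, per-sequence log probabilities split back out the
# same way:
logps = torch.randn(2 * num_pairs)  # stand-in for per-sequence model output
chosen_logps, rejected_logps = logps.split(num_pairs, dim=0)
```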
Key methods:
- __init__ (lines 39-59): Initializes with an optional reference model. Validates that per-token loss and sequence packing are not used (unsupported for DPO).
- odds_ratio_loss (lines 97-116): Implements ORPO loss (modified from LLaMA-Factory). Computes per-token average log probabilities, then calculates the log odds ratio between chosen and rejected responses. The final loss combines an SFT loss term with a beta-weighted odds ratio loss.
- dpo_loss (lines 118-137): Implements standard DPO sigmoid loss (modified from TRL). Computes the log-ratio difference between policy and reference models for chosen vs rejected responses, then applies a sigmoid loss with optional label smoothing.
- _post_compute_loss (lines 139-186): Dispatches to either dpo_loss or odds_ratio_loss based on train_config.pref_loss. Handles context parallelism by all-reducing log probabilities and response lengths. Computes all eight training metrics.
- compute_reference_log_probs (lines 193-210): Runs forward-only inference through the reference model to obtain chosen and rejected log probabilities. Only the last pipeline stage returns results (other stages return None).
- training_step (lines 212-251): Main training loop: (1) gathers reference log probabilities if using a reference model, (2) runs the forward-backward pass with the policy model, (3) performs optimizer step with gradient norm tracking, and (4) returns loss, metrics, and gradient statistics.
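A minimal sketch of the two loss functions described above, assuming beta and label_smoothing come from DPOConfig and that response_lens stacks chosen lengths before rejected lengths. Variable names and the exact reductions are illustrative, simplified from the TRL and LLaMA-Factory formulations the docstrings cite, not the repository's code:

```python
import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps,
                     beta=0.1, label_smoothing=0.0):
    """Standard DPO sigmoid loss (simplified sketch, after TRL)."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    # Label smoothing interpolates between the two sigmoid terms.
    loss = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
            - F.logsigmoid(-beta * logits) * label_smoothing)
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)
    return loss.mean(), chosen_rewards, rejected_rewards

def odds_ratio_loss(chosen_logps, rejected_logps, response_lens, beta=0.1):
    """ORPO loss (simplified sketch, after LLaMA-Factory): an SFT term on the
    chosen responses plus a beta-weighted log-odds-ratio preference term
    over length-averaged log probabilities."""
    n = chosen_logps.shape[0]
    # Assumption: response_lens holds chosen lengths first, then rejected.
    chosen_avg = chosen_logps / response_lens[:n]
    rejected_avg = rejected_logps / response_lens[n:]
    # Log odds ratio between chosen and rejected responses.
    log_odds = (chosen_avg - rejected_avg) - (
        torch.log1p(-torch.exp(chosen_avg))
        - torch.log1p(-torch.exp(rejected_avg))
    )
    sft_loss = -chosen_avg                 # SFT term on chosen responses
    or_loss = -F.logsigmoid(log_odds)      # preference term
    loss = (sft_loss + beta * or_loss).mean()
    return loss, sft_loss.mean(), log_odds.mean()
```

When the policy separates chosen from rejected more strongly than the reference does, the DPO logits are positive and the loss drops below log 2; ORPO needs no reference model because the odds ratio is computed from the policy's own length-averaged log probabilities.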
Usage
Use DPOTrainer for preference-based training of Megatron-Core models. Provide a reference model for standard DPO (pref_loss="sigmoid") or omit it for ORPO (pref_loss="orpo"). Input data must contain paired chosen/rejected examples stacked along the batch dimension.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/trainer/dpo_trainer.py
- Lines: 1-267
Signature
class DPOTrainer(McaTrainer):
    metrics_keys: list[str] = [
        "loss", "rewards/chosen", "rewards/rejected", "rewards/accuracies",
        "rewards/margins", "logps/chosen", "logps/rejected", "sft_loss",
    ]

    def __init__(
        self,
        model: VirtualModels = None,
        train_config: DPOConfig = None,
        ref_model: Optional[VirtualModels] = None,
        args: TrainingArguments = None,
        **kwargs,
    ) -> None: ...

    def odds_ratio_loss(
        self,
        chosen_logps: torch.Tensor,
        rejected_logps: torch.Tensor,
        response_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...

    def _post_compute_loss(
        self,
        labels: torch.Tensor,
        loss_mask: torch.Tensor,
        ref_chosen_logps: torch.Tensor,
        ref_rejected_logps: torch.Tensor,
        logits: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]: ...

    def compute_reference_log_probs(
        self,
        models: VirtualModels,
        data_list: List[Dict[str, Any]],
        seq_length: int,
        micro_batch_size: int,
    ) -> Optional[List[Dict[str, torch.Tensor]]]: ...

    def training_step(
        self,
        models: List[DistributedDataParallel],
        data_iterator,
        seq_length: int,
    ) -> Tuple[torch.Tensor, list, int, float, int]: ...
Import
from mcore_adapter.trainer.dpo_trainer import DPOTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | VirtualModels | Yes | The policy model to train |
| train_config | DPOConfig | Yes | DPO-specific configuration (beta, pref_loss type, label_smoothing, use_ref_model) |
| ref_model | VirtualModels or None | No | Reference model for KL-constrained DPO; required when use_ref_model is True |
| args | TrainingArguments | Yes | General training arguments (learning rate, batch size, etc.) |
| models | List[DistributedDataParallel] | Yes (training_step) | List of DDP-wrapped model shards |
| data_iterator | Iterator | Yes (training_step) | Iterator yielding paired chosen/rejected training batches |
| seq_length | int | Yes (training_step) | Sequence length for the current step |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Mean training loss for the step |
| metrics_tensors | list[dict] | List of metric dictionaries per microbatch |
| skipped_iter | int | 1 if optimizer step was skipped (e.g., due to gradient overflow), 0 otherwise |
| grad_norm | float | Gradient norm after clipping |
| num_zeros_in_grad | int | Number of zero-valued gradient elements |
Usage Examples
from mcore_adapter.trainer.dpo_trainer import DPOTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig

# Configure DPO training
dpo_config = DPOConfig(
    pref_loss="sigmoid",  # or "orpo"
    beta=0.1,
    label_smoothing=0.0,
    use_ref_model=True,
)

# Create trainer with policy and reference models
trainer = DPOTrainer(
    model=policy_model,
    train_config=dpo_config,
    ref_model=reference_model,
    args=training_args,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)

# Run a training step
loss, metrics, skipped, grad_norm, num_zeros = trainer.training_step(
    models=ddp_models,
    data_iterator=data_iter,
    seq_length=2048,
)
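For ORPO, the reference model is omitted. This configuration fragment mirrors the example above (policy_model and training_args are assumed to be defined as before; field values are illustrative):

```python
# ORPO: no reference model needed, since the preference term is computed
# from the policy's own log probabilities.
orpo_config = DPOConfig(
    pref_loss="orpo",
    beta=0.1,
    use_ref_model=False,
)

trainer = DPOTrainer(
    model=policy_model,
    train_config=orpo_config,
    ref_model=None,  # intentionally omitted for ORPO
    args=training_args,
)
```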