
Implementation:Alibaba ROLL DPOTrainer


Domains: Training, DPO, Preference_Optimization
Last Updated: 2026-02-07 20:00 GMT

Overview

A DPO (Direct Preference Optimization) and ORPO (Odds Ratio Preference Optimization) trainer for Megatron distributed training, with optional reference model support.

Description

DPOTrainer extends McaTrainer to implement preference-based optimization algorithms for aligning language models with human preferences in a distributed Megatron-Core training environment. It supports two loss functions, selected via train_config.pref_loss: sigmoid (standard DPO) and orpo (ORPO).

Core design: The trainer processes paired data where each batch contains both chosen and rejected responses (batch dimension is doubled). It tracks eight metrics: loss, chosen/rejected rewards, reward accuracies, reward margins, chosen/rejected log probabilities, and SFT loss.
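
To make the doubled batch dimension concrete, the following minimal illustration (not code from the trainer) shows the layout, assuming the common convention of stacking all chosen sequences before all rejected ones:

import torch

# Hypothetical shapes: 4 preference pairs, sequence length 2048.
chosen_ids = torch.randint(0, 32000, (4, 2048))    # tokenized chosen responses
rejected_ids = torch.randint(0, 32000, (4, 2048))  # tokenized rejected responses

# The trainer sees a single batch with the batch dimension doubled: (8, 2048).
input_ids = torch.cat([chosen_ids, rejected_ids], dim=0)

# Once the forward pass has produced one log probability per sequence in the
# same layout, splitting back into the two halves is the mirror operation.
seq_logps = torch.randn(8)  # stand-in for the computed log probs
chosen_logps, rejected_logps = seq_logps.chunk(2, dim=0)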

Key methods:

  • __init__ (lines 39-59): Initializes with an optional reference model. Validates that per-token loss and sequence packing are not used (unsupported for DPO).
  • odds_ratio_loss (lines 97-116): Implements ORPO loss (modified from LLaMA-Factory). Computes per-token average log probabilities, then calculates the log odds ratio between chosen and rejected responses. The final loss combines an SFT loss term with a beta-weighted odds-ratio loss (both loss functions are sketched in code after this list).
  • dpo_loss (lines 118-137): Implements the standard DPO sigmoid loss (modified from TRL). Computes the log-ratio difference between policy and reference models for chosen vs. rejected responses, then applies a sigmoid loss with optional label smoothing.
  • _post_compute_loss (lines 139-186): Dispatches to either dpo_loss or odds_ratio_loss based on train_config.pref_loss. Handles context parallelism by all-reducing log probabilities and response lengths. Computes all eight training metrics.
  • compute_reference_log_probs (lines 193-210): Runs forward-only inference through the reference model to obtain chosen and rejected log probabilities. Only the last pipeline stage returns results (other stages return None).
  • training_step (lines 212-251): Main training loop: (1) gathers reference log probabilities if using a reference model, (2) runs the forward-backward pass with the policy model, (3) performs optimizer step with gradient norm tracking, and (4) returns loss, metrics, and gradient statistics.
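
Both losses follow standard formulations (DPO from TRL, ORPO from LLaMA-Factory). The sketch below restates them in plain PyTorch with beta and label_smoothing as explicit arguments; the trainer reads these from train_config, and its odds_ratio_loss divides summed log probabilities by response_lens to obtain per-token averages first. This is an illustration of the math, not the trainer's exact code.

import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     reference_chosen_logps, reference_rejected_logps,
                     beta=0.1, label_smoothing=0.0):
    # Log-ratio difference between policy and reference, chosen vs. rejected.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios
    # Sigmoid loss; label_smoothing > 0 gives the conservative-DPO variant.
    losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
              - F.logsigmoid(-beta * logits) * label_smoothing)
    # Implicit rewards that feed the rewards/* metrics.
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards

def orpo_loss(chosen_avg_logps, rejected_avg_logps, beta=0.1):
    # log(odds) = log(p / (1 - p)); the loss uses the chosen-vs-rejected gap.
    log_odds = ((chosen_avg_logps - rejected_avg_logps)
                - (torch.log1p(-torch.exp(chosen_avg_logps))
                   - torch.log1p(-torch.exp(rejected_avg_logps))))
    sft_loss = -chosen_avg_logps               # NLL of the chosen response
    or_loss = -F.logsigmoid(log_odds)          # odds-ratio preference term
    return sft_loss + beta * or_loss, sft_loss, or_loss

From the rewards returned by dpo_sigmoid_loss, the remaining metrics follow directly: rewards/accuracies is (chosen_rewards > rejected_rewards).float().mean() and rewards/margins is (chosen_rewards - rejected_rewards).mean().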

Usage

Use DPOTrainer for preference-based training of Megatron-Core models. Provide a reference model for standard DPO (pref_loss="sigmoid") or omit it for ORPO (pref_loss="orpo"). Input data must contain paired chosen/rejected examples stacked along the batch dimension.

Code Reference

Signature

class DPOTrainer(McaTrainer):
    metrics_keys: list[str] = [
        "loss", "rewards/chosen", "rewards/rejected", "rewards/accuracies",
        "rewards/margins", "logps/chosen", "logps/rejected", "sft_loss",
    ]

    def __init__(
        self,
        model: VirtualModels = None,
        train_config: DPOConfig = None,
        ref_model: Optional[VirtualModels] = None,
        args: TrainingArguments = None,
        **kwargs,
    ) -> None: ...

    def odds_ratio_loss(
        self,
        chosen_logps: torch.Tensor,
        rejected_logps: torch.Tensor,
        response_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...

    def dpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ...

    def _post_compute_loss(
        self,
        labels: torch.Tensor,
        loss_mask: torch.Tensor,
        ref_chosen_logps: torch.Tensor,
        ref_rejected_logps: torch.Tensor,
        logits: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]: ...

    def compute_reference_log_probs(
        self,
        models: VirtualModels,
        data_list: List[Dict[str, Any]],
        seq_length: int,
        micro_batch_size: int,
    ) -> Optional[List[Dict[str, torch.Tensor]]]: ...

    def training_step(
        self,
        models: List[DistributedDataParallel],
        data_iterator,
        seq_length: int,
    ) -> Tuple[torch.Tensor, list, int, float, int]: ...

Import

from mcore_adapter.trainer.dpo_trainer import DPOTrainer

I/O Contract

Inputs

  • model (VirtualModels, required): the policy model to train.
  • train_config (DPOConfig, required): DPO-specific configuration (beta, pref_loss type, label_smoothing, use_ref_model).
  • ref_model (VirtualModels or None, optional): reference model for KL-constrained DPO; required when use_ref_model is True.
  • args (TrainingArguments, required): general training arguments (learning rate, batch size, etc.).
  • models (List[DistributedDataParallel], required by training_step): list of DDP-wrapped model shards.
  • data_iterator (Iterator, required by training_step): iterator yielding paired chosen/rejected training batches.
  • seq_length (int, required by training_step): sequence length for the current step.

Outputs

  • loss (torch.Tensor): mean training loss for the step.
  • metrics_tensors (list[dict]): list of metric dictionaries, one per microbatch.
  • skipped_iter (int): 1 if the optimizer step was skipped (e.g., due to gradient overflow), 0 otherwise.
  • grad_norm (float): gradient norm after clipping.
  • num_zeros_in_grad (int): number of zero-valued gradient elements.
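
As a small, assumed example of consuming these outputs, the helper below averages the per-microbatch metric dictionaries for logging. It assumes each dict maps the trainer's metrics_keys to scalar tensors, which is an assumption about the layout rather than a documented contract:

import torch

def average_step_metrics(metrics_tensors, keys):
    # Average each metric across microbatches; assumes each dict in
    # metrics_tensors holds a scalar tensor under every key in `keys`.
    return {k: torch.stack([m[k] for m in metrics_tensors]).mean().item()
            for k in keys}

# e.g. average_step_metrics(metrics_tensors, DPOTrainer.metrics_keys)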

Usage Examples

from mcore_adapter.trainer.dpo_trainer import DPOTrainer
from mcore_adapter.trainer.dpo_config import DPOConfig

# Configure DPO training
dpo_config = DPOConfig(
    pref_loss="sigmoid",  # or "orpo"
    beta=0.1,
    label_smoothing=0.0,
    use_ref_model=True,
)

# Create trainer with policy and reference models
trainer = DPOTrainer(
    model=policy_model,
    train_config=dpo_config,
    ref_model=reference_model,
    args=training_args,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)

# Run a training step
loss, metrics, skipped, grad_norm, num_zeros = trainer.training_step(
    models=ddp_models,
    data_iterator=data_iter,
    seq_length=2048,
)
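
For ORPO the same pattern applies without a reference model. A minimal variant of the configuration above (field values are illustrative):

# ORPO is reference-free: set pref_loss="orpo" and omit ref_model.
orpo_config = DPOConfig(
    pref_loss="orpo",
    beta=0.1,              # weight on the odds-ratio term
    use_ref_model=False,
)

trainer = DPOTrainer(
    model=policy_model,
    train_config=orpo_config,
    args=training_args,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
)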
