
Implementation:Alibaba ROLL DPO Loss Fn

From Leeroopedia


Knowledge Sources
Domains Alignment, Optimization
Last Updated 2026-02-07 20:00 GMT

Overview

Concrete DPO loss function implementations provided by the Alibaba ROLL library.

Description

The get_logps function reduces per-token log probabilities to sequence-level log probabilities, masking out prompt tokens and splitting the stacked batch into chosen and rejected halves. The loss_fn function computes the DPO/IPO/cDPO loss from the policy and reference log probabilities.

Usage

Called during the train_step of the DPO ActorWorker.
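The objective computed in that step follows the standard DPO family. Below is a minimal sketch of the usual DPO/IPO/cDPO formulas, written independently of ROLL's actual code; the function name `dpo_loss_sketch` is illustrative, and the exact reduction (mean vs. sum) in ROLL is an assumption here.

```python
import torch
import torch.nn.functional as F

def dpo_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                    reference_chosen_logps, reference_rejected_logps,
                    beta=0.1, label_smoothing=0.0, ipo=False):
    # Log-ratio of policy vs. reference for each preference pair.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios

    if ipo:
        # IPO: squared-error regression toward the margin 1/(2*beta).
        losses = (logits - 1 / (2 * beta)) ** 2
    else:
        # cDPO: label smoothing mixes in the flipped-preference term;
        # label_smoothing=0 recovers vanilla DPO.
        losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
                  - F.logsigmoid(-beta * logits) * label_smoothing)

    # Implicit rewards (detached; typically used for logging metrics).
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses.mean(), chosen_rewards, rejected_rewards
```

With policy equal to reference, the DPO logits are zero and the loss is log 2, which is a quick sanity check for any implementation of this objective.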

Code Reference

Source Location

  • Repository: Alibaba ROLL
  • File: roll/pipeline/dpo/actor_worker.py
  • Lines: L15-59

Signature

def get_logps(
    per_token_logps: torch.FloatTensor,
    attention_mask,
    prompt_id_lens,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
    """
    Extract chosen and rejected sequence-level log probabilities.

    Args:
        per_token_logps: Per-token log probs (2*B, seq_len)
        attention_mask: Attention mask
        prompt_id_lens: Prompt lengths for masking

    Returns:
        (chosen_logps, rejected_logps) each shape (B,)
    """

def loss_fn(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    reference_chosen_logps: torch.Tensor,
    reference_rejected_logps: torch.Tensor,
    ipo: bool = False,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute DPO/IPO/cDPO loss.

    Returns:
        (loss, chosen_rewards, rejected_rewards)
    """

Import

from roll.pipeline.dpo.actor_worker import get_logps, loss_fn

I/O Contract

Inputs

  Name                      Type          Required  Description
  policy_chosen_logps       torch.Tensor  Yes       Policy log probs for chosen responses (B,)
  policy_rejected_logps     torch.Tensor  Yes       Policy log probs for rejected responses (B,)
  reference_chosen_logps    torch.Tensor  Yes       Reference log probs for chosen responses (B,)
  reference_rejected_logps  torch.Tensor  Yes       Reference log probs for rejected responses (B,)
  ipo                       bool          No        Use the IPO squared loss (default False)
  beta                      float         No        DPO temperature (default 0.1)
  label_smoothing           float         No        cDPO label smoothing (default 0.0)

Outputs

  Name              Type          Description
  loss              torch.Tensor  Scalar DPO loss
  chosen_rewards    torch.Tensor  Implicit rewards for chosen responses (B,)
  rejected_rewards  torch.Tensor  Implicit rewards for rejected responses (B,)

Usage Examples

from roll.pipeline.dpo.actor_worker import get_logps, loss_fn

# per_token_logps stacks chosen and rejected rows: shape (2*B, seq_len)
chosen_logps, rejected_logps = get_logps(per_token_logps, mask, prompt_lens)
loss, chosen_rew, rejected_rew = loss_fn(
    chosen_logps, rejected_logps,
    ref_chosen_logps, ref_rejected_logps,
    beta=0.1, ipo=False, label_smoothing=0.0,
)
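The implicit rewards returned alongside the loss are commonly used to track preference accuracy during training. A small helper (hypothetical, not part of ROLL's API) illustrates the usual metric:

```python
import torch

def reward_accuracy(chosen_rew: torch.Tensor, rejected_rew: torch.Tensor) -> float:
    # Fraction of pairs where the policy implicitly prefers the chosen response.
    return (chosen_rew > rejected_rew).float().mean().item()
```

A value near 0.5 at initialization (policy equals reference) rising over training is the expected pattern.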

