
Implementation:Eric Mitchell Direct Preference Optimization Get Batch Metrics

From Leeroopedia


Knowledge Sources
Domains: Evaluation, Training, NLP
Last Updated: 2026-02-08 02:00 GMT

Overview

A concrete tool from the direct-preference-optimization repository for computing the SFT or DPO loss and associated metrics for a single batch.

Description

The get_batch_metrics method on BasicTrainer computes the loss and collects metrics for either SFT or DPO training. It branches on the name field of the loss configuration (see the sketch below):

  • DPO/IPO branch: Calls concatenated_forward for both the policy and reference models, then calls preference_loss to obtain per-example losses, chosen_rewards, and rejected_rewards. Reward accuracies and margins are derived by comparing the two reward tensors.
  • SFT branch: Runs the policy on the chosen sequences, computes their log probabilities, and returns the negative log probability as the loss.

Both branches gather metrics across distributed processes and return the mean loss plus a dictionary of logged metrics.
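
Assembled from the description above, the control flow can be sketched as follows. This is a minimal illustration, not the repository's verbatim code: concatenated_forward, preference_loss, the loss_config fields, and the metric keys are documented on this page, while the no_grad guard around the reference model, the exact preference_loss keyword arguments, the batch keys, and the _get_batch_logps helper are assumptions about the implementation.

import torch

def get_batch_metrics_sketch(self, batch, loss_config, train=True):
    """Sketch of the SFT/DPO branching described above (illustrative only)."""
    metrics = {}
    split = 'train' if train else 'eval'  # controls every metric key prefix

    if loss_config.name in ('dpo', 'ipo'):
        # One forward pass over the concatenated chosen+rejected sequences per model.
        policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
        with torch.no_grad():  # assumption: no gradients flow through the reference model
            ref_chosen_logps, ref_rejected_logps = self.concatenated_forward(self.reference_model, batch)

        # preference_loss returns per-example losses plus beta-scaled log-ratio
        # rewards, r = beta * (log pi_policy(y|x) - log pi_ref(y|x)).
        losses, chosen_rewards, rejected_rewards = preference_loss(
            policy_chosen_logps, policy_rejected_logps,
            ref_chosen_logps, ref_rejected_logps,
            beta=loss_config.beta,
            label_smoothing=loss_config.label_smoothing,
            reference_free=loss_config.reference_free,
        )
        accuracies = (chosen_rewards > rejected_rewards).float()
        metrics[f'rewards_{split}/chosen'] = chosen_rewards.tolist()
        metrics[f'rewards_{split}/rejected'] = rejected_rewards.tolist()
        metrics[f'rewards_{split}/accuracies'] = accuracies.tolist()
        metrics[f'rewards_{split}/margins'] = (chosen_rewards - rejected_rewards).tolist()
        metrics[f'logps_{split}/rejected'] = policy_rejected_logps.detach().tolist()
    else:  # 'sft'
        # Negative log-likelihood of the chosen responses under the policy.
        logits = self.policy(batch['chosen_input_ids'],
                             attention_mask=batch['chosen_attention_mask']).logits
        policy_chosen_logps = _get_batch_logps(logits, batch['chosen_labels'])  # assumed helper
        losses = -policy_chosen_logps

    metrics[f'logps_{split}/chosen'] = policy_chosen_logps.detach().tolist()
    metrics[f'loss/{split}'] = losses.detach().tolist()
    # In the repository, metrics are also gathered across distributed
    # processes before being returned.
    return losses.mean(), metrics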

Usage

Called by BasicTrainer.train for each training microbatch (to compute gradients) and for each evaluation batch (under torch.no_grad()). The train parameter controls metric key prefixes.
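
Concretely, the prefix swap affects every metric key (key names taken from the I/O contract below):

# train=True  -> 'loss/train', 'rewards_train/chosen', 'logps_train/chosen', ...
# train=False -> 'loss/eval',  'rewards_eval/chosen',  'logps_eval/chosen',  ...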

Code Reference

Source Location

trainers.py (method of the BasicTrainer class)

Signature

class BasicTrainer(object):
    def get_batch_metrics(
        self,
        batch: Dict[str, Union[List, torch.LongTensor]],
        loss_config: DictConfig,
        train: bool = True,
    ) -> Tuple[torch.FloatTensor, Dict]:
        """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""

Import

from trainers import BasicTrainer
# Method accessed as trainer.get_batch_metrics(batch, loss_config, train=True)

I/O Contract

Inputs

  • batch (Dict, required): Tokenized batch from get_batch_iterator
  • loss_config (DictConfig, required): Loss configuration with name ("sft", "dpo", or "ipo") and parameters (beta, label_smoothing, reference_free)
  • train (bool, optional, default True): If True, metric keys use the "train" prefix; if False, "eval"

Outputs

  • loss (torch.FloatTensor): Mean loss for the batch (scalar)
  • metrics (Dict[str, List[float]]): DPO/IPO: rewards_train/chosen, rewards_train/rejected, rewards_train/accuracies, rewards_train/margins, logps_train/chosen, logps_train/rejected, loss/train. SFT: logps_train/chosen, loss/train. With train=False, the train segment of each key becomes eval.
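
Each metric value is a list of per-example floats, so callers typically reduce it before logging. A hypothetical consumption pattern (the trainer variable and the averaging are illustrative, not taken from the repository):

loss, metrics = trainer.get_batch_metrics(batch, config.loss, train=True)

# Reduce each per-example list to a batch mean for logging.
mean_metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
print(f"loss: {loss.item():.4f}")
if 'rewards_train/accuracies' in mean_metrics:  # present only for DPO/IPO
    print(f"reward accuracy: {mean_metrics['rewards_train/accuracies']:.3f}")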

Usage Examples

Training Step

# Inside BasicTrainer.train() - training
loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
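# dividing by gradient_accumulation_steps keeps the accumulated gradient an average over microbatches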
(loss / self.config.gradient_accumulation_steps).backward()

Evaluation Step

# Inside BasicTrainer.train() - evaluation
with torch.no_grad():
    _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False)

Related Pages

  • Implements Principle
  • Requires Environment
  • Uses Heuristic
