Implementation: Eric Mitchell direct-preference-optimization get_batch_metrics
| Knowledge Sources | |
|---|---|
| Domains | Evaluation, Training, NLP |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Concrete reference for the method that computes the SFT or DPO loss and associated metrics for a single batch, provided by the direct-preference-optimization repository.
Description
The get_batch_metrics method on BasicTrainer computes the loss and collects metrics for either SFT or DPO training. It branches based on the loss configuration:
- DPO/IPO branch: Calls concatenated_forward for both policy and reference models, then calls preference_loss to get losses, chosen_rewards, and rejected_rewards. Computes reward accuracies and margins.
- SFT branch: Runs the policy on chosen sequences, computes log probabilities, and returns negative log probability as the loss.
Both branches gather metrics across distributed processes and return the mean loss plus a dictionary of logged metrics.
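The core of the DPO branch can be sketched as follows. This is a simplified, standalone version of the loss computation described above (DPO case only, ignoring the IPO and reference_free variants); the function name and the beta/label_smoothing defaults are assumptions for illustration, not the repository's exact code:

```python
import torch
import torch.nn.functional as F

def preference_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                           ref_chosen_logps, ref_rejected_logps,
                           beta=0.1, label_smoothing=0.0):
    """Sketch of the DPO loss over per-sequence log probabilities."""
    # Log-ratio of policy vs. reference for chosen and rejected completions.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    logits = chosen_logratios - rejected_logratios
    # DPO loss; label_smoothing > 0 gives the conservative (smoothed) variant.
    losses = (-F.logsigmoid(beta * logits) * (1 - label_smoothing)
              - F.logsigmoid(-beta * logits) * label_smoothing)
    # Implicit rewards, detached so they are logged but not differentiated.
    chosen_rewards = beta * chosen_logratios.detach()
    rejected_rewards = beta * rejected_logratios.detach()
    return losses, chosen_rewards, rejected_rewards
```

The reward accuracy metric then follows as `(chosen_rewards > rejected_rewards).float()` and the margin as `chosen_rewards - rejected_rewards`.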
Usage
Called by BasicTrainer.train for each training microbatch (to compute gradients) and for each evaluation batch (under torch.no_grad()). The train parameter controls the train/eval label in metric keys.
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: trainers.py
- Lines: 223-270
Signature
class BasicTrainer(object):
    def get_batch_metrics(
        self,
        batch: Dict[str, Union[List, torch.LongTensor]],
        loss_config: DictConfig,
        train: bool = True,
    ) -> Tuple[torch.FloatTensor, Dict]:
        """Compute the SFT or DPO loss and other metrics for the given batch of inputs."""
Import
from trainers import BasicTrainer
# Method accessed as trainer.get_batch_metrics(batch, loss_config, train=True)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| batch | Dict | Yes | Tokenized batch from get_batch_iterator |
| loss_config | DictConfig | Yes | Loss configuration with name ("sft", "dpo", or "ipo") and parameters (beta, label_smoothing, reference_free) |
| train | bool | No | If True, metric keys are labeled with "train"; if False, with "eval" (default True) |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.FloatTensor | Mean loss for the batch (scalar) |
| metrics | Dict[str, List[float]] | DPO: rewards_train/chosen, rewards_train/rejected, rewards_train/accuracies, rewards_train/margins, logps_train/chosen, logps_train/rejected, loss/train. SFT: logps_train/chosen, loss/train. With train=False, "train" in each key is replaced by "eval". |
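The logps_* entries are per-sequence log probabilities summed over non-masked tokens. A minimal sketch of that computation, assuming labels use -100 to mask prompt and padding positions (the helper name here is hypothetical, not the repository's):

```python
import torch

def batch_logps_sketch(logits, labels, mask_id=-100):
    """Sum of per-token log probabilities for each sequence in the batch."""
    # Shift so that token t is predicted from positions before t.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != mask_id
    # Replace masked positions with a dummy index so gather is valid.
    labels[labels == mask_id] = 0
    per_token_logps = torch.gather(
        logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    # Zero out masked positions and sum over the sequence dimension.
    return (per_token_logps * loss_mask).sum(-1)
```

In the SFT branch, the batch loss is then simply the negative mean of these per-sequence log probabilities.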
Usage Examples
Training Step
# Inside BasicTrainer.train() - training
loss, metrics = self.get_batch_metrics(local_microbatch, self.config.loss, train=True)
(loss / self.config.gradient_accumulation_steps).backward()
Evaluation Step
# Inside BasicTrainer.train() - evaluation
with torch.no_grad():
    _, eval_metrics = self.get_batch_metrics(local_eval_batch, self.config.loss, train=False)
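Since each call returns a dictionary of per-example metric lists, a caller typically accumulates them across evaluation batches and averages at the end. A minimal sketch of that pattern (the variable names are illustrative, not the repository's):

```python
from collections import defaultdict

all_eval_metrics = defaultdict(list)
# Each eval batch yields a dict of metric-name -> list of per-example floats.
for eval_metrics in ({"loss/eval": [0.7]}, {"loss/eval": [0.6, 0.65]}):
    for k, v in eval_metrics.items():
        all_eval_metrics[k].extend(v)
# Average each metric over all collected examples for logging.
mean_eval_metrics = {k: sum(v) / len(v) for k, v in all_eval_metrics.items()}
```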