Implementation: Eric Mitchell's Direct Preference Optimization, Concatenated Forward
| Knowledge Sources | |
|---|---|
| Domains | Efficiency_Optimization, Training |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
A concrete utility from the direct-preference-optimization repository for computing chosen and rejected log probabilities in a single forward pass.
Description
The concatenated_forward method on BasicTrainer takes a batch containing both chosen and rejected sequences, concatenates them via concatenated_inputs, runs a single forward pass through the model, extracts per-sequence log probabilities via _get_batch_logps, and splits the results back into chosen and rejected portions.
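The concatenate-then-split pattern described above can be sketched as follows. This is a minimal illustration of the idea, not the repository's exact code; the helper name `concatenate_chosen_rejected` and the zero pad value are assumptions for the sketch.

```python
import torch
import torch.nn.functional as F

def concatenate_chosen_rejected(batch: dict, pad_value: int = 0) -> torch.Tensor:
    """Pad chosen/rejected input_ids to a shared length and stack along the batch dim."""
    max_len = max(batch["chosen_input_ids"].shape[1],
                  batch["rejected_input_ids"].shape[1])

    def pad(t: torch.Tensor) -> torch.Tensor:
        # Right-pad the sequence dimension to max_len.
        return F.pad(t, (0, max_len - t.shape[1]), value=pad_value)

    # First half of dim 0 holds chosen sequences, second half rejected.
    return torch.cat([pad(batch["chosen_input_ids"]),
                      pad(batch["rejected_input_ids"])], dim=0)

batch = {
    "chosen_input_ids": torch.tensor([[1, 2, 3]]),
    "rejected_input_ids": torch.tensor([[1, 2, 3, 4]]),
}
concatenated = concatenate_chosen_rejected(batch)  # shape: (2, 4)

# A single forward pass would run over `concatenated` here; the
# results are then split back by batch size:
n = batch["chosen_input_ids"].shape[0]
chosen_part, rejected_part = concatenated[:n], concatenated[n:]
```

Doubling the batch this way trades memory for speed: one forward pass over 2x the batch is cheaper than two passes, especially under FSDP where each pass incurs parameter-gathering communication.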
Usage
Called within get_batch_metrics during DPO/IPO training for both the policy model (with gradients enabled) and the reference model (under torch.no_grad()).
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: trainers.py
- Lines: 210-220
Signature
```python
class BasicTrainer(object):
    def concatenated_forward(
        self,
        model: nn.Module,
        batch: Dict[str, Union[List, torch.LongTensor]],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs,
        concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """
```
Import
```python
from trainers import BasicTrainer
# Method accessed as trainer.concatenated_forward(model, batch)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | The model to run (policy or reference) |
| batch | Dict | Yes | Batch with chosen_input_ids, chosen_attention_mask, chosen_labels, rejected_input_ids, rejected_attention_mask, rejected_labels |
Outputs
| Name | Type | Description |
|---|---|---|
| chosen_logps | torch.FloatTensor | Log probabilities for chosen responses. Shape: (batch_size,) |
| rejected_logps | torch.FloatTensor | Log probabilities for rejected responses. Shape: (batch_size,) |
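The per-sequence log probabilities in the outputs above are the role played by _get_batch_logps: sum the log probability of each label token, skipping positions masked with -100. The function below is a hedged sketch of that computation, not the repository's exact implementation.

```python
import torch
import torch.nn.functional as F

def per_sequence_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum token log probs at label positions, ignoring labels == -100.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len).
    """
    # Shift so that position t's logits predict token t+1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]

    mask = labels != -100
    labels[labels == -100] = 0  # dummy index for gather; zeroed out by mask below

    token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), 2, labels.unsqueeze(2)
    ).squeeze(2)
    return (token_logps * mask).sum(-1)  # shape: (batch,)

# Uniform logits make each token's log prob exactly -log(vocab_size).
logits = torch.zeros(2, 5, 10)
labels = torch.tensor([[1, 2, 3, -100, -100],
                       [1, 2, 3, 4, 5]])
logps = per_sequence_logps(logits, labels)  # shape: (2,)
```

Because the sum runs only over unmasked positions, prompt tokens (conventionally labeled -100) contribute nothing, so each output is the log probability of the response given the prompt.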
Usage Examples
DPO Log Probability Computation
```python
# Inside get_batch_metrics (DPO branch)
policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)
with torch.no_grad():
    reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(
        self.reference_model, batch
    )

# Feed into preference_loss
losses, chosen_rewards, rejected_rewards = preference_loss(
    policy_chosen_logps, policy_rejected_logps,
    reference_chosen_logps, reference_rejected_logps,
    beta=loss_config.beta,
)
```