
Implementation:Eric Mitchell Direct preference optimization Concatenated Forward

From Leeroopedia


Knowledge Sources
Domains Efficiency_Optimization, Training
Last Updated 2026-02-08 02:00 GMT

Overview

A concrete tool, provided by the direct-preference-optimization repository, for computing chosen and rejected log probabilities in a single forward pass.

Description

The concatenated_forward method on BasicTrainer takes a batch containing both chosen and rejected sequences, concatenates them using concatenated_inputs, runs a single forward pass through the model, extracts log probabilities via _get_batch_logps, then splits the results into chosen and rejected portions.
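A minimal sketch of this flow (helper names and the `_sketch` suffix are hypothetical, not the repository's exact code; it assumes chosen and rejected tensors are already padded to a common length, and that `model` is a callable returning logits of shape `(batch, seq, vocab)`):

```python
import torch

def get_batch_logps_sketch(logits, labels, label_pad=-100):
    """Sum log probs of the label tokens, ignoring padded positions.

    logits: (batch, seq, vocab); labels: (batch, seq), with label_pad
    marking positions excluded from the sum.
    """
    labels = labels[:, 1:].clone()          # token t is predicted from tokens < t
    logits = logits[:, :-1, :]
    mask = labels != label_pad
    labels[labels == label_pad] = 0         # dummy index so gather stays valid
    per_token = torch.gather(
        logits.log_softmax(-1), 2, labels.unsqueeze(2)).squeeze(2)
    return (per_token * mask).sum(-1)

def concatenated_forward_sketch(model, batch):
    """Single forward pass over chosen + rejected, then split the log probs."""
    n_chosen = batch["chosen_input_ids"].shape[0]
    input_ids = torch.cat([batch["chosen_input_ids"],
                           batch["rejected_input_ids"]], dim=0)
    labels = torch.cat([batch["chosen_labels"],
                        batch["rejected_labels"]], dim=0)
    logits = model(input_ids)               # one pass covers both halves
    all_logps = get_batch_logps_sketch(logits, labels)
    return all_logps[:n_chosen], all_logps[n_chosen:]
```

Because chosen sequences are stacked first, slicing `all_logps` at `n_chosen` recovers the two halves in order.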

Usage

Called within get_batch_metrics during DPO/IPO training for both the policy model (with gradients enabled) and the reference model (under torch.no_grad()).

Code Reference

Source Location

Signature

class BasicTrainer(object):
    def concatenated_forward(
        self,
        model: nn.Module,
        batch: Dict[str, Union[List, torch.LongTensor]],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run the given model on the given batch of inputs,
        concatenating the chosen and rejected inputs together.

        We do this to avoid doing two forward passes, because it's faster for FSDP.
        """

Import

from trainers import BasicTrainer
# Method accessed as trainer.concatenated_forward(model, batch)

I/O Contract

Inputs

Name Type Required Description
model nn.Module Yes The model to run (policy or reference)
batch Dict Yes Batch with chosen_input_ids, chosen_attention_mask, chosen_labels, rejected_input_ids, rejected_attention_mask, rejected_labels

Outputs

Name Type Description
chosen_logps torch.FloatTensor Log probabilities for chosen responses. Shape: (batch_size,)
rejected_logps torch.FloatTensor Log probabilities for rejected responses. Shape: (batch_size,)
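A hypothetical minimal batch satisfying the input contract (shapes chosen purely for illustration):

```python
import torch

# Illustrative shapes only: batch of 2, sequence length 6, vocab 100
B, T, V = 2, 6, 100
batch = {}
for side in ("chosen", "rejected"):
    ids = torch.randint(0, V, (B, T))
    batch[f"{side}_input_ids"] = ids
    batch[f"{side}_attention_mask"] = torch.ones(B, T, dtype=torch.long)
    labels = ids.clone()
    labels[:, :2] = -100  # exclude prompt tokens from the log-prob sum
    batch[f"{side}_labels"] = labels
```

Per the output contract, each returned log-probability tensor then has shape `(B,)`.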

Usage Examples

DPO Log Probability Computation

# Inside get_batch_metrics (DPO branch)
policy_chosen_logps, policy_rejected_logps = self.concatenated_forward(self.policy, batch)

with torch.no_grad():
    reference_chosen_logps, reference_rejected_logps = self.concatenated_forward(
        self.reference_model, batch
    )

# Feed into preference_loss
losses, chosen_rewards, rejected_rewards = preference_loss(
    policy_chosen_logps, policy_rejected_logps,
    reference_chosen_logps, reference_rejected_logps,
    beta=loss_config.beta,
)
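The four log-probability tensors feed the DPO objective, which penalizes the policy when its chosen-over-rejected margin falls below the reference model's. A hedged sketch of `preference_loss` for the DPO case (ignoring label smoothing and the IPO variant):

```python
import torch
import torch.nn.functional as F

def preference_loss_sketch(policy_chosen_logps, policy_rejected_logps,
                           reference_chosen_logps, reference_rejected_logps,
                           beta=0.1):
    """Hedged sketch of the DPO loss; not the repository's exact code."""
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    # Larger margin over the reference model => lower loss
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    # Implicit rewards: beta-scaled log-ratio of policy to reference
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards
```

When policy and reference agree exactly, the loss reduces to `-log sigmoid(0) = log 2` and both rewards are zero.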

Related Pages

Implements Principle

Requires Environment
