

Implementation:Eric mitchell Direct preference optimization Concatenated Inputs Fn

From Leeroopedia


Knowledge Sources
Domains: Data_Manipulation, Efficiency_Optimization
Last Updated: 2026-02-08 02:00 GMT

Overview

Concrete tool, provided by the direct-preference-optimization repository, for merging the chosen and rejected tensors of a preference batch into one tensor per field.

Description

The concatenated_inputs function takes a batch dictionary with chosen_* and rejected_* keys. It pads each chosen/rejected pair of tensors to a common sequence length (the longer of the two) and concatenates them along the batch dimension, with the chosen rows first. It returns a new dictionary with concatenated_* keys, ready for a single model forward pass over both halves of the batch.
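The padding-and-concatenation step can be sketched in PyTorch as follows. This is a minimal sketch under stated assumptions, not the repository's code. The helper names (pad_to_length_sketch, concatenated_inputs_sketch) are hypothetical. Right-padding is assumed. The pad value of -100 for *_labels keys (the usual ignore index for cross-entropy) and 0 for every other key is also an assumption. Finally, the sketch assumes that non-tensor entries (the signature admits List values) are skipped.

```python
from typing import Dict, List, Union

import torch


def pad_to_length_sketch(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad the last dimension of `tensor` with `pad_value` up to `length`."""
    missing = length - tensor.size(-1)
    if missing <= 0:
        return tensor
    padding = torch.full((*tensor.shape[:-1], missing), pad_value,
                         dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, padding], dim=-1)


def concatenated_inputs_sketch(
    batch: Dict[str, Union[List, torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    """Stack chosen_* rows over rejected_* rows after padding both to a common length."""
    max_length = max(batch['chosen_input_ids'].shape[1],
                     batch['rejected_input_ids'].shape[1])
    concatenated = {}
    for key, value in batch.items():
        # Only chosen_* tensors start a concatenated_* entry; list values are skipped
        if not key.startswith('chosen_') or not isinstance(value, torch.Tensor):
            continue
        suffix = key[len('chosen_'):]
        # Assumed convention: labels padded with -100 (ignored by the loss), others with 0
        pad_value = -100 if 'labels' in suffix else 0
        chosen = pad_to_length_sketch(value, max_length, pad_value)
        rejected = pad_to_length_sketch(batch['rejected_' + suffix], max_length, pad_value)
        # Chosen rows first, rejected rows second, along the batch dimension
        concatenated['concatenated_' + suffix] = torch.cat([chosen, rejected], dim=0)
    return concatenated
```

Both halves are padded to the same length first because torch.cat along dim=0 requires every other dimension to match.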

Usage

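It is called by concatenated_forward immediately before the model forward pass. The output keys replace the chosen_ or rejected_ prefix with concatenated_; for example, chosen_input_ids and rejected_input_ids become concatenated_input_ids.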
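The chosen examples occupy the first batch_size rows of every concatenated_* tensor. Any per-sequence result of the single forward pass can therefore be split back into chosen and rejected halves by slicing. The sketch below relies only on that row order. split_chosen_rejected and the stand-in per-row score are hypothetical, used for illustration, and are not functions from the repository.

```python
from typing import Tuple

import torch


def split_chosen_rejected(
    per_row: torch.Tensor, batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a tensor with 2*batch_size leading rows into (chosen, rejected)."""
    if per_row.shape[0] != 2 * batch_size:
        raise ValueError(f"expected {2 * batch_size} rows, got {per_row.shape[0]}")
    # Rows [0, batch_size) came from chosen_*, rows [batch_size, 2*batch_size) from rejected_*
    return per_row[:batch_size], per_row[batch_size:]


# Toy concatenated attention mask for 2 preference pairs, padded to length 3
concatenated_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
# Hypothetical stand-in for a per-sequence model output: count of unpadded tokens per row
scores = concatenated_attention_mask.sum(dim=-1)
chosen_scores, rejected_scores = split_chosen_rejected(scores, batch_size=2)
```

Basic slicing returns views, so the split copies no data. This is what lets one forward pass over 2*batch_size rows stand in for separate passes over the chosen and rejected halves.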

Code Reference

Source Location

Signature

from typing import Dict, List, Union

import torch


def concatenated_inputs(
    batch: Dict[str, Union[List, torch.LongTensor]],
) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and
               'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    ...  # body omitted in this signature listing

Import

from trainers import concatenated_inputs

I/O Contract

Inputs

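| Name | Type | Required | Description |
|---|---|---|---|
| batch | Dict[str, Union[List, torch.LongTensor]] | Yes | Must contain chosen_input_ids, chosen_attention_mask and chosen_labels, each a torch.LongTensor of shape (batch_size, chosen_len), and rejected_input_ids, rejected_attention_mask and rejected_labels, each a torch.LongTensor of shape (batch_size, rejected_len). |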
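The input contract above can be checked before calling the function. The sketch below is a hypothetical helper: validate_preference_batch is not part of the repository. It assumes that, within each of the chosen and rejected groups, the three tensors share one shape, while the two groups may differ in sequence length.

```python
from typing import Dict

import torch

REQUIRED_SUFFIXES = ("input_ids", "attention_mask", "labels")


def validate_preference_batch(batch: Dict[str, object]) -> None:
    """Raise ValueError if `batch` violates the chosen_*/rejected_* input contract."""
    batch_sizes = set()
    for prefix in ("chosen", "rejected"):
        shapes = set()
        for suffix in REQUIRED_SUFFIXES:
            key = f"{prefix}_{suffix}"
            if key not in batch:
                raise ValueError(f"missing key: {key}")
            value = batch[key]
            # torch.LongTensor corresponds to dtype torch.int64 (torch.long)
            if not isinstance(value, torch.Tensor) or value.dtype != torch.long:
                raise ValueError(f"{key} must be an int64 tensor")
            if value.dim() != 2:
                raise ValueError(f"{key} must be 2-D (batch_size, seq_len)")
            shapes.add(tuple(value.shape))
        # All tensors in one group must share (batch_size, seq_len)
        if len(shapes) != 1:
            raise ValueError(f"{prefix}_* tensors have mismatched shapes: {sorted(shapes)}")
        batch_sizes.add(shapes.pop()[0])
    # Chosen and rejected rows must pair up one to one
    if len(batch_sizes) != 1:
        raise ValueError("chosen and rejected batch sizes differ")
```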

Outputs

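| Name | Type | Description |
|---|---|---|
| concatenated_batch | Dict[str, torch.LongTensor] | Contains concatenated_input_ids, concatenated_attention_mask and concatenated_labels. Each has shape (2*batch_size, max_seq_len), where max_seq_len = max(chosen_len, rejected_len). The first batch_size rows hold the chosen examples and the remaining batch_size rows hold the rejected examples. |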
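The output contract can likewise be verified against the original batch. This hypothetical checker (check_concatenated_output, not part of the repository) confirms the (2*batch_size, max_seq_len) shape. It also checks that each half reproduces the original rows on their unpadded prefix. It deliberately does not check pad values, since the padding convention is not stated on this page.

```python
from typing import Dict

import torch


def check_concatenated_output(
    batch: Dict[str, torch.Tensor], concatenated_batch: Dict[str, torch.Tensor]
) -> None:
    """Assert the shape and row layout of the concatenated_* tensors."""
    batch_size, chosen_len = batch["chosen_input_ids"].shape
    rejected_len = batch["rejected_input_ids"].shape[1]
    max_seq_len = max(chosen_len, rejected_len)
    for suffix in ("input_ids", "attention_mask", "labels"):
        out = concatenated_batch[f"concatenated_{suffix}"]
        assert out.shape == (2 * batch_size, max_seq_len), f"bad shape for {suffix}"
        # First batch_size rows: chosen examples, unchanged up to their original length
        assert torch.equal(out[:batch_size, :chosen_len], batch[f"chosen_{suffix}"])
        # Remaining rows: rejected examples, unchanged up to their original length
        assert torch.equal(out[batch_size:, :rejected_len], batch[f"rejected_{suffix}"])
```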

Usage Examples

Preparing Inputs for Forward Pass

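import torch
from trainers import concatenated_inputs

# Toy batch: 4 preference pairs; chosen length 200, rejected length 180
batch = {}
for prefix, seq_len in (('chosen', 200), ('rejected', 180)):
    batch[f'{prefix}_input_ids'] = torch.randint(0, 1000, (4, seq_len))
    batch[f'{prefix}_attention_mask'] = torch.ones(4, seq_len, dtype=torch.long)
    batch[f'{prefix}_labels'] = batch[f'{prefix}_input_ids'].clone()
concatenated_batch = concatenated_inputs(batch)
# concatenated_batch['concatenated_input_ids'] shape: (8, 200)
# First 4 rows are chosen, last 4 rows are rejected (padded to 200)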

Related Pages

Implements Principle

Requires Environment
