Implementation:Eric mitchell Direct preference optimization Concatenated Inputs Fn
| Knowledge Sources | |
|---|---|
| Domains | Data_Manipulation, Efficiency_Optimization |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Concrete tool, provided by the direct-preference-optimization repository, for concatenating the chosen and rejected batch tensors into a single tensor.
Description
The concatenated_inputs function takes a batch dictionary with chosen_* and rejected_* keys, pads each tensor to the longer of the two sequence lengths, and concatenates the chosen and rejected tensors along the batch dimension, with the chosen rows first. It returns a new dictionary with concatenated_* keys, ready for a single forward pass of the model.
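The padding-and-concatenation step described above can be sketched in PyTorch as follows. This is a minimal, hypothetical re-implementation: `concatenated_inputs_sketch` and `pad_to_length_sketch` are illustrative names, not the repository's code. The padding values are assumptions of this sketch. Labels are padded with -100, which is PyTorch's default `ignore_index` for `cross_entropy`, so padded positions drop out of a token-level loss. All other tensors are padded with 0.

```python
from typing import Dict

import torch


def pad_to_length_sketch(t: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Hypothetical helper: right-pad the last dimension of `t` to `length`."""
    missing = length - t.size(-1)
    if missing <= 0:
        return t
    pad = torch.full((*t.shape[:-1], missing), pad_value, dtype=t.dtype, device=t.device)
    return torch.cat([t, pad], dim=-1)


def concatenated_inputs_sketch(batch: Dict[str, object]) -> Dict[str, torch.Tensor]:
    """Hypothetical sketch: pad chosen_*/rejected_* tensors to a common length
    and stack them along the batch dimension, chosen rows first."""
    max_len = max(batch["chosen_input_ids"].size(1), batch["rejected_input_ids"].size(1))
    out: Dict[str, torch.Tensor] = {}
    for key, value in batch.items():
        # Only tensor-valued chosen_* entries are concatenated; other entries are skipped.
        if not (key.startswith("chosen_") and isinstance(value, torch.Tensor)):
            continue
        suffix = key[len("chosen_"):]
        # ASSUMPTION: labels are padded with -100 (ignored by the loss), everything else with 0.
        pad_value = -100 if "labels" in suffix else 0
        chosen = pad_to_length_sketch(value, max_len, pad_value)
        rejected = pad_to_length_sketch(batch["rejected_" + suffix], max_len, pad_value)
        out["concatenated_" + suffix] = torch.cat([chosen, rejected], dim=0)
    return out
```

Because the chosen half is concatenated first, row i and row i + batch_size of every concatenated_* tensor form one preference pair.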
Usage
Called by concatenated_forward before the model is run, so that the chosen and rejected sequences share one forward pass. The output keys use the prefix concatenated_ in place of chosen_ or rejected_.
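A minimal sketch of the single-forward-pass pattern follows. It uses a hypothetical `toy_model` (an embedding followed by a linear layer) in place of a language model. It also uses equal-length chosen and rejected ids, so no padding is needed. The sketch omits the attention mask and the log-probability computation that a real training step would need.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# HYPOTHETICAL stand-in for a causal LM: maps token ids to per-token "logits".
vocab_size, hidden = 32, 8
toy_model = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, vocab_size))

batch_size, seq_len = 3, 6
chosen_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
rejected_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# Chosen rows first, rejected rows second, matching the concatenated_* layout.
concatenated_ids = torch.cat([chosen_ids, rejected_ids], dim=0)

# One forward pass over 2 * batch_size sequences instead of two separate passes.
all_logits = toy_model(concatenated_ids)  # (2 * batch_size, seq_len, vocab_size)

# Split the outputs back into their halves by slicing at batch_size.
chosen_logits = all_logits[:batch_size]
rejected_logits = all_logits[batch_size:]
```

Running both halves together applies the model once rather than twice. The cost is that shorter sequences carry padding up to the longer length.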
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: trainers.py
- Lines: 118-142
Signature
```python
from typing import Dict, List, Union

import torch


def concatenated_inputs(
    batch: Dict[str, Union[List, torch.LongTensor]],
) -> Dict[str, torch.LongTensor]:
    """Concatenate the chosen and rejected inputs into a single tensor.

    Args:
        batch: A batch of data. Must contain the keys 'chosen_input_ids' and
            'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).

    Returns:
        A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
    """
    # Body omitted here; see trainers.py, lines 118-142.
```
Import
```python
from trainers import concatenated_inputs
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
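| batch | Dict[str, Union[List, torch.LongTensor]] | Yes | Must contain chosen_input_ids, chosen_attention_mask and chosen_labels (torch.LongTensor of shape (batch_size, chosen_len)), and rejected_input_ids, rejected_attention_mask and rejected_labels (torch.LongTensor of shape (batch_size, rejected_len)). chosen_len and rejected_len may differ. |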
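A pre-flight check for this input contract could look like the following sketch. `validate_preference_batch` and `REQUIRED_SUFFIXES` are hypothetical names and are not part of the repository. The checks simply encode the table row: six 2-D LongTensors, one shape per side, and the same batch_size on both sides.

```python
from typing import Any, Dict

import torch

# HYPOTHETICAL: the per-side key suffixes listed in the Inputs table.
REQUIRED_SUFFIXES = ("input_ids", "attention_mask", "labels")


def validate_preference_batch(batch: Dict[str, Any]) -> None:
    """Hypothetical checker: raise if `batch` violates the Inputs contract."""
    for side in ("chosen", "rejected"):
        shapes = set()
        for suffix in REQUIRED_SUFFIXES:
            key = f"{side}_{suffix}"
            if key not in batch:
                raise KeyError(f"missing required key: {key}")
            value = batch[key]
            if not isinstance(value, torch.Tensor) or value.dtype != torch.long or value.dim() != 2:
                raise TypeError(f"{key} must be a 2-D torch.LongTensor")
            shapes.add(tuple(value.shape))
        # All tensors on one side must share (batch_size, seq_len).
        if len(shapes) != 1:
            raise ValueError(f"{side}_* tensors must share one shape, got {sorted(shapes)}")
    # Sequence lengths may differ between sides; batch sizes may not.
    if batch["chosen_input_ids"].size(0) != batch["rejected_input_ids"].size(0):
        raise ValueError("chosen and rejected tensors must have the same batch_size")
```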
Outputs
| Name | Type | Description |
|---|---|---|
| concatenated_batch | Dict[str, torch.LongTensor] | Contains concatenated_input_ids, concatenated_attention_mask and concatenated_labels. Each has shape (2*batch_size, max_seq_len), where max_seq_len = max(chosen_len, rejected_len). The first batch_size rows hold the chosen examples and the remaining rows hold the rejected examples. |
Usage Examples
Preparing Inputs for Forward Pass
```python
from trainers import concatenated_inputs

# batch has chosen_input_ids of shape (4, 200) and rejected_input_ids of shape (4, 180)
concatenated_batch = concatenated_inputs(batch)

# concatenated_batch['concatenated_input_ids'] has shape (8, 200):
# rows 0-3 are the chosen sequences; rows 4-7 are the rejected sequences, padded from 180 to 200
```
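Consumers of the output often need to split it back into halves. The sketch below builds a toy tensor with the example's shapes by hand, in place of a real call to concatenated_inputs. It then slices out the chosen and rejected halves and recovers each rejected row's unpadded length from the attention mask. The padding scheme is an assumption of this sketch: right padding, with 0 used for both padded ids and padded mask positions.

```python
import torch

batch_size, chosen_len, rejected_len = 4, 200, 180

# Toy stand-in for the function's output in the example above (ASSUMPTION:
# right padding, 0 for padded ids and padded attention-mask positions).
chosen_ids = torch.randint(1, 1000, (batch_size, chosen_len))
rejected_ids = torch.randint(1, 1000, (batch_size, rejected_len))
zeros = torch.zeros(batch_size, chosen_len - rejected_len, dtype=torch.long)
concatenated_input_ids = torch.cat([chosen_ids, torch.cat([rejected_ids, zeros], dim=1)], dim=0)
concatenated_attention_mask = torch.cat(
    [
        torch.ones(batch_size, chosen_len, dtype=torch.long),
        torch.cat([torch.ones(batch_size, rejected_len, dtype=torch.long), zeros], dim=1),
    ],
    dim=0,
)

# Rows [0, batch_size) are chosen; rows [batch_size, 2 * batch_size) are rejected.
chosen_half = concatenated_input_ids[:batch_size]
rejected_half = concatenated_input_ids[batch_size:]

# Count of real tokens per rejected row, read off the attention mask.
rejected_lengths = concatenated_attention_mask[batch_size:].sum(dim=1)
rejected_unpadded = rejected_half[:, : int(rejected_lengths.max())]
```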
Related Pages
Implements Principle
Requires Environment