Principle:Eric mitchell Direct preference optimization Concatenated Forward Pass

From Leeroopedia


Knowledge Sources
Domains Efficiency_Optimization, Training, Deep_Learning
Last Updated 2026-02-08 02:00 GMT

Overview

An efficiency optimization that concatenates chosen and rejected response sequences into a single batch for one forward pass instead of two separate passes.

Description

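In DPO training, the loss needs the log probabilities of both the chosen and the rejected response, each computed under the same model (once for the policy and once for the reference). A naive implementation runs two separate forward passes per model, one for the chosen batch and one for the rejected batch. The concatenated forward pass optimization instead: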

  1. Pads chosen and rejected sequences to the same length
  2. Concatenates them along the batch dimension (chosen first, rejected second)
  3. Runs a single forward pass through the model
  4. Splits the resulting log probabilities back into chosen and rejected portions
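The four steps above can be sketched in plain Python. This is a minimal illustration, not the repository's code: sequences are lists of token IDs, `PAD_ID = 0` is an assumed padding token, and `score_batch` is a hypothetical stand-in for "run the model, then reduce each row to one log probability". The toy scorer scores each row on its own and ignores padding, which is the property the real model must also have.

```python
from typing import Callable, List, Tuple

PAD_ID = 0  # hypothetical padding token id for this sketch


def pad_to_length(seqs: List[List[int]], length: int, pad_id: int = PAD_ID) -> List[List[int]]:
    """Right-pad each sequence with pad_id up to `length` tokens."""
    return [s + [pad_id] * (length - len(s)) for s in seqs]


def concatenated_forward(
    score_batch: Callable[[List[List[int]]], List[float]],
    chosen: List[List[int]],
    rejected: List[List[int]],
) -> Tuple[List[float], List[float]]:
    """Score chosen and rejected sequences with a single batched call."""
    # Step 1: pad both groups to one shared length.
    max_len = max(len(s) for s in chosen + rejected)
    chosen_padded = pad_to_length(chosen, max_len)
    rejected_padded = pad_to_length(rejected, max_len)
    # Step 2: stack along the batch dimension, chosen rows first.
    batch = chosen_padded + rejected_padded
    # Step 3: a single forward pass over all 2B rows.
    scores = score_batch(batch)
    # Step 4: split back; the first B rows are chosen, the rest rejected.
    b = len(chosen)
    return scores[:b], scores[b:]


calls: List[int] = []  # records the batch size of every model call


def toy_score_batch(batch: List[List[int]]) -> List[float]:
    # Hypothetical stand-in for "model + log-prob extraction": each row is
    # scored on its own, and padding tokens contribute nothing.
    calls.append(len(batch))
    return [float(sum(t for t in row if t != PAD_ID)) for row in batch]


chosen_scores, rejected_scores = concatenated_forward(
    toy_score_batch, chosen=[[5, 6, 7], [8]], rejected=[[1, 2], [3, 4, 9, 9]]
)
print(chosen_scores, rejected_scores, calls)  # [18.0, 8.0] [3.0, 25.0] [4]
```

One call over four rows replaces two calls over two rows each; the split in step 4 relies only on the chosen rows having been placed first.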

This matters most for FSDP (Fully Sharded Data Parallel) training, where every forward pass must all-gather the sharded parameters of each wrapped module before computing with them. Running one pass instead of two halves this forward-pass communication per model, which can noticeably improve throughput.

Usage

Use this principle in DPO training whenever log probabilities are needed for both the chosen and the rejected responses. It is applied to both the policy model (with gradients) and the reference model (without gradients), so each training step needs two forward passes instead of four.
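The usage pattern can be illustrated end to end: one concatenated pass for the policy, one for the reference, and the four resulting log-probability lists combined into a loss. This sketch assumes the standard DPO sigmoid loss with an illustrative `beta = 0.1`; `policy`, `reference`, and `dpo_loss` are hypothetical toy names. Plain Python has no autograd, so comments only mark where a PyTorch version would keep gradients (policy) or wrap the call in `torch.no_grad()` (reference).

```python
import math
from typing import Callable, List, Tuple

Scorer = Callable[[List[List[int]]], List[float]]  # token rows -> one log-prob per row


def concatenated_forward(score_batch: Scorer, chosen: List[List[int]],
                         rejected: List[List[int]]) -> Tuple[List[float], List[float]]:
    # Right-pad with a hypothetical pad id 0, stack chosen then rejected,
    # make a single call, and split the results back.
    max_len = max(len(s) for s in chosen + rejected)
    batch = [s + [0] * (max_len - len(s)) for s in chosen + rejected]
    scores = score_batch(batch)
    return scores[:len(chosen)], scores[len(chosen):]


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(pi_c: List[float], pi_r: List[float], ref_c: List[float],
             ref_r: List[float], beta: float = 0.1) -> float:
    # Mean over pairs of -log sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r))),
    # using the identity -log sigmoid(m) = softplus(-m).
    margins = [beta * ((pc - rc) - (pr - rr))
               for pc, pr, rc, rr in zip(pi_c, pi_r, ref_c, ref_r)]
    return sum(softplus(-m) for m in margins) / len(margins)


calls: List[str] = []


def policy(batch: List[List[int]]) -> List[float]:  # toy trainable policy
    calls.append("policy")
    return [-0.1 * sum(t for t in row if t != 0) for row in batch]


def reference(batch: List[List[int]]) -> List[float]:  # toy frozen reference
    calls.append("reference")
    return [-0.2 * sum(t for t in row if t != 0) for row in batch]


chosen, rejected = [[1, 2]], [[3, 4]]
pi_c, pi_r = concatenated_forward(policy, chosen, rejected)        # gradients would flow here
ref_c, ref_r = concatenated_forward(reference, chosen, rejected)   # torch.no_grad() in PyTorch
loss = dpo_loss(pi_c, pi_r, ref_c, ref_r)
print(calls, loss)
```

Each model is called once per step; if the policy and reference agree exactly, every margin is zero and the loss equals log 2.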

Theoretical Basis

Because the model processes each sequence independently (there is no attention across the batch dimension), concatenating chosen and rejected sequences into one batch gives the same per-sequence outputs as two separate forward passes, up to floating-point rounding:

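$$f_\theta([y_w; y_l]) = [f_\theta(y_w); f_\theta(y_l)]$$
where $y_w$ denotes the chosen sequences, $y_l$ the rejected sequences, and $[\,\cdot\,;\,\cdot\,]$ stacking along the batch dimension.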

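This property holds because every transformer layer (attention, layer normalization, MLP) operates on each batch row separately: attention mixes information only across positions within a sequence, never across the batch. The padding added in step 1 is also harmless. Real tokens never attend to it (it sits after them under the causal mask, and the attention mask excludes it), and padded positions are masked out when per-token log probabilities are summed.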
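The equivalence, and why padding does not break it, can be checked on a toy model. In this sketch (an illustration under stated assumptions, not the real model) `causal_model` lets position t see only tokens 0..t, `PAD_ID = 99` is a hypothetical pad value chosen non-zero on purpose, and a 0/1 mask excludes padded positions from each sequence's score. The hypothetical `bidirectional_model`, which has neither a causal nor a padding mask, shows the check failing once real positions can see the padding.

```python
from typing import List, Tuple

PAD_ID = 99  # hypothetical pad id; non-zero so that only the mask can hide it


def causal_model(rows: List[List[int]]) -> List[List[float]]:
    # Toy causal "model": the output at position t is a prefix sum, so it
    # depends only on tokens 0..t, and each row is handled independently.
    outputs = []
    for row in rows:
        running, out = 0.0, []
        for tok in row:
            running += tok
            out.append(running)
        outputs.append(out)
    return outputs


def bidirectional_model(rows: List[List[int]]) -> List[List[float]]:
    # Counterexample: every position sees the whole row, padding included.
    return [[float(sum(row))] * len(row) for row in rows]


def pad(seqs: List[List[int]], length: int) -> Tuple[List[List[int]], List[List[int]]]:
    # Right-pad with PAD_ID and build a 1/0 mask marking real tokens.
    rows = [s + [PAD_ID] * (length - len(s)) for s in seqs]
    masks = [[1] * len(s) + [0] * (length - len(s)) for s in seqs]
    return rows, masks


def sequence_scores(model, seqs: List[List[int]], length: int) -> List[float]:
    # Sum the model outputs over real positions only (mask == 1).
    rows, masks = pad(seqs, length)
    outs = model(rows)
    return [sum(o for o, m in zip(out, mask) if m) for out, mask in zip(outs, masks)]


chosen, rejected = [[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]
max_len = max(len(s) for s in chosen + rejected)

# Two separate passes, each padded only within its own group.
separate = (sequence_scores(causal_model, chosen, max(map(len, chosen)))
            + sequence_scores(causal_model, rejected, max(map(len, rejected))))
# One concatenated pass over chosen + rejected, padded to the shared length.
concatenated = sequence_scores(causal_model, chosen + rejected, max_len)
print(separate == concatenated)  # True

# Without causal or padding masks, extra padding changes the scores.
sep_bi = (sequence_scores(bidirectional_model, chosen, max(map(len, chosen)))
          + sequence_scores(bidirectional_model, rejected, max(map(len, rejected))))
cat_bi = sequence_scores(bidirectional_model, chosen + rejected, max_len)
print(sep_bi == cat_bi)  # False
```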

Pseudo-code:

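# Abstract concatenated forward (NOT the actual implementation).
# Assumes chosen_sequences and rejected_sequences are [B, T] tensors already
# padded to the same length T; extract_logps is a placeholder that maps
# logits and target tokens to one summed log probability per sequence.
import torch

batch_size = chosen_sequences.shape[0]
concatenated = torch.cat([chosen_sequences, rejected_sequences], dim=0)  # [2B, T]
all_logits = model(concatenated)                     # [2B, T, vocab]; one forward pass
all_logps = extract_logps(all_logits, concatenated)  # [2B]; one value per row
chosen_logps = all_logps[:batch_size]                # first B rows
rejected_logps = all_logps[batch_size:]              # last B rows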
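The pseudo-code treats `extract_logps` as a black box. One common way to fill it in, sketched here in plain Python, is to shift logits against labels (the output at position t predicts token t + 1), take a log-softmax, and sum the log probabilities of the target tokens while skipping masked positions. The label value `IGNORE = -100` for prompt and padding positions, the toy two-token vocabulary, and the uniform logits are assumptions for illustration.

```python
import math
from typing import List

IGNORE = -100  # assumed label value for prompt and padding positions


def log_softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]


def extract_logps(logits: List[List[List[float]]], labels: List[List[int]]) -> List[float]:
    """Sum log p(label | prefix) over non-ignored positions, one value per row.

    logits[b][t] is the distribution after token t, which predicts token t + 1,
    so logits[:-1] are paired with labels[1:].
    """
    out = []
    for seq_logits, seq_labels in zip(logits, labels):
        total = 0.0
        for step_logits, label in zip(seq_logits[:-1], seq_labels[1:]):
            if label != IGNORE:
                total += log_softmax(step_logits)[label]
        out.append(total)
    return out


batch_size = 1
# Uniform logits over a 2-token vocabulary for 2B rows of 3 positions each.
all_logits = [[[0.0, 0.0] for _ in range(3)] for _ in range(2 * batch_size)]
concatenated_labels = [
    [IGNORE, 1, 0],       # chosen: prompt token masked, two response tokens
    [IGNORE, 1, IGNORE],  # rejected: one response token, then padding
]
all_logps = extract_logps(all_logits, concatenated_labels)
chosen_logps, rejected_logps = all_logps[:batch_size], all_logps[batch_size:]
print(chosen_logps, rejected_logps)
```

With uniform logits each counted token contributes log(1/2), so the chosen row sums two such terms and the rejected row only one, because its padded position is skipped.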

Related Pages

Implemented By
