Principle: Concatenated Forward Pass in Eric Mitchell's Direct Preference Optimization (DPO)
| Knowledge Sources | |
|---|---|
| Domains | Efficiency_Optimization, Training, Deep_Learning |
| Last Updated | 2026-02-08 02:00 GMT |
## Overview
An efficiency optimization that concatenates chosen and rejected response sequences into a single batch for one forward pass instead of two separate passes.
## Description
In DPO training, the loss requires log probabilities for both chosen and rejected responses under the same model. A naive implementation would run two separate forward passes. The concatenated forward pass optimization instead:
- Pads chosen and rejected sequences to the same length
- Concatenates them along the batch dimension (chosen first, rejected second)
- Runs a single forward pass through the model
- Splits the resulting log probabilities back into chosen and rejected portions
This is particularly important for FSDP (Fully Sharded Data Parallel) training where each forward pass involves expensive all-gather communication. Halving the number of forward passes significantly improves throughput.
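The padding and concatenation steps listed above can be sketched in NumPy as follows (the function name, array shapes, and pad token are illustrative assumptions, not the actual implementation):

```python
import numpy as np

def concatenate_pairs(chosen, rejected, pad_id=0):
    """Pad chosen/rejected token-id arrays to a common length, then
    stack them along the batch dimension (chosen first, rejected second)."""
    max_len = max(chosen.shape[1], rejected.shape[1])

    def pad(x):
        # right-pad each sequence to max_len with pad_id
        return np.pad(x, ((0, 0), (0, max_len - x.shape[1])),
                      constant_values=pad_id)

    return np.concatenate([pad(chosen), pad(rejected)], axis=0)

chosen = np.array([[5, 6, 7]])          # batch of 1, length 3
rejected = np.array([[5, 6, 7, 8, 9]])  # batch of 1, length 5
batch = concatenate_pairs(chosen, rejected)
# batch has shape (2, 5): row 0 is chosen (right-padded), row 1 is rejected
```

In practice the pad token positions are excluded from the log-probability sum via a loss mask, so padding does not change the computed sequence log probabilities.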
## Usage
Use this principle in DPO training when computing log probabilities for both chosen and rejected responses. Apply it to both the policy model (with gradients) and the reference model (without gradients).
## Theoretical Basis
Since the model processes each batch element independently (attention only mixes positions within a sequence, never across the batch dimension), concatenating chosen and rejected sequences into one batch produces log probabilities identical to those from two separate forward passes.
This property holds because batched transformer operations carry no interaction along the batch dimension, so no information leaks between sequences in the batch.
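This batch independence can be checked numerically. Below is a toy single-head causal attention layer in NumPy (the layer, dimensions, and weights are illustrative, not any particular model): running two batches separately and running their concatenation produce the same per-row outputs.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # hidden size of the toy layer
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))

def causal_attention(x):
    """x: (batch, seq, d). Attention mixes positions within a row only;
    nothing crosses the batch dimension."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)
    seq = x.shape[1]
    # causal mask: position t may not attend to positions > t
    mask = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(mask, -1e9, scores)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

chosen = rng.normal(size=(2, 5, d))
rejected = rng.normal(size=(2, 5, d))

# one concatenated pass vs. two separate passes
together = causal_attention(np.concatenate([chosen, rejected], axis=0))
separate = np.concatenate([causal_attention(chosen),
                           causal_attention(rejected)], axis=0)
assert np.allclose(together, separate)
```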
Pseudo-code:

```
# Abstract concatenated forward (NOT actual implementation)
concatenated = concat(chosen_sequences, rejected_sequences)
all_logits = model(concatenated)
all_logps = extract_logps(all_logits)
chosen_logps = all_logps[:batch_size]
rejected_logps = all_logps[batch_size:]
```
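A runnable sketch of the same flow, using NumPy with a toy position-wise stand-in for the language model (`model`, `sequence_logps`, the vocabulary size, and all shapes here are illustrative assumptions, not the actual DPO code):

```python
import numpy as np

rng = np.random.default_rng(42)
vocab, d = 16, 4
emb = rng.normal(size=(vocab, d))   # toy embedding table
out = rng.normal(size=(d, vocab))   # toy output projection

def model(tokens):
    """Toy stand-in for a causal LM: (batch, seq) -> (batch, seq, vocab)."""
    return emb[tokens] @ out

def sequence_logps(logits, tokens):
    """Sum of per-token log probabilities for each sequence."""
    # numerically stable log-softmax over the vocab axis
    logits = logits - logits.max(axis=-1, keepdims=True)
    logp = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    # position t predicts token t+1
    next_tok = tokens[:, 1:]
    per_tok = np.take_along_axis(logp[:, :-1],
                                 next_tok[..., None], axis=-1)[..., 0]
    return per_tok.sum(axis=-1)

chosen = rng.integers(0, vocab, size=(3, 6))
rejected = rng.integers(0, vocab, size=(3, 6))
batch_size = chosen.shape[0]

# one forward pass over the concatenated batch, then split
concatenated = np.concatenate([chosen, rejected], axis=0)
all_logps = sequence_logps(model(concatenated), concatenated)
chosen_logps = all_logps[:batch_size]
rejected_logps = all_logps[batch_size:]

# matches two separate forward passes
assert np.allclose(chosen_logps, sequence_logps(model(chosen), chosen))
assert np.allclose(rejected_logps, sequence_logps(model(rejected), rejected))
```

In a real PyTorch implementation the split is the same slicing operation; the savings come from issuing one forward (and, under FSDP, one set of all-gathers) instead of two.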