Heuristic:Princeton nlp SimPO Concatenated Forward Pass
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning, Distributed_Training |
| Last Updated | 2026-02-08 05:00 GMT |
Overview
Concatenate chosen and rejected inputs into a single forward pass instead of two separate passes, reducing overhead for FSDP and other distributed training strategies.
Description
The SimPOTrainer's `concatenated_forward` method combines the chosen and rejected inputs into a single batch before running the model's forward pass. This is an optimization for distributed training (especially FSDP), where each forward pass incurs parameter-gathering and synchronization overhead. By processing both chosen and rejected sequences in one call, the trainer halves the number of forward passes and their associated communication costs. The output logits, and the per-sequence log probabilities computed from them, are then split back into their chosen and rejected parts.
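The following is a minimal sketch, not the SimPOTrainer code, that checks the idea described above. For a model that processes each sequence independently, one forward pass over the concatenated batch yields the same logits as two separate passes, with half as many forward calls. The toy `nn.Sequential` model and the helper names `separate_forward` and `concatenated_forward_sketch` are hypothetical stand-ins for a real causal LM.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy "language model": token embedding followed by a vocab projection.
# Each sequence is processed independently, as in a causal LM without batch-dependent layers.
VOCAB, HIDDEN, SEQ_LEN, BATCH = 50, 16, 8, 3
model = nn.Sequential(nn.Embedding(VOCAB, HIDDEN), nn.Linear(HIDDEN, VOCAB)).eval()

# Count how many times the top-level forward is invoked.
forward_calls = {"n": 0}

def count_forward(module, args, output):
    forward_calls["n"] += 1

model.register_forward_hook(count_forward)

def separate_forward(model, chosen_ids, rejected_ids):
    """Baseline: one forward pass per side (two in total)."""
    return model(chosen_ids), model(rejected_ids)

def concatenated_forward_sketch(model, chosen_ids, rejected_ids):
    """One forward pass over [chosen; rejected], split back along the batch dim."""
    n_chosen = chosen_ids.shape[0]
    all_logits = model(torch.cat([chosen_ids, rejected_ids], dim=0))
    return all_logits[:n_chosen], all_logits[n_chosen:]  # chosen first, rejected second

chosen_ids = torch.randint(0, VOCAB, (BATCH, SEQ_LEN))
rejected_ids = torch.randint(0, VOCAB, (BATCH, SEQ_LEN))

with torch.no_grad():
    forward_calls["n"] = 0
    sep_chosen, sep_rejected = separate_forward(model, chosen_ids, rejected_ids)
    separate_calls = forward_calls["n"]

    forward_calls["n"] = 0
    cat_chosen, cat_rejected = concatenated_forward_sketch(model, chosen_ids, rejected_ids)
    concatenated_calls = forward_calls["n"]

print(separate_calls, concatenated_calls)  # 2 1
print(torch.allclose(sep_chosen, cat_chosen, atol=1e-5),
      torch.allclose(sep_rejected, cat_rejected, atol=1e-5))  # True True
```

The equivalence relies on the model having no batch-dependent layers such as `BatchNorm`. Transformer LMs normalize per token, so splitting the combined logits recovers the per-side results up to floating-point noise.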
Usage
Use this heuristic when implementing preference optimization trainers for distributed training, especially with FSDP. This is automatically applied by the SimPOTrainer; no configuration is needed.
The Insight (Rule of Thumb)
- Action: Concatenate chosen and rejected input_ids, attention_mask, and labels along the batch dimension before the forward pass.
- Action: Split the output logits back into chosen (first half) and rejected (second half) after the forward pass.
- Action: Pass `use_cache=False` in the forward call. The key/value cache only helps incremental decoding. During training it wastes memory, and it is incompatible with gradient checkpointing.
- Value: 1 forward pass instead of 2 per batch of preference pairs.
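- Trade-off: The single forward pass carries twice the per-device batch (chosen plus rejected), so its activations and logits take roughly twice the memory of one half-sized pass. The shorter side must also be padded to the longer side's sequence length before concatenation. Ensure VRAM is sufficient.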
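The Actions listed above can be sketched as a small helper that right-pads each chosen and rejected tensor to a common sequence length and stacks them along the batch dimension. This sketch rests on several assumptions that do not come from `scripts/simpo_trainer.py`. The batch keys (`chosen_input_ids`, `rejected_labels`, and so on) are assumed, and so is the `concatenated_labels` output key. The helper names `pad_to_length` and `concatenate_pair` are hypothetical. The pad token id `0` is also an assumption, as is the label pad value `-100`, which is the index PyTorch's cross-entropy ignores by default.

```python
import torch

def pad_to_length(t: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor along dim 1 to `length`."""
    if t.shape[1] >= length:
        return t
    pad = torch.full((t.shape[0], length - t.shape[1]), pad_value, dtype=t.dtype, device=t.device)
    return torch.cat([t, pad], dim=1)

def concatenate_pair(batch, pad_token_id: int = 0, label_pad_token_id: int = -100):
    """Stack chosen rows on top of rejected rows for input_ids, attention_mask and labels."""
    max_len = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
    pad_values = {
        "input_ids": pad_token_id,
        "attention_mask": 0,               # padded positions are not attended to
        "labels": label_pad_token_id,      # padded positions are excluded from the loss
    }
    out = {}
    for key, pad_value in pad_values.items():
        chosen = pad_to_length(batch[f"chosen_{key}"], max_len, pad_value)
        rejected = pad_to_length(batch[f"rejected_{key}"], max_len, pad_value)
        out[f"concatenated_{key}"] = torch.cat([chosen, rejected], dim=0)  # chosen first
    return out

# One preference pair whose rejected response is one token shorter.
batch = {
    "chosen_input_ids": torch.tensor([[5, 6, 7]]),
    "chosen_attention_mask": torch.tensor([[1, 1, 1]]),
    "chosen_labels": torch.tensor([[-100, 6, 7]]),
    "rejected_input_ids": torch.tensor([[5, 8]]),
    "rejected_attention_mask": torch.tensor([[1, 1]]),
    "rejected_labels": torch.tensor([[-100, 8]]),
}
concatenated = concatenate_pair(batch)
print(concatenated["concatenated_input_ids"].tolist())  # [[5, 6, 7], [5, 8, 0]]
print(concatenated["concatenated_labels"].tolist())     # [[-100, 6, 7], [-100, 8, -100]]
```

The resulting tensors would then feed one `model(..., attention_mask=..., use_cache=False).logits` call. The output would be split at `len_chosen = batch["chosen_input_ids"].shape[0]`, with the rows before that index being chosen and the rows from it onward being rejected.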
Reasoning
In FSDP (Fully Sharded Data Parallel), each forward pass triggers all-gather operations that reconstruct each sharded module's full parameters from the shards held across GPUs. Two separate forward passes therefore mean two rounds of all-gathers. Concatenating chosen and rejected into one pass halves this forward-pass communication. The same logic applies to DeepSpeed ZeRO-3, where parameters are likewise gathered on demand during the forward pass. The trade-off is higher memory use per forward pass, which gradient checkpointing can partially offset.
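The following back-of-envelope sketch puts numbers on both sides of this trade-off. Its counting model is an assumption: every FSDP-wrapped unit issues exactly one all-gather per forward call, and backward-pass gathers and prefetching are ignored. Memory is estimated only for the logits tensor of the single concatenated pass. The function names and the model sizes are hypothetical example values, not measurements from SimPO.

```python
def forward_allgather_rounds(num_fsdp_units: int, forward_passes: int) -> int:
    """All-gathers issued during the forward direction for one batch of preference pairs.

    Assumes every FSDP-wrapped unit gathers its full parameters once per forward call
    (FULL_SHARD-style resharding after each unit's forward).
    """
    return num_fsdp_units * forward_passes

def concatenated_logits_mib(pairs: int, seq_len: int, vocab_size: int, bytes_per_elem: int = 4) -> float:
    """Size in MiB of the (2 * pairs, seq_len, vocab_size) logits tensor from the single pass."""
    return 2 * pairs * seq_len * vocab_size * bytes_per_elem / 2**20

# Hypothetical setup: 32 wrapped decoder layers, 2 pairs per device,
# 2048-token sequences, 32,000-token vocabulary, fp32 logits.
print(forward_allgather_rounds(32, forward_passes=2))  # 64 (separate chosen/rejected passes)
print(forward_allgather_rounds(32, forward_passes=1))  # 32 (concatenated pass)
print(concatenated_logits_mib(2, 2048, 32_000))        # 1000.0
```

The vocabulary dimension dominates the logits term, so models with large vocabularies feel the doubled batch dimension the most.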
Code evidence from `scripts/simpo_trainer.py:597-603`:
```python
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
    ...  # remainder of the method body elided in this excerpt
```
Code evidence from `scripts/simpo_trainer.py:622-627`:
```python
all_logits = model(
    concatenated_batch["concatenated_input_ids"],
    attention_mask=concatenated_batch["concatenated_attention_mask"],
    use_cache=False,
    **model_kwargs,
).logits
```
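After the call above, the combined logits still have to be turned into per-sequence scores and split. The following sketch assumes SimPO's length-normalized reward, so it averages per-token log-probabilities over the non-masked label positions. The helper names `average_logps` and `split_outputs` are hypothetical, and the `-100` label pad value and the causal-LM shift are also assumptions. This is an illustration rather than the trainer's exact code. The four return values mirror the four-tensor return type in the `concatenated_forward` signature.

```python
import torch

def average_logps(logits: torch.Tensor, labels: torch.Tensor, label_pad_token_id: int = -100) -> torch.Tensor:
    """Mean log-probability of the label tokens for each sequence (length-normalized)."""
    # Causal-LM shift: the logits at position t score the token at position t + 1.
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()
    mask = labels != label_pad_token_id
    labels[~mask] = 0  # placeholder index; these positions are masked out below
    token_logps = torch.gather(logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    return (token_logps * mask).sum(dim=-1) / mask.sum(dim=-1)

def split_outputs(all_logits: torch.Tensor, concatenated_labels: torch.Tensor, len_chosen: int):
    """Return chosen/rejected log-probs and logits; chosen rows come first in the batch."""
    all_logps = average_logps(all_logits, concatenated_labels)
    return (
        all_logps[:len_chosen],
        all_logps[len_chosen:],
        all_logits[:len_chosen],
        all_logits[len_chosen:],
    )

# Uniform logits over a 4-token vocabulary: every label token has log-prob -log(4), about -1.386.
all_logits = torch.zeros(4, 5, 4)
labels = torch.tensor([[-100, 1, 2, 3, -100]] * 4)
chosen_logps, rejected_logps, chosen_logits, rejected_logits = split_outputs(all_logits, labels, len_chosen=2)
print(chosen_logps, rejected_logps)
```

Splitting after the log-probability computation keeps the expensive `log_softmax` over the vocabulary inside one batched operation as well.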