

Heuristic:Princeton nlp SimPO Concatenated Forward Pass

From Leeroopedia
Revision as of 10:42, 16 February 2026 by Admin (auto-imported from heuristics/Princeton_nlp_SimPO_Concatenated_Forward_Pass.md)




Knowledge Sources
Domains: Optimization, Deep_Learning, Distributed_Training
Last Updated: 2026-02-08 05:00 GMT

Overview

Concatenate the chosen and rejected inputs into a single batch and run one forward pass instead of two, reducing per-pass overhead (such as parameter all-gathers) under FSDP and other sharded training strategies.

Description

The SimPOTrainer's concatenated_forward method combines the chosen and rejected inputs into a single batch before running the model forward pass. This is an optimization for distributed training (especially FSDP), where each forward pass incurs synchronization overhead. By processing both chosen and rejected sequences in one call, the trainer halves the number of forward passes per step, and with them the associated communication costs. The combined output is then split back into its chosen half (the first rows) and its rejected half (the remaining rows), from which the chosen and rejected log probabilities are obtained.
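A minimal sketch of the concatenation step, assuming PyTorch tensors of shape (batch, seq_len) and right padding. The helper names `pad_to_length` and `concatenate_pair`, the `pad_id` default, and the use of -100 as the ignored-label value are illustrative assumptions, not the SimPOTrainer's actual code; the `concatenated_` key prefix simply mirrors the keys visible in the excerpt quoted further down. Because tensors can only be stacked along the batch dimension when they share a sequence length, the shorter side is padded first.

```python
import torch


def pad_to_length(t: torch.Tensor, length: int, value: int) -> torch.Tensor:
    """Right-pad a (batch, seq_len) tensor to `length` along the sequence dim."""
    if t.size(1) >= length:
        return t
    pad = torch.full((t.size(0), length - t.size(1)), value,
                     dtype=t.dtype, device=t.device)
    return torch.cat([t, pad], dim=1)


def concatenate_pair(chosen: dict, rejected: dict, pad_id: int = 0) -> dict:
    """Hypothetical helper: stack chosen rows on top of rejected rows.

    Each input dict holds "input_ids", "attention_mask" and "labels", all of
    shape (batch, seq_len). Rows 0..B-1 of every output tensor are the chosen
    examples and rows B..2B-1 are the rejected ones.
    """
    max_len = max(chosen["input_ids"].size(1), rejected["input_ids"].size(1))
    # Padding value per key: pad token for ids, 0 for the attention mask,
    # -100 for labels so padded positions are ignored when scoring.
    pad_values = {"input_ids": pad_id, "attention_mask": 0, "labels": -100}
    return {
        f"concatenated_{key}": torch.cat(
            [pad_to_length(chosen[key], max_len, value),
             pad_to_length(rejected[key], max_len, value)],
            dim=0,
        )
        for key, value in pad_values.items()
    }


chosen = {
    "input_ids": torch.tensor([[5, 6, 7]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "labels": torch.tensor([[-100, 6, 7]]),
}
rejected = {
    "input_ids": torch.tensor([[5, 8, 9, 10]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
    "labels": torch.tensor([[-100, 8, 9, 10]]),
}
batch = concatenate_pair(chosen, rejected)
print(batch["concatenated_input_ids"])
```

The chosen sequence here is one token shorter, so it gains one pad token, one zero in its attention mask, and one ignored label before the two halves are stacked.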

Usage

Use this heuristic when implementing preference optimization trainers for distributed training, especially with FSDP. This is automatically applied by the SimPOTrainer; no configuration is needed.

The Insight (Rule of Thumb)

  • Action: Concatenate chosen and rejected input_ids, attention_mask, and labels along the batch dimension before the forward pass.
  • Action: Split the output logits back into chosen (first half) and rejected (second half) after the forward pass.
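  • Action: Set `use_cache=False` in the forward call. The key/value cache only speeds up incremental generation; in a training forward pass it is never reused and only consumes memory, and Hugging Face Transformers disables it anyway when gradient checkpointing is enabled.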
  • Value: 1 forward pass instead of 2 per training step.
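  • Trade-off: The single pass processes twice as many sequences as either separate pass, and both halves must be padded to a common sequence length before they can be concatenated, so peak GPU memory for that pass rises accordingly. Ensure VRAM is sufficient.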
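For a rough sense of the memory trade-off, the arithmetic below counts padded token positions per step. It is an illustrative sketch only: the function name is hypothetical, `chosen_len` and `rejected_len` are assumed to be each side's already-padded length, and treating activation memory as proportional to padded token positions is a simplifying assumption.

```python
def token_positions(batch_size: int, chosen_len: int, rejected_len: int) -> dict:
    """Compare padded token positions for separate vs. concatenated passes.

    Assumption: each side is already padded to its own length; concatenation
    additionally pads both sides to the common (longer) length.
    """
    common = max(chosen_len, rejected_len)
    separate_peak = batch_size * common                    # larger of the two separate passes
    separate_total = batch_size * (chosen_len + rejected_len)
    concatenated = 2 * batch_size * common                 # one pass, both halves at `common`
    return {
        "separate_peak": separate_peak,
        "separate_total": separate_total,
        "concatenated": concatenated,
    }


print(token_positions(batch_size=4, chosen_len=512, rejected_len=384))
# {'separate_peak': 2048, 'separate_total': 3584, 'concatenated': 4096}
```

With these numbers the single pass holds twice as many positions as the larger separate pass (the doubled batch in the trade-off above), plus 512 extra padded positions for the shorter rejected side.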

Reasoning

In FSDP (Fully Sharded Data Parallelism), each forward pass triggers all-gather operations to reconstruct the model parameters from their shards across GPUs. Two separate forward passes therefore mean two rounds of all-gathers; concatenating chosen and rejected into one pass halves this communication. The same logic applies to DeepSpeed ZeRO-3, where parameters are also sharded and gathering them is expensive. The trade-off is higher peak memory per forward pass; gradient checkpointing can offset much of it by storing fewer activations and recomputing them during the backward pass.
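The communication argument can be illustrated without a cluster by counting how often each parameterized submodule is entered. In the sketch below a forward pre-hook stands in for FSDP's per-unit parameter all-gather; this is a proxy under that assumption, not a measurement of real FSDP traffic, and the toy `nn.Sequential` model and `fake_all_gather` hook are hypothetical.

```python
import torch
import torch.nn as nn

# Toy stand-in for a model whose Linear layers would each be an FSDP unit.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
units = [m for m in model if isinstance(m, nn.Linear)]
gather_calls = {"count": 0}


def fake_all_gather(module, inputs):
    # Proxy only: real FSDP would all-gather this unit's sharded parameters here.
    gather_calls["count"] += 1


handles = [u.register_forward_pre_hook(fake_all_gather) for u in units]

chosen = torch.randn(4, 8)
rejected = torch.randn(4, 8)

# Two separate passes: every unit is entered twice.
gather_calls["count"] = 0
model(chosen)
model(rejected)
separate = gather_calls["count"]

# One concatenated pass: every unit is entered once.
gather_calls["count"] = 0
out = model(torch.cat([chosen, rejected], dim=0))
concatenated = gather_calls["count"]
chosen_out, rejected_out = out[:4], out[4:]  # split back by row position

for h in handles:
    h.remove()

print(separate, concatenated)  # 4 2
```

Because the model treats each row independently, slicing the concatenated output recovers the same results as the separate passes, at half the number of (proxy) gathers.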

Code evidence from `scripts/simpo_trainer.py:597-603`:

```python
def concatenated_forward(
    self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.

    We do this to avoid doing two forward passes, because it's faster for FSDP.
    """
```

Code evidence from `scripts/simpo_trainer.py:622-627`:

```python
all_logits = model(
    concatenated_batch["concatenated_input_ids"],
    attention_mask=concatenated_batch["concatenated_attention_mask"],
    use_cache=False,
    **model_kwargs,
).logits
```
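After the single forward pass, the logits must be turned into per-sequence scores and split back into the two halves. The sketch below assumes causal-LM logits of shape (2B, seq_len, vocab), labels that use -100 for ignored positions, and length-averaged log probabilities (SimPO scores a response by its average per-token log probability). The function names `average_logps` and `split_chosen_rejected` are hypothetical; the trainer's own helper may differ in details.

```python
import torch


def average_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Length-averaged log probability of each sequence's labelled tokens.

    logits: (N, T, V) from a causal LM; labels: (N, T) with -100 = ignore.
    Position t predicts token t + 1, so logits and labels are shifted.
    """
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()
    mask = labels != -100
    labels[~mask] = 0  # any valid index; these positions are masked out below
    token_logps = torch.gather(
        logits.log_softmax(dim=-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    return (token_logps * mask).sum(-1) / mask.sum(-1)


def split_chosen_rejected(all_logits, all_labels, len_chosen):
    """Hypothetical helper: the first len_chosen rows are chosen, the rest rejected."""
    all_logps = average_logps(all_logits, all_labels)
    return (
        all_logps[:len_chosen],
        all_logps[len_chosen:],
        all_logits[:len_chosen],
        all_logits[len_chosen:],
    )


torch.manual_seed(0)
all_logits = torch.randn(4, 5, 11)          # 2 chosen + 2 rejected rows, vocab of 11
all_labels = torch.randint(0, 11, (4, 5))
all_labels[:, 0] = -100                     # e.g. a prompt token that is not scored
chosen_logps, rejected_logps, chosen_logits, rejected_logits = split_chosen_rejected(
    all_logits, all_labels, len_chosen=2
)
print(chosen_logps, rejected_logps)
```

Returning chosen log probs, rejected log probs, chosen logits, and rejected logits matches the four-tensor return annotation of `concatenated_forward` quoted above; that this is also the trainer's exact ordering is an assumption.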
