
Implementation:Princeton nlp SimPO SimPOTrainer



Knowledge Sources
Domains: Deep_Learning, NLP, Preference_Optimization
Last Updated: 2026-02-08 04:30 GMT

Overview

A concrete tool for training language models with the SimPO preference optimization algorithm, implemented as a subclass of Hugging Face's transformers.Trainer.

Description

SimPOTrainer is the core training class that implements the SimPO algorithm. It extends transformers.Trainer and overrides the loss computation to use SimPO's length-normalized preference loss. The trainer handles model loading (from a model name string via AutoModelForCausalLM), PEFT/LoRA wrapping, dataset tokenization, preference data collation, the training loop, metric logging, and evaluation. Key internal methods include:

  • simpo_loss: Computes the SimPO loss (sigmoid or hinge variant) from the chosen and rejected average log probabilities
  • get_batch_logps: Computes per-sequence log probabilities over the labeled (response) tokens, averaged over sequence length when average_log_prob=True (the default)
  • concatenated_forward: Runs a single forward pass on the concatenated chosen and rejected inputs instead of two separate passes, which is more efficient under FSDP
  • tokenize_row: Tokenizes a single preference example with proper BOS/EOS handling and truncation
  • get_batch_loss_metrics: Orchestrates the forward pass, loss computation, and metric collection
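As a rough illustration of what simpo_loss computes, here is a minimal PyTorch sketch. It assumes the inputs are per-sequence average log probabilities (as produced by get_batch_logps with average_log_prob=True) and that gamma_beta_ratio stores γ/β, so the sigmoid variant is -log σ(β·(Δ - γ/β)) = -log σ(β·Δ - γ), where Δ is the chosen-minus-rejected difference. The function name simpo_loss_sketch is hypothetical, the defaults are the example values from the hyperparameter list on this page, and any additional options the real method supports are omitted.

```python
import torch
import torch.nn.functional as F


def simpo_loss_sketch(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    beta: float = 2.0,
    gamma_beta_ratio: float = 0.5,
    loss_type: str = "sigmoid",
):
    """Hypothetical sketch of the SimPO objective.

    Both inputs hold per-sequence *average* log probabilities, shape (batch,).
    Returns (losses, chosen_rewards, rejected_rewards), each of shape (batch,).
    """
    # Chosen-minus-rejected difference, shifted by the target margin expressed as gamma / beta.
    logits = (policy_chosen_logps - policy_rejected_logps) - gamma_beta_ratio
    if loss_type == "sigmoid":
        # -log sigmoid(beta * diff - gamma)
        losses = -F.logsigmoid(beta * logits)
    elif loss_type == "hinge":
        # Zero loss once beta * diff reaches gamma + 1.
        losses = torch.relu(1 - beta * logits)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type!r}")
    # Implicit rewards: scaled average log probabilities, detached because they are only logged.
    chosen_rewards = beta * policy_chosen_logps.detach()
    rejected_rewards = beta * policy_rejected_logps.detach()
    return losses, chosen_rewards, rejected_rewards


# Usage: one pair with equal average log probabilities, one clearly separated pair.
chosen = torch.tensor([-1.0, -0.5])
rejected = torch.tensor([-1.0, -2.0])
losses, chosen_rewards, rejected_rewards = simpo_loss_sketch(chosen, rejected)
```

With β = 2.0 and γ/β = 0.5 (so γ = 1.0), the first pair, whose average log probabilities are equal, still has a sigmoid loss of log(1 + e) ≈ 1.31 because the target margin is not met, while the second pair (Δ = 1.5) drops to log(1 + e^-2) ≈ 0.13.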

Usage

Instantiate after all preparation steps (config parsing, data loading, tokenizer/model config). The trainer accepts a model name string and handles loading internally. Call trainer.train() to begin training and trainer.save_model() to save the result.

Code Reference

Source Location

  • Repository: SimPO
  • File: scripts/simpo_trainer.py (Lines 46-893)

Signature

class SimPOTrainer(Trainer):
    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[SimPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable] = None,
        peft_config: Optional[Dict] = None,
        compute_metrics: Optional[Callable] = None,
    ):
        """
        Args:
            model: Model name string (lazy-loaded) or pre-instantiated model.
            args: SimPOConfig with training hyperparameters.
            data_collator: Custom collator (default: DPODataCollatorWithPadding).
            train_dataset: Dataset with 'prompt', 'chosen', 'rejected' columns.
            eval_dataset: Evaluation dataset.
            tokenizer: Tokenizer for data processing.
            peft_config: LoRA configuration dict for PEFT wrapping.
        """

    def simpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute SimPO loss. Returns (losses, chosen_rewards, rejected_rewards)."""

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = True,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute per-sequence average log probabilities."""

    def concatenated_forward(
        self, model: nn.Module, batch: Dict
    ) -> Tuple[torch.FloatTensor, ...]:
        """Single forward pass on concatenated chosen+rejected inputs."""

    def tokenize_row(
        self, feature: Dict, model: Optional[PreTrainedModel] = None
    ) -> Dict:
        """Tokenize a single preference example with BOS/EOS handling."""

    def get_batch_loss_metrics(
        self, model, batch: Dict,
        train_eval: Literal["train", "eval"] = "train",
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute loss and all training metrics for a batch."""

    def compute_loss(
        self, model, inputs: Dict, return_outputs: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """Override Trainer.compute_loss with SimPO loss."""

    def train(
        self, resume_from_checkpoint: Optional[Union[str, bool]] = None
    ) -> TrainOutput:
        """Run the full training loop (inherited from Trainer, not overridden)."""
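The two ideas behind concatenated_forward and get_batch_logps (scoring chosen and rejected sequences in one forward pass, then averaging log probabilities over the labeled response tokens only) can be sketched for a decoder-only model as below. The helper names (pad_to_length, batch_logps, concatenated_logps), the batch keys, the padding values (0 for input ids, -100 for labels), and the toy uniform model are illustrative assumptions rather than the trainer's actual internals.

```python
import torch


def pad_to_length(t: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Right-pad a (batch, seq) tensor with `pad_value` up to `length` columns."""
    pad = torch.full((t.size(0), length - t.size(1)), pad_value, dtype=t.dtype)
    return torch.cat([t, pad], dim=1)


def batch_logps(logits, labels, average_log_prob=True, label_pad_token_id=-100):
    """Per-sequence log probability of `labels` (decoder-only), optionally length-averaged."""
    # Shift so that the logits at position t score the token at position t + 1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = labels != label_pad_token_id  # True only for scored (response) tokens
    labels[~loss_mask] = 0  # any valid index; these positions are masked out below
    per_token = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
    summed = (per_token * loss_mask).sum(-1)
    return summed / loss_mask.sum(-1) if average_log_prob else summed


def concatenated_logps(model_fn, batch, pad_token_id=0, label_pad_token_id=-100):
    """Score chosen and rejected sequences with one call to `model_fn`."""
    max_len = max(batch["chosen_input_ids"].size(1), batch["rejected_input_ids"].size(1))
    input_ids = torch.cat([
        pad_to_length(batch["chosen_input_ids"], max_len, pad_token_id),
        pad_to_length(batch["rejected_input_ids"], max_len, pad_token_id),
    ], dim=0)
    labels = torch.cat([
        pad_to_length(batch["chosen_labels"], max_len, label_pad_token_id),
        pad_to_length(batch["rejected_labels"], max_len, label_pad_token_id),
    ], dim=0)
    logits = model_fn(input_ids)  # one forward pass over 2 * batch_size rows
    logps = batch_logps(logits, labels, label_pad_token_id=label_pad_token_id)
    n = batch["chosen_input_ids"].size(0)
    return logps[:n], logps[n:]  # split back into chosen / rejected halves


# Usage with a toy "model" that returns uniform logits over a vocabulary of V tokens.
V = 8


def toy_model(ids: torch.Tensor) -> torch.Tensor:
    return torch.zeros(ids.size(0), ids.size(1), V)


batch = {
    "chosen_input_ids": torch.tensor([[1, 2, 3, 4]]),
    "chosen_labels": torch.tensor([[-100, 2, 3, 4]]),  # prompt token masked with -100
    "rejected_input_ids": torch.tensor([[1, 5]]),
    "rejected_labels": torch.tensor([[-100, 5]]),
}
chosen_logps, rejected_logps = concatenated_logps(toy_model, batch)
```

Because the toy model assigns probability 1/8 to every token, both sequences get the same average log probability (-log 8) even though the chosen response has three scored tokens and the rejected one has only one; this independence from response length is what the average_log_prob=True path provides.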

Import

from simpo_trainer import SimPOTrainer
from simpo_config import SimPOConfig

I/O Contract

Inputs

Name | Type | Required | Description
model | str or PreTrainedModel | Yes | Model name (lazy-loaded) or pre-instantiated model
args | SimPOConfig | Yes | Training config with beta, gamma_beta_ratio, sft_weight, loss_type
train_dataset | Dataset | Yes | Dataset with prompt, chosen, rejected string columns
eval_dataset | Dataset | No | Evaluation dataset (same format)
tokenizer | PreTrainedTokenizerBase | Yes | Tokenizer for data collation and tokenization
peft_config | Dict | No | LoRA config for PEFT wrapping (from get_peft_config)

Outputs

Name | Type | Description
train() return value | TrainOutput | Contains global_step, training_loss, and a metrics dict
Training metrics | Dict | rewards/chosen, rewards/rejected, rewards/accuracies, rewards/margins, logps/chosen, logps/rejected, logits/chosen, logits/rejected
Model weights | Updated in place | Model parameters optimized with the SimPO loss
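The rewards/* entries in the training metrics are batch summaries of the implicit rewards returned by simpo_loss. A minimal sketch, assuming a per-batch mean and a strict "chosen > rejected" test for accuracy (so ties count as incorrect); the helper name reward_metrics is hypothetical, the real trainer may also gather values across devices, and the logps/* and logits/* entries are left out here.

```python
import torch


def reward_metrics(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> dict:
    """Hypothetical helper: summarize per-example implicit rewards for logging."""
    return {
        "rewards/chosen": chosen_rewards.mean().item(),
        "rewards/rejected": rejected_rewards.mean().item(),
        # Fraction of pairs where the chosen response receives the strictly higher reward.
        "rewards/accuracies": (chosen_rewards > rejected_rewards).float().mean().item(),
        "rewards/margins": (chosen_rewards - rejected_rewards).mean().item(),
    }


# Usage: the first pair is a tie, the second is ranked correctly.
metrics = reward_metrics(torch.tensor([-2.0, -1.0]), torch.tensor([-2.0, -4.0]))
```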

Usage Examples

Complete SimPO Training

from simpo_trainer import SimPOTrainer
from simpo_config import SimPOConfig
from alignment import (
    H4ArgumentParser, ModelArguments, DataArguments,
    get_tokenizer, get_peft_config, get_quantization_config,
    get_kbit_device_map, get_datasets,
)
import torch

# 1. Parse configuration
parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
model_args, data_args, training_args = parser.parse()

# 2. Load datasets
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits,
    columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"])

# 3. Prepare tokenizer and model kwargs
data_args.truncation_side = "left"
tokenizer = get_tokenizer(model_args, data_args)

quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=getattr(torch, model_args.torch_dtype) if model_args.torch_dtype not in ["auto", None] else model_args.torch_dtype,
    device_map=get_kbit_device_map() if quantization_config else None,
    quantization_config=quantization_config,
    attn_implementation=model_args.attn_implementation,
)
training_args.model_init_kwargs = model_kwargs

# 4. Apply chat template (see Apply_Chat_Template implementation)
# ... (omitted for brevity)

# 5. Instantiate trainer
trainer = SimPOTrainer(
    model=model_args.model_name_or_path,  # String triggers lazy loading
    args=training_args,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_args),
)

# 6. Train
train_result = trainer.train()
print(f"Training loss: {train_result.metrics['train_loss']}")

# 7. Save model
trainer.save_model(training_args.output_dir)

Key SimPO Hyperparameters

# SimPO-specific parameters (set in the YAML config or as SimPOConfig fields):
# beta: 2.0               - Reward scaling factor (β) applied to the average log probability
# gamma_beta_ratio: 0.5   - Target reward margin divided by beta (γ/β); here γ = 1.0
# sft_weight: 0.0         - Weight of the SFT regularization term (0 = pure SimPO)
# loss_type: "sigmoid"    - Loss variant ("sigmoid" or "hinge")
# max_length: 2048        - Maximum total sequence length (prompt + response)
# max_prompt_length: 1800 - Maximum prompt length; longer prompts are truncated to this
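To illustrate how max_length and max_prompt_length interact, here is a sketch of prompt-first truncation for one tokenized example. It is modeled on the truncation step in TRL's DPOTrainer.tokenize_row, which SimPO's tokenize_row is assumed to follow; the function name truncate_pair and the keep_end flag (standing in for keeping the end of the prompt, in line with the left-truncation setting used in the training example) are hypothetical.

```python
from typing import List, Tuple


def truncate_pair(
    prompt_ids: List[int],
    chosen_ids: List[int],
    rejected_ids: List[int],
    max_length: int = 2048,
    max_prompt_length: int = 1800,
    keep_end: bool = True,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical prompt-first truncation for one tokenized preference example."""
    longer_response = max(len(chosen_ids), len(rejected_ids))
    # Step 1: if prompt + longest response is too long, shorten the prompt first.
    if len(prompt_ids) + longer_response > max_length:
        if keep_end:  # drop tokens from the start, keeping the end of the prompt
            prompt_ids = prompt_ids[max(0, len(prompt_ids) - max_prompt_length):]
        else:  # keep the start of the prompt
            prompt_ids = prompt_ids[:max_prompt_length]
    # Step 2: if it is still too long, cut both responses to the remaining budget.
    if len(prompt_ids) + longer_response > max_length:
        budget = max_length - max_prompt_length
        chosen_ids, rejected_ids = chosen_ids[:budget], rejected_ids[:budget]
    return prompt_ids, chosen_ids, rejected_ids


# Usage: a 2,000-token prompt with responses of 300 and 100 tokens.
prompt, chosen, rejected = truncate_pair(list(range(2000)), list(range(300)), list(range(100)))
```

With these defaults the prompt is cut to its last 1,800 tokens, and the responses are then capped at 2048 - 1800 = 248 tokens each, so the 300-token response loses its tail while the 100-token one is untouched.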

