Implementation:Princeton nlp SimPO SimPOTrainer
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, NLP, Preference_Optimization |
| Last Updated | 2026-02-08 04:30 GMT |
Overview
Concrete tool for training language models with the SimPO preference optimization algorithm, extending HuggingFace's Trainer class.
Description
SimPOTrainer is the core training class that implements the SimPO algorithm. It extends transformers.Trainer and overrides the loss computation to use SimPO's length-normalized preference loss. The trainer handles model loading (from string name via AutoModelForCausalLM), PEFT/LoRA wrapping, dataset tokenization, preference data collation, training loop execution, metric logging, and evaluation. Key internal methods include:
- simpo_loss — Computes the SimPO loss (sigmoid or hinge) from chosen/rejected log probabilities
- get_batch_logps — Computes length-normalized average log probabilities per sequence
- concatenated_forward — Runs a single forward pass on concatenated chosen+rejected inputs for FSDP efficiency
- tokenize_row — Tokenizes a single preference example with proper BOS/EOS handling and truncation
- get_batch_loss_metrics — Orchestrates forward pass, loss computation, and metric collection
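The loss computed by simpo_loss can be illustrated for a single preference pair. The following is a minimal, framework-free sketch, not the repository's code. It assumes the formulation from the SimPO paper: each response's implicit reward is beta times its length-normalized log probability, and the loss is -log sigmoid(beta * (avg_logp_chosen - avg_logp_rejected) - gamma) with gamma = gamma_beta_ratio * beta. The hinge variant is assumed to be max(0, 1 - beta * (margin - gamma_beta_ratio)). The function name simpo_pair_loss is hypothetical, and the real method operates on batched torch tensors.

```python
import math


def simpo_pair_loss(
    chosen_avg_logp: float,
    rejected_avg_logp: float,
    beta: float = 2.0,
    gamma_beta_ratio: float = 0.5,
    loss_type: str = "sigmoid",
) -> tuple[float, float, float]:
    """Hypothetical scalar sketch of the SimPO loss for one preference pair.

    Inputs are length-normalized (average per-token) log probabilities.
    Returns (loss, chosen_reward, rejected_reward).
    """
    # Difference of average log probs, shifted by the margin term gamma / beta.
    logits = (chosen_avg_logp - rejected_avg_logp) - gamma_beta_ratio
    x = beta * logits
    if loss_type == "sigmoid":
        # -log(sigmoid(x)) == softplus(-x), written in a numerically stable form.
        loss = max(-x, 0.0) + math.log1p(math.exp(-abs(x)))
    elif loss_type == "hinge":
        loss = max(0.0, 1.0 - x)
    else:
        raise ValueError(f"unknown loss_type: {loss_type!r}")
    # Implicit rewards: beta times the length-normalized log probability.
    return loss, beta * chosen_avg_logp, beta * rejected_avg_logp


loss, chosen_r, rejected_r = simpo_pair_loss(-1.0, -1.8)
hinge_loss, _, _ = simpo_pair_loss(-1.0, -1.8, loss_type="hinge")
```

Because the inputs are averages rather than sums, a long response is not rewarded or penalized merely for having more tokens; the fixed offset gamma_beta_ratio then demands that the chosen response win by a margin.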
Usage
Instantiate the trainer after all preparation steps (config parsing, data loading, tokenizer setup, and model kwargs). When model is a name string, the trainer loads the model itself, using the kwargs stored in args.model_init_kwargs. Call trainer.train() to begin training and trainer.save_model() to save the result.
Code Reference
Source Location
- Repository: SimPO
- File: scripts/simpo_trainer.py (Lines 46-893)
Signature
```python
# Signatures only; method bodies are omitted.
class SimPOTrainer(Trainer):
    def __init__(
        self,
        model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: Optional[SimPOConfig] = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable] = None,
        peft_config: Optional[Dict] = None,
        compute_metrics: Optional[Callable] = None,
    ):
        """
        Args:
            model: Model name string (loaded by the trainer via AutoModelForCausalLM)
                or pre-instantiated model.
            args: SimPOConfig with training hyperparameters.
            data_collator: Custom collator (default: DPODataCollatorWithPadding).
            train_dataset: Dataset with 'prompt', 'chosen', 'rejected' columns.
            eval_dataset: Evaluation dataset.
            tokenizer: Tokenizer for data processing.
            peft_config: LoRA configuration for PEFT wrapping.
        """

    def simpo_loss(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute SimPO loss. Returns (losses, chosen_rewards, rejected_rewards)."""

    @staticmethod
    def get_batch_logps(
        logits: torch.FloatTensor,
        labels: torch.LongTensor,
        average_log_prob: bool = True,
        label_pad_token_id: int = -100,
        is_encoder_decoder: bool = False,
    ) -> torch.FloatTensor:
        """Compute per-sequence average log probabilities."""

    def concatenated_forward(
        self, model: nn.Module, batch: Dict
    ) -> Tuple[torch.FloatTensor, ...]:
        """Single forward pass on concatenated chosen+rejected inputs."""

    def tokenize_row(
        self, feature: Dict, model: Optional[PreTrainedModel] = None
    ) -> Dict:
        """Tokenize a single preference example with BOS/EOS handling."""

    def get_batch_loss_metrics(
        self, model, batch: Dict,
        train_eval: Literal["train", "eval"] = "train",
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute loss and all training metrics for a batch."""

    def compute_loss(
        self, model, inputs: Dict, return_outputs: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict]]:
        """Override Trainer.compute_loss with SimPO loss."""

    def train(
        self, resume_from_checkpoint: Optional[str] = None
    ) -> TrainOutput:
        """Run the full training loop (inherited from Trainer)."""
```
Import
```python
# Both modules live in scripts/, so that directory must be on the import path.
from simpo_trainer import SimPOTrainer
from simpo_config import SimPOConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | str or PreTrainedModel | Yes | Model name (loaded by the trainer via AutoModelForCausalLM) or pre-instantiated model |
| args | SimPOConfig | Yes | Training config with beta, gamma_beta_ratio, sft_weight, loss_type |
| train_dataset | Dataset | Yes | Dataset with prompt, chosen, rejected string columns |
| eval_dataset | Dataset | No | Evaluation dataset (same format) |
| tokenizer | PreTrainedTokenizerBase | Yes | Tokenizer for data collation and tokenization |
| peft_config | Dict | No | LoRA config for PEFT wrapping (from get_peft_config) |
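The methods tokenize_row and concatenated_forward turn these prompt/chosen/rejected columns into a single model batch. The sketch below shows that data flow with toy token IDs, assuming that prompt positions are masked with -100 in the labels and that the chosen and rejected sequences are right-padded to a common length and stacked with chosen rows first. The helper names and the pad ID 0 are hypothetical, and BOS/EOS handling and truncation are omitted.

```python
from typing import Dict, List

LABEL_PAD = -100  # matches the label_pad_token_id default in the signature
PAD_ID = 0        # hypothetical padding token ID


def build_example(prompt_ids: List[int], response_ids: List[int]) -> Dict[str, List[int]]:
    """Concatenate prompt and response; mask prompt positions out of the labels."""
    input_ids = prompt_ids + response_ids
    labels = [LABEL_PAD] * len(prompt_ids) + response_ids
    return {"input_ids": input_ids, "labels": labels}


def concatenate_pair(
    chosen: Dict[str, List[int]], rejected: Dict[str, List[int]]
) -> Dict[str, List[List[int]]]:
    """Right-pad both sequences to one length and stack chosen, then rejected."""
    max_len = max(len(chosen["input_ids"]), len(rejected["input_ids"]))
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in (chosen, rejected):
        pad = max_len - len(ex["input_ids"])
        batch["input_ids"].append(ex["input_ids"] + [PAD_ID] * pad)
        batch["labels"].append(ex["labels"] + [LABEL_PAD] * pad)
        batch["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * pad)
    return batch


chosen = build_example([5, 6], [7, 8, 9])
rejected = build_example([5, 6], [10])
batch = concatenate_pair(chosen, rejected)
```

Stacking both halves lets one forward pass produce log probabilities for chosen and rejected responses; the first half of the outputs is then read as chosen and the second half as rejected.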
Outputs
| Name | Type | Description |
|---|---|---|
| train() returns | TrainOutput | Contains global_step, training_loss, metrics dict |
| Training metrics | Dict | rewards/chosen, rewards/rejected, rewards/accuracies, rewards/margins, logps/chosen, logps/rejected, logits/chosen, logits/rejected |
| Trained model | PreTrainedModel (PeftModel when peft_config is set) | trainer.model, whose parameters are updated in place by the SimPO loss; persist with trainer.save_model() |
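The reward metrics listed above are batch averages. A minimal plain-Python sketch, assuming (consistent with the simpo_loss docstring) that rewards are beta times the length-normalized log probabilities, and that rewards/accuracies is the fraction of pairs whose chosen reward exceeds the rejected one. The function name reward_metrics is hypothetical; the real trainer computes these from torch tensors and also reports the mean logits.

```python
from statistics import mean
from typing import Dict, List


def reward_metrics(
    chosen_logps: List[float], rejected_logps: List[float], beta: float = 2.0
) -> Dict[str, float]:
    """Hypothetical sketch of the rewards/* and logps/* metrics for one batch."""
    chosen_rewards = [beta * lp for lp in chosen_logps]
    rejected_rewards = [beta * lp for lp in rejected_logps]
    pairs = list(zip(chosen_rewards, rejected_rewards))
    return {
        "rewards/chosen": mean(chosen_rewards),
        "rewards/rejected": mean(rejected_rewards),
        # Fraction of pairs ranked correctly by the implicit reward.
        "rewards/accuracies": mean(float(c > r) for c, r in pairs),
        "rewards/margins": mean(c - r for c, r in pairs),
        "logps/chosen": mean(chosen_logps),
        "logps/rejected": mean(rejected_logps),
    }


metrics = reward_metrics([-1.0, -2.0], [-1.5, -1.0])
```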
Usage Examples
Complete SimPO Training
```python
from simpo_trainer import SimPOTrainer
from simpo_config import SimPOConfig
from alignment import (
    H4ArgumentParser, ModelArguments, DataArguments,
    get_tokenizer, get_peft_config, get_quantization_config,
    get_kbit_device_map, get_datasets,
)
import torch

# 1. Parse configuration
parser = H4ArgumentParser((ModelArguments, DataArguments, SimPOConfig))
model_args, data_args, training_args = parser.parse()

# 2. Load datasets
raw_datasets = get_datasets(
    data_args,
    splits=data_args.dataset_splits,
    columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
)

# 3. Prepare tokenizer and model kwargs
data_args.truncation_side = "left"  # over-long inputs keep their final tokens
tokenizer = get_tokenizer(model_args, data_args)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=(
        getattr(torch, model_args.torch_dtype)
        if model_args.torch_dtype not in ["auto", None]
        else model_args.torch_dtype
    ),
    device_map=get_kbit_device_map() if quantization_config else None,
    quantization_config=quantization_config,
    attn_implementation=model_args.attn_implementation,
)
training_args.model_init_kwargs = model_kwargs  # used when the trainer loads the model by name

# 4. Apply chat template (see Apply_Chat_Template implementation)
# ... (omitted for brevity; must yield "prompt", "chosen", "rejected" text columns)

# 5. Instantiate trainer
trainer = SimPOTrainer(
    model=model_args.model_name_or_path,  # a string makes the trainer load the model itself
    args=training_args,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    peft_config=get_peft_config(model_args),
)

# 6. Train
train_result = trainer.train()
print(f"Training loss: {train_result.metrics['train_loss']}")

# 7. Save model
trainer.save_model(training_args.output_dir)
```
Key SimPO Hyperparameters
```yaml
# SimPO-specific parameters (set in the YAML config or on SimPOConfig)
beta: 2.0                 # reward scale applied to the length-normalized log probability
gamma_beta_ratio: 0.5     # target reward margin divided by beta (gamma = 0.5 * 2.0 = 1.0)
sft_weight: 0.0           # weight of the SFT regularization term (0 = pure SimPO)
loss_type: sigmoid        # loss variant: "sigmoid" or "hinge"
max_length: 2048          # maximum total sequence length (prompt + response)
max_prompt_length: 1800   # maximum prompt length; longer prompts are truncated
```
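To make the interaction between beta and gamma_beta_ratio concrete, the sketch below derives the implied target margin from the values above and evaluates the sigmoid loss for one hypothetical pair of average log probabilities (-1.0 for chosen, -1.8 for rejected; illustrative numbers, not measured results). It also computes the response budget implied by the length limits, assuming a TRL-style truncation rule in which an example that still exceeds max_length after its prompt is cut to max_prompt_length has its responses cut to max_length - max_prompt_length tokens.

```python
import math

beta = 2.0
gamma_beta_ratio = 0.5
max_length = 2048
max_prompt_length = 1800

# Target reward margin implied by the ratio.
gamma = gamma_beta_ratio * beta

# Hypothetical average per-token log probabilities for one preference pair.
chosen_avg_logp, rejected_avg_logp = -1.0, -1.8
reward_margin = beta * (chosen_avg_logp - rejected_avg_logp)

# Sigmoid SimPO loss: -log sigmoid(reward_margin - gamma) == log1p(exp(-(margin - gamma))).
loss = math.log1p(math.exp(-(reward_margin - gamma)))

# Tokens left for a truncated response (assumed TRL-style truncation rule).
response_budget = max_length - max_prompt_length

print(f"gamma={gamma}, margin={reward_margin:.2f}, loss={loss:.4f}, response_budget={response_budget}")
```

Here the reward margin (1.6) clears gamma (1.0) by 0.6, giving a loss of about 0.4375; raising gamma_beta_ratio would shrink that slack and increase the loss for the same pair.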