Implementation:Eric mitchell Direct preference optimization BasicTrainer Train
| Knowledge Sources | Details |
|---|---|
| Domains | Training, Optimization, Deep_Learning |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Concrete tool, provided by the direct-preference-optimization repository, for executing the SFT or DPO training loop with interleaved evaluation.
Description
The BasicTrainer.train method is the main training loop. It initializes the optimizer and learning-rate scheduler, then iterates over the training data, accumulating gradients across microbatches, clipping the gradient norm, and taking optimizer steps. Evaluation is interleaved at configurable intervals, with optional text sample generation, and a checkpoint is saved at each evaluation point.
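The shape of this loop can be summarized in a minimal sketch. This is an illustrative outline under assumptions, not the repository's exact code: split_into_microbatches and evaluate_and_save are hypothetical helpers, and the optimizer/scheduler construction only mirrors the defaults noted in the code reference below (a configurable optimizer defaulting to RMSprop, with linear warmup).

import torch

def train_sketch(trainer):
    # Hedged outline of BasicTrainer.train; helper names are hypothetical.
    opt = torch.optim.RMSprop(trainer.policy.parameters(), lr=trainer.config.lr)
    warmup = trainer.config.warmup_steps
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda step: min(1.0, (step + 1) / (warmup + 1)))  # linear warmup

    for step, batch in enumerate(trainer.train_iterator):
        if step % trainer.config.eval_every == 0:
            evaluate_and_save(trainer)  # hypothetical: eval pass plus checkpointing

        # Gradient accumulation: each batch is split into microbatches.
        for microbatch in split_into_microbatches(batch, trainer.config.gradient_accumulation_steps):
            loss, metrics = trainer.get_batch_metrics(microbatch, trainer.config.loss, train=True)
            (loss / trainer.config.gradient_accumulation_steps).backward()

        grad_norm = trainer.clip_gradient()  # clips to config.max_grad_norm
        opt.step()
        sched.step()
        opt.zero_grad()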
Usage
Called after constructing a BasicTrainer (or subclass) instance with a loaded model, config, and data iterators. This is the final step in both SFT and DPO training workflows.
Code Reference
Source Location
- Repository: direct-preference-optimization
- File: trainers.py
- Lines: 272-394
Signature
class BasicTrainer(object):
def train(self):
"""Begin either SFT or DPO training, with periodic evaluation."""
# Initializes optimizer (configurable, default RMSprop)
# Initializes LR scheduler (linear warmup)
# Iterates over self.train_iterator
# Calls self.get_batch_metrics for loss computation
# Gradient accumulation over config.gradient_accumulation_steps
# Gradient clipping via self.clip_gradient()
# Periodic eval via self.get_batch_metrics(train=False)
# Optional sample generation via self.get_batch_samples()
# Checkpointing via self.save()
# Logging to wandb
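The periodic evaluation branch referenced in the comments above can be pictured as a small helper. The sketch below is a hedged approximation: trainer.eval_batches, the sample_during_eval flag, and the metric aggregation are assumptions, while get_batch_metrics, get_batch_samples, and save are the methods listed above.

import torch
import wandb

def periodic_eval_sketch(trainer, example_counter):
    # Hedged outline of the evaluation/checkpoint branch inside train().
    trainer.policy.eval()
    all_eval_metrics = []
    with torch.no_grad():
        for eval_batch in trainer.eval_batches:  # assumed attribute holding eval batches
            _, metrics = trainer.get_batch_metrics(eval_batch, trainer.config.loss, train=False)
            all_eval_metrics.append(metrics)

    if getattr(trainer.config, 'sample_during_eval', False) and all_eval_metrics:
        samples = trainer.get_batch_samples(trainer.eval_batches[0])  # optional text samples

    if all_eval_metrics:
        # Average each metric over the eval pass (aggregation simplified here).
        mean_eval_metrics = {}
        for key in all_eval_metrics[0]:
            values = [v for m in all_eval_metrics
                      for v in (m[key] if isinstance(m[key], list) else [m[key]])]
            mean_eval_metrics[key] = sum(values) / len(values)
        wandb.log(mean_eval_metrics, step=example_counter)

    trainer.save()  # writes the checkpoint files described in the I/O contract below
    trainer.policy.train()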
Import
from trainers import BasicTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| self.train_iterator | Iterator[Dict] | Yes | Training batch iterator from get_batch_iterator |
| self.policy | nn.Module | Yes | Trainable policy model |
| self.reference_model | Optional[nn.Module] | DPO only | Frozen reference model (None for SFT) |
| self.config | DictConfig | Yes | Training configuration with optimizer, lr, warmup_steps, gradient_accumulation_steps, max_grad_norm, eval_every, etc. |
Outputs
| Name | Type | Description |
|---|---|---|
| self.policy (modified) | nn.Module | Model with updated weights after training |
| wandb logs | Dict | Training metrics: loss/train, logps_train/chosen, examples_per_second, grad_norm; eval metrics: loss/eval, reward accuracies and margins |
| checkpoints | Files | policy.pt, optimizer.pt, scheduler.pt saved at each eval point |
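The checkpoint files named above would typically be written with torch.save. The following is a sketch of one plausible layout, assuming the trainer exposes optimizer and scheduler attributes; it is not the repository's actual save implementation.

import os
import torch

def save_sketch(trainer, output_dir, step):
    # Hypothetical checkpoint writer mirroring the file names listed above.
    os.makedirs(output_dir, exist_ok=True)
    torch.save({'step': step, 'state': trainer.policy.state_dict()},
               os.path.join(output_dir, 'policy.pt'))
    torch.save({'step': step, 'state': trainer.optimizer.state_dict()},
               os.path.join(output_dir, 'optimizer.pt'))
    torch.save({'step': step, 'state': trainer.scheduler.state_dict()},
               os.path.join(output_dir, 'scheduler.pt'))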
Usage Examples
Running SFT Training
from trainers import BasicTrainer
# After model loading and config setup
trainer = BasicTrainer(
policy=policy,
config=config,
seed=config.seed,
run_dir=config.local_run_dir,
reference_model=None, # None for SFT
rank=0,
world_size=1,
)
trainer.train() # Runs full training loop
trainer.save() # Save final checkpoint
Running DPO Training
trainer = BasicTrainer(
policy=policy,
config=config,
seed=config.seed,
run_dir=config.local_run_dir,
reference_model=reference_model, # Frozen reference for DPO
rank=0,
world_size=1,
)
trainer.train()
trainer.save()
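For DPO, the reference model passed to the trainer is expected to stay frozen for the whole run. A minimal way to prepare it from the SFT policy is sketched below; this is an assumption about common practice, not code taken from the repository.

import copy

# Freeze a copy of the SFT policy to act as the DPO reference model (hypothetical setup).
reference_model = copy.deepcopy(policy)
reference_model.eval()  # eval mode disables dropout, keeping reference log-probs deterministic
for param in reference_model.parameters():
    param.requires_grad_(False)

Disabling gradients on the reference model also avoids allocating gradient buffers for it, which matters when both models share a GPU.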
Related Pages
Implements Principle
Requires Environment
Uses Heuristic
- Heuristic:Eric_mitchell_Direct_preference_optimization_TF32_Matmul_Precision
- Heuristic:Eric_mitchell_Direct_preference_optimization_FSDP_Mixed_Precision_BFloat16
- Heuristic:Eric_mitchell_Direct_preference_optimization_Activation_Checkpointing_Memory
- Heuristic:Eric_mitchell_Direct_preference_optimization_RMSprop_Over_Adam
- Heuristic:Eric_mitchell_Direct_preference_optimization_FSDP_Batch_Size_Per_GPU