
Implementation:Eric Mitchell Direct preference optimization BasicTrainer Train



Knowledge Sources
Domains: Training, Optimization, Deep_Learning
Last Updated: 2026-02-08 02:00 GMT

Overview

A concrete tool for executing the SFT or DPO training loop with interleaved evaluation, provided by the direct-preference-optimization repository.

Description

The BasicTrainer.train method is the main training loop. It initializes the optimizer and scheduler, then iterates over the training data, computing loss with gradient accumulation, clipping gradients, and performing optimizer steps. Evaluation is interleaved at configurable intervals, with optional text sample generation. Checkpoints are saved at each evaluation point.
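
To make the control flow concrete, here is a minimal sketch of the loop structure described above. It is an illustration under stated assumptions, not the repository's actual code: the helper names (get_batch_metrics, clip_gradient, save) mirror the Signature section below, but their exact call signatures are simplified here.

import torch

def train(self):
    # Build optimizer from config (default RMSprop) and a linear-warmup LR scheduler.
    optimizer = getattr(torch.optim, self.config.optimizer)(
        self.policy.parameters(), lr=self.config.lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: min(1.0, (step + 1) / (self.config.warmup_steps + 1)))

    for step, batch in enumerate(self.train_iterator):
        if step % self.config.eval_every == 0:
            # Run evaluation (get_batch_metrics with train=False, optional
            # get_batch_samples) and write a checkpoint.
            self.save()

        # Scale the loss so gradients average over accumulation microbatches.
        loss, metrics = self.get_batch_metrics(batch, train=True)
        (loss / self.config.gradient_accumulation_steps).backward()

        if (step + 1) % self.config.gradient_accumulation_steps == 0:
            grad_norm = self.clip_gradient()  # logged to wandb as grad_norm
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()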

Usage

Called after constructing a BasicTrainer (or subclass) instance with a loaded model, config, and data iterators. This is the final step in both SFT and DPO training workflows.

Code Reference

Source Location

trainers.py (module name follows from the import below)

Signature

class BasicTrainer(object):
    def train(self):
        """Begin either SFT or DPO training, with periodic evaluation."""
        # Initializes optimizer (configurable, default RMSprop)
        # Initializes LR scheduler (linear warmup)
        # Iterates over self.train_iterator
        # Calls self.get_batch_metrics for loss computation
        # Gradient accumulation over config.gradient_accumulation_steps
        # Gradient clipping via self.clip_gradient()
        # Periodic eval via self.get_batch_metrics(train=False)
        # Optional sample generation via self.get_batch_samples()
        # Checkpointing via self.save()
        # Logging to wandb
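
The clipping step referenced above is, in most PyTorch trainers, a thin wrapper over torch.nn.utils.clip_grad_norm_. A minimal sketch, assuming max_grad_norm is read from the config:

import torch

def clip_gradient(self):
    # Clip the global gradient norm of the policy's parameters and return
    # the pre-clip norm so it can be logged as grad_norm.
    return torch.nn.utils.clip_grad_norm_(
        self.policy.parameters(), self.config.max_grad_norm).item()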

Import

from trainers import BasicTrainer

I/O Contract

Inputs

Name | Type | Required | Description
self.train_iterator | Iterator[Dict] | Yes | Training batch iterator from get_batch_iterator
self.policy | nn.Module | Yes | Trainable policy model
self.reference_model | Optional[nn.Module] | DPO only | Frozen reference model (None for SFT)
self.config | DictConfig | Yes | Training configuration: optimizer, lr, warmup_steps, gradient_accumulation_steps, max_grad_norm, eval_every, etc.
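
Since the config is an omegaconf DictConfig, the fields listed above can be sketched as follows. The values shown are placeholders for illustration, not the repository's defaults:

from omegaconf import OmegaConf

config = OmegaConf.create({
    "optimizer": "RMSprop",            # any torch.optim class name
    "lr": 5e-7,                        # placeholder learning rate
    "warmup_steps": 150,               # linear LR warmup horizon
    "gradient_accumulation_steps": 1,  # microbatches per optimizer step
    "max_grad_norm": 10.0,             # gradient clipping threshold
    "eval_every": 20000,               # steps between eval/checkpoint
})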

Outputs

Name | Type | Description
self.policy (modified) | nn.Module | Model with updated weights after training
wandb logs | Dict | Train metrics (loss/train, logps_train/chosen, examples_per_second, grad_norm); eval metrics (loss/eval, reward accuracies/margins)
checkpoints | Files | policy.pt, optimizer.pt, scheduler.pt, saved at each evaluation point
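
A saved checkpoint can be reloaded with torch.load. A hedged sketch: the state dict is assumed here to live under a "state" key, which is a guess about the file layout rather than a documented contract:

import torch
import torch.nn as nn

def load_policy_checkpoint(policy: nn.Module, path: str = "policy.pt") -> nn.Module:
    # The "state" key is an assumption about the checkpoint layout.
    checkpoint = torch.load(path, map_location="cpu")
    policy.load_state_dict(checkpoint["state"])
    return policy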

Usage Examples

Running SFT Training

from trainers import BasicTrainer

# After model loading and config setup
trainer = BasicTrainer(
    policy=policy,
    config=config,
    seed=config.seed,
    run_dir=config.local_run_dir,
    reference_model=None,  # None for SFT
    rank=0,
    world_size=1,
)

trainer.train()  # Runs full training loop
trainer.save()   # Save final checkpoint

Running DPO Training

trainer = BasicTrainer(
    policy=policy,
    config=config,
    seed=config.seed,
    run_dir=config.local_run_dir,
    reference_model=reference_model,  # Frozen reference for DPO
    rank=0,
    world_size=1,
)

trainer.train()
trainer.save()
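
For DPO, the reference model is typically the SFT checkpoint with gradient tracking disabled, so it stays frozen while the policy trains. A minimal sketch of the freezing step (how the weights are loaded is omitted):

import torch.nn as nn

def freeze_reference(reference_model: nn.Module) -> nn.Module:
    # Switch to eval mode (disables dropout) and stop gradient tracking
    # so the reference log-probabilities stay fixed during DPO training.
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad_(False)
    return reference_model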

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
