Implementation:ContextualAI HALOs Alignment Trainers

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, NLP, Reinforcement_Learning
Last Updated 2026-02-08 03:00 GMT

Overview

A concrete tool for preference-based alignment training, provided by the trainer classes in the HALOs repository.

Description

The train/trainers.py module implements a hierarchy of trainer classes for different alignment methods. All trainers inherit from BasicTrainer which provides the core training loop, evaluation, saving, and memory management. The key trainer classes are:

  • PairedPreferenceTrainer — Base for methods using paired preferences (DPO, CDPO, IPO, SimPO, SLiC)
  • UnpairedPreferenceTrainer — Base for methods using binary feedback (KTO)
  • DPOTrainer — Direct Preference Optimization with sigmoid loss
  • KTOTrainer — Kahneman-Tversky Optimization with prospect-theoretic loss and KL estimation
  • GRPOTrainer — Group Relative Policy Optimization with clipped ratio and group-normalized advantages
  • PPOTrainer — Proximal Policy Optimization with value head, GAE, and reward model
  • CDPOTrainer, IPOTrainer, SimPOTrainer, SLiCTrainer — Additional paired preference methods

Each trainer overrides loss() and may optionally override the get_batch_metrics() and forward() methods.
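This override pattern can be sketched in miniature. The snippet below is a hypothetical, standalone illustration: the base class is a stub (the real subclasses inherit batching, distributed, and checkpointing logic from PairedPreferenceTrainer), and the hinge loss shown is a SLiC-style example, not the repository's exact formulation.

```python
class PairedPreferenceTrainerStub:
    """Stand-in for train.trainers.PairedPreferenceTrainer (sketch only)."""

    def __init__(self, beta: float = 0.1):
        self.beta = beta

    def loss(self, policy_chosen_logp: float, policy_rejected_logp: float) -> float:
        raise NotImplementedError  # each alignment method overrides this

class HingeTrainer(PairedPreferenceTrainerStub):
    """A SLiC-style hinge loss: penalize the policy when the chosen response
    is not preferred over the rejected one by a sufficient log-prob margin."""

    def loss(self, policy_chosen_logp: float, policy_rejected_logp: float) -> float:
        margin = self.beta * (policy_chosen_logp - policy_rejected_logp)
        return max(0.0, 1.0 - margin)

trainer = HingeTrainer(beta=0.1)
```

Only loss() differs between methods; everything else (the training loop, evaluation, saving) comes from the shared base class.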

Usage

Invoke via accelerate launch launch.py loss={method} model=llama datasets=[{data}] model.load_from={sft_checkpoint}. The loss config determines which trainer class and dataloader class are used.

Code Reference

Source Location

  • Repository: ContextualAI/HALOs
  • File: train/trainers.py
  • Lines: L59-549 (BasicTrainer), L584-651 (UnpairedPreferenceTrainer), L654-743 (PairedPreferenceTrainer), L746-764 (DPOTrainer), L851-1006 (KTOTrainer), L1009-1103 (GRPOTrainer), L1106-1538 (PPOTrainer)

Signature

class BasicTrainer(object):
    def __init__(self,
                 tokenizer: AutoTokenizer,
                 config: DictConfig,
                 train_iterator: dataloader.DataLoader,
                 eval_iterator: dataloader.DataLoader,
                 accelerator: Accelerator,
                 optimizer: torch.optim.Optimizer,
                 scheduler: torch.optim.lr_scheduler.LRScheduler,
                 policy: nn.Module,
                 reference_model: Optional[nn.Module] = None,
                 **kwargs):
        """A trainer for a language model, supporting SFT, HALO, or offline PPO training."""

    def train(self) -> None:
        """Begin training with periodic evaluation."""

class DPOTrainer(PairedPreferenceTrainer):
    def loss(self, batch, policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps, *args):
        """Compute DPO loss: -log_sigmoid(beta * (chosen_reward - rejected_reward))"""

class KTOTrainer(UnpairedPreferenceTrainer):
    def loss(self, batch, policy_chosen_logps, policy_rejected_logps,
             policy_KL_logps, reference_chosen_logps, reference_rejected_logps,
             reference_KL_logps, *args):
        """Compute KTO loss with prospect-theoretic desirable/undesirable weighting."""

class GRPOTrainer(BasicTrainer):
    def loss(self, batch, policy_logps, reference_logps, advantages, group_size):
        """Compute GRPO loss with clipped ratios and group-normalized advantages."""

class PPOTrainer(BasicTrainer):
    def train(self) -> None:
        """PPO training loop with reward model scoring and GAE advantage estimation."""

Import

from train.trainers import DPOTrainer, KTOTrainer, GRPOTrainer, PPOTrainer
# Or invoke via CLI:
# accelerate launch launch.py loss=dpo model=llama datasets=[ultrabin] model.load_from=/path/to/sft

I/O Contract

Inputs

Name Type Required Description
config DictConfig Yes Hydra config with loss={dpo,kto,grpo,ppo,...}, model, datasets
config.loss.beta float Yes KL penalty weight (default 0.1)
config.model.load_from str Yes Path to SFT checkpoint
config.model.use_peft bool No Whether to apply LoRA
config.humanline bool No Enable humanline per-token clamping
config.cache_reference_logprobs bool No Precompute reference logprobs to save GPU memory
train_dataset DataLoader Yes Appropriate DataLoader for the loss type
reference_model nn.Module / ReferenceModelWrapper Conditional Required for DPO, KTO, GRPO; not for SLiC, SimPO

Outputs

Name Type Description
Aligned model Directory Saved to {cache_dir}/{exp_name}/FINAL/ with LoRA merged if applicable
Training metrics Dict Loss, reward margins, accuracies, KL estimates, grad norm
WandB logs Remote Training curves if enabled

Usage Examples

DPO Alignment

accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    launch.py \
    loss=dpo \
    model=llama \
    datasets=[ultrabin] \
    exp_name=llama3-8B-dpo \
    ++model.load_from=/models/llama3-8B-sft/FINAL \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B \
    ++loss.beta=0.1

KTO Alignment with Humanline

accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    launch.py \
    loss=kto \
    model=llama \
    datasets=[ultrafeedback_armorm] \
    exp_name=llama3-8B-kto-humanline \
    ++model.load_from=/models/llama3-8B-sft/FINAL \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B-Instruct \
    ++humanline=true
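The prospect-theoretic weighting that KTOTrainer applies can be sketched per example. This follows the published KTO formulation rather than the repository's exact tensor code, and the hyperparameter names (lambda_d, lambda_u) are illustrative:

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def kto_loss(policy_logp: float, reference_logp: float, kl_estimate: float,
             desirable: bool, beta: float = 0.1,
             lambda_d: float = 1.0, lambda_u: float = 1.0) -> float:
    """Per-example KTO loss (simplified sketch): the value of the implicit
    reward r = log(pi/pi_ref), measured against the KL estimate as the
    reference point, with separate weights for desirable vs. undesirable."""
    reward = policy_logp - reference_logp
    if desirable:
        # Gains: reward above the KL reference point is valued positively.
        return lambda_d - lambda_d * sigmoid(beta * (reward - kl_estimate))
    # Losses: reward below the KL reference point is penalized.
    return lambda_u - lambda_u * sigmoid(beta * (kl_estimate - reward))
```

A desirable example whose reward already sits far above the KL reference contributes almost no loss, while one at the reference point contributes half the maximum.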

GRPO Alignment

accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    launch.py \
    loss=grpo \
    model=llama \
    datasets=[ultrafeedback_armorm] \
    exp_name=llama3-8B-grpo \
    ++model.load_from=/models/llama3-8B-sft/FINAL \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B-Instruct
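The group-normalized advantages and clipped ratios that GRPOTrainer uses can be illustrated with a small sketch. This is pure Python for readability (the repository operates on tensors), and the epsilon values are assumptions:

```python
from statistics import mean, pstdev

def group_normalized_advantages(rewards: list[float], eps: float = 1e-6) -> list[float]:
    """Normalize each sampled completion's reward within its own group:
    A_i = (r_i - mean(group)) / (std(group) + eps)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

def clipped_objective(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """PPO-style clipped surrogate applied with the group advantage:
    min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)."""
    clipped = max(1.0 - clip_eps, min(1.0 + clip_eps, ratio))
    return min(ratio * advantage, clipped * advantage)
```

Because advantages are normalized within each group of completions for the same prompt, no separate value network is needed, which is the main contrast with PPOTrainer's value head and GAE.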

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
