Implementation: ContextualAI HALOs Alignment Trainers
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, NLP, Reinforcement_Learning |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
Concrete tool for preference-based alignment training provided by the HALOs trainer classes.
Description
The train/trainers.py module implements a hierarchy of trainer classes for different alignment methods. All trainers inherit from BasicTrainer which provides the core training loop, evaluation, saving, and memory management. The key trainer classes are:
- PairedPreferenceTrainer — Base for methods using paired preferences (DPO, CDPO, IPO, SimPO, SLiC)
- UnpairedPreferenceTrainer — Base for methods using binary feedback (KTO)
- DPOTrainer — Direct Preference Optimization with sigmoid loss
- KTOTrainer — Kahneman-Tversky Optimization with prospect-theoretic loss and KL estimation
- GRPOTrainer — Group Relative Policy Optimization with clipped ratio and group-normalized advantages
- PPOTrainer — Proximal Policy Optimization with value head, GAE, and reward model
- CDPOTrainer, IPOTrainer, SimPOTrainer, SLiCTrainer — Additional paired preference methods
Each trainer overrides loss() and, optionally, the get_batch_metrics() and forward() methods.
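As a minimal sketch of the DPO objective named above (not the repository's exact code), the loss is the negative log-sigmoid of the implicit reward margin, where each reward is the beta-scaled log-ratio between the policy and the frozen reference model:

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             reference_chosen_logps: torch.Tensor,
             reference_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Sketch of the DPO sigmoid loss:
    -log_sigmoid(beta * (chosen_reward - rejected_reward)).

    Rewards are implicit log-ratios between policy and reference logprobs.
    """
    chosen_reward = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_reward = beta * (policy_rejected_logps - reference_rejected_logps)
    return -F.logsigmoid(chosen_reward - rejected_reward)
```

When the policy has no preference (zero margin), the loss is ln 2 ≈ 0.693; it decreases as the chosen reward pulls ahead of the rejected reward.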
Usage
Invoke via `accelerate launch launch.py loss={method} model=llama datasets=[{data}] model.load_from={sft_checkpoint}`. The loss config determines which trainer and dataloader classes are used.
Code Reference
Source Location
- Repository: ContextualAI/HALOs
- File: train/trainers.py
- Lines: L59-549 (BasicTrainer), L584-651 (UnpairedPreferenceTrainer), L654-743 (PairedPreferenceTrainer), L746-764 (DPOTrainer), L851-1006 (KTOTrainer), L1009-1103 (GRPOTrainer), L1106-1538 (PPOTrainer)
Signature
class BasicTrainer(object):
    def __init__(self,
                 tokenizer: AutoTokenizer,
                 config: DictConfig,
                 train_iterator: dataloader.DataLoader,
                 eval_iterator: dataloader.DataLoader,
                 accelerator: Accelerator,
                 optimizer: torch.optim.Optimizer,
                 scheduler: torch.optim.lr_scheduler.LRScheduler,
                 policy: nn.Module,
                 reference_model: Optional[nn.Module] = None,
                 **kwargs):
        """A trainer for a language model, supporting SFT, HALO, or offline PPO training."""

    def train(self) -> None:
        """Begin training with periodic evaluation."""

class DPOTrainer(PairedPreferenceTrainer):
    def loss(self, batch, policy_chosen_logps, policy_rejected_logps,
             reference_chosen_logps, reference_rejected_logps, *args):
        """Compute DPO loss: -log_sigmoid(beta * (chosen_reward - rejected_reward))"""

class KTOTrainer(UnpairedPreferenceTrainer):
    def loss(self, batch, policy_chosen_logps, policy_rejected_logps,
             policy_KL_logps, reference_chosen_logps, reference_rejected_logps,
             reference_KL_logps, *args):
        """Compute KTO loss with prospect-theoretic desirable/undesirable weighting."""

class GRPOTrainer(BasicTrainer):
    def loss(self, batch, policy_logps, reference_logps, advantages, group_size):
        """Compute GRPO loss with clipped ratios and group-normalized advantages."""

class PPOTrainer(BasicTrainer):
    def train(self) -> None:
        """PPO training loop with reward model scoring and GAE advantage estimation."""
Import
from train.trainers import DPOTrainer, KTOTrainer, GRPOTrainer, PPOTrainer
# Or invoke via CLI:
# accelerate launch launch.py loss=dpo model=llama datasets=[ultrabin] model.load_from=/path/to/sft
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | DictConfig | Yes | Hydra config with loss={dpo,kto,grpo,ppo,...}, model, datasets |
| config.loss.beta | float | No | KL penalty weight (default 0.1) |
| config.model.load_from | str | Yes | Path to SFT checkpoint |
| config.model.use_peft | bool | No | Whether to apply LoRA |
| config.humanline | bool | No | Enable humanline per-token clamping |
| config.cache_reference_logprobs | bool | No | Precompute reference logprobs to save GPU memory |
| train_dataset | DataLoader | Yes | Appropriate DataLoader for the loss type |
| reference_model | nn.Module / ReferenceModelWrapper | Conditional | Required for DPO, KTO, GRPO; not for SLiC, SimPO |
Outputs
| Name | Type | Description |
|---|---|---|
| Aligned model | Directory | Saved to {cache_dir}/{exp_name}/FINAL/ with LoRA merged if applicable |
| Training metrics | Dict | Loss, reward margins, accuracies, KL estimates, grad norm |
| WandB logs | Remote | Training curves if enabled |
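The PPOTrainer listed in the signature section relies on GAE for advantage estimation. A hedged, standalone sketch of generalized advantage estimation over one trajectory (not the trainer's actual implementation; gamma and lam values are illustrative):

```python
import torch

def gae_advantages(rewards: torch.Tensor,
                   values: torch.Tensor,
                   gamma: float = 0.99,
                   lam: float = 0.95) -> torch.Tensor:
    """Sketch of Generalized Advantage Estimation for a single trajectory.

    delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lam * A_{t+1}, with V(s_T) = 0.
    """
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_adv = 0.0
    for t in reversed(range(T)):  # backward pass accumulates the recursion
        next_value = values[t + 1] if t + 1 < T else torch.tensor(0.0)
        delta = rewards[t] + gamma * next_value - values[t]
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    return advantages
```

With lam=1 this reduces to Monte Carlo returns minus the value baseline; with lam=0 it reduces to the one-step TD residual.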
Usage Examples
DPO Alignment
accelerate launch \
--config_file accelerate_config/fsdp_4gpu.yaml \
launch.py \
loss=dpo \
model=llama \
datasets=[ultrabin] \
exp_name=llama3-8B-dpo \
++model.load_from=/models/llama3-8B-sft/FINAL \
++model.name_or_path=meta-llama/Meta-Llama-3-8B \
++loss.beta=0.1
KTO Alignment with Humanline
accelerate launch \
--config_file accelerate_config/fsdp_4gpu.yaml \
launch.py \
loss=kto \
model=llama \
datasets=[ultrafeedback_armorm] \
exp_name=llama3-8B-kto-humanline \
++model.load_from=/models/llama3-8B-sft/FINAL \
++model.name_or_path=meta-llama/Meta-Llama-3-8B-Instruct \
++humanline=true
GRPO Alignment
accelerate launch \
--config_file accelerate_config/fsdp_4gpu.yaml \
launch.py \
loss=grpo \
model=llama \
datasets=[ultrafeedback_armorm] \
exp_name=llama3-8B-grpo \
++model.load_from=/models/llama3-8B-sft/FINAL \
++model.name_or_path=meta-llama/Meta-Llama-3-8B-Instruct
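For the GRPO run above, the trainer's "group-normalized advantages with clipped ratios" can be sketched as follows. This is a generic illustration of the GRPO technique, not the repository's exact code; the epsilon values and function names are assumptions:

```python
import torch

def group_normalized_advantages(rewards: torch.Tensor,
                                group_size: int) -> torch.Tensor:
    """Normalize rewards within each group of completions sampled per prompt:
    A_i = (r_i - mean(group)) / (std(group) + eps)."""
    grouped = rewards.view(-1, group_size)
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    return ((grouped - mean) / (std + 1e-8)).view(-1)

def clipped_surrogate(policy_logps: torch.Tensor,
                      old_logps: torch.Tensor,
                      advantages: torch.Tensor,
                      eps: float = 0.2) -> torch.Tensor:
    """PPO-style clipped surrogate applied to the group-normalized advantages."""
    ratio = torch.exp(policy_logps - old_logps)
    return -torch.min(ratio * advantages,
                      torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages)
```

Normalizing within the group removes the need for a learned value baseline: each completion is scored only relative to its siblings sampled from the same prompt.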