Implementation:Hpcaitech ColossalAI KTOTrainer
| Knowledge Sources | |
|---|---|
| Domains | RLHF, Preference_Learning, KTO |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
kto.py implements the KTOTrainer class for Kahneman-Tversky Optimization, a preference learning algorithm that trains from binary (desirable/undesirable) feedback without requiring paired comparisons.
Description
KTOTrainer extends SLTrainer to implement the KTO algorithm. Each training batch contains input tokens, attention masks, loss masks, binary labels (desirable = 1, undesirable = 0), and a separate set of KL-estimation sequences (kl_input_ids, kl_attention_mask, kl_loss_mask). For each batch, the trainer computes sequence log-probabilities under both the actor model and a frozen reference model, splits the examples into chosen (desirable) and rejected (undesirable) sets by label, and computes KTOLoss, configured by beta, desirable_weight, and undesirable_weight.

The trainer supports gradient accumulation, periodic checkpointing via save_checkpoint, and logging to TensorBoard and Weights & Biases. Tracked metrics are the loss, chosen rewards, rejected rewards, and the reward margin. Metrics are synchronized across ranks with all_reduce_mean and averaged over accumulation steps with AccumulativeMeanMeter. The _eval method runs the same computation without gradients and writes the evaluation results to text files.
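The loss step described above can be sketched in plain Python. The sketch follows the KTO formulation (as used by common open-source implementations): a KL baseline z is estimated as the mean policy/reference log-ratio over the KL sequences and clamped at zero; each desirable example contributes desirable_weight * (1 - sigmoid(beta * (r - z))) and each undesirable example contributes undesirable_weight * (1 - sigmoid(beta * (z - r))), where r is that example's policy/reference log-ratio; the loss is the mean over all examples. The function kto_loss and its list-of-floats interface are hypothetical. coati's KTOLoss works on torch tensors and, in a distributed run, may also average the KL estimate across ranks, so treat this as an illustration rather than the library's code.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def kto_loss(
    chosen_logps: List[float],
    rejected_logps: List[float],
    kl_logps: List[float],
    ref_chosen_logps: List[float],
    ref_rejected_logps: List[float],
    ref_kl_logps: List[float],
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
) -> Tuple[float, List[float], List[float], float]:
    """Hypothetical single-process sketch of the KTO loss."""
    # KL baseline: mean policy/reference log-ratio on the KL sequences,
    # clamped at zero (no cross-rank averaging in this sketch).
    kl_ratios = [p - r for p, r in zip(kl_logps, ref_kl_logps)]
    kl = max(sum(kl_ratios) / len(kl_ratios), 0.0) if kl_ratios else 0.0

    losses: List[float] = []
    chosen_rewards: List[float] = []
    for p, r in zip(chosen_logps, ref_chosen_logps):
        ratio = p - r
        # Desirable examples are pushed to a log-ratio above the baseline.
        losses.append(desirable_weight * (1.0 - sigmoid(beta * (ratio - kl))))
        chosen_rewards.append(beta * ratio)

    rejected_rewards: List[float] = []
    for p, r in zip(rejected_logps, ref_rejected_logps):
        ratio = p - r
        # Undesirable examples are pushed to a log-ratio below the baseline.
        losses.append(undesirable_weight * (1.0 - sigmoid(beta * (kl - ratio))))
        rejected_rewards.append(beta * ratio)

    loss = sum(losses) / len(losses) if losses else 0.0
    return loss, chosen_rewards, rejected_rewards, kl
```

When the policy equals the reference, every log-ratio is zero and each example contributes 1 - sigmoid(0) = 0.5 (times its weight), which is a handy sanity check.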
Usage
Use this trainer when training a language model from binary feedback data where each example is labeled as desirable or undesirable, without requiring paired preference comparisons. It is suitable for scenarios where preference data consists of individual ratings rather than A-vs-B comparisons.
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/trainer/kto.py
- Lines: 1-355
Signature
```python
class KTOTrainer(SLTrainer):
    def __init__(
        self,
        actor: Any,
        ref_model: Any,
        booster: Booster,
        actor_optim: Optimizer,
        plugin: Plugin,
        actor_lr_scheduler: _LRScheduler,
        tokenizer: PreTrainedTokenizerBase,
        max_epochs: int = 1,
        beta: float = 0.1,
        desirable_weight: float = 1.0,
        undesirable_weight: float = 1.0,
        apply_loss_mask: bool = True,
        accumulation_steps: int = 1,
        start_epoch: int = 0,
        save_interval: int = 0,
        save_dir: str = None,
        coordinator: DistCoordinator = None,
    ) -> None: ...
```
Key Methods
```python
def _before_fit(
    self,
    train_preference_dataloader: DataLoader = None,
    eval_preference_dataloader: DataLoader = None,
    log_dir: Optional[str] = None,
    use_wandb: bool = False,
): ...

def _train(self, epoch: int): ...

def _eval(self, epoch: int): ...
```
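These three methods are driven by fit(), which KTOTrainer inherits from SLTrainer. The sketch below assumes fit() passes its arguments to _before_fit() and then runs one _train() and one _eval() pass per epoch, from start_epoch up to max_epochs. ToySLTrainer and RecordingTrainer are hypothetical classes written only to show that call order; they are not part of coati.

```python
from abc import ABC, abstractmethod
from typing import List


class ToySLTrainer(ABC):
    """Hypothetical stand-in for the SLTrainer fit() loop (illustration only)."""

    def __init__(self, max_epochs: int = 1, start_epoch: int = 0) -> None:
        self.max_epochs = max_epochs
        self.start_epoch = start_epoch

    @abstractmethod
    def _before_fit(self, *args, **kwargs) -> None: ...

    @abstractmethod
    def _train(self, epoch: int) -> None: ...

    @abstractmethod
    def _eval(self, epoch: int) -> None: ...

    def fit(self, *args, **kwargs) -> None:
        # Setup (dataloaders, loggers) happens once, then train/eval alternate.
        self._before_fit(*args, **kwargs)
        for epoch in range(self.start_epoch, self.max_epochs):
            self._train(epoch)
            self._eval(epoch)


class RecordingTrainer(ToySLTrainer):
    """Records the order in which the lifecycle methods are called."""

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.calls: List[str] = []

    def _before_fit(
        self,
        train_preference_dataloader=None,
        eval_preference_dataloader=None,
        log_dir=None,
        use_wandb=False,
    ) -> None:
        self.calls.append("before_fit")

    def _train(self, epoch: int) -> None:
        self.calls.append(f"train:{epoch}")

    def _eval(self, epoch: int) -> None:
        self.calls.append(f"eval:{epoch}")
```

Under this reading, `RecordingTrainer(max_epochs=2).fit(log_dir="./logs")` records setup once, followed by a train and an eval pass for epochs 0 and 1.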
Import
```python
from coati.trainer.kto import KTOTrainer
```
I/O Contract
Inputs (__init__)
| Name | Type | Required | Description |
|---|---|---|---|
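| actor | Any | Yes | The actor (policy) model to train |
| ref_model | Any | Yes | Frozen reference model; supplies the reference log-probabilities for the rewards and the KL estimate |
| booster | Booster | Yes | ColossalAI Booster for distributed training |
| actor_optim | Optimizer | Yes | Optimizer for the actor model |
| plugin | Plugin | Yes | ColossalAI plugin for parallelism strategy |
| actor_lr_scheduler | _LRScheduler | Yes | Learning-rate scheduler for the actor optimizer |
| tokenizer | PreTrainedTokenizerBase | Yes | Tokenizer used for encoding |
| max_epochs | int | No | Number of training epochs (default: 1) |
| beta | float | No | Scale applied to the policy/reference log-ratio in the KTO loss (default: 0.1) |
| desirable_weight | float | No | Weight for desirable (chosen) examples in the loss (default: 1.0) |
| undesirable_weight | float | No | Weight for undesirable (rejected) examples in the loss (default: 1.0) |
| apply_loss_mask | bool | No | If False, the loss masks are replaced with all-ones masks so every token contributes (default: True) |
| accumulation_steps | int | No | Gradient accumulation steps (default: 1) |
| start_epoch | int | No | Epoch to start from; non-zero when resuming from a checkpoint (default: 0) |
| save_interval | int | No | Checkpoint saving interval in training steps (default: 0, disabled) |
| save_dir | str | No | Directory for saved checkpoints (default: None) |
| coordinator | DistCoordinator | No | Distributed coordinator used for logging (default: None) |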
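How accumulation_steps and save_interval combine can be simulated without a model. This sketch assumes the usual pattern: gradients accumulate over micro-batches, the optimizer steps once every accumulation_steps micro-batches, and save_interval counts those optimizer steps (0 disables periodic saving). step_schedule is a hypothetical helper, not part of coati; in this sketch, leftover micro-batches at the end of an epoch do not trigger a step.

```python
from typing import List, Tuple


def step_schedule(
    num_batches: int, accumulation_steps: int = 1, save_interval: int = 0
) -> Tuple[List[int], List[int]]:
    """Return (micro-batch indices that trigger an optimizer step,
    optimizer-step counts at which a checkpoint would be saved)."""
    optimizer_step_batches: List[int] = []
    save_at_steps: List[int] = []
    num_train_step = 0
    for i in range(num_batches):
        # Only every `accumulation_steps`-th micro-batch steps the optimizer.
        if (i + 1) % accumulation_steps == 0:
            optimizer_step_batches.append(i)
            num_train_step += 1
            # save_interval == 0 disables periodic checkpointing.
            if save_interval > 0 and num_train_step % save_interval == 0:
                save_at_steps.append(num_train_step)
    return optimizer_step_batches, save_at_steps
```

Under this assumption, accumulation_steps=4 with save_interval=500 would save a checkpoint every 2,000 micro-batches.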
Training Batch Format
| Name | Type | Description |
|---|---|---|
| input_ids | torch.Tensor | Token IDs of the prompt + completion sequences |
| attention_mask | torch.Tensor | Attention mask for input_ids |
| loss_mask | torch.Tensor | Mask selecting the tokens that count toward the log-probability |
| label | torch.Tensor | Binary label per example (1 = desirable, 0 = undesirable) |
| kl_input_ids | torch.Tensor | Token IDs of the sequences used to estimate the KL baseline |
| kl_attention_mask | torch.Tensor | Attention mask for kl_input_ids |
| kl_loss_mask | torch.Tensor | Loss mask for kl_input_ids |
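The per-example quantities that feed the loss can be illustrated with plain lists. The sketch sums token log-probabilities over positions where loss_mask is 1, then splits the resulting sequence log-probabilities by label. sequence_logprob and split_by_label are hypothetical helpers; the trainer itself works on batched tensors, and details such as aligning shifted logits with token IDs are omitted. In the KTO formulation the KL sequences are not split by label: they are scored as a whole to estimate the KL baseline.

```python
from typing import List, Sequence, Tuple


def sequence_logprob(token_logprobs: Sequence[float], loss_mask: Sequence[int]) -> float:
    # Sum log-probabilities where loss_mask == 1 (typically completion tokens);
    # masked-out positions such as prompt or padding contribute nothing.
    return sum((lp for lp, m in zip(token_logprobs, loss_mask) if m), 0.0)


def split_by_label(
    values: Sequence[float], labels: Sequence[int]
) -> Tuple[List[float], List[float]]:
    # label 1 = desirable ("chosen"), label 0 = undesirable ("rejected").
    chosen = [v for v, y in zip(values, labels) if y == 1]
    rejected = [v for v, y in zip(values, labels) if y == 0]
    return chosen, rejected


# Toy batch of three examples with per-token log-probabilities.
token_lps = [[-0.5, -1.0, -2.0], [-0.1, -0.2, -0.3], [-1.0, -1.0, -1.0]]
masks = [[0, 1, 1], [0, 0, 1], [1, 1, 1]]
labels = [1, 0, 1]
seq = [sequence_logprob(t, m) for t, m in zip(token_lps, masks)]
chosen, rejected = split_by_label(seq, labels)
```

Setting every mask entry to 1, which is what disabling apply_loss_mask amounts to, makes every token count toward the sequence log-probability.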
Outputs
| Name | Type | Description |
|---|---|---|
| (none) | None | Training modifies the model in-place; metrics logged to TensorBoard/W&B |
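Logged metrics are averaged over the micro-batches of an accumulation window. MeanMeter below is a hypothetical, single-process stand-in written for illustration; it is not coati's AccumulativeMeanMeter. In the distributed trainer each value would first be averaged across ranks (the role of all_reduce_mean) before being added. Taking the margin as chosen reward minus rejected reward is an assumption about how the reward margin is defined.

```python
from typing import Dict


class MeanMeter:
    """Hypothetical running-mean meter for named metrics."""

    def __init__(self) -> None:
        self._sum: Dict[str, float] = {}
        self._count: Dict[str, int] = {}

    def add(self, name: str, value: float, count_update: int = 1) -> None:
        # Weight the value by count_update (e.g. how many samples it averages).
        self._sum[name] = self._sum.get(name, 0.0) + value * count_update
        self._count[name] = self._count.get(name, 0) + count_update

    def get(self, name: str) -> float:
        # Raises KeyError if nothing was recorded under `name`.
        return self._sum[name] / self._count[name]

    def reset(self) -> None:
        # Called after logging, at the start of the next accumulation window.
        self._sum.clear()
        self._count.clear()


meter = MeanMeter()
# Two micro-batches: (mean chosen reward, mean rejected reward).
for chosen_r, rejected_r in [(0.25, -0.5), (0.75, -0.25)]:
    meter.add("chosen_rewards", chosen_r)
    meter.add("rejected_rewards", rejected_r)
    meter.add("margin", chosen_r - rejected_r)
```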
Usage Examples
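```python
import colossalai
from coati.trainer.kto import KTOTrainer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator

# Set up the distributed environment (launch the script with torchrun).
colossalai.launch_from_torch()
coordinator = DistCoordinator()

plugin = HybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=2)
booster = Booster(plugin=plugin)

# actor_model, ref_model, optimizer, lr_scheduler, tokenizer, train_dataloader
# and eval_dataloader are assumed to be created beforehand; batches must carry
# the fields listed under "Training Batch Format".
actor_model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(
    model=actor_model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    dataloader=train_dataloader,
)
# The reference model is only used for inference, so it is boosted without an
# optimizer (the ColossalChat scripts may use a separate booster for it).
ref_model, _, _, _, _ = booster.boost(model=ref_model)

trainer = KTOTrainer(
    actor=actor_model,
    ref_model=ref_model,
    booster=booster,
    actor_optim=optimizer,
    plugin=plugin,
    actor_lr_scheduler=lr_scheduler,
    tokenizer=tokenizer,
    max_epochs=3,
    beta=0.1,
    desirable_weight=1.0,
    undesirable_weight=1.0,
    accumulation_steps=4,
    save_interval=500,
    save_dir="./checkpoints/kto",
    coordinator=coordinator,
)

trainer.fit(
    train_preference_dataloader=train_dataloader,
    eval_preference_dataloader=eval_dataloader,
    log_dir="./logs",
    use_wandb=True,
)
```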