Implementation:OpenRLHF KDTrainer

From Leeroopedia


Knowledge Sources

Domains: NLP, Training, Model_Compression
Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete tool provided by OpenRLHF for knowledge-distillation training of language models.

Description

The KDTrainer class implements the knowledge distillation training loop. It runs both student and teacher models on the same data, computes the GPTLMLoss (standard cross-entropy) and KDLoss (forward KL divergence), combines them with a configurable kd_coef weight, and optimizes only the student model. The teacher model is kept in eval mode throughout.
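The loss combination described above can be sketched in plain PyTorch. This is an illustrative reconstruction, not the actual OpenRLHF code: the `kd_step` helper name and the exact weighting form `gpt_loss * (1 - kd_coef) + kd_loss * kd_coef` are assumptions; the real computation lives in GPTLMLoss and KDLoss.

```python
import torch
import torch.nn.functional as F

def kd_step(student_logits, teacher_logits, labels, kd_coef):
    # Hypothetical helper, not part of OpenRLHF.
    # Standard cross-entropy on the student (the GPTLMLoss analogue):
    gpt_loss = F.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)), labels.view(-1)
    )
    # Forward KL divergence KL(teacher || student) over the vocabulary
    # (the KDLoss analogue); the teacher distribution is detached so
    # gradients flow only into the student.
    teacher_probs = F.softmax(teacher_logits, dim=-1).detach()
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    kd_loss = (
        teacher_probs
        * (teacher_probs.clamp_min(1e-10).log() - student_logprobs)
    ).sum(dim=-1).mean()
    # Assumed combination weighted by kd_coef:
    return gpt_loss * (1 - kd_coef) + kd_loss * kd_coef
```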

Usage

Instantiate with the student model, frozen teacher model, optimizer, and SFT dataloaders, then call fit() to train.
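Since the teacher is never optimized, it is typically frozen before being handed to the trainer. A minimal sketch of that setup, assuming a plain `torch.nn.Module` teacher (the `freeze_teacher` helper name is hypothetical):

```python
import torch

def freeze_teacher(teacher_model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: switch the teacher to eval mode and disable
    # gradient tracking on all of its parameters.
    teacher_model.eval()
    for p in teacher_model.parameters():
        p.requires_grad_(False)
    return teacher_model
```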

Code Reference

Source Location

  • Repository: OpenRLHF
  • File: openrlhf/trainer/kd_trainer.py
  • Lines: L12-261 (class), L30-96 (__init__), L98-186 (fit)

Signature

class KDTrainer(ABC):
    def __init__(
        self,
        model,                          # Actor: student model to train
        teacher_model,                  # Actor: frozen teacher model
        strategy,                       # DeepspeedStrategy
        optim: Optimizer,               # optimizer
        train_dataloader,               # training DataLoader (SFTDataset)
        eval_dataloader,                # evaluation DataLoader
        scheduler,                      # learning rate scheduler
        max_norm: float = 1,            # gradient clipping norm
        pretrain_mode: bool = False,    # loss on all tokens
        batch_size: int = 1,            # batch size
        max_epochs: int = 2,            # training epochs
        tokenizer=None,                 # tokenizer for checkpointing
        save_hf_ckpt: bool = False,
        disable_ds_ckpt: bool = False,
    ) -> None:

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None):
        """Run the full KD training loop."""

Import

from openrlhf.trainer import KDTrainer

I/O Contract

Inputs

Name           Type   Required  Description
model          Actor  Yes       Student model to train
teacher_model  Actor  Yes       Frozen teacher model
args.kd_coef   float  Yes       Weight for the distillation loss (from CLI)
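Since `kd_coef` reaches the trainer through the parsed CLI namespace, a minimal hypothetical stand-in for that parsing looks like this (the flag name follows the table above; the default value here is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# kd_coef weights the distillation (forward-KL) term against the
# standard cross-entropy term.
parser.add_argument("--kd_coef", type=float, default=0.4)
args = parser.parse_args(["--kd_coef", "0.6"])
```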

Outputs

Name           Type  Description
(side effect)  None  Student model trained in-place
logs           Dict  gpt_loss and distil_loss metrics

Usage Examples

from openrlhf.trainer import KDTrainer

trainer = KDTrainer(
    model=student_model,
    teacher_model=teacher_model,
    strategy=strategy,
    optim=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    scheduler=scheduler,
    max_norm=args.max_norm,
    max_epochs=args.max_epochs,
    tokenizer=tokenizer,
)

trainer.fit(args, num_update_steps_per_epoch=num_update_steps_per_epoch)
