Implementation: OpenRLHF KDTrainer
| Knowledge Sources | Details |
|---|---|
| Domains | NLP, Training, Model_Compression |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
KDTrainer is the concrete OpenRLHF component that runs knowledge distillation training for language models.
Description
The KDTrainer class implements the knowledge distillation training loop. It runs both student and teacher models on the same data, computes the GPTLMLoss (standard cross-entropy) and KDLoss (forward KL divergence), combines them with a configurable kd_coef weight, and optimizes only the student model. The teacher model is kept in eval mode throughout.
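The two loss terms and their combination can be sketched as follows. This is a minimal illustration, not the OpenRLHF source: `kd_loss` here computes the masked forward KL divergence KL(teacher ∥ student) from raw logits, and the `(1 - kd_coef)` / `kd_coef` split in `combined_loss` is an assumption about how the weight is applied.

```python
import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, mask):
    """Masked forward KL divergence KL(teacher || student).

    A minimal sketch of the KDLoss idea; the real OpenRLHF
    implementation may differ in masking and reduction details.
    """
    t_logprobs = F.log_softmax(teacher_logits, dim=-1)
    s_logprobs = F.log_softmax(student_logits, dim=-1)
    # Per-token KL: sum over the vocabulary of p_t * (log p_t - log p_s)
    per_token = (t_logprobs.exp() * (t_logprobs - s_logprobs)).sum(-1)
    # Average only over valid (unmasked) token positions
    return (per_token * mask).sum() / mask.sum()

def combined_loss(ce_loss, distil_loss, kd_coef):
    # Weighted mix of cross-entropy and distillation terms
    # (this particular split is an assumption for illustration).
    return (1 - kd_coef) * ce_loss + kd_coef * distil_loss
```

Forward KL is mode-covering: it pushes the student to assign probability everywhere the teacher does, which is a common default for logit-level distillation.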
Usage
Instantiate with the student model, the frozen teacher model, an optimizer, and SFT train/eval dataloaders, then call fit() to run training.
Code Reference
Source Location
- Repository: OpenRLHF
- File: openrlhf/trainer/kd_trainer.py
- Lines: L12-261 (class), L30-96 (__init__), L98-186 (fit)
Signature
```python
class KDTrainer(ABC):
    def __init__(
        self,
        model,                        # Actor: student model to train
        teacher_model,                # Actor: frozen teacher model
        strategy,                     # DeepspeedStrategy
        optim: Optimizer,             # optimizer
        train_dataloader,             # training DataLoader (SFTDataset)
        eval_dataloader,              # evaluation DataLoader
        scheduler,                    # learning rate scheduler
        max_norm: float = 1,          # gradient clipping norm
        pretrain_mode: bool = False,  # loss on all tokens
        batch_size: int = 1,          # batch size
        max_epochs: int = 2,          # training epochs
        tokenizer=None,               # tokenizer for checkpointing
        save_hf_ckpt: bool = False,
        disable_ds_ckpt: bool = False,
    ) -> None: ...

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None):
        """Run the full KD training loop."""
```
Import
```python
from openrlhf.trainer import KDTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | Actor | Yes | Student model to train |
| teacher_model | Actor | Yes | Frozen teacher model |
| args.kd_coef | float | Yes | Weight for distillation loss (from CLI) |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Student model trained in-place |
| logs | Dict | gpt_loss, distil_loss metrics |
Usage Examples
```python
from openrlhf.trainer import KDTrainer

trainer = KDTrainer(
    model=student_model,
    teacher_model=teacher_model,
    strategy=strategy,
    optim=optimizer,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    scheduler=scheduler,
    max_norm=args.max_norm,
    max_epochs=args.max_epochs,
    tokenizer=tokenizer,
)
trainer.fit(args, num_update_steps_per_epoch=num_update_steps_per_epoch)
```
Related Pages
- Implements Principle
- Requires Environment
- Uses Heuristic