Principle:OpenRLHF OpenRLHF KD Loss Computation
| Knowledge Sources | |
|---|---|
| Domains | Model_Compression, Loss_Functions |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
A loss function that measures the divergence between student and teacher model output distributions for knowledge transfer.
Description
KD Loss computes the forward KL divergence between teacher and student probability distributions at the token level. It uses teacher softmax probabilities and student log-softmax probabilities, masking out padding and prompt tokens via an ignore index. This encourages the student to match the teacher's full output distribution, not just the argmax prediction.
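The computation described above can be sketched without any framework, in plain Python. This is an illustrative sketch, not the OpenRLHF implementation: the function names `log_softmax` and `kd_loss`, the sentinel value `-100`, and the choice to raise on an all-masked batch are assumptions made for the example. It computes the teacher-weighted negative student log-probability, which differs from the forward KL only by the teacher's entropy, a term that is constant with respect to the student.

```python
import math

IGNORE_INDEX = -100  # assumed sentinel marking padding/prompt positions


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized row."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kd_loss(student_logits, teacher_logits, labels, ignore_index=IGNORE_INDEX):
    """Average over non-masked positions of -sum_v p_teacher(v) * log p_student(v).

    student_logits, teacher_logits: one row of vocabulary logits per position.
    labels: one label per position; positions equal to ignore_index are skipped.
    """
    total, count = 0.0, 0
    for s_row, t_row, label in zip(student_logits, teacher_logits, labels):
        if label == ignore_index:
            continue  # masked position: contributes nothing to the loss
        teacher_probs = [math.exp(lp) for lp in log_softmax(t_row)]
        student_logprobs = log_softmax(s_row)
        total -= sum(p * lq for p, lq in zip(teacher_probs, student_logprobs))
        count += 1
    if count == 0:
        # Design choice for this sketch: fail loudly rather than divide by zero.
        raise ValueError("no non-masked positions")
    return total / count


# Two positions, vocabulary of size 2; the second position is masked out.
student = [[0.0, 0.0], [5.0, -5.0]]
teacher = [[0.0, 0.0], [0.0, 0.0]]
loss = kd_loss(student, teacher, [1, IGNORE_INDEX])  # equals log(2)
```

Only the first position counts here, and both distributions are uniform over two tokens, so the result is log 2. A tensor-based implementation would vectorize the same steps across the batch and sequence dimensions.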
Usage
Used internally by KDTrainer, where it is combined with GPTLMLoss via a weighting coefficient (kd_coef).
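The text says only that the two losses are combined through kd_coef. One common way to do this is a convex combination, sketched below. The function name `combined_loss` is hypothetical and the exact weighting scheme is an assumption, so check the trainer source before relying on it.

```python
def combined_loss(lm_loss, kd_loss_value, kd_coef):
    """Blend the language-modeling loss and the KD loss.

    Assumed scheme: kd_coef = 0 gives pure LM loss, kd_coef = 1 gives pure KD loss.
    """
    if not 0.0 <= kd_coef <= 1.0:
        raise ValueError("kd_coef is expected to lie in [0, 1] in this sketch")
    return (1.0 - kd_coef) * lm_loss + kd_coef * kd_loss_value


# With kd_coef = 0.25, the LM loss carries three quarters of the weight.
total = combined_loss(lm_loss=2.0, kd_loss_value=1.0, kd_coef=0.25)  # 1.75
```

Under this scheme a larger kd_coef pushes the student toward the teacher's distribution, while a smaller one keeps it closer to the ground-truth labels.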
Theoretical Basis
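Forward KL Divergence (token-level):
$$\mathcal{L}_{\mathrm{KD}} = \frac{1}{|\mathcal{M}|} \sum_{t \in \mathcal{M}} \sum_{v \in \mathcal{V}} p_T(v \mid x_{<t}) \, \log \frac{p_T(v \mid x_{<t})}{p_S(v \mid x_{<t})}$$
where $\mathcal{M}$ is the set of non-masked positions, $p_T$ is the teacher distribution, and $p_S$ is the student distribution. The inner sum runs over the vocabulary $\mathcal{V}$.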
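One way to connect the forward KL to a computation that multiplies teacher softmax probabilities by student log-softmax values is to split the per-position divergence into two terms. In this sketch, $p_T$ and $p_S$ denote the teacher and student distributions at a single position, and the entropy symbol $H$ is introduced only for the illustration.

```latex
\begin{aligned}
D_{\mathrm{KL}}(p_T \,\|\, p_S)
  &= \sum_{v} p_T(v) \log \frac{p_T(v)}{p_S(v)} \\
  &= \underbrace{-\sum_{v} p_T(v) \log p_S(v)}_{\text{cross-entropy of } p_S \text{ under } p_T}
     \;-\; \underbrace{\Big(-\sum_{v} p_T(v) \log p_T(v)\Big)}_{H(p_T)}
\end{aligned}
```

Because $H(p_T)$ does not depend on the student's parameters, minimizing the cross-entropy term alone yields the same student gradients as minimizing the full forward KL, provided the teacher is held fixed.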