Implementation:OpenRLHF OpenRLHF KDLoss
| Knowledge Sources | |
|---|---|
| Domains | Model_Compression, Loss_Functions |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool, provided by OpenRLHF, for computing the knowledge distillation loss between a teacher model and a student model.
Description
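The KDLoss class computes a forward-KL distillation objective between the teacher and student token distributions. It applies softmax to the teacher logits and log-softmax to the student logits, multiplies the two element-wise, and sums over the vocabulary dimension. The negated per-token sums are then averaged over the positions whose label is not the ignore index (-100). Strictly speaking, this quantity is the teacher-to-student cross-entropy, which equals the forward KL divergence KL(teacher ‖ student) plus the teacher's entropy. Because the entropy term does not depend on the student, both objectives yield the same gradients for the student. Student logits that are infinite (for example, vocabulary entries masked to -inf) would otherwise turn the per-token sum into -inf or NaN, so the corresponding products are zeroed with masked_fill before summing.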
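Written out, the computation described above corresponds to the following formula. The notation is ours, not OpenRLHF's: $z^{S}_t$ and $z^{T}_t$ are the student and teacher logit vectors at token position $t$, $y_t$ is that position's label, $V$ is the vocabulary size, and $\mathcal{M}$ is the set of non-ignored positions.

$$
\mathcal{L}_{\mathrm{KD}} = -\frac{1}{|\mathcal{M}|} \sum_{t \in \mathcal{M}} \sum_{v=1}^{V} p_T(v \mid t)\, \log p_S(v \mid t),
\qquad \mathcal{M} = \{\, t : y_t \neq -100 \,\}
$$

$$
p_T(\cdot \mid t) = \mathrm{softmax}\big(z^{T}_t\big), \qquad \log p_S(\cdot \mid t) = \mathrm{log\_softmax}\big(z^{S}_t\big)
$$

For each position, the inner term splits into a KL divergence and an entropy term that is constant with respect to the student:

$$
-\sum_{v=1}^{V} p_T(v \mid t)\, \log p_S(v \mid t) = D_{\mathrm{KL}}\big(p_T(\cdot \mid t) \,\|\, p_S(\cdot \mid t)\big) + H\big(p_T(\cdot \mid t)\big)
$$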
Usage
Instantiated by KDTrainer and called at every training step with the student logits, the teacher logits, and the labels.
Code Reference
Source Location
- Repository: OpenRLHF
- File: openrlhf/models/loss.py
- Lines: L376-394
Signature
```python
import torch
from torch import nn


class KDLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100  # label value excluded from the loss

    def forward(
        self,
        logits: torch.Tensor,          # Student logits (batch, seq, vocab)
        teacher_logits: torch.Tensor,  # Teacher logits (batch, seq, vocab)
        label: torch.Tensor,           # Ground-truth labels for masking (batch, seq)
    ) -> torch.Tensor:
        """Returns scalar distillation loss."""
```
Import
```python
from openrlhf.models import KDLoss
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| logits | Tensor | Yes | Student model logits (batch, seq_len, vocab_size) |
| teacher_logits | Tensor | Yes | Teacher model logits (same shape) |
| label | Tensor | Yes | Ground-truth labels used for masking (batch, seq_len); positions equal to -100 are excluded |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Scalar distillation loss, averaged over tokens whose label is not -100 |
Usage Examples
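```python
from openrlhf.models import KDLoss

kd_loss_fn = KDLoss()
# student_logits, teacher_logits: (batch, seq_len, vocab_size); labels: (batch, seq_len)
distil_loss = kd_loss_fn(student_logits, teacher_logits, labels)
# Combined loss: kd_coef weights distillation against the language-modeling loss gpt_loss
total_loss = gpt_loss * (1 - kd_coef) + distil_loss * kd_coef
```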
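When kd_coef lies in [0, 1], the combined loss is a convex interpolation between the two terms. The following hypothetical helper makes the endpoints explicit: kd_coef = 0 reduces to plain language-model training, and kd_coef = 1 reduces to pure distillation. The name combine_losses is illustrative. The range check is an assumption, since the source does not state a valid range for kd_coef.

```python
def combine_losses(gpt_loss: float, distil_loss: float, kd_coef: float) -> float:
    """Hypothetical helper: weighted mix of the LM loss and the distillation loss."""
    # Assumed constraint: restrict kd_coef to [0, 1] so the mix stays convex.
    if not 0.0 <= kd_coef <= 1.0:
        raise ValueError(f"kd_coef must lie in [0, 1], got {kd_coef}")
    return gpt_loss * (1 - kd_coef) + distil_loss * kd_coef


# Example: with gpt_loss = 2.0, distil_loss = 3.0 and kd_coef = 0.4,
# the result is 0.6 * 2.0 + 0.4 * 3.0, i.e. approximately 2.4.
mixed = combine_losses(2.0, 3.0, 0.4)
```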