Implementation:OpenRLHF OpenRLHF KDLoss

From Leeroopedia


Knowledge Sources
Domains: Model_Compression, Loss_Functions
Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete tool, provided by OpenRLHF, for computing the knowledge distillation loss between a teacher model and a student model.

Description

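The KDLoss class computes a token-level distillation loss that corresponds to the forward KL divergence from the teacher distribution to the student distribution. It applies softmax to the teacher logits and log-softmax to the student logits, multiplies the two element-wise, and sums across the vocabulary dimension. The negated sum is the teacher-student cross-entropy, which differs from the forward KL divergence only by the teacher's entropy, a term that does not depend on the student. Positions whose label equals the ignore index (-100) are masked out, and the loss is averaged over the remaining positions. Infinite student logits are handled with masked_fill: the corresponding entries of the product are set to zero, so a 0 * -inf term cannot turn the loss into NaN.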
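Written as a formula, the steps in this description amount to a masked, token-averaged cross-entropy between teacher and student. The notation below is introduced here for illustration and does not appear in the source: z^S_t and z^T_t are the student and teacher logits at position t, V is the vocabulary size, and M is the set of positions whose label is not -100.

```latex
\mathcal{L}_{\mathrm{KD}}
  = -\frac{1}{|M|} \sum_{t \in M} \sum_{v=1}^{V}
      p_T(v \mid t)\, \log p_S(v \mid t),
\qquad
p_T(\cdot \mid t) = \mathrm{softmax}\!\left(z^{T}_{t}\right),\quad
p_S(\cdot \mid t) = \mathrm{softmax}\!\left(z^{S}_{t}\right)

\mathrm{KL}\!\left(p_T \,\|\, p_S\right)
  = \underbrace{\sum_{v} p_T(v)\,\log p_T(v)}_{\text{independent of the student}}
    \;-\; \sum_{v} p_T(v)\,\log p_S(v)
```

Because the first term of the KL divergence does not depend on the student, minimizing the cross-entropy above with respect to the student parameters yields the same gradients as minimizing the forward KL divergence.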

Usage

KDTrainer instantiates KDLoss and calls it at each training step with the student logits, the teacher logits, and the labels.

Code Reference

Source Location

  • Repository: OpenRLHF
  • File: openrlhf/models/loss.py
  • Lines: L376-394

Signature

class KDLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100  # label value excluded from the loss

    def forward(
        self,
        logits: torch.Tensor,          # Student logits (batch, seq, vocab)
        teacher_logits: torch.Tensor,  # Teacher logits (batch, seq, vocab)
        label: torch.Tensor,           # Ground truth labels for masking (batch, seq)
    ) -> torch.Tensor:
        """Returns scalar distillation loss."""

Import

from openrlhf.models import KDLoss

I/O Contract

Inputs

Name           | Type   | Required | Description
logits         | Tensor | Yes      | Student model logits (batch, seq_len, vocab_size)
teacher_logits | Tensor | Yes      | Teacher model logits (same shape as logits)
label          | Tensor | Yes      | Ground truth labels for masking (batch, seq_len); positions equal to -100 are ignored

Outputs

Name | Type   | Description
loss | Tensor | Scalar distillation loss
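A function that satisfies this contract can be sketched in a few lines of PyTorch. This is a reconstruction from the Description section, not a copy of the library source: the function name kd_loss_sketch and the toy tensors are hypothetical, and the float32 upcast is an assumption made for numerical stability.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # label value marking positions excluded from the loss


def kd_loss_sketch(logits, teacher_logits, label):
    """Sketch of the masked teacher-student cross-entropy (not the library source)."""
    # Teacher distribution and student log-probabilities, computed in float32.
    teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
    student_logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)

    # Where a student logit is infinite, the product can be 0 * -inf = nan,
    # so those entries are zeroed instead.
    inf_mask = torch.isinf(logits)
    prod = torch.masked_fill(teacher_probs * student_logprobs, inf_mask, 0.0)

    # Sum over the vocabulary: one value per token position.
    per_token = prod.sum(dim=-1).view(-1)

    # Keep only positions whose label is not IGNORE_INDEX, then average.
    mask = (label != IGNORE_INDEX).view(-1).to(per_token.dtype)
    return -(per_token * mask).sum() / mask.sum()


# Toy shapes: batch=1, seq_len=3, vocab_size=4.
student = torch.zeros(1, 3, 4)
teacher = torch.zeros(1, 3, 4)
labels = torch.tensor([[-100, 2, 1]])  # the first position is ignored
loss = kd_loss_sketch(student, teacher, labels)
```

With uniform teacher and student distributions over four tokens, the result is log(4), regardless of how many positions are masked. If every label equals -100, the denominator in this sketch is zero and the result is NaN, so a caller would need to guard against fully masked batches.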

Usage Examples

from openrlhf.models import KDLoss

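kd_loss_fn = KDLoss()
# student_logits, teacher_logits: (batch, seq_len, vocab_size)
# labels: (batch, seq_len), with -100 at positions to ignore
distil_loss = kd_loss_fn(student_logits, teacher_logits, labels)

# Combined loss: kd_coef weights the distillation loss against the language-modeling loss (gpt_loss)
total_loss = gpt_loss * (1 - kd_coef) + distil_loss * kd_coef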
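The combined loss is a weighted mix of the two terms, which a small sketch makes concrete. The helper name combine_losses and the numeric values are hypothetical and chosen only for illustration; plain floats stand in for the scalar loss tensors.

```python
def combine_losses(gpt_loss, distil_loss, kd_coef):
    """Weighted mix of the language-modeling loss and the distillation loss."""
    return gpt_loss * (1 - kd_coef) + distil_loss * kd_coef


# Hypothetical scalar losses, used only to show the effect of kd_coef.
only_lm = combine_losses(2.0, 1.0, kd_coef=0.0)  # distillation term switched off
only_kd = combine_losses(2.0, 1.0, kd_coef=1.0)  # language-modeling term switched off
half = combine_losses(2.0, 1.0, kd_coef=0.5)     # equal weighting
```

Setting kd_coef to 0 recovers the plain language-modeling loss, and setting it to 1 trains on the teacher signal alone. The same expression works unchanged on scalar tensors, so gradients flow to both terms in proportion to their weights.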

Related Pages

Implements Principle
