
Implementation:Alibaba ROLL VariousDivergence



Knowledge Sources

  • Domains: Knowledge_Distillation, Optimization
  • Last Updated: 2026-02-07 20:00 GMT

Overview

This page documents the concrete KL divergence objective implementations for knowledge distillation provided by the Alibaba ROLL library.

Description

The VariousDivergence class implements six KL divergence objectives: forward KL, reverse KL, adaptive KL, skewed forward KL, skewed reverse KL, and Jensen-Shannon (JS) divergence. It selects the appropriate method based on the kd_objective configuration parameter.
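
For reference, the standard definitions of these objectives are shown below, writing p for the teacher distribution and q_θ for the student. The skew coefficient λ is illustrative, and the adaptive objective's weighting of the forward and reverse terms is configuration-dependent; this page does not specify ROLL's exact coefficients.

\mathrm{KL}(p \,\|\, q_\theta) = \sum_i p_i \log \frac{p_i}{q_{\theta,i}} \quad \text{(forward KL)}

\mathrm{KL}(q_\theta \,\|\, p) = \sum_i q_{\theta,i} \log \frac{q_{\theta,i}}{p_i} \quad \text{(reverse KL)}

\mathrm{SKL}(p \,\|\, q_\theta) = \mathrm{KL}\bigl(p \,\|\, \lambda p + (1-\lambda)\, q_\theta\bigr) \quad \text{(skewed forward)}

\mathrm{SRKL}(q_\theta \,\|\, p) = \mathrm{KL}\bigl(q_\theta \,\|\, \lambda q_\theta + (1-\lambda)\, p\bigr) \quad \text{(skewed reverse)}

\mathrm{JS}(p, q_\theta) = \tfrac{1}{2}\,\mathrm{KL}(p \,\|\, m) + \tfrac{1}{2}\,\mathrm{KL}(q_\theta \,\|\, m), \qquad m = \tfrac{1}{2}(p + q_\theta)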

Usage

The class is instantiated by the StudentWorker during initialization and called during each loss computation; the objective selection can be pictured as the dispatch sketch below.
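
The sketch illustrates the lookup from the configured objective name to the corresponding bound method. The method names match the signature in the Code Reference section, but the kd_objective string keys are assumptions for illustration, not ROLL's literal keys.

from roll.pipeline.distill.various_divergence import VariousDivergence

def select_objective(div: VariousDivergence, kd_objective: str):
    # Illustrative dispatch table; the string keys are assumed for this sketch.
    table = {
        "forward_kl": div.compute_forward_kl_divergence,
        "reverse_kl": div.compute_reverse_kl_divergence,
        "adaptive_kl": div.compute_adaptive_kl_divergence,
        "skewed_forward_kl": div.compute_skewed_forward_kl_divergence,
        "skewed_reverse_kl": div.compute_skewed_reverse_kl_divergence,
        "js": div.compute_js_divergence,
    }
    return table[kd_objective]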

Code Reference

Source Location

  • Repository: Alibaba ROLL
  • File: roll/pipeline/distill/various_divergence.py
  • Lines: L8-153

Signature

class VariousDivergence:
    def __init__(self, pipeline_config: DistillConfig, padding_id: int = -100) -> None:
        """Initialize with KD objective from config."""

    def __call__(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Compute selected divergence."""

    def compute_forward_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Forward KL: KL(teacher || student)."""

    def compute_reverse_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Reverse KL: KL(student || teacher)."""

    def compute_adaptive_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Adaptive KL: weighted combination."""

    def compute_skewed_forward_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Skewed forward KL."""

    def compute_skewed_reverse_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Skewed reverse KL."""

    def compute_js_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
        """Jensen-Shannon divergence."""

Import

from roll.pipeline.distill.various_divergence import VariousDivergence

I/O Contract

Inputs

  • logits (torch.Tensor, required): Student model logits
  • teacher_probs (torch.Tensor, required): Teacher top-k probabilities
  • teacher_log_probs (torch.Tensor, required): Teacher top-k log probabilities
  • teacher_inf_mask (torch.Tensor, required): Mask for infinite teacher values

Outputs

  • divergence_loss (torch.Tensor): Scalar divergence loss
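
To make the contract concrete, here is a minimal sketch of a forward-KL computation over these tensors. It assumes all four tensors share the same trailing (top-k) dimension and that teacher_inf_mask marks entries to drop; the masking and reduction details are assumptions, not ROLL's exact implementation.

import torch
import torch.nn.functional as F

def forward_kl_sketch(logits: torch.Tensor,
                      teacher_probs: torch.Tensor,
                      teacher_log_probs: torch.Tensor,
                      teacher_inf_mask: torch.Tensor) -> torch.Tensor:
    # Student log-probabilities over the same support as the teacher tensors.
    student_log_probs = F.log_softmax(logits, dim=-1)
    # Per-entry forward-KL contribution: p_t * (log p_t - log q_s).
    per_entry = teacher_probs * (teacher_log_probs - student_log_probs)
    # Assumed mask semantics: zero out entries where the teacher value was infinite.
    per_entry = per_entry.masked_fill(teacher_inf_mask.bool(), 0.0)
    # Sum over the top-k axis, then average over tokens (reduction is assumed).
    return per_entry.sum(dim=-1).mean()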

Usage Examples

from roll.pipeline.distill.various_divergence import VariousDivergence

# kd_objective in distill_config selects which divergence is computed.
divergence = VariousDivergence(pipeline_config=distill_config)

# student_logits: student forward pass; teacher_* tensors: teacher top-k outputs.
kd_loss = divergence(student_logits, teacher_probs, teacher_log_probs, teacher_inf_mask)

# Blend the KD loss with the supervised fine-tuning (SFT) loss.
total_loss = (1 - distill_config.distill_loss_weight) * sft_loss + distill_config.distill_loss_weight * kd_loss
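
The example assumes a DistillConfig whose kd_objective field selects the divergence. A hypothetical configuration might look like the following; apart from kd_objective and distill_loss_weight, which appear on this page, the values are placeholders.

# Hypothetical configuration sketch, not a verified DistillConfig signature.
distill_config = DistillConfig(
    kd_objective="forward_kl",   # assumed key; see the ROLL source for valid values
    distill_loss_weight=0.5,     # blend between SFT loss and KD loss
)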

