Implementation:Alibaba ROLL VariousDivergence
| Knowledge Sources | Value |
|---|---|
| Domains | Knowledge_Distillation, Optimization |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
Concrete KL divergence objective implementations for knowledge distillation provided by the Alibaba ROLL library.
Description
The VariousDivergence class implements six divergence objectives: forward KL, reverse KL, adaptive KL, skewed forward KL, skewed reverse KL, and Jensen-Shannon. It selects the compute method matching the kd_objective configuration parameter.
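As a minimal sketch of that selection step, assuming illustrative objective names (the exact strings accepted by DistillConfig.kd_objective are not confirmed here):

# Hypothetical dispatch sketch; the objective strings below are
# illustrative assumptions, not confirmed DistillConfig values.
KD_OBJECTIVES = {
    "forward_kl": "compute_forward_kl_divergence",
    "reverse_kl": "compute_reverse_kl_divergence",
    "adaptive_kl": "compute_adaptive_kl_divergence",
    "skewed_forward_kl": "compute_skewed_forward_kl_divergence",
    "skewed_reverse_kl": "compute_skewed_reverse_kl_divergence",
    "js": "compute_js_divergence",
}

def resolve_objective(divergence, kd_objective):
    # Map the configured name to the matching compute_* method.
    return getattr(divergence, KD_OBJECTIVES[kd_objective])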
Usage
Instantiated by the StudentWorker during initialization and called during each loss computation.
Code Reference
Source Location
- Repository: Alibaba ROLL
- File: roll/pipeline/distill/various_divergence.py
- Lines: L8-153
Signature
class VariousDivergence:
def __init__(self, pipeline_config: DistillConfig, padding_id: int = -100) -> None:
"""Initialize with KD objective from config."""
def __call__(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Compute selected divergence."""
def compute_forward_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Forward KL: KL(teacher || student)."""
def compute_reverse_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Reverse KL: KL(student || teacher)."""
def compute_adaptive_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Adaptive KL: weighted combination."""
def compute_skewed_forward_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Skewed forward KL."""
def compute_skewed_reverse_kl_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Skewed reverse KL."""
def compute_js_divergence(self, logits, teacher_probs, teacher_log_probs, teacher_inf_mask) -> torch.Tensor:
"""Jensen-Shannon divergence."""
Import
from roll.pipeline.distill.various_divergence import VariousDivergence
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| logits | torch.Tensor | Yes | Student model logits |
| teacher_probs | torch.Tensor | Yes | Teacher top-k probabilities |
| teacher_log_probs | torch.Tensor | Yes | Teacher top-k log probabilities |
| teacher_inf_mask | torch.Tensor | Yes | Boolean mask marking teacher slots whose log-probability is -inf |
Outputs
| Name | Type | Description |
|---|---|---|
| divergence_loss | torch.Tensor | Scalar divergence loss |
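A hedged illustration of how these inputs fit together; the [batch, seq, top_k] shape and the derivation of the mask are assumptions inferred from the parameter names, not confirmed by the source:

import torch

# Illustrative shapes only: batch=2, seq_len=16, top_k=64 are made up.
teacher_log_probs = torch.randn(2, 16, 64).log_softmax(dim=-1)
teacher_probs = teacher_log_probs.exp()
# Slots the teacher assigned zero probability appear as -inf log-probs;
# a boolean mask lets the loss computation skip them safely.
teacher_inf_mask = torch.isinf(teacher_log_probs)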
Usage Examples
from roll.pipeline.distill.various_divergence import VariousDivergence

# distill_config is the pipeline's DistillConfig; its kd_objective field selects the divergence.
divergence = VariousDivergence(pipeline_config=distill_config)
kd_loss = divergence(student_logits, teacher_probs, teacher_log_probs, teacher_inf_mask)
# Blend the distillation loss with the SFT loss using distill_loss_weight.
total_loss = (1 - config.distill_loss_weight) * sft_loss + config.distill_loss_weight * kd_loss
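Since the example blends the two terms linearly, distill_loss_weight = 0.0 recovers pure supervised fine-tuning, while 1.0 trains on the distillation signal alone.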
Related Pages
Implements Principle
Requires Environment
Environment Dependencies
This implementation requires the following environment constraints:
Heuristics Applied
This implementation uses the following heuristics: