Principle:NVIDIA NeMo Aligner Knowledge Distillation Training
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Model Compression, Knowledge Distillation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Knowledge distillation transfers knowledge from a larger teacher model to a smaller student model by training the student to match the teacher's output distribution (stored as top-k logits). The result is a smaller model that retains much of the teacher's capability.
Description
Knowledge distillation (KD) is a model compression technique in which a smaller student model is trained to replicate the behavior of a larger teacher model. Rather than training the student solely on hard labels (ground-truth tokens), the student learns from the teacher's soft probability distribution over the vocabulary, which encodes richer information about inter-token relationships and uncertainty.
In NeMo Aligner's implementation, the teacher's output is represented as top-k logits -- only the k highest-scoring logits and their corresponding token IDs are stored. Storing k (token ID, logit) pairs per position, instead of one logit for every vocabulary entry, greatly reduces storage requirements while retaining the most informative portion of the teacher's output.
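The top-k representation can be sketched for a single token position. The pure-Python helpers below, `extract_topk` and `storage_ratio`, are hypothetical illustrations and not part of NeMo Aligner. They show what is kept for one position and how the count of stored numbers compares with a full logit vector. The ratio counts stored values, not bytes, and assumes IDs and logits are counted equally.

```python
import heapq


def extract_topk(logits, k):
    """Hypothetical helper: return (token_ids, logit_values) for the k largest
    logits at one position, ordered from highest to lowest logit."""
    top = heapq.nlargest(k, enumerate(logits), key=lambda pair: pair[1])
    token_ids = [token_id for token_id, _ in top]
    values = [value for _, value in top]
    return token_ids, values


def storage_ratio(vocab_size, k):
    """Hypothetical helper: numbers stored for k (id, logit) pairs divided by
    numbers stored for a full vocab_size-length logit vector."""
    return (2 * k) / vocab_size
```

For example, `extract_topk([0.1, 2.5, -1.0, 3.2, 0.0], 2)` returns `([3, 1], [3.2, 2.5])`. With an illustrative 32,000-entry vocabulary and k = 100, `storage_ratio(32000, 100)` is 0.00625, so about 0.6% of the numbers in the full vector are kept.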
The training objective is a weighted combination of two losses:
- KD loss -- a cross-entropy or KL divergence loss computed between the student's predicted logits and the teacher's top-k logits. Both forward KL and backward KL divergence variants are supported.
- SFT loss -- a standard supervised fine-tuning (next-token prediction) loss computed against the ground-truth labels.
The relative weighting of these two loss terms is controlled by the kd_loss_weight and sft_loss_weight configuration parameters. In addition, the student's logits and the teacher's target logits can be scaled independently via the logits_scale and target_logits_scale parameters, which act as inverse temperatures.
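The combined objective for a single token position can be sketched in plain Python. This is a hypothetical illustration: `combined_loss` and its positional argument names are invented here, and only the four keyword parameters mirror the configuration names above. It assumes forward KL as the KD term, both distributions renormalized over the teacher's top-k token IDs, and scaling applied only to the KD term. The actual NeMo Aligner code operates on batched tensors and may normalize differently.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]


def combined_loss(student_logits, teacher_topk_ids, teacher_topk_logits, label,
                  kd_loss_weight=1.0, sft_loss_weight=1.0,
                  logits_scale=1.0, target_logits_scale=1.0):
    """Hypothetical per-token objective:
    kd_loss_weight * forward KL(P_T || P_S) over the teacher's top-k ids
    + sft_loss_weight * next-token cross-entropy over the full vocabulary.

    Assumption: P_T and P_S are both renormalized over the top-k ids.
    """
    # Teacher distribution over its own top-k tokens, after inverse-temperature scaling.
    log_p_t = log_softmax([x * target_logits_scale for x in teacher_topk_logits])
    # Student logits gathered at the same token ids, scaled, then renormalized.
    log_p_s = log_softmax([student_logits[i] * logits_scale for i in teacher_topk_ids])
    # Forward KL: sum_i P_T(i) * (log P_T(i) - log P_S(i)).
    kd = sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    # SFT term: -log P_S(label) over the full, unscaled student vocabulary.
    sft = -log_softmax(student_logits)[label]
    return kd_loss_weight * kd + sft_loss_weight * sft
```

With `student = [2.0, 1.0, 0.5, -1.0]`, calling `combined_loss(student, [0, 1], [2.0, 1.0], 0, sft_loss_weight=0.0)` gives a KD term of zero, because the student's logits at token IDs 0 and 1 already match the teacher's top-k logits. Swapping the teacher logits to `[1.0, 2.0]` makes the KD term positive.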
The implementation also supports cross-tokenizer distillation, enabling knowledge transfer between models that use different tokenizers.
Usage
Knowledge distillation training is used when:
- You have a large, high-quality teacher model and want to produce a smaller, deployable student model that retains much of the teacher's performance.
- You want to compress a model for inference efficiency (lower latency, reduced memory) without retraining from scratch.
- You want to transfer specialized capabilities from a fine-tuned teacher to a student with a different architecture or tokenizer.
The typical workflow involves two phases:
- Teacher logit extraction -- Run the teacher model over the training corpus to compute and save top-k logits using the compute_topk_logits script.
- Student training -- Train the student model using the train_gpt_knowledge_distillation script, which loads the precomputed teacher logits and optimizes the combined KD + SFT objective.
Theoretical Basis
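The core idea of knowledge distillation is grounded in minimizing the divergence between the teacher's output distribution $P_T$ and the student's output distribution $P_S$.
Forward KL Divergence (default):
$$\mathcal{L}_{\text{fwd-KL}} = \sum_{i \in \text{top-}k} P_T(i) \log \frac{P_T(i)}{P_S(i)}$$
Backward KL Divergence:
$$\mathcal{L}_{\text{bwd-KL}} = \sum_{i \in \text{top-}k} P_S(i) \log \frac{P_S(i)}{P_T(i)}$$
Combined Loss:
$$\mathcal{L} = \lambda_{\text{KD}} \, \mathcal{L}_{\text{KD}} + \lambda_{\text{SFT}} \, \mathcal{L}_{\text{SFT}}$$
where $\mathcal{L}_{\text{KD}}$ is the selected KL term ($\mathcal{L}_{\text{fwd-KL}}$ or $\mathcal{L}_{\text{bwd-KL}}$), $\mathcal{L}_{\text{SFT}}$ is the next-token prediction loss against the ground-truth labels, $\lambda_{\text{KD}}$ (kd_loss_weight) is the knowledge distillation loss weight, and $\lambda_{\text{SFT}}$ (sft_loss_weight) is the supervised fine-tuning loss weight.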
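The asymmetry between the two KL variants is easy to see numerically. The sketch below assumes both distributions have already been computed over the same top-k token IDs, for example by applying a softmax to the teacher's top-k logits and to the student's logits at those IDs. The names `fwd_kl` and `bwd_kl` and the example logits are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def fwd_kl(p_t, p_s):
    """Forward KL: sum_i P_T(i) * log(P_T(i) / P_S(i))."""
    return sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)


def bwd_kl(p_t, p_s):
    """Backward KL: sum_i P_S(i) * log(P_S(i) / P_T(i))."""
    return sum(s * math.log(s / t) for t, s in zip(p_t, p_s) if s > 0)


# Hypothetical top-k logits for one position (same token ids for both models).
p_t = softmax([3.0, 1.0, 0.0])   # peaked teacher
p_s = softmax([1.0, 1.0, 1.0])   # uniform student
print(fwd_kl(p_t, p_s), bwd_kl(p_t, p_s))  # the two values differ: KL is asymmetric
```

Forward KL weights each log-ratio by the teacher's probability, so it penalizes the student most where the teacher places mass. Backward KL weights by the student's probability, so it penalizes the student for placing mass where the teacher does not.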
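Temperature scaling is applied to both the student logits (via logits_scale) and the teacher logits (via target_logits_scale) before the KD loss is computed. Because each parameter acts as an inverse temperature that multiplies the logits, a value below 1 softens the corresponding distribution and a value above 1 sharpens it. Softer distributions encourage the student to better match the teacher's relative ranking of tokens, not just its top choice.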
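The softening effect of a scale below 1 can be checked by measuring entropy. The sketch below assumes that logits_scale and target_logits_scale simply multiply the logits. The helper names `entropy` and `scaled_entropy` and the example logits are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def entropy(p):
    """Shannon entropy in nats; higher means a softer (flatter) distribution."""
    return -sum(q * math.log(q) for q in p if q > 0)


def scaled_entropy(logits, scale):
    """Entropy of softmax(scale * logits), with scale acting as an inverse temperature."""
    return entropy(softmax([x * scale for x in logits]))


logits = [3.0, 1.0, 0.2, -1.0]  # hypothetical logits for one position
for scale in (0.5, 1.0, 2.0):
    # Entropy decreases as the scale grows: 0.5 softens, 2.0 sharpens.
    print(scale, round(scaled_entropy(logits, scale), 3))
```

For any non-constant logits, entropy decreases strictly as the scale increases. A scale of 0.5 therefore yields a softer distribution than the unscaled one, and 2.0 yields a sharper one.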