Principle:NVIDIA NeMo Aligner Knowledge Distillation Training

From Leeroopedia
Revision as of 17:22, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/NVIDIA_NeMo_Aligner_Knowledge_Distillation_Training.md)

Knowledge Sources
Domains: Natural Language Processing, Model Compression, Knowledge Distillation
Last Updated: 2026-02-08 00:00 GMT

Overview

Knowledge distillation transfers knowledge from a larger teacher model to a smaller student model by training the student to match the teacher's output distribution (here, its top-k logits), which reduces model size while preserving much of the teacher's capability.

Description

Knowledge distillation (KD) is a model compression technique in which a smaller student model is trained to replicate the behavior of a larger teacher model. Rather than training the student solely on hard labels (ground-truth tokens), the student learns from the teacher's soft probability distribution over the vocabulary, which encodes richer information about inter-token relationships and uncertainty.

In NeMo Aligner's implementation, the teacher's output is represented as top-k logits -- only the k highest-scoring logits and their corresponding token IDs are stored. This dramatically reduces the storage requirements compared to saving the full vocabulary distribution while retaining the most informative portion of the teacher's output.
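As a minimal sketch of the top-k representation described above, the following pure-Python snippet keeps only the k largest logits and their token IDs, and estimates the storage saving. The helper names (`top_k_logits`, `storage_ratio`), the toy logits, and the vocabulary size of 32000 with k = 100 are illustrative assumptions, not values or functions taken from NeMo Aligner.

```python
def top_k_logits(logits, k):
    """Return (token_ids, values) for the k largest logits, highest first."""
    # Sort token IDs by their logit, descending, and keep the first k.
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return order, [logits[i] for i in order]

def storage_ratio(vocab_size, k):
    """Numbers stored per position: full vocabulary vs. k logits plus k token IDs."""
    return vocab_size / (2 * k)

# Toy 8-token "vocabulary" (hypothetical values).
teacher_logits = [0.1, 2.5, -1.0, 4.0, 0.3, 3.2, -0.7, 1.1]
ids, vals = top_k_logits(teacher_logits, k=3)
print(ids, vals)                  # token IDs and logits of the 3 highest scores
print(storage_ratio(32000, 100))  # assumed vocab size and k, for illustration only
```

The saving grows with the vocabulary size, since the stored record depends only on k.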

The training objective is a weighted combination of two losses:

  • KD loss -- a cross-entropy or KL divergence loss computed between the student's predicted logits and the teacher's top-k logits. Both forward KL and backward KL divergence variants are supported.
  • SFT loss -- a standard supervised fine-tuning (next-token prediction) loss computed against the ground-truth labels.

The relative weighting of these two loss terms is controlled by kd_loss_weight and sft_loss_weight configuration parameters. Additionally, both the student's logits and the teacher's target logits can be independently scaled via logits_scale and target_logits_scale parameters, which act as inverse temperature controls.
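The weighting of the two terms can be sketched as follows. This is an illustration under stated assumptions, not NeMo Aligner's code: the function names, the toy logits, the placeholder KD loss value, and the weight values are hypothetical; only the parameter names `kd_loss_weight` and `sft_loss_weight` come from the text above.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def sft_loss(student_logits, label_id):
    """Next-token negative log-likelihood of the ground-truth token."""
    return -log_softmax(student_logits)[label_id]

def combined_loss(kd_loss, sft, kd_loss_weight, sft_loss_weight):
    """Weighted sum of the KD term and the SFT term."""
    return kd_loss_weight * kd_loss + sft_loss_weight * sft

student_logits = [0.0, 2.0, -0.5, 3.0, 0.2, 3.5, -1.0, 0.9]  # toy values
kd = 0.17  # hypothetical KD loss for this position
sft = sft_loss(student_logits, label_id=5)
total = combined_loss(kd, sft, kd_loss_weight=1.0, sft_loss_weight=0.5)
print(round(sft, 4), round(total, 4))
```

Setting one weight to zero recovers pure SFT or pure distillation, which is a convenient sanity check when tuning the balance.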

The implementation also supports cross-tokenizer distillation, enabling knowledge transfer between models that use different tokenizers.

Usage

Knowledge distillation training is used when:

  • You have a large, high-quality teacher model and want to produce a smaller, deployable student model that retains much of the teacher's performance.
  • You want to compress a model for inference efficiency (lower latency, reduced memory) without retraining from scratch.
  • You want to transfer specialized capabilities from a fine-tuned teacher to a student with a different architecture or tokenizer.

The typical workflow involves two phases:

  1. Teacher logit extraction -- Run the teacher model over the training corpus to compute and save top-k logits using the compute_topk_logits script.
  2. Student training -- Train the student model using the train_gpt_knowledge_distillation script, which loads the precomputed teacher logits and optimizes the combined KD + SFT objective.
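The point of the two-phase split is that the teacher runs only once and the student trainer reads cached targets. The sketch below illustrates that decoupling with an in-memory JSONL buffer. Everything in it is hypothetical: the toy teacher function, the record field names, and the JSONL layout are assumptions made for illustration and do not describe the actual file format or interfaces of the compute_topk_logits or train_gpt_knowledge_distillation scripts.

```python
import heapq
import io
import json

VOCAB_SIZE = 6  # toy vocabulary size (hypothetical)

def toy_teacher_logits(token_id):
    """Hypothetical stand-in for a teacher forward pass: one logit per vocab entry."""
    return [float((token_id * 7 + v * 3) % 11) for v in range(VOCAB_SIZE)]

def extract_topk(corpus, k, out_file):
    """Phase 1: run the teacher once and persist only top-k logits and token IDs."""
    for example_id, token_id in enumerate(corpus):
        logits = toy_teacher_logits(token_id)
        top = heapq.nlargest(k, enumerate(logits), key=lambda pair: pair[1])
        record = {
            "example_id": example_id,               # hypothetical field name
            "topk_token_ids": [i for i, _ in top],  # hypothetical field name
            "topk_logits": [v for _, v in top],     # hypothetical field name
        }
        out_file.write(json.dumps(record) + "\n")

def load_teacher_targets(in_file):
    """Phase 2: the student trainer reads cached targets; no teacher is needed."""
    return [json.loads(line) for line in in_file if line.strip()]

buffer = io.StringIO()  # stands in for a file on disk
extract_topk(corpus=[2, 5, 9], k=2, out_file=buffer)
buffer.seek(0)
targets = load_teacher_targets(buffer)
print(len(targets), targets[0])
```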

Theoretical Basis

The core idea of knowledge distillation is grounded in minimizing the divergence between the teacher's output distribution $P_T$ and the student's output distribution $P_S$.

Forward KL Divergence (default): $$\mathcal{L}_{\text{fwd-kl}} = \sum_{i \in \text{top-}k} P_T(i) \log \frac{P_T(i)}{P_S(i)}$$

Backward KL Divergence: $$\mathcal{L}_{\text{bwd-kl}} = \sum_{i \in \text{top-}k} P_S(i) \log \frac{P_S(i)}{P_T(i)}$$
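The two divergences above can be computed on a toy example as follows. This sketch assumes that both distributions are renormalized with a softmax over the teacher's k stored token IDs; the actual implementation may normalize differently (for example against the full vocabulary). The function names and all numeric values are hypothetical.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def forward_kl(p_teacher, p_student):
    """sum_i P_T(i) * log(P_T(i) / P_S(i)): expectation under the teacher."""
    return sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)

def backward_kl(p_teacher, p_student):
    """sum_i P_S(i) * log(P_S(i) / P_T(i)): expectation under the student."""
    return sum(ps * math.log(ps / pt) for pt, ps in zip(p_teacher, p_student) if ps > 0)

teacher_topk_ids = [3, 5, 1]           # token IDs stored for the teacher
teacher_topk_logits = [4.0, 3.2, 2.5]  # their logits
student_logits = [0.0, 2.0, -0.5, 3.0, 0.2, 3.5, -1.0, 0.9]  # full toy vocabulary

# Gather the student's logits at the teacher's top-k token IDs.
student_topk_logits = [student_logits[i] for i in teacher_topk_ids]

p_t = softmax(teacher_topk_logits)
p_s = softmax(student_topk_logits)
fwd = forward_kl(p_t, p_s)
bwd = backward_kl(p_t, p_s)
print(fwd, bwd)
```

The two values differ because KL divergence is asymmetric: the forward form penalizes the student for assigning low probability where the teacher assigns high probability, while the backward form penalizes student mass placed where the teacher assigns little.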

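Combined Loss: $$\mathcal{L} = w_{\text{kd}} \, \mathcal{L}_{\text{KD}} + w_{\text{sft}} \, \mathcal{L}_{\text{SFT}}$$

where $w_{\text{kd}}$ is the knowledge distillation loss weight (kd_loss_weight) and $w_{\text{sft}}$ is the supervised fine-tuning loss weight (sft_loss_weight).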

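Temperature scaling is applied to both student logits (via logits_scale) and teacher logits (via target_logits_scale) before computing the KD loss. Because these scales act as inverse temperatures, a scale below 1 softens a distribution and a scale above 1 sharpens it. Softer distributions give more weight to lower-ranked tokens, encouraging the student to match the teacher's relative ranking of tokens rather than only its top choice.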
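The effect of a scale on a distribution can be checked numerically. The sketch below assumes the scale multiplies the logits before the softmax, which is what an inverse temperature does; the function names and logit values are hypothetical. Entropy is used as a simple measure of how soft a distribution is.

```python
import math

def softmax(xs):
    """Numerically stable softmax over a list of logits."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_softmax(logits, scale):
    """Multiply logits by an inverse-temperature scale, then normalize."""
    return softmax([scale * x for x in logits])

def entropy(probs):
    """Shannon entropy in nats; higher means a softer (flatter) distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)

teacher_topk_logits = [4.0, 3.2, 2.5]  # toy values
entropies = {}
for scale in (0.5, 1.0, 2.0):
    probs = scaled_softmax(teacher_topk_logits, scale)
    entropies[scale] = entropy(probs)
    print(scale, [round(p, 3) for p in probs], round(entropies[scale], 3))
```

Under this assumption, entropy falls as the scale rises: a scale of 0.5 gives the flattest distribution of the three and a scale of 2.0 the most peaked.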
