Implementation:NVIDIA NeMo Aligner MegatronGPT Knowledge Distillation
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Model Compression, Knowledge Distillation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
GPTKnowledgeDistillationModel is a Megatron-based GPT model class that implements knowledge distillation training by combining a KD loss (top-k logit cross-entropy) with an optional SFT loss.
Description
The GPTKnowledgeDistillationModel class extends NLPAdapterModelMixin, MegatronGPTModel, and SupervisedInterface to provide a complete knowledge distillation training pipeline for GPT models within the NeMo Aligner framework.
During initialization, the model reads the following configuration parameters from cfg.knowledge_distillation:
- target_logits_scale (default 1.0) -- scaling factor applied to teacher logits before loss computation
- logits_scale (default 1.0) -- scaling factor applied to student logits before loss computation
- kd_loss (default "fwd_kl") -- the direction of the KL divergence used for the distillation loss; supports "fwd_kl" (forward KL) or "bwd_kl" (backward, i.e. reverse, KL)
- kd_loss_weight (default 1) -- weight for the knowledge distillation loss term
- sft_loss_weight (default 0) -- weight for the supervised fine-tuning loss term
- cross_tokenizer (default False) -- enables cross-tokenizer distillation
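The defaults listed above can be sketched as a small helper that merges user-supplied values over them. This is only an illustration: it uses a plain dict in place of the DictConfig, and the names KD_DEFAULTS and read_kd_config are hypothetical, not part of NeMo Aligner.

```python
# Defaults as documented for cfg.knowledge_distillation.
KD_DEFAULTS = {
    "target_logits_scale": 1.0,
    "logits_scale": 1.0,
    "kd_loss": "fwd_kl",
    "kd_loss_weight": 1,
    "sft_loss_weight": 0,
    "cross_tokenizer": False,
}


def read_kd_config(kd_cfg):
    """Hypothetical helper: fill in documented defaults for any missing keys."""
    resolved = {**KD_DEFAULTS, **kd_cfg}
    # The validation below is an assumption of this sketch, based on the two
    # documented values of kd_loss.
    if resolved["kd_loss"] not in ("fwd_kl", "bwd_kl"):
        raise ValueError(f"unsupported kd_loss: {resolved['kd_loss']!r}")
    return resolved


# Override two values and keep the remaining defaults.
resolved = read_kd_config({"kd_loss": "bwd_kl", "sft_loss_weight": 0.1})
```

With these overrides, kd_loss_weight stays at its default of 1, so the SFT term is added on top of the full KD term rather than replacing it.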
The core loss computation is delegated to _TopKLogitsCrossEntropy, a custom autograd function that efficiently computes the cross-entropy between the student's full logits and the teacher's sparse top-k logit representation. The model supports Megatron pipeline parallelism and data parallelism, correctly routing tensors to the appropriate pipeline stages.
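The idea behind a top-k logit cross-entropy can be sketched for a single token position in plain Python. This is not the _TopKLogitsCrossEntropy implementation: it ignores autograd, batching, and vocabulary sharding across tensor-parallel ranks, and it assumes the teacher distribution is renormalized over its k retained entries, which may differ from what the library does. The function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def topk_cross_entropy(student_logits, teacher_topk_logits, teacher_topk_ids,
                       logits_scale=1.0, target_logits_scale=1.0):
    """Cross-entropy between a top-k teacher distribution and the student.

    Assumption of this sketch: teacher probabilities are a softmax over only
    the k retained teacher logits.
    """
    # Student log-probabilities over the full vocabulary.
    student_logp = log_softmax([logits_scale * x for x in student_logits])
    # Teacher log-probabilities over its k retained entries.
    teacher_logp = log_softmax([target_logits_scale * x for x in teacher_topk_logits])
    # -sum_i p_teacher(i) * log q_student(token_id_i)
    return -sum(
        math.exp(t_lp) * student_logp[token_id]
        for t_lp, token_id in zip(teacher_logp, teacher_topk_ids)
    )


student = [2.0, 0.5, -1.0, 0.1, 1.2]  # student logits over a 5-token vocabulary
teacher_ids = [0, 4]                  # token ids of the teacher's top-2 entries
teacher_logits = [3.0, 1.0]           # teacher logits at those ids
loss = topk_cross_entropy(student, teacher_logits, teacher_ids)
```

Only k student log-probabilities are read per position, which is why storing sparse teacher logits keeps precomputed teacher data small relative to the full vocabulary.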
Usage
Import and instantiate this class when you need to train a student GPT model using precomputed teacher top-k logits. It is designed to be used with the SupervisedTrainer from NeMo Aligner.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/models/nlp/gpt/megatron_gpt_knowledge_distillation.py
- Lines: 44-212
Signature
class GPTKnowledgeDistillationModel(NLPAdapterModelMixin, MegatronGPTModel, SupervisedInterface):
    def __init__(self, cfg: DictConfig, trainer: Trainer):
Import
from nemo_aligner.models.nlp.gpt.megatron_gpt_knowledge_distillation import GPTKnowledgeDistillationModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Model configuration containing knowledge_distillation sub-config with kd_loss, kd_loss_weight, sft_loss_weight, logits_scale, target_logits_scale, and cross_tokenizer |
| trainer | Trainer | Yes | PyTorch Lightning Trainer instance |
Key Methods
| Method | Returns | Description |
|---|---|---|
| get_forward_output_and_loss_func(validation_step) | Callable | Returns a closure that runs the forward pass and computes the combined KD + SFT loss from the student logits and the teacher top-k logits |
| get_loss_and_metrics(batch, forward_only) | Tuple[float, dict] | Executes forward (and optionally backward) pass and returns loss value with metrics dict containing "loss", "sft_loss", and "kd_loss" |
| prepare_for_training_step() | None | Prepares model for a training step (gradient setup) |
| finish_training_step() | None | Performs gradient reductions across data parallel groups |
| prepare_for_validation_step() | None | Prepares model for validation |
| finish_validation_step() | None | Cleans up after validation step |
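The training-related methods in the table are meant to be called in a fixed order around each batch. The sketch below shows that order with a stand-in class that only records calls. StubKDModel and run_training_step are hypothetical names for illustration; the real step is driven by SupervisedTrainer, which also handles the optimizer and is not reproduced here.

```python
class StubKDModel:
    """Stand-in that records calls; NOT the real GPTKnowledgeDistillationModel."""

    def __init__(self):
        self.calls = []

    def prepare_for_training_step(self):
        self.calls.append("prepare_for_training_step")

    def get_loss_and_metrics(self, batch, forward_only):
        self.calls.append("get_loss_and_metrics")
        # Placeholder values only; a real model computes these from the batch.
        return 0.0, {"loss": 0.0, "sft_loss": 0.0, "kd_loss": 0.0}

    def finish_training_step(self):
        self.calls.append("finish_training_step")


def run_training_step(model, batch):
    """Hypothetical driver: prepare, forward/backward, then finish."""
    model.prepare_for_training_step()
    loss, metrics = model.get_loss_and_metrics(batch, forward_only=False)
    model.finish_training_step()
    return loss, metrics


model = StubKDModel()
loss, metrics = run_training_step(model, batch={})
```

A validation step would follow the same shape with prepare_for_validation_step, get_loss_and_metrics called with forward_only=True, and finish_validation_step.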
Outputs
| Name | Type | Description |
|---|---|---|
| loss_value | float | Combined weighted loss (kd_loss_weight * kd_loss + sft_loss_weight * sft_loss) |
| metrics | dict | Dictionary with keys "loss", "sft_loss", "kd_loss" containing averaged loss values |
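The weighting of the two loss terms can be illustrated with scalar values. This is a minimal sketch of the formula in the table above; the function name combine_losses is hypothetical, and the real model applies the weights to tensors inside its loss closure.

```python
def combine_losses(kd_loss, sft_loss, kd_loss_weight=1, sft_loss_weight=0):
    """Weighted sum of the KD and SFT terms, using the documented default weights."""
    return kd_loss_weight * kd_loss + sft_loss_weight * sft_loss


# With the default weights the SFT term contributes nothing.
kd_only = combine_losses(kd_loss=2.0, sft_loss=3.0)

# Equal weighting of both terms.
mixed = combine_losses(kd_loss=2.0, sft_loss=3.0, kd_loss_weight=0.5, sft_loss_weight=0.5)
```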
Usage Examples
from nemo_aligner.models.nlp.gpt.megatron_gpt_knowledge_distillation import GPTKnowledgeDistillationModel
from nemo_aligner.utils.utils import load_from_nemo
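# Load a pretrained checkpoint into the distillation model class.
# cfg (the training config) and trainer (a PyTorch Lightning Trainer) are assumed to exist already.
ptl_model = load_from_nemo(
    GPTKnowledgeDistillationModel,
    cfg.model,
    trainer,
    strict=True,
    load_base_model_only=False,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)
# SupervisedTrainer normally makes this call for each batch; it is shown directly here for illustration.
# batch is assumed to already contain the precomputed teacher top-k logits.
loss_value, metrics = ptl_model.get_loss_and_metrics(batch, forward_only=False)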