

Implementation:NVIDIA NeMo Aligner MegatronGPT Knowledge Distillation

From Leeroopedia


Knowledge Sources
Domains: Natural Language Processing, Model Compression, Knowledge Distillation
Last Updated: 2026-02-08 00:00 GMT

Overview

GPTKnowledgeDistillationModel is a Megatron-based GPT model class that implements knowledge distillation (KD) training by combining a KD loss, computed as a cross-entropy against the teacher's top-k logits, with an optional supervised fine-tuning (SFT) loss.

Description

The GPTKnowledgeDistillationModel class extends NLPAdapterModelMixin, MegatronGPTModel, and SupervisedInterface to provide a complete knowledge distillation training pipeline for GPT models within the NeMo Aligner framework.

During initialization, the model reads the following configuration parameters from cfg.knowledge_distillation:

  • target_logits_scale (default 1.0) -- scaling factor applied to teacher logits before loss computation
  • logits_scale (default 1.0) -- scaling factor applied to student logits before loss computation
  • kd_loss (default "fwd_kl") -- the type of KL divergence used; supports "fwd_kl" or "bwd_kl"
  • kd_loss_weight (default 1) -- weight for the knowledge distillation loss term
  • sft_loss_weight (default 0) -- weight for the supervised fine-tuning loss term
  • cross_tokenizer (default False) -- enables cross-tokenizer distillation
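The defaults above amount to a series of cfg.get(key, default) lookups. The sketch below mirrors them with a hypothetical KDSettings dataclass and read_kd_settings helper (neither is part of NeMo Aligner). A plain dict stands in for the DictConfig, and rejecting unknown kd_loss values is an assumption added for illustration.

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass(frozen=True)
class KDSettings:
    """Hypothetical container mirroring the cfg.knowledge_distillation keys."""
    target_logits_scale: float
    logits_scale: float
    kd_loss: str
    kd_loss_weight: float
    sft_loss_weight: float
    cross_tokenizer: bool


def read_kd_settings(kd_cfg: Mapping[str, Any]) -> KDSettings:
    """Read each key with the default listed above (a plain dict stands in for DictConfig)."""
    settings = KDSettings(
        target_logits_scale=kd_cfg.get("target_logits_scale", 1.0),
        logits_scale=kd_cfg.get("logits_scale", 1.0),
        kd_loss=kd_cfg.get("kd_loss", "fwd_kl"),
        kd_loss_weight=kd_cfg.get("kd_loss_weight", 1),
        sft_loss_weight=kd_cfg.get("sft_loss_weight", 0),
        cross_tokenizer=kd_cfg.get("cross_tokenizer", False),
    )
    # Assumption for illustration: only the two documented KL variants are accepted.
    if settings.kd_loss not in ("fwd_kl", "bwd_kl"):
        raise ValueError(f"unsupported kd_loss: {settings.kd_loss!r}")
    return settings
```

For example, read_kd_settings({"kd_loss": "bwd_kl", "sft_loss_weight": 0.5}) keeps every other field at its default.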

The core loss computation is delegated to _TopKLogitsCrossEntropy, a custom autograd function that computes the cross-entropy between the student's full-vocabulary logits and the teacher's sparse top-k logits, so only the teacher's top-k logits, rather than its full distribution, have to be supplied. The model supports Megatron pipeline parallelism and data parallelism and routes each tensor to the pipeline stages that need it.
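The real _TopKLogitsCrossEntropy is a torch autograd function operating on Megatron-partitioned tensors. The pure-Python sketch below only illustrates the math for a single token position. The helper names (log_softmax, topk_kd_loss) are hypothetical, and the normalization choices are assumptions: the teacher distribution is taken as the softmax of its top-k logits, while student log-probabilities come from a softmax over the full vocabulary and are gathered at the teacher's top-k ids. The library may normalize differently.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax: x_i - logsumexp(x)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def topk_kd_loss(
    student_logits: Sequence[float],
    teacher_topk_logits: Sequence[float],
    teacher_topk_ids: Sequence[int],
    kd_loss: str = "fwd_kl",
    logits_scale: float = 1.0,
    target_logits_scale: float = 1.0,
) -> float:
    """KD loss for one token position (illustrative sketch, not the NeMo code)."""
    # Apply the scaling factors before normalizing, as described for the config keys.
    student_logp = log_softmax([x * logits_scale for x in student_logits])
    # Assumption: teacher probabilities are renormalized over its k tokens.
    teacher_logp = log_softmax([x * target_logits_scale for x in teacher_topk_logits])
    # Student log-probabilities gathered at the teacher's top-k token ids.
    s_topk = [student_logp[i] for i in teacher_topk_ids]
    if kd_loss == "fwd_kl":
        # sum_k p_t(k) * (log p_t(k) - log p_s(k)); its gradient w.r.t. the student
        # matches that of the top-k cross-entropy -sum_k p_t(k) * log p_s(k).
        return sum(math.exp(t) * (t - s) for t, s in zip(teacher_logp, s_topk))
    if kd_loss == "bwd_kl":
        # Renormalize the student over the same k tokens, then KL(student || teacher).
        s_renorm = log_softmax(s_topk)
        return sum(math.exp(s) * (s - t) for s, t in zip(s_renorm, teacher_logp))
    raise ValueError(f"unsupported kd_loss: {kd_loss!r}")
```

Both variants are zero when the student's distribution over the teacher's top-k tokens matches the teacher's and positive otherwise.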

Usage

Import and instantiate this class when you need to train a student GPT model using precomputed teacher top-k logits. It is designed to be used with the SupervisedTrainer from NeMo Aligner.
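To make "precomputed teacher top-k logits" concrete, the hypothetical helpers below (extract_topk and precompute_topk, not part of NeMo Aligner) keep the k largest teacher logits and their token ids at every position of a sequence. How the real data pipeline produces and stores these values, including file format and field names, is not described on this page.

```python
import heapq
from typing import List, Sequence, Tuple


def extract_topk(teacher_logits: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Return the k largest logits and their token ids, largest first."""
    if not 0 < k <= len(teacher_logits):
        raise ValueError("k must be between 1 and the vocabulary size")
    ids = heapq.nlargest(k, range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return [teacher_logits[i] for i in ids], ids


def precompute_topk(
    teacher_logits_per_token: Sequence[Sequence[float]], k: int
) -> Tuple[List[List[float]], List[List[int]]]:
    """Apply extract_topk at every token position of one sequence."""
    values, ids = [], []
    for logits in teacher_logits_per_token:
        v, i = extract_topk(logits, k)
        values.append(v)
        ids.append(i)
    return values, ids
```

Storing only k values and ids per token, instead of a full vocabulary-sized vector, is what keeps precomputed teacher outputs small enough to save alongside the training data.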

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: nemo_aligner/models/nlp/gpt/megatron_gpt_knowledge_distillation.py
  • Lines: 44-212

Signature

class GPTKnowledgeDistillationModel(NLPAdapterModelMixin, MegatronGPTModel, SupervisedInterface):
    def __init__(self, cfg: DictConfig, trainer: Trainer):

Import

from nemo_aligner.models.nlp.gpt.megatron_gpt_knowledge_distillation import GPTKnowledgeDistillationModel

I/O Contract

Inputs

  • cfg (DictConfig, required) -- Model configuration containing the knowledge_distillation sub-config with kd_loss, kd_loss_weight, sft_loss_weight, logits_scale, target_logits_scale, and cross_tokenizer
  • trainer (Trainer, required) -- PyTorch Lightning Trainer instance

Key Methods

  • get_forward_output_and_loss_func(validation_step) -> Callable -- Returns a forward-and-loss closure that computes the combined KD + SFT loss from the student logits and the teacher top-k logits
  • get_loss_and_metrics(batch, forward_only) -> Tuple[float, dict] -- Runs the forward pass (and the backward pass unless forward_only is True) and returns the loss value together with a metrics dict containing "loss", "sft_loss", and "kd_loss"
  • prepare_for_training_step() -> None -- Prepares the model for a training step (gradient setup)
  • finish_training_step() -> None -- Performs gradient reductions across data-parallel groups
  • prepare_for_validation_step() -> None -- Prepares the model for validation
  • finish_validation_step() -> None -- Cleans up after a validation step
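The methods above are hooks that a trainer calls in a fixed order. The sketch below uses a hypothetical RecordingKDModel stub (not part of NeMo Aligner) that exposes the same method names and simply records each call, plus hypothetical train_step and validation_step helpers. The ordering is an assumption about how SupervisedTrainer uses these hooks; gradient clipping, the optimizer step, and metric logging are deliberately left out.

```python
from typing import Any, Dict, List, Tuple


class RecordingKDModel:
    """Hypothetical stand-in with the same hook names; no real model runs."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_for_training_step(self) -> None:
        self.calls.append("prepare_for_training_step")

    def get_loss_and_metrics(self, batch: Any, forward_only: bool) -> Tuple[float, Dict[str, float]]:
        self.calls.append(f"get_loss_and_metrics(forward_only={forward_only})")
        metrics = {"loss": 1.0, "sft_loss": 0.0, "kd_loss": 1.0}  # placeholder values
        return metrics["loss"], metrics

    def finish_training_step(self) -> None:
        self.calls.append("finish_training_step")

    def prepare_for_validation_step(self) -> None:
        self.calls.append("prepare_for_validation_step")

    def finish_validation_step(self) -> None:
        self.calls.append("finish_validation_step")


def train_step(model: Any, batch: Any) -> Tuple[float, Dict[str, float]]:
    model.prepare_for_training_step()
    loss, metrics = model.get_loss_and_metrics(batch, forward_only=False)
    model.finish_training_step()  # gradient reductions across data-parallel groups
    # A real trainer would clip gradients and step the optimizer here.
    return loss, metrics


def validation_step(model: Any, batch: Any) -> Tuple[float, Dict[str, float]]:
    model.prepare_for_validation_step()
    loss, metrics = model.get_loss_and_metrics(batch, forward_only=True)
    model.finish_validation_step()
    return loss, metrics
```

Keeping the hooks separate from get_loss_and_metrics lets one generic trainer drive different model classes that implement the same interface.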

Outputs

  • loss_value (float) -- Combined weighted loss: kd_loss_weight * kd_loss + sft_loss_weight * sft_loss
  • metrics (dict) -- Dictionary with the keys "loss", "sft_loss", and "kd_loss", holding the averaged loss values
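The combined loss can be written as a small function. The hypothetical combine_kd_sft_losses helper below assumes the per-micro-batch KD and SFT losses are averaged with equal weight and then weighted. The page says only that the metrics are averaged, so the exact reduction (over micro-batches, tokens, or data-parallel ranks) is an assumption.

```python
from typing import Dict, Sequence, Tuple


def combine_kd_sft_losses(
    kd_losses: Sequence[float],
    sft_losses: Sequence[float],
    kd_loss_weight: float = 1.0,
    sft_loss_weight: float = 0.0,
) -> Tuple[float, Dict[str, float]]:
    """Average per-micro-batch losses, then weight them into (loss_value, metrics)."""
    if len(kd_losses) != len(sft_losses) or not kd_losses:
        raise ValueError("need the same, non-zero number of KD and SFT loss values")
    n = len(kd_losses)
    kd_mean = sum(kd_losses) / n
    sft_mean = sum(sft_losses) / n
    loss = kd_loss_weight * kd_mean + sft_loss_weight * sft_mean
    return loss, {"loss": loss, "sft_loss": sft_mean, "kd_loss": kd_mean}
```

Because the combination is linear, averaging before or after weighting yields the same loss_value. With the default sft_loss_weight of 0, loss_value reduces to the KD term alone.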

Usage Examples

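from nemo_aligner.models.nlp.gpt.megatron_gpt_knowledge_distillation import GPTKnowledgeDistillationModel
from nemo_aligner.utils.utils import load_from_nemo

# Assumes `cfg` (the training script's Hydra/OmegaConf config), `trainer`
# (a PyTorch Lightning Trainer), and `batch` (a batch that includes the
# teacher's precomputed top-k logits) have already been created.

# Load a pretrained model for knowledge distillation training
ptl_model = load_from_nemo(
    GPTKnowledgeDistillationModel,
    cfg.model,
    trainer,
    strict=True,
    load_base_model_only=False,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)

# SupervisedTrainer normally makes this call inside its training loop;
# with forward_only=False it runs both the forward and the backward pass
loss_value, metrics = ptl_model.get_loss_and_metrics(batch, forward_only=False)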
