Implementation:NVIDIA NeMo Aligner Train GPT Knowledge Distillation

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Model Compression, Knowledge Distillation
Last Updated 2026-02-08 00:00 GMT

Overview

train_gpt_knowledge_distillation.py is the entry-point script that orchestrates knowledge distillation training for GPT models. It covers data loading with a custom collate function, model initialization, and trainer execution.

Description

This script sets up the complete knowledge distillation training pipeline. It performs the following steps:

  1. Configuration loading -- Uses Hydra with the config path "conf" and config name "gpt_knowledge_distillation". Overrides the model config from the pretrained checkpoint.
  2. Trainer creation -- Resolves and creates a PyTorch Lightning trainer configured for knowledge distillation.
  3. Model loading -- Loads a pretrained GPTKnowledgeDistillationModel from a NeMo checkpoint and optionally initializes PEFT adapters.
  4. Dataset construction -- Builds train and validation datasets using build_train_valid_test_knowledge_distillation_datasets, which loads chunked data containing tokens, labels, loss masks, and precomputed teacher top-k logits.
  5. Custom collate function -- The _collate_fn function pads variable-length sequences in a batch and constructs attention masks and position IDs. It handles tokens, labels, loss_mask, topk_logits, and topk_token_ids.
  6. Training execution -- Creates a SupervisedTrainer and calls fit() to run the training loop.

The collate function is critical because the examples in a batch have different lengths. It pads tokens and labels to a uniform length with the EOS token ID and pads loss_mask with zeros, so padded positions do not contribute to the loss. It pads topk_logits and topk_token_ids (each 3D after batching: [B, seq_len, k]) along the sequence dimension to the same length. It then computes left-to-right attention masks and position IDs via get_ltor_masks_and_position_ids.
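The padding behavior described above can be illustrated with a minimal, dependency-free sketch. It is not the script's _collate_fn: it uses nested Python lists instead of tensors, the names pad_1d, pad_2d, and collate_sketch are hypothetical, zero padding for the top-k fields is an assumption, and the mask convention (True means "may attend") may differ from what get_ltor_masks_and_position_ids returns.

```python
from typing import Any, Dict, List


def pad_1d(seq: List[int], length: int, pad_value: int) -> List[int]:
    """Right-pad a 1D sequence to `length` with `pad_value`."""
    return seq + [pad_value] * (length - len(seq))


def pad_2d(rows: List[list], length: int, k: int, pad_value) -> List[list]:
    """Right-pad a [seq_len, k] nested list to [length, k] with fresh rows."""
    return rows + [[pad_value] * k for _ in range(length - len(rows))]


def collate_sketch(batch: List[Dict[str, Any]], eos_id: int) -> Dict[str, Any]:
    """Hypothetical stand-in for the collate step, using plain lists."""
    max_len = max(len(item["tokens"]) for item in batch)
    k = len(batch[0]["topk_logits"][0])  # width of the teacher top-k rows
    return {
        # tokens and labels are padded with the EOS token ID.
        "tokens": [pad_1d(item["tokens"], max_len, eos_id) for item in batch],
        "labels": [pad_1d(item["labels"], max_len, eos_id) for item in batch],
        # loss_mask is zero-padded so padded positions carry no loss.
        "loss_mask": [pad_1d(item["loss_mask"], max_len, 0) for item in batch],
        # Zero padding for the top-k fields is an assumption of this sketch.
        "topk_logits": [pad_2d(item["topk_logits"], max_len, k, 0.0) for item in batch],
        "topk_token_ids": [pad_2d(item["topk_token_ids"], max_len, k, 0) for item in batch],
        # Causal (left-to-right) mask: position i may attend to positions j <= i.
        "attention_mask": [[j <= i for j in range(max_len)] for i in range(max_len)],
        # Position IDs without any reset at EOS boundaries.
        "position_ids": [list(range(max_len)) for _ in batch],
    }
```

Calling collate_sketch on a batch whose examples have lengths 3 and 1 yields rows of length 3 everywhere, with the shorter example's extra positions filled by eos_id (tokens, labels) or zeros (loss_mask and the top-k fields).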

Usage

Run this script to perform knowledge distillation training after teacher logits have been precomputed. It is typically invoked via the command line with Hydra configuration overrides.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: examples/nlp/gpt/train_gpt_knowledge_distillation.py
  • Lines: 1-205

Signature

def _collate_fn(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):

@hydra_runner(config_path="conf", config_name="gpt_knowledge_distillation")
def main(cfg) -> None:

Import

from nemo_aligner.models.nlp.gpt.megatron_gpt_knowledge_distillation import GPTKnowledgeDistillationModel
from nemo_aligner.algorithms.supervised import SupervisedTrainer
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_knowledge_distillation_datasets

I/O Contract

Inputs

Name | Type | Required | Description
cfg | DictConfig | Yes | Hydra configuration containing model, trainer, data, and pretrained_checkpoint sections
cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the pretrained NeMo checkpoint for the student model
cfg.model.data.data_prefix | str | Yes | Path prefix for the knowledge distillation dataset containing teacher logits
cfg.model.data.seq_length | int | Yes | Maximum sequence length
cfg.model.data.n_chunks | int | Yes | Number of data chunks in the dataset
cfg.model.data.n_examples_per_chunk | int | Yes | Number of examples per data chunk
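
As a rough illustration of the required inputs listed above, the following sketch checks a configuration for missing keys before launching a run. It is not part of NeMo Aligner: the helper missing_keys and the constant REQUIRED_KEYS are hypothetical, and a plain nested dict stands in for the Hydra DictConfig.

```python
from typing import Any, Dict, List

# Dotted paths of the required inputs from the table above.
REQUIRED_KEYS = [
    "pretrained_checkpoint.restore_from_path",
    "model.data.data_prefix",
    "model.data.seq_length",
    "model.data.n_chunks",
    "model.data.n_examples_per_chunk",
]


def missing_keys(cfg: Dict[str, Any], required: List[str] = REQUIRED_KEYS) -> List[str]:
    """Return the dotted paths that are absent or set to None in `cfg`."""
    missing = []
    for dotted in required:
        node: Any = cfg
        for part in dotted.split("."):
            # Stop descending as soon as a level is absent or unset.
            if not isinstance(node, dict) or part not in node or node[part] is None:
                node = None
                break
            node = node[part]
        if node is None:
            missing.append(dotted)
    return missing
```

A check like this could run at the top of a launcher script so that a missing path or chunk count is reported before the checkpoint is loaded.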

Outputs

Name | Type | Description
Trained model checkpoint | File | Saved NeMo checkpoint of the distilled student model
Training logs | Logs | Loss metrics (loss, sft_loss, kd_loss) logged via the experiment logger

Usage Examples

# Command-line invocation:
# python examples/nlp/gpt/train_gpt_knowledge_distillation.py \
#     pretrained_checkpoint.restore_from_path=/path/to/student.nemo \
#     model.data.data_prefix=/path/to/kd_data \
#     model.knowledge_distillation.kd_loss=fwd_kl \
#     model.knowledge_distillation.kd_loss_weight=1.0 \
#     model.knowledge_distillation.sft_loss_weight=0.0
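
The overrides above select a forward KL distillation loss and set the two loss weights. The sketch below shows one plausible reading of those options for a single token position. The model's actual loss code is not shown on this page, so two points are assumptions: that both distributions are renormalized over the teacher's k retained entries, and that the logged loss is a weighted sum of kd_loss and sft_loss. The function names are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over a short list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def forward_kl(teacher_topk_logits: List[float], student_topk_logits: List[float]) -> float:
    """KL(teacher || student) over the same k token IDs (assumed renormalized)."""
    t = log_softmax(teacher_topk_logits)
    s = log_softmax(student_topk_logits)
    # sum_i p_teacher(i) * (log p_teacher(i) - log p_student(i))
    return sum(math.exp(lt) * (lt - ls) for lt, ls in zip(t, s))


def combined_loss(kd_loss: float, sft_loss: float,
                  kd_loss_weight: float = 1.0, sft_loss_weight: float = 0.0) -> float:
    """Weighted sum of the two terms (assumed from the option names)."""
    return kd_loss_weight * kd_loss + sft_loss_weight * sft_loss
```

Under these assumptions, the weights in the command above (kd_loss_weight=1.0, sft_loss_weight=0.0) would train the student on the distillation term alone.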
