
Implementation:NVIDIA NeMo Aligner Compute Topk Logits

From Leeroopedia


Knowledge Sources
Domains: Natural Language Processing, Model Compression, Knowledge Distillation, Synthetic Data Generation
Last Updated: 2026-02-08 00:00 GMT

Overview

compute_topk_logits.py is a teacher logit extraction script. It runs a teacher GPT model over a dataset and saves the top-k logits, their token IDs, and the log-sum-exp values to disk for subsequent knowledge distillation training.

Description

This script is the first phase of the knowledge distillation pipeline. It loads a pretrained MegatronGPTModel (the teacher) and iterates over an SFT dataset in batches, computing the teacher's output logits for each sequence. For each token position, it extracts the top-k logits and their corresponding token IDs, along with the log-sum-exp of all logits (used for normalization).
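The per-position extraction described above can be illustrated with a minimal, framework-free sketch. It is not the script's code: the real pipeline operates on model output tensors through compute_topk_logits_in_batched_sequence, and the function name `topk_and_logsumexp` and the toy logits below are hypothetical.

```python
import heapq
import math
from typing import List, Tuple


def topk_and_logsumexp(logits: List[float], k: int) -> Tuple[List[float], List[int], float]:
    """Hypothetical helper: top-k logits, their token IDs, and log-sum-exp
    over the full vocabulary for a single token position."""
    # Select the k largest (logit, token_id) pairs, largest first.
    top = heapq.nlargest(k, ((value, idx) for idx, value in enumerate(logits)))
    topk_logits = [value for value, _ in top]
    topk_token_ids = [idx for _, idx in top]
    # Numerically stable log-sum-exp: shift by the maximum before exponentiating.
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(v - shift) for v in logits))
    return topk_logits, topk_token_ids, log_sum_exp


# Toy vocabulary of five tokens at one position (made-up values).
logits = [2.0, -1.0, 0.5, 3.0, 0.0]
topk_logits, topk_token_ids, log_sum_exp = topk_and_logsumexp(logits, k=2)
```

Storing the log-sum-exp alongside the top-k values is what lets a later stage normalize the saved logits without keeping the full vocabulary distribution.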

Key implementation details:

  • Incremental computation -- The script supports resuming from a previous run by reading already-processed indices from the output file. This avoids redundant computation.
  • Data parallel processing -- The script distributes batches across data parallel ranks, with each rank processing its assigned slice.
  • Batched sequence processing -- Uses compute_topk_logits_in_batched_sequence to efficiently process sequences in micro-batches.
  • Output format -- Results are written as JSONL (one JSON object per line), with each object containing the original tokens, labels, loss_mask, and the extracted topk_logits, topk_token_ids, log_sum_exp_logits, and index.
  • Configurable range -- Supports start_from_idx and end_at_idx parameters to process a subset of the dataset.
  • Padding -- Pads the last batch to the global batch size using the last example as a dummy.
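The resume, range, padding, and data-parallel points in the list above can be sketched together in plain Python. This is an illustration under assumptions, not the script's implementation: the helper names (`load_done_indices`, `plan_batches`, `rank_slice`) are hypothetical, end_at_idx is treated as inclusive (suggested by its default of len(dataset) - 1), and each rank is assumed to take a contiguous slice of the global batch.

```python
import json
import os
from typing import List, Set, Tuple


def load_done_indices(output_path: str) -> Set[int]:
    """Collect the `index` field of every record already in the output file."""
    if not os.path.exists(output_path):
        return set()
    with open(output_path, encoding="utf-8") as f:
        return {json.loads(line)["index"] for line in f if line.strip()}


def plan_batches(todo: List[int], global_batch_size: int) -> List[Tuple[List[int], int]]:
    """Split indices into global batches; pad the last one by repeating its
    final index, and report how many padded entries each batch holds."""
    batches = []
    for start in range(0, len(todo), global_batch_size):
        chunk = todo[start:start + global_batch_size]
        num_padding = global_batch_size - len(chunk)
        batches.append((chunk + [chunk[-1]] * num_padding, num_padding))
    return batches


def rank_slice(batch: List[int], dp_rank: int, dp_size: int) -> List[int]:
    """Contiguous slice of one global batch for one data-parallel rank (assumed layout)."""
    per_rank = len(batch) // dp_size
    return batch[dp_rank * per_rank:(dp_rank + 1) * per_rank]


# Example: indices 0-2 were written by an earlier run; process the range [0, 9].
done = {0, 1, 2}
start_from_idx, end_at_idx = 0, 9
todo = [i for i in range(start_from_idx, end_at_idx + 1) if i not in done]
batches = plan_batches(todo, global_batch_size=4)
```

In this example the second batch is `[7, 8, 9, 9]` with one padded entry, so a writer needs the padding count to avoid emitting index 9 twice.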

The write_generations helper function handles serialization and excludes padding examples from the output.
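A rough sketch of that serialization step follows. It mirrors the argument list shown in the Signature section but is not the actual helper: the name `write_records_sketch` is hypothetical, inputs are plain Python lists rather than tensors, and padded examples are assumed to sit at the end of the batch.

```python
import json
from typing import Dict, List


def write_records_sketch(
    output_path: str,
    indices: List[int],
    batch: Dict[str, List],
    topk_logits: List,
    topk_token_ids: List,
    log_sum_exp_logits: List,
    num_padding: int,
) -> int:
    """Append one JSON object per real example; skip trailing padded examples.

    Returns the number of records written.
    """
    num_real = len(indices) - num_padding  # assumption: padding is at the end
    with open(output_path, "a", encoding="utf-8") as f:
        for i in range(num_real):
            record = {
                "tokens": batch["tokens"][i],
                "labels": batch["labels"][i],
                "loss_mask": batch["loss_mask"][i],
                "topk_logits": topk_logits[i],
                "topk_token_ids": topk_token_ids[i],
                "log_sum_exp_logits": log_sum_exp_logits[i],
                "index": indices[i],
            }
            f.write(json.dumps(record) + "\n")
    return num_real
```

Opening the file in append mode keeps records from earlier batches and earlier runs, which is what makes index-based resumption possible.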

Usage

Run this script before knowledge distillation training to precompute teacher logits. The output JSONL file is then consumed by the knowledge distillation dataset builder.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: examples/nlp/synthetic_data_gen/compute_topk_logits.py
  • Lines: 1-157

Signature

def write_generations(output_path, indices, batch, topk_logits, topk_token_ids, log_sum_exp_logits, num_padding):

@hydra_runner(config_path="conf", config_name="compute_topk_logits")
def main(cfg) -> None:

Import

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo_aligner.utils.distributed import compute_topk_logits_in_batched_sequence
from nemo_aligner.data.nlp.builders import build_sft_dataset

I/O Contract

Inputs

Name | Type | Required | Description
cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the teacher model NeMo checkpoint
cfg.data.data | DictConfig | Yes | Dataset configuration for the SFT dataset
cfg.data.chat | bool | Yes | Whether the dataset is in chat format
cfg.data.chat_prompt_tokens | DictConfig | No | Special tokens for chat prompt formatting
cfg.top_k | int | Yes | Number of top logits to extract per token position
cfg.output_path | str | Yes | File path for the output JSONL file
cfg.start_from_idx | int | No | Starting dataset index (default: 0)
cfg.end_at_idx | int | No | Ending dataset index (default: len(dataset) - 1)

Outputs

Name | Type | Description
JSONL output file | File | Each line is a JSON object with keys: tokens, labels, loss_mask, topk_logits, topk_token_ids, log_sum_exp_logits, index
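As a hedged illustration of how a consumer might read this file, the sketch below checks each record for the documented keys and converts the stored logits into log-probabilities via log p = logit - log_sum_exp. The nested shapes (topk_logits as one list of k values per position, log_sum_exp_logits as one value per position) are assumptions, and `iter_records` and `topk_log_probs` are hypothetical names, not part of NeMo Aligner.

```python
import json
from typing import Dict, Iterator, List

EXPECTED_KEYS = {
    "tokens", "labels", "loss_mask",
    "topk_logits", "topk_token_ids", "log_sum_exp_logits", "index",
}


def iter_records(path: str) -> Iterator[Dict]:
    """Yield one dict per non-empty JSONL line, checking the documented keys."""
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            record = json.loads(line)
            missing = EXPECTED_KEYS - record.keys()
            if missing:
                raise ValueError(f"line {line_no}: missing keys {sorted(missing)}")
            yield record


def topk_log_probs(record: Dict) -> List[List[float]]:
    """Normalize stored logits per position: log p = logit - log_sum_exp.

    Assumes topk_logits has shape [seq_len][k] and log_sum_exp_logits [seq_len].
    """
    return [
        [logit - lse for logit in row]
        for row, lse in zip(record["topk_logits"], record["log_sum_exp_logits"])
    ]
```

Because only k entries per position are stored, the resulting probabilities sum to less than 1 at each position; how the remaining mass is handled is up to the distillation loss.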

Usage Examples

# Command-line invocation:
# python examples/nlp/synthetic_data_gen/compute_topk_logits.py \
#     pretrained_checkpoint.restore_from_path=/path/to/teacher.nemo \
#     data.data.file_path=/path/to/dataset.jsonl \
#     data.chat=True \
#     top_k=20 \
#     output_path=/path/to/output/topk_logits.jsonl
