Implementation:NVIDIA NeMo Aligner Compute Topk Logits
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Model Compression, Knowledge Distillation, Synthetic Data Generation |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
compute_topk_logits.py is a script that runs a teacher GPT model over a dataset and saves, for each token position, the top-k logits, their token IDs, and the log-sum-exp of the full logit vector to disk for subsequent knowledge distillation training.
Description
This script is the first phase of the knowledge distillation pipeline. It loads a pretrained MegatronGPTModel (the teacher) and iterates over an SFT dataset in batches, computing the teacher's output logits for each sequence. For each token position, it extracts the top-k logits and their corresponding token IDs, along with the log-sum-exp of all logits over the vocabulary. The log-sum-exp serves as the softmax normalizer: subtracting it from a stored logit gives the teacher's log-probability for that token, so the full logit vector does not need to be saved.
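The per-position extraction described above can be illustrated with a small standard-library sketch. The helper name `topk_with_logsumexp` and the toy logits are hypothetical; the actual script operates on model output tensors through compute_topk_logits_in_batched_sequence, whose internals are not shown here.

```python
import math

def topk_with_logsumexp(logits, k):
    """Return (topk_logits, topk_token_ids, log_sum_exp) for one token position.

    Hypothetical helper for illustration; not part of NeMo Aligner.
    """
    # Numerically stable log-sum-exp over the full vocabulary.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    # Token IDs of the k largest logits, highest first.
    ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    return [logits[i] for i in ids], ids, lse

# Toy vocabulary of four tokens at a single position.
logits = [2.0, -1.0, 0.5, 3.0]
topk_logits, topk_token_ids, log_sum_exp = topk_with_logsumexp(logits, k=2)

# exp(logit - log_sum_exp) over the whole vocabulary sums to 1, which is
# why storing the log-sum-exp is enough to normalize the retained logits.
total_prob = sum(math.exp(x - log_sum_exp) for x in logits)
```

Only `k` logits, `k` token IDs, and one scalar are kept per position, which is far smaller than a full vocabulary-sized vector.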
Key implementation details:
- Incremental computation -- The script supports resuming from a previous run by reading already-processed indices from the output file. This avoids redundant computation.
- Data parallel processing -- The script distributes batches across data parallel ranks, with each rank processing its assigned slice.
- Batched sequence processing -- Uses compute_topk_logits_in_batched_sequence to efficiently process sequences in micro-batches.
- Output format -- Results are written as JSONL (one JSON object per line), with each object containing the original tokens, labels, loss_mask, and the extracted topk_logits, topk_token_ids, log_sum_exp_logits, and index.
- Configurable range -- Supports start_from_idx and end_at_idx parameters to process a subset of the dataset.
- Padding -- If the last batch is smaller than the global batch size, it is padded by repeating the last example as a dummy.
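A minimal sketch of how resuming, range selection, and data-parallel slicing could fit together is shown below. All three function names are hypothetical, the inclusive treatment of end_at_idx is an assumption based on its stated default, and the contiguous per-rank slicing is one possible scheme rather than the script's confirmed behavior.

```python
import json
import os

def load_processed_indices(output_path):
    """Collect the `index` field of every record already in the output file."""
    if not os.path.exists(output_path):
        return set()
    done = set()
    with open(output_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                done.add(json.loads(line)["index"])
    return done

def remaining_indices(dataset_len, done, start_from_idx=None, end_at_idx=None):
    """Indices still to process within [start_from_idx, end_at_idx]."""
    start = 0 if start_from_idx is None else start_from_idx
    # Assumption: end_at_idx is inclusive, as its default of len(dataset) - 1 suggests.
    end = dataset_len - 1 if end_at_idx is None else end_at_idx
    return [i for i in range(start, end + 1) if i not in done]

def rank_slice(batch_indices, dp_rank, dp_size):
    """Contiguous slice of one global batch owned by a data-parallel rank.

    Assumes len(batch_indices) is divisible by dp_size (hence the padding step).
    """
    per_rank = len(batch_indices) // dp_size
    return batch_indices[dp_rank * per_rank:(dp_rank + 1) * per_rank]
```

For example, with indices 0 and 2 already written and a dataset of five examples, `remaining_indices(5, {0, 2})` yields `[1, 3, 4]`, so a restarted run only recomputes what is missing.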
The write_generations helper function handles serialization, carefully excluding padding examples from the output.
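The interaction between padding and serialization might look roughly like the following. This is a simplified sketch, not the real write_generations: `pad_to_global_batch` and `write_records` are hypothetical names, and examples are plain dicts of lists here, whereas the script works with batched tensors.

```python
import json

def pad_to_global_batch(examples, global_batch_size):
    """Pad a short final batch by repeating its last example.

    Returns the padded list and the number of padding entries appended.
    """
    num_padding = global_batch_size - len(examples)
    return examples + [examples[-1]] * num_padding, num_padding

def write_records(output_path, indices, examples, topk_logits,
                  topk_token_ids, log_sum_exp_logits, num_padding):
    """Append one JSON line per real example, skipping trailing padding."""
    num_real = len(examples) - num_padding
    with open(output_path, "a", encoding="utf-8") as f:
        for i in range(num_real):
            record = dict(examples[i])  # tokens, labels, loss_mask
            record["topk_logits"] = topk_logits[i]
            record["topk_token_ids"] = topk_token_ids[i]
            record["log_sum_exp_logits"] = log_sum_exp_logits[i]
            record["index"] = indices[i]
            f.write(json.dumps(record) + "\n")
```

Because padding entries sit at the end of the batch, dropping the last `num_padding` items is enough to keep duplicates out of the output, which also keeps the resume logic from seeing the same index twice.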
Usage
Run this script before knowledge distillation training to precompute teacher logits. The output JSONL file is then consumed by the knowledge distillation dataset builder.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/nlp/synthetic_data_gen/compute_topk_logits.py
- Lines: 1-157
Signature
def write_generations(output_path, indices, batch, topk_logits, topk_token_ids, log_sum_exp_logits, num_padding):
@hydra_runner(config_path="conf", config_name="compute_topk_logits")
def main(cfg) -> None:
Import
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo_aligner.utils.distributed import compute_topk_logits_in_batched_sequence
from nemo_aligner.data.nlp.builders import build_sft_dataset
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the teacher model NeMo checkpoint |
| cfg.data.data | DictConfig | Yes | Dataset configuration for the SFT dataset |
| cfg.data.chat | bool | Yes | Whether the dataset is in chat format |
| cfg.data.chat_prompt_tokens | DictConfig | No | Special tokens for chat prompt formatting |
| cfg.top_k | int | Yes | Number of top logits to extract per token position |
| cfg.output_path | str | Yes | File path for the output JSONL file |
| cfg.start_from_idx | int | No | Starting dataset index (default: 0) |
| cfg.end_at_idx | int | No | Ending dataset index (default: len(dataset) - 1) |
Outputs
| Name | Type | Description |
|---|---|---|
| JSONL output file | File | Each line is a JSON object with keys: tokens, labels, loss_mask, topk_logits, topk_token_ids, log_sum_exp_logits, index |
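On the consumer side, a record can be turned back into teacher log-probabilities with a few lines of Python. The sketch below assumes that topk_logits and topk_token_ids are nested as [sequence position][k] and that log_sum_exp_logits holds one value per position; this layout is not confirmed by the source. The sample line and the helper `topk_logprobs` are made up for illustration.

```python
import json
import math

def topk_logprobs(record):
    """Per position, map token ID -> teacher log-probability for the kept tokens."""
    out = []
    for logits, ids, lse in zip(record["topk_logits"],
                                record["topk_token_ids"],
                                record["log_sum_exp_logits"]):
        # log softmax of a retained token: logit - log_sum_exp.
        out.append({tok: logit - lse for tok, logit in zip(ids, logits)})
    return out

# Made-up record with two positions and top_k=2.
line = ('{"tokens": [5, 9], "labels": [9, 2], "loss_mask": [0, 1], '
        '"topk_logits": [[3.0, 2.0], [1.5, 0.5]], '
        '"topk_token_ids": [[9, 4], [2, 7]], '
        '"log_sum_exp_logits": [3.4, 2.0], "index": 0}')
record = json.loads(line)
logprobs = topk_logprobs(record)

# Probability mass covered by the retained tokens at each position;
# the remainder belongs to tokens outside the top-k.
covered_mass = [sum(math.exp(lp) for lp in pos.values()) for pos in logprobs]
```

The covered mass is below 1 whenever the vocabulary is larger than `top_k`, so a distillation loss built on these records has to decide how to treat the residual probability.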
Usage Examples
# Command-line invocation:
# python examples/nlp/synthetic_data_gen/compute_topk_logits.py \
# pretrained_checkpoint.restore_from_path=/path/to/teacher.nemo \
# data.data.file_path=/path/to/dataset.jsonl \
# data.chat=True \
# top_k=20 \
# output_path=/path/to/output/topk_logits.jsonl