

Workflow:Alibaba ROLL Knowledge Distillation Pipeline

From Leeroopedia


Knowledge Sources
Domains LLMs, Knowledge_Distillation, Distributed_Training
Last Updated 2026-02-07 19:00 GMT

Overview

End-to-end process for transferring knowledge from a large teacher LLM to a smaller student LLM using logit-based distillation with configurable KL divergence objectives.

Description

This workflow implements the Knowledge Distillation pipeline in the ROLL framework. It trains a smaller student model to mimic a larger teacher model's output distribution, producing a compact model that retains much of the teacher's capability. The pipeline supports multiple distillation objectives (forward KL, reverse KL, adaptive KL, skewed variants) and efficient logit transfer between teacher and student workers using NCCL or IPC backends. Only the top-k teacher logits are transferred to minimize communication overhead. Both text-only LLM and vision-language model (VLM) distillation are supported.
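As a minimal, framework-free illustration of the two basic objectives named above (this is plain Python on toy distributions, not ROLL code):

```python
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

# Toy bimodal teacher vs. unimodal student over a 4-token vocabulary.
teacher = [0.45, 0.05, 0.45, 0.05]
student = [0.80, 0.10, 0.05, 0.05]

# forward KL = KL(teacher || student): large wherever the teacher has mass
# the student lacks, so minimizing it is mode-covering.
forward_kl = kl(teacher, student)
# reverse KL = KL(student || teacher): large wherever the student puts mass
# the teacher lacks, so minimizing it is mode-seeking.
reverse_kl = kl(student, teacher)

print(f"forward KL = {forward_kl:.3f}, reverse KL = {reverse_kl:.3f}")
```

Here the student ignores the teacher's second mode, so the forward KL penalizes it more heavily than the reverse KL; the adaptive and skewed variants interpolate between these two behaviors.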

Usage

Execute this workflow when you have a large, capable teacher model (e.g., Qwen2.5-7B or Qwen2.5-14B) and want to transfer its knowledge to a smaller student model (e.g., Qwen2.5-1.5B) for more efficient inference deployment, while preserving as much of the teacher's quality as possible.

Execution Steps

Step 1: Environment Setup and Configuration

Prepare the compute environment and define the Hydra YAML configuration specifying both teacher and student model paths, distillation parameters (KD objective, temperature, loss weight, top-k logits), and distributed training backends for both models. Configure the logits transfer backend (NCCL or IPC) based on the deployment topology.

Key considerations:

  • Choose the KD objective based on the desired behavior: forward_kl (teacher-to-student KL) is mode-covering, pushing the student to spread mass over every teacher mode; reverse_kl (student-to-teacher KL) is mode-seeking, concentrating the student on the teacher's dominant modes
  • The distill_loss_weight parameter controls the balance between distillation loss and SFT loss (0.0 = pure SFT, 1.0 = pure distillation)
  • logits_topk controls how many of the teacher's top logits are transferred (typical: 64), trading off fidelity for communication efficiency
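A Hydra YAML sketch of such a configuration is shown below. Only distill_loss_weight, logits_topk, and distill_on_prompt are parameter names quoted on this page; the remaining keys and the overall layout are illustrative assumptions, not the exact ROLL schema.

```yaml
# Illustrative config sketch -- key names beyond those quoted on this page
# are assumptions, not ROLL's actual schema.
teacher:
  model_args:
    model_name_or_path: Qwen/Qwen2.5-7B-Instruct
  strategy_args:
    strategy_name: deepspeed_infer      # frozen, forward-only
student:
  model_args:
    model_name_or_path: Qwen/Qwen2.5-1.5B-Instruct
  strategy_args:
    strategy_name: deepspeed_train
distill:
  kd_objective: forward_kl              # forward_kl | reverse_kl | adaptive/skewed variants
  temperature: 2.0
  distill_loss_weight: 0.5              # 0.0 = pure SFT, 1.0 = pure distillation
  logits_topk: 64
  distill_on_prompt: false
logits_transfer_backend: nccl           # nccl or ipc, per deployment topology
```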

Step 2: Dataset Preparation

Prepare an instruction-response dataset in a format compatible with both teacher and student tokenizers. The pipeline tokenizes data using the model's chat template and creates labels for the SFT component of the combined loss. Both teacher and student share the same input sequences.

What happens:

  • Dataset is loaded and tokenized with the student model's chat template
  • Labels are created with prompt tokens masked (IGNORE_INDEX) and response tokens preserved
  • If distill_on_prompt is enabled, distillation loss is computed on prompt tokens as well
  • Data is batched and distributed across data-parallel ranks
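The label-masking step can be sketched as follows. IGNORE_INDEX is named on this page; the -100 value is the standard Hugging Face/PyTorch ignore index, and the helper below is an illustrative sketch, not ROLL's implementation.

```python
# Standard ignore index used by PyTorch cross-entropy (an assumption here,
# not quoted from ROLL).
IGNORE_INDEX = -100

def build_labels(input_ids, prompt_len):
    """Mask prompt tokens so SFT cross-entropy covers only the response."""
    return [IGNORE_INDEX] * prompt_len + input_ids[prompt_len:]

input_ids = [101, 102, 103, 201, 202, 203]   # toy ids: 3 prompt + 3 response tokens
labels = build_labels(input_ids, prompt_len=3)
print(labels)  # [-100, -100, -100, 201, 202, 203]
```

When distill_on_prompt is enabled, only the SFT labels stay masked; the KD loss is still evaluated over the prompt positions.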

Step 3: Distributed Worker Initialization

Launch the Ray cluster and initialize two worker groups: the teacher inference cluster (frozen, forward-only) and the student training cluster (being optimized). Both clusters load their respective models with the configured strategy backends. The logits transfer channel is established between teacher and student workers.

Key considerations:

  • Teacher workers use an inference-only strategy (DeepSpeed infer or Megatron infer) since no gradients are needed
  • Student workers use a training strategy (DeepSpeed train, Megatron train, or FSDP2)
  • IPC-based logits transfer is fastest when teacher and student workers share a node, since CUDA IPC passes GPU memory handles directly; NCCL is required when they span multiple nodes
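A hypothetical backend chooser illustrates the topology rule (this is not ROLL's API; it assumes the standard CUDA behavior that IPC memory handles only work between processes on the same node, while NCCL communicates across nodes):

```python
# Hypothetical helper, not part of ROLL. Encodes the assumption that CUDA
# IPC is intra-node only and NCCL handles the inter-node case.
def choose_logits_backend(teacher_node: str, student_node: str) -> str:
    if teacher_node == student_node:
        return "ipc"    # zero-copy GPU memory handles within one node
    return "nccl"       # collective transport over the node interconnect

print(choose_logits_backend("node-0", "node-0"))  # ipc
print(choose_logits_backend("node-0", "node-1"))  # nccl
```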

Step 4: Teacher Forward Pass

For each training batch, pass the input sequences through the frozen teacher model to obtain output logits. Extract the top-k logits and their corresponding token indices to reduce the data volume for transfer.

What happens:

  • Teacher model runs forward pass on the batch
  • Top-k logits and indices are extracted from the full vocabulary distribution
  • Compressed logits are transferred to student workers via the configured backend (NCCL or IPC)
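The compression step can be sketched in NumPy (function name and shapes are illustrative; ROLL operates on GPU tensors):

```python
import numpy as np

# Illustrative sketch of top-k logit compression before transfer.
def compress_logits(logits: np.ndarray, k: int):
    """Keep only the k largest logits per position, plus their vocab indices."""
    # argpartition is O(V) per row; the trailing slice holds the top-k
    # values in arbitrary order, which is fine for reconstructing the KD loss.
    topk_idx = np.argpartition(logits, -k, axis=-1)[..., -k:]
    topk_val = np.take_along_axis(logits, topk_idx, axis=-1)
    return topk_val, topk_idx

rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 5, 32000))       # (batch, seq_len, vocab)
vals, idx = compress_logits(logits, k=64)
print(vals.shape, idx.shape)                  # (2, 5, 64) (2, 5, 64)
```

With a 32k vocabulary and logits_topk of 64, this cuts the transferred volume by roughly 500x (values plus indices) relative to sending the full distribution.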

Step 5: Student Training with Distillation Loss

Compute the student model's logits on the same input sequences. Combine the distillation loss (KL divergence between teacher and student distributions using the transferred top-k logits) with the standard SFT cross-entropy loss using the configured distill_loss_weight. Apply gradient updates to the student model.

Key considerations:

  • Temperature scaling softens both distributions before computing KL divergence
  • The combined loss is: (1 - distill_loss_weight) * SFT_loss + distill_loss_weight * KD_loss
  • Gradient accumulation and distributed reduction follow the standard training strategy patterns
  • Only the student model parameters are updated; the teacher remains frozen
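A toy sketch of this combined objective for a single token position, using forward KL over full distributions (function names are illustrative, not ROLL's API; the real loss is computed from the transferred top-k logits):

```python
import numpy as np

def softmax(x, T=1.0):
    """Temperature-scaled softmax; T > 1 softens the distribution."""
    z = x / T
    z = z - z.max()                      # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def combined_loss(student_logits, teacher_logits, target_id,
                  temperature=2.0, distill_loss_weight=0.5):
    # SFT cross-entropy against the ground-truth token, at temperature 1.
    sft = -np.log(softmax(student_logits)[target_id])
    # Forward KL between temperature-softened teacher and student distributions.
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kd = float(np.sum(p * (np.log(p) - np.log(q))))
    return (1 - distill_loss_weight) * sft + distill_loss_weight * kd

student = np.array([2.0, 0.5, -1.0, 0.1])
teacher = np.array([3.0, 1.0, -2.0, 0.0])
print(f"{combined_loss(student, teacher, target_id=0):.4f}")
```

Setting distill_loss_weight to 0.0 recovers plain SFT, and a student whose logits match the teacher's incurs zero KD loss, matching the interpolation described in Step 1.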

Step 6: Validation and Checkpointing

Periodically evaluate the student model on a held-out validation set by computing both SFT loss and distillation loss. Save student model checkpoints at configured intervals and log training metrics to the tracking backend.

Key considerations:

  • Monitor both individual losses (SFT and KD) and the combined loss for training diagnostics
  • Student checkpoints can be directly deployed as standalone models
  • Compare student validation metrics against teacher baselines to measure knowledge transfer effectiveness

Execution Diagram

GitHub URL

Workflow Repository