
Principle:Microsoft DeepSpeedExamples Supervised Fine Tuning

From Leeroopedia


Overview

A transfer learning technique that adapts a pre-trained language model to follow instructions by training on curated prompt-response demonstrations.

Description

Supervised Fine-Tuning (SFT) is the first step in the Reinforcement Learning from Human Feedback (RLHF) alignment pipeline. It takes a pre-trained causal language model and continues training on high-quality instruction-following data. This teaches the model the expected response format before reward model training and reinforcement learning fine-tuning take place.

Through SFT, the model learns to produce helpful, structured responses rather than simply predicting the next token in a sequence. The training data consists of curated prompt-response demonstrations, where each example pairs an instruction (the prompt) with a desired output (the demonstration response). By exposing the model to many such pairs, SFT shifts the model's distribution from generic language modeling toward instruction-following behavior.

In the DeepSpeed-Chat pipeline, SFT corresponds to Step 1 of the three-step RLHF process:

  1. Step 1 -- Supervised Fine-Tuning (SFT): Train on demonstration data to learn instruction-following format.
  2. Step 2 -- Reward Model Training: Train a reward model on human preference comparisons.
  3. Step 3 -- RLHF Fine-Tuning (PPO): Optimize the SFT model using the reward model via Proximal Policy Optimization.

Usage

Use Supervised Fine-Tuning when initializing the RLHF pipeline. SFT produces the base policy model that will be further refined by PPO in Step 3. The SFT-trained model also serves as the reference model against which KL-divergence is measured during reinforcement learning, preventing the policy from drifting too far from coherent language.

SFT is also used standalone for instruction-tuning without full RLHF. In many practical deployments, SFT alone provides substantial improvements in instruction-following ability and can be sufficient when human preference data or reward model infrastructure is unavailable.

When to apply SFT:

  • As the first stage of an RLHF alignment pipeline.
  • When adapting a general-purpose language model to follow domain-specific instructions.
  • When high-quality demonstration data is available but preference comparison data is not.
  • As a warm-start before applying Direct Preference Optimization (DPO) or other alignment methods.

Theoretical Basis

SFT minimizes the cross-entropy loss on demonstration data. Given a prompt x and a demonstration response y = (y_1, y_2, ..., y_T), the loss is defined as:

L_SFT = - sum_{t=1}^{T} log P(y_t | y_{<t}, x)

where x is the input prompt, y is the demonstration response, and P(y_t | y_{<t}, x) is the model's predicted probability of token y_t given all preceding tokens and the prompt.
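To make the formula concrete, here is a minimal pure-Python sketch (an illustration, not code from DeepSpeed-Chat) that evaluates L_SFT for a toy model whose per-token predicted probabilities are given directly:

```python
import math

def sft_loss(token_probs):
    """Negative log-likelihood of the demonstration response.

    token_probs[t] stands for P(y_t | y_{<t}, x), the model's predicted
    probability of the t-th response token. The prompt x conditions the
    model but contributes no terms of its own to the sum.
    """
    return -sum(math.log(p) for p in token_probs)

# Toy response of three tokens with these predicted probabilities:
loss = sft_loss([0.5, 0.8, 0.9])  # -(ln 0.5 + ln 0.8 + ln 0.9)
```

A perfectly confident model (every probability 1.0) yields zero loss; lower probabilities on the demonstration tokens increase it, so minimizing L_SFT pushes probability mass onto the demonstrated response.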

Conceptually, this is standard causal language model fine-tuning but applied to instruction-formatted data. The key distinction from pre-training is not the loss function itself but the nature of the training distribution: rather than learning from unstructured web text, the model learns from curated instruction-response pairs that encode the desired behavior.
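As an illustration of how instruction-formatted data is commonly prepared (a hand-rolled sketch, not DeepSpeed-Chat code; whether the prompt tokens are masked out of the loss varies by codebase), prompt and response token ids are concatenated into one sequence, and prompt positions in the label sequence are set to an ignore index such as -100 so that only response tokens contribute to the loss:

```python
IGNORE_INDEX = -100  # label value excluded from the cross-entropy loss

def build_example(prompt_ids, response_ids):
    """Concatenate prompt and response; mask the prompt span in the labels."""
    input_ids = list(prompt_ids) + list(response_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

# Hypothetical token ids for a prompt and its demonstration response:
inp, lab = build_example([101, 7, 42], [88, 102])
# inp == [101, 7, 42, 88, 102]
# lab == [-100, -100, -100, 88, 102]
```

With this layout, a standard causal LM loss over `labels` trains the model to generate the response given the prompt, rather than to reproduce the prompt itself.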

In the DeepSpeed-Chat implementation, the loss computation supports optional fp32 upcasting for numerical stability when training in lower precision (fp16 or bf16). The shifted cross-entropy loss is computed as:

# From model_utils.py -- causal_lm_model_to_fp32_loss
# Upcast the logits to fp32 before computing the loss, so the
# softmax/cross-entropy is numerically stable even under fp16/bf16.
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length)
)

This ensures that even when the model forward pass runs in half precision, the loss gradients remain numerically stable.
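A small numerical illustration of why the upcast matters (using NumPy rather than the actual training code): fp16 can represent values only up to about 65504, so exponentiating even a moderate logit overflows to infinity in half precision, while the same computation after upcasting to fp32 stays finite:

```python
import numpy as np

logit = 12.0  # exp(12) is about 162755, beyond fp16's maximum of ~65504

with np.errstate(over="ignore"):
    exp_fp16 = np.exp(np.float16(logit))  # overflows to inf in half precision

# Upcast first, as the fp32-loss helper does, then exponentiate:
exp_fp32 = np.exp(np.float16(logit).astype(np.float32))  # finite

print(exp_fp16, exp_fp32)
```

The same effect applies inside the softmax of the cross-entropy loss, which is why the helper casts `lm_logits` to fp32 before the loss rather than after.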
