Principle: Microsoft DeepSpeedExamples Supervised Fine-Tuning
Overview
A transfer learning technique that adapts a pre-trained language model to follow instructions by training on curated prompt-response demonstrations.
Description
Supervised Fine-Tuning (SFT) is the first step in the Reinforcement Learning from Human Feedback (RLHF) alignment pipeline. It takes a pre-trained causal language model and continues training on high-quality instruction-following data. This teaches the model the expected response format before reward model training and reinforcement learning fine-tuning take place.
Through SFT, the model learns to produce helpful, structured responses rather than simply predicting the next token in a sequence. The training data consists of curated prompt-response demonstrations, where each example pairs an instruction (the prompt) with a desired output (the demonstration response). By exposing the model to many such pairs, SFT shifts the model's distribution from generic language modeling toward instruction-following behavior.
In the DeepSpeed-Chat pipeline, SFT corresponds to Step 1 of the three-step RLHF process:
- Step 1 -- Supervised Fine-Tuning (SFT): Train on demonstration data to learn instruction-following format.
- Step 2 -- Reward Model Training: Train a reward model on human preference comparisons.
- Step 3 -- RLHF Fine-Tuning (PPO): Optimize the SFT model using the reward model via Proximal Policy Optimization.
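Step 1 can be pictured as one standard next-token training step over instruction-formatted data. The following is a minimal sketch of such a step, using a tiny stand-in model instead of a real pre-trained checkpoint; all sizes, names, and the toy architecture are illustrative assumptions, not DeepSpeed-Chat code.

```python
import torch

# Tiny stand-in "causal LM": embedding + linear head (illustrative only).
vocab_size, hidden = 32, 16
model = torch.nn.Sequential(
    torch.nn.Embedding(vocab_size, hidden),
    torch.nn.Linear(hidden, vocab_size),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# One batch of token ids: prompt and demonstration response concatenated.
input_ids = torch.randint(0, vocab_size, (2, 8))  # (batch, seq_len)

# Next-token objective: predict token t+1 from tokens up to t.
logits = model(input_ids)                # (batch, seq_len, vocab)
shift_logits = logits[:, :-1, :]         # predictions for positions 1..T-1
shift_labels = input_ids[:, 1:]          # targets shifted by one
loss = torch.nn.functional.cross_entropy(
    shift_logits.reshape(-1, vocab_size), shift_labels.reshape(-1)
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

A real SFT run differs mainly in scale (a pre-trained model, curated demonstration data, DeepSpeed for distributed training), not in the shape of this step.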
Usage
Use Supervised Fine-Tuning when initializing the RLHF pipeline. SFT produces the base policy model that will be further refined by PPO in Step 3. The SFT-trained model also serves as the reference model against which KL-divergence is measured during reinforcement learning, preventing the policy from drifting too far from coherent language generation.
SFT is also used standalone for instruction-tuning without full RLHF. In many practical deployments, SFT alone provides substantial improvements in instruction-following ability and can be sufficient when human preference data or reward model infrastructure is unavailable.
When to apply SFT:
- As the first stage of an RLHF alignment pipeline.
- When adapting a general-purpose language model to follow domain-specific instructions.
- When high-quality demonstration data is available but preference comparison data is not.
- As a warm-start before applying Direct Preference Optimization (DPO) or other alignment methods.
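To make "demonstration data" concrete, a single training example pairs an instruction with a desired response, rendered into one training string. The "Human:/Assistant:" template below is a common convention similar to the one DeepSpeed-Chat uses, but the exact template here is an assumption for illustration:

```python
def format_example(prompt: str, response: str) -> str:
    # Render one prompt-response demonstration into a single training
    # string (hypothetical template; real pipelines fix one template
    # and apply it consistently at train and inference time).
    return f"Human: {prompt}\n\nAssistant: {response}"

text = format_example(
    "Summarize the water cycle in one sentence.",
    "Water evaporates, condenses into clouds, and returns as precipitation.",
)
```

Whatever template is chosen, it must match the format used at inference time, since SFT teaches the model to expect exactly this structure.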
Theoretical Basis
SFT minimizes the cross-entropy loss on demonstration data. Given a prompt x and a demonstration response y = (y_1, y_2, ..., y_T), the loss is defined as:
L_SFT = - sum_{t=1}^{T} log P(y_t | y_{<t}, x)
where x is the input prompt, y is the demonstration response, and P(y_t | y_{<t}, x) is the model's predicted probability of token y_t given all preceding tokens and the prompt.
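The loss above is just the summed negative log-probability of the demonstration tokens. A small numerical sketch (sizes and random values are made up) shows that gathering per-token log-probabilities and summing matches PyTorch's cross-entropy with sum reduction:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes: one response of T = 4 tokens over a 10-token vocabulary.
T, vocab_size = 4, 10
logits = torch.randn(T, vocab_size)           # model outputs at positions t = 1..T
targets = torch.randint(0, vocab_size, (T,))  # demonstration tokens y_1..y_T

# L_SFT = - sum_t log P(y_t | y_<t, x): gather each target's log-probability.
log_probs = F.log_softmax(logits, dim=-1)
loss_manual = -log_probs[torch.arange(T), targets].sum()

# The same quantity via the built-in cross-entropy loss.
loss_ce = F.cross_entropy(logits, targets, reduction="sum")
```

The two values agree, which is why implementations simply call the standard cross-entropy loss rather than computing token log-probabilities explicitly.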
Conceptually, this is standard causal language model fine-tuning but applied to instruction-formatted data. The key distinction from pre-training is not the loss function itself but the nature of the training distribution: rather than learning from unstructured web text, the model learns from curated instruction-response pairs that encode the desired behavior.
In the DeepSpeed-Chat implementation, the loss computation supports optional fp32 upcasting for numerical stability when training in lower precision (fp16 or bf16). The shifted cross-entropy loss is computed as:
# From model_utils.py -- causal_lm_model_to_fp32_loss
import torch

# lm_logits: (batch, seq_len, vocab) model outputs; labels: (batch, seq_len) ids.
# Upcast logits to fp32 before the loss so softmax/log stay numerically stable.
shift_logits = lm_logits[..., :-1, :].float().contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
loss_fct = torch.nn.CrossEntropyLoss()
loss = loss_fct(
    shift_logits.view(batch_size * seq_length, vocab_size),
    shift_labels.view(batch_size * seq_length),
)
This ensures that even when the model forward pass runs in half precision, the loss gradients remain numerically stable.
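The upcast itself is a one-line conversion. A minimal illustration (shapes and values are made up) mirrors the .float() call in the snippet above, confirming the loss comes out in float32 even when the logits are half precision:

```python
import torch

# Logits as if produced by an fp16 forward pass (illustrative shapes).
logits_fp16 = torch.randn(4, 10).half()
targets = torch.randint(0, 10, (4,))

# Convert to float32 before the loss; the loss tensor is then fp32.
loss = torch.nn.functional.cross_entropy(logits_fp16.float(), targets)
```

Because the softmax inside cross-entropy exponentiates logits, computing it in fp32 avoids the overflow and rounding issues that fp16 arithmetic can introduce.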