Principle: AllenAI Open-Instruct Supervised Fine-Tuning
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Deep Learning, Natural Language Processing |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Supervised fine-tuning (SFT) is the process of training a pre-trained language model on instruction-response pairs using next-token prediction with a cross-entropy loss computed only on assistant-generated tokens.
Description
SFT is typically the first post-training step after language model pre-training. While the pre-trained model has learned general language understanding from a large corpus, SFT teaches it to follow instructions and produce helpful, coherent responses in a conversational format.
The SFT training loop follows a standard supervised learning paradigm:
- Data preparation: Instruction-response datasets are loaded, mixed, tokenized, and cached. Labels are constructed with prompt masking so only assistant responses contribute to the loss.
- Model loading: The pre-trained model is loaded (optionally with LoRA/QLoRA adapters) and prepared for training with the appropriate precision and attention implementation.
- Optimization: The model is trained with the AdamW optimizer, a learning rate schedule (linear warmup + decay), and optional gradient accumulation to reach effective batch sizes larger than what fits in GPU memory.
- Distributed training: HuggingFace Accelerate handles multi-GPU and multi-node distribution, optionally with DeepSpeed ZeRO for memory-efficient training.
- Checkpointing: Model checkpoints are saved periodically and at the end of training, with optional upload to HuggingFace Hub.
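The optimization and gradient-accumulation steps above can be sketched as a runnable skeleton. This is a hypothetical toy (a scalar "model" and hand-written gradient stand in for the network and backward pass, and `toy_loss_and_grad` / `train` are illustrative names), not the actual open-instruct training loop, which uses PyTorch with Accelerate:

```python
# Toy sketch of the SFT optimization loop with gradient accumulation.
# A scalar parameter and a hand-coded gradient replace the real
# forward/backward pass so the control flow runs anywhere.

def toy_loss_and_grad(theta, batch):
    # stand-in for forward + backward on one micro-batch
    loss = sum((theta - x) ** 2 for x in batch) / len(batch)
    grad = sum(2 * (theta - x) for x in batch) / len(batch)
    return loss, grad

def train(theta, micro_batches, lr=0.1, grad_accum=2):
    accum_grad = 0.0
    for i, batch in enumerate(micro_batches, start=1):
        _, g = toy_loss_and_grad(theta, batch)
        accum_grad += g / grad_accum       # average over the accumulation window
        if i % grad_accum == 0:            # optimizer step every grad_accum micro-batches
            theta -= lr * accum_grad
            accum_grad = 0.0
    return theta
```

The key structural point is that gradients from `grad_accum` micro-batches are averaged before a single parameter update, which is what makes the effective batch size larger than the per-device batch size.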
Assistant-only masking is the key difference from standard language model training. In SFT, the loss is not computed over the entire sequence; instead, only the tokens that the model should learn to generate (the assistant's response) contribute to the gradient. This prevents the model from being trained to "predict" the user's input, which would be counterproductive.
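Assistant-only masking can be sketched in a few lines. The role annotations and function name here are hypothetical, but the convention of setting masked labels to -100 (the index that cross-entropy losses in PyTorch/HuggingFace ignore by default) is standard:

```python
# Sketch of assistant-only label masking. Tokens from non-assistant turns
# get label -100 so they contribute nothing to the loss; assistant tokens
# keep their token ids as labels.

IGNORE_INDEX = -100

def build_labels(token_ids, roles):
    """roles[i] is 'user' or 'assistant' for token_ids[i] (hypothetical format)."""
    return [tid if role == "assistant" else IGNORE_INDEX
            for tid, role in zip(token_ids, roles)]
```

In a real pipeline the role of each token is derived from the chat template's turn boundaries rather than a parallel `roles` list, but the resulting label tensor has the same shape.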
Usage
Use SFT as the first stage of post-training when adapting a base language model for instruction following. It is appropriate when you have high-quality instruction-response pairs and want the model to learn a specific response style or capability.
Theoretical Basis
Loss function: The SFT loss is the standard autoregressive cross-entropy loss, restricted to assistant tokens:
L_SFT = - (1 / T_a) * sum_{t in A} log P_theta(x_t | x_{<t})
Where:
- theta = model parameters
- x_t = token at position t
- A = set of positions where the token is part of an assistant response
- T_a = |A|, the total number of assistant tokens
- P_theta(x_t | x_{<t}) = model's predicted probability for token x_t given all preceding tokens
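As a sanity check on the formula, the masked cross-entropy can be computed directly from per-position logits in pure Python (a pedagogical sketch; frameworks compute this with a vectorized cross-entropy and `ignore_index=-100`):

```python
import math

def sft_loss(logits, labels, ignore_index=-100):
    """Average negative log-likelihood over assistant positions only.

    logits: list of per-position score lists over the vocabulary
    labels: target token id per position, or ignore_index for masked positions
    """
    total, count = 0.0, 0
    for pos_logits, label in zip(logits, labels):
        if label == ignore_index:
            continue                       # position t not in A: skip
        z = max(pos_logits)                # stable log-sum-exp
        log_norm = z + math.log(sum(math.exp(l - z) for l in pos_logits))
        total += log_norm - pos_logits[label]   # -log P_theta(x_t | x_<t)
        count += 1
    return total / count                   # divide by T_a = |A|
```

With a two-token vocabulary and uniform logits, each unmasked position contributes exactly log 2 to the loss, matching the formula.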
Gradient update with AdamW:
g_t = gradient of L_SFT w.r.t. theta
m_t = beta_1 * m_{t-1} + (1 - beta_1) * g_t # first moment
v_t = beta_2 * v_{t-1} + (1 - beta_2) * g_t^2 # second moment
m_hat = m_t / (1 - beta_1^t) # bias correction
v_hat = v_t / (1 - beta_2^t) # bias correction
theta = theta - lr * (m_hat / (sqrt(v_hat) + eps) + wd * theta) # weight decay
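The update rule above translates line for line into code. This is a single-parameter sketch (the function name is illustrative; real training uses `torch.optim.AdamW` over tensors), with decoupled weight decay applied as in the final line of the pseudocode:

```python
import math

def adamw_step(theta, grad, m, v, t, lr=1e-5,
               beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
    """One AdamW step for a scalar parameter; t is the 1-indexed step count."""
    m = beta1 * m + (1 - beta1) * grad          # first moment
    v = beta2 * v + (1 - beta2) * grad ** 2     # second moment
    m_hat = m / (1 - beta1 ** t)                # bias correction
    v_hat = v / (1 - beta2 ** t)                # bias correction
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v
```

Note that on the very first step the bias correction makes `m_hat / sqrt(v_hat)` approximately equal to the sign of the gradient, so the initial update magnitude is close to the learning rate.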
Effective batch size with gradient accumulation:
effective_batch_size = per_device_batch_size * num_devices * gradient_accumulation_steps
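A worked instance of this arithmetic (the numbers are illustrative, not a recommended configuration):

```python
def effective_batch_size(per_device_batch_size, num_devices,
                         gradient_accumulation_steps):
    # e.g. 2 sequences per GPU * 8 GPUs * 8 accumulation steps = 128
    return per_device_batch_size * num_devices * gradient_accumulation_steps
```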
Learning rate schedule (linear warmup + linear decay):
if step < warmup_steps:
lr = learning_rate * (step / warmup_steps)
else:
lr = learning_rate * (1 - (step - warmup_steps) / (total_steps - warmup_steps))
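The schedule above as a self-contained function (a direct transcription of the pseudocode; HuggingFace exposes the equivalent as `get_linear_schedule_with_warmup`):

```python
def lr_at(step, total_steps, warmup_steps, peak_lr):
    """Linear warmup from 0 to peak_lr, then linear decay back to 0."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * (1 - (step - warmup_steps) / (total_steps - warmup_steps))
```

For example, with 100 warmup steps out of 1000 total, the learning rate reaches half of its peak at step 50 (mid-warmup) and again at step 550 (mid-decay), and returns to 0 at the final step.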