Workflow:Eric mitchell Direct preference optimization SFT Training
| Knowledge Sources | |
|---|---|
| Domains | LLMs, Fine_Tuning, Preference_Learning |
| Last Updated | 2026-02-08 01:00 GMT |
Overview
End-to-end process for supervised fine-tuning (SFT) of causal language models on preference datasets to produce an in-distribution policy model for subsequent DPO training.
Description
This workflow covers the first stage of the DPO pipeline: Supervised Fine-Tuning. SFT trains a base HuggingFace causal language model on the chosen (preferred) responses from a preference dataset, ensuring the model produces in-distribution outputs before preference learning begins. The workflow handles environment setup, Hydra-based configuration, model loading with optional multi-GPU distribution via FSDP, dataset preparation and tokenization, the SFT training loop with periodic evaluation and checkpointing, and final model saving. The output is a checkpoint (policy.pt) that serves as both the policy initialization and reference model for the DPO stage.
Usage
Execute this workflow when you have a preference dataset (such as Anthropic HH-RLHF, Stanford Human Preferences, or StackExchange) and need to create a supervised fine-tuned base model as the prerequisite for DPO preference learning. This is always the first step in the DPO pipeline and must be completed before running DPO training.
Execution Steps
Step 1: Environment_Setup
Set up a Python virtual environment and install the required dependencies. The project requires PyTorch, HuggingFace Transformers, Hydra for configuration management, and additional libraries for distributed training (FSDP) and logging (Weights & Biases).
Key considerations:
- Python 3.8+ is required
- Dependencies are specified in requirements.txt with pinned versions
- For multi-GPU FSDP training, you may need to increase the file descriptor limit (ulimit -n 64000)
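The file descriptor limit mentioned above can also be raised from inside a Python process rather than the shell. The following is a minimal sketch using the standard-library `resource` module (Unix only); the helper name `raise_nofile_limit` is hypothetical and is not taken from the project.

```python
import resource


def raise_nofile_limit(target=64000):
    """Raise the soft open-file limit toward `target`; return the resulting soft limit."""
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft == resource.RLIM_INFINITY:
        return soft  # already unlimited, nothing to do
    # An unprivileged process cannot lift its soft limit above the hard limit.
    new_soft = target if hard == resource.RLIM_INFINITY else min(target, hard)
    if new_soft > soft:  # never lower an existing limit
        try:
            resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))
        except (ValueError, OSError):
            pass  # some platforms reject the request; keep the current limit
    return resource.getrlimit(resource.RLIMIT_NOFILE)[0]
```

Calling `raise_nofile_limit()` once at startup, before worker processes are spawned, should have roughly the same effect as running `ulimit -n 64000` in the launching shell, subject to the hard limit set by the system.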
Step 2: Configuration_Selection
Select the model, dataset, and training hyperparameters via Hydra configuration composition. The main config (config.yaml) composes with a model sub-config (e.g., pythia28.yaml) and the SFT loss sub-config (sft.yaml). Key parameters include batch size, learning rate, gradient accumulation steps, and the trainer class.
Key considerations:
- Set loss=sft to select supervised fine-tuning mode
- Choose a pre-configured model (gpt2-large, gpt2-xl, gptj, llama7b, pythia28, pythia69) or use blank_model with a custom path
- For FSDP training, specify the trainer as FSDPTrainer and provide model.block_name
- Enable mixed precision (model.fsdp_policy_mp=bfloat16) for faster training on compatible hardware
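As an illustration of how these choices combine, the sketch below assembles a list of Hydra-style `key=value` overrides for an SFT run. Only `loss=sft`, the `datasets=[...]` form, `FSDPTrainer`, and `model.fsdp_policy_mp` come from this page; the entry-point name `train.py`, the `model=`, `trainer=`, and `exp_name=` keys, and the helper `build_sft_command` itself are assumptions made for the example.

```python
def build_sft_command(model, datasets, exp_name, use_fsdp=False,
                      mixed_precision=None, extra=None):
    """Return an argv list for an SFT run (hypothetical helper, not project code)."""
    overrides = [
        f"model={model}",                      # assumed key for the model sub-config
        f"datasets=[{','.join(datasets)}]",    # one or more preference datasets
        "loss=sft",                            # selects supervised fine-tuning mode
        f"exp_name={exp_name}",                # assumed key for the run name
    ]
    if use_fsdp:
        overrides.append("trainer=FSDPTrainer")  # assumed key; class name is from the text
        if mixed_precision:
            overrides.append(f"model.fsdp_policy_mp={mixed_precision}")
    for key, value in (extra or {}).items():
        overrides.append(f"{key}={value}")
    # "train.py" is an assumed entry-point name.
    return ["python", "-u", "train.py", *overrides]


cmd = build_sft_command("pythia28", ["hh", "shp"], "sft_demo",
                        use_fsdp=True, mixed_precision="bfloat16")
```

The resulting list can be passed to `subprocess.run(cmd)`. If it is pasted into a shell instead, the bracketed `datasets=[hh,shp]` argument may need quoting.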
Step 3: Model_Loading
Load the base causal language model from HuggingFace with the specified dtype and device mapping. Dropout is disabled across all model layers so that forward passes are not randomly perturbed during fine-tuning. For BasicTrainer, the model is loaded with balanced device mapping across available GPUs; for FSDP, the model is loaded on CPU first and then sharded.
Key considerations:
- Models are loaded via transformers.AutoModelForCausalLM.from_pretrained
- Dropout is explicitly disabled for all modules in the model
- The policy dtype (e.g., float32 or bfloat16) is set via the model config
- No reference model is loaded for SFT (only needed for DPO)
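Disabling dropout can be sketched as a walk over the module tree. The example below uses a small `torch.nn.Sequential` as a stand-in for a model returned by `transformers.AutoModelForCausalLM.from_pretrained`, so that nothing has to be downloaded. The helper name `disable_dropout` is chosen for the example, and the sketch assumes dropout is implemented with `torch.nn.Dropout` modules.

```python
import torch.nn as nn


def disable_dropout(model: nn.Module) -> int:
    """Set p=0 on every nn.Dropout submodule; return how many were changed."""
    changed = 0
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.p = 0.0
            changed += 1
    return changed


# Toy stand-in for a pretrained causal LM.
toy_model = nn.Sequential(
    nn.Linear(8, 8), nn.Dropout(0.1),
    nn.Linear(8, 8), nn.Dropout(0.2),
)
n_disabled = disable_dropout(toy_model)
```

Setting `p` to zero, rather than calling `model.eval()`, leaves the model in training mode while removing the dropout noise.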
Step 4: Dataset_Preparation
Load and tokenize the preference dataset(s) for SFT mode. In SFT mode, only the chosen (sft_target) response is used for each prompt. The data pipeline loads from HuggingFace datasets, converts to a canonical format with prompt-response pairs, tokenizes with truncation handling, and produces batched iterators with proper padding and collation.
Key considerations:
- Multiple datasets can be combined (e.g., datasets=[hh,shp])
- Over-long prompts are truncated to max_prompt_length tokens: keep_end (retain the final tokens) for Anthropic HH, keep_start (retain the initial tokens) for the other datasets
- Labels mask prompt tokens with -100 so loss is only computed on the response
- Data is cached locally after first download
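The truncation, label masking, and padding described above can be sketched on plain lists of token ids. This is an illustration under stated assumptions rather than the project's code: the names `build_sft_example` and `pad_batch`, the overall `max_length` budget that triggers truncation, right-side padding, and the pad id of 0 are all assumptions.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def build_sft_example(prompt_ids, response_ids, max_length, max_prompt_length,
                      truncation_mode="keep_start"):
    """Join a prompt with its chosen response, truncating and masking as needed."""
    if not 0 < max_prompt_length < max_length:
        raise ValueError("need 0 < max_prompt_length < max_length")
    prompt_ids, response_ids = list(prompt_ids), list(response_ids)
    # Assumed trigger: truncate the prompt only when the pair is too long overall.
    if len(prompt_ids) + len(response_ids) > max_length:
        if truncation_mode == "keep_start":
            prompt_ids = prompt_ids[:max_prompt_length]
        elif truncation_mode == "keep_end":
            prompt_ids = prompt_ids[-max_prompt_length:]
        else:
            raise ValueError(f"unknown truncation mode: {truncation_mode}")
    # If the pair is still too long, shorten the response to fit the budget.
    if len(prompt_ids) + len(response_ids) > max_length:
        response_ids = response_ids[: max_length - len(prompt_ids)]
    return {
        "input_ids": prompt_ids + response_ids,
        # Prompt positions are masked so only response tokens contribute to the loss.
        "labels": [IGNORE_INDEX] * len(prompt_ids) + response_ids,
    }


def pad_batch(examples, pad_token_id=0):
    """Right-pad examples to a common length; padded labels are masked too."""
    width = max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "labels": [], "attention_mask": []}
    for ex in examples:
        n_real = len(ex["input_ids"])
        n_pad = width - n_real
        batch["input_ids"].append(ex["input_ids"] + [pad_token_id] * n_pad)
        batch["labels"].append(ex["labels"] + [IGNORE_INDEX] * n_pad)
        batch["attention_mask"].append([1] * n_real + [0] * n_pad)
    return batch


hh_style = build_sft_example(range(10), [100, 101, 102], max_length=8,
                             max_prompt_length=4, truncation_mode="keep_end")
other = build_sft_example(range(10), [100, 101, 102], max_length=8,
                          max_prompt_length=4, truncation_mode="keep_start")
short = build_sft_example([1, 2], [7], max_length=8, max_prompt_length=4)
batch = pad_batch([hh_style, short])
```

With `keep_end` the most recent part of a long dialogue survives, which suits multi-turn prompts where the final turn carries the question; `keep_start` preserves the opening of the prompt instead.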
Step 5: Training_Loop
Execute the SFT training loop with gradient accumulation, periodic evaluation, and logging. The trainer iterates over batches, computes the negative log-likelihood loss on chosen responses, accumulates gradients over microbatches, clips gradients, and updates parameters with the optimizer (RMSprop by default) and a linear warmup schedule.
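One parameter update of this kind can be sketched in PyTorch as follows. The toy embedding-plus-linear model stands in for a real causal LM, and the hyperparameter values, the token-averaged cross-entropy, the exact warmup formula, and the name `train_step` are assumptions made for the illustration, not the project's settings.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, HIDDEN = 32, 16
# Toy stand-in for a causal LM: maps each token id to next-token logits.
model = nn.Sequential(nn.Embedding(VOCAB, HIDDEN), nn.Linear(HIDDEN, VOCAB))


def sft_loss(logits, labels):
    """Negative log-likelihood of the labels; positions labelled -100 are ignored."""
    shift_logits = logits[:, :-1, :]  # position t predicts token t + 1
    shift_labels = labels[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1), ignore_index=-100)


warmup_steps = 5  # assumed value
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
# Linear warmup: the learning-rate multiplier rises to 1.0 and then stays there.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (warmup_steps + 1)))


def train_step(batch, grad_accum_steps, max_grad_norm=10.0):
    """Accumulate gradients over microbatches, clip them, then update once."""
    optimizer.zero_grad()
    total_loss = 0.0
    microbatches = zip(batch["input_ids"].chunk(grad_accum_steps),
                       batch["labels"].chunk(grad_accum_steps))
    for ids, labels in microbatches:
        # Scale so the accumulated gradient matches the full-batch average.
        loss = sft_loss(model(ids), labels) / grad_accum_steps
        loss.backward()
        total_loss += loss.item()
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()
    return total_loss, float(grad_norm)


input_ids = torch.randint(0, VOCAB, (8, 12))
labels = input_ids.clone()
labels[:, :4] = -100  # pretend the first four tokens of each row are the prompt
batch = {"input_ids": input_ids, "labels": labels}
losses = [train_step(batch, grad_accum_steps=2)[0] for _ in range(20)]
```

Dividing each microbatch loss by the number of accumulation steps keeps the update equivalent to a single pass over the whole batch, provided the microbatches contain equal numbers of unmasked tokens, as they do here.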
Key considerations:
- Evaluation runs every eval_every examples (default 20,000) with optional sample generation
- Gradient accumulation splits each batch into microbatches, so the effective batch size can exceed what fits in GPU memory in a single forward/backward pass
- The training loop supports both single-process (BasicTrainer) and multi-process (FSDPTrainer, TensorParallelTrainer) execution
- Metrics are logged to Weights & Biases when enabled
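To make the batch arithmetic and the evaluation cadence in these considerations concrete, here is a small standard-library sketch. The helper names are hypothetical, and the rule that evaluation fires the first time the example counter reaches each multiple of `eval_every` is an assumption; the project may trigger it differently.

```python
def microbatch_size(batch_size, grad_accum_steps):
    """Examples processed per forward/backward pass when accumulating gradients."""
    if batch_size % grad_accum_steps != 0:
        raise ValueError("batch_size must be divisible by grad_accum_steps")
    return batch_size // grad_accum_steps


def eval_points(n_examples, batch_size, eval_every):
    """Example counts at which evaluation would run under the assumed rule."""
    points, seen, next_eval = [], 0, eval_every
    while seen < n_examples:
        seen += batch_size  # one optimizer step consumes one full batch
        if seen >= next_eval:
            points.append(seen)
            # Skip any thresholds already passed by a large batch.
            while next_eval <= seen:
                next_eval += eval_every
    return points


per_pass = microbatch_size(64, 2)
schedule = eval_points(n_examples=50000, batch_size=64, eval_every=20000)
```

Because the counter advances in whole batches, an `eval_every` value that is not a multiple of the batch size is reached slightly late, which is why the first entry of `schedule` is 20032 rather than 20000.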
Step 6: Checkpoint_Saving
Save the final model checkpoint containing the policy state dict, optimizer state, and scheduler state. Checkpoints are saved both at evaluation intervals and at the end of training (to a LATEST directory). For FSDP, state dicts are gathered from all processes and saved only on rank 0.
Key considerations:
- The final checkpoint path follows the pattern: {run_dir}/LATEST/policy.pt
- Intermediate checkpoints are saved as {run_dir}/step-{N}/policy.pt
- The checkpoint contains step_idx, state dict, and evaluation metrics
- This LATEST/policy.pt path is needed as input for the DPO training stage
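The path patterns and checkpoint contents listed above can be expressed with a pair of small helpers. This is a sketch for orientation only: the function names and the dictionary keys `state` and `metrics` are assumptions, while `step_idx`, `LATEST`, `step-{N}`, and `policy.pt` come from the text.

```python
from pathlib import Path


def checkpoint_path(run_dir, step=None, filename="policy.pt"):
    """Return {run_dir}/LATEST/policy.pt, or {run_dir}/step-{N}/policy.pt if step is given."""
    subdir = "LATEST" if step is None else f"step-{step}"
    return Path(run_dir) / subdir / filename


def make_payload(step_idx, state_dict, metrics):
    """Bundle the three items the checkpoint is described as containing."""
    return {"step_idx": step_idx, "state": state_dict, "metrics": metrics}


final_path = checkpoint_path("runs/sft_demo")
intermediate_path = checkpoint_path("runs/sft_demo", step=20000)
payload = make_payload(20000, {"weight": [0.0]}, {"loss/eval": 1.23})
```

In a real run the payload would hold the model's state dict and be written with `torch.save(payload, final_path)` after creating the parent directory; the directory name `runs/sft_demo` and the metric values here are placeholders.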