Implementation: ContextualAI HALOs SFTTrainer Train
| Knowledge Sources | Details |
|---|---|
| Domains | Deep_Learning, NLP, Training |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
The HALOs SFTTrainer class provides a concrete tool for supervised fine-tuning (SFT) of language models.
Description
The SFTTrainer class extends BasicTrainer to implement supervised fine-tuning. It computes the NLL loss over target tokens by running a forward pass through the policy model, extracting log probabilities via get_batch_logps, and normalizing by the number of non-padding tokens. The SFTTrainer does not use a reference model (use_reference_model = False).
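The loss computation described above can be sketched in plain PyTorch. This is a standalone illustration, not the exact HALOs code: the real get_batch_logps lives in the repository, and the tensor names and -100 padding convention here are assumptions.

```python
# Hedged sketch of SFTTrainer's NLL loss: gather per-token log-probs of the
# target labels from the policy logits, mask padding, and normalize by the
# number of non-padding target tokens.
import torch

def get_batch_logps(logits, labels, ignore_index=-100):
    """Per-token log-probs of `labels` under `logits`, with padding masked out."""
    # Shift so that tokens < n predict token n
    logits = logits[:, :-1, :]
    labels = labels[:, 1:].clone()
    mask = labels != ignore_index
    labels[~mask] = 0  # dummy index so gather stays in-bounds
    per_token_logps = torch.gather(
        logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)
    ).squeeze(2)
    return per_token_logps * mask, mask

def sft_loss(logits, labels):
    logps, mask = get_batch_logps(logits, labels)
    # NLL normalized by the count of non-padding target tokens
    return -logps.sum() / mask.sum()
```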
Training is orchestrated by the main() function in launch.py, which handles Hydra config resolution, model loading (with optional LoRA via PEFT), distributed setup via Accelerate/FSDP, optimizer/scheduler creation, and DataLoader instantiation before passing everything to the trainer.
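The orchestration sequence above can be summarized as a sketch; every helper name below is a stand-in (the real launch.py uses Hydra, HuggingFace, PEFT, and Accelerate), and only the step ordering is taken from the description.

```python
# Hedged sketch of the main() flow in launch.py; records each stage in order.
from types import SimpleNamespace

def main(config, steps=None):
    steps = [] if steps is None else steps
    steps.append("resolve_config")             # Hydra config resolution
    steps.append("load_model")                 # HuggingFace model loading
    if config.use_peft:
        steps.append("wrap_lora")              # optional LoRA via PEFT
    steps.append("setup_distributed")          # Accelerate / FSDP
    steps.append("build_optimizer_scheduler")
    steps.append("build_dataloaders")
    steps.append("run_trainer")                # SFTTrainer(...).train()
    return steps
```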
Usage
Use this when you need to fine-tune a pre-trained base model on instruction-response pairs. Invoke via accelerate launch launch.py loss=sft model=llama datasets=[ultrachat].
Code Reference
Source Location
- Repository: ContextualAI/HALOs
- File: train/trainers.py (SFTTrainer), launch.py (orchestration)
- Lines: train/trainers.py:L552-581 (SFTTrainer), train/trainers.py:L312-398 (BasicTrainer.train), launch.py:L42-331 (main)
Signature
```python
class SFTTrainer(BasicTrainer):
    use_reference_model = False

    def get_batch_metrics(
        self,
        batch: Dict[str, Union[List, torch.LongTensor]],
        mode: str = 'train'
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute NLL loss over target tokens.

        Args:
            batch: Must contain 'target_combined_input_ids',
                'target_combined_attention_mask', 'target_labels'
            mode: 'train', 'eval', or 'sample'

        Returns:
            loss: Normalized NLL loss
            metrics: Dict with 'logps/{mode}' and 'loss/{mode}'
        """

# Orchestration entry point:
def main(config: DictConfig) -> None:
    """Main entry point. Resolves config, loads model, creates trainer, runs training."""
```
Import
```python
from train.trainers import SFTTrainer

# Or invoke via CLI:
# accelerate launch launch.py loss=sft model=llama datasets=[ultrachat]
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | DictConfig | Yes | Hydra config with loss=sft, model, datasets, hyperparameters |
| model | AutoModelForCausalLM | Yes | Pre-trained base model (loaded via HuggingFace) |
| train_dataset | SFTDataLoader | Yes | Iterator producing batches with target_combined_input_ids, target_labels |
| eval_dataset | SFTDataLoader | No | Evaluation data iterator |
| config.lr | float | Yes | Learning rate (default 5e-6) |
| config.n_epochs | int | Yes | Number of training epochs (default 1) |
| config.model.batch_size | int | Yes | Global batch size (default 32) |
| config.model.use_peft | bool | No | Whether to apply LoRA (default false) |
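The defaults in the inputs table can be captured as a plain dict for reference; the real configuration is a Hydra DictConfig assembled from YAML files and CLI overrides, so this nesting is illustrative.

```python
# Defaults from the inputs table above, as a plain dict.
default_config = {
    "lr": 5e-6,             # learning rate
    "n_epochs": 1,          # number of training epochs
    "model": {
        "batch_size": 32,   # global batch size
        "use_peft": False,  # set True to enable LoRA
    },
}
```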
Outputs
| Name | Type | Description |
|---|---|---|
| Model checkpoint | Directory | Saved to {cache_dir}/{exp_name}/FINAL/ with model weights, tokenizer, optimizer, scheduler, metrics.json |
| Training metrics | Dict | Per-step loss, log probabilities, learning rate, grad norm |
| WandB logs | Remote | Training curves if wandb.enabled=true |
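Given the checkpoint layout in the outputs table, the FINAL directory can be located and loaded back for inference; the helper below is illustrative, and the commented load calls assume the transformers library.

```python
# Hedged sketch: locate the saved checkpoint following the outputs table's
# {cache_dir}/{exp_name}/FINAL/ layout.
from pathlib import Path

def final_checkpoint_dir(cache_dir: str, exp_name: str) -> Path:
    """Path to the final checkpoint directory."""
    return Path(cache_dir) / exp_name / "FINAL"

ckpt = final_checkpoint_dir("/models", "llama3-8B-sft")
# Then load for inference (requires transformers):
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model = AutoModelForCausalLM.from_pretrained(ckpt)
# tokenizer = AutoTokenizer.from_pretrained(ckpt)
```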
Usage Examples
Basic SFT Training
```shell
# Launch SFT training on UltraChat with Llama-3-8B
accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    --main_process_port 29501 \
    launch.py \
    loss=sft \
    model=llama \
    datasets=[ultrachat] \
    exp_name=llama3-8B-sft \
    ++cache_dir=/models \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B
```
SFT with LoRA
```shell
# SFT with LoRA adapters for memory efficiency
accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    launch.py \
    loss=sft \
    model=llama \
    datasets=[ultrachat] \
    exp_name=llama3-8B-sft-lora \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B \
    ++model.use_peft=true \
    ++model.peft.lora_r=64 \
    ++model.peft.lora_alpha=256
```
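What the LoRA overrides above control can be sketched mathematically: a frozen weight W receives a low-rank update B @ A scaled by lora_alpha / lora_r. This is the general LoRA idea, not the HALOs/PEFT code itself.

```python
# Hedged sketch of the LoRA forward pass enabled by ++model.use_peft=true:
# output = x W^T + (alpha / r) * x A^T B^T, with W frozen and A, B trainable.
import torch

def lora_forward(x, W, A, B, alpha, r):
    # x: (batch, d_in); W: (d_out, d_in); A: (r, d_in); B: (d_out, r)
    return x @ W.T + (alpha / r) * (x @ A.T @ B.T)
```

With B initialized to zero (the usual LoRA convention), the adapted layer starts out identical to the frozen base layer.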