

Implementation:ContextualAI HALOs SFTTrainer Train

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, NLP, Training
Last Updated 2026-02-08 03:00 GMT

Overview

The HALOs SFTTrainer class is a concrete tool for supervised fine-tuning (SFT) of language models.

Description

The SFTTrainer class extends BasicTrainer to implement supervised fine-tuning. It computes the negative log-likelihood (NLL) loss over target tokens: it runs a forward pass through the policy model, extracts the log probabilities of the target labels with get_batch_logps, and normalizes the loss by the number of non-padding target tokens. Because this objective depends only on the policy, SFTTrainer does not use a reference model (use_reference_model = False).
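The log-probability extraction step can be illustrated with a stdlib-only sketch. It is not the HALOs code and works on nested lists rather than tensors. It assumes the Hugging Face convention that label positions equal to -100 (prompt and padding tokens) are ignored, and that the logits at position t score the label at position t + 1. The names IGNORE_INDEX, log_softmax, and token_logps are hypothetical.

```python
import math
from typing import List

IGNORE_INDEX = -100  # assumed label value for prompt/padding positions


def log_softmax(row: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def token_logps(
    logits: List[List[List[float]]], labels: List[List[int]]
) -> List[List[float]]:
    """Per-token log p(label), with masked positions set to 0.0.

    logits: [batch][seq_len][vocab]; labels: [batch][seq_len].
    The logits at position t score the label at position t + 1,
    so each output row has seq_len - 1 entries.
    """
    out = []
    for seq_logits, seq_labels in zip(logits, labels):
        row = []
        for t in range(len(seq_labels) - 1):
            target = seq_labels[t + 1]
            if target == IGNORE_INDEX:
                row.append(0.0)  # masked: contributes nothing to the loss
            else:
                row.append(log_softmax(seq_logits[t])[target])
        out.append(row)
    return out
```

Returning zeros at masked positions, rather than dropping them, keeps every row the same length, which is how a tensor implementation can multiply by a mask instead of indexing ragged sequences.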

Training is orchestrated by the main() function in launch.py, which handles Hydra config resolution, model loading (with optional LoRA via PEFT), distributed setup via Accelerate/FSDP, optimizer/scheduler creation, and DataLoader instantiation before passing everything to the trainer.
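A class-level flag such as use_reference_model lets orchestration code decide whether to load a second, frozen model before any trainer is constructed. The sketch below only illustrates that pattern; ToyBasicTrainer, ToySFTTrainer, and models_to_load are hypothetical names, not taken from launch.py, and the default of True on the base class is an assumption.

```python
class ToyBasicTrainer:
    # Assumed default: preference-style trainers compare against a frozen reference.
    use_reference_model = True


class ToySFTTrainer(ToyBasicTrainer):
    # SFT only needs the policy, so the subclass turns the flag off.
    use_reference_model = False


def models_to_load(trainer_cls: type) -> list:
    """Decide which models to instantiate by reading the flag off the class."""
    models = ["policy"]
    if trainer_cls.use_reference_model:
        models.append("reference")
    return models
```

Reading the flag from the class rather than from an instance matters here, because the models have to exist before the trainer that receives them is created.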

Usage

Use this to fine-tune a pre-trained base model on instruction-response pairs. Invoke it with: accelerate launch launch.py loss=sft model=llama datasets=[ultrachat]

Code Reference

Source Location

  • Repository: ContextualAI/HALOs
  • File: train/trainers.py (SFTTrainer), launch.py (orchestration)
  • Lines: train/trainers.py:L552-581 (SFTTrainer), train/trainers.py:L312-398 (BasicTrainer.train), launch.py:L42-331 (main)

Signature

# train/trainers.py
from typing import Dict, List, Tuple, Union

import torch

# BasicTrainer is defined earlier in train/trainers.py.
class SFTTrainer(BasicTrainer):
    use_reference_model = False  # SFT does not need a frozen reference model

    def get_batch_metrics(
        self,
        batch: Dict[str, Union[List, torch.LongTensor]],
        mode: str = 'train'
    ) -> Tuple[torch.Tensor, Dict]:
        """Compute NLL loss over target tokens.

        Args:
            batch: Must contain 'target_combined_input_ids',
                   'target_combined_attention_mask', 'target_labels'
            mode: 'train', 'eval', or 'sample'

        Returns:
            loss: Normalized NLL loss
            metrics: Dict with 'logps/{mode}' and 'loss/{mode}'
        """
        # Body elided here; see train/trainers.py:L552-581.

# launch.py: orchestration entry point
from omegaconf import DictConfig

def main(config: DictConfig) -> None:
    """Main entry point. Resolves config, loads model, creates trainer, runs training."""
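The loss and metrics contract in the docstring above can be approximated with plain Python, given masked per-token log probabilities. This is a sketch under stated assumptions: the loss is the summed negative log probability divided by the count of unmasked target tokens, and "logps/{mode}" is taken to be the mean summed log probability per sequence. The HALOs implementation works on tensors, gathers values across processes, and may define its metrics differently. sft_loss_and_metrics is a hypothetical name.

```python
from typing import Dict, List, Tuple


def sft_loss_and_metrics(
    per_token_logps: List[List[float]],
    loss_mask: List[List[bool]],
    mode: str = "train",
) -> Tuple[float, Dict[str, float]]:
    """Token-normalized NLL plus metrics keyed by mode, e.g. 'loss/train'."""
    total_logp = sum(
        lp
        for row, mask in zip(per_token_logps, loss_mask)
        for lp, keep in zip(row, mask)
        if keep
    )
    n_tokens = sum(keep for mask in loss_mask for keep in mask)
    loss = -total_logp / max(n_tokens, 1)  # guard against an all-masked batch
    metrics = {
        # Assumption: mean summed log probability per sequence.
        f"logps/{mode}": total_logp / max(len(per_token_logps), 1),
        f"loss/{mode}": loss,
    }
    return loss, metrics
```

Normalizing by token count rather than by sequence count means a long response contributes proportionally more to the gradient than a short one.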

Import

from train.trainers import SFTTrainer
# Or invoke via CLI:
# accelerate launch launch.py loss=sft model=llama datasets=[ultrachat]

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
config | DictConfig | Yes | Hydra config with loss=sft, model, datasets, and hyperparameters
model | AutoModelForCausalLM | Yes | Pre-trained base model (loaded via Hugging Face)
train_dataset | SFTDataLoader | Yes | Iterator producing batches with target_combined_input_ids, target_combined_attention_mask, and target_labels
eval_dataset | SFTDataLoader | No | Evaluation data iterator
config.lr | float | Yes | Learning rate (default 5e-6)
config.n_epochs | int | Yes | Number of training epochs (default 1)
config.model.batch_size | int | Yes | Global batch size (default 32)
config.model.use_peft | bool | No | Whether to apply LoRA (default false)
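The dotted config names above are overridden from the command line with Hydra's override grammar: key=value overrides a key that must already exist, and ++key=value overrides the key or adds it if it is absent. The sketch below approximates that behavior on a plain dict, for illustration only. It is not Hydra: it ignores config-group selections such as loss=sft (which Hydra resolves by composing YAML files) and does only crude type conversion. DEFAULTS, parse_value, and apply_override are hypothetical, with defaults copied from the table.

```python
import copy
from typing import Any, Dict

# Hypothetical defaults mirroring the Inputs table; nesting follows the dotted names.
DEFAULTS: Dict[str, Any] = {
    "lr": 5e-6,
    "n_epochs": 1,
    "model": {"batch_size": 32, "use_peft": False},
}


def parse_value(raw: str) -> Any:
    """Turn 'true'/'false' into bools and numeric strings into numbers; keep other text."""
    if raw.lower() in ("true", "false"):
        return raw.lower() == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return raw


def apply_override(config: Dict[str, Any], override: str) -> Dict[str, Any]:
    """Return a copy of config with one 'a.b=v' (must exist) or '++a.b=v' (override or add) applied."""
    force = override.startswith("++")
    key, raw = (override[2:] if force else override).split("=", 1)
    *parents, leaf = key.split(".")
    new_config = copy.deepcopy(config)
    node = new_config
    for part in parents:
        if part not in node:
            if not force:
                raise KeyError(f"{key!r} is not in the config; use ++{key}=...")
            node[part] = {}
        node = node[part]
    if leaf not in node and not force:
        raise KeyError(f"{key!r} is not in the config; use ++{key}=...")
    node[leaf] = parse_value(raw)
    return new_config
```

For example, applying "++model.use_peft=true" and then "++model.peft.lora_r=64" yields a config with model.use_peft set to True and a new model.peft.lora_r of 64, while a plain "cache_dir=/models" fails because that key is not in the defaults.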

Outputs

Name | Type | Description
--- | --- | ---
Model checkpoint | Directory | Saved to {cache_dir}/{exp_name}/FINAL/ with model weights, tokenizer, optimizer, scheduler, and metrics.json
Training metrics | Dict | Per-step loss, log probabilities, learning rate, and gradient norm
WandB logs | Remote | Training curves, if wandb.enabled=true
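Downstream code can find the final checkpoint from the two config values that determine its path. This minimal sketch assumes only the {cache_dir}/{exp_name}/FINAL/ layout and that metrics.json is a JSON file in that directory; its schema is not documented here, so the hypothetical helper load_final_metrics returns whatever JSON it finds.

```python
import json
from pathlib import Path
from typing import Any


def final_checkpoint_dir(cache_dir: str, exp_name: str) -> Path:
    """Build {cache_dir}/{exp_name}/FINAL, the documented save location."""
    return Path(cache_dir) / exp_name / "FINAL"


def load_final_metrics(cache_dir: str, exp_name: str) -> Any:
    """Read metrics.json from the final checkpoint, failing clearly if it is missing."""
    path = final_checkpoint_dir(cache_dir, exp_name) / "metrics.json"
    if not path.is_file():
        raise FileNotFoundError(f"no metrics.json at {path}; did training reach the FINAL save?")
    with path.open() as f:
        return json.load(f)
```

With the Basic SFT Training example below, final_checkpoint_dir("/models", "llama3-8B-sft") resolves to /models/llama3-8B-sft/FINAL.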

Usage Examples

Basic SFT Training

# Launch SFT training on UltraChat with Llama-3-8B
accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    --main_process_port 29501 \
    launch.py \
    loss=sft \
    model=llama \
    datasets=[ultrachat] \
    exp_name=llama3-8B-sft \
    ++cache_dir=/models \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B
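The fsdp_4gpu.yaml config name suggests four processes. Assuming model.batch_size is the global batch (as the Inputs table states) and that it is split evenly across processes and gradient-accumulation steps, the per-device micro-batch can be computed as below. The grad_accum_steps argument and the helper name per_device_microbatch are hypothetical, not documented HALOs options.

```python
def per_device_microbatch(global_batch: int, num_processes: int, grad_accum_steps: int = 1) -> int:
    """Examples per GPU per forward pass, so that
    global_batch == micro_batch * num_processes * grad_accum_steps."""
    denom = num_processes * grad_accum_steps
    if global_batch % denom != 0:
        raise ValueError(f"global batch {global_batch} is not divisible by {denom}")
    return global_batch // denom
```

Under these assumptions, the default global batch of 32 on 4 GPUs gives 8 examples per device per step, or 4 per forward pass with 2 accumulation steps.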

SFT with LoRA

# SFT with LoRA adapters for memory efficiency
accelerate launch \
    --config_file accelerate_config/fsdp_4gpu.yaml \
    launch.py \
    loss=sft \
    model=llama \
    datasets=[ultrachat] \
    exp_name=llama3-8B-sft-lora \
    ++model.name_or_path=meta-llama/Meta-Llama-3-8B \
    ++model.use_peft=true \
    ++model.peft.lora_r=64 \
    ++model.peft.lora_alpha=256
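The memory savings follow from the standard LoRA formulation used by PEFT: a frozen d×k weight W receives a trainable update B·A, with B of shape d×r and A of shape r×k, scaled by lora_alpha / r. Each adapted matrix therefore adds r·(d + k) trainable parameters. Assuming lora_r and lora_alpha map to r and alpha, the sketch below evaluates r=64, alpha=256 on a 4096×4096 matrix such as a Llama-3-8B query projection (hidden size 4096); which modules HALOs actually adapts is not stated on this page, and lora_stats is a hypothetical helper.

```python
def lora_stats(d: int, k: int, r: int, alpha: float) -> dict:
    """Trainable-parameter count and update scale for one LoRA-adapted d x k weight."""
    full = d * k           # parameters in the frozen weight W
    adapter = r * (d + k)  # B (d x r) plus A (r x k)
    return {
        "full_params": full,
        "lora_params": adapter,
        "fraction": adapter / full,
        "scaling": alpha / r,  # multiplier on B @ A before it is added to W's output
    }
```

For this matrix, lora_stats(4096, 4096, r=64, alpha=256) gives 524,288 trainable parameters against 16,777,216 frozen ones (3.125%) and a scaling factor of 4.0.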
