Workflow:CarperAI Trlx SFT Instruction Tuning

From Leeroopedia


Knowledge Sources
Domains: LLMs, Fine_Tuning, Instruction_Following
Last Updated: 2026-02-07 16:00 GMT

Overview

End-to-end process for supervised fine-tuning (SFT) of language models on text or instruction-following datasets using the trlX framework.

Description

This workflow fine-tunes a pretrained language model on a curated dataset of demonstrations using standard supervised learning with a next-token prediction (causal language modeling) loss. It supports both plain text samples (for domain adaptation or style transfer) and prompt-completion pairs (for instruction following). SFT is the simplest training mode in trlX and is often used as the first stage before reinforcement learning with PPO or ILQL. The workflow uses the unified trlx.train() API with the AccelerateSFTTrainer, which handles data loading, tokenization, distributed training, evaluation, and checkpointing.
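For reference, the next-token prediction objective can be written as follows. This is the standard causal language modeling loss, stated as an illustration rather than taken from trlX's source. For a tokenized sequence x_1, ..., x_T, let S be the set of supervised positions: every position after the first for plain text, and only the completion positions for prompt-completion pairs.

```latex
\mathcal{L}(\theta) = -\frac{1}{|S|} \sum_{t \in S} \log p_\theta\left(x_t \mid x_{<t}\right)
```

In batched training the average is typically taken over all supervised tokens in the batch.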

Usage

Execute this workflow when you have a dataset of high-quality text samples or instruction-response pairs and want to fine-tune a language model to imitate them. Common use cases include training on curated positive examples (e.g., positive sentiment reviews), instruction-following datasets (e.g., Alpaca, Dolly), or as the supervised warm-up stage before reinforcement learning.

Execution Steps

Step 1: Configure training

Set up the training configuration by loading the default SFT config (default_sft_config() in trlx.data.default_configs) and optionally overriding hyperparameters. The configuration specifies the base model, tokenizer, optimizer, scheduler, and SFT-specific settings such as the generation kwargs used to sample evaluation outputs. SFT configs typically set num_layers_unfrozen = -1, which leaves all layers trainable (full fine-tuning).

Key considerations:

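  • Learning rates between 1e-5 and 1e-4 are typical; prefer the lower end for larger models
  • Set seq_length to cover your longest training samples, since longer samples are truncated
  • The config.evolve() method applies nested overrides cleanly, e.g. config.evolve(train=dict(seq_length=512))
  • DeepSpeed ZeRO Stage 2 or 3, enabled through the Accelerate launch configuration, can be used for memory-efficient training of large models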
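A minimal sketch of building an SFT config, assuming trlX is installed and that default_sft_config() and TRLConfig.evolve() behave as described above (nested dicts merged into the matching config sections). All override values here are hypothetical placeholders, not recommended settings. trlX is imported inside the function so the override dictionary can be inspected without it.

```python
from typing import Any, Dict

# Hypothetical override values for illustration only; tune them for your model,
# hardware, and data. Top-level keys mirror sections of trlX's TRLConfig.
SFT_OVERRIDES: Dict[str, Dict[str, Any]] = {
    "train": {
        "seq_length": 512,           # should cover your longest training sample
        "batch_size": 8,
        "total_steps": 2000,
        "eval_interval": 200,        # evaluate every 200 steps
        "checkpoint_interval": 1000, # checkpoint every 1000 steps
    },
    "model": {"model_path": "gpt2", "num_layers_unfrozen": -1},  # -1: train all layers
    "tokenizer": {"tokenizer_path": "gpt2"},
    "optimizer": {"kwargs": {"lr": 1e-5}},
    "method": {"gen_kwargs": {"max_new_tokens": 64}},  # used when sampling eval outputs
}


def build_sft_config(overrides: Dict[str, Dict[str, Any]] = SFT_OVERRIDES):
    """Return trlX's default SFT config with the nested overrides applied.

    trlX is imported lazily so this module can be imported without it installed.
    """
    from trlx.data.default_configs import default_sft_config

    return default_sft_config().evolve(**overrides)
```

With trlX installed, the returned config is what gets passed to trlx.train(config=...).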

Step 2: Prepare training data

Load and format the training dataset. For plain text samples, provide a list of strings. For instruction-following, provide a list of [prompt, completion] pairs. Optionally filter or preprocess the data (e.g., selecting only positive examples, applying prompt templates).
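The two input shapes can be produced with a few lines of plain Python. This is a sketch: the Alpaca-style template and the record field names (instruction, input, output) are assumptions for illustration, so substitute the template and fields your dataset and base model actually use.

```python
from typing import Dict, List, Sequence

# Hypothetical Alpaca-style templates; use the ones your dataset/model expects.
PROMPT_NO_INPUT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:\n"
)
PROMPT_WITH_INPUT = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
)


def to_prompt_completion_pairs(records: Sequence[Dict[str, str]]) -> List[List[str]]:
    """Format instruction records into [prompt, completion] pairs."""
    pairs = []
    for rec in records:
        if rec.get("input"):  # use the input-aware template only when input is non-empty
            prompt = PROMPT_WITH_INPUT.format(instruction=rec["instruction"], input=rec["input"])
        else:
            prompt = PROMPT_NO_INPUT.format(instruction=rec["instruction"])
        pairs.append([prompt, rec["output"]])
    return pairs


def filter_plain_text(texts: Sequence[str], labels: Sequence[int], keep_label: int = 1) -> List[str]:
    """Keep only texts whose label equals keep_label (e.g. positive reviews)."""
    return [t for t, lab in zip(texts, labels) if lab == keep_label]
```

The first function yields the [prompt, completion] list for instruction tuning; the second yields a plain list of strings, as in the positive-reviews example.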

Key considerations:

  • For Alpaca-style data, apply a prompt template to each instruction (and its optional input) and pair the result with the response: [prompt, response]
  • For plain text, the model learns next-token prediction on the full sequence
  • For prompt-completion pairs, loss is computed only on the completion tokens
  • Data preprocessing (template application, filtering) happens before passing to trlx.train()
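The completion-only loss can be illustrated in plain Python with made-up token ids. This sketch follows the common convention of labeling prompt positions with -100, the default ignore_index of PyTorch's cross-entropy loss; it shows the idea and is not trlX's internal code. For plain text samples nothing is masked, so every position contributes.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_input_and_labels(
    prompt_ids: Sequence[int], completion_ids: Sequence[int]
) -> Tuple[List[int], List[int]]:
    """Concatenate prompt and completion; mask prompt positions in the labels."""
    input_ids = list(prompt_ids) + list(completion_ids)
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(completion_ids)
    return input_ids, labels


def masked_mean_nll(target_logprobs: Sequence[float], labels: Sequence[int]) -> float:
    """Average negative log-likelihood over positions whose label is not ignored.

    target_logprobs[i] is assumed to be log p(labels[i] | preceding tokens), i.e.
    already aligned with labels (causal LMs shift by one position internally).
    """
    kept = [-lp for lp, lab in zip(target_logprobs, labels) if lab != IGNORE_INDEX]
    if not kept:
        raise ValueError("no supervised positions")
    return sum(kept) / len(kept)
```

Because the prompt positions are ignored, the very low log-probabilities on the prompt tokens in a toy input do not change the loss; only the completion tokens count.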

Step 3: Define evaluation metrics

Optionally define a metric function for monitoring training quality. The function receives the samples generated from the evaluation prompts and returns a dictionary mapping metric names to values. It does not affect the training loss; it only provides visibility into model behavior.

Key considerations:

  • Use task-specific metrics (e.g., sentiment score, ROUGE, task accuracy)
  • Evaluation runs every eval_interval steps on the eval_prompts set
  • Keep the metric function cheap, since it runs at every evaluation
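A minimal metric function might look like the following. It assumes trlX passes the generated texts as the samples keyword argument and may pass further keywords (absorbed by **kwargs), as in the trlX SFT examples. The positive-word lexicon is a hypothetical, deliberately cheap stand-in for a real sentiment classifier or ROUGE scorer.

```python
import re
from typing import Dict, List

# Toy lexicon standing in for a real sentiment model (hypothetical).
POSITIVE_WORDS = {"good", "great", "excellent", "wonderful", "love", "enjoyed"}


def _words(text: str) -> List[str]:
    """Lowercase the text and split it into simple alphabetic word tokens."""
    return re.findall(r"[a-z']+", text.lower())


def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
    """Return per-sample metrics for the generated evaluation samples.

    Extra keyword arguments passed by the trainer are accepted and ignored.
    """
    positive_rate: List[float] = []
    num_words: List[float] = []
    for text in samples:
        words = _words(text)
        hits = sum(w in POSITIVE_WORDS for w in words)
        positive_rate.append(hits / len(words) if words else 0.0)
        num_words.append(float(len(words)))
    return {"positive_word_rate": positive_rate, "num_words": num_words}
```

Returning one value per sample keeps the function simple and leaves aggregation (for example, averaging) to the logging side.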

Step 4: Launch SFT training

Call trlx.train() with the samples (and optionally eval_prompts, metric_fn, and config). With an SFT config, whose train.trainer is "AccelerateSFTTrainer", this dispatches to the AccelerateSFTTrainer, which tokenizes the data, builds the training pipeline, and runs the supervised training loop with the next-token prediction loss. The trainer handles gradient accumulation, mixed precision, distributed training, logging, and checkpointing, and trlx.train() returns it once training finishes.

Key considerations:

  • SFT is the most straightforward mode: there is no reward function or rollout collection
  • The trainer automatically handles tokenization and data collation
  • Checkpoints are saved every checkpoint_interval steps under train.checkpoint_dir
  • Training progress is logged to the configured tracker (train.tracker, Weights & Biases by default)
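A sketch of the launch step follows. The check_samples helper is hypothetical: it rejects a mix of plain strings and [prompt, completion] pairs before training starts, on the assumption that the trainer chooses one data path for the whole dataset. trlx.train() is called only with the keyword arguments named in this step, and trlX is imported inside the function.

```python
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

Sample = Union[str, Sequence[str]]


def check_samples(samples: Sequence[Sample]) -> str:
    """Return "text" or "pairs" if all samples share one format; raise otherwise."""
    if not samples:
        raise ValueError("samples is empty")
    if all(isinstance(s, str) for s in samples):
        return "text"
    if all(
        isinstance(s, (list, tuple)) and len(s) == 2 and all(isinstance(p, str) for p in s)
        for s in samples
    ):
        return "pairs"
    raise ValueError("samples mix plain strings and pairs, or contain a malformed pair")


def launch_sft(
    samples: Sequence[Sample],
    eval_prompts: List[str],
    metric_fn: Optional[Callable[..., Dict[str, List[float]]]] = None,
    config: Any = None,
):
    """Validate the samples, then hand them to trlx.train() (requires trlX)."""
    check_samples(samples)
    import trlx

    return trlx.train(samples=samples, eval_prompts=eval_prompts, metric_fn=metric_fn, config=config)
```

For prompt-completion data, eval_prompts would normally use the same template as the training prompts, so evaluation samples reflect the instruction-following format.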

Step 5: Save the fine-tuned model

After training, call save_pretrained() on the trainer returned by trlx.train() to write a Hugging Face-compatible model directory. The saved model can be loaded for inference, uploaded to the Hugging Face Hub, or used as the starting point for subsequent reinforcement learning (PPO or ILQL).

Key considerations:

  • The output directory contains model weights, config, and tokenizer files
  • The saved model loads directly with Hugging Face Transformers (for example, AutoModelForCausalLM.from_pretrained)
  • This SFT checkpoint is commonly used as the base for Stage 2 (reward model) or Stage 3 (PPO) of RLHF
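A minimal sketch of saving and reloading. It assumes trainer is the object returned by trlx.train() and that Transformers and PyTorch are installed; the directory name sft_model is arbitrary. Reloading with AutoModelForCausalLM matches the point above that the output loads directly with Transformers.

```python
def save_and_reload(trainer, directory: str = "sft_model"):
    """Save a trained trlX trainer's model and tokenizer, then reload them.

    `trainer` is assumed to be the object returned by trlx.train().
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    trainer.save_pretrained(directory)  # writes weights, config, and tokenizer files
    tokenizer = AutoTokenizer.from_pretrained(directory)
    model = AutoModelForCausalLM.from_pretrained(directory)
    return tokenizer, model


def complete(tokenizer, model, prompt: str, max_new_tokens: int = 64) -> str:
    """Greedy-decode a completion and return only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt")
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    output_ids = model.generate(
        **inputs, max_new_tokens=max_new_tokens, do_sample=False, pad_token_id=pad_id
    )
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]  # drop the echoed prompt
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

For an instruction-tuned model, the prompt passed to complete() should use the same template as the training pairs.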
