Implementation:Hiyouga LLaMA Factory SFT Workflow
| Knowledge Sources | |
|---|---|
| Domains | Supervised Fine-Tuning, Training Workflow, NLP |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
run_sft is the end-to-end orchestrator function for supervised fine-tuning, the most commonly used training stage in LLaMA-Factory.
Description
The run_sft function loads the tokenizer, the chat template, the dataset for the "sft" stage, and the model. It then configures an SFTDataCollatorWith4DAttentionMask for correct attention masking. It sets up evaluation metrics (token accuracy or ROUGE/BLEU similarity), builds generation keyword arguments with EOS-token handling, and selects either CustomSeq2SeqTrainer or the KTransformers KTrainer. The function then drives four optional phases:
- training, with effective tokens-per-second (TPS) tracking and loss plotting
- generative evaluation
- prediction, with results saved to JSONL
- model card creation
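The control flow described above can be sketched in plain Python. SFTFlags, select_metric, build_gen_kwargs, and planned_phases are hypothetical names invented for this illustration. Two details are assumptions about the workflow rather than confirmed facts: generation-based similarity takes precedence over accuracy, and extra special-token ids are registered as additional stop tokens.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class SFTFlags:
    """Hypothetical stand-in for the flags run_sft reads from its argument objects."""
    do_train: bool = False
    do_eval: bool = False
    do_predict: bool = False
    predict_with_generate: bool = False
    compute_accuracy: bool = False


def select_metric(flags: SFTFlags) -> Optional[str]:
    """Pick the evaluation metric family (precedence is an assumption)."""
    if flags.predict_with_generate:
        return "similarity"  # ROUGE/BLEU computed on generated text
    if flags.compute_accuracy:
        return "accuracy"  # token-level accuracy computed from logits
    return None  # loss-only evaluation


def build_gen_kwargs(base: Dict, eos_token_id: int, extra_stop_ids: List[int], pad_token_id: int) -> Dict:
    """Copy generation settings and accept several token ids as end-of-sequence."""
    kwargs = dict(base)
    kwargs["eos_token_id"] = [eos_token_id] + list(extra_stop_ids)
    kwargs["pad_token_id"] = pad_token_id
    return kwargs


def planned_phases(flags: SFTFlags) -> List[str]:
    """Phases in the order the workflow runs them."""
    phases = []
    if flags.do_train:
        phases.append("train")
    if flags.do_eval:
        phases.append("eval")
    if flags.do_predict:
        phases.append("predict")
    if flags.do_train:
        phases.append("model_card")  # assumption: a card is only written after training
    return phases
```

Returning phase names instead of calling a trainer keeps the sketch testable without loading any model weights.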
Usage
Use run_sft to perform supervised fine-tuning on instruction-following or chat data. It is the primary training entry point, and the framework's dispatcher invokes it when the stage is "sft". It supports:
- evaluation-only runs on quantized models
- block-diagonal 4D attention masks, which keep packed sequences from attending to each other
- KTransformers for offloaded training
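Block-diagonal masking can be illustrated with nested lists. The sketch assumes that the packed attention mask labels each token with the index of the sub-sequence it came from, with 0 marking padding. Under that scheme, tokens from different packed samples cannot attend to each other. block_diag_causal_mask, to_4d_additive, and MASKED are hypothetical names, and the real collator builds tensors rather than lists.

```python
from typing import List

MASKED = -1e9  # stand-in for the dtype's minimum value; finite so fully masked rows stay numerically safe


def block_diag_causal_mask(seq_ids: List[int]) -> List[List[bool]]:
    """seq_ids[i] is the packed sub-sequence index of token i (1, 2, ...); 0 marks padding.

    Returns mask[q][k], True when query position q may attend to key position k:
    the key must not lie in the future and must belong to the same non-padding sub-sequence.
    """
    n = len(seq_ids)
    return [
        [k <= q and seq_ids[q] != 0 and seq_ids[k] == seq_ids[q] for k in range(n)]
        for q in range(n)
    ]


def to_4d_additive(mask: List[List[bool]]) -> List[List[List[List[float]]]]:
    """Wrap one sequence's boolean mask as a (batch=1, heads=1, q, k) additive mask."""
    return [[[[0.0 if allowed else MASKED for allowed in row] for row in mask]]]


# Two packed samples (lengths 2 and 3) followed by one padding token.
mask = block_diag_causal_mask([1, 1, 2, 2, 2, 0])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x.....
# xx....
# ..x...
# ..xx..
# ..xxx.
# ......
```

Each packed sample forms its own lower-triangular block along the diagonal. Adding the 4D additive form to attention scores removes cross-sample attention without having to split the batch.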
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/train/sft/workflow.py
- Lines: 1-173
Signature
```python
def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[list["TrainerCallback"]] = None,
) -> None
```
Import
```python
from llamafactory.train.sft.workflow import run_sft
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_args | ModelArguments | Yes | Model configuration, including block_diag_attn, compute_dtype, and the use_kt flag |
| data_args | DataArguments | Yes | Dataset configuration, including the ignore_pad_token_for_loss setting |
| training_args | Seq2SeqTrainingArguments | Yes | Training hyperparameters, including the do_train, do_eval, do_predict, and predict_with_generate flags |
| finetuning_args | FinetuningArguments | Yes | Fine-tuning settings, including compute_accuracy, plot_loss, and include_effective_tokens_per_second |
| generating_args | GeneratingArguments | Yes | Generation parameters for evaluation and prediction, including skip_special_tokens and a to_dict() method that produces the generation keyword arguments |
| callbacks | Optional[list[TrainerCallback]] | No | Additional trainer callbacks |
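Two of the inputs above, data_args.ignore_pad_token_for_loss and finetuning_args.compute_accuracy, both depend on label positions that are excluded from the loss. The sketch below assumes the conventional ignore value of -100 and a next-token alignment between predictions and labels. pad_labels and masked_token_accuracy are hypothetical helpers, not LLaMA-Factory functions, and ComputeAccuracy may differ in detail.

```python
from typing import List, Optional

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def pad_labels(
    batch_labels: List[List[int]],
    pad_token_id: int,
    ignore_pad_token_for_loss: bool = True,
    pad_to_multiple_of: Optional[int] = None,
) -> List[List[int]]:
    """Right-pad label sequences; padded positions are ignored by the loss if requested."""
    label_pad = IGNORE_INDEX if ignore_pad_token_for_loss else pad_token_id
    max_len = max(len(labels) for labels in batch_labels)
    if pad_to_multiple_of:
        max_len = -(-max_len // pad_to_multiple_of) * pad_to_multiple_of  # round up to a multiple
    return [labels + [label_pad] * (max_len - len(labels)) for labels in batch_labels]


def masked_token_accuracy(pred_ids: List[List[int]], label_ids: List[List[int]]) -> float:
    """Next-token accuracy: the prediction at position i is scored against label i + 1.

    Positions whose label is IGNORE_INDEX are skipped; the result is the mean of
    per-sample accuracies.
    """
    per_sample = []
    for preds, labels in zip(pred_ids, label_ids):
        pairs = [(p, t) for p, t in zip(preds[:-1], labels[1:]) if t != IGNORE_INDEX]
        if pairs:
            per_sample.append(sum(p == t for p, t in pairs) / len(pairs))
    return sum(per_sample) / len(per_sample) if per_sample else 0.0
```

If ignore_pad_token_for_loss is false, padding is labeled with the real pad id. The loss then trains on those positions, and the accuracy counts them as well.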
Outputs
| Name | Type | Description |
|---|---|---|
| (none) | None | Side effects: saves model, metrics (loss/accuracy/TPS/ROUGE/BLEU), predictions (prompt/predict/label JSONL), loss plots, and model card to output_dir |
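The prediction file listed above stores one JSON object per evaluation example. The following is a minimal writer that assumes predictions and labels have already been decoded to strings. prediction_lines and save_predictions_jsonl are hypothetical names. Only the file name generated_predictions.jsonl and the three keys come from this page.

```python
import json
import os
from typing import List


def prediction_lines(prompts: List[str], predicts: List[str], labels: List[str]) -> List[str]:
    """Serialize one JSON object per example with the prompt/predict/label keys."""
    if not (len(prompts) == len(predicts) == len(labels)):
        raise ValueError("prompts, predicts and labels must have the same length")
    return [
        json.dumps({"prompt": p, "predict": y_hat, "label": y}, ensure_ascii=False)
        for p, y_hat, y in zip(prompts, predicts, labels)
    ]


def save_predictions_jsonl(
    output_dir: str,
    prompts: List[str],
    predicts: List[str],
    labels: List[str],
    filename: str = "generated_predictions.jsonl",
) -> str:
    """Write the JSONL file into output_dir and return its path."""
    os.makedirs(output_dir, exist_ok=True)
    path = os.path.join(output_dir, filename)
    with open(path, "w", encoding="utf-8") as f:
        for line in prediction_lines(prompts, predicts, labels):
            f.write(line + "\n")
    return path
```

With ensure_ascii=False, non-English text stays readable in the file rather than being written as \u escapes.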
Usage Examples
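```python
# Typical invocation for supervised fine-tuning.
# model_args, data_args, training_args, finetuning_args and generating_args are
# assumed to be already parsed, e.g. by the framework's argument parser.
from llamafactory.train.sft.workflow import run_sft

run_sft(
    model_args=model_args,
    data_args=data_args,
    training_args=training_args,
    finetuning_args=finetuning_args,
    generating_args=generating_args,
    callbacks=None,
)

# With predict_with_generate enabled, evaluation reports similarity scores such as
# (illustrative values; saved metric keys carry the "eval_" prefix, e.g. "eval_rouge-1"):
# {"rouge-1": 45.2, "rouge-2": 22.1, "rouge-l": 40.3, "bleu-4": 18.7}

# Each line of the prediction output (generated_predictions.jsonl) has the form:
# {"prompt": "...", "predict": "...", "label": "..."}
```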
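To make the similarity scores concrete, ROUGE-1 F1 counts clipped unigram overlap between a prediction and its label. The sketch below uses whitespace tokenization and reports on a 0-100 scale, matching the example values. It illustrates the metric only. The library's ComputeSimilarity likely uses its own tokenization and a dedicated ROUGE/BLEU package, so its numbers will not match this hypothetical rouge1_f1 function exactly.

```python
from collections import Counter


def rouge1_f1(prediction: str, reference: str) -> float:
    """ROUGE-1 F1 on whitespace tokens, scaled to 0-100."""
    pred_tokens = prediction.split()
    ref_tokens = reference.split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Counter intersection keeps the minimum count per token, i.e. clipped matches.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 100 * 2 * precision * recall / (precision + recall)


# Overlap is 2 ("the", "cat"), P = 2/3, R = 1/2, so F1 is 4/7, about 57.14 on the 0-100 scale.
score = rouge1_f1("the cat sat", "the cat ran away")
```

Clipping matters for repeated tokens. A prediction that repeats "the" five times earns at most as many matches as the reference contains.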
Related Pages
- Hiyouga_LLaMA_Factory_SFT_Trainer - The CustomSeq2SeqTrainer class used internally
- Hiyouga_LLaMA_Factory_SFT_Metric - Metric classes for evaluation (ComputeAccuracy, ComputeSimilarity)
- Hiyouga_LLaMA_Factory_PT_Workflow - Pre-training workflow, typically the preceding stage
- Hiyouga_LLaMA_Factory_PPO_Workflow - PPO workflow, typically the following RLHF stage
- Hiyouga_LLaMA_Factory_Data_Collator - SFTDataCollatorWith4DAttentionMask used for data collation