Implementation:Huggingface Open r1 SFTTrainer Usage
Principle:Huggingface_Open_r1_Supervised_Fine_Tuning
Overview
Wrapper around HuggingFace TRL's SFTTrainer for supervised fine-tuning (SFT) of language models, configured with Open-R1's custom SFTConfig and callback system.
Description
This is a Wrapper Doc. Open-R1 uses TRL's SFTTrainer with its own extended SFTConfig that adds benchmark callbacks, Hub revision pushing, chat template support, and W&B logging integration. The trainer is initialized with model, tokenizer, dataset, optional PEFT config (for LoRA), and custom callbacks. Training supports checkpoint resumption and distributed training via DeepSpeed ZeRO or FSDP.
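Open-R1's SFTConfig extends TRL's through ordinary dataclass inheritance: it subclasses trl.SFTConfig and adds its own fields, so every standard training argument keeps working alongside the extensions. Below is a minimal stdlib sketch of that pattern. `BaseSFTConfig` is a hypothetical stand-in for trl.SFTConfig. The extension field names are modeled on the features listed above and should be checked against open_r1/configs.py.

```python
from dataclasses import dataclass, field, fields
from typing import Optional


@dataclass
class BaseSFTConfig:
    """Hypothetical stand-in for trl.SFTConfig (only a few standard fields)."""
    output_dir: str = "output"
    learning_rate: float = 2e-5
    per_device_train_batch_size: int = 8
    gradient_accumulation_steps: int = 1


@dataclass
class ExtendedSFTConfig(BaseSFTConfig):
    """Adds Open-R1-style extension fields (illustrative names) on top of the base config."""
    benchmarks: list[str] = field(default_factory=list)  # benchmarks to schedule via callbacks
    callbacks: list[str] = field(default_factory=list)   # callback names resolved at startup
    chat_template: Optional[str] = None                  # overrides the tokenizer's template
    hub_model_revision: Optional[str] = "main"           # Hub branch that checkpoints go to
    push_to_hub_revision: bool = False
    wandb_project: Optional[str] = None


cfg = ExtendedSFTConfig(learning_rate=5e-5, callbacks=["push_to_hub_revision"])
print(cfg.learning_rate, cfg.callbacks)
print([f.name for f in fields(cfg)][:4])  # inherited base fields come first
```

Because the subclass only adds fields with defaults, arguments written for the base config remain valid. The new keys are optional extras.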
Usage
Use when running the Open-R1 SFT distillation workflow. The trainer handles the full training loop including logging, checkpointing, and evaluation.
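The length of that loop depends on the batch-related fields in SFTConfig. DeepSpeed ZeRO and FSDP shard model state, but they are still data-parallel, so each batch is split across ranks. Each optimizer step therefore consumes per_device_train_batch_size × gradient_accumulation_steps × world_size samples. The sketch below shows that arithmetic in simplified form. The helper names are hypothetical, and the exact rounding in transformers can differ at epoch boundaries.

```python
import math


def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int,
                         world_size: int) -> int:
    """Samples consumed per optimizer step across all data-parallel ranks."""
    return per_device_train_batch_size * gradient_accumulation_steps * world_size


def approx_steps_per_epoch(num_samples: int,
                           per_device_train_batch_size: int,
                           gradient_accumulation_steps: int,
                           world_size: int) -> int:
    """Approximate optimizer steps per epoch (simplified).

    Each rank sees about ceil(N / (batch * world_size)) micro-batches, and every
    `gradient_accumulation_steps` micro-batches form one update (at least one step).
    """
    micro_batches = math.ceil(num_samples / (per_device_train_batch_size * world_size))
    return max(micro_batches // gradient_accumulation_steps, 1)


# e.g. 8 GPUs, 4 samples per device, 8 accumulation steps, 10k training examples
print(effective_batch_size(4, 8, 8))            # 256
print(approx_steps_per_epoch(10_000, 4, 8, 8))  # 39
```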
Code Reference
Source Location
| Repository | File | Lines |
|---|---|---|
| open-r1 | src/open_r1/sft.py | L101-125 |
Signature
trainer = SFTTrainer(
    model=model,  # AutoModelForCausalLM
    args=training_args,  # SFTConfig (Open-R1's subclass of trl.SFTConfig)
    train_dataset=dataset[script_args.dataset_train_split],  # Dataset
    eval_dataset=dataset[script_args.dataset_test_split],  # Optional[Dataset]
    processing_class=tokenizer,  # PreTrainedTokenizer
    peft_config=get_peft_config(model_args),  # Optional[PeftConfig]
    callbacks=get_callbacks(training_args, model_args),  # list[TrainerCallback]
)
train_result = trainer.train(resume_from_checkpoint=checkpoint)
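The `callbacks` argument is built by `get_callbacks(training_args, model_args)`. This helper turns the list of callback names in the config into TrainerCallback instances. Below is a minimal sketch of that name-to-class registry pattern. The stub classes are hypothetical stand-ins for transformers.TrainerCallback and Open-R1's real callbacks, and the actual lookup table and constructor arguments may differ.

```python
from dataclasses import dataclass, field


class TrainerCallbackStub:
    """Hypothetical stand-in for transformers.TrainerCallback."""

    def on_save(self, args, state, control, **kwargs):
        pass


class PushToHubRevisionCallbackStub(TrainerCallbackStub):
    """Hypothetical: would push each saved checkpoint to a Hub branch."""

    def __init__(self, model_config):
        self.model_config = model_config


# Registry mapping config-level callback names to callback classes
CALLBACKS = {
    "push_to_hub_revision": PushToHubRevisionCallbackStub,
}


@dataclass
class TrainConfigStub:
    """Hypothetical config holding only the list of callback names."""
    callbacks: list[str] = field(default_factory=list)


def get_callbacks_sketch(train_config, model_config) -> list[TrainerCallbackStub]:
    """Resolve callback names from the training config into callback instances."""
    callbacks = []
    for name in train_config.callbacks:
        if name not in CALLBACKS:
            raise ValueError(f"Callback {name} not found in CALLBACKS.")
        callbacks.append(CALLBACKS[name](model_config))
    return callbacks


cbs = get_callbacks_sketch(TrainConfigStub(callbacks=["push_to_hub_revision"]), model_config=None)
print([type(cb).__name__ for cb in cbs])  # ['PushToHubRevisionCallbackStub']
```

Keeping names rather than objects in the config allows callbacks to be selected from a YAML recipe or CLI flag. An unknown name fails at startup rather than mid-training.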
Import
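from trl import SFTTrainer, get_peft_config, setup_chat_format
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks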
External Reference
HuggingFace TRL SFTTrainer Documentation
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| model | AutoModelForCausalLM | Yes | The pretrained language model to fine-tune. |
| args | SFTConfig | Yes | Training configuration including learning rate, batch size, gradient accumulation steps, output directory, logging settings, and Open-R1 extensions. |
| train_dataset | Dataset | Yes | The training split of the instruction-response dataset. |
| eval_dataset | Dataset | No | The evaluation split for computing validation metrics during training. |
| processing_class | PreTrainedTokenizer | Yes | Tokenizer used to encode inputs and decode outputs. Must match the model's vocabulary. |
| peft_config | PeftConfig | No | Configuration for Parameter-Efficient Fine-Tuning (e.g., LoRA). When provided, only adapter weights are trained. |
| callbacks | list[TrainerCallback] | No | Custom callbacks for benchmark evaluation, W&B logging, and Hub revision pushing. |
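When `peft_config` is a LoRA config, each targeted weight matrix W0 (shape d_out × d_in) stays frozen. Training instead learns a low-rank update ΔW = (α/r)·B·A, with B of shape d_out × r and A of shape r × d_in. Each adapted matrix therefore adds only r·(d_in + d_out) trainable parameters. The arithmetic below illustrates this. The layer shapes, target modules, and rank are hypothetical examples, not Open-R1 defaults.

```python
def lora_trainable_params(shapes: list[tuple[int, int]], r: int) -> int:
    """Trainable parameters when each (d_out, d_in) weight gets a rank-r LoRA adapter.

    W0 is frozen; B (d_out x r) and A (r x d_in) are learned,
    so each adapted matrix contributes r * (d_in + d_out) parameters.
    """
    return sum(r * (d_in + d_out) for d_out, d_in in shapes)


def full_params(shapes: list[tuple[int, int]]) -> int:
    """Parameters that full fine-tuning would update in the same matrices."""
    return sum(d_out * d_in for d_out, d_in in shapes)


# Hypothetical: q_proj and v_proj (4096 x 4096 each) in 32 layers, rank 16
shapes = [(4096, 4096), (4096, 4096)] * 32
print(lora_trainable_params(shapes, r=16))  # 8388608
print(full_params(shapes))                  # 1073741824
```

In this example, the adapters train under 1% of the parameters in the targeted matrices, which is why LoRA runs need far less optimizer memory.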
Outputs
| Output | Description |
|---|---|
| TrainOutput | Object containing global_step, training_loss, and a metrics dictionary. |
| Checkpoints | Model weights, optimizer state, and scheduler state saved to output_dir at intervals defined by save_steps. |
| Logs | Training metrics written to logging_dir and optionally to W&B. |
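Checkpoints are written to output_dir as `checkpoint-<global_step>` folders. Resuming with get_last_checkpoint picks the folder with the highest step. Below is a minimal stdlib sketch of that convention. It is a simplified re-implementation for illustration, not the transformers code, and it ignores save_total_limit pruning and any end-of-training save. The helper names are hypothetical.

```python
import os
import re
import tempfile
from typing import Optional

# Checkpoint folders follow the "checkpoint-<global_step>" naming convention
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def saved_steps(max_steps: int, save_steps: int) -> list[int]:
    """Global steps at which a "steps" save strategy writes a checkpoint (simplified)."""
    return list(range(save_steps, max_steps + 1, save_steps))


def last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint folder with the highest step number, or None."""
    candidates = [
        name for name in os.listdir(output_dir)
        if _CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not candidates:
        return None
    # Compare parsed step numbers, not folder names
    best = max(candidates, key=lambda name: int(_CHECKPOINT_RE.match(name).group(1)))
    return os.path.join(output_dir, best)


with tempfile.TemporaryDirectory() as out:
    for step in saved_steps(max_steps=1200, save_steps=500):  # [500, 1000]
        os.makedirs(os.path.join(out, f"checkpoint-{step}"))
    os.makedirs(os.path.join(out, "runs"))  # non-checkpoint folder is ignored
    print(os.path.basename(last_checkpoint(out)))  # checkpoint-1000
```

Parsing the step number matters: compared as strings, `checkpoint-500` would sort after `checkpoint-1000`.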
Usage Examples
The following is a condensed version of the SFT training setup in sft.py. Logging setup and the evaluation and Hub-push steps that follow training are omitted:
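import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import ModelConfig, ScriptArguments, SFTTrainer, TrlParser, get_peft_config, setup_chat_format

from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks

# 0. Parse script, training, and model arguments from the CLI and/or a YAML recipe
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
set_seed(training_args.seed)
# 1. Load model and tokenizer
# ModelConfig stores the dtype as a string (e.g. "bfloat16"); convert it to a torch.dtype
torch_dtype = (
    model_args.torch_dtype
    if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    torch_dtype=torch_dtype,
    attn_implementation=model_args.attn_implementation,
    # The KV cache is not needed for training and conflicts with gradient checkpointing
    use_cache=not training_args.gradient_checkpointing,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    use_fast=True,
)
# 2. Use the recipe's chat template if set; otherwise fall back to ChatML when the model lacks one
if training_args.chat_template is not None:
    tokenizer.chat_template = training_args.chat_template
if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
# 3. Load dataset (dataset_name belongs to ScriptArguments, not SFTConfig)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
# 4. Initialize trainer (skip the eval split when evaluation is disabled)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=(
        dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None
    ),
    processing_class=tokenizer,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)
# 5. Train, resuming from an explicit checkpoint or the latest one in output_dir
checkpoint = training_args.resume_from_checkpoint
if checkpoint is None and os.path.isdir(training_args.output_dir):
    checkpoint = get_last_checkpoint(training_args.output_dir)  # None if no checkpoint-* folders
train_result = trainer.train(resume_from_checkpoint=checkpoint)
# 6. Save final model and metrics
metrics = train_result.metrics
metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
trainer.save_model(training_args.output_dir)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()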
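Step 2 falls back to ChatML through setup_chat_format when the tokenizer has no chat template. ChatML is the default format of that TRL helper. Each conversation is then rendered into a single string before tokenization. Below is a minimal pure-Python sketch of the ChatML layout. It assumes a conversational dataset of role/content messages, and `render_chatml` is a hypothetical helper. The real template is a Jinja string stored on the tokenizer and may differ in details such as special-token handling.

```python
def render_chatml(messages: list[dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Render role/content messages in ChatML layout (sketch, not the tokenizer's template)."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so a model can continue from here at inference time
        text += "<|im_start|>assistant\n"
    return text


example = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
print(render_chatml(example))
```

During SFT, the full rendered conversation (including the assistant turn) is the training text. The generation prompt is only needed when sampling from the trained model.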
Related Pages
Implements Principle
Principle:Huggingface_Open_r1_Supervised_Fine_Tuning