Implementation:Huggingface Open r1 SFTTrainer Usage

From Leeroopedia


Principle:Huggingface_Open_r1_Supervised_Fine_Tuning

Overview

Wrapper for HuggingFace TRL's SFTTrainer for supervised fine-tuning of language models, configured with Open-R1's custom SFTConfig and callback system.
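Because TRL's SFTConfig is a dataclass, Open-R1 can extend it by subclassing and adding fields, and TRL's argument parser then exposes each new field as a CLI/YAML option. The sketch below illustrates that pattern using only the standard library. `BaseSFTConfig` is a hypothetical stand-in for `trl.SFTConfig`. The added field names are modeled on the options this page lists (benchmarks, callbacks, chat template, Hub revision, W&B) and are illustrative, not copied from Open-R1's source.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseSFTConfig:
    """Hypothetical stand-in for trl.SFTConfig (itself a TrainingArguments subclass)."""
    output_dir: str = "data/sft-output"
    learning_rate: float = 2e-5
    num_train_epochs: float = 1.0


@dataclass
class OpenR1StyleSFTConfig(BaseSFTConfig):
    """Adds Open-R1-style options on top of the base config (illustrative names)."""
    benchmarks: list[str] = field(default_factory=list)  # benchmarks to run on saved checkpoints
    callbacks: list[str] = field(default_factory=list)   # callback names resolved at runtime
    chat_template: Optional[str] = None                  # optional template override
    hub_model_revision: Optional[str] = "main"           # Hub branch to push checkpoints to
    push_to_hub_revision: bool = False
    wandb_entity: Optional[str] = None
    wandb_project: Optional[str] = None


# Hypothetical values; "my_benchmark" is not a real Open-R1 benchmark name
config = OpenR1StyleSFTConfig(callbacks=["push_to_hub_revision"], benchmarks=["my_benchmark"])
print(isinstance(config, BaseSFTConfig))  # True: usable wherever the base config is expected
```

Subclassing keeps every inherited training option intact, so code that only understands the base config (such as the trainer) still accepts the extended one.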

Description

This is a Wrapper Doc. Open-R1 uses TRL's SFTTrainer together with its own SFTConfig, a subclass of TRL's SFTConfig that adds fields for benchmark runs, Hub revision pushing, chat template overrides, and W&B logging. The trainer is initialized with the model, the tokenizer (passed as processing_class), the train and optional eval dataset splits, an optional PEFT config (for LoRA), and custom callbacks. Training can resume from a saved checkpoint and can run distributed via DeepSpeed ZeRO or FSDP.
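The custom callbacks are selected by name in the config and instantiated at startup. Below is a minimal sketch of such a name-based registry, modeled on how Open-R1 resolves the names in SFTConfig.callbacks. `TrainerCallback` and `TrainConfig` here are hypothetical stand-ins so the snippet runs without transformers, and the callback class only records its input rather than pushing anything to the Hub.

```python
from dataclasses import dataclass, field


class TrainerCallback:
    """Hypothetical stand-in for transformers.TrainerCallback."""


class PushToHubRevisionCallback(TrainerCallback):
    """Illustrative callback; a real one would push each saved checkpoint to a Hub branch."""

    def __init__(self, model_config):
        self.model_config = model_config


# Registry mapping the names used in the config to callback classes
CALLBACKS = {"push_to_hub_revision": PushToHubRevisionCallback}


@dataclass
class TrainConfig:
    """Hypothetical config holding only the list of callback names."""
    callbacks: list[str] = field(default_factory=list)


def get_callbacks(train_config, model_config) -> list[TrainerCallback]:
    """Instantiate every callback named in train_config.callbacks, failing on unknown names."""
    callbacks = []
    for name in train_config.callbacks:
        if name not in CALLBACKS:
            raise ValueError(f"Callback {name} not found in CALLBACKS.")
        callbacks.append(CALLBACKS[name](model_config))
    return callbacks


cbs = get_callbacks(TrainConfig(callbacks=["push_to_hub_revision"]), model_config=None)
print(type(cbs[0]).__name__)  # PushToHubRevisionCallback
```

Failing fast on an unknown name surfaces a typo in a YAML recipe before training starts rather than silently skipping the callback.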

Usage

Use when running the Open-R1 SFT distillation workflow. The trainer handles the full training loop including logging, checkpointing, and evaluation.

Code Reference

Source Location

Repository | File | Lines
open-r1 | src/open_r1/sft.py | L101-125

Signature

trainer = SFTTrainer(
    model=model,                          # AutoModelForCausalLM
    args=training_args,                    # SFTConfig
    train_dataset=dataset[train_split],    # Dataset
    eval_dataset=dataset[test_split],      # Optional[Dataset]
    processing_class=tokenizer,            # PreTrainedTokenizer
    peft_config=get_peft_config(model_args),  # Optional[PeftConfig]
    callbacks=get_callbacks(training_args, model_args),  # list[TrainerCallback]
)
train_result = trainer.train(resume_from_checkpoint=checkpoint)
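The checkpoint passed to trainer.train() is typically resolved as follows: an explicit resume_from_checkpoint path wins; otherwise the newest checkpoint-<step> directory under output_dir is used. The standard-library sketch below mirrors the behavior of transformers.trainer_utils.get_last_checkpoint (step suffixes compared numerically) and adds a guard for a missing directory. `find_last_checkpoint` and `resolve_checkpoint` are hypothetical helper names, not library functions.

```python
import os
import re
import tempfile
from typing import Optional

# Trainer saves checkpoints as <output_dir>/checkpoint-<global_step>
_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")


def find_last_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint-N subdirectory with the highest N, or None if there is none."""
    if not os.path.isdir(output_dir):
        return None  # fresh run: nothing to resume from
    found = []
    for name in os.listdir(output_dir):
        match = _CHECKPOINT_RE.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            found.append((int(match.group(1)), name))
    if not found:
        return None
    # Numeric comparison, so checkpoint-1000 beats checkpoint-999
    return os.path.join(output_dir, max(found)[1])


def resolve_checkpoint(resume_from_checkpoint: Optional[str], output_dir: str) -> Optional[str]:
    """An explicit path wins; otherwise fall back to the latest checkpoint on disk."""
    if resume_from_checkpoint is not None:
        return resume_from_checkpoint
    return find_last_checkpoint(output_dir)


with tempfile.TemporaryDirectory() as out:
    for step in (500, 1000, 999):
        os.makedirs(os.path.join(out, f"checkpoint-{step}"))
    print(resolve_checkpoint(None, out))  # path ending in checkpoint-1000
```

Sorting on the integer suffix matters: a plain string sort would rank checkpoint-999 above checkpoint-1000.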

Import

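from trl import SFTTrainer, get_peft_config, setup_chat_format
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks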

External Reference

HuggingFace TRL SFTTrainer Documentation

I/O Contract

Inputs

Parameter | Type | Required | Description
model | AutoModelForCausalLM | Yes | The pretrained language model to fine-tune.
args | SFTConfig | Yes | Training configuration: learning rate, batch size, gradient accumulation steps, output directory, logging settings, and the Open-R1 extensions.
train_dataset | Dataset | Yes | The training split of the instruction-response dataset.
eval_dataset | Dataset | No | The evaluation split used to compute validation metrics during training; Open-R1 passes it only when eval_strategy is not "no".
processing_class | PreTrainedTokenizer | Yes | Tokenizer used to encode inputs and decode outputs; must match the model's vocabulary.
peft_config | PeftConfig | No | Parameter-efficient fine-tuning configuration (e.g., LoRA). get_peft_config returns None unless use_peft is set in the model arguments; when a config is provided, only the adapter weights are trained.
callbacks | list[TrainerCallback] | No | Custom callbacks built from the names listed in SFTConfig.callbacks, e.g. pushing checkpoints to a Hub revision and launching benchmark jobs.
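To make "only the adapter weights are trained" concrete: LoRA keeps a weight matrix W of shape d_out × d_in frozen and learns a low-rank update B·A, with B of shape d_out × r and A of shape r × d_in, so each adapted matrix adds r·(d_out + d_in) trainable parameters. With TRL, the rank comes from ModelConfig.lora_r when use_peft is enabled. The sizes below (one 4096 × 4096 projection, rank 16) are illustrative, not taken from an Open-R1 recipe.

```python
def lora_trainable_params(d_out: int, d_in: int, rank: int) -> int:
    """LoRA trains B (d_out x rank) and A (rank x d_in) instead of W (d_out x d_in)."""
    return rank * (d_out + d_in)


def full_params(d_out: int, d_in: int) -> int:
    """Parameters updated when the full matrix W is fine-tuned."""
    return d_out * d_in


# Illustrative sizes: a single 4096 x 4096 projection with a rank-16 adapter
d, r = 4096, 16
lora = lora_trainable_params(d, d, r)
full = full_params(d, d)
print(lora, full, f"{100 * lora / full:.2f}%")  # 131072 16777216 0.78%
```

The ratio shrinks as the matrices grow, since adapter cost is linear in d while the full matrix is quadratic; this is why LoRA fits on hardware that cannot hold optimizer state for every weight.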

Outputs

Output | Description
TrainOutput | Named tuple with global_step, training_loss, and a metrics dictionary.
Checkpoints | Model weights, optimizer state, and scheduler state saved to output_dir as checkpoint-<step> directories, at the interval set by save_strategy and save_steps.
Logs | Training metrics reported to the integrations listed in report_to (e.g., TensorBoard under logging_dir, or W&B).
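The final global_step and the checkpoint directories both follow from the batch settings. The sketch below gives a rough estimate for an epoch-based run with save_strategy="steps". It assumes no packing, no explicit max_steps, and floor division by the accumulation steps; recent transformers versions may round the last partial accumulation window differently, and it ignores save_total_limit pruning and any end-of-training save. The run sizes are hypothetical.

```python
import math


def estimate_max_steps(num_examples, per_device_batch_size, grad_accum_steps, world_size, num_epochs):
    """Rough optimizer-step count for an epoch-based run (assumptions in the text above)."""
    # Micro-batches each device processes per epoch
    batches_per_device = math.ceil(num_examples / (per_device_batch_size * world_size))
    # One optimizer step per grad_accum_steps micro-batches, at least one per epoch
    steps_per_epoch = max(batches_per_device // grad_accum_steps, 1)
    return math.ceil(num_epochs * steps_per_epoch)


def checkpoint_steps(max_steps, save_steps):
    """Global steps that are multiples of save_steps, where checkpoint-<step> is written."""
    return list(range(save_steps, max_steps + 1, save_steps))


# Hypothetical run: 100k examples, batch 4 per GPU, 8 GPUs, accumulation 8, 3 epochs
max_steps = estimate_max_steps(100_000, 4, 8, 8, 3)
print(max_steps)                                   # 1170
print(checkpoint_steps(max_steps, save_steps=500))  # [500, 1000]
```

Here the effective batch is 4 × 8 × 8 = 256 examples per optimizer step, which is the number to keep fixed when changing GPU count.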

Usage Examples

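The following is a condensed version of the SFT training setup in sft.py. Arguments are parsed with TrlParser; logging setup and Hub pushing are omitted:

import os

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
from transformers.trainer_utils import get_last_checkpoint
from trl import (
    ModelConfig, ScriptArguments, SFTTrainer, TrlParser, get_peft_config, setup_chat_format,
)

from open_r1.configs import SFTConfig
from open_r1.utils.callbacks import get_callbacks

# 0. Parse script, training, and model arguments (CLI flags and/or a YAML recipe)
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
set_seed(training_args.seed)

# 1. Load model and tokenizer ("bfloat16"-style strings become torch dtypes)
torch_dtype = (
    model_args.torch_dtype if model_args.torch_dtype in ["auto", None]
    else getattr(torch, model_args.torch_dtype)
)
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    torch_dtype=torch_dtype,
    attn_implementation=model_args.attn_implementation,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    use_fast=True,
)

# 2. Apply a default chat template if the tokenizer lacks one
if tokenizer.chat_template is None:
    model, tokenizer = setup_chat_format(model, tokenizer)

# 3. Load dataset (the dataset name lives on ScriptArguments, not on SFTConfig)
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

# 4. Initialize trainer (eval split only when evaluation is enabled)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset[script_args.dataset_train_split],
    eval_dataset=(
        dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None
    ),
    processing_class=tokenizer,
    peft_config=get_peft_config(model_args),
    callbacks=get_callbacks(training_args, model_args),
)

# 5. Train, resuming from an explicit checkpoint or the latest one in output_dir
#    (get_last_checkpoint raises on a missing directory, so guard the first run)
checkpoint = training_args.resume_from_checkpoint
if checkpoint is None and os.path.isdir(training_args.output_dir):
    checkpoint = get_last_checkpoint(training_args.output_dir)
train_result = trainer.train(resume_from_checkpoint=checkpoint)

# 6. Save final model, metrics, and trainer state
trainer.save_model(training_args.output_dir)
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()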
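Step 2 of the example falls back to setup_chat_format, which installs a ChatML template, adds the <|im_start|> and <|im_end|> special tokens, and resizes the model's embeddings to match. The pure-Python function below is a hand-written approximation of the text that template produces for a conversational (messages-style) record. It is not TRL's Jinja template itself, and the example conversation is invented.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Render role/content messages in the ChatML layout."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        text += "<|im_start|>assistant\n"  # open an assistant turn for the model to complete
    return text


example = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
print(render_chatml(example))
```

During SFT the full conversation, including the assistant turn, is rendered this way; at inference time add_generation_prompt=True leaves the assistant turn open for the model to fill.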

Related Pages

Implements Principle

Principle:Huggingface_Open_r1_Supervised_Fine_Tuning

Requires Environment
