Implementation:Huggingface Trl TrlParser SFTConfig

From Leeroopedia


Knowledge Sources
Domains: NLP, Training
Last Updated: 2026-02-06 17:00 GMT

Overview

Concrete tool for parsing CLI arguments and YAML configuration files into typed dataclass instances for SFT training, provided by the TRL library.

Description

TrlParser is a subclass of transformers.HfArgumentParser that adds support for YAML configuration files (via the --config flag) and environment-variable injection (via the env block in the YAML file). It accepts one or more dataclass types, merges their fields into a single argument namespace, and returns a tuple of populated dataclass instances, one per type. Values from the YAML file replace the dataclass defaults, and arguments given on the command line take precedence over the YAML values.
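
The precedence described above (dataclass defaults, then config-file values, then explicit command-line flags) can be illustrated with the standard library alone. This is a minimal sketch and not TRL's implementation: DemoArgs, parse_with_config, and DEMO_PROJECT are hypothetical names, and a plain dict stands in for the parsed YAML file.

```python
import argparse
import os
from dataclasses import dataclass, fields


@dataclass
class DemoArgs:  # hypothetical dataclass, not part of TRL
    learning_rate: float = 2e-5
    max_length: int = 1024


def parse_with_config(cli_args, config):
    """Parse cli_args into DemoArgs, using config values as defaults."""
    config = dict(config)  # copy so the caller's dict is not mutated
    # Export the `env` block as environment variables, then drop it.
    for key, value in config.pop("env", {}).items():
        os.environ[key] = str(value)

    parser = argparse.ArgumentParser()
    for f in fields(DemoArgs):
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    # Config values replace the dataclass defaults; explicit CLI flags still win.
    parser.set_defaults(**config)
    return DemoArgs(**vars(parser.parse_args(cli_args)))


# `config` stands in for the dict a YAML loader would return.
config = {"env": {"DEMO_PROJECT": "my-sft-experiment"}, "max_length": 2048}
args = parse_with_config(["--learning_rate", "1e-4"], config)
print(args)
```

Here max_length comes from the config dict, learning_rate from the command line, and the env entry ends up in os.environ.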

SFTConfig is a dataclass that subclasses transformers.TrainingArguments and adds SFT-specific fields such as max_length, packing, completion_only_loss, assistant_only_loss, loss_type, and activation_offloading. It also overrides several TrainingArguments defaults (e.g., gradient_checkpointing=True, logging_steps=10).
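
The pattern of subclassing a dataclass to override inherited defaults and add new fields can be sketched as follows. BaseArgs and DemoSFTArgs are hypothetical stand-ins with illustrative values; they are not the real TrainingArguments or SFTConfig classes.

```python
from dataclasses import dataclass, fields


@dataclass
class BaseArgs:  # hypothetical stand-in for a parent config
    learning_rate: float = 5e-5
    logging_steps: float = 500


@dataclass
class DemoSFTArgs(BaseArgs):  # hypothetical stand-in for a subclass config
    # Redeclaring a field overrides the inherited default.
    learning_rate: float = 2e-5
    logging_steps: float = 10
    # New fields are appended after the inherited ones.
    max_length: int = 1024
    packing: bool = False


args = DemoSFTArgs(packing=True)
field_names = [f.name for f in fields(DemoSFTArgs)]
print(args)
print(field_names)
```

Because the subclass is still an instance of the parent type, code that expects the parent config can accept it unchanged.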

In the standard SFT script, four dataclass types are composed together: ScriptArguments, SFTConfig, ModelConfig, and DatasetMixtureConfig.
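
Composing several dataclass types into one parser and getting back one instance per type can be sketched with argparse. This is an illustration under simplifying assumptions (unique field names, simple field types); DemoScriptArgs, DemoTrainArgs, and parse_into_dataclasses are hypothetical names and do not reproduce TrlParser's internals.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class DemoScriptArgs:  # hypothetical
    dataset_name: str = "demo/dataset"


@dataclass
class DemoTrainArgs:  # hypothetical
    max_length: int = 1024


def parse_into_dataclasses(dataclass_types, cli_args):
    """Register every field of every dataclass, then split the result."""
    parser = argparse.ArgumentParser()
    for dc in dataclass_types:
        for f in fields(dc):
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    namespace = vars(parser.parse_args(cli_args))
    # Route each parsed value back to the dataclass that declared it.
    return tuple(
        dc(**{f.name: namespace[f.name] for f in fields(dc)})
        for dc in dataclass_types
    )


script_args, train_args = parse_into_dataclasses(
    (DemoScriptArgs, DemoTrainArgs), ["--max_length", "2048"]
)
print(script_args, train_args)
```

The returned tuple follows the order of the types passed in, which is why the SFT script can unpack its four configuration objects positionally.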

Usage

Use TrlParser and SFTConfig whenever you need to:

  • Launch SFT training from the command line with typed argument validation.
  • Load training configurations from a YAML file for reproducibility.
  • Programmatically construct configuration objects for sweeps or notebook-based training.

Code Reference

Source Location

  • Repository: TRL
  • File: trl/scripts/utils.py (TrlParser, lines 241-389)
  • File: trl/trainer/sft_config.py (SFTConfig, lines 21-279)
  • File: trl/scripts/utils.py (ScriptArguments, lines 154-211; DatasetMixtureConfig, lines 90-152)

Signature

class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    ):
        ...

    def parse_args_and_config(
        self,
        args: Iterable[str] | None = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]:
        ...


@dataclass
class SFTConfig(TrainingArguments):
    # Overridden defaults
    learning_rate: float = 2e-5
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: bool | None = None  # if None, resolved from the fp16 setting

    # Model parameters
    model_init_kwargs: dict[str, Any] | None = None
    chat_template_path: str | None = None

    # Data preprocessing parameters
    dataset_text_field: str = "text"
    dataset_kwargs: dict[str, Any] | None = None
    dataset_num_proc: int | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    max_length: int | None = 1024
    shuffle_dataset: bool = False
    packing: bool = False
    packing_strategy: str = "bfd"
    padding_free: bool = False
    pad_to_multiple_of: int | None = None
    eval_packing: bool | None = None

    # Training parameters
    completion_only_loss: bool | None = None
    assistant_only_loss: bool = False
    loss_type: str = "nll"
    activation_offloading: bool = False

Import

from trl import TrlParser, SFTConfig
from trl import ScriptArguments, ModelConfig, DatasetMixtureConfig

I/O Contract

Inputs

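Name                      Type                                            Required  Description
dataclass_types           DataClassType | Iterable[DataClassType] | None  In practice (signature default is None)  One or more dataclass types whose fields become CLI/YAML arguments
args                      Iterable[str] | None                            No        CLI arguments; defaults to sys.argv[1:]
--config (CLI)            str                                             No        Path to a YAML configuration file
return_remaining_strings  bool                                            No        If True, return unrecognized arguments instead of raising an error (default False)
fail_with_unknown_args    bool                                            No        If True, raise an error when unrecognized arguments remain (default True)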

Outputs

Name Type Description
result tuple[DataClass, ...] A tuple of populated dataclass instances, one per type in dataclass_types. In the SFT script this is (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig).
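
The difference between strict parsing and returning the remaining strings can be illustrated with plain argparse. This sketch only mirrors the idea behind return_remaining_strings; it does not show TrlParser's own return shape or the exception type it raises, and --unknown_flag is a made-up argument.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max_length", type=int, default=1024)

cli = ["--max_length", "2048", "--unknown_flag", "1"]

# Lenient mode: unrecognized arguments come back as a list of strings.
known, remaining = parser.parse_known_args(cli)

# Strict mode: argparse reports the unknown arguments and exits.
try:
    parser.parse_args(cli)
    strict_failed = False
except SystemExit:
    strict_failed = True

print(known.max_length, remaining, strict_failed)
```

Lenient parsing is useful when a wrapper script forwards leftover arguments to another tool; strict parsing catches typos in flag names early.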

Usage Examples

Basic Usage (CLI)

# sft_train.py
from trl import TrlParser, SFTConfig, SFTTrainer
from trl import ScriptArguments, ModelConfig, DatasetMixtureConfig

parser = TrlParser(
    dataclass_types=(ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
)
script_args, training_args, model_args, dataset_args = parser.parse_args_and_config(
    return_remaining_strings=False
)

# training_args is a fully populated SFTConfig instance
print(training_args.max_length)       # 1024 (default)
print(training_args.packing)          # False (default)
print(model_args.model_name_or_path)  # set via CLI or YAML

YAML Configuration

# config.yaml
env:
  WANDB_PROJECT: my-sft-experiment

model_name_or_path: Qwen/Qwen2-0.5B
dataset_name: trl-lib/Capybara
learning_rate: 2.0e-5
num_train_epochs: 1
packing: true
max_length: 2048
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
output_dir: Qwen2-0.5B-SFT
use_peft: true
lora_r: 32
lora_alpha: 16

Launch the script with the configuration file:

python sft_train.py --config config.yaml

Programmatic Usage

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="./output",
    max_length=2048,
    packing=True,
    completion_only_loss=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
)
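
For sweeps, one common pattern is to derive variants from a base configuration object. The sketch below uses a hypothetical frozen dataclass, DemoConfig, together with dataclasses.replace; whether replace is appropriate for SFTConfig itself is not verified here, so with the real class it may be safer to call the constructor once per combination.

```python
from dataclasses import dataclass, replace
from itertools import product


@dataclass(frozen=True)
class DemoConfig:  # hypothetical stand-in for a training config
    output_dir: str
    learning_rate: float = 2e-5
    max_length: int = 1024


base = DemoConfig(output_dir="./output")

# One config per (learning_rate, max_length) combination, each with its own output_dir.
sweep = [
    replace(base, learning_rate=lr, max_length=ml, output_dir=f"./output/lr{lr}-len{ml}")
    for lr, ml in product([1e-5, 2e-5], [1024, 2048])
]

for cfg in sweep:
    print(cfg)
```

Giving every variant a distinct output_dir keeps checkpoints and logs from different runs separate.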

Related Pages

Implements Principle
