Implementation: Hugging Face TRL TrlParser / SFTConfig
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Concrete tool for parsing CLI arguments and YAML configuration files into typed dataclass instances for SFT training, provided by the TRL library.
Description
TrlParser is a subclass of transformers.HfArgumentParser that adds support for YAML configuration files (via the --config flag) and environment-variable injection (via the env block in YAML). It accepts one or more dataclass types, unions their fields into a single argument namespace, and produces a tuple of fully populated dataclass instances.
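The core mechanism, turning dataclass fields into CLI arguments and returning populated instances, can be sketched with the standard library alone. This is a simplified illustration of the pattern, not TRL's actual implementation; the `TrainArgs` dataclass and `parse_into` helper are hypothetical:

```python
import argparse
from dataclasses import dataclass, fields

@dataclass
class TrainArgs:  # hypothetical stand-in for a TRL config dataclass
    learning_rate: float = 2e-5
    packing: bool = False
    output_dir: str = "./output"

def parse_into(cls, argv):
    """Build an argparse parser from dataclass fields, then instantiate the dataclass."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        if f.type is bool:
            # booleans become presence flags in this sketch
            parser.add_argument(f"--{f.name}", action="store_true", default=f.default)
        else:
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    ns = parser.parse_args(argv)
    return cls(**vars(ns))

args = parse_into(TrainArgs, ["--learning_rate", "1e-4", "--packing"])
print(args.learning_rate, args.packing)  # 0.0001 True
```

HfArgumentParser (and therefore TrlParser) does the same field-to-argument mapping, but additionally handles Optional/union types, enums, and multiple dataclass types in one namespace.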
SFTConfig is a dataclass that subclasses transformers.TrainingArguments and adds SFT-specific fields such as max_length, packing, completion_only_loss, assistant_only_loss, loss_type, and activation_offloading. It also overrides several TrainingArguments defaults (e.g., gradient_checkpointing=True, logging_steps=10).
In the standard SFT script, four dataclass types are composed together: ScriptArguments, SFTConfig, ModelConfig, and DatasetMixtureConfig.
Usage
Use TrlParser and SFTConfig whenever you need to:
- Launch SFT training from the command line with typed argument validation.
- Load training configurations from a YAML file for reproducibility.
- Programmatically construct configuration objects for sweeps or notebook-based training.
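For the sweep use case in the last bullet, a common pattern is to hold one base configuration and derive variants with `dataclasses.replace`. Sketched below on a hypothetical stand-in dataclass rather than the real `SFTConfig` (whose `TrainingArguments.__post_init__` performs extra validation on construction):

```python
from dataclasses import dataclass, replace

@dataclass
class SweepConfig:  # hypothetical stand-in; a real sweep would use trl.SFTConfig
    output_dir: str
    learning_rate: float = 2e-5
    packing: bool = False

base = SweepConfig(output_dir="./runs/base")

# One config object per sweep point, each with its own output directory
sweep = [
    replace(base, learning_rate=lr, output_dir=f"./runs/lr-{lr}")
    for lr in (1e-5, 2e-5, 5e-5)
]
print([c.learning_rate for c in sweep])  # [1e-05, 2e-05, 5e-05]
```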
Code Reference
Source Location
- Repository: TRL
- File: trl/scripts/utils.py (TrlParser, lines 241-389)
- File: trl/trainer/sft_config.py (SFTConfig, lines 21-279)
- File: trl/scripts/utils.py (ScriptArguments, lines 154-211; DatasetMixtureConfig, lines 90-152)
Signature
```python
class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    ):
        ...

    def parse_args_and_config(
        self,
        args: Iterable[str] | None = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]:
        ...
```
```python
@dataclass
class SFTConfig(TrainingArguments):
    # Overridden defaults
    learning_rate: float = 2e-5
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: bool | None = None  # auto-detected from fp16 setting

    # Model parameters
    model_init_kwargs: dict[str, Any] | None = None
    chat_template_path: str | None = None

    # Data preprocessing parameters
    dataset_text_field: str = "text"
    dataset_kwargs: dict[str, Any] | None = None
    dataset_num_proc: int | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    max_length: int | None = 1024
    shuffle_dataset: bool = False
    packing: bool = False
    packing_strategy: str = "bfd"
    padding_free: bool = False
    pad_to_multiple_of: int | None = None
    eval_packing: bool | None = None

    # Training parameters
    completion_only_loss: bool | None = None
    assistant_only_loss: bool = False
    loss_type: str = "nll"
    activation_offloading: bool = False
```
Import
```python
from trl import TrlParser, SFTConfig
from trl import ScriptArguments, ModelConfig, DatasetMixtureConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataclass_types | Iterable[DataClassType] | Yes | One or more dataclass types whose fields become CLI/YAML arguments |
| args | Iterable[str] \| None | No | CLI arguments; defaults to sys.argv[1:] |
| --config (CLI) | str | No | Path to YAML configuration file |
| return_remaining_strings | bool | No | If True, return unrecognized arguments instead of raising an error |
| fail_with_unknown_args | bool | No | If True (default), raise on unknown arguments |
Outputs
| Name | Type | Description |
|---|---|---|
| result | tuple[DataClass, ...] | A tuple of populated dataclass instances, one per type in dataclass_types. In the SFT script this is (ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig). |
Usage Examples
Basic Usage (CLI)
```python
# sft_train.py
from trl import TrlParser, SFTConfig, SFTTrainer
from trl import ScriptArguments, ModelConfig, DatasetMixtureConfig

parser = TrlParser(
    dataclass_types=(ScriptArguments, SFTConfig, ModelConfig, DatasetMixtureConfig)
)
script_args, training_args, model_args, dataset_args = parser.parse_args_and_config(
    return_remaining_strings=False
)

# training_args is a fully populated SFTConfig instance
print(training_args.max_length)  # 1024 (default)
print(training_args.packing)  # False (default)
print(model_args.model_name_or_path)  # set via CLI or YAML
```
YAML Configuration
```yaml
# config.yaml
env:
  WANDB_PROJECT: my-sft-experiment
model_name_or_path: Qwen/Qwen2-0.5B
dataset_name: trl-lib/Capybara
learning_rate: 2.0e-5
num_train_epochs: 1
packing: true
max_length: 2048
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
output_dir: Qwen2-0.5B-SFT
use_peft: true
lora_r: 32
lora_alpha: 16
```

```shell
python sft_train.py --config config.yaml
```
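TrlParser applies the YAML file by installing its values as parser defaults, so anything passed explicitly on the command line still wins. That precedence can be sketched with plain argparse (a simplified illustration, not TRL's code; the literal `yaml_values` dict stands in for the result of loading the config file):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--packing", action="store_true", default=False)

# Values loaded from a YAML config become the new defaults...
yaml_values = {"learning_rate": 1e-4, "packing": True}  # stand-in for yaml.safe_load(...)
parser.set_defaults(**yaml_values)

# ...but an explicit CLI flag overrides the YAML-provided default
ns = parser.parse_args(["--learning_rate", "5e-5"])
print(ns.learning_rate, ns.packing)  # 5e-05 True
```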
Programmatic Usage
```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="./output",
    max_length=2048,
    packing=True,
    completion_only_loss=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
)
```