Implementation:Huggingface Trl TrlParser DPOConfig
| Knowledge Sources | |
|---|---|
| Domains | NLP, RLHF |
| Last Updated | 2026-02-06 17:00 GMT |
Overview
Concrete tool for parsing command-line arguments and YAML configurations into DPO training hyperparameters, provided by the TRL library.
Description
DPOConfig is a dataclass extending Hugging Face's TrainingArguments with DPO-specific parameters. It defines the hyperparameters for Direct Preference Optimization training, including the loss type, beta (KL penalty), label smoothing, reference model management, and sequence length constraints. Several TrainingArguments defaults are overridden with values suited to DPO (e.g., learning_rate=1e-6, gradient_checkpointing=True, and bf16=True when fp16 is not set).
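The override-and-extend pattern can be sketched with plain dataclasses. In this minimal sketch, `BaseTrainingArgs` and `MiniDPOConfig` are hypothetical stand-ins (not TRL or Transformers classes) with only a handful of fields. The `__post_init__` rule mirrors the documented "bf16=True when fp16 is not set" default; the real classes perform much more validation.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseTrainingArgs:
    # Hypothetical stand-in for transformers.TrainingArguments (a few fields only).
    output_dir: str = "out"
    learning_rate: float = 5e-5
    gradient_checkpointing: bool = False
    fp16: bool = False
    bf16: Optional[bool] = None


@dataclass
class MiniDPOConfig(BaseTrainingArgs):
    # Redeclaring an inherited field in the subclass replaces its default.
    learning_rate: float = 1e-6
    gradient_checkpointing: bool = True
    # DPO-specific fields are added alongside the inherited ones.
    beta: float = 0.1
    loss_type: list[str] = field(default_factory=lambda: ["sigmoid"])

    def __post_init__(self):
        # Tri-state bf16: None means "use bf16 unless fp16 was requested".
        if self.bf16 is None:
            self.bf16 = not self.fp16
```

With this layout, `MiniDPOConfig()` resolves to bf16=True, while `MiniDPOConfig(fp16=True)` leaves bf16 off; an explicit `bf16=False` is never overwritten.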
TrlParser is a subclass of HfArgumentParser that adds support for YAML configuration files via the --config flag. It parses command-line arguments and config files into a tuple of dataclass instances, one per dataclass type passed to the parser, with command-line arguments taking precedence over YAML values. Entries under an env key in the YAML config are set as environment variables.
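The precedence rule (dataclass defaults, then YAML values, then command-line flags) can be illustrated with a stdlib-only sketch. `MiniArgs` and `parse_with_config` are hypothetical; the config is passed as an already-loaded dict rather than a YAML path, and feeding config values to `argparse` through `set_defaults` is an assumption suggested by the name of TrlParser's `set_defaults_with_config` method, not a copy of its implementation.

```python
import argparse
import os
from dataclasses import dataclass, fields


@dataclass
class MiniArgs:
    # Hypothetical dataclass; this sketch handles only float/int/str fields.
    beta: float = 0.1
    learning_rate: float = 1e-6
    output_dir: str = "out"


def parse_with_config(cls, config: dict, argv: list[str]):
    """Resolve values with precedence: dataclass default < config file < CLI flag."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        parser.add_argument(f"--{f.name}", type=type(f.default), default=f.default)

    config = dict(config)  # copy so the caller's dict is not mutated
    # An "env" mapping is exported to the process environment, not parsed as an argument.
    for key, value in config.pop("env", {}).items():
        os.environ[key] = str(value)

    # Reject config keys that match no dataclass field.
    unknown = set(config) - {f.name for f in fields(cls)}
    if unknown:
        raise ValueError(f"unknown config keys: {sorted(unknown)}")

    # Config values replace the dataclass defaults; flags present in argv still win.
    parser.set_defaults(**config)
    return cls(**vars(parser.parse_args(argv)))
```

For example, `parse_with_config(MiniArgs, {"beta": 0.5}, ["--beta", "0.2"])` yields beta=0.2, because the explicit flag overrides the config value, while fields absent from both keep their dataclass defaults.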
Usage
Use DPOConfig and TrlParser when:
- Setting up a DPO training run from a script or CLI
- Loading hyperparameters from a YAML configuration file
- Overriding default training arguments for preference optimization
- Combining DPO-specific parameters with standard Hugging Face training arguments
- Using the trl CLI tool (trl dpo)
Code Reference
Source Location
- Repository: TRL
- File (DPOConfig): trl/trainer/dpo_config.py (lines 59-744)
- File (TrlParser): trl/scripts/utils.py (lines 241-389)
Signature
```python
@dataclass
class DPOConfig(TrainingArguments):
    # Overridden defaults from TrainingArguments
    learning_rate: float = 1e-6
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: bool | None = None  # defaults to True if fp16 is not set

    # Model and reference model
    model_init_kwargs: dict[str, Any] | None = None
    disable_dropout: bool = True

    # Data preprocessing
    dataset_num_proc: int | None = None
    pad_token: str | None = None
    max_length: int | None = 1024
    truncation_mode: str = "keep_end"
    padding_free: bool = False
    precompute_ref_log_probs: bool = False
    precompute_ref_batch_size: int | None = None

    # Training
    loss_type: list[str] = field(default_factory=lambda: ["sigmoid"])
    beta: float = 0.1
    f_divergence_type: FDivergenceType | str = FDivergenceType.REVERSE_KL
    f_alpha_divergence_coef: float = 1.0
    label_smoothing: float = 0.0
    use_weighting: bool = False
    ld_alpha: float | None = None
    discopop_tau: float = 0.05
    loss_weights: list[float] | None = None
    sync_ref_model: bool = False
    ref_model_mixup_alpha: float = 0.6
    ref_model_sync_steps: int = 512


class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    ): ...

    def parse_args_and_config(
        self,
        args: Iterable[str] | None = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]: ...

    def set_defaults_with_config(self, **kwargs) -> list[str]: ...
```
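The three reference-sync fields (sync_ref_model, ref_model_mixup_alpha, ref_model_sync_steps) describe a TR-DPO style soft update of the reference model toward the policy. A minimal sketch, assuming the update rule ref = alpha * policy + (1 - alpha) * ref is applied every ref_model_sync_steps optimizer steps; `maybe_sync_reference` is a hypothetical helper, not a TRL function, and parameters are represented as plain dicts of floats instead of tensors.

```python
def maybe_sync_reference(policy: dict, reference: dict, step: int,
                         alpha: float = 0.6, sync_steps: int = 512) -> bool:
    """Soft-update reference weights toward the policy every `sync_steps` steps.

    Returns True if an update was applied at this step. A real trainer would
    apply the same rule tensor by tensor.
    """
    if step == 0 or step % sync_steps != 0:
        return False
    for name, p in policy.items():
        # Mix the current policy weight into the reference weight.
        reference[name] = alpha * p + (1.0 - alpha) * reference[name]
    return True
```

With alpha=0.6, a reference weight of 0.0 chasing a fixed policy weight of 1.0 moves to 0.6 after the first sync and 0.84 after the second, so the reference trails the policy rather than jumping to it.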
Import
```python
from trl import DPOConfig
from trl import TrlParser
from trl import ScriptArguments, ModelConfig, DatasetMixtureConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| beta | float | No (default: 0.1) | KL penalty strength controlling deviation from the reference model |
| loss_type | list[str] | No (default: ["sigmoid"]) | Loss variant(s): sigmoid, hinge, ipo, exo_pair, nca_pair, robust, bco_pair, sppo_hard, aot, aot_pair, discopop, apo_zero, apo_down, sft |
| max_length | int or None | No (default: 1024) | Maximum total sequence length (prompt + completion) |
| precompute_ref_log_probs | bool | No (default: False) | Whether to precompute reference model log probabilities to save GPU memory during training |
| label_smoothing | float | No (default: 0.0) | Robust DPO label smoothing; should be between 0.0 and 0.5 |
| sync_ref_model | bool | No (default: False) | Whether to periodically synchronize the reference model with the policy (TR-DPO) |
| loss_weights | list[float] or None | No (default: None) | Weight for each loss type when combining multiple losses (MPO) |
| --config (CLI) | str | No | Path to a YAML configuration file read by TrlParser |
Outputs
| Name | Type | Description |
|---|---|---|
| parsed args | tuple[ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig] | Tuple of dataclass instances parsed from CLI/YAML by TrlParser, in the order given in dataclass_types |
| DPOConfig instance | DPOConfig | Complete configuration for DPOTrainer with all hyperparameters resolved |
Usage Examples
```python
# Example 1: Direct instantiation of DPOConfig
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="./dpo-output",
    beta=0.1,
    loss_type=["sigmoid"],
    max_length=1024,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-7,
    num_train_epochs=1,
    precompute_ref_log_probs=False,
    eval_strategy="steps",
    eval_steps=50,
)
```
```python
# Example 2: Parsing from CLI with TrlParser
from trl import TrlParser, DPOConfig, ScriptArguments, ModelConfig, DatasetMixtureConfig

parser = TrlParser(
    dataclass_types=(ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig)
)
script_args, training_args, model_args, dataset_args = parser.parse_args_and_config()
```
```yaml
# Example 3: YAML config file (config.yaml)
# Run with: python dpo.py --config config.yaml
model_name_or_path: Qwen/Qwen2-0.5B-Instruct
dataset_name: trl-lib/ultrafeedback_binarized
beta: 0.1
loss_type:
  - sigmoid
max_length: 1024
learning_rate: 5.0e-7
num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
output_dir: Qwen2-0.5B-DPO
```
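Examples 1 and 3 trade per-device batch size against gradient accumulation. One way to compare them is the number of preference pairs feeding each optimizer step. The helper below is hypothetical and assumes plain data parallelism with one training process per device.

```python
def effective_batch_size(per_device_train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_devices: int = 1) -> int:
    # Preference pairs (chosen/rejected) consumed per optimizer step across all devices.
    return per_device_train_batch_size * gradient_accumulation_steps * num_devices
```

On a single device, Example 1 gives 4 * 8 = 32 pairs per step and Example 3 gives 2 * 8 = 16; running Example 3 on 4 devices would bring it to 64.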