Implementation:Huggingface Trl TrlParser DPOConfig

Knowledge Sources
Domains: NLP, RLHF
Last Updated: 2026-02-06 17:00 GMT

Overview

Concrete tools, provided by the TRL library, for parsing command-line arguments and YAML configuration files into DPO training hyperparameters.

Description

DPOConfig is a dataclass that extends Hugging Face's TrainingArguments with DPO-specific parameters. It defines the hyperparameters for Direct Preference Optimization training, including the loss type, beta (the strength of the KL penalty toward the reference model), label smoothing, reference model management, and sequence length constraints. It also overrides several TrainingArguments defaults for DPO: learning_rate=1e-6, logging_steps=10, gradient_checkpointing=True, and bf16=True unless fp16 is set.
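To make the roles of beta and label_smoothing concrete, here is a minimal, framework-free sketch of the per-example sigmoid DPO loss for a single preference pair. It assumes the summed log probabilities of the chosen and rejected completions have already been computed under both the policy and the reference model. The function names are hypothetical; this is a scalar illustration using only the standard library, not TRL's batched PyTorch implementation.

```python
import math


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def sigmoid_dpo_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> float:
    """Per-example sigmoid DPO loss with optional label smoothing (hypothetical helper)."""
    # Log-ratio of policy vs. reference for each completion.
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    # The DPO margin: how much more the policy favors "chosen" than the reference does.
    logits = chosen_logratio - rejected_logratio
    # beta scales the margin; label_smoothing treats a fraction of labels as possibly flipped.
    return (
        -(1.0 - label_smoothing) * log_sigmoid(beta * logits)
        - label_smoothing * log_sigmoid(-beta * logits)
    )
```

When the policy still matches the reference, the margin is zero and the loss equals log 2 for any label_smoothing value. A positive margin lowers the loss, and a larger beta makes the loss respond more sharply to the same margin.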

TrlParser is a subclass of HfArgumentParser that adds support for YAML configuration files via the --config flag. It parses command-line arguments and config files into a tuple of dataclass instances, and values given on the command line override those in the YAML file. It also supports setting environment variables through an env key in the YAML config.
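The precedence rule can be illustrated with the standard library alone. The sketch below is a hypothetical analogue, not TrlParser's code: a plain dict stands in for the loaded YAML file, its values are installed as parser defaults (in the spirit of set_defaults_with_config, though its exact internals are an assumption here), an env mapping is exported to os.environ, and explicit command-line flags still win.

```python
import argparse
import os
from typing import Any, Dict, List


def parse_with_config(argv: List[str], config: Dict[str, Any]) -> argparse.Namespace:
    """Toy YAML-then-CLI precedence; `config` stands in for the parsed --config file."""
    parser = argparse.ArgumentParser()
    # Hypothetical subset of DPO hyperparameters with DPOConfig-like defaults.
    parser.add_argument("--beta", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=1e-6)
    parser.add_argument("--max_length", type=int, default=1024)

    config = dict(config)  # avoid mutating the caller's dict
    # An `env` mapping is exported as environment variables, not parsed as arguments.
    for key, value in config.pop("env", {}).items():
        os.environ[key] = str(value)

    # Parser-level defaults from the config replace the per-argument defaults...
    parser.set_defaults(**config)
    # ...and anything passed explicitly on the command line overrides both.
    return parser.parse_args(argv)
```

For example, with config {"beta": 0.2, "learning_rate": 5e-7} and argv ["--beta", "0.5"], beta resolves to 0.5 (CLI), learning_rate to 5e-7 (config), and max_length to 1024 (built-in default).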

Usage

Use DPOConfig and TrlParser when:

  • Setting up a DPO training run from a script or CLI
  • Loading hyperparameters from a YAML configuration file
  • Overriding default training arguments for preference optimization
  • Combining DPO-specific parameters with standard Hugging Face training arguments
  • Using the trl CLI tool (trl dpo)

Code Reference

Source Location

  • Repository: TRL
  • File (DPOConfig): trl/trainer/dpo_config.py (lines 59-744)
  • File (TrlParser): trl/scripts/utils.py (lines 241-389)

Signature

# Abridged: imports, docstrings, and method bodies omitted.
@dataclass
class DPOConfig(TrainingArguments):
    # Overridden defaults from TrainingArguments
    learning_rate: float = 1e-6
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: bool | None = None  # defaults to True if fp16 is not set

    # Model and reference model
    model_init_kwargs: dict[str, Any] | None = None
    disable_dropout: bool = True

    # Data preprocessing
    dataset_num_proc: int | None = None
    pad_token: str | None = None
    max_length: int | None = 1024
    truncation_mode: str = "keep_end"
    padding_free: bool = False
    precompute_ref_log_probs: bool = False
    precompute_ref_batch_size: int | None = None

    # Training
    loss_type: list[str] = field(default_factory=lambda: ["sigmoid"])
    beta: float = 0.1
    f_divergence_type: FDivergenceType | str = FDivergenceType.REVERSE_KL
    f_alpha_divergence_coef: float = 1.0
    label_smoothing: float = 0.0
    use_weighting: bool = False
    ld_alpha: float | None = None
    discopop_tau: float = 0.05
    loss_weights: list[float] | None = None
    sync_ref_model: bool = False
    ref_model_mixup_alpha: float = 0.6
    ref_model_sync_steps: int = 512


class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    ):
        ...

    def parse_args_and_config(
        self,
        args: Iterable[str] | None = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]:
        ...

    def set_defaults_with_config(self, **kwargs) -> list[str]:
        ...
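The default-override pattern in the DPOConfig signature can be sketched with plain dataclasses. BaseTrainingArgs and SketchDPOConfig are hypothetical stand-ins for TrainingArguments and DPOConfig with only a handful of fields. The bf16 resolution follows the comment in the signature, while the loss_weights length check is an illustrative assumption rather than documented TRL behavior.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class BaseTrainingArgs:
    """Hypothetical stand-in for a tiny subset of transformers.TrainingArguments."""
    learning_rate: float = 5e-5
    logging_steps: float = 500
    gradient_checkpointing: bool = False
    fp16: bool = False
    bf16: Optional[bool] = False


@dataclass
class SketchDPOConfig(BaseTrainingArgs):
    # Redeclaring inherited fields overrides their defaults in the subclass.
    learning_rate: float = 1e-6
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: Optional[bool] = None  # resolved in __post_init__
    # A few DPO-specific fields.
    beta: float = 0.1
    loss_type: List[str] = field(default_factory=lambda: ["sigmoid"])
    loss_weights: Optional[List[float]] = None

    def __post_init__(self) -> None:
        # bf16 defaults to True unless fp16 was requested.
        if self.bf16 is None:
            self.bf16 = not self.fp16
        # Illustrative check (assumption): one weight per loss type.
        if self.loss_weights is not None and len(self.loss_weights) != len(self.loss_type):
            raise ValueError("loss_weights must have one entry per loss_type")
```

Using default_factory for loss_type gives each instance its own list, which is why the signature declares it with field(...) rather than a bare list literal.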

Import

from trl import DPOConfig
from trl import TrlParser
from trl import ScriptArguments, ModelConfig, DatasetMixtureConfig

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| beta | float | No (default: 0.1) | KL penalty strength controlling deviation from the reference model |
| loss_type | list[str] | No (default: ["sigmoid"]) | Loss variant(s): sigmoid, hinge, ipo, exo_pair, nca_pair, robust, bco_pair, sppo_hard, aot, aot_pair, discopop, apo_zero, apo_down, sft |
| max_length | int or None | No (default: 1024) | Maximum total sequence length (prompt + completion) |
| precompute_ref_log_probs | bool | No (default: False) | Whether to precompute reference model log probabilities to save GPU memory during training |
| label_smoothing | float | No (default: 0.0) | Robust DPO label smoothing, between 0.0 and 0.5 |
| sync_ref_model | bool | No (default: False) | Whether to synchronize the reference model with the policy (TR-DPO) |
| loss_weights | list[float] or None | No (default: None) | Weights for each loss type when combining multiple losses (MPO) |
| --config (CLI) | str | No | Path to a YAML configuration file for TrlParser |
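As a hedged sketch of how several entries in loss_type might be combined through loss_weights, the code below sums weighted per-example losses for three of the listed variants (sigmoid, hinge, ipo), using the commonly cited formulas for each. It assumes a weight of 1.0 per loss when loss_weights is None and operates on a single scalar margin. TRL's real implementation is batched, and details such as IPO's length-averaged log probabilities are omitted. All function names are hypothetical.

```python
import math
from typing import List, Optional


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))


def pairwise_loss(loss_type: str, logits: float, beta: float) -> float:
    """Per-example loss; `logits` is the DPO margin (chosen minus rejected log-ratio)."""
    if loss_type == "sigmoid":
        return -log_sigmoid(beta * logits)
    if loss_type == "hinge":
        return max(0.0, 1.0 - beta * logits)
    if loss_type == "ipo":
        # IPO regresses the margin toward 1 / (2 * beta).
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"loss_type not covered by this sketch: {loss_type}")


def combined_loss(
    loss_types: List[str],
    logits: float,
    beta: float = 0.1,
    loss_weights: Optional[List[float]] = None,
) -> float:
    # Assumption of this sketch: every loss gets weight 1.0 when no weights are given.
    weights = loss_weights if loss_weights is not None else [1.0] * len(loss_types)
    if len(weights) != len(loss_types):
        raise ValueError("need one weight per loss type")
    return sum(w * pairwise_loss(t, logits, beta) for t, w in zip(loss_types, weights))
```

At a zero margin with beta=0.1, for instance, the sigmoid term is log 2, the hinge term is 1.0, and the IPO term is 25.0, so weights matter a great deal when losses on very different scales are mixed.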

Outputs

| Name | Type | Description |
|---|---|---|
| parsed args | tuple[ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig] | Tuple of dataclass instances parsed from CLI/YAML by TrlParser, in the same order as the dataclass_types passed to the parser |
| DPOConfig instance | DPOConfig | Complete configuration for DPOTrainer with all hyperparameters resolved |
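The following standard-library sketch illustrates the general mechanism by which a single parser can return one instance per dataclass: one --flag per field, a single flat parse, then a split back into instances in the order the classes were given. All names are hypothetical and this is not the HfArgumentParser source. It handles only fields with plain defaults, not default_factory fields such as loss_type.

```python
import argparse
import dataclasses
from dataclasses import dataclass


@dataclass
class ScriptArgsSketch:
    dataset_name: str = "my-dataset"  # hypothetical placeholder


@dataclass
class TrainingArgsSketch:
    beta: float = 0.1
    learning_rate: float = 1e-6


def parse_into_dataclasses(dataclass_types, argv):
    """Toy analogue of parsing CLI flags into several dataclasses (hypothetical)."""
    parser = argparse.ArgumentParser()
    for dc in dataclass_types:
        for f in dataclasses.fields(dc):
            # One --flag per field; the annotated type converts the CLI string.
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    namespace = parser.parse_args(argv)
    # Split the flat namespace back into one instance per dataclass, preserving order.
    return tuple(
        dc(**{f.name: getattr(namespace, f.name) for f in dataclasses.fields(dc)})
        for dc in dataclass_types
    )
```

Calling parse_into_dataclasses((ScriptArgsSketch, TrainingArgsSketch), ["--beta", "0.25"]) yields a ScriptArgsSketch followed by a TrainingArgsSketch whose beta is 0.25. Reversing the tuple of classes reverses the output.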

Usage Examples

# Example 1: Direct instantiation of DPOConfig
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="./dpo-output",
    beta=0.1,
    loss_type=["sigmoid"],
    max_length=1024,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-7,
    num_train_epochs=1,
    precompute_ref_log_probs=False,
    eval_strategy="steps",
    eval_steps=50,
)

# Example 2: Parsing from CLI with TrlParser
from trl import TrlParser, DPOConfig, ScriptArguments, ModelConfig, DatasetMixtureConfig

parser = TrlParser(
    dataclass_types=(ScriptArguments, DPOConfig, ModelConfig, DatasetMixtureConfig)
)
script_args, training_args, model_args, dataset_args = parser.parse_args_and_config()

# Example 3: YAML config file (config.yaml)
# Run with: python dpo.py --config config.yaml
model_name_or_path: Qwen/Qwen2-0.5B-Instruct
dataset_name: trl-lib/ultrafeedback_binarized
beta: 0.1
loss_type:
  - sigmoid
max_length: 1024
learning_rate: 5.0e-7
num_train_epochs: 1
per_device_train_batch_size: 2
gradient_accumulation_steps: 8
output_dir: Qwen2-0.5B-DPO
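As a worked check on the YAML example, the number of preference pairs contributing to each optimizer step is per_device_train_batch_size × gradient_accumulation_steps × the number of training processes, so 2 × 8 = 16 on a single GPU. The helper names below are hypothetical, and the steps-per-epoch value is only a rough estimate: the Trainer's exact count also depends on how partial accumulation windows and multi-process sharding are handled.

```python
import math


def effective_batch_size(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_processes: int = 1,
) -> int:
    """Preference pairs contributing to each optimizer step."""
    return per_device_train_batch_size * gradient_accumulation_steps * num_processes


def approx_steps_per_epoch(num_pairs: int, effective_batch: int) -> int:
    # Rough estimate: one optimizer step per full (or final partial) effective batch.
    return math.ceil(num_pairs / effective_batch)
```

For a hypothetical 10,000-pair dataset on one GPU, this gives an effective batch of 16 and roughly 625 optimizer steps per epoch. On four GPUs, the effective batch grows to 64.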

Related Pages

Implements Principle
