Implementation:Huggingface Trl TrlParser RewardConfig

From Leeroopedia


Property | Value
Implementation Name | TrlParser RewardConfig
Technology | Huggingface TRL
Type | API Doc
Workflow | Reward Model Training
Principle | Principle:Huggingface_Trl_Reward_Argument_Configuration

Overview

Description

The RewardConfig dataclass provides all configuration parameters for reward model training in TRL. It extends transformers.TrainingArguments with reward-specific fields such as center_rewards_coefficient, disable_dropout, and max_length. The TrlParser class extends HfArgumentParser to support YAML configuration files alongside command-line arguments.

Usage

RewardConfig can be instantiated directly or parsed from command-line arguments. It is passed as the args parameter to RewardTrainer.

Code Reference

Source Location

  • RewardConfig: trl/trainer/reward_config.py lines 22-185
  • TrlParser: trl/scripts/utils.py lines 241-389

Signature

@dataclass
class RewardConfig(TrainingArguments):
    # Overridden defaults from TrainingArguments
    learning_rate: float = 1e-4
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: bool | None = None

    # Model control parameters
    model_init_kwargs: dict[str, Any] | None = None
    chat_template_path: str | None = None
    disable_dropout: bool = True

    # Data preprocessing parameters
    dataset_num_proc: int | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    max_length: int | None = 1024
    pad_to_multiple_of: int | None = None

    # Training parameters
    center_rewards_coefficient: float | None = None
    activation_offloading: bool = False

class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    ): ...

    def parse_args_and_config(
        self,
        args: Iterable[str] | None = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]: ...

Import

from trl import RewardConfig
from trl import TrlParser

I/O Contract

Inputs

Parameter | Type | Default | Description
learning_rate | float | 1e-4 | Initial learning rate for AdamW optimizer
max_length | int or None | 1024 | Maximum tokenized sequence length; samples exceeding this are filtered
center_rewards_coefficient | float or None | None | Coefficient for mean-zero reward regularization (recommended: 0.01)
disable_dropout | bool | True | Whether to disable all dropout layers in the model
activation_offloading | bool | False | Whether to offload activations to CPU during training
model_init_kwargs | dict or None | None | Keyword arguments for AutoModelForSequenceClassification.from_pretrained
gradient_checkpointing | bool | True | Whether to use gradient checkpointing
dataset_num_proc | int or None | None | Number of processes for dataset preprocessing
pad_to_multiple_of | int or None | None | Pad sequences to a multiple of this value
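To make the role of center_rewards_coefficient concrete, here is a pure-Python sketch of a pairwise Bradley-Terry reward loss with the centering penalty added. The function name and list-based inputs are illustrative only; TRL's actual implementation operates on model-output tensors:

```python
import math

def pairwise_reward_loss(chosen, rejected, center_coeff=None):
    """Sketch of a pairwise reward loss with optional mean-zero centering.

    chosen / rejected: lists of scalar rewards for preferred / rejected responses.
    """
    # Bradley-Terry pairwise loss: -log sigmoid(chosen - rejected),
    # written as log(1 + exp(-margin)) for numerical convenience.
    margins = [c - r for c, r in zip(chosen, rejected)]
    loss = sum(math.log1p(math.exp(-m)) for m in margins) / len(margins)
    if center_coeff is not None:
        # Centering penalty: discourage rewards from drifting away from zero
        # by penalizing the squared sum of each pair's rewards.
        loss += center_coeff * sum(
            (c + r) ** 2 for c, r in zip(chosen, rejected)
        ) / len(chosen)
    return loss
```

With zero margins the base loss is log 2, and a nonzero center_coeff strictly increases the loss whenever the paired rewards are not already centered at zero.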

Outputs

Output | Type | Description
RewardConfig instance | RewardConfig | Configured training arguments for RewardTrainer
Parsed dataclass tuple | tuple[DataClass, ...] | When using TrlParser, returns all parsed dataclass instances

Usage Examples

Direct Instantiation

from trl import RewardConfig

config = RewardConfig(
    output_dir="reward-model-output",
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    max_length=512,
    center_rewards_coefficient=0.01,
    disable_dropout=True,
    num_train_epochs=1,
    gradient_checkpointing=True,
)

Command-Line Parsing with TrlParser

from trl import TrlParser, RewardConfig, ScriptArguments, ModelConfig

parser = TrlParser(dataclass_types=(ScriptArguments, RewardConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()

YAML Configuration File

# reward_config.yaml
output_dir: reward-model-output
learning_rate: 1e-4
per_device_train_batch_size: 8
max_length: 512
center_rewards_coefficient: 0.01
disable_dropout: true
num_train_epochs: 1
