Implementation:Huggingface Trl TrlParser RewardConfig
| Property | Value |
|---|---|
| Implementation Name | TrlParser RewardConfig |
| Technology | Huggingface TRL |
| Type | API Doc |
| Workflow | Reward Model Training |
| Principle | Principle:Huggingface_Trl_Reward_Argument_Configuration |
Overview
Description
The RewardConfig dataclass provides all configuration parameters for reward model training in TRL. It extends transformers.TrainingArguments with reward-specific fields such as center_rewards_coefficient, disable_dropout, and max_length. The TrlParser class extends HfArgumentParser to support YAML configuration files alongside command-line arguments.
Usage
RewardConfig can be instantiated directly or parsed from command-line arguments. It is passed as the args parameter to RewardTrainer.
Code Reference
Source Location
- RewardConfig: trl/trainer/reward_config.py, lines 22-185
- TrlParser: trl/scripts/utils.py, lines 241-389
Signature
```python
@dataclass
class RewardConfig(TrainingArguments):
    # Overridden defaults from TrainingArguments
    learning_rate: float = 1e-4
    logging_steps: float = 10
    gradient_checkpointing: bool = True
    bf16: bool | None = None

    # Model control parameters
    model_init_kwargs: dict[str, Any] | None = None
    chat_template_path: str | None = None
    disable_dropout: bool = True

    # Data preprocessing parameters
    dataset_num_proc: int | None = None
    eos_token: str | None = None
    pad_token: str | None = None
    max_length: int | None = 1024
    pad_to_multiple_of: int | None = None

    # Training parameters
    center_rewards_coefficient: float | None = None
    activation_offloading: bool = False


class TrlParser(HfArgumentParser):
    def __init__(
        self,
        dataclass_types: DataClassType | Iterable[DataClassType] | None = None,
        **kwargs,
    ): ...

    def parse_args_and_config(
        self,
        args: Iterable[str] | None = None,
        return_remaining_strings: bool = False,
        fail_with_unknown_args: bool = True,
    ) -> tuple[DataClass, ...]: ...
```
Import
```python
from trl import RewardConfig
from trl import TrlParser
```
I/O Contract
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| learning_rate | float | 1e-4 | Initial learning rate for AdamW optimizer |
| max_length | int or None | 1024 | Maximum tokenized sequence length; samples exceeding this are filtered |
| center_rewards_coefficient | float or None | None | Coefficient for mean-zero reward regularization (recommended: 0.01) |
| disable_dropout | bool | True | Whether to disable all dropout layers in the model |
| activation_offloading | bool | False | Whether to offload activations to CPU during training |
| model_init_kwargs | dict or None | None | Keyword arguments for AutoModelForSequenceClassification.from_pretrained |
| gradient_checkpointing | bool | True | Whether to use gradient checkpointing |
| dataset_num_proc | int or None | None | Number of processes for dataset preprocessing |
| pad_to_multiple_of | int or None | None | Pad sequences to a multiple of this value |
Outputs
| Output | Type | Description |
|---|---|---|
| RewardConfig instance | RewardConfig | Configured training arguments for RewardTrainer |
| Parsed dataclass tuple | tuple[DataClass, ...] | When using TrlParser, returns all parsed dataclass instances |
Usage Examples
Direct Instantiation
```python
from trl import RewardConfig

config = RewardConfig(
    output_dir="reward-model-output",
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    max_length=512,
    center_rewards_coefficient=0.01,
    disable_dropout=True,
    num_train_epochs=1,
    gradient_checkpointing=True,
)
```
Command-Line Parsing with TrlParser
```python
from trl import TrlParser, RewardConfig, ScriptArguments, ModelConfig

parser = TrlParser(dataclass_types=(ScriptArguments, RewardConfig, ModelConfig))
script_args, training_args, model_args = parser.parse_args_and_config()
```
YAML Configuration File
```yaml
# reward_config.yaml
output_dir: reward-model-output
learning_rate: 1e-4
per_device_train_batch_size: 8
max_length: 512
center_rewards_coefficient: 0.01
disable_dropout: true
num_train_epochs: 1
```
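TrlParser reads such a file through its --config flag, with values from the YAML acting as defaults that explicit command-line flags override. A hedged launch sketch (the script name train_reward.py is hypothetical; it is assumed to call parse_args_and_config as in the example above):

```shell
# --config loads reward_config.yaml; the explicit --learning_rate flag
# takes precedence over the learning_rate value in the YAML file.
python train_reward.py --config reward_config.yaml --learning_rate 5e-5
```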