
Implementation:Huggingface Trl PPOTrainer Init

Property            | Value
Implementation Name | PPOTrainer Init
Technology          | Huggingface TRL
Type                | API Doc
Workflow            | PPO RLHF Training
Principle           | Principle:Huggingface_Trl_PPO_Trainer_Initialization

Overview

Description

The PPOTrainer.__init__ method assembles the complete PPO training pipeline from four separate models. It wraps the policy and value models in a PolicyAndValueWrapper for joint training, computes the batch size hierarchy, sets up the Accelerator for distributed training, creates the optimizer and scheduler, and prepares the DataLoaders. For DeepSpeed setups, it additionally prepares the reward and reference models through prepare_deepspeed.
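
As an illustrative sketch of the batch-size hierarchy, the arithmetic below mirrors comparable TRL on-policy trainers; the exact attribute names are an assumption and may differ across versions:

def batch_size_hierarchy(
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_mini_batches: int,
    world_size: int,
) -> dict:
    # Per-process rollout batch gathered before each optimization phase.
    local_batch_size = (
        per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches
    )
    # Global rollout batch across all processes.
    batch_size = local_batch_size * world_size
    return {
        "local_batch_size": local_batch_size,
        "batch_size": batch_size,
        "mini_batch_size": batch_size // num_mini_batches,
        "local_mini_batch_size": local_batch_size // num_mini_batches,
    }

# Example: 64 per device, no accumulation, 1 mini-batch, single process.
print(batch_size_hierarchy(64, 1, 1, 1))
# {'local_batch_size': 64, 'batch_size': 64, 'mini_batch_size': 64, 'local_mini_batch_size': 64}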

Usage

All four models, the tokenizer, and the datasets must be pre-loaded before constructing the PPOTrainer. The trainer handles all internal wiring and distributed training setup.

Code Reference

Source Location

  • PolicyAndValueWrapper: trl/experimental/ppo/ppo_trainer.py lines 279-291
  • PPOTrainer.__init__: trl/experimental/ppo/ppo_trainer.py lines 345-559

Signature

class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model) -> None:
        super().__init__()
        self.policy = policy
        self.value_model = value_model
        self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
        self.is_gradient_checkpointing = policy.is_gradient_checkpointing

    def forward(self, **kwargs):
        output = self.critic_backbone(**kwargs)
        logits = self.value_model.score(output.hidden_states[-1])
        return self.policy(**kwargs), logits

class PPOTrainer(BaseTrainer):
    def __init__(
        self,
        args: PPOConfig,
        processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
        model: nn.Module,
        ref_model: nn.Module | None,
        reward_model: nn.Module,
        train_dataset: Dataset,
        value_model: nn.Module,
        data_collator: DataCollatorWithPadding | None = None,
        eval_dataset: Dataset | dict[str, Dataset] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: list[TrainerCallback] | None = None,
        peft_config: "PeftConfig | None" = None,
    ) -> None:
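
The wrapper's forward pass requires hidden states from the backbone, since the value head scores the final layer. Below is a minimal sketch of a joint forward call, assuming Llama-style checkpoints (where the sequence-classification head is named score); the model names are placeholders:

import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification

policy = AutoModelForCausalLM.from_pretrained("my-sft-model")  # placeholder checkpoint
value_model = AutoModelForSequenceClassification.from_pretrained(
    "my-reward-model", num_labels=1                            # placeholder checkpoint
)
wrapper = PolicyAndValueWrapper(policy, value_model)

# output_hidden_states=True is required: forward() reads output.hidden_states[-1].
policy_output, value_logits = wrapper(
    input_ids=torch.tensor([[1, 2, 3, 4]]),
    output_hidden_states=True,
)
# policy_output.logits -> next-token logits; value_logits -> per-token value estimates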

Import

from trl.experimental.ppo import PPOTrainer, PPOConfig

I/O Contract

Inputs

Parameter        | Type                             | Default      | Description
args             | PPOConfig                        | (required)   | PPO training configuration
processing_class | PreTrainedTokenizerBase          | (required)   | Tokenizer with left-side padding
model            | nn.Module                        | (required)   | Policy model (AutoModelForCausalLM)
ref_model        | nn.Module or None                | (required)   | Reference policy; None when using PEFT
reward_model     | nn.Module                        | (required)   | Frozen reward model (AutoModelForSequenceClassification)
train_dataset    | Dataset                          | (required)   | Tokenized prompt dataset with an input_ids column
value_model      | nn.Module                        | (required)   | Value model (AutoModelForSequenceClassification)
data_collator    | DataCollatorWithPadding or None  | None         | Batch collator; defaults to DataCollatorWithPadding
eval_dataset     | Dataset or dict or None          | None         | Evaluation dataset for generation quality checks
optimizers       | tuple of (Optimizer, LambdaLR)   | (None, None) | Optional optimizer/scheduler pair; created internally when omitted
callbacks        | list[TrainerCallback] or None    | None         | Trainer callbacks
peft_config      | PeftConfig or None               | None         | PEFT configuration for parameter-efficient training
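
Since train_dataset must arrive pre-tokenized, here is a minimal preparation sketch; the dataset name and prompt column are assumptions:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("my-sft-model", padding_side="left")  # placeholder checkpoint
raw = load_dataset("my-prompts")["train"]  # hypothetical dataset with a "prompt" column

def tokenize(example):
    # Tokenize the prompt only; PPO generates completions during training.
    return {"input_ids": tokenizer(example["prompt"])["input_ids"]}

train_dataset = raw.map(tokenize, remove_columns=raw.column_names)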

Internal State After Initialization

Attribute            | Type                  | Description
self.model           | PolicyAndValueWrapper | Combined policy + value wrapper (accelerator-prepared)
self.ref_model       | nn.Module or None     | Reference model (DeepSpeed-prepared or on device)
self.reward_model    | nn.Module             | Reward model (DeepSpeed-prepared or on device)
self.optimizer       | Optimizer             | AdamW optimizer for policy + value parameters
self.lr_scheduler    | LRScheduler           | Learning rate scheduler
self.dataloader      | DataLoader            | Accelerator-prepared training DataLoader
self.eval_dataloader | DataLoader            | Accelerator-prepared evaluation DataLoader
self.accelerator     | Accelerator           | Hugging Face Accelerator instance
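
A quick post-construction sanity check (illustrative; in distributed runs trainer.model may additionally be wrapped by Accelerate):

# Assumes `trainer` from the examples below.
print(type(trainer.model).__name__)       # PolicyAndValueWrapper (possibly Accelerate-wrapped)
print(trainer.accelerator.num_processes)  # world size behind the batch-size hierarchy
print(len(trainer.dataloader))            # batches per epoch on this process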

Usage Examples

Standard Initialization

from trl.experimental.ppo import PPOTrainer, PPOConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

config = PPOConfig(
    output_dir="ppo-output",
    per_device_train_batch_size=64,
    total_episodes=10000,
    sft_model_path="my-sft-model",
    reward_model_path="my-reward-model",
)

tokenizer = AutoTokenizer.from_pretrained("my-sft-model", padding_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

policy = AutoModelForCausalLM.from_pretrained("my-sft-model")
ref_policy = AutoModelForCausalLM.from_pretrained("my-sft-model")
reward_model = AutoModelForSequenceClassification.from_pretrained("my-reward-model", num_labels=1)
value_model = AutoModelForSequenceClassification.from_pretrained("my-reward-model", num_labels=1)

# train_dataset / eval_dataset: pre-tokenized prompt datasets with an "input_ids" column
trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy,
    ref_model=ref_policy,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
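
After construction, training proceeds as with other TRL trainers:

trainer.train()                         # run the PPO training loop
trainer.save_model(config.output_dir)   # save the trained policy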

With PEFT

from peft import LoraConfig

peft_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")

trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy,
    ref_model=None,  # No separate reference model needed
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
# trainer.is_peft_model == True
# trainer.ref_model is None (adapter disabling provides reference behavior)
