Implementation:Huggingface Trl PPOTrainer Init
| Property | Value |
|---|---|
| Implementation Name | PPOTrainer Init |
| Technology | Huggingface TRL |
| Type | API Doc |
| Workflow | PPO RLHF Training |
| Principle | Principle:Huggingface_Trl_PPO_Trainer_Initialization |
Overview
Description
The PPOTrainer.__init__ method assembles the complete PPO training pipeline from four separate models. It wraps the policy and value models in a PolicyAndValueWrapper for joint training, computes the batch size hierarchy, sets up the Accelerator for distributed training, creates the optimizer and scheduler, and prepares the DataLoaders. For DeepSpeed setups, it additionally prepares the reward and reference models through prepare_deepspeed.
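The batch-size hierarchy can be illustrated with a rough, hedged sketch. The names below mirror PPOConfig conventions (local_batch_size, mini_batch_size, num_total_batches), but the arithmetic is a simplified illustration of what __init__ computes, not a verbatim excerpt of the source.
import math

# Hedged sketch of the batch-size hierarchy; a simplified illustration, not the source.
per_device_train_batch_size = 64
gradient_accumulation_steps = 1
num_mini_batches = 1
world_size = 8            # accelerator.num_processes
total_episodes = 10_000

local_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches
batch_size = local_batch_size * world_size                   # prompts consumed per rollout across all processes
mini_batch_size = batch_size // num_mini_batches              # global optimization chunk
local_mini_batch_size = local_batch_size // num_mini_batches  # per-process optimization chunk
num_total_batches = math.ceil(total_episodes / batch_size)    # rollout/update cycles

print(local_batch_size, batch_size, num_total_batches)        # 64 512 20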
Usage
All four models, the tokenizer, and the datasets must be pre-loaded before constructing the PPOTrainer. The trainer handles all internal wiring and distributed training setup.
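As a hedged illustration of that pre-loading requirement, the sketch below shows one way to produce a tokenized prompt dataset with the input_ids column that train_dataset expects; the prompt texts and column names are hypothetical.
from datasets import Dataset
from transformers import AutoTokenizer

# Hypothetical prompts; only the resulting input_ids column matters to the trainer.
raw = Dataset.from_dict({"prompt": ["Write a haiku about the sea.", "Summarize: the cat sat on the mat."]})

tokenizer = AutoTokenizer.from_pretrained("my-sft-model", padding_side="left")

def tokenize(example):
    return {"input_ids": tokenizer(example["prompt"])["input_ids"]}

train_dataset = raw.map(tokenize, remove_columns=raw.column_names)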
Code Reference
Source Location
- PolicyAndValueWrapper: trl/experimental/ppo/ppo_trainer.py lines 279-291
- PPOTrainer.__init__: trl/experimental/ppo/ppo_trainer.py lines 345-559
Signature
class PolicyAndValueWrapper(nn.Module):
    def __init__(self, policy, value_model) -> None:
        super().__init__()
        self.policy = policy
        self.value_model = value_model
        self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
        self.is_gradient_checkpointing = policy.is_gradient_checkpointing

    def forward(self, **kwargs):
        output = self.critic_backbone(**kwargs)
        logits = self.value_model.score(output.hidden_states[-1])
        return self.policy(**kwargs), logits

class PPOTrainer(BaseTrainer):
    def __init__(
        self,
        args: PPOConfig,
        processing_class: PreTrainedTokenizerBase | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin,
        model: nn.Module,
        ref_model: nn.Module | None,
        reward_model: nn.Module,
        train_dataset: Dataset,
        value_model: nn.Module,
        data_collator: DataCollatorWithPadding | None = None,
        eval_dataset: Dataset | dict[str, Dataset] | None = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: list[TrainerCallback] | None = None,
        peft_config: "PeftConfig | None" = None,
    ) -> None:
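As a hedged sketch (not part of the documented API), the wrapper's forward pass can be exercised directly: one set of keyword arguments drives both the policy and the value head, and because the value head reads output.hidden_states[-1], the call must request hidden states.
import torch

# Illustrative only: `policy` and `value_model` as loaded in the usage examples below.
wrapper = PolicyAndValueWrapper(policy, value_model)

input_ids = torch.tensor([[1, 2, 3, 4]])       # hypothetical tokenized prompt
attention_mask = torch.ones_like(input_ids)

policy_output, value_logits = wrapper(
    input_ids=input_ids,
    attention_mask=attention_mask,
    return_dict=True,
    output_hidden_states=True,  # needed so value_model.score can read hidden_states[-1]
)
# policy_output.logits -> next-token logits from the policy
# value_logits         -> per-token scalar value estimates from the value head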
Import
from trl.experimental.ppo import PPOTrainer, PPOConfig
I/O Contract
Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| args | PPOConfig | (required) | PPO training configuration |
| processing_class | PreTrainedTokenizerBase | (required) | Tokenizer with left-side padding |
| model | nn.Module | (required) | Policy model (AutoModelForCausalLM) |
| ref_model | nn.Module or None | (required) | Reference policy; None when using PEFT |
| reward_model | nn.Module | (required) | Frozen reward model (AutoModelForSequenceClassification) |
| train_dataset | Dataset | (required) | Tokenized prompt dataset with input_ids column |
| value_model | nn.Module | (required) | Value model (AutoModelForSequenceClassification) |
| data_collator | DataCollatorWithPadding or None | None | Batch collator; defaults to DataCollatorWithPadding |
| eval_dataset | Dataset or dict or None | None | Evaluation dataset for generation quality checks |
| optimizers | tuple[Optimizer, LambdaLR] | (None, None) | Optional pre-built optimizer and LR scheduler; created internally when not supplied |
| callbacks | list[TrainerCallback] or None | None | Trainer callbacks to register |
| peft_config | PeftConfig or None | None | PEFT configuration for parameter-efficient training |
Internal State After Initialization
| Attribute | Type | Description |
|---|---|---|
| self.model | PolicyAndValueWrapper | Combined policy + value wrapper (accelerator-prepared) |
| self.ref_model | nn.Module or None | Reference model (DeepSpeed-prepared or on device) |
| self.reward_model | nn.Module | Reward model (DeepSpeed-prepared or on device) |
| self.optimizer | Optimizer | AdamW optimizer for policy + value parameters |
| self.lr_scheduler | LRScheduler | Learning rate scheduler |
| self.dataloader | DataLoader | Accelerator-prepared training DataLoader |
| self.eval_dataloader | DataLoader | Accelerator-prepared evaluation DataLoader |
| self.accelerator | Accelerator | Huggingface Accelerator instance |
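A small, hedged sketch of what this wiring looks like right after construction, assuming a trainer built as in the usage examples that follow; exact attribute types may additionally be wrapped by Accelerate or DeepSpeed.
# Illustrative inspection of the internal state; wrapper types depend on the backend.
print(type(trainer.model).__name__)        # PolicyAndValueWrapper (possibly wrapped by Accelerate)
print(type(trainer.optimizer).__name__)    # e.g. AdamW
print(trainer.accelerator.num_processes)   # number of distributed processes
print(len(trainer.dataloader))             # training batches per pass over the prompt dataset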
Usage Examples
Standard Initialization
from trl.experimental.ppo import PPOTrainer, PPOConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
config = PPOConfig(
    output_dir="ppo-output",
    per_device_train_batch_size=64,
    total_episodes=10000,
    sft_model_path="my-sft-model",
    reward_model_path="my-reward-model",
)
tokenizer = AutoTokenizer.from_pretrained("my-sft-model", padding_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
policy = AutoModelForCausalLM.from_pretrained("my-sft-model")
ref_policy = AutoModelForCausalLM.from_pretrained("my-sft-model")
reward_model = AutoModelForSequenceClassification.from_pretrained("my-reward-model", num_labels=1)
value_model = AutoModelForSequenceClassification.from_pretrained("my-reward-model", num_labels=1)
trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy,
    ref_model=ref_policy,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
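After construction, training and checkpointing follow the usual Trainer pattern (see Implementation:Huggingface_Trl_PPOTrainer_Train below); a minimal, hedged continuation of the example above:
trainer.train()
trainer.save_model(config.output_dir)  # assumes the standard Trainer save_model API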
With PEFT
from peft import LoraConfig
peft_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
trainer = PPOTrainer(
    args=config,
    processing_class=tokenizer,
    model=policy,
    ref_model=None,  # No separate reference model needed
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)
# trainer.is_peft_model == True
# trainer.ref_model is None (adapter disabling provides reference behavior)
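As a hedged, conceptual illustration of the comment above (not the trainer's internal code): with a PEFT policy, reference-policy behavior comes from temporarily disabling the adapter, which restores the base SFT weights. The attribute path below is illustrative and may differ once Accelerate or DeepSpeed wrapping is applied.
import torch

input_ids = torch.tensor([[1, 2, 3, 4]])   # hypothetical tokenized prompt
peft_policy = trainer.model.policy          # illustrative path to the PEFT-wrapped policy

with torch.no_grad(), peft_policy.disable_adapter():
    ref_logits = peft_policy(input_ids=input_ids).logits
# With the adapter disabled, the base (SFT) weights stand in for the reference policy.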
Related Pages
- Principle:Huggingface_Trl_PPO_Trainer_Initialization
- Implementation:Huggingface_Trl_HfArgumentParser_PPOConfig
- Implementation:Huggingface_Trl_PPO_Model_Loading_Pattern
- Implementation:Huggingface_Trl_PPOTrainer_Train
- Environment:Huggingface_Trl_Python_Core_Dependencies
- Environment:Huggingface_Trl_DeepSpeed_Environment