Implementation:Huggingface TRL PPO Model Loading Pattern
| Property | Value |
|---|---|
| Implementation Name | PPO Model Loading Pattern |
| Technology | Huggingface TRL, Transformers |
| Type | Pattern Doc |
| Workflow | PPO RLHF Training |
| Principle | Principle:Huggingface_Trl_PPO_Multi_Model_Loading |
Overview
Description
The PPO model loading pattern loads a left-padded tokenizer plus the four models required for the full RLHF pipeline: a value model and a reward model (both sequence classifiers with num_labels=1), a policy (the causal language model being optimized), and a reference policy (a frozen copy of the policy used for the KL penalty). When PEFT is enabled, the reference policy is omitted, since the frozen base-model weights beneath the adapters serve as the reference.
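The left-padding requirement can be illustrated without any libraries: with left padding, every row's most recent token lands in the last column, so next-token logits line up across a generation batch. A toy sketch (plain Python lists standing in for token IDs, not Transformers code):

```python
# Why padding_side="left" for generation: left padding keeps each sequence's
# final (most recent) token in the last position of the batch tensor.
PAD = 0

def pad_batch(seqs, side):
    width = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        fill = [PAD] * (width - len(s))
        out.append(fill + s if side == "left" else s + fill)
    return out

batch = [[5, 6], [7, 8, 9]]
left = pad_batch(batch, "left")    # [[0, 5, 6], [7, 8, 9]]
right = pad_batch(batch, "right")  # [[5, 6, 0], [7, 8, 9]]

# With left padding, the last real token sits in the final column of every row,
# which is where the model reads the "current" position during generation.
assert all(row[-1] != PAD for row in left)
```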
Usage
This pattern is implemented in the PPO training script. The models are loaded independently and then passed to PPOTrainer.
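The hand-off to PPOTrainer can be sketched as a small wiring helper. The keyword names (`processing_class`, `ref_model`, `reward_model`, `value_model`) follow recent TRL releases and are an assumption here; verify them against your installed trl version:

```python
# Hypothetical wiring helper: hands the loaded models to a PPOTrainer-style
# class. Keyword names mirror recent TRL releases (an assumption; check your
# installed trl version).
def build_ppo_trainer(trainer_cls, training_args, tokenizer, policy, ref_policy,
                      reward_model, value_model, train_dataset, peft_config=None):
    return trainer_cls(
        args=training_args,
        processing_class=tokenizer,
        model=policy,
        ref_model=ref_policy,  # None when training with PEFT adapters
        reward_model=reward_model,
        value_model=value_model,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )
```

Passing the trainer class in as a parameter keeps the sketch testable without loading real models.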
Code Reference
Source Location
examples/scripts/ppo/ppo.py lines 97-121
Pattern
```python
# Tokenizer with left-side padding for generation
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    padding_side="left",
    trust_remote_code=model_args.trust_remote_code,
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# Value model: initialized from the reward model checkpoint
value_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path,
    trust_remote_code=model_args.trust_remote_code,
    num_labels=1,
    **model_kwargs,
)

# Reward model: frozen scoring function
reward_model = AutoModelForSequenceClassification.from_pretrained(
    training_args.reward_model_path,
    trust_remote_code=model_args.trust_remote_code,
    num_labels=1,
    **model_kwargs,
)

# Policy: the causal LM to be optimized
policy = AutoModelForCausalLM.from_pretrained(
    training_args.sft_model_path,
    trust_remote_code=model_args.trust_remote_code,
    **model_kwargs,
)

# Reference policy: frozen copy for KL divergence (skipped when using PEFT)
peft_config = get_peft_config(model_args)
if peft_config is None:
    ref_policy = AutoModelForCausalLM.from_pretrained(
        training_args.sft_model_path,
        trust_remote_code=model_args.trust_remote_code,
        **model_kwargs,
    )
else:
    ref_policy = None
```
Import
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from trl import ModelConfig, get_peft_config, get_quantization_config, get_kbit_device_map
```
I/O Contract
Inputs
| Parameter | Type | Source | Description |
|---|---|---|---|
| model_args.model_name_or_path | str | ModelConfig | Base model identifier (for tokenizer) |
| training_args.sft_model_path | str | PPOConfig | Path to the supervised fine-tuned model (for policy and ref_policy) |
| training_args.reward_model_path | str | PPOConfig | Path to the trained reward model (for reward_model and value_model) |
| model_kwargs | dict | Computed | Contains revision, attn_implementation, dtype, and optionally quantization_config and device_map |
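The model_kwargs input can be sketched as a small builder that mirrors the shape described in the Inputs table. The parameter names here are illustrative assumptions; TRL's example scripts assemble the dict inline from ModelConfig fields:

```python
# Sketch of model_kwargs assembly (shape matches the Inputs table above;
# exact field names in TRL's scripts may differ).
def build_model_kwargs(revision=None, attn_implementation=None, dtype=None,
                       quantization_config=None, device_map=None):
    kwargs = {
        "revision": revision,
        "attn_implementation": attn_implementation,
        "torch_dtype": dtype,
    }
    # Quantization settings are attached only when a config is provided.
    if quantization_config is not None:
        kwargs["quantization_config"] = quantization_config
        kwargs["device_map"] = device_map if device_map is not None else {"": 0}
    return kwargs

plain = build_model_kwargs(revision="main", dtype="bfloat16")
quantized = build_model_kwargs(dtype="bfloat16",
                               quantization_config={"load_in_4bit": True})
```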
Outputs
| Output | Type | Description |
|---|---|---|
| tokenizer | PreTrainedTokenizerBase | Tokenizer with padding_side="left" and explicit pad_token |
| policy | AutoModelForCausalLM | Trainable causal language model (actor) |
| ref_policy | AutoModelForCausalLM or None | Frozen reference policy; None when using PEFT |
| reward_model | AutoModelForSequenceClassification | Frozen reward scoring model (num_labels=1) |
| value_model | AutoModelForSequenceClassification | Trainable value estimator (critic, num_labels=1) |
Usage Examples
Standard Four-Model Loading
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

sft_path = "my-sft-model"
reward_path = "my-reward-model"

tokenizer = AutoTokenizer.from_pretrained(sft_path, padding_side="left")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

model_kwargs = {"torch_dtype": torch.bfloat16}

value_model = AutoModelForSequenceClassification.from_pretrained(
    reward_path, num_labels=1, **model_kwargs
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    reward_path, num_labels=1, **model_kwargs
)
policy = AutoModelForCausalLM.from_pretrained(sft_path, **model_kwargs)
ref_policy = AutoModelForCausalLM.from_pretrained(sft_path, **model_kwargs)
```
With PEFT (Three Models)
```python
from peft import LoraConfig

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type="CAUSAL_LM",
)

# No reference policy needed when using PEFT
ref_policy = None
# The PPOTrainer will wrap the policy with LoRA adapters
# and use the base model weights as the reference
```
With Quantization
```python
import torch
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

reward_path = "my-reward-model"

quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "quantization_config": quantization_config,
    "device_map": {"": 0},
}

# Load all models with 4-bit quantization. Note that 4-bit base weights are
# frozen, so quantized PPO training is typically paired with PEFT adapters.
value_model = AutoModelForSequenceClassification.from_pretrained(
    reward_path, num_labels=1, **model_kwargs
)
```