Overview
Description
The save_model method extracts the policy model from the PolicyAndValueWrapper and saves only the policy weights. The generate_completions method performs evaluation by generating near-greedy responses from evaluation prompts, scoring them with the reward model, and logging the results as formatted tables. Both methods include special handling for DeepSpeed distributed training.
Usage
save_model is called at the end of training or at checkpoint intervals. generate_completions is called periodically during training based on num_sample_generations and also once at the end of training for a final evaluation.
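The periodic trigger can be sketched roughly as follows. This is a hedged illustration of "evenly spaced over training", not the trainer's actual scheduling code; the function and parameter names here are hypothetical:

```python
def should_generate_samples(update_step, num_total_batches, num_sample_generations):
    """Hypothetical sketch: fire generate_completions num_sample_generations
    times, evenly spaced across training updates (0 disables sampling)."""
    if num_sample_generations <= 0:
        return False
    freq = max(1, num_total_batches // num_sample_generations)
    return update_step % freq == 0
```

For example, with 100 total batches and num_sample_generations=10, evaluation samples would be generated every 10th update.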
Code Reference
Source Location
- save_model: trl/experimental/ppo/ppo_trainer.py, lines 581-594
- generate_completions: trl/experimental/ppo/ppo_trainer.py, lines 942-1010
- _save_checkpoint: trl/experimental/ppo/ppo_trainer.py, lines 1012-1019
Signature
```python
def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
    """
    Save only the policy model weights (not the value model).

    Temporarily swaps self.model to self.model.policy for saving, then
    restores the wrapper. For DeepSpeed, also swaps self.deepspeed.

    Args:
        output_dir: Directory to save the model. Defaults to args.output_dir.
        _internal_call: Internal flag for Trainer compatibility.
    """
    backup_model = self.model
    self.model = self.model.policy  # save only the policy
    if self.is_deepspeed_enabled:
        backup_deepspeed = self.deepspeed
        self.deepspeed = self.model
    super().save_model(output_dir, _internal_call)
    self.model = backup_model
    if self.is_deepspeed_enabled:
        self.deepspeed = backup_deepspeed
```
```python
def generate_completions(self, sampling: bool = False):
    """
    Generate completions from evaluation prompts and log results.

    Uses near-greedy decoding (temperature=0.01) for deterministic evaluation.
    Scores each generated response with the reward model. Logs results to the
    console (Rich tables), Weights & Biases, and Comet ML.

    Args:
        sampling: If True, only process the first batch (used during training).
            If False, process all evaluation batches (used at end of training).
    """
```
```python
def _save_checkpoint(self, model, trial):
    """Save checkpoint with automatic model card generation."""
    if self.args.hub_model_id is None:
        model_name = Path(self.args.output_dir).name
    else:
        model_name = self.args.hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)
```
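The model-name derivation in _save_checkpoint can be illustrated in isolation (a minimal sketch mirroring the same Path-based fallback; `derive_model_name` is a hypothetical helper, not part of the trainer):

```python
from pathlib import Path

def derive_model_name(hub_model_id, output_dir):
    # Mirrors _save_checkpoint's naming logic: prefer the Hub repo name,
    # otherwise fall back to the output directory's basename.
    if hub_model_id is None:
        return Path(output_dir).name
    return hub_model_id.split("/")[-1]

print(derive_model_name("org/my-ppo-model", "runs/ppo"))  # -> my-ppo-model
print(derive_model_name(None, "runs/ppo"))                # -> ppo
```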
Import
```python
# These are methods on PPOTrainer, invoked as:
from trl.experimental.ppo import PPOTrainer

trainer = PPOTrainer(...)
trainer.save_model("output_dir")
trainer.generate_completions()
```
I/O Contract
save_model Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_dir | str or None | None | Save directory; defaults to args.output_dir |
save_model Outputs
| Output | Location | Description |
|---|---|---|
| Policy weights | output_dir/model.safetensors | Only the causal LM policy weights (no value model weights) |
| Tokenizer | output_dir/ | Tokenizer files for the policy model |
| Model card | output_dir/README.md | Auto-generated model card |
generate_completions Inputs
| Parameter | Type | Default | Description |
|---|---|---|---|
| sampling | bool | False | If True, process only the first eval batch; if False, process all |
generate_completions Outputs
| Output | Destination | Description |
|---|---|---|
| table["query"] | Logs | Decoded prompt text |
| table["model response"] | Logs | Decoded generated response text |
| table["score"] | Logs | Reward model score for each response |
| Rich table | Console | Pretty-printed first 5 samples |
| wandb.Table | Weights and Biases | Full table for interactive exploration |
| CSV table | Comet ML | Full table for experiment tracking |
Generation Configuration
| Parameter | Value | Rationale |
|---|---|---|
| temperature | 0.01 + 1e-7 | Near-greedy for deterministic evaluation |
| max_new_tokens | args.response_length | Same length limit as training |
| top_k | 0.0 | Disabled (no top-k filtering) |
| top_p | 1.0 | Disabled (no nucleus sampling) |
| do_sample | True | Required for temperature-based sampling |
Usage Examples
Save After Training
```python
from trl.experimental.ppo import PPOTrainer

# After training completes
trainer.train()

# Save only the policy model
trainer.save_model("ppo-trained-policy")

# Push to Hub if configured
if trainer.args.push_to_hub:
    trainer.push_to_hub(dataset_name="my-dataset")
```
Generate and Evaluate
# During training (sampling=True processes only first batch)
trainer.generate_completions(sampling=True)
# Final evaluation (sampling=False processes all eval batches)
trainer.generate_completions(sampling=False)
Load Saved Policy for Inference
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# The saved model is a standard causal LM (no value model)
model = AutoModelForCausalLM.from_pretrained("ppo-trained-policy")
tokenizer = AutoTokenizer.from_pretrained("ppo-trained-policy")

inputs = tokenizer("What is the meaning of life?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
Related Pages