
Implementation:Huggingface Trl PPOTrainer Save Generate

From Leeroopedia


Property Value
Implementation Name PPOTrainer Save Generate
Technology Huggingface TRL
Type Wrapper Doc
Workflow PPO RLHF Training
Principle Principle:Huggingface_Trl_PPO_Model_Saving_and_Evaluation

Overview

Description

The save_model method extracts the policy model from the PolicyAndValueWrapper and saves only the policy weights. The generate_completions method performs evaluation by generating near-greedy responses from evaluation prompts, scoring them with the reward model, and logging the results as formatted tables. Both methods include special handling for DeepSpeed distributed training.

Usage

save_model is called at the end of training or at checkpoint intervals. generate_completions is called periodically during training based on num_sample_generations and also once at the end of training for a final evaluation.
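The periodic schedule implied by num_sample_generations can be sketched as follows; the function name and the exact interval formula are illustrative assumptions, not the literal TRL internals.

```python
# Sketch of deriving a sampling interval from num_sample_generations.
# Names and the interval rule are illustrative, not the exact TRL code.
def sample_generation_updates(num_total_batches: int, num_sample_generations: int) -> list[int]:
    """Return the 1-indexed update steps at which evaluation generations would run."""
    if num_sample_generations <= 0:
        return []
    freq = max(1, num_total_batches // num_sample_generations)
    return [u for u in range(1, num_total_batches + 1) if (u - 1) % freq == 0]

# e.g. 100 training updates with 4 sample generations -> roughly every 25 updates
print(sample_generation_updates(100, 4))  # [1, 26, 51, 76]
```

A final generate_completions(sampling=False) call at the end of training is independent of this schedule.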

Code Reference

Source Location

  • save_model: trl/experimental/ppo/ppo_trainer.py lines 581-594
  • generate_completions: trl/experimental/ppo/ppo_trainer.py lines 942-1010
  • _save_checkpoint: trl/experimental/ppo/ppo_trainer.py lines 1012-1019

Signature

def save_model(self, output_dir: str | None = None, _internal_call: bool = False):
    """
    Save only the policy model weights (not the value model).

    Temporarily swaps self.model to self.model.policy for saving,
    then restores the wrapper. For DeepSpeed, also swaps self.deepspeed.

    Args:
        output_dir: Directory to save the model. Defaults to args.output_dir.
        _internal_call: Internal flag for Trainer compatibility.
    """
    backup_model = self.model
    self.model = self.model.policy  # save only the policy

    if self.is_deepspeed_enabled:
        backup_deepspeed = self.deepspeed
        self.deepspeed = self.model

    super().save_model(output_dir, _internal_call)

    self.model = backup_model

    if self.is_deepspeed_enabled:
        self.deepspeed = backup_deepspeed

def generate_completions(self, sampling: bool = False):
    """
    Generate completions from evaluation prompts and log results.

    Uses near-greedy decoding (temperature=0.01) for deterministic evaluation.
    Scores each generated response with the reward model.
    Logs results to console (Rich tables), Weights & Biases, and Comet ML.

    Args:
        sampling: If True, only process the first batch (used during training).
                  If False, process all evaluation batches (used at end of training).
    """

def _save_checkpoint(self, model, trial):
    """Save checkpoint with automatic model card generation."""
    # Path is pathlib.Path (imported at module level in ppo_trainer.py)
    if self.args.hub_model_id is None:
        model_name = Path(self.args.output_dir).name
    else:
        model_name = self.args.hub_model_id.split("/")[-1]
    self.create_model_card(model_name=model_name)
    super()._save_checkpoint(model, trial)

Import

# These are methods on PPOTrainer, invoked as:
from trl.experimental.ppo import PPOTrainer
trainer = PPOTrainer(...)
trainer.save_model("output_dir")
trainer.generate_completions()

I/O Contract

save_model Inputs

Parameter Type Default Description
output_dir str or None None Save directory; defaults to args.output_dir
_internal_call bool False Internal flag for Trainer compatibility

save_model Outputs

Output Location Description
Policy weights output_dir/model.safetensors Only the causal LM policy weights (no value model weights)
Tokenizer output_dir/ Tokenizer files for the policy model
Model card output_dir/README.md Auto-generated model card

generate_completions Inputs

Parameter Type Default Description
sampling bool False If True, process only the first eval batch; if False, process all

generate_completions Outputs

Output Destination Description
table["query"] Logs Decoded prompt text
table["model response"] Logs Decoded generated response text
table["score"] Logs Reward model score for each response
Rich table Console Pretty-printed first 5 samples
wandb.Table Weights and Biases Full table for interactive exploration
CSV table Comet ML Full table for experiment tracking
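The shape of the logged table can be sketched with the column names from the contract above; the sample rows and population loop are illustrative assumptions, not the exact TRL code.

```python
from collections import defaultdict

# Hedged sketch of the logging table's shape; column names mirror the
# I/O contract above, the sample data is made up.
table = defaultdict(list)
eval_samples = [
    ("Summarize the post about ...", "A short summary.", 1.37),
    ("Summarize the thread on ...", "Another summary.", 0.92),
]
for query, response, score in eval_samples:
    table["query"].append(query)
    table["model response"].append(response)
    table["score"].append(score)

# Console logging shows only the first few rows; W&B/Comet receive the full table.
first_rows = {k: v[:5] for k, v in table.items()}
print(sorted(first_rows))  # ['model response', 'query', 'score']
```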

Generation Configuration

Parameter Value Rationale
temperature 0.01 + 1e-7 Near-greedy for deterministic evaluation
max_new_tokens args.response_length Same length limit as training
top_k 0.0 Disabled (no top-k filtering)
top_p 1.0 Disabled (no nucleus sampling)
do_sample True Required for temperature-based sampling
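Why a temperature near 0.01 counts as "near-greedy" can be seen directly from the softmax: dividing logits by a tiny temperature concentrates almost all probability mass on the argmax token. The logits below are made up for illustration.

```python
import math

# Demonstrates the effect of a tiny temperature on the softmax distribution.
def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.5, 0.5]
print(round(max(softmax_with_temperature(logits, 1.0)), 3))   # 0.547 (broad)
print(round(max(softmax_with_temperature(logits, 0.01)), 3))  # 1.0 (near-greedy)
```

This is why do_sample=True with temperature 0.01 behaves almost identically to greedy decoding while keeping the sampling code path active.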

Usage Examples

Save After Training

from trl.experimental.ppo import PPOTrainer

# After training completes
trainer.train()

# Save only the policy model
trainer.save_model("ppo-trained-policy")

# Push to Hub if configured
if trainer.args.push_to_hub:
    trainer.push_to_hub(dataset_name="my-dataset")

Generate and Evaluate

# During training (sampling=True processes only first batch)
trainer.generate_completions(sampling=True)

# Final evaluation (sampling=False processes all eval batches)
trainer.generate_completions(sampling=False)

Load Saved Policy for Inference

from transformers import AutoModelForCausalLM, AutoTokenizer

# The saved model is a standard causal LM (no value model)
model = AutoModelForCausalLM.from_pretrained("ppo-trained-policy")
tokenizer = AutoTokenizer.from_pretrained("ppo-trained-policy")

inputs = tokenizer("What is the meaning of life?", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
