
Implementation:CarperAI Trlx Save Pretrained

From Leeroopedia


Knowledge Sources
Domains Training, Model_Persistence
Last Updated 2026-02-07 16:00 GMT

Overview

A concrete tool, provided by the trlx base trainer, for saving trained model weights in HuggingFace format.

Description

The save_pretrained() method on AccelerateRLTrainer (base class for all Accelerate trainers) saves the model, tokenizer, and configuration to a directory. It coordinates across distributed processes using Accelerate, ensuring only the main process writes files. For PEFT models, adapter weights are saved separately.
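The distributed coordination described above can be sketched in plain Python. This is an illustrative sketch of the pattern, not trlx's actual internals: the function name and arguments are hypothetical stand-ins for Accelerate's is_main_process check and wait_for_everyone() barrier.

```python
import os

def save_on_main_process(weight_bytes, directory, is_main_process, wait_for_everyone):
    """Illustrative sketch of the Accelerate coordination pattern:
    only the main process writes files, then every rank hits a barrier."""
    if is_main_process:
        os.makedirs(directory, exist_ok=True)
        # The real trainer calls the HuggingFace model's save_pretrained here.
        with open(os.path.join(directory, "pytorch_model.bin"), "wb") as f:
            f.write(weight_bytes)
    wait_for_everyone()  # barrier: non-main ranks wait until the files exist
```

The barrier after the write matters: without it, non-main ranks could proceed (e.g. to load or upload the checkpoint) before the files are fully on disk.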

Usage

Call trainer.save_pretrained() after training completes, or rely on automatic checkpointing via config.train.checkpoint_interval. The default save location is config.train.checkpoint_dir/hf_model.
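The automatic checkpointing mentioned above is controlled by fields on the training config. A minimal configuration sketch (the interval and directory values here are arbitrary examples):

```python
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
config.train.checkpoint_interval = 1000  # checkpoint every 1000 steps (arbitrary value)
config.train.checkpoint_dir = "ckpts"    # final HF-format save lands in ckpts/hf_model
```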

Code Reference

Source Location

  • Repository: trlx
  • File: trlx/trainer/accelerate_base_trainer.py
  • Lines: L284-307

Signature

def save_pretrained(self, directory: Optional[str] = None, **kwargs) -> None:
    """Save the underlying model, tokenizer, and configuration files.

    Args:
        directory: The directory to save to. If None, saves to
            checkpoint_dir/hf_model.
        **kwargs: Additional keyword arguments passed to the underlying
            HuggingFace model's save_pretrained method.
    """

Import

# Called on the trainer instance returned by trlx.train()
import trlx
trainer = trlx.train(...)
trainer.save_pretrained("./my_model")

I/O Contract

Inputs

  • directory (Optional[str], not required): save path; defaults to checkpoint_dir/hf_model
  • **kwargs (dict, not required): extra keyword arguments forwarded to the underlying HuggingFace save_pretrained

Outputs

  • config.json: model configuration file
  • pytorch_model.bin: model weights (adapter weights for PEFT models)
  • tokenizer files: tokenizer.json, special_tokens_map.json, etc.
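As a quick sanity check on the output files listed above, one can verify that a save directory contains a configuration file and at least one weight file. This is a hedged sketch: newer transformers versions may write model.safetensors instead of pytorch_model.bin, and PEFT saves write adapter_config.json and adapter_model files.

```python
import os

# Possible weight filenames: full-model and PEFT-adapter variants,
# in both pickle (.bin) and safetensors formats.
WEIGHT_FILES = {
    "pytorch_model.bin", "model.safetensors",
    "adapter_model.bin", "adapter_model.safetensors",
}

def looks_like_saved_model(directory: str) -> bool:
    """Return True if `directory` resembles a save_pretrained() output."""
    files = set(os.listdir(directory))
    has_config = "config.json" in files or "adapter_config.json" in files
    return has_config and bool(files & WEIGHT_FILES)
```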

Usage Examples

Save After Training

import trlx
from trlx.data.default_configs import default_ppo_config

config = default_ppo_config()
trainer = trlx.train(reward_fn=reward_fn, prompts=prompts, config=config)

# Save to default location (checkpoint_dir/hf_model)
trainer.save_pretrained()

# Save to custom location
trainer.save_pretrained("./my_ppo_model")

Load Saved Model

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the saved model with standard HuggingFace API
model = AutoModelForCausalLM.from_pretrained("./my_ppo_model")
tokenizer = AutoTokenizer.from_pretrained("./my_ppo_model")
