Implementation: CarperAI Trlx Save Pretrained
| Knowledge Sources | |
|---|---|
| Domains | Training, Model_Persistence |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
A concrete tool, provided by the trlx base trainer, for saving trained model weights in HuggingFace format.
Description
The save_pretrained() method on AccelerateRLTrainer (the base class for all Accelerate-based trainers in trlx) saves the model, tokenizer, and configuration to a directory. It coordinates across distributed processes using Accelerate, ensuring that only the main process writes files. For PEFT models, adapter weights are saved separately.
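The coordination pattern described above can be sketched with Accelerate's barrier and rank-check APIs (wait_for_everyone() and is_main_process are real Accelerate attributes); the helper name save_on_main_process is illustrative, not trlx code:

```python
def save_on_main_process(accelerator, model, directory):
    # All processes synchronize at a barrier so the weights are final,
    # then only the main process (rank 0) writes files to disk.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        model.save_pretrained(directory)
        return True
    return False
```

Non-main ranks pass through the barrier and return without touching the filesystem, which avoids concurrent writes to the same directory.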
Usage
Call trainer.save_pretrained() after training completes, or rely on automatic checkpointing via config.train.checkpoint_interval. The default save location is config.train.checkpoint_dir/hf_model.
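The two save paths above can be sketched as plain helpers. The default-directory rule follows the text (checkpoint_dir/hf_model when no directory is given); the modular-interval trigger for checkpoint_interval is an assumed convention, and both function names are hypothetical:

```python
import os

def resolve_save_directory(directory=None, checkpoint_dir="ckpts"):
    # If no directory is given, fall back to <checkpoint_dir>/hf_model.
    if directory is not None:
        return directory
    return os.path.join(checkpoint_dir, "hf_model")

def is_checkpoint_step(step, checkpoint_interval):
    # Automatic checkpointing fires every `checkpoint_interval` steps
    # (assumed trigger; the exact logic lives in the trlx training loop).
    return checkpoint_interval > 0 and step > 0 and step % checkpoint_interval == 0
```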
Code Reference
Source Location
- Repository: trlx
- File: trlx/trainer/accelerate_base_trainer.py
- Lines: L284-307
Signature
def save_pretrained(self, directory: Optional[str] = None, **kwargs) -> None:
    """Save the underlying model, tokenizer, and configuration files.

    Args:
        directory: The directory to save to. If None, saves to
            checkpoint_dir/hf_model.
        **kwargs: Additional keyword arguments passed to the underlying
            HuggingFace model's save_pretrained method.
    """
Import
# Called on the trainer instance returned by trlx.train()
import trlx
trainer = trlx.train(...)
trainer.save_pretrained("./my_model")
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| directory | Optional[str] | No | Save path (default: checkpoint_dir/hf_model) |
| **kwargs | Dict | No | Extra args passed to HuggingFace save_pretrained |
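The **kwargs pass-through in the table can be illustrated with a stand-in model object. safe_serialization is a real keyword of HuggingFace save_pretrained, but whether a given trlx version forwards it unchanged is an assumption; StubHFModel and save_with_forwarded_kwargs are illustrative names only:

```python
class StubHFModel:
    """Stand-in for the wrapped HuggingFace model; records each call."""
    def __init__(self):
        self.calls = []

    def save_pretrained(self, directory, **kwargs):
        self.calls.append((directory, kwargs))

def save_with_forwarded_kwargs(hf_model, directory, **kwargs):
    # Extra keyword arguments flow through unchanged to the HF model.
    hf_model.save_pretrained(directory, **kwargs)

stub = StubHFModel()
save_with_forwarded_kwargs(stub, "./out", safe_serialization=True)
```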
Outputs
| Name | Type | Description |
|---|---|---|
| config.json | File | Model configuration file |
| pytorch_model.bin | File | Model weights (or adapter weights for PEFT) |
| tokenizer files | Files | tokenizer.json, special_tokens_map.json, etc. |
Usage Examples
Save After Training
import trlx
from trlx.data.default_configs import default_ppo_config
config = default_ppo_config()
trainer = trlx.train(reward_fn=reward_fn, prompts=prompts, config=config)
# Save to default location (checkpoint_dir/hf_model)
trainer.save_pretrained()
# Save to custom location
trainer.save_pretrained("./my_ppo_model")
Load Saved Model
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load the saved model with standard HuggingFace API
model = AutoModelForCausalLM.from_pretrained("./my_ppo_model")
tokenizer = AutoTokenizer.from_pretrained("./my_ppo_model")