
Implementation:ContextualAI HALOs BasicTrainer Save

From Leeroopedia


Knowledge Sources
Domains: Deep_Learning, Infrastructure
Last Updated: 2026-02-08 03:00 GMT

Overview

BasicTrainer.save is a concrete tool for saving model checkpoints, with support for LoRA merging and FSDP-wrapped models.

Description

The BasicTrainer.save method handles the complete checkpoint persistence flow: saving the tokenizer, writing training metrics to JSON, extracting optimizer and scheduler state dicts, and saving the model weights. When use_peft=True and final_save=True, it calls get_base_model_state_dict_from_peft to merge LoRA adapter weights into the base model before saving.
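The persistence flow described above can be sketched in outline (a minimal stdlib sketch; `save_checkpoint` is an illustrative name, not the repository's API, and plain JSON stands in for save_pretrained / torch.save so the sketch stays self-contained):

```python
import json
import os

def save_checkpoint(output_dir, metrics, state_dicts):
    """Illustrative sketch of the checkpoint persistence flow:
    write training metrics to JSON, then persist each component's
    state dict as its own file in the checkpoint directory."""
    os.makedirs(output_dir, exist_ok=True)
    # 1. Training metrics go to metrics.json alongside the weights.
    with open(os.path.join(output_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f)
    # 2. Each component (optimizer, scheduler, model, ...) gets its
    #    own file, mirroring optimizer.pt / scheduler.pt in HALOs.
    for name, state in state_dicts.items():
        with open(os.path.join(output_dir, f"{name}.json"), "w") as f:
            json.dump(state, f)
```

In the real method, the model is written via save_pretrained and the optimizer/scheduler state dicts via torch.save; the flow of "metrics first, then one file per component" is the same.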

The PPOTrainer overrides save() to additionally persist the value head weights (v_head.pt).

Usage

Called automatically by BasicTrainer.train() at the end of training and at intermediate checkpoints. Can also be called directly for manual saves.

Code Reference

Source Location

  • Repository: ContextualAI/HALOs
  • File: train/trainers.py (BasicTrainer.save), train/utils.py (get_base_model_state_dict_from_peft)
  • Lines: train/trainers.py:L400-451 (BasicTrainer.save), train/utils.py:L70-95 (get_base_model_state_dict_from_peft)

Signature

class BasicTrainer:
    def save(
        self,
        output_dir: Optional[str] = None,
        metrics: Optional[Dict] = {},
        final_save: bool = True
    ) -> None:
        """Save tokenizer, policy model, optimizer, scheduler state to disk.

        Args:
            output_dir: Path to save directory. Defaults to {run_dir}/step-{counter}.
            metrics: Training metrics dict to write alongside model.
            final_save: If True and use_peft, merges LoRA into base model.
        """

def get_base_model_state_dict_from_peft(
    peft_state_dict: Dict,
    lora_alpha: int,
    lora_r: int
) -> Dict:
    """Merge LoRA adapter weights into the base model state dict.

    Computes: W_merged = W_base + (alpha/r) * B @ A
    for each LoRA-modified layer.
    """

Import

from train.trainers import BasicTrainer
from train.utils import get_base_model_state_dict_from_peft

I/O Contract

Inputs

Name | Type | Required | Description
output_dir | str | No | Save directory (defaults to step-based path)
metrics | Dict | No | Training metrics to save as metrics.json
final_save | bool | No | If True and LoRA is in use, merges adapters into base model
self.policy | nn.Module | Yes | Trained model (possibly FSDP-wrapped with LoRA)
self.optimizer | Optimizer | Yes | Optimizer whose state dict is saved
self.scheduler | LRScheduler | Yes | Scheduler whose state dict is saved

Outputs

Name | Type | Description
model weights | Files | Model weights via save_pretrained (bin or safetensors)
config.json | File | Model configuration
tokenizer files | Files | Tokenizer vocab, config, special tokens
optimizer.pt | File | Optimizer state dict
scheduler.pt | File | Scheduler state dict
metrics.json | File | Training metrics with example counter
v_head.pt | File | Value head weights (PPO only)
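Given the output contract above, a checkpoint directory can be sanity-checked with a small helper (a hypothetical utility, not part of HALOs; file names are taken from the table, and only the fixed-name files are checked since model-weight and tokenizer file names vary):

```python
import os

# Fixed-name files every checkpoint should contain, per the table above.
EXPECTED_FILES = ["config.json", "optimizer.pt", "scheduler.pt", "metrics.json"]

def missing_checkpoint_files(checkpoint_dir, expect_value_head=False):
    """Return the list of expected checkpoint files that are absent.

    PPO checkpoints additionally carry v_head.pt, so callers can opt
    into checking for it with expect_value_head=True."""
    expected = list(EXPECTED_FILES)
    if expect_value_head:
        expected.append("v_head.pt")
    return [
        name for name in expected
        if not os.path.exists(os.path.join(checkpoint_dir, name))
    ]
```

An empty return value means the fixed-name files are all present; a non-empty list names what is missing.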

Usage Examples

Automatic Save During Training

# Inside BasicTrainer.train(), save is called automatically:
# At end of training:
output_dir = os.path.join(self.run_dir, 'FINAL')
self.save(output_dir, results['results'], final_save=True)

# At intermediate checkpoints (if config.intermediate_checkpoints=True):
output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}')
self.save(output_dir, results['results'], final_save=False)

Manual LoRA Merge

from train.utils import get_base_model_state_dict_from_peft

# Manually merge LoRA weights into base model
peft_state_dict = accelerator.get_state_dict(policy)
merged_state_dict = get_base_model_state_dict_from_peft(
    peft_state_dict,
    lora_alpha=256,
    lora_r=64
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
