Implementation: ContextualAI HALOs BasicTrainer Save
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Infrastructure |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
BasicTrainer.save is a concrete tool for saving model checkpoints, with support for LoRA merging and FSDP.
Description
The BasicTrainer.save method handles the complete checkpoint persistence flow: saving the tokenizer, writing training metrics to JSON, extracting optimizer and scheduler state dicts, and saving the model weights. When use_peft=True and final_save=True, it calls get_base_model_state_dict_from_peft to merge LoRA adapter weights into the base model before saving.
The PPOTrainer overrides save() to additionally persist the value head weights (v_head.pt).
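The persistence flow described above can be sketched as a minimal stand-in. Everything here is illustrative (a tiny linear model, a `save_checkpoint` helper of our own); only the file names follow the I/O contract of the real method, which additionally saves the tokenizer and writes model weights via `save_pretrained`:

```python
import json
import os
import tempfile

import torch
import torch.nn as nn

def save_checkpoint(model, optimizer, scheduler, metrics, output_dir):
    """Sketch of the flow: metrics -> JSON, optimizer/scheduler state
    dicts and model weights -> torch .pt files. Not the HALOs code."""
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f)
    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
    torch.save(model.state_dict(), os.path.join(output_dir, "model.pt"))

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
with tempfile.TemporaryDirectory() as d:
    save_checkpoint(model, optimizer, scheduler, {"example_counter": 0}, d)
    saved = sorted(os.listdir(d))
print(saved)  # ['metrics.json', 'model.pt', 'optimizer.pt', 'scheduler.pt']
```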
Usage
Called automatically by BasicTrainer.train() at the end of training and at intermediate checkpoints. Can also be called directly for manual saves.
Code Reference
Source Location
- Repository: ContextualAI/HALOs
- File: train/trainers.py (BasicTrainer.save), train/utils.py (get_base_model_state_dict_from_peft)
- Lines: train/trainers.py:L400-451 (BasicTrainer.save), train/utils.py:L70-95 (get_base_model_state_dict_from_peft)
Signature
class BasicTrainer:
    def save(
        self,
        output_dir: Optional[str] = None,
        metrics: Optional[Dict] = {},
        final_save: bool = True
    ) -> None:
        """Save tokenizer, policy model, optimizer, and scheduler state to disk.

        Args:
            output_dir: Path to save directory. Defaults to {run_dir}/step-{counter}.
            metrics: Training metrics dict to write alongside the model.
            final_save: If True and use_peft, merges LoRA into the base model.
        """

def get_base_model_state_dict_from_peft(
    peft_state_dict: Dict,
    lora_alpha: int,
    lora_r: int
) -> Dict:
    """Merge LoRA adapter weights into the base model state dict.

    Computes W_merged = W_base + (alpha / r) * B @ A
    for each LoRA-modified layer.
    """
Import
from train.trainers import BasicTrainer
from train.utils import get_base_model_state_dict_from_peft
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| output_dir | str | No | Save directory (defaults to step-based path) |
| metrics | Dict | No | Training metrics to save as metrics.json |
| final_save | bool | No | If True + LoRA, merges adapters into base model |
| self.policy | nn.Module | Yes | Trained model (possibly FSDP-wrapped with LoRA) |
| self.optimizer | Optimizer | Yes | Optimizer whose state dict is saved |
| self.scheduler | LRScheduler | Yes | Scheduler whose state dict is saved |
Outputs
| Name | Type | Description |
|---|---|---|
| model weights | Files | Model weights via save_pretrained (bin or safetensors) |
| config.json | File | Model configuration |
| tokenizer files | Files | Tokenizer vocab, config, special tokens |
| optimizer.pt | File | Optimizer state dict |
| scheduler.pt | File | Scheduler state dict |
| metrics.json | File | Training metrics with example counter |
| v_head.pt | File | Value head weights (PPO only) |
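Since optimizer.pt and scheduler.pt in the table above are plain torch state dicts, a later run can restore them with `load_state_dict`. The sketch below is a minimal stand-in (tiny model, our own directory handling), not the HALOs resume logic; only the file names follow the output contract:

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

with tempfile.TemporaryDirectory() as ckpt_dir:
    # Write the state dicts as the save flow does.
    torch.save(optimizer.state_dict(), os.path.join(ckpt_dir, "optimizer.pt"))
    torch.save(scheduler.state_dict(), os.path.join(ckpt_dir, "scheduler.pt"))

    # Later / in a new process: rebuild the objects, then restore state.
    new_optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=10)
    new_optimizer.load_state_dict(torch.load(os.path.join(ckpt_dir, "optimizer.pt")))
    new_scheduler.load_state_dict(torch.load(os.path.join(ckpt_dir, "scheduler.pt")))

restored_lr = new_optimizer.param_groups[0]["lr"]
print(restored_lr)  # 0.0003
```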
Usage Examples
Automatic Save During Training
# Inside BasicTrainer.train(), save is called automatically:
# At end of training:
output_dir = os.path.join(self.run_dir, 'FINAL')
self.save(output_dir, results['results'], final_save=True)
# At intermediate checkpoints (if config.intermediate_checkpoints=True):
output_dir = os.path.join(self.run_dir, f'step-{self.example_counter}')
self.save(output_dir, results['results'], final_save=False)
Manual LoRA Merge
from train.utils import get_base_model_state_dict_from_peft
# Manually merge LoRA weights into base model
peft_state_dict = accelerator.get_state_dict(policy)
merged_state_dict = get_base_model_state_dict_from_peft(
peft_state_dict,
lora_alpha=256,
lora_r=64
)