# Heuristic: ContextualAI HALOs LoRA Merge at Save
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Model_Checkpointing, LoRA |
| Last Updated | 2026-02-08 03:00 GMT |
## Overview
LoRA weights are manually merged into the base model at final save time using a custom helper, because automated PEFT merging does not work correctly with FSDP.
## Description
When training with LoRA (`use_peft=true`), the HALOs framework uses a custom `get_base_model_state_dict_from_peft()` function to manually merge LoRA weights (A and B matrices) into the base model weights at save time. This is done because the standard PEFT `.merge_and_unload()` method does not work correctly when the model is sharded across GPUs with FSDP. The merge formula is: `base_weight + (lora_B @ lora_A) * (lora_alpha / lora_r)`. Intermediate checkpoints during LoRA training save only the LoRA adapter weights, but the final save produces a fully merged model.
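The merge formula can be checked numerically. Below is a minimal NumPy sketch (shapes, values, and variable names are illustrative, not taken from the framework) showing that the merged weight behaves identically to the base path plus the scaled low-rank path:

```python
import numpy as np

# Illustrative shapes: an 8x8 base weight with rank-4 LoRA factors.
rng = np.random.default_rng(0)
d, r = 8, 4
lora_alpha = 16  # illustrative, not the framework default

W_base = rng.standard_normal((d, d))
lora_A = rng.standard_normal((r, d))  # A projects down: (r, d)
lora_B = rng.standard_normal((d, r))  # B projects back up: (d, r)

scaling = lora_alpha / r
W_merged = W_base + (lora_B @ lora_A) * scaling

# Applying the merged weight equals applying the base weight plus the
# scaled low-rank update separately.
x = rng.standard_normal(d)
assert np.allclose(W_merged @ x, W_base @ x + (lora_B @ (lora_A @ x)) * scaling)
```

This is exactly why the merge can be done offline on a gathered state dict: it is a pure weight-space operation with no dependence on the forward pass.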
## Usage
This heuristic is automatically applied when `use_peft=true` and the final save is performed. No user action is required. However, be aware that:
- Intermediate checkpoints contain only the LoRA adapter weights (not the merged model), so the base model is needed to resume training or run inference from them.
- Final checkpoints contain the fully merged model and can be used directly for inference without PEFT.
- The default LoRA config is `lora_r=64, lora_alpha=256, lora_dropout=0.05, target_modules="all-linear"`.
## The Insight (Rule of Thumb)
- Action: Use `++model.use_peft=true` for LoRA training; the framework handles merging at final save.
- Value: Produces a standalone merged model at `FINAL/` that does not require PEFT for inference.
- Trade-off: Intermediate checkpoints are LoRA-only (smaller, but need base model to use). The custom merge function adds complexity but is necessary for FSDP compatibility.
- LoRA defaults: `lora_r=64` (rank), `lora_alpha=256` (scaling = alpha/r = 4.0), `lora_dropout=0.05`, targeting all linear layers.
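The effective scaling implied by these defaults can be computed directly; every merged low-rank update `(B @ A)` is multiplied by this factor:

```python
# LoRA defaults quoted above.
lora_r = 64
lora_alpha = 256

scaling = lora_alpha / lora_r
print(scaling)  # 4.0
```

A scaling well above 1.0 means the adapter's contribution is amplified relative to its raw low-rank product, so changing `lora_alpha` without adjusting `lora_r` directly rescales the merged update.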
## Reasoning
FSDP shards model parameters across GPUs, and PEFT's merge utilities expect all parameters to be on a single device. The custom `get_base_model_state_dict_from_peft` function operates on the gathered state dict (after `accelerator.get_state_dict` collects from all shards), performing the LoRA merge mathematically: `W_merged = W_base + (B @ A) * (alpha / r)`. The function also strips PEFT-specific name prefixes (`base_model.model.`, `lora_A.default.`, etc.) to produce a clean state dict compatible with the original model architecture.
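The prefix stripping can be illustrated on a single key. The key below follows PEFT's wrapping convention for a `q_proj` linear layer (the specific layer path is illustrative):

```python
# A LoRA "A" matrix key as it appears in a PEFT-wrapped model's state dict.
peft_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"

# Locate the corresponding frozen base weight for the same module.
base_param_name = peft_name.replace("lora_A.default", "base_layer")

# Strip the PEFT-specific prefixes to recover the original model's key.
new_name = peft_name.replace("lora_A.default.", "").replace("base_model.model.", "")

print(base_param_name)  # base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight
print(new_name)         # model.layers.0.self_attn.q_proj.weight
```

The resulting `new_name` matches what a plain `AutoModelForCausalLM` expects, which is what makes the merged state dict loadable without PEFT.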
## Code Evidence
Custom LoRA merge function in `train/utils.py:70-95`:
```python
def get_base_model_state_dict_from_peft(peft_state_dict, lora_alpha, lora_r):
    """
    Return the state dict for the base model given the state dict for a lora-wrapped
    AutoModelForCausalLM, merging the lora weights as needed.

    This helper is needed because automated weight merging does not work with FSDP.
    """
    state_dict = {}

    for name in peft_state_dict.keys():
        if 'lora_A' in name:
            # Merge the low-rank update into the corresponding base weight.
            base_param_name = name.replace('lora_A.default', 'base_layer')
            lora_a = peft_state_dict[name]
            lora_b = peft_state_dict[name.replace('lora_A', 'lora_B')]
            scaling = lora_alpha / lora_r
            new_name = name.replace('lora_A.default.', '').replace('base_model.model.', '')
            state_dict[new_name] = peft_state_dict[base_param_name] + (lora_b @ lora_a) * scaling
        elif 'lora' not in name:
            # Pass non-LoRA parameters through, stripping PEFT name prefixes.
            new_name = name.replace('base_layer.', '').replace('base_model.model.', '')
            state_dict[new_name] = peft_state_dict[name]

    return state_dict
```
Conditional merge at save in `train/trainers.py:430-441`:
```python
if self.config.model.use_peft and final_save:
    state_dict = get_base_model_state_dict_from_peft(
        self.accelerator.get_state_dict(self.policy),
        self.config.model.peft.lora_alpha,
        self.config.model.peft.lora_r,
    )
    unwrapped_model = self.accelerator.unwrap_model(self.policy).base_model
else:
    state_dict = self.accelerator.get_state_dict(self.policy)
    unwrapped_model = self.accelerator.unwrap_model(self.policy)
```
LoRA dtype alignment in `launch.py:247-250`:
```python
# Ensure LoRA layers are in the same dtype as the base model
for name, module in peft_model.named_modules():
    if 'lora_' in name:
        module.to(getattr(torch, config.model.policy_dtype))
```