
Heuristic:ContextualAI HALOs LoRA Merge At Save

From Leeroopedia




Knowledge Sources
Domains Optimization, Model_Checkpointing, LoRA
Last Updated 2026-02-08 03:00 GMT

Overview

LoRA weights are manually merged into the base model at final save time using a custom helper, because automated PEFT merging does not work correctly with FSDP.

Description

When training with LoRA (`use_peft=true`), the HALOs framework uses a custom `get_base_model_state_dict_from_peft()` function to manually merge LoRA weights (A and B matrices) into the base model weights at save time. This is done because the standard PEFT `.merge_and_unload()` method does not work correctly when the model is sharded across GPUs with FSDP. The merge formula is: `base_weight + (lora_B @ lora_A) * (lora_alpha / lora_r)`. Intermediate checkpoints during LoRA training save only the LoRA adapter weights, but the final save produces a fully merged model.
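The merge formula can be checked numerically on toy data. In this sketch (all shapes and values are illustrative, and numpy arrays stand in for torch tensors), a rank-2 adapter with `lora_alpha=8` is folded into a base weight:

```python
import numpy as np

# Toy numeric check of the merge formula W_merged = W_base + (B @ A) * (alpha / r)
# (shapes and values are illustrative; numpy stands in for torch tensors).
d, r, alpha = 4, 2, 8
W_base = np.eye(d)
A = np.ones((r, d))          # lora_A: (r, d)
B = np.ones((d, r))          # lora_B: (d, r)

W_merged = W_base + (B @ A) * (alpha / r)   # scaling = 8 / 2 = 4.0
# Each entry of (B @ A) is 2, so every entry of W_base gains 2 * 4 = 8.
assert W_merged[0, 1] == 8.0 and W_merged[0, 0] == 9.0
```

Note that `lora_B` is zero-initialized in standard LoRA, so immediately after initialization the merge is a no-op and `W_merged == W_base`.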

Usage

This heuristic is automatically applied when `use_peft=true` and the final save is performed. No user action is required. However, be aware that:

  • Intermediate checkpoints contain only the LoRA adapter weights (not the merged model), so resuming training or running inference from them requires loading the base model as well.
  • Final checkpoints contain the fully merged model and can be used directly for inference without PEFT.
  • The default LoRA config is `lora_r=64, lora_alpha=256, lora_dropout=0.05, target_modules="all-linear"`.
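The defaults above map onto Hydra-style overrides. A hypothetical invocation is sketched below; only `++model.use_peft=true` and the `model.peft.*` names come from this page, while the entry point and any other arguments depend on your setup:

```shell
# Illustrative only: entry point and config-group arguments are assumptions.
python launch.py ++model.use_peft=true \
    ++model.peft.lora_r=64 \
    ++model.peft.lora_alpha=256 \
    ++model.peft.lora_dropout=0.05
```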

The Insight (Rule of Thumb)

  • Action: Use `++model.use_peft=true` for LoRA training; the framework handles merging at final save.
  • Value: Produces a standalone merged model at `FINAL/` that does not require PEFT for inference.
  • Trade-off: Intermediate checkpoints are LoRA-only (smaller, but need base model to use). The custom merge function adds complexity but is necessary for FSDP compatibility.
  • LoRA defaults: `lora_r=64` (rank), `lora_alpha=256` (scaling = alpha/r = 4.0), `lora_dropout=0.05`, targeting all linear layers.

Reasoning

FSDP shards model parameters across GPUs, and PEFT's merge utilities expect all parameters to be on a single device. The custom `get_base_model_state_dict_from_peft` function operates on the gathered state dict (after `accelerator.get_state_dict` collects from all shards), performing the LoRA merge mathematically: `W_merged = W_base + (B @ A) * (alpha / r)`. The function also strips PEFT-specific name prefixes (`base_model.model.`, `lora_A.default.`, etc.) to produce a clean state dict compatible with the original model architecture.
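The prefix stripping described above is plain string rewriting. A minimal sketch (the parameter name is a typical example of PEFT's naming scheme, not taken verbatim from this codebase):

```python
# A LoRA-wrapped parameter name in PEFT's naming scheme maps back to the
# base model's name once the PEFT-specific prefixes are stripped.
peft_name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"
clean_name = peft_name.replace("lora_A.default.", "").replace("base_model.model.", "")
print(clean_name)  # model.layers.0.self_attn.q_proj.weight
```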

Code Evidence

Custom LoRA merge function in `train/utils.py:70-95`:

def get_base_model_state_dict_from_peft(peft_state_dict, lora_alpha, lora_r):
    """
    Return the state dict for the base model given the state dict for a lora-wrapped
    AutoModelForCausalLM, merging the lora weights as needed.

    This helper is needed because automated weight merging does not work with FSDP.
    """
    state_dict = {}
    for name in peft_state_dict.keys():
        if 'lora_A' in name:
            # Locate the matching base weight and lora_B for this lora_A entry.
            base_param_name = name.replace('lora_A.default', 'base_layer')
            lora_a = peft_state_dict[name]
            lora_b = peft_state_dict[name.replace('lora_A', 'lora_B')]
            scaling = lora_alpha / lora_r
            # Strip PEFT prefixes so the key matches the original architecture.
            new_name = name.replace('lora_A.default.', '').replace('base_model.model.', '')
            state_dict[new_name] = peft_state_dict[base_param_name] + (lora_b @ lora_a) * scaling
    return state_dict
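The helper can be exercised on a toy state dict. In this self-contained sketch the function is re-declared in full so the example runs on its own; numpy arrays stand in for torch tensors, and the parameter names follow PEFT's naming scheme (the single `linear` layer is hypothetical):

```python
import numpy as np

# Self-contained toy exercise of the merge helper described above.
def get_base_model_state_dict_from_peft(peft_state_dict, lora_alpha, lora_r):
    state_dict = {}
    for name in peft_state_dict.keys():
        if 'lora_A' in name:
            base_param_name = name.replace('lora_A.default', 'base_layer')
            lora_a = peft_state_dict[name]
            lora_b = peft_state_dict[name.replace('lora_A', 'lora_B')]
            scaling = lora_alpha / lora_r
            new_name = name.replace('lora_A.default.', '').replace('base_model.model.', '')
            state_dict[new_name] = peft_state_dict[base_param_name] + (lora_b @ lora_a) * scaling
    return state_dict

d, r = 4, 2
peft_sd = {
    'base_model.model.linear.base_layer.weight': np.eye(d),
    'base_model.model.linear.lora_A.default.weight': np.ones((r, d)),
    'base_model.model.linear.lora_B.default.weight': np.zeros((d, r)),
}
merged = get_base_model_state_dict_from_peft(peft_sd, lora_alpha=8, lora_r=r)
assert list(merged) == ['linear.weight']               # clean, PEFT-free key
assert np.array_equal(merged['linear.weight'], np.eye(d))  # zero B => base weight unchanged
```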

Conditional merge at save in `train/trainers.py:430-441`:

if self.config.model.use_peft and final_save:
    state_dict = get_base_model_state_dict_from_peft(
        self.accelerator.get_state_dict(self.policy),
        self.config.model.peft.lora_alpha,
        self.config.model.peft.lora_r,
    )
    unwrapped_model = self.accelerator.unwrap_model(self.policy).base_model
else:
    state_dict = self.accelerator.get_state_dict(self.policy)
    unwrapped_model = self.accelerator.unwrap_model(self.policy)

LoRA dtype alignment in `launch.py:247-250`:

# Ensure LoRA layers are in the same dtype as the base model
for name, module in peft_model.named_modules():
    if 'lora_' in name:
        module.to(getattr(torch, config.model.policy_dtype))
