
Implementation:Alibaba ROLL Merge LoRA Into State Dict

From Leeroopedia


Knowledge Sources
Domains Model_Management, Diffusion_Models
Last Updated 2026-02-07 20:00 GMT

Overview

A concrete LoRA merge utility script for diffusion models, provided in the Alibaba ROLL examples.

Description

The merge_lora_into_state_dict function merges LoRA adapter weights into a base model state dictionary. It matches each pair of LoRA weights (named weight_A/weight_B or lora_A/lora_B) to its corresponding base weight and adds the scaled low-rank product to that weight: merged = base + alpha * (W_up @ W_down).
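
The pairing step described above can be sketched as follows. This is a minimal illustration rather than the ROLL implementation: the helper name pair_lora_keys is hypothetical, and the key layout (a `lora_A`/`lora_B` or `weight_A`/`weight_B` segment that replaces the trailing `weight` of the base key) is an assumption. Real checkpoints may carry extra prefixes or adapter names that need stripping first.

```python
from typing import Dict, Iterable, Tuple

# Marker names come from the description above; the key layout is assumed.
A_MARKERS = ("lora_A", "weight_A")
B_MARKERS = ("lora_B", "weight_B")


def pair_lora_keys(lora_keys: Iterable[str]) -> Dict[str, Tuple[str, str]]:
    """Map an assumed base key '<prefix>.weight' to its (A_key, B_key) pair."""
    found: Dict[str, Dict[str, str]] = {}
    for key in lora_keys:
        parts = key.split(".")
        for i, part in enumerate(parts):
            if part in A_MARKERS or part in B_MARKERS:
                # Everything before the marker is treated as the module prefix.
                base_key = ".".join(parts[:i] + ["weight"])
                role = "A" if part in A_MARKERS else "B"
                found.setdefault(base_key, {})[role] = key
                break
    # Keep only complete pairs; a lone A or B cannot be merged.
    return {
        base: (roles["A"], roles["B"])
        for base, roles in found.items()
        if "A" in roles and "B" in roles
    }


pairs = pair_lora_keys([
    "blocks.0.attn.q.lora_A.weight",
    "blocks.0.attn.q.lora_B.weight",
    "blocks.0.attn.k.lora_A.weight",  # no matching lora_B, so it is skipped
])
print(pairs)
```

Returning only complete pairs keeps the later merge loop simple: any base key absent from the result is copied through unchanged.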

Usage

Run as a standalone script after training to produce a merged model for deployment.

Code Reference

Source Location

  • Repository: Alibaba ROLL
  • File: examples/wan2.2-14B-reward_fl_ds/merge_lora.py
  • Lines: L14-61

Signature

def load_state_dict_from_safetensors(file_path, device: str = "cpu") -> dict:
    """Load state dict from safetensors file."""

def merge_lora_into_state_dict(
    base_state_dict: dict,
    lora_state_dict: dict,
    alpha: float = 1.0,
    device: str = "cpu",
    dtype=torch.bfloat16
) -> dict:
    """
    Merge LoRA adapters into base model state dictionary.

    Args:
        base_state_dict: Base model weights
        lora_state_dict: LoRA adapter weights (from training checkpoint)
        alpha: Merge alpha (default 1.0)
        device: Target device
        dtype: Target dtype (default bfloat16)

    Returns:
        Merged state dictionary: base + alpha * (W_up @ W_down)
    """
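
The return value in the docstring corresponds to the usual LoRA merge. Written out per weight matrix, with shapes that are assumed here rather than stated by the source (r is the LoRA rank):

```latex
W' = W + \alpha \, W_{\text{up}} W_{\text{down}},
\qquad
W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\;
W_{\text{up}} \in \mathbb{R}^{d_{\text{out}} \times r},\;
W_{\text{down}} \in \mathbb{R}^{r \times d_{\text{in}}}
```

Under the common PEFT naming, lora_B plays the role of W_up and lora_A the role of W_down. Whether this script follows that mapping, or applies any additional rank-based scaling beyond alpha, is not stated on this page and should be checked against the source file.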

Import

# Standalone script usage:
# python examples/wan2.2-14B-reward_fl_ds/merge_lora.py \
#   --base_model_path ./wan22_dit.safetensors \
#   --lora_path ./checkpoint/diffusion_module.pth \
#   --output_path ./merged_model.safetensors
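
For readers wiring a similar entry point, the three flags above could be parsed roughly as shown below. This is a sketch under assumptions: only the flag names come from the command above, while the help texts, the `required=True` settings, and the build_parser helper are hypothetical and may differ from the actual script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Flag names mirror the command above; everything else is assumed.
    parser = argparse.ArgumentParser(description="Merge LoRA weights into a base model")
    parser.add_argument("--base_model_path", required=True,
                        help="Base model .safetensors file")
    parser.add_argument("--lora_path", required=True,
                        help="Training checkpoint containing the LoRA weights")
    parser.add_argument("--output_path", required=True,
                        help="Destination for the merged .safetensors file")
    return parser


# Parse an explicit argument list so the sketch runs without a real command line.
args = build_parser().parse_args([
    "--base_model_path", "./wan22_dit.safetensors",
    "--lora_path", "./checkpoint/diffusion_module.pth",
    "--output_path", "./merged_model.safetensors",
])
print(args.base_model_path, args.lora_path, args.output_path)
```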

I/O Contract

Inputs

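Name | Type | Required | Description
base_state_dict | dict | Yes | Base model weights from safetensors
lora_state_dict | dict | Yes | LoRA weights from training checkpoint
alpha | float | No | Merge alpha (default 1.0)
device | str | No | Target device (default "cpu")
dtype | torch.dtype | No | Target dtype (default torch.bfloat16)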

Outputs

Name | Type | Description
merged_state_dict | dict | Merged model weights ready for deployment
safetensors file | File | Saved merged model in safetensors format

Usage Examples

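import sys

import torch
from safetensors.torch import save_file

# The script lives in examples/wan2.2-14B-reward_fl_ds/, which is not a valid
# Python package name, so add that directory to sys.path and import the module
# directly. This assumes the code is run from the repository root.
sys.path.insert(0, "examples/wan2.2-14B-reward_fl_ds")
from merge_lora import load_state_dict_from_safetensors, merge_lora_into_state_dict

# Load base model and LoRA weights (map_location keeps the checkpoint on CPU)
base_sd = load_state_dict_from_safetensors("./wan22_dit.safetensors")
lora_sd = torch.load("./checkpoint/diffusion_module.pth", map_location="cpu")

# Merge LoRA into base
merged_sd = merge_lora_into_state_dict(base_sd, lora_sd, alpha=1.0)

# Save merged model
save_file(merged_sd, "./merged_model.safetensors")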

Related Pages

Implements Principle

Requires Environment

Environment Dependencies

This implementation requires the following environment constraints:

Heuristics Applied

No specific heuristics apply to this implementation.
