Implementation:Alibaba ROLL Merge LoRA Into State Dict
| Knowledge Sources | |
|---|---|
| Domains | Model_Management, Diffusion_Models |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
A concrete LoRA merge utility for diffusion models, provided as a standalone script in the Alibaba ROLL examples.
Description
The merge_lora_into_state_dict function merges LoRA adapter weights into a base model state dictionary. It matches LoRA weight pairs (weight_A/weight_B or lora_A/lora_B) to their corresponding base weights and computes each merged weight as base + alpha * (W_up @ W_down), where W_up and W_down are the two low-rank matrices of the pair.
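The pairing step described above can be sketched as follows. This is a minimal illustration, not the script's actual code: the helper name pair_lora_keys is hypothetical, and the sketch assumes a key layout of "<module>.lora_A.weight" / "<module>.lora_B.weight" pairing with "<module>.weight" in the base model. The weight_A/weight_B naming is not covered because this page does not spell out its layout.

```python
def pair_lora_keys(lora_keys, base_keys):
    """Map each base weight key to its (down_key, up_key) LoRA pair.

    Assumed (hypothetical) layout: "<module>.lora_A.weight" and
    "<module>.lora_B.weight" belong to the base key "<module>.weight".
    """
    lora_keys, base_keys = set(lora_keys), set(base_keys)
    pairs = {}
    for down_key in sorted(lora_keys):
        if ".lora_A." not in down_key:
            continue
        up_key = down_key.replace(".lora_A.", ".lora_B.")
        # Dropping the LoRA tag recovers the base weight name.
        base_key = down_key.replace(".lora_A.", ".")
        # Keep only complete pairs that have a matching base weight.
        if up_key in lora_keys and base_key in base_keys:
            pairs[base_key] = (down_key, up_key)
    return pairs


lora = [
    "blocks.0.attn.q.lora_A.weight",
    "blocks.0.attn.q.lora_B.weight",
    "blocks.0.attn.k.lora_A.weight",  # incomplete pair, will be skipped
]
base = ["blocks.0.attn.q.weight", "blocks.0.attn.k.weight"]
pairs = pair_lora_keys(lora, base)
```

Base keys without a complete LoRA pair are simply absent from the result, so a caller can copy those weights through unchanged.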
Usage
Run as a standalone script after training to produce a merged model for deployment.
Code Reference
Source Location
- Repository: Alibaba ROLL
- File: examples/wan2.2-14B-reward_fl_ds/merge_lora.py
- Lines: L14-61
Signature
def load_state_dict_from_safetensors(file_path, device: str = "cpu") -> dict:
    """Load state dict from safetensors file."""

def merge_lora_into_state_dict(
    base_state_dict: dict,
    lora_state_dict: dict,
    alpha: float = 1.0,
    device: str = "cpu",
    dtype=torch.bfloat16
) -> dict:
    """
    Merge LoRA adapters into base model state dictionary.

    Args:
        base_state_dict: Base model weights
        lora_state_dict: LoRA adapter weights (from training checkpoint)
        alpha: Merge alpha (default 1.0)
        device: Target device
        dtype: Target dtype (default bfloat16)

    Returns:
        Merged state dictionary: base + alpha * (W_up @ W_down)
    """
Import
# Standalone script usage:
# python examples/wan2.2-14B-reward_fl_ds/merge_lora.py \
# --base_model_path ./wan22_dit.safetensors \
# --lora_path ./checkpoint/diffusion_module.pth \
# --output_path ./merged_model.safetensors
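For orientation, a command line with these three flags could be parsed as in the sketch below. It uses only the flag names shown in the invocation above; whether the real script marks them as required, or accepts further options, is not stated on this page, so required=True and the help strings are assumptions.

```python
import argparse


def build_parser():
    # Flag names come from the invocation above; everything else is assumed.
    parser = argparse.ArgumentParser(
        description="Merge LoRA weights into a base model (illustrative sketch)."
    )
    parser.add_argument("--base_model_path", required=True,
                        help="Path to the base model safetensors file")
    parser.add_argument("--lora_path", required=True,
                        help="Path to the LoRA training checkpoint")
    parser.add_argument("--output_path", required=True,
                        help="Where to write the merged safetensors file")
    return parser


# Parse the same arguments as the example invocation.
args = build_parser().parse_args([
    "--base_model_path", "./wan22_dit.safetensors",
    "--lora_path", "./checkpoint/diffusion_module.pth",
    "--output_path", "./merged_model.safetensors",
])
```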
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_state_dict | dict | Yes | Base model weights from safetensors |
| lora_state_dict | dict | Yes | LoRA weights from training checkpoint |
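| alpha | float | No | Scaling factor applied to the LoRA delta before it is added to the base weight (default 1.0) |
| device | str | No | Target device (default "cpu") |
| dtype | torch.dtype | No | Target dtype of the merged weights (default torch.bfloat16) |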
Outputs
| Name | Type | Description |
|---|---|---|
| merged_state_dict | dict | Merged model weights ready for deployment |
| safetensors file | File | Merged model saved in safetensors format (written to --output_path when run as a script) |
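To make the contract concrete, the sketch below applies the documented formula base + alpha * (W_up @ W_down) to a single toy weight and checks that the merged weight reproduces the output of the base layer plus the LoRA branch. The helper merge_single_weight is hypothetical, and accumulating in float32 before casting to the target dtype is an assumption of this sketch, not a confirmed detail of merge_lora.py.

```python
import torch


def merge_single_weight(base_w, w_down, w_up, alpha=1.0, dtype=torch.bfloat16):
    """Return base_w + alpha * (w_up @ w_down), cast to the target dtype."""
    # Assumption: accumulate in float32 to limit rounding error, then cast.
    delta = w_up.to(torch.float32) @ w_down.to(torch.float32)
    return (base_w.to(torch.float32) + alpha * delta).to(dtype)


torch.manual_seed(0)
out_features, in_features, rank = 8, 6, 2
base_w = torch.randn(out_features, in_features)
w_down = torch.randn(rank, in_features)   # rank x in_features
w_up = torch.randn(out_features, rank)    # out_features x rank
x = torch.randn(4, in_features)

merged = merge_single_weight(base_w, w_down, w_up, alpha=1.0, dtype=torch.float32)

# Base layer plus LoRA branch versus the single merged layer.
y_adapter = x @ base_w.T + (x @ w_down.T) @ w_up.T
y_merged = x @ merged.T
```

The two outputs agree up to floating-point rounding, which is why the merged state dictionary can be deployed without the adapter.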
Usage Examples
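import sys

import torch
from safetensors.torch import save_file

# The script's directory name contains dots and hyphens, so it cannot be imported
# as a package; add it to sys.path instead (paths assume the repository root).
sys.path.insert(0, "examples/wan2.2-14B-reward_fl_ds")
from merge_lora import load_state_dict_from_safetensors, merge_lora_into_state_dict

# Load base model and LoRA weights (map the checkpoint to CPU to match the default device)
base_sd = load_state_dict_from_safetensors("./wan22_dit.safetensors")
lora_sd = torch.load("./checkpoint/diffusion_module.pth", map_location="cpu")

# Merge LoRA into base
merged_sd = merge_lora_into_state_dict(base_sd, lora_sd, alpha=1.0)

# Save merged model
save_file(merged_sd, "./merged_model.safetensors")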
Related Pages
Implements Principle
Requires Environment
Environment Dependencies
This implementation requires the following environment constraints:
Heuristics Applied
No specific heuristics apply to this implementation.