Implementation:Microsoft DeepSpeedExamples Fuse LoRA
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (Pattern Doc) |
| Title | Fuse_LoRA |
| Repository | Microsoft/DeepSpeedExamples |
| Application | DeepSpeed-VisualChat |
| Files | applications/DeepSpeed-VisualChat/utils/module/lora.py (Lines 113-137), applications/DeepSpeed-VisualChat/utils/utils.py (Lines 179-206) |
| Language | Python |
| Status | Active |
Overview
A concrete tool for fusing trained LoRA weights into their frozen base Linear weights and saving the resulting DeepSpeed-VisualChat model (including under ZeRO-3 partitioning) for deployment.
Code Reference
LinearLayer_LoRA Class (lora.py, Lines 13-82)
# Imports needed to run this excerpt (at the top of lora.py):
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearLayer_LoRA(nn.Module):
    """Simple LoRA implementation for Linear layers."""

    # Note: the 'lora_droppout' spelling matches the upstream signature.
    def __init__(self, weight, lora_dim=0, lora_scaling=1,
                 lora_droppout=0, bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias
        try:
            rows, columns = weight.ds_shape  # ZeRO-3 partitioned shape
        except AttributeError:
            rows, columns = weight.shape
        # Down-projection (LoRA "A"): Kaiming-initialized below.
        self.lora_right_weight = nn.Parameter(
            torch.zeros(columns, lora_dim))
        # Up-projection (LoRA "B"): zero-initialized, so the adapter
        # starts as a no-op.
        self.lora_left_weight = nn.Parameter(
            torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim
        if lora_droppout > 0:
            self.lora_dropout = nn.Dropout(lora_droppout)
        else:
            self.lora_dropout = nn.Identity()
        self.reset_parameters()
        self.weight.requires_grad = False  # freeze the original weight
        self.fuse_lora = False

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_left_weight)

    def fuse_lora_weight(self):
        if not self.fuse_lora:
            self.weight.data += self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        if self.fuse_lora:
            self.weight.data -= self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = False

    def forward(self, input):
        if self.fuse_lora:
            return F.linear(input, self.weight, self.bias)
        return F.linear(input, self.weight, self.bias) + \
            (self.lora_dropout(input) @ self.lora_right_weight
             @ self.lora_left_weight) * self.lora_scaling
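As a sanity check (a hypothetical snippet, not part of the repository), the fuse/unfuse round trip can be exercised on a plain nn.Linear with toy dimensions:

# Wrap an 8->4 Linear; lora_left_weight is randomized to stand in for a
# trained adapter (it is zero-initialized, which would make all three
# outputs trivially equal).
base = nn.Linear(8, 4)
lora = LinearLayer_LoRA(base.weight, lora_dim=2, lora_scaling=1,
                        bias=base.bias)
nn.init.normal_(lora.lora_left_weight)

x = torch.randn(3, 8)
y_unfused = lora(x)            # base path + low-rank path
lora.fuse_lora_weight()
y_fused = lora(x)              # single F.linear on the merged weight
lora.unfuse_lora_weight()
y_roundtrip = lora(x)

assert torch.allclose(y_unfused, y_fused, atol=1e-5)
assert torch.allclose(y_unfused, y_roundtrip, atol=1e-5)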
LoRA Conversion Functions (lora.py, Lines 86-137)
# Imports needed to run this excerpt; _z3_params_to_fetch is a repo
# helper that filters for partitioned (ZeRO-3) parameters.
import deepspeed
from deepspeed.compression.helper import recursive_getattr, recursive_setattr


def convert_linear_layer_to_lora(model, part_module_name,
                                 lora_dim=0, lora_scaling=1,
                                 lora_droppout=0):
    """Replace matching Linear layers with LoRA variants."""
    repalce_name = []  # (sic) variable name as in the upstream source
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and part_module_name in name:
            repalce_name.append(name)
    for name in repalce_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(
            module.weight, lora_dim, lora_scaling, lora_droppout,
            module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model


def convert_lora_to_linear_layer(model, fuse_lora=True):
    """Fuse or unfuse all LoRA layers in the model."""
    repalce_name = []
    for name, module in model.named_modules():
        if isinstance(module, LinearLayer_LoRA):
            repalce_name.append(name)
    for name in repalce_name:
        module = recursive_getattr(model, name)
        zero_stage_3 = hasattr(module.weight, 'ds_id')
        # Under ZeRO-3, parameters are partitioned across ranks and must
        # be gathered before they can be modified in place.
        with deepspeed.zero.GatheredParameters(
                _z3_params_to_fetch([
                    module.weight, module.bias,
                    module.lora_left_weight, module.lora_right_weight
                ]),
                modifier_rank=0,
                enabled=zero_stage_3):
            if fuse_lora:
                module.fuse_lora_weight()
            else:
                module.unfuse_lora_weight()
    return model


def fuse_lora(model):
    """Convenience wrapper: fuse all LoRA weights into base weights."""
    return convert_lora_to_linear_layer(model, fuse_lora=True)


def unfuse_lora(model):
    """Convenience wrapper: unfuse all LoRA weights from base weights."""
    return convert_lora_to_linear_layer(model, fuse_lora=False)


def only_optimize_lora_parameters(model):
    """Freeze all parameters except LoRA weights."""
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
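A hypothetical composition of the two helpers on a toy model (names and sizes are illustrative; it assumes the lora.py definitions above are importable):

# An empty part_module_name matches every module name, so both Linear
# layers are converted; real configs pass substrings such as
# "model.layers." or "encoder.layers.".
toy = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
toy = convert_linear_layer_to_lora(toy, part_module_name="", lora_dim=4)
toy = only_optimize_lora_parameters(toy)

trainable = [n for n, p in toy.named_parameters() if p.requires_grad]
print(trainable)
# ['0.lora_right_weight', '0.lora_left_weight',
#  '2.lora_right_weight', '2.lora_left_weight']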
ZeRO-3 Model Saving (utils/utils.py, Lines 179-206)
# Imports needed to run this excerpt (at the top of utils.py):
import os

import torch
import deepspeed


def save_zero_three_model(model_ema, global_rank, save_dir,
                          zero_stage=0, sub_folder=""):
    zero_stage_3 = (zero_stage == 3)
    output_dir = os.path.join(save_dir, sub_folder)
    os.makedirs(output_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
    model_to_save = model_ema.module if hasattr(model_ema, 'module') \
        else model_ema
    if not zero_stage_3:
        # Stages 0-2 keep full parameters on every rank, so rank 0 can
        # save the state dict directly.
        if global_rank == 0:
            torch.save(model_to_save.state_dict(), output_model_file)
    else:
        # Stage 3 partitions parameters, so each one is gathered before
        # being copied to CPU. LoRA keys are dropped: after fuse_lora()
        # their contribution already lives in the base weights.
        output_state_dict = {}
        for k, v in model_to_save.named_parameters():
            if hasattr(v, 'ds_id'):
                with deepspeed.zero.GatheredParameters(
                        _z3_params_to_fetch([v]),
                        enabled=zero_stage_3):
                    v_p = v.data.clone().detach().cpu()
            else:
                v_p = v.cpu()
            if global_rank == 0 and "lora" not in k:
                output_state_dict[k] = v_p
        if global_rank == 0:
            torch.save(output_state_dict, output_model_file)
        del output_state_dict
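Loading the resulting checkpoint is plain PyTorch. A hypothetical sketch (the paths and base_model are illustrative, not from the repository):

# Because LoRA weights were fused into the base weights before saving
# and the "lora" keys were filtered out, the saved state dict matches
# the original (non-LoRA) architecture.
state_dict = torch.load(
    os.path.join("output", "epoch-0", "pytorch_model.bin"),
    map_location="cpu")
# strict=False guards against missing buffer keys: the ZeRO-3 path saves
# only named_parameters, not buffers.
base_model.load_state_dict(state_dict, strict=False)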
I/O Contract
fuse_lora / unfuse_lora
| Direction | Parameter | Type | Description |
|---|---|---|---|
| Input | model | nn.Module | Model containing LinearLayer_LoRA modules |
| Output | (return) | nn.Module | Same model, with LoRA weights fused/unfused in place |
convert_linear_layer_to_lora
| Direction | Parameter | Type | Description |
|---|---|---|---|
| Input | model | nn.Module | Model with standard nn.Linear layers |
| Input | part_module_name | str | Module-name filter (e.g., "model.layers.") |
| Input | lora_dim | int | LoRA rank (e.g., 16) |
| Input | lora_scaling | float | Scaling numerator; the effective scale is lora_scaling / lora_dim (default 1) |
| Input | lora_droppout | float | Dropout probability (default 0) |
| Output | (return) | nn.Module | Model with matching Linear layers replaced by LinearLayer_LoRA |
save_zero_three_model
| Direction | Parameter | Type | Description |
|---|---|---|---|
| Input | model_ema | DeepSpeedEngine | The DeepSpeed-wrapped model |
| Input | global_rank | int | Current process rank (only rank 0 writes to disk) |
| Input | save_dir | str | Root output directory |
| Input | zero_stage | int | ZeRO stage (special gathering path for stage 3) |
| Input | sub_folder | str | Subfolder name (e.g., "epoch-0") |
| Output | (side effect) | file | Writes pytorch_model.bin to save_dir/sub_folder/ |
Usage Example
Applying LoRA During Model Setup
# In training/main.py
if args.lang_lora_dim > 0:
    model.lang_decoder = convert_linear_layer_to_lora(
        model.lang_decoder,
        args.lang_lora_module_name,  # "model.layers."
        args.lang_lora_dim)          # e.g., 16
    if args.only_optimize_lora:
        model.lang_decoder = only_optimize_lora_parameters(model.lang_decoder)

if args.vis_lora_dim > 0:
    model.vis_encoder = convert_linear_layer_to_lora(
        model.vis_encoder,
        args.vis_lora_module_name,   # "encoder.layers."
        args.vis_lora_dim)
    if args.only_optimize_lora:
        model.vis_encoder = only_optimize_lora_parameters(model.vis_encoder)
Per-Epoch Fusion and Saving
# At the end of each epoch in training/main.py
model = fuse_lora(model)

# Save in Hugging Face format (rank 0 only)
if args.global_rank == 0:
    save_hf_format(model, tokenizer, args, f'epoch-{epoch}')

# Save ZeRO-3 checkpoint (all ranks participate in gathering)
if args.zero_stage == 3:
    save_zero_three_model(model, args.global_rank, args.output_dir,
                          zero_stage=args.zero_stage,
                          sub_folder=f'epoch-{epoch}')

# Unfuse so LoRA training can continue
model = unfuse_lora(model)
Command-Line LoRA Configuration
deepspeed training/main.py \
--lang_lora_dim 16 \
--lang_lora_module_name "model.layers." \
--vis_lora_dim 16 \
--vis_lora_module_name "encoder.layers." \
--only_optimize_lora \
...
Fusion Mathematics in Code
The fusion operation in fuse_lora_weight():
# self.lora_left_weight: shape [r, out_features]
# self.lora_right_weight: shape [in_features, r]
# self.weight: shape [out_features, in_features]
self.weight.data += self.lora_scaling * torch.matmul(
self.lora_left_weight.t(), # [out_features, r]
self.lora_right_weight.t() # [r, in_features]
)
# Result: [out_features, in_features] -- matches weight shape
The unfusion is the exact inverse:
self.weight.data -= self.lora_scaling * torch.matmul(
self.lora_left_weight.t(),
self.lora_right_weight.t()
)
A boolean flag self.fuse_lora prevents double-fusion or double-unfusion.
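The same algebra can be checked numerically with toy tensors (a standalone sketch, not repository code):

# x @ W_fused.T must equal the unfused forward pass
# x @ W.T + s * (x @ right @ left), since (left.T @ right.T).T = right @ left.
out_f, in_f, r, s = 4, 8, 2, 0.5
W = torch.randn(out_f, in_f)
left = torch.randn(r, out_f)     # lora_left_weight
right = torch.randn(in_f, r)     # lora_right_weight
x = torch.randn(3, in_f)

unfused = x @ W.t() + s * (x @ right @ left)
W_fused = W + s * (left.t() @ right.t())
assert torch.allclose(unfused, x @ W_fused.t(), atol=1e-5)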
Dependencies
- torch -- Core tensor operations
- torch.nn -- Linear, Parameter, Dropout, Identity
- torch.nn.functional -- F.linear for the forward pass
- deepspeed -- ZeRO-3 parameter gathering (deepspeed.zero.GatheredParameters)
- deepspeed.compression.helper -- recursive_getattr, recursive_setattr for module replacement
- deepspeed.runtime.zero.partition_parameters -- ZeroParamStatus for parameter availability checks
Related Pages
- Principle:Microsoft_DeepSpeedExamples_LoRA_Fusion_And_Export -- The theoretical basis for LoRA fusion and export
- Implementation:Microsoft_DeepSpeedExamples_DeepSpeed_Initialize_VisualChat -- The training loop that calls fuse/unfuse
- Implementation:Microsoft_DeepSpeedExamples_Create_DSVL_Model -- The model architecture containing LoRA-adapted layers
- Environment:Microsoft_DeepSpeedExamples_VisualChat_Training_Environment
- Heuristic:Microsoft_DeepSpeedExamples_LoRA_Learning_Rate_Scaling