Jump to content

Connect Leeroopedia MCP: Equip your AI agents to search best practices, build plans, verify code, diagnose failures, and look up hyperparameter defaults.

Implementation:Microsoft DeepSpeedExamples Fuse LoRA

From Leeroopedia


  1. Implementation: Fuse_LoRA

Metadata

Field Value
Page Type Implementation (Pattern Doc)
Title Fuse_LoRA
Repository Microsoft/DeepSpeedExamples
Application DeepSpeed-VisualChat
Files applications/DeepSpeed-VisualChat/utils/module/lora.py (Lines 113-137), applications/DeepSpeed-VisualChat/utils/utils.py (Lines 179-206)
Language Python
Status Active

Overview

Concrete tool for fusing LoRA weights and saving DeepSpeed-VisualChat models for deployment.

Code Reference

LinearLayer_LoRA Class (lora.py, Lines 13-82)

class LinearLayer_LoRA(nn.Module):
    """Linear layer whose frozen weight is augmented by a low-rank LoRA update.

    While unfused, the output is::

        F.linear(x, weight, bias)
            + (dropout(x) @ lora_right_weight @ lora_left_weight) * lora_scaling

    After ``fuse_lora_weight()`` the scaled low-rank product has been added
    into ``weight`` in place and the forward pass is a plain ``F.linear``.

    Args:
        weight: 2-D weight to wrap, unpacked as ``(rows, columns)`` and used
            by ``F.linear`` as ``[out_features, in_features]``. It is kept by
            reference and frozen (``requires_grad = False``).
        lora_dim: Rank of the low-rank update. Must be greater than 0.
        lora_scaling: Numerator of the scale; the effective scale applied to
            the update is ``lora_scaling / lora_dim``.
        lora_droppout: Dropout probability applied to the input of the LoRA
            branch only; values ``<= 0`` disable dropout. (The spelling is
            kept as-is because it is part of the public signature.)
        bias: Optional bias forwarded to ``F.linear``.

    Raises:
        ValueError: If ``lora_dim`` is not positive.
    """

    def __init__(self, weight, lora_dim=0, lora_scaling=1,
                 lora_droppout=0, bias=None):
        super().__init__()
        if lora_dim <= 0:
            # Fail with a clear message instead of the ZeroDivisionError that
            # `lora_scaling / lora_dim` would otherwise produce further down.
            raise ValueError(
                f"lora_dim must be a positive integer, got {lora_dim!r}")

        self.weight = weight
        self.bias = bias

        try:
            # Prefer `ds_shape` when the weight carries it (the original code
            # labels this the ZeRO-3 shape).
            rows, columns = weight.ds_shape
        except AttributeError:
            rows, columns = weight.shape

        # Down-projection, [in_features, lora_dim]: applied to the input first.
        self.lora_right_weight = nn.Parameter(
            torch.zeros(columns, lora_dim))
        # Up-projection, [lora_dim, out_features]: maps back to the output size.
        self.lora_left_weight = nn.Parameter(
            torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_droppout > 0:
            self.lora_dropout = nn.Dropout(lora_droppout)
        else:
            self.lora_dropout = nn.Identity()

        self.reset_parameters()
        self.weight.requires_grad = False      # freeze original weight
        # True while the low-rank update is folded into `self.weight`.
        self.fuse_lora = False

    def reset_parameters(self):
        """Initialise the LoRA factors so the initial update is exactly zero."""
        nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        # A zero up-projection makes right @ left == 0, so a freshly wrapped
        # layer produces the same output as the original linear layer.
        nn.init.zeros_(self.lora_left_weight)

    def _lora_delta(self):
        """Return the scaled low-rank update laid out like ``self.weight``."""
        # left.t() is [out_features, r] and right.t() is [r, in_features],
        # so the product is [out_features, in_features].
        return self.lora_scaling * torch.matmul(
            self.lora_left_weight.t(), self.lora_right_weight.t())

    def fuse_lora_weight(self):
        """Add the LoRA update into ``weight`` in place (no-op if already fused)."""
        if not self.fuse_lora:
            self.weight.data += self._lora_delta()
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        """Subtract the LoRA update from ``weight`` in place (no-op if not fused)."""
        if self.fuse_lora:
            self.weight.data -= self._lora_delta()
        self.fuse_lora = False

    def forward(self, input):
        """Apply the linear layer, adding the LoRA branch unless it is fused."""
        base_output = F.linear(input, self.weight, self.bias)
        if self.fuse_lora:
            # The update already lives inside `self.weight`.
            return base_output
        lora_output = (self.lora_dropout(input) @ self.lora_right_weight
                       @ self.lora_left_weight) * self.lora_scaling
        return base_output + lora_output

LoRA Conversion Functions (lora.py, Lines 86-137)

def convert_linear_layer_to_lora(model, part_module_name,
                                  lora_dim=0, lora_scaling=1,
                                  lora_droppout=0):
    """Swap selected ``nn.Linear`` submodules of ``model`` for ``LinearLayer_LoRA``.

    A submodule is selected when it is an ``nn.Linear`` and its qualified
    name contains ``part_module_name`` as a substring. Each replacement wraps
    the original layer's weight and bias and is moved to that weight's device
    and dtype. ``model`` is modified in place and also returned.
    """
    # Collect the names up front so the module tree is not rewritten while
    # named_modules() is still walking it.
    target_names = [
        module_name
        for module_name, submodule in model.named_modules()
        if isinstance(submodule, nn.Linear) and part_module_name in module_name
    ]

    for module_name in target_names:
        linear = recursive_getattr(model, module_name)
        lora_layer = LinearLayer_LoRA(linear.weight, lora_dim, lora_scaling,
                                      lora_droppout, linear.bias)
        lora_layer = lora_layer.to(linear.weight.device)
        lora_layer = lora_layer.to(linear.weight.dtype)
        recursive_setattr(model, module_name, lora_layer)

    return model


def convert_lora_to_linear_layer(model, fuse_lora=True):
    """Fuse (or unfuse) every ``LinearLayer_LoRA`` found in ``model``.

    With ``fuse_lora=True`` each layer's ``fuse_lora_weight()`` is called,
    otherwise ``unfuse_lora_weight()``. The modules stay ``LinearLayer_LoRA``
    instances; only their weights and fused flag change. ``model`` is
    returned for convenience.
    """
    lora_module_names = [
        module_name
        for module_name, submodule in model.named_modules()
        if isinstance(submodule, LinearLayer_LoRA)
    ]

    for module_name in lora_module_names:
        lora_layer = recursive_getattr(model, module_name)
        # A `ds_id` attribute on the weight is what this code uses to decide
        # that the parameters must be gathered before they can be edited.
        needs_gather = hasattr(lora_layer.weight, 'ds_id')
        tensors_to_gather = [
            lora_layer.weight, lora_layer.bias,
            lora_layer.lora_left_weight, lora_layer.lora_right_weight
        ]
        # NOTE(review): modifier_rank=0 presumably tells DeepSpeed that rank 0's
        # in-place edit is the one to keep -- confirm against GatheredParameters.
        with deepspeed.zero.GatheredParameters(
                _z3_params_to_fetch(tensors_to_gather),
                modifier_rank=0,
                enabled=needs_gather):
            toggle = (lora_layer.fuse_lora_weight
                      if fuse_lora else lora_layer.unfuse_lora_weight)
            toggle()

    return model


def fuse_lora(model):
    """Fold every LoRA update in ``model`` into its base weight; return ``model``."""
    fused_model = convert_lora_to_linear_layer(model, fuse_lora=True)
    return fused_model


def unfuse_lora(model):
    """Remove every fused LoRA update from the base weights; return ``model``."""
    unfused_model = convert_lora_to_linear_layer(model, fuse_lora=False)
    return unfused_model


def only_optimize_lora_parameters(model):
    """Leave only the LoRA factors trainable.

    Every parameter whose name contains ``lora_right_weight`` or
    ``lora_left_weight`` gets ``requires_grad = True``; all others get
    ``requires_grad = False``. ``model`` is modified in place and returned.
    """
    for param_name, parameter in model.named_parameters():
        is_lora_factor = ("lora_right_weight" in param_name
                          or "lora_left_weight" in param_name)
        parameter.requires_grad = is_lora_factor
    return model

ZeRO-3 Model Saving (utils/utils.py, Lines 179-206)

def save_zero_three_model(model_ema, global_rank, save_dir,
                          zero_stage=0, sub_folder=""):
    """Write the model's weights to ``<save_dir>/<sub_folder>/pytorch_model.bin``.

    The output directory is created on every rank, but only the process with
    ``global_rank == 0`` writes the file.

    * ``zero_stage != 3``: rank 0 saves the full ``state_dict()``.
    * ``zero_stage == 3``: every rank walks ``named_parameters()``; parameters
      carrying a ``ds_id`` attribute are gathered and copied to CPU, the rest
      are moved with ``.cpu()``. Rank 0 keeps the entries whose name does not
      contain ``"lora"`` and saves that dictionary.

    If ``model_ema`` has a ``module`` attribute, that inner module is the one
    saved; otherwise ``model_ema`` itself is used.
    """
    stage_three = (zero_stage == 3)
    is_rank_zero = (global_rank == 0)

    target_dir = os.path.join(save_dir, sub_folder)
    os.makedirs(target_dir, exist_ok=True)
    checkpoint_path = os.path.join(target_dir, "pytorch_model.bin")

    if hasattr(model_ema, 'module'):
        unwrapped = model_ema.module
    else:
        unwrapped = model_ema

    if not stage_three:
        if is_rank_zero:
            torch.save(unwrapped.state_dict(), checkpoint_path)
        return

    collected = {}
    for param_name, param in unwrapped.named_parameters():
        if not hasattr(param, 'ds_id'):
            cpu_value = param.cpu()
        else:
            # The gather runs on every rank, outside the rank-0 check below.
            with deepspeed.zero.GatheredParameters(
                    _z3_params_to_fetch([param]),
                    enabled=stage_three):
                cpu_value = param.data.clone().detach().cpu()
        # Only rank 0 accumulates, and LoRA factors are left out of the file.
        if is_rank_zero and "lora" not in param_name:
            collected[param_name] = cpu_value

    if is_rank_zero:
        torch.save(collected, checkpoint_path)
    del collected

I/O Contract

fuse_lora / unfuse_lora

Direction Parameter Type Description
Input model nn.Module Model containing LinearLayer_LoRA modules
Output (return) nn.Module Same model with LoRA weights fused/unfused in-place

convert_linear_layer_to_lora

Direction Parameter Type Description
Input model nn.Module Model with standard nn.Linear layers
Input part_module_name str Module name filter (e.g., "model.layers.")
Input lora_dim int LoRA rank (e.g., 16); must be greater than 0 -- constructing a LinearLayer_LoRA with the default of 0 raises an error
Input lora_scaling int Scaling numerator (default 1); the effective scale applied to the update is lora_scaling / lora_dim
Input lora_droppout float Dropout probability (default 0)
Output (return) nn.Module Model with matching Linear layers replaced by LinearLayer_LoRA

save_zero_three_model

Direction Parameter Type Description
Input model_ema DeepSpeedEngine The DeepSpeed-wrapped model; its .module is what gets saved (an object without a .module attribute is saved as-is)
Input global_rank int Current process rank (only rank 0 saves)
Input save_dir str Root output directory
Input zero_stage int ZeRO stage (special handling for stage 3)
Input sub_folder str Subfolder name (e.g., "epoch-0")
Output (side effect) file Saves pytorch_model.bin to save_dir/sub_folder/ (written by rank 0 only; under ZeRO stage 3, parameters whose name contains "lora" are omitted)

Usage Example

Applying LoRA During Model Setup

# In training/main.py
# Wrap the language decoder's nn.Linear layers whose name contains
# args.lang_lora_module_name. Only the rank is passed, so lora_scaling and
# lora_droppout keep their defaults (1 and 0).
if args.lang_lora_dim > 0:
    model.lang_decoder = convert_linear_layer_to_lora(
        model.lang_decoder,
        args.lang_lora_module_name,   # "model.layers."
        args.lang_lora_dim)           # e.g., 16
    if args.only_optimize_lora:
        # Freeze everything except lora_left_weight / lora_right_weight.
        model.lang_decoder = only_optimize_lora_parameters(model.lang_decoder)

# Same treatment for the vision encoder, with its own name filter and rank.
if args.vis_lora_dim > 0:
    model.vis_encoder = convert_linear_layer_to_lora(
        model.vis_encoder,
        args.vis_lora_module_name,    # "encoder.layers."
        args.vis_lora_dim)
    if args.only_optimize_lora:
        # Freeze everything except lora_left_weight / lora_right_weight.
        model.vis_encoder = only_optimize_lora_parameters(model.vis_encoder)

Per-Epoch Fusion and Saving

# At end of each epoch in training/main.py
# Fold each LoRA update into its base weight so the saved weights include it.
model = fuse_lora(model)

# Save in Hugging Face format (rank 0 only)
if args.global_rank == 0:
    save_hf_format(model, tokenizer, args, f'epoch-{epoch}')

# Save ZeRO-3 checkpoint (all ranks participate in gathering)
# Call this on every rank: only rank 0 writes the file, and parameters whose
# name contains "lora" are left out of the ZeRO-3 checkpoint.
if args.zero_stage == 3:
    save_zero_three_model(model, args.global_rank, args.output_dir,
                          zero_stage=args.zero_stage,
                          sub_folder=f'epoch-{epoch}')

# Restore for continued training
# Subtracts the same update again, so the base weight and the LoRA factors
# are separate once more.
model = unfuse_lora(model)

Command-Line LoRA Configuration

deepspeed training/main.py \
    --lang_lora_dim 16 \
    --lang_lora_module_name "model.layers." \
    --vis_lora_dim 16 \
    --vis_lora_module_name "encoder.layers." \
    --only_optimize_lora \
    ...

Fusion Mathematics in Code

The fusion operation in fuse_lora_weight():

# self.lora_left_weight:  shape [r, out_features]
# self.lora_right_weight: shape [in_features, r]
# self.weight:            shape [out_features, in_features]

self.weight.data += self.lora_scaling * torch.matmul(
    self.lora_left_weight.t(),    # [out_features, r]
    self.lora_right_weight.t()    # [r, in_features]
)
# Result: [out_features, in_features] -- matches weight shape

The unfusion subtracts the same term, undoing the fusion (up to floating-point rounding):

self.weight.data -= self.lora_scaling * torch.matmul(
    self.lora_left_weight.t(),
    self.lora_right_weight.t()
)

A boolean flag self.fuse_lora prevents double-fusion or double-unfusion.

Dependencies

  • torch -- Core tensor operations
  • torch.nn -- Linear, Parameter, Dropout, Identity
  • torch.nn.functional -- F.linear for forward pass
  • deepspeed -- ZeRO-3 parameter gathering
  • deepspeed.compression.helper -- recursive_getattr, recursive_setattr for module replacement
  • deepspeed.runtime.zero.partition_parameters -- ZeroParamStatus for parameter availability checks

Related Pages

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment