Implementation:Microsoft LoRA Lora State Dict

From Leeroopedia


Knowledge Sources
Domains: Serialization, Parameter_Efficient_Fine_Tuning
Last Updated: 2026-02-10 05:00 GMT

Overview

Utility function that extracts only LoRA parameters (and optionally biases) from a model's state dict for compact checkpoint saving.

Description

The lora_state_dict function filters a model's full state dictionary to return only entries corresponding to LoRA parameters. This filtered dictionary can then be saved with torch.save to produce a compact checkpoint file containing only the task-specific adaptations.
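To give a sense of how compact such a checkpoint can be, the arithmetic below counts the entries for a single adapted linear layer. It is a rough sketch under stated assumptions: the usual LoRA factorization, where lora_A has shape (r, in_features) and lora_B has shape (out_features, r), 32-bit floats, and illustrative sizes (768 by 768, r = 8) that do not come from this page. The helper name lora_param_count is hypothetical.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Number of entries in lora_A (r x in) plus lora_B (out x r)."""
    return r * in_features + out_features * r


in_features, out_features, r = 768, 768, 8  # illustrative sizes (assumption)
bytes_per_param = 4  # float32 (assumption)

lora_params = lora_param_count(in_features, out_features, r)
full_params = in_features * out_features  # frozen weight matrix, bias ignored

print(lora_params)                    # 12288 entries
print(lora_params * bytes_per_param)  # 49152 bytes
print(full_params // lora_params)     # 48 (the full weight is 48x larger)
```

Real checkpoints contain one such pair per adapted layer, so the total scales with the number of LoRA layers rather than with the size of the base model.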

Usage

Call this function whenever you need to save a LoRA checkpoint. Use the same bias mode that was passed to mark_only_lora_as_trainable during model preparation.

Code Reference

Source Location

Signature

def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Extract LoRA parameters from model state dict.

    Args:
        model: The PyTorch model containing LoRA layers
        bias: Bias handling mode - 'none', 'all', or 'lora_only'

    Returns:
        Dictionary mapping parameter names to tensors, containing only
        LoRA parameters (and optionally biases)
    """

Import

from loralib import lora_state_dict
# or
import loralib as lora
# then use lora.lora_state_dict

I/O Contract

Inputs

Name | Type | Required | Description
model | nn.Module | Yes | PyTorch model containing LoRA-augmented layers
bias | str | No (default 'none') | Bias handling mode: 'none', 'all', or 'lora_only'

Outputs

Name | Type | Description
state_dict | Dict[str, torch.Tensor] | Filtered state dictionary containing only LoRA parameters (and optionally biases)

Bias Mode Filter Details

bias Value | Keys Included in Output Dict
'none' | Only keys containing "lora_" (e.g., "transformer.h.0.attn.c_attn.lora_A", "transformer.h.0.attn.c_attn.lora_B")
'all' | Keys containing "lora_" plus every key containing "bias", whether or not the module has LoRA parameters (e.g., "transformer.h.0.attn.c_attn.bias", "transformer.ln_f.bias")
'lora_only' | Keys containing "lora_" plus the bias key of each module that also has lora_ parameters (e.g., "transformer.h.0.attn.c_attn.bias" but not "transformer.ln_f.bias")

Implementation Details

The function works by calling model.state_dict() and filtering the resulting dictionary:

  1. Collect the full state dictionary with model.state_dict()
  2. Keep every entry whose key contains "lora_"
  3. If bias is 'all', also keep every entry whose key contains "bias"
  4. If bias is 'lora_only', also keep the bias entry of each module that has LoRA parameters, when that module has a bias
  5. Return the filtered dictionary; any other bias value raises NotImplementedError
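The steps above can be sketched in plain Python. This is a minimal illustration, not the library's source: filter_lora_entries is a hypothetical name, it accepts any mapping of names to values so it runs without PyTorch, and the way the bias key is derived for 'lora_only' (replacing everything from "lora_" onward with "bias") is an assumption about the key naming convention.

```python
from typing import Any, Dict, Mapping


def filter_lora_entries(state: Mapping[str, Any], bias: str = "none") -> Dict[str, Any]:
    """Keep LoRA entries (and optionally biases) from a state-dict-like mapping."""
    if bias == "none":
        return {k: v for k, v in state.items() if "lora_" in k}
    if bias == "all":
        return {k: v for k, v in state.items() if "lora_" in k or "bias" in k}
    if bias == "lora_only":
        kept: Dict[str, Any] = {}
        for k, v in state.items():
            if "lora_" in k:
                kept[k] = v
                # Assumed convention: "<module>.lora_A" pairs with "<module>.bias".
                bias_key = k.split("lora_")[0] + "bias"
                if bias_key in state:
                    kept[bias_key] = state[bias_key]
        return kept
    raise NotImplementedError(f"unknown bias mode: {bias!r}")


# Toy stand-in for model.state_dict(); the values would be tensors in practice.
toy_state = {
    "attn.c_attn.weight": 0,
    "attn.c_attn.bias": 1,
    "attn.c_attn.lora_A": 2,
    "attn.c_attn.lora_B": 3,
    "ln_f.weight": 4,
    "ln_f.bias": 5,
}

for mode in ("none", "all", "lora_only"):
    print(mode, sorted(filter_lora_entries(toy_state, mode)))
```

On the toy keys, 'all' picks up ln_f.bias while 'lora_only' does not, because ln_f has no lora_ entries.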

Usage Examples

Save LoRA Checkpoint

import os

import torch
import loralib as lora

# model is assumed to be an already-trained nn.Module with LoRA layers.
# After training, save only the LoRA parameters
lora_dict = lora.lora_state_dict(model, bias='none')
torch.save(lora_dict, 'lora_checkpoint.pt')

# Check the size of the saved file
size_mb = os.path.getsize('lora_checkpoint.pt') / (1024 * 1024)
print(f"LoRA checkpoint size: {size_mb:.2f} MB")

Save with Biases

import torch
import loralib as lora

# Must match the bias mode used in mark_only_lora_as_trainable
lora_dict = lora.lora_state_dict(model, bias='all')
torch.save(lora_dict, 'lora_checkpoint_with_bias.pt')

Load LoRA Checkpoint

import torch
import loralib as lora

# 1. Load the base pretrained model
#    (load_pretrained_model is a placeholder for your own loading code)
model = load_pretrained_model()

# 2. Add LoRA layers; the configuration must match the one used during training
#    (add_lora_layers is a placeholder, not a loralib function)
add_lora_layers(model, r=8, lora_alpha=16)

# 3. Load the LoRA checkpoint (strict=False because the dict holds only LoRA params)
lora_dict = torch.load('lora_checkpoint.pt')
model.load_state_dict(lora_dict, strict=False)
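
Because strict=False suppresses the error for every key mismatch, it may be worth inspecting the value that load_state_dict returns, which exposes missing_keys and unexpected_keys. The sketch below shows one way to do that. check_lora_load is a hypothetical helper, and it takes plain lists of key names so it can be run without a model.

```python
from typing import Iterable, List


def check_lora_load(missing_keys: Iterable[str], unexpected_keys: Iterable[str]) -> List[str]:
    """Return human-readable problems found after a strict=False load."""
    problems = []
    # Checkpoint keys the model does not have: likely an architecture
    # mismatch (for example, LoRA layers were not added before loading).
    for key in unexpected_keys:
        problems.append(f"checkpoint key not found in model: {key}")
    # Missing base weights are expected, but a missing LoRA key means the
    # adapter was only partially restored.
    for key in missing_keys:
        if "lora_" in key:
            problems.append(f"LoRA parameter not loaded: {key}")
    return problems


# Intended use with a real model (not executed here):
#   result = model.load_state_dict(lora_dict, strict=False)
#   problems = check_lora_load(result.missing_keys, result.unexpected_keys)

# Toy key lists standing in for the result of load_state_dict.
ok = check_lora_load(["attn.c_attn.weight", "ln_f.bias"], [])
bad = check_lora_load(["attn.c_attn.lora_B"], ["attn.c_proj.lora_A"])
print(ok)
print(bad)
```

An empty list means every checkpoint key matched a model parameter and no LoRA parameter was left at its initial value.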

Complete Save/Load Workflow

import torch
import loralib as lora

# ===== Saving (after training) =====
lora_dict = lora.lora_state_dict(model, bias='none')
torch.save(lora_dict, 'task_adapter.pt')

# Inspect what was saved
print(f"Saved {len(lora_dict)} parameter tensors:")
for name, param in lora_dict.items():
    print(f"  {name}: {param.shape}")

# ===== Loading (for inference or continued training) =====
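# Start from the base model with LoRA layers already added
# (load_base_model_with_lora_layers is a placeholder for your own code)
model = load_base_model_with_lora_layers()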

# Overlay LoRA weights
lora_dict = torch.load('task_adapter.pt')
model.load_state_dict(lora_dict, strict=False)

# Ready for inference
model.eval()

Related Pages

Implements Principle
