Implementation:Microsoft LoRA Lora State Dict
| Knowledge Sources | |
|---|---|
| Domains | Serialization, Parameter_Efficient_Fine_Tuning |
| Last Updated | 2026-02-10 05:00 GMT |
Overview
Utility function that extracts only LoRA parameters (and optionally biases) from a model's state dict for compact checkpoint saving.
Description
The lora_state_dict function filters a model's full state dictionary, keeping only the entries whose keys identify LoRA parameters and, depending on the bias mode, bias terms. The filtered dictionary can then be saved with torch.save to produce a compact checkpoint that contains only the task-specific adaptations, not the frozen pretrained weights.
Usage
Call this function whenever you need to save a LoRA checkpoint. Pass the same bias mode that was given to mark_only_lora_as_trainable during model preparation; otherwise biases that were trained may be left out of the checkpoint.
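To see why the two modes should agree, here is a minimal sketch in plain Python. The parameter names are made up, and the two selection rules (substring matches on "lora_" and "bias") are assumptions that mirror the behavior described on this page rather than code taken from loralib.

```python
# Hypothetical parameter names: one adapted layer plus a LayerNorm.
param_names = [
    "h.0.attn.c_attn.weight",
    "h.0.attn.c_attn.bias",
    "h.0.attn.c_attn.lora_A",
    "h.0.attn.c_attn.lora_B",
    "ln_f.weight",
    "ln_f.bias",
]

# Assumed: training with bias='all' updates LoRA matrices and every bias.
trained = {n for n in param_names if "lora_" in n or "bias" in n}

# Assumed: saving with bias='none' keeps only the LoRA matrices.
saved = {n for n in param_names if "lora_" in n}

# Parameters that were updated during training but never written to disk.
lost = sorted(trained - saved)
print(lost)
```

In this toy case both bias tensors were trained but neither is saved, so a model restored from the checkpoint would fall back to the pretrained biases.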
Code Reference
Source Location
- Repository: microsoft/LoRA
- File: loralib/utils.py
- Lines: 33-49
Signature
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Extract LoRA parameters from model state dict.

    Args:
        model: The PyTorch model containing LoRA layers
        bias: Bias handling mode - 'none', 'all', or 'lora_only'

    Returns:
        Dictionary mapping parameter names to tensors, containing only
        LoRA parameters (and optionally biases)
    """
Import
from loralib import lora_state_dict
# or
import loralib as lora
# then use lora.lora_state_dict
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | PyTorch model containing LoRA-augmented layers |
| bias | str | No (default 'none') | Bias handling mode: 'none', 'all', or 'lora_only' |
Outputs
| Name | Type | Description |
|---|---|---|
| state_dict | Dict[str, torch.Tensor] | Filtered state dictionary containing only LoRA parameters (and optionally biases) |
Bias Mode Filter Details
| bias Value | Keys Included in Output Dict |
|---|---|
| none | Only keys containing "lora_" (e.g., "transformer.h.0.attn.c_attn.lora_A", "transformer.h.0.attn.c_attn.lora_B") |
| all | Keys containing "lora_" + all keys containing "bias" (e.g., "transformer.h.0.attn.c_attn.bias", "transformer.ln_f.bias") |
| lora_only | Keys containing "lora_" + bias keys from modules that also have lora_ parameters |
Implementation Details
The function works by calling model.state_dict() and filtering the resulting dictionary:
- Collect the full state dictionary from the model
- For each key-value pair, check if the key contains "lora_"
- If bias mode is 'all', also include keys containing "bias"
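- If bias mode is 'lora_only', also include the bias of each module that has LoRA parameters: the bias key is the LoRA key's prefix (the part before "lora_") followed by "bias", and it is included only if that key exists in the state dictionary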
- Return the filtered dictionary
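The steps above can be sketched in plain Python as follows. This is an illustrative reimplementation that works on any dictionary with string keys, not the loralib source; the name filter_lora_entries, the toy state dictionary, and the choice of ValueError for unknown modes are all assumptions of this sketch.

```python
from typing import Any, Dict


def filter_lora_entries(state: Dict[str, Any], bias: str = "none") -> Dict[str, Any]:
    """Keep LoRA entries (and optionally biases) from a state-dict-like mapping."""
    if bias == "none":
        return {k: v for k, v in state.items() if "lora_" in k}
    if bias == "all":
        return {k: v for k, v in state.items() if "lora_" in k or "bias" in k}
    if bias == "lora_only":
        kept: Dict[str, Any] = {}
        for k, v in state.items():
            if "lora_" in k:
                kept[k] = v
                # The module prefix before 'lora_' plus 'bias' names the sibling bias.
                bias_key = k.split("lora_")[0] + "bias"
                if bias_key in state:
                    kept[bias_key] = state[bias_key]
        return kept
    # Assumption of this sketch: reject modes other than the three documented ones.
    raise ValueError(f"unknown bias mode: {bias!r}")


# Toy stand-in for model.state_dict(); integers replace tensors.
toy_state = {
    "fc.weight": 0, "fc.bias": 1, "fc.lora_A": 2, "fc.lora_B": 3,
    "norm.weight": 4, "norm.bias": 5,
}

for mode in ("none", "all", "lora_only"):
    print(mode, sorted(filter_lora_entries(toy_state, mode)))
```

Note that in 'lora_only' mode the toy "norm.bias" entry is skipped because the "norm" module has no LoRA parameters, whereas 'all' picks it up.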
Usage Examples
Save LoRA Checkpoint
import os

import torch
import loralib as lora

# After training, save only LoRA parameters
lora_dict = lora.lora_state_dict(model, bias='none')
torch.save(lora_dict, 'lora_checkpoint.pt')

# Check saved size
size_mb = os.path.getsize('lora_checkpoint.pt') / (1024 * 1024)
print(f"LoRA checkpoint size: {size_mb:.2f} MB")
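As a rough cross-check on the reported file size, the expected parameter count can be estimated by hand. The sketch below assumes LoRA on plain linear layers, where lora_A has shape (r, in_features) and lora_B has shape (out_features, r), and float32 storage; the 12-layer, 768-wide configuration is purely hypothetical, and the estimate ignores file-format overhead and any saved biases.

```python
def lora_linear_param_count(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one rank-r adapter on a linear layer."""
    # lora_A: (r, in_features); lora_B: (out_features, r)
    return r * in_features + out_features * r


def estimated_size_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Raw tensor storage in MiB, assuming float32 by default."""
    return num_params * bytes_per_param / (1024 * 1024)


# Hypothetical setup: 12 layers, one 768x768 projection adapted per layer, rank 8.
per_layer = lora_linear_param_count(768, 768, r=8)
total = 12 * per_layer
print(per_layer, total, f"{estimated_size_mb(total):.2f} MB")
```

If the size on disk is far above such an estimate, the checkpoint may contain more than the LoRA tensors, for example because the full state dict was saved by mistake.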
Save with Biases
import torch
import loralib as lora
# Must match the bias mode used in mark_only_lora_as_trainable
lora_dict = lora.lora_state_dict(model, bias='all')
torch.save(lora_dict, 'lora_checkpoint_with_bias.pt')
Load LoRA Checkpoint
import torch
import loralib as lora
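# 1. Load base pretrained model
#    (load_pretrained_model is a placeholder for your own loading code)
model = load_pretrained_model()
# 2. Add LoRA layers (must match the architecture used during training)
#    (add_lora_layers is a placeholder for your own code, not a loralib function)
add_lora_layers(model, r=8, lora_alpha=16)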
# 3. Load LoRA checkpoint (strict=False because only LoRA params are in the dict)
lora_dict = torch.load('lora_checkpoint.pt')
model.load_state_dict(lora_dict, strict=False)
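Because strict=False tolerates key mismatches, a checkpoint whose names do not line up with the model can load without an error while leaving the LoRA weights untouched. The sketch below shows one way to check the key sets beforehand; check_adapter_keys is a hypothetical helper, and the key names are made up. With a real model you would pass model.state_dict().keys() and lora_dict.keys().

```python
from typing import Iterable, List, Tuple


def check_adapter_keys(
    model_keys: Iterable[str], checkpoint_keys: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Compare key sets before a strict=False load (hypothetical helper)."""
    model_set, ckpt_set = set(model_keys), set(checkpoint_keys)
    # Checkpoint entries with no counterpart in the model would be ignored.
    unexpected = sorted(ckpt_set - model_set)
    # LoRA entries the model expects but the checkpoint does not provide.
    missing_lora = sorted(k for k in model_set - ckpt_set if "lora_" in k)
    return unexpected, missing_lora


# Made-up keys: the checkpoint was saved from a model wrapped under "module.".
model_keys = ["fc.weight", "fc.lora_A", "fc.lora_B"]
checkpoint_keys = ["module.fc.lora_A", "module.fc.lora_B"]

unexpected, missing_lora = check_adapter_keys(model_keys, checkpoint_keys)
print("unexpected:", unexpected)
print("missing LoRA:", missing_lora)
```

Both lists should be empty for a clean load. PyTorch's load_state_dict also returns an object with missing_keys and unexpected_keys fields, which can be inspected after the call in the same spirit.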
Complete Save/Load Workflow
import torch
import loralib as lora
# ===== Saving (after training) =====
lora_dict = lora.lora_state_dict(model, bias='none')
torch.save(lora_dict, 'task_adapter.pt')
# Inspect what was saved
print(f"Saved {len(lora_dict)} parameter tensors:")
for name, param in lora_dict.items():
    print(f"  {name}: {param.shape}")
# ===== Loading (for inference or continued training) =====
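# Start from the base model with LoRA layers already in place
# (load_base_model_with_lora_layers is a placeholder for your own code)
model = load_base_model_with_lora_layers()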
# Overlay LoRA weights
lora_dict = torch.load('task_adapter.pt')
model.load_state_dict(lora_dict, strict=False)
# Ready for inference
model.eval()