Implementation: LoRA Layers (microsoft/LoRA)
| Knowledge Sources | Details |
|---|---|
| Domains | Model_Architecture, Parameter_Efficient_Fine_Tuning |
| Last Updated | 2026-02-10 05:00 GMT |
Overview
Concrete implementation of LoRA-augmented PyTorch layer classes that serve as drop-in replacements for standard layers.
Description
The loralib/layers.py module defines the LoRALayer base class and the concrete layer types Linear, Embedding, and MergedLinear, plus the ConvLoRA base class with Conv1d, Conv2d, and Conv3d subclasses. Each layer adds low-rank trainable matrices A and B alongside the frozen pretrained weights, enabling parameter-efficient fine-tuning.
Usage
Replace target layers in a pretrained model with their LoRA equivalents. LoRA is typically applied to the attention projection layers (Q, K, V, O) in transformer models. After replacement, call lora.mark_only_lora_as_trainable(model) to freeze all non-LoRA parameters.
Code Reference
Source Location
- Repository: microsoft/LoRA
- File: loralib/layers.py
- Lines: 12-311
Signatures
LoRALayer.__init__
class LoRALayer:
def __init__(self, r: int, lora_alpha: int, lora_dropout: float, merge_weights: bool):
"""Base class for all LoRA layers.
Args:
r: Rank of the low-rank decomposition
lora_alpha: Scaling factor (applied as lora_alpha / r)
lora_dropout: Dropout probability for LoRA path
merge_weights: If True, merge LoRA weights into base weights on eval()
"""
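The lora_alpha / r relationship means the rank and the scale can be tuned independently; a small arithmetic sketch:

```python
# The LoRA path adds (dropout(x) @ A^T @ B^T) * scaling to the base output,
# where scaling = lora_alpha / r. Doubling r while holding lora_alpha fixed
# halves the per-component contribution of the low-rank update.
r, lora_alpha = 8, 16
scaling = lora_alpha / r
print(scaling)  # 2.0
```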
Linear.__init__
class Linear(nn.Linear, LoRALayer):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs
):
"""LoRA-augmented linear layer.
Args:
in_features: Size of input features
out_features: Size of output features
r: LoRA rank (0 disables LoRA)
lora_alpha: Scaling factor
lora_dropout: Dropout on LoRA path
fan_in_fan_out: Set True if the weight is stored as (fan_in, fan_out), e.g. GPT-2's Conv1D
merge_weights: Merge LoRA weights on eval()
"""
Embedding.__init__
class Embedding(nn.Embedding, LoRALayer):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
r: int = 0,
lora_alpha: int = 1,
merge_weights: bool = True,
**kwargs
):
"""LoRA-augmented embedding layer.
Args:
num_embeddings: Size of the embedding dictionary
embedding_dim: Size of each embedding vector
r: LoRA rank (0 disables LoRA)
lora_alpha: Scaling factor
merge_weights: Merge LoRA weights on eval()
"""
MergedLinear.__init__
class MergedLinear(nn.Linear, LoRALayer):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
enable_lora: list = [False],
fan_in_fan_out: bool = False,
merge_weights: bool = True,
**kwargs
):
"""LoRA-augmented merged linear layer for combined projections (e.g., QKV).
Args:
in_features: Size of input features
out_features: Total size of all merged output features
r: LoRA rank (0 disables LoRA)
lora_alpha: Scaling factor
lora_dropout: Dropout on LoRA path
enable_lora: List of booleans indicating which merged outputs get LoRA
fan_in_fan_out: Set True if weight is stored as (fan_in, fan_out)
merge_weights: Merge LoRA weights on eval()
"""
Conv2d (via ConvLoRA)
class ConvLoRA(nn.Module, LoRALayer):
def __init__(self, conv_module, in_channels, out_channels, kernel_size,
r=0, lora_alpha=1, lora_dropout=0.0, merge_weights=True, **kwargs):
"""Base class for LoRA-augmented convolution layers."""
class Conv2d(ConvLoRA):
def __init__(self, *args, **kwargs):
"""LoRA-augmented 2D convolution layer.
Passes nn.Conv2d as the conv_module to ConvLoRA.
"""
Import
from loralib import Linear, Embedding, MergedLinear, Conv2d
# or
import loralib as lora
# then use lora.Linear, lora.Embedding, lora.MergedLinear, lora.Conv2d
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| in_features / num_embeddings / in_channels | int | Yes | Input dimension (same as corresponding nn module) |
| out_features / embedding_dim / out_channels | int | Yes | Output dimension (same as corresponding nn module) |
| r | int | No (default 0) | LoRA rank; 0 disables LoRA and the layer behaves as a standard PyTorch layer |
| lora_alpha | int | No (default 1) | Scaling constant; effective scale is lora_alpha / r |
| lora_dropout | float | No (default 0.0) | Dropout probability applied to input on the LoRA path |
| merge_weights | bool | No (default True) | Whether to merge LoRA weights into base weights during eval() |
| enable_lora | list[bool] | No (MergedLinear only) | Which merged output segments receive LoRA adaptation; out_features must be divisible by its length |
| fan_in_fan_out | bool | No (Linear/MergedLinear only) | True if the weight is stored transposed as (fan_in, fan_out), e.g. GPT-2's Conv1D |
Outputs
| Name | Type | Description |
|---|---|---|
| layer | nn.Module | Drop-in replacement layer with identical forward signature to the original PyTorch module |
Usage Examples
Replace a Standard Linear Layer
import loralib as lora
# Before: self.linear = nn.Linear(768, 768)
# After:
self.linear = lora.Linear(768, 768, r=8, lora_alpha=16)
Replace Combined QKV Attention Projection
import loralib as lora
# For GPT-2 style combined QKV projection where Q, K, V are concatenated
# Apply LoRA to Q and V but not K:
self.c_attn = lora.MergedLinear(
768, 768 * 3,
r=4,
lora_alpha=8,
enable_lora=[True, False, True],
fan_in_fan_out=True
)
Replace an Embedding Layer
import loralib as lora
# Before: self.embed = nn.Embedding(50257, 768)
# After:
self.embed = lora.Embedding(50257, 768, r=4, lora_alpha=8)
Replace a Conv2d Layer
import loralib as lora
# Before: self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
# After:
self.conv = lora.Conv2d(3, 64, kernel_size=3, padding=1, r=4, lora_alpha=8)
Complete Model Modification Example
import torch.nn as nn
import loralib as lora
class MyModel(nn.Module):
def __init__(self):
super().__init__()
# Replace target layers with LoRA versions
self.query = lora.Linear(768, 768, r=8, lora_alpha=16)
self.value = lora.Linear(768, 768, r=8, lora_alpha=16)
# Keep non-target layers as standard PyTorch
self.key = nn.Linear(768, 768)
self.ffn = nn.Linear(768, 3072)
def forward(self, x):
q = self.query(x) # LoRA-augmented
k = self.key(x) # Standard (frozen)
v = self.value(x) # LoRA-augmented
return q, k, v