Implementation: Alibaba ROLL LoraParallelLinear
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, LoRA, Distributed_Computing |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
LoRA adapter layers compatible with Megatron-Core tensor-parallel and expert-parallel linear layers, enabling parameter-efficient fine-tuning in distributed training settings.
Description
lora_layer.py provides a family of LoRA (Low-Rank Adaptation) layer classes that integrate with Megatron-Core's Transformer Engine (TE) linear layers. The base class LoraParallelLinear inherits from both MegatronModule and PEFT's LoraLayer, bridging the gap between the PEFT LoRA ecosystem and Megatron-Core's distributed infrastructure.
Key design decisions:
- Sequence parallel awareness: The forward pass handles gather/scatter operations for sequence parallelism, gathering inputs before LoRA computation for column-parallel layers and scattering results for row-parallel layers.
- Grouped GEMM support: Supports TE grouped linear layers used in Mixture-of-Experts (MoE) architectures, where multiple expert weights are batched into grouped GEMMs.
- Dtype management: Casts inputs to the LoRA weight dtype during forward, then restores the original dtype, preventing mixed-precision issues.
- Distributed checkpointing: Overrides sharded_state_dict() to produce checkpoint-compatible sharded tensors, including SwiGLU factory transformations for MLP layers.
- Router LoRA: Supports applying LoRA to MoE TopKRouter layers by patching the router's gating function via a context manager.
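The dtype-management and zero-initialized-B behaviors described above can be sketched with a minimal, non-parallel wrapper. All names here are illustrative, not the ROLL implementation:

```python
import torch
import torch.nn as nn

class TinyLoraWrapper(nn.Module):
    """Illustrative sketch of the dtype-cast pattern: cast the input to the
    LoRA weight dtype for the low-rank computation, then restore the
    original dtype before adding the delta to the base output."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)  # B starts at zero, so the LoRA delta is 0
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base(x)
        orig_dtype = x.dtype
        x = x.to(self.lora_a.weight.dtype)            # cast to LoRA dtype
        delta = self.lora_b(self.lora_a(x)) * self.scaling
        return result + delta.to(orig_dtype)          # restore original dtype
```

Because B is zero-initialized, the wrapped layer initially reproduces the base layer's output exactly, so fine-tuning starts from the pretrained behavior.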
The module provides a base class, three concrete LoRA layer subclasses, and two utility functions:
- LoraRouterParallelLinear for MoE router layers (non-parallel TELinear)
- LoraRowParallelLinear for row-parallel linear layers (splits input across TP ranks)
- LoraColumnParallelLinear for column-parallel linear layers (splits output across TP ranks)
- dispatch_megatron() returns the correct LoRA wrapper class for a given base layer
- apply_megatron_lora() patches PEFT's dispatch mechanism and TE layer representations
Usage
Use this module when applying LoRA fine-tuning to models loaded with Megatron-Core tensor/expert parallelism. Typically invoked indirectly through apply_megatron_lora() which patches PEFT to use these parallel-aware LoRA layers instead of standard ones.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/adapters/lora_layer.py
- Lines: 1-550
Key Classes
LoraParallelLinear
```python
class LoraParallelLinear(MegatronModule, LoraLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    )
```
Base class for all parallel LoRA layers. Manages LoRA A/B weight creation, initialization (Kaiming uniform for A, zeros for B), scaling computation, merge/unmerge operations, and distributed checkpointing. Subclasses implement _create_lora_layers() to define the specific parallel topology of A and B matrices.
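The scaling computation follows the standard LoRA convention (alpha / r, or alpha / sqrt(r) for rsLoRA). A hedged sketch; the function name is hypothetical:

```python
import math

def lora_scaling(lora_alpha: int, r: int, use_rslora: bool = False) -> float:
    """Standard LoRA scaling is alpha / r; the rsLoRA variant uses
    alpha / sqrt(r) to keep the update magnitude stable across ranks."""
    if r <= 0:
        raise ValueError("LoRA rank r must be positive")
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
```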
Key methods:
- forward(x) (lines 207-260): Computes base layer output, then adds scaled LoRA delta. Handles sequence parallel gather/scatter and grouped GEMM variants.
- merge(safe_merge, adapter_names) (lines 262-316): Merges LoRA weights into base weights, with optional NaN checking via safe_merge.
- sharded_state_dict(prefix, sharded_offsets, metadata) (lines 318-359): Produces distributed checkpoint-compatible state dict with SwiGLU sharding for MLP fc1 layers.
- get_delta_weights(adapter) (lines 361-382): Computes B @ A * scaling for each expert (or single weight pair).
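The delta-weight and safe-merge logic can be sketched for a single (non-grouped) weight pair; helper names here are illustrative, not the ROLL API:

```python
import torch

def get_delta_weight(lora_a: torch.Tensor, lora_b: torch.Tensor,
                     scaling: float) -> torch.Tensor:
    """Compute the dense LoRA update B @ A * scaling.
    lora_a: (r, in_features), lora_b: (out_features, r)."""
    return (lora_b @ lora_a) * scaling

def safe_merge(base_weight: torch.Tensor, lora_a: torch.Tensor,
               lora_b: torch.Tensor, scaling: float) -> torch.Tensor:
    """Merge the LoRA delta into the base weight, rejecting non-finite
    results (the safe_merge=True behavior described above)."""
    merged = base_weight + get_delta_weight(lora_a, lora_b, scaling)
    if not torch.isfinite(merged).all():
        raise ValueError("NaNs/Infs detected in merged weights")
    return merged
```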
LoraRouterParallelLinear
```python
class LoraRouterParallelLinear(LoraParallelLinear)  # lines 385-406
```
LoRA layer for MoE TopKRouter modules. Uses non-parallel TELinear for both A and B matrices since router weights are not tensor-parallelized.
LoraRowParallelLinear
```python
class LoraRowParallelLinear(LoraParallelLinear)  # lines 409-450
```
LoRA for row-parallel layers. The A matrix uses TERowParallelLinear (or TERowParallelGroupedLinear for MoE) to accept already-partitioned input, while B uses non-parallel TELinear (or TEGroupedLinear).
LoraColumnParallelLinear
```python
class LoraColumnParallelLinear(LoraParallelLinear)  # lines 453-494
```
LoRA for column-parallel layers. The A matrix uses non-parallel TELinear (or TEGroupedLinear for MoE), while B uses TEColumnParallelLinear (or TEColumnParallelGroupedLinear) to partition output.
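The row-parallel topology relies on a basic identity: splitting A along its input dimension and summing the partial products reconstructs the full low-rank projection. A toy single-process demonstration, with plain tensor slices standing in for TP ranks (an all-reduce would perform the sum in a real setup):

```python
import torch

torch.manual_seed(0)
in_f, out_f, r, tp = 8, 6, 4, 2

x = torch.randn(3, in_f)
A = torch.randn(r, in_f)     # LoRA A: (r, in_features)
B = torch.randn(out_f, r)    # LoRA B: (out_features, r)

# Reference: full low-rank projection without any partitioning.
full = x @ A.T @ B.T

# Row-parallel: each "rank" holds a slice of the input features and the
# matching column slice of A; partial results are summed across ranks.
chunk = in_f // tp
partial = sum(
    (x[:, i * chunk:(i + 1) * chunk] @ A[:, i * chunk:(i + 1) * chunk].T) @ B.T
    for i in range(tp)
)
assert torch.allclose(full, partial, atol=1e-5)
```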
Key Functions
dispatch_megatron
```python
def dispatch_megatron(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config,
    **kwargs: Any,
) -> Optional[torch.nn.Module]  # lines 497-522
```
Factory function that inspects the base layer type and returns the appropriate LoRA wrapper. Checks for TopKRouter, TERowParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear, TELinear, and TEGroupedLinear.
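The factory pattern can be sketched with stand-in classes. None of these names come from ROLL; they only illustrate the isinstance-based dispatch described above:

```python
from typing import Optional
import torch.nn as nn

class RowLinear(nn.Linear):   # stand-in for a row-parallel TE layer type
    pass

class ColLinear(nn.Linear):   # stand-in for a column-parallel TE layer type
    pass

class RowLora:                # stand-in for a row-parallel LoRA wrapper
    def __init__(self, base_layer: nn.Module, adapter_name: str, **kwargs):
        self.base_layer = base_layer
        self.adapter_name = adapter_name

class ColLora(RowLora):       # stand-in for a column-parallel LoRA wrapper
    pass

# Map base-layer types to LoRA wrapper classes, most specific first.
DISPATCH_TABLE = [(RowLinear, RowLora), (ColLinear, ColLora)]

def dispatch_sketch(target: nn.Module, adapter_name: str,
                    **kwargs) -> Optional[object]:
    # PEFT may hand over an already-wrapped module; unwrap if possible.
    base = target.get_base_layer() if hasattr(target, "get_base_layer") else target
    for layer_cls, lora_cls in DISPATCH_TABLE:
        if isinstance(base, layer_cls):
            return lora_cls(base, adapter_name, **kwargs)
    return None  # fall through to PEFT's other dispatchers
```

Returning None for unrecognized layers lets PEFT's remaining dispatchers handle standard (non-parallel) modules unchanged.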
apply_megatron_lora
```python
def apply_megatron_lora()  # lines 547-550
```
Patches PEFT's model.dispatch_megatron with the local dispatch_megatron function, patches TELinear.__repr__ for readable logging, and patches TEGroupedLinear.sharded_state_dict for compatibility.
Import
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear, TERowParallelLinear, TELinear, TEGroupedLinear,
    TEColumnParallelGroupedLinear, TERowParallelGroupedLinear,
    TELayerNormColumnParallelLinear,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.router import TopKRouter
from peft.tuners.lora.layer import LoraLayer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_layer | torch.nn.Module | Yes | The Megatron-Core TE linear layer to wrap with LoRA |
| adapter_name | str | Yes | Name identifier for the LoRA adapter (e.g., default) |
| r | int | No | LoRA rank (default: 0 is a placeholder; must be set to a positive value) |
| lora_alpha | int | No | LoRA scaling factor alpha (default: 1) |
| lora_dropout | float | No | Dropout probability for LoRA input (default: 0.0) |
| x | torch.Tensor | Yes | Input tensor to the forward pass |
Outputs
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Output tensor: base_layer(x) + LoRA_B(LoRA_A(dropout(x))) * scaling |
| bias | torch.Tensor | Bias tensor from the base layer (may be None) |
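TE-style parallel linears return an (output, bias) tuple and leave the bias addition to the caller, which is why the contract lists a separate bias output. A toy stand-in illustrating the convention (not a real TE class):

```python
import torch
import torch.nn as nn

class TupleLinear(nn.Module):
    """Toy stand-in for TE parallel linears, which return (output, bias)
    instead of adding the bias internally."""

    def __init__(self, in_f: int, out_f: int):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_f, in_f))
        self.bias = nn.Parameter(torch.zeros(out_f))

    def forward(self, x: torch.Tensor):
        return x @ self.weight.T, self.bias

layer = TupleLinear(4, 3)
x = torch.randn(2, 4)
out, bias = layer(x)
result = out if bias is None else out + bias  # caller applies the bias
```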
Usage Examples
```python
# Typical usage is indirect via apply_megatron_lora()
from mcore_adapter.adapters import apply_megatron_lora
from peft import LoraConfig, get_peft_model

# Patch PEFT to use Megatron-compatible LoRA layers
apply_megatron_lora()

# Then use PEFT as normal - it will dispatch to LoraParallelLinear variants
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["linear_fc1", "linear_fc2"])
model = get_peft_model(model, lora_config)

# Direct usage of dispatch_megatron
from mcore_adapter.adapters.lora_layer import dispatch_megatron

new_module = dispatch_megatron(
    target=some_te_linear_layer,
    adapter_name="default",
    lora_config=lora_config,
    r=16,
    lora_alpha=32,
)
```