Implementation:Alibaba ROLL LoraParallelLinear

From Leeroopedia


Knowledge Sources
Domains Model_Architecture, LoRA, Distributed_Computing
Last Updated 2026-02-07 20:00 GMT

Overview

LoRA adapter layers compatible with Megatron-Core tensor-parallel and expert-parallel linear layers, enabling parameter-efficient fine-tuning in distributed training settings.

Description

lora_layer.py provides a family of LoRA (Low-Rank Adaptation) layer classes that integrate with Megatron-Core's Transformer Engine (TE) linear layers. The base class LoraParallelLinear inherits from both MegatronModule and PEFT's LoraLayer, bridging the gap between the PEFT LoRA ecosystem and Megatron-Core's distributed infrastructure.

Key design decisions:

  • Sequence parallel awareness: The forward pass handles gather/scatter operations for sequence parallelism, gathering inputs before LoRA computation for column-parallel layers and scattering results for row-parallel layers (this and the dtype handling are sketched after this list).
  • Grouped GEMM support: Supports TE grouped linear layers used in Mixture-of-Experts (MoE) architectures, where multiple expert weights are batched into grouped GEMMs.
  • Dtype management: Casts inputs to the LoRA weight dtype during forward, then restores the original dtype, preventing mixed-precision issues.
  • Distributed checkpointing: Overrides sharded_state_dict() to produce checkpoint-compatible sharded tensors, including SwiGLU factory transformations for MLP layers.
  • Router LoRA: Supports applying LoRA to MoE TopKRouter layers by patching the router's gating function via a context manager.
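
Two of these decisions, the sequence-parallel gather/scatter and the dtype round-trip, are sketched below in a framework-free form. This is a minimal illustration, not the actual LoraParallelLinear.forward(); gather_sp and scatter_sp are identity stand-ins for Megatron-Core's sequence-parallel collectives, and plain nn.Linear modules stand in for the TE layers.

import torch
import torch.nn as nn

def gather_sp(x):   # stand-in for the sequence-parallel all-gather
    return x

def scatter_sp(x):  # stand-in for the sequence-parallel scatter
    return x

def lora_forward_sketch(x, base_layer, lora_A, lora_B, dropout, scaling,
                        column_parallel, sequence_parallel):
    result = base_layer(x)                      # base path runs untouched

    lora_in = x
    if sequence_parallel and column_parallel:
        lora_in = gather_sp(lora_in)            # column-parallel: gather the full sequence first

    orig_dtype = lora_in.dtype
    lora_in = lora_in.to(lora_A.weight.dtype)   # cast input to the LoRA weight dtype
    delta = lora_B(lora_A(dropout(lora_in))) * scaling
    delta = delta.to(orig_dtype)                # restore the original dtype

    if sequence_parallel and not column_parallel:
        delta = scatter_sp(delta)               # row-parallel: scatter the delta back

    return result + delta

# Toy check with non-parallel stand-ins:
base = nn.Linear(16, 16)
A, B = nn.Linear(16, 4, bias=False), nn.Linear(4, 16, bias=False)
out = lora_forward_sketch(torch.randn(2, 16), base, A, B, nn.Dropout(0.0), 2.0,
                          column_parallel=True, sequence_parallel=False)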

In addition to the LoraParallelLinear base class, the module provides three concrete LoRA layer classes and two utility functions:

  • LoraRouterParallelLinear for MoE router layers (non-parallel TELinear)
  • LoraRowParallelLinear for row-parallel linear layers (splits input across TP ranks)
  • LoraColumnParallelLinear for column-parallel linear layers (splits output across TP ranks)
  • dispatch_megatron() returns the correct LoRA wrapper class for a given base layer
  • apply_megatron_lora() patches PEFT's dispatch mechanism and TE layer representations

Usage

Use this module when applying LoRA fine-tuning to models loaded with Megatron-Core tensor/expert parallelism. It is typically invoked indirectly through apply_megatron_lora(), which patches PEFT to use these parallel-aware LoRA layers instead of the standard ones.

Code Reference

Source Location

Key Classes

LoraParallelLinear

class LoraParallelLinear(MegatronModule, LoraLayer):
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    )

Base class for all parallel LoRA layers. Manages LoRA A/B weight creation, initialization (Kaiming uniform for A, zeros for B), scaling computation, merge/unmerge operations, and distributed checkpointing. Subclasses implement _create_lora_layers() to define the specific parallel topology of A and B matrices.
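
The initialization and scaling rules just described match the standard PEFT conventions. A minimal, non-parallel sketch of that logic is shown below; it uses plain nn.Linear in place of the TE layers and is not the class's own _create_lora_layers().

import math
import torch.nn as nn

def init_lora_pair(in_features, out_features, r, lora_alpha, use_rslora=False):
    lora_A = nn.Linear(in_features, r, bias=False)
    lora_B = nn.Linear(r, out_features, bias=False)

    # Kaiming-uniform for A so the down-projection starts well conditioned...
    nn.init.kaiming_uniform_(lora_A.weight, a=math.sqrt(5))
    # ...and zeros for B so the adapter contributes exactly nothing at step 0.
    nn.init.zeros_(lora_B.weight)

    # Classic LoRA scaling is alpha / r; rank-stabilized LoRA uses alpha / sqrt(r).
    scaling = lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r
    return lora_A, lora_B, scaling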

Key methods:

  • forward(x) (lines 207-260): Computes base layer output, then adds scaled LoRA delta. Handles sequence parallel gather/scatter and grouped GEMM variants.
  • merge(safe_merge, adapter_names) (lines 262-316): Merges LoRA weights into base weights, with optional NaN checking via safe_merge.
  • sharded_state_dict(prefix, sharded_offsets, metadata) (lines 318-359): Produces distributed checkpoint-compatible state dict with SwiGLU sharding for MLP fc1 layers.
  • get_delta_weights(adapter) (lines 361-382): Computes B @ A * scaling for each expert (or single weight pair); sketched below.
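
The delta-weight and merge arithmetic referenced above reduces to a few lines. The sketch below uses plain nn.Linear stand-ins and a single weight pair; the real methods also iterate over grouped per-expert weights.

import torch
import torch.nn as nn

def get_delta_weight(lora_A: nn.Linear, lora_B: nn.Linear, scaling: float) -> torch.Tensor:
    # delta_W = B @ A * scaling, shaped like the base weight (out_features, in_features).
    return (lora_B.weight @ lora_A.weight) * scaling

def merge_into_base(base_weight: torch.Tensor, delta: torch.Tensor,
                    safe_merge: bool = False) -> torch.Tensor:
    merged = base_weight + delta.to(base_weight.dtype)
    if safe_merge and not torch.isfinite(merged).all():
        # safe_merge-style check: refuse to write NaN/Inf into the base weights.
        raise ValueError("merged weights contain NaN/Inf; adapter appears to be broken")
    return merged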

LoraRouterParallelLinear

class LoraRouterParallelLinear(LoraParallelLinear)  # lines 385-406

LoRA layer for MoE TopKRouter modules. Uses non-parallel TELinear for both A and B matrices since router weights are not tensor-parallelized.

LoraRowParallelLinear

class LoraRowParallelLinear(LoraParallelLinear)  # lines 409-450

LoRA for row-parallel layers. The A matrix uses TERowParallelLinear (or TERowParallelGroupedLinear for MoE) to accept already-partitioned input, while B uses non-parallel TELinear (or TEGroupedLinear).

LoraColumnParallelLinear

class LoraColumnParallelLinear(LoraParallelLinear)  # lines 453-494

LoRA for column-parallel layers. The A matrix uses non-parallel TELinear (or TEGroupedLinear for MoE), while B uses TEColumnParallelLinear (or TEColumnParallelGroupedLinear) to partition output.
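
The A/B placement rule that LoraRowParallelLinear and LoraColumnParallelLinear encode can be summarized as a small lookup. In the sketch below the Megatron-Core/TE class names are returned as strings so it runs without a distributed setup; it illustrates the rule, not the classes' _create_lora_layers() implementations.

def lora_ab_topology(base_kind: str, grouped: bool = False) -> dict:
    if base_kind == "row_parallel":
        # A consumes the already input-partitioned activations; B stays replicated.
        return {"lora_A": "TERowParallelGroupedLinear" if grouped else "TERowParallelLinear",
                "lora_B": "TEGroupedLinear" if grouped else "TELinear"}
    if base_kind == "column_parallel":
        # A stays replicated; B produces this TP rank's output partition.
        return {"lora_A": "TEGroupedLinear" if grouped else "TELinear",
                "lora_B": "TEColumnParallelGroupedLinear" if grouped else "TEColumnParallelLinear"}
    # Router layers: neither side is tensor-parallel.
    return {"lora_A": "TELinear", "lora_B": "TELinear"}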

Key Functions

dispatch_megatron

def dispatch_megatron(
    target: torch.nn.Module,
    adapter_name: str,
    lora_config,
    **kwargs: Any,
) -> Optional[torch.nn.Module]  # lines 497-522

Factory function that inspects the base layer type and returns the appropriate LoRA wrapper. Checks for TopKRouter, TERowParallelLinear, TEColumnParallelLinear, TELayerNormColumnParallelLinear, TELinear, and TEGroupedLinear.
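
A simplified sketch of that dispatch shape is shown below. It omits the plain TELinear/TEGroupedLinear cases and the extra keyword arguments the real function forwards, and the mcore_adapter import path is an assumption taken from the usage example later on this page.

from typing import Any, Optional
import torch
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelGroupedLinear, TEColumnParallelLinear,
    TELayerNormColumnParallelLinear, TERowParallelGroupedLinear, TERowParallelLinear,
)
from megatron.core.transformer.moe.router import TopKRouter
from mcore_adapter.adapters.lora_layer import (  # assumed path
    LoraColumnParallelLinear, LoraRouterParallelLinear, LoraRowParallelLinear,
)

def dispatch_megatron_sketch(target: torch.nn.Module, adapter_name: str,
                             lora_config, **kwargs: Any) -> Optional[torch.nn.Module]:
    # PEFT may hand over an already-wrapped module; unwrap to the base layer if so.
    base = target.get_base_layer() if hasattr(target, "get_base_layer") else target
    if isinstance(base, TopKRouter):
        return LoraRouterParallelLinear(base, adapter_name, **kwargs)
    if isinstance(base, (TERowParallelLinear, TERowParallelGroupedLinear)):
        return LoraRowParallelLinear(base, adapter_name, **kwargs)
    if isinstance(base, (TEColumnParallelLinear, TELayerNormColumnParallelLinear,
                         TEColumnParallelGroupedLinear)):
        return LoraColumnParallelLinear(base, adapter_name, **kwargs)
    return None  # not a layer type this module wraps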

apply_megatron_lora

def apply_megatron_lora()  # lines 547-550

Patches PEFT's model.dispatch_megatron with the local dispatch_megatron function, patches TELinear.__repr__ for readable logging, and patches TEGroupedLinear.sharded_state_dict for compatibility.
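
The patching itself is ordinary attribute replacement. The sketch below shows the pattern; the peft module path and the __repr__ body are assumptions, and the TEGroupedLinear.sharded_state_dict patch is omitted.

import peft.tuners.lora.model as peft_lora_model
from megatron.core.extensions.transformer_engine import TELinear
from mcore_adapter.adapters.lora_layer import dispatch_megatron  # path as in the usage example

def apply_megatron_lora_sketch():
    # Point PEFT's Megatron dispatch hook at the parallel-aware dispatcher above.
    peft_lora_model.dispatch_megatron = dispatch_megatron
    # Give TE linears a compact repr so printed PEFT models stay readable.
    TELinear.__repr__ = lambda self: f"{type(self).__name__}()"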

Import

import torch
import torch.nn as nn
import torch.nn.functional as F
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear, TERowParallelLinear, TELinear, TEGroupedLinear,
    TEColumnParallelGroupedLinear, TERowParallelGroupedLinear,
    TELayerNormColumnParallelLinear,
)
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.router import TopKRouter
from peft.tuners.lora.layer import LoraLayer

I/O Contract

Inputs

Name Type Required Description
base_layer torch.nn.Module Yes The Megatron-Core TE linear layer to wrap with LoRA
adapter_name str Yes Name identifier for the LoRA adapter (e.g., default)
r int No LoRA rank (signature default: 0; must be set to a positive value)
lora_alpha int No LoRA scaling factor alpha (default: 1)
lora_dropout float No Dropout probability for LoRA input (default: 0.0)
x torch.Tensor Yes Input tensor to the forward pass

Outputs

Name Type Description
result torch.Tensor Output tensor: base_layer(x) + LoRA_B(LoRA_A(dropout(x))) * scaling
bias torch.Tensor Bias tensor from the base layer (may be None)
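
Like the Megatron-Core TE linears they wrap, these layers return a (result, bias) tuple rather than a single tensor. The toy stand-in below only illustrates unpacking that contract; it is not the real layer.

import torch
import torch.nn.functional as F

class TupleLinear(torch.nn.Linear):
    def forward(self, x):
        # Megatron-style contract: return (result, bias); bias may be None when fused or absent.
        return F.linear(x, self.weight), self.bias

result, bias = TupleLinear(8, 8)(torch.randn(2, 8))
output = result if bias is None else result + bias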

Usage Examples

# Typical usage is indirect via apply_megatron_lora()
from mcore_adapter.adapters import apply_megatron_lora
from peft import LoraConfig, get_peft_model

# Patch PEFT to use Megatron-compatible LoRA layers
apply_megatron_lora()

# Then use PEFT as normal - it will dispatch to LoraParallelLinear variants
lora_config = LoraConfig(r=16, lora_alpha=32, target_modules=["linear_fc1", "linear_fc2"])
model = get_peft_model(model, lora_config)

# Direct usage of dispatch_megatron
from mcore_adapter.adapters.lora_layer import dispatch_megatron

new_module = dispatch_megatron(
    target=some_te_linear_layer,
    adapter_name="default",
    lora_config=lora_config,
    r=16,
    lora_alpha=32,
)
