# Principle: Alibaba ROLL Megatron LoRA Adaptation
| Knowledge Sources | |
|---|---|
| Domains | Model_Architecture, LoRA, Distributed_Computing |
| Last Updated | 2026-02-07 20:00 GMT |
## Overview
Parameter-efficient adaptation of linear layers in tensor-parallel and expert-parallel distributed models through low-rank matrix decomposition.
## Description
Low-Rank Adaptation (LoRA) is a technique that freezes the original model weights and injects trainable rank-decomposition matrices into each target layer. The core idea is that weight updates during fine-tuning occupy a low-rank subspace, so a full-rank update $\Delta W \in \mathbb{R}^{k \times d}$ can be approximated by two small matrices $B \in \mathbb{R}^{k \times r}$ and $A \in \mathbb{R}^{r \times d}$, where $r \ll \min(d, k)$.
In a distributed setting with tensor parallelism, model weights are sharded across GPUs along specific dimensions. Applying LoRA naively would break the parallelism invariants because the low-rank matrices must respect the same sharding scheme as their base layers. This principle addresses the problem by creating LoRA adapter layers that are parallelism-aware: row-parallel base layers get row-parallel LoRA-A matrices, column-parallel base layers get column-parallel LoRA-B matrices, and grouped expert layers get grouped LoRA matrices that preserve the expert-parallel layout.
The design also handles sequence parallelism interactions: when sequence parallelism is active, the LoRA forward path must gather inputs from the sequence-parallel region before applying the low-rank computation through LoRA-A, then scatter the result back after LoRA-B for row-parallel layers.
## Usage
Use this principle when:
- Fine-tuning a large language model that is deployed with Megatron-Core tensor or expert parallelism, and full-parameter training is too expensive.
- The model uses Transformer Engine linear layers (TEColumnParallelLinear, TERowParallelLinear, TEGroupedLinear) and you need adapter layers that preserve the existing parallel communication patterns.
- You want to apply LoRA to Mixture-of-Experts routers or grouped expert layers while maintaining correct sharded checkpointing.
## Theoretical Basis
The standard LoRA update for a pretrained weight matrix $W_0 \in \mathbb{R}^{k \times d}$ is:

$$W' = W_0 + \Delta W = W_0 + \frac{\alpha}{r} B A$$

where $A \in \mathbb{R}^{r \times d}$, $B \in \mathbb{R}^{k \times r}$, $\alpha$ is the scaling factor, and $r$ is the rank.
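This identity can be checked numerically: applying the merged weight $W_0 + \frac{\alpha}{r}BA$ to an input gives the same result as the base output plus the scaled low-rank path. A minimal NumPy sketch (shapes and variable names are illustrative, not taken from the ROLL codebase):

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, alpha = 16, 8, 4, 8          # in-dim, out-dim, LoRA rank, alpha

W0 = rng.normal(size=(k, d))          # frozen pretrained weight
A = rng.normal(size=(r, d))           # LoRA-A (trainable)
B = rng.normal(size=(k, r))           # LoRA-B (trainable)
scaling = alpha / r                   # standard LoRA scaling

x = rng.normal(size=(d,))

merged = (W0 + scaling * B @ A) @ x        # merged-weight forward
two_path = W0 @ x + scaling * B @ (A @ x)  # base path + low-rank path

assert np.allclose(merged, two_path)
```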
For a column-parallel base layer, where the output dimension is sharded across $T$ tensor-parallel ranks:

```
LoRA_A: TELinear(in=d, out=r)                  # not parallelized (full input)
LoRA_B: TEColumnParallelLinear(in=r, out=k*T)  # sharded on output dim
```
For a row-parallel base layer, where the input dimension is sharded:

```
LoRA_A: TERowParallelLinear(in=d*T, out=r, input_is_parallel=True)  # sharded on input dim
LoRA_B: TELinear(in=r, out=k)                                       # not parallelized (full output)
```
For grouped expert layers with $E$ local experts:

```
LoRA_A: TEGroupedLinear(num_gemms=E, in=d, out=r/topk)
LoRA_B: TEColumnParallelGroupedLinear(num_gemms=E, in=r/topk, out=k*T)
```
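The sharding choices above can be verified with a toy single-process simulation: in the column-parallel case, LoRA-B is split along its output dimension and the per-rank outputs concatenate (all-gather) to the full result; in the row-parallel case, LoRA-A is split along its input dimension and the per-rank partial products sum (all-reduce) to the full result. A NumPy sketch under these assumptions:

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, r, T = 12, 6, 4, 2              # dims, LoRA rank, tensor-parallel size

# --- Column-parallel layout: LoRA-A replicated, LoRA-B sharded on output dim.
x = rng.normal(size=(d,))
A = rng.normal(size=(r, d))
B = rng.normal(size=(k * T, r))
B_shards = np.split(B, T, axis=0)     # each rank holds a (k, r) slice
ax = A @ x                            # replicated low-rank activation
col_out = np.concatenate([Bt @ ax for Bt in B_shards])  # "all-gather"
assert np.allclose(col_out, B @ ax)

# --- Row-parallel layout: LoRA-A sharded on input dim, LoRA-B replicated.
A2 = rng.normal(size=(r, d * T))
B2 = rng.normal(size=(k, r))
x2 = rng.normal(size=(d * T,))
A_shards = np.split(A2, T, axis=1)    # each rank holds an (r, d) slice
x_shards = np.split(x2, T)            # each rank sees its input slice
row_ax = sum(At @ xt for At, xt in zip(A_shards, x_shards))  # "all-reduce"
assert np.allclose(B2 @ row_ax, B2 @ (A2 @ x2))
```

The simulation mirrors why the adapter matrices must adopt the base layer's sharding: the partial LoRA products combine with exactly the same collectives the base layer already performs.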
The forward pass computes:

$$\text{output} = \text{base\_layer}(x) + \text{scaling} \cdot B(A(\text{dropout}(x)))$$
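A minimal single-process sketch of this forward path, with inverted dropout on the adapter input only (the function name and signature are illustrative stand-ins, not Megatron or Transformer Engine APIs):

```python
import numpy as np

def lora_forward(x, base_weight, A, B, scaling, drop_p=0.0, rng=None):
    """base(x) + scaling * B(A(dropout(x))); dropout touches only the LoRA path."""
    h = x
    if drop_p > 0.0:
        rng = rng or np.random.default_rng()
        mask = rng.random(x.shape) >= drop_p
        h = x * mask / (1.0 - drop_p)   # inverted dropout keeps expectation
    return base_weight @ x + scaling * (B @ (A @ h))
```

Note that the base path always sees the undropped input; only the low-rank branch is regularized.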
When sequence parallelism is active on a column-parallel layer, the input $x$ must first be gathered from the sequence-parallel region before the LoRA computation. For row-parallel layers, the LoRA result must be scattered back to the sequence-parallel region.
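The gather/scatter interaction can be simulated on one process: each rank holds a slice of the sequence, the slices are all-gathered before LoRA-A, and in the row-parallel case the result is scattered back so each rank keeps only its own sequence slice. A NumPy sketch (names and shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(2)
seq, d, k, r, T = 8, 6, 4, 2, 2       # sequence length, dims, rank, SP ranks

A = rng.normal(size=(r, d))
B = rng.normal(size=(k, r))
x = rng.normal(size=(seq, d))
x_slices = np.split(x, T, axis=0)     # sequence-parallel slices, one per rank

# Gather the full sequence before LoRA-A ("all-gather" over the SP region).
gathered = np.concatenate(x_slices, axis=0)
lora_out = gathered @ A.T @ B.T       # full-sequence LoRA result, (seq, k)

# For a row-parallel base layer, scatter the result back: each rank
# keeps only its own sequence slice (reduce-scatter in the real system).
scattered = np.split(lora_out, T, axis=0)
assert scattered[0].shape == (seq // T, k)
assert np.allclose(np.concatenate(scattered, axis=0), lora_out)
```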
The scaling factor is:
- Standard: $\text{scaling} = \alpha / r$
- RSLoRA: $\text{scaling} = \alpha / \sqrt{r}$
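For example, with $\alpha = 16$ and $r = 8$ the two variants give noticeably different adapter magnitudes (RSLoRA's $\sqrt{r}$ denominator keeps the update scale stable as the rank grows):

```python
import math

alpha, r = 16, 8
standard = alpha / r               # 2.0
rslora = alpha / math.sqrt(r)      # ~5.657
```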
Weight initialization follows the original LoRA paper: $A$ is initialized with Kaiming uniform, and $B$ is initialized to zero, so the adapter's initial contribution is zero and the adapted layer starts out identical to the base layer.
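The effect of this initialization is easy to verify: with $B = 0$, the low-rank path contributes nothing regardless of $A$, so the first forward pass reproduces the frozen base layer exactly. A NumPy sketch using a Kaiming-style uniform bound for $A$ (the exact bound formula is an assumption here, standing in for the framework's initializer):

```python
import numpy as np

rng = np.random.default_rng(3)
d, k, r = 16, 8, 4

W0 = rng.normal(size=(k, d))
# Kaiming-style uniform init for A (bound ~ 1/sqrt(fan_in)); B starts at zero.
bound = 1.0 / np.sqrt(d)
A = rng.uniform(-bound, bound, size=(r, d))
B = np.zeros((k, r))
scaling = 2.0

x = rng.normal(size=(d,))
adapted = W0 @ x + scaling * (B @ (A @ x))

# With B = 0 the adapter contributes nothing at step 0.
assert np.allclose(adapted, W0 @ x)
```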