Implementation:Hpcaitech ColossalAI LoRAConstructor
| Knowledge Sources | |
|---|---|
| Domains | LoRA, Parameter Efficient Fine-Tuning, Distributed Training |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A utility class for reconstructing model weight increments from LoRA parameters, enabling efficient transfer of only LoRA weights between remote Ray actors in the distributed RLHF pipeline.
Description
LoRAConstructor provides tools for filtering, extracting, reconstructing, and applying LoRA (Low-Rank Adaptation) weight updates exchanged between sender and receiver components in the distributed ColossalChat training system. The class supports a multi-step protocol:
1. Sender: filter_state_dict_lora reduces the model state dict to the LoRA parameters (the lora_A and lora_B matrices).
2. Sender (optional): extract_lora_config collects the LoRA configuration of each LoraLinear module.
3. Sender: the LoRA state dict and, if extracted, the LoRA configuration are sent to the receiver.
4. Receiver: reconstruct_increase rebuilds each layer's full weight increment as lora_B @ lora_A * scaling.
5. Receiver: load_state_dict_increase adds the increments to the model's current weights.
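The receiver-side steps (reconstruct_increase, then load_state_dict_increase) can be illustrated with a minimal pure-Python sketch. It is not the coati implementation: matrices are nested lists instead of tensors, the key convention `<layer>.lora_A` / `<layer>.lora_B` -> `<layer>.weight` is inferred from the I/O contract, the scaling factor is assumed to be lora_alpha / r as in the standard LoRA formulation, and layers with r = 0 are assumed to carry no update.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List

Matrix = List[List[float]]


@dataclass
class LoRAConfig:
    # Same fields and defaults as the documented LoRAConfig.
    r: int = 0
    lora_alpha: int = 1
    lora_dropout: float = 0
    fan_in_fan_out: bool = False


def matmul(x: Matrix, y: Matrix) -> Matrix:
    """Naive matrix product x @ y for small nested-list matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def transpose(x: Matrix) -> Matrix:
    return [list(col) for col in zip(*x)]


def reconstruct_increase(state_dict_lora: Dict[str, Matrix], lora_config_dict: Dict[str, LoRAConfig]) -> Dict[str, Matrix]:
    """Turn each <layer>.lora_A / <layer>.lora_B pair into a <layer>.weight increment."""
    increase = OrderedDict()
    for layer, cfg in lora_config_dict.items():
        if cfg.r <= 0:
            continue  # assumption: r == 0 means the layer carries no LoRA update
        lora_a = state_dict_lora[layer + ".lora_A"]  # shape (r, in_features)
        lora_b = state_dict_lora[layer + ".lora_B"]  # shape (out_features, r)
        delta = matmul(lora_b, lora_a)               # shape (out_features, in_features)
        if cfg.fan_in_fan_out:
            delta = transpose(delta)                 # weight stored as (in, out)
        scaling = cfg.lora_alpha / cfg.r             # assumed standard LoRA scaling
        increase[layer + ".weight"] = [[v * scaling for v in row] for row in delta]
    return increase


def load_state_dict_increase(weights: Dict[str, Matrix], increase: Dict[str, Matrix]) -> None:
    """Add each increment to the receiver's current weight, in place."""
    for key, delta in increase.items():
        weights[key] = [[w + d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weights[key], delta)]


# Toy example: one 2x2 layer with a rank-1 update.
lora = OrderedDict({"fc.lora_A": [[1.0, 2.0]], "fc.lora_B": [[1.0], [0.5]]})
configs = OrderedDict({"fc": LoRAConfig(r=1, lora_alpha=2)})
receiver_weights = {"fc.weight": [[0.0, 0.0], [0.0, 0.0]]}
load_state_dict_increase(receiver_weights, reconstruct_increase(lora, configs))
print(receiver_weights["fc.weight"])  # [[2.0, 4.0], [1.0, 2.0]]
```

Looking the factors up by layer name keeps the sketch short; with tensors, the product would simply be `lora_B @ lora_A`.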
The companion LoRAConfig dataclass stores the per-layer LoRA hyperparameters: rank (r), lora_alpha, lora_dropout, and fan_in_fan_out.
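Assuming the standard LoRA convention (scaling factor lora_alpha / r, as used by loralib), the LoRAConfig fields would enter a layer's reconstructed increment as follows, with A = lora_A of shape r × d_in, B = lora_B of shape d_out × r, and α = lora_alpha:

$$
\Delta W =
\begin{cases}
\dfrac{\alpha}{r}\, B A \in \mathbb{R}^{d_{\mathrm{out}} \times d_{\mathrm{in}}}, & \texttt{fan\_in\_fan\_out} = \text{False} \\[6pt]
\dfrac{\alpha}{r}\, (B A)^{\top} \in \mathbb{R}^{d_{\mathrm{in}} \times d_{\mathrm{out}}}, & \texttt{fan\_in\_fan\_out} = \text{True}
\end{cases}
$$

For example, r = 8 and lora_alpha = 16 give a scaling factor of 16 / 8 = 2. The expression is only defined for r > 0, and lora_dropout does not appear in it: in LoRA, dropout acts on the layer input during training, not on the stored factors.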
Usage
Use LoRAConstructor when update_lora_weights=True is set in DetachedPPOTrainer and ExperienceMakerHolder to reduce communication overhead during model weight synchronization. Instead of transferring the full model state dict, only LoRA parameters are sent and the weight increments are reconstructed on the receiver side.
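The saving can be estimated per layer: a dense weight of shape d_out × d_in has d_out · d_in elements, while lora_A (r × d_in) and lora_B (d_out × r) together have r · (d_in + d_out). The helper below is hypothetical, and the layer size in the example is chosen only for illustration.

```python
def lora_transfer_ratio(d_out: int, d_in: int, r: int) -> float:
    """Elements in a dense d_out x d_in weight divided by elements in its LoRA factors."""
    full = d_out * d_in          # what a full state-dict sync sends for this layer
    lora = r * d_in + d_out * r  # lora_A is (r, d_in), lora_B is (d_out, r)
    return full / lora


# Hypothetical 4096 x 4096 projection with rank 8.
print(lora_transfer_ratio(4096, 4096, 8))  # 256.0
```

This only covers LoRA-wrapped layers; sending nothing else presumes the base weights stay unchanged on the sender during LoRA training.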
Code Reference
Source Location
- Repository: Hpcaitech_ColossalAI
- File: applications/ColossalChat/coati/ray/lora_constructor.py
- Lines: 1-123
Signature
```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch.nn as nn


@dataclass
class LoRAConfig:
    r: int = 0
    lora_alpha: int = 1
    lora_dropout: float = 0
    fan_in_fan_out: bool = False


class LoRAConstructor:
    def __init__(self): ...

    def register_lora_config(self, lora_config_dict: Dict[str, Any]): ...

    def reconstruct_increase(
        self,
        state_dict_lora: Dict[str, Any],
        lora_config_dict: Dict[str, Any],
    ) -> OrderedDict: ...

    def load_state_dict_increase(
        self,
        model: nn.Module,
        state_dict_increase: Dict[str, Any],
    ): ...

    @staticmethod
    def filter_state_dict_lora(
        state_dict: Dict[str, Any],
        keep_non_lora: bool = False,
    ) -> Tuple[OrderedDict, Optional[OrderedDict]]: ...

    @staticmethod
    def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: ...
```
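The sender-side extract_lora_config step amounts to a walk over the model's named modules that records one LoRAConfig per LoRA layer. The sketch below shows that pattern without PyTorch: ToyLoraLinear is a hypothetical stand-in for coati's LoraLinear, and the list of (name, module) pairs plays the role of model.named_modules(). It does not reproduce how LoraLinear actually stores its attributes.

```python
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, Iterable, Tuple


@dataclass
class LoRAConfig:
    r: int = 0
    lora_alpha: int = 1
    lora_dropout: float = 0
    fan_in_fan_out: bool = False


class ToyLoraLinear:
    """Hypothetical stand-in for a LoRA-wrapped linear layer (not coati's LoraLinear)."""

    def __init__(self, r: int, lora_alpha: int, lora_dropout: float = 0.0, fan_in_fan_out: bool = False):
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.fan_in_fan_out = fan_in_fan_out


def extract_lora_config(named_modules: Iterable[Tuple[str, object]]) -> Dict[str, LoRAConfig]:
    """Collect one LoRAConfig per LoRA layer, keyed by the layer's qualified name."""
    configs = OrderedDict()
    for name, module in named_modules:
        if isinstance(module, ToyLoraLinear):
            configs[name] = LoRAConfig(
                r=module.r,
                lora_alpha=module.lora_alpha,
                lora_dropout=module.lora_dropout,
                fan_in_fan_out=module.fan_in_fan_out,
            )
    return configs


# With PyTorch this list would come from model.named_modules(); here it is hand-built.
modules = [("", object()), ("attn.q_proj", ToyLoraLinear(r=8, lora_alpha=16)), ("norm", object())]
print(list(extract_lora_config(modules)))  # ['attn.q_proj']
```

Keeping the result ordered matters if the receiver matches configurations to LoRA factors by position rather than by name.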
Import
```python
from coati.ray.lora_constructor import LoRAConstructor, LoRAConfig
```
I/O Contract
Inputs
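| Name | Type | Required | Used by | Description |
|---|---|---|---|---|
| state_dict_lora | Dict[str, Any] | Yes | reconstruct_increase | State dict containing only the lora_A and lora_B parameters |
| lora_config_dict | Dict[str, Any] | Yes | reconstruct_increase, register_lora_config | Ordered dict mapping layer names to LoRAConfig instances |
| state_dict_increase | Dict[str, Any] | Yes | load_state_dict_increase | Weight increments returned by reconstruct_increase |
| model | nn.Module | Yes | load_state_dict_increase, extract_lora_config | Target model that receives the increments, or source model whose LoRA configuration is extracted |
| state_dict | Dict[str, Any] | Yes | filter_state_dict_lora | Full model state dict to filter |
| keep_non_lora | bool | No | filter_state_dict_lora | Whether to also return the non-LoRA parameters (default False) |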
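Because reconstruct_increase needs both LoRA factors and a LoRAConfig for every layer it rebuilds, a receiver may want to check that state_dict_lora and lora_config_dict agree before reconstructing. The helper below is hypothetical (it is not part of LoRAConstructor) and assumes the key convention `<layer>.lora_A` / `<layer>.lora_B`, where `<layer>` is a key of lora_config_dict.

```python
from typing import Any, Dict, List


def check_lora_inputs(state_dict_lora: Dict[str, Any], lora_config_dict: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the two inputs line up."""
    problems = []
    layers = set()
    for key in state_dict_lora:
        prefix, _, suffix = key.rpartition(".")
        if suffix not in ("lora_A", "lora_B"):
            problems.append(f"unexpected key: {key}")
        else:
            layers.add(prefix)
    # Every configured layer needs both factors.
    for layer in lora_config_dict:
        for suffix in ("lora_A", "lora_B"):
            if f"{layer}.{suffix}" not in state_dict_lora:
                problems.append(f"missing {layer}.{suffix}")
    # Every layer that has factors needs a config.
    for layer in sorted(layers - set(lora_config_dict)):
        problems.append(f"no config for {layer}")
    return problems


print(check_lora_inputs({"fc.lora_A": 0, "fc.lora_B": 0}, {"fc": None}))  # []
```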
Outputs
| Method | Return type | Description |
|---|---|---|
| reconstruct_increase | OrderedDict | State dict of reconstructed weight increments (layer_name.weight -> increment tensor) |
| filter_state_dict_lora | Tuple[OrderedDict, Optional[OrderedDict]] | Tuple of (lora_state_dict, non_lora_state_dict); the second element is None unless keep_non_lora=True |
| extract_lora_config | Dict[str, LoRAConfig] | Ordered dict mapping layer names to LoRAConfig |
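The return contract of filter_state_dict_lora (a LoRA dict plus either a non-LoRA dict or None) can be sketched as follows. The sketch assumes a key counts as LoRA when its name contains lora_A or lora_B; treat it as an illustration of the contract, not as the library code.

```python
from collections import OrderedDict
from typing import Any, Dict, Optional, Tuple


def filter_state_dict_lora(
    state_dict: Dict[str, Any], keep_non_lora: bool = False
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    lora, non_lora = OrderedDict(), OrderedDict()
    for key, value in state_dict.items():
        if "lora_A" in key or "lora_B" in key:  # assumed LoRA key test
            lora[key] = value
        elif keep_non_lora:
            non_lora[key] = value
    # The second element is only populated on request.
    return lora, (non_lora if keep_non_lora else None)


sd = OrderedDict([("fc.weight", 1), ("fc.lora_A", 2), ("fc.lora_B", 3), ("fc.bias", 4)])
lora_only, rest = filter_state_dict_lora(sd)
print(list(lora_only), rest)  # ['fc.lora_A', 'fc.lora_B'] None
```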
Usage Examples
```python
from coati.ray.lora_constructor import LoRAConstructor

# Sender side: `model` is the trained model whose LoRA layers are LoraLinear modules.
state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(model.state_dict())
lora_config = LoRAConstructor.extract_lora_config(model)

# ... send state_dict_lora and lora_config to the receiver (e.g., through a Ray remote call) ...

# Receiver side: `receiver_model` is the model held by the receiving actor.
constructor = LoRAConstructor()
state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config)
constructor.load_state_dict_increase(receiver_model, state_dict_increase)
```