
Implementation:Hpcaitech ColossalAI LoRAConstructor

From Leeroopedia


Knowledge Sources
Domains: LoRA, Parameter Efficient Fine-Tuning, Distributed Training
Last Updated: 2026-02-09 00:00 GMT

Overview

A utility class for reconstructing model weight increments from LoRA parameters, enabling efficient transfer of only LoRA weights between remote Ray actors in the distributed RLHF pipeline.

Description

LoRAConstructor provides tools for extracting, filtering, transferring, and applying LoRA (Low-Rank Adaptation) weight updates between sender and receiver components in the distributed ColossalChat training system. The class implements a multi-step protocol. On the sender side, filter_state_dict_lora reduces the state dict to the LoRA parameters only (the lora_A and lora_B matrices), extract_lora_config optionally collects the LoRA configuration from the model's LoraLinear modules, and both are sent to the receiver. On the receiver side, reconstruct_increase rebuilds the full weight increment for each layer as (lora_B @ lora_A) * scaling, with the scaling factor derived from that layer's LoRAConfig, and load_state_dict_increase applies the increments to the model.
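The reconstruction step can be illustrated with a minimal sketch. It uses nested Python lists instead of torch tensors so that it runs without dependencies, and it assumes the conventional LoRA scaling of lora_alpha / r; the names LoRAConfigSketch, matmul, and compute_increase are hypothetical and are not part of coati.

```python
from dataclasses import dataclass


@dataclass
class LoRAConfigSketch:
    # Hypothetical stand-in for two of the fields listed for LoRAConfig.
    r: int = 0
    lora_alpha: int = 1


def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def compute_increase(lora_A, lora_B, config):
    """Return (lora_B @ lora_A) * scaling for one layer."""
    if config.r <= 0:
        return None  # assumption: a layer with rank 0 contributes no increment
    scaling = config.lora_alpha / config.r  # assumption: conventional LoRA scaling
    return [[value * scaling for value in row] for row in matmul(lora_B, lora_A)]


# lora_A has shape (r, in_features); lora_B has shape (out_features, r).
lora_A = [[1.0, 2.0, 3.0]]   # r = 1, in_features = 3
lora_B = [[1.0], [0.5]]      # out_features = 2, r = 1
config = LoRAConfigSketch(r=1, lora_alpha=2)

# The increment has the shape of the full weight: (out_features, in_features).
delta_w = compute_increase(lora_A, lora_B, config)
```

The result has the same shape as the layer's full weight matrix, which is why the receiver can add it directly to the existing weight.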

The companion LoRAConfig dataclass stores per-layer LoRA hyperparameters: the rank (r), lora_alpha, lora_dropout, and the fan_in_fan_out flag.

Usage

Use LoRAConstructor when update_lora_weights=True is set in DetachedPPOTrainer and ExperienceMakerHolder to reduce communication overhead during model weight synchronization. Instead of transferring the full model state dict, only LoRA parameters are sent and the weight increments are reconstructed on the receiver side.
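A rough back-of-the-envelope calculation shows where the saving comes from. The layer dimensions and rank below are hypothetical, and the count ignores biases and any parameters that are not wrapped with LoRA.

```python
def full_weight_params(in_features, out_features):
    """Number of values in one full weight matrix."""
    return in_features * out_features


def lora_params(in_features, out_features, r):
    """Number of values in lora_A (r x in_features) plus lora_B (out_features x r)."""
    return r * (in_features + out_features)


# Hypothetical linear layer and LoRA rank, chosen only for illustration.
in_features, out_features, r = 4096, 4096, 8

full = full_weight_params(in_features, out_features)
lora = lora_params(in_features, out_features, r)
ratio = full / lora

print(f"full weight: {full} values, LoRA factors: {lora} values, ratio: {ratio:.0f}x")
```

For this example layer, sending the two low-rank factors moves 256 times fewer values than sending the full weight.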

Code Reference

Source Location

Signature

@dataclass
class LoRAConfig:
    r: int = 0
    lora_alpha: int = 1
    lora_dropout: float = 0
    fan_in_fan_out: bool = False

class LoRAConstructor:
    def __init__(self): ...
    def register_lora_config(self, lora_config_dict: Dict[str, Any]): ...
    def reconstruct_increase(
        self,
        state_dict_lora: Dict[str, Any],
        lora_config_dict: Dict[str, Any],
    ) -> OrderedDict: ...
    def load_state_dict_increase(
        self,
        model: nn.Module,
        state_dict_increase: Dict[str, Any],
    ): ...

    @staticmethod
    def filter_state_dict_lora(
        state_dict: Dict[str, Any],
        keep_non_lora: bool = False,
    ) -> Tuple[OrderedDict, Optional[OrderedDict]]: ...

    @staticmethod
    def extract_lora_config(model: nn.Module) -> Dict[str, LoRAConfig]: ...
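To make the return shape of filter_state_dict_lora concrete, here is a small sketch of the filtering idea. It is not the library's code: the function name filter_lora_keys is hypothetical, and it assumes that LoRA parameters can be recognised by a key ending in lora_A or lora_B.

```python
from collections import OrderedDict


def filter_lora_keys(state_dict, keep_non_lora=False):
    """Split a state dict into LoRA entries and, optionally, everything else."""
    lora = OrderedDict()
    non_lora = OrderedDict() if keep_non_lora else None
    for key, value in state_dict.items():
        # Assumption: the last dotted component identifies a LoRA parameter.
        if key.rpartition(".")[-1] in ("lora_A", "lora_B"):
            lora[key] = value
        elif keep_non_lora:
            non_lora[key] = value
    return lora, non_lora


# Hypothetical parameter names; the values would be tensors in practice.
state_dict = OrderedDict(
    [("fc.weight", 1), ("fc.lora_A", 2), ("fc.lora_B", 3), ("fc.bias", 4)]
)
lora_only, rest = filter_lora_keys(state_dict, keep_non_lora=True)
lora_default, rest_default = filter_lora_keys(state_dict)
```

With keep_non_lora left at its default, the second element of the tuple is None, matching the Optional[OrderedDict] in the signature.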

Import

from coati.ray.lora_constructor import LoRAConstructor, LoRAConfig

I/O Contract

Inputs

Name | Type | Required | Description
state_dict_lora | Dict[str, Any] | Yes | State dict containing only lora_A and lora_B parameters
lora_config_dict | Dict[str, Any] | Yes | Ordered dict mapping layer names to LoRAConfig instances
model | nn.Module | Yes | The target model for applying weight increments or extracting config
state_dict | Dict[str, Any] | Yes | Full model state dict to filter (for filter_state_dict_lora)
keep_non_lora | bool | No | Whether to also return non-LoRA parameters (default False)

Outputs

Name | Type | Description
reconstruct_increase return | OrderedDict | State dict with reconstructed weight increments (layer_name.weight -> increment tensor)
filter_state_dict_lora return | Tuple[OrderedDict, Optional[OrderedDict]] | Tuple of (lora_state_dict, non_lora_state_dict or None)
extract_lora_config return | Dict[str, LoRAConfig] | Ordered dict mapping layer names to LoRAConfig

Usage Examples

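from coati.ray.lora_constructor import LoRAConstructor

# Sender side: keep only the LoRA tensors and collect the per-layer config.
# `model` is the sender's LoRA-wrapped nn.Module, defined elsewhere.
state_dict_lora, _ = LoRAConstructor.filter_state_dict_lora(model.state_dict())
lora_config = LoRAConstructor.extract_lora_config(model)

# ... transfer state_dict_lora and lora_config to the receiver ...

# Receiver side: rebuild the weight increments and apply them to the local model.
# `receiver_model` is the receiver's nn.Module, defined elsewhere.
constructor = LoRAConstructor()
state_dict_increase = constructor.reconstruct_increase(state_dict_lora, lora_config)
constructor.load_state_dict_increase(receiver_model, state_dict_increase)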
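The final step of the example above adds increments to existing weights rather than overwriting them. The sketch below shows that idea with plain nested lists; apply_increase is a hypothetical helper, not the library method, and skipping unmatched names is an assumption made for the illustration.

```python
def apply_increase(weights, state_dict_increase):
    """Add each increment to the weight with the same name, in place."""
    for name, delta in state_dict_increase.items():
        if name not in weights:
            continue  # assumption: names without a matching weight are skipped
        target = weights[name]
        for i, row in enumerate(delta):
            for j, value in enumerate(row):
                target[i][j] += value


# Hypothetical receiver weights and a reconstructed increment for one layer.
weights = {"fc.weight": [[1.0, 1.0], [1.0, 1.0]]}
state_dict_increase = {"fc.weight": [[0.5, 0.0], [0.0, -1.0]]}

apply_increase(weights, state_dict_increase)
```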
