Heuristic:CarperAI Trlx PEFT LoRA Integration

From Leeroopedia




Knowledge Sources
Domains Optimization, LLMs, Reinforcement_Learning
Last Updated 2026-02-07 16:00 GMT

Overview

A parameter-efficient fine-tuning strategy that attaches LoRA adapters during PPO training, reducing trainable parameters by over 90% while eliminating the need for a separate reference model.

Description

trlx supports PEFT (Parameter Efficient Fine-Tuning) through the HuggingFace PEFT library, with LoRA (Low-Rank Adaptation) as the primary method. When a `peft_config` is set in `ModelConfig`, the model is wrapped with LoRA adapters that inject low-rank weight matrices into attention layers. For KL divergence computation, the adapter is temporarily disabled via `forward_hydra()` to obtain reference logits, eliminating the need to store a full separate reference model. This is mutually exclusive with the `num_layers_unfrozen` freezing strategy.

Usage

Apply this heuristic when GPU memory is limited and the model is too large for full fine-tuning. LoRA pairs particularly well with PPO because the adapter-disable trick provides the KL divergence reference logits at no extra memory cost. Use it when `num_layers_unfrozen` partial freezing is insufficient or when you want to preserve the entire base model's knowledge.

The Insight (Rule of Thumb)

  • Action: Set `config.model.peft_config = LoraConfig(r=8, task_type=TaskType.CAUSAL_LM, lora_alpha=32, lora_dropout=0.1)`.
  • Value:
    • `r=8`: LoRA rank (reasonable default; range 4-32)
    • `lora_alpha=32`: Scaling factor (typically 4x the rank)
    • `lora_dropout=0.1`: Regularization (range 0.05-0.2)
    • `task_type=TaskType.CAUSAL_LM`: Must match model type
  • Trade-off: LoRA trains only ~0.1-1% of parameters but limits the model's capacity to deviate from the base. Higher rank increases capacity but also memory. The value head does NOT use PEFT (TODO in codebase).
  • Incompatibility: `num_layers_unfrozen` is ignored when PEFT is active (a warning is logged). Use `LoraConfig.modules_to_save` instead.
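The ~0.1-1% figure above can be sanity-checked with back-of-the-envelope arithmetic. The sketch below uses illustrative GPT-2-small-sized dimensions (assumptions, not values read from any trlx config) and the default PEFT behaviour of adapting the attention query/value projections:

```python
# Back-of-the-envelope LoRA parameter count; dimensions are illustrative.
d_model = 768      # hidden size
n_layers = 12      # transformer blocks
r = 8              # LoRA rank, as in the rule of thumb above

# Each adapted d_model x d_model matrix gains A (r x d_model) and
# B (d_model x r), i.e. 2 * d_model * r new trainable parameters.
adapted_matrices_per_layer = 2          # q_proj and v_proj
lora_params = n_layers * adapted_matrices_per_layer * (2 * d_model * r)

base_params = 124_000_000               # ~124M for GPT-2 small
fraction = lora_params / base_params
print(f"LoRA params: {lora_params:,} ({fraction:.3%} of base)")
```

At rank 8 this lands at roughly a quarter of a percent of the base parameters, consistent with the 0.1-1% range stated above.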

Reasoning

LoRA injects trainable low-rank decomposition matrices (A and B) into frozen weight matrices: `W' = W + BA`. During PPO, KL divergence requires computing logits from the reference (original) model. With LoRA, this is achieved by temporarily disabling the adapter (`lora_model.disable_adapter_layers()`), which restores the original weights without keeping a separate model copy. This saves approximately 50% of GPU memory compared to the frozen head approach.

The `forward_hydra()` method handles this: for LoRA models it disables the adapter layers, computes reference logits, then re-enables them. For prefix or prompt tuning, it instead calls the PeftModel's underlying `base_model` directly. The value branch does not yet support PEFT (marked as a TODO in the codebase).
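The control flow can be mocked in a few lines. The classes below are simplified stand-ins (hypothetical, not trlx's actual classes) that capture the dispatch: LoRA models toggle their adapter layers around the reference forward pass, while other PEFT types bypass the wrapper:

```python
class ToyLoraModel:
    """Stand-in for peft's LoraModel: runs with or without adapters."""
    def __init__(self):
        self.adapters_enabled = True
    def disable_adapter_layers(self):
        self.adapters_enabled = False
    def enable_adapter_layers(self):
        self.adapters_enabled = True
    def __call__(self, x):
        return ("adapted" if self.adapters_enabled else "reference", x)

def forward_hydra_sketch(peft_type, lora_model, x):
    """Simplified mirror of the branch shown in the Code Evidence below."""
    if "LORA" in peft_type:
        lora_model.disable_adapter_layers()
        try:
            outputs = lora_model(x)   # reference logits, adapter off
        finally:
            lora_model.enable_adapter_layers()
        return outputs
    # Prefix/prompt tuning: call the unwrapped base model directly
    return ("reference", x)

model = ToyLoraModel()
assert forward_hydra_sketch("LORA", model, 1) == ("reference", 1)
assert model.adapters_enabled          # adapters restored afterwards
assert model(1) == ("adapted", 1)      # policy forward still uses LoRA
```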

Code Evidence

LoRA configuration from `examples/ppo_sentiments_peft.py:41-47`:

config.model.peft_config = LoraConfig(
    r=8,
    task_type=TaskType.CAUSAL_LM,
    lora_alpha=32,
    lora_dropout=0.1,
)

PEFT model wrapping from `trlx/models/modeling_base.py:114-117`:

if peft_config:
    if isinstance(peft_config, dict):
        peft_config = get_peft_config(peft_config)
    base_model = get_peft_model(base_model, peft_config)

KL divergence with adapter disable from `trlx/models/modeling_ppo.py:314-329`:

if self.peft_type and ignore_peft_adapter:
    if "LORA" in self.peft_type:
        lora_model = self.base_model.base_model
        lora_model.disable_adapter_layers()
        outputs = self.base_model(**forward_kwargs)
        lora_model.enable_adapter_layers()
    else:
        outputs = self.base_model.base_model(**forward_kwargs)

Reference model skipped for PEFT from `trlx/trainer/accelerate_ppo_trainer.py:73-77`:

if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
    self.ref_model = self.get_arch(self.config)
    self.ref_model.to(self.accelerator.device)
    self.ref_model.eval()
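Skipping `self.ref_model` is where the memory saving described in the Reasoning section comes from. The arithmetic sketch below uses illustrative numbers (a hypothetical ~1.4B-parameter policy in fp16, not figures from trlx) to estimate the avoided allocation:

```python
# Rough estimate of the reference-model memory avoided by the
# adapter-disable trick. Numbers are illustrative assumptions.
n_params = 1_400_000_000        # e.g. a ~1.4B-parameter policy model
bytes_per_param = 2             # fp16 weights

ref_model_bytes = n_params * bytes_per_param
print(f"Separate reference model would cost ~{ref_model_bytes / 1e9:.1f} GB")
# With PEFT active, trlx never instantiates self.ref_model,
# so this entire allocation is skipped.
```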

PEFT overrides layer freezing from `trlx/trainer/accelerate_base_trainer.py:163-169`:

if self.config.model.num_layers_unfrozen >= 0:
    logger.warning(
        "The argument num_layers_unfrozen is ignored when using peft, "
        "to prevent unexpected behaviour. "
        "For Lora, use the `LoraConfig` argument `modules_to_save` instead."
    )
