Heuristic: CarperAI trlx PEFT LoRA Integration
| Knowledge Sources | |
|---|---|
| Domains | Optimization, LLMs, Reinforcement_Learning |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Parameter-efficient fine-tuning strategy using LoRA adapters during PPO training, reducing trainable parameters by 90%+ while eliminating the need for a separate reference model.
Description
trlx supports PEFT (Parameter-Efficient Fine-Tuning) through the HuggingFace PEFT library, with LoRA (Low-Rank Adaptation) as the primary method. When a `peft_config` is set in `ModelConfig`, the model is wrapped with LoRA adapters that inject low-rank weight matrices into the attention layers. For KL-divergence computation, the adapter is temporarily disabled inside `forward_hydra()` to obtain reference logits, eliminating the need to store a full separate reference model. This approach is mutually exclusive with the `num_layers_unfrozen` layer-freezing strategy.
Usage
Apply this heuristic when GPU memory is limited and you need to train models larger than what fits with full fine-tuning. LoRA is particularly effective when combined with PPO because the adapter-disable trick provides free KL divergence computation. Use when `num_layers_unfrozen` partial freezing is insufficient or when you want to preserve the entire base model's knowledge.
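Enabling this in a trlx run is essentially a one-line config change. The fragment below follows the shape of the `ppo_sentiments_peft.py` example; the reward function and prompts are placeholders, and `default_ppo_config` is trlx's stock PPO config.

```python
# Config fragment (hedged sketch): attach a LoRA peft_config to trlx's
# default PPO config. The reward_fn and prompts below are placeholders,
# not part of the trlx source.
import trlx
from trlx.data.default_configs import default_ppo_config
from peft import LoraConfig, TaskType

config = default_ppo_config()
config.model.peft_config = LoraConfig(
    r=8, task_type=TaskType.CAUSAL_LM, lora_alpha=32, lora_dropout=0.1
)

# trlx.train(
#     reward_fn=lambda samples, **kwargs: [float(len(s)) for s in samples],
#     prompts=["Sample prompt"] * 64,
#     config=config,
# )
```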
The Insight (Rule of Thumb)
- Action: Set `config.model.peft_config = LoraConfig(r=8, task_type=TaskType.CAUSAL_LM, lora_alpha=32, lora_dropout=0.1)`.
- Value:
  - `r=8`: LoRA rank (reasonable default; range 4-32)
  - `lora_alpha=32`: Scaling factor (typically 4x the rank)
  - `lora_dropout=0.1`: Regularization (range 0.05-0.2)
  - `task_type=TaskType.CAUSAL_LM`: Must match model type
- Trade-off: LoRA trains only ~0.1-1% of parameters but limits the model's capacity to deviate from the base. Higher rank increases capacity but also memory. The value head does NOT use PEFT (TODO in codebase).
- Incompatibility: `num_layers_unfrozen` is ignored when PEFT is active (a warning is logged). Use `LoraConfig.modules_to_save` instead.
Reasoning
LoRA injects trainable low-rank decomposition matrices (A and B) into frozen weight matrices: `W' = W + BA`. During PPO, KL divergence requires computing logits from the reference (original) model. With LoRA, this is achieved by temporarily disabling the adapter (`lora_model.disable_adapter_layers()`), which restores the original weights without keeping a separate model copy. This saves approximately 50% of GPU memory compared to the frozen head approach.
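The update can be seen on a single linear layer. The sketch below also shows why disabling the adapter recovers the base output exactly: PEFT initializes `B` to zero, so `W' = W` before any training (shapes and scaling here are illustrative).

```python
# Sketch: LoRA update W' = W + (alpha/r) * B @ A on one linear layer.
import torch

d_out, d_in, r, alpha = 192, 64, 8, 32
W = torch.randn(d_out, d_in)        # frozen base weight
A = torch.randn(r, d_in) * 0.01     # trainable low-rank factor
B = torch.zeros(d_out, r)           # zero-initialized: W' == W at start

x = torch.randn(d_in)
y_base = W @ x                            # adapter disabled (reference path)
y_lora = (W + (alpha / r) * (B @ A)) @ x  # adapter enabled

assert torch.allclose(y_base, y_lora)  # holds only while B is zero
```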
The `forward_hydra()` method handles this: for LoRA models it disables adapters, computes reference logits, then re-enables adapters. For prefix/prompt tuning, it falls back to the PeftModel's base_model directly. The value branch currently does not support PEFT (marked as TODO).
Code Evidence
LoRA configuration from `examples/ppo_sentiments_peft.py:41-47`:
```python
config.model.peft_config = LoraConfig(
    r=8,
    task_type=TaskType.CAUSAL_LM,
    lora_alpha=32,
    lora_dropout=0.1,
)
```
PEFT model wrapping from `trlx/models/modeling_base.py:114-117`:
```python
if peft_config:
    if isinstance(peft_config, dict):
        peft_config = get_peft_config(peft_config)
    base_model = get_peft_model(base_model, peft_config)
```
KL divergence with adapter disable from `trlx/models/modeling_ppo.py:314-329`:
```python
if self.peft_type and ignore_peft_adapter:
    if "LORA" in self.peft_type:
        lora_model = self.base_model.base_model
        lora_model.disable_adapter_layers()
        outputs = self.base_model(**forward_kwargs)
        lora_model.enable_adapter_layers()
    else:
        outputs = self.base_model.base_model(**forward_kwargs)
```
Reference model skipped for PEFT from `trlx/trainer/accelerate_ppo_trainer.py:73-77`:
```python
if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
    self.ref_model = self.get_arch(self.config)
    self.ref_model.to(self.accelerator.device)
    self.ref_model.eval()
```
PEFT overrides layer freezing from `trlx/trainer/accelerate_base_trainer.py:163-169`:
```python
if self.config.model.num_layers_unfrozen >= 0:
    logger.warning(
        "The argument num_layers_unfrozen is ignored when using peft, "
        "to prevent unexpected behaviour. "
        "For Lora, use the `LoraConfig` argument `modules_to_save` instead."
    )
```