
Heuristic:CarperAI Trlx Partial Layer Freezing

From Leeroopedia




Knowledge Sources
Domains: Optimization, Reinforcement_Learning, LLMs
Last Updated: 2026-02-07 16:00 GMT

Overview

A parameter-efficiency strategy that freezes the bottom transformer layers during PPO training and trains only the top few layers (typically 2 to 8). This reduces memory usage and helps prevent catastrophic forgetting.

Description

In trlx PPO training, `num_layers_unfrozen` controls how many of the top (last) transformer decoder layers remain trainable. All lower layers and the input embeddings are frozen, as is the output embedding when it is tied to the input embedding. When partial freezing is active, a frozen head (a deep copy of the unfrozen layers, the final norm, and the LM head) serves as the reference model for the KL divergence computation, so a full separate reference model does not need to be kept in memory. This is distinct from PEFT/LoRA, which is handled separately and overrides the freezing mechanism.

Usage

Apply this heuristic when training large models (6B+ parameters) with PPO, where keeping a full reference model would roughly double the memory needed for model weights. Set `num_layers_unfrozen=2` as a conservative default. Use `-1` to train all layers (full fine-tuning) or `0` to freeze all decoder layers together with the input and output embeddings.
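The settings above can be summarized with a small helper. This is a hypothetical sketch (`layer_plan` is not part of trlx) that only encodes the rules described on this page: a negative value trains every decoder layer, `0` trains none, a positive `k` keeps the top `k` trainable, and a frozen head is built only when `k > 0` and PEFT is not in use.

```python
from typing import Optional, Tuple


def layer_plan(num_layers: int, num_layers_unfrozen: int,
               use_peft: bool = False) -> Tuple[Optional[int], bool]:
    """Hypothetical helper: (trainable decoder layers, whether a frozen head is built).

    Mirrors the rules described on this page; it does not import or call trlx.
    """
    if use_peft:
        # With PEFT/LoRA, num_layers_unfrozen is ignored and no frozen head is built.
        return None, False
    if num_layers_unfrozen < 0:
        # -1: full fine-tuning; every decoder layer trains, no frozen head.
        return num_layers, False
    # 0 freezes all decoder layers; k > 0 keeps the top k trainable.
    trainable = min(num_layers_unfrozen, num_layers)
    # The frozen head is only created when num_layers_unfrozen > 0.
    return trainable, num_layers_unfrozen > 0


for setting in (-1, 0, 2, 8):
    print(setting, layer_plan(28, setting))  # 28 layers, as in GPT-J-6B
```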

The Insight (Rule of Thumb)

  • Action: Set `num_layers_unfrozen=2` in `ModelConfig` for large model PPO training.
  • Value:
    • `num_layers_unfrozen=2`: Standard for 6B-20B models (on GPT-J-6B this trains 2 of 28 decoder layers, about 7%)
    • `num_layers_unfrozen=8`: More aggressive, used for summarization tasks (GPT-J)
    • `num_layers_unfrozen=-1`: Full fine-tuning (all layers trainable, requires separate ref model)
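    • `num_layers_unfrozen=0`: Freeze all decoder layers and the input/output embeddings (essentially only the value head trains)
  • Trade-off: Fewer unfrozen layers mean less memory and faster training, but less capacity for the model to adapt. The frozen head avoids keeping a full reference copy of the model, saving up to roughly half of the weight memory compared with holding the policy and reference as two full models.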
  • Incompatibility: When PEFT/LoRA is active, `num_layers_unfrozen` is ignored with a warning.
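The percentages in the list above are plain layer-count ratios. The sketch below (hypothetical helper `unfrozen_fraction`) computes them, counting decoder layers only and ignoring embeddings, heads, and differences in per-layer size. GPT-J-6B's 28 layers come from the example config cited on this page; the 44-layer count for GPT-NeoX-20B is added here for comparison.

```python
def unfrozen_fraction(num_layers_unfrozen: int, num_layers: int) -> float:
    """Fraction of decoder layers left trainable (embeddings and heads not counted)."""
    if num_layers_unfrozen < 0:
        return 1.0  # -1: full fine-tuning
    return min(num_layers_unfrozen, num_layers) / num_layers


print(f"{unfrozen_fraction(2, 28):.1%}")  # GPT-J-6B, 2 unfrozen: 7.1%
print(f"{unfrozen_fraction(2, 44):.1%}")  # GPT-NeoX-20B, 2 unfrozen: 4.5%
print(f"{unfrozen_fraction(8, 28):.1%}")  # GPT-J-6B, 8 unfrozen: 28.6%
```

Because the frozen head copies only the unfrozen layers (plus the final norm and LM head), the same ratio roughly indicates how much of the decoder stack it duplicates, versus all of it for a full reference model.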

Reasoning

The intuition is that lower transformer layers capture general linguistic knowledge that does not need to change during RLHF, while the top layers handle task-specific generation patterns. By freezing the bottom layers, we:

  1. Reduce trainable parameters and memory for optimizer states
  2. Reuse the frozen layers as part of an implicit reference model (via the "frozen head" architecture), so only the top layers need a second copy
  3. Prevent catastrophic forgetting of pre-trained knowledge
  4. Enable training much larger models on limited hardware
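To make point 1 concrete: Adam-style optimizers keep two moment buffers per trainable parameter, so optimizer-state memory scales with the trainable parameter count. The sketch below assumes fp32 buffers (4 bytes each) and a hypothetical round 6B-parameter model whose trainable share is crudely approximated as 2 of 28 layers; real numbers also depend on embeddings, heads, and the mixed-precision setup.

```python
def adam_state_gib(trainable_params: float, states_per_param: int = 2,
                   bytes_per_state: int = 4) -> float:
    """Approximate GiB for Adam/AdamW moment buffers (exp_avg and exp_avg_sq)."""
    return trainable_params * states_per_param * bytes_per_state / 2**30


total_params = 6.0e9               # hypothetical round number for a 6B model
partial = total_params * 2 / 28    # crude approximation: 2 of 28 layers trainable

print(f"full fine-tuning:  {adam_state_gib(total_params):.1f} GiB")  # 44.7 GiB
print(f"2 layers unfrozen: {adam_state_gib(partial):.1f} GiB")       # 3.2 GiB
```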

The frozen head architecture (`AutoModelForCausalLMWithHydraValueHead`) creates a deep copy of only the unfrozen decoder layers, the final norm, and the LM head. Together with the shared frozen lower layers, this copy acts as the reference model for the KL divergence: `forward_hydra()` runs the frozen head on the hidden state at the branch point instead of doing a separate full-model forward pass.
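The frozen-head idea can be illustrated with a toy model in plain Python. This is an illustrative sketch, not trlx's `AutoModelForCausalLMWithHydraValueHead`: scalar "layers" stand in for decoder blocks, and `ToyLayer`, `ToyHydra`, and `forward_both` are hypothetical names. One pass through the frozen trunk feeds both the trainable top layers and their deep-copied reference head.

```python
import copy
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyLayer:
    """Stand-in for a decoder block: y = weight * x + bias."""
    weight: float
    bias: float = 0.0

    def __call__(self, x: float) -> float:
        return self.weight * x + self.bias


def run(layers: List[ToyLayer], x: float) -> float:
    # Apply layers in order, bottom to top.
    for layer in layers:
        x = layer(x)
    return x


class ToyHydra:
    """Toy version of the frozen-head idea (illustrative, not trlx's class)."""

    def __init__(self, layers: List[ToyLayer], num_layers_unfrozen: int):
        if num_layers_unfrozen <= 0:
            raise ValueError("a frozen head needs num_layers_unfrozen > 0")
        self.trunk = layers[:-num_layers_unfrozen]    # frozen, shared by both branches
        self.top = layers[-num_layers_unfrozen:]      # trainable policy layers
        self.frozen_head = copy.deepcopy(self.top)    # reference copy, never updated

    def forward_both(self, x: float) -> Tuple[float, float]:
        h = run(self.trunk, x)  # the frozen trunk runs once
        return run(self.top, h), run(self.frozen_head, h)


model = ToyHydra([ToyLayer(1.0 + 0.1 * i) for i in range(6)], num_layers_unfrozen=2)
policy_before, ref_before = model.forward_both(1.0)

model.top[-1].weight += 0.5  # pretend an optimizer step updated a trainable layer
policy_after, ref_after = model.forward_both(1.0)

print(policy_before, policy_after)  # policy output changes
print(ref_before, ref_after)        # reference output does not
```

After the simulated update, the policy output moves while the reference output stays at its initial value, which is what the KL term needs. The bottom layers never need a second copy because they never change.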

Code Evidence

Causal model freezing from `trlx/utils/modeling.py:22-39`:

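import torch.nn as nn


def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0):
    """Freeze the bottom decoder blocks (and embeddings) of a causal LM."""
    # hf_get_decoder_blocks is defined in the same module (trlx/utils/modeling.py)
    hidden_layers = hf_get_decoder_blocks(model)
    if num_layers_unfrozen == 0:
        # Freeze every decoder block plus the input and output embeddings
        hidden_layers_to_freeze = list(hidden_layers)
        hidden_layers_to_freeze += [model.get_input_embeddings(),
                                    model.get_output_embeddings()]
    elif num_layers_unfrozen > 0:
        # Keep only the top `num_layers_unfrozen` blocks trainable
        hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen]
        hidden_layers_to_freeze += [model.get_input_embeddings()]
        # A tied LM head shares its weights with the input embeddings
        if model.config.tie_word_embeddings:
            hidden_layers_to_freeze += [model.get_output_embeddings()]
    else:
        # Negative values (e.g. -1): full fine-tuning, nothing is frozen
        hidden_layers_to_freeze = []
    for layer in hidden_layers_to_freeze:
        layer.requires_grad_(False)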
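The `num_layers_unfrozen == 0` branch exists because of how Python slicing treats zero: `-0` is just `0`, so `blocks[:-0]` would select nothing instead of everything. The pure-Python sketch below mirrors the block-selection logic of the function above using indices instead of modules (`frozen_block_indices` is a hypothetical name). It treats negative values as "freeze nothing", consistent with `-1` meaning full fine-tuning, and leaves embedding handling out.

```python
from typing import List


def frozen_block_indices(num_layers: int, num_layers_unfrozen: int) -> List[int]:
    """Indices of decoder blocks that would be frozen for a given setting."""
    blocks = list(range(num_layers))
    if num_layers_unfrozen == 0:
        return blocks                          # freeze every block
    if num_layers_unfrozen > 0:
        return blocks[:-num_layers_unfrozen]   # keep the top k trainable
    return []                                  # negative (-1): freeze nothing


# Why 0 needs its own branch: a "negative zero" slice bound is just 0.
assert list(range(5))[:-0] == []

print(frozen_block_indices(6, 2))   # [0, 1, 2, 3]
print(frozen_block_indices(6, 0))   # [0, 1, 2, 3, 4, 5]
print(frozen_block_indices(6, -1))  # []
```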

Frozen head creation from `trlx/models/modeling_ppo.py:389-408`:

class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead):
    _supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]

    def __init__(self, base_model, *, num_layers_unfrozen=-1, peft_config=None, ...):
        super().__init__(base_model, peft_config=peft_config, ...)
        self.num_layers_unfrozen = num_layers_unfrozen
        if self.num_layers_unfrozen > 0 and not self.peft_type:
            config = self.base_model.config
            branch_class = hf_get_branch_class(config)
            self.frozen_head = branch_class(
                self.base_model, num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()

PEFT incompatibility warning from `trlx/trainer/accelerate_base_trainer.py:163-169`:

if self.config.model.num_layers_unfrozen >= 0:
    logger.warning(
        "The argument num_layers_unfrozen is ignored when using peft, "
        "to prevent unexpected behaviour. "
        "For Lora, use the `LoraConfig` argument `modules_to_save` instead."
    )

Real-world usage from `examples/hh/ppo_hh.py:39`:

num_layers_unfrozen=2  # GPT-J-6B: 2 of 28 layers
