# Heuristic: CarperAI trlx Partial Layer Freezing
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Reinforcement_Learning, LLMs |
| Last Updated | 2026-02-07 16:00 GMT |
## Overview
A parameter-efficiency strategy that freezes the bottom transformer layers during PPO training, leaving only the top 2-8 layers trainable, to reduce memory usage and prevent catastrophic forgetting.
## Description
In trlx PPO training, `num_layers_unfrozen` controls how many of the top (last) transformer decoder layers remain trainable. All lower layers and the input embeddings are frozen. When partial freezing is active, a frozen head (deep copy of the unfrozen layers) is maintained as the reference model for KL divergence computation, avoiding the need to keep a full separate reference model in memory. This is distinct from PEFT/LoRA, which is handled separately and overrides the freezing mechanism.
## Usage
Apply this heuristic when training large models (6B+ parameters) with PPO, where keeping a full reference model would double GPU memory requirements. Set `num_layers_unfrozen=2` as a conservative default. Use `-1` to train all layers (full fine-tuning) or `0` to freeze all decoder layers and embeddings (only the value/reward heads remain trainable).
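As a minimal sketch of how this is set in practice, assuming trlx's packaged PPO example config and the field names of `trlx.data.configs.ModelConfig` (paths here are illustrative):

```python
# Sketch: enabling partial layer freezing in a trlx PPO run.
from trlx.data.configs import TRLConfig

config = TRLConfig.load_yaml("configs/ppo_config.yml")  # hypothetical path
config.model.num_layers_unfrozen = 2    # train only the top 2 decoder blocks
# config.model.num_layers_unfrozen = -1 # full fine-tuning instead
```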
## The Insight (Rule of Thumb)
- Action: Set `num_layers_unfrozen=2` in `ModelConfig` for large-model PPO training.
- Value:
  - `num_layers_unfrozen=2`: Standard for 6B-20B models (trains ~7% of layers)
  - `num_layers_unfrozen=8`: More aggressive; used for summarization tasks (GPT-J)
  - `num_layers_unfrozen=-1`: Full fine-tuning (all layers trainable; requires a separate reference model)
  - `num_layers_unfrozen=0`: Freeze all layers (only the value/reward heads train)
- Trade-off: Fewer unfrozen layers mean less memory and faster training, but less capacity for the model to adapt. The frozen-head trick saves ~50% of model memory by avoiding a full reference-model copy.
- Incompatibility: When PEFT/LoRA is active, `num_layers_unfrozen` is ignored with a warning.
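To make the "~7% of layers" figure concrete, here is a small back-of-the-envelope helper (hypothetical, not part of trlx) that reports which fraction of decoder blocks stays trainable for a given setting:

```python
def trainable_block_fraction(total_blocks: int, num_layers_unfrozen: int) -> float:
    """Fraction of decoder blocks left trainable under trlx-style freezing.

    -1 -> everything trains; 0 -> nothing trains; n > 0 -> top n blocks train.
    """
    if num_layers_unfrozen < 0:
        return 1.0
    return min(num_layers_unfrozen, total_blocks) / total_blocks

# GPT-J-6B has 28 decoder blocks:
print(round(trainable_block_fraction(28, 2), 3))  # 0.071, the "~7%" quoted above
print(trainable_block_fraction(28, -1))           # 1.0, full fine-tuning
```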
## Reasoning
The intuition is that lower transformer layers capture general linguistic knowledge that does not need to change during RLHF, while the top layers handle task-specific generation patterns. By freezing the bottom layers, we:
- Reduce trainable parameters and memory for optimizer states
- Use frozen layers as an implicit reference model (via the "frozen head" architecture)
- Prevent catastrophic forgetting of pre-trained knowledge
- Enable training much larger models on limited hardware
The frozen head architecture (`AutoModelForCausalLMWithHydraValueHead`) creates a deep copy of only the unfrozen decoder layers + final norm + LM head. This copy serves as the reference model for KL divergence, using `forward_hydra()` instead of a separate full model forward pass.
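The hydra idea can be illustrated with a torch-free toy model (a hypothetical sketch, not trlx's actual tensor code): the frozen bottom stack is shared by both branches, while the deep-copied top stack serves as a fixed reference even after the policy's top layers are updated.

```python
import copy

class ToyLayer:
    """Stand-in for a decoder block: multiplies its input by a scalar."""
    def __init__(self, scale):
        self.scale = scale
    def __call__(self, x):
        return x * self.scale

class ToyHydra:
    """Shared frozen bottom + trainable top; the deep-copied top is the reference."""
    def __init__(self, bottom, top):
        self.bottom = bottom                   # frozen, shared by both branches
        self.top = top                         # trainable policy branch
        self.frozen_head = copy.deepcopy(top)  # fixed reference branch

    def _run(self, layers, x):
        for layer in layers:
            x = layer(x)
        return x

    def forward(self, x):        # policy output
        return self._run(self.top, self._run(self.bottom, x))

    def forward_hydra(self, x):  # reference output for the KL term
        return self._run(self.frozen_head, self._run(self.bottom, x))

model = ToyHydra(bottom=[ToyLayer(2), ToyLayer(3)], top=[ToyLayer(5)])
before = model.forward_hydra(1)  # reference output: 1*2*3*5 = 30
model.top[0].scale = 7           # a "PPO update" changes only the policy top
assert model.forward(1) == 42           # policy shifted: 1*2*3*7
assert model.forward_hydra(1) == before # reference unchanged: still 30
```

The memory saving falls out of the structure: only the copied top layers are duplicated, while the (typically much larger) bottom stack is stored once.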
## Code Evidence
Causal model freezing from `trlx/utils/modeling.py:22-39`:
```python
def freeze_bottom_causal_layers(model, num_layers_unfrozen: int = 0):
    hidden_layers = hf_get_decoder_blocks(model)
    if num_layers_unfrozen == 0:
        # Freeze everything: all decoder blocks plus both embedding matrices
        hidden_layers_to_freeze = list(hidden_layers)
        hidden_layers_to_freeze += [model.get_input_embeddings(),
                                    model.get_output_embeddings()]
    elif num_layers_unfrozen > 0:
        # Freeze all but the top `num_layers_unfrozen` blocks, plus input embeddings
        hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen]
        hidden_layers_to_freeze += [model.get_input_embeddings()]
        if model.config.tie_word_embeddings:
            hidden_layers_to_freeze += [model.get_output_embeddings()]
    else:
        # num_layers_unfrozen == -1: full fine-tuning, freeze nothing
        hidden_layers_to_freeze = []
    for layer in hidden_layers_to_freeze:
        layer.requires_grad_(False)
```
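The `[:-num_layers_unfrozen]` slice above selects the bottom of the stack; this can be checked with plain lists (illustrative only):

```python
blocks = list(range(6))           # stand-ins for 6 decoder blocks
num_layers_unfrozen = 2
frozen = blocks[:-num_layers_unfrozen]
print(frozen)  # [0, 1, 2, 3] -- bottom 4 blocks frozen, top 2 trainable
```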
Frozen head creation from `trlx/models/modeling_ppo.py:389-408`:
```python
class AutoModelForCausalLMWithHydraValueHead(AutoModelForCausalLMWithValueHead):
    _supported_args = ["num_layers_unfrozen", "peft_config", "num_value_layers_unfrozen"]

    def __init__(self, base_model, *, num_layers_unfrozen=-1, peft_config=None, ...):
        super().__init__(base_model, peft_config=peft_config, ...)
        self.num_layers_unfrozen = num_layers_unfrozen
        if self.num_layers_unfrozen > 0 and not self.peft_type:
            config = self.base_model.config
            branch_class = hf_get_branch_class(config)
            self.frozen_head = branch_class(
                self.base_model,
                num_layers_unfrozen=self.num_layers_unfrozen,
            ).eval()
```
PEFT incompatibility warning from `trlx/trainer/accelerate_base_trainer.py:163-169`:
```python
if self.config.model.num_layers_unfrozen >= 0:
    logger.warning(
        "The argument num_layers_unfrozen is ignored when using peft, "
        "to prevent unexpected behaviour. "
        "For Lora, use the `LoraConfig` argument `modules_to_save` instead."
    )
```
Real-world usage from `examples/hh/ppo_hh.py:39`:
```python
num_layers_unfrozen=2,  # GPT-J-6B: train only 2 of 28 decoder layers
```