
Heuristic:CarperAI Trlx Reward Model Layer Freezing

From Leeroopedia



Knowledge Sources
Domains Reinforcement_Learning, Optimization, LLMs
Last Updated 2026-02-07 16:00 GMT

Overview

A training-stability technique that freezes the bottom 70% of transformer layers when training a reward model from human preference comparisons, preventing catastrophic forgetting while leaving sufficient capacity for adaptation.

Description

When training a GPT-J-based reward model on pairwise human preference data (chosen vs rejected completions), freezing the first 70% of transformer layers balances preserving pre-trained knowledge against learning the reward signal. The unfrozen fraction (30%) is higher than in PPO policy training (typically 7%) because reward modeling requires more layers to adapt to the comparison task. The reward model uses a simple linear head (`v_head`) on top of the transformer backbone, scoring each token position and extracting the end-of-sequence score for comparison.
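A minimal sketch of such a reward head, assuming per-position hidden states and known end-of-sequence indices (the class name and interface here are illustrative, not the actual trlX API):

```python
import torch
import torch.nn as nn

class RewardHead(nn.Module):
    """Linear head that scores every token position of a transformer backbone;
    the score at the end-of-sequence token serves as the sequence reward."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # Maps each hidden state to a single scalar score per position
        self.v_head = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor, eos_positions: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden); eos_positions: (batch,)
        scores = self.v_head(hidden_states).squeeze(-1)  # (batch, seq_len)
        # Extract the score at each sequence's end-of-sequence token
        return scores[torch.arange(scores.size(0)), eos_positions]
```

During training, the head is called once on the chosen sequence and once on the rejected one, and the two end-of-sequence scores feed the pairwise loss.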

Usage

Apply this heuristic when training a reward model from scratch on pairwise comparison data (e.g., chosen vs rejected summaries). The 70/30 freeze ratio is specifically tuned for GPT-J (28 layers) but the principle generalizes to other model sizes. Combine with conservative training hyperparameters: `lr=1e-5`, `batch_size=1`, `gradient_accumulation_steps=4`, and `fp16=True`.
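Since the principle generalizes beyond GPT-J's 28 layers, the split can be computed for any layer count. A small helper (hypothetical, not part of trlX) that mirrors the recipe:

```python
def freeze_split(num_layers: int, unfrozen_frac: float = 0.3):
    """Return (frozen, trainable) layer-index ranges for a bottom-freeze strategy.

    Mirrors the trlX GPT-J recipe: int(unfrozen_frac * num_layers) top layers
    stay trainable; everything below the boundary is frozen.
    """
    num_unfrozen = int(unfrozen_frac * num_layers)
    boundary = num_layers - num_unfrozen
    return range(0, boundary), range(boundary, num_layers)

# GPT-J: 28 layers -> freeze layers 0-19, train layers 20-27
frozen, trainable = freeze_split(28)
```

Note that `int()` truncates, so for layer counts that are not multiples of ten the trainable slice rounds down (e.g. 24 layers yields 7 trainable, about 29%).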

The Insight (Rule of Thumb)

  • Action: Freeze the first 70% of transformer layers during reward model training.
  • Value:
    • Freeze ratio: 70% bottom layers frozen, 30% top layers trainable
    • For GPT-J (28 layers): freeze layers 0-19, train layers 20-27
    • Learning rate: `1e-5` (conservative)
    • Batch size: `1` per device with `gradient_accumulation_steps=4` (effective batch 4)
    • Epochs: 5
    • Mixed precision: `fp16=True`
  • Trade-off: More unfrozen layers than PPO (30% vs 7%) because the reward task differs more from the pre-training objective than continuation generation does. Too few unfrozen layers leads to underfitting on comparison data; too many leads to overfitting and loss of general language understanding.
  • Metric: Track pairwise accuracy (% of correctly ordered chosen > rejected pairs). Good models achieve 65-75% accuracy on held-out data.

Reasoning

Reward modeling is a classification-like task (which completion is better?) overlaid on a generative model. The bottom layers capture syntax, semantics, and general knowledge that transfer well from pre-training. The top layers need to adapt to output a scalar reward signal that distinguishes quality differences between completions. The 70/30 ratio provides enough trainable capacity while preventing catastrophic forgetting.

The conservative batch size (1 per device) is necessary because each training example consists of TWO full-length sequences (chosen + rejected), effectively doubling memory per sample. Gradient accumulation compensates for the small batch. The pairwise comparison loss uses log-sigmoid of the reward difference, computed only from the divergence point onward (where chosen and rejected actually differ).
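The divergence-masked loss can be illustrated on plain Python lists (a sketch of the idea; trlX's implementation operates on padded token tensors in batches):

```python
import math

def pairwise_loss(chosen_rewards, rejected_rewards, chosen_ids, rejected_ids):
    """Mean -log sigmoid(r_chosen - r_rejected) over the positions from the
    first token where the two sequences diverge onward.

    Assumes the pair is not identical (the dataset filter skips such pairs).
    """
    # First position where the token ids differ
    div = next(i for i, (c, r) in enumerate(zip(chosen_ids, rejected_ids)) if c != r)
    losses = [
        -math.log(1.0 / (1.0 + math.exp(-(c - r))))  # -log sigmoid(c - r)
        for c, r in zip(chosen_rewards[div:], rejected_rewards[div:])
    ]
    return sum(losses) / len(losses)
```

When the chosen completion is scored higher than the rejected one at every compared position, the loss falls below `-log(0.5) ≈ 0.693`; masking out the shared prefix keeps positions where the two sequences are identical from diluting the signal.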

Code Evidence

Layer freezing strategy from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:120-125`:

layers = model.transformer.h
num_layers = len(layers)
num_unfrozen = int(0.3 * num_layers)
for layer in layers[:-num_unfrozen]:
    layer.requires_grad_(False)

Conservative training configuration from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:96-115`:

training_args = TrainingArguments(
    output_dir="rm_checkpoint/",
    num_train_epochs=5,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=100,
    fp16=True,
    learning_rate=1e-5,
    deepspeed="ds_config_gpt_j.json",
)

Pairwise comparison loss from `examples/summarize_rlhf/reward_model/reward_model.py:51-89`:

bs = input_ids.shape[0] // 2
chosen = input_ids[:bs]
rejected = input_ids[bs:]
# ...
loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()

Dataset quality filtering from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:11-26`:

if chosen_summary == rejected_summary:
    continue  # Skip identical pairs
if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
    continue  # Skip very short summaries

Accuracy metric from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:78-86`:

def compute_metrics(eval_preds):
    chosen_end_scores = eval_preds.predictions[0]
    rejected_end_scores = eval_preds.predictions[1]
    result = {}
    acc = sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores)
    result["accuracy"] = acc
    return result
