Heuristic: CarperAI Trlx Reward Model Layer Freezing
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Optimization, LLMs |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Training stability technique that freezes the bottom 70% of transformer layers when training a reward model from human preference comparisons, preventing catastrophic forgetting while allowing sufficient adaptation.
Description
When training a GPT-J-based reward model on pairwise human preference data (chosen vs rejected completions), freezing the first 70% of transformer layers balances preserving pre-trained knowledge against learning the reward signal. This leaves a larger unfrozen fraction (30%) than is typical for PPO policy training (around 7%), because reward modeling requires more layers to adapt to the comparison task. The reward model uses a simple linear head (`v_head`) on top of the transformer backbone, scoring each token position and extracting the end-of-sequence score for comparison.
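The scoring head described above can be sketched as follows. This is an illustrative minimal version, not the repo's `GPTRewardModel` class: the `RewardHead` name, shapes, and `eos_positions` argument are assumptions for the sketch.

```python
import torch
import torch.nn as nn

class RewardHead(nn.Module):
    """Illustrative sketch: a linear v_head scoring every token position,
    from which the end-of-sequence score is extracted."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.v_head = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor, eos_positions: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, hidden) from the transformer backbone
        scores = self.v_head(hidden_states).squeeze(-1)  # (batch, seq_len)
        # Keep only the score at each sequence's end-of-sequence position
        return scores.gather(1, eos_positions.unsqueeze(1)).squeeze(1)

# Toy usage with random hidden states
head = RewardHead(hidden_size=16)
h = torch.randn(2, 10, 16)
eos = torch.tensor([9, 7])
rewards = head(h, eos)  # one scalar reward per sequence
```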
Usage
Apply this heuristic when training a reward model from scratch on pairwise comparison data (e.g., chosen vs rejected summaries). The 70/30 freeze ratio is specifically tuned for GPT-J (28 layers) but the principle generalizes to other model sizes. Combine with conservative training hyperparameters: `lr=1e-5`, `batch_size=1`, `gradient_accumulation_steps=4`, and `fp16=True`.
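Since the principle generalizes beyond GPT-J's 28 layers, a parameterized freeze helper can make the ratio explicit. A minimal sketch (the `freeze_bottom_layers` name is hypothetical; `layers` stands in for a module list like `model.transformer.h`):

```python
import torch.nn as nn

def freeze_bottom_layers(layers, unfrozen_ratio: float = 0.3) -> int:
    """Freeze the bottom (1 - unfrozen_ratio) fraction of transformer blocks.

    `layers` is any indexable sequence of nn.Modules. Returns the number
    of frozen blocks.
    """
    num_unfrozen = int(unfrozen_ratio * len(layers))
    num_frozen = len(layers) - num_unfrozen
    for layer in layers[:num_frozen]:
        for p in layer.parameters():
            p.requires_grad = False
    return num_frozen

# Toy stand-in for a 28-block GPT-J stack
blocks = nn.ModuleList(nn.Linear(4, 4) for _ in range(28))
frozen = freeze_bottom_layers(blocks)  # freezes blocks 0-19, leaves 20-27 trainable
```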
The Insight (Rule of Thumb)
- Action: Freeze the first 70% of transformer layers during reward model training.
- Value:
- Freeze ratio: 70% bottom layers frozen, 30% top layers trainable
- For GPT-J (28 layers): freeze layers 0-19, train layers 20-27
- Learning rate: `1e-5` (conservative)
- Batch size: `1` per device with `gradient_accumulation_steps=4` (effective batch 4)
- Epochs: 5
- Mixed precision: `fp16=True`
- Trade-off: More unfrozen layers than PPO (30% vs 7%) because the reward task differs more from the pre-training objective than continuation generation does. Too few unfrozen layers leads to underfitting on comparison data; too many leads to overfitting and loss of general language understanding.
- Metric: Track pairwise accuracy (% of correctly ordered chosen > rejected pairs). Good models achieve 65-75% accuracy on held-out data.
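The pairwise accuracy metric above reduces to a simple comparison over held-out score pairs. A toy sketch with hypothetical scores (the values below are invented for illustration):

```python
import numpy as np

# Hypothetical end-of-sequence reward scores for four chosen/rejected pairs
chosen_end_scores = np.array([1.2, 0.3, 2.1, -0.5])
rejected_end_scores = np.array([0.8, 0.9, 1.0, -1.0])

# Pairwise accuracy: fraction of pairs where the chosen completion outscores the rejected one
accuracy = float(np.mean(chosen_end_scores > rejected_end_scores))  # 0.75 here
```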
Reasoning
Reward modeling is a classification-like task (which completion is better?) overlaid on a generative model. The bottom layers capture syntax, semantics, and general knowledge that transfer well from pre-training. The top layers need to adapt to output a scalar reward signal that distinguishes quality differences between completions. The 70/30 ratio provides enough trainable capacity while preventing catastrophic forgetting.
The conservative batch size (1 per device) is necessary because each training example consists of TWO full-length sequences (chosen + rejected), effectively doubling memory per sample. Gradient accumulation compensates for the small batch. The pairwise comparison loss uses log-sigmoid of the reward difference, computed only from the divergence point onward (where chosen and rejected actually differ).
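The truncation-from-divergence idea can be sketched for a single chosen/rejected pair. This is a simplified illustration, not the repo's batched implementation; the function name and shapes are assumptions:

```python
import torch
import torch.nn.functional as F

def pairwise_loss(c_rewards, r_rewards, c_ids, r_ids):
    """Sketch: per-token rewards and token ids for one chosen/rejected pair,
    all of shape (seq_len,). The loss uses only positions from the first
    token where the two sequences diverge."""
    diff = (c_ids != r_ids).nonzero()
    start = diff[0].item() if len(diff) > 0 else 0
    c_trunc = c_rewards[start:]
    r_trunc = r_rewards[start:]
    # Log-sigmoid of the reward difference, averaged over divergent positions
    return -F.logsigmoid(c_trunc - r_trunc).mean()
```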
Code Evidence
Layer freezing strategy from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:120-125`:
```python
layers = model.transformer.h          # GPT-J's stack of transformer blocks
num_layers = len(layers)              # 28 for GPT-J
num_unfrozen = int(0.3 * num_layers)  # 8 trainable top layers
for layer in layers[:-num_unfrozen]:  # bottom 20 layers
    layer.requires_grad_(False)
```
Conservative training configuration from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:96-115`:
```python
training_args = TrainingArguments(
    output_dir="rm_checkpoint/",
    num_train_epochs=5,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=100,
    fp16=True,
    learning_rate=1e-5,
    deepspeed="ds_config_gpt_j.json",
)
```
Pairwise comparison loss from `examples/summarize_rlhf/reward_model/reward_model.py:51-89`:
```python
bs = input_ids.shape[0] // 2   # batch holds chosen and rejected halves
chosen = input_ids[:bs]
rejected = input_ids[bs:]
# ...
loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()
```
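As an aside (not from the repo), `-log(sigmoid(x))` is mathematically equal to `-logsigmoid(x)`, and the fused PyTorch form is more numerically robust when the reward difference is very negative:

```python
import torch
import torch.nn.functional as F

# For a very negative reward difference, sigmoid underflows to 0 in fp32,
# so the naive -log(sigmoid(x)) blows up to inf; logsigmoid stays finite.
x = torch.tensor([-120.0, 0.0, 5.0])
naive = -torch.log(torch.sigmoid(x))
stable = -F.logsigmoid(x)
```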
Dataset quality filtering from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:11-26`:
```python
if chosen_summary == rejected_summary:
    continue  # Skip identical pairs
if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5:
    continue  # Skip very short summaries
```
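The two filters can be wrapped into a small standalone function. A sketch under assumed types (the `filter_pairs` name and the list-of-tuples input are hypothetical, not the repo's dataset builder):

```python
def filter_pairs(pairs):
    """Drop degenerate comparison pairs. Assumed input: a list of
    (chosen_summary, rejected_summary) string tuples."""
    kept = []
    for chosen, rejected in pairs:
        if chosen == rejected:
            continue  # identical pairs carry no preference signal
        if len(chosen.split()) < 5 or len(rejected.split()) < 5:
            continue  # very short summaries are usually noise
        kept.append((chosen, rejected))
    return kept
```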
Accuracy metric from `examples/summarize_rlhf/reward_model/train_reward_model_gptj.py:78-86`:
```python
def compute_metrics(eval_preds):
    chosen_end_scores = eval_preds.predictions[0]
    rejected_end_scores = eval_preds.predictions[1]
    result = {}  # added for completeness; elided in the excerpt
    acc = sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores)
    result["accuracy"] = acc
    return result
```