Heuristic: CarperAI Trlx Delta Rewards
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Optimization, LLMs |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Reward normalization technique that computes the improvement (delta) between the trained model's reward and the reference model's reward, reducing variance and centering the learning signal.
Description
Instead of using absolute reward scores from a reward model, delta rewards measure the improvement over a baseline. The baseline is the reward score of the original (reference) model's completion for the same prompt. This is combined with running or reference statistics normalization (`scale_reward`) and optional reward clipping (`cliprange_reward`) to produce a stable training signal. trlx implements three reward scaling modes: `"running"` (adaptive normalization by running statistics), `"ref"` (normalization by initial reference statistics), and `None` (no normalization).
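The three scaling modes can be sketched as a single dispatch on the configured mode. This is an illustrative sketch, not trlx's actual code; the function name `scale_score` and scalar interface are assumptions for clarity:

```python
def scale_score(score: float, mode, running_std: float, ref_std: float) -> float:
    """Illustrative sketch of trlx's three scale_reward modes (not the library's API)."""
    if mode == "running":
        return score / running_std  # adaptive: divide by running std
    elif mode == "ref":
        return score / ref_std      # fixed: divide by initial reference std
    return score                    # None: leave the raw score untouched
```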
Usage
Apply this heuristic when reward scores have high variance across prompts, or when absolute reward magnitudes are inconsistent. Delta rewards are particularly important for summarization and dialogue tasks where prompt difficulty varies significantly. The `scale_reward="running"` mode is the recommended default for most tasks.
The Insight (Rule of Thumb)
- Action: Compute reward as `new_reward - baseline_reward` and use `scale_reward="running"` for normalization.
- Value:
- Delta mode: `reward = get_reward(new_completion) - get_reward(reference_completion)`
- Scale mode: `"running"` (adaptive, recommended), `"ref"` (fixed baseline), or `None` (raw)
- Clip range: `cliprange_reward=10` (prevent extreme outliers)
- Reward padding: `-np.inf` sentinel for masked positions
- Trade-off: Delta rewards add one extra reward model forward pass per batch (for baseline scores) but significantly reduce reward variance and improve training stability.
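The knobs above would sit together in the method section of a trlx config. A hedged sketch as a plain dict (field names follow trlx's `PPOConfig`; the values are the illustrative defaults from this page, not mandated settings):

```python
# Sketch of the reward-shaping fields in a trlx PPO method config.
# Values are illustrative; None for ref_mean/ref_std means they are
# computed from the first batch when scale_reward="ref" is used.
method_config = {
    "scale_reward": "running",  # "running" | "ref" | None
    "cliprange_reward": 10,     # clip scaled rewards to [-10, 10]
    "ref_mean": None,
    "ref_std": None,
}
```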
Reasoning
Absolute reward scores conflate prompt difficulty with model quality: an easy prompt may yield high scores for any model, while a hard prompt yields low scores even for good completions. By subtracting the reference model's score, we isolate the improvement attributable to training. This centers the reward distribution around zero and reduces variance across prompts.
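A toy numeric example (made-up scores) makes the centering effect concrete: the absolute scores differ wildly by prompt difficulty, while the deltas agree on how much training helped:

```python
# Trained-model and reference-model scores for the same two prompts
# (numbers are invented for illustration only).
trained = {"easy_prompt": 8.0, "hard_prompt": 2.5}
reference = {"easy_prompt": 7.5, "hard_prompt": 2.0}

# Delta reward: improvement over the reference completion per prompt.
deltas = {p: trained[p] - reference[p] for p in trained}
# Both prompts yield the same delta (0.5) despite a 5.5-point gap in
# absolute scores, so prompt difficulty no longer dominates the signal.
```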
The running moments normalization (Welford's online algorithm) further stabilizes rewards by dividing by the running standard deviation, ensuring the PPO advantage estimates remain in a consistent numerical range throughout training. The `1e-8` epsilon in the whitening function (`torch.rsqrt(var + 1e-8)`) prevents division by zero.
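The two mechanisms described above can be sketched in a few lines. This is a minimal reimplementation for illustration, assuming a Chan-style parallel merge of batch moments; it mirrors the spirit of trlx's `RunningMoments` and whitening helper, not their exact code:

```python
import torch

class RunningMoments:
    """Welford/Chan-style online mean and variance over reward batches (sketch)."""
    def __init__(self):
        self.mean = 0.0
        self.var = 1.0
        self.count = 1e-24  # tiny init avoids division by zero on the first update

    def update(self, xs: torch.Tensor):
        n = xs.numel()
        batch_mean = xs.mean().item()
        batch_var = xs.var(unbiased=False).item()
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean += delta * n / total
        # Merge the old and new groups' variances (parallel-variance formula).
        self.var = (self.var * self.count + batch_var * n
                    + delta ** 2 * self.count * n / total) / total
        self.count = total
        return self.mean, self.var ** 0.5

def whiten(xs: torch.Tensor) -> torch.Tensor:
    # Center and scale; the 1e-8 epsilon guards against zero variance.
    return (xs - xs.mean()) * torch.rsqrt(xs.var(unbiased=False) + 1e-8)
```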
Code Evidence
Delta reward computation from `examples/hh/ppo_hh.py:191-200`:
```python
def reward_fn(samples, prompts, original_output, **kwargs):
    samples = [s + reward_tokenizer.eos_token for s in samples]
    rewards = get_reward(samples)
    if not delta_reward:
        return rewards
    original_samples = [
        p + o + reward_tokenizer.eos_token
        for p, o in zip(prompts, original_output)
    ]
    original_rewards = get_reward(original_samples)
    return rewards - original_rewards
```
Summarization delta normalization from `examples/summarize_rlhf/trlx_gptj_text_summarization.py:147-153`:
```python
def reward_fn(samples: List[str], **kwargs):
    original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples]
    original_samples = [text + post_summary_dict[text.strip()] for text in original_samples]
    original_scores = get_scores(original_samples)
    scores = get_scores(samples)
    norms_scores = scores - original_scores
    return norms_scores
```
Reward scaling in PPO trainer from `trlx/trainer/accelerate_ppo_trainer.py:364-381`:
```python
if self.ref_mean is None:
    self.ref_mean = (scores * scores_mask).sum(dim=1).mean()
    self.ref_std = (scores * scores_mask).sum(dim=1).std()
all_scores_mean, all_scores_std = self.running_moments.update(
    torch.sum(scores * scores_mask, dim=1)
)
if self.config.method.scale_reward == "running":
    scores /= self.running_moments.std
elif self.config.method.scale_reward == "ref":
    scores /= self.ref_std
```
Running moments with Welford's algorithm from `trlx/utils/modeling.py:275-307`:
```python
class RunningMoments:
    def __init__(self):
        self.mean = 0
        self.std = 1
        self.var = 1
        self.count = 1e-24  # small init to avoid division by zero
```