Heuristic: AllenAI open-instruct Logprob Clamping
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Reinforcement Learning |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Clamp log probability differences to [-40, 40] in the KL divergence computation to prevent numerical overflow.
Description
When computing the KL penalty between the current policy and the reference policy in GRPO, the difference `new_logprobs - ref_logprobs` can grow very large (positive or negative) as the policy diverges from the reference. Because KL estimators exponentiate this difference, unbounded values eventually overflow: float32 `exp` overflows to Inf once its argument exceeds roughly 88.7 (the log of the float32 maximum), and logprob differences of 100+ are plausible for a divergent policy. Clamping to [-40, 40] caps the exponential at about 2.35e17, leaving ample float32 headroom while preventing NaN/Inf propagation through the loss.
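A quick numerical check (NumPy used here purely for illustration) shows why the bound matters: float32 `exp` overflows near an argument of 88.7, while the clamp ceiling of 40 stays comfortably finite.

```python
import numpy as np

# float32 exp() overflows to Inf once its argument exceeds log(float32 max) ~ 88.7.
# A logprob difference of 100, plausible for a divergent policy, is already Inf;
# the clamp ceiling of 40 maps to ~2.35e17, far below the float32 max of ~3.4e38.
with np.errstate(over="ignore"):
    unclamped = np.exp(np.float32(100.0))  # Inf
clamped = np.exp(np.float32(40.0))         # ~2.35e17, finite

print(bool(np.isinf(unclamped)), bool(np.isfinite(clamped)))  # True True
```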
Usage
Apply this heuristic whenever computing KL divergence between policy and reference log probabilities. This is critical in GRPO training but applies to any on-policy RL algorithm with a KL penalty.
The Insight (Rule of Thumb)
- Action: Clamp `(new_logprobs - ref_logprobs)` to the range [-40.0, 40.0] before computing KL.
- Value: Bounds of -40.0 and 40.0.
- Trade-off: Extremely divergent policies will have their KL slightly underestimated, but this is desirable as it prevents training collapse from NaN gradients.
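As a minimal sketch of the rule in isolation, here is a clamped k3-style KL estimate in NumPy. The estimator choice and function name are illustrative assumptions, not taken from open-instruct; the key point is that the clamp happens before exponentiation.

```python
import numpy as np

def clamped_kl_k3(new_logprobs: np.ndarray, ref_logprobs: np.ndarray) -> np.ndarray:
    """Per-token KL(pi || ref) via Schulman's k3 estimator, with the
    logprob difference clamped to [-40, 40] before exponentiation."""
    diff = np.clip(new_logprobs - ref_logprobs, -40.0, 40.0)
    # k3: exp(ref - new) - (ref - new) - 1 = exp(-diff) + diff - 1, always >= 0.
    return np.exp(-diff) + diff - 1.0

# Identical policies give zero KL; a wildly divergent token stays finite.
same = clamped_kl_k3(np.array([-1.0, -2.0]), np.array([-1.0, -2.0]))
far = clamped_kl_k3(np.array([-500.0]), np.array([-1.0]))
print(bool(np.allclose(same, 0.0)), bool(np.isfinite(far).all()))  # True True
```

Without the `np.clip`, the `far` case would compute `exp(499)` and return Inf, which is exactly the failure mode the heuristic guards against.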
Reasoning
The KL estimator uses exponentials of logprob differences. Without clamping, a policy that drifts far from the reference (common during early training or after a large learning-rate step) can produce logprob differences of 100 or more, causing `exp(100)` to overflow to Inf in float32 and the loss to become NaN. The [-40, 40] range was chosen because `exp(40)` is approximately 2.35e17, well within float32 range, while still allowing detection of significant policy divergence.
Code Evidence
From `open_instruct/grpo_utils.py:260-266`:
```python
if ref_logprobs is not None:
    # We want the KL loss to backpropagate through the model.
    # We also clamp the KL loss to avoid numerical instability.
    # https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae
    ref_logprobs_diff = (new_logprobs - ref_logprobs).clamp(-40.0, 40.0)
    kl_all = model_utils.estimate_kl(ref_logprobs_diff, ratio)
    kl = kl_all[config.kl_estimator]
```