Heuristic: DreamerV3 Percentile Return Normalization (danijar/dreamerv3)
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Optimization, Model_Based_RL |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Scale-free return normalization using running estimates of the 5th and 95th return percentiles instead of mean/std, yielding a scale that is robust to skewed and multi-modal return distributions.
Description
DreamerV3 normalizes imagined returns using the percentile method: instead of dividing by standard deviation, it computes the 5th and 95th percentiles of returns via exponential moving averages and normalizes by their difference. This makes the policy gradient updates invariant to the absolute scale and distribution shape of returns.
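A minimal single-batch sketch of this idea (plain NumPy, no EMA tracking; the function name is illustrative) makes the invariance concrete: shifting or rescaling all returns leaves the normalized values unchanged.

```python
import numpy as np

def perc_normalize(returns, perclo=5.0, perchi=95.0, limit=1e-8):
    """Single-batch sketch of percentile normalization (no EMA state)."""
    lo, hi = np.percentile(returns, [perclo, perchi])
    return (returns - lo) / max(hi - lo, limit)

rng = np.random.default_rng(0)
ret = rng.exponential(1.0, size=1000)  # heavily skewed returns

# Multiplying or shifting every return leaves the normalized values unchanged,
# so gradient magnitudes do not depend on the environment's reward scale.
assert np.allclose(perc_normalize(ret), perc_normalize(1000.0 * ret))
assert np.allclose(perc_normalize(ret), perc_normalize(ret + 50.0))
```

The real implementation replaces the per-batch percentiles with exponential moving averages so the scale is stable across gradient steps.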
Three separate normalizers are used:
- retnorm (percentile): Normalizes the advantage (return - baseline) for the policy gradient. Uses 5th/95th percentiles with EMA rate 0.01.
- valnorm (none by default): Normalizes value targets before the value loss. Disabled by default.
- advnorm (none by default): Additional advantage normalization. Disabled by default.
The `Normalize` class supports three implementations: `none` (passthrough), `meanstd` (standard normalization), and `perc` (percentile-based).
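A hypothetical pure-NumPy sketch of the three implementations (the real class keeps its running statistics in `nj.Module` state; the dict-based `state` and the function name here are illustrative assumptions):

```python
import numpy as np

def normalize_stats(impl, x, state, rate=0.01, limit=1e-8,
                    perclo=5.0, perchi=95.0):
    """Update running stats from batch `x`; return (offset, scale)
    to be applied as (x - offset) / scale."""
    if impl == 'none':
        return 0.0, 1.0  # passthrough
    if impl == 'meanstd':
        # EMA of first and second moments, std from their difference.
        state['mean'] = (1 - rate) * state.get('mean', 0.0) + rate * x.mean()
        state['sqrs'] = (1 - rate) * state.get('sqrs', 0.0) + rate * (x ** 2).mean()
        var = max(state['sqrs'] - state['mean'] ** 2, 0.0)
        return state['mean'], max(np.sqrt(var), limit)
    if impl == 'perc':
        # EMA of the low/high percentiles, scale floored at `limit`.
        lo, hi = np.percentile(x, [perclo, perchi])
        state['lo'] = (1 - rate) * state.get('lo', 0.0) + rate * lo
        state['hi'] = (1 - rate) * state.get('hi', 0.0) + rate * hi
        return state['lo'], max(state['hi'] - state['lo'], limit)
    raise ValueError(f'unknown impl: {impl}')
```

With `impl='perc'` and `limit=1.0` (the `retnorm` defaults below), the returned scale never drops below 1, so near-constant returns pass through unscaled.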
Usage
Applied automatically during imagination-based policy optimization. The percentile normalizer is the key innovation enabling fixed hyperparameters: it prevents the policy entropy coefficient (`actent=3e-4`) from being dominated by large returns in high-reward environments or irrelevant in low-reward ones.
The Insight (Rule of Thumb)
- Action: Use percentile-based return normalization (`impl: perc`) with the 5th and 95th percentiles for the return normalizer. Keep value and advantage normalizers at `none`.
- Value: `perclo=5.0, perchi=95.0, rate=0.01, limit=1.0`. The `limit=1.0` sets the minimum scale to 1.0 (prevents division by near-zero when 5th≈95th percentile).
- Trade-off: Percentile normalization is robust to outliers and skewed distributions (unlike mean/std), but the EMA tracking introduces a small lag after abrupt shifts in the return distribution. With `debias: False`, the EMA estimates are not bias-corrected, so they start near zero and the `limit` floor dominates early in training, leaving returns effectively unscaled until the statistics warm up.
- Compatibility: Requires imagined returns to have sufficient variance for meaningful percentile estimation. Works across environments with reward scales from 0-1 (Crafter) to 0-100,000+ (Atari).
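A quick illustration of the robustness claim on sparse, outlier-heavy returns (the numbers come from this toy setup, not from DreamerV3 itself):

```python
import numpy as np

# 1000 trajectories: most fail (return 0), a few succeed spectacularly.
ret = np.zeros(1000)
ret[:10] = 1000.0

std_scale = ret.std()  # dragged up by the 1% of outliers (~99.5)
lo, hi = np.percentile(ret, [5.0, 95.0])  # both 0: successes lie beyond p95
perc_scale = max(hi - lo, 1.0)            # limit=1.0 floors the scale

assert std_scale > 99.0   # mean/std would shrink all returns ~100x
assert perc_scale == 1.0  # percentile norm leaves sparse rewards untouched
```

Under mean/std, the rare successes inflate the scale and crush the gradient signal for the common case; under percentile normalization, the 5th–95th spread ignores the outliers and the `limit` floor lets sparse rewards pass through at their natural scale.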
Reasoning
Standard mean/std normalization is sensitive to outliers and assumes roughly Gaussian returns. In RL, returns are often heavily skewed (most trajectories fail, few succeed spectacularly). Percentile normalization captures the "spread" of the return distribution without being pulled by extreme values.
The normalization ensures that the policy entropy coefficient `actent=3e-4` provides consistent exploration pressure regardless of return scale: normalized returns are in the approximate range [0, 1], so the entropy bonus is always a meaningful fraction of the gradient signal.
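A back-of-envelope check of that claim, assuming a discrete action space of 17 actions (Crafter's count; the specific number is an illustrative assumption):

```python
import math

actent = 3e-4
max_entropy = math.log(17)            # ~2.83 nats for a uniform policy
entropy_term = actent * max_entropy   # ~8.5e-4

# With normalized advantages of order 1, the policy-gradient term
# logpi * adv_normed is also of order 1, so the entropy bonus stays a
# small, scale-independent fraction of the loss in every environment.
assert 5e-4 < entropy_term < 1e-3
```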
Implementation from `embodied/jax/utils.py:L16-91`:
```python
class Normalize(nj.Module):

  rate: float = 0.01
  limit: float = 1e-8
  perclo: float = 5.0
  perchi: float = 95.0

  def stats(self):
    # Excerpt abridged to the 'perc' branch; `corr` is the EMA debias
    # correction computed earlier in the source (1.0 with debias off).
    if self.impl == 'perc':
      lo, hi = self.lo.read() * corr, self.hi.read() * corr
      return sg(lo), sg(jnp.maximum(self.limit, hi - lo))

  def _perc(self, x, q):
    # Gather across data-parallel devices so every replica computes the
    # percentile over the same global batch.
    axes = internal.get_data_axes()
    if axes:
      x = jax.lax.all_gather(x, axes)
    x = jnp.percentile(x, q)
    return x
```
Usage in the actor-critic loss from `dreamerv3/agent.py:L407-414`:
```python
roffset, rscale = retnorm(ret, update)
adv = (ret - tarval[:, :-1]) / rscale
aoffset, ascale = advnorm(adv, update)
adv_normed = (adv - aoffset) / ascale
logpi = sum([v.logp(sg(act[k]))[:, :-1] for k, v in policy.items()])
ents = {k: v.entropy()[:, :-1] for k, v in policy.items()}
policy_loss = sg(weight[:, :-1]) * -(
    logpi * sg(adv_normed) + actent * sum(ents.values()))
```
Default config from `dreamerv3/configs.yaml:L111-113`:
```yaml
retnorm: {impl: perc, rate: 0.01, limit: 1.0, perclo: 5.0, perchi: 95.0, debias: False}
valnorm: {impl: none, rate: 0.01, limit: 1e-8}
advnorm: {impl: none, rate: 0.01, limit: 1e-8}
```