Heuristic:Danijar Dreamerv3 Free Nats KL Thresholding
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Model_Based_RL, Optimization |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Training stability technique that clamps KL divergence losses to a minimum threshold (free nats), preventing the posterior from collapsing onto the prior early in training.
Description
In the RSSM world model, the KL divergence between the posterior (informed by observations) and the prior (predicted from dynamics alone) is used as both a dynamics loss (train prior toward posterior) and a representation loss (train posterior toward prior). Without a minimum threshold, the model can satisfy the KL objective trivially by making both distributions identical and uninformative, leading to posterior collapse.
The free nats technique sets a floor: any KL divergence below the threshold is clamped to the threshold value, so it contributes a constant to the loss and no gradient. This allows the latent space to carry a minimum amount of information without penalty, so the model can encode useful state information before the KL regularization takes effect.
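The clamp can be illustrated with a minimal, framework-free sketch. It assumes small 4-class distributions and uses hypothetical helper names (`categorical_kl`, `free_nats_clamp`, `slope`) that do not appear in the DreamerV3 code base. A finite difference stands in for the gradient that an automatic-differentiation framework would compute.

```python
import math

def categorical_kl(p, q):
    """KL(p || q) in nats for two discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

def free_nats_clamp(kl, free_nats=1.0):
    """Clamp a KL value from below: anything under the floor becomes the floor."""
    return max(kl, free_nats)

def slope(fn, x, eps=1e-6):
    """Central finite difference, a stand-in for the gradient w.r.t. the raw KL."""
    return (fn(x + eps) - fn(x - eps)) / (2 * eps)

# Illustrative distributions (assumptions, not values from the repository).
prior = [0.25, 0.25, 0.25, 0.25]
close_post = [0.4, 0.3, 0.2, 0.1]          # mildly different from the prior
far_post = [0.997, 0.001, 0.001, 0.001]    # almost deterministic

kl_small = categorical_kl(close_post, prior)  # below the 1.0 nat floor
kl_large = categorical_kl(far_post, prior)    # above the 1.0 nat floor

for name, kl in [("small", kl_small), ("large", kl_large)]:
    print(name, round(kl, 3), round(free_nats_clamp(kl), 3),
          round(slope(free_nats_clamp, kl), 3))
```

Below the floor the clamped loss is flat, so the optimizer receives no signal to shrink the KL further. Above the floor the loss tracks the raw KL with slope 1.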
Usage
Applied automatically during RSSM loss computation. Use when training world models with stochastic latent spaces to prevent early posterior collapse. The default threshold of 1.0 free nat works across all tested environments.
The Insight (Rule of Thumb)
- Action: Set `free_nats=1.0` in the RSSM configuration to threshold both the dynamics and representation KL losses.
- Value: Default `free_nats=1.0`. Any KL divergence below 1.0 nat is clamped to 1.0, so it adds only a constant to the loss and produces no gradient.
- Trade-off: A higher threshold lets the latent space carry more information before regularization begins, but it weakens regularization pressure. A lower threshold strengthens regularization but risks posterior collapse. The default of 1.0 balances these across all tested domains.
- Compatibility: Works with the categorical OneHot latent space (stoch=32, classes=64 by default). The maximum entropy of this space is `32 * log(64) ≈ 133 nats`, which is also the largest KL attainable against a uniform prior, so 1.0 is a modest floor.
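The 133-nat figure can be checked with a few lines of arithmetic. The sketch below only assumes the default `stoch`, `classes`, and `free_nats` values quoted in this page; the variable names are illustrative.

```python
import math

stoch, classes = 32, 64   # default latent shape quoted above
free_nats = 1.0           # default threshold quoted above

# Entropy of 32 independent, uniformly distributed 64-way categoricals, in nats.
max_entropy = stoch * math.log(classes)

# Share of that capacity that the free-nats floor leaves unpenalized.
floor_fraction = free_nats / max_entropy

print(f"{max_entropy:.2f} nats")   # 133.08 nats
print(f"{floor_fraction:.2%}")     # 0.75%
```

Under these defaults the floor exempts less than one percent of the latent capacity from the KL penalty.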
Reasoning
Without free nats, the dual KL losses pull toward each other: the dynamics loss pushes the prior to match the posterior, while the representation loss (weighted at 0.1x) pushes the posterior to match the prior. If both converge to the same observation-independent distribution (for example, uniform), the KL is zero but the latent space encodes no information. The free nats threshold leaves the first 1.0 nat of KL per timestep unpenalized, so the posterior can retain at least that much observation-specific information.
The implementation from `dreamerv3/rssm.py:L120-133`:
```python
def loss(self, carry, tokens, acts, reset, training):
  metrics = {}
  carry, entries, feat = self.observe(carry, tokens, acts, reset, training)
  prior = self._prior(feat['deter'])
  post = feat['logit']
  dyn = self._dist(sg(post)).kl(self._dist(prior))
  rep = self._dist(post).kl(self._dist(sg(prior)))
  if self.free_nats:
    dyn = jnp.maximum(dyn, self.free_nats)
    rep = jnp.maximum(rep, self.free_nats)
  losses = {'dyn': dyn, 'rep': rep}
  ...
```
Default config from `dreamerv3/configs.yaml:L91`:
```yaml
rssm: {deter: 8192, hidden: 1024, stoch: 32, classes: 64, free_nats: 1.0}
```
Loss scales showing the asymmetric weighting from `dreamerv3/configs.yaml:L86`:
```yaml
loss_scales: {dyn: 1.0, rep: 0.1, ...}
```
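As a rough illustration of how the threshold and the loss scales interact, the sketch below combines the two clamped KL terms with the `dyn` and `rep` weights shown above. The function `weighted_kl_loss` is hypothetical and ignores every other term in the world model loss. It relies on the fact that `dyn` and `rep` have the same forward value, since they differ only in where the stop-gradient is placed.

```python
def weighted_kl_loss(kl, free_nats=1.0, scales=None):
    """Forward value of the clamped, weighted KL terms for one timestep.

    `kl` is the raw posterior-to-prior KL in nats. Both terms share this
    value in the forward pass; only their gradients differ.
    """
    if scales is None:
        scales = {"dyn": 1.0, "rep": 0.1}  # weights quoted from the config above
    dyn = max(kl, free_nats)  # would train the prior toward the posterior
    rep = max(kl, free_nats)  # would train the posterior toward the prior
    return scales["dyn"] * dyn + scales["rep"] * rep

# Below the floor the total stays constant, whatever the raw KL is.
print(weighted_kl_loss(0.3))
print(weighted_kl_loss(0.9))

# Above the floor the total grows with the raw KL.
print(weighted_kl_loss(5.0))
```

With the default weights, any raw KL under 1.0 nat yields the same total, so neither the prior nor the posterior is pushed to close the remaining gap.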