Heuristic: Danijar DreamerV3 Replay Context Carry Init
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Model_Based_RL, Optimization |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Replay buffer technique that stores and retrieves RSSM carry states alongside transitions, enabling accurate world model initialization when training on non-contiguous replay sequences.
Description
When sampling from the replay buffer, sequences are drawn at random positions — not necessarily from episode boundaries. The RSSM world model is recurrent and requires an initial carry state (deterministic + stochastic) to produce meaningful predictions. Without proper initialization, the first several timesteps of each training sequence are wasted as the model "warms up" from a zero state.
Replay context solves this by:
- During data collection, storing the RSSM carry states (deter, stoch) alongside each transition in the replay buffer via the `ext_space` mechanism
- During training, prefixing each sampled sequence with `replay_context` extra timesteps (default 1) that are used only to initialize the carry state
- Using stored carry states from the prefix to restore the model's recurrent state, then training only on the remaining timesteps
The `_apply_replay_context()` method uses the stored carry states when a sampled chunk is the first of a freshly drawn sequence (indicated by `consec == 0`), and keeps the carry propagated from the previous training batch when the chunk directly continues it (`consec > 0`).
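The per-row selection between the two carry sources can be sketched in NumPy (a stand-in for the JAX version; all names here are illustrative):

```python
import numpy as np

def select_carry(consec, train_carry, replay_carry):
    # consec[:, 0] == 0 marks chunks freshly sampled from replay; for
    # those rows the in-training carry is stale, so use the carry
    # reconstructed from the stored replay context instead.
    first_chunk = (consec[:, 0] == 0)      # shape (batch,)
    mask = first_chunk[:, None]            # broadcast over state dims
    return np.where(mask, replay_carry, train_carry)

consec = np.array([[0, 1], [3, 4]])        # row 0: fresh sample, row 1: continuation
train_carry = np.zeros((2, 4))
replay_carry = np.ones((2, 4))
out = select_carry(consec, train_carry, replay_carry)
# row 0 comes from replay_carry, row 1 from train_carry
```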
Usage
Enabled by default with `replay_context=1`. Increase for longer context windows. Set to `0` to disable (model always starts from zeros, wasting early timesteps).
The Insight (Rule of Thumb)
- Action: Set `replay_context=1` (or higher) to prepend carry-state context to each training sequence. The stored carry states are written back to replay during training via `outs['replay']`.
- Value: Default `replay_context=1`. The replay samples sequences of length `consec * batch_length + replay_context`, where the first `replay_context` steps are used only for state initialization.
- Trade-off: Each additional context step increases memory and compute per batch proportionally, but eliminates "warmup waste" from the RSSM. For `batch_length=64`, a context of 1 adds only about 1/64 ≈ 1.6% overhead versus potentially wasting 5-10 early timesteps without it.
- Compatibility: Requires the replay buffer to store extra fields (deter, stoch for each module). The ext_space mechanism handles this automatically.
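The length and overhead arithmetic above can be sketched with local stand-ins for the config fields:

```python
def replay_sample_length(consec, batch_length, replay_context):
    # Total timesteps fetched per sampled sequence.
    return consec * batch_length + replay_context

def context_overhead(batch_length, replay_context):
    # Fraction of fetched timesteps used only for carry initialization.
    return replay_context / (batch_length + replay_context)

length = replay_sample_length(consec=1, batch_length=64, replay_context=1)  # 65
overhead = context_overhead(batch_length=64, replay_context=1)              # ~0.015
```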
Reasoning
The RSSM has a deterministic state of 8192 dimensions and stochastic state of 32x64 = 2048 dimensions. Starting from zeros means the first ~5-10 timesteps produce unreliable features and gradients. With replay context, the stored carry state provides a "warm start" that makes every timestep in the training window useful.
The assertion at `dreamerv3/main.py:L188` ensures the replay is large enough: `assert config.batch_size * length <= capacity`, where `length = consec * batlen + replay_context`.
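A local sanity check mirroring that assertion; the concrete values below are illustrative, only the inequality matches the source:

```python
# Illustrative stand-ins for the config values named in the assertion.
batch_size, consec, batlen, replay_context = 16, 1, 64, 1
capacity = 1_000_000
length = consec * batlen + replay_context
# The replay buffer must hold at least one full batch of sequences.
assert batch_size * length <= capacity
```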
Implementation from `dreamerv3/agent.py:L312-340`:
```python
def _apply_replay_context(self, carry, data):
  (enc_carry, dyn_carry, dec_carry, prevact) = carry
  carry = (enc_carry, dyn_carry, dec_carry)
  stepid = data['stepid']
  obs = {k: data[k] for k in self.obs_space}
  # Shift actions right by one so step t sees the action taken at t-1.
  prepend = lambda x, y: jnp.concatenate([x[:, None], y[:, :-1]], 1)
  prevact = {k: prepend(prevact[k], data[k]) for k in self.act_space}
  if not self.config.replay_context:
    return carry, obs, prevact, stepid
  K = self.config.replay_context
  ...
  # consec == 0 marks a freshly sampled sequence: restore the carry
  # stored in replay; otherwise keep the carry from the previous batch.
  first_chunk = (data['consec'][:, 0] == 0)
  carry, obs, prevact, stepid = jax.tree.map(
      lambda normal, replay: nn.where(first_chunk, replay, normal),
      (carry, rhs(obs), rhs(prevact), rhs(stepid)),
      (rep_carry, rep_obs, rep_prevact, rep_stepid))
  return carry, obs, prevact, stepid
```
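The `prepend` lambda in the snippet can be illustrated with NumPy (the array values are made up):

```python
import numpy as np

# The carried last action is inserted at the front and the stored
# actions are shifted right by one, so step t sees the action from t-1.
prev_last_action = np.array([9])        # action carried in, shape (batch,)
actions = np.array([[1, 2, 3, 4]])      # actions stored with this chunk, shape (batch, time)
prepend = lambda x, y: np.concatenate([x[:, None], y[:, :-1]], 1)
prevact = prepend(prev_last_action, actions)
# prevact == [[9, 1, 2, 3]]
```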
The ext_space that stores carry states from `dreamerv3/agent.py:L91-99`:
```python
@property
def ext_space(self):
  spaces = {}
  spaces['consec'] = elements.Space(np.int32)
  spaces['stepid'] = elements.Space(np.uint8, 20)
  if self.config.replay_context:
    # Store each module's recurrent entries (e.g. deter, stoch) as
    # flat 'enc/...', 'dyn/...', 'dec/...' columns in the replay buffer.
    spaces.update(elements.tree.flatdict(dict(
        enc=self.enc.entry_space,
        dyn=self.dyn.entry_space,
        dec=self.dec.entry_space)))
  return spaces
```
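A minimal stand-in for the flattening step; the space shapes below are purely illustrative, not the actual `entry_space` contents:

```python
# Nested per-module entry spaces are flattened to 'module/key' names so
# the replay buffer can store them as extra columns per transition.
def flatdict(nested, sep='/'):
    flat = {}
    for module, entries in nested.items():
        for key, space in entries.items():
            flat[f'{module}{sep}{key}'] = space
    return flat

ext = flatdict({
    'dyn': {'deter': ('float32', (8192,)), 'stoch': ('float32', (32, 64))},
})
# keys: 'dyn/deter', 'dyn/stoch'
```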