
Heuristic:Danijar Dreamerv3 Replay Context Carry Init

From Leeroopedia
Knowledge Sources
Domains Reinforcement_Learning, Model_Based_RL, Optimization
Last Updated 2026-02-15 09:00 GMT

Overview

Replay buffer technique that stores and retrieves RSSM carry states alongside transitions, enabling accurate world model initialization when training on non-contiguous replay sequences.

Description

When sampling from the replay buffer, sequences are drawn at random positions — not necessarily from episode boundaries. The RSSM world model is recurrent and requires an initial carry state (deterministic + stochastic) to produce meaningful predictions. Without proper initialization, the first several timesteps of each training sequence are wasted as the model "warms up" from a zero state.

Replay context solves this by:

  1. During data collection, storing the RSSM carry states (deter, stoch) alongside each transition in the replay buffer via the `ext_space` mechanism
  2. During training, prefixing each sampled sequence with `replay_context` extra timesteps (default 1) that are used only to initialize the carry state
  3. Using stored carry states from the prefix to restore the model's recurrent state, then training only on the remaining timesteps
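The three steps above can be illustrated with a toy recurrent model. This is a minimal NumPy sketch, not DreamerV3 code. A single `tanh` cell stands in for the RSSM, `stored_carry` is a hypothetical name for the per-step carry column the replay would keep, and the sketch assumes the carry is recorded *after* each step. Restoring the carry from the last of the `K` context steps reproduces the states the model had during collection, while a zero start does not.

```python
import numpy as np

rng = np.random.default_rng(0)
D, X = 8, 4  # toy carry and observation sizes, far smaller than the real RSSM
W = rng.normal(scale=0.5, size=(D, D))
U = rng.normal(scale=0.5, size=(D, X))


def cell(h, x):
    """One recurrent step: a stand-in for the RSSM update."""
    return np.tanh(W @ h + U @ x)


def unroll(h0, xs):
    """Run the cell over a sequence and return the carry after every step."""
    hs, h = [], h0
    for x in xs:
        h = cell(h, x)
        hs.append(h)
    return np.stack(hs)


# Step 1 (collection): act online and store the carry next to each transition.
T = 200
obs = rng.normal(size=(T, X))
stored_carry = unroll(np.zeros(D), obs)  # hypothetical replay column, shape (T, D)

# Step 2 (sampling): draw K context steps plus L training steps at some offset.
K, L = 1, 16
start = 57
window = slice(start, start + K + L)
seq_obs, seq_carry = obs[window], stored_carry[window]

# Step 3 (training): restore the carry from the last context step,
# then unroll only over the remaining L steps.
warm = unroll(seq_carry[K - 1], seq_obs[K:])
cold = unroll(np.zeros(D), seq_obs[K:])  # what happens without replay context

print("warm-start max error:", np.abs(warm - seq_carry[K:]).max())
print("zero-start first-step error:", np.abs(cold[0] - seq_carry[K]).max())
```

The warm start matches the collected states up to floating-point rounding; the zero start is already off at the first trained step.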

The `_apply_replay_context()` method decides per sample which carry to start from. For the first chunk of a consecutive sample (`consec == 0`) there is no carry left over from a previous training step, so it uses the carry restored from the stored context steps; for later chunks (`consec > 0`) it keeps the carry propagated from the preceding chunk.
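The per-sample choice can be sketched in NumPy. `select_carry` is a hypothetical helper, and the carry is flattened to one array per sample for simplicity (the code on this page maps `nn.where` over the whole carry pytree instead). Samples whose first `consec` entry is 0 take the carry restored from replay; later chunks keep the carry handed over from the previous chunk.

```python
import numpy as np


def select_carry(consec, normal_carry, replay_carry):
    """Pick a starting carry per batch element.

    consec:       int array (B, T); only the first step's value is used.
    normal_carry: (B, D) carry handed over from the previous training chunk.
    replay_carry: (B, D) carry restored from the stored context steps.
    """
    first_chunk = consec[:, 0] == 0  # (B,) bool mask
    # Broadcast the per-sample mask over the feature dimension.
    return np.where(first_chunk[:, None], replay_carry, normal_carry)


B, T, D = 3, 5, 4
consec = np.array([[0] * T, [1] * T, [2] * T])  # sample 0 is a first chunk
normal = np.full((B, D), -1.0)  # marks "carried over from the previous chunk"
replay = np.full((B, D), 1.0)   # marks "restored from the replay buffer"
print(select_carry(consec, normal, replay)[:, 0])  # [ 1. -1. -1.]
```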

Usage

Enabled by default with `replay_context=1`. Increase it for longer context windows. Set it to `0` to disable replay context; sequences then start from the initial zero carry, so their early timesteps are spent warming up.

The Insight (Rule of Thumb)

  • Action: Set `replay_context=1` (or higher) to prepend carry-state context to each training sequence. The stored carry states are written back to replay during training via `outs['replay']`.
  • Value: Default `replay_context=1`. The replay samples sequences of length `consec * batch_length + replay_context`, where the first `replay_context` steps are used only for state initialization.
  • Trade-off: Each additional context step increases memory and compute per batch proportionally, but eliminates the RSSM's "warm-up waste". For `batch_length=64`, a context of 1 costs only 1/64 ≈ 1.6% overhead, versus potentially wasting 5-10 early timesteps without it.
  • Compatibility: Requires the replay buffer to store extra fields (deter, stoch for each module). The `ext_space` mechanism handles this automatically.
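The sequence length and the context overhead from the bullets above follow from simple arithmetic. A small sketch; the parameter names mirror the config names used on this page, and `sampled_length` and `context_overhead` are hypothetical helpers:

```python
def sampled_length(consec, batch_length, replay_context):
    """Steps drawn from replay per sample: training chunks plus the context prefix."""
    return consec * batch_length + replay_context


def context_overhead(consec, batch_length, replay_context):
    """Context steps relative to the steps that actually receive a training loss."""
    return replay_context / (consec * batch_length)


print(sampled_length(1, 64, 1))    # 65
print(context_overhead(1, 64, 1))  # 0.015625
```

With `consec=2` the single context step is shared by 128 training steps, so the relative overhead halves to 1/128.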

Reasoning

The RSSM has a deterministic state of 8192 dimensions and a stochastic state of 32x64 = 2048 dimensions. Starting from zeros means the first ~5-10 timesteps produce unreliable features and gradients. With replay context, the stored carry state provides a "warm start" that makes every timestep in the training window useful.
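The warm-up effect can be seen in a toy model. In this sketch (a stand-in under stated assumptions, not the RSSM), a `tanh` cell whose recurrent matrix is rescaled to spectral norm 0.5 is run once from a nonzero "true" carry and once from zeros on the same inputs. Because `tanh` is 1-Lipschitz, the cell is then a contraction in its state, so the zero-start error shrinks by at least half per step. How many steps the real RSSM needs depends on its learned dynamics and is not measured here.

```python
import numpy as np

rng = np.random.default_rng(1)
D, X, STEPS = 8, 4, 20
W = rng.normal(size=(D, D))
W *= 0.5 / np.linalg.norm(W, 2)  # spectral norm 0.5 -> contraction in the state
U = rng.normal(size=(D, X))

true_h = np.tanh(rng.normal(size=D))  # stand-in for a correctly restored carry
zero_h = np.zeros(D)                  # cold start
errors = []
for x in rng.normal(size=(STEPS, X)):
    true_h = np.tanh(W @ true_h + U @ x)
    zero_h = np.tanh(W @ zero_h + U @ x)
    errors.append(np.linalg.norm(true_h - zero_h))

for t in (0, 4, 9, 19):
    print(f"step {t:2d}: error {errors[t]:.2e}")
```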

The assertion at `dreamerv3/main.py:L188` ensures the replay is large enough: `assert config.batch_size * length <= capacity`, where `length = consec * batlen + replay_context`.
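The capacity constraint can be checked before a run starts. A minimal sketch; `check_replay_capacity` is a hypothetical helper that restates the assertion above, and the numbers in the usage lines are illustrative, not DreamerV3 defaults:

```python
def check_replay_capacity(batch_size, consec, batlen, replay_context, capacity):
    """Raise if one batch of sampled sequences cannot fit in the replay buffer."""
    length = consec * batlen + replay_context
    needed = batch_size * length
    if needed > capacity:
        raise ValueError(
            f"batch_size * length = {batch_size} * {length} = {needed} "
            f"exceeds replay capacity {capacity}")
    return needed


print(check_replay_capacity(16, 1, 64, 1, capacity=10_000))  # 1040
try:
    check_replay_capacity(16, 1, 64, 1, capacity=1_000)
except ValueError as err:
    print(err)
```

Raising `ValueError` instead of using `assert` keeps the check active when Python runs with `-O`.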

Implementation from `dreamerv3/agent.py:L312-340`:

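def _apply_replay_context(self, carry, data):
    (enc_carry, dyn_carry, dec_carry, prevact) = carry
    carry = (enc_carry, dyn_carry, dec_carry)
    stepid = data['stepid']
    obs = {k: data[k] for k in self.obs_space}
    # Previous actions: shift the sampled actions right by one step and put
    # the last action from the carry in the first slot.
    prepend = lambda x, y: jnp.concatenate([x[:, None], y[:, :-1]], 1)
    prevact = {k: prepend(prevact[k], data[k]) for k in self.act_space}
    if not self.config.replay_context:
      # No context prefix: train on the whole sequence from the given carry.
      return carry, obs, prevact, stepid

    K = self.config.replay_context
    ...
    # Per sample: the first chunk of a consecutive sample (consec == 0) has no
    # carry from a previous training step, so it takes the replay-restored
    # values; later chunks keep the carry propagated from the previous chunk.
    first_chunk = (data['consec'][:, 0] == 0)
    carry, obs, prevact, stepid = jax.tree.map(
        lambda normal, replay: nn.where(first_chunk, replay, normal),
        (carry, rhs(obs), rhs(prevact), rhs(stepid)),
        (rep_carry, rep_obs, rep_prevact, rep_stepid))
    return carry, obs, prevact, stepid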
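The `prepend` lambda in the code above turns the action sequence into a previous-action sequence: it shifts the actions right by one step and fills the first slot with the last action from the carry. The same operation in NumPy, assuming a single action key with scalar integer actions for brevity:

```python
import numpy as np


def prepend(first, seq):
    """Shift seq right by one step along time, filling slot 0 with first.

    first: (B, ...) last action carried over from the previous chunk.
    seq:   (B, T, ...) actions taken at each step of the sampled sequence.
    """
    return np.concatenate([first[:, None], seq[:, :-1]], axis=1)


carried = np.array([9, 8])                   # previous action per batch element
actions = np.array([[1, 2, 3], [4, 5, 6]])  # (B=2, T=3)
print(prepend(carried, actions))
# [[9 1 2]
#  [8 4 5]]
```

The output keeps the sequence length `T`, so step `t` is paired with the action that preceded it.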

The `ext_space` property, which declares the extra replay fields that hold the carry states, from `dreamerv3/agent.py:L91-99`:

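@property
def ext_space(self):
    # Extra per-step fields the replay stores next to each transition.
    spaces = {}
    spaces['consec'] = elements.Space(np.int32)
    spaces['stepid'] = elements.Space(np.uint8, 20)
    if self.config.replay_context:
      # Carry entries of each module, flattened into top-level replay keys
      # so they can be restored when a sampled sequence is replayed.
      spaces.update(elements.tree.flatdict(dict(
          enc=self.enc.entry_space,
          dyn=self.dyn.entry_space,
          dec=self.dec.entry_space)))
    return spaces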
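The flattening of the per-module entry spaces, and the write-back of refreshed carries through `outs['replay']`, can be mimicked with plain dictionaries. This sketch makes several assumptions: `/` as the key separator, `(dtype, shape)` tuples in place of `elements.Space`, made-up shapes, and `enc`/`dec` treated as stateless. `flatten` and `write_back` are hypothetical helpers, not the `elements` or replay API.

```python
import numpy as np


def flatten(tree, prefix=""):
    """Flatten nested dicts into {'a/b': leaf}; the '/' separator is an assumption."""
    flat = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


# Hypothetical entry spaces as (dtype, shape); enc/dec assumed stateless here.
entry_spaces = dict(
    enc={},
    dyn=dict(deter=(np.float32, (8,)), stoch=(np.float32, (2, 3))),
    dec={},
)
ext_space = {"consec": (np.int32, ()), "stepid": (np.uint8, (20,))}
ext_space.update(flatten(entry_spaces))
print(sorted(ext_space))  # ['consec', 'dyn/deter', 'dyn/stoch', 'stepid']

# Toy replay table: one preallocated column per field.
CAPACITY = 100
table = {k: np.zeros((CAPACITY, *shape), dtype)
         for k, (dtype, shape) in ext_space.items()}


def write_back(table, indices, fresh):
    """Overwrite stored carries at the given indices with newly computed ones."""
    for key, values in fresh.items():
        table[key][indices] = values


idx = np.array([3, 4, 5])
write_back(table, idx, {"dyn/deter": np.ones((3, 8), np.float32)})
print(table["dyn/deter"][3:6].sum())  # 24.0
```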
