Implementation: Danijar DreamerV3 Agent Init
| Knowledge Sources | Details |
|---|---|
| Domains | Reinforcement_Learning, Model_Based_RL, JAX |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Concrete tool for constructing a DreamerV3 world-model-based agent, wiring together the RSSM, encoder/decoder, actor-critic heads, and optimizer provided by the dreamerv3 package.
Description
The Agent.__init__() method in dreamerv3/agent.py wires together all neural network modules: RSSM world model, CNN+MLP encoder/decoder, reward/continue prediction heads, policy and value heads with slow target network, percentile normalizers, and a custom optimizer. The Agent class inherits from embodied.jax.Agent which adds JAX-specific lifecycle management (JIT compilation, parameter sharding, streaming).
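The wiring described above can be sketched with trivial stand-in classes. This is a simplified illustration of the composition pattern only, not the real dreamerv3 networks; the attribute names (dyn, enc, dec, rew, con, pol, val, slowval, opt) mirror the module list in the Outputs table below, while `_Module` and `AgentSketch` are hypothetical stand-ins:

```python
# Simplified sketch of the Agent.__init__ wiring pattern.
# _Module is a trivial placeholder, not a real dreamerv3 network;
# only the composition structure is illustrative.

class _Module:
    def __init__(self, name, **settings):
        self.name = name
        self.settings = settings

class AgentSketch:
    def __init__(self, obs_space, act_space, config):
        self.obs_space = obs_space
        self.act_space = act_space
        # World model: RSSM dynamics plus CNN+MLP encoder/decoder.
        self.dyn = _Module('rssm', **config['dyn'])
        self.enc = _Module('encoder', **config['enc'])
        self.dec = _Module('decoder', **config['dec'])
        # Prediction heads for reward and episode continuation.
        self.rew = _Module('reward_head', **config['rewhead'])
        self.con = _Module('continue_head', **config['conhead'])
        # Actor-critic heads; the slow value head is an EMA target copy.
        self.pol = _Module('policy_head', **config['policy'])
        self.val = _Module('value_head', **config['value'])
        self.slowval = _Module('slow_value', **config['slowvalue'])
        # A single optimizer spans all trainable modules.
        self.opt = _Module('optimizer', **config['opt'])

# Empty settings per section, just to exercise the wiring.
config = {k: {} for k in (
    'dyn', 'enc', 'dec', 'rewhead', 'conhead',
    'policy', 'value', 'slowvalue', 'opt')}
agent = AgentSketch({'image': None}, {'action': None}, config)
```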
The factory function make_agent() in dreamerv3/main.py creates a temporary environment to extract obs_space and act_space, then instantiates the Agent with these spaces and the merged config.
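The factory pattern (spin up a throwaway environment solely to read its spaces, then build the agent) can be sketched as follows. `DummyEnv`, its space dicts, and `make_agent_sketch` are hypothetical stand-ins for illustration, not the actual make_agent implementation:

```python
class DummyEnv:
    """Hypothetical stand-in for an embodied environment exposing space dicts."""
    obs_space = {'image': '64x64x3', 'reward': 'f32', 'is_first': 'bool'}
    act_space = {'action': 'discrete(6)'}
    def close(self):
        pass

def make_agent_sketch(config):
    # Create a temporary env only to extract the spaces, then discard it.
    env = DummyEnv()
    obs_space, act_space = dict(env.obs_space), dict(env.act_space)
    env.close()
    # The real factory would instantiate Agent(obs_space, act_space, config);
    # here we just return the pieces it would be built from.
    return {'obs_space': obs_space, 'act_space': act_space, 'config': config}

result = make_agent_sketch({'seed': 0})
```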
Usage
Call make_agent(config) to create the agent. This always happens before checkpoint restoration and before entering the training loop.
Code Reference
Source Location
- Repository: dreamerv3
- File: dreamerv3/agent.py (Agent.__init__), dreamerv3/main.py (make_agent)
- Lines: dreamerv3/agent.py L24-83, dreamerv3/main.py L127-149
Signature
class Agent(embodied.jax.Agent):
    def __init__(self, obs_space, act_space, config):
        """Construct the DreamerV3 agent.

        Args:
          obs_space: dict[str, elements.Space] - Observation spaces
            (images, vectors, reward, is_first, is_last, is_terminal).
          act_space: dict[str, elements.Space] - Action spaces
            (discrete or continuous per key).
          config: elements.Config - Agent configuration including:
            - config.enc: Encoder settings (typ, depth, mults, etc.)
            - config.dyn: RSSM settings (deter=4096, stoch=32, classes=32)
            - config.dec: Decoder settings
            - config.rewhead, config.conhead: Reward/continue head settings
            - config.policy: Policy head settings
            - config.value: Value head settings
            - config.slowvalue: Slow target EMA settings
            - config.opt: Optimizer settings (lr, agc, eps, etc.)
            - config.loss_scales: Per-loss scaling weights
        """

def make_agent(config):
    """Factory function: creates a temporary env to get spaces, then builds Agent.

    Args:
      config: elements.Config - Full DreamerV3 configuration.

    Returns:
      Agent: Initialized DreamerV3 agent (or RandomAgent if config.random_agent).
    """
Import
from dreamerv3.agent import Agent
from dreamerv3.main import make_agent
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| obs_space | dict[str, elements.Space] | Yes | Observation spaces from environment (images, vectors, signals) |
| act_space | dict[str, elements.Space] | Yes | Action spaces from environment (discrete/continuous) |
| config | elements.Config | Yes | Agent hyperparameters, model architecture, optimizer settings |
Outputs
| Name | Type | Description |
|---|---|---|
| agent | Agent | Initialized agent with: self.dyn (RSSM), self.enc (Encoder), self.dec (Decoder), self.rew (reward head), self.con (continue head), self.pol (policy head), self.val (value head), self.slowval (slow target), self.opt (optimizer) |
| agent.policy() | method | (carry, obs, mode) -> (carry, acts, outs) |
| agent.train() | method | (carry, data) -> (carry, outs, metrics) |
| agent.report() | method | (carry, data) -> (carry, metrics) |
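The carry-threading shape of the methods above can be illustrated with a stub. `StubAgent` is hypothetical and returns dummy values; only the `(carry, ...) -> (carry, ...)` signatures mirror the documented interface:

```python
class StubAgent:
    """Hypothetical stub mirroring the carry-threading agent interface."""
    def init_policy(self, batch_size):
        # Initial recurrent carry state for a policy rollout.
        return {'step': 0, 'batch': batch_size}
    def policy(self, carry, obs, mode='train'):
        # Each call consumes the carry and returns an updated one,
        # alongside actions and auxiliary outputs.
        carry = dict(carry, step=carry['step'] + 1)
        acts = {'action': [0] * carry['batch']}
        outs = {'mode': mode}
        return carry, acts, outs

agent = StubAgent()
carry = agent.init_policy(batch_size=4)
for obs in [{'image': None}] * 3:
    carry, acts, outs = agent.policy(carry, obs, mode='eval')
# After three calls the carry reflects three threaded steps.
```

The same threading applies to train() and report(): the caller owns the carry and passes it back on every call, which keeps the methods pure and JIT-friendly.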
Usage Examples
Via Factory Function
from dreamerv3.main import make_agent
# config is an elements.Config object from configuration loading
agent = make_agent(config)
# Initialize carry states
carry_train = agent.init_train(batch_size=16)
carry_policy = agent.init_policy(batch_size=4)
Direct Construction
from dreamerv3.agent import Agent
# obs_space and act_space extracted from environment
agent = Agent(obs_space, act_space, config)
# The agent's modules are accessible:
# agent.dyn - RSSM world model
# agent.enc - Encoder (CNN + MLP)
# agent.dec - Decoder (CNN + MLP)
# agent.pol - Policy head
# agent.val - Value head
# agent.opt - Optimizer