Implementation:Danijar Dreamerv3 Train Loop
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Model_Based_RL, JAX |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Concrete tool, provided by the embodied run library, for executing the DreamerV3 single-process training loop, which interleaves environment data collection with world-model and policy training.
Description
The train() function in embodied/run/train.py orchestrates the complete training pipeline:
- Creates agent, replay, and logger from factory functions
- Sets up a Driver with parallel environments and step callbacks (replay insertion, logging, training trigger)
- Creates data streams (train and report) that sample and batch from replay
- Restores or initializes from checkpoint
- Runs the main loop: the driver collects environment steps in short slices of 10, and training callbacks fire at a rate governed by train_ratio (see the sketch after this list)
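The control flow can be condensed into the skeleton below. This is an illustrative sketch, not the real embodied/run/train.py: the Driver constructor arguments, the callback signatures, and the carry initializer name (init_train) are assumptions, the pacing arithmetic is simplified, and logging/checkpointing are omitted.
import embodied

def train_skeleton(make_agent, make_replay, make_env, make_stream, make_logger, args):
    agent, replay = make_agent(), make_replay()
    logger = make_logger()  # metric/checkpoint wiring omitted from this sketch
    # Driver construction is simplified; the real code passes env factories
    # plus parallelism options.
    driver = embodied.Driver([lambda i=i: make_env(i) for i in range(args.envs)])
    stream = iter(make_stream(replay, 'train'))
    carry = [agent.init_train(args.batch_size)]  # initializer name assumed
    done, pending = [0], [0.0]

    def on_step(tran, worker):
        replay.add(tran, worker)
        done[0] += 1
        # train_ratio counts replayed steps per env step; one training batch
        # consumes batch_size * batch_length replayed steps.
        pending[0] += args.train_ratio / (args.batch_size * args.batch_length)
        while pending[0] >= 1:
            carry[0], outs, mets = agent.train(carry[0], next(stream))
            pending[0] -= 1

    driver.on_step(on_step)
    policy = lambda *inputs: agent.policy(*inputs, mode='train')  # driver threads (carry, obs)
    while done[0] < args.steps:
        driver(policy, steps=10)  # collect a short slice; callbacks do the training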
The Agent.train() method in dreamerv3/agent.py computes the world-model and actor-critic losses through Agent.loss(), which calls RSSM.observe() for posterior inference over the replayed trajectory, RSSM.imagine() for imagined rollouts from the inferred states, and imag_loss() for actor-critic optimization; the toy example below illustrates the observe/imagine split.
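The split between posterior inference and imagination can be seen in a toy recurrence: both phases are scans over time, but observe() folds in real observations while imagine() feeds the policy's own actions back into the model. This is not the real RSSM (no stochastic state, no KL terms, toy weights); it only shows the data flow.
import jax
import jax.numpy as jnp

H, O, A = 8, 4, 2  # toy state, observation, and action sizes
key = jax.random.PRNGKey(0)
w_obs = 0.1 * jax.random.normal(key, (H, H + O + A))
w_img = 0.1 * jax.random.normal(key, (H, H + A))

def observe(h0, obs, prevact):
    # Posterior inference: each step conditions on the real observation.
    def step(h, xs):
        o, a = xs
        h = jnp.tanh(w_obs @ jnp.concatenate([h, o, a]))
        return h, h
    return jax.lax.scan(step, h0, (obs, prevact))

def imagine(h0, policy, horizon):
    # Prior rollout: no observations; actions come from the policy itself.
    def step(h, _):
        h = jnp.tanh(w_img @ jnp.concatenate([h, policy(h)]))
        return h, h
    return jax.lax.scan(step, h0, None, length=horizon)

T = 16
h_last, posts = observe(jnp.zeros(H), jnp.zeros((T, O)), jnp.zeros((T, A)))
_, traj = imagine(h_last, lambda h: jnp.tanh(h[:A]), horizon=15)
# posts feed the world-model losses; traj feeds imag_loss (actor-critic).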
Usage
train() is invoked by main() when config.script == 'train' (sketched below) and runs until args.steps environment steps have been completed.
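Schematically, the dispatch in dreamerv3/main.py looks like the following condensed fragment; the exact branch structure may differ, and the factory bindings match the full example under Usage Examples below.
# Condensed view of the dispatch; other scripts (e.g. eval_only) follow
# the same pattern with their own run functions.
if config.script == 'train':
    embodied.run.train(
        bind(make_agent, config), bind(make_replay, config, 'replay'),
        bind(make_env, config), bind(make_stream, config),
        bind(make_logger, config), args)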
Code Reference
Source Location
- Repository: dreamerv3
- File: embodied/run/train.py (loop), dreamerv3/agent.py (Agent.train, Agent.loss, imag_loss), dreamerv3/rssm.py (RSSM.observe, RSSM.imagine)
- Lines: embodied/run/train.py L9-119, dreamerv3/agent.py L137-245 (train+loss), dreamerv3/agent.py L382-446 (imag_loss), dreamerv3/rssm.py L61-118 (observe+imagine)
Signature
def train(make_agent, make_replay, make_env, make_stream, make_logger, args):
    """Single-process training loop.

    Args:
      make_agent: Callable -> Agent
      make_replay: Callable -> Replay
      make_env: Callable(index) -> Env
      make_stream: Callable(replay, mode) -> Stream
      make_logger: Callable -> Logger
      args: elements.Config with envs, train_ratio, steps, batch_size,
        batch_length, log_every, report_every, save_every, etc.
    """

# Agent methods called within (methods of the Agent class):
class Agent:
    def train(self, carry, data):
        """(carry, data) -> (carry, outs, metrics)"""
    def loss(self, carry, obs, prevact, training):
        """(carry, obs, prevact, training) -> (loss, (carry, entries, outs, metrics))"""
    def policy(self, carry, obs, mode='train'):
        """(carry, obs, mode) -> (carry, acts, outs)"""
Import
import embodied
embodied.run.train(make_agent, make_replay, make_env, make_stream, make_logger, args)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| make_agent | Callable | Yes | Factory returning initialized Agent |
| make_replay | Callable | Yes | Factory returning configured Replay buffer |
| make_env | Callable(int) | Yes | Factory returning wrapped environment for given index |
| make_stream | Callable(replay, mode) | Yes | Factory returning data stream iterator |
| make_logger | Callable | Yes | Factory returning Logger with configured outputs |
| args | elements.Config | Yes | Run arguments: envs, train_ratio, steps, batch_size, batch_length, log_every, report_every, save_every, from_checkpoint |
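To make the factory shapes concrete, the fragment below binds a config into a per-index factory with functools.partial so it matches the Callable(int) -> Env contract above. ToyEnv and the SimpleNamespace config are stand-ins, not the real dreamerv3 factories, which also apply the embodied wrappers before returning.
from functools import partial
from types import SimpleNamespace

class ToyEnv:
    # Stand-in for a real wrapped environment; only here to show the shapes.
    def __init__(self, seed):
        self.seed = seed

def make_env(config, index):
    # One environment per worker; 'index' enables per-worker seeding.
    return ToyEnv(seed=config.seed + index)

config = SimpleNamespace(seed=0)         # stands in for the real config
env_factory = partial(make_env, config)  # Callable(int) -> Env, as required
env = env_factory(3)                     # worker 3 gets seed 3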
Outputs
| Name | Type | Description |
|---|---|---|
| Trained agent | Side effect | Agent parameters updated in-place via optimizer |
| Checkpoints | Files | Periodic checkpoint files saved to logdir/ckpt |
| Metrics | Files | Training metrics written to JSONL, TensorBoard, WandB |
| Reports | Side effect | Open-loop video predictions logged periodically |
Usage Examples
Standard Training
from functools import partial as bind
import elements
import embodied
from dreamerv3.main import make_agent, make_replay, make_env, make_stream, make_logger
# config loaded from YAML + CLI
args = elements.Config(**config.run, batch_size=config.batch_size, ...)
embodied.run.train(
    bind(make_agent, config),
    bind(make_replay, config, 'replay'),
    bind(make_env, config),
    bind(make_stream, config),
    bind(make_logger, config),
    args)
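Resuming Interrupted Training
To resume a run, the same call works with from_checkpoint set (key name from the args table above). The path below is illustrative, and elements.Config.update is assumed to return an updated copy.
# Resume: restore agent, replay, and counters before the loop starts.
args = args.update(from_checkpoint='~/logdir/run1/ckpt')  # illustrative path
embodied.run.train(
    bind(make_agent, config),
    bind(make_replay, config, 'replay'),
    bind(make_env, config),
    bind(make_stream, config),
    bind(make_logger, config),
    args)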
Related Pages
Implements Principle
Requires Environment
Uses Heuristics
- Heuristic:Danijar_Dreamerv3_Symlog_TwoHot_Prediction
- Heuristic:Danijar_Dreamerv3_Adaptive_Gradient_Clipping
- Heuristic:Danijar_Dreamerv3_Free_Nats_KL_Thresholding
- Heuristic:Danijar_Dreamerv3_Percentile_Return_Normalization
- Heuristic:Danijar_Dreamerv3_Replay_Context_Carry_Init
- Heuristic:Danijar_Dreamerv3_XLA_GPU_Optimization_Flags