
Implementation:Danijar Dreamerv3 Train Loop

From Leeroopedia
Knowledge Sources
Domains: Reinforcement_Learning, Model_Based_RL, JAX
Last Updated: 2026-02-15 09:00 GMT

Overview

A concrete tool for running the DreamerV3 single-process training loop, provided by the embodied run library, which interleaves environment data collection with world model and policy training.

Description

The train() function in embodied/run/train.py orchestrates the complete training pipeline:

  1. Creates agent, replay, and logger from factory functions
  2. Sets up a Driver with parallel environments and step callbacks (replay insertion, logging, training trigger)
  3. Creates data streams (train and report) that sample and batch from replay
  4. Restores or initializes from checkpoint
  5. Runs the main loop: driver collects 10 steps, then training callbacks fire based on train_ratio
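The train-ratio gating in step 5 can be modeled as a small accumulator: every collected environment step earns a fractional training update, and whole updates fire once enough credit accrues. The sketch below is illustrative only; `TrainRatioTrigger` and its fields are assumed names, not the embodied API.

```python
# Illustrative model of the train_ratio trigger; `TrainRatioTrigger` is a
# hypothetical name, not part of the embodied library.

class TrainRatioTrigger:
    """Fire roughly train_ratio gradient updates per env step collected,
    where one update consumes batch_size * batch_length replay steps."""

    def __init__(self, train_ratio, batch_steps):
        self.train_ratio = train_ratio
        self.batch_steps = batch_steps  # batch_size * batch_length
        self.credit = 0.0

    def __call__(self, env_steps):
        # Each env step earns train_ratio / batch_steps of one update.
        self.credit += env_steps * self.train_ratio / self.batch_steps
        updates = int(self.credit)
        self.credit -= updates
        return updates

trigger = TrainRatioTrigger(train_ratio=32.0, batch_steps=16 * 64)
first = trigger(10)  # 0 on the first call; fractional credit carries over
```

With these illustrative numbers, 1024 driver iterations of 10 steps each yield exactly 320 training updates, matching train_ratio * env_steps / (batch_size * batch_length).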

The Agent.train() method in dreamerv3/agent.py executes the world model and actor-critic losses through Agent.loss(), which calls RSSM.observe() for posterior inference, RSSM.imagine() for imagined rollouts, and imag_loss() for actor-critic optimization.
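The observe → imagine → imag_loss pipeline just described can be sketched with dummy stand-ins; every class, method name, and numeric value below is an illustrative assumption, not the actual DreamerV3 code.

```python
# Dummy stand-ins for the loss pipeline; names and numbers are illustrative.

class DummyWorldModel:
    def observe(self, carry, obs, prevact):
        # Posterior inference over the observed batch (RSSM.observe analogue):
        # returns updated carry, posterior states, and world-model loss terms.
        states = [f"z{i}" for i in range(len(obs))]
        losses = {"dyn": 0.5, "rep": 0.1, "recon": 1.2}
        return carry, states, losses

    def imagine(self, policy, states, horizon):
        # Imagined rollout (RSSM.imagine analogue): roll the policy forward
        # from each posterior state for `horizon` latent steps.
        return [policy(s) for s in states for _ in range(horizon)]

class DummyActorCritic:
    def policy(self, state):
        return ("act", state)

    def loss(self, traj):
        # Actor-critic losses on the imagined trajectory (imag_loss analogue).
        return {"actor": 0.05 * len(traj), "critic": 0.02 * len(traj)}

def agent_loss(wm, ac, carry, obs, prevact):
    carry, states, wm_losses = wm.observe(carry, obs, prevact)
    traj = wm.imagine(ac.policy, states, horizon=15)
    ac_losses = ac.loss(traj)
    total = sum(wm_losses.values()) + sum(ac_losses.values())
    return total, {**wm_losses, **ac_losses}

total, metrics = agent_loss(DummyWorldModel(), DummyActorCritic(),
                            {}, ["o1", "o2"], ["a0", "a1"])
```

The real Agent.loss() additionally threads the recurrent carry and returns entries and outs alongside the metrics, but the control flow follows this shape: posterior inference, imagined rollout, then actor-critic losses summed into one scalar.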

Usage

This is called by main() when config.script == 'train'. It runs until args.steps environment steps are completed.

Code Reference

Source Location

  • Repository: dreamerv3
  • File: embodied/run/train.py (loop), dreamerv3/agent.py (Agent.train, Agent.loss, imag_loss), dreamerv3/rssm.py (RSSM.observe, RSSM.imagine)
  • Lines: embodied/run/train.py L9-119, dreamerv3/agent.py L137-245 (train+loss), dreamerv3/agent.py L382-446 (imag_loss), dreamerv3/rssm.py L61-118 (observe+imagine)

Signature

def train(make_agent, make_replay, make_env, make_stream, make_logger, args):
    """
    Single-process training loop.

    Args:
        make_agent: Callable -> Agent
        make_replay: Callable -> Replay
        make_env: Callable(index) -> Env
        make_stream: Callable(replay, mode) -> Stream
        make_logger: Callable -> Logger
        args: elements.Config with envs, train_ratio, steps, batch_size,
              batch_length, log_every, report_every, save_every, etc.
    """

# Agent methods called within:
class Agent:
    def train(self, carry, data):
        """(carry, data) -> (carry, outs, metrics)"""

    def loss(self, carry, obs, prevact, training):
        """(carry, obs, prevact, training) -> (loss, (carry, entries, outs, metrics))"""

    def policy(self, carry, obs, mode='train'):
        """(carry, obs, mode) -> (carry, acts, outs)"""

Import

import embodied
embodied.run.train(make_agent, make_replay, make_env, make_stream, make_logger, args)

I/O Contract

Inputs

Name Type Required Description
make_agent Callable Yes Factory returning initialized Agent
make_replay Callable Yes Factory returning configured Replay buffer
make_env Callable(int) Yes Factory returning wrapped environment for given index
make_stream Callable(replay, mode) Yes Factory returning data stream iterator
make_logger Callable Yes Factory returning Logger with configured outputs
args elements.Config Yes Run arguments: envs, train_ratio, steps, batch_size, batch_length, log_every, report_every, save_every, from_checkpoint

Outputs

Name Type Description
Trained agent Side effect Agent parameters updated in-place via optimizer
Checkpoints Files Periodic checkpoint files saved to logdir/ckpt
Metrics Files Training metrics written to JSONL, TensorBoard, WandB
Reports Side effect Open-loop video predictions logged periodically

Usage Examples

Standard Training

from functools import partial as bind
from dreamerv3.main import make_agent, make_replay, make_env, make_stream, make_logger
import elements
import embodied

# config loaded from YAML + CLI
args = elements.Config(**config.run, batch_size=config.batch_size, ...)

embodied.run.train(
    bind(make_agent, config),
    bind(make_replay, config, 'replay'),
    bind(make_env, config),
    bind(make_stream, config),
    bind(make_logger, config),
    args)
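For reference, the bind alias above is functools.partial: each call pre-loads config so that the resulting factory matches the I/O contract (e.g. make_env becomes a Callable(int)). A toy illustration with a hypothetical stand-in:

```python
from functools import partial as bind

# Toy stand-in; this `make_env` is illustrative, not dreamerv3.main.make_env.
def make_env(config, index):
    return {"config": config, "index": index}

# Pre-load the config; the result is a Callable(int), per the I/O contract.
env_factory = bind(make_env, {"task": "dummy"})
env = env_factory(3)
```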

Related Pages

  • Implements Principle
  • Requires Environment
  • Uses Heuristics