Implementation:Danijar Dreamerv3 Checkpoint Operations
| Field | Value |
|---|---|
| Knowledge Sources | |
| Domains | Reinforcement_Learning, Training_Infrastructure |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Wrapper for the elements.Checkpoint class, which serializes and restores the agent parameters, the replay buffer state, and the training step counter of a DreamerV3 training run.
Description
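DreamerV3 uses the Checkpoint class from the external elements library for state persistence. Components are registered by attribute assignment (e.g. cp.agent = agent); load_or_save() resumes from an existing checkpoint or, if none exists, saves the initial state; and load() accepts a keys argument to restore only selected components. Regex-filtered restoration of agent parameters goes through the module-level elements.checkpoint.load(), with agent.load bound to a regex pattern.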
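The registration and resume-or-initialize behavior can be illustrated with a small stand-in. This is a minimal sketch, not the elements implementation: the `MiniCheckpoint` and `Counter` classes are hypothetical, state is written to a single pickle file, and it assumes each registered object exposes `save()` returning its state and `load(data)` restoring it in place.

```python
import os
import pickle
import tempfile


class Counter:
    """Hypothetical step counter exposing the assumed save()/load() protocol."""

    def __init__(self, value=0):
        self.value = value

    def save(self):
        return self.value

    def load(self, data):
        self.value = data


class MiniCheckpoint:
    """Hypothetical checkpoint: assigning an attribute registers a component."""

    def __init__(self, path=None):
        # Store internals directly so they are not registered as components.
        object.__setattr__(self, '_path', path)
        object.__setattr__(self, '_components', {})

    def __setattr__(self, name, value):
        # cp.agent = agent registers `agent` under the name 'agent'.
        self._components[name] = value

    def __getattr__(self, name):
        # Only called when normal lookup fails, i.e. for registered components.
        components = object.__getattribute__(self, '_components')
        if name in components:
            return components[name]
        raise AttributeError(name)

    def save(self):
        assert self._path is not None, 'a checkpoint without a path cannot save'
        state = {name: obj.save() for name, obj in self._components.items()}
        with open(self._path, 'wb') as f:
            pickle.dump(state, f)

    def load(self, path=None):
        path = path if path is not None else self._path
        assert path is not None, 'no path to load from'
        with open(path, 'rb') as f:
            state = pickle.load(f)
        for name, obj in self._components.items():
            obj.load(state[name])  # restore each component in place

    def load_or_save(self):
        # Resume if this run already has a checkpoint, else write the initial state.
        if os.path.exists(self._path):
            self.load()
        else:
            self.save()


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'ckpt.pkl')
        cp = MiniCheckpoint(path)
        cp.step = Counter(0)
        cp.load_or_save()          # no file yet: saves step 0
        cp.step.value = 500
        cp.save()
        resumed = MiniCheckpoint(path)
        resumed.step = Counter(0)
        resumed.load_or_save()     # file exists: restores step 500
        print(resumed.step.value)  # 500
```

Because components are restored in place, any other code holding a reference to the same counter or agent object sees the restored values without reassignment.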
In the single-process training loop (embodied/run/train.py), the checkpoint registers step, agent, and replay. In train_eval mode, it additionally registers replay_train and replay_eval. In distributed mode, separate checkpoints are maintained for the agent (learner process), replay (replay process), and logger (logger process).
Usage
This is a Wrapper Doc for the external elements.Checkpoint API. Set up the checkpoint after the agent and replay buffer are initialized, and call load_or_save() before entering the main training loop.
Code Reference
Source Location
- Repository: dreamerv3
- File: embodied/run/train.py (single-process), embodied/run/train_eval.py (train+eval), embodied/run/eval_only.py (eval)
- Lines: embodied/run/train.py L83-90, embodied/run/train_eval.py L114-123, embodied/run/eval_only.py L59-61
Signature
# elements.Checkpoint API (external library)
import elements
from functools import partial as bind  # bind is functools.partial under an alias
cp = elements.Checkpoint(path)
cp.step = step_counter  # Register step counter
cp.agent = agent  # Register agent
cp.replay = replay  # Register replay buffer
cp.load_or_save()  # Load if a checkpoint exists, else save the initial state
# Selective loading from a pretrained checkpoint
elements.checkpoint.load(
    path,  # Path to pretrained checkpoint
    dict(agent=bind(agent.load, regex=regex_pattern)))
# Evaluation only: load specific keys
cp = elements.Checkpoint()
cp.agent = agent
cp.load(path, keys=['agent'])  # Only load agent params
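Here `bind` refers to `functools.partial` under an alias (imported as `from functools import partial as bind`). Pre-binding `regex` turns `agent.load` into a one-argument loader that the checkpoint can call with only the stored data. The sketch below assumes a hypothetical `ToyAgent` whose parameters are a flat dict keyed by name and filters them with `re.search`; how the real agent matches the pattern is not specified here.

```python
import re
from functools import partial as bind


class ToyAgent:
    """Hypothetical agent holding a flat dict of named parameters."""

    def __init__(self, params):
        self.params = dict(params)

    def save(self):
        return dict(self.params)

    def load(self, data, regex=None):
        # Restore only parameters whose names match the regex (all if None).
        # Using re.search is an assumption of this sketch.
        for name, value in data.items():
            if regex is None or re.search(regex, name):
                self.params[name] = value


pretrained = {'enc/w': 1.0, 'dyn/w': 2.0, 'policy/w': 3.0}
agent = ToyAgent({'enc/w': 0.0, 'dyn/w': 0.0, 'policy/w': 0.0})

# bind pre-fills regex, leaving a loader that takes just the saved data.
loader = bind(agent.load, regex=r'^(enc|dyn)/')
loader(pretrained)
print(agent.params)  # {'enc/w': 1.0, 'dyn/w': 2.0, 'policy/w': 0.0}
```

In this toy, the encoder and dynamics weights are transferred while the policy keeps its fresh initialization, which is the kind of partial transfer a regex filter enables.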
Import
import elements
from functools import partial as bind
logdir = elements.Path(args.logdir)
cp = elements.Checkpoint(logdir / 'ckpt')
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| path | str or Path | No | Disk path for checkpoint storage. If omitted, creates a transient checkpoint that is only used for loading (as in eval_only.py). |
| cp.step | elements.Counter | No | Step counter to persist |
| cp.agent | Agent | No | Agent whose parameters to save/restore |
| cp.replay | Replay | No | Replay buffer state to persist |
| args.from_checkpoint | str | No | Path to a pretrained checkpoint for transfer learning |
| args.from_checkpoint_regex | str | No | Regex pattern to filter which parameters to load from pretrained checkpoint |
Outputs
| Name | Type | Description |
|---|---|---|
| Restored state | in-place | All registered components are restored in-place from disk |
| Checkpoint file | File | Serialized state written to disk at the specified path |
Usage Examples
Single Process Training
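import elements
from functools import partial as bind
# Setup checkpoint with all components
logdir = elements.Path(args.logdir)
cp = elements.Checkpoint(logdir / 'ckpt')
cp.step = step
cp.agent = agent
cp.replay = replay
# Optionally initialize agent weights from a pretrained checkpoint
if args.from_checkpoint:
    elements.checkpoint.load(args.from_checkpoint, dict(
        agent=bind(agent.load, regex=args.from_checkpoint_regex)))
# Resume from logdir / 'ckpt' if it exists, otherwise save the initial state
cp.load_or_save()
# Periodic saving inside the training loop (should_save is a periodic trigger defined in train.py)
if should_save(step):
    cp.save()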
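The order in the example above matters: pretrained weights are loaded first and load_or_save() runs afterwards, so when the run directory already holds a checkpoint, resuming overwrites the pretrained initialization, and any progress made after the last cp.save() is repeated. The toy loop below models that flow with plain dicts and pickle. The `run` function, the state layout, and the step-based save trigger are hypothetical and chosen for determinism; should_save in the real loop may work differently.

```python
import os
import pickle
import tempfile


def save_state(path, state):
    with open(path, 'wb') as f:
        pickle.dump(state, f)


def load_state(path):
    with open(path, 'rb') as f:
        return pickle.load(f)


def run(ckpt_path, total_steps, save_every, pretrained=None):
    """Hypothetical single-process loop following the order of the example."""
    state = {'step': 0, 'weights': 0.0}
    # 1. Optional transfer: initialize weights from a pretrained value.
    if pretrained is not None:
        state['weights'] = pretrained
    # 2. Resume-or-initialize: an existing run checkpoint overrides step 1.
    if os.path.exists(ckpt_path):
        state = load_state(ckpt_path)
    else:
        save_state(ckpt_path, state)
    # 3. Train, saving periodically (step-based trigger assumed here).
    while state['step'] < total_steps:
        state['step'] += 1
        state['weights'] += 0.5  # stand-in for a gradient update
        if state['step'] % save_every == 0:
            save_state(ckpt_path, state)
    return state


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'ckpt.pkl')
        first = run(path, total_steps=10, save_every=4, pretrained=100.0)
        print(first)   # {'step': 10, 'weights': 105.0}
        # Restart: resumes from the step-8 save, not from the pretrained value.
        second = run(path, total_steps=12, save_every=4, pretrained=100.0)
        print(second)  # {'step': 12, 'weights': 106.0}
```

The restarted run resumes at step 8, so steps 9 and 10 are trained again; shorter save intervals reduce that repeated work at the cost of more frequent writes.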
Evaluation Only
# Create transient checkpoint, load agent only
cp = elements.Checkpoint()
cp.agent = agent
cp.load(args.from_checkpoint, keys=['agent'])
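Loading with keys can be pictured as restoring only the named entries from a file that stores every registered component. This is a minimal sketch, assuming a hypothetical single-pickle format and a hypothetical `load_selected()` helper that takes a mapping from component name to loader function; it is not the elements API.

```python
import os
import pickle
import tempfile


def load_selected(path, loaders, keys=None):
    """Restore only the requested components from a pickled state dict.

    loaders maps a component name to a function that accepts its saved
    state; keys=None restores every component that has a loader.
    """
    with open(path, 'rb') as f:
        state = pickle.load(f)
    names = list(loaders) if keys is None else list(keys)
    missing = [name for name in names if name not in state]
    if missing:
        raise KeyError(f'checkpoint has no entries for {missing}')
    for name in names:
        loaders[name](state[name])
    return names


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'ckpt.pkl')
        # A training checkpoint holding every registered component.
        with open(path, 'wb') as f:
            pickle.dump({'step': 1000, 'agent': {'w': 0.7}, 'replay': [1, 2, 3]}, f)
        agent_params = {}
        # Evaluation only needs the agent; step and replay are left untouched.
        load_selected(path, {'agent': agent_params.update}, keys=['agent'])
        print(agent_params)  # {'w': 0.7}
```

Restricting the load to the agent avoids deserializing state that evaluation never uses, such as the replay buffer.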