
Environment: Danijar Dreamerv3 JAX CUDA

From Leeroopedia
Knowledge Sources
Domains Infrastructure, Deep_Learning, Reinforcement_Learning
Last Updated 2026-02-15 09:00 GMT

Overview

Linux environment with CUDA 12, JAX 0.4.33+, Python 3.11+, and bfloat16 compute for training DreamerV3 world models.

Description

This environment provides a GPU-accelerated JAX runtime for training the DreamerV3 reinforcement learning agent. The default compute dtype is bfloat16 (configurable to float32 or float16). The JAX platform defaults to cuda and requires NVIDIA GPU hardware with CUDA 12 support. The codebase applies XLA compiler flags tuned for GPU execution, including the latency-hiding scheduler, pipelined all-reduce/all-gather, and disabled rematerialization. Multi-host distributed training is supported via jax.distributed.initialize() with a coordinator address.

Usage

Use this environment for all DreamerV3 training workflows: single-process training, train-and-evaluate, evaluation-only, and distributed parallel training. The JAX platform can be switched to cpu or tpu via the --jax.platform flag, but GPU (CUDA) is the default and recommended platform.

System Requirements

| Category | Requirement | Notes |
|---|---|---|
| OS | Linux (tested) or macOS | README states tested on Linux and Mac |
| Python | 3.11+ | README requirement |
| Hardware | NVIDIA GPU with CUDA 12 support | Default platform is cuda; bfloat16 requires Ampere (A100) or newer for native support |
| Disk | Sufficient for replay buffer | Default replay size is 5M transitions; chunks stored on disk |
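Whether a GPU has native bfloat16 support can be checked from its CUDA compute capability. The sketch below is illustrative and not part of the codebase; it assumes the common rule that native bf16 arrives with compute capability 8.0 (Ampere, e.g. A100) and newer.

```python
# Sketch: decide native bfloat16 support from NVIDIA compute capability.
# Assumption (not from the codebase): native bf16 requires CC >= 8.0.

def supports_native_bf16(compute_capability: tuple) -> bool:
    """Return True if a GPU with this (major, minor) compute
    capability has native bfloat16 support."""
    return tuple(compute_capability) >= (8, 0)

# Examples: A100 is CC 8.0, V100 is CC 7.0, RTX 4090 is CC 8.9.
print(supports_native_bf16((8, 0)))  # True
print(supports_native_bf16((7, 0)))  # False
```

Older GPUs (e.g. V100) still run bfloat16 code, but emulated rather than in hardware, which is why float32 remains a configurable fallback.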

Dependencies

System Packages

  • `nvidia-cuda-nvcc-cu12` <= 12.2 (CUDA compiler)
  • CUDA 12 toolkit (for JAX GPU backend)

Python Packages

  • `jax[cuda12]` == 0.4.33
  • `ninjax` >= 3.5.1
  • `optax` (JAX optimizer library)
  • `chex` (JAX testing utilities)
  • `jaxtyping` (JAX type annotations)
  • `einops` (tensor rearrangement)
  • `elements` >= 3.19.1 (training infrastructure)
  • `portal` >= 3.5.0 (RPC/process management)
  • `granular` >= 0.20.3
  • `scope` >= 0.4.4
  • `numpy` < 2 (DMLab/MineRL compatibility constraint)
  • `colored_traceback`
  • `tqdm`
  • `ipdb`
  • `av` (video encoding)
  • `google-resumable-media` >= 2.7.2

Environment-Specific Packages

  • `ale_py` == 0.9.0 (Atari)
  • `autorom[accept-rom-license]` == 0.6.1 (Atari ROM management)

Credentials

The following environment variables may be needed depending on configuration:

  • `JOB_COMPLETION_INDEX`: Optional Kubernetes job index for multi-replica training (dreamerv3/main.py:L33-34).
  • `XLA_FLAGS`: Automatically set by the setup function with GPU/TPU optimization flags (embodied/jax/internal.py:L43-94).
  • `XLA_PYTHON_CLIENT_PREALLOCATE`: Set automatically to control JAX GPU memory preallocation (embodied/jax/internal.py:L39).
  • `TF_CUDNN_DETERMINISTIC`: Set to '1' when deterministic mode is enabled (embodied/jax/internal.py:L46).
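These variables are ordinary process environment settings. A minimal sketch of how they are typically read and written, mirroring the behavior described above (the helper name `configure_env` is hypothetical; this is not the actual embodied/jax code):

```python
import os

def configure_env(prealloc: bool = False, deterministic: bool = False) -> int:
    """Sketch of the environment-variable handling described above."""
    # Multi-replica training: Kubernetes sets JOB_COMPLETION_INDEX per pod;
    # default to replica 0 when running a single process.
    replica = int(os.environ.get('JOB_COMPLETION_INDEX', 0))
    # JAX GPU memory preallocation toggle, as 'true' or 'false'.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = str(bool(prealloc)).lower()
    # Deterministic cuDNN kernels when requested.
    if deterministic:
        os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    return replica

replica = configure_env(prealloc=False, deterministic=True)
print(replica, os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'])
```

Note that these variables must be set before JAX initializes its backend; setting them after the first JAX call has no effect.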

Quick Install

# Install JAX with CUDA 12 support (quote specs so the shell does not
# expand the brackets or treat >= as a redirect)
pip install "jax[cuda12]==0.4.33"

# Install all other requirements
pip install -U -r requirements.txt

# Or install directly
pip install "ninjax>=3.5.1" optax chex jaxtyping einops "elements>=3.19.1" "portal>=3.5.0" "granular>=0.20.3" "scope>=0.4.4" "numpy<2" colored_traceback tqdm ipdb av "google-resumable-media>=2.7.2" "nvidia-cuda-nvcc-cu12<=12.2" ale_py==0.9.0 "autorom[accept-rom-license]==0.6.1"

Code Evidence

JAX platform and compute dtype configuration from `embodied/jax/internal.py:L15-109`:

def setup(
    platform=None,
    compute_dtype=jnp.bfloat16,
    debug=False,
    jit=True,
    prealloc=False,
    ...
):
  platform and jax.config.update('jax_platforms', platform)
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = str(bool(prealloc)).lower()

Default platform set to cuda in `dreamerv3/configs.yaml:L73`:

jax:
    platform: cuda
    compute_dtype: bfloat16
    prealloc: True

GPU XLA flags from `embodied/jax/internal.py:L54-79`:

  if gpuflags and platform == 'gpu':
    xlaflags += [
        '--xla_disable_hlo_passes=rematerialization',
        '--xla_gpu_enable_latency_hiding_scheduler=true',
        '--xla_gpu_enable_pipelined_all_gather=true',
        '--xla_gpu_enable_pipelined_all_reduce=true',
        '--xla_gpu_enable_triton_gemm=false',
        '--xla_gpu_enable_triton_softmax_fusion=false',
        '--xla_gpu_graph_level=0',
        ...
    ]
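XLA consumes these flags through the `XLA_FLAGS` environment variable as a single space-separated string. A standalone sketch of that mechanism (not the `setup()` code itself, which builds a longer list):

```python
import os

# Sketch: XLA picks up compiler flags from the XLA_FLAGS environment
# variable, which must be set before JAX initializes its backend.
xlaflags = [
    '--xla_disable_hlo_passes=rematerialization',
    '--xla_gpu_enable_latency_hiding_scheduler=true',
]
os.environ['XLA_FLAGS'] = ' '.join(xlaflags)
print(os.environ['XLA_FLAGS'])
```

This is why the flags are applied inside a setup function that runs before any JAX computation: once the backend is created, changing `XLA_FLAGS` has no effect.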

Device validation from `embodied/jax/agent.py:L72-79`:

    available = jax.devices()
    elements.print(f'JAX devices ({jax.device_count()}):', available)
    if self.jaxcfg.expect_devices > 0:
      if len(available) != self.jaxcfg.expect_devices:
        print('ALERT: Wrong number of devices')
        while True:
          time.sleep(1)
    assert len(available) == jax.process_count() * jax.local_device_count()
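The final assertion encodes the invariant that the globally visible device count equals the number of processes times the local devices per process. As a standalone check (the helper name `check_devices` is hypothetical, mirroring the validation above):

```python
def check_devices(available: int, processes: int, local_per_process: int,
                  expect: int = 0) -> bool:
    """Mirror of the validation above: the global device count must equal
    process_count * local_device_count, and optionally match an explicitly
    expected number when expect > 0."""
    if expect > 0 and available != expect:
        return False  # the real code prints an ALERT and hangs here
    return available == processes * local_per_process

# Example: 2 hosts with 8 local GPUs each -> 16 global devices.
print(check_devices(available=16, processes=2, local_per_process=8, expect=16))
```

The deliberate hang (rather than a crash) keeps a multi-host job from tearing down all replicas while one node's GPUs are misconfigured.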

Python version and requirements from `requirements.txt`:

jax[cuda12]==0.4.33
ninjax>=3.5.1
numpy<2  # DMLab: <2, MineRLv1.0: <1.24
nvidia-cuda-nvcc-cu12<=12.2
elements>=3.19.1
portal>=3.5.0

Common Errors

| Error Message | Cause | Solution |
|---|---|---|
| `Too many leaves for PyTreeDef` | Checkpoint incompatible with current config | Ensure the logdir matches the config used to create it; do not reuse old logdirs with different architecture settings |
| CUDA errors (various) | Often caused by OOM or a JAX/CUDA version mismatch | Scroll up in the logs for the root cause; try `--batch_size 1` to rule out OOM; verify that the JAX and CUDA versions match |
| `ALERT: Wrong number of devices` (hangs) | `jax.expect_devices` is set but the number of available devices does not match | Check GPU availability; set `--jax.expect_devices 0` to disable the check |
| `Inter-node TP is not supported!` | Tensor parallelism dimension exceeds local devices | Reduce the `t` dimension of the mesh shape to fit within a single node |
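The last error reflects the constraint that a tensor-parallel group must live entirely within one node. A hedged sketch of that check (the function name `valid_mesh` is hypothetical, not from the codebase):

```python
def valid_mesh(t_dim: int, local_devices: int) -> bool:
    """A tensor-parallel (t) group must fit within one node: the t
    dimension can be at most the number of local devices and should
    divide evenly into them."""
    return 0 < t_dim <= local_devices and local_devices % t_dim == 0

print(valid_mesh(t_dim=8, local_devices=8))   # fits on one node
print(valid_mesh(t_dim=16, local_devices=8))  # would span nodes -> invalid
```

Tensor parallelism exchanges activations at every layer, so spanning the slower inter-node interconnect is disallowed; cross-node scaling is handled by data parallelism instead.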

Compatibility Notes

  • CPU mode: Use `--jax.platform cpu` for debugging. The `debug` config preset sets platform to cpu, disables preallocation, and reduces model size.
  • TPU mode: Use `--jax.platform tpu`. Separate XLA flags are applied for TPU including async collective fusion and megacore settings.
  • Float16: When using float16 compute dtype, the optimizer automatically enables gradient scaling via `optax.apply_if_finite` with dynamic loss scaling (embodied/jax/opt.py:L25-29).
  • Bfloat16 (default): Requires Ampere or newer GPU for native support. Prioritized replay is incompatible with float16 — an assertion enforces bfloat16 or float32 when priority fracs are non-zero (dreamerv3/main.py:L198-200).
  • Multi-host: Distributed training requires `--jax.coordinator_address` and uses `jax.distributed.initialize()`. Process ID is set via config, not via `jax.process_index()`.
  • numpy < 2: Required for DMLab and MineRL environment compatibility.
