# Environment: Danijar DreamerV3 JAX CUDA

| Knowledge Sources | |
|---|---|
| Domains | Infrastructure, Deep_Learning, Reinforcement_Learning |
| Last Updated | 2026-02-15 09:00 GMT |
## Overview

Linux environment with CUDA 12, JAX 0.4.33, Python 3.11+, and bfloat16 compute for training DreamerV3 world models.
## Description

This environment provides a GPU-accelerated JAX runtime for training the DreamerV3 reinforcement learning agent. The default compute dtype is bfloat16 (configurable to float32 or float16). The JAX platform defaults to cuda and requires NVIDIA GPU hardware with CUDA 12 support. The codebase uses XLA compiler flags optimized for GPU execution, including latency-hiding scheduling, pipelined all-reduce/all-gather, and disabled rematerialization. Multi-host distributed training is supported via `jax.distributed.initialize()` with a coordinator address.
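Since the compute dtype is configurable between bfloat16, float32, and float16, the choice directly scales memory footprint. A rough back-of-envelope sketch (the 200M parameter count below is a hypothetical example, not a DreamerV3 model size):

```python
# Rough memory arithmetic for compute dtypes (illustrative only).
BYTES = {'float32': 4, 'bfloat16': 2, 'float16': 2}

def value_bytes(num_values: int, dtype: str) -> int:
    """Bytes needed to hold num_values scalars in the given dtype."""
    return num_values * BYTES[dtype]

params = 200_000_000  # hypothetical parameter count
print(value_bytes(params, 'float32'))   # 800000000 (~0.8 GB)
print(value_bytes(params, 'bfloat16'))  # 400000000 (~0.4 GB)
```

Switching from float32 to bfloat16 halves the bytes per value, which is the main motivation for the bfloat16 default on Ampere-class GPUs.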
## Usage

Use this environment for all DreamerV3 training workflows: single-process training, train-and-evaluate, evaluation-only, and distributed parallel training. The JAX platform can be switched to cpu or tpu via the `--jax.platform` flag, but GPU (CUDA) is the default and recommended platform.
## System Requirements
| Category | Requirement | Notes |
|---|---|---|
| OS | Linux (tested) or macOS | README states tested on Linux and Mac |
| Python | 3.11+ | README requirement |
| Hardware | NVIDIA GPU with CUDA 12 support | Default platform is cuda; bfloat16 requires Ampere (A100) or newer for native support |
| Disk | Sufficient for replay buffer | Default replay size is 5M transitions; chunks stored on disk |
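The 5M-transition replay default translates into substantial raw storage. A back-of-envelope estimate, assuming 64x64x3 uint8 observations (a common DreamerV3 setting, assumed here rather than taken from this document) and ignoring any on-disk chunk compression:

```python
# Raw (uncompressed) upper bound for the default replay buffer.
# The 64x64x3 uint8 observation shape is an assumption.
obs_bytes = 64 * 64 * 3           # 12,288 bytes per RGB frame
transitions = 5_000_000           # default replay size
raw_gb = obs_bytes * transitions / 1e9
print(round(raw_gb, 2))  # 61.44
```

Compression and smaller observation spaces reduce this considerably, but plan for tens of gigabytes of disk for the default configuration.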
## Dependencies

### System Packages
- `nvidia-cuda-nvcc-cu12` <= 12.2 (CUDA compiler)
- CUDA 12 toolkit (for JAX GPU backend)
### Python Packages
- `jax[cuda12]` == 0.4.33
- `ninjax` >= 3.5.1
- `optax` (JAX optimizer library)
- `chex` (JAX testing utilities)
- `jaxtyping` (JAX type annotations)
- `einops` (tensor rearrangement)
- `elements` >= 3.19.1 (training infrastructure)
- `portal` >= 3.5.0 (RPC/process management)
- `granular` >= 0.20.3
- `scope` >= 0.4.4
- `numpy` < 2 (DMLab/MineRL compatibility constraint)
- `colored_traceback`
- `tqdm`
- `ipdb`
- `av` (video encoding)
- `google-resumable-media` >= 2.7.2
### Environment-Specific Packages
- `ale_py` == 0.9.0 (Atari)
- `autorom[accept-rom-license]` == 0.6.1 (Atari ROM management)
## Credentials
The following environment variables may be needed depending on configuration:
- `JOB_COMPLETION_INDEX`: Optional Kubernetes job index for multi-replica training (dreamerv3/main.py:L33-34).
- `XLA_FLAGS`: Automatically set by the setup function with GPU/TPU optimization flags (embodied/jax/internal.py:L43-94).
- `XLA_PYTHON_CLIENT_PREALLOCATE`: Set automatically to control JAX GPU memory preallocation (embodied/jax/internal.py:L39).
- `TF_CUDNN_DETERMINISTIC`: Set to '1' when deterministic mode is enabled (embodied/jax/internal.py:L46).
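A minimal sketch of how `JOB_COMPLETION_INDEX` can map to a replica index (the fallback to 0 for single-process runs is an assumption here, not necessarily what main.py does):

```python
import os

def replica_index(env=None) -> int:
    """Return the Kubernetes job completion index, defaulting to 0
    when the variable is unset (e.g. single-process training)."""
    env = os.environ if env is None else env
    return int(env.get('JOB_COMPLETION_INDEX', 0))

print(replica_index({'JOB_COMPLETION_INDEX': '3'}))  # 3
print(replica_index({}))                             # 0
```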
## Quick Install

```bash
# Install JAX with CUDA 12 support
pip install "jax[cuda12]==0.4.33"

# Install all other requirements
pip install -U -r requirements.txt

# Or install directly (version specifiers are quoted so the shell
# does not interpret [, <, and >)
pip install "ninjax>=3.5.1" optax chex jaxtyping einops "elements>=3.19.1" \
  "portal>=3.5.0" "granular>=0.20.3" "scope>=0.4.4" "numpy<2" \
  colored_traceback tqdm ipdb av "google-resumable-media>=2.7.2" \
  "nvidia-cuda-nvcc-cu12<=12.2" ale_py==0.9.0 "autorom[accept-rom-license]==0.6.1"
```
## Code Evidence

JAX platform and compute dtype configuration from `embodied/jax/internal.py:L15-109`:

```python
def setup(
    platform=None,
    compute_dtype=jnp.bfloat16,
    debug=False,
    jit=True,
    prealloc=False,
    ...
):
  platform and jax.config.update('jax_platforms', platform)
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = str(bool(prealloc)).lower()
```
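The `platform and jax.config.update(...)` line uses Python's short-circuit `and` as a one-line conditional: the update runs only when `platform` is truthy. A stdlib-only illustration of the idiom (the `config` dict stands in for `jax.config`):

```python
# Demonstrate the `x and f(x)` short-circuit idiom used in setup().
config = {}

def update(key, value):
    config[key] = value

platform = None
platform and update('jax_platforms', platform)  # no-op: platform is falsy
print(config)  # {}

platform = 'cuda'
platform and update('jax_platforms', platform)  # runs: platform is truthy
print(config)  # {'jax_platforms': 'cuda'}
```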
Default platform set to cuda in `dreamerv3/configs.yaml:L73`:

```yaml
jax:
  platform: cuda
  compute_dtype: bfloat16
  prealloc: True
```
GPU XLA flags from `embodied/jax/internal.py:L54-79`:

```python
if gpuflags and platform == 'gpu':
  xlaflags += [
      '--xla_disable_hlo_passes=rematerialization',
      '--xla_gpu_enable_latency_hiding_scheduler=true',
      '--xla_gpu_enable_pipelined_all_gather=true',
      '--xla_gpu_enable_pipelined_all_reduce=true',
      '--xla_gpu_enable_triton_gemm=false',
      '--xla_gpu_enable_triton_softmax_fusion=false',
      '--xla_gpu_graph_level=0',
      ...
  ]
```
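XLA consumes such flags as a single space-separated string in the `XLA_FLAGS` environment variable; the setup function presumably performs an equivalent join. A minimal sketch using two of the flags above:

```python
import os

# Join a flag list into the single string form that XLA reads from
# the XLA_FLAGS environment variable.
xlaflags = [
    '--xla_disable_hlo_passes=rematerialization',
    '--xla_gpu_enable_latency_hiding_scheduler=true',
]
os.environ['XLA_FLAGS'] = ' '.join(xlaflags)
print(os.environ['XLA_FLAGS'])
```

Note that `XLA_FLAGS` must be set before JAX initializes its backend, which is why setup performs this before any computation.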
Device validation from `embodied/jax/agent.py:L72-79`:

```python
available = jax.devices()
elements.print(f'JAX devices ({jax.device_count()}):', available)
if self.jaxcfg.expect_devices > 0:
  if len(available) != self.jaxcfg.expect_devices:
    print('ALERT: Wrong number of devices')
    while True:
      time.sleep(1)
assert len(available) == jax.process_count() * jax.local_device_count()
```
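The same check can be expressed as a standalone function (a sketch: the real code deliberately hangs in a sleep loop so the alert stays visible on a cluster; here we raise instead so the failure is testable):

```python
def check_devices(available, expect_devices):
    """Validate the visible device count. expect_devices <= 0
    disables the check, mirroring the agent's behavior."""
    if expect_devices > 0 and len(available) != expect_devices:
        raise RuntimeError('ALERT: Wrong number of devices')
    return len(available)

print(check_devices(['gpu:0', 'gpu:1'], 2))  # 2
print(check_devices(['gpu:0'], 0))           # 1 (check disabled)
```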
Python version and requirements from `requirements.txt`:

```text
jax[cuda12]==0.4.33
ninjax>=3.5.1
numpy<2  # DMLab: <2, MineRLv1.0: <1.24
nvidia-cuda-nvcc-cu12<=12.2
elements>=3.19.1
portal>=3.5.0
```
## Common Errors
| Error Message | Cause | Solution |
|---|---|---|
| `Too many leaves for PyTreeDef` | Checkpoint incompatible with current config | Ensure logdir matches the config used to create it; do not reuse old logdirs with different architecture settings |
| CUDA errors (various) | Often caused by OOM or JAX/CUDA version mismatch | Scroll up in logs for root cause; try `--batch_size 1` to rule out OOM; verify JAX and CUDA versions match |
| `ALERT: Wrong number of devices` (hangs) | `jax.expect_devices` set but available devices don't match | Check GPU availability; set `--jax.expect_devices 0` to disable check |
| `Inter-node TP is not supported!` | Tensor parallelism dimension exceeds local devices | Reduce the `t` dimension in mesh shape to fit within a single node |
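The tensor-parallelism error in the last row reduces to a simple constraint: the `t` mesh dimension must fit within a single node's local devices. A hypothetical check (function and parameter names are illustrative, not from the codebase):

```python
def tp_fits_on_node(t_dim: int, local_devices: int) -> bool:
    """True when the tensor-parallel mesh dimension stays within
    one node; inter-node tensor parallelism is unsupported."""
    return t_dim <= local_devices

print(tp_fits_on_node(8, 8))   # True
print(tp_fits_on_node(16, 8))  # False (would require inter-node TP)
```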
## Compatibility Notes
- CPU mode: Use `--jax.platform cpu` for debugging. The `debug` config preset sets platform to cpu, disables preallocation, and reduces model size.
- TPU mode: Use `--jax.platform tpu`. Separate XLA flags are applied for TPU including async collective fusion and megacore settings.
- Float16: When using float16 compute dtype, the optimizer automatically enables gradient scaling via `optax.apply_if_finite` with dynamic loss scaling (embodied/jax/opt.py:L25-29).
- Bfloat16 (default): Requires Ampere or newer GPU for native support. Prioritized replay is incompatible with float16 — an assertion enforces bfloat16 or float32 when priority fracs are non-zero (dreamerv3/main.py:L198-200).
- Multi-host: Distributed training requires `--jax.coordinator_address` and uses `jax.distributed.initialize()`. Process ID is set via config, not via `jax.process_index()`.
- numpy < 2: Required for DMLab and MineRL environment compatibility.
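The dynamic loss scaling mentioned for float16 follows a standard pattern: shrink the scale whenever gradients overflow, and grow it back after a long run of finite steps. A minimal sketch of the mechanism (not the optax implementation; the constants are conventional defaults, not taken from this codebase):

```python
def update_scale(scale, grads_finite, good_steps, growth_interval=2000):
    """One step of dynamic loss scaling: halve the scale on non-finite
    gradients, double it after growth_interval consecutive finite steps."""
    if not grads_finite:
        return scale / 2, 0
    good_steps += 1
    if good_steps >= growth_interval:
        return scale * 2, 0
    return scale, good_steps

scale, good = 2.0 ** 15, 0
scale, good = update_scale(scale, False, good)  # overflow: halve
print(scale)  # 16384.0
```

With bfloat16 (the default) this machinery is unnecessary, since bfloat16 shares float32's exponent range and rarely overflows.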