Heuristic: Danijar DreamerV3 XLA GPU Optimization Flags
| Knowledge Sources | |
|---|---|
| Domains | Infrastructure, Optimization, Deep_Learning |
| Last Updated | 2026-02-15 09:00 GMT |
Overview
Curated set of XLA compiler flags for GPU execution that disable Triton kernels, enable latency-hiding scheduling, and configure pipelined collective operations for optimal DreamerV3 throughput.
Description
DreamerV3 sets a comprehensive set of `XLA_FLAGS` environment variables when running on GPU to optimize JAX/XLA compilation and execution. These flags were empirically tuned and address several categories:
Disabled features:
- Rematerialization is disabled (`--xla_disable_hlo_passes=rematerialization`) — the model fits in memory without recomputation
- Triton GEMM and softmax fusion are disabled — the XLA cuBLAS/cuDNN kernels are faster for this workload
- CUDA graphs are disabled (`--xla_gpu_graph_level=0`) — the dynamic control flow in RSSM scan doesn't benefit
Enabled optimizations:
- Latency-hiding scheduler — overlaps computation with communication
- Pipelined all-gather, all-reduce, and reduce-scatter — overlaps collective operations
- Double-buffered while loops — improves scan/loop throughput
- Large combine thresholds (128 MB for all-gather/all-reduce, 64 MB for reduce-scatter) — reduce the number of separate collective operations
A separate set of TPU-specific flags is provided for TPU execution.
Usage
Applied automatically when `platform == 'gpu'` and `gpuflags == True` (the default). Set `--jax.platform cpu` to bypass these flags for debugging.
The Insight (Rule of Thumb)
- Action: When running DreamerV3 on GPU, let the default XLA flags be applied. Do not set `XLA_FLAGS` manually, as the code overwrites them. If you need custom flags, modify `embodied/jax/internal.py`.
- Value: The flags disable Triton (which underperforms for this workload), disable CUDA graphs (incompatible with dynamic scan), and enable pipelined collectives.
- Trade-off: Disabling rematerialization increases peak memory usage but improves throughput. Disabling Triton kernels may reduce performance on future hardware where Triton outperforms cuBLAS.
- Compatibility: These flags are JAX/XLA-version-specific. Some flags may become no-ops or cause warnings on newer XLA versions.
Reasoning
The DreamerV3 training loop has a complex computation graph with recurrent scans (RSSM observe/imagine), which doesn't map well to CUDA graphs. Triton kernels in XLA were found to be slower than the default cuBLAS/cuDNN kernels for the matrix sizes in DreamerV3 (hidden sizes of 1024 to 1536, with modest batch dimensions). Pipelined collectives are essential for multi-GPU training, where data-parallel gradient synchronization would otherwise sit on the critical path.
The rematerialization disable is important: XLA's automatic rematerialization can choose to recompute activations that are needed for the backward pass, but for DreamerV3's model sizes the activations fit in memory, making rematerialization pure overhead.
Implementation from `embodied/jax/internal.py:L54-79`:
```python
if gpuflags and platform == 'gpu':
  xlaflags += [
      '--xla_disable_hlo_passes=rematerialization',
      '--xla_gpu_all_gather_combine_threshold_bytes=134217728',
      '--xla_gpu_all_reduce_combine_threshold_bytes=134217728',
      '--xla_gpu_enable_all_gather_combine_by_dim=false',
      '--xla_gpu_enable_highest_priority_async_stream=true',
      '--xla_gpu_enable_latency_hiding_scheduler=true',
      '--xla_gpu_enable_pipelined_all_gather=true',
      '--xla_gpu_enable_pipelined_all_reduce=true',
      '--xla_gpu_enable_pipelined_reduce_scatter=true',
      '--xla_gpu_enable_reduce_scatter_combine_by_dim=false',
      '--xla_gpu_enable_triton_gemm=false',
      '--xla_gpu_enable_triton_softmax_fusion=false',
      '--xla_gpu_enable_while_loop_double_buffering=true',
      '--xla_gpu_graph_level=0',
      '--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864',
  ]
```
Additional JAX configuration from the same function:
```python
jax.config.update('jax_disable_most_optimizations', debug)
jax.config.update('jax_disable_jit', not jit)
if transfer_guard and jit and not debug_nans:
  jax.config.update('jax_transfer_guard', 'disallow')
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = str(bool(prealloc)).lower()
```