
Heuristic:Danijar Dreamerv3 XLA GPU Optimization Flags

From Leeroopedia
Domains Infrastructure, Optimization, Deep_Learning
Last Updated 2026-02-15 09:00 GMT

Overview

Curated set of XLA compiler flags for GPU execution that disable Triton kernels, enable latency-hiding scheduling, and configure pipelined collective operations for optimal DreamerV3 throughput.

Description

DreamerV3 sets a comprehensive set of `XLA_FLAGS` environment variables when running on GPU to optimize JAX/XLA compilation and execution. These flags were empirically tuned and address several categories:

Disabled features:

  • Rematerialization is disabled (`--xla_disable_hlo_passes=rematerialization`) — the model fits in memory without recomputation
  • Triton GEMM and softmax fusion are disabled — the XLA cuBLAS/cuDNN kernels are faster for this workload
  • CUDA graphs are disabled (`--xla_gpu_graph_level=0`) — the dynamic control flow in RSSM scan doesn't benefit

Enabled optimizations:

  • Latency-hiding scheduler — overlaps computation with communication
  • Pipelined all-gather, all-reduce, and reduce-scatter — overlaps collective operations with surrounding computation across loop iterations
  • Double-buffered while loops — improves scan/loop throughput
  • Large combine thresholds (128 MiB for all-gather/all-reduce, 64 MiB for reduce-scatter) — fewer, larger collective operations
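The combine thresholds are exact power-of-two byte counts; a quick check of the arithmetic (the values are taken directly from the flag list in the Reasoning section):

```python
# The combine thresholds from the flag list, decoded into mebibytes.
ALL_GATHER_BYTES = 134217728      # --xla_gpu_all_gather_combine_threshold_bytes
ALL_REDUCE_BYTES = 134217728      # --xla_gpu_all_reduce_combine_threshold_bytes
REDUCE_SCATTER_BYTES = 67108864   # --xla_gpu_reduce_scatter_combine_threshold_bytes

MIB = 1024 * 1024
print(ALL_GATHER_BYTES // MIB)      # 128
print(REDUCE_SCATTER_BYTES // MIB)  # 64
```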

Separate TPU-specific flags are also provided for TPU platform execution.

Usage

These flags are applied automatically when `platform == 'gpu'` and `gpuflags == True` (the default). To bypass them for debugging, run on CPU with `--jax.platform cpu`.
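A minimal sketch of this gating logic; `xla_gpu_flags` and the abridged flag subset are illustrative, not the project's actual API:

```python
def xla_gpu_flags(platform, gpuflags=True):
    """Return the extra XLA flags to apply, or none on non-GPU platforms."""
    if not (gpuflags and platform == 'gpu'):
        return []
    return [
        '--xla_disable_hlo_passes=rematerialization',
        '--xla_gpu_enable_latency_hiding_scheduler=true',
        '--xla_gpu_graph_level=0',
        # ... remaining flags as listed in the Reasoning section
    ]

print(len(xla_gpu_flags('gpu')))  # 3 flags in this abridged sketch
print(xla_gpu_flags('cpu'))       # [] -- e.g. when run with --jax.platform cpu
```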

The Insight (Rule of Thumb)

  • Action: When running DreamerV3 on GPU, let the default XLA flags be applied. Do not set `XLA_FLAGS` manually as the code overwrites them. If you need custom flags, modify `embodied/jax/internal.py`.
  • Value: The flags disable Triton (which underperforms for this workload), disable CUDA graphs (incompatible with dynamic scan), and enable pipelined collectives.
  • Trade-off: Disabling rematerialization increases peak memory usage but improves throughput. Disabling Triton kernels may reduce performance on future hardware where Triton outperforms cuBLAS.
  • Compatibility: These flags are JAX/XLA-version-specific. Some flags may become no-ops or cause warnings on newer XLA versions.

Reasoning

The DreamerV3 training loop has a complex computation graph with recurrent scans (RSSM observe/imagine), which doesn't map well to CUDA graphs. Triton kernels in XLA were found to be slower than the default cuBLAS/cuDNN kernels for the matrix sizes in DreamerV3 (hidden=1024-1536, batch not very large). Pipelined collectives are essential for multi-GPU training where data-parallel gradient synchronization would otherwise be on the critical path.

Disabling rematerialization also matters: XLA's rematerialization pass can choose to recompute, rather than store, activations needed for the backward pass. For DreamerV3's model sizes the activations fit in memory, so rematerialization would be pure compute overhead.
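As a rough illustration of why the activations fit, consider a per-layer footprint with hypothetical shapes (the batch, sequence length, and dtype below are assumptions for illustration, not the project's actual configuration):

```python
# Hypothetical per-layer activation footprint for a DreamerV3-sized model.
batch, length, hidden = 16, 64, 1024   # assumed shapes, fp32
bytes_per_elem = 4
activation_bytes = batch * length * hidden * bytes_per_elem
print(activation_bytes / 2**20)  # 4.0 MiB of activations per layer
```

At a few MiB per layer, storing activations for the backward pass is cheap on a modern GPU, so recomputing them inside the recurrent scan would only add latency.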

Implementation from `embodied/jax/internal.py:L54-79`:

  if gpuflags and platform == 'gpu':
    xlaflags += [
        '--xla_disable_hlo_passes=rematerialization',
        '--xla_gpu_all_gather_combine_threshold_bytes=134217728',
        '--xla_gpu_all_reduce_combine_threshold_bytes=134217728',
        '--xla_gpu_enable_all_gather_combine_by_dim=false',
        '--xla_gpu_enable_highest_priority_async_stream=true',
        '--xla_gpu_enable_latency_hiding_scheduler=true',
        '--xla_gpu_enable_pipelined_all_gather=true',
        '--xla_gpu_enable_pipelined_all_reduce=true',
        '--xla_gpu_enable_pipelined_reduce_scatter=true',
        '--xla_gpu_enable_reduce_scatter_combine_by_dim=false',
        '--xla_gpu_enable_triton_gemm=false',
        '--xla_gpu_enable_triton_softmax_fusion=false',
        '--xla_gpu_enable_while_loop_double_buffering=true',
        '--xla_gpu_graph_level=0',
        '--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864',
    ]
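These flags only take effect if they reach the environment before JAX initializes its backend. A sketch of how such a list can be applied (variable names are illustrative; per the Insight above, the code overwrites any pre-existing `XLA_FLAGS`):

```python
import os

xlaflags = [
    '--xla_disable_hlo_passes=rematerialization',
    '--xla_gpu_graph_level=0',
    # ... remaining flags from the list above
]

# Set XLA_FLAGS before the first `import jax`, which is when XLA reads it.
# Any flags the user set manually are overwritten.
os.environ['XLA_FLAGS'] = ' '.join(xlaflags)
print(os.environ['XLA_FLAGS'])
```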

Additional JAX configuration from the same function:

  jax.config.update('jax_disable_most_optimizations', debug)
  jax.config.update('jax_disable_jit', not jit)
  if transfer_guard and jit and not debug_nans:
    jax.config.update('jax_transfer_guard', 'disallow')
  os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = str(bool(prealloc)).lower()
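The preallocation toggle above relies on Python's boolean-to-string coercion; `prealloc_env_value` here is an illustrative helper, not part of the codebase:

```python
def prealloc_env_value(prealloc):
    # XLA_PYTHON_CLIENT_PREALLOCATE expects the literal strings 'true'/'false',
    # so the truthy value is converted to bool, then lowercased.
    return str(bool(prealloc)).lower()

print(prealloc_env_value(True))  # true
print(prealloc_env_value(0))     # false
```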
