
Implementation:Farama Foundation Gymnasium ArrayConversion

From Leeroopedia
Knowledge Sources
Domains Reinforcement_Learning, Wrappers
Last Updated 2026-02-15 03:00 GMT

Overview

A wrapper and set of utility functions for converting between arbitrary Array API compatible frameworks (numpy, torch, jax.numpy, cupy, etc.) in Gymnasium environments.

Description

This module provides the ArrayConversion wrapper class and the array_conversion singledispatch function for converting data between Array API compatible frameworks.

The array_conversion function handles multiple types via singledispatch:

  • numbers.Number -- Converts Python numbers to framework arrays via xp.asarray.
  • Mapping -- Recursively converts all values in a dict/mapping.
  • Iterable -- Recursively converts iterables; detects Array API objects and uses DLPack for zero-copy transfer when possible.
  • NoneType -- Passes through None values unchanged.
  • Array API objects -- Uses xp.from_dlpack for efficient cross-framework conversion, with fallback handling for PyTorch device transfer issues and read-only buffer limitations.
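The dispatch pattern listed above can be sketched with functools.singledispatch from the standard library. This is an illustrative toy, not Gymnasium's implementation: convert, ToyArray, and toy_xp are hypothetical names, toy_xp stands in for a real Array API namespace, and the DLPack branch for array objects is omitted.

```python
import numbers
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from functools import singledispatch
from types import SimpleNamespace


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for a framework array holding one scalar."""
    data: float


# Hypothetical stand-in for an Array API namespace such as numpy or torch.
toy_xp = SimpleNamespace(asarray=lambda v: ToyArray(float(v)))


@singledispatch
def convert(value, xp):
    # Fallback: reject types that have no registered converter.
    raise TypeError(f"no conversion registered for {type(value).__name__}")


@convert.register(numbers.Number)
def _(value, xp):
    # Python numbers become arrays in the target namespace.
    return xp.asarray(value)


@convert.register(Mapping)
def _(value, xp):
    # Convert every value of the mapping recursively, keeping the keys.
    return {k: convert(v, xp) for k, v in value.items()}


@convert.register(Iterable)
def _(value, xp):
    # Rebuild lists and tuples element by element, preserving the container type.
    return type(value)(convert(v, xp) for v in value)


@convert.register(type(None))
def _(value, xp):
    # None passes through unchanged.
    return None


out = convert({"a": 1, "b": [2, None], "c": (3.5,)}, toy_xp)
```

A dict matches both Mapping and Iterable, and singledispatch picks the more specific Mapping handler, which is why the nested structure keeps its shape.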

The ArrayConversion wrapper converts environment outputs (observations and info) from the environment's framework to a target framework, and converts actions from the target framework back to the environment's framework. The reward is cast to a Python float, and terminated/truncated are cast to Python bools. The wrapper supports pickling via __getstate__/__setstate__.
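One reason a wrapper that stores framework modules needs custom __getstate__/__setstate__ hooks is that module objects cannot be pickled directly. The sketch below shows one way to handle that, by storing the module's import name and re-importing it on restore. ModuleHolder is a hypothetical class, the math module stands in for a framework such as numpy or torch, and the details of Gymnasium's own hooks may differ.

```python
import importlib
import math  # stand-in for a framework module such as numpy or torch
import pickle
from types import ModuleType


class ModuleHolder:
    """Hypothetical stand-in for a wrapper that keeps a framework module."""

    def __init__(self, xp: ModuleType):
        self.xp = xp

    def __getstate__(self):
        # Module objects cannot be pickled, so store the import name instead.
        return {"xp_name": self.xp.__name__}

    def __setstate__(self, state):
        # Re-import the module by name when the object is restored.
        self.xp = importlib.import_module(state["xp_name"])


holder = ModuleHolder(math)
restored = pickle.loads(pickle.dumps(holder))
```

After the round trip, restored.xp refers to the same imported module as the original holder.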

Requires the array-api-compat package and NumPy >= 2.1.0.

Usage

Use this wrapper when your environment uses one Array API framework (e.g., JAX) but your agent code uses another (e.g., PyTorch or NumPy). This is the base wrapper that JaxToNumpy, JaxToTorch, and NumpyToTorch are built upon.

Code Reference

Source Location

Signature

class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
    def __init__(
        self,
        env: gym.Env,
        env_xp: ModuleType,
        target_xp: ModuleType,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ): ...

Import

from gymnasium.wrappers import ArrayConversion
from gymnasium.wrappers.array_conversion import array_conversion

I/O Contract

Inputs

Name | Type | Required | Description
env | Env | Yes | The Array API compatible environment to wrap
env_xp | ModuleType | Yes | The Array API framework the environment uses (e.g., jax.numpy)
target_xp | ModuleType | Yes | The Array API framework to convert outputs to (e.g., torch)
env_device | Device or None | No | The device the environment is on (default None)
target_device | Device or None | No | The device on which output arrays should be placed (default None)

Outputs

Name | Type | Description
observation | target_xp array | Observation converted to target framework
reward | float | Reward as a Python float
terminated | bool | Termination flag as a Python bool
truncated | bool | Truncation flag as a Python bool
info | dict | Info dict with values converted to target framework
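The input/output contract above can be illustrated with a dependency-free sketch. Everything here is hypothetical: ToyEnv treats lists as its "framework", the target "framework" is tuples, and ConvertingWrapper only mirrors the shape of the contract, not Gymnasium's actual code.

```python
from typing import Any, Callable


class ToyEnv:
    """Hypothetical environment whose 'framework' is the list type."""

    def step(self, action: list):
        obs = [a * 2 for a in action]
        # Reward and flags are returned as ints rather than float/bool.
        return obs, 3, 0, 1, {"raw": obs}


class ConvertingWrapper:
    """Sketch of the contract: convert actions in, convert outputs back out."""

    def __init__(
        self,
        env: ToyEnv,
        to_env: Callable[[Any], Any],
        to_target: Callable[[Any], Any],
    ):
        self.env = env
        self.to_env = to_env
        self.to_target = to_target

    def step(self, action):
        # The action goes from the target framework to the environment's framework.
        obs, reward, terminated, truncated, info = self.env.step(self.to_env(action))
        return (
            self.to_target(obs),                              # observation converted
            float(reward),                                    # reward as a Python float
            bool(terminated),                                 # flags as Python bools
            bool(truncated),
            {k: self.to_target(v) for k, v in info.items()},  # info values converted
        )


env = ConvertingWrapper(ToyEnv(), to_env=list, to_target=tuple)
obs, reward, terminated, truncated, info = env.step((1, 2))
```

The caller only ever sees tuples, a float, and bools, while ToyEnv only ever sees lists.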

Usage Examples

import torch
import jax.numpy as jnp
import gymnasium as gym
from gymnasium.wrappers import ArrayConversion

# Convert a JAX-based environment to return PyTorch tensors.
# "JaxEnv-vx" is a placeholder ID; substitute a registered JAX-based environment.
env = gym.make("JaxEnv-vx")
env = ArrayConversion(env, env_xp=jnp, target_xp=torch)
obs, _ = env.reset(seed=123)
type(obs)  # <class 'torch.Tensor'>

# Actions are passed as torch tensors and converted back for the environment.
action = torch.tensor(env.action_space.sample())
obs, reward, terminated, truncated, info = env.step(action)
type(reward)  # <class 'float'>
