Implementation:Farama Foundation Gymnasium ArrayConversion
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Wrappers |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
A wrapper and set of utility functions for converting between arbitrary Array API compatible frameworks (numpy, torch, jax.numpy, cupy, etc.) in Gymnasium environments.
Description
This module provides the ArrayConversion wrapper class and the array_conversion singledispatch function for converting data between Array API compatible frameworks.
The array_conversion function handles multiple types via singledispatch:
- numbers.Number -- Converts Python numbers to framework arrays via xp.asarray.
- Mapping -- Recursively converts all values in a dict/mapping.
- Iterable -- Recursively converts iterables; detects Array API objects and uses DLPack for zero-copy transfer when possible.
- NoneType -- Passes through None values unchanged.
- Array API objects -- Uses xp.from_dlpack for efficient cross-framework conversion, with fallback handling for PyTorch device transfer issues and read-only buffer limitations.
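The dispatch rules above can be sketched with functools.singledispatch. This is a minimal, hypothetical re-creation: the name to_xp is illustrative, not Gymnasium's API. It uses only NumPy as the target framework and checks for a __dlpack__ method instead of using array-api-compat. It also omits the device handling and fallback paths that Gymnasium's array_conversion provides.

```python
import functools
import numbers
from collections import abc
from types import ModuleType
from typing import Any

import numpy as np


@functools.singledispatch
def to_xp(value: Any, xp: ModuleType) -> Any:
    """Hypothetical converter mirroring the dispatch rules listed above."""
    raise TypeError(f"No conversion registered for {type(value).__name__}")


@to_xp.register(numbers.Number)
def _(value: numbers.Number, xp: ModuleType) -> Any:
    # Python (and NumPy) scalars become 0-d arrays of the target framework.
    return xp.asarray(value)


@to_xp.register(abc.Mapping)
def _(value: abc.Mapping, xp: ModuleType) -> Any:
    # Rebuild the mapping with every value converted recursively.
    return type(value)((k, to_xp(v, xp)) for k, v in value.items())


@to_xp.register(abc.Iterable)
def _(value: abc.Iterable, xp: ModuleType) -> Any:
    # Array objects that export DLPack are handed over without copying.
    if hasattr(value, "__dlpack__"):
        return xp.from_dlpack(value)
    # Other iterables (lists, tuples, ...) are rebuilt element by element.
    return type(value)(to_xp(v, xp) for v in value)


@to_xp.register(str)
def _(value: str, xp: ModuleType) -> str:
    # Sketch-only addition: strings are Iterable but should not be split into characters.
    return value


@to_xp.register(type(None))
def _(value: None, xp: ModuleType) -> None:
    # None passes through unchanged.
    return value
```

Because singledispatch selects the most specific registered type, a dict is handled by the Mapping rule even though it is also Iterable. A NumPy array reaches the Iterable rule and is passed through DLPack, so the result shares memory with the input.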
The ArrayConversion wrapper converts environment outputs (observations and info) from the environment's framework to the target framework, and converts actions from the target framework back to the environment's framework. The reward is cast to a Python float, and terminated/truncated are cast to Python bools. The wrapper also supports pickling via __getstate__/__setstate__.
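The action/output flow described above can be sketched as a plain Python class. This is a hypothetical stand-in: ConvertingWrapperSketch, _NumpyEnv and to_lists are illustrative names, not Gymnasium API. It takes conversion callables instead of Array API modules so that it runs with NumPy alone, and nested Python lists play the role of the target framework.

```python
from __future__ import annotations

from typing import Any, Callable

import numpy as np


class ConvertingWrapperSketch:
    """Hypothetical stand-in for ArrayConversion's reset/step flow (not Gymnasium's class)."""

    def __init__(self, env: Any, to_env: Callable[[Any], Any], to_target: Callable[[Any], Any]):
        self.env = env
        self.to_env = to_env  # target-framework data -> environment framework
        self.to_target = to_target  # environment-framework data -> target framework

    def reset(self, **kwargs: Any) -> tuple[Any, Any]:
        obs, info = self.env.reset(**kwargs)
        return self.to_target(obs), self.to_target(info)

    def step(self, action: Any) -> tuple[Any, float, bool, bool, Any]:
        # Actions travel back into the environment's framework.
        obs, reward, terminated, truncated, info = self.env.step(self.to_env(action))
        # Observations and info go to the target framework; scalars become Python types.
        return (
            self.to_target(obs),
            float(reward),
            bool(terminated),
            bool(truncated),
            self.to_target(info),
        )


class _NumpyEnv:
    """Toy environment that works purely in NumPy."""

    def reset(self, seed: int | None = None) -> tuple[np.ndarray, dict]:
        return np.zeros(3, dtype=np.float32), {}

    def step(self, action: np.ndarray):
        obs = action.astype(np.float32) + 1
        return obs, np.float32(0.5), np.bool_(False), np.bool_(False), {"steps": np.int64(1)}


def to_lists(x: Any) -> Any:
    # Stand-in "target framework": nested Python lists and scalars.
    if isinstance(x, dict):
        return {k: to_lists(v) for k, v in x.items()}
    return x.tolist() if isinstance(x, (np.ndarray, np.generic)) else x


env = ConvertingWrapperSketch(_NumpyEnv(), to_env=np.asarray, to_target=to_lists)
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step([1.0, 2.0, 3.0])
```

Passing callables keeps the sketch framework-agnostic. The real wrapper instead takes env_xp/target_xp modules plus optional devices.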
Requires array-api-compat package and numpy >= 2.1.0.
Usage
Use this wrapper when your environment uses one Array API framework (e.g., JAX) but your agent code uses another (e.g., PyTorch or NumPy). This is the base wrapper that JaxToNumpy, JaxToTorch, and NumpyToTorch are built upon.
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/wrappers/array_conversion.py
Signature
```python
class ArrayConversion(gym.Wrapper, gym.utils.RecordConstructorArgs):
    def __init__(
        self,
        env: gym.Env,
        env_xp: ModuleType,
        target_xp: ModuleType,
        env_device: Device | None = None,
        target_device: Device | None = None,
    ): ...
```
Import
```python
from gymnasium.wrappers import ArrayConversion
from gymnasium.wrappers.array_conversion import array_conversion
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| env | Env | Yes | The Array API compatible environment to wrap |
| env_xp | ModuleType | Yes | The Array API framework the environment uses (e.g., jax.numpy) |
| target_xp | ModuleType | Yes | The Array API framework to convert outputs to (e.g., torch) |
| env_device | Device or None | No | The device the environment is on (default None) |
| target_device | Device or None | No | The device on which output arrays should be placed (default None) |
Outputs
| Name | Type | Description |
|---|---|---|
| observation | target_xp array | Observation converted to target framework |
| reward | float | Reward as a Python float |
| terminated | bool | Termination flag as a Python bool |
| truncated | bool | Truncation flag as a Python bool |
| info | dict | Info dict with values converted to target framework |
Usage Examples
```python
import torch
import jax.numpy as jnp
import gymnasium as gym
from gymnasium.wrappers import ArrayConversion

# Convert a JAX-based environment to return PyTorch tensors
# ("JaxEnv-vx" stands in for the id of a JAX-based environment)
env = gym.make("JaxEnv-vx")
env = ArrayConversion(env, env_xp=jnp, target_xp=torch)

obs, _ = env.reset(seed=123)
print(type(obs))  # <class 'torch.Tensor'>

# Actions are passed as torch tensors and converted back to the environment's framework
action = torch.tensor(env.action_space.sample())
obs, reward, terminated, truncated, info = env.step(action)
print(type(reward))  # <class 'float'>
```