Implementation:Farama Foundation Gymnasium JaxToNumpy
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, Wrappers |
| Last Updated | 2026-02-15 03:00 GMT |
Overview
A convenience wrapper that lets code interact with a JAX-based Gymnasium environment using NumPy arrays.
Description
The JaxToNumpy wrapper is a thin subclass of ArrayConversion that pre-configures the conversion from JAX arrays to NumPy arrays. Actions provided as NumPy arrays are automatically converted to JAX arrays for the underlying environment, and observations are returned as NumPy arrays.
The module also provides two utility functions:
- `jax_to_numpy`: a partial application of `array_conversion` targeting the NumPy namespace, for converting JAX arrays to NumPy.
- `numpy_to_jax`: a partial application of `array_conversion` targeting the JAX namespace, for converting NumPy arrays to JAX.
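The sketch below shows the partial-application pattern. It assumes a generic converter that dispatches on the value's type and recurses into dicts and tuples. The names `array_conversion_sketch` and `to_numpy_sketch` are hypothetical, and this is not Gymnasium's `array_conversion`. To keep the example runnable without JAX, it only targets NumPy. With JAX installed, `functools.partial(array_conversion_sketch, xp=jax.numpy)` would give the NumPy-to-JAX direction.

```python
import functools
from collections.abc import Mapping

import numpy as np


@functools.singledispatch
def array_conversion_sketch(value, xp):
    """Hypothetical converter: turn any array-like leaf into an array of namespace `xp`."""
    return xp.asarray(value)


@array_conversion_sketch.register(Mapping)
def _(value, xp):
    # Recurse into dicts (e.g. `info`), keeping keys and converting values.
    return {k: array_conversion_sketch(v, xp) for k, v in value.items()}


@array_conversion_sketch.register(tuple)
def _(value, xp):
    # Recurse into tuples element-wise instead of stacking them into one array.
    return tuple(array_conversion_sketch(v, xp) for v in value)


# Fix the target namespace once, so callers only pass the value to convert.
to_numpy_sketch = functools.partial(array_conversion_sketch, xp=np)

converted = to_numpy_sketch({"x": [1.0, 2.0], "nested": ({"y": 3},)})
print(converted)
```

Dispatching on type keeps nested `info` structures intact while converting only the leaves. A plain scalar such as `3` comes out as a 0-d array, which is the source of the roundtrip caveat in the note that follows.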
Note: Conversion is not guaranteed to round-trip in either direction (JAX → NumPy → JAX or NumPy → JAX → NumPy), because JAX does not support non-array scalar values.
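The following example illustrates the caveat and runs with NumPy alone. Here `np.asarray` stands in for the NumPy → JAX → NumPy trip. This substitution is an assumption, since JAX itself is not imported. Because JAX can only hold a scalar as a 0-d array, a NumPy scalar such as `np.int32(5)` comes back as a 0-d `np.ndarray`. The value is the same, but the type is not.

```python
import numpy as np

original = np.int32(5)  # a NumPy scalar, not an array

# Stand-in for NumPy -> JAX -> NumPy: JAX can only represent this value as a
# 0-d array, and converting that array back to NumPy yields a 0-d ndarray.
roundtripped = np.asarray(original)

print(type(original), type(roundtripped), roundtripped.shape)
print(roundtripped.item() == original.item())  # value survives, type does not
```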
Requires the jax package.
Usage
Use this wrapper when you have a JAX-based environment but your agent or training code operates on NumPy arrays. This is common when using JAX environments with non-JAX training frameworks.
Code Reference
Source Location
- Repository: Farama_Foundation_Gymnasium
- File: gymnasium/wrappers/jax_to_numpy.py
Signature
```python
class JaxToNumpy(ArrayConversion):
    def __init__(self, env: gym.Env[ObsType, ActType]): ...
```
Import
```python
from gymnasium.wrappers import JaxToNumpy
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| env | Env | Yes | The JAX-based environment to wrap |
Outputs
| Name | Type | Description |
|---|---|---|
| observation | numpy.ndarray | Observation converted to NumPy array |
| reward | float | Reward as a Python float |
| terminated | bool | Termination flag as a Python bool |
| truncated | bool | Truncation flag as a Python bool |
| info | dict | Info dict with values converted to NumPy arrays |
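The input/output contract above can be illustrated with a simplified, hypothetical wrapper. This is not Gymnasium's `ArrayConversion`. The class name, the `to_env`/`to_numpy` callables and `ToyEnv` are all illustrative assumptions, and a Python list stands in for the JAX array type so the example runs without JAX. The sketch shows both directions: the NumPy action is converted before it reaches the wrapped environment. On the way out, the observation and info values become NumPy arrays, and reward and flags become Python scalars.

```python
from typing import Any, Callable

import numpy as np


class ConversionWrapperSketch:
    """Hypothetical stand-in for the conversion this wrapper performs (not Gymnasium API)."""

    def __init__(self, env: Any, to_env: Callable[[Any], Any], to_numpy: Callable[[Any], Any]):
        self.env = env
        self.to_env = to_env  # agent-side (NumPy) value -> env-native value
        self.to_numpy = to_numpy  # env-native value -> NumPy array

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        return self.to_numpy(obs), {k: self.to_numpy(v) for k, v in info.items()}

    def step(self, action):
        # Convert the NumPy action before it reaches the wrapped environment.
        obs, reward, terminated, truncated, info = self.env.step(self.to_env(action))
        # Observation and info values become NumPy arrays; reward and flags become Python scalars.
        return (
            self.to_numpy(obs),
            float(reward),
            bool(terminated),
            bool(truncated),
            {k: self.to_numpy(v) for k, v in info.items()},
        )


class ToyEnv:
    """Hypothetical env whose 'native' values are lists and 0-d arrays (stand-ins for JAX arrays)."""

    def reset(self, seed=None):
        return [0.0, 0.0], {"steps": 0}

    def step(self, action):
        self.last_action_type = type(action)  # records what the env actually received
        return [1.0, 2.0], np.float32(1.0), np.asarray(False), np.asarray(False), {"steps": 1}


env = ToyEnv()
wrapped = ConversionWrapperSketch(env, to_env=lambda a: a.tolist(), to_numpy=np.asarray)
obs, info = wrapped.reset(seed=123)
obs, reward, terminated, truncated, info = wrapped.step(np.array([0.5]))
print(type(obs), type(reward), type(terminated), env.last_action_type)
```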
Usage Examples
```python
import gymnasium as gym
from gymnasium.wrappers import JaxToNumpy

# "JaxEnv-vx" is a placeholder: substitute the id of a registered JAX-based environment.
env = gym.make("JaxEnv-vx")
env = JaxToNumpy(env)

obs, _ = env.reset(seed=123)
print(type(obs))  # <class 'numpy.ndarray'>

action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)
print(type(obs))  # <class 'numpy.ndarray'>
print(type(reward))  # <class 'float'>
print(type(terminated))  # <class 'bool'>
```