

Implementation:Farama Foundation Gymnasium JaxToNumpy

From Leeroopedia
Revision as of 12:37, 16 February 2026 by Admin (auto-imported from implementations/Farama_Foundation_Gymnasium_JaxToNumpy.md)
Domains: Reinforcement_Learning, Wrappers
Last Updated: 2026-02-15 03:00 GMT

Overview

A convenience wrapper that lets a JAX-based Gymnasium environment be driven with NumPy arrays.

Description

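The JaxToNumpy wrapper is a thin subclass of ArrayConversion that fixes the conversion direction: JAX is the environment's array library and NumPy is the agent's. Actions passed in as NumPy arrays are converted to JAX arrays before they reach the underlying environment, and observations and info values are returned as NumPy arrays. Rewards and the terminated/truncated flags are returned as plain Python float and bool values.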
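The conversion flow described above can be sketched without Gymnasium as a small adapter. This is a hypothetical illustration, not the source of ArrayConversion: the names ArrayAdapter, JaxToNumpySketch and ListEnv, and the to_env / to_target callables, are assumptions made for the sketch. The base class converts the action into the environment's array type before step() and converts the results back; the subclass only fixes the two converters, which is the sense in which JaxToNumpy is a "thin" subclass.

```python
from __future__ import annotations

from typing import Any, Callable

import numpy as np


class ArrayAdapter:
    """Hypothetical stand-in for ArrayConversion: convert actions in, outputs out."""

    def __init__(self, env: Any, to_env: Callable[[Any], Any], to_target: Callable[[Any], Any]) -> None:
        self.env = env
        self.to_env = to_env        # agent-side array -> environment's array type
        self.to_target = to_target  # environment's array type -> agent-side array

    def _convert_info(self, info: dict) -> dict:
        # Only top-level info values are converted in this sketch.
        return {key: self.to_target(value) for key, value in info.items()}

    def reset(self, **kwargs: Any) -> tuple[Any, dict]:
        obs, info = self.env.reset(**kwargs)
        return self.to_target(obs), self._convert_info(info)

    def step(self, action: Any) -> tuple[Any, float, bool, bool, dict]:
        obs, reward, terminated, truncated, info = self.env.step(self.to_env(action))
        return (
            self.to_target(obs),
            float(reward),      # scalar outputs become plain Python types
            bool(terminated),
            bool(truncated),
            self._convert_info(info),
        )


class JaxToNumpySketch(ArrayAdapter):
    """Thin subclass that only pre-configures the two converters (needs jax at runtime)."""

    def __init__(self, env: Any) -> None:
        import jax.numpy as jnp  # imported lazily so this module loads without jax

        super().__init__(env, to_env=jnp.asarray, to_target=np.asarray)


class ListEnv:
    """Hypothetical toy environment whose native 'arrays' are Python lists."""

    def reset(self, seed: int | None = None) -> tuple[list, dict]:
        return [0.0, 0.0], {"t": [0]}

    def step(self, action: list) -> tuple[list, Any, Any, Any, dict]:
        assert isinstance(action, list), "the adapter should pass the env its native type"
        return list(action), np.float32(1.5), np.bool_(True), np.bool_(False), {"t": [1]}


# Exercising the base adapter without JAX: lists play the role of the env's arrays.
adapter = ArrayAdapter(ListEnv(), to_env=lambda a: np.asarray(a).tolist(), to_target=np.asarray)
```

The toy environment only exists so the adapter can be exercised without JAX installed; Gymnasium's own wrapper is more general than this sketch.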

The module also provides two utility functions:

  • jax_to_numpy -- A partial application of array_conversion targeting the NumPy namespace, for converting JAX arrays to NumPy.
  • numpy_to_jax -- A partial application of array_conversion targeting the JAX namespace, for converting NumPy arrays to JAX.
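The partial-application pattern behind these two helpers can be illustrated with a hypothetical recursive converter. array_tree_conversion below is an assumption written for illustration, not Gymnasium's array_conversion (which is more general); it only shows how fixing the target namespace with functools.partial yields ready-made one-argument converters.

```python
from __future__ import annotations

import functools
from collections.abc import Mapping
from typing import Any

import numpy as np


def array_tree_conversion(value: Any, xp: Any) -> Any:
    """Convert every leaf of a nested dict/tuple structure with ``xp.asarray``.

    Hypothetical sketch: dicts and tuples are walked recursively (namedtuples come
    back as plain tuples), ``None`` is kept, and anything else is treated as an array leaf.
    """
    if value is None:
        return None
    if isinstance(value, Mapping):
        return {key: array_tree_conversion(item, xp) for key, item in value.items()}
    if isinstance(value, tuple):
        return tuple(array_tree_conversion(item, xp) for item in value)
    return xp.asarray(value)


# Fixing ``xp`` once gives single-argument converters, mirroring jax_to_numpy / numpy_to_jax.
to_numpy = functools.partial(array_tree_conversion, xp=np)

# With JAX installed, the reverse direction is built the same way:
#   import jax.numpy as jnp
#   to_jax = functools.partial(array_tree_conversion, xp=jnp)

converted = to_numpy({"position": [1.0, 2.0], "extra": (3, None)})
```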

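Note: Conversion between JAX and NumPy is not guaranteed to roundtrip (JAX -> NumPy -> JAX, or NumPy -> JAX -> NumPy), because JAX has no non-array scalar type: a NumPy scalar becomes a JAX array, so converting it back does not necessarily restore the original scalar.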
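The roundtrip caveat can be observed directly with jax.numpy.asarray and numpy.asarray. This sketch uses those two public functions rather than the wrapper's own helpers, and assumes both jax and numpy are installed; it only illustrates why a scalar does not come back as a scalar.

```python
import jax.numpy as jnp
import numpy as np

original = np.int32(5)          # a NumPy scalar, not an array
as_jax = jnp.asarray(original)  # JAX has no scalar type, so this is a 0-d JAX array
back = np.asarray(as_jax)       # a 0-d numpy.ndarray, not a numpy.int32 scalar

print(type(original).__name__, type(back).__name__, back.shape)
```

The value survives the trip, but its type changes from a NumPy scalar to a zero-dimensional array.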

Requires the jax package.

Usage

Use this wrapper when you have a JAX-based environment but your agent or training code operates on NumPy arrays. This is common when using JAX environments with non-JAX training frameworks.
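A minimal sketch of NumPy-side training code that would sit on top of the wrapper. The run_episode helper, the linear policy and CountdownEnv are hypothetical; the loop relies only on the standard Gymnasium reset()/step() API and on the output types listed in the I/O Contract. CountdownEnv is a stand-in so the loop can run without JAX; in practice env would be a JaxToNumpy-wrapped JAX environment.

```python
from __future__ import annotations

from typing import Any, Callable

import numpy as np


def run_episode(
    env: Any,
    policy: Callable[[np.ndarray], np.ndarray],
    seed: int | None = None,
    max_steps: int = 1_000,
) -> float:
    """Roll out one episode with a NumPy policy and return the undiscounted return."""
    obs, _info = env.reset(seed=seed)
    episode_return = 0.0
    for _ in range(max_steps):
        action = policy(obs)  # the policy only ever sees numpy.ndarray observations
        obs, reward, terminated, truncated, _info = env.step(action)
        episode_return += reward  # a plain Python float after conversion
        if terminated or truncated:
            break
    return episode_return


# Hypothetical linear policy written purely with NumPy.
rng = np.random.default_rng(0)
weights = rng.normal(size=(1, 4))


def linear_policy(obs: np.ndarray) -> np.ndarray:
    return np.clip(weights @ obs, -1.0, 1.0)


class CountdownEnv:
    """Hypothetical stand-in env: terminates after three steps, reward 1.0 per step."""

    def reset(self, seed: int | None = None) -> tuple[np.ndarray, dict]:
        self.remaining = 3
        return np.zeros(4), {}

    def step(self, action: np.ndarray) -> tuple[np.ndarray, float, bool, bool, dict]:
        self.remaining -= 1
        return np.ones(4), 1.0, self.remaining == 0, False, {}


# With a real JAX environment (placeholder id):
#   total = run_episode(JaxToNumpy(gym.make("JaxEnv-vx")), linear_policy, seed=123)
```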

Code Reference

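Source Location

gymnasium/wrappers/jax_to_numpy.py in the Farama-Foundation/Gymnasium repository (module gymnasium.wrappers.jax_to_numpy)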

Signature

class JaxToNumpy(ArrayConversion):
    def __init__(self, env: gym.Env[ObsType, ActType]): ...

Import

from gymnasium.wrappers import JaxToNumpy
from gymnasium.wrappers.jax_to_numpy import jax_to_numpy, numpy_to_jax

I/O Contract

Inputs

Name   Type                        Required   Description
env    gym.Env[ObsType, ActType]   Yes        The JAX-based environment to wrap

Outputs (of step(); reset() returns the observation and info with the same conversions)

Name          Type            Description
observation   numpy.ndarray   Observation converted to a NumPy array
reward        float           Reward as a Python float
terminated    bool            Termination flag as a Python bool
truncated     bool            Truncation flag as a Python bool
info          dict            Info dict with its values converted to NumPy arrays
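The Outputs table can be turned into a quick smoke test. check_step_contract is a hypothetical helper, not part of Gymnasium; it assumes an array (non-Dict) observation space and checks only the top-level types from the table, so nested info entries are not inspected.

```python
from __future__ import annotations

from typing import Any

import numpy as np


def check_step_contract(step_result: tuple[Any, ...]) -> None:
    """Raise AssertionError if a step() result does not match the Outputs table."""
    obs, reward, terminated, truncated, info = step_result
    assert isinstance(obs, np.ndarray), f"observation is {type(obs).__name__}"
    assert type(reward) is float, f"reward is {type(reward).__name__}"
    assert type(terminated) is bool, f"terminated is {type(terminated).__name__}"
    assert type(truncated) is bool, f"truncated is {type(truncated).__name__}"
    assert isinstance(info, dict), f"info is {type(info).__name__}"


# Typical use against a wrapped environment (placeholder id, as in the usage example):
#   env = JaxToNumpy(gym.make("JaxEnv-vx"))
#   env.reset(seed=123)
#   check_step_contract(env.step(env.action_space.sample()))
```

Using `type(...) is float` rather than isinstance deliberately rejects NumPy scalar types such as numpy.float64, which subclass float.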

Usage Examples

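import gymnasium as gym
from gymnasium.wrappers import JaxToNumpy

# "JaxEnv-vx" is a placeholder: substitute the id of a registered JAX-based environment.
env = gym.make("JaxEnv-vx")
env = JaxToNumpy(env)
obs, info = env.reset(seed=123)
print(type(obs))  # <class 'numpy.ndarray'>

action = env.action_space.sample()  # NumPy action; converted to JAX before reaching the env
obs, reward, terminated, truncated, info = env.step(action)
print(type(obs))         # <class 'numpy.ndarray'>
print(type(reward))      # <class 'float'>
print(type(terminated))  # <class 'bool'>
print(type(truncated))   # <class 'bool'>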
