Environment:Huggingface Datasets JAX Integration

From Leeroopedia
Knowledge Sources
Domains Deep_Learning, Data_Processing
Last Updated 2026-02-14 19:00 GMT

Overview

JAX integration in HuggingFace Datasets enables the library to produce native jax.Array outputs with device placement and dtype selection that respects JAX's 64-bit precision configuration. Unlike the TensorFlow and PyTorch integrations, JAX requires two separate packages (jax and jaxlib) to both be importable, and the integration is not supported on Windows.

Description

At import time the library reads the USE_JAX environment variable (default: "AUTO") and upper-cases it. When the value is one of 1, ON, YES, TRUE, or AUTO (the ENV_VARS_TRUE_AND_AUTO_VALUES set), the detection routine checks that both importlib.util.find_spec("jax") and importlib.util.find_spec("jaxlib") return non-None values. JAX_AVAILABLE is set to True only when both packages are present, after which the JAX version is resolved from the jax package metadata. Any other USE_JAX value disables the integration.
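The decision rule above can be mirrored in a few lines of standard-library Python. This is a minimal sketch, not the library's code. The function name `jax_integration_enabled` and its injectable `find_spec` parameter are hypothetical; they exist only so the rule can be exercised without JAX installed. The accepted values are assumed to match the library's ENV_VARS_TRUE_AND_AUTO_VALUES set.

```python
import importlib.util
import os
from typing import Callable, Mapping, Optional

# Assumed to match datasets' ENV_VARS_TRUE_AND_AUTO_VALUES.
TRUE_AND_AUTO_VALUES = {"1", "ON", "YES", "TRUE", "AUTO"}


def jax_integration_enabled(
    env: Optional[Mapping[str, str]] = None,
    find_spec: Callable[[str], object] = importlib.util.find_spec,
) -> bool:
    """Hypothetical helper mirroring the USE_JAX / find_spec detection rule."""
    env = os.environ if env is None else env
    # The variable defaults to "AUTO" and is compared upper-cased.
    use_jax = env.get("USE_JAX", "AUTO").upper()
    if use_jax not in TRUE_AND_AUTO_VALUES:
        return False  # any other value disables the integration
    # Both packages must be importable; one alone is not enough.
    return find_spec("jax") is not None and find_spec("jaxlib") is not None


if __name__ == "__main__":
    print("JAX integration enabled:", jax_integration_enabled())
```

Injecting `find_spec` keeps the check deterministic in tests. The real module only calls `importlib.util.find_spec` directly.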

When JAX is available, the JaxFormatter class is registered under the format type "jax" with no aliases. When JAX is unavailable, a placeholder is registered that raises ValueError("JAX needs to be installed to be able to return JAX arrays.") on use.
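The register-or-placeholder pattern can be sketched with a small dictionary-based registry. The names `register_formatter`, `register_unavailable` and `get_formatter` are hypothetical stand-ins for the library's private `_register_formatter` and `_register_unavailable_formatter`. The point is that the error object is created at import time but raised only when someone requests the "jax" format.

```python
from typing import Callable, Dict

# Hypothetical registries; the real ones live inside datasets.formatting.
_FORMATTERS: Dict[str, Callable[..., object]] = {}
_UNAVAILABLE: Dict[str, Exception] = {}


def register_formatter(factory: Callable[..., object], name: str) -> None:
    _FORMATTERS[name] = factory


def register_unavailable(error: Exception, name: str) -> None:
    # Store the error instead of raising it, so importing the package never fails.
    _UNAVAILABLE[name] = error


def get_formatter(name: str, **kwargs) -> object:
    if name in _FORMATTERS:
        return _FORMATTERS[name](**kwargs)
    if name in _UNAVAILABLE:
        raise _UNAVAILABLE[name]  # surfaces only when the format is requested
    raise ValueError(f"Unknown format type: {name!r}")


jax_available = False  # pretend detection failed
if jax_available:
    register_formatter(dict, "jax")
else:
    register_unavailable(
        ValueError("JAX needs to be installed to be able to return JAX arrays."), "jax"
    )
```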

The JaxFormatter handles several JAX-specific concerns:

  • Device validation: The constructor checks whether the provided device argument is a jaxlib.xla_client.Device object. If so, it raises a ValueError because Device objects cannot be serialized with either pickle or dill. Users must pass the device as a string identifier instead. When no string is given, the formatter defaults to str(jax.devices()[0]). A global DEVICE_MAPPING dictionary maps string identifiers to actual device objects.
  • Precision-aware dtype selection: When tensorizing integer values, the formatter checks jax.config.jax_enable_x64. If 64-bit mode is enabled, integers default to jnp.int64; otherwise they default to jnp.int32. Floating-point values default to jnp.float32 regardless of the x64 setting.
  • Device placement: All array creation is wrapped in a jax.default_device(DEVICE_MAPPING[self.device]) context manager, ensuring tensors land on the intended device.
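The dtype rule in the list above can be written as a pure function of the NumPy dtype kind and the x64 flag. This is a hypothetical stand-in for the formatter's branch. It returns dtype names as strings instead of `jnp` objects so it runs without JAX. It treats the NumPy kind codes "i" (signed) and "u" (unsigned) as integers, because `np.issubdtype(dtype, np.integer)` accepts both.

```python
from typing import Optional


def default_dtype_name(kind: str, enable_x64: bool) -> Optional[str]:
    """Hypothetical mirror of JaxFormatter's default-dtype branch.

    kind: a NumPy dtype.kind code such as "i", "u", "f" or "b".
    Returns None when no default is forced, leaving jnp.array to infer the dtype.
    """
    if kind in ("i", "u"):  # np.integer covers signed and unsigned integers
        return "int64" if enable_x64 else "int32"
    if kind == "f":  # floats get float32 whatever the x64 flag says
        return "float32"
    return None  # e.g. booleans: np.bool_ is not a subtype of np.integer


for kind in ("i", "u", "f", "b"):
    print(kind, default_dtype_name(kind, False), default_dtype_name(kind, True))
```

The asymmetry is deliberate in the source: enabling x64 widens integers but never floats.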

Usage

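Set the dataset format to JAX arrays. GLUE MRPC ships raw text, so the input_ids and attention_mask columns only exist after tokenization:

from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("glue", "mrpc", split="train")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# MRPC is raw text: create input_ids/attention_mask first (fixed length so batches stack)
dataset = dataset.map(lambda b: tokenizer(b["sentence1"], b["sentence2"], truncation=True, padding="max_length"), batched=True)
dataset.set_format(type="jax", columns=["input_ids", "attention_mask", "label"])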

Specify a target device by its string identifier, that is, str() of one of the devices returned by jax.devices():

import jax

dataset.set_format(type="jax", columns=["input_ids"], device=str(jax.devices()[0]))

Access formatted samples (returns jax.Array objects):

sample = dataset[0]
print(type(sample["input_ids"]))  # e.g. <class 'jaxlib.xla_extension.ArrayImpl'>; the module path varies across jaxlib versions

System Requirements

  • Python: 3.9+
  • Operating System: Linux or macOS only. JAX dependencies carry the marker sys_platform != 'win32' in both the extras and test requirements. JAX is not supported on Windows within this project.

Dependencies

| Dependency | Version constraint | Notes |
|---|---|---|
| jax | >=0.3.14 | Core JAX library; must be importable for the integration to activate |
| jaxlib | >=0.3.14 | XLA compilation backend; must also be importable alongside jax |

Install via the extras:

pip install "datasets[jax]"

Credentials

No credentials are required for JAX integration itself. Standard HuggingFace Hub authentication (token-based) is used when downloading datasets from the Hub but is independent of the JAX environment.

Quick Install

pip install "datasets[jax]"

To verify the integration is active:

from datasets import config
print(f"JAX available: {config.JAX_AVAILABLE}")
print(f"JAX version:   {config.JAX_VERSION}")

Code Evidence

Environment variable and detection logic (from src/datasets/config.py lines 44, 117-129):

USE_JAX = os.environ.get("USE_JAX", "AUTO").upper()

JAX_VERSION = "N/A"
JAX_AVAILABLE = False

if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    JAX_AVAILABLE = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("jaxlib") is not None
    if JAX_AVAILABLE:
        try:
            JAX_VERSION = version.parse(importlib.metadata.version("jax"))
            logger.info(f"JAX version {JAX_VERSION} available.")
        except importlib.metadata.PackageNotFoundError:
            pass
else:
    logger.info("Disabling JAX because USE_JAX is set to False")

Formatter registration (from src/datasets/formatting/__init__.py lines 106-112):

if config.JAX_AVAILABLE:
    from .jax_formatter import JaxFormatter
    _register_formatter(JaxFormatter, "jax", aliases=[])
else:
    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
    _register_unavailable_formatter(_jax_error, "jax", aliases=[])

Device validation in JaxFormatter constructor (from src/datasets/formatting/jax_formatter.py lines 39-64):

class JaxFormatter(TensorFormatter[Mapping, "jax.Array", Mapping]):
    def __init__(self, features=None, device=None, token_per_repo_id=None, **jnp_array_kwargs):
        super().__init__(features=features, token_per_repo_id=token_per_repo_id)
        import jax
        from jaxlib.xla_client import Device

        if isinstance(device, Device):
            raise ValueError(
                f"Expected {device} to be a `str` not {type(device)}, as `jaxlib.xla_extension.Device` "
                "is not serializable neither with `pickle` nor with `dill`. Instead you can surround "
                "the device with `str()` to get its string identifier that will be internally mapped "
                "to the actual `jaxlib.xla_extension.Device`."
            )
        self.device = device if isinstance(device, str) else str(jax.devices()[0])
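The reason for strings is that a formatter holding only a str pickles cleanly. The string is then turned back into a device object lazily through a module-level mapping. The sketch below is hypothetical throughout:

- `FakeDevice` stands in for `jaxlib.xla_extension.Device`. Its `__reduce__` raises to imitate the non-serializable object.
- `list_devices` stands in for `jax.devices()`.
- `StringDeviceFormatter` is not the library's class.

The fallback-with-warning branch follows the behaviour listed under Common Errors.

```python
import logging
import pickle
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


class FakeDevice:
    """Stand-in for a JAX device object that cannot be pickled."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name

    def __reduce__(self):
        raise TypeError("device objects are not serializable")


def list_devices() -> List[FakeDevice]:  # stand-in for jax.devices()
    return [FakeDevice("cpu_0"), FakeDevice("cpu_1")]


DEVICE_MAPPING: Optional[Dict[str, FakeDevice]] = None  # built lazily, once per process


def device_mapping() -> Dict[str, FakeDevice]:
    global DEVICE_MAPPING
    if DEVICE_MAPPING is None:
        DEVICE_MAPPING = {str(d): d for d in list_devices()}
    return DEVICE_MAPPING


class StringDeviceFormatter:
    def __init__(self, device: Optional[str] = None) -> None:
        if isinstance(device, FakeDevice):
            raise ValueError("Pass str(device) instead of the device object.")
        default = str(list_devices()[0])
        self.device = device if isinstance(device, str) else default
        if self.device not in device_mapping():
            logger.warning("Device %s not available, falling back to %s.", self.device, default)
            self.device = default

    def resolve(self) -> FakeDevice:
        # Looked up from the string on demand, so it still works after unpickling.
        return device_mapping()[self.device]


restored = pickle.loads(pickle.dumps(StringDeviceFormatter("cpu_1")))
print(restored.device, restored.resolve())
```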

Precision-aware dtype selection (from src/datasets/formatting/jax_formatter.py lines 92-102):

default_dtype = {}

if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
    if jax.config.jax_enable_x64:
        default_dtype = {"dtype": jnp.int64}
    else:
        default_dtype = {"dtype": jnp.int32}
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
    default_dtype = {"dtype": jnp.float32}

Device-aware tensor creation (from src/datasets/formatting/jax_formatter.py lines 126-129):

with jax.default_device(DEVICE_MAPPING[self.device]):
    return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})
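The {**default_dtype, **self.jnp_array_kwargs} merge means that any keyword held in jnp_array_kwargs wins over the precision-based default, because later keys in a dict display override earlier ones. This sketch assumes that extra set_format(...) keyword arguments reach the formatter as jnp_array_kwargs, which the constructor signature in the device-validation excerpt suggests. `merged_array_kwargs` is a hypothetical helper, and dtype names are plain strings so it runs without JAX.

```python
from typing import Dict


def merged_array_kwargs(default_dtype: Dict[str, object], user_kwargs: Dict[str, object]) -> Dict[str, object]:
    """Hypothetical helper reproducing the merge order used before jnp.array()."""
    # Keys unpacked later win, so anything in user_kwargs overrides the default dtype.
    return {**default_dtype, **user_kwargs}


# Only the precision-based default is present: it is used as-is.
print(merged_array_kwargs({"dtype": "int32"}, {}))  # {'dtype': 'int32'}
# Assumed: a dtype forwarded from set_format(...) lands in user_kwargs and replaces the default.
print(merged_array_kwargs({"dtype": "int32"}, {"dtype": "int16"}))  # {'dtype': 'int16'}
```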

Extras in setup.py (from setup.py line 222):

"jax": ["jax>=0.3.14", "jaxlib>=0.3.14"],

Windows exclusion in test dependencies (from setup.py lines 171-172):

"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",

Common Errors

| Error | Cause | Resolution |
|---|---|---|
| ValueError: JAX needs to be installed to be able to return JAX arrays. | set_format(type="jax") called when jax or jaxlib is not importable | Install both packages, quoting the specifiers so the shell does not treat >= as a redirect: pip install "jax>=0.3.14" "jaxlib>=0.3.14" |
| ValueError: Expected {device} to be a `str` not {type}, as `jaxlib.xla_extension.Device` is not serializable... | A jaxlib.xla_extension.Device object was passed directly as the device parameter | Wrap the device with str() to pass its string identifier instead of the object |
| Device not listed among available devices (warning; falls back to the default device) | The string device identifier does not match str() of any device returned by jax.devices() | Use a string produced by str(d) for some d in jax.devices(), e.g. str(jax.devices()[0]) |
| Disabling JAX because USE_JAX is set to False (info log) | USE_JAX is set to a value outside 1, ON, YES, TRUE, AUTO (e.g. 0, OFF, NO, FALSE) | Unset USE_JAX or set it to AUTO or 1 |
| Integer precision mismatch (int32 vs int64) | The formatter picks jnp.int32 for integer data unless jax_enable_x64 is enabled | Call jax.config.update("jax_enable_x64", True) before accessing formatted rows if 64-bit integers are needed |

Compatibility Notes

  • Windows: JAX is explicitly excluded on Windows. Both the extras and test dependencies carry the sys_platform != 'win32' platform marker.
  • Dual-package requirement: Unlike TensorFlow or PyTorch, JAX requires both jax and jaxlib to be importable. Having only one installed will result in JAX_AVAILABLE = False.
  • No mutual exclusion: Unlike the TensorFlow/PyTorch pair, JAX detection has no mutual exclusion logic with other frameworks. JAX can be active simultaneously with PyTorch or TensorFlow.
  • Serialization constraint: jaxlib.xla_extension.Device objects cannot be serialized with pickle or dill. The formatter uses a global DEVICE_MAPPING dictionary and string-based device references to work around this limitation. This design means device objects are re-resolved from strings after deserialization.
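  • 64-bit precision: JAX's default 32-bit behavior affects dtype selection in the formatter. Code relying on 64-bit integer precision must enable jax_enable_x64 explicitly. Do this ideally at program startup, and in any case before formatted rows are accessed, because the formatter reads the flag each time it converts a value.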
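When JAX_AVAILABLE is unexpectedly False, checking the platform and the two import specs usually explains why. The sketch below is a hypothetical helper; `jax_setup_problems` is not part of datasets. It uses only the standard library and does not read USE_JAX, so check that variable separately.

```python
import importlib.util
import sys
from typing import Callable, List


def jax_setup_problems(
    platform: str = sys.platform,
    find_spec: Callable[[str], object] = importlib.util.find_spec,
) -> List[str]:
    """Hypothetical diagnostic: list reasons the JAX integration may be inactive."""
    problems = []
    if platform == "win32":
        problems.append("Windows: the jax extra is skipped by the sys_platform marker")
    for pkg in ("jax", "jaxlib"):
        # Detection needs BOTH packages, so report each missing one separately.
        if find_spec(pkg) is None:
            problems.append(f"{pkg} is not importable")
    return problems


if __name__ == "__main__":
    for problem in jax_setup_problems() or ["no obvious problems found"]:
        print(problem)
```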
