Environment:Huggingface Datasets JAX Integration
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Data_Processing |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
JAX integration in HuggingFace Datasets enables the library to produce native jax.Array outputs with device placement and dtype selection that respects JAX's 64-bit precision configuration. Unlike the TensorFlow and PyTorch integrations, JAX requires two separate packages (jax and jaxlib) to both be importable, and the integration is not supported on Windows.
Description
At startup the library reads the USE_JAX environment variable (default: "AUTO"). When the value is in the auto-or-true set, the detection routine checks that both importlib.util.find_spec("jax") and importlib.util.find_spec("jaxlib") return non-None values. Both packages must be present for JAX_AVAILABLE to be set to True. The JAX version is then resolved from the jax package metadata.
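When detection fails, it helps to know which half of the jax/jaxlib pair is missing. The following is a minimal sketch rather than the library's own code: use_jax_allows_detection and missing_jax_packages are hypothetical helper names, and the set of accepted USE_JAX values is an assumption modeled on the auto-or-true set described above.

```python
import importlib.util
import os

# Assumed contents of the auto-or-true set; the library defines its own constant.
ASSUMED_TRUE_AND_AUTO_VALUES = {"1", "ON", "YES", "TRUE", "AUTO"}


def use_jax_allows_detection(environ=os.environ):
    """Return True when USE_JAX (default "AUTO") does not disable detection."""
    return environ.get("USE_JAX", "AUTO").upper() in ASSUMED_TRUE_AND_AUTO_VALUES


def missing_jax_packages(find_spec=importlib.util.find_spec):
    """Return the names among jax and jaxlib that cannot be located."""
    return [name for name in ("jax", "jaxlib") if find_spec(name) is None]


if __name__ == "__main__":
    print("USE_JAX allows detection:", use_jax_allows_detection())
    print("Missing packages:", missing_jax_packages() or "none")
```

Both arguments are injectable so the check can be exercised without changing the real environment. An empty list from missing_jax_packages() corresponds to the condition under which JAX_AVAILABLE can become True.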
When JAX is available, the JaxFormatter class is registered under the format type "jax" with no aliases. When JAX is unavailable, a placeholder is registered that raises ValueError("JAX needs to be installed to be able to return JAX arrays.") on use.
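The placeholder behavior can be pictured as a registry that stores an error instead of a class and raises it only when the format is requested. This is an illustrative sketch under that assumption; the registry dictionaries and the function names register_formatter, register_unavailable_formatter, and get_formatter_class are hypothetical and are not the library's internal API.

```python
# Hypothetical registries; the names below are illustrative, not the library's API.
_FORMATTERS = {}
_UNAVAILABLE = {}


def register_formatter(cls, format_type):
    """Record a usable formatter class under its format type."""
    _FORMATTERS[format_type] = cls


def register_unavailable_formatter(error, format_type):
    """Record the error to raise when this format type is requested."""
    _UNAVAILABLE[format_type] = error


def get_formatter_class(format_type):
    """Return the registered class, or raise the stored error for a placeholder."""
    if format_type in _FORMATTERS:
        return _FORMATTERS[format_type]
    if format_type in _UNAVAILABLE:
        raise _UNAVAILABLE[format_type]
    raise ValueError(f"Unknown format type: {format_type!r}")


# Simulate an environment where JAX was not detected.
register_unavailable_formatter(
    ValueError("JAX needs to be installed to be able to return JAX arrays."), "jax"
)
```

Deferring the error to lookup time means importing the library never fails just because an optional backend is absent; only a request for the "jax" format does.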
The JaxFormatter handles several JAX-specific concerns:
- Device validation: The constructor checks whether the provided device argument is a jaxlib.xla_client.Device object. If so, it raises a ValueError because Device objects are not serializable with either pickle or dill. Users must pass the device as a string identifier instead. A global DEVICE_MAPPING dictionary maps string identifiers to actual device objects.
- Precision-aware dtype selection: When tensorizing integer values, the formatter checks jax.config.jax_enable_x64. If 64-bit mode is enabled, integers default to jnp.int64; otherwise they default to jnp.int32. Floating-point values default to jnp.float32 regardless of the x64 setting.
- Device placement: All array creation is wrapped in a jax.default_device(DEVICE_MAPPING[self.device]) context manager, ensuring tensors land on the intended device.
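The three concerns above can be observed directly in JAX, independent of the formatter. The sketch below is an illustration only: describe_jax_defaults is a hypothetical helper, it returns None when jax or jaxlib cannot be found, and it mirrors the integer dtype rule described above rather than calling any formatter code.

```python
import importlib.util


def describe_jax_defaults():
    """Return (device_string, integer_dtype_name), or None if JAX is not importable."""
    if importlib.util.find_spec("jax") is None or importlib.util.find_spec("jaxlib") is None:
        return None

    import jax
    import jax.numpy as jnp

    first_device = jax.devices()[0]
    # A string identifier is what the formatter expects instead of a Device object.
    device_string = str(first_device)

    # Same rule as described above: int64 only when 64-bit mode is enabled.
    int_dtype = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32

    # Create the array while the chosen device is the default one.
    with jax.default_device(first_device):
        array = jnp.array([1, 2, 3], dtype=int_dtype)

    return device_string, str(array.dtype)


if __name__ == "__main__":
    print(describe_jax_defaults())
```

The returned device string is the kind of value that can be passed as the device argument shown in the Usage section.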
Usage
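Set the dataset format to JAX arrays (the column names below assume the dataset has already been tokenized; the raw GLUE MRPC split does not contain input_ids or attention_mask):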
from datasets import load_dataset
dataset = load_dataset("glue", "mrpc", split="train")
dataset.set_format(type="jax", columns=["input_ids", "attention_mask", "label"])
Specify a target device by string identifier:
dataset.set_format(type="jax", columns=["input_ids"], device="cpu:0")
Access formatted samples (returns jax.Array objects):
sample = dataset[0]
print(type(sample["input_ids"]))  # e.g. <class 'jaxlib.xla_extension.ArrayImpl'>
System Requirements
- Python: 3.9+
- Operating System: Linux or macOS only. The JAX test requirements in setup.py carry the marker sys_platform != 'win32'; the jax extra quoted under Code Evidence is declared without a platform marker. JAX is not supported on Windows within this project.
Dependencies
| Dependency | Version Constraint | Notes |
|---|---|---|
| jax | >=0.3.14 | Core JAX library; must be importable for integration to activate |
| jaxlib | >=0.3.14 | XLA compilation backend; must also be importable alongside jax |
Install via the extras:
pip install "datasets[jax]"
Credentials
No credentials are required for JAX integration itself. Standard HuggingFace Hub authentication (token-based) is used when downloading datasets from the Hub but is independent of the JAX environment.
Quick Install
pip install "datasets[jax]"
To verify the integration is active:
from datasets import config
print(f"JAX available: {config.JAX_AVAILABLE}")
print(f"JAX version: {config.JAX_VERSION}")
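The version constraint from the Dependencies table can also be checked without importing the library. This is a rough sketch using only the standard library: installed_version, release_tuple, and meets_minimum are hypothetical helpers, and the comparison is a simplification that looks only at the leading numeric components and ignores pre-release or local version tags.

```python
import importlib.metadata
import re

MINIMUM_JAX_VERSION = "0.3.14"  # constraint listed under Dependencies


def installed_version(package):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return None


def release_tuple(version_string):
    """Extract up to three leading numeric components, e.g. "0.4.30" -> (0, 4, 30)."""
    return tuple(int(part) for part in re.findall(r"\d+", version_string)[:3])


def meets_minimum(version_string, minimum=MINIMUM_JAX_VERSION):
    """Compare numeric release components only; pre-release tags are ignored."""
    return release_tuple(version_string) >= release_tuple(minimum)


for name in ("jax", "jaxlib"):
    found = installed_version(name)
    status = "not installed" if found is None else f"{found} (ok: {meets_minimum(found)})"
    print(f"{name}: {status}")
```

For strict comparisons that respect pre-release ordering, a dedicated version parser is the safer choice; the tuple comparison here is only meant as a quick sanity check.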
Code Evidence
Environment variable and detection logic (from src/datasets/config.py lines 44, 117-129):
USE_JAX = os.environ.get("USE_JAX", "AUTO").upper()
JAX_VERSION = "N/A"
JAX_AVAILABLE = False
if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES:
    JAX_AVAILABLE = importlib.util.find_spec("jax") is not None and importlib.util.find_spec("jaxlib") is not None
    if JAX_AVAILABLE:
        try:
            JAX_VERSION = version.parse(importlib.metadata.version("jax"))
            logger.info(f"JAX version {JAX_VERSION} available.")
        except importlib.metadata.PackageNotFoundError:
            pass
else:
    logger.info("Disabling JAX because USE_JAX is set to False")
Formatter registration (from src/datasets/formatting/__init__.py lines 106-112):
if config.JAX_AVAILABLE:
    from .jax_formatter import JaxFormatter
    _register_formatter(JaxFormatter, "jax", aliases=[])
else:
    _jax_error = ValueError("JAX needs to be installed to be able to return JAX arrays.")
    _register_unavailable_formatter(_jax_error, "jax", aliases=[])
Device validation in JaxFormatter constructor (from src/datasets/formatting/jax_formatter.py lines 39-64):
class JaxFormatter(TensorFormatter[Mapping, "jax.Array", Mapping]):
    def __init__(self, features=None, device=None, token_per_repo_id=None, **jnp_array_kwargs):
        super().__init__(features=features, token_per_repo_id=token_per_repo_id)
        import jax
        from jaxlib.xla_client import Device
        if isinstance(device, Device):
            raise ValueError(
                f"Expected {device} to be a `str` not {type(device)}, as `jaxlib.xla_extension.Device` "
                "is not serializable neither with `pickle` nor with `dill`. Instead you can surround "
                "the device with `str()` to get its string identifier that will be internally mapped "
                "to the actual `jaxlib.xla_extension.Device`."
            )
        self.device = device if isinstance(device, str) else str(jax.devices()[0])
Precision-aware dtype selection (from src/datasets/formatting/jax_formatter.py lines 92-102):
default_dtype = {}
if isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.integer):
    if jax.config.jax_enable_x64:
        default_dtype = {"dtype": jnp.int64}
    else:
        default_dtype = {"dtype": jnp.int32}
elif isinstance(value, (np.number, np.ndarray)) and np.issubdtype(value.dtype, np.floating):
    default_dtype = {"dtype": jnp.float32}
Device-aware tensor creation (from src/datasets/formatting/jax_formatter.py lines 126-129):
with jax.default_device(DEVICE_MAPPING[self.device]):
    return jnp.array(value, **{**default_dtype, **self.jnp_array_kwargs})
Extras in setup.py (from setup.py line 222):
"jax": ["jax>=0.3.14", "jaxlib>=0.3.14"],
Windows exclusion in test dependencies (from setup.py lines 171-172):
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
Common Errors
| Error | Cause | Resolution |
|---|---|---|
| ValueError: JAX needs to be installed to be able to return JAX arrays. | set_format(type="jax") called when JAX or jaxlib is not installed | Install both packages, quoting the specifiers so the shell does not treat > as a redirect: pip install "jax>=0.3.14" "jaxlib>=0.3.14" |
| ValueError: Expected {device} to be a `str` not {type}, as `jaxlib.xla_extension.Device` is not serializable... | A jaxlib.xla_extension.Device object was passed directly as the device parameter | Wrap the device with str() to pass its string identifier instead of the object |
| Device not listed among available devices (warning, falls back to default) | The string device identifier does not match any device returned by jax.devices() | Use a valid device string such as "cpu:0" or "gpu:0" matching the output of jax.devices() |
| Disabling JAX because USE_JAX is set to False (info log) | USE_JAX environment variable is set to a false value (0, OFF, NO, FALSE) | Unset USE_JAX or set it to AUTO or 1 |
| Integer precision mismatch (int32 vs int64) | JAX defaults to 32-bit integers unless jax_enable_x64 is enabled | Set jax.config.update("jax_enable_x64", True) before loading data if 64-bit integers are needed |
Compatibility Notes
- Windows: JAX is explicitly excluded on Windows. The test dependencies carry the sys_platform != 'win32' platform marker; the jax extra quoted under Code Evidence is declared without it.
- Dual-package requirement: Unlike TensorFlow or PyTorch, JAX requires both jax and jaxlib to be importable. Having only one installed will result in JAX_AVAILABLE = False.
- No mutual exclusion: Unlike the TensorFlow/PyTorch pair, JAX detection has no mutual exclusion logic with other frameworks. JAX can be active simultaneously with PyTorch or TensorFlow.
- Serialization constraint: jaxlib.xla_extension.Device objects cannot be serialized with pickle or dill. The formatter uses a global DEVICE_MAPPING dictionary and string-based device references to work around this limitation. This design means device objects are re-resolved from strings after deserialization.
- 64-bit precision: JAX's default 32-bit behavior affects dtype selection in the formatter. Code relying on 64-bit integer precision must explicitly enable jax_enable_x64 before data loading.
Related Pages
- Huggingface_Datasets_JaxFormatter -- The JaxFormatter class that converts Arrow tables to JAX arrays with device placement
- Huggingface_Datasets_Dataset_Set_Format -- The set_format() method used to select the JAX output format