Implementation:Huggingface Datasets JaxFormatter
| Knowledge Sources | |
|---|---|
| Domains | Data_Engineering, NLP |
| Last Updated | 2026-02-14 18:00 GMT |
Overview
A concrete tool, provided by the Hugging Face Datasets library, for converting Arrow table data to JAX arrays.
Description
JaxFormatter is a formatting class that extends TensorFormatter and converts Arrow data to jax.Array objects. It exposes three formatting methods: format_row (a single example), format_column (a single column), and format_batch (a batch of examples). Each method first extracts NumPy arrays from the Arrow table, then recursively tensorizes the resulting structure by calling jnp.array() inside a jax.default_device context manager, which handles device placement. Default dtypes are applied as follows: integers become int64 when JAX's x64 mode is enabled and int32 otherwise, and floats become float32. The device parameter is a string that is resolved to a jaxlib.xla_extension.Device through a global mapping, and it controls which device the arrays are placed on. Lists of arrays with identical shapes are consolidated into a single array via jnp.stack.
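The dtype defaults, device placement, and stacking described above can be sketched roughly as follows. This is a simplified illustration, not the library's code: `tensorize_sketch` and `consolidate_sketch` are hypothetical names, and the sketch assumes the inputs are NumPy arrays and that lists are stacked only when every element shares one shape and dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np


def tensorize_sketch(value, device, **jnp_array_kwargs):
    """Convert a NumPy array to a jax.Array on `device` (hypothetical helper)."""
    default_dtype = {}
    if np.issubdtype(value.dtype, np.integer):
        # Integer precision follows JAX's x64 flag.
        default_dtype = {"dtype": jnp.int64 if jax.config.jax_enable_x64 else jnp.int32}
    elif np.issubdtype(value.dtype, np.floating):
        default_dtype = {"dtype": jnp.float32}
    # Caller-supplied keyword arguments take precedence over the defaults.
    with jax.default_device(device):
        return jnp.array(value, **{**default_dtype, **jnp_array_kwargs})


def consolidate_sketch(arrays):
    """Stack a non-empty list of arrays that share one shape and dtype (hypothetical helper)."""
    if arrays and all(
        isinstance(a, jax.Array) and a.shape == arrays[0].shape and a.dtype == arrays[0].dtype
        for a in arrays
    ):
        return jnp.stack(arrays, axis=0)
    # Mixed shapes or dtypes: leave the list as it is.
    return arrays


device = jax.devices()[0]
ints = tensorize_sketch(np.array([1, 2, 3], dtype=np.int64), device)
floats = tensorize_sketch(np.array([0.5, 1.5], dtype=np.float64), device)
stacked = consolidate_sketch([ints, ints])    # same shape and dtype, so stacked
ragged = consolidate_sketch([ints, floats])   # shapes differ, so returned as a list
```

Merging the defaults first and the user keyword arguments second means an explicit `dtype` passed through `**jnp_array_kwargs` overrides the default.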
Usage
JaxFormatter is typically not instantiated directly by users. It is automatically selected when Dataset.with_format("jax") or Dataset.set_format("jax") is called. It powers the tensor conversion layer for all JAX-formatted dataset access.
Code Reference
Source Location
- Repository: datasets
- File: src/datasets/formatting/jax_formatter.py
- Lines: L38-L171
Signature
class JaxFormatter(TensorFormatter[Mapping, "jax.Array", Mapping]):
    def __init__(self, features=None, device=None, token_per_repo_id=None, **jnp_array_kwargs):
    @staticmethod
    def _map_devices_to_str() -> dict[str, "jaxlib.xla_extension.Device"]:
    def _consolidate(self, column):
    def _tensorize(self, value):
    def _recursive_tensorize(self, data_struct):
    def recursive_tensorize(self, data_struct: dict):
    def format_row(self, pa_table: pa.Table) -> Mapping:
    def format_column(self, pa_table: pa.Table) -> "jax.Array":
    def format_batch(self, pa_table: pa.Table) -> Mapping:
Import
from datasets.formatting.jax_formatter import JaxFormatter
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| features | Optional[Features] | No | Dataset features for decoding special types (e.g., Image, Audio). |
| device | Optional[str] | No | String identifier of the JAX device to place arrays on, i.e. the str() form of an entry in jax.devices() (e.g., "cpu:0", "gpu:0"; the exact strings depend on the JAX version and backend). Defaults to the first available device. |
| token_per_repo_id | Optional[dict] | No | Authentication tokens for accessing private repositories. |
| **jnp_array_kwargs | | No | Additional keyword arguments forwarded to jnp.array() (e.g., dtype). |
| pa_table | pa.Table | Yes (for format methods) | The Arrow table to convert. Passed to format_row, format_column, or format_batch. |
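To make the device parameter more concrete, here is a minimal sketch of how a device string could be resolved against the devices JAX reports. `resolve_device` is a hypothetical helper, not part of the library. The fallback to the first device for an unrecognized string reflects my understanding of the library's behavior (which, as far as I know, also logs a warning) and should be treated as an assumption.

```python
import jax


def resolve_device(device=None):
    """Return (key, jax device) for a device string (hypothetical helper)."""
    # Map the string form of each visible device to the device object.
    mapping = {str(d): d for d in jax.devices()}
    default_key = str(jax.devices()[0])
    # Assumption: anything that is not a known string falls back to the first device.
    key = device if isinstance(device, str) and device in mapping else default_key
    return key, mapping[key]


# The string form of a real device always resolves to that device.
first_key = str(jax.devices()[0])
key, dev = resolve_device(first_key)

# An unknown identifier resolves to the first available device in this sketch.
fallback_key, fallback_dev = resolve_device("no-such-device")
```

Because device strings differ across JAX versions and backends, passing `str(jax.devices()[i])` is usually safer than hard-coding an identifier.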
Outputs
| Name | Type | Description |
|---|---|---|
| row | Mapping | A dict mapping column names to jax.Array values (from format_row). |
| column | jax.Array | A single JAX array for the column (from format_column). |
| batch | Mapping | A dict mapping column names to batched jax.Array values (from format_batch). |
Usage Examples
Basic Usage
import jax
from datasets import load_dataset

# JaxFormatter is selected automatically when the format is "jax"
ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")
ds = ds.with_format("jax")

# Accessing a row returns JAX arrays for numeric columns
row = ds[0]
print(isinstance(row["label"], jax.Array))  # True
print(type(row["label"]))  # concrete ArrayImpl class; its module path varies across JAX versions

# Accessing a batch returns stacked JAX arrays
batch = ds[:8]
print(batch["label"].shape)  # (8,)

# Specify a device by the string form of an entry in jax.devices()
ds = ds.with_format("jax", device=str(jax.devices()[0]))