

Implementation:Huggingface Datasets JaxFormatter

From Leeroopedia
Domains: Data_Engineering, NLP
Last Updated: 2026-02-14 18:00 GMT

Overview

A concrete formatter class in the Hugging Face Datasets library that converts Arrow table data into JAX arrays.

Description

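JaxFormatter is a formatting class that extends TensorFormatter and converts Arrow data to jax.Array objects. It provides three formatting methods: format_row (a single example), format_column (a single column), and format_batch (a batch of examples). Each method first extracts NumPy arrays from the Arrow table, decodes special feature types (e.g., Image, Audio) when features are provided, and then recursively tensorizes the resulting structure with jnp.array() inside a jax.default_device context manager that controls device placement.

Default dtypes depend on the value type: integers become int64 when JAX's x64 mode (jax_enable_x64) is enabled and int32 otherwise, floats always become float32, and a dtype passed through jnp_array_kwargs overrides these defaults. Strings, bytes, and None are returned unchanged, and NumPy string arrays are converted to Python lists. The device is given as a string and resolved to a jaxlib.xla_extension.Device through a global mapping, because Device objects cannot be serialized with pickle or dill. Lists of arrays that share the same shape and dtype are consolidated into a single array with jnp.stack; other lists are left as Python lists.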
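The tensorize-and-consolidate step can be sketched in plain JAX. The helper names tensorize, consolidate and recursive_tensorize below are hypothetical stand-ins for the private _tensorize, _consolidate and _recursive_tensorize methods. This is a simplified sketch that skips Arrow extraction, feature decoding, and any other special cases the real methods handle; it is not the library's implementation.

```python
import numpy as np
import jax
import jax.numpy as jnp


def tensorize(value, device=None):
    """Hypothetical helper: convert one leaf value using the dtype rules above."""
    # Strings, bytes and None pass through unchanged
    if isinstance(value, (str, bytes, type(None))):
        return value
    # NumPy string scalars/arrays become Python str / lists of str
    if isinstance(value, (np.character, np.ndarray)) and np.issubdtype(value.dtype, np.character):
        return value.tolist()
    kwargs = {}
    if isinstance(value, (np.number, np.ndarray)):
        if np.issubdtype(value.dtype, np.integer):
            # Integer width follows JAX's x64 flag
            kwargs["dtype"] = jnp.int64 if jax.config.jax_enable_x64 else jnp.int32
        elif np.issubdtype(value.dtype, np.floating):
            kwargs["dtype"] = jnp.float32
    # Create the array on the requested device (first visible device by default)
    target = device if device is not None else jax.devices()[0]
    with jax.default_device(target):
        return jnp.array(value, **kwargs)


def consolidate(column):
    """Hypothetical helper: stack a non-empty list of arrays sharing shape and dtype."""
    if isinstance(column, list) and column:
        first = column[0]
        if all(
            isinstance(x, jax.Array) and x.shape == first.shape and x.dtype == first.dtype
            for x in column
        ):
            return jnp.stack(column, axis=0)
    return column  # ragged or non-array items stay a Python list


def recursive_tensorize(data, device=None):
    """Hypothetical helper: walk dicts, lists, tuples and object arrays bottom-up."""
    if isinstance(data, dict):
        return {k: recursive_tensorize(v, device) for k, v in data.items()}
    if isinstance(data, (list, tuple)) or (isinstance(data, np.ndarray) and data.dtype == object):
        return consolidate([recursive_tensorize(x, device) for x in data])
    return tensorize(data, device)


example = {
    "label": np.int64(1),
    "tokens": [np.array([1, 2, 3]), np.array([4, 5, 6])],
    "ragged": [np.array([1, 2]), np.array([3])],
    "scores": np.array([0.5, 0.25]),
    "text": "hello",
}
out = recursive_tensorize(example)
print(out["tokens"].shape, out["scores"].dtype, out["text"])  # (2, 3) float32 hello
print(type(out["ragged"]).__name__, len(out["ragged"]))  # list 2
```

Consolidating bottom-up means equal-shaped sub-arrays are merged into one array, while ragged data degrades gracefully to a list instead of raising.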

Usage

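JaxFormatter is typically not instantiated directly by users. It is selected automatically when Dataset.with_format("jax") or Dataset.set_format("jax") is called, and extra keyword arguments passed to those methods (such as device or dtype) are forwarded to its constructor. It performs the tensor conversion for all JAX-formatted dataset access, including row indexing, slicing, and column access.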

Code Reference

Source Location

  • Repository: datasets
  • File: src/datasets/formatting/jax_formatter.py
  • Lines: L38-L171

Signature

class JaxFormatter(TensorFormatter[Mapping, "jax.Array", Mapping]):
    def __init__(self, features=None, device=None, token_per_repo_id=None, **jnp_array_kwargs):

    @staticmethod
    def _map_devices_to_str() -> dict[str, "jaxlib.xla_extension.Device"]:
    def _consolidate(self, column):
    def _tensorize(self, value):
    def _recursive_tensorize(self, data_struct):
    def recursive_tensorize(self, data_struct: dict):
    def format_row(self, pa_table: pa.Table) -> Mapping:
    def format_column(self, pa_table: pa.Table) -> "jax.Array":
    def format_batch(self, pa_table: pa.Table) -> Mapping:

Import

from datasets.formatting.jax_formatter import JaxFormatter

I/O Contract

Inputs

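  • features (Optional[Features], optional): Dataset features used to decode special types (e.g., Image, Audio).
  • device (Optional[str], optional): String identifier of the JAX device to place arrays on, i.e. the str() of an element of jax.devices() (for example str(jax.devices()[0])). Passing a Device object instead of a string raises a ValueError, and an unrecognized string logs a warning and falls back to the default device. Defaults to the first available device.
  • token_per_repo_id (Optional[dict], optional): Mapping from repository ID to authentication token, used to access private repositories when decoding features.
  • **jnp_array_kwargs (optional): Additional keyword arguments forwarded to jnp.array() (e.g., dtype); they override the default dtypes.
  • pa_table (pa.Table, required for the format methods): The Arrow table to convert, passed to format_row, format_column, or format_batch.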

Outputs

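  • row (Mapping): Returned by format_row; a dict mapping column names to jax.Array values. String values stay Python str.
  • column (jax.Array): Returned by format_column; a single JAX array for the table's first column. If the values cannot be stacked (e.g., ragged lists or strings), a list is returned instead.
  • batch (Mapping): Returned by format_batch; a dict mapping column names to batched jax.Array values, with the same list fallback for columns that cannot be stacked.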

Usage Examples

Basic Usage

import jax
from datasets import load_dataset

# JaxFormatter is used automatically when the format is "jax"
ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")
ds = ds.with_format("jax")

# Accessing a row returns JAX arrays for numeric columns; strings stay Python str
row = ds[0]
print(isinstance(row["label"], jax.Array))  # True
print(type(row["text"]))  # <class 'str'>

# Accessing a batch returns stacked JAX arrays
batch = ds[:8]
print(batch["label"].shape)  # (8,)

# Specify a device by its string identifier (a jax.Device object raises ValueError)
ds = ds.with_format("jax", device=str(jax.devices()[0]))
