Implementation:Huggingface Datasets TorchFormatter

Knowledge Sources
Domains: Data_Engineering, NLP
Last Updated: 2026-02-14 18:00 GMT

Overview

Concrete tool, provided by the HuggingFace Datasets library, for converting Arrow table data to PyTorch tensors.

Description

TorchFormatter is a formatting class that extends TensorFormatter and converts Arrow data to torch.Tensor objects. It provides three main formatting methods: format_row (single example), format_column (single column), and format_batch (batch of examples). Conversion first extracts NumPy arrays from the Arrow table, then recursively tensorizes the resulting data structure. During tensorization it applies default dtypes (int64 for integers, float32 for floats), converts PIL images to CHW-ordered tensors, coerces unsigned integer types, and consolidates lists of same-shaped tensors into a single tensor via torch.stack. String, bytes, and None values are passed through unchanged.
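The rules in this description can be sketched outside the library. The snippet below is only an illustration under stated assumptions, not the library's code: the helper names tensorize_sketch, consolidate_sketch, and image_to_chw_sketch are hypothetical, the widening of uint16/uint32 to int64 is an assumed form of the unsigned coercion, and the dtype check before stacking is an added assumption beyond the "same-shaped" rule stated here.

```python
import numpy as np
import torch


def tensorize_sketch(value, **torch_tensor_kwargs):
    """Hypothetical helper: one leaf value -> torch.Tensor (or pass-through)."""
    # Strings, bytes and None are returned unchanged.
    if isinstance(value, (str, bytes, type(None))):
        return value

    array = np.asarray(value)
    default_dtype = {}
    if np.issubdtype(array.dtype, np.integer):
        default_dtype = {"dtype": torch.int64}
        # Assumption: small unsigned types are widened to int64 first.
        if array.dtype in (np.uint16, np.uint32):
            array = array.astype(np.int64)
    elif np.issubdtype(array.dtype, np.floating):
        default_dtype = {"dtype": torch.float32}

    # Caller-supplied kwargs (e.g. dtype, device) override the defaults.
    return torch.tensor(array, **{**default_dtype, **torch_tensor_kwargs})


def consolidate_sketch(column):
    """Hypothetical helper: stack a list of tensors when shape and dtype agree."""
    if isinstance(column, list) and column:
        first = column[0]
        if all(
            isinstance(x, torch.Tensor) and x.shape == first.shape and x.dtype == first.dtype
            for x in column
        ):
            return torch.stack(column)
    return column


def image_to_chw_sketch(hwc_array):
    """Hypothetical helper: reorder an image array from HWC to CHW."""
    array = np.asarray(hwc_array)
    if array.ndim == 2:  # grayscale: add a trailing channel axis first
        array = array[:, :, np.newaxis]
    return array.transpose((2, 0, 1))


ints = tensorize_sketch([1, 2, 3])
floats = tensorize_sketch(np.array([0.5, 1.5], dtype=np.float64))
same_shape = consolidate_sketch([tensorize_sketch([1, 2]), tensorize_sketch([3, 4])])
ragged = consolidate_sketch([tensorize_sketch([1, 2]), tensorize_sketch([3])])
chw = image_to_chw_sketch(np.zeros((4, 6, 3), dtype=np.uint8))

print(ints.dtype, floats.dtype)  # torch.int64 torch.float32
print(same_shape.shape)          # torch.Size([2, 2])
print(type(ragged).__name__)     # list
print(chw.shape)                 # (3, 4, 6)
```

Note that a list whose tensors differ in shape is left as a plain list in this sketch rather than raising an error, which is why ragged stays a list.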

Usage

TorchFormatter is typically not instantiated directly by users. It is automatically selected when Dataset.with_format("torch") or Dataset.set_format("torch") is called. It powers the tensor conversion layer for all PyTorch-formatted dataset access.

Code Reference

Source Location

  • Repository: datasets
  • File: src/datasets/formatting/torch_formatter.py
  • Lines: L32-L127

Signature

class TorchFormatter(TensorFormatter[Mapping, "torch.Tensor", Mapping]):
    def __init__(self, features=None, token_per_repo_id=None, **torch_tensor_kwargs):

    def _consolidate(self, column):
    def _tensorize(self, value):
    def _recursive_tensorize(self, data_struct):
    def recursive_tensorize(self, data_struct: dict):
    def format_row(self, pa_table: pa.Table) -> Mapping:
    def format_column(self, pa_table: pa.Table) -> "torch.Tensor":
    def format_batch(self, pa_table: pa.Table) -> Mapping:

Import

from datasets.formatting.torch_formatter import TorchFormatter

I/O Contract

Inputs

| Name | Type | Required | Description |
| --- | --- | --- | --- |
| features | Optional[Features] | No | Dataset features for decoding special types (e.g., Image, Audio). |
| token_per_repo_id | Optional[dict] | No | Authentication tokens for accessing private repositories. |
| **torch_tensor_kwargs | | No | Additional keyword arguments forwarded to torch.tensor() (e.g., dtype, device). |
| pa_table | pa.Table | Yes (for format methods) | The Arrow table to convert. Passed to format_row, format_column, or format_batch. |

Outputs

| Name | Type | Description |
| --- | --- | --- |
| row | Mapping | A dict mapping column names to torch.Tensor values (from format_row). |
| column | torch.Tensor | A single tensor for the column (from format_column). |
| batch | Mapping | A dict mapping column names to batched torch.Tensor values (from format_batch). |

Usage Examples

Basic Usage

from datasets import load_dataset

# TorchFormatter is used automatically when format is "torch"
ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")
ds = ds.with_format("torch")

# Accessing a row returns torch tensors
row = ds[0]
print(type(row["label"]))  # <class 'torch.Tensor'>

# Accessing a batch returns stacked torch tensors
batch = ds[:8]
print(batch["label"].shape)  # torch.Size([8])

Related Pages

Implements Principle

Requires Environment