Implementation: HuggingFace Datasets Dataset.to_tf_dataset
| Knowledge Sources | |
|---|---|
| Domains | Data_Engineering, NLP |
| Last Updated | 2026-02-14 18:00 GMT |
Overview
Concrete tool for creating a TensorFlow tf.data.Dataset pipeline from a Dataset object in the HuggingFace Datasets library.
Description
Dataset.to_tf_dataset is a method that constructs a fully functional tf.data.Dataset suitable for model.fit(), model.evaluate(), or model.predict(). The method infers the output tensor signatures (shapes and dtypes) by running a configurable number of test batches through the collate function, then creates a tf.data.Dataset.from_generator backed by the HuggingFace Dataset. It supports batching, shuffling, drop-remainder, custom collate functions with arguments, feature/label column separation, prefetching with AUTOTUNE, and multiprocessing via num_workers. The dataset is automatically formatted to NumPy before being passed to the TF pipeline.
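The signature-inference step can be illustrated with a minimal NumPy sketch (infer_signature is a hypothetical helper, not part of the library): dimensions that differ across the test batches, such as a dynamically padded sequence length, are recorded as unknown (None) in the resulting spec.

```python
import numpy as np

# Hypothetical sketch of the signature-inference idea: collate a few test
# batches, then mark any dimension that varies across batches as unknown
# (None), similar to how to_tf_dataset builds its output tensor specs.
def infer_signature(sample_batches):
    shapes, dtypes = {}, {}
    for batch in sample_batches:
        for name, arr in batch.items():
            dtypes[name] = arr.dtype
            if name not in shapes:
                shapes[name] = list(arr.shape)
            else:
                shapes[name] = [
                    old if old == new else None
                    for old, new in zip(shapes[name], arr.shape)
                ]
    return shapes, dtypes

# Two test batches whose padded sequence lengths differ:
shapes, dtypes = infer_signature([
    {"input_ids": np.zeros((16, 32), dtype=np.int64)},
    {"input_ids": np.zeros((16, 48), dtype=np.int64)},
])
# shapes["input_ids"] -> [16, None]; dtypes["input_ids"] -> int64
```

This is why a larger num_test_batches gives a more reliable signature: a dimension only becomes None once two test batches actually disagree on it.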
Usage
Use Dataset.to_tf_dataset when you need a tf.data.Dataset for training or evaluating TensorFlow/Keras models. This is the primary integration point and handles all the plumbing needed to bridge HuggingFace datasets and TensorFlow's data pipeline API.
Code Reference
Source Location
- Repository: datasets
- File: src/datasets/arrow_dataset.py
- Lines: L330-L535
Signature
def to_tf_dataset(
    self,
    batch_size: Optional[int] = None,
    columns: Optional[Union[str, list[str]]] = None,
    shuffle: bool = False,
    collate_fn: Optional[Callable] = None,
    drop_remainder: bool = False,
    collate_fn_args: Optional[dict[str, Any]] = None,
    label_cols: Optional[Union[str, list[str]]] = None,
    prefetch: bool = True,
    num_workers: int = 0,
    num_test_batches: int = 20,
):
Import
from datasets import Dataset
# to_tf_dataset is a method on Dataset instances
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| batch_size | Optional[int] | No | Size of batches. None means unbatched (can be batched later with tf_dataset.batch()). |
| columns | Optional[Union[str, list[str]]] | No | Column(s) to include in the output. Can include columns created by collate_fn. |
| shuffle | bool | No | Whether to shuffle the dataset order. Recommended True for training. Defaults to False. |
| collate_fn | Optional[Callable] | No | Function to collate samples into batches (e.g., a DataCollator). Defaults to simple stacking. |
| drop_remainder | bool | No | Drop the last incomplete batch. Defaults to False. |
| collate_fn_args | Optional[dict[str, Any]] | No | Keyword arguments passed to the collate_fn. |
| label_cols | Optional[Union[str, list[str]]] | No | Column(s) to use as labels. Separated from features in the output for model.fit() compatibility. |
| prefetch | bool | No | Whether to prefetch batches in the background. Defaults to True. |
| num_workers | int | No | Number of worker processes for parallel data loading. Defaults to 0 (main process). |
| num_test_batches | int | No | Number of batches used to infer the output signature. Defaults to 20. |
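The "simple stacking" default used when no collate_fn is supplied can be sketched in plain NumPy (simple_stack_collate is a hypothetical stand-in, not the library's internal function):

```python
import numpy as np

# Hypothetical stand-in for the default collation: stack per-example
# arrays key by key into batched arrays. This only works when every
# example has the same shape; for variable-length sequences, pass a
# padding collator (e.g. DataCollatorWithPadding) as collate_fn instead.
def simple_stack_collate(examples):
    return {
        key: np.stack([ex[key] for ex in examples])
        for key in examples[0]
    }

batch = simple_stack_collate([
    {"input_ids": np.array([1, 2, 3]), "label": np.array(0)},
    {"input_ids": np.array([4, 5, 6]), "label": np.array(1)},
])
# batch["input_ids"].shape == (2, 3); batch["label"].shape == (2,)
```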
Outputs
| Name | Type | Description |
|---|---|---|
| tf_dataset | tf.data.Dataset | A TensorFlow Dataset pipeline ready for training or evaluation. |
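When label_cols is set, the output elements are (features, labels) tuples rather than a single dict, which is the shape Keras' model.fit() expects. A minimal sketch of that partition (split_features_labels is a hypothetical helper for illustration):

```python
# Hypothetical sketch of the feature/label partition applied when
# label_cols is given: each batch dict is split into a features dict
# and a labels value for model.fit() compatibility.
def split_features_labels(batch, columns, label_cols):
    features = {name: batch[name] for name in columns}
    labels = {name: batch[name] for name in label_cols}
    # A single label column is unwrapped to a bare array/tensor.
    if len(label_cols) == 1:
        labels = labels[label_cols[0]]
    return features, labels

features, labels = split_features_labels(
    {"input_ids": [[1, 2]], "attention_mask": [[1, 1]], "label": [0]},
    columns=["input_ids", "attention_mask"],
    label_cols=["label"],
)
# features has keys input_ids and attention_mask; labels == [0]
```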
Usage Examples
Basic Usage
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
ds = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
ds = ds.map(lambda x: tokenizer(x["text"], truncation=True), batched=True)
data_collator = DataCollatorWithPadding(tokenizer, return_tensors="np")
tf_dataset = ds.to_tf_dataset(
    columns=["input_ids", "token_type_ids", "attention_mask", "label"],
    shuffle=True,
    batch_size=16,
    collate_fn=data_collator,
)
# Use with model.fit()
# model.fit(tf_dataset, epochs=3)