Implementation:Apache Paimon TorchDataset
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Data Integration |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
TorchDataset and TorchIterDataset integrate Paimon table data with PyTorch's training pipeline for ML workloads.
Description
TorchDataset and TorchIterDataset are two PyTorch Dataset implementations that read Paimon table data directly inside PyTorch training loops. TorchDataset implements the map-style Dataset interface: it loads all data into memory as a Python list and provides random access by index. TorchIterDataset implements the iterable-style IterableDataset interface: it streams data on demand without loading everything into memory.
TorchDataset converts all splits to an Arrow table, then to a Python list of dictionaries. This approach is simple and supports random access, making it suitable for algorithms that require shuffling or non-sequential access. However, the whole table must fit in memory as Python objects, so it is unsuitable for datasets larger than the available RAM.
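The materialize-then-index pattern described above can be sketched without any Paimon or PyTorch dependency. This is a minimal illustration, not the library's code: `columns_to_rows` and `InMemoryRows` are hypothetical names, and a plain dict of columns stands in for the Arrow table (pyarrow's `Table.to_pylist()` performs a comparable column-to-row conversion).

```python
from typing import Any, Dict, List


def columns_to_rows(columns: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Turn a column-oriented mapping into a list of row dictionaries."""
    names = list(columns)
    # zip(*...) walks all columns in lockstep, yielding one tuple per row.
    return [dict(zip(names, values)) for values in zip(*columns.values())]


class InMemoryRows:
    """Hypothetical map-style container: all rows are materialized up front."""

    def __init__(self, columns: Dict[str, List[Any]]):
        self._rows = columns_to_rows(columns)

    def __len__(self) -> int:
        # Total number of rows held in memory.
        return len(self._rows)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Random access by position, which is what makes shuffling cheap.
        return self._rows[index]


# Placeholder column names; a real table would supply its own schema.
rows = InMemoryRows({"feature_col": [0.1, 0.2, 0.3], "label_col": [0, 1, 0]})
```

Because every row lives in a Python list, memory use grows with the table size, which is the trade-off the paragraph above points out.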
TorchIterDataset provides a more scalable approach by iterating over splits and converting rows on the fly. It supports PyTorch's multi-worker DataLoader by automatically partitioning splits across workers. Each worker processes a disjoint subset of the splits, so data is loaded in parallel and no row is read twice. The iterator converts internal OffsetRow objects to dictionaries using field names from the read schema.
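One way to give each worker a disjoint share of the splits is contiguous chunking, sketched below. This is an assumption about the general technique, not a copy of the Paimon implementation, which may partition differently; `splits_for_worker` is a hypothetical helper. Inside a real `__iter__`, the worker id and worker count would typically come from `torch.utils.data.get_worker_info()`, which returns `None` in the single-process case.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def splits_for_worker(splits: Sequence[T], worker_id: int, num_workers: int) -> List[T]:
    """Return the contiguous slice of splits owned by one worker."""
    if num_workers <= 0 or not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    base, extra = divmod(len(splits), num_workers)
    # The first `extra` workers take one additional split each,
    # so chunk sizes differ by at most one.
    start = worker_id * base + min(worker_id, extra)
    end = start + base + (1 if worker_id < extra else 0)
    return list(splits[start:end])


# Strings stand in for Split objects in this illustration.
splits = [f"split-{i}" for i in range(10)]
assignment = [splits_for_worker(splits, w, 4) for w in range(4)]
```

Every split lands in exactly one worker's chunk, so concatenating the chunks reproduces the original list. If there are more workers than splits, some workers simply receive an empty list.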
Usage
Use TorchDataset for small to medium datasets that fit in memory and require random access. Use TorchIterDataset for large datasets or when memory is constrained, especially with PyTorch's multi-worker DataLoader. Both can be passed directly to torch.utils.data.DataLoader; note that DataLoader's shuffle=True option applies only to the map-style TorchDataset.
Code Reference
Source Location
- Repository: Apache_Paimon
- File: paimon-python/pypaimon/read/datasource/torch_dataset.py
Signature
class TorchDataset(Dataset):
    """
    PyTorch Dataset implementation for reading Paimon table data.
    """

    def __init__(self, table_read: TableRead, splits: List[Split]):
        """Initialize with table_read and splits."""
        ...

    def __len__(self) -> int:
        """Return the total number of rows."""
        ...

    def __getitem__(self, index: int):
        """Get a single item by index."""
        ...


class TorchIterDataset(IterableDataset):
    """
    PyTorch IterableDataset implementation for streaming Paimon data.
    """

    def __init__(self, table_read: TableRead, splits: List[Split]):
        """Initialize with table_read and splits."""
        ...

    def __iter__(self):
        """
        Iterate over dataset with automatic split partitioning for multi-worker.
        """
        ...
Import
from pypaimon.read.datasource.torch_dataset import TorchDataset, TorchIterDataset
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| table_read | TableRead | Yes | TableRead instance for reading data |
| splits | List[Split] | Yes | List of splits to read |
Outputs
| Name | Type | Description |
|---|---|---|
| item | dict | Dictionary mapping field names to values for each row |
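To make the output contract concrete, the sketch below lazily pairs positional row values with schema field names to produce one dictionary per row. It is only an illustration: `rows_to_items` is a hypothetical helper, plain tuples stand in for the internal row objects, and the column names are placeholders.

```python
from typing import Any, Dict, Iterable, Iterator, Sequence


def rows_to_items(
    field_names: Sequence[str], rows: Iterable[Sequence[Any]]
) -> Iterator[Dict[str, Any]]:
    """Yield one {field name: value} dictionary per positional row."""
    for values in rows:
        if len(values) != len(field_names):
            raise ValueError("row width does not match the schema")
        # Being a generator, this converts a row only when it is requested.
        yield dict(zip(field_names, values))


items = list(rows_to_items(["feature_col", "label_col"], [(0.1, 0), (0.2, 1)]))
```

A DataLoader receives a list of such dictionaries per batch, which is why a collate function keyed on the column names is usually needed for non-trivial column types.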
Usage Examples
import torch
from torch.utils.data import DataLoader

from pypaimon import CatalogFactory
from pypaimon.read.datasource.torch_dataset import TorchDataset, TorchIterDataset

# NOTE: model, criterion, optimizer and num_epochs are assumed to be defined elsewhere.

# Custom collate function for batching.
# Defined before the DataLoaders because they reference it.
# 'feature_col' and 'label_col' are placeholder column names.
def custom_collate(batch):
    # Convert list of dicts to dict of tensors
    features = torch.tensor([item['feature_col'] for item in batch])
    labels = torch.tensor([item['label_col'] for item in batch])
    return {'features': features, 'labels': labels}

# Open Paimon table.
# This follows the read-builder pattern from the pypaimon documentation;
# the warehouse path and table name are placeholders, and the exact calls
# should be checked against the installed pypaimon version.
catalog = CatalogFactory.create({'warehouse': 'path/to/warehouse'})
table = catalog.get_table('database_name.table_name')
read_builder = table.new_read_builder()
table_read = read_builder.new_read()
splits = read_builder.new_scan().plan().splits()

# Option 1: Map-style dataset (loads all data)
dataset = TorchDataset(table_read, splits)
print(f"Dataset size: {len(dataset)}")

# Use with DataLoader for shuffling
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)
for batch in dataloader:
    # batch is the dict of tensors returned by custom_collate
    predictions = model(batch['features'])
    loss = criterion(predictions, batch['labels'])

# Option 2: Iterable dataset (streams data)
iter_dataset = TorchIterDataset(table_read, splits)

# Use with multi-worker DataLoader
dataloader = DataLoader(
    iter_dataset,
    batch_size=32,
    num_workers=4,  # splits are partitioned across the 4 workers
    collate_fn=custom_collate
)

for epoch in range(num_epochs):
    for batch in dataloader:
        # Training loop
        optimizer.zero_grad()
        output = model(batch['features'])
        loss = criterion(output, batch['labels'])
        loss.backward()
        optimizer.step()