Implementation: Gretel.ai gretel-synthetics DataFrameBatch Train All Batches
| Knowledge Sources | |
|---|---|
| Domains | Synthetic_Data, Tabular_Data |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
Concrete tool from the gretel-synthetics library for training one generative model per column-cluster batch.
Description
DataFrameBatch.train_all_batches() logs the batch sizes and then iterates over every key in the batches dictionary, calling train_batch(idx) for each one.
DataFrameBatch.train_batch(batch_idx) trains a single batch's model. If a custom tokenizer was provided to the DataFrameBatch, it is deep-copied and the copy's config attribute is set to the target batch's config. The method then delegates to the library's core train() function, passing the batch's config and the optional tokenizer. A RuntimeError is raised if this method is called in read mode, and a ValueError is raised for an invalid batch index.
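The deep-copy step can be sketched in isolation. The classes below are hypothetical stand-ins (not gretel-synthetics types); the point is the described behavior of copying the user's tokenizer and re-pointing the copy's config at the target batch:

```python
from copy import deepcopy
from dataclasses import dataclass

# Hypothetical stand-ins: these class names are illustrative only and
# are not part of gretel-synthetics.
@dataclass
class FakeConfig:
    checkpoint_dir: str

@dataclass
class FakeTokenizer:
    config: object = None

def prepare_batch_tokenizer(custom_tokenizer, batch_config):
    # Deep-copy the user-supplied tokenizer and point the copy's config
    # at the target batch's config, leaving the original untouched.
    tok = deepcopy(custom_tokenizer)
    tok.config = batch_config
    return tok

shared = FakeTokenizer()
cfg = FakeConfig(checkpoint_dir="/tmp/my_model/batch_0")
tok = prepare_batch_tokenizer(shared, cfg)
assert tok is not shared        # original tokenizer object is untouched
assert shared.config is None
assert tok.config is cfg
```

Deep-copying matters here because the same user tokenizer is reused across many batches, each with its own config.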
Before training begins, _log_batches() emits an INFO-level log message listing the number of batches and the column count in each batch (e.g., [15, 12, 8]).
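The log-then-iterate flow described above can be shown with a simplified sketch. This is not the library source; `batches` here is a plain dict mapping each batch index to its column names, and the trainer callback stands in for the real per-batch training call:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("batch_demo")

# Stand-in for the batches dictionary: batch index -> column names.
batches = {0: ["a", "b", "c"], 1: ["d", "e"], 2: ["f"]}

def log_batches(batches):
    # Mirrors the described INFO log: batch count and per-batch column counts.
    sizes = [len(cols) for cols in batches.values()]
    logger.info("Training %d batch models, column counts: %s",
                len(batches), sizes)
    return sizes

def train_all_batches(batches, train_batch):
    # Log once up front, then train every batch key in order.
    log_batches(batches)
    for idx in batches:
        train_batch(idx)

trained = []
train_all_batches(batches, trained.append)
assert trained == [0, 1, 2]
assert log_batches(batches) == [3, 2, 1]
```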
Usage
Call train_all_batches() after create_training_data() to train models for all batches sequentially. Use train_batch(idx) when you need to retrain a single batch (for example, after adjusting its config or validator).
Code Reference
Source Location
- Repository: gretel-synthetics
- File: src/gretel_synthetics/batch.py
- Lines: 1178-1185 (train_all_batches), 1158-1176 (train_batch)
Signature
def train_all_batches(self):
"""Train a model for each batch."""
def train_batch(self, batch_idx: int):
"""Train a model for a single batch. All model information will
be written into that batch's directory.
Args:
batch_idx: The index of the batch, from the ``batches`` dictionary
"""
Import
from gretel_synthetics.batch import DataFrameBatch
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| self | DataFrameBatch | Yes | Must be in write mode with training data already created. |
| batch_idx (train_batch only) | int | Yes | Index of the batch to train, must be a valid key in self.batches. |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) model checkpoints | files on disk | Trained model weights, tokenizer artifacts, and config written into each batch's checkpoint_dir. |
| (return value) | None | Both methods return None. |
Usage Examples
Basic Example: Train All Batches
from gretel_synthetics.batch import DataFrameBatch
config = {
"checkpoint_dir": "/tmp/my_model",
"field_delimiter": ",",
"overwrite": True,
"epochs": 30,
}
batcher = DataFrameBatch(df=my_dataframe, batch_size=10, config=config)
batcher.create_training_data()
# Train all batches sequentially
batcher.train_all_batches()
Selective Retraining
# Retrain only batch 2
batcher.train_batch(2)
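Error Handling Sketch
The two error modes listed in the I/O Contract can be demonstrated without the real library. `SimpleBatcher` below is a hypothetical stand-in for `DataFrameBatch` that reproduces only the two documented guards (read mode raises RuntimeError, an invalid index raises ValueError):

```python
# Hypothetical stand-in for DataFrameBatch; only the two documented
# error guards are reproduced here (this is not library code).
class SimpleBatcher:
    def __init__(self, batches, mode="write"):
        self.batches = batches  # batch index -> column names
        self.mode = mode

    def train_batch(self, batch_idx):
        if self.mode == "read":
            raise RuntimeError("Method cannot be used in read-only mode")
        if batch_idx not in self.batches:
            raise ValueError(f"{batch_idx} is not a valid batch index")
        return None  # the real method trains the batch and returns None

b = SimpleBatcher({0: ["name"], 1: ["age"]})
try:
    b.train_batch(5)  # no such batch
except ValueError as err:
    print("caught:", err)
```

Guarding retraining calls this way is useful in pipelines where batch indices come from configuration rather than being hard-coded.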