Implementation: Gretel.ai gretel-synthetics Train RNN
| Knowledge Sources | |
|---|---|
| Domains | Synthetic_Data, Deep_Learning, Model_Training |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
Concrete tool in the gretel-synthetics library for training an LSTM text-generation model on tokenized training data.
Description
The training implementation consists of two layers:
1. The facade function train(store, tokenizer_trainer) (in src/gretel_synthetics/train.py) orchestrates the full pipeline: it creates a default tokenizer if none is provided, calls annotate_data() and train() on the tokenizer trainer, loads the trained tokenizer, packages everything into a TrainingParams dataclass, saves model parameters, performs a GPU check, and dispatches to the engine-specific training callable returned by store.get_training_callable().
2. The engine-specific function train_rnn(params) (in src/gretel_synthetics/tensorflow/train.py) performs the actual TensorFlow model training. It extracts the tokenizer and configuration from params, constructs a TensorFlow dataset of input-target pairs, builds the LSTM model, configures Keras callbacks (checkpointing, history tracking, early stopping, an epoch-callback wrapper, and a max-training-time limit; see the sketch below), calls model.fit(), and saves the training history to CSV.
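The callback wiring can be pictured roughly as follows. This is a minimal sketch, not the library's exact code: the helper name and the checkpoint filename are hypothetical, and the store field names follow the TensorFlowConfig options shown in the usage examples below.

import tensorflow as tf

def _build_callbacks_sketch(store):
    # Hypothetical helper illustrating the callback wiring described
    # above; not the library's verbatim implementation.
    callbacks = [
        # Checkpointing: keep the best weights seen so far.
        tf.keras.callbacks.ModelCheckpoint(
            filepath=f"{store.checkpoint_dir}/synthetic",
            save_weights_only=True,
            save_best_only=True,
            monitor="loss",
        ),
        # History tracking: Keras collects per-epoch metrics, which
        # train_rnn later writes out as model_history.csv.
        tf.keras.callbacks.History(),
    ]
    if store.early_stopping:
        callbacks.append(
            tf.keras.callbacks.EarlyStopping(
                monitor="loss",
                patience=store.early_stopping_patience,
                restore_best_weights=True,
            )
        )
    return callbacks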
The dataset construction (_create_dataset) encodes all training lines into token IDs, creates fixed-length sequences, splits each into input/target pairs shifted by one token, shuffles them, and optionally creates a validation set using an enumeration-based 80/20 filter.
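The shape of that pipeline can be sketched with plain tf.data operations. This is a minimal illustration of the steps described above, assuming a flat list of token IDs; the helper name _make_dataset and its arguments are hypothetical, not the library's _create_dataset signature.

import tensorflow as tf

def _make_dataset(token_ids, seq_length, batch_size, validation=False):
    # Hypothetical sketch of the steps described above, not the
    # library's _create_dataset verbatim.
    ds = tf.data.Dataset.from_tensor_slices(token_ids)
    # Fixed-length chunks of seq_length + 1 tokens ...
    ds = ds.batch(seq_length + 1, drop_remainder=True)
    # ... split into (input, target) pairs shifted by one token.
    ds = ds.map(lambda chunk: (chunk[:-1], chunk[1:]))
    # Shuffle once so the 80/20 split below stays stable across epochs.
    ds = ds.shuffle(buffer_size=10_000, reshuffle_each_iteration=False)
    if validation:
        # Enumeration-based 80/20 filter: every fifth sequence is
        # routed to validation, the rest to training.
        val = ds.enumerate().filter(lambda i, _: i % 5 == 0).map(lambda _, x: x)
        train = ds.enumerate().filter(lambda i, _: i % 5 != 0).map(lambda _, x: x)
        return train.batch(batch_size), val.batch(batch_size)
    return ds.batch(batch_size), None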
Usage
Call train(config) from user code; the facade handles all orchestration. Call train_rnn directly only if you need to bypass the facade (e.g., for custom tokenizer flows), as sketched below.
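A bypass flow essentially replays the facade's steps by hand. The following is a sketch assuming the default SentencePiece flow; paths are placeholders, and it assumes tokenizer_from_model_dir is importable from gretel_synthetics.tokenizers.

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tensorflow.train import train_rnn
from gretel_synthetics.tokenizers import (
    SentencePieceTokenizerTrainer,
    tokenizer_from_model_dir,
)
from gretel_synthetics.train import TrainingParams

config = TensorFlowConfig(
    input_data_path="/path/to/data.txt",
    checkpoint_dir="/path/to/checkpoints",
)

# Mirror the facade: train the tokenizer, reload it, then call the
# engine-specific trainer directly.
trainer = SentencePieceTokenizerTrainer(config=config)
trainer.annotate_data()
trainer.train()
tokenizer = tokenizer_from_model_dir(config.checkpoint_dir)
train_rnn(TrainingParams(tokenizer_trainer=trainer, tokenizer=tokenizer, config=config))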
Code Reference
Source Location
- Repository: gretel-synthetics
- Files:
  - src/gretel_synthetics/train.py (L75–98): train() facade
  - src/gretel_synthetics/tensorflow/train.py (L215–329): train_rnn() engine-specific training
Signature
Facade:
def train(store: BaseConfig, tokenizer_trainer: Optional[BaseTokenizerTrainer] = None):
    """Train a Synthetic Model. This is a facade entrypoint that implements the engine
    specific training operation based on the provided configuration.

    Args:
        store: A subclass instance of BaseConfig.
        tokenizer_trainer: An optional subclass instance of a BaseTokenizerTrainer.
    """
    if tokenizer_trainer is None:
        tokenizer_trainer = _create_default_tokenizer(store)
    tokenizer_trainer.annotate_data()
    tokenizer_trainer.train()
    tokenizer = tokenizer_from_model_dir(store.checkpoint_dir)
    params = TrainingParams(
        tokenizer_trainer=tokenizer_trainer, tokenizer=tokenizer, config=store
    )
    train_fn = store.get_training_callable()
    store.save_model_params()
    store.gpu_check()
    train_fn(params)
Engine-specific:
def train_rnn(params: TrainingParams):
    """Fit synthetic data model on training data.

    Args:
        params: The parameters controlling model training.

    Returns:
        None
    """
    store = params.config
    tokenizer = params.tokenizer
    num_lines = params.tokenizer_trainer.num_lines
    text_iter = params.tokenizer_trainer.data_iterator()
    # ... dataset creation, model building, callback setup, model.fit() ...
Import
from gretel_synthetics.train import train
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| store | BaseConfig (typically TensorFlowConfig) | Yes | Configuration object with all training hyperparameters, file paths, and callback settings |
| tokenizer_trainer | BaseTokenizerTrainer | No | A pre-configured tokenizer trainer. If None, a default is created based on config.vocab_size: CharTokenizerTrainer when it is 0, SentencePieceTokenizerTrainer otherwise (sketched below) |
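The default-tokenizer dispatch can be sketched as follows. This is a hedged illustration of _create_default_tokenizer's behavior as described above, not its verbatim source; it assumes vocab_size is readable from the config object.

from gretel_synthetics.tokenizers import (
    CharTokenizerTrainer,
    SentencePieceTokenizerTrainer,
)

def _default_tokenizer_sketch(store):
    # Sketch of the dispatch described above: vocab_size == 0 selects
    # character-level tokenization, anything else SentencePiece.
    if store.vocab_size == 0:
        return CharTokenizerTrainer(config=store)
    return SentencePieceTokenizerTrainer(
        vocab_size=store.vocab_size, config=store
    )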
Outputs
| Name | Type | Description |
|---|---|---|
| Model checkpoints | files | Saved model weights in the checkpoint directory (best model or all checkpoints depending on config) |
| model_history.csv | file | CSV file with per-epoch loss, accuracy, val_loss, val_accuracy, and optionally epsilon/delta columns |
| model_params.json | file | Serialized configuration parameters saved to the checkpoint directory |
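After training, the history file can be inspected with ordinary tooling, for example pandas. This is a usage sketch; the path is a placeholder and the columns follow the table above.

import pandas as pd

# Load the per-epoch metrics written to the checkpoint directory.
history = pd.read_csv("/path/to/checkpoints/model_history.csv")
print(history.tail())  # loss, accuracy, val_* and optional epsilon/delta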
Usage Examples
Basic Example
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.train import train
config = TensorFlowConfig(
    input_data_path="/path/to/training_data.txt",
    checkpoint_dir="/path/to/checkpoints",
    epochs=50,
    early_stopping=True,
    early_stopping_patience=5,
    batch_size=64,
)
train(config)
Training with Custom Tokenizer
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tokenizers import SentencePieceTokenizerTrainer
from gretel_synthetics.train import train
config = TensorFlowConfig(
    input_data_path="/path/to/data.csv",
    checkpoint_dir="/path/to/checkpoints",
    field_delimiter=",",
    epochs=30,
)
tokenizer_trainer = SentencePieceTokenizerTrainer(
    vocab_size=10000,
    character_coverage=1.0,
    config=config,
)
train(config, tokenizer_trainer=tokenizer_trainer)
Training with Epoch Callback
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.train import train, EpochState
def my_callback(epoch_state: EpochState):
    print(f"Epoch {epoch_state.epoch}: loss={epoch_state.loss:.4f}")

config = TensorFlowConfig(
    input_data_path="/path/to/data.txt",
    checkpoint_dir="/path/to/checkpoints",
    epoch_callback=my_callback,
)
train(config)