
Implementation:Gretelai Gretel synthetics Train RNN

From Leeroopedia
Knowledge Sources
Domains Synthetic_Data, Deep_Learning, Model_Training
Last Updated 2026-02-14 19:00 GMT

Overview

Concrete implementation for training an LSTM text-generation model on tokenized data, provided by the gretel-synthetics library.

Description

The training implementation consists of two layers:

1. The facade function train(store, tokenizer_trainer) (in src/gretel_synthetics/train.py) orchestrates the full pipeline: it creates a default tokenizer if none is provided, calls annotate_data() and train() on the tokenizer trainer, loads the trained tokenizer, packages everything into a TrainingParams dataclass, saves model parameters, performs a GPU check, and dispatches to the engine-specific training callable returned by store.get_training_callable().

2. The engine-specific function train_rnn(params) (in src/gretel_synthetics/tensorflow/train.py) performs the actual TensorFlow model training. It extracts the tokenizer and configuration from params, constructs a TensorFlow dataset with input-target pairs, builds the LSTM model, configures Keras callbacks (checkpointing, history tracking, early stopping, epoch callback wrapper, max training time), calls model.fit(), and saves the training history to CSV.

The dataset construction (_create_dataset) encodes all training lines into token IDs, creates fixed-length sequences, splits each into input/target pairs shifted by one token, shuffles them, and optionally creates a validation set using an enumeration-based 80/20 filter.
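The windowing and split behavior described above can be sketched in plain Python. This is an illustrative reconstruction, not the library's actual _create_dataset (which builds a tf.data.Dataset); the seq_length value and the modulo used for the 80/20 split are assumptions for the sketch:

```python
# Illustrative sketch of the _create_dataset behavior described above.
# The real implementation operates on a tf.data.Dataset of token IDs.

def split_input_target(chunk):
    """Shift a fixed-length sequence into an (input, target) pair offset by one token."""
    return chunk[:-1], chunk[1:]

def make_examples(token_ids, seq_length):
    """Slice an encoded ID stream into fixed-length windows, then split each window."""
    windows = [
        token_ids[i : i + seq_length + 1]
        for i in range(0, len(token_ids) - seq_length, seq_length + 1)
    ]
    return [split_input_target(w) for w in windows]

def is_validation(index, modulo=5):
    """Enumeration-based 80/20 split: every 5th example goes to validation."""
    return index % modulo == 0

ids = list(range(10))  # stand-in for encoded token IDs
examples = make_examples(ids, seq_length=4)
train = [ex for i, ex in enumerate(examples) if not is_validation(i)]
val = [ex for i, ex in enumerate(examples) if is_validation(i)]
```

Each input sequence is the window minus its last token and each target is the window minus its first token, so the model learns to predict the next token at every position.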

Usage

Call train(config) from user code; the facade handles all orchestration, including default tokenizer creation. Call train_rnn directly only if you need to bypass the facade, e.g. when constructing TrainingParams yourself for a custom tokenizer flow.

Code Reference

Source Location

  • Repository: gretel-synthetics
  • Files:
    • src/gretel_synthetics/train.py (L75--98): train() facade
    • src/gretel_synthetics/tensorflow/train.py (L215--329): train_rnn() engine-specific training

Signature

Facade:

def train(store: BaseConfig, tokenizer_trainer: Optional[BaseTokenizerTrainer] = None):
    """Train a Synthetic Model. This is a facade entrypoint that implements the engine
    specific training operation based on the provided configuration.

    Args:
        store: A subclass instance of BaseConfig.
        tokenizer_trainer: An optional subclass instance of a BaseTokenizerTrainer.
    """
    if tokenizer_trainer is None:
        tokenizer_trainer = _create_default_tokenizer(store)
    tokenizer_trainer.annotate_data()
    tokenizer_trainer.train()
    tokenizer = tokenizer_from_model_dir(store.checkpoint_dir)
    params = TrainingParams(
        tokenizer_trainer=tokenizer_trainer, tokenizer=tokenizer, config=store
    )
    train_fn = store.get_training_callable()
    store.save_model_params()
    store.gpu_check()
    train_fn(params)

Engine-specific:

def train_rnn(params: TrainingParams):
    """Fit synthetic data model on training data.

    Args:
        params: The parameters controlling model training.

    Returns:
        None
    """
    store = params.config
    tokenizer = params.tokenizer
    num_lines = params.tokenizer_trainer.num_lines
    text_iter = params.tokenizer_trainer.data_iterator()
    # ... dataset creation, model building, callback setup, model.fit() ...

Import

from gretel_synthetics.train import train

I/O Contract

Inputs

  • store (BaseConfig, typically TensorFlowConfig; required): configuration object with all training hyperparameters, file paths, and callback settings.
  • tokenizer_trainer (BaseTokenizerTrainer; optional): a pre-configured tokenizer trainer. If None, a default is created based on config.vocab_size (CharTokenizerTrainer if 0, SentencePieceTokenizerTrainer otherwise).
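The default-tokenizer rule can be sketched as follows. This is an illustrative dispatch only, not the library's actual _create_default_tokenizer code, though the class names match those cited above:

```python
def default_tokenizer_class_name(vocab_size: int) -> str:
    """Mirror the rule described above: vocab_size == 0 selects character-level
    tokenization; any positive value selects a SentencePiece vocabulary.

    The returned names refer to classes in gretel_synthetics.tokenizers."""
    if vocab_size == 0:
        return "CharTokenizerTrainer"
    return "SentencePieceTokenizerTrainer"

print(default_tokenizer_class_name(0))      # → CharTokenizerTrainer
print(default_tokenizer_class_name(20000))  # → SentencePieceTokenizerTrainer
```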

Outputs

  • Model checkpoints (files): saved model weights in the checkpoint directory (best model only or all checkpoints, depending on config).
  • model_history.csv (file): CSV with per-epoch loss, accuracy, val_loss, val_accuracy, and optionally epsilon/delta columns.
  • model_params.json (file): serialized configuration parameters saved to the checkpoint directory.
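Once training finishes, the per-epoch history can be inspected with standard CSV tooling. A minimal sketch, assuming the column names listed above (the exact header layout in a given release may differ); the sample string stands in for reading model_history.csv from the checkpoint directory:

```python
import csv
import io

# Sample rows in the shape described above; a real run would open
# model_history.csv from the checkpoint directory instead.
sample = """epoch,loss,accuracy,val_loss,val_accuracy
0,2.31,0.41,2.45,0.39
1,1.87,0.52,2.01,0.48
"""

def best_epoch(history_text: str) -> int:
    """Return the epoch with the lowest validation loss."""
    rows = list(csv.DictReader(io.StringIO(history_text)))
    return int(min(rows, key=lambda r: float(r["val_loss"]))["epoch"])

print(best_epoch(sample))  # → 1
```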

Usage Examples

Basic Example

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.train import train

config = TensorFlowConfig(
    input_data_path="/path/to/training_data.txt",
    checkpoint_dir="/path/to/checkpoints",
    epochs=50,
    early_stopping=True,
    early_stopping_patience=5,
    batch_size=64,
)

train(config)

Training with Custom Tokenizer

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.tokenizers import SentencePieceTokenizerTrainer
from gretel_synthetics.train import train

config = TensorFlowConfig(
    input_data_path="/path/to/data.csv",
    checkpoint_dir="/path/to/checkpoints",
    field_delimiter=",",
    epochs=30,
)

tokenizer_trainer = SentencePieceTokenizerTrainer(
    vocab_size=10000,
    character_coverage=1.0,
    config=config,
)

train(config, tokenizer_trainer=tokenizer_trainer)

Training with Epoch Callback

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.train import train, EpochState

def my_callback(epoch_state: EpochState):
    print(f"Epoch {epoch_state.epoch}: loss={epoch_state.loss:.4f}")

config = TensorFlowConfig(
    input_data_path="/path/to/data.txt",
    checkpoint_dir="/path/to/checkpoints",
    epoch_callback=my_callback,
)

train(config)

Related Pages

Implements Principle

Requires Environment
