
Implementation:Gretelai Gretel synthetics TensorFlowConfig

From Leeroopedia
Knowledge Sources
Domains Synthetic_Data, Deep_Learning, Hyperparameter_Management
Last Updated 2026-02-14 19:00 GMT

Overview

A concrete tool provided by the gretel-synthetics library for centralizing all LSTM training and generation hyperparameters in a single configuration object.

Description

TensorFlowConfig is a Python dataclass that inherits from BaseConfig and collects every parameter needed to train an LSTM text generation model and subsequently generate synthetic text. It provides sensible defaults (e.g., 100 epochs, batch size 64, embedding dimension 256, dropout 0.2), validates parameter consistency on construction (such as checking TensorFlow version compatibility when differential privacy is enabled), and automatically derives file paths for training data and checkpoints from the provided checkpoint_dir.
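The construction-time behavior described above can be sketched with a plain dataclass. This is a hypothetical stand-in, not the library's actual implementation; the derived filename and the validation hook are chosen for illustration only:

```python
import os
from dataclasses import dataclass, field


@dataclass
class SketchConfig:
    # Illustrative stand-in for BaseConfig/TensorFlowConfig behavior;
    # the real logic lives in gretel_synthetics.config.
    input_data_path: str
    checkpoint_dir: str
    dp: bool = False
    training_data_path: str = field(init=False, default="")

    def __post_init__(self):
        if self.dp:
            # The real class validates TensorFlow version compatibility
            # here when differential privacy is enabled; this sketch only
            # shows that invalid combinations should fail at construction.
            pass
        # Derive the annotated training-data path from checkpoint_dir
        # (filename chosen for illustration).
        self.training_data_path = os.path.join(
            self.checkpoint_dir, "training_data.txt"
        )
```

A caller never sets training_data_path directly; it is always computed from checkpoint_dir during __post_init__.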

The class also serves as a strategy selector: its get_training_callable() returns the train_rnn function and its get_generator_class() returns TensorFlowGenerator, enabling the facade layer to dispatch to the correct engine.
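The dispatch pattern can be illustrated with stand-in classes. Everything below is a placeholder sketch of the strategy-selector idea, not library code; train_rnn and TensorFlowGenerator here are dummies standing in for the real engine pieces:

```python
from dataclasses import dataclass


# Placeholder stand-ins for the real engine pieces.
def train_rnn(config):
    return f"training with {type(config).__name__}"


class TensorFlowGenerator:
    def __init__(self, config):
        self.config = config


@dataclass
class TensorFlowConfigSketch:
    checkpoint_dir: str = "/tmp/ckpt"

    def get_training_callable(self):
        # Return the engine-specific training function.
        return train_rnn

    def get_generator_class(self):
        # Return the engine-specific generator class.
        return TensorFlowGenerator


def train(config):
    # Facade: dispatch to whichever engine the config selects, without
    # the facade knowing anything about the backend itself.
    return config.get_training_callable()(config)
```

Because the facade only calls the two selector methods, adding a new backend means adding a new config class rather than changing the facade.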

Usage

Use TensorFlowConfig when you want to train or generate with the TensorFlow/LSTM backend of gretel-synthetics. Pass it to the train() facade or to generate_text() directly.

Code Reference

Signature

@dataclass
class TensorFlowConfig(BaseConfig):
    # Training configurations
    epochs: int = 100
    early_stopping: bool = True
    early_stopping_patience: int = 5
    best_model_metric: str = None
    early_stopping_min_delta: float = 0.001
    batch_size: int = 64
    buffer_size: int = 10000
    seq_length: int = 100
    embedding_dim: int = 256
    rnn_units: int = 256
    learning_rate: float = 0.01
    dropout_rate: float = 0.2
    rnn_initializer: str = "glorot_uniform"

    # Differential privacy configs
    dp: bool = False
    dp_noise_multiplier: float = 0.1
    dp_l2_norm_clip: float = 3.0
    dp_microbatches: int = 1

    # Generation settings
    gen_temp: float = 1.0
    gen_chars: int = 0
    gen_lines: int = 1000
    predict_batch_size: int = 64
    reset_states: bool = True

    # Checkpoint storage
    save_all_checkpoints: bool = False
    save_best_model: bool = True

Import

from gretel_synthetics.config import TensorFlowConfig

I/O Contract

Inputs

Name Type Required Description
input_data_path str Yes Path to the raw training data file (inherited from BaseConfig)
checkpoint_dir str Yes Directory for storing model checkpoints and training artifacts (inherited from BaseConfig)
epochs int No Number of training epochs (default: 100)
early_stopping bool No Enable early stopping (default: True)
early_stopping_patience int No Epochs to wait with no improvement before stopping (default: 5)
best_model_metric str No Metric to track for best model: "val_loss", "loss", "val_accuracy", or "accuracy" (default: auto-selected based on validation_split)
early_stopping_min_delta float No Minimum improvement to qualify as progress (default: 0.001)
batch_size int No Samples per gradient update (default: 64)
buffer_size int No Shuffle buffer size (default: 10000)
seq_length int No Length of each training input sequence in tokens (default: 100)
embedding_dim int No Embedding vector dimension (default: 256)
rnn_units int No Dimensionality of LSTM output space (default: 256)
learning_rate float No Optimizer learning rate (default: 0.01)
dropout_rate float No Fraction of units to drop (default: 0.2)
rnn_initializer str No Kernel weight initializer for LSTM (default: "glorot_uniform")
dp bool No Enable differentially private training (default: False)
dp_noise_multiplier float No Noise added to gradients for DP (default: 0.1)
dp_l2_norm_clip float No Max L2 norm for gradient clipping in DP (default: 3.0)
dp_microbatches int No Number of microbatches for DP (default: 1)
gen_temp float No Softmax temperature for generation (default: 1.0)
gen_chars int No Max characters per generated line, 0 for no limit (default: 0)
gen_lines int No Max number of lines to generate (default: 1000)
predict_batch_size int No Parallel prediction batch size (default: 64)
reset_states bool No Reset RNN states between generated records (default: True)
save_all_checkpoints bool No Save all epoch checkpoints (default: False)
save_best_model bool No Track and save best model checkpoint (default: True)
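To make gen_temp concrete: temperature rescales logits before the softmax, so values below 1.0 sharpen the distribution toward the most likely token, while values above 1.0 flatten it toward uniform sampling. A minimal sketch in pure Python, independent of the library:

```python
import math


def softmax_with_temperature(logits, temp=1.0):
    # Divide logits by the temperature before normalizing: temp < 1.0
    # concentrates probability on the top token, temp > 1.0 spreads it out.
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


cold = softmax_with_temperature([2.0, 1.0, 0.1], temp=0.5)
hot = softmax_with_temperature([2.0, 1.0, 0.1], temp=2.0)
# cold assigns more mass to the leading token than hot does.
```

In practice, lower gen_temp values yield more repetitive but conservative synthetic records, while higher values yield more varied but noisier output.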

Outputs

Name Type Description
TensorFlowConfig instance TensorFlowConfig A validated configuration object ready for use with train() and generate_text()
training_data_path str Auto-generated path to the annotated training data file (set during __post_init__)

Usage Examples

Basic Example

from gretel_synthetics.config import TensorFlowConfig

config = TensorFlowConfig(
    input_data_path="/path/to/training_data.txt",
    checkpoint_dir="/path/to/checkpoints",
    epochs=50,
    batch_size=64,
    embedding_dim=256,
    rnn_units=256,
    dropout_rate=0.2,
    gen_temp=1.0,
)

Differential Privacy Example

from gretel_synthetics.config import TensorFlowConfig

config = TensorFlowConfig(
    input_data_path="/path/to/sensitive_data.txt",
    checkpoint_dir="/path/to/dp_checkpoints",
    dp=True,
    dp_noise_multiplier=0.1,
    dp_l2_norm_clip=3.0,
    dp_microbatches=1,
    learning_rate=0.01,
    epochs=30,
)
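One practical consistency check when enabling DP, stated here as an assumption about the TensorFlow Privacy backend rather than a documented guarantee of gretel-synthetics: DP-SGD optimizers generally require the batch to split evenly into microbatches. A hedged sketch of that check:

```python
def check_dp_batching(batch_size: int, dp_microbatches: int) -> None:
    # Assumption: the DP optimizer needs batch_size to be evenly
    # divisible by dp_microbatches so each microbatch is the same size.
    if batch_size % dp_microbatches != 0:
        raise ValueError(
            f"batch_size={batch_size} is not divisible by "
            f"dp_microbatches={dp_microbatches}"
        )


check_dp_batching(64, 1)  # fine: the defaults from the table above
```

With the default dp_microbatches=1 the constraint is trivially satisfied; it only bites when raising the microbatch count for tighter per-example clipping.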

Related Pages

Implements Principle

Requires Environment
