Implementation: Gretel.ai gretel-synthetics TensorFlowConfig
| Knowledge Sources | |
|---|---|
| Domains | Synthetic_Data, Deep_Learning, Hyperparameter_Management |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
Concrete configuration class that centralizes all LSTM training and generation hyperparameters for the gretel-synthetics library.
Description
TensorFlowConfig is a Python dataclass that inherits from BaseConfig and collects every parameter needed to train an LSTM text generation model and subsequently generate synthetic text. It provides sensible defaults (e.g., 100 epochs, batch size 64, embedding dimension 256, dropout 0.2), validates parameter consistency on construction (such as checking TensorFlow version compatibility when differential privacy is enabled), and automatically derives file paths for training data and checkpoints from the provided checkpoint_dir.
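The construction-time validation and path derivation described above can be illustrated with a minimal, self-contained sketch. Note this is a simplified mimic, not the library's actual logic: the derivation rule and the filename `training_data.txt` are assumptions for illustration only.

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class ConfigSketch:
    """Hypothetical mimic of a config dataclass that validates
    parameters and derives artifact paths in __post_init__."""

    input_data_path: str
    checkpoint_dir: str
    epochs: int = 100
    # Derived, not supplied by the caller.
    training_data_path: str = field(init=False, default="")

    def __post_init__(self):
        # Consistency check at construction time.
        if self.epochs <= 0:
            raise ValueError("epochs must be positive")
        # Derive the annotated training-data path from checkpoint_dir.
        self.training_data_path = str(
            Path(self.checkpoint_dir) / "training_data.txt"
        )


cfg = ConfigSketch(input_data_path="/data/raw.txt", checkpoint_dir="/ckpt")
print(cfg.training_data_path)
```

Deriving paths in `__post_init__` keeps the caller's surface small: only `input_data_path` and `checkpoint_dir` are required, and every downstream artifact location follows from them.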
The class also serves as a strategy selector: its get_training_callable() returns the train_rnn function and its get_generator_class() returns TensorFlowGenerator, enabling the facade layer to dispatch to the correct engine.
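The strategy-selector pattern can be sketched in isolation. The stubs below (`train_rnn_stub`, `GeneratorStub`, `TFConfigSketch`) are stand-ins invented for this example; in the real library the config returns `train_rnn` and `TensorFlowGenerator`.

```python
from dataclasses import dataclass
from typing import Callable, Type


def train_rnn_stub(config) -> str:
    # Stand-in for the backend's training function.
    return f"trained for {config.epochs} epochs"


class GeneratorStub:
    # Stand-in for the backend's generator class.
    def __init__(self, config):
        self.config = config


@dataclass
class TFConfigSketch:
    epochs: int = 100

    def get_training_callable(self) -> Callable:
        return train_rnn_stub

    def get_generator_class(self) -> Type:
        return GeneratorStub


def train(config):
    # Facade: dispatches to whichever engine the config selects,
    # without knowing which backend it is talking to.
    return config.get_training_callable()(config)


print(train(TFConfigSketch(epochs=5)))  # trained for 5 epochs
```

Because the facade only calls `get_training_callable()` and `get_generator_class()`, adding another backend means adding another config class, not changing the facade.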
Usage
Use TensorFlowConfig when you want to train or generate with the TensorFlow/LSTM backend of gretel-synthetics. Pass it to the train() facade or to generate_text() directly.
Code Reference
Source Location
- Repository: gretel-synthetics
- File: src/gretel_synthetics/config.py
- Lines: 172-331
Signature
@dataclass
class TensorFlowConfig(BaseConfig):
# Training configurations
epochs: int = 100
early_stopping: bool = True
early_stopping_patience: int = 5
best_model_metric: str = None
early_stopping_min_delta: float = 0.001
batch_size: int = 64
buffer_size: int = 10000
seq_length: int = 100
embedding_dim: int = 256
rnn_units: int = 256
learning_rate: float = 0.01
dropout_rate: float = 0.2
rnn_initializer: str = "glorot_uniform"
# Diff privacy configs
dp: bool = False
dp_noise_multiplier: float = 0.1
dp_l2_norm_clip: float = 3.0
dp_microbatches: int = 1
# Generation settings
gen_temp: float = 1.0
gen_chars: int = 0
gen_lines: int = 1000
predict_batch_size: int = 64
reset_states: bool = True
# Checkpoint storage
save_all_checkpoints: bool = False
save_best_model: bool = True
Import
from gretel_synthetics.config import TensorFlowConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_data_path | str | Yes | Path to the raw training data file (inherited from BaseConfig) |
| checkpoint_dir | str | Yes | Directory for storing model checkpoints and training artifacts (inherited from BaseConfig) |
| epochs | int | No | Number of training epochs (default: 100) |
| early_stopping | bool | No | Enable early stopping (default: True) |
| early_stopping_patience | int | No | Epochs to wait with no improvement before stopping (default: 5) |
| best_model_metric | str | No | Metric to track for best model: "val_loss", "loss", "val_accuracy", or "accuracy" (default: auto-selected based on validation_split) |
| early_stopping_min_delta | float | No | Minimum improvement to qualify as progress (default: 0.001) |
| batch_size | int | No | Samples per gradient update (default: 64) |
| buffer_size | int | No | Shuffle buffer size (default: 10000) |
| seq_length | int | No | Length of each training input sequence in tokens (default: 100) |
| embedding_dim | int | No | Embedding vector dimension (default: 256) |
| rnn_units | int | No | Dimensionality of LSTM output space (default: 256) |
| learning_rate | float | No | Optimizer learning rate (default: 0.01) |
| dropout_rate | float | No | Fraction of units to drop (default: 0.2) |
| rnn_initializer | str | No | Kernel weight initializer for LSTM (default: "glorot_uniform") |
| dp | bool | No | Enable differentially private training (default: False) |
| dp_noise_multiplier | float | No | Noise added to gradients for DP (default: 0.1) |
| dp_l2_norm_clip | float | No | Max L2 norm for gradient clipping in DP (default: 3.0) |
| dp_microbatches | int | No | Number of microbatches for DP (default: 1) |
| gen_temp | float | No | Softmax temperature for generation (default: 1.0) |
| gen_chars | int | No | Max characters per generated line, 0 for no limit (default: 0) |
| gen_lines | int | No | Max number of lines to generate (default: 1000) |
| predict_batch_size | int | No | Parallel prediction batch size (default: 64) |
| reset_states | bool | No | Reset RNN states between generated records (default: True) |
| save_all_checkpoints | bool | No | Save all epoch checkpoints (default: False) |
| save_best_model | bool | No | Track and save best model checkpoint (default: True) |
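Several of these inputs constrain one another. The checks below are a hedged sketch of plausible cross-parameter validation, not necessarily the library's exact rules; the divisibility constraint does reflect TensorFlow Privacy's requirement that the batch split evenly into microbatches.

```python
def check_config(batch_size: int, dp: bool,
                 dp_microbatches: int, gen_temp: float) -> None:
    """Hypothetical cross-parameter checks mirroring the kind of
    validation a config's __post_init__ might perform."""
    if gen_temp <= 0:
        raise ValueError("gen_temp must be > 0 (softmax temperature)")
    if dp and batch_size % dp_microbatches != 0:
        # TensorFlow Privacy requires batch_size to split evenly
        # into microbatches when clipping per-microbatch gradients.
        raise ValueError("batch_size must be divisible by dp_microbatches")


# Defaults from the table above pass the checks.
check_config(batch_size=64, dp=True, dp_microbatches=1, gen_temp=1.0)
```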
Outputs
| Name | Type | Description |
|---|---|---|
| TensorFlowConfig instance | TensorFlowConfig | A validated configuration object ready for use with train() and generate_text() |
| training_data_path | str | Auto-generated path to the annotated training data file (set during __post_init__) |
Usage Examples
Basic Example
from gretel_synthetics.config import TensorFlowConfig
config = TensorFlowConfig(
input_data_path="/path/to/training_data.txt",
checkpoint_dir="/path/to/checkpoints",
epochs=50,
batch_size=64,
embedding_dim=256,
rnn_units=256,
dropout_rate=0.2,
gen_temp=1.0,
)
Differential Privacy Example
from gretel_synthetics.config import TensorFlowConfig
config = TensorFlowConfig(
input_data_path="/path/to/sensitive_data.txt",
checkpoint_dir="/path/to/dp_checkpoints",
dp=True,
dp_noise_multiplier=0.1,
dp_l2_norm_clip=3.0,
dp_microbatches=1,
learning_rate=0.01,
epochs=30,
)