Implementation: Gretel.ai gretel-synthetics Generate Text
| Knowledge Sources | |
|---|---|
| Domains | Synthetic_Data, Deep_Learning, Text_Generation |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
Concrete tool from the gretel-synthetics library for generating synthetic text records from a trained LSTM model.
Description
The generation implementation spans two modules:
1. The generate_text function (in src/gretel_synthetics/generate.py) is the primary user-facing entry point. It loads the trained tokenizer from the checkpoint directory, creates a Settings object that validates and processes the start string (including converting field delimiters to special tokens), determines the number of lines to generate, resolves parallelism settings, and delegates to generate_parallel which coordinates one or more TensorFlowGenerator workers.
2. The TensorFlowGenerator class (in src/gretel_synthetics/tensorflow/generator.py) implements the actual text generation loop. It loads the trained model weights, creates an infinite prediction generator (_predict_forever), and exposes a generate_next method that yields GenText records. Each record is validated against the optional line_validator callback, and the generator tracks the total number of invalid lines to enforce the max_invalid budget.
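The invalid-line budget described in item 2 can be illustrated with a simplified, library-independent sketch. The names TooManyInvalidError, max_invalid, and the valid/explain fields mirror the documented API; the loop body itself is an assumption, not the library's implementation:

```python
from dataclasses import dataclass
from typing import Callable, Iterator, Optional


class TooManyInvalidError(RuntimeError):
    """Raised when the invalid-line budget is exhausted."""


@dataclass
class Record:
    # Simplified stand-in for GenText.
    text: str
    valid: Optional[bool] = None
    explain: Optional[str] = None


def generate_with_budget(
    raw_lines: Iterator[str],
    line_validator: Optional[Callable[[str], object]],
    num_lines: int,
    max_invalid: int = 1000,
) -> Iterator[Record]:
    """Yield records until num_lines valid ones are produced,
    aborting once more than max_invalid lines fail validation."""
    valid_count = 0
    total_invalid = 0
    for line in raw_lines:
        if line_validator is None:
            # No validator: lines pass through with valid=None.
            yield Record(text=line, valid=None)
            valid_count += 1
        else:
            try:
                # Returning False or raising marks the line invalid.
                if line_validator(line) is False:
                    raise ValueError("validator returned False")
                yield Record(text=line, valid=True)
                valid_count += 1
            except Exception as err:
                total_invalid += 1
                if total_invalid > max_invalid:
                    raise TooManyInvalidError("too many invalid lines") from err
                yield Record(text=line, valid=False, explain=str(err))
        if valid_count >= num_lines:
            break
```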
The low-level prediction loop (_predict_chars) operates in batches: it vectorizes the start string, feeds it to the model, and iteratively samples the next token using temperature-scaled categorical sampling via a @tf.function-compiled helper. Each batch produces predict_batch_size lines in parallel. Lines terminate when a newline token is found in the decoded output or when the gen_chars limit is reached.
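The temperature-scaled categorical sampling step can be illustrated without TensorFlow. This pure-Python sketch mirrors what sampling from temperature-divided logits (as `tf.random.categorical` does) accomplishes; the function name is illustrative:

```python
import math
import random


def sample_with_temperature(logits, temperature=1.0, rng=None):
    """Sample a token index from logits after temperature scaling.

    Dividing logits by a low temperature sharpens the distribution
    (near-greedy decoding); a high temperature flattens it, producing
    more diverse but less predictable text.
    """
    rng = rng or random.Random()
    scaled = [l / temperature for l in logits]
    # Numerically stable softmax.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Inverse-CDF sampling over the categorical distribution.
    r = rng.random()
    cum = 0.0
    for i, p in enumerate(probs):
        cum += p
        if r <= cum:
            return i
    return len(probs) - 1
```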
Usage
Call generate_text(config) after training to produce synthetic records. Use start_string to seed generation, line_validator to filter invalid records, and parallelism for multi-worker throughput.
Code Reference
Source Location
- Repository: gretel-synthetics
- Files:
  - src/gretel_synthetics/generate.py (L144-250): generate_text() entry point
  - src/gretel_synthetics/tensorflow/generator.py (L19-137): TensorFlowGenerator class
Signature
generate_text:
def generate_text(
config: BaseConfig,
start_string: Optional[Union[str, List[str]]] = None,
line_validator: Optional[Callable] = None,
max_invalid: int = 1000,
num_lines: Optional[int] = None,
parallelism: int = 0,
) -> Iterator[GenText]:
"""A generator that will load a model and start creating records.
Args:
config: A configuration object, which you must have created previously
start_string: A prefix string used to seed record generation.
line_validator: An optional callback validator function.
max_invalid: Maximum number of invalid lines to generate.
num_lines: Overrides config.gen_lines if set.
parallelism: Number of concurrent workers (0 = number of CPUs).
Yields:
A GenText object for each generated record.
"""
TensorFlowGenerator:
class TensorFlowGenerator(BaseGenerator):
settings: Settings
model: tf.keras.Sequential
delim: str
total_invalid: int = 0
total_generated: int = 0
def __init__(self, settings: Settings):
...
def generate_next(
self, num_lines: Optional[int], hard_limit: Optional[int] = None
) -> Iterator[GenText]:
...
GenText dataclass:
@dataclass
class GenText(gen_text):
valid: bool = None
text: str = None
explain: str = None
delimiter: str = None
def as_dict(self) -> dict: ...
def values_as_list(self) -> Optional[List[str]]: ...
Import
from gretel_synthetics.generate import generate_text, GenText
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | BaseConfig | Yes | Configuration object with checkpoint_dir (for loading the model and tokenizer), gen_lines, gen_temp, gen_chars, predict_batch_size, reset_states, and field_delimiter |
| start_string | str or List[str] | No | Seed prefix for generation. Defaults to the tokenizer's newline_str. If a list, generates one record per seed (disables parallelism) |
| line_validator | Callable | No | A function that takes a raw string and validates it. Returning False or raising an exception marks the line invalid |
| max_invalid | int | No | Maximum number of invalid lines before raising TooManyInvalidError (default: 1000) |
| num_lines | int | No | Number of valid lines to generate. Overrides config.gen_lines. If start_string is a list, this is set to len(start_string) |
| parallelism | int | No | Number of parallel workers. 0 means use all available CPUs; 1 disables parallelism (generation runs in the calling process). Float values are interpreted as a fraction of available CPUs |
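The parallelism rules in the table (0 = all CPUs, a float = a fraction of CPUs) can be sketched as a small resolver. The helper name and the rounding direction are assumptions; the real resolution logic lives inside gretel-synthetics:

```python
import math
import os


def resolve_workers(parallelism, total_cpus=None):
    """Map a parallelism setting to a concrete worker count.

    0 -> use all CPUs; a float in (0, 1] -> that fraction of CPUs
    (rounded up here, and never below 1); an int >= 1 -> that many
    workers, where 1 effectively disables parallelism.
    """
    if total_cpus is None:
        total_cpus = os.cpu_count() or 1
    if parallelism == 0:
        return total_cpus
    if isinstance(parallelism, float):
        return max(1, math.ceil(parallelism * total_cpus))
    return max(1, int(parallelism))
```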
Outputs
| Name | Type | Description |
|---|---|---|
| records | Iterator[GenText] | A generator yielding GenText objects, each containing: text (the generated string), valid (True/False/None), explain (error message if validation failed), and delimiter (field delimiter if applicable) |
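The output contract can be illustrated with a minimal stand-in for GenText. This is a simplified sketch of the documented fields and helpers, not the library's own dataclass; the split-on-delimiter behavior of values_as_list is an assumption:

```python
from dataclasses import asdict, dataclass
from typing import List, Optional


@dataclass
class GenTextSketch:
    """Simplified stand-in for gretel_synthetics.generate.GenText."""

    text: Optional[str] = None
    valid: Optional[bool] = None
    explain: Optional[str] = None
    delimiter: Optional[str] = None

    def as_dict(self) -> dict:
        return asdict(self)

    def values_as_list(self) -> Optional[List[str]]:
        # Only a non-empty line with a known delimiter splits into fields.
        if self.text is None or self.delimiter is None:
            return None
        return self.text.split(self.delimiter)
```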
Usage Examples
Basic Example
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text
config = TensorFlowConfig(
input_data_path="/path/to/data.txt",
checkpoint_dir="/path/to/trained_model",
gen_lines=100,
gen_temp=1.0,
)
for record in generate_text(config, num_lines=100):
print(record.text)
With Validation
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text
config = TensorFlowConfig(
input_data_path="/path/to/data.csv",
checkpoint_dir="/path/to/trained_model",
field_delimiter=",",
)
def validate_record(raw_line: str):
parts = raw_line.split(",")
if len(parts) != 5:
raise ValueError(f"Expected 5 fields, got {len(parts)}")
for record in generate_text(
config,
line_validator=validate_record,
max_invalid=500,
num_lines=1000,
):
if record.valid:
print(record.text)
With Seed Prefix
from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text
config = TensorFlowConfig(
input_data_path="/path/to/data.csv",
checkpoint_dir="/path/to/trained_model",
field_delimiter=",",
)
# Seed with specific first-column values
seeds = ["Alice,", "Bob,", "Carol,"]
for record in generate_text(config, start_string=seeds):
print(record.text)