Implementation:Gretelai Gretel synthetics Generate Text

From Leeroopedia
Knowledge Sources
Domains Synthetic_Data, Deep_Learning, Text_Generation
Last Updated 2026-02-14 19:00 GMT

Overview

A concrete tool, provided by the gretel-synthetics library, for generating synthetic text records from a trained LSTM model.

Description

The generation implementation spans two modules:

1. The generate_text function (in src/gretel_synthetics/generate.py) is the primary user-facing entry point. It loads the trained tokenizer from the checkpoint directory and creates a Settings object that validates and processes the start string, including converting field delimiters to special tokens. It then determines the number of lines to generate, resolves the parallelism setting, and delegates to generate_parallel, which coordinates one or more TensorFlowGenerator workers.

2. The TensorFlowGenerator class (in src/gretel_synthetics/tensorflow/generator.py) implements the actual text generation loop. It loads the trained model weights, creates an infinite prediction generator (_predict_forever), and exposes a generate_next method that yields GenText records. Each record is validated against the optional line_validator callback, and the generator tracks the total number of invalid lines to enforce the max_invalid budget.
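
The validation and invalid-line budget described in item 2 can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: Record, TooManyInvalid, and generate_valid are hypothetical stand-ins for GenText, TooManyInvalidError, and generate_next, and both the exact comparison against max_invalid and the choice to leave valid as None when no validator is supplied are assumptions.

```python
from dataclasses import dataclass
from typing import Callable, Iterable, Iterator, Optional


class TooManyInvalid(RuntimeError):
    """Hypothetical stand-in for the library's TooManyInvalidError."""


@dataclass
class Record:
    """Hypothetical stand-in for GenText (text, valid, explain only)."""
    text: str
    valid: Optional[bool] = None
    explain: Optional[str] = None


def generate_valid(
    raw_lines: Iterable[str],
    num_lines: int,
    line_validator: Optional[Callable[[str], Optional[bool]]] = None,
    max_invalid: int = 1000,
) -> Iterator[Record]:
    """Yield every record, stopping after num_lines valid ones."""
    total_valid = 0
    total_invalid = 0
    if num_lines <= 0:
        return
    for line in raw_lines:
        if line_validator is None:
            # Assumption: without a validator, `valid` is left as None.
            record = Record(text=line)
        else:
            try:
                # Only an explicit False (or an exception) marks a line invalid.
                passed = line_validator(line) is not False
                explain = None if passed else "validator returned False"
            except Exception as err:
                passed, explain = False, str(err)
            record = Record(text=line, valid=passed, explain=explain)

        if record.valid is False:
            total_invalid += 1
        else:
            total_valid += 1
        yield record

        # Assumption: the budget is exceeded only when the count passes max_invalid.
        if total_invalid > max_invalid:
            raise TooManyInvalid(f"more than {max_invalid} invalid lines")
        if total_valid >= num_lines:
            return
```

Invalid records are still yielded in this sketch, which is why callers in the examples further down check record.valid before using record.text.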

The low-level prediction loop (_predict_chars) operates in batches: it vectorizes the start string, feeds it to the model, and iteratively samples the next token using temperature-scaled categorical sampling via a @tf.function-compiled helper. Each batch produces predict_batch_size lines in parallel. Lines terminate when a newline token is found in the decoded output or when the gen_chars limit is reached.
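
To make the sampling step concrete, here is a small standard-library sketch of temperature-scaled categorical sampling and the two stopping conditions. It is an illustration only: the library performs this step on batched TensorFlow tensors, while sample_token, generate_line, and the next_logits callable (a stand-in for the model) are hypothetical names. The sketch also simplifies gen_chars to a cap on sampled tokens and checks for the newline on the token id rather than on decoded text.

```python
import math
import random
from typing import Callable, List, Sequence


def sample_token(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Draw one token id from softmax(logits / temperature)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def generate_line(
    next_logits: Callable[[List[int]], Sequence[float]],
    start_ids: Sequence[int],
    newline_id: int,
    gen_chars: int,
    temperature: float,
    rng: random.Random,
) -> List[int]:
    """Sample tokens until the newline token appears or the limit is hit."""
    ids = list(start_ids)
    for _ in range(gen_chars):
        token = sample_token(next_logits(ids), temperature, rng)
        if token == newline_id:
            break  # the line is complete
        ids.append(token)
    return ids
```

Lower temperatures sharpen the distribution toward the highest logit, while higher temperatures flatten it and produce more varied output.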

Usage

Call generate_text(config) after training to produce synthetic records. Use start_string to seed generation, line_validator to filter invalid records, and parallelism for multi-worker throughput.

Code Reference

Source Location

  • Repository: gretel-synthetics
  • Files:
    • src/gretel_synthetics/generate.py (L144-250): generate_text() entry point
    • src/gretel_synthetics/tensorflow/generator.py (L19-137): TensorFlowGenerator class

Signature

generate_text:

def generate_text(
    config: BaseConfig,
    start_string: Optional[Union[str, List[str]]] = None,
    line_validator: Optional[Callable] = None,
    max_invalid: int = 1000,
    num_lines: Optional[int] = None,
    parallelism: int = 0,
) -> Iterator[GenText]:
    """A generator that will load a model and start creating records.

    Args:
        config: A configuration object, which you must have created previously
        start_string: A prefix string used to seed record generation.
        line_validator: An optional callback validator function.
        max_invalid: Maximum number of invalid lines to generate.
        num_lines: Overrides config.gen_lines if set.
        parallelism: Number of concurrent workers (0 = number of CPUs).

    Yields:
        A GenText object for each generated record.
    """

TensorFlowGenerator:

class TensorFlowGenerator(BaseGenerator):
    settings: Settings
    model: tf.keras.Sequential
    delim: str
    total_invalid: int = 0
    total_generated: int = 0

    def __init__(self, settings: Settings):
        ...

    def generate_next(
        self, num_lines: Optional[int], hard_limit: Optional[int] = None
    ) -> Iterator[GenText]:
        ...

GenText dataclass:

@dataclass
class GenText(gen_text):
    valid: bool = None
    text: str = None
    explain: str = None
    delimiter: str = None

    def as_dict(self) -> dict: ...
    def values_as_list(self) -> Optional[List[str]]: ...

Import

from gretel_synthetics.generate import generate_text, GenText

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| config | BaseConfig | Yes | Configuration object with checkpoint_dir (for loading the model and tokenizer), gen_lines, gen_temp, gen_chars, predict_batch_size, reset_states, and field_delimiter |
| start_string | str or List[str] | No | Seed prefix for generation. Defaults to the tokenizer's newline_str. If a list, generates one record per seed (disables parallelism) |
| line_validator | Callable | No | A function that takes a raw string and validates it. Returning False or raising an exception marks the line invalid |
| max_invalid | int | No | Maximum number of invalid lines before raising TooManyInvalidError (default: 1000) |
| num_lines | int | No | Number of valid lines to generate. Overrides config.gen_lines. If start_string is a list, this is set to len(start_string) |
| parallelism | int | No | Number of parallel workers. 0 (the default in the signature) uses one worker per CPU; 1 disables parallelism. A float value is interpreted as a fraction of the available CPUs |
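
The parallelism rules in the table can be restated as a small resolver. This is a sketch under assumptions: resolve_workers and its parameters are hypothetical names, not part of gretel-synthetics, and it covers only the cases the table describes. Rounding fractions down and rejecting negative values are choices made for this illustration.

```python
import math
import os
from typing import Optional, Union


def resolve_workers(
    parallelism: Union[int, float],
    cpu_count: Optional[int] = None,
    seeded_with_list: bool = False,
) -> int:
    """Translate a parallelism setting into a concrete worker count."""
    cpus = cpu_count or os.cpu_count() or 1
    if seeded_with_list:
        # A list of start strings disables parallelism.
        return 1
    if parallelism < 0:
        # Not covered by the table above; rejected in this sketch.
        raise ValueError("negative parallelism is not handled here")
    if isinstance(parallelism, float):
        # Fraction of the available CPUs, never fewer than one worker.
        return max(1, math.floor(cpus * parallelism))
    if parallelism == 0:
        return cpus
    return parallelism


workers = resolve_workers(0.5, cpu_count=8)
```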

Outputs

| Name | Type | Description |
|---|---|---|
| (return value) | Iterator[GenText] | A generator yielding GenText objects, each containing: text (the generated string), valid (True/False/None), explain (error message if validation failed), and delimiter (field delimiter if applicable) |

Usage Examples

Basic Example

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text

config = TensorFlowConfig(
    input_data_path="/path/to/data.txt",
    checkpoint_dir="/path/to/trained_model",
    gen_lines=100,
    gen_temp=1.0,
)

for record in generate_text(config, num_lines=100):
    print(record.text)

With Validation

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text

config = TensorFlowConfig(
    input_data_path="/path/to/data.csv",
    checkpoint_dir="/path/to/trained_model",
    field_delimiter=",",
)

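def validate_record(raw_line: str) -> bool:
    # Raising (or returning False) marks the line as invalid.
    parts = raw_line.split(",")
    if len(parts) != 5:
        raise ValueError(f"Expected 5 fields, got {len(parts)}")
    # Return True explicitly so a passing line is unambiguous.
    return True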

for record in generate_text(
    config,
    line_validator=validate_record,
    max_invalid=500,
    num_lines=1000,
):
    if record.valid:
        print(record.text)

With Seed Prefix

from gretel_synthetics.config import TensorFlowConfig
from gretel_synthetics.generate import generate_text

config = TensorFlowConfig(
    input_data_path="/path/to/data.csv",
    checkpoint_dir="/path/to/trained_model",
    field_delimiter=",",
)

# Seed with specific first-column values
seeds = ["Alice,", "Bob,", "Carol,"]
for record in generate_text(config, start_string=seeds):
    print(record.text)
