
Implementation:Gretelai Gretel synthetics ACTGANSynthesizer Actual Fit

Domains Synthetic_Data, GAN, Tabular_Data
Last Updated 2026-02-14 19:00 GMT

Overview

Concrete tool for executing the ACTGAN adversarial training loop over transformed tabular data, provided by the gretel-synthetics library.

Description

The ACTGANSynthesizer._actual_fit(train_data) method implements the core WGAN-GP training loop of the ACTGAN model. It constructs the Generator and Discriminator networks, creates an Adam optimizer for each, initializes a DataSampler for batch preparation, and then alternates discriminator and generator updates for the configured number of epochs. Each epoch runs max(len(train_data) // batch_size, 1) steps, which is roughly one pass over the training data. Each step performs discriminator_steps discriminator updates followed by one generator update.
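The following pure-Python sketch counts how many discriminator and generator updates one epoch performs. It assumes the steps_per_epoch = max(len(train_data) // batch_size, 1) rule from the training-flow outline. The function updates_per_epoch is hypothetical and not part of gretel-synthetics.

```python
from typing import Tuple


def updates_per_epoch(n_rows: int, batch_size: int, discriminator_steps: int) -> Tuple[int, int]:
    """Return (discriminator_updates, generator_updates) for one epoch.

    Hypothetical helper mirroring the loop structure: every step runs
    `discriminator_steps` discriminator updates, then one generator update.
    """
    # Integer division drops a trailing partial batch; at least one step always runs.
    steps_per_epoch = max(n_rows // batch_size, 1)
    return steps_per_epoch * discriminator_steps, steps_per_epoch


print(updates_per_epoch(10_000, 500, 1))  # (20, 20)
print(updates_per_epoch(10_000, 500, 5))  # (100, 20)
print(updates_per_epoch(300, 500, 1))     # (1, 1): fewer rows than one batch still yields one step
```

Raising discriminator_steps multiplies the discriminator work per epoch, but the number of generator updates does not change.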

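The Generator is a stack of residual layers followed by a final Linear layer. Each residual layer applies Linear, BatchNorm1d, and ReLU, then concatenates its output with its input (a skip connection via concatenation), so the width grows by that layer's output size. The network takes input of dimension embedding_dim + cond_vec_dim (the noise vector concatenated with the conditional vector). It produces output of dimension data_dim, the dimensionality of the encoded data.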
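Because each residual layer concatenates its output with its input, the width seen by the next layer keeps growing. The pure-Python sketch below traces the resulting Linear layer shapes. The helper generator_layer_shapes is hypothetical, and the example sizes (embedding_dim=128, cond_vec_dim=20, data_dim=75) are purely illustrative.

```python
from typing import List, Sequence, Tuple


def generator_layer_shapes(
    embedding_dim: int, cond_vec_dim: int, generator_dim: Sequence[int], data_dim: int
) -> List[Tuple[int, int]]:
    """Return (in_features, out_features) for each Linear layer in the Generator.

    Hypothetical helper: a residual layer maps dim -> item and then concatenates
    its input, so the running width becomes dim + item.
    """
    dim = embedding_dim + cond_vec_dim  # noise concatenated with the conditional vector
    shapes = []
    for item in generator_dim:
        shapes.append((dim, item))  # Linear inside the residual layer
        dim += item                 # width after concatenating output with input
    shapes.append((dim, data_dim))  # final projection to the encoded data width
    return shapes


print(generator_layer_shapes(128, 20, (256, 256), 75))
# [(148, 256), (404, 256), (660, 75)]
```

This growth is why the constructor increments dim by item after each Residual, rather than replacing it as the Discriminator does.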

The Discriminator is a sequence of Linear layers, each followed by LeakyReLU(0.2) and Dropout(0.5), with a final Linear layer that outputs a single score. Its input has dimension (data_dim + cond_vec_dim) * pac. The PAC ("packing") mechanism reshapes each batch so that every group of pac samples is scored jointly, which yields one score per pack rather than one per sample.
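The PAC reshape can be illustrated with NumPy. A batch of batch_size rows of width data_dim + cond_vec_dim becomes batch_size / pac rows of width (data_dim + cond_vec_dim) * pac. This is the same row-major regrouping that input_.view(-1, self.pacdim) performs in Discriminator.forward. The sizes below are illustrative only.

```python
import numpy as np

batch_size, row_width, pac = 500, 95, 10  # illustrative; row_width = data_dim + cond_vec_dim
pacdim = row_width * pac

# Each row stands for one (encoded sample, conditional vector) pair.
batch = np.arange(batch_size * row_width, dtype=np.float32).reshape(batch_size, row_width)

# Row-major reshape, like torch's view(-1, pacdim): each group of `pac`
# consecutive rows is laid end to end in a single packed row.
packed = batch.reshape(-1, pacdim)
print(packed.shape)  # (50, 950)

# The first packed row is exactly rows 0..pac-1 concatenated.
assert np.array_equal(packed[0], batch[:pac].ravel())
```

The reshape only succeeds when batch_size is a multiple of pac, and the discriminator then produces 50 scores for this batch rather than 500.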

After each epoch, the method logs the generator, discriminator, and reconstruction losses when verbose is enabled. If an epoch_callback was provided, it then invokes the callback with an EpochInfo dataclass.

Usage

This method is called internally by ACTGANSynthesizer.fit() after _pre_fit_transform() has prepared the TrainData. It is not typically called directly by users.

Code Reference

Source Location

  • Repository: gretel-synthetics
  • File: src/gretel_synthetics/actgan/actgan.py
  • Lines: 685-837 (_actual_fit), 127-140 (Generator), 63-104 (Discriminator)

Signature

# ACTGANSynthesizer._actual_fit
def _actual_fit(self, train_data: TrainData) -> None: ...

# Generator.__init__
class Generator(Module):
    def __init__(self, embedding_dim: int, generator_dim: Sequence[int], data_dim: int): ...

# Discriminator.__init__
class Discriminator(Module):
    def __init__(self, input_dim, discriminator_dim, pac=10): ...

Import

from gretel_synthetics.actgan.actgan import ACTGANSynthesizer, Generator, Discriminator

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| train_data | TrainData | Yes | Transformed training data produced by _pre_fit_transform(). Contains the decoded column data and the column metadata. |

Outputs

| Name | Type | Description |
|------|------|-------------|
| (none) | None | Modifies the ACTGANSynthesizer instance in place. After training, self._generator holds the trained Generator network and self._condvec_sampler holds the conditional-vector sampler used during generation. |

Internal Training Flow

The training loop is structured as follows. This is a simplified outline: optimizer steps and logging are summarized as comments.

steps_per_epoch = max(len(train_data) // self._batch_size, 1)
for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        # --- Discriminator updates ---
        for _ in range(self._discriminator_steps):
            fakez = self._make_noise()  # [batch_size, embedding_dim]
            fake_cond_vec, real_cond_vec, fake_column_mask, real_encoded = \
                self._prepare_batch(data_sampler)
            fake, fakeact = self._apply_generator(fakez, fake_cond_vec)
            fake_cat, y_fake = self._apply_discriminator(fakeact, fake_cond_vec, discriminator)
            real_cat, y_real = self._apply_discriminator(real_encoded, real_cond_vec, discriminator)
            # WGAN-GP gradient penalty on the packed real/fake inputs
            pen = discriminator.calc_gradient_penalty(real_cat, fake_cat, self._device, self.pac)
            # Wasserstein critic loss: raise scores on real data, lower them on fake data
            loss_d = -(torch.mean(y_real) - torch.mean(y_fake))
            # Backprop discriminator: zero gradients, backpropagate pen and loss_d,
            # then step the discriminator optimizer

        # --- Generator update ---
        fakez = self._make_noise()
        fake_cond_vec, _, fake_column_mask, _ = self._prepare_batch(data_sampler)
        fake, fakeact = self._apply_generator(fakez, fake_cond_vec)
        _, y_fake = self._apply_discriminator(fakeact, fake_cond_vec, discriminator)
        # Compute loss_reconstruction based on conditional_vector_type
        loss_g = -torch.mean(y_fake) + reconstruction_loss_coef * loss_reconstruction
        # Backprop generator: zero gradients, backpropagate loss_g,
        # then step the generator optimizer

    # After the epoch's last step: report losses (loss_r is the reconstruction loss)
    if self._epoch_callback is not None:
        self._epoch_callback(EpochInfo(epoch, loss_g, loss_d, loss_r))
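The losses in the outline can be written as a WGAN-GP objective plus a reconstruction term. This is a sketch under stated assumptions. It assumes calc_gradient_penalty follows the usual WGAN-GP recipe: interpolate per pack between real and fake inputs, then penalize gradient norms that deviate from 1, weighted by λ. That recipe is not shown on this page.

The symbols are defined as follows:
- r_j and f_j are the packed discriminator inputs (real_cat and fake_cat).
- m = batch_size / pac is the number of packs per batch.
- β stands for reconstruction_loss_coef.
- In the generator step, f_j comes from a fresh noise draw.

```latex
\begin{aligned}
\mathcal{L}_D &= -\left(\frac{1}{m}\sum_{j=1}^{m} D(r_j) - \frac{1}{m}\sum_{j=1}^{m} D(f_j)\right)
  + \underbrace{\frac{\lambda}{m}\sum_{j=1}^{m}\left(\lVert \nabla D(\hat{r}_j) \rVert_2 - 1\right)^2}_{\texttt{pen}},
  \qquad \hat{r}_j = \epsilon_j r_j + (1 - \epsilon_j) f_j,\quad \epsilon_j \sim \mathcal{U}[0, 1] \\[4pt]
\mathcal{L}_G &= -\frac{1}{m}\sum_{j=1}^{m} D(f_j) + \beta\, \mathcal{L}_{\mathrm{rec}}
\end{aligned}
```

The first term of L_D is loss_d and the second is pen. L_G is loss_g with L_rec = loss_reconstruction.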

Key Components

Generator Architecture

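from typing import Sequence

from torch.nn import Linear, Module, Sequential

# Residual (Linear -> BatchNorm1d -> ReLU, output concatenated with its input)
# is defined alongside Generator in actgan.py.

class Generator(Module):
    def __init__(self, embedding_dim: int, generator_dim: Sequence[int], data_dim: int):
        super().__init__()  # required before assigning submodules
        # Receives the full input width: noise dim + conditional vector dim
        dim = embedding_dim
        seq = []
        for item in list(generator_dim):
            seq += [Residual(dim, item)]
            dim += item  # residual output is concatenated with its input
        seq.append(Linear(dim, data_dim))  # project to the encoded data width
        self.seq = Sequential(*seq)

    def forward(self, input_):
        data = self.seq(input_)
        return data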

Discriminator Architecture

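from torch.nn import Dropout, LeakyReLU, Linear, Module, Sequential

class Discriminator(Module):
    def __init__(self, input_dim, discriminator_dim, pac=10):
        super().__init__()  # required before assigning submodules
        # input_dim = data_dim + cond_vec_dim; pac such rows are scored together
        dim = input_dim * pac
        self.pac = pac
        self.pacdim = dim
        seq = []
        for item in list(discriminator_dim):
            seq += [Linear(dim, item), LeakyReLU(0.2), Dropout(0.5)]
            dim = item
        seq += [Linear(dim, 1)]  # one score per pack
        self.seq = Sequential(*seq)

    def forward(self, input_):
        # Group every pac consecutive rows into one packed row; the batch size
        # must be a multiple of pac for this view to succeed.
        return self.seq(input_.view(-1, self.pacdim))

    # calc_gradient_penalty(...) omitted here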

Usage Examples

Basic Example

import pandas as pd
from gretel_synthetics.actgan.actgan_wrapper import ACTGAN

# The _actual_fit method is called internally by fit()
data = pd.read_csv("sample_data.csv")
model = ACTGAN(
    epochs=100,
    batch_size=500,
    discriminator_steps=1,
    generator_dim=(256, 256),
    discriminator_dim=(256, 256),
    verbose=True,
)
model.fit(data)  # internally calls _actual_fit after data transformation
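Discriminator.forward reshapes each batch with view(-1, self.pacdim), so the training batch size must be a multiple of pac. The default pac is 10 in the Discriminator signature, and batch_size=500 above satisfies this. The check_batch_size helper below is a hypothetical pre-flight check, not a gretel-synthetics API.

```python
def check_batch_size(batch_size: int, pac: int = 10) -> None:
    """Hypothetical pre-flight check (not part of gretel-synthetics).

    Raises ValueError if batch_size cannot be split evenly into packs of
    `pac` rows, which the Discriminator's view(-1, pacdim) reshape requires.
    """
    if batch_size <= 0 or pac <= 0:
        raise ValueError("batch_size and pac must be positive")
    if batch_size % pac != 0:
        suggestion = (batch_size // pac + 1) * pac  # next multiple of pac
        raise ValueError(
            f"batch_size={batch_size} is not a multiple of pac={pac}; try {suggestion}"
        )


check_batch_size(500)    # OK: 500 rows -> 50 packs of 10
# check_batch_size(505)  # would raise ValueError suggesting 510
```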

Example with Epoch Callback

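import pandas as pd
from gretel_synthetics.actgan.actgan_wrapper import ACTGAN
from gretel_synthetics.actgan.structures import EpochInfo

def on_epoch(info: EpochInfo):
    # Called once per epoch with the generator, discriminator, and reconstruction losses
    print(f"Epoch {info.epoch}: G={info.loss_g:.4f}, D={info.loss_d:.4f}, R={info.loss_r:.4f}")

data = pd.read_csv("sample_data.csv")
model = ACTGAN(
    epochs=50,
    epoch_callback=on_epoch,
    verbose=False,  # rely on the callback instead of built-in logging
)
model.fit(data)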
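A callback can also accumulate per-epoch losses for later inspection, for example to plot training curves. The sketch below is hypothetical. It uses a local stand-in dataclass with the field names the example above reads (epoch, loss_g, loss_d, loss_r), so it runs without gretel-synthetics installed. With the real library you would pass history.on_epoch as epoch_callback and receive genuine EpochInfo objects.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class StubEpochInfo:
    """Stand-in for EpochInfo; field names assumed from the example above."""
    epoch: int
    loss_g: float
    loss_d: float
    loss_r: float


@dataclass
class LossHistory:
    """Hypothetical recorder that stores the losses reported each epoch."""
    epochs: List[int] = field(default_factory=list)
    loss_g: List[float] = field(default_factory=list)
    loss_d: List[float] = field(default_factory=list)
    loss_r: List[float] = field(default_factory=list)

    def on_epoch(self, info) -> None:
        # Invoked once per epoch by the training loop.
        self.epochs.append(info.epoch)
        self.loss_g.append(info.loss_g)
        self.loss_d.append(info.loss_d)
        self.loss_r.append(info.loss_r)


history = LossHistory()
# Simulated invocations; with the library: ACTGAN(..., epoch_callback=history.on_epoch)
history.on_epoch(StubEpochInfo(epoch=0, loss_g=1.5, loss_d=-0.3, loss_r=0.8))
history.on_epoch(StubEpochInfo(epoch=1, loss_g=1.2, loss_d=-0.2, loss_r=0.6))
print(history.epochs, history.loss_g)  # [0, 1] [1.5, 1.2]
```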
