
Implementation:Gretelai Gretel synthetics DGAN Train Loop

From Leeroopedia
Knowledge Sources
Domains Synthetic_Data, Time_Series, GAN
Last Updated 2026-02-14 19:00 GMT

Overview

A concrete tool from the gretel-synthetics library for executing the WGAN-GP adversarial training loop on prepared data.

Description

The DGAN._train() method implements the core adversarial training loop for the DoppelGANger model. It operates on a PyTorch Dataset of 3-element tuples (attributes, additional_attributes, features) already in the internal encoded representation.
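The 3-tuple dataset layout can be sketched as follows. This is an illustrative construction, not the library's internal transformation code; the shapes and the single-column NaN placeholder are assumptions chosen for demonstration.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

# Hypothetical shapes: 100 sequences of length 20 with 3 features each,
# plus 2 static attributes per sequence.
n, seq_len, n_features, n_attrs = 100, 20, 3, 2

attributes = torch.from_numpy(np.random.rand(n, n_attrs).astype(np.float32))
features = torch.from_numpy(
    np.random.rand(n, seq_len, n_features).astype(np.float32)
)

# When additional_attributes are absent, a NaN-filled tensor acts as a
# placeholder so every dataset element is still a 3-tuple.
additional_attributes = torch.full((n, 1), float("nan"))

dataset = TensorDataset(attributes, additional_attributes, features)
attrs, extra, feats = dataset[0]
```

Each element unpacks into the (attributes, additional_attributes, features) order that `_train()` expects.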

The method sets up a DataLoader with shuffling and optional drop_last (to avoid a final batch of size 1). Three Adam optimizers are created: one each for the generator, the feature discriminator, and the attribute discriminator (if present). Training iterates over epochs and batches; each batch step executes the configured number of discriminator rounds (updating both the feature and attribute discriminators) followed by generator rounds. A torch.cuda.amp.GradScaler enables optional mixed-precision training.
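The alternating update schedule described above can be sketched with toy stand-in networks. This is a minimal illustration of the discriminator-rounds/generator-rounds pattern, not the gretel-synthetics internals; the layer sizes, learning rates, and loss terms here are assumptions.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real generator and feature discriminator.
generator = nn.Linear(4, 8)
discriminator = nn.Linear(8, 1)

opt_g = optim.Adam(generator.parameters(), lr=1e-3, betas=(0.5, 0.999))
opt_d = optim.Adam(discriminator.parameters(), lr=1e-3, betas=(0.5, 0.999))

real_data = TensorDataset(torch.randn(64, 8))
# drop_last avoids a trailing batch of size 1, as in _train().
loader = DataLoader(real_data, batch_size=16, shuffle=True, drop_last=True)

discriminator_rounds, generator_rounds = 1, 1
for epoch in range(2):
    for (real_batch,) in loader:
        # Discriminator rounds first ...
        for _ in range(discriminator_rounds):
            opt_d.zero_grad()
            fake = generator(torch.randn(real_batch.shape[0], 4))
            # Wasserstein critic loss (gradient penalty omitted here).
            loss_d = (discriminator(fake.detach()).mean()
                      - discriminator(real_batch).mean())
            loss_d.backward()
            opt_d.step()
        # ... then generator rounds, within the same batch step.
        for _ in range(generator_rounds):
            opt_g.zero_grad()
            fake = generator(torch.randn(real_batch.shape[0], 4))
            loss_g = -discriminator(fake).mean()
            loss_g.backward()
            opt_g.step()
```

In the real `_train()`, the generator produces (attributes, additional_attributes, features) tuples and the WGAN-GP gradient penalty is added to the discriminator loss.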

The _discriminate() helper filters out NaN placeholder tensors, flattens features, concatenates all inputs, and passes through the feature discriminator. The _get_gradient_penalty() helper interpolates between real and generated batches using a random alpha, computes the discriminator output on the interpolation, obtains gradients via torch.autograd.grad, and returns the squared deviation of the gradient L2 norm from 1.
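The gradient-penalty computation follows the standard WGAN-GP recipe, which can be sketched as below. This is a generic illustration of the technique on a single flat tensor, not the library's `_get_gradient_penalty()` (which operates on tuples of tensors); the function name and shapes are assumptions.

```python
import torch

def gradient_penalty(discriminator, real, fake):
    """Squared deviation of the gradient L2 norm from 1, evaluated at a
    random interpolation between real and generated batches (WGAN-GP)."""
    # One interpolation coefficient per sample, broadcast over features.
    alpha = torch.rand(real.shape[0], 1)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(interp)
    # Gradients of the discriminator output w.r.t. the interpolated input.
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    norms = grads.norm(2, dim=1)
    return ((norms - 1.0) ** 2).mean()

disc = torch.nn.Linear(8, 1)
penalty = gradient_penalty(disc, torch.randn(16, 8), torch.randn(16, 8))
```

The resulting scalar is scaled by a coefficient (gradient_penalty_coef in DGANConfig) and added to the discriminator loss.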

Usage

_train() is called internally by train_numpy() after data transformation. It should not be called directly. Training behavior is controlled through DGANConfig parameters.

Code Reference

Source Location

  • Repository: gretel-synthetics
  • File: src/gretel_synthetics/timeseries_dgan/dgan.py
  • Lines: 779-956 (_train), 983-1005 (_discriminate), 1028-1074 (_get_gradient_penalty)

Signature

def _train(
    self,
    dataset: Dataset,
    progress_callback: Optional[Callable[[ProgressInfo], None]] = None,
):
def _discriminate(
    self,
    batch,
) -> torch.Tensor:
def _get_gradient_penalty(
    self, generated_batch, real_batch, discriminator_func
) -> torch.Tensor:

Import

from gretel_synthetics.timeseries_dgan.dgan import DGAN

I/O Contract

Inputs (_train)

Name Type Required Description
dataset torch.utils.data.Dataset Yes TensorDataset of 3-element tuples: (attributes_tensor, additional_attributes_tensor, features_tensor). NaN-filled tensors serve as placeholders when attributes/additional_attributes are absent.
progress_callback Optional[Callable[[ProgressInfo], None]] No Called after each batch with a ProgressInfo containing epoch, total_epochs, batch, and total_batches.

Inputs (_discriminate)

Name Type Required Description
batch tuple of torch.Tensor Yes Tuple of (attributes, additional_attributes, features) tensors; NaN-filled tensors are automatically filtered out
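The filter-flatten-concatenate step that `_discriminate()` performs on such a batch can be sketched as below. This is an illustrative stand-in, not the library's code; the helper name and shapes are assumptions.

```python
import torch

def prepare_discriminator_input(batch):
    """Drop NaN-filled placeholder tensors, flatten sequence features to
    2-D, and concatenate the survivors along the feature axis."""
    kept = [t for t in batch if not torch.isnan(t).any()]
    kept = [t.flatten(start_dim=1) for t in kept]
    return torch.cat(kept, dim=1)

attributes = torch.rand(4, 2)
placeholder = torch.full((4, 1), float("nan"))  # absent additional_attributes
features = torch.rand(4, 20, 3)

x = prepare_discriminator_input((attributes, placeholder, features))
# Features flatten from (4, 20, 3) to (4, 60), giving (4, 2 + 60) overall.
```

The concatenated tensor is what gets passed through the feature discriminator network.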

Inputs (_get_gradient_penalty)

Name Type Required Description
generated_batch tuple of torch.Tensor Yes Generator output tensors
real_batch tuple of torch.Tensor Yes Real training data tensors
discriminator_func Callable Yes Either self._discriminate or self._discriminate_attributes

Outputs

Name Type Description
_train returns None Model weights are updated in-place through the training loop
_discriminate returns torch.Tensor Discriminator scores with shape (batch_size, 1)
_get_gradient_penalty returns torch.Tensor Scalar tensor containing the gradient penalty value

Usage Examples

Basic Example

import numpy as np
from gretel_synthetics.timeseries_dgan.dgan import DGAN
from gretel_synthetics.timeseries_dgan.config import DGANConfig

config = DGANConfig(
    max_sequence_len=20,
    sample_len=5,
    batch_size=500,
    epochs=100,
    discriminator_rounds=1,
    generator_rounds=1,
    gradient_penalty_coef=10.0,
    generator_learning_rate=0.001,
    discriminator_learning_rate=0.001,
)

model = DGAN(config)
features = np.random.rand(5000, 20, 3)
attributes = np.random.rand(5000, 2)

# Training loop runs automatically inside train_numpy
model.train_numpy(features=features, attributes=attributes)

Progress Callback Example

from gretel_synthetics.timeseries_dgan.structures import ProgressInfo

def on_progress(info: ProgressInfo):
    if info.batch == 0:
        print(f"Epoch {info.epoch + 1}/{info.total_epochs}")

model.train_numpy(
    features=features,
    attributes=attributes,
    progress_callback=on_progress,
)

Related Pages

Implements Principle

Requires Environment

Uses Heuristic
