Implementation: gretel-synthetics DGAN Train Loop
| Knowledge Sources | |
|---|---|
| Domains | Synthetic_Data, Time_Series, GAN |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
Internal method of the gretel-synthetics library that executes the WGAN-GP adversarial training loop on prepared (encoded) data.
Description
The DGAN._train() method implements the core adversarial training loop for the DoppelGANger model. It operates on a PyTorch Dataset of 3-element tuples (attributes, additional_attributes, features) already in the internal encoded representation.
The method sets up a DataLoader with shuffling and optional drop_last (to avoid batches of size 1). Three Adam optimizers are created for the generator, feature discriminator, and attribute discriminator (if present). Training iterates over epochs and batches, executing discriminator rounds (updating both feature and attribute discriminators) followed by generator rounds in each batch step. A torch.cuda.amp.GradScaler enables optional mixed precision training.
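The alternating round structure described above can be illustrated with a minimal WGAN-style sketch. This is not the library's code: the tiny linear generator and discriminator, tensor shapes, and hyperparameters are all illustrative stand-ins, and the gradient penalty term described in the next paragraph is omitted for brevity.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for DGAN's generator and feature discriminator.
gen = torch.nn.Linear(4, 8)
disc = torch.nn.Linear(8, 1)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3, betas=(0.5, 0.999))

real = torch.randn(64, 8)  # toy "encoded" training data
# Shuffling and drop_last mirror the DataLoader setup described above.
loader = DataLoader(TensorDataset(real), batch_size=16, shuffle=True, drop_last=True)

for epoch in range(2):
    for (real_batch,) in loader:
        # Discriminator round: push D(real) up and D(fake) down (Wasserstein loss).
        noise = torch.randn(real_batch.size(0), 4)
        fake = gen(noise).detach()  # detach so only D is updated here
        loss_d = disc(fake).mean() - disc(real_batch).mean()
        opt_d.zero_grad()
        loss_d.backward()
        opt_d.step()

        # Generator round: push D(fake) up.
        fake = gen(torch.randn(real_batch.size(0), 4))
        loss_g = -disc(fake).mean()
        opt_g.zero_grad()
        loss_g.backward()
        opt_g.step()
```

In the real _train(), the counts of discriminator and generator rounds per batch come from DGANConfig (discriminator_rounds, generator_rounds), and a separate attribute discriminator is updated alongside the feature discriminator when attributes are present.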
The _discriminate() helper filters out NaN placeholder tensors, flattens features, concatenates all inputs, and passes through the feature discriminator. The _get_gradient_penalty() helper interpolates between real and generated batches using a random alpha, computes the discriminator output on the interpolation, obtains gradients via torch.autograd.grad, and returns the squared deviation of the gradient L2 norm from 1.
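The gradient penalty computation can be sketched as follows. This is a generic WGAN-GP penalty following the description above, not the library's _get_gradient_penalty(); the function name and the assumption that the discriminator maps (batch, dim) inputs to (batch, 1) scores are illustrative.

```python
import torch

def gradient_penalty(discriminator, real_batch, generated_batch):
    """Sketch: squared deviation of the gradient L2 norm from 1 (WGAN-GP)."""
    batch_size = real_batch.size(0)
    # Random interpolation coefficient per sample, broadcast over features.
    alpha = torch.rand(batch_size, 1)
    interpolated = alpha * real_batch + (1 - alpha) * generated_batch
    interpolated.requires_grad_(True)
    scores = discriminator(interpolated)
    # Gradients of the discriminator output w.r.t. the interpolated inputs;
    # create_graph=True keeps the penalty differentiable for backprop.
    grads, = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```

For a linear discriminator the gradient is constant, so the penalty is exactly (||w|| - 1)^2 regardless of the random interpolation point, which makes the computation easy to sanity-check.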
Usage
_train() is called internally by train_numpy() after data transformation. It should not be called directly. Training behavior is controlled through DGANConfig parameters.
Code Reference
Source Location
- Repository: gretel-synthetics
- File: src/gretel_synthetics/timeseries_dgan/dgan.py
- Lines: 779-956 (_train), 983-1005 (_discriminate), 1028-1074 (_get_gradient_penalty)
Signature
def _train(
    self,
    dataset: Dataset,
    progress_callback: Optional[Callable[[ProgressInfo], None]] = None,
):
def _discriminate(
    self,
    batch,
) -> torch.Tensor:
def _get_gradient_penalty(
    self, generated_batch, real_batch, discriminator_func
) -> torch.Tensor:
Import
from gretel_synthetics.timeseries_dgan.dgan import DGAN
I/O Contract
Inputs (_train)
| Name | Type | Required | Description |
|---|---|---|---|
| dataset | torch.utils.data.Dataset | Yes | TensorDataset of 3-element tuples: (attributes_tensor, additional_attributes_tensor, features_tensor). NaN-filled tensors serve as placeholders when attributes/additional_attributes are absent. |
| progress_callback | Optional[Callable[[ProgressInfo], None]] | No | Called after each batch with ProgressInfo containing epoch, total_epochs, batch, and total_batches |
Inputs (_discriminate)
| Name | Type | Required | Description |
|---|---|---|---|
| batch | tuple of torch.Tensor | Yes | Tuple of (attributes, additional_attributes, features) tensors; NaN-filled tensors are automatically filtered out |
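The NaN-placeholder filtering behavior can be sketched as below. This is a simplified, hypothetical version for illustration; the function name and the check used to detect placeholders are assumptions, not the library's exact logic.

```python
import torch

def discriminate(batch, discriminator):
    """Sketch: drop NaN placeholder tensors, flatten, concatenate, score."""
    parts = []
    for t in batch:
        if torch.isnan(t).all():
            continue  # placeholder for an absent input: skip it
        parts.append(t.flatten(start_dim=1))  # e.g. (batch, seq, dim) -> (batch, seq*dim)
    return discriminator(torch.cat(parts, dim=1))
```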
Inputs (_get_gradient_penalty)
| Name | Type | Required | Description |
|---|---|---|---|
| generated_batch | tuple of torch.Tensor | Yes | Generator output tensors |
| real_batch | tuple of torch.Tensor | Yes | Real training data tensors |
| discriminator_func | Callable | Yes | Either self._discriminate or self._discriminate_attributes |
Outputs
| Name | Type | Description |
|---|---|---|
| _train returns | None | Model weights are updated in-place through the training loop |
| _discriminate returns | torch.Tensor | Discriminator scores with shape (batch_size, 1) |
| _get_gradient_penalty returns | torch.Tensor | Scalar tensor containing the gradient penalty value |
Usage Examples
Basic Example
import numpy as np
from gretel_synthetics.timeseries_dgan.dgan import DGAN
from gretel_synthetics.timeseries_dgan.config import DGANConfig
config = DGANConfig(
max_sequence_len=20,
sample_len=5,
batch_size=500,
epochs=100,
discriminator_rounds=1,
generator_rounds=1,
gradient_penalty_coef=10.0,
generator_learning_rate=0.001,
discriminator_learning_rate=0.001,
)
model = DGAN(config)
features = np.random.rand(5000, 20, 3)
attributes = np.random.rand(5000, 2)
# Training loop runs automatically inside train_numpy
model.train_numpy(features=features, attributes=attributes)
Progress Callback Example
from gretel_synthetics.timeseries_dgan.structures import ProgressInfo
def on_progress(info: ProgressInfo):
    if info.batch == 0:
        print(f"Epoch {info.epoch + 1}/{info.total_epochs}")
model.train_numpy(
features=features,
attributes=attributes,
progress_callback=on_progress,
)