Heuristic: Gretel Synthetics DGAN Mixed Precision Training Tradeoff
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Time_Series, Deep_Learning |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
DGAN offers optional mixed precision training (FP16/FP32) that reduces GPU memory usage and training time by using 16-bit floating point where numerically safe. It is disabled by default due to potential stability concerns.
Description
DGAN supports PyTorch Automatic Mixed Precision (AMP) via `torch.cuda.amp.GradScaler` and `torch.cuda.amp.autocast`. When enabled, the training loop automatically identifies computation steps that can safely use FP16 (16-bit floating point) and keeps FP32 (32-bit) only where needed for numerical stability. The GradScaler prevents gradient underflow by dynamically scaling loss values. This is disabled by default (`mixed_precision_training=False`) because GAN training can be sensitive to numerical precision, and the Wasserstein loss with gradient penalty requires careful handling of floating point arithmetic.
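The underflow problem that GradScaler addresses can be seen directly with NumPy's `float16` type (an illustrative sketch, not Gretel code): gradients below FP16's smallest subnormal value (about 6e-8) round to zero, but scaling the loss up first keeps them representable, and unscaling in FP32 before the optimizer step recovers the true value.

```python
import numpy as np

# A gradient magnitude below float16's smallest subnormal (~5.96e-8)
# silently underflows to zero when cast to 16-bit.
tiny_grad = 1e-8
assert np.float16(tiny_grad) == 0.0  # underflow: the update is lost

# Loss scaling (what GradScaler automates): multiply the loss, and hence
# every gradient, by a large factor before the backward pass ...
scale = 2.0 ** 16
scaled = np.float16(tiny_grad * scale)  # ~6.55e-4, representable in FP16
assert scaled != 0.0

# ... then unscale in FP32 before the optimizer step.
recovered = np.float32(scaled) / scale
assert abs(recovered - tiny_grad) / tiny_grad < 0.01
```

GradScaler also adjusts the scale factor dynamically, backing off when scaled gradients overflow to infinity.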
Usage
Enable this heuristic when training DGAN on large time series datasets and running into GPU memory limitations. Set `DGANConfig(mixed_precision_training=True)`. Monitor training loss for stability. If loss becomes NaN or diverges, disable mixed precision.
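A minimal configuration sketch: `mixed_precision_training` is the flag documented above, while the other parameter values shown (`max_sequence_len`, `sample_len`, `batch_size`, `epochs`) are illustrative placeholders, not recommendations.

```python
from gretel_synthetics.timeseries_dgan.config import DGANConfig

# Sketch: enable AMP only when GPU memory is the bottleneck.
config = DGANConfig(
    max_sequence_len=56,         # illustrative value
    sample_len=8,                # illustrative value
    batch_size=1000,             # illustrative value
    epochs=100,                  # illustrative value
    mixed_precision_training=True,  # default is False
)
```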
The Insight (Rule of Thumb)
- Action: Set `mixed_precision_training=True` in `DGANConfig` for memory-constrained scenarios.
- Value: Can reduce VRAM usage by 30-50% and increase training throughput by 20-40% on modern NVIDIA GPUs (Volta+).
- Trade-off: Disabled by default due to potential numerical stability issues with WGAN-GP gradient penalty computation. GAN training is more sensitive to precision than standard supervised learning.
Reasoning
The DGAN training loop involves three separate networks (generator, feature discriminator, attribute discriminator) with Wasserstein loss and gradient penalty. The gradient penalty computation requires second-order gradients through interpolated samples, which can amplify floating point errors. Mixed precision is safe for the forward passes and most gradient computations, but the gradient penalty step may benefit from FP32 precision. PyTorch AMP handles this automatically via `autocast`, but edge cases in GAN training may still cause instability.
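To make the precision-sensitive step concrete, the sketch below shows a generic WGAN-GP gradient penalty (not DGAN's actual implementation) that forces FP32 locally by disabling autocast around the second-order gradient computation; `torch.autocast(device_type=...)` is the current spelling of the same API the training loop uses.

```python
import torch

def gradient_penalty(disc, real, fake):
    """Generic WGAN-GP penalty, kept in FP32 even under mixed precision."""
    alpha = torch.rand(real.size(0), 1)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    # Locally disable autocast so the second-order gradient runs in FP32,
    # even when the enclosing training loop runs under autocast.
    with torch.autocast(device_type=interp.device.type, enabled=False):
        d_out = disc(interp.float())
        grads = torch.autograd.grad(
            outputs=d_out.sum(), inputs=interp, create_graph=True
        )[0]
    return ((grads.norm(2, dim=1) - 1.0) ** 2).mean()
```

Because `create_graph=True` builds a graph through the gradient itself, small FP16 rounding errors in this step can compound, which is why pinning it to FP32 is the conservative choice.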
Code Evidence
DGAN config for mixed precision from `timeseries_dgan/config.py:88-91,133`:
```python
mixed_precision_training: enabling automatic mixed precision while training
    in order to reduce memory costs, bandwith, and time by identifying the
    steps that require full precision and using 32-bit floating point for
    only those steps while using 16-bit floating point everywhere else.
# ...
mixed_precision_training: bool = False
```
Training loop with AMP from `timeseries_dgan/dgan.py:840-850`:
```python
scaler = torch.cuda.amp.GradScaler(enabled=self.config.mixed_precision_training)
for epoch in range(self.config.epochs):
    for batch_idx, real_batch in enumerate(loader):
        with torch.cuda.amp.autocast(
            enabled=self.config.mixed_precision_training
        ):
            attribute_noise = self.attribute_noise_func(real_batch[0].shape[0])
            feature_noise = self.feature_noise_func(real_batch[0].shape[0])
```