Principle: Gretel.ai gretel-synthetics GAN Model Persistence
| Knowledge Sources | |
|---|---|
| Domains | Synthetic_Data, GAN, Tabular_Data |
| Last Updated | 2026-02-14 19:00 GMT |
Overview
GAN model persistence is the process of serializing a trained generative model to disk and deserializing it back for later use, handling the complexities of PyTorch state, device mapping, and non-serializable components.
Description
After training a GAN model (which can be computationally expensive, taking hours or days), it is essential to save the model to disk so that it can be loaded later for synthetic data generation without retraining. Model persistence for tabular GANs involves several challenges:
- Non-serializable components: Certain objects attached to the model, such as callback functions, cannot be serialized with Python's pickle protocol. These must be temporarily removed before saving and restored afterwards.
- Device portability: A model trained on a GPU (CUDA device) may need to be loaded on a machine with only a CPU, or on a different GPU. The loading process must handle device remapping transparently.
- PyTorch tensor storage: PyTorch tensors contain device-specific storage objects. When unpickling a model trained on CUDA on a CPU-only machine, the default unpickler fails. A patched unpickler must intercept PyTorch's storage loading and apply the correct map_location.
- Complete state preservation: The saved model must include not just the neural network weights, but also the DataTransformer (with fitted encoders and normalizers), the conditional vector sampler, all hyperparameters, and the SDV metadata, so that sampling can proceed identically after loading.
The persistence approach used in gretel-synthetics relies on Python's pickle protocol (inherited from SDV's BaseTabularModel.save()) for saving, and provides an enhanced load_v2() method that uses a custom unpickler to handle cross-device loading.
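The "complete state preservation" requirement can be illustrated with a minimal sketch. The classes below are stand-ins, not the real gretel-synthetics `ACTGAN` or `DataTransformer`; the point is that pickling the top-level object carries the fitted preprocessing state and hyperparameters along with the weights, so behavior is identical after a round trip.

```python
import pickle

class DataTransformer:
    """Stand-in for the fitted transformer bundled with the saved model
    (the real DataTransformer holds fitted encoders and normalizers)."""
    def fit(self, values):
        self.lo, self.hi = min(values), max(values)

    def transform(self, value):
        return (value - self.lo) / (self.hi - self.lo)

class TabularModel:
    """Stand-in for the ACTGAN wrapper: bundles weights, transformer,
    and hyperparameters into one picklable object."""
    def __init__(self, epochs):
        self.epochs = epochs                  # hyperparameter
        self.transformer = DataTransformer()  # fitted preprocessing state
        self.weights = [0.1, -0.3]            # stand-in for network weights

    def fit(self, values):
        self.transformer.fit(values)

model = TabularModel(epochs=300)
model.fit([0.0, 5.0, 10.0])

# Round-trip through pickle: all state travels together.
restored = pickle.loads(pickle.dumps(model))
assert restored.epochs == 300
assert restored.transformer.transform(5.0) == model.transformer.transform(5.0)
```

Because the transformer's fitted bounds ride along inside the same pickle, the restored object transforms (and, in the real model, samples) exactly as the original did.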
Usage
Use model persistence when:
- Training is expensive and you want to reuse a trained model across sessions
- Deploying a trained model to a different environment (e.g., from GPU training server to CPU inference server)
- Sharing a trained model between team members
- Creating checkpoints during experimentation
Theoretical Basis
Serialization Strategy
The save/load cycle follows this pattern:
Save:
1. Temporarily set epoch_callback to None (non-picklable)
2. Pickle the entire ACTGAN object (including SDV metadata, DataTransformer,
ACTGANSynthesizer with Generator weights, all hyperparameters)
3. Restore epoch_callback on the live object for continued use
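The save steps above can be sketched as follows. `Model` here is a hypothetical stand-in for the ACTGAN wrapper; the pattern of interest is detaching the non-picklable callback, pickling, then restoring it in a `finally` block so the live object keeps working even if pickling fails.

```python
import os
import pickle
import tempfile

class Model:
    """Hypothetical stand-in for the ACTGAN wrapper class."""
    def __init__(self, epoch_callback=None):
        self.weights = {"w": [1.0, 2.0]}
        self.epoch_callback = epoch_callback  # closures do not pickle reliably

    def save(self, path):
        callback = self.epoch_callback
        self.epoch_callback = None            # step 1: detach non-picklable part
        try:
            with open(path, "wb") as f:
                pickle.dump(self, f)          # step 2: pickle everything else
        finally:
            self.epoch_callback = callback    # step 3: restore on live object

model = Model(epoch_callback=lambda update: None)
path = os.path.join(tempfile.mkdtemp(), "model.pkl")
model.save(path)

assert model.epoch_callback is not None       # live object still has its callback
with open(path, "rb") as f:
    loaded = pickle.load(f)
assert loaded.epoch_callback is None          # callback was excluded from the file
```

The `try`/`finally` is the important design choice: restoration of the callback must not depend on the pickle call succeeding.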
Load (v2):
1. Determine current device (CUDA if available, else CPU)
2. Open the pickle file
3. Use a patched unpickler that intercepts torch.storage._load_from_bytes
and applies map_location to redirect tensor storage to the current device
4. After unpickling, call set_device() on the synthesizer to move all
network parameters to the determined device
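Steps 1-3 of the load path can be sketched with a custom `pickle.Unpickler` that overrides `find_class`. This assumes PyTorch is installed; `load_mapped` is a hypothetical helper name (not the real `load_v2` signature), and the real step 4 (`set_device()` on the synthesizer) is shown only as a comment because its exact attribute path depends on the library's internals.

```python
import io
import os
import pickle
import tempfile

import torch

def load_mapped(path: str):
    """Hypothetical helper mirroring load_v2's steps 1-3: pick the current
    device, then unpickle with tensor storages redirected to it."""
    device = "cuda" if torch.cuda.is_available() else "cpu"   # step 1

    class MappedUnpickler(pickle.Unpickler):
        def find_class(self, module, name):
            # Step 3: intercept PyTorch's storage loader and force map_location.
            if module == "torch.storage" and name == "_load_from_bytes":
                return lambda b: torch.load(io.BytesIO(b), map_location=device)
            return super().find_class(module, name)

    with open(path, "rb") as f:                               # step 2
        obj = MappedUnpickler(f).load()
    # Step 4 in the real load_v2: call set_device() on the synthesizer so
    # all network parameters are moved to `device` as well.
    return obj

# Plain pickle of a tensor routes through torch.storage._load_from_bytes
# on load, so the interception above takes effect.
path = os.path.join(tempfile.mkdtemp(), "weights.pkl")
with open(path, "wb") as f:
    pickle.dump({"w": torch.tensor([1.0, 2.0])}, f)
restored = load_mapped(path)
assert torch.equal(restored["w"].cpu(), torch.tensor([1.0, 2.0]))
```

Overriding `find_class` is the standard hook point: it is called whenever the unpickler needs to resolve a global by module and name, which is exactly how pickled tensors reference their storage loader.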
Cross-Device Loading
The core challenge is that PyTorch's default pickle handling stores tensors with their original device information. The patched unpickler solves this:
Standard pickle: torch.storage._load_from_bytes(b) -> loads to original device
Patched pickle: torch.load(BytesIO(b), map_location=target_device) -> loads to target device
This allows a model trained on cuda:0 to be loaded on cpu without errors.
What Is Not Saved
The following components are deliberately excluded from the serialized model:
- epoch_callback: Callback functions are typically closures or lambda functions that reference external state and cannot be reliably pickled. They are set to None before saving and must be re-attached after loading if needed.