Overview
A concrete tool, provided by the gretel-synthetics library, for saving a trained ACTGAN model to disk and loading it back with automatic device detection.
Description
The ACTGAN.save(path) method serializes a trained ACTGAN model to a file using Python's pickle protocol. Callbacks are typically closures or lambdas, which pickle cannot serialize. For that reason, save() first temporarily removes the epoch_callback from both the internal ACTGANSynthesizer and the _model_kwargs dictionary. It then delegates to the SDV BaseTabularModel.save() method and finally restores the callback on the live object so it can continue to be used.
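The detach, pickle, restore pattern that save() relies on can be sketched with the standard library alone. In this sketch a types.SimpleNamespace is a hypothetical stand-in for ACTGANSynthesizer, and is_picklable and dumps_without_callback are hypothetical helpers, not gretel-synthetics APIs. Unlike the library code shown under Implementation Details, which restores the callback after super().save() returns, this sketch restores it in a finally block so the callback is reattached even if pickling raises. It also handles only one attribute, whereas save() clears both _epoch_callback and the _model_kwargs entry.

```python
import pickle
from types import SimpleNamespace
from typing import Any


def is_picklable(obj: Any) -> bool:
    """Return True if pickle can serialize obj (hypothetical helper)."""
    try:
        pickle.dumps(obj)
    except (pickle.PicklingError, AttributeError, TypeError):
        return False
    return True


def dumps_without_callback(model: SimpleNamespace) -> bytes:
    """Pickle model with its epoch_callback temporarily set to None (hypothetical helper)."""
    saved_callback = model.epoch_callback
    model.epoch_callback = None
    try:
        return pickle.dumps(model)
    finally:
        # Reattach the callback on the live object, even if pickling raised.
        model.epoch_callback = saved_callback


# Hypothetical stand-in for the synthesizer: plain state plus a lambda callback.
epochs_seen = []
model = SimpleNamespace(epochs=3, epoch_callback=lambda epoch: epochs_seen.append(epoch))

print(is_picklable(model))            # False: the lambda cannot be pickled
blob = dumps_without_callback(model)  # succeeds because the callback is detached
restored = pickle.loads(blob)
print(restored.epochs, restored.epoch_callback)  # 3 None
model.epoch_callback(1)               # the live object can still call its callback
print(epochs_seen)                    # [1]
```

As with ACTGAN.load_v2(), the restored object comes back with its callback set to None.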
The ACTGAN.load_v2(path) class method provides loading that also handles cross-device portability, for example restoring a model trained on a GPU machine onto a CPU-only machine. It calls torch_utils.determine_device(), which returns "cuda" when CUDA is available and "cpu" otherwise. It then opens the pickle file in binary mode and deserializes it with torch_utils.patched_torch_unpickle(). That function uses _PyTorchPatchedUnpickler, a subclass of pickle.Unpickler whose find_class() intercepts lookups of torch.storage._load_from_bytes and substitutes a loader that calls torch.load(BytesIO(b), map_location=device), so every tensor storage is loaded directly onto the selected device. After unpickling, set_device(device) is called on the synthesizer to move all generator parameters to the target device.
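The find_class() interception can be illustrated without PyTorch. In this stdlib-only sketch, the hypothetical CompressedBlob class pickles as a call to zlib.decompress, playing the role that torch.storage._load_from_bytes plays for tensor storages. The hypothetical TaggingUnpickler swaps that one global for a loader that also records the requested map_location, analogous to how the real unpickler swaps in torch.load(..., map_location=device). The names and the tagging behavior are illustrative assumptions, not part of gretel-synthetics.

```python
import io
import pickle
import zlib
from typing import Any, Callable


class CompressedBlob:
    """Hypothetical stand-in for a tensor storage: pickles as zlib.decompress(data)."""

    def __init__(self, raw: bytes) -> None:
        self._compressed = zlib.compress(raw)

    def __reduce__(self):
        # The pickle stream records the global "zlib.decompress" plus its argument,
        # much as torch records "torch.storage._load_from_bytes" plus raw bytes.
        return (zlib.decompress, (self._compressed,))


def _load_with_tag(map_location: str) -> Callable[[bytes], Any]:
    # Replacement loader: decompress as usual, then attach the requested location.
    return lambda b: (map_location, zlib.decompress(b))


class TaggingUnpickler(pickle.Unpickler):
    def __init__(self, *args: Any, map_location: str, **kwargs: Any) -> None:
        self._map_location = map_location
        super().__init__(*args, **kwargs)

    def find_class(self, module: str, name: str) -> Any:
        # Intercept exactly one global; resolve every other global normally.
        if module == "zlib" and name == "decompress":
            return _load_with_tag(self._map_location)
        return super().find_class(module, name)


payload = pickle.dumps({"weights": CompressedBlob(b"\x00\x01\x02")})
plain = pickle.loads(payload)
patched = TaggingUnpickler(io.BytesIO(payload), map_location="cpu").load()
print(plain)    # {'weights': b'\x00\x01\x02'}
print(patched)  # {'weights': ('cpu', b'\x00\x01\x02')}
```

Hooking the single loader global means every storage in the object graph is routed through it during unpickling, with no need to walk the restored object afterwards.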
Usage
Use save() after training to persist the model, and load_v2() to restore it on any machine (GPU or CPU).
Code Reference
Source Location
- Repository: gretel-synthetics
- Files: src/gretel_synthetics/actgan/actgan_wrapper.py (lines 169-197); src/gretel_synthetics/utils/torch_utils.py (lines 21-25)
Signature
# _ACTGANModel.save
def save(self, path: str) -> None: ...

# _ACTGANModel.load_v2
@classmethod
def load_v2(cls, path: str) -> ACTGAN: ...

# torch_utils.determine_device
def determine_device() -> str: ...

# torch_utils.patched_torch_unpickle
def patched_torch_unpickle(file_handle: BinaryIO, device: str) -> object: ...
Import
from gretel_synthetics.actgan.actgan_wrapper import ACTGAN
I/O Contract
Inputs (save)
| Name | Type | Required | Description |
|---|---|---|---|
| path | str | Yes | File path where the pickled model will be written. |
Outputs (save)
| Name | Type | Description |
|---|---|---|
| (none) | None | Writes the serialized model to the specified file path. The epoch_callback is excluded from the serialized state. |
Inputs (load_v2)
| Name | Type | Required | Description |
|---|---|---|---|
| path | str | Yes | File path to a previously saved ACTGAN model. |
Outputs (load_v2)
| Name | Type | Description |
|---|---|---|
| (return value) | ACTGAN | A fully restored ACTGAN model with all network weights, DataTransformer state, and hyperparameters. The model is placed on the best available device (CUDA if available, otherwise CPU). The epoch_callback will be None. |
Implementation Details
Save Method
def save(self, path: str) -> None:
    self._model: ACTGANSynthesizer
    # Temporarily remove any epoch callback so pickling can be done
    _tmp_callback = self._model._epoch_callback
    self._model._epoch_callback = None
    self._model_kwargs[EPOCH_CALLBACK] = None
    super().save(path)
    # Restore our callback for continued use of the model
    self._model._epoch_callback = _tmp_callback
    self._model_kwargs[EPOCH_CALLBACK] = _tmp_callback
Load V2 Method
@classmethod
def load_v2(cls, path: str) -> ACTGAN:
    # Pick "cuda" if available, otherwise "cpu"
    device = torch_utils.determine_device()
    with open(path, "rb") as fin:
        loaded_model: ACTGAN = torch_utils.patched_torch_unpickle(fin, device)
    # Move the synthesizer's parameters to the selected device
    loaded_model._model.set_device(device)
    return loaded_model
Patched Unpickler
import pickle
from io import BytesIO
from typing import Callable

import torch


class _PyTorchPatchedUnpickler(pickle.Unpickler):
    def __init__(self, *args, map_location: str, **kwargs):
        self._map_location = map_location
        super().__init__(*args, **kwargs)

    def find_class(self, module, name):
        # Redirect torch's storage loader so every storage is loaded onto map_location
        if module == "torch.storage" and name == "_load_from_bytes":
            return _load_with_map_location(self._map_location)
        else:
            return super().find_class(module, name)


def _load_with_map_location(map_location: str) -> Callable[[bytes], object]:
    return lambda b: torch.load(BytesIO(b), map_location=map_location)
Device Detection
import torch


def determine_device() -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return device
Usage Examples
Basic Example
import pandas as pd
from gretel_synthetics.actgan.actgan_wrapper import ACTGAN
# Train and save
data = pd.read_csv("training_data.csv")
model = ACTGAN(epochs=100)
model.fit(data)
model.save("actgan_model.pkl")
# Load on any device (CPU or GPU)
loaded_model = ACTGAN.load_v2("actgan_model.pkl")
synthetic_data = loaded_model.sample(1000)
Example with Callback Restoration
from gretel_synthetics.actgan.actgan_wrapper import ACTGAN
from gretel_synthetics.actgan.structures import EpochInfo


def log_epoch(info: EpochInfo) -> None:
    # Example user callback: print the per-epoch training info
    print(info)


# Load a previously saved model; epoch_callback is None after loading
model = ACTGAN.load_v2("actgan_model.pkl")
# Re-attach a callback if needed. This sets the same private attribute
# that save() clears before pickling and restores afterwards.
model._model._epoch_callback = log_epoch
# Generate synthetic data
synthetic = model.sample(500)