

Implementation:Gretelai Gretel synthetics ACTGAN Save Load

From Leeroopedia
Knowledge Sources
Domains Synthetic_Data, GAN, Tabular_Data
Last Updated 2026-02-14 19:00 GMT

Overview

Concrete tool for saving a trained ACTGAN model to disk and loading it back with automatic device detection, provided by the gretel-synthetics library.

Description

The ACTGAN.save(path) method serializes a trained ACTGAN model to a file using Python's pickle protocol. Epoch callbacks are usually closures or lambdas that pickle cannot serialize. For that reason, save() first sets the epoch_callback to None in two places: on the internal ACTGANSynthesizer (self._model._epoch_callback) and in the _model_kwargs dictionary. It then delegates to the SDV BaseTabularModel.save() method, which writes the file. Finally, it restores the original callback on the live object so the model can still use it.
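You can reproduce the failure that motivates this step with the standard library alone. The minimal sketch below makes two assumptions. A plain dict stands in for the model state, and make_callback() is a hypothetical closure. The sketch checks that pickle rejects the state while the closure is attached and accepts it once the entry is set to None.

```python
import pickle


def make_callback():
    # Hypothetical epoch callback: a closure over local state, like many
    # progress-reporting callbacks
    history = []

    def on_epoch(info):
        history.append(info)

    return on_epoch


def can_pickle(obj) -> bool:
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        # The exception type raised for local functions varies by Python version
        return False


# Stand-in for the model state (hypothetical keys)
state = {"epochs": 100, "epoch_callback": make_callback()}
with_callback = can_pickle(state)     # False: the local closure cannot be pickled

state["epoch_callback"] = None
without_callback = can_pickle(state)  # True once the callback is removed
```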

The ACTGAN.load_v2(path) class method makes loading portable across devices. A model pickled on a GPU machine contains CUDA tensor storages. Plain unpickling restores them through torch.storage._load_from_bytes, which does not pass a map_location, so it fails on a machine without CUDA. load_v2 works around this in three steps:

  • torch_utils.determine_device() returns "cuda" if CUDA is available, otherwise "cpu".
  • The file is opened and deserialized with torch_utils.patched_torch_unpickle(). Its custom unpickler subclasses pickle.Unpickler and overrides find_class. Whenever the pickle references torch.storage._load_from_bytes, it substitutes a loader that calls torch.load(BytesIO(b), map_location=device), so every tensor is loaded onto the chosen device.
  • set_device(device) is called on the unpickled synthesizer to move all generator parameters to the target device.
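You can try the find_class redirection without PyTorch. The sketch below uses two hypothetical classes. Payload pickles itself as a call to zlib.decompress, playing the role that torch.storage._load_from_bytes plays for tensor storages. RedirectingUnpickler then swaps in a loader that tags each result with a target location, much as the real unpickler injects map_location.

```python
import io
import pickle
import zlib


class Payload:
    """Hypothetical object that, like a torch storage, pickles itself as a
    call to a module-level loader applied to raw bytes."""

    def __init__(self, raw: bytes):
        self.raw = raw

    def __reduce__(self):
        # Recorded in the pickle as the global zlib.decompress plus its argument
        return (zlib.decompress, (zlib.compress(self.raw),))


def _load_with_tag(location: str):
    """Return a replacement loader that also records the target location."""

    def loader(b: bytes):
        return (location, zlib.decompress(b))

    return loader


class RedirectingUnpickler(pickle.Unpickler):
    """Swap one pickled global for a replacement, as the patched torch
    unpickler does for torch.storage._load_from_bytes."""

    def __init__(self, *args, location: str, **kwargs):
        self._location = location
        super().__init__(*args, **kwargs)

    def find_class(self, module, name):
        if module == "zlib" and name == "decompress":
            return _load_with_tag(self._location)
        return super().find_class(module, name)


blob = pickle.dumps([Payload(b"weights")])
plain = pickle.loads(blob)  # [b"weights"]
patched = RedirectingUnpickler(io.BytesIO(blob), location="cpu").load()
```

Plain pickle.loads calls the recorded global unchanged. The subclass, by contrast, gets a chance to replace that global before any payload is rebuilt. The pickle file itself does not need to change.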

Usage

Use save() after training to persist the model, and load_v2() to restore it on any machine (GPU or CPU).

Code Reference

Source Location

  • Repository: gretel-synthetics
  • File: src/gretel_synthetics/actgan/actgan_wrapper.py (lines 169-197), src/gretel_synthetics/utils/torch_utils.py (lines 21-25)

Signature

# _ACTGANModel.save
def save(self, path: str) -> None:

# _ACTGANModel.load_v2
@classmethod
def load_v2(cls, path: str) -> ACTGAN:

# torch_utils.determine_device
def determine_device() -> str:

# torch_utils.patched_torch_unpickle
def patched_torch_unpickle(file_handle: BinaryIO, device: str) -> object:

Import

from gretel_synthetics.actgan.actgan_wrapper import ACTGAN

I/O Contract

Inputs (save)

Name | Type | Required | Description
path | str | Yes | File path where the pickled model will be written.

Outputs (save)

Name | Type | Description
(none) | None | Writes the serialized model to the specified file path. The epoch_callback is excluded from the serialized state.

Inputs (load_v2)

Name | Type | Required | Description
path | str | Yes | File path to a previously saved ACTGAN model.

Outputs (load_v2)

Name | Type | Description
(return value) | ACTGAN | A fully restored ACTGAN model with all network weights, DataTransformer state, and hyperparameters. The model is placed on the best available device (CUDA if available, otherwise CPU). The epoch_callback will be None.

Implementation Details

Save Method

def save(self, path: str) -> None:
    self._model: ACTGANSynthesizer

    # Temporarily remove any epoch callback so pickling can be done
    _tmp_callback = self._model._epoch_callback
    self._model._epoch_callback = None
    self._model_kwargs[EPOCH_CALLBACK] = None

    super().save(path)

    # Restore our callback for continued use of the model
    self._model._epoch_callback = _tmp_callback
    self._model_kwargs[EPOCH_CALLBACK] = _tmp_callback
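The save() method above restores the callback only when super().save() returns normally. If pickling or the file write raises, the live object is left with epoch_callback set to None. A try/finally block closes that gap. The sketch below is a hypothetical variant, not the library's code: callback_removed is an illustrative helper, and a SimpleNamespace stands in for the ACTGANSynthesizer.

```python
import contextlib
import io
import pickle
from types import SimpleNamespace


@contextlib.contextmanager
def callback_removed(model, kwargs, key="epoch_callback"):
    """Hypothetical helper: detach a callback for the duration of the block
    and always put it back, even if the block raises."""
    saved = model.epoch_callback
    model.epoch_callback = None
    kwargs[key] = None
    try:
        yield
    finally:
        model.epoch_callback = saved
        kwargs[key] = saved


def make_callback():
    seen = []

    def on_epoch(info):  # local closure: not picklable
        seen.append(info)

    return on_epoch


cb = make_callback()
model = SimpleNamespace(epoch_callback=cb)
model_kwargs = {"epochs": 100, "epoch_callback": cb}

# Successful save: pickling happens while the callback is detached
buf = io.BytesIO()
with callback_removed(model, model_kwargs):
    pickle.dump((model, model_kwargs), buf)

# Failing save: the finally clause still restores the callback
try:
    with callback_removed(model, model_kwargs):
        raise OSError("simulated write failure")
except OSError:
    pass

restored_model, restored_kwargs = pickle.loads(buf.getvalue())
```

After both blocks, the live object still holds the original callback, while the unpickled copy carries None, matching the documented behavior of load_v2.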

Load V2 Method

@classmethod
def load_v2(cls, path: str) -> ACTGAN:
    device = torch_utils.determine_device()
    with open(path, "rb") as fin:
        loaded_model: ACTGAN = torch_utils.patched_torch_unpickle(fin, device)
        loaded_model._model.set_device(device)
    return loaded_model

Patched Unpickler

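import pickle
from io import BytesIO
from typing import Any, Callable

import torch


class _PyTorchPatchedUnpickler(pickle.Unpickler):
    """Unpickler that loads pickled torch storages onto map_location."""

    def __init__(self, *args, map_location: str, **kwargs):
        self._map_location = map_location
        super().__init__(*args, **kwargs)

    def find_class(self, module, name):
        # Pickled tensors reference torch.storage._load_from_bytes; swap in a
        # loader that honors map_location instead of the original device
        if module == "torch.storage" and name == "_load_from_bytes":
            return _load_with_map_location(self._map_location)
        else:
            return super().find_class(module, name)


def _load_with_map_location(map_location: str) -> Callable[[bytes], Any]:
    return lambda b: torch.load(BytesIO(b), map_location=map_location)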

Device Detection

import torch


def determine_device() -> str:
    # Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return device

Usage Examples

Basic Example

import pandas as pd
from gretel_synthetics.actgan.actgan_wrapper import ACTGAN

# Train and save
data = pd.read_csv("training_data.csv")
model = ACTGAN(epochs=100)
model.fit(data)
model.save("actgan_model.pkl")

# Load on any device (CPU or GPU)
loaded_model = ACTGAN.load_v2("actgan_model.pkl")
synthetic_data = loaded_model.sample(1000)

Example with Callback Restoration

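from gretel_synthetics.actgan.actgan_wrapper import ACTGAN
from gretel_synthetics.actgan.structures import EpochInfo

# Load a previously saved model
model = ACTGAN.load_v2("actgan_model.pkl")

# Note: epoch_callback is None after loading, because save() strips it


def my_callback_function(info: EpochInfo) -> None:
    # Hypothetical callback: report progress after each epoch
    print(info)


# Re-attach the callback if needed before training again, mirroring how
# save() restores it. Both attributes are private; EPOCH_CALLBACK is the
# key constant used in actgan_wrapper.py.
# model._model._epoch_callback = my_callback_function
# model._model_kwargs[EPOCH_CALLBACK] = my_callback_function

# Generate synthetic data
synthetic = model.sample(500)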

Related Pages

Implements Principle

Requires Environment
