Implementation:Huggingface Diffusers Save Pretrained Quantized

From Leeroopedia

Metadata

Property    Value
API         DiffusionPipeline.save_pretrained(save_directory, safe_serialization=True) / model.save_pretrained(save_directory)
Module      src/diffusers/pipelines/pipeline_utils.py (pipeline), src/diffusers/models/modeling_utils.py (model)
Lines       Pipeline: L240-L371, Model: L667-L820
Import      pipeline.save_pretrained() or model.save_pretrained()
Type        API Doc
Principle   Huggingface_Diffusers_Quantized_Model_Saving
Implements  Principle:Huggingface_Diffusers_Quantized_Model_Saving

Purpose

The save_pretrained methods on both DiffusionPipeline and ModelMixin handle serialization of quantized models. At the model level, the method checks that the quantizer supports serialization, saves the quantized weights (in safetensors format by default), and writes a config.json with the quantization_config embedded in it. At the pipeline level, each component is saved to its own subdirectory, and each quantized component is handled by its own save_pretrained method.

Model-Level: ModelMixin.save_pretrained

I/O Contract

Parameter           Type               Default     Description
save_directory      str | os.PathLike  (required)  Directory to save the model and config
safe_serialization  bool               True        Use safetensors format (recommended)
variant             str | None         None        Weight filename variant (e.g., "fp16")
max_shard_size      int | str          "10GB"      Maximum size per checkpoint shard
push_to_hub         bool               False       Push to Hugging Face Hub after saving
is_main_process     bool               True        Only save on main process (for distributed training)
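
To make the variant row concrete, the following is a minimal sketch of how a variant string could be spliced into a weights filename. The helper add_variant is a hypothetical stand-in written for this page, not the library's internal _add_variant, and it assumes the filename has a single extension such as .safetensors or .bin.

```python
from typing import Optional


def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
    """Hypothetical helper: place `variant` just before the file extension."""
    if variant is None:
        return weights_name
    stem, ext = weights_name.rsplit(".", 1)
    return f"{stem}.{variant}.{ext}"


# Default diffusers model weights filename, used here as the base name.
base = "diffusion_pytorch_model.safetensors"
plain_name = add_variant(base)
fp16_name = add_variant(base, "fp16")
print(plain_name)
print(fp16_name)
```

With variant=None the name is left unchanged, so existing loaders that expect the default filename keep working.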

Quantization-Specific Control Flow

def save_pretrained(self, save_directory, is_main_process=True, safe_serialization=True,
                    variant=None, max_shard_size="10GB", push_to_hub=False, **kwargs):
    # Step 1: Check quantizer serializability
    hf_quantizer = getattr(self, "hf_quantizer", None)
    if hf_quantizer is not None:
        quantization_serializable = (
            hf_quantizer is not None
            and isinstance(hf_quantizer, DiffusersQuantizer)
            and hf_quantizer.is_serializable
        )
        if not quantization_serializable:
            raise ValueError(
                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} "
                f"and is not serializable..."
            )

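    # Step 2: Determine weights filename and the pattern used to name shards
    weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    weights_name = _add_variant(weights_name, variant)
    weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )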

    # Step 3: Create directory
    os.makedirs(save_directory, exist_ok=True)

    # Step 4: Save config (includes quantization_config in config.json)
    model_to_save = self
    if is_main_process:
        model_to_save.save_config(save_directory)

    # Step 5: Get state dict (quantized weights)
    state_dict = model_to_save.state_dict()

    # Step 6: Shard and save weights
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
    )

    for filename, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
        filepath = os.path.join(save_directory, filename)
        if safe_serialization:
            safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
        else:
            torch.save(shard, filepath)

    # Step 7: Save index file for sharded checkpoints
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
        save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
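
The index file written in Step 7 can be illustrated without any model at all. The sketch below uses only the standard library; write_shard_index is a hypothetical helper, and the tensor names, shard filenames, and total size are made-up illustrative values rather than output from a real checkpoint.

```python
import json
import os
import tempfile


def write_shard_index(save_directory, tensor_to_filename, total_size, index_name):
    """Hypothetical helper: write a shard index mapping tensor names to shard files."""
    index = {
        "metadata": {"total_size": total_size},
        "weight_map": tensor_to_filename,
    }
    path = os.path.join(save_directory, index_name)
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
    return path


# Illustrative (made-up) tensor names spread over two shards.
weight_map = {
    "blocks.0.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
    "blocks.1.weight": "diffusion_pytorch_model-00002-of-00002.safetensors",
}

with tempfile.TemporaryDirectory() as tmp:
    index_path = write_shard_index(
        tmp, weight_map, 2048, "diffusion_pytorch_model.safetensors.index.json"
    )
    with open(index_path, encoding="utf-8") as f:
        loaded = json.load(f)

print(sorted(loaded["weight_map"]))
```

A loader can read this file first and then open only the shards that contain the tensors it needs.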

Key observations:

  • The quantization serialization check (Step 1) is a gate that prevents saving non-serializable quantized models
  • save_config() (Step 4) writes config.json which includes the quantization_config dict
  • state_dict() (Step 5) returns the quantized weights as they exist in memory -- the backend-specific layers handle converting their internal representation to serializable tensors
  • The actual weight serialization (Step 6) is backend-agnostic -- it just saves whatever tensors the state dict contains
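
The gate described in the first observation can be sketched in isolation. FakeQuantizer, FakeModel, ensure_serializable, and try_save are hypothetical names invented for this example; they only mimic the shape of the Step 1 check and are not diffusers classes.

```python
class FakeQuantizer:
    """Hypothetical stand-in for a quantizer backend."""

    def __init__(self, quant_method, is_serializable):
        self.quant_method = quant_method
        self.is_serializable = is_serializable


class FakeModel:
    """Hypothetical model that may or may not carry a quantizer."""

    def __init__(self, hf_quantizer=None):
        if hf_quantizer is not None:
            self.hf_quantizer = hf_quantizer


def ensure_serializable(model):
    """Raise before any file is written if the quantizer cannot be serialized."""
    hf_quantizer = getattr(model, "hf_quantizer", None)
    if hf_quantizer is None:
        return  # Unquantized models pass straight through.
    if not (isinstance(hf_quantizer, FakeQuantizer) and hf_quantizer.is_serializable):
        raise ValueError(
            f"The model is quantized with {hf_quantizer.quant_method} and is not serializable."
        )


def try_save(model):
    """Return 'saved' if the gate passes, 'blocked' otherwise."""
    try:
        ensure_serializable(model)
    except ValueError:
        return "blocked"
    return "saved"


results = [
    try_save(FakeModel()),
    try_save(FakeModel(FakeQuantizer("backend_a", is_serializable=True))),
    try_save(FakeModel(FakeQuantizer("backend_b", is_serializable=False))),
]
print(results)
```

Because the check runs before the directory is populated, a blocked save leaves no partial checkpoint behind.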

Pipeline-Level: DiffusionPipeline.save_pretrained

I/O Contract

Parameter           Type               Default     Description
save_directory      str | os.PathLike  (required)  Root directory for the pipeline
safe_serialization  bool               True        Use safetensors format
variant             str | None         None        Weight filename variant
max_shard_size      int | str | None   None        Maximum size per shard
push_to_hub         bool               False       Push to Hugging Face Hub

Pipeline Save Flow

def save_pretrained(self, save_directory, safe_serialization=True,
                    variant=None, max_shard_size=None, push_to_hub=False, **kwargs):
    model_index_dict = dict(self.config)
    # Filter to saveable modules
    expected_modules, optional_kwargs = self._get_signature_keys(self)

    for pipeline_component_name in model_index_dict.keys():
        sub_model = getattr(self, pipeline_component_name)
        model_cls = sub_model.__class__

        # Handle compiled models
        if is_compiled_module(sub_model):
            sub_model = _unwrap_model(sub_model)
            model_cls = sub_model.__class__

        # Find the correct save method from LOADABLE_CLASSES
        save_method_name = None
        for library_name, library_classes in LOADABLE_CLASSES.items():
            # ... resolve save_method_name
            pass

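        save_method = getattr(sub_model, save_method_name)
        # Inspect the signature so only supported kwargs are forwarded
        save_method_signature = inspect.signature(save_method)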

        # Build save kwargs based on method signature
        save_kwargs = {}
        if "safe_serialization" in save_method_signature.parameters:
            save_kwargs["safe_serialization"] = safe_serialization
        if "variant" in save_method_signature.parameters:
            save_kwargs["variant"] = variant
        if "max_shard_size" in save_method_signature.parameters and max_shard_size is not None:
            save_kwargs["max_shard_size"] = max_shard_size

        # Save component to its subdirectory
        save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

    # Save pipeline config (model_index.json)
    self.save_config(save_directory)

Key observations:

  • Each component is saved independently to save_directory/<component_name>/
  • The pipeline delegates to each component's own save_pretrained method
  • Quantized model components (diffusers models) hit the ModelMixin.save_pretrained path with its serialization check
  • Transformers model components use their own save_pretrained with analogous quantization handling
  • The pipeline's model_index.json is quantization-agnostic
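
The signature-based kwarg filtering in the flow above can be tried on its own. In this sketch ModelLike, SchedulerLike, and build_save_kwargs are hypothetical names; the two classes stand in for components whose save_pretrained methods accept different arguments.

```python
import inspect


class ModelLike:
    """Hypothetical component whose save method accepts the full set of options."""

    def save_pretrained(self, save_directory, safe_serialization=True,
                        variant=None, max_shard_size="10GB"):
        return save_directory


class SchedulerLike:
    """Hypothetical component whose save method only takes a directory."""

    def save_pretrained(self, save_directory):
        return save_directory


def build_save_kwargs(save_method, safe_serialization=True, variant=None, max_shard_size=None):
    """Forward only the kwargs that the component's save method declares."""
    params = inspect.signature(save_method).parameters
    save_kwargs = {}
    if "safe_serialization" in params:
        save_kwargs["safe_serialization"] = safe_serialization
    if "variant" in params:
        save_kwargs["variant"] = variant
    if "max_shard_size" in params and max_shard_size is not None:
        save_kwargs["max_shard_size"] = max_shard_size
    return save_kwargs


model_kwargs = build_save_kwargs(ModelLike().save_pretrained, max_shard_size="5GB")
scheduler_kwargs = build_save_kwargs(SchedulerLike().save_pretrained, max_shard_size="5GB")
print(model_kwargs)
print(scheduler_kwargs)
```

Leaving max_shard_size out when it is None lets each component fall back to its own default shard size.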

Saved Directory Structure

After saving a pipeline with a quantized transformer:

save_directory/
  model_index.json                    # Pipeline config (no quantization info)
  transformer/
    config.json                       # Model config WITH quantization_config section
    diffusion_pytorch_model.safetensors  # Quantized weight tensors
  text_encoder/
    config.json                       # Standard config (no quantization if not quantized)
    model.safetensors                 # Full-precision weights
  vae/
    config.json
    diffusion_pytorch_model.safetensors
  scheduler/
    scheduler_config.json
  tokenizer/
    ...
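
One way to check a saved pipeline against this layout is to scan each component's config.json for a quantization_config entry. The sketch below assumes that layout; find_quantized_components is a hypothetical helper, and the demo builds a throwaway directory with fake configs instead of touching a real checkpoint.

```python
import json
import os
import tempfile


def find_quantized_components(save_directory):
    """Hypothetical helper: map component name -> quant_method for quantized components."""
    quantized = {}
    for name in sorted(os.listdir(save_directory)):
        config_path = os.path.join(save_directory, name, "config.json")
        if not os.path.isfile(config_path):
            continue  # Skips model_index.json and components without a config.json.
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        if "quantization_config" in config:
            quantized[name] = config["quantization_config"].get("quant_method")
    return quantized


def write_json(path, payload):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f)


with tempfile.TemporaryDirectory() as root:
    # Fake, minimal configs that only mirror the directory layout shown above.
    write_json(os.path.join(root, "model_index.json"), {"_class_name": "ExamplePipeline"})
    write_json(
        os.path.join(root, "transformer", "config.json"),
        {"quantization_config": {"quant_method": "bitsandbytes", "load_in_4bit": True}},
    )
    write_json(os.path.join(root, "vae", "config.json"), {"latent_channels": 16})
    found = find_quantized_components(root)

print(found)
```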

Usage Examples

Save a Quantized Model

import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

# Load with quantization
config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=config,
    torch_dtype=torch.bfloat16,
)

# Save quantized model
transformer.save_pretrained("./my_quantized_transformer")
# Creates: config.json (with quantization_config) + diffusion_pytorch_model.safetensors (quantized weights)

Save a Quantized Pipeline

import torch
from diffusers import DiffusionPipeline, BitsAndBytesConfig
from diffusers.quantizers import PipelineQuantizationConfig

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=PipelineQuantizationConfig(
        quant_mapping={
            "transformer": BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
        }
    ),
    torch_dtype=torch.bfloat16,
)

# Save entire pipeline -- quantized components preserve their configs
pipe.save_pretrained("./my_quantized_pipeline")

Save and Push to Hub

transformer.save_pretrained(
    "./my_quantized_transformer",
    push_to_hub=True,
    repo_id="my-username/flux-nf4-transformer",
)

Reload the Saved Model

from diffusers import FluxTransformer2DModel

# No quantization_config needed -- detected from config.json
transformer = FluxTransformer2DModel.from_pretrained("./my_quantized_transformer")
print(transformer.is_quantized)  # True

Implementation Notes

  • Serialization gate importance: If a backend's is_serializable returns False, save_pretrained raises immediately. This prevents saving corrupted or unusable checkpoints. Users see a clear error explaining the limitation.
  • state_dict() behavior for quantized models: Each backend's quantized layers hook into state_dict() so that their internal representation comes out as serializable entries. For BitsAndBytes, this includes the packed weights, absmax values, and quant_state. For TorchAO, weights are held as tensor subclasses; because safetensors stores plain tensors, check the backend documentation for your version to see whether safe_serialization=False is required.
  • Safetensors metadata: The metadata={"format": "pt"} tag in the safetensors save indicates PyTorch format, enabling correct deserialization.
  • Shard cleanup: Before saving, the method removes stale shard files from previous saves to prevent orphaned shards from confusing the loader.
  • Compiled model unwrapping: Pipeline save handles torch.compile()-wrapped models by unwrapping them before determining the save method.
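
The shard cleanup note can be sketched with the standard library. This is an illustrative approximation and not the library's code: remove_stale_shards, the SHARD_RE pattern, and the filenames in the demo are assumptions chosen to mirror the -00001-of-00003 style of shard naming.

```python
import os
import re
import tempfile

# Assumed shard naming: <stem>-<5 digits>-of-<5 digits>, before the extension.
SHARD_RE = re.compile(r"(.*?)-\d{5}-of-\d{5}")


def remove_stale_shards(save_directory, weights_stem, keep):
    """Delete shard files from an earlier save that the current save will not overwrite."""
    removed = []
    for filename in sorted(os.listdir(save_directory)):
        full_path = os.path.join(save_directory, filename)
        if filename in keep or not os.path.isfile(full_path):
            continue
        stem = filename.replace(".safetensors", "").replace(".bin", "")
        if filename.startswith(weights_stem) and SHARD_RE.fullmatch(stem):
            os.remove(full_path)
            removed.append(filename)
    return removed


stem = "diffusion_pytorch_model"
old_shards = [f"{stem}-0000{i}-of-00003.safetensors" for i in (1, 2, 3)]
new_shards = [f"{stem}-0000{i}-of-00002.safetensors" for i in (1, 2)]

with tempfile.TemporaryDirectory() as tmp:
    # Simulate a directory holding an old 3-shard save plus the files of a new 2-shard save.
    for name in old_shards + new_shards + ["config.json"]:
        with open(os.path.join(tmp, name), "w", encoding="utf-8") as f:
            f.write("")
    removed = remove_stale_shards(tmp, stem, keep=set(new_shards))
    remaining = sorted(os.listdir(tmp))

print(removed)
print(remaining)
```

Matching on both the filename prefix and the shard pattern keeps unrelated files, such as config.json, out of the deletion set.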

Related Pages

Requires Environment

Source References

  • src/diffusers/models/modeling_utils.py:L667-L677 - save_pretrained signature and docstring
  • src/diffusers/models/modeling_utils.py:L715-L726 - Quantization serialization check
  • src/diffusers/models/modeling_utils.py:L749-L806 - Config save, state dict extraction, and weight sharding
  • src/diffusers/pipelines/pipeline_utils.py:L240-L248 - Pipeline save_pretrained signature
  • src/diffusers/pipelines/pipeline_utils.py:L276-L357 - Per-component save iteration
  • src/diffusers/quantizers/base.py:L234-L236 - is_serializable abstract property
