Implementation:Huggingface Diffusers Save Pretrained Quantized
Metadata
| Property | Value |
|---|---|
| API | DiffusionPipeline.save_pretrained(save_directory, safe_serialization=True) / model.save_pretrained(save_directory) |
| Module | src/diffusers/pipelines/pipeline_utils.py (pipeline), src/diffusers/models/modeling_utils.py (model) |
| Lines | Pipeline: L240-L371, Model: L667-L820 |
| Import | pipeline.save_pretrained() or model.save_pretrained() |
| Type | API Doc |
| Principle | Huggingface_Diffusers_Quantized_Model_Saving |
| Implements | Principle:Huggingface_Diffusers_Quantized_Model_Saving |
Purpose
The save_pretrained methods on both DiffusionPipeline and ModelMixin handle serialization of quantized models. At the model level, the method validates that the quantizer supports serialization, saves quantized weights in safetensors format, and writes the config.json with embedded quantization_config. At the pipeline level, each component is saved to its own subdirectory, with quantized components handled by their individual save_pretrained methods.
Model-Level: ModelMixin.save_pretrained
I/O Contract
| Parameter | Type | Default | Description |
|---|---|---|---|
| save_directory | os.PathLike | (required) | Directory to save the model and config |
| safe_serialization | bool | True | Use safetensors format (recommended) |
| variant | str or None | None | Weight filename variant (e.g., "fp16") |
| max_shard_size | str | "10GB" | Maximum size per checkpoint shard |
| push_to_hub | bool | False | Push to Hugging Face Hub after saving |
| is_main_process | bool | True | Only save on main process (for distributed training) |
Quantization-Specific Control Flow
# Condensed excerpt: module-level imports (os, json, torch, safetensors.torch) and constants are omitted.
def save_pretrained(self, save_directory, is_main_process=True, safe_serialization=True,
                    variant=None, max_shard_size="10GB", push_to_hub=False, **kwargs):
    # Step 1: Check quantizer serializability
    hf_quantizer = getattr(self, "hf_quantizer", None)
    if hf_quantizer is not None:
        quantization_serializable = (
            hf_quantizer is not None
            and isinstance(hf_quantizer, DiffusersQuantizer)
            and hf_quantizer.is_serializable
        )
        if not quantization_serializable:
            raise ValueError(
                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} "
                f"and is not serializable..."
            )

    # Step 2: Determine weights filename and the pattern used to name shards
    weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
    weights_name = _add_variant(weights_name, variant)
    weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )

    # Step 3: Create directory
    os.makedirs(save_directory, exist_ok=True)

    # Step 4: Save config (includes quantization_config in config.json)
    model_to_save = self
    if is_main_process:
        model_to_save.save_config(save_directory)

    # Step 5: Get state dict (quantized weights)
    state_dict = model_to_save.state_dict()

    # Step 6: Shard and save weights
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
    )
    for filename, tensors in state_dict_split.filename_to_tensors.items():
        shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
        filepath = os.path.join(save_directory, filename)
        if safe_serialization:
            safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
        else:
            torch.save(shard, filepath)

    # Step 7: Save index file for sharded checkpoints
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
        save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
        with open(save_index_file, "w", encoding="utf-8") as f:
            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
            f.write(content)
Key observations:
- The quantization serialization check (Step 1) is a gate that prevents saving non-serializable quantized models
- save_config() (Step 4) writes config.json which includes the quantization_config dict
- state_dict() (Step 5) returns the quantized weights as they exist in memory -- the backend-specific layers handle converting their internal representation to serializable tensors
- The actual weight serialization (Step 6) is backend-agnostic -- it just saves whatever tensors the state dict contains
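Steps 6 and 7 hide the shard planning inside split_torch_state_dict_into_shards. The sketch below is a simplified stand-in, not the library's implementation: it assumes a greedy split over tensor sizes given in bytes and a "-00001-of-00002" style suffix, and the function name plan_shards is hypothetical. It shows how a filename pattern with a {suffix} placeholder turns into the weight_map that the index file stores.

```python
def plan_shards(tensor_sizes: dict, max_shard_bytes: int, filename_pattern: str):
    """Greedily group tensors into shards and map each tensor to a shard filename.

    Hypothetical helper for illustration; sizes are plain byte counts.
    """
    shards = [[]]          # list of shards, each a list of tensor names
    current_bytes = 0
    for name, size in tensor_sizes.items():
        # Start a new shard only if the current one is non-empty and would overflow.
        if shards[-1] and current_bytes + size > max_shard_bytes:
            shards.append([])
            current_bytes = 0
        shards[-1].append(name)
        current_bytes += size

    total = len(shards)
    weight_map = {}
    for idx, names in enumerate(shards, start=1):
        # A single-shard checkpoint keeps the plain filename (empty suffix).
        suffix = f"-{idx:05d}-of-{total:05d}" if total > 1 else ""
        filename = filename_pattern.format(suffix=suffix)
        for name in names:
            weight_map[name] = filename
    return weight_map, total > 1


pattern = "diffusion_pytorch_model{suffix}.safetensors"
sizes = {"a.weight": 6, "b.weight": 5, "c.weight": 3}
weight_map, is_sharded = plan_shards(sizes, max_shard_bytes=10, filename_pattern=pattern)
```

With a 10-byte limit the three toy tensors land in two shards, so an index file would be written. Raising the limit above the total size collapses everything into one unsuffixed file and no index is needed.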
Pipeline-Level: DiffusionPipeline.save_pretrained
I/O Contract
| Parameter | Type | Default | Description |
|---|---|---|---|
| save_directory | os.PathLike | (required) | Root directory for the pipeline |
| safe_serialization | bool | True | Use safetensors format |
| variant | str or None | None | Weight filename variant |
| max_shard_size | str or None | None | Maximum size per shard |
| push_to_hub | bool | False | Push to Hugging Face Hub |
Pipeline Save Flow
# Condensed excerpt: module-level imports (os, inspect) are omitted.
def save_pretrained(self, save_directory, safe_serialization=True,
                    variant=None, max_shard_size=None, push_to_hub=False, **kwargs):
    model_index_dict = dict(self.config)

    # Filter to saveable modules
    expected_modules, optional_kwargs = self._get_signature_keys(self)

    for pipeline_component_name in model_index_dict.keys():
        sub_model = getattr(self, pipeline_component_name)
        model_cls = sub_model.__class__

        # Handle compiled models
        if is_compiled_module(sub_model):
            sub_model = _unwrap_model(sub_model)
            model_cls = sub_model.__class__

        # Find the correct save method from LOADABLE_CLASSES
        save_method_name = None
        for library_name, library_classes in LOADABLE_CLASSES.items():
            # ... resolve save_method_name
            pass

        save_method = getattr(sub_model, save_method_name)

        # Build save kwargs based on method signature
        save_method_signature = inspect.signature(save_method)
        save_kwargs = {}
        if "safe_serialization" in save_method_signature.parameters:
            save_kwargs["safe_serialization"] = safe_serialization
        if "variant" in save_method_signature.parameters:
            save_kwargs["variant"] = variant
        if "max_shard_size" in save_method_signature.parameters and max_shard_size is not None:
            save_kwargs["max_shard_size"] = max_shard_size

        # Save component to its subdirectory
        save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)

    # Save pipeline config (model_index.json)
    self.save_config(save_directory)
Key observations:
- Each component is saved independently to save_directory/<component_name>/
- The pipeline delegates to each component's own save_pretrained method
- Quantized model components (diffusers models) hit the ModelMixin.save_pretrained path with its serialization check
- Transformers model components use their own save_pretrained with analogous quantization handling
- The pipeline's model_index.json is quantization-agnostic
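The signature-based forwarding in the pipeline loop can be tried in isolation. The following is a minimal sketch with toy components: ToyModelComponent, ToySchedulerComponent, and build_save_kwargs are hypothetical names, not diffusers classes. It shows why a component whose save method lacks a parameter never receives it, and why max_shard_size is only forwarded when the caller sets it.

```python
import inspect


class ToyModelComponent:
    """Hypothetical component whose save method accepts the weight-related options."""

    def save_pretrained(self, save_directory, safe_serialization=True,
                        variant=None, max_shard_size="10GB"):
        return save_directory


class ToySchedulerComponent:
    """Hypothetical component whose save method only takes a directory."""

    def save_pretrained(self, save_directory):
        return save_directory


def build_save_kwargs(save_method, safe_serialization=True, variant=None, max_shard_size=None):
    """Forward only the options that the target save method declares."""
    params = inspect.signature(save_method).parameters
    save_kwargs = {}
    if "safe_serialization" in params:
        save_kwargs["safe_serialization"] = safe_serialization
    if "variant" in params:
        save_kwargs["variant"] = variant
    # Leave the component's own default in place unless the caller overrides it.
    if "max_shard_size" in params and max_shard_size is not None:
        save_kwargs["max_shard_size"] = max_shard_size
    return save_kwargs


model_kwargs = build_save_kwargs(ToyModelComponent().save_pretrained, variant="fp16")
scheduler_kwargs = build_save_kwargs(ToySchedulerComponent().save_pretrained, variant="fp16")
```

Here model_kwargs carries safe_serialization and variant but no max_shard_size, while scheduler_kwargs is empty, so the scheduler's save method is called with the directory alone.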
Saved Directory Structure
After saving a pipeline with a quantized transformer:
save_directory/
    model_index.json                          # Pipeline config (no quantization info)
    transformer/
        config.json                           # Model config WITH quantization_config section
        diffusion_pytorch_model.safetensors   # Quantized weight tensors
    text_encoder/
        config.json                           # Standard config (no quantization if not quantized)
        model.safetensors                     # Full-precision weights
    vae/
        config.json
        diffusion_pytorch_model.safetensors
    scheduler/
        scheduler_config.json
    tokenizer/
        ...
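Because quantization metadata lives in each component's config.json rather than in model_index.json, a saved pipeline can be inspected without loading any weights. The helper below is a sketch under that assumption. The function name find_quantized_components is hypothetical, and the code relies only on the quantization_config key described above.

```python
import json
from pathlib import Path


def find_quantized_components(pipeline_dir):
    """Return {component_name: quantization_config} for components saved as quantized.

    Hypothetical helper: it only looks at <component>/config.json files, so components
    that use another config filename (e.g. the scheduler) are skipped.
    """
    quantized = {}
    for config_path in sorted(Path(pipeline_dir).glob("*/config.json")):
        config = json.loads(config_path.read_text(encoding="utf-8"))
        if "quantization_config" in config:
            quantized[config_path.parent.name] = config["quantization_config"]
    return quantized
```

Calling find_quantized_components("./my_quantized_pipeline") on the layout above would be expected to report only the transformer entry, since the other components carry no quantization_config section.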
Usage Examples
Save a Quantized Model
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

# Load with quantization
config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=config,
    torch_dtype=torch.bfloat16,
)

# Save quantized model
transformer.save_pretrained("./my_quantized_transformer")
# Creates: config.json (with quantization_config) + diffusion_pytorch_model.safetensors (quantized weights)
Save a Quantized Pipeline
import torch
from diffusers import DiffusionPipeline, BitsAndBytesConfig
from diffusers.quantizers import PipelineQuantizationConfig

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=PipelineQuantizationConfig(
        quant_mapping={
            "transformer": BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
        }
    ),
    torch_dtype=torch.bfloat16,
)

# Save entire pipeline -- quantized components preserve their configs
pipe.save_pretrained("./my_quantized_pipeline")
Save and Push to Hub
# `transformer` is the quantized model loaded in "Save a Quantized Model"
transformer.save_pretrained(
    "./my_quantized_transformer",
    push_to_hub=True,
    repo_id="my-username/flux-nf4-transformer",
)
Reload the Saved Model
from diffusers import FluxTransformer2DModel

# No quantization_config needed -- detected from config.json
transformer = FluxTransformer2DModel.from_pretrained("./my_quantized_transformer")
print(transformer.is_quantized)  # True
Implementation Notes
- Serialization gate importance: If a backend's is_serializable returns False, save_pretrained raises immediately. This prevents saving corrupted or unusable checkpoints. Users see a clear error explaining the limitation.
- state_dict() behavior for quantized models: Each backend's quantized layers override state_dict() to return their internal representation in a serializable format. For BitsAndBytes, this includes the packed weights, absmax values, and quant_state. For TorchAO, tensor subclasses are converted to standard tensors.
- Safetensors metadata: The metadata={"format": "pt"} tag in the safetensors save indicates PyTorch format, enabling correct deserialization.
- Shard cleanup: Before saving, the method removes stale shard files from previous saves to prevent orphaned shards from confusing the loader.
- Compiled model unwrapping: Pipeline save handles torch.compile()-wrapped models by unwrapping them before determining the save method.
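The shard cleanup note can be illustrated with a small standalone routine. This is a sketch only: the regular expression, the function name remove_stale_shards, and the "-00001-of-00003" naming are assumptions made for illustration, and the library's own matching logic may differ. The idea is to delete files that look like shards of the same base name but are not part of the save about to be written.

```python
import os
import re

# Assumed shard naming: <base>-<5 digits>-of-<5 digits>.<safetensors|bin>
_SHARD_RE = re.compile(r"(?P<base>.+)-\d{5}-of-\d{5}\.(?:safetensors|bin)")


def remove_stale_shards(save_directory, base_name, keep):
    """Delete leftover shard files for base_name that are not in keep.

    Hypothetical helper; returns the removed filenames in sorted order.
    """
    removed = []
    for filename in sorted(os.listdir(save_directory)):
        path = os.path.join(save_directory, filename)
        if filename in keep or not os.path.isfile(path):
            continue
        match = _SHARD_RE.fullmatch(filename)
        # Only touch files that match the shard pattern for this exact base name.
        if match is not None and match.group("base") == base_name:
            os.remove(path)
            removed.append(filename)
    return removed
```

Restricting deletion to filenames that match the shard pattern leaves config.json, index files, and unrelated weights untouched, which keeps a re-save from three shards down to two from leaving an orphaned third shard behind.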
Related Pages
- Huggingface_Diffusers_Quantized_Model_Saving - Principle of quantization-aware serialization
- Huggingface_Diffusers_Quantized_Model_Loading - The loading counterpart for saved quantized models
- Huggingface_Diffusers_Quantization_Config_Classes - Config objects serialized into config.json
- Huggingface_Diffusers_Quantized_Pipeline_Call - Running inference before saving
Source References
- src/diffusers/models/modeling_utils.py:L667-L677 - save_pretrained signature and docstring
- src/diffusers/models/modeling_utils.py:L715-L726 - Quantization serialization check
- src/diffusers/models/modeling_utils.py:L749-L806 - Config save, state dict extraction, and weight sharding
- src/diffusers/pipelines/pipeline_utils.py:L240-L248 - Pipeline save_pretrained signature
- src/diffusers/pipelines/pipeline_utils.py:L276-L357 - Per-component save iteration
- src/diffusers/quantizers/base.py:L234-L236 - is_serializable abstract property