Implementation:Huggingface Diffusers Infer Model Type
| Field | Value |
|---|---|
| Type | Pattern Doc |
| Overview | Key-based architecture detection function that identifies checkpoint model type from weight key names and tensor shapes |
| Domains | Model Conversion, Checkpoint Analysis |
| Workflow | Checkpoint_Conversion |
| Related Principle | Huggingface_Diffusers_Checkpoint_Format_Identification |
| Source | src/diffusers/loaders/single_file_utils.py:L583-L817 |
| Last Updated | 2026-02-13 00:00 GMT |
Code Reference
infer_diffusers_model_type
Source: src/diffusers/loaders/single_file_utils.py:L583-L817
```python
# Excerpt: CHECKPOINT_KEY_NAMES is a module-level dict in single_file_utils.py;
# branches marked "..." are elided.
def infer_diffusers_model_type(checkpoint):
    # Inpainting UNets have a 9-channel input conv. Checked first because SD2/SDXL
    # inpainting checkpoints would also match the plain v2/SDXL checks below.
    if (
        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
    ):
        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"
        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
            model_type = "xl_inpaint"
        else:
            model_type = "inpainting"

    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
        model_type = "v2"

    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
        model_type = "xl_base"

    # ... (cascade of elif checks for each architecture) ...

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
        # Keys may or may not carry the original-format "model.diffusion_model." prefix.
        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
            target_key = "model.diffusion_model.patch_embedding.weight"
        else:
            target_key = "patch_embedding.weight"

        # shape[0] is the hidden size (1536 for 1.3B, 5120 for 14B); shape[1] is the input channels.
        if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
            if checkpoint[target_key].shape[0] == 1536:
                model_type = "wan-vace-1.3B"
            elif checkpoint[target_key].shape[0] == 5120:
                model_type = "wan-vace-14B"
        # elif: a VACE match must not be overwritten by the size-based checks below.
        elif CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
            model_type = "wan-animate-14B"
        elif checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
            model_type = "wan-t2v-14B"
        else:
            model_type = "wan-i2v-14B"

    # ... (more elif checks) ...

    else:
        # Nothing more specific matched: default to Stable Diffusion v1.x.
        model_type = "v1"

    return model_type
```
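The Wan branch can be exercised on its own by replaying it over plain shape tuples, so it runs without PyTorch or a real checkpoint. This is a sketch, not the diffusers code: `classify_wan` is a hypothetical helper, `WAN_VACE_KEY` and `WAN_ANIMATE_KEY` are placeholders for the real marker strings in `CHECKPOINT_KEY_NAMES` (not shown on this page), and VACE and Animate are treated as mutually exclusive with the size-based checks.

```python
# Hypothetical placeholders: the real marker keys are the values of
# CHECKPOINT_KEY_NAMES["wan_vace"] and CHECKPOINT_KEY_NAMES["wan_animate"].
WAN_VACE_KEY = "<wan_vace marker key>"
WAN_ANIMATE_KEY = "<wan_animate marker key>"


def classify_wan(shapes: dict[str, tuple[int, ...]]) -> str:
    """Classify a checkpoint already known to be Wan, given {key: shape}."""
    # Original-format checkpoints prefix keys with "model.diffusion_model.".
    prefixed = "model.diffusion_model.patch_embedding.weight"
    target_key = prefixed if prefixed in shapes else "patch_embedding.weight"
    # Conv3d weight layout: (out_channels, in_channels, kT, kH, kW).
    hidden_size, in_channels = shapes[target_key][:2]

    if WAN_VACE_KEY in shapes:
        if hidden_size == 1536:
            return "wan-vace-1.3B"
        if hidden_size == 5120:
            return "wan-vace-14B"
        # Unrecognized VACE width: raise rather than guess.
        raise ValueError(f"unrecognized VACE hidden size: {hidden_size}")
    if WAN_ANIMATE_KEY in shapes:
        return "wan-animate-14B"
    if hidden_size == 1536:
        return "wan-t2v-1.3B"
    if hidden_size == 5120 and in_channels == 16:
        return "wan-t2v-14B"
    # Everything else (a 14B model with extra conditioning channels) is treated as I2V.
    return "wan-i2v-14B"
```

A 1.3B checkpoint is identified by its hidden size alone, while the 14B T2V/I2V split also looks at the input channel count.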
Import
```python
from diffusers.loaders.single_file_utils import infer_diffusers_model_type
```
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| checkpoint | dict[str, torch.Tensor] | Full checkpoint state dictionary loaded from a .safetensors or .ckpt file |
I/O Contract
Inputs
checkpoint: A dictionary mapping string key names to `torch.Tensor` values. This is the raw state dict from the checkpoint file.
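Because detection is driven by key names and tensor shapes, the same information can be previewed for a .safetensors file without materializing any weights. The sketch below is a stdlib-only header reader based on the published safetensors layout (an 8-byte little-endian header length followed by a JSON header). `read_safetensors_shapes` is a hypothetical helper, not part of diffusers, and it does not handle .ckpt (pickle) files; the safetensors library's own `safe_open` offers similar lazy access.

```python
import json
import struct


def read_safetensors_shapes(path: str) -> dict[str, tuple[int, ...]]:
    """Return {tensor_name: shape} from a .safetensors header without reading tensor data."""
    with open(path, "rb") as f:
        # First 8 bytes: unsigned little-endian length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional string map, not a tensor entry.
    return {
        name: tuple(info["shape"])
        for name, info in header.items()
        if name != "__metadata__"
    }
```

Checking whether any returned key starts with `guidance_in` previews the Flux dev/schnell distinction shown in the usage examples, without loading the weights.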
Outputs
str: A model type identifier string. Possible values include:
| Model Type String | Architecture |
|---|---|
| "v1" | Stable Diffusion v1.x (default fallback) |
| "v2" | Stable Diffusion v2.x |
| "inpainting" | Stable Diffusion v1.x Inpainting |
| "inpainting_v2" | Stable Diffusion v2.x Inpainting |
| "xl_base" | Stable Diffusion XL Base |
| "xl_refiner" | Stable Diffusion XL Refiner |
| "xl_inpaint" | Stable Diffusion XL Inpainting |
| "flux-dev" | Flux Dev |
| "flux-schnell" | Flux Schnell |
| "wan-t2v-1.3B" | Wan 1.3B Text-to-Video |
| "wan-t2v-14B" | Wan 14B Text-to-Video |
| "wan-i2v-14B" | Wan 14B Image-to-Video |
| "wan-vace-1.3B" | Wan VACE 1.3B |
| "wan-vace-14B" | Wan VACE 14B |
| "wan-animate-14B" | Wan Animate 14B |
| "hunyuan-video" | HunyuanVideo |
| "sd3" / "sd35_medium" / "sd35_large" | Stable Diffusion 3.x |
Detection Logic
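The function uses a priority-ordered cascade of if/elif checks, so the first matching branch wins. The order matters because some models share key patterns:
- Inpainting models (checked first: their UNet input convolution has 9 channels instead of 4, and SD v2 or SDXL inpainting checkpoints would otherwise match the plain v2 or SDXL checks)
- SD v2 (1024-dim cross-attention context, vs. 768 in v1)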
- Playground v2.5
- SDXL Base/Refiner
- Upscale
- ControlNet (with XL variants)
- Stable Cascade
- SD3 (with shape-based size detection)
- AnimateDiff (multiple version variants)
- Flux (dev vs schnell via guidance keys; fill/depth via img_in shape)
- LTX-Video (multiple version variants)
- AutoEncoder-DC
- Mochi
- HunyuanVideo
- AuraFlow
- Wan (T2V/I2V/VACE/Animate via marker keys and the patch_embedding shape: hidden size and input channels)
- HiDream
- Cosmos
- Fallback to v1
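The ordering constraint is easiest to see with a reduced cascade. This sketch re-implements only the first branches of the excerpt (inpainting, v2, SDXL base, v1 fallback) over shape tuples instead of tensors. `toy_infer` is hypothetical, and the three key strings are assumed values of the corresponding `CHECKPOINT_KEY_NAMES` entries; verify them against the source before reuse.

```python
# Assumed values of CHECKPOINT_KEY_NAMES entries (verify against
# src/diffusers/loaders/single_file_utils.py).
INPAINTING_KEY = "model.diffusion_model.input_blocks.0.0.weight"
V2_KEY = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
XL_BASE_KEY = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"


def toy_infer(shapes: dict[str, tuple[int, ...]]) -> str:
    """First-match-wins cascade over a few branches, given {key: shape}."""
    is_v2 = V2_KEY in shapes and shapes[V2_KEY][-1] == 1024
    is_xl = XL_BASE_KEY in shapes
    # 9 input channels = 4 latent + 4 masked-image latent + 1 mask. This must run
    # before the v2/SDXL checks, which an SD2 or SDXL inpainting model also passes.
    if INPAINTING_KEY in shapes and shapes[INPAINTING_KEY][1] == 9:
        if is_v2:
            return "inpainting_v2"
        if is_xl:
            return "xl_inpaint"
        return "inpainting"
    if is_v2:
        return "v2"
    if is_xl:
        return "xl_base"
    return "v1"
```

Moving the inpainting check below the v2 check would make an SD 2 inpainting checkpoint report "v2", which is exactly the hazard the priority order avoids.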
Usage Examples
Identifying a Wan Checkpoint
```python
from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import infer_diffusers_model_type

checkpoint = load_file("wan-14b-t2v.safetensors")
model_type = infer_diffusers_model_type(checkpoint)
# Returns: "wan-t2v-14B"
```
Identifying a Flux Checkpoint
```python
from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import infer_diffusers_model_type

checkpoint = load_file("flux1-dev.safetensors")
model_type = infer_diffusers_model_type(checkpoint)
# Returns: "flux-dev" (has guidance_in keys)

checkpoint_schnell = load_file("flux1-schnell.safetensors")
model_type = infer_diffusers_model_type(checkpoint_schnell)
# Returns: "flux-schnell" (no guidance_in keys)
```
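Once a checkpoint is known to be Flux, the dev/schnell split reduces to a key-presence test: FLUX.1 [dev] carries a guidance embedder and [schnell] does not. The sketch below shows only that test. `flux_variant` is hypothetical, the exact guidance key names are an assumption, and the fill/depth detection via the img_in shape is deliberately omitted.

```python
# Assumed key names for the Flux guidance embedder; original-format files
# may carry the "model.diffusion_model." prefix.
GUIDANCE_KEYS = (
    "guidance_in.in_layer.bias",
    "model.diffusion_model.guidance_in.in_layer.bias",
)


def flux_variant(keys) -> str:
    """Return "flux-dev" if a guidance embedding key is present, else "flux-schnell".

    `keys` is any iterable of key names, e.g. checkpoint.keys().
    Assumes the checkpoint is already known to be Flux.
    """
    key_set = set(keys)
    if any(k in key_set for k in GUIDANCE_KEYS):
        return "flux-dev"
    return "flux-schnell"
```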
Related Pages
- Huggingface_Diffusers_Checkpoint_Format_Identification (principle for this implementation) - Theory of key-based detection
- Huggingface_Diffusers_Single_File_Loadable_Classes (next step) - Uses model type to fetch config and select conversion
- Huggingface_Diffusers_From_Single_File (caller) - Invoked by from_single_file via fetch_diffusers_config