Implementation:Huggingface Diffusers Infer Model Type

From Leeroopedia
Type: Pattern Doc
Overview: Key-based architecture detection function that identifies a checkpoint's model type from weight key names and tensor shapes
Domains: Model Conversion, Checkpoint Analysis
Workflow: Checkpoint_Conversion
Related Principle: Huggingface_Diffusers_Checkpoint_Format_Identification
Source: src/diffusers/loaders/single_file_utils.py:L583-L817
Last Updated: 2026-02-13 00:00 GMT

Code Reference

infer_diffusers_model_type

Source: src/diffusers/loaders/single_file_utils.py:L583-L817

def infer_diffusers_model_type(checkpoint):
    if (
        CHECKPOINT_KEY_NAMES["inpainting"] in checkpoint
        and checkpoint[CHECKPOINT_KEY_NAMES["inpainting"]].shape[1] == 9
    ):
        if CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
            model_type = "inpainting_v2"
        elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
            model_type = "xl_inpaint"
        else:
            model_type = "inpainting"

    elif CHECKPOINT_KEY_NAMES["v2"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["v2"]].shape[-1] == 1024:
        model_type = "v2"

    elif CHECKPOINT_KEY_NAMES["xl_base"] in checkpoint:
        model_type = "xl_base"

    # ... (cascade of elif checks for each architecture) ...

    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["wan"]):
        if "model.diffusion_model.patch_embedding.weight" in checkpoint:
            target_key = "model.diffusion_model.patch_embedding.weight"
        else:
            target_key = "patch_embedding.weight"

        if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
            if checkpoint[target_key].shape[0] == 1536:
                model_type = "wan-vace-1.3B"
            elif checkpoint[target_key].shape[0] == 5120:
                model_type = "wan-vace-14B"

        elif CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
            model_type = "wan-animate-14B"

        elif checkpoint[target_key].shape[0] == 1536:
            model_type = "wan-t2v-1.3B"
        elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
            model_type = "wan-t2v-14B"
        else:
            model_type = "wan-i2v-14B"

    # ... (more elif checks) ...

    else:
        model_type = "v1"

    return model_type
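
The Wan branch above tolerates two serialization layouts: the same tensor may be saved under a "model.diffusion_model." prefix or under its bare name, depending on the exporter. A minimal, self-contained sketch of that prefix-fallback pattern (plain Python, no torch; the helper name and the dict contents are illustrative):

```python
def resolve_key(checkpoint, candidates):
    """Return the first candidate key present in the checkpoint.

    Exporters disagree on prefixes: one tool saves
    'model.diffusion_model.patch_embedding.weight', another saves just
    'patch_embedding.weight'. Checking candidates in priority order lets a
    single detection path serve both layouts.
    """
    for key in candidates:
        if key in checkpoint:
            return key
    raise KeyError(f"none of {candidates} found in checkpoint")


# Two hypothetical checkpoints storing the same tensor under different names.
ckpt_prefixed = {"model.diffusion_model.patch_embedding.weight": "tensor"}
ckpt_bare = {"patch_embedding.weight": "tensor"}

candidates = [
    "model.diffusion_model.patch_embedding.weight",
    "patch_embedding.weight",
]
print(resolve_key(ckpt_prefixed, candidates))  # prefixed form wins when present
print(resolve_key(ckpt_bare, candidates))      # falls back to the bare name
```

The same idea generalizes to any checkpoint family whose keys drift across export tools.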

Import

from diffusers.loaders.single_file_utils import infer_diffusers_model_type

Key Parameters

checkpoint (dict[str, torch.Tensor]): Full checkpoint state dictionary loaded from a .safetensors or .ckpt file

I/O Contract

Inputs

  • checkpoint: A dictionary mapping string key names to torch.Tensor values. This is the raw state dict from the checkpoint file.

Outputs

  • str: A model type identifier string. Possible values include:
"v1": Stable Diffusion v1.x (default fallback)
"v2": Stable Diffusion v2.x
"xl_base": Stable Diffusion XL Base
"xl_refiner": Stable Diffusion XL Refiner
"flux-dev": Flux Dev
"flux-schnell": Flux Schnell
"wan-t2v-1.3B": Wan 1.3B Text-to-Video
"wan-t2v-14B": Wan 14B Text-to-Video
"wan-i2v-14B": Wan 14B Image-to-Video
"wan-vace-1.3B": Wan VACE 1.3B
"wan-vace-14B": Wan VACE 14B
"hunyuan-video": HunyuanVideo
"sd3" / "sd35_medium" / "sd35_large": Stable Diffusion 3.x
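
Since the detector only ever reads key membership and tensor shapes, the I/O contract can be exercised without torch or a real checkpoint. The sketch below is a hedged illustration: a tiny stand-in class exposing only .shape replaces torch.Tensor, and the probe key name is fabricated (the real probe keys live in diffusers' CHECKPOINT_KEY_NAMES, not shown here):

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor exposing only .shape, the one attribute the detector reads."""
    shape: tuple


def detect(checkpoint, v2_key="v2.probe.weight"):
    # Hypothetical probe key, for illustration only. Mirrors the contract:
    # key present AND last dim == 1024 -> "v2"; otherwise fall through to "v1".
    if v2_key in checkpoint and checkpoint[v2_key].shape[-1] == 1024:
        return "v2"
    return "v1"  # default fallback, as in infer_diffusers_model_type


ckpt_v2 = {"v2.probe.weight": FakeTensor(shape=(77, 1024))}
ckpt_v1 = {"other.weight": FakeTensor(shape=(320, 4, 3, 3))}
print(detect(ckpt_v2))  # "v2"
print(detect(ckpt_v1))  # "v1"
```

This also demonstrates why an unrecognized or empty state dict is always classified: every path through the cascade ends at the "v1" fallback.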

Detection Logic

The function uses a priority-ordered cascade of checks. The order matters because some models share key patterns:

  1. Inpainting models (checked first: their input conv takes 9 channels, 4 latent + 4 masked-image latent + 1 mask, instead of 4)
  2. SD v2 (1024-dim projection)
  3. Playground v2.5
  4. SDXL Base/Refiner
  5. Upscale
  6. ControlNet (with XL variants)
  7. Stable Cascade
  8. SD3 (with shape-based size detection)
  9. AnimateDiff (multiple version variants)
  10. Flux (dev vs schnell via guidance keys; fill/depth via img_in shape)
  11. LTX-Video (multiple version variants)
  12. AutoEncoder-DC
  13. Mochi
  14. HunyuanVideo
  15. AuraFlow
  16. Wan (T2V/I2V/VACE via patch_embedding shape and in_channels)
  17. HiDream
  18. Cosmos
  19. Fallback to v1
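
The ordering constraint is easiest to see in a self-contained toy: an SD v2 inpainting checkpoint matches both the 9-channel inpainting probe and the 1024-dim v2 probe, so the more specific probe must run first. Key names and shapes below are fabricated for illustration; tensors are represented as bare shape tuples:

```python
def classify(ckpt):
    """Toy priority cascade mirroring the structure of infer_diffusers_model_type.

    The specific probe (9-channel input conv for inpainting) must run before
    the broader one (1024-dim attention projection for v2). Reversing them
    would mislabel a v2 inpainting checkpoint as plain "v2".
    """
    if "conv_in.weight" in ckpt and ckpt["conv_in.weight"][1] == 9:
        if "attn.to_k.weight" in ckpt and ckpt["attn.to_k.weight"][-1] == 1024:
            return "inpainting_v2"
        return "inpainting"
    if "attn.to_k.weight" in ckpt and ckpt["attn.to_k.weight"][-1] == 1024:
        return "v2"
    return "v1"


# This checkpoint satisfies BOTH probes; check order decides the label.
ckpt = {"conv_in.weight": (320, 9, 3, 3), "attn.to_k.weight": (320, 1024)}
print(classify(ckpt))  # "inpainting_v2", not "v2"
```

Swapping the two outer checks would silently return "v2" for the same input, which is why new architectures are usually added to the cascade above the fallback but below any probes they could shadow.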

Usage Examples

Identifying a Wan Checkpoint

from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import infer_diffusers_model_type

checkpoint = load_file("wan-14b-t2v.safetensors")
model_type = infer_diffusers_model_type(checkpoint)
# Returns: "wan-t2v-14B"

Identifying a Flux Checkpoint

checkpoint = load_file("flux1-dev.safetensors")
model_type = infer_diffusers_model_type(checkpoint)
# Returns: "flux-dev" (has guidance_in keys)

checkpoint_schnell = load_file("flux1-schnell.safetensors")
model_type = infer_diffusers_model_type(checkpoint_schnell)
# Returns: "flux-schnell" (no guidance_in keys)
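
The dev-vs-schnell split hinges purely on key presence: Flux Dev is guidance-distilled and carries guidance-embedding weights, while Schnell omits them. A hedged, self-contained sketch of that presence test (the key prefix below is illustrative, not necessarily the exact one diffusers probes):

```python
def flux_variant(ckpt):
    # Illustrative key prefix: Dev checkpoints carry guidance-embedding
    # weights; Schnell checkpoints have no such module at all.
    has_guidance = any(key.startswith("guidance_in.") for key in ckpt)
    return "flux-dev" if has_guidance else "flux-schnell"


print(flux_variant({"guidance_in.in_layer.weight": None}))  # "flux-dev"
print(flux_variant({"img_in.weight": None}))                # "flux-schnell"
```

Presence/absence tests like this are cheaper and more robust than shape tests, but only work when one variant has a module the other lacks entirely.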

Related Pages

Principle:Huggingface_Diffusers_Checkpoint_Format_Identification

Uses Heuristic
