Jump to content

Connect SuperML | Leeroopedia MCP: Equip your AI agents with best practices, code verification, and debugging knowledge. Powered by Leeroo — building Organizational Superintelligence. Contact us at founders@leeroo.com.

Implementation:Huggingface Diffusers Convert Checkpoint To Diffusers

From Leeroopedia
Revision as of 13:03, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Huggingface_Diffusers_Convert_Checkpoint_To_Diffusers.md)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Field Value
Type Pattern Doc
Overview Concrete weight key remapping functions that convert original checkpoint state dicts to Diffusers format, using Flux as a detailed example
Domains Model Conversion, Tensor Operations
Workflow Checkpoint_Conversion
Related Principle Huggingface_Diffusers_Weight_Mapping
Source src/diffusers/loaders/single_file_utils.py:L2244-L2438
Last Updated 2026-02-13 00:00 GMT

Code Reference

convert_flux_transformer_checkpoint_to_diffusers

Source: src/diffusers/loaders/single_file_utils.py:L2244-L2438

def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Remap an original Flux transformer state dict to Diffusers key names.

    This listing is abridged: each ``# ...`` marker stands for further
    mappings of the same kind that are not shown here.

    Args:
        checkpoint: Mapping from original parameter names to tensors. It is
            modified in place: any ``model.diffusion_model.`` prefix is
            removed from its keys and every converted entry is popped.
        **kwargs: Accepted for interface compatibility; not used by the code
            shown here.

    Returns:
        A new dict that maps Diffusers parameter names to tensors.

    Raises:
        KeyError: If a key the conversion expects is absent from
            ``checkpoint``.
    """
    converted_state_dict = {}
    keys = list(checkpoint.keys())

    # Strip framework prefix. The loop walks a snapshot of the keys because
    # the dict itself is mutated (pop + re-insert) while iterating.
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    def count_blocks(marker):
        # Highest block index + 1. max() is used instead of indexing into a
        # set, whose iteration order is not guaranteed; a checkpoint without
        # any matching key yields a count of 0.
        indices = (int(k.split(".", 2)[1]) for k in checkpoint if marker in k)
        return max(indices, default=-1) + 1

    # Detect layer counts dynamically
    num_layers = count_blocks("double_blocks.")
    num_single_layers = count_blocks("single_blocks.")
    mlp_ratio = 4.0
    inner_dim = 3072

    def swap_scale_shift(weight):
        # The original layout is [shift, scale]; Diffusers expects [scale, shift].
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    ## Timestep embeddings: time_in -> time_text_embed.timestep_embedder
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = \
        checkpoint.pop("time_in.in_layer.weight")
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = \
        checkpoint.pop("time_in.in_layer.bias")
    # ... (linear_2 similarly)

    ## Text embeddings: vector_in -> time_text_embed.text_embedder
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = \
        checkpoint.pop("vector_in.in_layer.weight")
    # ...

    ## Input projections
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")

    # Double transformer blocks (joint attention)
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."

        # AdaLN modulation
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = \
            checkpoint.pop(f"double_blocks.{i}.img_mod.lin.weight")

        # QKV splitting: fused (3*dim, dim) -> separate Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = sample_q
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = sample_k
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = sample_v
        # ... (context Q, K, V similarly from txt_attn.qkv)

        # QK norm weights
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = \
            checkpoint.pop(f"double_blocks.{i}.img_attn.norm.query_norm.scale")
        # ...

    # Single transformer blocks (fused Q+K+V+MLP).
    # The split sizes do not depend on the block index, so compute them once.
    mlp_hidden_dim = int(inner_dim * mlp_ratio)
    split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."

        q, k, v, mlp = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = mlp
        # ...

    # Final layer with scale-shift swap
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
    )

    return converted_state_dict

Key Parameters

Parameter Type Description
checkpoint dict[str, torch.Tensor] Raw checkpoint state dict (modified in-place via pop())
**kwargs dict Additional keyword arguments (e.g., config for some converters)

I/O Contract

Inputs

  • checkpoint: Raw state dict with original key naming. Keys may optionally carry a model.diffusion_model. prefix, which the converter strips before remapping.

Outputs

  • dict[str, torch.Tensor]: New state dict with Diffusers-compatible key names and correctly shaped tensors.

Conversion Patterns

Pattern 1: Simple Key Rename

# Original: "time_in.in_layer.weight"
# Diffusers: "time_text_embed.timestep_embedder.linear_1.weight"
# pop() removes old_key from `checkpoint` while handing its tensor to new_key.
converted[new_key] = checkpoint.pop(old_key)

Pattern 2: QKV Splitting (Equal)

# Original: fused QKV weight of shape (3*dim, dim)
# Split along dim 0 into three chunks, unpacked in Q, K, V order.
q, k, v = torch.chunk(checkpoint.pop("attn.qkv.weight"), 3, dim=0)
converted["attn.to_q.weight"] = q  # (dim, dim)
converted["attn.to_k.weight"] = k  # (dim, dim)
converted["attn.to_v.weight"] = v  # (dim, dim)

Pattern 3: QKVM Splitting (Unequal)

# Original: fused Q+K+V+MLP weight of shape (3*dim + mlp_dim, dim)
# In the converter listing, mlp_hidden_dim = int(inner_dim * mlp_ratio).
# torch.split with a tuple of sizes returns one piece per entry, in order.
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop("linear1.weight"), split_size, dim=0)

Pattern 4: Scale-Shift Swap

# Original: [shift, scale] concatenated
# Diffusers: [scale, shift] concatenated
# Split the weight into two chunks along dim 0, then re-join them in the opposite order.
shift, scale = weight.chunk(2, dim=0)
converted_weight = torch.cat([scale, shift], dim=0)

Pattern 5: Prefix Stripping

# Iterate over a snapshot (list) of the keys: entries are popped and
# re-inserted under the shortened name inside the loop.
for k in list(checkpoint.keys()):
    # Substring match: the prefix is removed wherever it occurs in the key.
    if "model.diffusion_model." in k:
        checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

External Dependencies

  • torch (for torch.chunk, torch.split, torch.cat)

Usage Examples

Converting a Flux Checkpoint

from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers
from safetensors.torch import load_file

# Load original checkpoint
checkpoint = load_file("flux1-dev.safetensors")

# Convert to Diffusers format (the converter pops entries out of `checkpoint`)
diffusers_state_dict = convert_flux_transformer_checkpoint_to_diffusers(checkpoint)

# Load into model
# NOTE: `config` is not defined in this snippet; provide the transformer's
# configuration before running the next line.
model = FluxTransformer2DModel.from_config(config)
model.load_state_dict(diffusers_state_dict)

Verifying Conversion Completeness

# After conversion, check for unconverted keys.
# The converter pops every key it maps, so anything still present was not
# consumed. Sorting keeps the report stable and easy to scan.
remaining = sorted(checkpoint)
if remaining:
    print(f"Warning: {len(remaining)} keys not consumed: {remaining}")

Related Pages

Principle:Huggingface_Diffusers_Weight_Mapping

Uses Heuristic

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment