Jump to content

Connect Leeroopedia MCP: Equip your AI agents to search best practices, build plans, verify code, diagnose failures, and look up hyperparameter defaults.

Implementation:Huggingface Diffusers Convert Checkpoint To Diffusers

From Leeroopedia
Field Value
Type Pattern Doc
Overview Concrete weight key remapping functions that convert original checkpoint state dicts to Diffusers format, using Flux as a detailed example
Domains Model Conversion, Tensor Operations
Workflow Checkpoint_Conversion
Related Principle Huggingface_Diffusers_Weight_Mapping
Source src/diffusers/loaders/single_file_utils.py:L2244-L2438
Last Updated 2026-02-13 00:00 GMT

Code Reference

convert_flux_transformer_checkpoint_to_diffusers

Source: src/diffusers/loaders/single_file_utils.py:L2244-L2438

def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Remap a Flux transformer state dict from its original key layout to Diffusers format.

    Args:
        checkpoint (dict[str, torch.Tensor]): Raw state dict with original key
            naming. Keys may carry an optional ``model.diffusion_model.``
            prefix. Consumed in place: every converted entry is removed via
            ``pop()``, so leftover keys afterward indicate unconverted weights.
        **kwargs: Ignored here; accepted for signature parity with the other
            ``convert_*_checkpoint_to_diffusers`` converters.

    Returns:
        dict[str, torch.Tensor]: New state dict with Diffusers-compatible key
        names and correctly shaped tensors.
    """
    converted_state_dict = {}
    keys = list(checkpoint.keys())

    # Strip framework prefix so all downstream lookups use bare key names
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Detect layer counts dynamically from the highest block index present.
    # (Fixed: the previous list(set(...))[-1] relied on unspecified set
    # iteration order; max() is guaranteed to yield the largest index.)
    num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint
                     if "double_blocks." in k) + 1
    num_single_layers = max(int(k.split(".", 2)[1]) for k in checkpoint
                            if "single_blocks." in k) + 1
    mlp_ratio = 4.0   # Flux MLP expansion factor
    inner_dim = 3072  # Flux hidden size

    def swap_scale_shift(weight):
        # Original checkpoints store [shift, scale]; Diffusers expects [scale, shift].
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    ## Timestep embeddings: time_in -> time_text_embed.timestep_embedder
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = \
        checkpoint.pop("time_in.in_layer.weight")
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = \
        checkpoint.pop("time_in.in_layer.bias")
    # ... (linear_2 similarly)

    ## Text embeddings: vector_in -> time_text_embed.text_embedder
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = \
        checkpoint.pop("vector_in.in_layer.weight")
    # ...

    ## Input projections
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")

    # Double transformer blocks (joint attention)
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."

        # AdaLN modulation
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = \
            checkpoint.pop(f"double_blocks.{i}.img_mod.lin.weight")

        # QKV splitting: fused (3*dim, dim) -> separate Q, K, V of (dim, dim) each
        sample_q, sample_k, sample_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = sample_q
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = sample_k
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = sample_v
        # ... (context Q, K, V similarly from txt_attn.qkv)

        # QK norm weights
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = \
            checkpoint.pop(f"double_blocks.{i}.img_attn.norm.query_norm.scale")
        # ...

    # Single transformer blocks (fused Q+K+V+MLP in one linear layer)
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        # Unequal split: three inner_dim-row attention chunks plus the MLP rows
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)

        q, k, v, mlp = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = mlp
        # ...

    # Final layer with scale-shift swap
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
    )

    return converted_state_dict

Key Parameters

Parameter Type Description
checkpoint dict[str, torch.Tensor] Raw checkpoint state dict (modified in-place via pop())
**kwargs dict Additional keyword arguments (e.g., config for some converters)

I/O Contract

Inputs

  • checkpoint: Raw state dict with original key naming. Keys may carry an optional model.diffusion_model. prefix, which the converter strips before remapping. The dict is mutated in place: converted entries are removed via pop().

Outputs

  • dict[str, torch.Tensor]: New state dict with Diffusers-compatible key names and correctly shaped tensors.

Conversion Patterns

Pattern 1: Simple Key Rename

# Original: "time_in.in_layer.weight"
# Diffusers: "time_text_embed.timestep_embedder.linear_1.weight"
converted[new_key] = checkpoint.pop(old_key)

Pattern 2: QKV Splitting (Equal)

# Original: fused QKV weight of shape (3*dim, dim)
q, k, v = torch.chunk(checkpoint.pop("attn.qkv.weight"), 3, dim=0)
converted["attn.to_q.weight"] = q  # (dim, dim)
converted["attn.to_k.weight"] = k  # (dim, dim)
converted["attn.to_v.weight"] = v  # (dim, dim)

Pattern 3: QKVM Splitting (Unequal)

# Original: fused Q+K+V+MLP weight of shape (3*dim + mlp_dim, dim)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop("linear1.weight"), split_size, dim=0)

Pattern 4: Scale-Shift Swap

# Original: [shift, scale] concatenated
# Diffusers: [scale, shift] concatenated
shift, scale = weight.chunk(2, dim=0)
converted_weight = torch.cat([scale, shift], dim=0)

Pattern 5: Prefix Stripping

for k in list(checkpoint.keys()):
    if "model.diffusion_model." in k:
        checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

External Dependencies

  • torch (for torch.chunk, torch.split, torch.cat)

Usage Examples

Converting a Flux Checkpoint

# Example: convert an original Flux checkpoint and load it into a Diffusers model.
from safetensors.torch import load_file
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

# Load original checkpoint
checkpoint = load_file("flux1-dev.safetensors")

# Convert to Diffusers format
# NOTE: this mutates `checkpoint` in place — consumed keys are pop()ed out,
# so `checkpoint` should be empty (or nearly so) after a complete conversion.
diffusers_state_dict = convert_flux_transformer_checkpoint_to_diffusers(checkpoint)

# Load into model
# NOTE(review): `config` must already be defined (e.g. loaded from the model
# repo) — it is not shown in this example.
from diffusers import FluxTransformer2DModel
model = FluxTransformer2DModel.from_config(config)
model.load_state_dict(diffusers_state_dict)

Verifying Conversion Completeness

# After conversion, check for unconverted keys.
# This works because the converter pops consumed entries from `checkpoint`
# in place — anything left behind was never remapped.
remaining = set(checkpoint.keys())
if remaining:
    print(f"Warning: {len(remaining)} keys not consumed: {remaining}")

Related Pages

Principle:Huggingface_Diffusers_Weight_Mapping

Uses Heuristic

Page Connections

Double-click a node to navigate. Hold to expand connections.
Principle
Implementation
Heuristic
Environment