Implementation:Huggingface Diffusers Convert Checkpoint To Diffusers
| Field | Value |
|---|---|
| Type | Pattern Doc |
| Overview | Concrete weight key remapping functions that convert original checkpoint state dicts to Diffusers format, using Flux as a detailed example |
| Domains | Model Conversion, Tensor Operations |
| Workflow | Checkpoint_Conversion |
| Related Principle | Huggingface_Diffusers_Weight_Mapping |
| Source | src/diffusers/loaders/single_file_utils.py:L2244-L2438 |
| Last Updated | 2026-02-13 00:00 GMT |
Code Reference
convert_flux_transformer_checkpoint_to_diffusers
Source: src/diffusers/loaders/single_file_utils.py:L2244-L2438
```python
import torch


def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    converted_state_dict = {}
    keys = list(checkpoint.keys())

    # Strip framework prefix
    for k in keys:
        if "model.diffusion_model." in k:
            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

    # Detect layer counts dynamically from the block indices in the key names
    num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1
    num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1
    mlp_ratio = 4.0
    inner_dim = 3072

    def swap_scale_shift(weight):
        shift, scale = weight.chunk(2, dim=0)
        return torch.cat([scale, shift], dim=0)

    # Timestep embeddings: time_in -> time_text_embed.timestep_embedder
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
        "time_in.in_layer.weight"
    )
    converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop(
        "time_in.in_layer.bias"
    )
    # ... (linear_2 similarly)

    # Text embeddings: vector_in -> time_text_embed.text_embedder
    converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop(
        "vector_in.in_layer.weight"
    )
    # ...

    # Input projections
    converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
    converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")

    # Double transformer blocks (joint attention)
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        # AdaLN modulation
        converted_state_dict[f"{block_prefix}norm1.linear.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_mod.lin.weight"
        )
        # QKV splitting: fused (3*dim, dim) -> separate Q, K, V
        sample_q, sample_k, sample_v = torch.chunk(
            checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = sample_q
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = sample_k
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = sample_v
        # ... (context Q, K, V similarly from txt_attn.qkv)
        # QK norm weights
        converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
            f"double_blocks.{i}.img_attn.norm.query_norm.scale"
        )
        # ...

    # Single transformer blocks (fused Q+K+V+MLP)
    for i in range(num_single_layers):
        block_prefix = f"single_transformer_blocks.{i}."
        mlp_hidden_dim = int(inner_dim * mlp_ratio)
        split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
        q, k, v, mlp = torch.split(
            checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0
        )
        converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q
        converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k
        converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v
        converted_state_dict[f"{block_prefix}proj_mlp.weight"] = mlp
        # ...

    # Final layer with scale-shift swap
    converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
    converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        checkpoint.pop("final_layer.adaLN_modulation.1.weight")
    )

    return converted_state_dict
```
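The layer-count detection above infers block counts from the key names themselves, so the converter adapts to checkpoint variants without a hardcoded depth. A minimal standalone sketch of the same idea on a toy key list, using `max` over the parsed indices (equivalent to the set-based one-liner above):

```python
# Toy key list standing in for a real Flux checkpoint's keys.
checkpoint_keys = [
    "double_blocks.0.img_attn.qkv.weight",
    "double_blocks.1.img_attn.qkv.weight",
    "double_blocks.1.txt_attn.qkv.weight",
    "single_blocks.0.linear1.weight",
]

# The block index is the second dotted component; the count is max index + 1.
num_layers = max(int(k.split(".", 2)[1]) for k in checkpoint_keys if "double_blocks." in k) + 1
num_single_layers = max(int(k.split(".", 2)[1]) for k in checkpoint_keys if "single_blocks." in k) + 1

print(num_layers, num_single_layers)  # 2 1
```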
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| `checkpoint` | `dict[str, torch.Tensor]` | Raw checkpoint state dict (modified in place via `pop()`) |
| `**kwargs` | `dict` | Additional keyword arguments (e.g., `config` for some converters) |
I/O Contract
Inputs
`checkpoint`: Raw state dict with original key naming. Keys may carry an optional `model.diffusion_model.` prefix.
Outputs
`dict[str, torch.Tensor]`: New state dict with Diffusers-compatible key names and correctly shaped tensors.
Conversion Patterns
Pattern 1: Simple Key Rename
```python
# Original:  "time_in.in_layer.weight"
# Diffusers: "time_text_embed.timestep_embedder.linear_1.weight"
converted[new_key] = checkpoint.pop(old_key)
```
Pattern 2: QKV Splitting (Equal)
```python
# Original: fused QKV weight of shape (3*dim, dim)
q, k, v = torch.chunk(checkpoint.pop("attn.qkv.weight"), 3, dim=0)
converted["attn.to_q.weight"] = q  # (dim, dim)
converted["attn.to_k.weight"] = k  # (dim, dim)
converted["attn.to_v.weight"] = v  # (dim, dim)
```
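A quick sanity check of the equal split, assuming `torch` is available: the three chunks have identical shapes and concatenate back to the fused weight exactly.

```python
import torch

dim = 4  # toy size; Flux uses inner_dim = 3072
fused_qkv = torch.arange(3 * dim * dim, dtype=torch.float32).reshape(3 * dim, dim)

q, k, v = torch.chunk(fused_qkv, 3, dim=0)
assert q.shape == k.shape == v.shape == (dim, dim)

# The split is lossless: concatenating in Q, K, V order restores the original.
assert torch.equal(torch.cat([q, k, v], dim=0), fused_qkv)
```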
Pattern 3: QKVM Splitting (Unequal)
```python
# Original: fused Q+K+V+MLP weight of shape (3*dim + mlp_dim, dim)
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(checkpoint.pop("linear1.weight"), split_size, dim=0)
```
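The unequal split can be verified the same way; a toy sketch with scaled-down dimensions (Flux itself uses `inner_dim = 3072` and `mlp_hidden_dim = 12288`, the same 4x ratio):

```python
import torch

inner_dim, mlp_hidden_dim = 8, 32  # toy sizes, same mlp_ratio = 4.0 as Flux
fused = torch.randn(3 * inner_dim + mlp_hidden_dim, inner_dim)

split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
q, k, v, mlp = torch.split(fused, split_size, dim=0)

# First three slices are attention projections, the last is the MLP weight.
assert q.shape == k.shape == v.shape == (inner_dim, inner_dim)
assert mlp.shape == (mlp_hidden_dim, inner_dim)
```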
Pattern 4: Scale-Shift Swap
```python
# Original:  [shift, scale] concatenated
# Diffusers: [scale, shift] concatenated
shift, scale = weight.chunk(2, dim=0)
converted_weight = torch.cat([scale, shift], dim=0)
```
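Since the swap only reorders the two halves, applying it twice is the identity, which makes it easy to test in isolation (a small check, assuming `torch`):

```python
import torch

def swap_scale_shift(weight):
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)

w = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])  # halves: shift=[0,1,2], scale=[3,4,5]
swapped = swap_scale_shift(w)
assert torch.equal(swapped, torch.tensor([3.0, 4.0, 5.0, 0.0, 1.0, 2.0]))

# The swap is its own inverse.
assert torch.equal(swap_scale_shift(swapped), w)
```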
Pattern 5: Prefix Stripping
```python
for k in list(checkpoint.keys()):
    if "model.diffusion_model." in k:
        checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
```
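Because the loop only rewrites keys that contain the prefix, already-clean keys pass through untouched. A dependency-free illustration with string placeholders standing in for tensors:

```python
checkpoint = {
    "model.diffusion_model.img_in.weight": "tensor-A",
    "txt_in.weight": "tensor-B",  # already unprefixed; left alone
}

# Snapshot the keys first, since the dict is mutated during iteration.
for k in list(checkpoint.keys()):
    if "model.diffusion_model." in k:
        checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

assert set(checkpoint) == {"img_in.weight", "txt_in.weight"}
print(sorted(checkpoint))  # ['img_in.weight', 'txt_in.weight']
```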
External Dependencies
`torch` (for `torch.chunk`, `torch.split`, `torch.cat`)
Usage Examples
Converting a Flux Checkpoint
```python
from safetensors.torch import load_file

from diffusers import FluxTransformer2DModel
from diffusers.loaders.single_file_utils import convert_flux_transformer_checkpoint_to_diffusers

# Load the original checkpoint
checkpoint = load_file("flux1-dev.safetensors")

# Convert to Diffusers format
diffusers_state_dict = convert_flux_transformer_checkpoint_to_diffusers(checkpoint)

# Load into a model built from a matching config
# (`config` comes from the model-type inference step; see Related Pages)
model = FluxTransformer2DModel.from_config(config)
model.load_state_dict(diffusers_state_dict)
```
Verifying Conversion Completeness
```python
# After conversion, check for unconverted keys. The converter pops keys
# as it consumes them, so anything left behind was not mapped.
remaining = set(checkpoint.keys())
if remaining:
    print(f"Warning: {len(remaining)} keys not consumed: {remaining}")
```
Related Pages
- Huggingface_Diffusers_Weight_Mapping (principle for this implementation) - Theory of key remapping and tensor operations
- Huggingface_Diffusers_Single_File_Loadable_Classes (dispatches to this) - Registry that selects the conversion function
- Huggingface_Diffusers_Infer_Model_Type (prerequisite) - Model type determines config for conversion
- Huggingface_Diffusers_From_Single_File (caller) - from_single_file invokes the conversion