
Implementation:Huggingface Diffusers WanTransformer3DModel Forward

From Leeroopedia
Field Value
Type API Doc
Overview The forward pass of WanTransformer3DModel: patch embedding, condition embedding, transformer blocks with 3D attention, and unpatchification of the output
Domains Video Generation, 3D Transformers
Workflow Video_Generation
Related Principle Huggingface_Diffusers_Video_Denoising
Source src/diffusers/models/transformers/transformer_wan.py:L507-L709
Last Updated 2026-02-13 00:00 GMT

Code Reference

WanTransformer3DModel.forward

Source: src/diffusers/models/transformers/transformer_wan.py:L628-L709

def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: torch.Tensor | None = None,
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
    batch_size, num_channels, num_frames, height, width = hidden_states.shape
    p_t, p_h, p_w = self.config.patch_size
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p_h
    post_patch_width = width // p_w

    # 1. Rotary position embeddings
    rotary_emb = self.rope(hidden_states)

    # 2. Patch embedding: (B, C, F, H, W) -> (B, seq_len, inner_dim)
    hidden_states = self.patch_embedding(hidden_states)
    hidden_states = hidden_states.flatten(2).transpose(1, 2)

    # 3. Condition embedding
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = \
        self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)

    timestep_proj = timestep_proj.unflatten(1, (6, -1))  # (B, 6, inner_dim)

    if encoder_hidden_states_image is not None:
        encoder_hidden_states = torch.concat(
            [encoder_hidden_states_image, encoder_hidden_states], dim=1
        )

    # 4. Transformer blocks
    for block in self.blocks:
        hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

    # 5. Output: norm + projection + unpatchify
    shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
    hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
    hidden_states = self.proj_out(hidden_states)

    # Unpatchify back to (B, C, F, H, W)
    hidden_states = hidden_states.reshape(
        batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
    )
    hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
    output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
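The unpatchify sequence at the end (reshape, permute, three flattens) can be illustrated in isolation. The sketch below reproduces the same operation order with NumPy on a tiny toy tensor; all shapes here are made-up miniature values, not the real model's:

```python
import numpy as np

# Toy sizes: batch 1, 3 output channels, 2 latent frames, 4x4 latent grid, patch (1, 2, 2)
B, C_out = 1, 3
F, H, W = 2, 4, 4
p_t, p_h, p_w = 1, 2, 2
pf, ph, pw = F // p_t, H // p_h, W // p_w  # post-patch sizes: 2, 2, 2
seq_len = pf * ph * pw                     # tokens per sample: 8

# Stand-in for the proj_out output: (B, seq_len, p_t * p_h * p_w * C_out)
tokens = np.arange(B * seq_len * p_t * p_h * p_w * C_out, dtype=np.float32)
tokens = tokens.reshape(B, seq_len, p_t * p_h * p_w * C_out)

# Same sequence of ops as in forward():
x = tokens.reshape(B, pf, ph, pw, p_t, p_h, p_w, C_out)
x = x.transpose(0, 7, 1, 4, 2, 5, 3, 6)    # NumPy's equivalent of torch permute
# flatten(6, 7), flatten(4, 5), flatten(2, 3): merge each patch dim back into F, H, W.
# Each flatten merges adjacent dims, so one reshape over the merged sizes is equivalent.
out = x.reshape(B, C_out, pf * p_t, ph * p_h, pw * p_w)

print(out.shape)  # (1, 3, 2, 4, 4) -- channels-first video layout, matching the input
```

The key point is that the permute interleaves patch dims next to their grid dims (frame/p_t, height/p_h, width/p_w) so the flattens restore full spatial and temporal resolution.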

Import

from diffusers.models import WanTransformer3DModel

Key Parameters

Parameter Type Description
hidden_states torch.Tensor (B,C,F,H,W) Noisy latent video tensor, e.g., shape (1, 16, 21, 60, 104) for 480x832 at 81 frames
timestep torch.LongTensor (B,) Current diffusion timestep value
encoder_hidden_states torch.Tensor (B, seq_len, dim) Text encoder output embeddings (UMT5 hidden states)
encoder_hidden_states_image torch.Tensor | None CLIP image embeddings for image-to-video (optional; None for T2V)
return_dict bool Whether to return Transformer2DModelOutput or tuple
attention_kwargs dict[str, Any] | None Additional kwargs passed to attention processors (e.g., LoRA scale)

I/O Contract

Inputs

  • hidden_states: 5D tensor (B, C, F, H, W) where C=16 (latent channels), F=num_latent_frames, H and W are latent spatial dimensions
  • timestep: 1D tensor (B,) with float timestep values, or 2D (B, seq_len) for Wan 2.2 TI2V
  • encoder_hidden_states: 3D tensor (B, 512, text_dim) with text embeddings

Outputs

  • Transformer2DModelOutput with .sample attribute: 5D tensor (B, out_channels, F, H, W) matching the input spatial and temporal dimensions. This is the predicted noise (or velocity) for the current timestep.
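This shape contract can be checked with plain arithmetic. For the example latent shape (1, 16, 21, 60, 104) (480x832 video at 81 frames) and patch_size (1, 2, 2), the token sequence length seen by the transformer blocks is:

```python
# Latent dimensions for a 480x832, 81-frame video after the VAE: (1, 16, 21, 60, 104)
F, H, W = 21, 60, 104
p_t, p_h, p_w = 1, 2, 2  # patch_size from the model config

post_f = F // p_t   # 21 latent frames -> 21 temporal patches
post_h = H // p_h   # 60 -> 30
post_w = W // p_w   # 104 -> 52

seq_len = post_f * post_h * post_w
print(seq_len)  # 32760 tokens per sample in the transformer blocks
```

Since unpatchify inverts this exactly, the output tensor recovers the input's (F, H, W) dimensions.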

Architecture Details

Component Configuration (14B) Configuration (1.3B)
patch_size (1, 2, 2) (1, 2, 2)
num_attention_heads 40 12
attention_head_dim 128 128
num_layers 40 30
ffn_dim 13824 8960
in_channels / out_channels 16 / 16 16 / 16
inner_dim 5120 1536
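The inner_dim row follows from the two rows above it: inner_dim = num_attention_heads * attention_head_dim. A quick consistency check over the table's values:

```python
# Values copied from the architecture table above.
configs = {
    "14B":  {"num_attention_heads": 40, "attention_head_dim": 128, "inner_dim": 5120},
    "1.3B": {"num_attention_heads": 12, "attention_head_dim": 128, "inner_dim": 1536},
}

for name, cfg in configs.items():
    derived = cfg["num_attention_heads"] * cfg["attention_head_dim"]
    assert derived == cfg["inner_dim"], f"{name}: {derived} != {cfg['inner_dim']}"
    print(name, derived)  # 14B -> 5120, 1.3B -> 1536
```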

Usage Examples

Direct Forward Pass (Advanced Usage)

import torch
from diffusers.models import WanTransformer3DModel

model = WanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.to("cuda")

# Example inputs
hidden_states = torch.randn(1, 16, 21, 60, 104, dtype=torch.bfloat16, device="cuda")
timestep = torch.tensor([500.0], device="cuda")
encoder_hidden_states = torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda")

output = model(
    hidden_states=hidden_states,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
)
noise_pred = output.sample  # Shape: (1, 16, 21, 60, 104)

Within Pipeline Denoising Loop

# This is what happens inside WanPipeline.__call__:
for i, t in enumerate(timesteps):
    latent_model_input = latents.to(transformer_dtype)
    timestep = t.expand(latents.shape[0])

    with current_model.cache_context("cond"):
        noise_pred = current_model(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]

    if do_classifier_free_guidance:
        with current_model.cache_context("uncond"):
            noise_uncond = current_model(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=negative_prompt_embeds,
                return_dict=False,
            )[0]
        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

    latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
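The guidance line near the end of the loop is plain arithmetic on the two predictions: extrapolate from the unconditional prediction toward the conditional one. A minimal numeric sketch with toy scalars (not real model outputs):

```python
# Classifier-free guidance combination, as in the loop above.
guidance_scale = 5.0
noise_cond = 1.0    # toy scalar standing in for the conditional prediction
noise_uncond = 0.2  # toy scalar standing in for the unconditional prediction

noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
print(noise_pred)  # 4.2; with guidance_scale == 1.0 this reduces to noise_cond
```

On real tensors the same expression is applied elementwise, which is why the conditional and unconditional passes must share latents and timestep.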

Related Pages

Principle:Huggingface_Diffusers_Video_Denoising
