Implementation:Huggingface Diffusers WanTransformer3DModel Forward
| Field | Value |
|---|---|
| Type | API Doc |
| Overview | The forward pass of WanTransformer3DModel: patch embedding, condition embedding, transformer blocks with 3D attention, and output unpatchifying |
| Domains | Video Generation, 3D Transformers |
| Workflow | Video_Generation |
| Related Principle | Huggingface_Diffusers_Video_Denoising |
| Source | src/diffusers/models/transformers/transformer_wan.py:L507-L709 |
| Last Updated | 2026-02-13 00:00 GMT |
Code Reference
WanTransformer3DModel.forward
Source: src/diffusers/models/transformers/transformer_wan.py:L628-L709
def forward(
    self,
    hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_image: torch.Tensor | None = None,
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
    batch_size, num_channels, num_frames, height, width = hidden_states.shape
    p_t, p_h, p_w = self.config.patch_size
    post_patch_num_frames = num_frames // p_t
    post_patch_height = height // p_h
    post_patch_width = width // p_w
    # 1. Rotary position embeddings
    rotary_emb = self.rope(hidden_states)
    # 2. Patch embedding: (B, C, F, H, W) -> (B, seq_len, inner_dim)
    hidden_states = self.patch_embedding(hidden_states)
    hidden_states = hidden_states.flatten(2).transpose(1, 2)
    # 3. Condition embedding
    temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = \
        self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
    timestep_proj = timestep_proj.unflatten(1, (6, -1))  # (B, 6, inner_dim)
    if encoder_hidden_states_image is not None:
        encoder_hidden_states = torch.concat(
            [encoder_hidden_states_image, encoder_hidden_states], dim=1
        )
    # 4. Transformer blocks
    for block in self.blocks:
        hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
    # 5. Output: norm + projection + unpatchify
    shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
    hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
    hidden_states = self.proj_out(hidden_states)
    # Unpatchify back to (B, C, F, H, W)
    hidden_states = hidden_states.reshape(
        batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
    )
    hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
    output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
    if not return_dict:
        return (output,)
    return Transformer2DModelOutput(sample=output)
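The patch embedding turns each (p_t, p_h, p_w) latent patch into one token, so the sequence length seen by the transformer blocks is the product of the post-patch dimensions. A quick sanity check of that shape arithmetic for the example latent shape (1, 16, 21, 60, 104) with patch_size (1, 2, 2):

```python
# Shape arithmetic for the patchify step (values from the example above)
num_frames, height, width = 21, 60, 104  # latent F, H, W
p_t, p_h, p_w = 1, 2, 2                  # patch_size

post_patch_num_frames = num_frames // p_t  # 21
post_patch_height = height // p_h          # 30
post_patch_width = width // p_w            # 52

# One token per 3D patch -> sequence length for the transformer blocks
seq_len = post_patch_num_frames * post_patch_height * post_patch_width
print(seq_len)  # 32760
```

The unpatchify step at the end of forward inverts exactly this mapping, which is why the output recovers the input's (F, H, W) dimensions.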
Import
from diffusers.models import WanTransformer3DModel
Key Parameters
| Parameter | Type | Description |
|---|---|---|
| hidden_states | torch.Tensor (B, C, F, H, W) | Noisy latent video tensor, e.g., shape (1, 16, 21, 60, 104) for 480x832 at 81 frames |
| timestep | torch.LongTensor (B,) | Current diffusion timestep value |
| encoder_hidden_states | torch.Tensor (B, seq_len, dim) | Text encoder output embeddings (UMT5 hidden states) |
| encoder_hidden_states_image | torch.Tensor or None | CLIP image embeddings for image-to-video (optional, None for T2V) |
| return_dict | bool | Whether to return Transformer2DModelOutput or a tuple |
| attention_kwargs | dict or None | Additional kwargs passed to attention processors (e.g., LoRA scale) |
I/O Contract
Inputs
- hidden_states: 5D tensor (B, C, F, H, W) where C=16 (latent channels), F is the number of latent frames, and H and W are latent spatial dimensions
- timestep: 1D tensor (B,) with float timestep values, or 2D (B, seq_len) for Wan 2.2 TI2V
- encoder_hidden_states: 3D tensor (B, 512, text_dim) with text embeddings
Outputs
Transformer2DModelOutput with a .sample attribute: 5D tensor (B, out_channels, F, H, W) matching the input spatial and temporal dimensions. This is the predicted noise (or velocity) for the current timestep.
Architecture Details
| Component | Configuration (14B) | Configuration (1.3B) |
|---|---|---|
| patch_size | (1, 2, 2) | (1, 2, 2) |
| num_attention_heads | 40 | 12 |
| attention_head_dim | 128 | 128 |
| num_layers | 40 | 30 |
| ffn_dim | 13824 | 8960 |
| in_channels / out_channels | 16 / 16 | 16 / 16 |
| inner_dim | 5120 | 1536 |
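The inner_dim row follows directly from the head configuration: inner_dim = num_attention_heads * attention_head_dim, which is also the embedding width produced by patch_embedding. A one-line check of the table's values:

```python
# inner_dim is the product of head count and per-head dimension
configs = {
    "14B": {"num_attention_heads": 40, "attention_head_dim": 128},
    "1.3B": {"num_attention_heads": 12, "attention_head_dim": 128},
}
inner_dims = {
    name: c["num_attention_heads"] * c["attention_head_dim"]
    for name, c in configs.items()
}
print(inner_dims)  # {'14B': 5120, '1.3B': 1536}
```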
Usage Examples
Direct Forward Pass (Advanced Usage)
import torch
from diffusers.models import WanTransformer3DModel
model = WanTransformer3DModel.from_pretrained(
"Wan-AI/Wan2.1-T2V-14B-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
)
model.to("cuda")
# Example inputs
hidden_states = torch.randn(1, 16, 21, 60, 104, dtype=torch.bfloat16, device="cuda")
timestep = torch.tensor([500.0], device="cuda")
encoder_hidden_states = torch.randn(1, 512, 4096, dtype=torch.bfloat16, device="cuda")
output = model(
hidden_states=hidden_states,
timestep=timestep,
encoder_hidden_states=encoder_hidden_states,
)
noise_pred = output.sample # Shape: (1, 16, 21, 60, 104)
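The example latent shape (1, 16, 21, 60, 104) for a 480x832, 81-frame video follows from the Wan VAE's compression factors. A sketch of that mapping, assuming a temporal stride of 4 (with the first frame kept) and a spatial stride of 8, which are the factors used by the Wan 2.1 VAE:

```python
def latent_shape(num_frames: int, height: int, width: int,
                 vae_t: int = 4, vae_s: int = 8, latent_channels: int = 16):
    """Map pixel-space video dims to Wan latent dims (assumed VAE compression factors)."""
    # First frame is encoded on its own, remaining frames are compressed 4x temporally
    num_latent_frames = (num_frames - 1) // vae_t + 1
    return (latent_channels, num_latent_frames, height // vae_s, width // vae_s)

print(latent_shape(81, 480, 832))  # (16, 21, 60, 104)
```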
Within Pipeline Denoising Loop
# This is what happens inside WanPipeline.__call__:
for i, t in enumerate(timesteps):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
with current_model.cache_context("cond"):
noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_classifier_free_guidance:
with current_model.cache_context("uncond"):
noise_uncond = current_model(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
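The classifier-free guidance line above extrapolates from the unconditional prediction toward the conditional one; with guidance_scale = 1.0 it reduces to the conditional prediction alone. A minimal numeric check of the formula using plain floats:

```python
# CFG: guided = uncond + scale * (cond - uncond)
noise_cond = [1.0, 2.0]
noise_uncond = [0.0, 0.0]
guidance_scale = 5.0

guided = [u + guidance_scale * (c - u) for c, u in zip(noise_cond, noise_uncond)]
print(guided)  # [5.0, 10.0]
```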
Related Pages
- Huggingface_Diffusers_Video_Denoising (principle for this implementation) - Theory of 3D video denoising
- Huggingface_Diffusers_Video_Pipeline_From_Pretrained (loads this model) - Pipeline loads the transformer
- Huggingface_Diffusers_Video_Memory_Setup (optimizes this) - CPU offloading manages transformer GPU placement