

Implementation:Alibaba ROLL WanTrainingModule



Domains: Diffusion_Models, Model_Architecture
Last Updated: 2026-02-07 20:00 GMT

Overview

Concrete Wan video diffusion training module with LoRA injection and reward scoring provided by the Alibaba ROLL library.

Description

The WanTrainingModule class extends DiffusionTrainingModule with Wan2.2-specific initialization: it loads the DiT model, injects LoRA adapters via PEFT, keeps the VAE and text encoder frozen, sets up an Euler ODE scheduler, and creates a FaceAnalysis reward scorer.
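The `lora_target_modules` default is a comma-separated string. Below is a minimal sketch of how such a string could select layers, assuming it is split on commas and matched the way PEFT matches a list of `target_modules` (a module is selected when its full name equals a key or ends with `.` plus the key). The module names are hypothetical, loosely modeled on a DiT block; they are not read from the ROLL or Wan code.

```python
def parse_target_modules(spec: str) -> list[str]:
    """Split a comma-separated LoRA target string into stripped, non-empty keys."""
    return [key.strip() for key in spec.split(",") if key.strip()]


def matches_target(module_name: str, targets: list[str]) -> bool:
    """PEFT-style list matching: exact name, or name ending in '.' + key."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


def select_lora_modules(module_names: list[str], spec: str) -> list[str]:
    """Return the module names that would receive a LoRA adapter (assumed rule)."""
    targets = parse_target_modules(spec)
    return [name for name in module_names if matches_target(name, targets)]


if __name__ == "__main__":
    # Hypothetical module names for a single transformer block.
    names = [
        "blocks.0.self_attn.q",
        "blocks.0.self_attn.k",
        "blocks.0.self_attn.v",
        "blocks.0.self_attn.o",
        "blocks.0.self_attn.norm_q",  # ends in "norm_q", not ".q": not selected
        "blocks.0.ffn.0",             # first Linear of the feed-forward network
        "blocks.0.ffn.1",             # activation: not selected
        "blocks.0.ffn.2",             # second Linear of the feed-forward network
    ]
    print(select_lora_modules(names, "q,k,v,o,ffn.0,ffn.2"))
```

One consequence of suffix matching: a short key such as `q` selects every module whose last name component is `q`, so if a block also contained, say, a cross-attention `q` projection, it would receive an adapter too.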

Usage

Instantiated by the diffusion module provider during cluster initialization.

Code Reference

Source Location

  • Repository: Alibaba ROLL
  • File: roll/pipeline/diffusion/modules/wan_module.py
  • Lines: L98-174

Signature

class WanTrainingModule(DiffusionTrainingModule):
    def __init__(
        self,
        model_paths,
        reward_model_path,
        tokenizer_path,
        trainable_models,
        lora_target_modules: str = "q,k,v,o,ffn.0,ffn.2",
        lora_rank: int = 32,
        use_gradient_checkpointing: bool = True,
        num_inference_steps: int = 8,
        mid_timestep: int = 4,
        final_timestep: int = 7,
        **kwargs
    ) -> None:
        """
        Initialize Wan video training module.

        Args:
            model_paths: Paths to DiT, VAE, text encoder components
            reward_model_path: Path to ONNX face analysis models
            tokenizer_path: Path to text tokenizer
            trainable_models: Which components to train (usually ["dit"])
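            lora_target_modules: Comma-separated names of layers that receive LoRA adapters
            lora_rank: LoRA rank (default 32)
            use_gradient_checkpointing: Enable gradient checkpointing (default True)
            num_inference_steps: Euler denoising steps (default 8)
            mid_timestep: First step with gradients enabled (default 4; with 8 steps,
                steps 0-3 run without gradients and steps 4-7 with gradients)
            final_timestep: Last denoising step (default 7)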
        """
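The interaction between `num_inference_steps`, `mid_timestep`, and `final_timestep` can be illustrated with a toy scalar Euler ODE loop. This is a sketch under stated assumptions, not ROLL's sampler: noise levels are linearly spaced from 1 to 0 with no shift, a flow-matching convention `x_sigma = (1 - sigma) * data + sigma * noise` is assumed, a constant velocity function stands in for the DiT, and a step is assumed to run with gradients when `mid_timestep <= step <= final_timestep`, matching the docstring's "steps 0-3 frozen, 4-7 grad". The function name is hypothetical.

```python
from typing import Callable


def euler_sample_with_grad_window(
    x0: float,
    velocity: Callable[[float, float], float],
    num_inference_steps: int = 8,
    mid_timestep: int = 4,
    final_timestep: int = 7,
) -> tuple[float, list[int]]:
    """Toy scalar Euler ODE integration from sigma=1 (noise) to sigma=0 (data).

    Returns the final sample and the step indices that would run with
    gradients enabled under the assumed rule mid_timestep <= step <= final_timestep.
    """
    # Linearly spaced noise levels 1.0 -> 0.0 (assumed schedule, no shift).
    sigmas = [1.0 - i / num_inference_steps for i in range(num_inference_steps + 1)]
    x = x0
    grad_steps = []
    for step in range(num_inference_steps):
        if mid_timestep <= step <= final_timestep:
            # In a real trainer this step would run with autograd enabled.
            grad_steps.append(step)
        v = velocity(x, sigmas[step])
        # Explicit Euler update: x_{i+1} = x_i + (sigma_{i+1} - sigma_i) * v
        x = x + (sigmas[step + 1] - sigmas[step]) * v
    return x, grad_steps


if __name__ == "__main__":
    noise, data = 0.5, 2.0
    # Under the assumed convention, dx/dsigma = noise - data (constant here).
    x, steps = euler_sample_with_grad_window(noise, lambda x, s: noise - data)
    print(x, steps)
```

With a constant velocity `noise - data`, the Euler increments sum exactly to `data - noise`, so starting from `0.5` the loop returns `2.0`, and with the defaults the gradient-enabled steps are `[4, 5, 6, 7]`.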

Import

from roll.pipeline.diffusion.modules.wan_module import WanTrainingModule

I/O Contract

Inputs

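Name                        Type  Required  Description
model_paths                 dict  Yes       Paths to the DiT, VAE, and text encoder components
reward_model_path           str   Yes       Path to the ONNX face analysis models
tokenizer_path              str   Yes       Path to the text tokenizer
trainable_models            list  Yes       Components to train (usually ["dit"])
lora_target_modules         str   No        Comma-separated LoRA target layers (default "q,k,v,o,ffn.0,ffn.2")
lora_rank                   int   No        LoRA rank (default 32)
use_gradient_checkpointing  bool  No        Enable gradient checkpointing (default True)
num_inference_steps         int   No        Euler denoising steps (default 8)
mid_timestep                int   No        First gradient-enabled step (default 4)
final_timestep              int   No        Last denoising step (default 7)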
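To put `lora_rank` in perspective: LoRA adds two low-rank factors, A of shape (rank, d_in) and B of shape (d_out, rank), alongside a frozen Linear layer, so each adapted layer gains `rank * (d_in + d_out)` trainable parameters. The sketch below computes this; the 5120 hidden size used in the example is a hypothetical value chosen for illustration, not taken from the Wan2.2 configuration.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Extra trainable parameters LoRA adds to one Linear(d_in, d_out).

    LoRA learns A (rank x d_in) and B (d_out x rank); the frozen weight
    W (d_out x d_in) is left untouched, so only A and B are trained.
    """
    return rank * d_in + d_out * rank


def lora_fraction(d_in: int, d_out: int, rank: int) -> float:
    """LoRA parameters as a fraction of the frozen weight matrix size."""
    return lora_param_count(d_in, d_out, rank) / (d_in * d_out)


if __name__ == "__main__":
    # Hypothetical square 5120 x 5120 attention projection at the default rank 32.
    print(lora_param_count(5120, 5120, 32))  # 327680
    print(lora_fraction(5120, 5120, 32))     # 0.0125
```

For a square 5120 x 5120 projection at rank 32, LoRA adds 327,680 parameters, 1.25% of the frozen weight; the cost grows linearly with the rank.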

Outputs

Name               Type               Description
WanTrainingModule  WanTrainingModule  Initialized module with LoRA-injected DiT and reward scorer
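The page does not describe how the FaceAnalysis scorer turns face detections into a reward. One common pattern for face-based rewards, shown here only as a hypothetical sketch, is identity similarity: the mean cosine similarity between a reference face embedding and the face embedding of each generated frame. The function names, the 0.0 score for frames without a detected face, and the averaging are all assumptions; in practice the embeddings would come from the ONNX face models.

```python
import math
from typing import Optional, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity of two vectors; 0.0 if either has zero norm."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def face_identity_reward(
    ref_embedding: Sequence[float],
    frame_embeddings: Sequence[Optional[Sequence[float]]],
) -> float:
    """Hypothetical reward: mean cosine similarity to the reference face.

    Frames where no face was detected are passed as None and score 0.0
    (an assumed policy); an empty video scores 0.0.
    """
    if not frame_embeddings:
        return 0.0
    scores = [
        0.0 if emb is None else cosine_similarity(ref_embedding, emb)
        for emb in frame_embeddings
    ]
    return sum(scores) / len(scores)


if __name__ == "__main__":
    ref = [1.0, 0.0]
    frames = [[1.0, 0.0], [0.0, 1.0], None]  # match, orthogonal face, no face
    print(face_identity_reward(ref, frames))
```

Averaging over frames rewards identity consistency across the whole clip rather than in a single frame; a real scorer might instead weight frames or use a minimum.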

Usage Examples

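from roll.pipeline.diffusion.modules.wan_module import WanTrainingModule

# wan22_paths: paths to the Wan2.2 DiT, VAE, and text encoder components (defined elsewhere)
module = WanTrainingModule(
    model_paths=wan22_paths,
    reward_model_path="./face_models",  # ONNX face analysis models
    tokenizer_path="./tokenizer",
    trainable_models=["dit"],  # train only the DiT (through its LoRA adapters)
    lora_rank=32,
    num_inference_steps=8,
    mid_timestep=4,  # steps 4-7 run with gradients; final_timestep keeps its default of 7
)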

