
Implementation:Zai org CogVideo VideoDecoder Temporal

From Leeroopedia
Revision as of 17:09, 16 February 2026 by Admin (auto-imported from implementations/Zai_org_CogVideo_VideoDecoder_Temporal.md)


Knowledge Sources
Domains: Video_Generation, Autoencoding, Temporal_Modeling
Last Updated: 2026-02-10 00:00 GMT

Overview

VideoDecoder extends a 2D image autoencoder decoder with temporal (video) capabilities by injecting time-aware residual blocks, 3D convolutions, and temporal attention layers that process frame sequences using configurable alpha-blending merge strategies.

Description

The VideoDecoder class extends the base Decoder from the diffusion modules and transforms it into a video-capable decoder through a set of modular temporal components. It overrides three factory methods -- _make_conv, _make_resblock, and _make_attn -- to inject temporal processing based on a configurable time_mode parameter that supports three modes: "all" (both temporal convolutions and attention), "conv-only" (temporal convolutions only), and "attn-only" (temporal attention only).
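The dispatch performed by the three factory methods can be summarized with a small table-driven sketch. It returns component names as strings instead of building modules, and the helper temporal_components is hypothetical. The mapping follows the mode descriptions above, assuming the VideoResBlock residual blocks count as "temporal convolutions" (they mix frames with a 3D ResBlock) and that "all" may use either VideoBlock or its memory-efficient variant for attention.

```python
# Hypothetical sketch: which building blocks VideoDecoder's factory methods
# would choose for each time_mode. Strings stand in for the real classes.
AVAILABLE_TIME_MODES = ("all", "conv-only", "attn-only")


def temporal_components(time_mode: str) -> dict:
    """Map a time_mode to the conv / resblock / attn choice it implies."""
    if time_mode not in AVAILABLE_TIME_MODES:
        raise ValueError(f"time_mode must be one of {AVAILABLE_TIME_MODES}")
    use_temporal_conv = time_mode in ("all", "conv-only")
    use_temporal_attn = time_mode in ("all", "attn-only")
    return {
        # _make_conv: Conv2d followed by a Conv3d time mix, or a plain 2D conv
        "conv": "AE3DConv" if use_temporal_conv else "Conv2d",
        # _make_resblock: spatial ResnetBlock wrapped with a 3D ResBlock
        "resblock": "VideoResBlock" if use_temporal_conv else "ResnetBlock",
        # _make_attn: spatial attention plus temporal attention
        # (VideoBlock or MemoryEfficientVideoBlock in the real code)
        "attn": "VideoBlock" if use_temporal_attn else "AttnBlock",
    }


for mode in AVAILABLE_TIME_MODES:
    print(mode, temporal_components(mode))
```

Table-driven dispatch keeps the mode check in one place, which mirrors how the class validates time_mode against available_time_modes before any layer is built.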

The temporal components include:

  • VideoResBlock: Wraps a spatial ResnetBlock with an additional 3D ResBlock that operates along the time dimension. The spatial and temporal outputs are blended via a learnable or fixed alpha mixing factor, allowing gradual activation of temporal processing during training.
  • AE3DConv: Extends Conv2d with an additional Conv3d layer for temporal mixing. The 2D convolution processes spatial features first, then the 3D convolution mixes across the time dimension.
  • VideoBlock / MemoryEfficientVideoBlock: Add temporal transformer attention on top of spatial attention, injecting sinusoidal embeddings of each frame's index into the features before the temporal attention. The xformers-based MemoryEfficientVideoBlock computes the same attention with lower memory use, which matters for long frame sequences.
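The frame-index signal that VideoBlock adds before temporal attention can be illustrated with a pure-Python sinusoidal embedding. This sketch assumes the layout common to latent-diffusion code (cosines first, then sines, max_period=10000); the function name, the ordering, and the default are assumptions for illustration. The real block works on tensors and may project the embedding further before adding it to the features.

```python
import math


def frame_index_embedding(frame: int, dim: int, max_period: float = 10000.0) -> list:
    """Sinusoidal embedding of a frame index (assumed layout: cosines, then sines)."""
    half = dim // 2
    # Geometrically spaced frequencies from 1 down towards 1 / max_period
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [frame * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:  # pad odd widths with a trailing zero
        emb.append(0.0)
    return emb


# One embedding per frame of a 16-frame clip, each of width 8.
clip = [frame_index_embedding(t, dim=8) for t in range(16)]
```

Because every frame in a clip receives a distinct embedding, the temporal attention can tell frames apart even though spatial features alone carry no ordering.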

The merge_strategy parameter controls how spatial and temporal features are combined: "fixed" uses a constant alpha registered as a buffer, while "learned" uses a sigmoid-gated parameter that is trained end-to-end.
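The two strategies differ only in how the stored mixing factor becomes a blend weight. In the scalar sketch below, alpha weights the temporal output, as described for VideoResBlock; which path alpha multiplies in each block type is an assumption worth checking against the source. Under "learned", the stored value is a logit, so an initial 0.0 gives sigmoid(0) = 0.5 rather than a purely spatial output.

```python
import math


def get_alpha(mix_factor: float, merge_strategy: str) -> float:
    """Effective blend weight for a stored mix_factor."""
    if merge_strategy == "fixed":
        return mix_factor  # constant buffer, used as-is
    if merge_strategy == "learned":
        return 1.0 / (1.0 + math.exp(-mix_factor))  # sigmoid keeps it in (0, 1)
    raise ValueError(f"unknown merge strategy {merge_strategy!r}")


def blend(spatial: float, temporal: float, alpha: float) -> float:
    """Convex combination; in this sketch alpha weights the temporal path."""
    return alpha * temporal + (1.0 - alpha) * spatial


print(blend(2.0, 10.0, get_alpha(0.0, "fixed")))    # spatial only
print(blend(2.0, 10.0, get_alpha(0.0, "learned")))  # even mix
```

The sigmoid gate lets gradient descent move the blend smoothly in either direction while never leaving the valid (0, 1) range, whereas "fixed" pins the blend for the whole run.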

Usage

Use VideoDecoder when building a video autoencoder that needs to decode frame sequences from latent representations. It is the primary decoder architecture for converting 2D autoencoder checkpoints into video-capable decoders, leveraging the alpha-blending mechanism to enable smooth fine-tuning from pretrained image models to video models.

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/autoencoding/temporal_ae.py
  • Lines: 281-333 (VideoDecoder), 18-81 (VideoResBlock), 84-105 (AE3DConv), 108-171 (VideoBlock), 174-237 (MemoryEfficientVideoBlock), 240-273 (make_time_attn)

Signature

class VideoDecoder(Decoder):
    available_time_modes = ["all", "conv-only", "attn-only"]

    def __init__(
        self,
        *args,
        video_kernel_size: Union[int, list] = 3,
        alpha: float = 0.0,
        merge_strategy: str = "learned",
        time_mode: str = "conv-only",
        **kwargs,
    ):
        ...

    def get_last_layer(self, skip_time_mix=False, **kwargs):
        ...

    def _make_attn(self) -> Callable:
        ...

    def _make_conv(self) -> Callable:
        ...

    def _make_resblock(self) -> Callable:
        ...


class VideoResBlock(ResnetBlock):
    def __init__(
        self,
        out_channels,
        *args,
        dropout=0.0,
        video_kernel_size=3,
        alpha=0.0,
        merge_strategy="learned",
        **kwargs,
    ):
        ...

    def forward(self, x, temb, skip_video=False, timesteps=None):
        ...


class AE3DConv(torch.nn.Conv2d):
    def __init__(self, in_channels, out_channels, video_kernel_size=3, *args, **kwargs):
        ...

    def forward(self, input, timesteps, skip_video=False):
        ...

Import

from sat.sgm.modules.autoencoding.temporal_ae import VideoDecoder, VideoResBlock, AE3DConv, VideoBlock

I/O Contract

VideoDecoder Inputs

Name Type Required Description
video_kernel_size Union[int, list] No (default 3) Kernel size for temporal convolutions; can be a list like [3, 1, 1]
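alpha float No (default 0.0) Initial value of the spatial-temporal mixing factor. With merge_strategy="fixed" it is used directly (0.0 = fully spatial in the residual blocks); with "learned" it passes through a sigmoid, so 0.0 starts at a 0.5 spatial/temporal blend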
merge_strategy str No (default "learned") How to combine spatial and temporal features: "fixed" or "learned"
time_mode str No (default "conv-only") Which temporal components to activate: "all", "conv-only", or "attn-only"
*args, **kwargs - - All remaining arguments are passed to the base Decoder class
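A list-valued video_kernel_size gives separate kernel extents for the (time, height, width) axes of the temporal Conv3d, so [3, 1, 1] mixes each frame with its two neighbours without further spatial mixing. The sketch below assumes "same"-style padding of k // 2 on each axis, an assumption consistent with the decoder keeping the frame count unchanged, and checks it with the standard convolution output-length formula. Both helper names are hypothetical.

```python
from typing import List, Sequence, Union


def same_padding(video_kernel_size: Union[int, Sequence[int]]) -> Union[int, List[int]]:
    """Per-axis padding k // 2 that preserves (T, H, W) for odd kernels."""
    if isinstance(video_kernel_size, int):
        return video_kernel_size // 2
    return [k // 2 for k in video_kernel_size]


def conv_out_len(n: int, k: int, p: int, stride: int = 1) -> int:
    """Standard convolution output length along one axis."""
    return (n + 2 * p - k) // stride + 1


# A [3, 1, 1] kernel mixes neighbouring frames but leaves pixels unmixed.
pad = same_padding([3, 1, 1])          # [1, 0, 0]
t_out = conv_out_len(16, 3, pad[0])    # 16 frames in, 16 frames out
```

Odd kernel sizes are what make k // 2 padding exactly length-preserving; an even kernel would shrink or shift the time axis by one.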

VideoResBlock Forward Inputs

Name Type Required Description
x torch.Tensor Yes Input tensor of shape (B*T, C, H, W) where T is the number of frames
temb torch.Tensor Yes Timestep embedding tensor (can be None)
skip_video bool No (default False) If True, skip temporal processing and behave as a spatial-only block
timesteps int No Number of frames; defaults to self.timesteps if not provided
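Because frames arrive flattened into the batch axis, the temporal layers rely on a fixed ordering: all T frames of the first video, then all T frames of the second, and so on (the "(b t)" grouping that einops-style reshapes assume). The sketch below uses plain Python lists in place of tensors and hypothetical helper names to show that index arithmetic. It also shows why timesteps must match the true clip length: any value that divides the batch size is accepted, but a wrong one regroups frames incorrectly.

```python
def flat_index(b: int, t: int, timesteps: int) -> int:
    """Row of frame t of video b in a (B*T, C, H, W) batch."""
    return b * timesteps + t


def split_index(row: int, timesteps: int) -> tuple:
    """Inverse mapping: which (video, frame) a flat batch row belongs to."""
    return divmod(row, timesteps)


def group_frames(rows: list, timesteps: int) -> list:
    """Regroup a flat per-frame list into one list of frames per video."""
    if len(rows) % timesteps:
        raise ValueError("batch size must be a multiple of timesteps")
    return [rows[i : i + timesteps] for i in range(0, len(rows), timesteps)]


# 4 videos x 16 frames = 64 rows, matching the usage example's batch.
videos = group_frames(list(range(64)), timesteps=16)
```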

Outputs

Name Type Description
output torch.Tensor Decoded tensor of shape (B*T, C_out, H_out, W_out) with temporal information blended via alpha mixing
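get_last_layer torch.Tensor Returns the weight of the final layer's time_mix_conv (default) or, with skip_time_mix=True, the spatial conv_out weight; only meaningful when time_mode is not "attn-only", since that mode builds no temporal convolutions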
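H_out and W_out follow from the base Decoder's upsampling schedule. Assuming every resolution level except the lowest upsamples by 2 (an assumption about the base Decoder, consistent with the 32 to 256 shapes in the usage example for ch_mult=(1, 2, 4, 4)), the spatial size grows by 2 ** (len(ch_mult) - 1) while the B*T axis is unchanged. The helper name is hypothetical.

```python
def decoder_output_shape(latent_shape: tuple, out_ch: int, ch_mult: tuple) -> tuple:
    """Output shape of the decoder for a (B*T, z_channels, h, w) latent."""
    bt, _, h, w = latent_shape
    scale = 2 ** (len(ch_mult) - 1)  # one 2x upsample per level except the lowest
    return (bt, out_ch, h * scale, w * scale)


print(decoder_output_shape((64, 4, 32, 32), out_ch=3, ch_mult=(1, 2, 4, 4)))
# (64, 3, 256, 256)
```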

Usage Examples

import torch

from sat.sgm.modules.autoencoding.temporal_ae import VideoDecoder

# Create a video decoder with learned temporal blending
decoder = VideoDecoder(
    ch=128,
    out_ch=3,
    ch_mult=(1, 2, 4, 4),
    num_res_blocks=2,
    attn_resolutions=[],
    dropout=0.0,
    in_channels=4,
    resolution=256,
    z_channels=4,
    video_kernel_size=3,
    alpha=0.0,           # "learned": sigmoid(0.0) = 0.5, an even spatial/temporal blend at init
    merge_strategy="learned",
    time_mode="conv-only",
)

# Decode a batch of 4 videos, each with 16 frames
# Input shape: (B*T, C, H, W) = (64, 4, 32, 32)
z = torch.randn(64, 4, 32, 32)
decoded = decoder(z, timesteps=16)
# Output shape: (64, 3, 256, 256)

# Access the last layer for training
last_layer = decoder.get_last_layer()
