Implementation:Zai org CogVideo Autoencoder Model

From Leeroopedia
Revision as of 17:08, 16 February 2026 by Admin (auto-imported from implementations/Zai_org_CogVideo_Autoencoder_Model.md)


Knowledge Sources
Domains: Video_Generation, Autoencoding
Last Updated: 2026-02-10 00:00 GMT

Overview

Implements the autoencoder architecture used for the VAE latent space in CogVideo's diffusion pipeline: an Encoder, a Decoder, and a full U-Net Model. All three support multiple attention backends, including vanilla, xformers, and linear attention.

Description

This module provides the core autoencoder components that map between pixel space and the latent space where the diffusion process operates. It contains three primary classes:

Encoder progressively downsamples the input through one resolution level per entry in ch_mult, halving H and W between consecutive levels. Each level uses ResnetBlock layers (GroupNorm, Swish, 3x3 convolution) and, at the spatial sizes listed in attn_resolutions, attention blocks. The attention type is selected via make_attn, which supports "vanilla" (PyTorch 2.0 scaled_dot_product_attention), "vanilla-xformers" (the xformers memory-efficient implementation), "linear" attention, and "none" (identity). A middle section at the bottleneck contains ResnetBlock, Attention, ResnetBlock. The final convolution projects to 2*z_channels channels when double_z=True and to z_channels otherwise.
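The residual block and the between-level downsampling described above can be sketched in plain PyTorch. This is a minimal illustration, not the module's code: the names ResnetBlockSketch, DownsampleSketch, and normalize are hypothetical, the timestep-embedding input of the real ResnetBlock is omitted, and the 32-group GroupNorm and the right/bottom padding before the stride-2 convolution are assumptions of the sketch. The with_conv flag plays the role of resamp_with_conv.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalize(channels: int, num_groups: int = 32) -> nn.GroupNorm:
    # GroupNorm before each convolution; 32 groups is an assumption of this sketch
    return nn.GroupNorm(num_groups=num_groups, num_channels=channels, eps=1e-6, affine=True)


class ResnetBlockSketch(nn.Module):
    """(GroupNorm -> Swish -> 3x3 conv) twice, plus a residual shortcut."""

    def __init__(self, in_channels: int, out_channels: int, dropout: float = 0.0):
        super().__init__()
        self.norm1 = normalize(in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = normalize(out_channels)
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # 1x1 projection so the shortcut matches when the channel count changes
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv1(F.silu(self.norm1(x)))  # F.silu is Swish: x * sigmoid(x)
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))
        return self.shortcut(x) + h


class DownsampleSketch(nn.Module):
    """Halves H and W between levels: strided 3x3 conv, or 2x2 average pooling."""

    def __init__(self, channels: int, with_conv: bool = True):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.with_conv:
            x = F.pad(x, (0, 1, 0, 1))  # pad right/bottom only so 2k -> k exactly
            return self.conv(x)
        return F.avg_pool2d(x, kernel_size=2, stride=2)


# One encoder level: 64 -> 128 channels, then 32x32 -> 16x16
x = torch.randn(2, 64, 32, 32)
y = DownsampleSketch(128)(ResnetBlockSketch(64, 128)(x))
```

The strided convolution keeps downsampling learnable, while average pooling adds no parameters; the sketch exposes both so the trade-off is visible.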

Decoder mirrors the encoder with progressive upsampling. It builds its layers through factory methods (_make_attn, _make_resblock, _make_conv) that subclasses can override to substitute custom attention, residual-block, or convolution layers. The decoder can optionally apply tanh to the final output (tanh_out=True), and the give_pre_end flag makes it return features before the final normalization, nonlinearity, and convolution.
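The factory-method pattern, together with give_pre_end and tanh_out, can be illustrated with a heavily reduced decoder. Everything here is hypothetical: ToyDecoder, WeightStdConv2d, and WSDecoder are not part of the module, the toy has a single upsampling step, and the weight-standardized convolution is simply an arbitrary example of a replacement layer. The point is that a subclass overrides only the factory, and the inherited constructor then builds the substituted layer.

```python
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyDecoder(nn.Module):
    """Hypothetical, heavily reduced decoder illustrating an overridable factory hook."""

    def __init__(self, z_channels: int = 4, hidden: int = 32, out_ch: int = 3,
                 give_pre_end: bool = False, tanh_out: bool = False):
        super().__init__()
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.conv_in = nn.Conv2d(z_channels, hidden, kernel_size=3, padding=1)
        self.norm_out = nn.GroupNorm(8, hidden)
        make_conv = self._make_conv()  # resolved on the subclass, if overridden
        self.conv_out = make_conv(hidden, out_ch, kernel_size=3, padding=1)

    def _make_conv(self) -> Callable[..., nn.Module]:
        return nn.Conv2d

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        h = F.interpolate(self.conv_in(z), scale_factor=2.0, mode="nearest")  # upsample 2x
        if self.give_pre_end:
            return h  # features before the final norm, nonlinearity, and conv
        h = self.conv_out(F.silu(self.norm_out(h)))
        return torch.tanh(h) if self.tanh_out else h


class WeightStdConv2d(nn.Conv2d):
    """Example replacement layer: convolution with per-filter standardized weights."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight
        w = (w - w.mean(dim=(1, 2, 3), keepdim=True)) / (w.std(dim=(1, 2, 3), keepdim=True) + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)


class WSDecoder(ToyDecoder):
    def _make_conv(self) -> Callable[..., nn.Module]:
        return WeightStdConv2d  # only the factory changes; forward() is inherited


z = torch.randn(1, 4, 8, 8)
out = WSDecoder(tanh_out=True)(z)          # (1, 3, 16, 16), values in [-1, 1]
feats = ToyDecoder(give_pre_end=True)(z)   # (1, 32, 16, 16)
```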

Model is a full U-Net: its downsampling path stores intermediate features, which are concatenated with the upsampling path's features through skip connections. When use_timestep=True it adds timestep conditioning: sinusoidal timestep embeddings pass through a two-layer MLP, and the result is fed to the residual blocks.
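The timestep conditioning can be sketched as a sinusoidal embedding followed by a two-layer MLP. This is a minimal sketch under stated assumptions: the 10000 frequency base, the sin-then-cos ordering, the hidden width of 4 * ch, and the Swish between the two layers follow common DDPM-style implementations and are not confirmed by this page; timestep_embedding and TimestepMLP are hypothetical names.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal embedding of shape (B, dim) for timesteps t of shape (B,)."""
    assert t.dim() == 1, "expected a 1-D tensor of timesteps"
    half = dim // 2
    # geometric ladder of frequencies from 1 down to 1/10000 (assumed base)
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / (half - 1))
    args = t.float()[:, None] * freqs.to(t.device)[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
    if dim % 2 == 1:  # zero-pad odd embedding sizes
        emb = F.pad(emb, (0, 1))
    return emb


class TimestepMLP(nn.Module):
    """Two-layer MLP: Linear(ch -> 4*ch), Swish, Linear(4*ch -> 4*ch)."""

    def __init__(self, ch: int):
        super().__init__()
        self.ch = ch
        self.fc1 = nn.Linear(ch, 4 * ch)
        self.fc2 = nn.Linear(4 * ch, 4 * ch)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.fc2(F.silu(self.fc1(timestep_embedding(t, self.ch))))


temb = TimestepMLP(128)(torch.tensor([0, 10, 999]))  # (3, 512), one vector per sample
```

Each residual block would then project this vector to its own channel count and add it to its feature map.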

The AttnBlock implementation uses scaled_dot_product_attention from PyTorch 2.0+ by default, reshaping the (B, C, H, W) feature map into a single-head sequence of H*W tokens, each of width C. MemoryEfficientAttnBlock provides an xformers-based alternative that reduces memory use on PyTorch versions older than 2.0.
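The reshape-then-attend pattern can be sketched as follows, assuming PyTorch 2.0+ for F.scaled_dot_product_attention. SpatialSelfAttention is a hypothetical name, and the 1x1-convolution projections, pre-normalization, and residual connection are assumptions of this sketch rather than a copy of AttnBlock. Because the query width is C, the default scale of scaled_dot_product_attention is 1/sqrt(C).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the H*W positions of a feature map."""

    def __init__(self, channels: int, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, channels, eps=1e-6)
        # 1x1 convolutions act as per-position linear projections
        self.q = nn.Conv2d(channels, channels, kernel_size=1)
        self.k = nn.Conv2d(channels, channels, kernel_size=1)
        self.v = nn.Conv2d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        hn = self.norm(x)
        # (B, C, H, W) -> (B, 1, H*W, C): one head, H*W tokens of width C
        q, k, v = (proj(hn).reshape(b, 1, c, h * w).transpose(2, 3)
                   for proj in (self.q, self.k, self.v))
        out = F.scaled_dot_product_attention(q, k, v)  # softmax(q k^T / sqrt(C)) v
        out = out.transpose(2, 3).reshape(b, c, h, w)  # back to a feature map
        return x + self.proj_out(out)  # residual connection


attn = SpatialSelfAttention(64)
x = torch.randn(2, 64, 8, 8)
y = attn(x)  # same shape as x
```

The attention matrix is (H*W) x (H*W), so memory grows quadratically with spatial size; this is why attention is usually restricted to the low resolutions listed in attn_resolutions and the bottleneck.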

Usage

Use this module as the VAE encoder/decoder that bridges pixel space and the latent space for latent diffusion models. The Encoder compresses images to latent codes, and the Decoder reconstructs images from latent codes after the diffusion process.
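With double_z=True the Encoder emits 2*z_channels channels. A common VAE convention, assumed here rather than confirmed by this page, reads the first half as the mean and the second half as the log-variance of a diagonal Gaussian, then samples a latent with the reparameterization trick. The helpers sample_latent and kl_to_standard_normal, and the clamp bounds on the log-variance, are hypothetical.

```python
import torch


def sample_latent(moments: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
    """Split a (B, 2*z, H', W') encoder output into mean / log-variance and sample z."""
    mean, logvar = torch.chunk(moments, 2, dim=1)
    logvar = torch.clamp(logvar, -30.0, 20.0)  # numerical safety bounds (assumed values)
    if deterministic:
        return mean
    std = torch.exp(0.5 * logvar)
    return mean + std * torch.randn_like(mean)  # reparameterization trick


def kl_to_standard_normal(moments: torch.Tensor) -> torch.Tensor:
    """Per-sample KL(N(mean, var) || N(0, I)), summed over channels and positions."""
    mean, logvar = torch.chunk(moments, 2, dim=1)
    return 0.5 * torch.sum(mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3])


moments = torch.randn(1, 8, 32, 32)  # stand-in for encoder(x) with z_channels=4
z = sample_latent(moments)           # (1, 4, 32, 32), ready for the Decoder
```

Sampling keeps the latent space smooth during autoencoder training; at inference, passing deterministic=True simply uses the mean.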

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/diffusionmodules/model.py
  • Lines: 1-708

Signature

import torch.nn as nn

class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        ...

class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        ...

class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        ...

Import

from sat.sgm.modules.diffusionmodules.model import Encoder, Decoder, Model

I/O Contract

Inputs (Encoder)

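Name | Type | Required | Description
ch | int | Yes | Base channel count; level i uses ch * ch_mult[i] channels
out_ch | int | Yes | Output channels; accepted for config symmetry with the Decoder, which uses it
ch_mult | tuple of int | No | Channel multiplier per resolution level; default (1, 2, 4, 8)
num_res_blocks | int | Yes | Residual blocks per resolution level
attn_resolutions | list of int | Yes | Spatial sizes (feature-map side length) at which attention blocks are inserted
dropout | float | No | Dropout probability; default 0.0
resamp_with_conv | bool | No | Downsample with a strided convolution instead of average pooling; default True
in_channels | int | Yes | Input image channels
resolution | int | Yes | Input spatial resolution (side length)
z_channels | int | Yes | Latent channel count
double_z | bool | No | If True, output 2*z_channels channels (mean and log-variance); default True
use_linear_attn | bool | No | If True, overrides attn_type with "linear"; default False
attn_type | str | No | Attention backend: "vanilla", "vanilla-xformers", "linear", or "none"; default "vanilla"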
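How ch, ch_mult, resolution, and attn_resolutions combine can be worked out with a small helper. This sketch assumes one 2x downsampling between consecutive levels and none after the last, and that attention is placed at levels whose current side length appears in attn_resolutions; these assumptions agree with the (1, 8, 32, 32) latent shape in the usage example. encoder_level_plan and latent_shape are hypothetical helpers, not part of the module.

```python
from typing import List, Sequence, Tuple


def encoder_level_plan(ch: int, ch_mult: Sequence[int], resolution: int,
                       attn_resolutions: Sequence[int]) -> List[Tuple[int, int, bool]]:
    """Return (channels, side_length, has_attention) for each resolution level."""
    plan = []
    size = resolution
    for level, mult in enumerate(ch_mult):
        plan.append((ch * mult, size, size in attn_resolutions))
        if level != len(ch_mult) - 1:  # no downsampling after the last level
            size //= 2
    return plan


def latent_shape(resolution: int, ch_mult: Sequence[int], z_channels: int,
                 double_z: bool = True) -> Tuple[int, int, int]:
    """(channels, H', W') of the Encoder output for a square input."""
    side = resolution // 2 ** (len(ch_mult) - 1)
    return (2 * z_channels if double_z else z_channels, side, side)


plan = encoder_level_plan(128, (1, 2, 4, 4), 256, [32])
# [(128, 256, False), (256, 128, False), (512, 64, False), (512, 32, True)]
shape = latent_shape(256, (1, 2, 4, 4), 4)  # (8, 32, 32)
```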

Inputs (Model forward)

Name | Type | Required | Description
x | torch.Tensor | Yes | Input image tensor (B, C, H, W)
t | torch.Tensor | No | Timestep tensor (B,) for conditioning; required when use_timestep=True
context | torch.Tensor | No | Spatially aligned tensor concatenated to x along the channel axis
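Because context is concatenated along the channel axis, it has to match x in batch size and spatial size, and the first convolution then sees the image and context channels together. The sketch below infers from this that in_channels must count both; concat_context is a hypothetical helper, and the 4-channel latent plus 1-channel mask are made-up sizes for illustration.

```python
from typing import Optional

import torch
import torch.nn as nn


def concat_context(x: torch.Tensor, context: Optional[torch.Tensor]) -> torch.Tensor:
    """Concatenate a spatially aligned context map onto x along the channel axis."""
    if context is None:
        return x
    if context.shape[0] != x.shape[0] or context.shape[2:] != x.shape[2:]:
        raise ValueError("context must match x in batch size and H, W")
    return torch.cat((x, context), dim=1)


# A model for 4-channel inputs plus a 1-channel context map would need in_channels=5
image = torch.randn(2, 4, 32, 32)
mask = torch.randn(2, 1, 32, 32)
conv_in = nn.Conv2d(4 + 1, 128, kernel_size=3, padding=1)  # stand-in first convolution
h = conv_in(concat_context(image, mask))                   # (2, 128, 32, 32)
```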

Outputs

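Name | Type | Description
h (Encoder) | torch.Tensor | Latent tensor (B, 2*z_channels, H', W') when double_z=True, else (B, z_channels, H', W'), where H' = H / 2^(len(ch_mult) - 1)
h (Decoder) | torch.Tensor | Reconstructed image (B, out_ch, H, W), or pre-output features when give_pre_end=True
h (Model) | torch.Tensor | U-Net output (B, out_ch, H, W)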

Usage Examples

import torch
from sat.sgm.modules.diffusionmodules.model import Encoder, Decoder

# Create encoder
encoder = Encoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
    attn_type="vanilla",
)

# Create decoder
decoder = Decoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,  # Decoder's first conv takes z_channels
    attn_type="vanilla",
)

# Encode image to latent: 2 * z_channels = 8 channels, 256 / 2**3 = 32
x = torch.randn(1, 3, 256, 256)
z = encoder(x)  # shape: (1, 8, 32, 32)

# Decode latent to image, using the first z_channels channels (the mean half)
recon = decoder(z[:, :4])  # shape: (1, 3, 256, 256)

# Access last layer weight for loss computation
last_layer = decoder.get_last_layer()
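One common use of a decoder's last-layer weight, in VQGAN-style autoencoder losses, is an adaptive weight that balances the adversarial term against the reconstruction term by comparing their gradient norms at that layer. Whether CogVideo's loss uses it exactly this way is not stated on this page. In the sketch, adaptive_gan_weight is hypothetical, a single convolution stands in for the decoder's final layer (what get_last_layer() would return), and the generator loss is a placeholder.

```python
import torch
import torch.nn as nn


def adaptive_gan_weight(rec_loss: torch.Tensor, g_loss: torch.Tensor,
                        last_layer: torch.Tensor, max_weight: float = 1e4) -> torch.Tensor:
    """Scale the GAN loss so its gradient at last_layer matches the reconstruction loss."""
    rec_grad = torch.autograd.grad(rec_loss, last_layer, retain_graph=True)[0]
    g_grad = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
    weight = torch.norm(rec_grad) / (torch.norm(g_grad) + 1e-4)
    return torch.clamp(weight, 0.0, max_weight).detach()  # treated as a constant


# Toy stand-ins: one conv plays the decoder's final layer
conv_out = nn.Conv2d(8, 3, kernel_size=3, padding=1)
recon = conv_out(torch.randn(1, 8, 16, 16))
target = torch.randn_like(recon)
rec_loss = (recon - target).abs().mean()
g_loss = -recon.mean()  # placeholder for an adversarial generator loss
w = adaptive_gan_weight(rec_loss, g_loss, conv_out.weight)
total = rec_loss + w * g_loss
```

Measuring gradients at the last layer keeps this check cheap, since it avoids differentiating through the whole network twice.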
