Implementation:Zai org CogVideo VQVAE Blocks


Domains: Video_Generation, Autoencoding
Last Updated: 2026-02-10 00:00 GMT

Overview

Standard 2D encoder and decoder building blocks for VQ-VAE, implementing a multi-resolution convolutional autoencoder with residual blocks, self-attention, and progressive spatial downsampling/upsampling.

Description

This module provides the foundational 2D Encoder and Decoder classes along with their constituent building blocks for the VQ-VAE architecture. These blocks define the standard spatial autoencoder upon which CogVideo's video-specific extensions are built.

The Encoder progressively downsamples input images through the resolution levels defined by ch_mult. At each level it applies num_res_blocks residual blocks (ResnetBlock), each using GroupNorm with 32 groups, Swish activation (x * sigmoid(x)), and 3x3 convolutions. When the current resolution appears in attn_resolutions, an AttnBlock follows each residual block at that level. Between levels, spatial downsampling uses strided 2D convolutions with asymmetric zero-padding (padding only on the right and bottom edges), so each level transition halves the spatial size. A middle section at the bottleneck contains two ResnetBlocks with an AttnBlock in between. The final output is produced by GroupNorm, Swish, and a 3x3 convolution projecting to 2*z_channels (when double_z=True, for mean and log-variance) or to z_channels (when double_z=False).
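The level schedule described above can be traced without building any layers. The following pure-Python sketch is not code from the repository: it assumes a taming-transformers-style layout in which the first residual block of level i takes ch * in_ch_mult[i] channels with in_ch_mult = (1,) + ch_mult, and it assumes the strided downsampling convolution uses kernel size 3 and stride 2. The names LevelPlan, plan_encoder, downsampled_size, and in_ch_mult are hypothetical helpers for illustration.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass
class LevelPlan:
    level: int
    resolution: int    # spatial side length at this level
    in_channels: int   # width entering the level's first residual block
    out_channels: int  # width after the level's residual blocks
    num_attn: int      # AttnBlocks at this level (one per residual block, or 0)
    downsample: bool   # True if a strided conv follows this level


def downsampled_size(size: int, kernel: int = 3, stride: int = 2) -> int:
    # Asymmetric padding: one extra zero row/column on the bottom/right only,
    # then a strided conv with no further padding (kernel=3 is an assumption).
    return (size + 1 - kernel) // stride + 1


def plan_encoder(
    ch: int,
    ch_mult: Sequence[int],
    num_res_blocks: int,
    attn_resolutions: Sequence[int],
    resolution: int,
    z_channels: int,
    double_z: bool = True,
    num_groups: int = 32,
) -> Tuple[List[LevelPlan], Tuple[int, int, int]]:
    in_ch_mult = (1,) + tuple(ch_mult)  # hypothetical, taming-style input multipliers
    levels: List[LevelPlan] = []
    curr_res = resolution
    for i, mult in enumerate(ch_mult):
        block_in, block_out = ch * in_ch_mult[i], ch * mult
        for width in (block_in, block_out):
            # GroupNorm with 32 groups needs the channel count divisible by 32.
            if width % num_groups:
                raise ValueError(f"width {width} is not divisible by {num_groups}")
        num_attn = num_res_blocks if curr_res in attn_resolutions else 0
        is_last = i == len(ch_mult) - 1
        levels.append(LevelPlan(i, curr_res, block_in, block_out, num_attn, not is_last))
        if not is_last:
            curr_res = downsampled_size(curr_res)
    # The middle section (ResnetBlock, AttnBlock, ResnetBlock) keeps the shape;
    # the final conv then projects to 2*z_channels or z_channels.
    z_out = 2 * z_channels if double_z else z_channels
    return levels, (z_out, curr_res, curr_res)
```

With the configuration from the usage example (ch=128, ch_mult=(1, 2, 4, 4), attn_resolutions=[32], resolution=256, z_channels=4), this plan places attention only at the 32x32 level and predicts an (8, 32, 32) latent per image, matching the encoder output shape quoted in that example.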

The Decoder mirrors the encoder, upsampling progressively via nearest-neighbor interpolation, with each upsampling step optionally followed by a learned 3x3 convolution (controlled by resamp_with_conv). It processes from the lowest resolution upward, applying num_res_blocks + 1 residual blocks per level. Both Encoder and Decoder provide a forward_with_features_output method that returns the intermediate feature maps of each layer alongside the regular output.
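A similar hypothetical sketch traces the Decoder from the bottleneck upward. It assumes the decoder starts at resolution // 2**(len(ch_mult) - 1), the encoder's bottleneck size, and that attention placement mirrors the encoder (one AttnBlock per residual block at levels listed in attn_resolutions). plan_decoder and upsample_nearest_2x are illustrative names, not repository functions; the upsampling helper works on a plain nested list to show what 2x nearest-neighbor interpolation does to a single channel.

```python
from typing import List, Sequence, Tuple


def upsample_nearest_2x(grid: List[List[float]]) -> List[List[float]]:
    # Nearest-neighbor x2: every value becomes a 2x2 patch of copies.
    out: List[List[float]] = []
    for row in grid:
        wide = [v for v in row for _ in range(2)]
        out.append(wide)
        out.append(list(wide))
    return out


def plan_decoder(
    ch: int,
    ch_mult: Sequence[int],
    num_res_blocks: int,
    attn_resolutions: Sequence[int],
    resolution: int,
) -> List[Tuple[int, int, int, int, int]]:
    """(level, resolution, output width, res blocks, attn blocks), lowest res first."""
    num_levels = len(ch_mult)
    # Assumed: the decoder starts at the encoder's bottleneck resolution.
    curr_res = resolution // 2 ** (num_levels - 1)
    plan: List[Tuple[int, int, int, int, int]] = []
    for i in reversed(range(num_levels)):
        n_res = num_res_blocks + 1  # one more residual block than the encoder level
        n_attn = n_res if curr_res in attn_resolutions else 0  # assumed to mirror encoder
        plan.append((i, curr_res, ch * ch_mult[i], n_res, n_attn))
        if i != 0:
            curr_res *= 2  # Upsample: nearest-neighbor, then an optional 3x3 conv
    return plan
```

For the usage-example configuration, the plan runs 32, 64, 128, 256, ending at the requested output resolution.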

AttnBlock implements single-head spatial self-attention. Q, K, and V are produced by 1x1 convolutions, and the scaling factor (1/sqrt(C) for C channels) is applied to the queries before the query-key matrix multiplication rather than to the product afterwards. Keeping the intermediate logits small in this way improves numerical stability in fp16.
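To illustrate the pre-scaling choice, here is a pure-Python single-head attention over flattened spatial positions. It is a sketch, not the repository's AttnBlock: it takes already-projected Q, K, and V matrices (standing in for the 1x1 convolution outputs) and omits any normalization, output projection, or residual connection the real block may apply. In exact arithmetic, scaling the queries first and scaling the logits afterwards give the same result; the two orders differ only in the magnitude of intermediate values, which is what matters for fp16 overflow.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax_rows(m: Matrix) -> Matrix:
    out = []
    for row in m:
        peak = max(row)  # subtract the row max before exp for stability
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def single_head_attention(q: Matrix, k: Matrix, v: Matrix, prescale: bool = True) -> Matrix:
    """q, k, v: N x C (N = H*W flattened positions, C = channels)."""
    scale = len(q[0]) ** -0.5
    if prescale:
        # Scale the queries first, so the N x N product is never formed unscaled.
        logits = matmul([[x * scale for x in row] for row in q], transpose(k))
    else:
        logits = [[x * scale for x in row] for row in matmul(q, transpose(k))]
    weights = softmax_rows(logits)  # N x N; each row sums to 1
    return matmul(weights, v)       # N x C
```

Running both settings of prescale on the same inputs yields outputs that agree up to floating-point rounding.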

Usage

Use these blocks as the 2D spatial backbone for VQ-VAE autoencoders. They serve as the base architecture that can be extended with temporal layers for video processing.

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/autoencoding/vqvae/vqvae_blocks.py
  • Lines: 1-424

Signature

class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        **ignore_kwargs,
    ):

class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        **ignorekwargs,
    ):

Import

from sat.sgm.modules.autoencoding.vqvae.vqvae_blocks import Encoder, Decoder

I/O Contract

Inputs (Encoder)

Name | Type | Required | Description
ch | int | Yes | Base channel count for the network
out_ch | int | Yes | Number of output channels (used by the Decoder counterpart, not by the Encoder itself)
ch_mult | tuple of int | No | Channel multiplier per resolution level; default (1, 2, 4, 8)
num_res_blocks | int | Yes | Number of residual blocks per resolution level
attn_resolutions | list of int | Yes | Feature-map resolutions at which to apply self-attention
dropout | float | No | Dropout probability; default 0.0
resamp_with_conv | bool | No | Use learned convolutions for downsampling; default True
in_channels | int | Yes | Number of input channels (e.g. 3 for RGB)
resolution | int | Yes | Input spatial resolution (side length)
z_channels | int | Yes | Number of latent channels
double_z | bool | No | If True, output 2*z_channels (mean and log-variance); otherwise z_channels. Default True

Inputs (Decoder)

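Name | Type | Required | Description
ch | int | Yes | Base channel count
out_ch | int | Yes | Number of output channels (e.g. 3 for RGB)
ch_mult | tuple of int | No | Channel multiplier per level; default (1, 2, 4, 8)
num_res_blocks | int | Yes | Residual-block count parameter; each decoder level applies num_res_blocks + 1 blocks
attn_resolutions | list of int | Yes | Feature-map resolutions at which to apply self-attention
dropout | float | No | Dropout probability; default 0.0
resamp_with_conv | bool | No | Apply a learned 3x3 convolution after each nearest-neighbor upsampling; default True
in_channels | int | Yes | Image channel count, mirroring the Encoder argument; the latent input width is set by z_channels
resolution | int | Yes | Target output resolution (side length)
z_channels | int | Yes | Number of latent input channels
give_pre_end | bool | No | Return features before the final projection; default False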

Outputs

Name | Type | Description
h (Encoder) | torch.Tensor | Latent tensor of shape (B, 2*z_channels or z_channels, H', W'), where H' = W' = resolution / 2^(len(ch_mult) - 1)
h (Decoder) | torch.Tensor | Reconstructed image tensor of shape (B, out_ch, resolution, resolution) when give_pre_end=False
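When double_z=True, the extra encoder channels are described as mean and log-variance. The following is a minimal sketch of how such a doubled output could be consumed, assuming the first z_channels channels hold the mean and the rest the log-variance (consistent with the usage example below, which decodes z[:, :4, :, :]), and assuming the standard reparameterization z = mean + exp(0.5 * logvar) * eps. split_moments and sample_latent are hypothetical helpers that operate on the channel values of a single spatial position; whether a given CogVideo pipeline samples this way or quantizes the latent instead is not stated on this page.

```python
import math
import random
from typing import List, Sequence, Tuple


def split_moments(values: Sequence[float], z_channels: int) -> Tuple[List[float], List[float]]:
    """Split 2*z_channels per-channel values into (mean, logvar) halves (assumed order)."""
    if len(values) != 2 * z_channels:
        raise ValueError(f"expected {2 * z_channels} values, got {len(values)}")
    return list(values[:z_channels]), list(values[z_channels:])


def sample_latent(mean: Sequence[float], logvar: Sequence[float], rng: random.Random) -> List[float]:
    # Reparameterization: z = mean + exp(0.5 * logvar) * eps, with eps ~ N(0, 1).
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, logvar)]
```

Passing a seeded random.Random keeps the sampling reproducible; a very negative log-variance collapses the sample onto the mean.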

Usage Examples

import torch
from sat.sgm.modules.autoencoding.vqvae.vqvae_blocks import Encoder, Decoder

encoder = Encoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
)

decoder = Decoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
)

# Encode
x = torch.randn(1, 3, 256, 256)
z = encoder(x)  # shape: (1, 8, 32, 32)

# Decode from the first z_channels (= 4) channels of the encoder output
recon = decoder(z[:, :4, :, :])  # shape: (1, 3, 256, 256)

# Extract intermediate features
z, features = encoder.forward_with_features_output(x)
