Implementation:Zai org CogVideo Autoencoder Model
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements the autoencoder architecture (Encoder and Decoder) that maps between pixel space and the latent space of CogVideo's diffusion pipeline, plus a full U-Net (Model). It supports multiple attention backends: vanilla, xformers, and linear attention.
Description
This module provides the core autoencoder components that map between pixel space and the latent space where the diffusion process operates. It contains three primary classes:
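Encoder progressively downsamples the input through one resolution level per ch_mult entry. It halves the spatial size after every level except the last, so an H x W input yields latents of H / 2^(len(ch_mult) - 1) per side. Each level stacks num_res_blocks ResnetBlock layers (GroupNorm, Swish, 3x3 convolution). When the level's resolution is listed in attn_resolutions, each ResnetBlock is followed by an attention block. The attention type is selected via make_attn, which supports four options: "vanilla" (PyTorch 2.0 scaled_dot_product_attention), "vanilla-xformers" (the xformers memory-efficient implementation), "linear" attention, and "none" (identity). Setting use_linear_attn=True forces "linear". A middle section at the bottleneck contains ResnetBlock, attention, ResnetBlock. A final GroupNorm, Swish, and 3x3 convolution project the features to 2*z_channels channels when double_z=True (mean and log-variance), or to z_channels otherwise.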
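The level-by-level bookkeeping above can be traced without PyTorch. The sketch below uses `encoder_plan`, a hypothetical helper that is not part of the module. It assumes the layout just described: a 2x downsample after every level except the last, and one attention layer per ResnetBlock at resolutions listed in attn_resolutions. It reports the latent shape for the configuration used in the usage example.

```python
def encoder_plan(ch, ch_mult, num_res_blocks, attn_resolutions, resolution,
                 z_channels, double_z=True):
    """Hypothetical helper: trace per-level sizes for an Encoder-style stack."""
    in_ch_mult = (1,) + tuple(ch_mult)
    curr_res = resolution
    levels = []
    for i_level, mult in enumerate(ch_mult):
        levels.append({
            "res": curr_res,
            "in_ch": ch * in_ch_mult[i_level],
            "out_ch": ch * mult,
            # one attention layer after each ResnetBlock at listed resolutions
            "attn_layers": num_res_blocks if curr_res in attn_resolutions else 0,
        })
        if i_level != len(ch_mult) - 1:  # no downsample after the last level
            curr_res //= 2
    # The bottleneck (ResnetBlock, attention, ResnetBlock) keeps the spatial
    # size; the output convolution sets the channel count.
    z_out = 2 * z_channels if double_z else z_channels
    return levels, (z_out, curr_res, curr_res)


if __name__ == "__main__":
    levels, latent_shape = encoder_plan(
        ch=128, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
        attn_resolutions=[32], resolution=256, z_channels=4,
    )
    for lvl in levels:
        print(lvl)
    print("latent (C, H, W):", latent_shape)  # (8, 32, 32)
```

For that configuration the helper reports levels at 256, 128, 64, and 32 pixels. Attention appears only at 32, and the latent shape is (8, 32, 32), matching the Encoder output shown in the usage example.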
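Decoder mirrors the encoder. A 3x3 convolution lifts z_channels to the bottleneck width, followed by the same ResnetBlock-attention-ResnetBlock middle section. Each level then applies num_res_blocks + 1 ResnetBlock layers and doubles the spatial size after every level except the last. The decoder builds its layers through factory methods (_make_attn, _make_resblock, _make_conv), which subclasses can override to substitute custom layer types. The optional tanh_out flag applies tanh to the final output. Setting give_pre_end=True returns the features before the final normalization, Swish, and convolution.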
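The factory-method hook can be illustrated in plain Python. Every class below (`ToyConv`, `ToyTemporalConv`, `ToyDecoder`, `ToyVideoDecoder`) is hypothetical and stands in for real layers. The point is only that the base class resolves its layer constructor through an overridable method, so a subclass can swap layer types without copying `__init__`.

```python
class ToyConv:
    """Hypothetical stand-in for a 2D convolution layer."""

    kind = "plain"

    def __init__(self, name):
        self.name = name


class ToyTemporalConv(ToyConv):
    """Hypothetical stand-in for a time-aware replacement layer."""

    kind = "temporal"


class ToyDecoder:
    """Hypothetical decoder that builds its layers through a factory hook."""

    def __init__(self, num_layers=2):
        # Resolve the factory once, then build every layer through it.
        make_conv = self._make_conv()
        self.layers = [make_conv(f"conv_{i}") for i in range(num_layers)]

    def _make_conv(self):
        # Default hook: return the class (a callable), not an instance.
        return ToyConv


class ToyVideoDecoder(ToyDecoder):
    # Overriding only the hook changes the layer type; __init__ is inherited.
    def _make_conv(self):
        return ToyTemporalConv


if __name__ == "__main__":
    print([layer.kind for layer in ToyDecoder().layers])       # ['plain', 'plain']
    print([layer.kind for layer in ToyVideoDecoder().layers])  # ['temporal', 'temporal']
```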
Model is a full U-Net that combines encoding and decoding with skip connections. With use_timestep=True, it embeds the timestep t with sinusoidal embeddings, passes the result through a two-layer MLP (Linear, Swish, Linear), and adds it inside every ResnetBlock. The downsampling path pushes each intermediate feature map onto a stack. Each upsampling ResnetBlock pops one entry and concatenates it with the current features along the channel axis.
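The skip-connection bookkeeping can be checked with channel counts alone. The hypothetical helper `unet_skip_trace` assumes the layout described above. The down path pushes the conv_in output, every ResnetBlock output, and every downsample output. The up path runs num_res_blocks + 1 ResnetBlocks per level, each popping one entry. The helper also assumes at least one ResnetBlock per level.

```python
def unet_skip_trace(ch, ch_mult, num_res_blocks):
    """Hypothetical helper: channel widths flowing through a U-Net skip stack."""
    assert num_res_blocks >= 1, "sketch assumes at least one ResnetBlock per level"
    num_levels = len(ch_mult)

    # Down path: conv_in output, each ResnetBlock output, each downsample output.
    hs = [ch]
    for i_level, mult in enumerate(ch_mult):
        hs.extend([ch * mult] * num_res_blocks)
        if i_level != num_levels - 1:
            hs.append(ch * mult)  # downsampling keeps the channel count
    pushed = len(hs)

    # Up path: num_res_blocks + 1 ResnetBlocks per level, each consuming one skip.
    h = ch * ch_mult[-1]  # bottleneck width after the middle section
    concat_widths = []
    for i_level in reversed(range(num_levels)):
        for _ in range(num_res_blocks + 1):
            concat_widths.append(h + hs.pop())  # like torch.cat([h, skip], dim=1)
            h = ch * ch_mult[i_level]
    return pushed, concat_widths, hs


if __name__ == "__main__":
    pushed, widths, leftover = unet_skip_trace(64, (1, 2), 2)
    print(pushed, widths, leftover)  # 6 [256, 256, 192, 192, 128, 128] []
```

The push count works out to 1 + len(ch_mult) * num_res_blocks + (len(ch_mult) - 1) = len(ch_mult) * (num_res_blocks + 1). That equals the number of up-path blocks, which is why the up path has one more ResnetBlock per level than the down path.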
AttnBlock implements single-head self-attention. By default, it flattens the spatial dimensions into a sequence of H*W tokens and calls scaled_dot_product_attention from PyTorch 2.0+. MemoryEfficientAttnBlock is an xformers-based alternative with lower memory use. make_attn falls back to it on PyTorch versions older than 2.0, which then require xformers.
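The reshape-then-attend pattern can be sketched in pure Python. This is a simplification, not the module's AttnBlock. It uses the features directly as queries, keys, and values. It also omits the GroupNorm, the 1x1 convolution projections, the output projection, and the residual addition that AttnBlock applies. The 1/sqrt(C) scale matches the default of scaled_dot_product_attention.

```python
import math


def spatial_self_attention(feat):
    """Simplified sketch: single-head softmax attention over H*W positions.

    feat is a nested list indexed [c][y][x]. Queries, keys and values are the
    features themselves (identity projections), unlike the real AttnBlock.
    """
    C, H, W = len(feat), len(feat[0]), len(feat[0][0])
    # "c h w -> (h w) c": one token of C features per spatial position
    tokens = [[feat[c][y][x] for c in range(C)] for y in range(H) for x in range(W)]
    scale = 1.0 / math.sqrt(C)
    out_tokens = []
    for q in tokens:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in tokens]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out_tokens.append(
            [sum(w * v[c] for w, v in zip(weights, tokens)) / total for c in range(C)]
        )
    # "(h w) c -> c h w": fold the tokens back into the spatial grid
    return [[[out_tokens[y * W + x][c] for x in range(W)] for y in range(H)]
            for c in range(C)]


if __name__ == "__main__":
    print(spatial_self_attention([[[0.0, 1.0]]]))  # 1 channel, 1 x 2 map
```

Because every position attends to every other position, the cost grows with (H*W)^2. This is why attn_resolutions usually restricts attention to the low-resolution levels.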
Usage
Use this module as the VAE encoder/decoder that bridges pixel space and the latent space for latent diffusion models. The Encoder compresses images to latent codes, and the Decoder reconstructs images from latent codes after the diffusion process.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/diffusionmodules/model.py
- Lines: 1-708
Signature
```python
class Encoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        ...


class Decoder(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        ...


class Model(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        ...
```
Import
```python
from sat.sgm.modules.diffusionmodules.model import Encoder, Decoder, Model
```
I/O Contract
Inputs (Encoder constructor)
| Name | Type | Required | Description |
|---|---|---|---|
| ch | int | Yes | Base channel count |
| out_ch | int | Yes | Not used by the Encoder; accepted so that Encoder and Decoder can share one config |
| ch_mult | tuple of int | No | Channel multiplier per level, default (1, 2, 4, 8) |
| num_res_blocks | int | Yes | Residual blocks per resolution level |
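| attn_resolutions | list of int | Yes | Feature-map sizes at which attention layers are added (e.g. [32]) |
| dropout | float | No | Dropout probability inside ResnetBlocks, default 0.0 |
| resamp_with_conv | bool | No | Downsample with a strided 3x3 convolution (True) instead of 2x2 average pooling, default True |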
| in_channels | int | Yes | Input image channels |
| resolution | int | Yes | Input spatial resolution |
| z_channels | int | Yes | Latent channel count |
| double_z | bool | No | Double z_channels for mean/variance, default True |
| attn_type | str | No | Attention backend: "vanilla", "vanilla-xformers", "linear", or "none"; default "vanilla" |
| use_linear_attn | bool | No | If True, overrides attn_type with "linear"; default False |
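With double_z=True, the Encoder's 2*z_channels output channels are read as a mean and a log-variance per latent channel. The sketch below assumes the usual LDM convention (first half mean, second half log-variance). It applies the reparameterization z = mean + exp(0.5 * logvar) * eps to a single spatial position's channel vector for brevity. The helpers `split_moments` and `sample_latent` are hypothetical. This module's Encoder only outputs the moments and does not perform the sampling itself.

```python
import math
import random


def split_moments(moments, z_channels):
    """Hypothetical helper: split 2*z_channels values into (mean, logvar)."""
    if len(moments) != 2 * z_channels:
        raise ValueError("expected 2 * z_channels values (double_z=True)")
    return moments[:z_channels], moments[z_channels:]


def sample_latent(moments, z_channels, rng=None):
    """Hypothetical helper: z = mean + exp(0.5 * logvar) * eps, eps ~ N(0, 1)."""
    if rng is None:
        rng = random.Random()
    mean, logvar = split_moments(moments, z_channels)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mean, logvar)]


if __name__ == "__main__":
    moments = [0.1, -0.2, 0.3, 0.0, -1.0, -1.0, -1.0, -1.0]  # 4 means, 4 logvars
    print(split_moments(moments, 4)[0])                  # deterministic: the mean
    print(sample_latent(moments, 4, random.Random(0)))   # stochastic sample
```

Taking only the mean half, as the usage example does with z[:, :4], gives a deterministic latent. Sampling adds noise scaled by the predicted standard deviation.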
Inputs (Model forward)
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input image tensor (B, C, H, W) |
| t | torch.Tensor | No | Timestep tensor (B,) for conditioning; required when use_timestep=True (the default) |
| context | torch.Tensor | No | Spatially aligned context concatenated with x along the channel axis before conv_in; in_channels must include these channels |
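For the t input, Model first maps each timestep to a sinusoidal embedding of width ch and then feeds it through the two-layer MLP. The function below is a pure-Python sketch of a single-timestep embedding in the LDM style. Frequencies are spaced geometrically from 1 down to 1/10000, sines come before cosines, and odd widths are zero-padded. It is an illustrative re-implementation, not the module's tensor-based helper, and it assumes dim >= 4.

```python
import math


def timestep_embedding(t, dim):
    """Illustrative sketch: sinusoidal embedding of one timestep t, length dim."""
    if dim < 4:
        raise ValueError("sketch assumes dim >= 4 (needs half - 1 > 0)")
    half = dim // 2
    # Frequencies fall geometrically from 1 to 1/10000 across the half-width.
    step = math.log(10000) / (half - 1)
    args = [t * math.exp(-step * i) for i in range(half)]
    emb = [math.sin(a) for a in args] + [math.cos(a) for a in args]
    if dim % 2 == 1:
        emb.append(0.0)  # zero-pad odd widths
    return emb


if __name__ == "__main__":
    print(timestep_embedding(0, 8))  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
    print(timestep_embedding(500, 8))
```

Low-index channels oscillate quickly with t and high-index channels slowly, so nearby timesteps get similar but distinguishable vectors.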
Outputs
| Name | Type | Description |
|---|---|---|
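| h (Encoder) | torch.Tensor | Latent moments (B, 2*z_channels if double_z else z_channels, H', W'), with H' = H / 2^(len(ch_mult) - 1) |
| h (Decoder) | torch.Tensor | Reconstructed image (B, out_ch, H, W), with H = H' * 2^(len(ch_mult) - 1); pre-output features if give_pre_end=True |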
| h (Model) | torch.Tensor | U-Net output (B, out_ch, H, W) |
Usage Examples
```python
import torch
from sat.sgm.modules.diffusionmodules.model import Encoder, Decoder

# "vanilla" attention uses scaled_dot_product_attention (PyTorch 2.0+).

# Create encoder: 4 levels -> 3 downsamples, so 256 px inputs give 32 px latents
encoder = Encoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
    attn_type="vanilla",
)
# Create decoder: its input width comes from z_channels; in_channels is not used
decoder = Decoder(
    ch=128, out_ch=3, ch_mult=(1, 2, 4, 4),
    num_res_blocks=2, attn_resolutions=[32],
    in_channels=3, resolution=256, z_channels=4,
    attn_type="vanilla",
)
# Encode image to latent moments (double_z=True: mean and log-variance)
x = torch.randn(1, 3, 256, 256)
z = encoder(x)  # shape: (1, 8, 32, 32)
# Decode the mean half (first z_channels channels) back to an image
recon = decoder(z[:, :4])  # shape: (1, 3, 256, 256)
# Access the final conv_out weight for loss computation
last_layer = decoder.get_last_layer()
```