Implementation:Zai org CogVideo DiagonalGaussianDistribution

From Leeroopedia


Knowledge Sources
Domains: Video_Generation, Variational_Inference, Autoencoding
Last Updated: 2026-02-10 00:00 GMT

Overview

DiagonalGaussianDistribution is a probability distribution class that represents a multivariate Gaussian with a diagonal covariance matrix, used as the latent distribution in VAE encoders for sampling and KL divergence computation.

Description

The DiagonalGaussianDistribution class takes a parameter tensor and splits it along the channel dimension (dim=1) into two equal halves: the mean and the log-variance. The log-variance is clamped to the range [-30.0, 20.0] for numerical stability. The standard deviation, exp(0.5 * logvar), and the variance, exp(logvar), are then derived from the clamped value.
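The construction step can be sketched roughly as follows. This is a minimal illustration, not the repository's code: the helper name `split_parameters` is hypothetical, and the use of `torch.chunk` for the split is an assumption. Only the clamp range and the two exponential formulas are taken from this page.

```python
import torch


def split_parameters(parameters: torch.Tensor):
    """Hypothetical helper: split (B, 2*C, H, W) into mean, logvar, std, var."""
    # Assumption: the split is an even two-way chunk along the channel dim.
    mean, logvar = torch.chunk(parameters, 2, dim=1)
    # Clamp the log-variance to the documented range for numerical stability.
    logvar = torch.clamp(logvar, -30.0, 20.0)
    std = torch.exp(0.5 * logvar)
    var = torch.exp(logvar)
    return mean, logvar, std, var


# Exaggerated scale so that some log-variance values actually hit the clamp.
params = torch.randn(4, 8, 32, 32) * 50.0
mean, logvar, std, var = split_parameters(params)
```

Scaling the random input by 50 is only there to push values past both clamp bounds; real encoder outputs would normally sit well inside the range.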

Sampling uses the reparameterization trick: x = mean + std * epsilon, where epsilon is drawn from a standard normal distribution via torch.randn_like. This formulation allows gradients to flow through the sampling operation during backpropagation.
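To make the gradient-flow claim concrete, here is a small standalone sketch of the reparameterization trick. The tensors `mean` and `logvar` are illustrative leaf variables standing in for encoder outputs; nothing here is taken from the repository.

```python
import torch

torch.manual_seed(0)

# Learnable distribution parameters (stand-ins for encoder outputs).
mean = torch.zeros(3, requires_grad=True)
logvar = torch.zeros(3, requires_grad=True)

std = torch.exp(0.5 * logvar)
eps = torch.randn_like(mean)  # the noise itself carries no gradient
x = mean + std * eps          # reparameterized sample

# Backpropagate a scalar function of the sample.
x.sum().backward()

# Analytically: d(sum x)/d(mean) = 1 and d(sum x)/d(logvar) = 0.5 * std * eps.
grad_mean = mean.grad
grad_logvar = logvar.grad
```

Because the randomness is isolated in `eps`, both `mean` and `logvar` receive well-defined gradients even though `x` is a random draw.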

The class provides methods for computing the KL divergence, either against a standard normal prior or against another DiagonalGaussianDistribution instance, and for computing the negative log-likelihood of a sample. In deterministic mode the variance and standard deviation are set to zero, so sampling always returns the mean.
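For a diagonal Gaussian, the KL divergence against a standard normal prior has the well-known closed form 0.5 * sum(mean^2 + var - 1 - logvar). The sketch below implements that formula and cross-checks it against PyTorch's built-in KL. The helper name `kl_to_standard_normal` is hypothetical, and summing over dims 1 to 3 is an assumption chosen to match the per-sample shape (B,) documented on this page.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_to_standard_normal(mean, logvar):
    """Hypothetical helper: closed-form KL(N(mean, exp(logvar)) || N(0, 1)) per sample."""
    var = torch.exp(logvar)
    # Sum over all non-batch dims (assumed to be channel, height, width).
    return 0.5 * torch.sum(mean.pow(2) + var - 1.0 - logvar, dim=[1, 2, 3])


torch.manual_seed(0)
mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8)

kl_closed = kl_to_standard_normal(mean, logvar)

# Cross-check against PyTorch's element-wise Normal-to-Normal KL.
q = Normal(mean, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mean), torch.ones_like(mean))
kl_reference = kl_divergence(q, p).sum(dim=[1, 2, 3])
```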

The file also includes DiracDistribution (a degenerate distribution returning a fixed value) and the standalone normal_kl function for element-wise KL divergence computation between two Gaussians with full broadcasting support.
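An element-wise KL between two Gaussians parameterized by mean and log-variance can be sketched as below. The function `normal_kl_sketch` is a hypothetical stand-in written from the standard formula, not a copy of the repository's normal_kl; the broadcasting behavior shown is simply what PyTorch arithmetic provides.

```python
import torch


def normal_kl_sketch(mean1, logvar1, mean2, logvar2):
    """Hypothetical element-wise KL(N(mean1, e^logvar1) || N(mean2, e^logvar2))."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + (mean1 - mean2).pow(2) * torch.exp(-logvar2)
    )


# Shapes (4, 1) and (1, 3) broadcast to (4, 3); the second Gaussian is standard normal.
mean1 = torch.linspace(-1.0, 1.0, 4).reshape(4, 1)
logvar1 = torch.zeros(4, 1)
mean2 = torch.zeros(1, 3)
logvar2 = torch.zeros(1, 3)

kl = normal_kl_sketch(mean1, logvar1, mean2, logvar2)
```

With unit variances on both sides the expression reduces to 0.5 * (mean1 - mean2)^2, which makes the result easy to check by hand.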

Usage

Use DiagonalGaussianDistribution as the latent distribution in VAE and VAE-based diffusion model encoders. It provides the sampling mechanism and KL regularization term needed for the VAE objective function.
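As a rough illustration of how the sampling and KL terms combine in a VAE objective, consider the sketch below. It is deliberately simplified and makes several assumptions that are not taken from the repository: the "decoder" is the identity, the reconstruction loss is MSE, and `kl_weight` is a hypothetical weighting value.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for encoder outputs: mean and log-variance of shape (B, C, H, W).
mean = torch.randn(2, 4, 8, 8, requires_grad=True)
logvar = torch.zeros(2, 4, 8, 8, requires_grad=True)
target = torch.randn(2, 4, 8, 8)

# Reparameterized latent sample.
z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

# Assumption: identity decoder with an MSE reconstruction loss.
recon_loss = F.mse_loss(z, target)

# KL against a standard normal prior, one value per batch element.
kl_per_sample = 0.5 * torch.sum(
    mean.pow(2) + logvar.exp() - 1.0 - logvar, dim=[1, 2, 3]
)

kl_weight = 1e-6  # hypothetical weighting, not taken from the repository
loss = recon_loss + kl_weight * kl_per_sample.mean()
loss.backward()
```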

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: sat/sgm/modules/distributions/distributions.py

Signature

class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):

Import

from sat.sgm.modules.distributions.distributions import DiagonalGaussianDistribution

I/O Contract

Inputs

Name | Type | Required | Description
parameters | Tensor | Yes | Concatenated mean and log-variance tensor of shape (B, 2*C, H, W), which is split along dim=1 into mean and logvar, each of shape (B, C, H, W).
deterministic | bool | No | If True, sets variance and standard deviation to zero, making sampling return the mean. Default: False.

Outputs

Method | Return Type | Description
sample() | Tensor | A sample from the distribution via the reparameterization trick, shape (B, C, H, W).
kl(other=None) | Tensor | KL divergence per sample. Against standard normal if other=None, otherwise against another DiagonalGaussianDistribution. Shape (B,).
nll(sample, dims) | Tensor | Negative log-likelihood of sample under this distribution, summed over the specified dims.
mode() | Tensor | The mode (mean) of the distribution.
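The negative log-likelihood listed above follows the standard diagonal-Gaussian formula, 0.5 * sum(log(2*pi) + logvar + (x - mean)^2 / var). The sketch below is a hypothetical standalone helper (`gaussian_nll` is not a repository function, and the default `dims` is an assumption) that is cross-checked against `torch.distributions.Normal.log_prob`.

```python
import math

import torch


def gaussian_nll(sample, mean, logvar, dims=(1, 2, 3)):
    """Hypothetical helper: diagonal-Gaussian NLL summed over `dims`."""
    log_two_pi = math.log(2.0 * math.pi)
    return 0.5 * torch.sum(
        log_two_pi + logvar + (sample - mean).pow(2) / torch.exp(logvar),
        dim=dims,
    )


torch.manual_seed(0)
mean = torch.randn(2, 4, 8, 8)
logvar = torch.randn(2, 4, 8, 8)
sample = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

nll = gaussian_nll(sample, mean, logvar)

# Cross-check: the NLL is the negated log-probability under the same Gaussian.
normal = torch.distributions.Normal(mean, torch.exp(0.5 * logvar))
reference = -normal.log_prob(sample).sum(dim=(1, 2, 3))
```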

Key Attributes

Attribute | Type | Description
mean | Tensor | Mean of the distribution, shape (B, C, H, W).
logvar | Tensor | Log-variance, clamped to [-30.0, 20.0].
std | Tensor | Standard deviation, computed as exp(0.5 * logvar).
var | Tensor | Variance, computed as exp(logvar).
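To see what the clamp range implies for the derived attributes, the bounds on std and var can be worked out directly from the formulas above. This is plain arithmetic using only the standard library; the constant names are illustrative.

```python
import math

LOGVAR_MIN, LOGVAR_MAX = -30.0, 20.0  # clamp range stated on this page

std_min = math.exp(0.5 * LOGVAR_MIN)  # exp(-15)
std_max = math.exp(0.5 * LOGVAR_MAX)  # exp(10)
var_min = math.exp(LOGVAR_MIN)        # exp(-30)
var_max = math.exp(LOGVAR_MAX)        # exp(20)

print(f"std in [{std_min:.3e}, {std_max:.3e}]")
print(f"var in [{var_min:.3e}, {var_max:.3e}]")
```

Both ends stay strictly positive and finite, so divisions by var and logarithms of it remain well defined in non-deterministic mode.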

Usage Examples

import torch

from sat.sgm.modules.distributions.distributions import DiagonalGaussianDistribution

# Encoder outputs concatenated mean and logvar
encoder_output = torch.randn(4, 8, 32, 32)  # 2*C=8 -> C=4

# Create distribution and sample
dist = DiagonalGaussianDistribution(encoder_output)
z = dist.sample()          # shape: (4, 4, 32, 32)

# Compute KL divergence against standard normal prior
kl_loss = dist.kl()        # shape: (4,)

# Compute negative log-likelihood
nll = dist.nll(z)          # shape: (4,)

# Deterministic mode (no stochasticity)
dist_det = DiagonalGaussianDistribution(encoder_output, deterministic=True)
z_det = dist_det.sample()  # returns the mean exactly
