Implementation:Zai org CogVideo DiagonalGaussianDistribution
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Variational_Inference, Autoencoding |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
DiagonalGaussianDistribution is a probability distribution class that represents a multivariate Gaussian with a diagonal covariance matrix, used as the latent distribution in VAE encoders for sampling and KL divergence computation.
Description
The DiagonalGaussianDistribution class takes a parameter tensor and splits it along the channel dimension into mean and log-variance components. The log-variance is clamped to the range [-30.0, 20.0] for numerical stability. From these parameters, the standard deviation and variance are derived.
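The split-and-clamp step can be illustrated without tensors. The following is a minimal plain-Python sketch for a single spatial location, assuming a flat list of 2*C values in place of the channel dimension. The helper name `split_parameters` is hypothetical and not part of the repository, which performs the same steps on tensors.

```python
import math

LOGVAR_MIN, LOGVAR_MAX = -30.0, 20.0  # clamp range stated in the description


def split_parameters(params):
    """Split a flat list of 2*C values into mean/logvar and derive std/var.

    Hypothetical helper for illustration only; the class works on tensors.
    """
    if len(params) % 2 != 0:
        raise ValueError("expected an even number of values (mean + logvar)")
    c = len(params) // 2
    mean = params[:c]
    # Clamp the log-variance so exp() stays within a safe numeric range.
    logvar = [min(max(lv, LOGVAR_MIN), LOGVAR_MAX) for lv in params[c:]]
    std = [math.exp(0.5 * lv) for lv in logvar]
    var = [math.exp(lv) for lv in logvar]
    return mean, logvar, std, var


mean, logvar, std, var = split_parameters([0.5, -1.0, 0.0, 100.0])
print(mean)    # [0.5, -1.0]
print(logvar)  # [0.0, 20.0]  (100.0 was clamped to the upper bound)
```

Without the clamp, a log-variance of 100.0 would give a variance of exp(100), which is far outside the range where later divisions and sums remain well behaved.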
Sampling uses the reparameterization trick: x = mean + std * epsilon, where epsilon is drawn from a standard normal distribution via torch.randn_like. This formulation allows gradients to flow through the sampling operation during backpropagation.
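As a rough illustration of the formula above, here is a scalar sketch using only the Python standard library. It assumes `random.Random.gauss` as a stand-in for torch.randn_like, and the function name `reparameterized_sample` is hypothetical. Because the noise is drawn independently of the parameters, the sample is an affine function of mean and std, which is what lets gradients pass through in the tensor version.

```python
import random
import statistics


def reparameterized_sample(mean, std, rng):
    """Draw x = mean + std * eps with eps ~ N(0, 1).

    Hypothetical scalar stand-in for the tensor-based sample() method.
    """
    eps = rng.gauss(0.0, 1.0)  # noise does not depend on mean or std
    return mean + std * eps


rng = random.Random(0)  # fixed seed so the run is repeatable
draws = [reparameterized_sample(2.0, 0.5, rng) for _ in range(20_000)]

# Empirical moments should land close to the requested mean and std.
print(round(statistics.fmean(draws), 2), round(statistics.pstdev(draws), 2))

# With std = 0 (as in deterministic mode), every draw equals the mean.
print(reparameterized_sample(2.0, 0.0, rng))  # 2.0
```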
The class provides methods for computing KL divergence, either against a standard normal prior or against another DiagonalGaussianDistribution instance, and for computing the negative log-likelihood of a sample. In deterministic mode, the variance and standard deviation are set to zero, so sampling always returns the mean.
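The two quantities can be written out with the standard closed forms for a diagonal Gaussian. The sketch below uses plain Python lists in place of tensors and sums over all elements. The function names are hypothetical, and it is an assumption that the class uses these same textbook expressions; the page does not show the method bodies.

```python
import math


def kl_to_standard_normal(mean, logvar):
    """KL( N(mean, exp(logvar)) || N(0, 1) ), summed over elements."""
    return 0.5 * sum(
        m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mean, logvar)
    )


def gaussian_nll(sample, mean, logvar):
    """Negative log-density of `sample` under N(mean, exp(logvar)), summed."""
    log_two_pi = math.log(2.0 * math.pi)
    return 0.5 * sum(
        log_two_pi + lv + (x - m) ** 2 / math.exp(lv)
        for x, m, lv in zip(sample, mean, logvar)
    )


# A standard normal has zero divergence from itself.
print(kl_to_standard_normal([0.0, 0.0], [0.0, 0.0]))  # 0.0

# Shifting one mean by 1 at unit variance costs 0.5 nats.
print(kl_to_standard_normal([1.0, 0.0], [0.0, 0.0]))  # 0.5

# NLL at the mean of a unit-variance Gaussian is 0.5 * log(2*pi) per element.
print(gaussian_nll([0.0], [0.0], [0.0]))
```

The division by the variance in the NLL is one reason the log-variance is clamped: an unbounded negative log-variance would drive the variance toward zero and the quotient toward infinity.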
The file also includes DiracDistribution (a degenerate distribution returning a fixed value) and the standalone normal_kl function for element-wise KL divergence computation between two Gaussians with full broadcasting support.
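For intuition about the element-wise KL divergence between two Gaussians, here is a scalar sketch of the standard closed form. The name `normal_kl_scalar` is hypothetical; the repository's normal_kl operates on tensors, and its exact signature is not shown on this page.

```python
import math


def normal_kl_scalar(mean1, logvar1, mean2, logvar2):
    """KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) ) for one element."""
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + math.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * math.exp(-logvar2)
    )


# Identical distributions have zero divergence.
print(normal_kl_scalar(0.3, -1.0, 0.3, -1.0))  # 0.0

# KL is not symmetric: swapping the arguments changes the value.
narrow_vs_wide = normal_kl_scalar(0.0, 0.0, 0.0, 2.0)
wide_vs_narrow = normal_kl_scalar(0.0, 2.0, 0.0, 0.0)
print(narrow_vs_wide < wide_vs_narrow)  # True
```

Setting the second distribution to mean 0 and log-variance 0 reduces the expression to the standard-normal case, 0.5 * (mean^2 + var - 1 - logvar).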
Usage
Use DiagonalGaussianDistribution as the latent distribution in VAE and VAE-based diffusion model encoders. It provides the sampling mechanism and KL regularization term needed for the VAE objective function.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: sat/sgm/modules/distributions/distributions.py
Signature
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
Import
from sat.sgm.modules.distributions.distributions import DiagonalGaussianDistribution
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| parameters | Tensor | Yes | Concatenated mean and log-variance tensor of shape (B, 2*C, H, W), which is split along dim=1 into mean and logvar, each of shape (B, C, H, W). |
| deterministic | bool | No | If True, sets variance and standard deviation to zero, making sampling return the mean. Default: False. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| sample() | Tensor | A sample from the distribution via the reparameterization trick, shape (B, C, H, W). |
| kl(other=None) | Tensor | KL divergence per sample. Against standard normal if other=None, otherwise against another DiagonalGaussianDistribution. Shape (B,). |
| nll(sample, dims) | Tensor | Negative log-likelihood of sample under this distribution, summed over the specified dims. |
| mode() | Tensor | The mode (mean) of the distribution. |
Key Attributes
| Attribute | Type | Description |
|---|---|---|
| mean | Tensor | Mean of the distribution, shape (B, C, H, W). |
| logvar | Tensor | Log-variance, clamped to [-30.0, 20.0]. |
| std | Tensor | Standard deviation, computed as exp(0.5 * logvar). |
| var | Tensor | Variance, computed as exp(logvar). |
Usage Examples
import torch

from sat.sgm.modules.distributions.distributions import DiagonalGaussianDistribution

# Encoder outputs concatenated mean and logvar along the channel dimension
encoder_output = torch.randn(4, 8, 32, 32)  # 2*C=8 -> C=4

# Create distribution and sample
dist = DiagonalGaussianDistribution(encoder_output)
z = dist.sample()  # shape: (4, 4, 32, 32)

# Compute KL divergence against standard normal prior
kl_loss = dist.kl()  # shape: (4,)

# Compute negative log-likelihood, summed over channel and spatial dims
nll = dist.nll(z, dims=[1, 2, 3])  # shape: (4,)

# Deterministic mode (no stochasticity)
dist_det = DiagonalGaussianDistribution(encoder_output, deterministic=True)
z_det = dist_det.sample()  # returns the mean exactly