Implementation:Zai org CogVideo SSIM MS SSIM
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Image_Quality_Assessment, Loss_Functions |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Implements Structural Similarity Index (SSIM) and Multi-Scale SSIM (MS-SSIM) as differentiable PyTorch modules for image quality evaluation and training loss computation.
Description
This module provides both functional and class-based interfaces for computing SSIM and MS-SSIM metrics. It includes three SSIM computation variants:
- ssim: Standard 2D SSIM using an 11x11 Gaussian window (sigma=1.5) with replicate padding. It computes the local means (mu1, mu2), variances (sigma1^2, sigma2^2), and covariance (sigma12) via grouped convolution. It then combines them into the SSIM map using the stability constants C1 = (0.01*L)^2 and C2 = (0.03*L)^2, where L is the dynamic range of the pixel values. When val_range is None, L is inferred from the input data.
- ssim_matlab: A MATLAB-compatible variant that treats the color channels as the depth axis of a volume and computes the local statistics with 3D convolution. This matches the MATLAB SSIM implementation commonly used in academic benchmarks, so scores can be compared fairly with published results.
- msssim: Multi-Scale SSIM that evaluates SSIM at 5 scales, downsampling both images with 2x average pooling between scales. It uses the standard MS-SSIM weights w = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] (finest scale first). It combines the contrast sensitivity (CS) values of the four finer scales with the SSIM value at the coarsest scale:
  $$\mathrm{MS\text{-}SSIM} = \mathrm{SSIM}_4^{\,w_4} \prod_{k=0}^{3} \mathrm{CS}_k^{\,w_k}$$
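The combination step can be sketched in plain Python. The sketch assumes the per-scale SSIM and CS values have already been computed, for example by running single-scale SSIM after each 2x average pooling. `combine_ms_ssim` and `MS_SSIM_WEIGHTS` are hypothetical names. The code follows the formula stated for msssim and is not the module's own code. The `(x + 1) / 2` mapping used for `normalize` is an assumption about how that flag rescales values to [0, 1].

```python
# Standard MS-SSIM weights from the description, finest scale first
MS_SSIM_WEIGHTS = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)


def combine_ms_ssim(ssim_per_scale, cs_per_scale, weights=MS_SSIM_WEIGHTS, normalize=False):
    """Hypothetical helper: SSIM_4^w_4 * prod_{k=0..3} CS_k^w_k.

    ssim_per_scale / cs_per_scale: one float per scale, finest scale first.
    """
    if normalize:
        # Assumed normalization: map [-1, 1] to [0, 1] so that fractional
        # powers are never applied to a negative base
        ssim_per_scale = [(s + 1) / 2 for s in ssim_per_scale]
        cs_per_scale = [(c + 1) / 2 for c in cs_per_scale]

    # The full SSIM term is used only at the coarsest scale
    result = ssim_per_scale[-1] ** weights[-1]
    # The finer scales contribute only their contrast-sensitivity terms
    for cs, w in zip(cs_per_scale[:-1], weights[:-1]):
        result *= cs ** w
    return result


print(combine_ms_ssim([0.95, 0.93, 0.92, 0.91, 0.90], [0.96, 0.94, 0.93, 0.92, 0.91]))
```

Because the weights sum to roughly 1, identical images (all values 1.0) give a score of exactly 1.0.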
The SSIM class wraps the functional ssim with cached Gaussian windows for efficient reuse and returns the dissimilarity (1 - SSIM) / 2 for use as a loss function. The MSSSIM class wraps the functional msssim.
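The PyTorch sketch below is self-contained. It covers the single-scale computation described for ssim and the (1 - SSIM) / 2 dissimilarity that the SSIM class returns. `gaussian_window`, `ssim_sketch`, and `dssim_loss` are hypothetical names. The sketch passes `val_range` explicitly instead of auto-detecting it and assumes an odd window size. It illustrates the technique and is not the module's own implementation.

```python
import torch
import torch.nn.functional as F


def gaussian_window(window_size: int = 11, sigma: float = 1.5, channel: int = 3) -> torch.Tensor:
    """Normalized 2D Gaussian kernel shaped (channel, 1, k, k) for grouped conv2d."""
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    g = g / g.sum()
    w2d = torch.outer(g, g)  # separable kernel: outer product of the 1D Gaussian
    return w2d.expand(channel, 1, window_size, window_size).contiguous()


def ssim_sketch(img1: torch.Tensor, img2: torch.Tensor, window_size: int = 11, val_range: float = 1.0):
    """Return (mean SSIM, mean CS) for (B, C, H, W) tensors."""
    channel = img1.shape[1]
    window = gaussian_window(window_size, channel=channel).to(device=img1.device, dtype=img1.dtype)
    pad = window_size // 2

    def local_mean(x: torch.Tensor) -> torch.Tensor:
        # Replicate padding keeps the output the same spatial size as the input;
        # groups=channel filters each channel independently
        x = F.pad(x, (pad, pad, pad, pad), mode="replicate")
        return F.conv2d(x, window, groups=channel)

    mu1, mu2 = local_mean(img1), local_mean(img2)
    sigma1_sq = local_mean(img1 * img1) - mu1 ** 2
    sigma2_sq = local_mean(img2 * img2) - mu2 ** 2
    sigma12 = local_mean(img1 * img2) - mu1 * mu2

    C1 = (0.01 * val_range) ** 2
    C2 = (0.03 * val_range) ** 2
    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1 * mu2 + C1) / (mu1 ** 2 + mu2 ** 2 + C1)) * cs_map
    return ssim_map.mean(), cs_map.mean()


def dssim_loss(img1: torch.Tensor, img2: torch.Tensor, val_range: float = 1.0) -> torch.Tensor:
    """Dissimilarity in [0, 1]: (1 - SSIM) / 2."""
    ssim_val, _ = ssim_sketch(img1, img2, val_range=val_range)
    return (1 - ssim_val) / 2
```

Mapping SSIM from [-1, 1] to [0, 1] makes the loss non-negative, and the loss is zero only when both images are identical.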
Helper functions gaussian, create_window, and create_window_3d construct the Gaussian kernels from a 1D Gaussian with sigma=1.5, expanded via outer products.
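The separable construction can be shown in plain Python. `gaussian_1d`, `window_2d`, and `window_3d` are hypothetical stand-ins for gaussian, create_window, and create_window_3d. They return nested lists instead of tensors and skip the per-channel expansion the real helpers perform.

```python
import math


def gaussian_1d(window_size: int, sigma: float = 1.5) -> list[float]:
    """1D Gaussian centred on window_size // 2, normalized to sum to 1."""
    center = window_size // 2
    g = [math.exp(-((x - center) ** 2) / (2 * sigma ** 2)) for x in range(window_size)]
    total = sum(g)
    return [v / total for v in g]


def window_2d(window_size: int, sigma: float = 1.5) -> list[list[float]]:
    """2D kernel as the outer product g g^T of the 1D Gaussian."""
    g = gaussian_1d(window_size, sigma)
    return [[a * b for b in g] for a in g]


def window_3d(window_size: int, sigma: float = 1.5) -> list[list[list[float]]]:
    """3D kernel: the 2D kernel times the 1D Gaussian along a third (depth) axis."""
    g = gaussian_1d(window_size, sigma)
    w2 = window_2d(window_size, sigma)
    return [[[w2[i][j] * g[k] for k in range(window_size)]
             for j in range(window_size)]
            for i in range(window_size)]


w = window_2d(11)
print(f"center weight: {w[5][5]:.4f}, total: {sum(map(sum, w)):.4f}")
```

The 1D Gaussian is normalized, so every outer product of it also sums to 1. Local means computed with these kernels are therefore true weighted averages.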
Usage
Use SSIM and MS-SSIM to evaluate RIFE frame-interpolation results against ground-truth frames. The MATLAB-compatible variant (ssim_matlab) produces scores that can be compared directly with benchmark results from the academic literature. The differentiable SSIM class can also be used as a training loss.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: inference/gradio_composite_demo/rife/pytorch_msssim/__init__.py
Signature
```python
def gaussian(window_size: int, sigma: float) -> torch.Tensor: ...
def create_window(window_size: int, channel=1) -> torch.Tensor: ...
def create_window_3d(window_size: int, channel=1) -> torch.Tensor: ...
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None) -> torch.Tensor: ...
def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None) -> torch.Tensor: ...
def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False) -> torch.Tensor: ...

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None): ...
    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: ...

class MSSSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, channel=3): ...
    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor: ...
```
Import
```python
from inference.gradio_composite_demo.rife.pytorch_msssim import SSIM, MSSSIM, ssim, ssim_matlab, msssim
```
I/O Contract
Inputs
ssim / ssim_matlab:
| Name | Type | Required | Description |
|---|---|---|---|
| img1 | torch.Tensor | Yes | First image tensor of shape (B, C, H, W) |
| img2 | torch.Tensor | Yes | Second image tensor of shape (B, C, H, W) |
| window_size | int | No | Size of the Gaussian window, default 11 |
| window | torch.Tensor | No | Pre-computed Gaussian window for reuse |
| size_average | bool | No | If True, return scalar mean; if False, return per-sample values. Default True |
| full | bool | No | If True, return both SSIM and contrast sensitivity (CS). Default False |
| val_range | float | No | Dynamic range of input values. Auto-detected if None |
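The dynamic range L feeds the stability constants C1 and C2. The sketch below shows one plausible auto-detection heuristic. It is an assumption for illustration and has not been confirmed from the module source. `infer_val_range` and `stability_constants` are hypothetical names.

```python
def infer_val_range(min_val: float, max_val: float) -> float:
    """Hypothetical heuristic (assumption): guess L from an image's observed min/max."""
    upper = 255.0 if max_val > 128 else 1.0  # 8-bit pixel values vs. normalized floats
    lower = -1.0 if min_val < -0.5 else 0.0  # tanh-style [-1, 1] outputs vs. [0, 1]
    return upper - lower


def stability_constants(L: float) -> tuple[float, float]:
    """C1 = (0.01 * L)^2 and C2 = (0.03 * L)^2, as given in the description."""
    return (0.01 * L) ** 2, (0.03 * L) ** 2


L = infer_val_range(0.0, 0.8)
print(L, stability_constants(L))
```

Under a heuristic like this, a dark 8-bit image whose maximum is at most 128 would be treated as having range 1. That is one reason to pass val_range explicitly when the input scale is known.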
msssim:
| Name | Type | Required | Description |
|---|---|---|---|
| img1 | torch.Tensor | Yes | First image tensor of shape (B, C, H, W) |
| img2 | torch.Tensor | Yes | Second image tensor of shape (B, C, H, W) |
| window_size | int | No | Size of the Gaussian window, default 11 |
| size_average | bool | No | If True, return scalar mean. Default True |
| val_range | float | No | Dynamic range of input values. Auto-detected if None |
| normalize | bool | No | If True, normalize SSIM and CS to [0, 1] range for training stability. Default False |
Outputs
| Name | Type | Description |
|---|---|---|
| ssim_value | torch.Tensor | SSIM index in [-1, 1] (scalar if size_average=True, per-sample otherwise) |
| cs | torch.Tensor | Contrast sensitivity value (only returned when full=True) |
| msssim_value | torch.Tensor | Multi-Scale SSIM index (scalar if size_average=True) |
| dssim (SSIM class) | torch.Tensor | Dissimilarity score computed as (1 - SSIM) / 2, in [0, 1] |
Usage Examples
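```python
import torch
from inference.gradio_composite_demo.rife.pytorch_msssim import SSIM, MSSSIM, ssim_matlab

# Fall back to CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images in [0, 1]; img2 is a slightly noisy copy of img1, so the scores are
# meaningful (unrelated random images can give negative SSIM and NaN MS-SSIM)
img1 = torch.rand(1, 3, 256, 256, device=device)
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

# Functional SSIM (MATLAB-compatible for benchmarking)
score = ssim_matlab(img1, img2, val_range=1.0)
print(f"SSIM (MATLAB): {score.item():.4f}")

# SSIM as a training loss (returns dissimilarity); backward() needs an input that requires grad
pred = img1.clone().requires_grad_(True)
ssim_loss = SSIM(window_size=11, size_average=True, val_range=1.0).to(device)
loss = ssim_loss(pred, img2)
loss.backward()

# Multi-Scale SSIM for evaluation
msssim_metric = MSSSIM(window_size=11, size_average=True, channel=3).to(device)
ms_score = msssim_metric(img1, img2)
print(f"MS-SSIM: {ms_score.item():.4f}")
```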