

Implementation:Zai org CogVideo SSIM MS SSIM

From Leeroopedia


Knowledge Sources
Domains: Video_Generation, Image_Quality_Assessment, Loss_Functions
Last Updated: 2026-02-10 00:00 GMT

Overview

Implements Structural Similarity Index (SSIM) and Multi-Scale SSIM (MS-SSIM) as differentiable PyTorch modules for image quality evaluation and training loss computation.

Description

This module provides both functional and class-based interfaces for computing SSIM and MS-SSIM metrics. It includes three SSIM computation variants:

  • ssim: Standard 2D SSIM using an 11x11 Gaussian window (sigma=1.5) with replicate padding. It computes local means (mu), variances (sigma^2), and covariance (sigma12) via grouped convolution, then combines them into the SSIM map using the stability constants C1=(0.01*L)^2 and C2=(0.03*L)^2, where L is the dynamic range. When val_range is None, L is inferred from the input values.
  • ssim_matlab: A MATLAB-compatible variant that treats color channels as a volumetric (3D) dimension, using 3D convolution for computing statistics. This matches the MATLAB implementation commonly used in academic benchmarks, enabling fair comparison with published results.
  • msssim: Multi-Scale SSIM that computes SSIM at 5 scales, halving the resolution of both images between scales with 2x2 average pooling. It uses the standard MS-SSIM weights w = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] and combines the per-scale contrast sensitivity (CS) and SSIM values as: output = CS_0^w_0 * CS_1^w_1 * CS_2^w_2 * CS_3^w_3 * SSIM_4^w_4, i.e. CS at the four finer scales and full SSIM only at the coarsest scale (k=4).
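The single-scale computation described above can be illustrated with a minimal NumPy sketch for one grayscale channel. It is not the module's code: the PyTorch version filters every channel at once with grouped F.conv2d, while here a direct loop over the window taps stands in for the convolution. The function names (gaussian_window, filter2d_replicate, ssim_2d) are hypothetical, and the sketch assumes the caller passes the dynamic range explicitly.

```python
import numpy as np


def gaussian_window(size: int = 11, sigma: float = 1.5) -> np.ndarray:
    # Normalized 1D Gaussian, expanded to a 2D kernel with an outer product
    x = np.arange(size) - size // 2
    g = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    g /= g.sum()
    return np.outer(g, g)


def filter2d_replicate(img: np.ndarray, win: np.ndarray) -> np.ndarray:
    # Replicate ("edge") padding keeps the output the same size as the input
    k = win.shape[0]
    pad = k // 2
    padded = np.pad(img, pad, mode="edge")
    h, w = img.shape
    out = np.zeros((h, w), dtype=np.float64)
    # Weighted sum of shifted copies; equals convolution because the kernel is symmetric
    for dy in range(k):
        for dx in range(k):
            out += win[dy, dx] * padded[dy:dy + h, dx:dx + w]
    return out


def ssim_2d(img1: np.ndarray, img2: np.ndarray, data_range: float = 1.0):
    """Return (mean SSIM, mean CS) for two 2D arrays with the given dynamic range."""
    win = gaussian_window()
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu1 = filter2d_replicate(img1, win)
    mu2 = filter2d_replicate(img2, win)
    # Local variances and covariance: E[x^2] - E[x]^2 and E[xy] - E[x]E[y]
    sigma1_sq = filter2d_replicate(img1 * img1, win) - mu1 ** 2
    sigma2_sq = filter2d_replicate(img2 * img2, win) - mu2 ** 2
    sigma12 = filter2d_replicate(img1 * img2, win) - mu1 * mu2
    cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
    ssim_map = ((2 * mu1 * mu2 + c1) / (mu1 ** 2 + mu2 ** 2 + c1)) * cs_map
    return float(ssim_map.mean()), float(cs_map.mean())


rng = np.random.default_rng(0)
a = rng.random((64, 64))
b = np.clip(a + 0.1 * rng.standard_normal((64, 64)), 0.0, 1.0)
print(ssim_2d(a, a))  # identical inputs give SSIM and CS of (numerically) 1
print(ssim_2d(a, b))  # added noise lowers both values
```

The CS map returned alongside SSIM is the same term that the multi-scale variant reuses at the finer scales.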

The SSIM class wraps the functional ssim with cached Gaussian windows for efficient reuse and returns the dissimilarity (1 - SSIM) / 2 for use as a loss function. The MSSSIM class wraps the functional msssim.
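The dissimilarity mapping and the idea of reusing a window can be sketched in plain Python. The dssim helper follows the (1 - SSIM) / 2 formula stated above. WindowCache is a hypothetical stand-in for the class's cached window; keying the cache on channel count and dtype is an assumption made for illustration, not a description of the module's exact logic.

```python
from typing import Callable, Optional, Tuple


def dssim(ssim_value: float) -> float:
    # SSIM lies in [-1, 1]; (1 - SSIM) / 2 maps it to a loss in [0, 1] (0 = identical)
    return (1.0 - ssim_value) / 2.0


class WindowCache:
    """Hypothetical helper: rebuild the window only when (channels, dtype) changes."""

    def __init__(self, make_window: Callable[[int, str], object]) -> None:
        self.make_window = make_window
        self.key: Optional[Tuple[int, str]] = None
        self.window: object = None
        self.builds = 0  # how many times a window was constructed

    def get(self, channels: int, dtype: str) -> object:
        key = (channels, dtype)
        if key != self.key:
            self.window = self.make_window(channels, dtype)
            self.key = key
            self.builds += 1
        return self.window


cache = WindowCache(lambda c, d: f"window[{c}, {d}]")
cache.get(3, "float32")
cache.get(3, "float32")  # same key: reused, no rebuild
cache.get(1, "float32")  # channel count changed: rebuilt
print(cache.builds, dssim(1.0), dssim(0.0))  # → 2 0.0 0.5
```

Because identical images give a loss of exactly 0 and the loss is bounded above by 1, DSSIM can be added to other bounded losses without extra rescaling.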

The helper gaussian(window_size, sigma) returns a normalized 1D Gaussian kernel; create_window and create_window_3d call it with sigma=1.5 and expand it into 2D and 3D kernels via outer products.
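The kernel construction can be reproduced in NumPy as a minimal sketch. The function names here are hypothetical stand-ins for gaussian, create_window, and create_window_3d; the PyTorch helpers additionally reshape and expand the kernel so it can be used as a convolution weight, which this sketch omits.

```python
import numpy as np


def gaussian_1d(window_size: int, sigma: float) -> np.ndarray:
    # Sample exp(-(x - center)^2 / (2 sigma^2)) at integer offsets, then normalize to sum 1
    x = np.arange(window_size) - window_size // 2
    g = np.exp(-(x ** 2) / (2.0 * sigma ** 2))
    return g / g.sum()


def window_2d(window_size: int, sigma: float = 1.5) -> np.ndarray:
    g = gaussian_1d(window_size, sigma)
    return np.outer(g, g)  # shape (window_size, window_size), sums to 1


def window_3d(window_size: int, sigma: float = 1.5) -> np.ndarray:
    g = gaussian_1d(window_size, sigma)
    # Outer product of three 1D kernels: shape (window_size, window_size, window_size)
    return np.einsum("i,j,k->ijk", g, g, g)


w2 = window_2d(11)
w3 = window_3d(11)
print(w2.shape, w3.shape, round(float(w2.sum()), 6))  # → (11, 11) (11, 11, 11) 1.0
```

Since the 1D kernel sums to 1, every outer-product expansion also sums to 1, so filtering a constant image returns the same constant.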

Usage

Use SSIM and MS-SSIM to evaluate RIFE frame-interpolation results against ground-truth frames. The MATLAB-compatible variant (ssim_matlab) enables fair comparison with benchmark results from the academic literature. The differentiable SSIM class can also be used as a training loss.

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: inference/gradio_composite_demo/rife/pytorch_msssim/__init__.py

Signature

def gaussian(window_size: int, sigma: float) -> torch.Tensor

def create_window(window_size: int, channel=1) -> torch.Tensor

def create_window_3d(window_size: int, channel=1) -> torch.Tensor

def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None) -> torch.Tensor

def ssim_matlab(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None) -> torch.Tensor

def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False) -> torch.Tensor

class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, val_range=None)
    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor

class MSSSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True, channel=3)
    def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor

Import

from inference.gradio_composite_demo.rife.pytorch_msssim import SSIM, MSSSIM, ssim, ssim_matlab, msssim

I/O Contract

Inputs

ssim / ssim_matlab:

Name | Type | Required | Description
img1 | torch.Tensor | Yes | First image tensor of shape (B, C, H, W)
img2 | torch.Tensor | Yes | Second image tensor of shape (B, C, H, W)
window_size | int | No | Size of the Gaussian window. Default 11
window | torch.Tensor | No | Pre-computed Gaussian window to reuse
size_average | bool | No | If True, return the scalar mean; if False, return per-sample values. Default True
full | bool | No | If True, also return the contrast sensitivity (CS) term, as a tuple (ssim, cs). Default False
val_range | float | No | Dynamic range of the input values. Inferred from the input if None
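The auto-detection of the dynamic range can be sketched as below. The helper name infer_dynamic_range is hypothetical, and the thresholds (a maximum above 128 implies 8-bit data, a minimum below -0.5 implies data in [-1, 1]) are an assumption based on commonly used PyTorch SSIM code; verify them against the source file before relying on them.

```python
from typing import Optional, Sequence


def infer_dynamic_range(values: Sequence[float], val_range: Optional[float] = None) -> float:
    # An explicit val_range always takes precedence over detection
    if val_range is not None:
        return float(val_range)
    # Assumed heuristic: large values suggest [0, 255]; clearly negative values suggest [-1, 1]
    max_val = 255.0 if max(values) > 128 else 1.0
    min_val = -1.0 if min(values) < -0.5 else 0.0
    return max_val - min_val


print(infer_dynamic_range([0.0, 0.4, 1.0]))      # → 1.0
print(infer_dynamic_range([0.0, 200.0]))         # → 255.0
print(infer_dynamic_range([-0.9, 0.9]))          # → 2.0
print(infer_dynamic_range([0.0, 1.0], 255.0))    # → 255.0
```

Under such a heuristic a dark 8-bit frame whose maximum is at most 128 would be treated as [0, 1] data, which changes C1 and C2; passing val_range explicitly avoids this when reporting benchmark numbers.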

msssim:

Name | Type | Required | Description
img1 | torch.Tensor | Yes | First image tensor of shape (B, C, H, W)
img2 | torch.Tensor | Yes | Second image tensor of shape (B, C, H, W)
window_size | int | No | Size of the Gaussian window. Default 11
size_average | bool | No | If True, return the scalar mean. Default True
val_range | float | No | Dynamic range of the input values. Inferred from the input if None
normalize | bool | No | If True, normalize SSIM and CS to the [0, 1] range for training stability. Default False
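The effect of normalize on the multi-scale combination can be shown with a small NumPy sketch. It follows the weighted-product formula from the Description, uses the weights listed there, and assumes the normalization is the linear map (x + 1) / 2 from [-1, 1] to [0, 1]. The function name combine_msssim and the per-scale values are hypothetical inputs chosen for illustration.

```python
import numpy as np

WEIGHTS = np.array([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])


def combine_msssim(ssim_per_scale, cs_per_scale, normalize: bool = False) -> float:
    ssim_vals = np.asarray(ssim_per_scale, dtype=np.float64)
    cs_vals = np.asarray(cs_per_scale, dtype=np.float64)
    if normalize:
        # Assumed linear map [-1, 1] -> [0, 1] so fractional powers never see a negative base
        ssim_vals = (ssim_vals + 1.0) / 2.0
        cs_vals = (cs_vals + 1.0) / 2.0
    # CS at the four finer scales, full SSIM only at the coarsest scale
    with np.errstate(invalid="ignore"):  # negative base ** fractional power -> nan
        cs_term = np.prod(cs_vals[:-1] ** WEIGHTS[:-1])
        return float(cs_term * ssim_vals[-1] ** WEIGHTS[-1])


cs = [0.9, 0.8, -0.05, 0.7, 0.95]   # hypothetical per-scale CS; one slightly negative
ss = [0.85, 0.8, 0.75, 0.7, 0.9]    # hypothetical per-scale SSIM
raw = combine_msssim(ss, cs)                      # nan because of the negative CS value
stable = combine_msssim(ss, cs, normalize=True)   # finite value in (0, 1]
print(raw, stable)
```

A single NaN at any scale propagates through the product and then through the gradients, which is why normalization is useful when training unstable models, at the cost of departing from the standard MS-SSIM definition.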

Outputs

Name | Type | Description
ssim_value | torch.Tensor | SSIM index in [-1, 1] (scalar if size_average=True, per-sample otherwise)
cs | torch.Tensor | Contrast sensitivity value (only returned when full=True)
msssim_value | torch.Tensor | Multi-Scale SSIM index (scalar if size_average=True)
dssim (SSIM class) | torch.Tensor | Dissimilarity score computed as (1 - SSIM) / 2, in [0, 1]

Usage Examples

import torch
from inference.gradio_composite_demo.rife.pytorch_msssim import SSIM, MSSSIM, ssim_matlab

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Images in [0, 1] to match val_range=1.0; img2 is a slightly noisy copy of img1
img1 = torch.rand(1, 3, 256, 256, device=device)
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

# Functional SSIM (MATLAB-compatible for benchmarking)
score = ssim_matlab(img1, img2, val_range=1.0)
print(f"SSIM (MATLAB): {score.item():.4f}")

# SSIM as a training loss (returns dissimilarity); the prediction must require grad
pred = img2.clone().requires_grad_(True)
ssim_loss = SSIM(window_size=11, size_average=True, val_range=1.0).to(device)
loss = ssim_loss(pred, img1)
loss.backward()

# Multi-Scale SSIM for evaluation (5 scales: 256, 128, 64, 32, 16 pixels per side)
msssim_metric = MSSSIM(window_size=11, size_average=True, channel=3).to(device)
ms_score = msssim_metric(img1, img2)
print(f"MS-SSIM: {ms_score.item():.4f}")
