Implementation:Zai org CogVideo RIFE Loss Functions

From Leeroopedia


Knowledge Sources
Domains: Video_Generation, Loss_Functions, Optical_Flow
Last Updated: 2026-02-10 00:00 GMT

Overview

A collection of loss function modules used for training RIFE frame interpolation models, including endpoint error (EPE), ternary census loss, Sobel edge loss, and VGG perceptual loss.

Description

This module defines five loss-related classes used in the RIFE training pipeline:

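  • EPE (Endpoint Error): Computes the per-pixel L2 (Euclidean) distance between predicted and ground-truth optical flow, restricted to a loss region. The error at each pixel is the square root of the sum of squared differences over the flow channels, and the resulting map is multiplied by a binary mask. The ground truth is detached, so no gradient flows into it.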
  • Ternary: Implements a census-transform loss for structural comparison. Each image is converted to grayscale, 7x7 local patches are extracted via convolution with an identity kernel, and every patch value is replaced by its difference from the center pixel, soft-normalized as transf / sqrt(0.81 + transf^2). The loss is a robust (soft) Hamming distance between the two transformed images. A valid mask excludes border pixels.
  • SOBEL: Applies horizontal and vertical Sobel kernels to both prediction and ground truth and computes L1 differences between the resulting gradient maps. The kernel [[1,0,-1],[2,0,-2],[1,0,-1]] measures intensity change along the horizontal axis (so it responds to vertical edges), and its transpose measures change along the vertical axis (horizontal edges).
  • MeanShift: A 1x1 convolution layer that normalizes (norm=True) or denormalizes (norm=False) images with a per-channel mean and standard deviation. It is used with ImageNet statistics as the preprocessing step for VGG feature extraction.
  • VGGPerceptualLoss: Extracts features from 5 layers of a pretrained VGG-19 network (indices 2, 7, 12, 21, 30) and computes weighted L1 differences. Layer weights [1/2.6, 1/4.8, 1/3.7, 1/5.6, 10/1.5] balance contributions from early texture features and deeper semantic features.
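
The MeanShift idea (per-channel normalization expressed as a 1x1 convolution) can be sketched as follows. This is a minimal illustration, not the module's code: the function name `normalize_1x1` is hypothetical, and the ImageNet statistics are the commonly used values, assumed here rather than taken from the source file.

```python
import torch
import torch.nn.functional as F

# Commonly used ImageNet statistics (assumed; the source only says "ImageNet mean and std").
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_1x1(img: torch.Tensor) -> torch.Tensor:
    """Compute (x - mean) / std per channel with a 1x1 convolution (hypothetical helper)."""
    mean = torch.tensor(IMAGENET_MEAN)
    std = torch.tensor(IMAGENET_STD)
    c = mean.numel()
    # Diagonal weight: output channel k reads only input channel k, scaled by 1 / std[k].
    weight = torch.eye(c).view(c, c, 1, 1) / std.view(c, 1, 1, 1)
    # The bias supplies the -mean / std term.
    bias = -mean / std
    return F.conv2d(img, weight.to(img), bias.to(img))


img = torch.rand(1, 3, 4, 4)  # RGB values in [0, 1]
out = normalize_1x1(img)

# Reference result computed with plain broadcasting.
expected = (img - torch.tensor(IMAGENET_MEAN).view(1, 3, 1, 1)) / torch.tensor(IMAGENET_STD).view(1, 3, 1, 1)
```

Writing the normalization as a convolution keeps it inside the module graph, so it moves with the model across devices like any other layer.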

Usage

Use these loss functions when training RIFE models. EPE supervises flow accuracy, Ternary and SOBEL enforce structural and edge consistency, and VGGPerceptualLoss ensures perceptual quality. They are instantiated by the Model class in RIFE.py.

Code Reference

Source Location

Signature

class EPE(nn.Module):
    def __init__(self)
    def forward(self, flow: torch.Tensor, gt: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor

class Ternary(nn.Module):
    def __init__(self)
    def transform(self, img: torch.Tensor) -> torch.Tensor
    def rgb2gray(self, rgb: torch.Tensor) -> torch.Tensor
    def hamming(self, t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor
    def valid_mask(self, t: torch.Tensor, padding: int) -> torch.Tensor
    def forward(self, img0: torch.Tensor, img1: torch.Tensor) -> torch.Tensor

class SOBEL(nn.Module):
    def __init__(self)
    def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor

class MeanShift(nn.Conv2d):
    def __init__(self, data_mean: list, data_std: list, data_range=1, norm=True)

class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self, rank=0)
    def forward(self, X: torch.Tensor, Y: torch.Tensor, indices=None) -> torch.Tensor

Import

from inference.gradio_composite_demo.rife.loss import EPE, Ternary, SOBEL, VGGPerceptualLoss, MeanShift

I/O Contract

Inputs

EPE.forward:

Name | Type | Required | Description
flow | torch.Tensor | Yes | Predicted optical flow tensor
gt | torch.Tensor | Yes | Ground truth optical flow tensor (detached internally)
loss_mask | torch.Tensor | Yes | Binary mask indicating valid flow regions
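
A minimal sketch of the masked endpoint error described for EPE, written as a standalone function. The name `epe_map` is hypothetical, and the small `eps` inside the square root is an assumption added for numerical stability at zero error; the source does not state it.

```python
import torch


def epe_map(flow: torch.Tensor, gt: torch.Tensor, loss_mask: torch.Tensor,
            eps: float = 1e-6) -> torch.Tensor:
    """Per-pixel L2 flow error, masked (hypothetical helper; eps is an assumption)."""
    sq = (flow - gt.detach()) ** 2                      # squared per-channel differences
    mag = (sq.sum(dim=1, keepdim=True) + eps) ** 0.5    # L2 norm over the flow channels
    return mag * loss_mask                              # zero out pixels outside the loss region


flow = torch.zeros(1, 2, 2, 2)
gt = torch.zeros(1, 2, 2, 2)
gt[0, 0] = 3.0   # horizontal flow component
gt[0, 1] = 4.0   # vertical flow component
mask = torch.ones(1, 1, 2, 2)
mask[0, 0, 0, 0] = 0.0   # exclude the top-left pixel

err = epe_map(flow, gt, mask)
```

With a (3, 4) flow error at every pixel, the unmasked entries are approximately 5 and the masked pixel is exactly 0.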

Ternary.forward:

Name | Type | Required | Description
img0 | torch.Tensor | Yes | First RGB image tensor (B, 3, H, W)
img1 | torch.Tensor | Yes | Second RGB image tensor (B, 3, H, W)
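
The census-transform pipeline described for Ternary can be sketched as below. This is an illustration under stated assumptions, not the module's code: the function names are hypothetical, the luma weights in `rgb2gray` are the standard ones, and the constant 0.1 in the soft Hamming distance and the one-pixel border are assumptions. Only the 7x7 patch size and the 0.81 constant come from the description.

```python
import torch
import torch.nn.functional as F

PATCH = 7  # patch size given in the description


def rgb2gray(rgb: torch.Tensor) -> torch.Tensor:
    # Standard luma weights (assumed).
    r, g, b = rgb[:, 0:1], rgb[:, 1:2], rgb[:, 2:3]
    return 0.2989 * r + 0.5870 * g + 0.1140 * b


def census_transform(gray: torch.Tensor) -> torch.Tensor:
    n = PATCH * PATCH
    # Identity kernel: output channel k copies the k-th pixel of the 7x7 neighbourhood.
    w = torch.eye(n).view(n, 1, PATCH, PATCH).to(gray)
    patches = F.conv2d(gray, w, padding=PATCH // 2)   # (B, 49, H, W)
    diff = patches - gray                             # difference from the center pixel
    return diff / torch.sqrt(0.81 + diff ** 2)        # soft normalization into (-1, 1)


def soft_hamming(t0: torch.Tensor, t1: torch.Tensor) -> torch.Tensor:
    dist = (t0 - t1) ** 2
    # Robust distance: each term saturates below 1 (the 0.1 constant is an assumption).
    return (dist / (0.1 + dist)).mean(dim=1, keepdim=True)


def valid_mask(t: torch.Tensor, padding: int) -> torch.Tensor:
    b, _, h, w = t.shape
    inner = torch.ones(b, 1, h - 2 * padding, w - 2 * padding).to(t)
    return F.pad(inner, [padding] * 4)                # zeros on the border


def ternary_map(img0: torch.Tensor, img1: torch.Tensor, border: int = 1) -> torch.Tensor:
    t0 = census_transform(rgb2gray(img0))
    t1 = census_transform(rgb2gray(img1))
    return soft_hamming(t0, t1) * valid_mask(t0, border)


a = torch.rand(2, 3, 16, 16)
b = torch.rand(2, 3, 16, 16)
same_map = ternary_map(a, a)   # identical images give a zero map
diff_map = ternary_map(a, b)   # values in [0, 1), zero on the border
```

Because each patch value is compared with its own center pixel, the transform depends on local structure rather than absolute brightness.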

SOBEL.forward:

Name | Type | Required | Description
pred | torch.Tensor | Yes | Predicted image tensor (B, C, H, W)
gt | torch.Tensor | Yes | Ground truth image tensor (B, C, H, W)
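
A sketch of the Sobel edge loss as described, assuming zero padding of one pixel and channels folded into the batch dimension (which matches the documented (B*C, 1, H, W) output shape). The function name `sobel_l1_map` is hypothetical.

```python
import torch
import torch.nn.functional as F


def sobel_l1_map(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """L1 difference of Sobel gradients, per pixel (hypothetical helper)."""
    kx = torch.tensor([[1.0, 0.0, -1.0],
                       [2.0, 0.0, -2.0],
                       [1.0, 0.0, -1.0]])
    ky = kx.t()                                  # transpose gives the vertical-gradient kernel
    kx = kx.reshape(1, 1, 3, 3).to(pred)
    ky = ky.reshape(1, 1, 3, 3).to(pred)

    b, c, h, w = pred.shape
    # Treat every channel as its own single-channel image.
    stack = torch.cat([pred.reshape(b * c, 1, h, w),
                       gt.reshape(b * c, 1, h, w)], dim=0)
    gx = F.conv2d(stack, kx, padding=1)          # zero padding is an assumption
    gy = F.conv2d(stack, ky, padding=1)
    pred_x, gt_x = gx[:b * c], gx[b * c:]
    pred_y, gt_y = gy[:b * c], gy[b * c:]
    return (pred_x - gt_x).abs() + (pred_y - gt_y).abs()


pred = torch.rand(2, 3, 8, 8)
zero_map = sobel_l1_map(pred, pred)          # identical inputs
shift_map = sobel_l1_map(pred + 0.5, pred)   # constant brightness offset
```

Since each Sobel kernel sums to zero, a constant brightness offset leaves the interior of the map at zero; only the border reacts, because of the zero padding.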

VGGPerceptualLoss.forward:

Name | Type | Required | Description
X | torch.Tensor | Yes | Predicted image tensor (B, 3, H, W)
Y | torch.Tensor | Yes | Target image tensor (B, 3, H, W)
indices | list | No | VGG layer indices for feature extraction (overridden to [2, 7, 12, 21, 30] internally)
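
The weighted multi-layer L1 feature comparison can be sketched with a small stand-in network, so the example runs without downloading pretrained weights. Everything here is illustrative: `weighted_feature_l1` and `stand_in` are hypothetical names, the stand-in stack is not VGG-19, and the convention that index k means "after k layers have been applied" is an assumption.

```python
import torch
import torch.nn as nn


def weighted_feature_l1(features: nn.Sequential, x: torch.Tensor, y: torch.Tensor,
                        indices: list, weights: list) -> torch.Tensor:
    """Weighted sum of mean-L1 distances between activations (hypothetical helper)."""
    loss = x.new_zeros(())
    k = 0
    for i in range(indices[-1]):
        x = features[i](x)
        y = features[i](y)
        if (i + 1) in indices:                   # tap after i + 1 layers (assumed convention)
            loss = loss + weights[k] * (x - y.detach()).abs().mean()
            k += 1
    return loss


torch.manual_seed(0)
# Randomly initialised stand-in for a pretrained feature extractor.
stand_in = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
).eval()
for p in stand_in.parameters():
    p.requires_grad_(False)                      # the feature extractor stays frozen

a = torch.rand(1, 3, 16, 16)
b = torch.rand(1, 3, 16, 16)
taps, tap_weights = [2, 4], [1 / 2.6, 1 / 4.8]   # first two documented weights, stand-in indices
same = weighted_feature_l1(stand_in, a, a, taps, tap_weights)
diff = weighted_feature_l1(stand_in, a, b, taps, tap_weights)
```

With a real VGG-19 feature stack, the same loop would use the documented indices [2, 7, 12, 21, 30] and all five weights.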

Outputs

Name | Type | Description
EPE output | torch.Tensor | Per-pixel L2 flow error map multiplied by loss mask
Ternary output | torch.Tensor | Census transform distance map with border mask, shape (B, 1, H, W)
SOBEL output | torch.Tensor | Sum of horizontal and vertical edge L1 differences, shape (B*C, 1, H, W)
VGGPerceptualLoss output | torch.Tensor | Scalar weighted sum of L1 feature differences across 5 VGG-19 layers
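
Three of the four outputs are maps rather than scalars, so a training step has to reduce them before summing. The sketch below uses random stand-in tensors with the documented shapes instead of the real modules; the (B, 1, H, W) shape for the EPE map and the unit loss weights are assumptions for illustration only.

```python
import torch

B, C, H, W = 2, 3, 8, 8

# Stand-ins with the documented output shapes (random values, not real losses).
flow_map = torch.rand(B, 1, H, W)        # EPE output (shape assumed)
census_map = torch.rand(B, 1, H, W)      # Ternary output
edge_map = torch.rand(B * C, 1, H, W)    # SOBEL output
perceptual = torch.tensor(0.25)          # VGGPerceptualLoss output is already a scalar

# Hypothetical weights; the source does not specify how the terms are balanced.
w = {"flow": 1.0, "census": 1.0, "edge": 1.0, "vgg": 1.0}

total = (w["flow"] * flow_map.mean()
         + w["census"] * census_map.mean()
         + w["edge"] * edge_map.mean()
         + w["vgg"] * perceptual)
```

Taking the mean of each map makes every term independent of image resolution and batch size, which keeps the weights comparable across configurations.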

Usage Examples

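import torch
from inference.gradio_composite_demo.rife.loss import EPE, Ternary, SOBEL, VGGPerceptualLoss

# A CUDA device is assumed here, as in the original .cuda() calls.
device = torch.device("cuda")

# Endpoint Error for flow supervision (returns a masked per-pixel error map)
epe = EPE()
pred_flow = torch.randn(2, 2, 256, 256, device=device)
gt_flow = torch.randn(2, 2, 256, 256, device=device)
mask = torch.ones(2, 1, 256, 256, device=device)
flow_loss = epe(pred_flow, gt_flow, mask)

# Ternary census loss for structural similarity (returns a per-pixel map)
ternary = Ternary()
img0 = torch.rand(2, 3, 256, 256, device=device)  # RGB values in [0, 1]
img1 = torch.rand(2, 3, 256, 256, device=device)
census_loss = ternary(img0, img1)

# Sobel edge loss (returns a map of shape (B*C, 1, H, W))
sobel = SOBEL()
edge_loss = sobel(img0, img1)

# VGG perceptual loss (returns a scalar); keep the module on the same device as the inputs
vgg_loss_fn = VGGPerceptualLoss(rank=0).to(device)
perceptual_loss = vgg_loss_fn(img0, img1)

# Reduce the map-valued losses (for example with .mean()) before calling backward().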
