Implementation:Zai org CogVideo RIFE Model

From Leeroopedia


Knowledge Sources
Domains Video_Generation, Frame_Interpolation, Optical_Flow
Last Updated 2026-02-10 00:00 GMT

Overview

The RIFE Model class wraps IFNet in a high-level interface for training, inference, checkpoint management, and evaluation of Real-Time Intermediate Flow Estimation (RIFE) frame interpolation.

Description

The Model class is the main orchestrator of the RIFE frame interpolation pipeline. It initializes the flow estimation network (standard IFNet for fixed-timestep interpolation, or IFNet_m for arbitrary-timestep interpolation), an AdamW optimizer with a weight decay of 1e-3, and three loss modules: EPE (endpoint error for flow supervision), LapLoss (Laplacian pyramid loss for multi-scale reconstruction), and SOBEL (edge-aware loss). When local_rank is not -1, it supports distributed training via PyTorch's DistributedDataParallel (DDP).
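The construction described above can be sketched roughly as follows. This is a minimal sketch under stated assumptions, not the repository's code: `TinyFlowNet` and `build_model_parts` are hypothetical names, the network is a toy stand-in for IFNet / IFNet_m, and the initial learning rate of 1e-6 is a placeholder. Only the weight decay of 1e-3 and the `local_rank` convention come from the description.

```python
import torch
from torch import nn
from torch.optim import AdamW
from torch.nn.parallel import DistributedDataParallel as DDP


class TinyFlowNet(nn.Module):
    """Hypothetical stand-in for IFNet / IFNet_m (not the real network)."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(6, 3, kernel_size=3, padding=1)

    def forward(self, imgs):
        # Map two concatenated RGB frames (6 channels) to one RGB frame.
        return torch.sigmoid(self.conv(imgs))


def build_model_parts(local_rank=-1, lr=1e-6):
    net = TinyFlowNet()
    # Weight decay as quoted in the description; lr is a placeholder value.
    optimizer = AdamW(net.parameters(), lr=lr, weight_decay=1e-3)
    if local_rank != -1:
        # Needs torch.distributed.init_process_group() to have been called.
        net = DDP(net, device_ids=[local_rank], output_device=local_rank)
    return net, optimizer


net, optimizer = build_model_parts(local_rank=-1)
```

With `local_rank=-1` the network is left unwrapped, which matches the "no DDP" default listed in the I/O contract.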

The inference method concatenates the two input frames along the channel dimension, runs them through the flow network at the configured scales, and returns the merged interpolated frame. With optional test-time augmentation (TTA), it also runs the network on flipped inputs (horizontal and vertical) and averages that result with the unflipped output.
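The flip-based TTA described above might look like the sketch below. It assumes the augmented pass reverses both spatial axes at once and that the result is flipped back before averaging; `flip_tta` and `toy_blend` are hypothetical names, and `toy_blend` merely stands in for the flow network.

```python
import torch


def flip_tta(interpolate, img0, img1):
    """Average a plain pass with a pass over inputs flipped along H and W."""
    imgs = torch.cat((img0, img1), dim=1)            # (B, 6, H, W)
    plain = interpolate(imgs)
    flipped = interpolate(imgs.flip(2).flip(3))      # flip height, then width
    # Undo the flips so both predictions are spatially aligned.
    return (plain + flipped.flip(2).flip(3)) / 2


def toy_blend(imgs):
    """Hypothetical stand-in for the flow network: average the two frames."""
    return (imgs[:, :3] + imgs[:, 3:]) / 2


img0 = torch.rand(1, 3, 8, 8)
img1 = torch.rand(1, 3, 8, 8)
out = flip_tta(toy_blend, img0, img1)
```

Because `toy_blend` treats every pixel independently, the two passes agree exactly here; a real network is not flip-equivariant, which is why averaging can help.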

The update method handles the full training step. It passes the concatenated input frames and the ground truth through the flow network, then computes a combined loss: the Laplacian pyramid loss on the student output, plus the Laplacian pyramid loss on the teacher output, plus the distillation loss weighted by 0.01. It then backpropagates the combined loss. Checkpoint loading uses a convert helper to strip the module prefix that DDP adds to parameter names.
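DDP wraps the network under a `module` attribute, so checkpoints saved from a wrapped model carry keys such as `module.block0.weight`. A prefix-stripping helper in the spirit of the convert step could be sketched as below; `strip_ddp_prefix` is a hypothetical name, the key names are made up, and the repository's own helper may behave differently (for example in how it treats keys that lack the prefix).

```python
def strip_ddp_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading DDP prefix removed from keys."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Illustrative keys only; real checkpoints map names to tensors.
checkpoint = {"module.block0.weight": 1, "module.block0.bias": 2, "step": 3}
cleaned = strip_ddp_prefix(checkpoint)
```

This variant keeps keys that do not start with the prefix, so the same checkpoint loads whether or not it was saved from a DDP-wrapped model.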

Usage

Use the RIFE Model as the main entry point for frame interpolation tasks, both for training new models and running inference. It is the primary model class used by the Gradio composite demo.

Code Reference

Source Location

Signature

class Model:
    def __init__(self, local_rank=-1, arbitrary=False)
    def train(self)
    def eval(self)
    def device(self)
    def load_model(self, path, rank=0)
    def save_model(self, path, rank=0)
    def inference(self, img0, img1, scale=1, scale_list=[4, 2, 1], TTA=False, timestep=0.5)
    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None)

Import

from inference.gradio_composite_demo.rife.RIFE import Model

I/O Contract

Inputs

__init__:

| Name | Type | Required | Description |
|---|---|---|---|
| local_rank | int | No | GPU rank for DDP, default -1 (no DDP) |
| arbitrary | bool | No | If True, use IFNet_m for arbitrary timestep interpolation; default False |

inference:

| Name | Type | Required | Description |
|---|---|---|---|
| img0 | torch.Tensor | Yes | First input frame, shape (B, 3, H, W) |
| img1 | torch.Tensor | Yes | Second input frame, shape (B, 3, H, W) |
| scale | int | No | Global scale divisor applied to scale_list, default 1 |
| scale_list | list[int] | No | Multi-scale factors for flow estimation, default [4, 2, 1] |
| TTA | bool | No | Enable test-time augmentation via flipping, default False |
| timestep | float | No | Interpolation position between frames, default 0.5 |
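Read literally, "global scale divisor" means each entry of scale_list is divided by scale before flow estimation. A minimal sketch of that arithmetic, assuming plain element-wise division (`effective_scales` is a hypothetical helper, not part of the class):

```python
def effective_scales(scale=1, scale_list=(4, 2, 1)):
    """Return a new list with every factor divided by the global scale."""
    return [factor / scale for factor in scale_list]


print(effective_scales())          # [4.0, 2.0, 1.0]
print(effective_scales(scale=2))   # [2.0, 1.0, 0.5]
```

The sketch builds a fresh list instead of modifying its argument. In Python a list used as a default argument is shared between calls, so passing an explicit scale_list on each call is the safer habit when scale is not 1.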

update:

| Name | Type | Required | Description |
|---|---|---|---|
| imgs | torch.Tensor | Yes | Concatenated input frames (B, 6, H, W) |
| gt | torch.Tensor | Yes | Ground truth intermediate frame (B, 3, H, W) |
| learning_rate | float | No | Learning rate for this step, default 0 |
| mul | int | No | Multiplier (unused in current implementation), default 1 |
| training | bool | No | Training mode flag, default True |
| flow_gt | torch.Tensor | No | Ground truth optical flow (unused in current implementation) |
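Since learning_rate is supplied per call, one plausible reading is that update writes it into the optimizer's parameter groups before stepping. The sketch below illustrates that pattern with a standalone AdamW optimizer; `set_learning_rate` is a hypothetical helper, and whether the class does exactly this is an assumption.

```python
import torch
from torch.optim import AdamW


def set_learning_rate(optimizer, learning_rate):
    """Overwrite the learning rate of every parameter group."""
    for group in optimizer.param_groups:
        group["lr"] = learning_rate


weight = torch.nn.Parameter(torch.ones(2))
optimizer = AdamW([weight], lr=1e-6, weight_decay=1e-3)

# With a learning rate of 0, an AdamW step leaves the weights unchanged.
set_learning_rate(optimizer, 0.0)
weight.sum().backward()
optimizer.step()
unchanged = torch.equal(weight.detach(), torch.ones(2))

# A caller-side schedule would pass a positive value on each step instead.
set_learning_rate(optimizer, 3e-4)
```

Under this assumption, calling update with the default learning_rate of 0 would compute losses without moving the weights, so training loops should pass an explicit value.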

Outputs

inference:

| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Interpolated frame of shape (B, 3, H, W), values in [0, 1] |
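Because the interface works on float tensors of shape (B, 3, H, W) with values in [0, 1], 8-bit video frames need converting on the way in and out. A minimal sketch, assuming frames arrive as (H, W, 3) uint8 tensors; `frame_to_tensor` and `tensor_to_frame` are hypothetical helpers, not part of the class:

```python
import torch


def frame_to_tensor(frame_u8):
    """(H, W, 3) uint8 -> (1, 3, H, W) float32 with values in [0, 1]."""
    return frame_u8.permute(2, 0, 1).unsqueeze(0).float() / 255.0


def tensor_to_frame(result):
    """(1, 3, H, W) float in [0, 1] -> (H, W, 3) uint8."""
    scaled = (result[0].clamp(0, 1) * 255.0).round()
    return scaled.to(torch.uint8).permute(1, 2, 0)


frame = torch.randint(0, 256, (4, 6, 3), dtype=torch.uint8)
batch = frame_to_tensor(frame)
restored = tensor_to_frame(batch)
```

The clamp in `tensor_to_frame` is a defensive step so that any value slightly outside [0, 1] does not wrap around when cast to uint8.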

update:

| Name | Type | Description |
|---|---|---|
| merged | torch.Tensor | Interpolated frame at finest scale, shape (B, 3, H, W) |
| info | dict | Dictionary containing: merged_tea, mask, mask_tea, flow, flow_tea, loss_l1, loss_tea, loss_distill |
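The three loss entries in info correspond to the terms of the training objective described earlier: the student loss, the teacher loss, and the distillation loss scaled by 0.01. A small sketch of that weighting with made-up scalar values (`combine_losses` is a hypothetical helper, and the numbers are illustrative only):

```python
import torch


def combine_losses(loss_l1, loss_tea, loss_distill, distill_weight=0.01):
    """Sum the student and teacher losses with a down-weighted distillation term."""
    return loss_l1 + loss_tea + distill_weight * loss_distill


# Illustrative values, not measurements from any training run.
loss_l1 = torch.tensor(0.20, requires_grad=True)
loss_tea = torch.tensor(0.10, requires_grad=True)
loss_distill = torch.tensor(5.0, requires_grad=True)

total = combine_losses(loss_l1, loss_tea, loss_distill)
total.backward()  # gradients show how strongly each term drives the update
```

After the backward pass, the gradient with respect to loss_distill equals the 0.01 weight, while the other two terms contribute with weight 1.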

Usage Examples

import torch
from inference.gradio_composite_demo.rife.RIFE import Model

# Initialize and load pretrained weights
model = Model(local_rank=-1, arbitrary=False)
model.load_model("/path/to/rife_weights/", rank=0)
model.eval()

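# Interpolate between two frames (random stand-ins with values in [0, 1])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img0 = torch.rand(1, 3, 480, 640, device=device)
img1 = torch.rand(1, 3, 480, 640, device=device)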

with torch.no_grad():
    result = model.inference(img0, img1, scale=1, TTA=False, timestep=0.5)
    # result shape: (1, 3, 480, 640)

# With test-time augmentation for higher quality
with torch.no_grad():
    result_tta = model.inference(img0, img1, TTA=True)
