Overview
Provides the RIFE (Real-Time Intermediate Flow Estimation) frame interpolation inference pipeline that doubles video frame rates by inserting synthesized intermediate frames between consecutive original frames.
Description
This module implements the complete RIFE inference workflow for post-processing generated videos:
- pad_image() -- Pads an image tensor so that its height and width are multiples of max(32, int(32/scale)), as required by the RIFE network architecture.
- make_inference() -- Recursively generates n intermediate frames between two input frames. For n=1 it returns the single middle frame; for larger n it bisects the interval recursively. Called with n = 2^exp - 1, it yields 2^exp - 1 intermediate frames.
- ssim_interpolation_rife() -- The core inference function. For each consecutive frame pair, it computes SSIM on 32x32 thumbnails to select the interpolation strategy. Pairs with SSIM > 0.996 are treated as near-duplicates, and a slightly varied frame is produced. Pairs with SSIM < 0.2 are treated as scene cuts, and the current frame is duplicated. All other pairs go through normal recursive interpolation.
- load_rife_model() -- Loads the RIFE HDv3 model from a local directory and sets it to evaluation mode.
- rife_inference_with_path() -- End-to-end pipeline: reads a video file via OpenCV, converts the frames to tensors, runs SSIM-based interpolation, and exports the result as an MP4 at 16 fps.
- rife_inference_with_latents() -- Operates on in-memory tensor batches from the diffusion pipeline and returns the interpolated latent tensors.
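The strategy selection and the recursive bisection described above can be sketched in plain Python. This is an illustration only, not the module's code: the names choose_strategy and bisect_frames are hypothetical, the midpoint callable stands in for the RIFE model call, and the way the recursion merges its halves is an assumption. Only the two SSIM thresholds are taken from the description.

```python
# Thresholds quoted in the description above.
DUPLICATE_SSIM = 0.996
SCENE_CUT_SSIM = 0.2


def choose_strategy(ssim):
    """Map an SSIM score to one of the three strategies (hypothetical helper)."""
    if ssim > DUPLICATE_SSIM:
        return "near_duplicate"
    if ssim < SCENE_CUT_SSIM:
        return "scene_cut"
    return "interpolate"


def bisect_frames(midpoint, first, last, n):
    """Return n in-between values by repeatedly bisecting [first, last].

    `midpoint` is a stand-in for the model call that synthesizes the
    frame halfway between two inputs (an assumption for this sketch).
    """
    middle = midpoint(first, last)
    if n == 1:
        return [middle]
    left = bisect_frames(midpoint, first, middle, n // 2)
    right = bisect_frames(midpoint, middle, last, n // 2)
    # For odd n the middle value itself is part of the output.
    if n % 2:
        return [*left, middle, *right]
    return [*left, *right]


if __name__ == "__main__":
    average = lambda a, b: (a + b) / 2  # numeric stand-in for a frame pair
    exp = 2
    print(choose_strategy(0.5))                           # interpolate
    print(bisect_frames(average, 0.0, 1.0, 2**exp - 1))   # [0.25, 0.5, 0.75]
```

Using numbers in place of frames makes the ordering easy to check: with exp=2 the three results sit at one quarter, one half, and three quarters of the interval.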
Usage
Used as an optional post-processing step in the composite Gradio demo to improve video smoothness. After generating video frames with the diffusion model, call rife_inference_with_latents() on the tensor output or rife_inference_with_path() on a saved video file to double the frame rate.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: inference/gradio_composite_demo/rife_model.py
Signature
def pad_image(img, scale) -> tuple[torch.Tensor, tuple]
def make_inference(model, I0, I1, upscale_amount, n) -> list[torch.Tensor]
@torch.inference_mode()
def ssim_interpolation_rife(
    model, samples, exp=1, upscale_amount=1, output_device="cpu"
) -> list[torch.Tensor]
def load_rife_model(model_path) -> Model
def frame_generator(video_capture) -> Generator
def rife_inference_with_path(model, video_path) -> str
def rife_inference_with_latents(model, latents) -> torch.Tensor
Import
from rife_model import (
    load_rife_model,
    rife_inference_with_path,
    rife_inference_with_latents,
    ssim_interpolation_rife,
)
I/O Contract
Inputs (ssim_interpolation_rife)
| Name | Type | Required | Description |
|---|---|---|---|
| model | Model | Yes | Loaded RIFE HDv3 model instance |
| samples | torch.Tensor | Yes | Video frames tensor of shape [F, C, H, W] with values in [0, 1] |
| exp | int | No | Interpolation exponent; produces 2^exp - 1 intermediate frames per pair. Default: 1 |
| upscale_amount | float | No | Scale factor for padding calculation. Default: 1 |
| output_device | str | No | Device for output tensors. Default: "cpu" |
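The role of upscale_amount in the padding calculation can be illustrated with a small arithmetic sketch. The helper name padded_size is hypothetical, and rounding each dimension up to the next multiple is an assumption; the description only states that height and width must be multiples of max(32, int(32/scale)).

```python
def padded_size(height, width, scale=1.0):
    """Round height and width up to the next multiple of max(32, int(32/scale)).

    Hypothetical helper for illustration; it computes sizes only and
    does not pad a tensor.
    """
    multiple = max(32, int(32 / scale))
    padded_h = ((height - 1) // multiple + 1) * multiple
    padded_w = ((width - 1) // multiple + 1) * multiple
    return padded_h, padded_w


if __name__ == "__main__":
    print(padded_size(480, 720))             # (480, 736)
    print(padded_size(480, 720, scale=0.5))  # (512, 768)
```

At scale 1 a 480x720 frame only needs its width extended, while a scale of 0.5 raises the required multiple to 64 and pads both dimensions.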
Outputs (ssim_interpolation_rife)
| Name | Type | Description |
|---|---|---|
| frames | list[torch.Tensor] | List of interpolated frame tensors, each of shape [1, C, H, W] |
Inputs (rife_inference_with_latents)
| Name | Type | Required | Description |
|---|---|---|---|
| model | Model | Yes | Loaded RIFE HDv3 model instance |
| latents | torch.Tensor | Yes | Batch of video latents of shape [B, F, C, H, W] |
Outputs (rife_inference_with_latents)
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Interpolated latents of shape [B, F', C, H, W] where F' > F |
Usage Examples
from rife_model import (
    load_rife_model,
    rife_inference_with_latents,
    rife_inference_with_path,
)

# Load the RIFE model from a local directory
model = load_rife_model("model_rife")

# Interpolate frames from diffusion pipeline output
# latents: tensor of shape [batch, frames, channels, height, width]
interpolated = rife_inference_with_latents(model, latents)

# Or process a saved video file directly; the return value is a path string
output_path = rife_inference_with_path(model, "input_video.mp4")