
Implementation:Zai org CogVideo IFNet

From Leeroopedia


Knowledge Sources
Domains Video_Generation, Optical_Flow, Frame_Interpolation
Last Updated 2026-02-10 00:00 GMT

Overview

The Intermediate Flow Network (IFNet) estimates bidirectional optical flow and blending masks for video frame interpolation using a coarse-to-fine multi-scale architecture with teacher distillation.

Description

IFNet is the core neural network component of the RIFE (Real-Time Intermediate Flow Estimation) frame interpolation pipeline. It consists of three student IFBlock stages operating at progressively finer scales (4x, 2x, 1x) and one teacher block (block_tea) that receives the ground-truth frame during training for knowledge distillation.
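The coarse-to-fine refinement described above can be sketched as a loop that accumulates flow and mask updates across the three stages. This is a simplified illustration, not the repository code: `stub_block` and `coarse_to_fine` are hypothetical names, and the sketch omits the image warping that the real pipeline feeds into each refinement stage.

```python
import torch

# Hypothetical stand-in for an IFBlock: returns a flow update and a mask
# logit at full resolution. The real block is defined in IFNet.py.
def stub_block(x, flow, scale):
    b, _, h, w = x.shape
    return torch.zeros(b, 4, h, w), torch.zeros(b, 1, h, w)

def coarse_to_fine(img0, img1, scales=(4, 2, 1)):
    """Accumulate flow and mask across three progressively finer stages."""
    flow, mask = None, None
    for s in scales:
        x = torch.cat((img0, img1), dim=1)
        if flow is None:
            # First (coarsest) stage: no prior flow to refine.
            flow, mask = stub_block(x, None, s)
        else:
            # Later stages predict residual updates to the accumulated
            # flow and mask rather than re-estimating from scratch.
            fd, md = stub_block(torch.cat((x, mask), dim=1), flow, s)
            flow = flow + fd
            mask = mask + md
    return flow, mask

img0 = torch.randn(1, 3, 64, 64)
img1 = torch.randn(1, 3, 64, 64)
flow, mask = coarse_to_fine(img0, img1)
```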

Each IFBlock first bilinearly resizes its input according to the current scale factor, then reduces spatial resolution by a further 4x via two stride-2 convolutions with PReLU activations, processes the features through eight residual convolution layers, and produces a 5-channel output via transposed convolution: 4 channels for bidirectional optical flow (2 channels per direction) and 1 channel for a blending mask. The flow is iteratively refined across scales, with each stage receiving the warped images and the accumulated flow from the previous stage.
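A minimal sketch of that block structure, under the assumptions stated above: `IFBlockSketch` is an illustrative name, the sketch omits the flow input and its handling, and layer hyperparameters should be checked against the actual IFNet.py.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def conv(in_planes, out_planes, stride=1):
    # 3x3 convolution followed by PReLU, as described above.
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, 3, stride, 1),
        nn.PReLU(out_planes),
    )

class IFBlockSketch(nn.Module):
    """Simplified IFBlock: resize by the stage scale, downsample 4x with
    two stride-2 convs, run a residual conv stack, and emit a 5-channel
    output (4 flow channels + 1 mask channel) via transposed conv."""

    def __init__(self, in_planes, c=64):
        super().__init__()
        self.conv0 = nn.Sequential(conv(in_planes, c // 2, stride=2),
                                   conv(c // 2, c, stride=2))
        self.convblock = nn.Sequential(*[conv(c, c) for _ in range(8)])
        self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)

    def forward(self, x, scale):
        # Resize to the working resolution for this stage.
        x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear",
                          align_corners=False)
        feat = self.conv0(x)
        feat = self.convblock(feat) + feat  # residual connection
        out = self.lastconv(feat)
        # Upsample back to the input resolution; flow magnitudes must be
        # rescaled to match the resized coordinate grid.
        out = F.interpolate(out, scale_factor=scale * 2.0, mode="bilinear",
                            align_corners=False)
        flow = out[:, :4] * scale * 2.0
        mask = out[:, 4:5]
        return flow, mask

block = IFBlockSketch(6)
flow, mask = block(torch.randn(1, 6, 64, 64), 2)
```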

At inference, the three student blocks produce progressively refined flow fields and blending masks. The final result is computed by warping both input frames according to the estimated flows and blending them using the sigmoid-activated mask. A Contextnet and Unet (from the refine module) produce a residual correction applied to the final merged output to enhance detail.
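The warp-and-blend step at the end of inference can be illustrated as follows. This is a sketch, not the repository code: `warp` is a simplified stand-in for the pipeline's warp layer, the `blend` helper is a hypothetical name, and the residual refinement by Contextnet/Unet is omitted.

```python
import torch
import torch.nn.functional as F

def warp(img, flow):
    """Backward-warp img by a per-pixel flow field using grid_sample.
    A simplified stand-in for the pipeline's warp layer."""
    b, _, h, w = img.shape
    gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    grid = torch.stack((gx, gy), dim=-1).float().to(img.device)  # (H, W, 2)
    grid = grid.unsqueeze(0) + flow.permute(0, 2, 3, 1)          # add flow
    # Normalize coordinates to [-1, 1] for grid_sample.
    grid[..., 0] = 2.0 * grid[..., 0] / (w - 1) - 1.0
    grid[..., 1] = 2.0 * grid[..., 1] / (h - 1) - 1.0
    return F.grid_sample(img, grid, align_corners=True)

def blend(img0, img1, flow, mask_logit):
    """Warp each frame by its 2-channel half of the flow, then blend
    with the sigmoid-activated mask."""
    warped0 = warp(img0, flow[:, :2])
    warped1 = warp(img1, flow[:, 2:4])
    m = torch.sigmoid(mask_logit)
    return warped0 * m + warped1 * (1 - m)

img0 = torch.randn(1, 3, 32, 32)
img1 = torch.randn(1, 3, 32, 32)
flow = torch.zeros(1, 4, 32, 32)   # zero flow: no displacement
mask = torch.zeros(1, 1, 32, 32)   # zero logit: sigmoid = 0.5
out = blend(img0, img1, flow, mask)
```

With zero flow and a zero mask logit, the result is the plain average of the two frames, which is a quick sanity check for the blending convention.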

Usage

Use IFNet as the primary flow estimation backbone for fixed-timestep frame interpolation in the RIFE pipeline. It is instantiated and managed by the Model class in RIFE.py.

Code Reference

Source Location

  • Repository: Zai_org_CogVideo
  • File: inference/gradio_composite_demo/rife/IFNet.py

Signature

class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64)
    def forward(self, x, flow, scale) -> Tuple[torch.Tensor, torch.Tensor]

class IFNet(nn.Module):
    def __init__(self)
    def forward(self, x, scale=[4, 2, 1], timestep=0.5) -> Tuple[list, torch.Tensor, list, Optional[torch.Tensor], Optional[torch.Tensor], float]

Import

from inference.gradio_composite_demo.rife.IFNet import IFNet, IFBlock

I/O Contract

Inputs

IFNet.forward:

  • x (torch.Tensor, required) — concatenated input tensor: img0 (3 ch), img1 (3 ch), and optionally gt (3 ch) along the channel dimension. Shape: (B, 6, H, W) at inference or (B, 9, H, W) during training
  • scale (list[int], optional) — multi-scale factors for the three IFBlock stages; default [4, 2, 1]
  • timestep (float, optional) — interpolation timestep; default 0.5 (midpoint between frames)

IFBlock.forward:

  • x (torch.Tensor, required) — concatenated image features and optionally the accumulated mask
  • flow (torch.Tensor or None, required) — accumulated optical flow from previous stages, or None for the first stage
  • scale (int, required) — current scale factor for input downsampling

Outputs

IFNet.forward:

  • flow_list (list[torch.Tensor]) — optical flow fields at each scale, each of shape (B, 4, H, W)
  • mask (torch.Tensor) — final blending mask of shape (B, 1, H, W)
  • merged (list[torch.Tensor]) — blended interpolated frames at each scale, each of shape (B, 3, H, W)
  • flow_teacher (torch.Tensor or None) — teacher flow field (training only); None at inference
  • merged_teacher (torch.Tensor or None) — teacher merged frame (training only); None at inference
  • loss_distill (float) — distillation loss between student and teacher flows; 0 at inference

Usage Examples

import torch
from inference.gradio_composite_demo.rife.IFNet import IFNet

model = IFNet()
model.eval()

# Two input frames concatenated along channel dim
img0 = torch.randn(1, 3, 256, 256)
img1 = torch.randn(1, 3, 256, 256)
x = torch.cat((img0, img1), dim=1)  # (1, 6, 256, 256)

with torch.no_grad():
    flow_list, mask, merged, flow_teacher, merged_teacher, loss_distill = model(x)
    interpolated_frame = merged[2]  # Final scale result: (1, 3, 256, 256)
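Because the network repeatedly halves and re-expands spatial resolution, frame dimensions that are not a multiple of the total downsampling factor need padding before inference. A hedged sketch of one common approach; `pad_to_multiple` is a hypothetical helper, and the multiple of 32 is an assumption that should be checked against the padding actually used in RIFE.py.

```python
import torch
import torch.nn.functional as F

def pad_to_multiple(x, multiple=32):
    """Zero-pad H and W up to the next multiple so repeated 2x
    down/upsampling round-trips cleanly; crop back afterwards."""
    _, _, h, w = x.shape
    ph = (multiple - h % multiple) % multiple
    pw = (multiple - w % multiple) % multiple
    # F.pad order for 4D input: (left, right, top, bottom)
    return F.pad(x, (0, pw, 0, ph)), (h, w)

x = torch.randn(1, 6, 250, 270)
padded, (h, w) = pad_to_multiple(x)
# After inference, crop the output back: out[..., :h, :w]
```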
