Implementation: Zai_org_CogVideo IFNet
| Knowledge Sources | |
|---|---|
| Domains | Video_Generation, Optical_Flow, Frame_Interpolation |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
The Intermediate Flow Network (IFNet) estimates bidirectional optical flow and blending masks for video frame interpolation using a coarse-to-fine multi-scale architecture with teacher distillation.
Description
IFNet is the core neural network component of the RIFE (Real-Time Intermediate Flow Estimation) frame interpolation pipeline. It consists of three student IFBlock stages operating at progressively finer scales (4x, 2x, 1x) and one teacher block (block_tea) that receives the ground-truth frame during training for knowledge distillation.
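The coarse-to-fine control flow described above can be sketched in plain Python. This is a toy illustration, not the real network: `stage` is a hypothetical stand-in for an IFBlock that returns a constant residual update, so only the accumulation logic across the three scales is shown.

```python
# Toy sketch of IFNet's coarse-to-fine loop over three stages at scales 4, 2, 1.
# `stage` is a hypothetical stand-in for an IFBlock; a real block downsamples
# by `scale`, predicts a residual flow update, and upsamples it back. Here it
# just returns a fixed value so the accumulation pattern stays visible.
def stage(flow, scale):
    return 1.0 / scale  # toy residual "flow" magnitude

def coarse_to_fine(scales=(4, 2, 1)):
    flow = None
    for s in scales:
        residual = stage(flow, s)
        # Each stage refines the accumulated flow from the previous stage.
        flow = residual if flow is None else flow + residual
    return flow

print(coarse_to_fine())  # 0.25 + 0.5 + 1.0 = 1.75
```

The same pattern holds in the real network, except the residual is a dense (B, 4, H, W) flow field predicted from the warped images and the accumulated flow.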
Each IFBlock downscales input features by 2x using two strided convolutions with PReLU activations, processes them through eight residual convolution layers, and produces a 5-channel output via transposed convolution: 4 channels for bidirectional optical flow (2 channels per direction) and 1 channel for a blending mask. The flow is iteratively refined across scales, with each stage receiving the warped images and accumulated flow from the previous stage.
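The 5-channel output layout can be made concrete with a minimal sketch. The channel values below are hypothetical scalars; in the real network each channel is an (H, W) map.

```python
# Sketch of splitting an IFBlock's 5-channel prediction into flows and mask.
# Toy scalars stand in for full (H, W) channel maps.
out = [0.1, -0.2, 0.3, -0.4, 0.9]  # hypothetical 5-channel output

flow_0to1 = out[0:2]   # channels 0-1: flow warping img0 toward the target frame
flow_1to0 = out[2:4]   # channels 2-3: flow warping img1 toward the target frame
mask_logit = out[4]    # channel 4: blending-mask logit (sigmoid applied later)
```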
At inference, the three student blocks produce progressively refined flow fields and blending masks. The final result is computed by warping both input frames according to the estimated flows and blending them using the sigmoid-activated mask. A Contextnet and Unet (from the refine module) produce a residual correction applied to the final merged output to enhance detail.
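The final blend described above reduces to a simple per-pixel formula. A minimal sketch, using scalar pixel values in place of full tensors and a `residual` argument standing in for the Contextnet/Unet correction:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def blend(warped0, warped1, mask_logit, residual=0.0):
    # Per-pixel blend: the sigmoid-activated mask weights the two warped
    # frames, then the refine module's residual correction is added on top.
    m = sigmoid(mask_logit)
    return m * warped0 + (1.0 - m) * warped1 + residual

# With a zero logit the mask is 0.5, so the result is the plain average.
print(blend(0.2, 0.6, 0.0))  # 0.4
```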
Usage
Use IFNet as the primary flow estimation backbone for fixed-timestep frame interpolation in the RIFE pipeline. It is instantiated and managed by the Model class in RIFE.py.
Code Reference
Source Location
- Repository: Zai_org_CogVideo
- File: inference/gradio_composite_demo/rife/IFNet.py
Signature
class IFBlock(nn.Module):
    def __init__(self, in_planes, c=64)
    def forward(self, x, flow, scale) -> Tuple[torch.Tensor, torch.Tensor]

class IFNet(nn.Module):
    def __init__(self)
    def forward(self, x, scale=[4, 2, 1], timestep=0.5) -> Tuple[list, torch.Tensor, list, Optional[torch.Tensor], Optional[torch.Tensor], float]
Import
from inference.gradio_composite_demo.rife.IFNet import IFNet, IFBlock
I/O Contract
Inputs
IFNet.forward:
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Concatenated input tensor: img0 (3ch), img1 (3ch), and optionally gt (3ch) along channel dimension. Shape: (B, 6, H, W) at inference or (B, 9, H, W) during training |
| scale | list[int] | No | Multi-scale factors for the three IFBlock stages, default [4, 2, 1] |
| timestep | float | No | Interpolation timestep, default 0.5 (midpoint between frames) |
IFBlock.forward:
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Concatenated image features and optionally accumulated mask |
| flow | torch.Tensor or None | Yes | Accumulated optical flow from previous stages, or None for the first stage |
| scale | int | Yes | Current scale factor for input downsampling |
Outputs
IFNet.forward:
| Name | Type | Description |
|---|---|---|
| flow_list | list[torch.Tensor] | Optical flow fields at each scale, each of shape (B, 4, H, W) |
| mask | torch.Tensor | Final blending mask of shape (B, 1, H, W) |
| merged | list[torch.Tensor] | Blended interpolated frames at each scale, each of shape (B, 3, H, W) |
| flow_teacher | torch.Tensor or None | Teacher flow field (training only), None at inference |
| merged_teacher | torch.Tensor or None | Teacher merged frame (training only), None at inference |
| loss_distill | float | Distillation loss between student and teacher flows (0 at inference) |
Usage Examples
import torch
from inference.gradio_composite_demo.rife.IFNet import IFNet

model = IFNet()
model.eval()

# Two input frames concatenated along the channel dimension
img0 = torch.randn(1, 3, 256, 256)
img1 = torch.randn(1, 3, 256, 256)
x = torch.cat((img0, img1), dim=1)  # (1, 6, 256, 256)

with torch.no_grad():
    flow_list, mask, merged, flow_teacher, merged_teacher, loss_distill = model(x)

interpolated_frame = merged[2]  # Final-scale result: (1, 3, 256, 256)