
Implementation:Zai org CogVideo SAT VideoDataset

From Leeroopedia


Metadata

Field Value
Page Type Implementation (API Doc)
Knowledge Sources CogVideo
Domains Data_Pipeline, Video_Processing
Last Updated 2026-02-10 00:00 GMT

Overview

A concrete tool, provided by the CogVideo SAT module, for loading video datasets in either WebDataset or plain-directory format for SAT training.

Description

The data_video.py module provides two dataset classes for the SAT training pipeline:

  • VideoDataset: A WebDataset-based class that streams video samples from sharded tar archives. It inherits from MetaDistributedWebDataset and applies the process_fn_video processing function to each sample in the stream.
  • SFTDataset: A standard PyTorch Dataset class that reads .mp4 video files and .txt caption files from a directory structure. It walks the directory tree, collects all .mp4 files, and loads each corresponding caption from the .txt file found by substituting .txt for the .mp4 suffix and labels for videos in the path.

Both classes yield dictionaries with the same structure, ensuring compatibility with the downstream training loop.
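The caption lookup described above for SFTDataset can be sketched as a simple path substitution. This is an illustrative helper, not a function exported by data_video.py:

```python
def caption_path_for(video_path: str) -> str:
    # Derive the caption file path from a video path by swapping the
    # "videos" directory for "labels" and the .mp4 suffix for .txt,
    # following the convention described above (illustrative sketch).
    return video_path.replace("videos", "labels").replace(".mp4", ".txt")

caption_path_for("/data/videos/clip001.mp4")
# -> "/data/labels/clip001.txt"
```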

Usage

Import and use these classes when creating the dataset for SAT-based CogVideoX training. The dataset class is typically selected via the YAML configuration file and instantiated by the create_dataset_function class method.
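The factory pattern behind create_dataset_function can be sketched as follows. SFTDatasetSketch and its argument handling are illustrative stand-ins, not the real signatures from data_video.py; the key idea is that SAT passes the configured data path plus the YAML params to a classmethod, which forwards them to the constructor:

```python
class SFTDatasetSketch:
    # Hypothetical stand-in for data_video.SFTDataset, kept minimal to
    # show only the factory pattern.
    def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
        self.data_dir = data_dir
        self.video_size = video_size
        self.fps = fps
        self.max_num_frames = max_num_frames
        self.skip_frms_num = skip_frms_num

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        # SAT supplies the data path from the config and the remaining
        # YAML params as keyword arguments.
        return cls(data_dir=path, **kwargs)

ds = SFTDatasetSketch.create_dataset_function(
    "path/to/videos", None, video_size=(480, 720), fps=8, max_num_frames=49
)
```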

Code Reference

Source Location

  • sat/data_video.py:L321-363 (VideoDataset)
  • sat/data_video.py:L365-470 (SFTDataset)

Signature

class VideoDataset(MetaDistributedWebDataset):
    def __init__(
        self,
        path: str,
        image_size: Tuple[int, int],
        num_frames: int,
        fps: int,
        skip_frms_num: float = 0.0,
        nshards: int = sys.maxsize,
        seed: int = 1,
        meta_names: Optional[List] = None,
        shuffle_buffer: int = 1000,
        include_dirs: Optional[str] = None,
        txt_key: str = "caption",
        **kwargs,
    )

class SFTDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str,
        video_size: Tuple[int, int],
        fps: int,
        max_num_frames: int,
        skip_frms_num: float = 3,
    )

Import

from data_video import VideoDataset, SFTDataset  # within sat/ directory

I/O Contract

Inputs

Parameter Type Required Description
path / data_dir str Yes Path to WebDataset tar shards (VideoDataset) or directory containing .mp4 files (SFTDataset).
image_size / video_size Tuple[int, int] Yes Target spatial resolution as (height, width). Standard CogVideoX-2B uses (480, 720).
num_frames / max_num_frames int Yes Number of frames to sample per video. Standard CogVideoX-2B uses 49.
fps int Yes Target frames per second for temporal sampling. Standard CogVideoX uses 8.
skip_frms_num float No Number of frames to skip at the beginning and end of the video to avoid scene transitions. Default: 0.0 for VideoDataset, 3 for SFTDataset.
nshards int No Maximum number of tar shards to use (VideoDataset only). Default: sys.maxsize.
seed int No Random seed for shuffling (VideoDataset only). Use -1 for random seed. Default: 1.
shuffle_buffer int No Size of the in-memory shuffle buffer (VideoDataset only). Default: 1000.
txt_key str No Key name for the caption field in WebDataset samples (VideoDataset only). Default: "caption".

Outputs

Output Type Description
Sample dictionary Dict Each sample yields {"mp4": Tensor[T,C,H,W], "txt": str, "num_frames": int, "fps": int}. The mp4 tensor contains normalized pixel values in [-1, 1]. The num_frames field contains the actual frame count (may differ from max_num_frames for short videos in SFTDataset). The fps field contains the target FPS value.
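The [-1, 1] value range in the contract above corresponds to the standard uint8-to-float mapping. A minimal sketch with NumPy (the actual module operates on torch tensors via torchvision transforms, so this is illustrative only):

```python
import numpy as np

def normalize_frames(frames: np.ndarray) -> np.ndarray:
    # Map uint8 pixel values in [0, 255] to float32 values in [-1, 1],
    # matching the value range of the "mp4" tensor described above.
    return frames.astype(np.float32) / 127.5 - 1.0

normalize_frames(np.array([0, 255], dtype=np.uint8))
# -> array([-1., 1.], dtype=float32)
```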

Usage Examples

Using SFTDataset for Custom Fine-tuning

from data_video import SFTDataset

dataset = SFTDataset(
    data_dir="path/to/videos",
    video_size=(480, 720),
    fps=8,
    max_num_frames=49,
    skip_frms_num=3.0,
)

sample = dataset[0]
# sample["mp4"].shape -> torch.Size([49, 3, 480, 720])
# sample["txt"] -> "A caption describing the video content"
# sample["num_frames"] -> 49
# sample["fps"] -> 8

Using VideoDataset for Large-Scale Training

from data_video import VideoDataset

dataset = VideoDataset(
    path="/data/video_shards",
    image_size=(480, 720),
    num_frames=49,
    fps=8,
    skip_frms_num=0.0,
    shuffle_buffer=1000,
)

YAML Configuration (SFTDataset)

data:
  target: data_video.SFTDataset
  params:
    video_size: [480, 720]
    fps: 8
    max_num_frames: 49
    skip_frms_num: 3.
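The target/params convention above is typically resolved by a small helper that imports the class named by target and constructs it with params. A minimal sketch of that resolution (instantiate_from_config is a hypothetical name for illustration, demonstrated with a stdlib class rather than data_video.SFTDataset):

```python
import importlib

def instantiate_from_config(config: dict):
    # Split "module.ClassName", import the module, and construct the
    # class with the params dict (sketch of the YAML convention above).
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

counter = instantiate_from_config({"target": "collections.Counter", "params": {"a": 2}})
# -> Counter({'a': 2})
```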

External Dependencies

  • decord: Video frame decoding via VideoReader.
  • torchvision.transforms: Spatial resizing and cropping operations.
  • webdataset: Tar shard streaming via MetaDistributedWebDataset (for VideoDataset).
