Implementation:NVIDIA NeMo Curator TransNetV2Model
| Knowledge Sources | |
|---|---|
| Domains | Video Processing, Deep Learning, Shot Detection |
| Last Updated | 2026-02-14 00:00 GMT |
Overview
The TransNetV2 class implements a deep neural network for fast shot transition detection in video, used to split videos into scenes and clips in the NeMo Curator video processing pipeline.
Description
The module implements the TransNetV2 architecture as described in the paper "TransNet V2: An effective deep network architecture for fast shot transition detection" (Soucek & Lokoc, 2020). The architecture consists of several interconnected neural network components:
- _TransNetV2 (internal): The core model class extending nn.Module. It chains stacked dilated dense 3D CNNs with optional frame similarity and color histogram features. The model expects input tensors of shape [B, T, 27, 48, 3] as uint8 and outputs per-frame shot transition probabilities via sigmoid activation.
- StackedDDCNNV2: Stacked dilated dense convolutional blocks with shortcut connections and optional stochastic depth dropout. Each block contains multiple DilatedDCNNV2 layers followed by average or max pooling.
- DilatedDCNNV2: Dilated dense convolutional layers that concatenate outputs from four parallel 3D convolutions at dilation rates 1, 2, 4, and 8, with optional batch normalization and activation.
- Conv3DConfigurable: Configurable 3D convolution supporting (2+1)D separable convolutions, which factorize 3D convolution into spatial and temporal components for efficiency.
- FrameSimilarity: Projects multi-scale features into a similarity space and computes windowed cosine similarity matrices between frames, enabling the model to detect visual changes.
- ColorHistograms: Computes 512-bin RGB color histograms for each frame and uses windowed similarity to detect color distribution changes across frames.
- TransNetV2 (public interface): Wraps _TransNetV2 as a ModelInterface. It handles weight loading from HuggingFace (Sn4kehead/TransNetV2) and provides a callable interface for inference.
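The dilated dense idea described for DilatedDCNNV2 can be illustrated with a minimal PyTorch sketch. The class name `DilatedBlockSketch` is hypothetical and is not the NeMo Curator implementation; the sketch assumes plain 3x3x3 convolutions with dilation applied along the time axis only, and it omits batch normalization and activation.

```python
import torch
from torch import nn


class DilatedBlockSketch(nn.Module):
    """Hypothetical stand-in for a dilated dense block (not the library class)."""

    def __init__(self, in_channels: int, filters: int) -> None:
        super().__init__()
        # One 3D convolution per dilation rate. Temporal padding equals the
        # dilation rate, so the number of frames T is preserved (assumption:
        # dilation is applied along time only).
        self.branches = nn.ModuleList(
            [
                nn.Conv3d(
                    in_channels,
                    filters,
                    kernel_size=3,
                    dilation=(d, 1, 1),
                    padding=(d, 1, 1),
                )
                for d in (1, 2, 4, 8)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Concatenate the four branch outputs along the channel axis.
        return torch.cat([branch(x) for branch in self.branches], dim=1)


block = DilatedBlockSketch(in_channels=3, filters=16)
clip = torch.rand(1, 3, 100, 27, 48)  # [B, C, T, H, W]
features = block(clip)
print(features.shape)  # torch.Size([1, 64, 100, 27, 48])
```

Each branch sees a different temporal context, and concatenation multiplies the channel count by four while leaving the temporal and spatial sizes unchanged.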
The default configuration uses three stacked layers (rl=3) starting from 16 filters (rf=16), with frame similarity and color histogram features enabled, a 1024-dimensional fully connected layer (rd=1024), and a dropout rate of 0.5.
Usage
Use TransNetV2 as the first step in video processing pipelines where you need to detect shot boundaries and split long videos into coherent clips. Shot transition detection is a prerequisite for downstream video curation tasks such as clip-level embedding, filtering, and captioning.
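To show how per-frame probabilities might be turned into clip boundaries, here is a minimal sketch in plain Python. The helper `probs_to_clips` and the 0.5 threshold are assumptions made for illustration; they are not part of the NeMo Curator API, and the pipeline may use a different convention for frames that fall inside a transition.

```python
def probs_to_clips(probs: list[float], threshold: float = 0.5) -> list[tuple[int, int]]:
    """Return inclusive (start, end) frame ranges between detected transitions."""
    clips = []
    start = 0
    in_transition = False
    for i, p in enumerate(probs):
        is_transition = p > threshold
        if is_transition and not in_transition:
            # A transition begins: close the current clip if it is non-empty.
            if i > start:
                clips.append((start, i - 1))
        elif in_transition and not is_transition:
            # The transition ended: the next clip starts at this frame.
            start = i
        in_transition = is_transition
    # Close the final clip unless the video ends inside a transition.
    if not in_transition and len(probs) > start:
        clips.append((start, len(probs) - 1))
    return clips


example = [0.1, 0.1, 0.9, 0.1, 0.1, 0.1, 0.8, 0.9, 0.2]
print(probs_to_clips(example))  # [(0, 1), (3, 5), (8, 8)]
```

With a model output of shape [B, T, 1], the list for one video could be obtained with something like `transition_probs[0, :, 0].tolist()`.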
Code Reference
Source Location
- Repository: NeMo-Curator
- File: nemo_curator/models/transnetv2.py
- Lines: 1-616
Signature
class TransNetV2(ModelInterface):
    def __init__(self, model_dir: str | None = None) -> None: ...
    def setup(self) -> None: ...
    def __call__(self, inputs: torch.Tensor) -> torch.Tensor: ...
    @classmethod
    def download_weights_on_node(cls, model_dir: str) -> None: ...

# Internal model (not intended for direct use)
class _TransNetV2(nn.Module):
    def __init__(
        self,
        rf: int = 16,
        rl: int = 3,
        rs: int = 2,
        rd: int = 1024,
        *,
        use_many_hot_targets: bool = True,
        use_frame_similarity: bool = True,
        use_color_histograms: bool = True,
        use_mean_pooling: bool = False,
        dropout_rate: float = 0.5,
    ) -> None: ...
    def forward(self, inputs: torch.Tensor) -> torch.Tensor: ...
Import
from nemo_curator.models.transnetv2 import TransNetV2
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_dir | str \| None | No | Directory where model weights are stored (defaults to None) |
| inputs (callable) | torch.Tensor | Yes | Tensor of shape [B, T, 27, 48, 3] with dtype torch.uint8 representing batched video frames (B=batch, T=frames, 27x48 spatial resolution, 3 RGB channels) |
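As an illustration of the input contract, the following sketch resizes decoded frames to the expected 27x48 resolution and restores the uint8 dtype. The helper `prepare_frames` is hypothetical; the NeMo Curator pipeline may resize frames during decoding instead, and the choice of bilinear interpolation is an assumption.

```python
import torch
import torch.nn.functional as F


def prepare_frames(video: torch.Tensor) -> torch.Tensor:
    """Resize uint8 frames [T, H, W, 3] into a [1, T, 27, 48, 3] uint8 batch."""
    x = video.permute(0, 3, 1, 2).float()  # [T, 3, H, W] for interpolate
    x = F.interpolate(x, size=(27, 48), mode="bilinear", align_corners=False)
    x = x.round().clamp(0, 255).to(torch.uint8)  # back to uint8 pixel values
    return x.permute(0, 2, 3, 1).unsqueeze(0)  # [1, T, 27, 48, 3]


video = torch.randint(0, 256, (10, 270, 480, 3), dtype=torch.uint8)
batch = prepare_frames(video)
print(batch.shape, batch.dtype)
```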
Outputs
| Name | Type | Description |
|---|---|---|
| predictions | torch.Tensor | Tensor of shape [B, T, 1] containing per-frame shot transition probabilities (0.0 to 1.0 via sigmoid) |
Architecture
The TransNetV2 network processes video frames through the following stages:
| Stage | Component | Description |
|---|---|---|
| 1 | Input normalization | Converts uint8 [B, T, H, W, 3] to float [B, 3, T, H, W] and divides by 255 |
| 2 | StackedDDCNNV2 (x3) | Three layers of stacked dilated dense 3D CNNs with increasing filter counts (16, 32, 64) and average pooling |
| 3 | FrameSimilarity | Windowed cosine similarity features from multi-scale block outputs (128-dim output) |
| 4 | ColorHistograms | 512-bin RGB histogram similarity features across a 101-frame window (128-dim output) |
| 5 | FC + ReLU + Dropout | Fully connected layer (output_dim -> 1024) with ReLU and 0.5 dropout |
| 6 | Classification + Sigmoid | Linear layer (1024 -> 1) with sigmoid for per-frame probability |
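The 512-bin histogram in stage 4 can be illustrated with a short sketch. It assumes each RGB channel is quantized to 8 levels (its top 3 bits), giving 8 x 8 x 8 = 512 bins, and that histograms are L2-normalized so a dot product acts as a cosine similarity. The function `color_histograms` is a hypothetical helper rather than the library's ColorHistograms module, and it compares all frame pairs instead of a 101-frame window.

```python
import torch


def color_histograms(frames: torch.Tensor) -> torch.Tensor:
    """Compute L2-normalized 512-bin RGB histograms for uint8 frames [B, T, H, W, 3]."""
    b, t, h, w, _ = frames.shape
    pixels = frames.to(torch.int64).reshape(b * t, h * w, 3)
    # Keep the top 3 bits of each channel (values 0..7).
    r, g, bl = (pixels[..., c] >> 5 for c in range(3))
    # Combine the three 3-bit codes into one bin index in 0..511.
    bins = (r << 6) + (g << 3) + bl  # [B*T, H*W]
    hist = torch.zeros(b * t, 512, dtype=torch.float32)
    # Count how many pixels of each frame fall into each bin.
    hist.scatter_add_(1, bins, torch.ones_like(bins, dtype=torch.float32))
    hist = hist.reshape(b, t, 512)
    return torch.nn.functional.normalize(hist, p=2, dim=2)


frames = torch.randint(0, 256, (2, 5, 27, 48, 3), dtype=torch.uint8)
hist = color_histograms(frames)
similarity = torch.bmm(hist, hist.transpose(1, 2))  # [B, T, T] frame-to-frame similarity
```

A sharp drop in similarity between neighboring frames is the kind of signal that a color-based feature can contribute to detecting a cut.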
Usage Examples
Basic Usage
import torch

from nemo_curator.models.transnetv2 import TransNetV2

# Download weights first
TransNetV2.download_weights_on_node(model_dir="/path/to/models")

# Initialize and set up
model = TransNetV2(model_dir="/path/to/models")
model.setup()

# Run inference on video frames
# frames shape: [batch_size, num_frames, 27, 48, 3], dtype: uint8
# torch.randint excludes the upper bound, so 256 covers the full uint8 range
frames = torch.randint(0, 256, (1, 100, 27, 48, 3), dtype=torch.uint8).cuda()
transition_probs = model(frames)
# transition_probs shape: [1, 100, 1], values between 0 and 1