Overview
Utility module that computes FLOPs, parameter counts, and inference FPS for PyTorch models using JIT tracing and operation-level flop counting.
Description
The flop counter module (rfdetr/util/benchmark.py) provides comprehensive computational profiling for PyTorch models. It uses torch.jit.get_trace_graph to trace a model and walks the resulting computation graph node by node. Each ATen operation is matched against _SUPPORTED_OPS, a dictionary mapping 30+ operation types to their FLOP-counting handlers. Handlers cover linear layers (addmm, linear), convolutions, matrix multiplications (matmul, bmm, einsum), normalization layers (batchnorm, layernorm, groupnorm), activations (relu, sigmoid, softmax), and element-wise operations. _IGNORED_OPS lists 40+ zero-compute operations (reshape, permute, slice, etc.) that are skipped without warning.
The benchmark function orchestrates end-to-end evaluation: it counts parameters, measures FLOPs across dataset images (converting to GFLOPs), and times inference after a warmup phase, producing mean/std/min/max statistics.
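To illustrate the handler convention, here is a minimal sketch of what a matmul handler computes (a hypothetical `matmul_flop` helper, not the module's actual code; the real handlers receive traced graph nodes rather than bare shape tuples): given the input shapes recorded during tracing, the flop count for `A @ B` is the product of the batch/output dimensions times the inner dimension.

```python
from functools import reduce
from operator import mul

def matmul_flop(input_shapes):
    """Hypothetical matmul handler: multiply-add count for A @ B.

    For A of shape (*batch, m, k) and B of shape (k, n),
    flops = prod(batch) * m * k * n.
    """
    a, b = input_shapes
    assert a[-1] == b[-2 if len(b) > 1 else 0], "inner dimensions must match"
    batch_m = reduce(mul, a[:-1], 1)  # product of all dims of A except the inner one
    return batch_m * a[-1] * b[-1]

# e.g. a ViT-style attention projection: (197, 768) @ (768, 768)
print(matmul_flop([(197, 768), (768, 768)]))  # → 116195328
```

Note that this counts multiply-adds (MACs); whether a MAC is reported as one or two FLOPs is a convention choice that varies between counters.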
Usage
Use this module during model development to profile computational cost, compare architecture variants, and verify that FLOPs and latency meet deployment budgets. Import flop_count for standalone FLOP analysis or benchmark for comprehensive profiling with FPS measurement.
Code Reference
Source Location
rfdetr/util/benchmark.py
Signature
def flop_count(
    model: nn.Module,
    inputs: typing.Tuple[Any, ...],
    whitelist: typing.Optional[typing.List[str]] = None,
    customized_ops: typing.Optional[typing.Dict[str, typing.Callable]] = None,
) -> typing.DefaultDict[str, float]:
    """
    Given a model and an input, compute the Gflops of the given model.
    Note the input should have a batch size of 1.

    Args:
        model: The model to compute flop counts.
        inputs: Inputs passed to model (must be a tuple).
        whitelist: Subset of _SUPPORTED_OPS to count.
        customized_ops: Custom operations and their flop handlers.

    Returns:
        A dictionary mapping operation names to gigaflops.
    """
    ...
def benchmark(
    model: torch.nn.Module,
    dataset: Sequence[Any],
    output_dir: Any,
) -> Dict[str, Any]:
    """
    Compute model size, FLOPs, and FPS.

    Args:
        model: The PyTorch model to benchmark.
        dataset: Dataset providing (image, target) tuples.
        output_dir: Directory to save benchmark log.

    Returns:
        Dictionary with nparam, flops, time, fps, detailed_flops.
    """
    ...
def warmup(model: torch.nn.Module, inputs: Any, N: int = 10) -> None:
    """Run N warmup iterations with CUDA sync."""
    ...

def measure_time(model: torch.nn.Module, inputs: Any, N: int = 10) -> float:
    """Measure average inference time over N iterations."""
    ...
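How `whitelist` and `customized_ops` interact can be sketched in plain Python (the merge order shown here is an assumption, not the module's exact code): custom handlers are merged over the built-in table, and the whitelist then restricts which op names are actually counted.

```python
# Hypothetical sketch; _SUPPORTED_OPS stands in for the module's real handler table.
_SUPPORTED_OPS = {"aten::addmm": "addmm_handler", "aten::matmul": "matmul_handler"}

def resolve_ops(whitelist=None, customized_ops=None):
    """Merge custom handlers over the defaults, then filter by whitelist."""
    ops = {**_SUPPORTED_OPS, **(customized_ops or {})}  # custom handlers take precedence
    names = set(whitelist) if whitelist is not None else set(ops)
    return {name: ops[name] for name in names if name in ops}

print(resolve_ops(whitelist=["aten::matmul"]))  # only matmul is counted
```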
Import
from rfdetr.util.benchmark import flop_count, benchmark, measure_time, warmup
I/O Contract
Inputs (flop_count)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| model | nn.Module | Yes | PyTorch model to profile |
| inputs | Tuple[Any, ...] | Yes | Model inputs (batch size 1) |
| whitelist | List[str] | No | Subset of operations to count (default: all supported) |
| customized_ops | Dict[str, Callable] | No | Custom operation flop handlers |
Outputs (flop_count)
| Name | Type | Description |
|------|------|-------------|
| gflops | DefaultDict[str, float] | Dictionary mapping operation names to gigaflop counts |
Inputs (benchmark)
| Name | Type | Required | Description |
|------|------|----------|-------------|
| model | nn.Module | Yes | PyTorch model to benchmark |
| dataset | Sequence[Any] | Yes | Dataset providing (image, target) tuples |
| output_dir | Path | Yes | Directory to save benchmark log |
Outputs (benchmark)
| Name | Type | Description |
|------|------|-------------|
| nparam | int | Number of trainable parameters |
| flops | Dict[str, float] | Mean/std/min/max GFLOPs across images |
| time | Dict[str, float] | Mean/std/min/max inference time in seconds |
| fps | float | Frames per second (1 / mean_time) |
| detailed_flops | Dict[str, float] | Per-operation GFLOPs breakdown |
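The mean/std/min/max dictionaries in `flops` and `time` can be reproduced with a small stdlib helper (a sketch assuming benchmark aggregates per-image values this way; `summarize` is a hypothetical name):

```python
import statistics

def summarize(values):
    """Aggregate per-image measurements into a benchmark-style stats dict."""
    return {
        "mean": statistics.mean(values),
        "std": statistics.pstdev(values),  # population std; the module may use sample std
        "min": min(values),
        "max": max(values),
    }

per_image_gflops = [98.2, 101.5, 99.9, 100.4]
print(summarize(per_image_gflops))

# fps is derived from the timing stats: fps = 1 / mean_time
fps = 1.0 / summarize([0.021, 0.020, 0.019])["mean"]
```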
Usage Examples
Standalone FLOP Count
import torch
from rfdetr.util.benchmark import flop_count

# Assume model is an RF-DETR model on CUDA
model.eval()
model.cuda()

# Create dummy input (batch size 1)
dummy_input = [torch.randn(3, 560, 560).cuda()]
gflops = flop_count(model, (dummy_input,))

total_gflops = sum(gflops.values())
print(f"Total GFLOPs: {total_gflops:.2f}")
for op, count in sorted(gflops.items(), key=lambda x: -x[1]):
    print(f"  {op}: {count:.4f} GFLOPs")
Measure Inference Time
from rfdetr.util.benchmark import measure_time
import torch
model.eval()
model.cuda()
input_tensor = [torch.randn(3, 560, 560).cuda()]
avg_time = measure_time(model, input_tensor, N=50)
print(f"Average inference time: {avg_time*1000:.1f}ms ({1/avg_time:.1f} FPS)")
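For reference, the timing loop behind measure_time can be sketched with the stdlib (a minimal CPU-only sketch: it omits the CUDA synchronization the real helper performs, so it is not accurate for GPU models; `measure_time_sketch` is a hypothetical name):

```python
import time

def measure_time_sketch(fn, *args, N=10):
    """Average wall-clock time of fn(*args) over N iterations."""
    start = time.perf_counter()
    for _ in range(N):
        fn(*args)
    return (time.perf_counter() - start) / N

avg = measure_time_sketch(sum, range(100_000), N=5)
print(f"{avg * 1e3:.3f} ms per call")
```

On CUDA, a timer like this must call torch.cuda.synchronize() before reading the clock, since kernel launches are asynchronous; that is why the module pairs warmup and measure_time with explicit syncs.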
Related Pages