Implementation:VainF Torch Pruning Count Ops And Params
Metadata
| Field | Value |
|---|---|
| Source | Torch-Pruning |
| Domains | Deep_Learning, Model_Analysis |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A Torch-Pruning utility for counting the FLOPs and parameters of a neural network.
Description
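count_ops_and_params profiles a deep copy of the model, so the caller's model is left untouched. It instruments the copy with forward hooks that count operations per layer, runs a single forward pass on example_inputs, and then collects the total FLOPs and parameter count. The counter is adapted from flops-counter.pytorch, so "FLOPs" here follows that project's convention of counting mainly multiply-accumulate operations (MACs); the Torch-Pruning README accordingly names this value macs. example_inputs may be a single tensor, a tuple or list (unpacked as positional arguments to forward), or a dict (expanded as keyword arguments). With layer_wise=True, the function also returns per-layer breakdowns.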
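The hook-based procedure described above can be illustrated with a minimal sketch. This is not the Torch-Pruning source: count_macs_sketch is a hypothetical name, and for brevity it counts only the multiply-accumulates of Conv2d and Linear layers, ignoring biases, normalization, and activations, which the real counter also handles.

```python
import copy

import torch
import torch.nn as nn


def count_macs_sketch(model: nn.Module, example_inputs):
    """Hypothetical hook-based counter; NOT the Torch-Pruning implementation.

    Counts multiply-accumulates (MACs) for Conv2d and Linear layers only.
    """
    model = copy.deepcopy(model).eval()  # profile a copy so the caller's model is untouched
    layer_macs = {}  # profiled (copied) module -> MACs from its forward pass

    def conv_hook(module, inputs, output):
        # Each output element needs (in_channels / groups) * kH * kW MACs.
        kh, kw = module.kernel_size
        per_output = (module.in_channels // module.groups) * kh * kw
        layer_macs[module] = layer_macs.get(module, 0) + output.numel() * per_output

    def linear_hook(module, inputs, output):
        # Each output element is a dot product of length in_features.
        layer_macs[module] = layer_macs.get(module, 0) + output.numel() * module.in_features

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    with torch.no_grad():
        # Dispatch inputs: unpack tuple/list positionally, expand dict as kwargs.
        if isinstance(example_inputs, (tuple, list)):
            model(*example_inputs)
        elif isinstance(example_inputs, dict):
            model(**example_inputs)
        else:
            model(example_inputs)

    for h in handles:
        h.remove()  # detach the counting hooks

    total_params = sum(p.numel() for p in model.parameters())
    return sum(layer_macs.values()), total_params


# Toy usage: conv 128 outputs * 27 MACs + linear 10 outputs * 128 MACs.
toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 4 * 4, 10))
macs, params = count_macs_sketch(toy, torch.randn(1, 3, 4, 4))
print(macs, params)  # 4736 1514
```

Deep-copying before attaching hooks mirrors the description: the profiling hooks never touch the model the caller keeps training or pruning.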
Code Reference
- Source: torch_pruning/utils/op_counter.py, lines 21-47
- Signature:

```python
@torch.no_grad()
def count_ops_and_params(model, example_inputs, layer_wise=False):
    """Count FLOPs and parameters.

    Args:
        model: nn.Module to profile
        example_inputs: Tensor, tuple, list, or dict
        layer_wise: if True, return per-layer breakdowns

    Returns:
        (flops_count: float, params_count: int) or
        (flops_count, params_count, layer_flops: dict, layer_params: dict)
    """
```
- Import:

```python
import torch_pruning as tp
tp.utils.count_ops_and_params  # the function is exposed under tp.utils
```
I/O Contract
Inputs
| Parameter | Type | Required | Default |
|---|---|---|---|
| model | nn.Module | Yes | — |
| example_inputs | Tensor / tuple / list / dict | Yes | — |
| layer_wise | bool | No | False |
Outputs
- Standard mode: (flops_count: float, params_count: int)
- With layer_wise=True: (flops_count: float, params_count: int, layer_flops: Dict[nn.Module, float], layer_params: Dict[nn.Module, int])
Usage Examples
```python
import torch
import torch.nn as nn
import torch_pruning as tp

# Build a simple model
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)
example_inputs = torch.randn(1, 3, 224, 224)

# Profile BEFORE pruning
flops_before, params_before = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Before pruning: FLOPs={flops_before / 1e9:.2f} G, Params={params_before / 1e6:.2f} M")

# ... apply pruning here; until the model is actually pruned, both reductions print 0.00% ...

# Profile AFTER pruning (same example_inputs, so the counts are comparable)
flops_after, params_after = tp.utils.count_ops_and_params(model, example_inputs)
print(f"After pruning: FLOPs={flops_after / 1e9:.2f} G, Params={params_after / 1e6:.2f} M")
print(f"FLOPs reduction: {1 - flops_after / flops_before:.2%}")
print(f"Param reduction: {1 - params_after / params_before:.2%}")
```
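The dominant terms for the usage-example model can be checked by hand. This worked count covers only the convolution MACs and the trainable parameters (BatchNorm2d contributes a weight and a bias per channel; its running statistics are buffers, not parameters). The tool also counts smaller terms such as bias additions, BatchNorm, and ReLU under flops-counter.pytorch's conventions, so its FLOPs figure should be close to, but not necessarily exactly, the convolution total below.

```python
# Hand count for the usage-example model with a 1x3x224x224 input.
H = W = 224  # padding=1 with a 3x3 kernel keeps the spatial size


def conv_params(c_in, c_out, k):
    return c_out * c_in * k * k + c_out  # weights + bias


def conv_macs(c_in, c_out, k, h, w):
    return h * w * c_out * c_in * k * k  # one MAC per weight per output position


params = conv_params(3, 64, 3) + 2 * 64 + conv_params(64, 128, 3)  # middle term: BatchNorm2d weight + bias
macs = conv_macs(3, 64, 3, H, W) + conv_macs(64, 128, 3, H, W)

print(params)  # 75776
print(macs)    # 3786080256
```

The second convolution accounts for roughly 98% of the convolution MACs, which is why pruning its channels (and the matching input channels it depends on) drives most of the FLOPs reduction.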