
Implementation:VainF Torch Pruning Count Ops And Params

From Leeroopedia


Metadata

Field Value
Source Torch-Pruning
Domains Deep_Learning, Model_Analysis
Last Updated 2026-02-08 00:00 GMT

Overview

A concrete utility provided by Torch-Pruning for counting the FLOPs and parameters of a neural network.

Description

count_ops_and_params deep-copies the model, instruments the copy with forward hooks that count operations per layer, runs a forward pass with the given example inputs, and then aggregates total FLOPs and parameter counts. It is adapted from flops-counter.pytorch. example_inputs may be a Tensor, tuple, list, or dict, and the optional layer_wise mode additionally returns per-layer breakdowns.
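The hook-based mechanism described above can be sketched in plain PyTorch. This is a simplified illustration of the idea, not Torch-Pruning's actual implementation: it handles only Conv2d and Linear layers and ignores bias and normalization costs.

```python
import copy
import torch
import torch.nn as nn

def count_ops_and_params_sketch(model, example_inputs):
    """Hook-based FLOP/parameter counter (illustration only; the real
    counter in Torch-Pruning covers many more layer types and inputs)."""
    model = copy.deepcopy(model)  # profile a copy, as count_ops_and_params does
    flops = 0

    def conv_hook(module, inputs, output):
        nonlocal flops
        # MACs per output element = (C_in / groups) * kH * kW; FLOPs = 2 * MACs
        kh, kw = module.kernel_size
        macs_per_elem = module.in_channels // module.groups * kh * kw
        flops += 2 * macs_per_elem * output.numel()

    def linear_hook(module, inputs, output):
        nonlocal flops
        # output.numel() already includes the batch dimension
        flops += 2 * module.in_features * output.numel()

    handles = []
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            handles.append(m.register_forward_hook(conv_hook))
        elif isinstance(m, nn.Linear):
            handles.append(m.register_forward_hook(linear_hook))

    with torch.no_grad():
        model(example_inputs)  # one forward pass triggers every hook
    for h in handles:
        h.remove()

    params = sum(p.numel() for p in model.parameters())
    return flops, params
```

For a Conv2d(3, 8, 3, padding=1) on a 1x3x8x8 input, this yields 2 * 27 * 512 = 27648 FLOPs and 224 parameters.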

Code Reference

  • Source: torch_pruning/utils/op_counter.py, Lines 21-47
  • Signature:
@torch.no_grad()
def count_ops_and_params(model, example_inputs, layer_wise=False):
    """Count FLOPs and parameters.

    Args:
        model: nn.Module to profile
        example_inputs: Tensor, tuple, list, or dict
        layer_wise: if True, return per-layer breakdowns

    Returns:
        (flops_count: float, params_count: int) or
        (flops_count, params_count, layer_flops: dict, layer_params: dict)
    """
  • Import:
import torch_pruning as tp
tp.utils.count_ops_and_params

I/O Contract

Inputs

Parameter       Type                          Required  Default
model           nn.Module                     Yes       —
example_inputs  Tensor / tuple / list / dict  Yes       —
layer_wise      bool                          No        False

Outputs

  • Standard mode: (flops_count: float, params_count: int)
  • With layer_wise=True: (flops_count: float, params_count: int, layer_flops: Dict[nn.Module, float], layer_params: Dict[nn.Module, int])
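Because the layer_wise dicts are keyed by the submodules themselves, ranking layers by cost is a one-liner. The sketch below post-processes such a breakdown; the hand-built dict and its FLOP values stand in for the real return value of count_ops_and_params.

```python
import torch.nn as nn

# Stand-in for the layer_flops dict returned with layer_wise=True
# (keys are submodules, values are per-layer FLOP counts; values here
# are made-up placeholders for illustration).
conv1 = nn.Conv2d(3, 64, 3, padding=1)
conv2 = nn.Conv2d(64, 128, 3, padding=1)
layer_flops = {conv1: 1.73e8, conv2: 3.70e9}

# Rank layers by FLOPs to find the most expensive pruning candidates
ranked = sorted(layer_flops.items(), key=lambda kv: kv[1], reverse=True)
for module, f in ranked:
    print(f"{type(module).__name__:10s} {f / 1e9:.2f} GFLOPs")
```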

Usage Examples

import torch
import torch.nn as nn
import torch_pruning as tp

# Build a simple model
model = nn.Sequential(
    nn.Conv2d(3, 64, 3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, 3, padding=1),
)

example_inputs = torch.randn(1, 3, 224, 224)

# Profile BEFORE pruning
flops_before, params_before = tp.utils.count_ops_and_params(model, example_inputs)
print(f"Before pruning: FLOPs={flops_before:.2f}, Params={params_before}")

# ... apply pruning ...

# Profile AFTER pruning
flops_after, params_after = tp.utils.count_ops_and_params(model, example_inputs)
print(f"After pruning:  FLOPs={flops_after:.2f}, Params={params_after}")
print(f"FLOPs reduction: {1 - flops_after / flops_before:.2%}")
print(f"Param reduction: {1 - params_after / params_before:.2%}")
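count_ops_and_params returns raw counts, which print as large floats. A small helper (not part of Torch-Pruning, just a convenience sketch) can make printouts like the ones above easier to read:

```python
def format_count(n, units=("", "K", "M", "G", "T")):
    """Render a raw count such as 27648.0 as a short string like '27.65K'."""
    for unit in units[:-1]:
        if abs(n) < 1000:
            return f"{n:.2f}{unit}"
        n /= 1000.0
    return f"{n:.2f}{units[-1]}"

print(format_count(27648))   # a small conv's FLOPs
print(format_count(3.7e9))   # a larger layer, in the billions
```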
