
Implementation:VainF Torch Pruning GroupHessianImportance

From Leeroopedia


Field Value
Sources Torch-Pruning, Optimal Brain Damage
Domains Deep_Learning, Model_Compression
Last Updated 2026-02-08 00:00 GMT

Overview

A concrete tool for Hessian-based group importance estimation, provided by the Torch-Pruning library.

Description

GroupHessianImportance extends GroupMagnitudeImportance to incorporate second-order information for structured pruning. Instead of relying on weight magnitudes alone, it uses per-sample gradient accumulation (grad.pow(2)) to approximate the diagonal Hessian, yielding curvature-weighted importance scores for each channel group.
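Conceptually, this follows the diagonal second-order saliency of Optimal Brain Damage (one of the listed sources): the cost of zeroing a weight is its squared magnitude weighted by the corresponding diagonal Hessian entry, which is estimated from accumulated squared per-sample gradients (an empirical-Fisher-style approximation). A hedged sketch of the idea, omitting the library's group reduction and normalization steps:

```latex
% Saliency of zeroing weight w_i (diagonal second-order term, as in OBD):
\Delta \mathcal{L}_i \;\approx\; \tfrac{1}{2}\, H_{ii}\, w_i^{2},
\qquad
H_{ii} \;\approx\; \sum_{n=1}^{N} \left( \frac{\partial \mathcal{L}_n}{\partial w_i} \right)^{\!2}
```

Here N is the number of calibration samples and L_n the loss on sample n; the sum of squared gradients is exactly what accumulate_grad collects.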

The typical workflow requires three steps:

  1. Call zero_grad() to clear any previously accumulated gradients.
  2. For each sample in a calibration batch, call accumulate_grad(model) to accumulate the squared gradients (grad.pow(2)) across model parameters.
  3. Call __call__(group) to compute the final importance scores for a given pruning group.

This three-step pattern ensures that the diagonal Hessian estimate reflects the curvature information from the calibration data before importance scores are computed.
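The three-step pattern can be illustrated with a minimal pure-Python sketch, independent of Torch-Pruning, using a 1-D linear model y = w * x with squared-error loss. All names and values here are illustrative and not part of the Torch-Pruning API:

```python
# Minimal sketch of the zero / accumulate / score pattern above,
# for a single scalar weight w (illustrative only, not the library API).

def per_sample_grad(w, x, t):
    """Gradient of the per-sample loss L_n = (w*x - t)^2 with respect to w."""
    return 2.0 * x * (w * x - t)

w = 0.5
calibration = [(1.0, 1.0), (2.0, 1.5), (3.0, 2.5)]  # (input, target) pairs

# Step 1: zero the accumulator (analogue of imp.zero_grad())
acc = 0.0

# Step 2: accumulate squared per-sample gradients (analogue of accumulate_grad)
for x, t in calibration:
    g = per_sample_grad(w, x, t)
    acc += g * g

# Step 3: curvature-weighted importance, OBD-style: w^2 times the H_ii estimate
importance = w * w * acc
print(importance)  # → 10.25
```

The accumulator plays the role of the diagonal Hessian estimate; in GroupHessianImportance this is done per parameter tensor and then reduced over each channel group.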

Code Reference

Source Location

Repository
Torch-Pruning
File
torch_pruning/pruner/importance.py
Lines
662--816

Signature

class GroupHessianImportance(GroupMagnitudeImportance):
    def __init__(self,
                 group_reduction: str = "mean",
                 normalizer: str = 'mean',
                 bias: bool = False,
                 target_types: list = [_ConvNd, Linear, _BatchNorm, LayerNorm]):

Additional Methods

def zero_grad(self):
    """Clear accumulated gradients."""

def accumulate_grad(self, model):
    """Accumulate grad.pow(2) for all target parameters in the model."""

Import

from torch_pruning.pruner.importance import GroupHessianImportance

I/O Contract

Constructor Inputs

Name Type Default Description
group_reduction str "mean" Reduction method to aggregate per-parameter importance within a group (e.g., "mean", "sum", "max")
normalizer str "mean" Normalization strategy applied to importance scores (e.g., "mean", "sum", "max", "gaussian")
bias bool False Whether to include bias parameters in importance computation
target_types list [_ConvNd, Linear, _BatchNorm, LayerNorm] Module types whose parameters are considered for importance scoring

Methods

Method Input Output Description
zero_grad() (none) (none) Clears all accumulated squared gradients
accumulate_grad(model) model (nn.Module) (none) Accumulates grad.pow(2) for each parameter in the model
__call__(group) group (Group) 1-D torch.Tensor or None Returns importance scores for each channel in the group, or None if the group contains no target types

Usage Examples

Per-Sample Gradient Loop Pattern

import torch
import torch.nn.functional as F
from torch_pruning.pruner.importance import GroupHessianImportance

# Initialize the importance estimator
imp = GroupHessianImportance()

# Clear any previously accumulated gradients
imp.zero_grad()

# Accumulate squared per-sample gradients over the calibration data.
# Per-sample losses (reduction='none') let each backward pass contribute
# a single sample's gradient, matching the per-sample pattern described above.
for inputs, targets in calibration_loader:
    outputs = model(inputs)
    losses = F.cross_entropy(outputs, targets, reduction='none')
    for loss in losses:
        model.zero_grad()
        loss.backward(retain_graph=True)
        imp.accumulate_grad(model)

# Compute importance for a specific pruning group
importance_scores = imp(group)
