# Implementation: VainF Torch-Pruning GroupHessianImportance
| Field | Value |
|---|---|
| Sources | Torch-Pruning, Optimal Brain Damage |
| Domains | Deep_Learning, Model_Compression |
| Last Updated | 2026-02-08 00:00 GMT |
## Overview

Concrete tool for Hessian-based group importance estimation provided by Torch-Pruning.
## Description
GroupHessianImportance extends GroupMagnitudeImportance to incorporate second-order information for structured pruning. Instead of relying on weight magnitudes alone, it uses per-sample gradient accumulation (grad.pow(2)) to approximate the diagonal Hessian, yielding curvature-weighted importance scores for each channel group.
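The idea behind the squared-gradient approximation can be shown in a minimal, self-contained sketch (a toy loss and hand-picked weights, not Torch-Pruning's actual code): accumulate `grad.pow(2)` per parameter over several samples to estimate the diagonal Hessian, then weight the squared weights by that estimate.

```python
import torch

# Toy parameter vector standing in for one channel group's weights.
w = torch.tensor([0.5, -2.0, 0.1], requires_grad=True)

h_diag = torch.zeros_like(w)      # running diagonal-Hessian estimate
for x in [1.0, 2.0, 3.0]:         # stand-in calibration samples
    loss = (w * x).pow(2).sum()   # toy per-sample loss
    loss.backward()
    h_diag += w.grad.pow(2)       # accumulate squared gradients
    w.grad = None                 # clear before the next sample

# Curvature-weighted importance: large-magnitude weights in
# high-curvature directions score highest (index 1 here).
importance = h_diag * w.detach().pow(2)
```

This mirrors the estimator's three-step pattern: clear, accumulate per sample, then score.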
The typical workflow requires three steps:

1. Call `zero_grad()` to clear any previously accumulated gradients.
2. For each sample in a calibration batch, call `accumulate_grad(model)` to accumulate the squared gradients (`grad.pow(2)`) across model parameters.
3. Call `__call__(group)` to compute the final importance scores for a given pruning group.
This three-step pattern ensures that the diagonal Hessian estimate reflects the curvature information from the calibration data before importance scores are computed.
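Under the Optimal Brain Damage framing (a sketch of the rationale; Torch-Pruning's exact reduction and normalization may differ), removing weight $w_i$ at a local minimum perturbs the loss by approximately

$$\Delta\mathcal{L} \approx \tfrac{1}{2}\, H_{ii}\, w_i^2, \qquad H_{ii} \approx \sum_{n} g_{i,n}^2,$$

where $g_{i,n}$ is the gradient of the loss on calibration sample $n$ with respect to $w_i$. The accumulated `grad.pow(2)` supplies the $\sum_n g_{i,n}^2$ term, which is why all calibration gradients must be accumulated before scoring.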
## Code Reference

### Source Location

- Repository: Torch-Pruning
- File: `torch_pruning/pruner/importance.py`
- Lines: 662–816
### Signature

```python
class GroupHessianImportance(GroupMagnitudeImportance):
    def __init__(self,
                 group_reduction: str = "mean",
                 normalizer: str = "mean",
                 bias: bool = False,
                 target_types: list = [_ConvNd, Linear, _BatchNorm, LayerNorm]):
```
### Additional Methods

```python
def zero_grad(self):
    """Clear accumulated gradients."""

def accumulate_grad(self, model):
    """Accumulate grad.pow(2) for all target parameters in the model."""
```
### Import

```python
from torch_pruning.pruner.importance import GroupHessianImportance
```
## I/O Contract

### Constructor Inputs

| Name | Type | Default | Description |
|---|---|---|---|
| `group_reduction` | `str` | `"mean"` | Reduction method used to aggregate per-parameter importance within a group (e.g., `"mean"`, `"sum"`, `"max"`) |
| `normalizer` | `str` | `"mean"` | Normalization strategy applied to importance scores (e.g., `"mean"`, `"sum"`, `"max"`, `"gaussian"`) |
| `bias` | `bool` | `False` | Whether to include bias parameters in the importance computation |
| `target_types` | `list` | `[_ConvNd, Linear, _BatchNorm, LayerNorm]` | Module types whose parameters are considered for importance scoring |
### Methods

| Method | Input | Output | Description |
|---|---|---|---|
| `zero_grad()` | (none) | (none) | Clears all accumulated squared gradients |
| `accumulate_grad(model)` | `model` (`nn.Module`) | (none) | Accumulates `grad.pow(2)` for each parameter in the model |
| `__call__(group)` | `group` (`Group`) | 1-D `torch.Tensor` or `None` | Returns importance scores for each channel in the group, or `None` if the group contains no target types |
## Usage Examples

### Per-Sample Gradient Loop Pattern

```python
import torch
from torch_pruning.pruner.importance import GroupHessianImportance

# Initialize the importance estimator
imp = GroupHessianImportance()

# Clear any previously accumulated gradients
imp.zero_grad()

# Accumulate squared gradients over calibration samples
for batch in calibration_loader:
    inputs, targets = batch
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    imp.accumulate_grad(model)
    model.zero_grad()

# Compute importance for a specific pruning group
importance_scores = imp(group)
```