Implementation: VainF Torch-Pruning GroupNormPruner
Overview
A concrete pruner provided by Torch-Pruning that combines group-level, importance-adaptive regularization with structural pruning.
Description
GroupNormPruner extends BasePruner with regularization capabilities. During training, call update_regularizer() each epoch and regularize(model) each iteration after loss.backward(). It modifies weight gradients to drive low-importance channels toward zero. After training, use step() to prune.
The pruner operates at the dependency group level: it iterates over all groups in the dependency graph, computes per-channel importance scores, derives an adaptive scaling factor gamma, and applies the regularization term to gradients of output channels, input channels, and BatchNorm layers within each group.
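The adaptive scaling described above can be sketched in plain Python. The min-max normalization and the exact exponent form below are illustrative assumptions consistent with the stated [2^0, 2^alpha] range, not the library's verbatim code:

```python
def adaptive_gamma(scores, alpha=4):
    """Map per-channel importance scores to scaling factors in [2**0, 2**alpha].

    Low-importance channels receive a larger gamma, so their weights are
    regularized more strongly toward zero. The min-max normalization here
    is an illustrative assumption.
    """
    lo, hi = min(scores), max(scores)
    span = (hi - lo) or 1.0  # guard against division by zero for uniform scores
    return [2.0 ** (alpha * (1.0 - (s - lo) / span)) for s in scores]

# The least important channel gets the largest scale (2**alpha),
# the most important channel gets a scale of 1 (2**0).
gammas = adaptive_gamma([0.1, 0.5, 0.9], alpha=4)
```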
Code Reference
Source: torch_pruning/pruner/algorithms/group_norm_pruner.py, Lines 10-192
Signature:
class GroupNormPruner(BasePruner):
    def __init__(
        self,
        model: nn.Module,
        example_inputs: torch.Tensor,
        importance: typing.Callable,
        reg=1e-4,
        alpha=4,
        global_pruning: bool = False,
        pruning_ratio: float = 0.5,
        pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        max_pruning_ratio: float = 1.0,
        iterative_steps: int = 1,
        iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler,
        ignored_layers: typing.List[nn.Module] = None,
        round_to: int = None,
        isomorphic: bool = False,
        in_channel_groups: typing.Dict[nn.Module, int] = dict(),
        out_channel_groups: typing.Dict[nn.Module, int] = dict(),
        num_heads: typing.Dict[nn.Module, int] = dict(),
        prune_num_heads: bool = False,
        prune_head_dims: bool = True,
        head_pruning_ratio: float = 0.0,
        head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None,
        unwrapped_parameters: typing.Dict[nn.Parameter, int] = None,
        root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM],
        forward_fn: typing.Callable = None,
        output_transform: typing.Callable = None,
        # ... inherits all BasePruner params
    ):
Key Methods:
update_regularizer() -- Refreshes the internal group list from the dependency graph. Call once per epoch.
regularize(model, alpha=16, bias=False) -- Applies importance-adaptive regularization to weight gradients in-place. Call after each loss.backward().
step() -- Inherited from BasePruner. Executes one round of structural pruning.
Import:
import torch_pruning as tp
pruner = tp.pruner.GroupNormPruner(...)
I/O Contract
| Parameter | Type | Default | Description |
|---|---|---|---|
| reg | float | 1e-4 | Base regularization coefficient |
| alpha | float | 4 | Regularization scaling factor; the adaptive scale ranges over [2^0, 2^alpha] |
| model | nn.Module | (required) | The model to prune |
| example_inputs | torch.Tensor | (required) | Dummy inputs for dependency graph tracing |
| importance | Callable | (required) | Group importance estimator |
Behavior:
regularize() modifies weight.grad.data in-place for all layers in each dependency group.
step() prunes the model in-place, removing the channels identified as least important.
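The in-place gradient modification can be illustrated with a toy sketch. The scalar update form g += reg * gamma * w is an assumption for illustration; the real pruner applies this to torch tensors across the layers of each dependency group:

```python
def regularize_grads(weights, grads, gammas, reg=1e-4):
    """Add an importance-scaled penalty term to each channel's gradient in-place.

    Each gradient is shifted by reg * gamma * w. Low-importance channels have
    large gamma, so gradient descent drives their weights toward zero faster.
    This scalar per-channel form is an illustrative assumption.
    """
    for i, (w, g, gamma) in enumerate(zip(weights, grads, gammas)):
        grads[i] = g + reg * gamma * w  # mirrors the in-place grad update
    return grads
```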
Usage Examples
import torch
import torch.nn as nn
import torch_pruning as tp

# 1. Setup model and pruner
# ResNet50, num_epochs, and train_loader are placeholders; supply your own
# (e.g. torchvision.models.resnet50() and a standard DataLoader).
model = ResNet50(num_classes=100)
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.GroupNormImportance(p=2)
pruner = tp.pruner.GroupNormPruner(
    model=model,
    example_inputs=example_inputs,
    importance=imp,
    reg=1e-4,
    alpha=4,
    pruning_ratio=0.5,
    iterative_steps=1,
)

# 2. Sparse training loop with regularization
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    pruner.update_regularizer()  # refresh groups each epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        pruner.regularize(model)  # apply adaptive regularization to gradients
        optimizer.step()
        optimizer.zero_grad()

# 3. Prune the model after sparse training
pruner.step()