Implementation:VainF Torch Pruning GroupNormPruner

Overview

GroupNormPruner is Torch-Pruning's concrete tool for group-level, importance-adaptive regularization and pruning.

Description

GroupNormPruner extends BasePruner with regularization capabilities. During training, call update_regularizer() each epoch and regularize(model) each iteration after loss.backward(). It modifies weight gradients to drive low-importance channels toward zero. After training, use step() to prune.

The pruner operates at the dependency group level: it iterates over all groups in the dependency graph, computes per-channel importance scores, derives an adaptive scaling factor gamma, and applies the regularization term to gradients of output channels, input channels, and BatchNorm layers within each group.
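
To make the adaptive scaling concrete, the snippet below is a minimal illustrative sketch (not the library's exact code) of how a per-channel factor in [2^0, 2^alpha] could be derived from one group's importance scores; imp is a stand-in for those scores.

import torch

# Illustrative only: derive an adaptive scaling factor from per-channel importance.
imp = torch.rand(64)   # stand-in for the importance scores of a 64-channel group
alpha, reg = 4, 1e-4   # the constructor defaults described on this page

s = (imp.max() - imp) / (imp.max() - imp.min() + 1e-8)  # 1 = least important, 0 = most important
gamma = 2 ** (alpha * s)                                # adaptive factor in [2^0, 2^alpha]

# Conceptually, each weight gradient in the group then receives an extra shrinkage
# term proportional to gamma (e.g. weight.grad += reg * gamma * weight, broadcast
# along the pruned dimension), so low-importance channels decay toward zero faster.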

Code Reference

Source: torch_pruning/pruner/algorithms/group_norm_pruner.py, Lines 10-192

Signature:

class GroupNormPruner(BasePruner):
    def __init__(
        self,
        model: nn.Module,
        example_inputs: torch.Tensor,
        importance: typing.Callable,
        reg=1e-4,
        alpha=4,
        global_pruning: bool = False,
        pruning_ratio: float = 0.5,
        pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        max_pruning_ratio: float = 1.0,
        iterative_steps: int = 1,
        iterative_pruning_ratio_scheduler: typing.Callable = linear_scheduler,
        ignored_layers: typing.List[nn.Module] = None,
        round_to: int = None,
        isomorphic: bool = False,
        in_channel_groups: typing.Dict[nn.Module, int] = dict(),
        out_channel_groups: typing.Dict[nn.Module, int] = dict(),
        num_heads: typing.Dict[nn.Module, int] = dict(),
        prune_num_heads: bool = False,
        prune_head_dims: bool = True,
        head_pruning_ratio: float = 0.0,
        head_pruning_ratio_dict: typing.Dict[nn.Module, float] = None,
        customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None,
        unwrapped_parameters: typing.Dict[nn.Parameter, int] = None,
        root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM],
        forward_fn: typing.Callable = None,
        output_transform: typing.Callable = None,
        # ... inherits all BasePruner params
    ):

Key Methods:

  • update_regularizer() -- Refreshes the internal group list from the dependency graph. Call once per epoch.
  • regularize(model, alpha=16, bias=False) -- Applies importance-adaptive regularization to weight gradients in-place. Call after each loss.backward().
  • step() -- Inherited from BasePruner. Executes one round of structural pruning.

Import:

import torch_pruning as tp
pruner = tp.pruner.GroupNormPruner(...)

I/O Contract

Parameter        Type           Default     Description
model            nn.Module      (required)  The model to prune
example_inputs   torch.Tensor   (required)  Dummy inputs for dependency graph tracing
importance       Callable       (required)  Group importance estimator; see the sketch below
reg              float          1e-4        Base regularization coefficient
alpha            float          4           Exponent of the adaptive scaling; per-channel factors range over [2^0, 2^alpha]
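
The importance argument can be any callable that follows Torch-Pruning's group interface: it receives a dependency group and returns one score per prunable channel. A minimal sketch of a custom estimator is shown below; RandomImportance is a hypothetical name used only for illustration.

import torch

# Hypothetical importance estimator, for illustration only.
class RandomImportance:
    @torch.no_grad()
    def __call__(self, group, **kwargs):
        # A real estimator aggregates scores across every layer in the group;
        # here, random scores are returned for the first item only, for brevity.
        for _dep, idxs in group:
            return torch.rand(len(idxs))  # one score per prunable channel

# An instance can then be passed to the pruner as importance=RandomImportance().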

Behavior:

  • regularize() modifies weight.grad.data in-place for all layers in each dependency group.
  • step() prunes the model in-place, removing channels identified as least important.
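
The in-place effect of step() can be checked by comparing parameter and MAC counts before and after pruning. A minimal sketch, assuming a torchvision ResNet-18 and Torch-Pruning's count_ops_and_params utility:

import torch
import torch_pruning as tp
from torchvision.models import resnet18

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)

pruner = tp.pruner.GroupNormPruner(
    model, example_inputs,
    importance=tp.importance.GroupNormImportance(p=2),
    pruning_ratio=0.5,
    ignored_layers=[model.fc],  # keep the classifier's output dimension
)
pruner.step()  # structural pruning, applied to `model` in-place

macs, params = tp.utils.count_ops_and_params(model, example_inputs)
print(f"MACs: {base_macs/1e9:.2f}G -> {macs/1e9:.2f}G, "
      f"Params: {base_params/1e6:.2f}M -> {params/1e6:.2f}M")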

Usage Examples

import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision.models import resnet50

# 1. Setup model and pruner
model = resnet50(num_classes=100)
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.GroupNormImportance(p=2)  # group L2-norm importance

pruner = tp.pruner.GroupNormPruner(
    model=model,
    example_inputs=example_inputs,
    importance=imp,
    reg=1e-4,
    alpha=4,
    pruning_ratio=0.5,
    iterative_steps=1,
    ignored_layers=[model.fc],  # keep the final 100-class output layer intact
)

# 2. Sparse training loop with regularization
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
num_epochs = 10  # sparse-training epochs; assumes a standard train_loader is defined elsewhere
for epoch in range(num_epochs):
    pruner.update_regularizer()  # refresh groups each epoch
    for batch_idx, (data, target) in enumerate(train_loader):
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        pruner.regularize(model)   # apply adaptive regularization to gradients
        optimizer.step()
        optimizer.zero_grad()

# 3. Prune the model after sparse training
pruner.step()
