Author: VainF (MIT License)

Torch-Pruning

A PyTorch toolkit for structured neural network pruning and layer dependency maintenance.

This tool automatically detects and handles layer dependencies (channel consistency) during pruning. It supports a variety of network architectures, such as DenseNet, ResNet, and Inception. See examples/test_models.py for more supported models.

How it works

This package runs your model on fake inputs and collects forward information, much like `torch.jit`. It then builds a dependency graph that describes the computational graph. When a pruning function (e.g. `torch_pruning.prune_conv`) is applied to a layer through `DependencyGraph.get_pruning_plan`, the package traverses the whole graph to fix inconsistent modules such as BN. Pruning indices are automatically mapped to the correct positions if your model contains `torch.split` or `torch.cat`.
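To illustrate the index mapping for `torch.cat`, here is a minimal sketch, assuming a concatenation along the channel dimension (`dim=1`): a channel index in the concatenated output is traced back to the input it came from by subtracting cumulative channel offsets. The helper `map_cat_indices` is hypothetical and is not part of Torch-Pruning's API.

```python
import torch

def map_cat_indices(idxs, split_sizes):
    """Hypothetical helper: map channel indices of torch.cat(..., dim=1)
    back to {input_position: [local_indices]}."""
    # Cumulative start offset of each concatenated input
    offsets = [0]
    for size in split_sizes:
        offsets.append(offsets[-1] + size)
    mapped = {i: [] for i in range(len(split_sizes))}
    for idx in idxs:
        for i in range(len(split_sizes)):
            if offsets[i] <= idx < offsets[i + 1]:
                mapped[i].append(idx - offsets[i])
                break
        else:
            raise IndexError(f"index {idx} is outside the concatenated channels")
    return mapped

# Two feature maps with 64 and 32 channels, concatenated along dim=1
a = torch.randn(1, 64, 8, 8)
b = torch.randn(1, 32, 8, 8)
out = torch.cat([a, b], dim=1)

mapped = map_cat_indices([2, 70, 95], [64, 32])
print(mapped)  # {0: [2], 1: [6, 31]}

# Sanity check: channel 70 of the output is channel 6 of b
assert torch.equal(out[:, 70], b[:, 6])
```

The same offsets work in reverse for `torch.split`: each output chunk covers one `[offsets[i], offsets[i + 1])` range of the input channels.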

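Tip: remember to save the whole model object (architecture + weights) rather than the weights only. The pruned layers no longer match the original architecture definition, so a bare state_dict cannot be loaded back into it:

```python
import torch

# save a pruned model
# torch.save(model.state_dict(), 'model.pth')  # weights only
torch.save(model, 'model.pth')  # obj (arch) + weights

# load a pruned model (no load_state_dict)
# PyTorch >= 2.6 defaults to weights_only=True, so loading a full
# model object requires weights_only=False
model = torch.load('model.pth', weights_only=False)
```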

| Dependency | Example |
| :------------------: | :-----: |
| Conv-Conv | AlexNet |
| Conv-FC (Global Pooling or Flatten) | ResNet, VGG |
| Skip Connection | ResNet |
| Concatenation | DenseNet, ASPP |
| Split | torch.chunk |

Known Issues:

  • When groups > 1, only depthwise convolutions are supported, i.e. `groups` = `in_channels` = `out_channels`.
  • Customized operations, e.g. subclasses of `torch.autograd.Function`, will be treated as element-wise ops.
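The depthwise restriction follows from the weight layout: a depthwise `nn.Conv2d` stores one `[1, k, k]` filter per channel, so removing channel `i` removes both its input and its output, and `groups`, `in_channels` and `out_channels` must shrink together. A minimal plain-PyTorch sketch; `prune_depthwise` is a hypothetical helper, not the library's implementation.

```python
import torch
import torch.nn as nn

def prune_depthwise(conv: nn.Conv2d, idxs):
    """Hypothetical helper: remove channels `idxs` from a depthwise conv in place."""
    assert conv.groups == conv.in_channels == conv.out_channels, "not depthwise"
    drop = set(idxs)
    keep = [i for i in range(conv.out_channels) if i not in drop]
    # Weight shape is [C, 1, k, k]: one filter per channel
    conv.weight = nn.Parameter(conv.weight.data[keep].clone())
    if conv.bias is not None:
        conv.bias = nn.Parameter(conv.bias.data[keep].clone())
    # All three must stay equal for the conv to remain depthwise
    conv.in_channels = conv.out_channels = conv.groups = len(keep)
    return keep

dw = nn.Conv2d(8, 8, kernel_size=3, padding=1, groups=8)
x = torch.randn(1, 8, 16, 16)
ref = dw(x)

keep = prune_depthwise(dw, [2, 6])
# The input must lose the same channels as the output
y = dw(x[:, keep])
print(y.shape)  # torch.Size([1, 6, 16, 16])
assert torch.allclose(y, ref[:, keep], atol=1e-5)
```

For a general grouped conv (1 < groups < in_channels), each filter reads a whole block of input channels, which is why such layers are not handled.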

Installation

pip install torch_pruning

Quickstart

A minimal example

```python
import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# 1. setup strategy (L1 Norm)
strategy = tp.strategy.L1Strategy()  # or tp.strategy.RandomStrategy()

# 2. build layer dependency for resnet18
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))

# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4)  # or manually selected pruning_idxs=[2, 6, 9]
pruning_plan = DG.get_pruning_plan(model.conv1, tp.prune_conv, idxs=pruning_idxs)
print(pruning_plan)

# 4. execute this plan (prune the model)
pruning_plan.exec()
```
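The L1 strategy ranks filters by magnitude. Below is a minimal sketch of that idea, assuming `amount` is the fraction of output channels to remove and that the filters with the smallest L1 norms are selected; `l1_pruning_idxs` is a hypothetical stand-in, and the exact selection and rounding used by `tp.strategy.L1Strategy` may differ.

```python
import torch

def l1_pruning_idxs(weight: torch.Tensor, amount: float):
    """Hypothetical stand-in for an L1-norm strategy: return the indices of the
    `amount` fraction of output channels with the smallest L1 norm."""
    # Sum |w| over every dimension except the output-channel dimension
    l1 = weight.abs().flatten(1).sum(dim=1)
    n_pruned = int(amount * weight.shape[0])
    # Indices of the n_pruned smallest norms, sorted for readability
    return sorted(torch.argsort(l1)[:n_pruned].tolist())

# Toy conv weight: 5 filters, every weight of filter c equals scales[c]
scales = [0.5, 0.1, 0.9, 0.05, 0.7]
w = torch.zeros(5, 2, 3, 3)
for c, scale in enumerate(scales):
    w[c] = scale
print(l1_pruning_idxs(w, amount=0.4))  # [1, 3]
```

A random strategy would simply replace the ranking with a random sample of the same size.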

Pruning resnet.conv1 affects several layers. Here is the resulting pruning plan (with pruning_idxs=[2, 6, 9]):

-------------
[  prune_conv on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False))>, Index=[2, 6, 9], NumPruned=441]
[  prune_batchnorm on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[  _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[  _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[  prune_related_conv on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[  _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[  prune_batchnorm on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[  prune_conv on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[  _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[  prune_related_conv on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[  _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[  prune_batchnorm on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))>, Index=[2, 6, 9], NumPruned=6]
[  prune_conv on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=1728]
[  _prune_elementwise_op on _ElementWiseOp()>, Index=[2, 6, 9], NumPruned=0]
[  prune_related_conv on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False))>, Index=[2, 6, 9], NumPruned=3456]
[  prune_related_conv on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False))>, Index=[2, 6, 9], NumPruned=384]
11211 parameters will be pruned
-------------
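The `NumPruned` values follow directly from the tensor shapes. Removing 3 output filters from `conv1` removes `3 × (3 × 7 × 7)` weights, a `prune_related_conv` removes 3 input slices of size `out_channels × k × k`, and each BatchNorm loses one weight and one bias per channel. This short script only re-derives the numbers printed above:

```python
n = 3  # number of pruned channels: idxs=[2, 6, 9]

# Parameters removed per pruned channel, per layer in the plan above
per_channel = {
    "conv1 (out)":                3 * 7 * 7,    # in_ch * k * k
    "bn1":                        2,            # weight + bias
    "layer1.0.conv1 (in)":        64 * 3 * 3,   # out_ch * k * k
    "layer1.0.bn2":               2,
    "layer1.0.conv2 (out)":       64 * 3 * 3,   # in_ch * k * k
    "layer1.1.conv1 (in)":        64 * 3 * 3,
    "layer1.1.bn2":               2,
    "layer1.1.conv2 (out)":       64 * 3 * 3,
    "layer2.0.conv1 (in)":        128 * 3 * 3,
    "layer2.0.downsample.0 (in)": 128 * 1 * 1,
}

total = 0
for name, count in per_channel.items():
    print(f"{name:28s} {n * count}")
    total += n * count
print(total)  # 11211
```

The BatchNorm running statistics are buffers rather than parameters, so they do not appear in the count.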

Low-level pruning functions

Without DependencyGraph, broken dependencies must be fixed manually:

```python
tp.prune_conv(model.conv1, idxs=[2, 6, 9])

# fix the broken dependencies manually
tp.prune_batchnorm(model.bn1, idxs=[2, 6, 9])
tp.prune_related_conv(model.layer2[0].conv1, idxs=[2, 6, 9])
...
```
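For illustration, this is roughly what removing a conv's output channels and repairing the following BatchNorm involves, written in plain PyTorch. It is a simplified sketch, not Torch-Pruning's implementation; `prune_conv_out` and `prune_bn` are hypothetical names.

```python
import torch
import torch.nn as nn

def prune_conv_out(conv: nn.Conv2d, idxs):
    """Hypothetical helper: drop output channels (filters) `idxs` in place."""
    drop = set(idxs)
    keep = [i for i in range(conv.out_channels) if i not in drop]
    conv.weight = nn.Parameter(conv.weight.data[keep].clone())
    if conv.bias is not None:
        conv.bias = nn.Parameter(conv.bias.data[keep].clone())
    conv.out_channels = len(keep)
    return keep

def prune_bn(bn: nn.BatchNorm2d, idxs):
    """Hypothetical helper: drop channels `idxs` from BN parameters and statistics."""
    drop = set(idxs)
    keep = [i for i in range(bn.num_features) if i not in drop]
    bn.weight = nn.Parameter(bn.weight.data[keep].clone())
    bn.bias = nn.Parameter(bn.bias.data[keep].clone())
    bn.running_mean = bn.running_mean[keep].clone()
    bn.running_var = bn.running_var[keep].clone()
    bn.num_features = len(keep)
    return keep

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
bn = nn.BatchNorm2d(16).eval()  # eval mode: normalize with running statistics
x = torch.randn(2, 3, 8, 8)
ref = bn(conv(x))

keep = prune_conv_out(conv, [2, 6, 9])
prune_bn(bn, [2, 6, 9])  # without this, bn(conv(x)) fails with a size mismatch
out = bn(conv(x))
print(out.shape)  # torch.Size([2, 13, 8, 8])
assert torch.allclose(out, ref[:, keep], atol=1e-5)
```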

Customized Layers

Please refer to 'examples/customize_layer.py' for pruning customized layers with this package. A detailed tutorial is on the way!

Layer Dependency

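During structured pruning, we need to maintain channel consistency between dependent layers: when output channels are removed from one layer, the corresponding input channels must be removed from every layer that consumes them.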

A Simple Case
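In the simplest case two convolutions are chained, so removing output channel `i` from the first conv means the second conv must drop input slice `weight[:, i]`. A minimal plain-PyTorch sketch with hypothetical helper names (assuming `groups == 1`); it checks that the pruned pair computes the same result as the original pair with the removed channels zeroed.

```python
import torch
import torch.nn as nn

def prune_out_channels(conv, keep):
    """Hypothetical helper: keep only output channels `keep`."""
    conv.weight = nn.Parameter(conv.weight.data[keep].clone())
    if conv.bias is not None:
        conv.bias = nn.Parameter(conv.bias.data[keep].clone())
    conv.out_channels = len(keep)

def prune_in_channels(conv, keep):
    """Hypothetical helper: keep only input channels `keep` (groups == 1 assumed)."""
    conv.weight = nn.Parameter(conv.weight.data[:, keep].clone())
    conv.in_channels = len(keep)

conv1 = nn.Conv2d(3, 8, 3, padding=1)
conv2 = nn.Conv2d(8, 4, 3, padding=1)
x = torch.randn(1, 3, 16, 16)

idxs = [2, 6]
keep = [i for i in range(8) if i not in idxs]
with torch.no_grad():
    # Reference: original network with channels 2 and 6 of conv1's output zeroed
    h = conv1(x)
    h[:, idxs] = 0
    ref = conv2(h)

prune_out_channels(conv1, keep)  # conv1: 8 -> 6 filters
prune_in_channels(conv2, keep)   # conv2 must now accept 6 input channels
with torch.no_grad():
    out = conv2(conv1(x))
print(out.shape)  # torch.Size([1, 4, 16, 16])
assert torch.allclose(out, ref, atol=1e-5)
```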

More Complicated Cases

The layer dependency becomes much more complicated when the model contains skip connections or concatenations.
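With a skip connection, the two tensors being added must keep exactly the same channels, so one set of indices has to be removed from every layer that produces an addend, as well as from the layers that consume them. A minimal sketch of a toy residual block, assuming `groups == 1` and no BatchNorm for brevity; the block and helper names are hypothetical.

```python
import torch
import torch.nn as nn

def keep_out(conv, keep):
    """Hypothetical helper: keep only output channels `keep`."""
    conv.weight = nn.Parameter(conv.weight.data[keep].clone())
    conv.bias = nn.Parameter(conv.bias.data[keep].clone())
    conv.out_channels = len(keep)

def keep_in(conv, keep):
    """Hypothetical helper: keep only input channels `keep` (groups == 1)."""
    conv.weight = nn.Parameter(conv.weight.data[:, keep].clone())
    conv.in_channels = len(keep)

stem = nn.Conv2d(3, 8, 3, padding=1)  # produces the identity branch h
body = nn.Conv2d(8, 8, 3, padding=1)  # residual branch

def block(x, zero_idxs=()):
    """Toy residual block: y = h + body(h) with h = stem(x)."""
    h = stem(x)
    if zero_idxs:
        h[:, list(zero_idxs)] = 0  # emulate removed channels in the unpruned model
    return h + body(h)

x = torch.randn(1, 3, 16, 16)
idxs = [2, 6]
keep = [i for i in range(8) if i not in idxs]
with torch.no_grad():
    ref = block(x, zero_idxs=idxs)

keep_out(stem, keep)  # identity branch loses channels 2 and 6,
keep_in(body, keep)   # so body receives 6 input channels,
keep_out(body, keep)  # and body must drop the SAME output channels for the add
with torch.no_grad():
    out = block(x)
print(out.shape)  # torch.Size([1, 6, 16, 16])
assert torch.allclose(out, ref[:, keep], atol=1e-5)
```

If `body` dropped different output channels, the addition would pair unrelated channels; if it dropped none, the shapes would not match at all.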

Residual Block:

Concatenation:

See the paper "Pruning Filters for Efficient ConvNets" for more details.

Example: ResNet18 on CIFAR-10

1. Train the model

cd examples
python prune_resnet18_cifar10.py --mode train # 11.1M, Acc=0.9248

2. Pruning and fine-tuning

python prune_resnet18_cifar10.py --mode prune --round 1 --total_epochs 30 --step_size 20 # 4.5M, Acc=0.9229
python prune_resnet18_cifar10.py --mode prune --round 2 --total_epochs 30 --step_size 20 # 1.9M, Acc=0.9207
python prune_resnet18_cifar10.py --mode prune --round 3 --total_epochs 30 --step_size 20 # 0.8M, Acc=0.9176
python prune_resnet18_cifar10.py --mode prune --round 4 --total_epochs 30 --step_size 20 # 0.4M, Acc=0.9102
python prune_resnet18_cifar10.py --mode prune --round 5 --total_epochs 30 --step_size 20 # 0.2M, Acc=0.9011
...
