A Closer Look at Structured Pruning for Neural Network Compression

Overview

A Closer Look at Structured Pruning for Neural Network Compression

Code used to reproduce experiments in https://arxiv.org/abs/1810.04622.

To prune, we fill our networks with custom MaskBlocks, which are manipulated using Pruner in funcs.py. There will certainly be a better way to do this, but we leave this as an exercise to someone who can code much better than we can.

Setup

This is best done in a clean conda environment:

conda create -n prunes python=3.6
conda activate prunes
conda install pytorch torchvision -c pytorch

Repository layout

-train.py: contains all of the code for training large models from scratch and for training pruned models from scratch
-prune.py: contains the code for pruning trained models
-funcs.py: contains useful pruning functions and any functions we used commonly

CIFAR Experiments

First, you will need some initial models.

To train a WRN-40-2:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res'

The default arguments of train.py are suitable for training WRNs. The following trains a DenseNet-BC-100 (k=12) with its default hyperparameters:

python train.py --net='dense' --depth=100 --data_loc= --save_file='dense' --no_epochs 300 -b 64 --epoch_step '[150,225]' --weight_decay 0.0001 --lr_decay_ratio 0.1

These will automatically save checkpoints to the checkpoints folder.

Pruning

Once training is finished, we can prune our networks using prune.py (defaults are set to WRN pruning, so extra arguments are needed for DenseNets)

python prune.py --net='res'   --data_loc= --base_model='res' --save_file='res_fisher'
python prune.py --net='res'   --data_loc= --l1_prune=True --base_model='res' --save_file='res_l1'

python prune.py --net='dense' --depth 100 --data_loc= --base_model='dense' --save_file='dense_fisher' --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64 --no_epochs 2600
python prune.py --net='dense' --depth 100 --data_loc= --l1_prune=True --base_model='dense' --save_file='dense_l1'  --learning_rate 1e-3 --weight_decay 1e-4 --batch_size 64  --no_epochs 2600

Note that the default is to perform Fisher pruning, so you don't need to pass a flag to use it.
Once finished, we can train the pruned models from scratch, e.g.:

python train.py --data_loc= --net='res' --base_file='res_fisher__prunes' --deploy --mask=1 --save_file='res_fisher__prunes_scratch'

Each model can then be evaluated using:

python train.py --deploy --eval --data_loc= --net='res' --mask=1 --base_file='res_fisher__prunes'

Training Reduced models

This can be done by varying the input arguments to train.py. To reduce depth or width of a WRN, change the corresponding option:

python train.py --net='res' --depth= --width= --data_loc= --save_file='res_reduced'

To add bottlenecks, use the following:

python train.py --net='res' --depth=40 --width=2.0 --data_loc= --save_file='res_bottle' --bottle --bottle_mult 

With DenseNets you can modify the depth or growth, or use --bottle --bottle_mult as above.

Acknowledgements

Jack Turner wrote the L1 stuff, and some other stuff for that matter.

Code has been liberally borrowed from many a repo, including, but not limited to:

https://github.com/xternalz/WideResNet-pytorch
https://github.com/bamos/densenet.pytorch
https://github.com/kuangliu/pytorch-cifar
https://github.com/ShichenLiu/CondenseNet

Citing this work

If you would like to cite this work, please use the following bibtex entry:

@article{crowley2018pruning,
  title={A Closer Look at Structured Pruning for Neural Network Compression},
  author={Crowley, Elliot J and Turner, Jack and Storkey, Amos and O'Boyle, Michael},
  journal={arXiv preprint arXiv:1810.04622},
  year={2018},
  }
Owner
Bayesian and Neural Systems Group
Machine learning research group @ University of Edinburgh
Bayesian and Neural Systems Group
Official implementation of Rich Semantics Improve Few-Shot Learning (BMVC, 2021)

Rich Semantics Improve Few-Shot Learning Paper Link Abstract : Human learning benefits from multi-modal inputs that often appear as rich semantics (e.

Mohamed Afham 11 Jul 26, 2022
Code for ViTAS_Vision Transformer Architecture Search

Vision Transformer Architecture Search This repository open source the code for ViTAS: Vision Transformer Architecture Search. ViTAS aims to search fo

46 Dec 17, 2022
[CVPRW 2021] Code for Region-Adaptive Deformable Network for Image Quality Assessment

RADN [CVPRW 2021] Code for Region-Adaptive Deformable Network for Image Quality Assessment [Paper on arXiv] Overview Update [2021/5/7] add codes for W

IIGROUP 53 Dec 28, 2022
PyTorch implementation of our Adam-NSCL algorithm from our CVPR2021 (oral) paper "Training Networks in Null Space for Continual Learning"

Adam-NSCL This is a PyTorch implementation of Adam-NSCL algorithm for continual learning from our CVPR2021 (oral) paper: Title: Training Networks in N

Shipeng Wang 34 Dec 21, 2022
Deep Convolutional Generative Adversarial Networks

Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks Alec Radford, Luke Metz, Soumith Chintala All images in t

Alec Radford 3.4k Dec 29, 2022
Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition in CVPR19

2s-AGCN Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based Action Recognition in CVPR19 Note PyTorch version should be 0.3! For PyTor

LShi 547 Dec 26, 2022
iNAS: Integral NAS for Device-Aware Salient Object Detection

iNAS: Integral NAS for Device-Aware Salient Object Detection Introduction Integral search design (jointly consider backbone/head structures, design/de

é”¾å®‡č¶… 77 Dec 02, 2022
Dcf-game-infrastructure-public - Contains all the components necessary to run a DC finals (attack-defense CTF) game from OOO

dcf-game-infrastructure All the components necessary to run a game of the OOO DC

Order of the Overflow 46 Sep 13, 2022
Codebase to experiment with a hybrid Transformer that combines conditional sequence generation with regression

Regression Transformer Codebase to experiment with a hybrid Transformer that combines conditional sequence generation with regression . Development se

International Business Machines 27 Jan 05, 2023
Flickr-Faces-HQ (FFHQ) is a high-quality image dataset of human faces, originally created as a benchmark for generative adversarial networks (GAN)

Flickr-Faces-HQ Dataset (FFHQ) Flickr-Faces-HQ (FFHQ) is a high-quality image dataset of human faces, originally created as a benchmark for generative

NVIDIA Research Projects 2.9k Dec 28, 2022
Geometric Algebra package for JAX

JAXGA - JAX Geometric Algebra GitHub | Docs JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing onl

Robin Kahlow 36 Dec 22, 2022
The official implementation of A Unified Game-Theoretic Interpretation of Adversarial Robustness.

This repository is the official implementation of A Unified Game-Theoretic Interpretation of Adversarial Robustness. Requirements pip install -r requi

Jie Ren 17 Dec 12, 2022
PyTorch implementation for MINE: Continuous-Depth MPI with Neural Radiance Fields

MINE: Continuous-Depth MPI with Neural Radiance Fields Project Page | Video PyTorch implementation for our ICCV 2021 paper. MINE: Towards Continuous D

Zijian Feng 325 Dec 29, 2022
Self-Supervised Document-to-Document Similarity Ranking via Contextualized Language Models and Hierarchical Inference

Self-Supervised Document Similarity Ranking (SDR) via Contextualized Language Models and Hierarchical Inference This repo is the implementation for SD

Microsoft 36 Nov 28, 2022
A high-performance Python-based I/O system for large (and small) deep learning problems, with strong support for PyTorch.

WebDataset WebDataset is a PyTorch Dataset (IterableDataset) implementation providing efficient access to datasets stored in POSIX tar archives and us

1.1k Jan 08, 2023
Self-Supervised Pillar Motion Learning for Autonomous Driving (CVPR 2021)

Self-Supervised Pillar Motion Learning for Autonomous Driving Chenxu Luo, Xiaodong Yang, Alan Yuille Self-Supervised Pillar Motion Learning for Autono

QCraft 101 Dec 05, 2022
Implementations of paper Controlling Directions Orthogonal to a Classifier

Classifier Orthogonalization Implementations of paper Controlling Directions Orthogonal to a Classifier , ICLR 2022, Yilun Xu, Hao He, Tianxiao Shen,

Yilun Xu 33 Dec 01, 2022
This is a simple framework to make object detection dataset very quickly

FastAnnotation Table of contents General info Requirements Setup General info This is a simple framework to make object detection dataset very quickly

Serena Tetart 1 Jan 24, 2022
MTA:SA Server Configer.

MTAConfiger MTA:SA Server Configer. Hi šŸ‘‹ , I'm Alireza A Python Developer Boy šŸ”­ Iā€™m currently working on my C# projects šŸŒ± Iā€™m currently Learning CS

3 Jun 07, 2022
Finite-temperature variational Monte Carlo calculation of uniform electron gas using neural canonical transformation.

CoulombGas This code implements the neural canonical transformation approach to the thermodynamic properties of uniform electron gas. Building on JAX,

FermiFlow 9 Mar 03, 2022