Implementation of Continuous Sparsification, a method for pruning and ticket search in deep networks

Continuous Sparsification

Implementation of Continuous Sparsification (CS), a method based on ℓ0 regularization for finding sparse neural networks, proposed in Winning the Lottery with Continuous Sparsification (arXiv:1912.04427).

Requirements

Python 2/3, PyTorch == 1.1.0

Training a ResNet on CIFAR with Continuous Sparsification

The main.py script can be used to train a ResNet-18 on CIFAR-10 with Continuous Sparsification. By default it performs 3 rounds of training, each consisting of 85 epochs. With the default hyperparameter values for the mask initialization, mask penalty, and final temperature, the method finds a sub-network with 20-30% sparsity that achieves 91.5-92.0% test accuracy when trained after rewinding (the dense network achieves 90-91%). The training and rewinding protocols follow those of the Lottery Ticket Hypothesis papers by Frankle et al.
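
To train with these default hyperparameters, running the script without arguments should suffice:

python main.py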

In general, the sparsity of the final sub-network can be controlled by changing the value used to initialize the soft mask parameters. This can be done with, for example:

python main.py --mask-initial-value 0.1

The default value is 0.0, and increasing it results in less sparse sub-networks; high-sparsity sub-networks can be found by decreasing it, for example to -0.1, as in the command below.
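
For example, to search for a high-sparsity sub-network:

python main.py --mask-initial-value -0.1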

Extending the code

To train other network models with Continuous Sparsification, the first step is to choose which layers you want to sparsify and then implement PyTorch modules that apply soft masking to their original parameters. This repository contains code for 2D convolutions with soft masking: the SoftMaskedConv2d module in models/layers.py:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def sigmoid(x):
    # Scalar sigmoid, used to rescale the mask so that it equals 1 at initialization.
    return 1. / (1. + math.exp(-x))


class SoftMaskedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, padding=1, stride=1, mask_initial_value=0.):
        super(SoftMaskedConv2d, self).__init__()
        self.mask_initial_value = mask_initial_value

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
        nn.init.xavier_normal_(self.weight)
        self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
        self.init_mask()

    def init_mask(self):
        # One mask entry per weight entry, initialized to a constant value.
        self.mask_weight = nn.Parameter(torch.Tensor(self.out_channels, self.in_channels, self.kernel_size, self.kernel_size))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

    def compute_mask(self, temp, ticket):
        scaling = 1. / sigmoid(self.mask_initial_value)
        if ticket: mask = (self.mask_weight > 0).float()         # hard, binary mask
        else: mask = torch.sigmoid(temp * self.mask_weight)      # soft mask, sharpened by the temperature
        return scaling * mask

    def prune(self, temp):
        # Scale the mask logits by the temperature and cap them at the initial value:
        # pruned entries (negative logits) stay strongly negative, surviving ones restart from the initial value.
        self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        out = F.conv2d(x, masked_weight, stride=self.stride, padding=self.padding)
        return out

    def checkpoint(self):
        # Store the current weights so training can later be rewound to them.
        self.init_weight.data = self.weight.clone()

    def rewind_weights(self):
        self.weight.data = self.init_weight.clone()

    def extra_repr(self):
        return '{}, {}, kernel_size={}, stride={}, padding={}'.format(
            self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding)
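
As a quick illustration of how the module is used (the shapes and hyperparameter values below are arbitrary and not taken from this repository):

    # Illustrative usage of SoftMaskedConv2d:
    conv = SoftMaskedConv2d(in_channels=3, out_channels=16, kernel_size=3, mask_initial_value=0.05)
    x = torch.randn(8, 3, 32, 32)
    out_search = conv(x, temp=10, ticket=False)  # soft mask: used while searching for the sub-network
    out_ticket = conv(x, temp=10, ticket=True)   # hard (binary) mask: used to evaluate the found ticket
    print(out_search.shape)                      # torch.Size([8, 16, 32, 32])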

Extending this to other layers is straightforward, since you only need to change the __init__, init_mask and forward methods. In init_mask, you should create a mask parameter (a PyTorch nn.Parameter) for each parameter set that you want to sparsify -- each mask parameter must have the same dimensions as the corresponding parameter.

    def init_mask(self):
        self.mask_weight = nn.Parameter(torch.Tensor(...))
        nn.init.constant_(self.mask_weight, self.mask_initial_value)

In the forward method, you need to compute the masked version of each parameter to be sparsified (e.g. masked weights for a Linear layer), and then compute the output of the layer with the corresponding PyTorch functional call (e.g. F.linear for Linear layers). For example:

    def forward(self, x, temp=1, ticket=False):
        self.mask = self.compute_mask(temp, ticket)
        masked_weight = self.weight * self.mask
        out = F.linear(x, masked_weight)        
        return out
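
Putting these pieces together, a complete soft-masked linear layer could look like the sketch below. This SoftMaskedLinear is a hypothetical example and is not part of this repository: its compute_mask, prune, checkpoint and rewind_weights methods simply mirror those of SoftMaskedConv2d, and the bias is omitted for brevity.

    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SoftMaskedLinear(nn.Module):
        # Hypothetical soft-masked Linear layer, mirroring SoftMaskedConv2d above.
        def __init__(self, in_features, out_features, mask_initial_value=0.):
            super(SoftMaskedLinear, self).__init__()
            self.mask_initial_value = mask_initial_value
            self.in_features = in_features
            self.out_features = out_features

            self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
            nn.init.xavier_normal_(self.weight)
            self.init_weight = nn.Parameter(torch.zeros_like(self.weight), requires_grad=False)
            self.init_mask()

        def init_mask(self):
            # The mask has the same shape as the weight it masks.
            self.mask_weight = nn.Parameter(torch.Tensor(self.out_features, self.in_features))
            nn.init.constant_(self.mask_weight, self.mask_initial_value)

        def compute_mask(self, temp, ticket):
            scaling = 1. + math.exp(-self.mask_initial_value)  # equals 1 / sigmoid(mask_initial_value)
            if ticket: mask = (self.mask_weight > 0).float()
            else: mask = torch.sigmoid(temp * self.mask_weight)
            return scaling * mask

        def prune(self, temp):
            self.mask_weight.data = torch.clamp(temp * self.mask_weight.data, max=self.mask_initial_value)

        def forward(self, x, temp=1, ticket=False):
            self.mask = self.compute_mask(temp, ticket)
            masked_weight = self.weight * self.mask
            return F.linear(x, masked_weight)

        def checkpoint(self):
            self.init_weight.data = self.weight.clone()

        def rewind_weights(self):
            self.weight.data = self.init_weight.clone()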

Once all the required layers have been implemented, it remains to implement the network that CS will sparsify. In models/networks.py you can find code for the ResNet-18, which can be used as a base to implement other networks. In general, your network can inherit from MaskedNet instead of nn.Module, and most of the required functionality will be available immediately. What remains is to use the layers you implemented (the ones with soft-masked parameters) in your network and to pass temp and ticket as additional inputs: temp is the current temperature of CS (assumed to be the attribute model.temp in main.py), while ticket is a boolean that controls whether the parameters' masks are soft (ticket=False) or hard (ticket=True). With ticket=True the mask is binary and the masked parameters are actually sparse. Use ticket=False for training (i.e. sub-network search) and ticket=True once you are done and want to evaluate the sparse sub-network, as sketched below.
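
For illustration, a small fully-connected network built this way might look like the following sketch. It is hypothetical and uses the SoftMaskedLinear example from above; the attribute names mask_modules, temp and ticket are assumptions based on the conventions described here and on the ResNet example, so adjust them if your copy of the code differs.

    import torch.nn.functional as F
    from models.networks import MaskedNet  # base class provided by this repository

    class MaskedMLP(MaskedNet):
        # Hypothetical two-layer MLP, shown only to illustrate how temp and ticket are threaded through.
        def __init__(self, in_dim=784, hidden_dim=300, out_dim=10, mask_initial_value=0.):
            super(MaskedMLP, self).__init__()
            self.fc1 = SoftMaskedLinear(in_dim, hidden_dim, mask_initial_value=mask_initial_value)
            self.fc2 = SoftMaskedLinear(hidden_dim, out_dim, mask_initial_value=mask_initial_value)
            # Layers whose parameters CS should sparsify (assumed attribute name):
            self.mask_modules = [m for m in self.modules() if isinstance(m, SoftMaskedLinear)]
            self.temp = 1        # current CS temperature, updated by the training loop (model.temp in main.py)
            self.ticket = False  # False during sub-network search, True to evaluate the sparse sub-network

        def forward(self, x):
            x = F.relu(self.fc1(x, self.temp, self.ticket))
            return self.fc2(x, self.temp, self.ticket)

During search, the training loop anneals model.temp and keeps model.ticket = False; once training is done, setting model.ticket = True makes every masked layer use its binary mask, so the evaluated sub-network is actually sparse.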

Future plans

We plan to make it considerably easier to apply CS to other layers and networks. We hope to achieve this by offering a function that receives a standard PyTorch Module object and returns another Module with the mask parameters properly created and the forward passes overloaded to use the masked parameters instead.

If there are specific functionalities that would help you in your research or in applying our method in general, feel free to suggest them and we will consider implementing them.

Citation

If you use our method for research purposes, please cite our work:

@article{ssm2019cs,
  author  = {Savarese, Pedro and Silva, Hugo and Maire, Michael},
  title   = {Winning the Lottery with Continuous Sparsification},
  journal = {arXiv:1912.04427},
  year    = {2019}
}