PyTorch implementation for Convolutional Networks with Adaptive Inference Graphs

Overview

Convolutional Networks with Adaptive Inference Graphs (ConvNet-AIG)

This repository contains a PyTorch implementation of the paper Convolutional Networks with Adaptive Inference Graphs presented at ECCV 2018.

The code is based on the PyTorch example for training ResNet on ImageNet.

Table of Contents

  1. Introduction
  2. Usage
  3. Requirements
  4. Target Rate Schedules
  5. Citing
  6. Contact

Introduction

Do convolutional networks really need a fixed feed-forward structure? What if, after identifying the high-level concept of an image, a network could move directly to a layer that can distinguish fine-grained differences? Currently, a network must first execute intermediate layers, sometimes hundreds of them, that specialize in unrelated aspects. Ideally, the more a network already knows about an image, the better it should be at deciding which layer to compute next.

Convolutional networks with adaptive inference graphs (ConvNet-AIG) adaptively define their network topology conditioned on the input image. Following a high-level structure similar to residual networks (ResNets), ConvNet-AIG decides on the fly, for each input image, which layers are needed. In experiments on ImageNet, we show that ConvNet-AIG learns distinct inference graphs for different categories.
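To make the idea of a per-input inference graph concrete, here is a toy sketch in plain Python; it is not code from this repository. Scalars stand in for feature maps, and each residual block has a gate that outputs the probability of executing that block, so a block computes x + F(x) when its gate is open and passes x through unchanged otherwise. As we understand the paper, the training-time decision is a hard sample drawn with the Gumbel-max trick (the paper obtains gradients through a continuous relaxation of it, which this sketch omits). The 0.5 threshold at inference and the helper names gumbel_max_gate and adaptive_forward are assumptions for illustration.

```python
import math
import random


def gumbel_max_gate(p_execute, rng, training=True):
    """Return True if the gated layer should run for this input.

    During training, a hard on/off decision is sampled with the Gumbel-max
    trick: compare log-probability plus Gumbel(0, 1) noise for "on" and "off",
    so the layer runs with probability p_execute. At inference the decision
    is deterministic (assumption: run the layer if p_execute > 0.5).
    """
    if not training:
        return p_execute > 0.5
    # Keep the probability away from 0 and 1 so that log() stays finite.
    p = min(max(p_execute, 1e-6), 1.0 - 1e-6)

    def gumbel_noise():
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        return -math.log(-math.log(u))

    return math.log(p) + gumbel_noise() > math.log(1.0 - p) + gumbel_noise()


def adaptive_forward(x, residual_fns, gate_fns, rng, training=True):
    """Toy ResNet-style forward pass in which each residual branch can be skipped.

    Returns the output and the indices of the executed blocks, i.e. the
    inference graph chosen for this particular input.
    """
    executed = []
    for idx, (residual, gate) in enumerate(zip(residual_fns, gate_fns)):
        # The gate sees the current features, so the graph depends on the input.
        if gumbel_max_gate(gate(x), rng, training):
            x = x + residual(x)  # identity shortcut plus residual branch
            executed.append(idx)
        # Otherwise only the identity shortcut is kept and x is unchanged.
    return x, executed


# Toy usage: scalar "features", constant residual branches, and one gate
# (block 1) whose probability depends on the incoming features.
residual_fns = [lambda v: 1.0, lambda v: 1.0, lambda v: 1.0]
gate_fns = [lambda v: 0.9, lambda v: 0.9 if v < 3 else 0.2, lambda v: 0.7]
rng = random.Random(0)
print(adaptive_forward(2.0, residual_fns, gate_fns, rng, training=False))  # (4.0, [0, 2])
print(adaptive_forward(0.0, residual_fns, gate_fns, rng, training=False))  # (3.0, [0, 1, 2])
```

The two inputs take different paths through the same three blocks, which is the behavior the paragraph above describes at the scale of a full network.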

Usage

There are two training scripts: train.py for CIFAR-10 and train_img.py for ImageNet.

The network can be trained simply with python train.py, or with optional arguments that set different hyperparameters:

$ python train.py --expname {your experiment name}

For ImageNet, the folder containing the dataset must be supplied:

$ python train_img.py --expname {your experiment name} [imagenet-folder with train and val folders]
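Since the code is based on the PyTorch ImageNet example, the dataset folder is presumably expected to follow the torchvision ImageFolder convention, root/train/<class>/<images> and root/val/<class>/<images>; this layout is an assumption, not something this README specifies. A hypothetical helper, check_imagenet_layout, can verify it before a long training run is started:

```python
from pathlib import Path


def check_imagenet_layout(root):
    """Check that root/train and root/val exist and hold the same class folders.

    Assumes the torchvision ImageFolder convention used by the PyTorch
    ImageNet example: root/<split>/<class_name>/<image files>.
    Returns the number of classes.
    """
    root = Path(root)
    class_names = {}
    for split in ("train", "val"):
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Every subdirectory of a split is treated as one class.
        class_names[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())
    if class_names["train"] != class_names["val"]:
        raise ValueError("train and val contain different class folders")
    return len(class_names["train"])
```

For example, check_imagenet_layout("/path/to/imagenet") should return 1000 for a standard ILSVRC-2012 copy and raises an error if the train or val folder is missing.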

Training progress can be easily tracked with visdom using the --visdom flag. It keeps track of the learning rate, loss, training and validation accuracy as well as the activation rates of the gates for each class.
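The per-class gate activation rate is simply the fraction of images of a class for which a given gate was open. The sketch below computes that statistic independently of visdom; it is hypothetical and assumes the hard 0/1 gate decisions of each validation image have been collected together with the image's label (the name per_class_activation_rates is not part of this repository).

```python
def per_class_activation_rates(labels, gate_decisions):
    """Fraction of images of each class for which each gate was open.

    labels: the class index of each image.
    gate_decisions: for each image, a list of 0/1 gate decisions, one per gated layer.
    Returns {class_index: [rate of layer 0, rate of layer 1, ...]}.
    """
    open_counts = {}
    image_counts = {}
    for label, decisions in zip(labels, gate_decisions):
        if label not in open_counts:
            open_counts[label] = [0] * len(decisions)
            image_counts[label] = 0
        # Accumulate, per layer, how often the gate was open for this class.
        open_counts[label] = [c + d for c, d in zip(open_counts[label], decisions)]
        image_counts[label] += 1
    return {
        label: [c / image_counts[label] for c in counts]
        for label, counts in open_counts.items()
    }


rates = per_class_activation_rates([0, 0, 1], [[1, 0, 1], [1, 1, 1], [0, 0, 1]])
print(rates)  # {0: [1.0, 0.5, 1.0], 1: [0.0, 0.0, 1.0]}
```

Plotting one such vector per class is what makes category-specific inference graphs visible.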

By default, the training code keeps track of the model with the highest performance on the validation set. Thus, once the model has converged, it can be evaluated directly on the test set as follows:

$ python train.py --test --resume runs/{your experiment name}/model_best.pth.tar
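The best-model bookkeeping described above can be sketched as follows. This is a hypothetical, dependency-free sketch modeled on the save_checkpoint helper of the PyTorch ImageNet example that this code is based on. Only the runs/{your experiment name}/model_best.pth.tar path comes from this README; the file name checkpoint.pth.tar, the dictionary keys, and the use of pickle instead of torch.save are assumptions.

```python
import os
import pickle
import shutil


def save_checkpoint(state, is_best, run_dir):
    """Save the latest state and keep a copy of the best one as model_best.pth.tar.

    Follows the save_checkpoint pattern of the PyTorch ImageNet example, but
    uses pickle in place of torch.save so that the sketch has no dependencies.
    """
    os.makedirs(run_dir, exist_ok=True)
    latest = os.path.join(run_dir, "checkpoint.pth.tar")
    with open(latest, "wb") as f:
        pickle.dump(state, f)
    if is_best:
        shutil.copyfile(latest, os.path.join(run_dir, "model_best.pth.tar"))


def track_best(val_accuracies, run_dir):
    """Simulate the end-of-epoch bookkeeping for a sequence of validation accuracies."""
    best_acc = float("-inf")
    for epoch, acc in enumerate(val_accuracies):
        is_best = acc > best_acc
        best_acc = max(best_acc, acc)
        # Hypothetical checkpoint contents; a real checkpoint would also hold model weights.
        save_checkpoint({"epoch": epoch, "val_acc": acc, "best_acc": best_acc}, is_best, run_dir)
    return best_acc
```

Because the best copy is only overwritten when validation performance improves, a later, worse epoch (for example, accuracies 50.0, 62.5, 60.0) leaves model_best.pth.tar pointing at the second epoch, which is the file passed to --resume.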

Requirements

This implementation was developed with:

  1. Python 3.6.5
  2. PyTorch 0.3.1
  3. CUDA 9.1

Target Rate Schedules

To improve performance and memory efficiency, the target rates of the early, last and downsampling layers can be fixed rather than tied to the global target rate t, for example to 1 so that these layers are always executed. Specifically, the results in the paper use the following target rate schedule for ResNet 50, for t in [0.4, 0.5, 0.6, 0.7]:

[1, 1, 0.8, 1, t, t, t, 1, t, t, t, t, t, 1, 0.7, 1]

For ResNet 101, the following schedule can be used, for t in [0.3, 0.5]:

[1] * 8 + [t] * 25
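The schedules above can be written out and checked in plain Python: ResNet 50 has 3 + 4 + 6 + 3 = 16 residual blocks and ResNet 101 has 3 + 4 + 23 + 3 = 33, so each schedule provides one rate per block. The sketch also includes a target-rate penalty; it assumes, following our reading of the paper, that each layer's execution rate over a batch is pushed toward its target with a squared error averaged over layers. It uses hard 0/1 decisions for readability, whereas training needs differentiable gate outputs. All function names are hypothetical.

```python
def resnet50_schedule(t):
    """Per-block target rates for ResNet 50 (3 + 4 + 6 + 3 = 16 residual blocks)."""
    return [1, 1, 0.8, 1, t, t, t, 1, t, t, t, t, t, 1, 0.7, 1]


def resnet101_schedule(t):
    """Per-block target rates for ResNet 101 (3 + 4 + 23 + 3 = 33 residual blocks).

    list.extend() returns None, so the two parts are concatenated with + instead.
    """
    return [1] * 8 + [t] * 25


def target_rate_loss(gate_decisions, targets):
    """Mean over layers of the squared gap between execution rate and target rate.

    gate_decisions: for each image in the batch, a list of 0/1 decisions per layer.
    targets: one target rate per layer, e.g. resnet50_schedule(0.5).
    """
    n_images = len(gate_decisions)
    total = 0.0
    for layer, target in enumerate(targets):
        # Fraction of images in the batch for which this layer was executed.
        rate = sum(decisions[layer] for decisions in gate_decisions) / n_images
        total += (rate - target) ** 2
    return total / len(targets)


print(len(resnet50_schedule(0.5)), len(resnet101_schedule(0.3)))  # 16 33
print(target_rate_loss([[1, 1], [0, 1]], [0.5, 0.5]))  # 0.125
```

In the last line, the first layer runs for half of the batch (matching its target of 0.5) while the second runs for all of it, so only the second layer contributes (1.0 - 0.5)^2 = 0.25, averaged over two layers.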

For compatibility with newer versions, please make a pull request.

Citing

If you find this code helpful in your research, please consider citing:

@conference{Veit2018,
title = {Convolutional Networks with Adaptive Inference Graphs},
author = {Andreas Veit and Serge Belongie},
year = {2018},
booktitle = {European Conference on Computer Vision (ECCV)},
}

Contact

andreas at cs dot cornell dot edu
