Code used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks"

Related tags

Deep LearningbigBatch
Overview

Train longer, generalize better - Big batch training

This is a code repository used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks" By Elad Hoffer, Itay Hubara and Daniel Soudry.

It is based off convNet.pytorch with some helpful options such as:

  • Training on several datasets
  • Complete logging of trained experiment
  • Graph visualization of the training/validation loss and accuracy
  • Definition of preprocessing and optimization regime for each model

Dependencies

Data

  • Configure your dataset path at data.py.
  • To get the ILSVRC data, you should register on their site for access: http://www.image-net.org/

Experiment examples

python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_lr_fix --epochs 100 --b 2048 --lr_bb_fix;
python main_normal.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_regime_adaptation --epochs 100 --b 2048 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --dataset cifar10 --model resnet --save cifar10_resnet44_bs2048_ghost_bn256 --epochs 100 --b 2048 --lr_bb_fix --mini-batch-size 256;
python main_normal.py --dataset cifar100 --model resnet --save cifar100_wresnet16_4_bs1024_regime_adaptation --epochs 100 --b 1024 --lr_bb_fix --regime_bb_fix;
python main_gbn.py --model mnist_f1 --dataset mnist --save mnist_baseline_bs4096_gbn --epochs 50 --b 4096 --lr_bb_fix --no-regime_bb_fix --mini-batch-size 128;
  • See run_experiments.sh for more examples

Model configuration

Network model is defined by writing a .py file in models folder, and selecting it using the model flag. Model function must be registered in models/__init__.py The model function must return a trainable network. It can also specify additional training options such optimization regime (either a dictionary or a function), and input transform modifications.

e.g for a model definition:

class Model(nn.Module):

    def __init__(self, num_classes=1000):
        super(Model, self).__init__()
        self.model = nn.Sequential(...)

        self.regime = {
            0: {'optimizer': 'SGD', 'lr': 1e-2,
                'weight_decay': 5e-4, 'momentum': 0.9},
            15: {'lr': 1e-3, 'weight_decay': 0}
        }

        self.input_transform = {
            'train': transforms.Compose([...]),
            'eval': transforms.Compose([...])
        }
    def forward(self, inputs):
        return self.model(inputs)

 def model(**kwargs):
        return Model()
Owner
Elad Hoffer
Elad Hoffer
Image-Stitching - Panorama composition using SIFT Features and a custom implementaion of RANSAC algorithm

About The Project Panorama composition using SIFT Features and a custom implementaion of RANSAC algorithm (Random Sample Consensus). Author: Andreas P

Andreas Panayiotou 3 Jan 03, 2023
Pytorch implementation for the EMNLP 2020 (Findings) paper: Connecting the Dots: A Knowledgeable Path Generator for Commonsense Question Answering

Path-Generator-QA This is a Pytorch implementation for the EMNLP 2020 (Findings) paper: Connecting the Dots: A Knowledgeable Path Generator for Common

Peifeng Wang 33 Dec 05, 2022
Code for WECHSEL: Effective initialization of subword embeddings for cross-lingual transfer of monolingual language models.

WECHSEL Code for WECHSEL: Effective initialization of subword embeddings for cross-lingual transfer of monolingual language models. arXiv: https://arx

Institute of Computational Perception 45 Dec 29, 2022
The most simple and minimalistic navigation dashboard.

Navigation This project follows a goal to have simple and lightweight dashboard with different links. I use it to have my own self-hosted service dash

Yaroslav 23 Dec 23, 2022
Fast and customizable reconnaissance workflow tool based on simple YAML based DSL.

Fast and customizable reconnaissance workflow tool based on simple YAML based DSL, with support of notifications and distributed workload of that work

Américo Júnior 3 Mar 11, 2022
Dynamic Graph Event Detection

DyGED Dynamic Graph Event Detection Get Started pip install -r requirements.txt TODO Paper link to arxiv, and how to cite. Twitter Weather dataset tra

Mert Koşan 3 May 09, 2022
Pose estimation with MoveNet Lightning

Pose Estimation With MoveNet Lightning MoveNet is the TensorFlow pre-trained model that identifies 17 different key points of the human body. It is th

Yash Vora 2 Jan 04, 2022
classification task on dataset-CIFAR10,by using Tensorflow/keras

CIFAR10-Tensorflow classification task on dataset-CIFAR10,by using Tensorflow/keras 在这一个库中,我使用Tensorflow与keras框架搭建了几个卷积神经网络模型,针对CIFAR10数据集进行了训练与测试。分别使

3 Oct 17, 2021
AOT (Associating Objects with Transformers) in PyTorch

An efficient modular implementation of Associating Objects with Transformers for Video Object Segmentation in PyTorch

162 Dec 14, 2022
Generalized Data Weighting via Class-level Gradient Manipulation

Generalized Data Weighting via Class-level Gradient Manipulation This repository is the official implementation of Generalized Data Weighting via Clas

18 Nov 12, 2022
Using OpenAI's CLIP to upscale and enhance images

CLIP Upscaler and Enhancer Using OpenAI's CLIP to upscale and enhance images Based on nshepperd's JAX CLIP Guided Diffusion v2.4 Sample Results Viewpo

Tripp Lyons 5 Jun 14, 2022
Bio-OFC gym implementation and Gym-Fly environment

Bio-OFC gym implementation and Gym-Fly environment This repository includes the gym compatible implementation of the Bio-OFC algorithm from the paper

Siavash Golkar 1 Nov 16, 2021
LSTMs (Long Short Term Memory) RNN for prediction of price trends

Price Prediction with Recurrent Neural Networks LSTMs BTC-USD price prediction with deep learning algorithm. Artificial Neural Networks specifically L

5 Nov 12, 2021
Understanding and Overcoming the Challenges of Efficient Transformer Quantization

Transformer Quantization This repository contains the implementation and experiments for the paper presented in Yelysei Bondarenko1, Markus Nagel1, Ti

83 Dec 30, 2022
Code for the paper "Implicit Representations of Meaning in Neural Language Models"

Implicit Representations of Meaning in Neural Language Models Preliminaries Create and set up a conda environment as follows: conda create -n state-pr

Belinda Li 39 Nov 03, 2022
Learnable Motion Coherence for Correspondence Pruning

Learnable Motion Coherence for Correspondence Pruning Yuan Liu, Lingjie Liu, Cheng Lin, Zhen Dong, Wenping Wang Project Page Any questions or discussi

liuyuan 41 Nov 30, 2022
A Machine Teaching Framework for Scalable Recognition

MEMORABLE This repository contains the source code accompanying our ICCV 2021 paper. A Machine Teaching Framework for Scalable Recognition Pei Wang, N

2 Dec 08, 2021
WormMovementSimulation - 3D Simulation of Worm Body Movement with Neurons attached to its body

Generate 3D Locomotion Data This module is intended to create 2D video trajector

1 Aug 09, 2022
Keras-tensorflow implementation of Fully Convolutional Networks for Semantic Segmentation(Unfinished)

Keras-FCN Fully convolutional networks and semantic segmentation with Keras. Models Models are found in models.py, and include ResNet and DenseNet bas

645 Dec 29, 2022
MADT: Offline Pre-trained Multi-Agent Decision Transformer

MADT: Offline Pre-trained Multi-Agent Decision Transformer A link to our paper can be found on Arxiv. Overview Official codebase for Offline Pre-train

Linghui Meng 51 Dec 21, 2022