PyTorch Personal Trainer: My framework for deep learning experiments

Alex's PyTorch Personal Trainer (ptpt)

(name subject to change)

This repository contains my personal lightweight framework for deep learning projects in PyTorch.

Disclaimer: this project is very much a work in progress. Although technically usable, it is missing many features. Nonetheless, you may find some of the design patterns and code snippets useful in the meantime.

Installation

Simply run python -m build in the root of the repo, then run pip install on the resulting .whl file.

There is no pip package yet.

Usage

Import the library like any other Python library:

from ptpt.trainer import Trainer, TrainerConfig
from ptpt.log import debug, info, warning, error, critical

The core of the library is the trainer.Trainer class. In the simplest case, it takes the following as input:

net:            an `nn.Module` that is the model we wish to train.
loss_fn:        a function that takes an `nn.Module` and a batch as input.
                it returns the loss and optionally other metrics.
train_dataset:  the training dataset.
test_dataset:   the test dataset.
cfg:            a `TrainerConfig` instance that holds all
                hyperparameters.

Once this is instantiated, starting the training loop is as simple as calling trainer.train() where trainer is an instance of Trainer.
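
To make the contract concrete, here is a minimal sketch of the kind of loop that a call like trainer.train() could wrap. It is not ptpt's actual implementation: `run_epoch`, the toy linear model and the random data are all assumptions made for illustration. The only thing taken from the description above is that loss_fn receives the module and a batch, and returns the loss optionally followed by other metrics.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def loss_fn(net, batch):
    # Same contract as described above: (module, batch) -> loss, *metrics
    X, y = batch
    logits = net(X)
    loss = F.cross_entropy(logits, y)
    accuracy = (logits.argmax(dim=-1) == y).float().mean().item()
    return loss, accuracy


def run_epoch(net, loss_fn, loader, optim=None):
    """Hypothetical helper: one pass over a non-empty `loader`.

    Trains when `optim` is given, otherwise only evaluates.
    Returns the mean loss and a list of mean metrics.
    """
    training = optim is not None
    net.train(training)
    total_loss, totals, nb_batches = 0.0, None, 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            out = loss_fn(net, batch)
            # loss_fn may return a bare loss or a tuple (loss, *metrics)
            loss, *metrics = out if isinstance(out, tuple) else (out,)
            if training:
                optim.zero_grad()
                loss.backward()
                optim.step()
            total_loss += loss.item()
            metrics = [float(m) for m in metrics]
            totals = metrics if totals is None else [t + m for t, m in zip(totals, metrics)]
            nb_batches += 1
    return total_loss / nb_batches, [t / nb_batches for t in totals]


# Toy usage on random data (illustrative only)
torch.manual_seed(0)
data = TensorDataset(torch.randn(256, 8), torch.randint(0, 3, (256,)))
loader = DataLoader(data, batch_size=64, shuffle=True)
net = nn.Linear(8, 3)
optim = torch.optim.Adam(net.parameters(), lr=4e-4)

train_loss, (train_acc,) = run_epoch(net, loss_fn, loader, optim)
eval_loss, (eval_acc,) = run_epoch(net, loss_fn, loader)
print(f'train loss {train_loss:.3f}, eval loss {eval_loss:.3f}, eval acc {eval_acc:.3f}')
```

Passing the whole batch to loss_fn, rather than splitting it into inputs and targets inside the loop, is what lets the same loop serve tasks whose batches have different layouts.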

cfg stores most of the configuration options for Trainer. See the class definition of TrainerConfig for details on all options.
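
For readers unfamiliar with this pattern, a configuration object of this kind can be sketched with a standard-library dataclass. The class below is hypothetical (`SketchConfig` is not ptpt's `TrainerConfig`): the field names are taken from the `TrainerConfig` call in the Examples section, while every default value is a guess made for illustration.

```python
from dataclasses import asdict, dataclass, field, replace
from typing import List


@dataclass
class SketchConfig:
    # Field names mirror the README's example; all defaults are assumptions.
    exp_name: str = 'exp'
    batch_size: int = 32
    learning_rate: float = 1e-3
    nb_workers: int = 0
    save_outputs: bool = True
    metric_names: List[str] = field(default_factory=list)


base = SketchConfig(exp_name='mnist-conv', metric_names=['accuracy'])
# Derive a variant (e.g. for a sweep) without mutating the base configuration
variant = replace(base, learning_rate=4e-4, batch_size=64)
# A plain dict is convenient for writing the settings next to the results
print(asdict(variant))
```

Keeping every hyperparameter in one object means an experiment can be reproduced from a single saved dictionary.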

Examples

An example workflow would go like this:

Define your training and test datasets:

from torchvision import datasets, transforms

# convert images to tensors, then normalise with the MNIST mean and std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('../data', train=False, download=True, transform=transform)

Define your model:

# in this case, we have imported `Net` from another file
net = Net()

Define your loss function that calls net, taking the full batch as input:

import torch.nn.functional as F

# minimising classification error
def loss_fn(net, batch):
    X, y = batch
    logits = net(X)
    loss = F.nll_loss(logits, y)

    # percentage of correct predictions in the batch
    pred = logits.argmax(dim=-1, keepdim=True)
    accuracy = 100. * pred.eq(y.view_as(pred)).sum().item() / y.shape[0]
    return loss, accuracy
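
One detail worth spelling out: F.nll_loss expects log-probabilities, so the loss function above only computes a proper classification loss if Net ends in a log_softmax. The README does not show Net, so that is an assumption. The sketch below uses random tensors to check the standard identity between the two formulations; if your model returns raw scores instead, F.cross_entropy is the equivalent call.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
raw = torch.randn(16, 10)          # unnormalised scores: 16 samples, 10 classes
y = torch.randint(0, 10, (16,))    # integer class labels

# nll_loss applied to log-probabilities ...
via_nll = F.nll_loss(F.log_softmax(raw, dim=-1), y)
# ... agrees with cross_entropy applied to the raw scores
via_ce = F.cross_entropy(raw, y)

print(torch.allclose(via_nll, via_ce))
```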

Optionally create a configuration object:

# see class definition for full list of parameters
cfg = TrainerConfig(
    exp_name = 'mnist-conv',
    batch_size = 64,
    learning_rate = 4e-4,
    nb_workers = 4,
    save_outputs = False,
    metric_names = ['accuracy']
)

Initialise the Trainer class:

trainer = Trainer(
    net=net,
    loss_fn=loss_fn,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    cfg=cfg
)

Call trainer.train() to begin the training loop:

trainer.train() # Go!

See more examples here.

Motivation

I found myself repeating a lot of the same structure in many of my deep learning projects. This project is the culmination of my efforts to refine the typical structure of my projects into (what I hope will be) a wholly reusable and general-purpose library.

Additionally, there are many nice theoretical and engineering tricks available to deep learning researchers. Unfortunately, many of them are forgotten because they fall outside the typical workflow, even though they are very beneficial to include. Another goal of this project is to include these tricks transparently, so they can be added and removed with minimal code changes. Where it is sane to do so, some of these could be on by default.

Finally, I am guilty of forgetting to implement decent logging: both of standard output and of metrics. Logging of standard output is not hard, and is implemented using other libraries such as rich. However, metric logging is less obvious. I'd like to avoid larger dependencies such as tensorboard being an integral part of the project, so metrics will be logged to simple numpy arrays. The library will then provide functions to produce plots from these, or they can be used in another library.
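
As a rough illustration of that idea, the sketch below accumulates one row of metrics per epoch and exposes them as a numpy array, which can then be saved with np.save or handed to any plotting library. `MetricLog` and its methods are hypothetical names chosen for this example, not part of ptpt.

```python
import numpy as np


class MetricLog:
    """Hypothetical metric store: one row per epoch, one column per metric."""

    def __init__(self, metric_names):
        # The loss is always logged first, followed by the named metrics
        self.names = ['loss'] + list(metric_names)
        self.rows = []

    def append(self, loss, *metrics):
        if len(metrics) != len(self.names) - 1:
            raise ValueError(f'expected {len(self.names) - 1} metrics, got {len(metrics)}')
        self.rows.append([float(loss), *map(float, metrics)])

    def as_array(self):
        # Shape (nb_epochs, nb_columns); an empty log gives shape (0, nb_columns)
        return np.asarray(self.rows, dtype=np.float64).reshape(-1, len(self.names))

    def column(self, name):
        # All logged values of a single metric, in epoch order
        return self.as_array()[:, self.names.index(name)]


log = MetricLog(['accuracy'])
log.append(0.91, 71.5)
log.append(0.47, 86.0)
print(log.as_array().shape)
print(log.column('accuracy'))
```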

TODO:

  • Make a todo.

Owner
Alex McKinney