A best-practice TensorFlow project template architecture.

Overview

Tensorflow Project Template

A simple and well-designed structure is essential for any deep learning project. After a lot of practice and contributing to TensorFlow projects, here's a TensorFlow project template that combines simplicity, a best-practice folder structure, and good OOP design. The main idea is that you repeat a lot of the same work every time you start a TensorFlow project, so wrapping all this shared boilerplate lets you change just the core idea every time you start a new project.

So, here's a simple TensorFlow template that helps you get into your main project faster and focus on your core work (model, training, etc.).


In a Nutshell

In a nutshell, here's how to use this template. For example, assume you want to implement a VGG model; then you should do the following:

  • In the models folder, create a class named VGGModel that inherits from the BaseModel class:
    class VGGModel(BaseModel):
        def __init__(self, config):
            super(VGGModel, self).__init__(config)
            # call the build_model and init_saver functions
            self.build_model()
            self.init_saver()
  • Override these two functions: "build_model", where you implement the VGG model, and "init_saver", where you define a tensorflow saver; then call them in the initializer.
    def build_model(self):
        # here you build the tensorflow graph of any model you want and also define the loss
        pass

    def init_saver(self):
        # here you initialize the tensorflow saver that will be used for saving the checkpoints
        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)
  • In the trainers folder, create a VGGTrainer class that inherits from the BaseTrain class:
    class VGGTrainer(BaseTrain):
        def __init__(self, sess, model, data, config, logger):
            super(VGGTrainer, self).__init__(sess, model, data, config, logger)
  • Override these two functions, "train_step" and "train_epoch", where you write the logic of the training process:
    def train_epoch(self):
        """
        Implement the logic of an epoch:
        - loop over the number of iterations in the config and call train_step
        - add any summaries you want using the logger
        """
        pass

    def train_step(self):
        """
        Implement the logic of the train step:
        - run the tensorflow session
        - return any metrics you need to summarize
        """
        pass
  • In the main file, create the session and instances of the following objects: "Model", "Logger", "Data_Generator", "Trainer", and the config:
    # parse the config file into a config object (the path is just an example)
    config = process_config("configs/example.json")

    sess = tf.Session()
    # create an instance of the model you want
    model = VGGModel(config)
    # create your data generator
    data = DataGenerator(config)
    # create a tensorboard logger
    logger = Logger(sess, config)
  • Pass all these objects to the trainer object, and start your training by calling "trainer.train()":
    trainer = VGGTrainer(sess, model, data, config, logger)

    # here you train your model
    trainer.train()

You will find a template file and a simple example in the model and trainer folders that show you how to try your first model simply.

In Details

Project architecture

Folder structure

├── base
│   ├── base_model.py       - this file contains the abstract class of the model.
│   └── base_train.py       - this file contains the abstract class of the trainer.
│
├── model                   - this folder contains any model of your project.
│   └── example_model.py
│
├── trainer                 - this folder contains the trainers of your project.
│   └── example_trainer.py
│
├── mains                   - here are the mains of your project (you may need more than one main).
│   └── example_main.py     - here's an example main that is responsible for the whole pipeline.
│
├── data_loader
│   └── data_generator.py   - here's the data generator that is responsible for all data handling.
│
└── utils
    ├── logger.py
    └── any_other_utils_you_need

Main Components

Models


  • Base model

    Base model is an abstract class that must be inherited by any model you create. The idea behind this is that there's a lot of shared logic between all models. The base model contains (see the sketch after this list):

    • Save - this function saves a checkpoint to disk.
    • Load - this function loads a checkpoint from disk.
    • Cur_epoch and Global_step counters - these variables keep track of the current epoch and the global step.
    • Init_saver - an abstract function that initializes the saver used for saving and loading checkpoints. Note: override this function in the model you want to implement.
    • Build_model - an abstract function that defines the model. Note: override this function in the model you want to implement.
  • Your model

    Here's where you implement your model. You should:

    • Create your model class and inherit from the base_model class
    • Override "build_model", where you write the tensorflow model you want
    • Override "init_saver", where you create a tensorflow saver used to save and load checkpoints
    • Call "build_model" and "init_saver" in the initializer.
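For reference, here's a minimal sketch of such a base model (a sketch of the idea, using the attribute names described above; the actual base/base_model.py may differ in details):

    import tensorflow as tf

    class BaseModel:
        def __init__(self, config):
            self.config = config
            # counters shared by all models: global step and current epoch
            self.init_global_step()
            self.init_cur_epoch()

        def save(self, sess):
            # save the current checkpoint to disk
            self.saver.save(sess, self.config.checkpoint_dir, self.global_step_tensor)

        def load(self, sess):
            # load the latest checkpoint from disk, if one exists
            latest_checkpoint = tf.train.latest_checkpoint(self.config.checkpoint_dir)
            if latest_checkpoint:
                self.saver.restore(sess, latest_checkpoint)

        def init_cur_epoch(self):
            # variable (plus an increment op) tracking the current epoch
            with tf.variable_scope('cur_epoch'):
                self.cur_epoch_tensor = tf.Variable(0, trainable=False, name='cur_epoch')
                self.increment_cur_epoch_tensor = tf.assign(
                    self.cur_epoch_tensor, self.cur_epoch_tensor + 1)

        def init_global_step(self):
            # variable tracking the global training step
            with tf.variable_scope('global_step'):
                self.global_step_tensor = tf.Variable(0, trainable=False, name='global_step')

        def init_saver(self):
            # override this: create self.saver = tf.train.Saver(...)
            raise NotImplementedError

        def build_model(self):
            # override this: define the graph and the loss
            raise NotImplementedError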

Trainer


  • Base trainer

    Base trainer is an abstract class that wraps the training process; see the sketch after this list.

  • Your trainer

    Here's what you should implement in your trainer.

    1. Create your trainer class and inherit from the base_train class.
    2. Override the "train_step" and "train_epoch" functions, where you implement the training process of each step and each epoch.
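For reference, a minimal sketch of the base trainer (a sketch, assuming the config exposes a num_epochs value and the model provides the counters from the base model sketch above; the actual base/base_train.py may differ):

    import tensorflow as tf

    class BaseTrain:
        def __init__(self, sess, model, data, config, logger):
            self.sess = sess
            self.model = model
            self.data = data
            self.config = config
            self.logger = logger
            # initialize all graph variables before training starts
            self.sess.run(tf.global_variables_initializer())

        def train(self):
            # loop over epochs, resuming from the model's saved epoch counter
            for _ in range(self.model.cur_epoch_tensor.eval(self.sess),
                           self.config.num_epochs + 1):
                self.train_epoch()
                self.sess.run(self.model.increment_cur_epoch_tensor)

        def train_epoch(self):
            # override this: run the iterations of one epoch
            raise NotImplementedError

        def train_step(self):
            # override this: run one training step and return its metrics
            raise NotImplementedError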

Data Loader

This class is responsible for all data handling and processing, and provides an easy interface that can be used by the trainer.
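A minimal sketch of the interface the trainer expects (the arrays of ones are placeholders for your real data, and the shapes are illustrative):

    import numpy as np

    class DataGenerator:
        def __init__(self, config):
            self.config = config
            # load your data here; these arrays are just placeholders
            self.input = np.ones((500, 784))
            self.y = np.ones((500, 10))

        def next_batch(self, batch_size):
            # yield a random batch of (inputs, labels)
            idx = np.random.choice(500, batch_size)
            yield self.input[idx], self.y[idx]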

Logger

This class is responsible for the tensorboard summaries. In your trainer, create a dictionary of all the tensorflow variables you want to summarize, then pass this dictionary to logger.summarize().
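For example, a train_epoch might average the metrics returned by train_step and hand them to the logger like this (a sketch; num_iter_per_epoch is an assumed config key, and numpy is assumed to be imported as np):

    # inside your trainer's train_epoch
    losses, accs = [], []
    for _ in range(self.config.num_iter_per_epoch):
        loss, acc = self.train_step()
        losses.append(loss)
        accs.append(acc)

    cur_it = self.model.global_step_tensor.eval(self.sess)
    summaries_dict = {
        'loss': np.mean(losses),
        'acc': np.mean(accs),
    }
    self.logger.summarize(cur_it, summaries_dict=summaries_dict)
    self.model.save(self.sess)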


Comet.ml Integration

This template also supports reporting to Comet.ml, which allows you to see all your hyperparameters, metrics, graphs, dependencies and more, including real-time metrics.

Add your API key in the configuration file:

For example: "comet_api_key": "your key here"


You can also link your GitHub repository to your Comet.ml project for full version control. There's a live page on Comet.ml showing the example from this repo.

Configuration

I use JSON as the configuration format. Write all the configs you want, then parse the file using "utils/config/process_config" and pass the resulting configuration object to all the other objects.
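For example, a config file might look like this (all the keys shown are illustrative; add whatever your model and trainer read from the config object):

    {
      "exp_name": "example",
      "num_epochs": 10,
      "num_iter_per_epoch": 10,
      "learning_rate": 0.001,
      "batch_size": 16,
      "max_to_keep": 5
    }

which you would then parse at the top of your main (the path is just an example):

    from utils.config import process_config

    config = process_config("configs/example.json")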

Main

Here's where you combine all the previous parts.

  1. Parse the config file.
  2. Create a tensorflow session.
  3. Create an instance of "Model", "Data_Generator" and "Logger" and pass the config to all of them.
  4. Create an instance of "Trainer" and pass all previous objects to it.
  5. Now you can train your model by calling "trainer.train()". (A full example follows this list.)
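Putting the five steps together, a minimal main might look like this (a sketch; the module paths follow the folder structure above, and the config path is just an example):

    import tensorflow as tf

    from data_loader.data_generator import DataGenerator
    from model.example_model import ExampleModel
    from trainer.example_trainer import ExampleTrainer
    from utils.config import process_config
    from utils.logger import Logger


    def main():
        # 1. parse the config file
        config = process_config("configs/example.json")

        # 2. create a tensorflow session
        sess = tf.Session()

        # 3. create the model, data generator and logger, passing the config to each
        model = ExampleModel(config)
        data = DataGenerator(config)
        logger = Logger(sess, config)

        # 4. create the trainer and pass all the previous objects to it
        trainer = ExampleTrainer(sess, model, data, config, logger)

        # 5. start training
        trainer.train()


    if __name__ == '__main__':
        main()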

Future Work

  • Replace the data loader part with the new tensorflow dataset API (see the sketch below).
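For reference, a tf.data-based replacement might look roughly like this (a sketch with placeholder arrays; the batch size and shapes are illustrative):

    import numpy as np
    import tensorflow as tf

    # placeholder data; swap in your real features and labels
    features = np.ones((500, 784), dtype=np.float32)
    labels = np.ones((500, 10), dtype=np.float32)

    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.shuffle(buffer_size=500).batch(16).repeat()
    iterator = dataset.make_one_shot_iterator()
    next_x, next_y = iterator.get_next()  # run these tensors inside train_step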

Contributing

Any kind of enhancement or contribution is welcomed.

Acknowledgments

Thanks to my colleague Mo'men Abdelrazek for contributing to this work, and thanks to Mohamed Zahran for the review. Thanks to Jtoy for including the repo in Awesome Tensorflow.
