Refactoring dalle-pytorch and taming-transformers for TPU VM

Overview

Text-to-Image Translation (DALL-E) for TPU in Pytorch

Refactoring Taming Transformers and DALLE-pytorch for TPU VM with Pytorch Lightning

Requirements

pip install -r requirements.txt

Data Preparation

Place any image dataset with ImageNet-style directory structure (at least 1 subfolder) to fit the dataset into pytorch ImageFolder.

Training VQVAEs

You can easily test main.py with randomly generated fake data.

python train_vae.py --use_tpus --fake_data

For actual training provide specific directory for train_dir, val_dir, log_dir:

python train_vae.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results]

Training DALL-E

python train_dalle.py --use_tpus --train_dir [training_set] --val_dir [val_set] --log_dir [where to save results] --vae_path [pretrained vae] --bpe_path [pretrained bpe(optional)]

TODO

  • Refactor Encoder and Decoder modules for better readability
  • Refactor VQVAE2
  • Add Net2Net Conditional Transformer for conditional image generation
  • Refactor, optimize, and merge DALL-E with Net2Net Conditional Transformer
  • Add Guided Diffusion + CLIP for image refinement
  • Add VAE converter for JAX to support dalle-mini
  • Add DALL-E colab notebook
  • Add RBGumbelQuantizer
  • Add HiT

ON-GOING

  • Test large dataset loading on TPU Pods
  • Change current DALL-E code to fully support latest updates from DALLE-pytorch

DONE

  • Add VQVAE, VQGAN, and Gumbel VQVAE(Discrete VAE), Gumbel VQGAN
  • Add VQVAE2
  • Add EMA update for Vector Quantization
  • Debug VAEs (Single TPU Node, TPU Pods, GPUs)
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3028
  • Add DALL-E
  • Debug DALL-E (Single TPU Node, TPU Pods, GPUs)
  • Add WebDataset support
  • Add VAE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add DALLE Image Logger by modifying pl_bolts TensorboardGenerativeModelImageSampler()
  • Add automatic checkpoint saver and resume for sudden (which happens a lot) TPU restart
  • Reimplement EMA VectorQuantizer with nn.Embedding
  • Add DALL-E colab notebook by afiaka87
  • Add Normed Vector Quantizer by GallagherCommaJack
  • Resolve SIGSEGV issue with large TPU Pods pytorch-xla #3068
  • Debug WebDataset functionality

BibTeX

@misc{oord2018neural,
      title={Neural Discrete Representation Learning}, 
      author={Aaron van den Oord and Oriol Vinyals and Koray Kavukcuoglu},
      year={2018},
      eprint={1711.00937},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{razavi2019generating,
      title={Generating Diverse High-Fidelity Images with VQ-VAE-2}, 
      author={Ali Razavi and Aaron van den Oord and Oriol Vinyals},
      year={2019},
      eprint={1906.00446},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
@misc{esser2020taming,
      title={Taming Transformers for High-Resolution Image Synthesis}, 
      author={Patrick Esser and Robin Rombach and Björn Ommer},
      year={2020},
      eprint={2012.09841},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
@misc{ramesh2021zeroshot,
    title   = {Zero-Shot Text-to-Image Generation}, 
    author  = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
    year    = {2021},
    eprint  = {2102.12092},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
Owner
Kim, Taehoon
Research Scientist & Machine Learning Engineer.
Kim, Taehoon
PyTorch implemention of ICCV'21 paper SGPA: Structure-Guided Prior Adaptation for Category-Level 6D Object Pose Estimation

SGPA: Structure-Guided Prior Adaptation for Category-Level 6D Object Pose Estimation This is the PyTorch implemention of ICCV'21 paper SGPA: Structure

Chen Kai 24 Dec 05, 2022
A simple code to perform canny edge contrast detection on images.

CECED-Canny-Edge-Contrast-Enhanced-Detection A simple code to perform canny edge contrast detection on images. A simple code to process images using c

Happy N. Monday 3 Feb 15, 2022
Project dự đoán giá cổ phiếu bằng thuật toán LSTM gồm: code train và code demo

Web predicts stock prices using Long - Short Term Memory algorithm Give me some start please!!! User interface image: Choose: DayBegin, DayEnd, Stock

Vo Thuong Truong Nhon 8 Nov 11, 2022
Monk is a low code Deep Learning tool and a unified wrapper for Computer Vision.

Monk - A computer vision toolkit for everyone Why use Monk Issue: Want to begin learning computer vision Solution: Start with Monk's hands-on study ro

Tessellate Imaging 507 Dec 04, 2022
[NeurIPS 2021] ORL: Unsupervised Object-Level Representation Learning from Scene Images

Unsupervised Object-Level Representation Learning from Scene Images This repository contains the official PyTorch implementation of the ORL algorithm

Jiahao Xie 55 Dec 03, 2022
Codes for building and training the neural network model described in Domain-informed neural networks for interaction localization within astroparticle experiments.

Domain-informed Neural Networks Codes for building and training the neural network model described in Domain-informed neural networks for interaction

DIDACTS 0 Dec 13, 2021
InferPy: Deep Probabilistic Modeling with Tensorflow Made Easy

InferPy: Deep Probabilistic Modeling Made Easy InferPy is a high-level API for probabilistic modeling written in Python and capable of running on top

PGM-Lab 141 Oct 13, 2022
TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

FunMatch-Distillation TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A g

Sayak Paul 67 Dec 20, 2022
Face recognize and crop them

Face Recognize Cropping Module Source 아이디어 Face Alignment with OpenCV and Python Requirement 필요 라이브러리 imutil dlib python-opence (cv2) Usage 사용 방법 open

Cho Moon Gi 1 Feb 15, 2022
Where2Act: From Pixels to Actions for Articulated 3D Objects

Where2Act: From Pixels to Actions for Articulated 3D Objects The Proposed Where2Act Task. Given as input an articulated 3D object, we learn to propose

Kaichun Mo 69 Nov 28, 2022
A framework for Quantification written in Python

QuaPy QuaPy is an open source framework for quantification (a.k.a. supervised prevalence estimation, or learning to quantify) written in Python. QuaPy

41 Dec 14, 2022
💛 Code and Dataset for our EMNLP 2021 paper: "Perspective-taking and Pragmatics for Generating Empathetic Responses Focused on Emotion Causes"

Perspective-taking and Pragmatics for Generating Empathetic Responses Focused on Emotion Causes Official PyTorch implementation and EmoCause evaluatio

Hyunwoo Kim 51 Jan 06, 2023
A resource for learning about ML, DL, PyTorch and TensorFlow. Feedback always appreciated :)

A resource for learning about ML, DL, PyTorch and TensorFlow. Feedback always appreciated :)

Aladdin Persson 4.7k Jan 08, 2023
Planning from Pixels in Environments with Combinatorially Hard Search Spaces -- NeurIPS 2021

PPGS: Planning from Pixels in Environments with Combinatorially Hard Search Spaces Environment Setup We recommend pipenv for creating and managing vir

Autonomous Learning Group 11 Jun 26, 2022
ruptures: change point detection in Python

Welcome to ruptures ruptures is a Python library for off-line change point detection. This package provides methods for the analysis and segmentation

Charles T. 1.1k Jan 03, 2023
A Pytorch Implementation of [Source data‐free domain adaptation of object detector through domain

A Pytorch Implementation of Source data‐free domain adaptation of object detector through domain‐specific perturbation Please follow Faster R-CNN and

1 Dec 25, 2021
Animation of solving the traveling salesman problem to optimality using mixed-integer programming and iteratively eliminating sub tours

tsp-streamlit Animation of solving the traveling salesman problem to optimality using mixed-integer programming and iteratively eliminating sub tours.

4 Nov 05, 2022
Speckle-free Holography with Partially Coherent Light Sources and Camera-in-the-loop Calibration

Speckle-free Holography with Partially Coherent Light Sources and Camera-in-the-loop Calibration Project Page | Paper Yifan Peng*, Suyeon Choi*, Jongh

Stanford Computational Imaging Lab 19 Dec 11, 2022
Code for "Learning Skeletal Graph Neural Networks for Hard 3D Pose Estimation" ICCV'21

Skeletal-GNN Code for "Learning Skeletal Graph Neural Networks for Hard 3D Pose Estimation" ICCV'21 Various deep learning techniques have been propose

37 Oct 23, 2022
Official implementation for paper: A Latent Transformer for Disentangled Face Editing in Images and Videos.

A Latent Transformer for Disentangled Face Editing in Images and Videos Official implementation for paper: A Latent Transformer for Disentangled Face

InterDigital 108 Dec 09, 2022