Unofficial JAX implementations of Deep Learning models

JAX Models

Table of Contents
  1. About The Project
  2. Getting Started
  3. Contributing
  4. License
  5. Contact

About The Project

The JAX Models repository aims to provide open-source JAX/Flax implementations of research papers that were released either without code or with code written in frameworks other than JAX. The goal of this project is to build a collection of the models, layers, activations and other utilities most commonly used in research. All papers and any derived or translated code are cited in either the README or the docstrings. If you think a citation is missing, please raise an issue.

All implementations provided here are available on Papers With Code.


Available model implementations for JAX are:
  1. MetaFormer is Actually What You Need for Vision (Weihao Yu et al., 2021)
  2. Augmenting Convolutional networks with attention-based aggregation (Hugo Touvron et al., 2021)
  3. MPViT: Multi-Path Vision Transformer for Dense Prediction (Youngwan Lee et al., 2021)
  4. MLP-Mixer: An all-MLP Architecture for Vision (Ilya Tolstikhin et al., 2021)
  5. Patches Are All You Need (Anonymous et al., 2021)
  6. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers (Enze Xie et al., 2021)
  7. A ConvNet for the 2020s (Zhuang Liu et al., 2021)
  8. Masked Autoencoders Are Scalable Vision Learners (Kaiming He et al., 2021)

Available layers for out-of-the-box integration:
  1. DropPath (Stochastic Depth) (Gao Huang et al., 2021)
  2. Squeeze-and-Excitation Layer (Jie Hu et al., 2019)
  3. Depthwise Convolution (François Chollet, 2017)
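
The repository's DropPath is a Flax layer whose exact signature is not shown here. As a hedged illustration of what stochastic depth does, the sketch below is a pure-JAX function; the name `drop_path` and its arguments are assumptions, not the library's API. It zeroes an entire sample of a residual branch with probability `rate` and rescales the survivors so the expected output matches inference.

```python
import jax
import jax.numpy as jnp


def drop_path(x, rate, key, deterministic=False):
    """Hypothetical stochastic-depth helper (not the jax_models API).

    Drops whole samples (leading batch axis) with probability `rate`.
    """
    if deterministic or rate == 0.0:
        # Inference: the residual branch is always kept, unscaled.
        return x
    keep_prob = 1.0 - rate
    # One Bernoulli draw per sample, broadcast over all remaining axes.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = jax.random.bernoulli(key, p=keep_prob, shape=mask_shape)
    # Rescale kept samples so the expectation equals the deterministic output.
    return jnp.where(mask, x / keep_prob, 0.0)


y = drop_path(jnp.ones((8, 4, 4, 3)), rate=0.5, key=jax.random.PRNGKey(0))
```

Drawing the mask per sample rather than per element is what distinguishes DropPath from ordinary dropout: a dropped sample skips the whole branch and only the identity path survives.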

Getting Started

Prerequisites

Prerequisites can be installed separately through the requirements.txt file in the main directory using:

pip install -r requirements.txt

The use of a virtual environment is highly recommended to avoid version incompatibilities.

Installation

This project is built with Python 3 for the latest JAX/Flax versions and can be directly installed via pip.

pip install jax-models

If you wish to use the latest version, you can also clone the repository directly:

git clone https://github.com/DarshanDeshpande/jax-models.git

Usage

To see all model architectures available:

from jax_models.models.model_registry import list_models
from pprint import pprint

pprint(list_models())

To load your desired model:

from jax_models.models.model_registry import load_model
model = load_model('mpvit-base', attach_head=True, num_classes=1000, dropout=0.1)

Contributing

Please raise an issue if any implementation gives incorrect results, crashes unexpectedly during training/inference or if any citation is missing.

You can also contribute to jax_models by supporting me with compute resources, or by using your own resources to provide pretrained weights.

If you wish to donate to this initiative, please drop me an email.

License

Distributed under the Apache 2.0 License. See LICENSE for more information.

Contact

Feel free to reach out for any issues or requests related to these implementations.

Darshan Deshpande - Email | Twitter | LinkedIn

Comments
  • Missing Axis Swap in ExtractPatches and MergePatches

    In patch_utils.py, the modules ExtractPatches and MergePatches are missing an axis swap between the reshapes, resulting in the extracted patches becoming horizontal stripes. For example, if we follow the code in ExtractPatches:

    >>> inputs = jnp.arange(16).reshape(1, 4, 4, 1)
    >>> inputs[0, :, :, 0]
    
    DeviceArray([[ 0,  1,  2,  3],
                 [ 4,  5,  6,  7],
                 [ 8,  9, 10, 11],
                 [12, 13, 14, 15]], dtype=int32)
    
    >>> patch_size = 2
    >>> batch, height, width, channels = inputs.shape
    >>> height, width = height // patch_size, width // patch_size
    >>> x = jnp.reshape(inputs, (batch, height, patch_size, width, patch_size, channels))
    >>> x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    >>> x[0, 0, :]
    
    DeviceArray([0, 1, 2, 3], dtype=int32)
    

    We see that the first patch extracted is not the patch containing [0, 1, 4, 5], but the horizontal stripe [0, 1, 2, 3]. To fix this problem, we should add an axis swap. For ExtractPatches, this should be:

    batch, height, width, channels = inputs.shape
    height, width = height // patch_size, width // patch_size
    x = jnp.reshape(
        inputs, (batch, height, patch_size, width, patch_size, channels)
    )
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * width, patch_size ** 2 * channels))
    

    For MergePatches, this should be:

    batch, length, _ = inputs.shape
    height = width = int(length**0.5)
    x = jnp.reshape(inputs, (batch, height, width, patch_size, patch_size, -1))
    x = jnp.swapaxes(x, 2, 3)
    x = jnp.reshape(x, (batch, height * patch_size, width * patch_size, -1))
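
A self-contained sketch of both fixes as plain functions, handy for checking the patch order and the round trip. The function names `extract_patches` and `merge_patches` are hypothetical stand-ins for the two Flax modules, and a square patch grid is assumed, as in the MergePatches snippet above.

```python
import jax.numpy as jnp


def extract_patches(inputs, patch_size):
    """(B, H, W, C) -> (B, num_patches, patch_size**2 * C), row-major patch order."""
    batch, height, width, channels = inputs.shape
    h, w = height // patch_size, width // patch_size
    x = jnp.reshape(inputs, (batch, h, patch_size, w, patch_size, channels))
    # Bring the patch-grid axes (h, w) together before flattening each patch.
    x = jnp.swapaxes(x, 2, 3)  # (B, h, w, p, p, C)
    return jnp.reshape(x, (batch, h * w, patch_size ** 2 * channels))


def merge_patches(patches, patch_size):
    """Inverse of extract_patches for a square grid of patches."""
    batch, length, _ = patches.shape
    h = w = int(round(length ** 0.5))
    x = jnp.reshape(patches, (batch, h, w, patch_size, patch_size, -1))
    # Interleave each patch's rows/cols back into the image grid.
    x = jnp.swapaxes(x, 2, 3)  # (B, h, p, w, p, C)
    return jnp.reshape(x, (batch, h * patch_size, w * patch_size, -1))


inputs = jnp.arange(16).reshape(1, 4, 4, 1)
patches = extract_patches(inputs, patch_size=2)
print(patches[0, 0])  # -> [0 1 4 5], the top-left 2x2 block
```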
    
    Opened by young-geng (label: bug).
  • fix convnext to make it work with jax.jit

    Hey, first of all, thanks for the nice codebase. When doing inference using the convnext model, I noticed the following issue:

    Calling x.item() calls float(x), which breaks the jit tracer. Removing the list comprehension and the unnecessary conversion makes jax.jit work. Without jax.jit, the model is very slow for me, running at only ~30% GPU utilization (RTX 3090).

    This issue could apply to other models as well; maybe it would be a good idea to include a test that applies jax.jit to each model?
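
A hypothetical illustration of the failure mode, not the repository's ConvNeXt code: a per-block rate schedule built with `jnp.linspace` is a tracer under `jax.jit`, so converting its elements with `float()` (which is what `.item()` does) fails, while indexing the array keeps everything traceable. `NUM_BLOCKS`, `DROP_PATH_RATE` and both functions are invented for the example.

```python
import jax
import jax.numpy as jnp

NUM_BLOCKS = 4        # hypothetical number of residual blocks
DROP_PATH_RATE = 0.1  # hypothetical maximum rate in the schedule


def forward_broken(x):
    # Under jax.jit, jnp.linspace yields a tracer; float() needs a concrete
    # value and raises a ConcretizationTypeError. Works fine without jit.
    rates = [float(r) for r in jnp.linspace(0.0, DROP_PATH_RATE, NUM_BLOCKS)]
    for rate in rates:
        x = x * (1.0 - rate)
    return x


def forward_fixed(x):
    # Keep the schedule as an array and index into it; no host conversion.
    rates = jnp.linspace(0.0, DROP_PATH_RATE, NUM_BLOCKS)
    for i in range(NUM_BLOCKS):
        x = x * (1.0 - rates[i])
    return x


y = jax.jit(forward_fixed)(jnp.ones(3))
```

Calling `jax.jit` on each model's apply function in a test, as suggested above, would catch this class of bug early.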

    Opened by maxidl.