Convert BART models to ONNX with quantization. 3X reduction in size, and upto 3X boost in inference speed

Overview

fast-Bart

Reduction of BART model size by 3X, and boost in inference speed up to 3X

BART implementation of the fastT5 library (https://github.com/Ki6an/fastT5)

Pytorch model -> ONNX model -> Quantized ONNX model


Install

Install using requirements.txt file

git clone https://github.com/siddharth-sharma7/fast-Bart
cd fast-Bart
pip install -r requirements.txt

Usage

The export_and_get_onnx_model() method exports the given pretrained Bart model to onnx, quantizes it and runs it on the onnxruntime with default settings. The returned model from this method supports the generate() method of huggingface.

If you don't wish to quantize the model then use quantized=False in the method.

from fastBart import export_and_get_onnx_model
from transformers import AutoTokenizer

model_name = 'facebook/bart-base'
model = export_and_get_onnx_model(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input = "This is a very long sentence and needs to be summarized."
token = tokenizer(input, return_tensors='pt')

tokens = model.generate(input_ids=token['input_ids'],
               attention_mask=token['attention_mask'],
               num_beams=3)

output = tokenizer.decode(tokens.squeeze(), skip_special_tokens=True)
print(output)

to run the already exported model use get_onnx_model()

you can customize the whole pipeline as shown in the below code example:

from fastBart import (OnnxBart, get_onnx_runtime_sessions,
                    generate_onnx_representation, quantize)
from transformers import AutoTokenizer

model_or_model_path = 'facebook/bart-base'

# Step 1. convert huggingfaces bart model to onnx
onnx_model_paths = generate_onnx_representation(model_or_model_path)

# Step 2. (recommended) quantize the converted model for fast inference and to reduce model size.
# The process is slow for the decoder and init-decoder onnx files (can take up to 15 mins)
quant_model_paths = quantize(onnx_model_paths)

# step 3. setup onnx runtime
model_sessions = get_onnx_runtime_sessions(quant_model_paths)

# step 4. get the onnx model
model = OnnxBart(model_or_model_path, model_sessions)

                      ...
custom output paths

By default, fastBart creates a models-bart folder in the current directory and stores all the models. You can provide a custom path for a folder to store the exported models. And to run already exported models that are stored in a custom folder path: use get_onnx_model(onnx_models_path="/path/to/custom/folder/")

from fastBart import export_and_get_onnx_model, get_onnx_model

model_name = "facebook/bart-base"
custom_output_path = "/path/to/custom/folder/"

# 1. stores models to custom_output_path
model = export_and_get_onnx_model(model_name, custom_output_path)

# 2. run already exported models that are stored in custom path
# model = get_onnx_model(model_name, custom_output_path)

Functionalities

  • Export any pretrained Bart model to ONNX easily.
  • The exported model supports beam search and greedy search and more via generate() method.
  • Reduce the model size by 3X using quantization.
  • Up to 3X speedup compared to PyTorch execution for greedy search and 2-3X for beam search.
Owner
Siddharth Sharma
Machine learning | NLP | Computer Vision
Siddharth Sharma
SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals

SE-MSCNN: A Lightweight Multi-scaled Fusion Network for Sleep Apnea Detection Using Single-Lead ECG Signals Abstract Sleep apnea (SA) is a common slee

9 Dec 21, 2022
Simple-Neural-Network From Scratch in Python

Simple-Neural-Network From Scratch in Python This is a simple Neural Network created without any Machine Learning Libraries. The only dependencies are

Aum Shah 1 Dec 28, 2021
Yolo Traffic Light Detection With Python

Yolo-Traffic-Light-Detection This project is based on detecting the Traffic light. Pretained data is used. This application entertained both real time

Ananta Raj Pant 2 Aug 08, 2022
Second-Order Neural ODE Optimizer, NeurIPS 2021 spotlight

Second-order Neural ODE Optimizer (NeurIPS 2021 Spotlight) [arXiv] ✔️ faster convergence in wall-clock time | ✔️ O(1) memory cost | ✔️ better test-tim

Guan-Horng Liu 39 Oct 22, 2022
FL-WBC: Enhancing Robustness against Model Poisoning Attacks in Federated Learning from a Client Perspective

FL-WBC: Enhancing Robustness against Model Poisoning Attacks in Federated Learning from a Client Perspective Official implementation of "FL-WBC: Enhan

Jingwei Sun 26 Nov 28, 2022
Effect of Deep Transfer and Multi task Learning on Sperm Abnormality Detection

Effect of Deep Transfer and Multi task Learning on Sperm Abnormality Detection Introduction This repository includes codes and models of "Effect of De

Amir Abbasi 5 Sep 05, 2022
Multiple Object Extraction from Aerial Imagery with Convolutional Neural Networks

This is an implementation of Volodymyr Mnih's dissertation methods on his Massachusetts road & building dataset and my original methods that are publi

Shunta Saito 255 Sep 07, 2022
Differentiable molecular simulation of proteins with a coarse-grained potential

Differentiable molecular simulation of proteins with a coarse-grained potential This repository contains the learned potential, simulation scripts and

UCL Bioinformatics Group 44 Dec 10, 2022
Image-to-Image Translation in PyTorch

CycleGAN and pix2pix in PyTorch New: Please check out contrastive-unpaired-translation (CUT), our new unpaired image-to-image translation model that e

Jun-Yan Zhu 19k Jan 07, 2023
Disease Informed Neural Networks (DINNs) — neural networks capable of learning how diseases spread, forecasting their progression, and finding their unique parameters (e.g. death rate).

DINN We introduce Disease Informed Neural Networks (DINNs) — neural networks capable of learning how diseases spread, forecasting their progression, a

19 Dec 10, 2022
ManiSkill-Learn is a framework for training agents on SAPIEN Open-Source Manipulation Skill Challenge (ManiSkill Challenge), a large-scale learning-from-demonstrations benchmark for object manipulation.

ManiSkill-Learn ManiSkill-Learn is a framework for training agents on SAPIEN Open-Source Manipulation Skill Challenge, a large-scale learning-from-dem

Hao Su's Lab, UCSD 48 Dec 30, 2022
TensorFlow CNN for fast style transfer

Fast Style Transfer in TensorFlow Add styles from famous paintings to any photo in a fraction of a second! It takes 100ms on a 2015 Titan X to style t

1 Dec 14, 2021
Make Watson Assistant send messages to your Discord Server

Make Watson Assistant send messages to your Discord Server Prerequisites Sign up for an IBM Cloud account. Fill in the required information and press

1 Jan 10, 2022
Official implementation of the RAVE model: a Realtime Audio Variational autoEncoder

RAVE: Realtime Audio Variational autoEncoder Official implementation of RAVE: A variational autoencoder for fast and high-quality neural audio synthes

ACIDS 587 Jan 01, 2023
Sync2Gen Code for ICCV 2021 paper: Scene Synthesis via Uncertainty-Driven Attribute Synchronization

Sync2Gen Code for ICCV 2021 paper: Scene Synthesis via Uncertainty-Driven Attribute Synchronization 0. Environment Environment: python 3.6 and cuda 10

Haitao Yang 62 Dec 30, 2022
Brain Tumor Detection with Tensorflow Neural Networks.

Brain-Tumor-Detection A convolutional neural network model built with Tensorflow & Keras to detect brain tumor and its different variants. Data of the

404ErrorNotFound 5 Aug 23, 2022
Unofficial pytorch implementation of paper "One-Shot Free-View Neural Talking-Head Synthesis for Video Conferencing"

One-Shot Free-View Neural Talking Head Synthesis Unofficial pytorch implementation of paper "One-Shot Free-View Neural Talking-Head Synthesis for Vide

ZLH 406 Dec 23, 2022
Fine-grained Post-training for Improving Retrieval-based Dialogue Systems - NAACL 2021

Fine-grained Post-training for Multi-turn Response Selection Implements the model described in the following paper Fine-grained Post-training for Impr

Janghoon Han 83 Dec 20, 2022
FB-tCNN for SSVEP Recognition

FB-tCNN for SSVEP Recognition Here are the codes of the tCNN and FB-tCNN in the paper "Filter Bank Convolutional Neural Network for Short Time-Window

Wenlong Ding 12 Dec 14, 2022
A scanpy extension to analyse single-cell TCR and BCR data.

Scirpy: A Scanpy extension for analyzing single-cell immune-cell receptor sequencing data Scirpy is a scalable python-toolkit to analyse T cell recept

ICBI 145 Jan 03, 2023