Sinkformers: Transformers with Doubly Stochastic Attention

Overview

Code for the paper : "Sinkformers: Transformers with Doubly Stochastic Attention"

Paper

You will find our paper here.

Compat

This package has been developed and tested with python3.8. It is therefore not guaranteed to work with earlier versions of python.

Install the repository on your machine

This package can easily be installed using pip, with the following command:

pip install numpy
pip install -e .

This will install the package and all its dependencies, listed in requirements.txt.

Each command has to be executed from the root folder sinkformers. Our code is distributed in the different repositories. For each repository, we modify the architectures proposed by replacing the SoftMax attention with a Sinkhorn attention.

Defining a toy Sinkformer for which attention matrices are doubly stochastic

For this example we use a Transformer from the nlp-tutorial library and define its Sinkformer counterpart with the argument "n_it", the number of iterations in Sinkhorn's algorithm.

cd nlp-tutorial/text-classification-transformer
import torch
from model import TransformerEncoder
n_it = 1
print('1 iteration in Sinkhorn corresponds to the original Transformer: ')
transformer = TransformerEncoder(vocab_size=1000, seq_len=512, n_layers=1,  n_heads=1, n_it=n_it, print_attention=True, pad_id=-1)
inp = torch.arange(512).repeat(5, 1)
out = transformer(inp)
n_it = 5
print('5 iteration in Sinkhorn gives a Sinkformer with perfectly doubly stochastic attention matrices: ')
sinkformer = TransformerEncoder(vocab_size=1000, seq_len=512, n_layers=1,  n_heads=1, n_it=n_it, print_attention=True, pad_id=-1)
inp = torch.arange(512).repeat(5, 1)
out = sinkformer(inp)

Then go back to the root:

cd ..
cd ..

Reproducing the experiments of the paper

Comparison of the different normalizations.

python plot_normalizations.py

ModelNet 40 classification. Code adapted from this repository. First, you need to preprocess the ModelNet40 dataset available here. Unzip it and save it under model_net_40/data. Then, preferably on multiple cpus, run

cd model_net_40
python to_h5.py
python formatting.py
cd ..
mv model_net_40/data/ModelNet40_cloud.h5 set_transformer/ModelNet40_cloud.h5
cd set_transformer
mkdir ../dataset
mv ModelNet40_cloud.h5 ../dataset/ModelNet40_cloud.h5
cd ..

Then you can train a Set Sinkformer (or Set Transformer) on ModelNet 40 with

cd set_transformer
python one_expe.py
cd ..

Arguments for one_expe.py can be accessed through

cd set_transformer
python one_expe.py --help
cd ..

Results are saved in the folder set_transformer/results. You can plot the learning curves using the script set_transformer/plot_results.py. The array iterations in the script must contains the different values for n_it used when training.

Sentiment Analysis. Code adapted from this repository. You can also train a Sinkformer for Sentiment Analysis on the IMDb Dataset with the following command (the IMDb Dataset is downloaded automatically).

cd nlp-tutorial/text-classification-transformer
python one_expe.py
cd ..
cd ..

Arguments for one_expe.py can be accessed through

cd nlp-tutorial/text-classification-transformer
python one_expe.py --help
cd ..

Results are saved in the folder nlp-tutorial/text-classification-transformer/results. You can plot the learning curves using the script nlp-tutorial/text-classification-transformer/plot_results.py. The array iterations in the script must contain the different values for "n_it" used when training.

ViT Cats and Dogs classification. Code adapted from this repository. First, you can download the data set here, unzip it and save the train and test repositories at sinkformers/vit-pytorch/examples/data. Then you can run

cd vit-pytorch
python one_expe.py
cd ..

Arguments for one_expe.py can be accessed through

cd vit-pytorch
python one_expe.py --help
cd ..

Results are saved in the folder vit-pytorch/results. You can plot the learning curves using the script vit-pytorch/plot_results.py. The array iterations in the script must contain the different values for "n_it" used when training.

ViT MNIST. The MNIST dataset will be downloaded automatically.

cd vit-pytorch
python one_expe_mnist.py
cd ..

Arguments for one_expe_mnist.py can be accessed through

cd vit-pytorch
python one_expe_mnist.py --help
cd ..

Especially, the argument "ps" is the patch size. Results are saved in the folder vit-pytorch/results_mnist. You can plot the learning curves using the script vit-pytorch/plot_results_mnist.py. The array iterations in the script must contain the different values for "n_it" used when training. The array patches_size in the script must contain the different values for "ps" used when training.

Cite

If you use this code in your project, please cite::

Michael E. Sander, Pierre Ablin, Mathieu Blondel, Gabriel Peyré
Sinkformers: Transformers with Doubly Stochastic Attention
arXiv preprint arXiv:2110.11773, 2021
https://arxiv.org/abs/2110.11773
Owner
Michael E. Sander
Michael E. Sander
A Pytorch Implementation of Source Data-free Domain Adaptation for a Faster R-CNN

A Pytorch Implementation of Source Data-free Domain Adaptation for a Faster R-CNN Please follow Faster R-CNN and DAF to complete the environment confi

2 Jan 12, 2022
[NeurIPS 2021] Garment4D: Garment Reconstruction from Point Cloud Sequences

Garment4D [PDF] | [OpenReview] | [Project Page] Overview This is the codebase for our NeurIPS 2021 paper Garment4D: Garment Reconstruction from Point

Fangzhou Hong 112 Dec 23, 2022
Network Compression via Central Filter

Network Compression via Central Filter Environments The code has been tested in the following environments: Python 3.8 PyTorch 1.8.1 cuda 10.2 torchsu

2 May 12, 2022
GAN Image Generator and Characterwise Image Recognizer with python

MODEL SUMMARY 모델의 구조는 크게 6단계로 나뉩니다. STEP 0: Input Image Predict 할 이미지를 모델에 입력합니다. STEP 1: Make Black and White Image STEP 1 은 입력받은 이미지의 글자를 흑색으로, 배경을

Juwan HAN 1 Feb 09, 2022
The NEOSSat is a dual-mission microsatellite designed to detect potentially hazardous Earth-orbit-crossing asteroids and track objects that reside in deep space

The NEOSSat is a dual-mission microsatellite designed to detect potentially hazardous Earth-orbit-crossing asteroids and track objects that reside in deep space

John Salib 2 Jan 30, 2022
kapre: Keras Audio Preprocessors

Kapre Keras Audio Preprocessors - compute STFT, ISTFT, Melspectrogram, and others on GPU real-time. Tested on Python 3.6 and 3.7 Why Kapre? vs. Pre-co

Keunwoo Choi 867 Dec 29, 2022
Official repository for the NeurIPS 2021 paper Get Fooled for the Right Reason: Improving Adversarial Robustness through a Teacher-guided curriculum Learning Approach

Get Fooled for the Right Reason Official repository for the NeurIPS 2021 paper Get Fooled for the Right Reason: Improving Adversarial Robustness throu

Sowrya Gali 1 Apr 25, 2022
Really awesome semantic segmentation

really-awesome-semantic-segmentation A list of all papers on Semantic Segmentation and the datasets they use. This site is maintained by Holger Caesar

Holger Caesar 400 Nov 28, 2022
"Learning Free Gait Transition for Quadruped Robots vis Phase-Guided Controller"

PhaseGuidedControl The current version is developed based on the old version of RaiSim series, and possibly requires further modification. It will be

X-Mechanics 12 Oct 21, 2022
Cross-Task Consistency Learning Framework for Multi-Task Learning

Cross-Task Consistency Learning Framework for Multi-Task Learning Tested on numpy(v1.19.1) opencv-python(v4.4.0.42) torch(v1.7.0) torchvision(v0.8.0)

Aki Nakano 2 Jan 08, 2022
NaturalCC is a sequence modeling toolkit that allows researchers and developers to train custom models

NaturalCC NaturalCC is a sequence modeling toolkit that allows researchers and developers to train custom models for many software engineering tasks,

159 Dec 28, 2022
ADGAN - The Implementation of paper Controllable Person Image Synthesis with Attribute-Decomposed GAN

ADGAN - The Implementation of paper Controllable Person Image Synthesis with Attribute-Decomposed GAN CVPR 2020 (Oral); Pose and Appearance Attributes Transfer;

Men Yifang 400 Dec 29, 2022
PyTorch implementation of our paper: Decoupling and Recoupling Spatiotemporal Representation for RGB-D-based Motion Recognition

Decoupling and Recoupling Spatiotemporal Representation for RGB-D-based Motion Recognition, arxiv This is a PyTorch implementation of our paper. 1. Re

DamoCV 11 Nov 19, 2022
Evolutionary Population Curriculum for Scaling Multi-Agent Reinforcement Learning

Evolutionary Population Curriculum for Scaling Multi-Agent Reinforcement Learning This is the code for implementing the MADDPG algorithm presented in

97 Dec 21, 2022
Forest R-CNN: Large-Vocabulary Long-Tailed Object Detection and Instance Segmentation (ACM MM 2020)

Forest R-CNN: Large-Vocabulary Long-Tailed Object Detection and Instance Segmentation (ACM MM 2020) Official implementation of: Forest R-CNN: Large-Vo

Jialian Wu 54 Jan 06, 2023
Official tensorflow implementation for CVPR2020 paper “Learning to Cartoonize Using White-box Cartoon Representations”

Tensorflow implementation for CVPR2020 paper “Learning to Cartoonize Using White-box Cartoon Representations”.

3.7k Dec 31, 2022
This folder contains the python code of UR5E's advanced forward kinematics model.

This folder contains the python code of UR5E's advanced forward kinematics model. By entering the angle of the joint of UR5e, the detailed coordinates of up to 48 points around the robot arm can be c

Qiang Wang 4 Sep 17, 2022
Unified MultiWOZ evaluation scripts for the context-to-response task.

MultiWOZ Context-to-Response Evaluation Standardized and easy to use Inform, Success, BLEU ~ See the paper ~ Easy-to-use scripts for standardized eval

Tomáš Nekvinda 38 Dec 13, 2022
Code for Low-Cost Algorithmic Recourse for Users With Uncertain Cost Functions

EMS-COLS-recourse Initial Code for Low-Cost Algorithmic Recourse for Users With Uncertain Cost Functions Folder structure: data folder contains raw an

Prateek Yadav 1 Nov 25, 2022
Code for NAACL 2021 full paper "Efficient Attentions for Long Document Summarization"

LongDocSum Code for NAACL 2021 paper "Efficient Attentions for Long Document Summarization" This repository contains data and models needed to reprodu

56 Jan 02, 2023