Minimal implementation of PAWS (https://arxiv.org/abs/2104.13963) in TensorFlow.

Overview

PAWS-TF 🐾

Implementation of Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples (PAWS) in TensorFlow (2.4.1).

PAWS introduces a simple way to combine a very small fraction of labeled data with a comparatively larger corpus of unlabeled data during pre-training. With its approach, it sets the state-of-the-art in semi-supervised learning (as of May 2021) beating methods like SimCLRV2, Meta Pseudo Labels that too with fewer parameters and a smaller pre-training schedule. For details, I recommend checking out the original paper as well as this blog post by the authors.

This repository implements and includes all the major bits proposed in PAWS in TensorFlow. The only major difference is that the pre-training and subsequent fine-tuning weren't run for the original number of epochs (600 and 30 respectively) to save compute. I have reused the utility components for PAWS loss from the original implementation.

Dataset ⌗

The current code works with CIFAR10 and uses 4000 labeled samples (8%) during pre-training (along with the unlabeled samples).

Features

  • Multi-crop augmentation strategy (originally introduced in SwAV)
  • Class stratified sampler (common in few-shot classification problems)
  • WarmUpCosine learning rate schedule (which is typical for self-supervised and semi-supervised pre-training)
  • LARS optimizer (comes from TensorFlow Model Garden)

The trunk portion (all, except the last classification layer) of a WideResNet-28-2 is used inside the encoder for CIFAR10. All the experimental configurations were followed from the Appendix C of the paper.

Setup and code structure 💻

A GCP VM (n1-standard-8) with a single V100 GPU was used for executing the code.

  • paws_train.py runs the pre-training as introduced in PAWS.
  • fine_tune.py runs the fine-tuning part as suggested in Appendix C. Note that this is only required for CIFAR10.
  • nn_eval.py runs the soft nearest neighbor classification on CIFAR10 test set.

Pre-training and fine-tuning total take 1.4 hours to complete. All the logs are available in misc/logs.txt. Additionally, the indices that were used to sample the labeled examples from the CIFAR10 training set are available here.

Results 📊

Pre-training

PAWS minimizes the cross-entropy loss (as well as maximizes mean-entropy) during pre-training. This is what the training plot indicates too:

To evaluate the effectivity of the pre-training, PAWS performs soft nearest neighbor classification to report the top-1 accuracy score on a given test set.

Top-1 Accuracy

This repository gets to 73.46% top-1 accuracy on the CIFAR10 test set. Again, note that I only pre-trained for 50 epochs (as opposed to 600) and fine-tuned for 10 epochs (as opposed to 30). With the original schedule this score should be around 96.0%.

In the following PCA projection plot, we see that the embeddings of images (computed after fine-tuning) of PAWS are starting to be well separated:

Notebooks 📘

There are two Colab Notebooks:

Misc ⺟

  • Model weights are available here for reproducibility.
  • With mixed-precision training, the performance can further be improved. I am open to accepting contributions that would implement mixed-precision training in the current code.

Acknowledgements

  • Huge amount of thanks to Mahmoud Assran (first author of PAWS) for patiently resolving my doubts.
  • ML-GDE program for providing GCP credit support.

Paper Citation

@misc{assran2021semisupervised,
      title={Semi-Supervised Learning of Visual Features by Non-Parametrically Predicting View Assignments with Support Samples}, 
      author={Mahmoud Assran and Mathilde Caron and Ishan Misra and Piotr Bojanowski and Armand Joulin and Nicolas Ballas and Michael Rabbat},
      year={2021},
      eprint={2104.13963},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
You might also like...
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch
Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) in PyTorch

alias-free-gan-pytorch Unofficial implementation of Alias-Free Generative Adversarial Networks. (https://arxiv.org/abs/2106.12423) This implementation

Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286
Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286

Pytorch-DPPO Pytorch implementation of Distributed Proximal Policy Optimization: https://arxiv.org/abs/1707.02286 Using PPO with clip loss (from https

PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)
PyTorch implementation of Asymmetric Siamese (https://arxiv.org/abs/2204.00613)

Asym-Siam: On the Importance of Asymmetry for Siamese Representation Learning This is a PyTorch implementation of the Asym-Siam paper, CVPR 2022: @inp

This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).
This repository contains the code used for Predicting Patient Outcomes with Graph Representation Learning (https://arxiv.org/abs/2101.03940).

Predicting Patient Outcomes with Graph Representation Learning This repository contains the code used for Predicting Patient Outcomes with Graph Repre

https://arxiv.org/abs/2102.11005
https://arxiv.org/abs/2102.11005

LogME LogME: Practical Assessment of Pre-trained Models for Transfer Learning How to use Just feed the features f and labels y to the function, and yo

Supplementary code for the paper
Supplementary code for the paper "Meta-Solver for Neural Ordinary Differential Equations" https://arxiv.org/abs/2103.08561

Meta-Solver for Neural Ordinary Differential Equations Towards robust neural ODEs using parametrized solvers. Main idea Each Runge-Kutta (RK) solver w

Code for paper "A Critical Assessment of State-of-the-Art in Entity Alignment" (https://arxiv.org/abs/2010.16314)

A Critical Assessment of State-of-the-Art in Entity Alignment This repository contains the source code for the paper A Critical Assessment of State-of

Code for the paper: Learning Adversarially Robust Representations via Worst-Case Mutual Information Maximization (https://arxiv.org/abs/2002.11798)

Representation Robustness Evaluations Our implementation is based on code from MadryLab's robustness package and Devon Hjelm's Deep InfoMax. For all t

ISTR: End-to-End Instance Segmentation with Transformers (https://arxiv.org/abs/2105.00637)

This is the project page for the paper: ISTR: End-to-End Instance Segmentation via Transformers, Jie Hu, Liujuan Cao, Yao Lu, ShengChuan Zhang, Yan Wa

Releases(v1.0.0)
Owner
Sayak Paul
Trying to learn how machines learn.
Sayak Paul
🔀 Visual Room Rearrangement

AI2-THOR Rearrangement Challenge Welcome to the 2021 AI2-THOR Rearrangement Challenge hosted at the CVPR'21 Embodied-AI Workshop. The goal of this cha

AI2 55 Dec 22, 2022
Code Release for ICCV 2021 (oral), "AdaFit: Rethinking Learning-based Normal Estimation on Point Clouds"

AdaFit: Rethinking Learning-based Normal Estimation on Point Clouds (ICCV 2021 oral) **Project Page | Arxiv ** Runsong Zhu¹, Yuan Liu², Zhen Dong¹, Te

40 Dec 30, 2022
Using machine learning to predict and analyze high and low reader engagement for New York Times articles posted to Facebook.

How The New York Times can increase Engagement on Facebook Using machine learning to understand characteristics of news content that garners "high" Fa

Jessica Miles 0 Sep 16, 2021
Optical Character Recognition + Instance Segmentation for russian and english languages

Распознавание рукописного текста в школьных тетрадях Соревнование, проводимое в рамках олимпиады НТО, разработанное Сбером. Платформа ODS. Результаты

Gerasimov Maxim 21 Dec 19, 2022
source code for https://arxiv.org/abs/2005.11248 "Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics"

Accelerating Antimicrobial Discovery with Controllable Deep Generative Models and Molecular Dynamics This work will be published in Nature Biomedical

International Business Machines 71 Nov 15, 2022
Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields.

This repository contains the code release for Mip-NeRF: A Multiscale Representation for Anti-Aliasing Neural Radiance Fields. This implementation is written in JAX, and is a fork of Google's JaxNeRF

Google 625 Dec 30, 2022
FedGS: A Federated Group Synchronization Framework Implemented by LEAF-MX.

FedGS: Data Heterogeneity-Robust Federated Learning via Group Client Selection in Industrial IoT Preparation For instructions on generating data, plea

Lizonghang 9 Dec 22, 2022
A new framework, collaborative cascade prediction based on graph neural networks (CCasGNN) to jointly utilize the structural characteristics, sequence features, and user profiles.

CCasGNN A new framework, collaborative cascade prediction based on graph neural networks (CCasGNN) to jointly utilize the structural characteristics,

5 Apr 29, 2022
[NeurIPS 2021] Galerkin Transformer: a linear attention without softmax

[NeurIPS 2021] Galerkin Transformer: linear attention without softmax Summary A non-numerical analyst oriented explanation on Toward Data Science abou

Shuhao Cao 159 Dec 20, 2022
Doing the asl sign language classification on static images using graph neural networks.

SignLangGNN When GNNs 💜 MediaPipe. This is a starter project where I tried to implement some traditional image classification problem i.e. the ASL si

10 Nov 09, 2022
PCACE: A Statistical Approach to Ranking Neurons for CNN Interpretability

PCACE: A Statistical Approach to Ranking Neurons for CNN Interpretability PCACE is a new algorithm for ranking neurons in a CNN architecture in order

4 Jan 04, 2022
Rethinking Nearest Neighbors for Visual Classification

Rethinking Nearest Neighbors for Visual Classification arXiv Environment settings Check out scripts/env_setup.sh Setup data Download the following fin

Menglin Jia 29 Oct 11, 2022
This is an open source library implementing hyperbox-based machine learning algorithms

hyperbox-brain is a Python open source toolbox implementing hyperbox-based machine learning algorithms built on top of scikit-learn and is distributed

Complex Adaptive Systems (CAS) Lab - University of Technology Sydney 21 Dec 14, 2022
A Diagnostic Dataset for Compositional Language and Elementary Visual Reasoning

CLEVR Dataset Generation This is the code used to generate the CLEVR dataset as described in the paper: CLEVR: A Diagnostic Dataset for Compositional

Facebook Research 503 Jan 04, 2023
基于DouZero定制AI实战欢乐斗地主

DouZero_For_Happy_DouDiZhu: 将DouZero用于欢乐斗地主实战 本项目基于DouZero 环境配置请移步项目DouZero 模型默认为WP,更换模型请修改start.py中的模型路径 运行main.py即可 SL (baselines/sl/): 基于人类数据进行深度学习

1.5k Jan 08, 2023
Prometheus Exporter for data scraped from datenplattform.darmstadt.de

darmstadt-opendata-exporter Scrapes data from https://datenplattform.darmstadt.de and presents it in the Prometheus Exposition format. Pull requests w

Martin Weinelt 2 Apr 12, 2022
code for our paper "Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer"

SHOT++ Code for our TPAMI submission "Source Data-absent Unsupervised Domain Adaptation through Hypothesis Transfer and Labeling Transfer" that is ext

75 Dec 16, 2022
A PyTorch implementation of PointRend: Image Segmentation as Rendering

PointRend A PyTorch implementation of PointRend: Image Segmentation as Rendering [arxiv] [Official Implementation: Detectron2] This repo for Only Sema

AhnDW 336 Dec 26, 2022
Dynamic hair modeling from monocular videos using deep neural networks

Dynamic Hair Modeling The source code of the networks for our paper "Dynamic hair modeling from monocular videos using deep neural networks" (SIGGRAPH

53 Oct 18, 2022
Tensorflow2.0 🍎🍊 is delicious, just eat it! 😋😋

How to eat TensorFlow2 in 30 days ? 🔥 🔥 Click here for Chinese Version(中文版) 《10天吃掉那只pyspark》 🚀 github项目地址: https://github.com/lyhue1991/eat_pyspark

lyhue1991 9.7k Jan 01, 2023