TF2 implementation of knowledge distillation using the "function matching" hypothesis from the paper Knowledge distillation: A good teacher is patient and consistent by Beyer et al.

Overview

FunMatch-Distillation

The techniques have been demonstrated using three datasets: Flowers102, Pet37, and Food101.

This repository provides Kaggle Kernel notebooks so that we can leverage the free TPU v3-8 to run the long training schedules. Please refer to the "About the notebooks" section.

Importance

The importance of knowledge distillation lies in its practical usefulness. With the recipes from "function matching", we can now perform knowledge distillation using a principled approach yielding student models that can actually match the performance of their teacher models. This essentially allows us to compress bigger models into (much) smaller ones thereby reducing storage costs and improving inference speed.

Key ingredients

  • No ground-truth labels are used during distillation.
  • The teacher and student should see the same images during distillation, as opposed to differently augmented views of the same images.
  • An aggressive form of MixUp is the key augmentation recipe. MixUp is paired with "Inception-style" cropping (implemented in this script).
  • A LONG training schedule for distillation: at least 1000 epochs to get good results without overfitting. The paper shows that a long training schedule is paramount.
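
The ingredients above can be sketched in a few lines of framework-free Python. This is only an illustration of the arithmetic, not the code from the notebooks: it assumes the distillation objective is the KL divergence between the teacher's and the student's temperature-scaled softmax outputs, and that the mixing coefficient is drawn uniformly from [0, 1]. The names `toy_teacher` and `toy_student` are hypothetical stand-ins for real networks, and "Inception-style" cropping is left out.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * (math.log(pi + eps) - math.log(qi + eps)) for pi, qi in zip(p, q))


def mixup(images, lam):
    """Blend image i with image i-1 (the pairing tf.roll(images, 1, axis=0) gives)."""
    n = len(images)
    return [
        [lam * a + (1.0 - lam) * b for a, b in zip(images[i], images[(i - 1) % n])]
        for i in range(n)
    ]


def distillation_loss(teacher_fn, student_fn, images, temperature=1.0):
    """Mean KL(teacher || student) over a batch; no ground-truth labels involved."""
    total = 0.0
    for x in images:  # both models receive the *same* input x
        p_teacher = softmax(teacher_fn(x), temperature)
        p_student = softmax(student_fn(x), temperature)
        total += kl_divergence(p_teacher, p_student)
    return total / len(images)


# Hypothetical stand-ins for real networks: each maps a 2-pixel "image" to 3 logits.
def toy_teacher(x):
    return [x[0] + x[1], -(x[0] + x[1]), 0.5 * x[0]]


def toy_student(x):
    return [0.5 * (x[0] + x[1]), 0.0, x[0]]


rng = random.Random(0)
batch = [[rng.random(), rng.random()] for _ in range(4)]
lam = rng.random()  # assumption: mixing coefficient drawn uniformly from [0, 1]
mixed = mixup(batch, lam)
loss = distillation_loss(toy_teacher, toy_student, mixed)
print(f"distillation loss on mixed batch: {loss:.4f}")
```

Note that the mixed batch is computed once and passed to both models, and that labels never appear in the loss. In a real training loop, only the student's parameters would be updated to reduce this value.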

Results

The table below summarizes the results of my experiments. In all cases, the teacher is a BiT-ResNet101x3 model and the student is a BiT-ResNet50x1 model. For fun, you can also try distilling into other model families. BiT stands for "Big Transfer" and was proposed in this paper.

| Dataset | Teacher/Student | Top-1 Acc on Test | Location |
|---|---|---|---|
| Flowers102 | Teacher | 98.18% | Link |
| Flowers102 | Student (1000 epochs) | 81.02% | Link |
| Pet37 | Teacher | 90.92% | Link |
| Pet37 | Student (300 epochs) | 81.3% | Link |
| Pet37 | Student (1000 epochs) | 86% | Link |
| Food101 | Teacher | 85.52% | Link |
| Food101 | Student (100 epochs) | 76.06% | Link |

(Location denotes the trained model location.)

These results are consistent with Table 4 of the original paper.

None of the above student training regimes showed signs of overfitting, so further improvements may be possible by training for longer. The authors also showed that Shampoo can reach similar performance much more quickly than Adam during distillation, so it may well be possible to match these results in fewer epochs with Shampoo.

A few differences from the original implementation:

  • The authors use BiT-ResNet152x2 as a teacher.
  • The mixup() variant I used produces a pair of duplicate images if the number of images is even. With 8 workers, that becomes 8 pairs, which may have led to the reduced performance. We can overcome this by using tf.roll(images, 1, axis=0) instead of tf.reverse in the mixup() function. Thanks to Lucas Beyer for pointing this out.
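
One way to see the difference between the two pairings is to look only at which image each batch position gets blended with. The sketch below is a plain-Python illustration, not the repository's mixup() function; the helper names are hypothetical. It assumes reversal pairs image i with image n-1-i, and that rolling by one pairs image i with image i-1, which is how tf.reverse and tf.roll(images, 1, axis=0) reorder the batch axis. Whether two mixed images end up exactly identical also depends on the mixing coefficient, which is not modelled here.

```python
def reverse_partners(n):
    """Partner indices when the batch is reversed: i is blended with n-1-i."""
    return [(i, n - 1 - i) for i in range(n)]


def roll_partners(n):
    """Partner indices when the batch is rolled by one: i is blended with i-1."""
    return [(i, (i - 1) % n) for i in range(n)]


def distinct_unordered_pairs(pairs):
    """Collapse (i, j) and (j, i) into one entry to count distinct pairings."""
    return {frozenset(pair) for pair in pairs}


n = 8  # an even per-worker batch size, chosen only for illustration
print(len(distinct_unordered_pairs(reverse_partners(n))))  # 4: every pairing occurs twice
print(len(distinct_unordered_pairs(roll_partners(n))))     # 8: every pairing is different
```

With an odd batch size, reversal additionally pairs the middle image with itself, for example `(3, 3)` when `n = 7`.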

About the notebooks

All the notebooks are fully runnable on Kaggle Kernel. The only requirement is a billing-enabled GCP account, so that GCS Buckets can be used to store data.

| Notebook | Description | Kaggle Kernel |
|---|---|---|
| train_bit.ipynb | Shows how to train the teacher model. | Link |
| train_bit_keras_tuner.ipynb | Shows how to run hyperparameter tuning using Keras Tuner for the teacher model. | Link |
| funmatch_distillation.ipynb | Shows an implementation of the recipes from "function matching". | Link |

These are only demonstrated on the Pet37 dataset but will work out-of-the-box for the other datasets too.

TFRecords

For convenience, TFRecords of different datasets are provided:

| Dataset | TFRecords |
|---|---|
| Flowers102 | Link |
| Pet37 | Link |
| Food101 | Link |

Paper citation

@misc{beyer2021knowledge,
      title={Knowledge distillation: A good teacher is patient and consistent}, 
      author={Lucas Beyer and Xiaohua Zhai and Amélie Royer and Larisa Markeeva and Rohan Anil and Alexander Kolesnikov},
      year={2021},
      eprint={2106.05237},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

Huge thanks to Lucas Beyer (first author of the paper) for providing suggestions on the initial version of the implementation.

Thanks to the ML-GDE program for providing GCP credits.

Thanks to TRC for providing Cloud TPU access.

Releases: v4.0.0
Owner: Sayak Paul