ATS

About

Source code of the paper Meta-learning with an Adaptive Task Scheduler.

If you find this repository useful in your research, please cite the following paper:

@inproceedings{yao2021adaptive,
  title={Meta-learning with an Adaptive Task Scheduler},
  author={Yao, Huaxiu and Wang, Yu and Wei, Ying and Zhao, Peilin and Mahdavi, Mehrdad and Lian, Defu and Finn, Chelsea},
  booktitle={Proceedings of the Thirty-fifth Conference on Neural Information Processing Systems},
  year={2021} 
}

MiniImageNet

The processed miniImageNet dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/miniimagenet, which has the following file structure:

-- miniimagenet  // /data/miniimagenet
  -- miniImagenet
    -- train_task_id.pkl
    -- test_task_id.pkl
    -- mini_imagenet_train.pkl
    -- mini_imagenet_test.pkl
    -- mini_imagenet_val.pkl
    -- training_classes_20000_2_new.npz
    -- training_classes_20000_4_new.npz

Then $datadir in the following commands should be set to /data/miniimagenet.
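For example, in a bash shell (the paths are placeholders; point them at your own data and log locations):

export datadir=/data/miniimagenet  # where the dataset was unzipped
export logdir=/data/train_logs     # any writable directory for checkpoints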

ATS with noise = 0.6

We first need to pretrain the model without noise. A pretrained model has been uploaded to this repo, but you can also pretrain it yourself. The pretraining scripts are as follows:
(1) 1 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

(2) 5 shot:

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.0

Then move the pretrained checkpoint to the current directory (iteration 20000 for 1 shot and 10000 for 5 shot, matching the --pretrain_iter values used below):
(1) 1 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32/model20000 ./model20000_1shot

(2) 5 shot:

mv $logdir/ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32/model10000 ./model10000_5shot

With this pretrained model, we can run both uniform sampling and ATS sampling. For ATS, the scripts are below; a sketch of the sampling idea behind the scheduler flags follows them:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10  --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 20000

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --noise 0.6 --logdir $logdir --sampling_method ATS --buffer_size 10 --utility_function sample --temperature 0.1 --scheduler_lr 0.001 --warmup 2000 --pretrain_iter 10000
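As intuition for the scheduler flags above (--buffer_size, --temperature, --replace), here is a minimal, hypothetical sketch of temperature-scaled sampling over a buffer of candidate tasks. It is not the repo's actual scheduler: in ATS the per-task scores come from a learned scheduler network updated with --scheduler_lr after --warmup iterations, whereas here they are random placeholders.

import numpy as np

def sample_meta_batch(utilities, meta_batch_size, temperature, replace=False):
    # Temperature-scaled softmax over scheduler scores, one per buffered task.
    logits = np.asarray(utilities, dtype=float) / temperature
    probs = np.exp(logits - logits.max())  # subtract max for numerical stability
    probs /= probs.sum()
    # --replace 0 corresponds to replace=False: no task repeats within a batch
    return np.random.choice(len(probs), size=meta_batch_size, replace=replace, p=probs)

# e.g. --buffer_size 10, --meta_batch_size 2, --temperature 0.1
picked = sample_meta_batch(np.random.rand(10), meta_batch_size=2, temperature=0.1)
print(picked)

A lower --temperature concentrates sampling on the highest-scoring tasks; --temperature 1 (used in the limited-budget 1-shot run below) keeps the distribution closer to uniform.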

For uniform sampling, we additionally need to finetune the model trained under uniform sampling on the validation set. The training commands are:
(1) 1 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir models
mv ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_1.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_1shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0 --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune

(2) 5 shot

python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6
mkdir models  # if the "models" directory does not exist
mv ANIL_pytorch.data_miniimagenetcls_5.mbs_2.ubs_5.metalr0.001.innerlr0.01.hidden32_noise0.6/model30000 ./models/ANIL_0.4_model_5shot
python3 main.py --meta_batch_size 2 --datasource miniimagenet --datadir $datadir --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --logdir $logdir --noise 0.6 --finetune

ATS with limited budgets

In this setting, pretraining is not needed. You can directly run the following commands:
Uniform sampling, 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

Uniform sampling, 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --limit_data 1 --logdir ../train_logs --limit_classes 16

ATS, 1 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 1 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 1 --warmup 0 --limit_classes 16

ATS, 5 shot

python3 main.py --meta_batch_size 3 --datasource miniimagenet --datadir ./miniimagenet/ --num_updates 5 --num_updates_test 10 --update_batch_size 5 --update_batch_size_eval 15 --resume 0  --num_classes 5 --metatrain_iterations 30000 --replace 0 --limit_data 1 --logdir ../train_logs --sampling_method ATS --buffer_size 6 --utility_function sample --temperature 0.1 --warmup 0 --limit_classes 16

Drug

The processed drug dataset can be downloaded here. Assume the dataset has been downloaded and unzipped to /data/drug, which has the following structure:

-- drug  // /data/drug
  -- ci9b00375_si_001.txt
  -- ci9b00375_si_002.txt
  -- ci9b00375_si_003.txt
  -- compound_fp.npy
  -- drug_split_id_group1.pickle
  -- drug_split_id_group2.pickle
  -- drug_split_id_group3.pickle
  -- drug_split_id_group4.pickle
  -- drug_split_id_group6.pickle
  -- drug_split_id_group9.pickle
  -- drug_split_id_group17.pickle
  -- important_readme.md

Then $datadir in the following commands should be set to /data/.
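Optionally, you can sanity-check the download before training. The snippet below is illustrative and only assumes that compound_fp.npy is a NumPy array and that the .pickle files are standard pickles:

import pickle
import numpy as np

fps = np.load('/data/drug/compound_fp.npy', allow_pickle=True)  # compound fingerprints
print('fingerprint array shape:', fps.shape)

with open('/data/drug/drug_split_id_group17.pickle', 'rb') as f:
    split = pickle.load(f)  # task split ids for drug_group 17
print('split object type:', type(split))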

ATS with noise = 4

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --noise 4 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --noise 4 --data_dir $datadir --train 0

ATS with full budgets

Uniform Sampling:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --data_dir $datadir --train 0

ATS:

python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir
python3 main.py --datasource=drug --metatrain_iterations=20 --update_lr=0.005 --meta_lr=0.001 --num_updates=5 --test_num_updates=5 --trial=1 --drug_group=17 --sampling_method ATS --data_dir $datadir --train 0

For ATS, if you want the scheduler to take a simplified loss as its input, you can add --simple_loss to the commands above.

Owner
Huaxiu Yao
Postdoctoral Scholar