Nested-Co-teaching
(L2ID@CVPR 2021) PyTorch implementation of the paper "Boosting Co-teaching with Compression Regularization for Label Noise"
If our project is helpful for your research, please consider citing:
@inproceedings{chen2021boosting,
title={Boosting Co-teaching with Compression Regularization for Label Noise},
author={Chen, Yingyi and Shen, Xi and Hu, Shell Xu and Suykens, Johan AK},
booktitle={CVPR Learning from Limited and Imperfect Data (L2ID) workshop},
year={2021}
}
Our model can be trained on a single GeForce GTX 1080Ti GPU (12G). This code has been tested with PyTorch 1.7.1.
1. Toy Results
Nested regularization lets us learn an ordered representation, which is useful for combating noisy labels. In this toy example, we aim to learn a projection from X to Y from noisy pairs. With nested regularization, the most informative part of the reconstruction is stored in the first few channels.
Baseline, same MLP | Nested200, 1st channel |
---|---|
Nested200, first 10 channels | Nested200, first 100 channels |
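The idea of keeping only the leading channels can be sketched in a few lines of plain Python. This is a minimal illustration, not the code used in this repository: the function name `nested_dropout` is hypothetical, and drawing `k` uniformly is an assumption made for simplicity (the distribution actually used for training may differ).

```python
import random


def nested_dropout(features, rng):
    """Keep the first k channels of a feature vector and zero out the rest.

    k is drawn uniformly from 1..len(features); this sampling choice is an
    assumption made for illustration only.
    """
    k = rng.randint(1, len(features))  # number of leading channels to keep
    masked = [v if i < k else 0.0 for i, v in enumerate(features)]
    return masked, k


rng = random.Random(0)
features = [0.9, -0.4, 0.3, 0.1, -0.2]  # hypothetical 5-channel representation
masked, k = nested_dropout(features, rng)
print(k, masked)
```

Because later channels are dropped more often than earlier ones, the network is pushed to put the most useful information in the first channels, which is the ordering effect described above.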
2. Results on Clothing1M and Animal
Clothing1M: We report the average accuracy and the standard deviation (%) over three runs on the test set of Clothing1M [Xiao et al., 2015]. Results marked with "*" use either a balanced subset or a balanced loss.
Methods | Accuracy (%) | result_ref/download |
---|---|---|
CE | 67.2 | [Wei et al., 2020] |
F-correction [Patrini et al., 2017] | 68.9 | [Wei et al., 2020] |
Decoupling [Malach and Shalev-Shwartz, 2017] | 68.5 | [Wei et al., 2020] |
Co-teaching [Han et al., 2018] | 69.2 | [Wei et al., 2020] |
Co-teaching+ [Yu et al., 2019] | 59.3 | [Wei et al., 2020] |
JoCoR [Wei et al., 2020] | 70.3 | -- |
JO [Tanaka et al., 2018] | 72.2 | -- |
Dropout* [Srivastava et al., 2014] | 72.8 | -- |
PENCIL* [Yi and Wu, 2019] | 73.5 | -- |
MLNT [Li et al., 2019] | 73.5 | -- |
PLC* [Zhang et al., 2021] | 74.0 | -- |
DivideMix* [Li et al., 2020] | 74.8 | -- |
Nested* (Ours) | 73.1 ± 0.3 | model |
Nested + Co-teaching* (Ours) | 74.9 ± 0.2 | model |
ANIMAL-10N: We report the average accuracy and the standard deviation (%) over three runs on the test set of ANIMAL-10N [Song et al., 2019].
Methods | Accuracy (%) | result_ref/download |
---|---|---|
CE | 79.4 ± 0.1 | [Song et al., 2019] |
Dropout [Srivastava et al., 2014] | 81.3 ± 0.3 | -- |
SELFIE [Song et al., 2019] | 81.8 ± 0.1 | -- |
PLC [Zhang et al., 2021] | 83.4 ± 0.4 | -- |
Nested (Ours) | 81.3 ± 0.6 | model |
Nested + Co-teaching (Ours) | 84.1 ± 0.1 | model |
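For readers who want to report their own runs in the same "mean ± std" form, a minimal sketch follows. The accuracies below are hypothetical placeholders, not results from this project, and using the sample standard deviation (n - 1 denominator) is an assumption; the tables above do not state which variant was used.

```python
from statistics import mean, stdev


def summarize_runs(accuracies):
    """Format per-run accuracies (%) as 'mean ± std' with one decimal."""
    avg = mean(accuracies)
    std = stdev(accuracies)  # sample standard deviation; an assumption
    return f"{avg:.1f} ± {std:.1f}"


runs = [80.0, 81.0, 82.0]  # hypothetical accuracies from three runs
print(summarize_runs(runs))  # prints "81.0 ± 1.0"
```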
3. Datasets
Clothing1M
To download the Clothing1M dataset [Xiao et al., 2015], please refer to here. Once it is downloaded, put it into ./data/. The directory structure should be:
./data/Clothing1M
├── noisy_train
├── clean_val
└── clean_test
After unzipping, generate two random noisy subsets of Clothing1M for training:
cd data/
# generate two random subsets for training
python3 clothing1M_rand_subset.py --name noisy_rand_subtrain1 --data-dir ./Clothing1M/ --seed 123
python3 clothing1M_rand_subset.py --name noisy_rand_subtrain2 --data-dir ./Clothing1M/ --seed 321
Please refer to data/gen_data.sh for more details.
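The role of the `--seed` argument can be illustrated with a small sketch of seeded random sampling. This does not reproduce `clothing1M_rand_subset.py`, whose sampling strategy (for example, whether it balances classes) is not described here; the function `sample_subset`, the file names, and the subset size are all hypothetical.

```python
import random


def sample_subset(paths, subset_size, seed):
    """Draw a reproducible random subset of file paths for a given seed."""
    rng = random.Random(seed)  # a private generator keeps the draw reproducible
    return sorted(rng.sample(paths, subset_size))


# Hypothetical file list standing in for the noisy training images.
all_images = [f"img_{i:04d}.jpg" for i in range(1000)]

subset1 = sample_subset(all_images, 100, seed=123)
subset2 = sample_subset(all_images, 100, seed=321)
print(len(subset1), len(subset2))
```

Using two different seeds gives two different training subsets, one per network, while rerunning with the same seed yields the same subset again.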
ANIMAL-10N
To download the ANIMAL-10N dataset [Song et al., 2019], please refer to here. It includes one training set and one test set. Once it is downloaded, put it into ./data/. The directory structure should be:
./data/Animal10N/
├── train
└── test
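Before training, it can help to confirm that both datasets follow the layouts shown above. The check below is a small convenience sketch and not part of this repository; the helper name `missing_dirs` is hypothetical, and the expected folder names are copied from the two directory trees in this section.

```python
from pathlib import Path

# Expected sub-directories, taken from the directory trees in this section.
EXPECTED = {
    "Clothing1M": ["noisy_train", "clean_val", "clean_test"],
    "Animal10N": ["train", "test"],
}


def missing_dirs(data_root):
    """Return the expected dataset directories that do not exist yet."""
    root = Path(data_root)
    missing = []
    for dataset, subdirs in EXPECTED.items():
        for sub in subdirs:
            path = root / dataset / sub
            if not path.is_dir():
                missing.append(str(path))
    return missing


for path in missing_dirs("./data"):
    print(f"missing: {path}")
```

An empty output means every expected directory was found under ./data/.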
4. Train
4.1. Stage One : Training Nested Dropout Networks
We first train two Nested Dropout networks separately to provide reliable base networks for the subsequent stage. You can run this stage as follows:
- To train networks on Clothing1M (ResNet-18), run the commands below. You can also train baseline/dropout networks for comparison. More details are provided in nested/run_clothing1m.sh.
cd nested/
# train one Nested network
python3 train_resnet.py --train-dir ../data/Clothing1M/noisy_rand_subtrain1/ --val-dir ../data/Clothing1M/clean_val/ --dataset Clothing1M --arch resnet18 --lrSchedule 5 --lr 0.02 --nbEpoch 30 --batchsize 448 --nested 100 --pretrained --freeze-bn --out-dir ./checkpoints/Cloth1M_nested100_lr2e-2_bs448_freezeBN_imgnet_model1 --gpu 0
- To train networks on ANIMAL-10N (VGG-19+BN), run the commands below. You can also train baseline/dropout networks for comparison. More details are provided in nested/run_animal10n.sh.
cd nested/
python3 train_vgg.py --train-dir ../data/Animal10N/train/ --val-dir ../data/Animal10N/test/ --dataset Animal10N --arch vgg19-bn --lr-gamma 0.2 --batchsize 128 --warmUpIter 6000 --nested1 100 --nested2 100 --alter-train --out-dir ./checkpoints_animal10n/Animal10N_alter_nested100_100_vgg19bn_lr0.1_warm6000_bs128_model1 --gpu 0
4.2. Stage Two : Fine-tuning with Co-teaching
In this stage, the two trained networks are further fine-tuned with Co-teaching. You can run this stage as follows:
- For fine-tuning with Co-teaching on Clothing1M (ResNet-18):
cd co_teaching_resnet/
python3 main.py --train-dir ../data/Clothing1M/noisy_rand_subtrain1/ --val-dir ../data/Clothing1M/clean_val/ --dataset Clothing1M --lrSchedule 5 --nGradual 0 --lr 0.002 --nbEpoch 30 --warmUpIter 0 --batchsize 448 --freeze-bn --forgetRate 0.3 --out-dir ./finetune_ckpt/Cloth1M_nested100_lr2e-3_bs448_freezeBN_fgr0.3_pre_nested100_100 --resumePthList ../nested/checkpoints/Cloth1M_nested100_lr2e-2_bs448_imgnet_freezeBN_model1_Acc0.735_K12 ../nested/checkpoints/Cloth1M_nested100_lr2e-2_bs448_imgnet_freezeBN_model2_Acc0.733_K15 --nested 100 --gpu 0
The two Nested ResNet-18 networks trained in stage one can be downloaded here: ckpt1, ckpt2. For comparison, we also provide commands for training Co-teaching from scratch in co_teaching_resnet/run_clothing1m.sh.
- For fine-tuning with Co-teaching on ANIMAL-10N (VGG-19+BN):
cd co_teaching_vgg/
python3 main.py --train-dir ../data/Animal10N/train/ --val-dir ../data/Animal10N/test/ --dataset Animal10N --arch vgg19-bn --lrSchedule 5 --nGradual 0 --lr 0.004 --nbEpoch 30 --warmUpIter 0 --batchsize 128 --freeze-bn --forgetRate 0.2 --out-dir ./finetune_ckpt/Animal10N_alter_nested100_lr4e-3_bs128_freezeBN_fgr0.2_pre_nested100_100_nested100_100 --resumePthList ../nested/checkpoints_animal10n/new_code_nested/Animal10N_alter_nested100_100_vgg19bn_lr0.1_warm6000_bs128_model1_Acc0.803_K14 ../nested/checkpoints_animal10n/new_code_nested/Animal10N_alter_nested100_100_vgg19bn_lr0.1_warm6000_bs128_model2_Acc0.811_K14 --nested1 100 --nested2 100 --alter-train --gpu 0
The two Nested VGG-19+BN networks trained in stage one can be downloaded here: ckpt1, ckpt2. For comparison, we also provide commands for training Co-teaching from scratch in co_teaching_vgg/run_animal10n.sh.
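The core selection step of Co-teaching [Han et al., 2018] can be sketched as follows: each network keeps the samples with the smallest loss in a mini-batch and hands them to its peer for the update. This is a simplified pure-Python illustration, not the code in co_teaching_resnet/ or co_teaching_vgg/. The function names and loss values are hypothetical, and treating `forget_rate` as the counterpart of the `--forgetRate` flag is an assumption.

```python
def small_loss_indices(losses, forget_rate):
    """Indices of the (1 - forget_rate) fraction of samples with smallest loss."""
    num_keep = round((1.0 - forget_rate) * len(losses))
    order = sorted(range(len(losses)), key=lambda i: losses[i])
    return order[:num_keep]


def co_teaching_exchange(losses_a, losses_b, forget_rate):
    """Each network selects small-loss samples for its peer to train on."""
    train_b_on = small_loss_indices(losses_a, forget_rate)  # chosen by network A
    train_a_on = small_loss_indices(losses_b, forget_rate)  # chosen by network B
    return train_a_on, train_b_on


# Hypothetical per-sample losses of the two networks on a mini-batch of 10.
losses_a = [0.1, 2.3, 0.4, 0.2, 1.9, 0.3, 0.5, 3.1, 0.6, 0.7]
losses_b = [0.2, 0.3, 2.8, 0.1, 0.4, 2.2, 0.5, 0.6, 3.0, 0.7]

train_a_on, train_b_on = co_teaching_exchange(losses_a, losses_b, forget_rate=0.3)
print(sorted(train_a_on))  # samples network A is updated on
print(sorted(train_b_on))  # samples network B is updated on
```

With a forget rate of 0.3, each network discards the 30% of the batch that its peer considers most likely to be mislabeled (the largest-loss samples).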
5. Evaluation
To evaluate the models' ability to combat label noise, we compute classification accuracy on the provided clean test set.
5.1. Stage One : Nested Dropout Networks
Evaluation commands for networks from stage one are provided below:
cd nested/
# for networks on Clothing1M
python3 test.py --test-dir ../data/Clothing1M/clean_test/ --dataset Clothing1M --arch resnet18 --resumePthList ./checkpoints/Cloth1M_nested100_lr2e-2_bs448_imgnet_freezeBN_model1_Acc0.735_K12 --KList 12 --gpu 0
More details can be found in nested/run_test.sh. Note that "_K12" in the model's name denotes the index of the optimal K, so the optimal number of channels for the model is actually 13 (number of optimal channels = channel index + 1).
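The index-to-count conversion described in the note can be made explicit with a small helper. This is an illustrative sketch only: `optimal_channels` is a hypothetical function, and it assumes the checkpoint name ends with a `_K<index>` suffix as in the examples in this README.

```python
import re


def optimal_channels(checkpoint_name):
    """Return the number of channels implied by a trailing '_K<index>' suffix."""
    match = re.search(r"_K(\d+)$", checkpoint_name)
    if match is None:
        raise ValueError(f"no _K<index> suffix in {checkpoint_name!r}")
    k_index = int(match.group(1))  # zero-based index of the optimal channel
    return k_index + 1  # number of channels = index + 1


name = "Cloth1M_nested100_lr2e-2_bs448_imgnet_freezeBN_model1_Acc0.735_K12"
print(optimal_channels(name))  # prints 13
```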
5.2. Stage Two : Fine-tuning Co-teaching Networks
Evaluation commands for networks from stage two are provided below.
- Networks trained on Clothing1M:
cd co_teaching_resnet/
python3 test.py --test-dir ../data/Clothing1M/clean_test/ --dataset Clothing1M --arch resnet18 --resumePthList ./finetune_ckpt/Cloth1M_nested100_lr2e-3_bs448_freezeBN_fgr0.3_pre_nested100_100_model2_Acc0.749_K24 --KList 24 --gpu 0
More details can be found in co_teaching_resnet/run_test.sh.
- Networks trained on ANIMAL-10N:
cd co_teaching_vgg/
python3 test.py --test-dir ../data/Animal10N/test/ --dataset Animal10N --resumePthList ./finetune_ckpt/Animal10N_nested100_lr4e-3_bs128_freezeBN_fgr0.2_pre_nested100_100_nested100_100_model1_Acc0.842_K12 --KList 12 --gpu 0
More details can be found in co_teaching_vgg/run_test.sh.