ICCV2021, Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

Overview

Update:

2021/03/11: Updated results. Our T2T-ViT-14 with 21.5M parameters now reaches 81.5% top-1 accuracy at 224x224 image resolution and 83.3% top-1 accuracy at 384x384 resolution.

2021/02/21: T2T-ViT can be trained stably with '--amp' (Automatic Mixed Precision) on most common GPUs: 1080Ti, 2080Ti, TITAN V, V100. On some specific GPUs such as the Tesla T4, 'amp' can cause NaN loss when training T2T-ViT. If you get NaN loss during training, you can disable amp by removing '--amp' from the training scripts.

2021/01/28: Released the code and uploaded most of the pretrained T2T-ViT models.

Reference

If you find this repo useful, please consider citing:

@InProceedings{Yuan_2021_ICCV,
    author    = {Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Jiang, Zi-Hang and Tay, Francis E.H. and Feng, Jiashi and Yan, Shuicheng},
    title     = {Tokens-to-Token ViT: Training Vision Transformers From Scratch on ImageNet},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {558-567}
}

Our code is based on the official ImageNet example by PyTorch and pytorch-image-models by Ross Wightman.

1. Requirements

timm (pip install timm==0.3.4)

torch>=1.4.0

torchvision>=0.5.0

pyyaml

Data preparation: ImageNet with the following folder structure; you can extract ImageNet with this script.

│imagenet/
├──train/
│  ├── n01440764
│  │   ├── n01440764_10026.JPEG
│  │   ├── n01440764_10027.JPEG
│  │   ├── ......
│  ├── ......
├──val/
│  ├── n01440764
│  │   ├── ILSVRC2012_val_00000293.JPEG
│  │   ├── ILSVRC2012_val_00002138.JPEG
│  │   ├── ......
│  ├── ......
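
Before launching a long run, it can help to confirm that the dataset folder matches the layout above. The following is a minimal sketch using only the Python standard library; the function name `summarize_imagenet_layout` is hypothetical and not part of this repository.

```python
from pathlib import Path


def summarize_imagenet_layout(root):
    """Count class folders and .JPEG files under root/train and root/val.

    Hypothetical helper: it only checks the folder layout shown above and
    does not open or validate the images themselves.
    """
    summary = {}
    for split in ("train", "val"):
        split_dir = Path(root) / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        # Each class is a sub-folder named by its WordNet ID, e.g. n01440764.
        class_dirs = sorted(p for p in split_dir.iterdir() if p.is_dir())
        n_images = sum(
            1
            for class_dir in class_dirs
            for f in class_dir.iterdir()
            if f.suffix.upper() == ".JPEG"
        )
        summary[split] = {"classes": len(class_dirs), "images": n_images}
    return summary
```

For a complete ImageNet-1k copy one would expect 1000 class folders in each split; a different count suggests an incomplete extraction.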

2. T2T-ViT Models

Model                       T2T Transformer  Top1 Acc  #params  MACs   Download
T2T-ViT-14                  Performer        81.5      21.5M    4.8G   here
T2T-ViT-19                  Performer        81.9      39.2M    8.5G   here
T2T-ViT-24                  Performer        82.3      64.1M    13.8G  here
T2T-ViT-14, 384             Performer        83.3      21.7M    -      here
T2T-ViT-24, Token Labeling  Performer        84.2      65M      -      here
T2T-ViT_t-14                Transformer      81.7      21.5M    6.1G   here
T2T-ViT_t-19                Transformer      82.4      39.2M    9.8G   here
T2T-ViT_t-24                Transformer      82.6      64.1M    15.0G  here

The 'T2T-ViT-14, 384' means we train T2T-ViT-14 with image size of 384 x 384.

The 'T2T-ViT-24, Token Labeling' means we train T2T-ViT-24 with Token Labeling.

The three lite variants of T2T-ViT (compared with MobileNets):

Model       T2T Transformer  Top1 Acc  #params  MACs  Download
T2T-ViT-7   Performer        71.7      4.3M     1.1G  here
T2T-ViT-10  Performer        75.2      5.9M     1.5G  here
T2T-ViT-12  Performer        76.5      6.9M     1.8G  here

Usage

The way to use our pretrained T2T-ViT:

from models.t2t_vit import *
from utils import load_for_transfer_learning

# create model
model = t2t_vit_14()

# load the pretrained weights; change num_classes to match your dataset.
# Other image sizes also work because the position embedding is interpolated.
load_for_transfer_learning(model, '/path/to/pretrained/weights', use_ema=True, strict=False, num_classes=1000)

3. Validation

Test T2T-ViT-14 (with Performer in the T2T module):

Download T2T-ViT-14, then test it by running:

CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model t2t_vit_14 -b 100 --eval_checkpoint path/to/checkpoint

The results look like:

Test: [   0/499]  Time: 2.083 (2.083)  Loss:  0.3578 (0.3578)  Acc@1: 96.0000 (96.0000)  Acc@5: 99.0000 (99.0000)
Test: [  50/499]  Time: 0.166 (0.202)  Loss:  0.5823 (0.6404)  Acc@1: 85.0000 (86.1569)  Acc@5: 99.0000 (97.5098)
...
Test: [ 499/499]  Time: 0.272 (0.172)  Loss:  1.3983 (0.8261)  Acc@1: 62.0000 (81.5000)  Acc@5: 93.0000 (95.6660)
Top-1 accuracy of the model is: 81.5%

Test the three lite variants, T2T-ViT-7, T2T-ViT-10 and T2T-ViT-12 (with Performer in the T2T module):

Download T2T-ViT-7, T2T-ViT-10 or T2T-ViT-12, then test it by running:

CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model t2t_vit_7 -b 100 --eval_checkpoint path/to/checkpoint

Test the model T2T-ViT-14, 384 with 83.3% top-1 accuracy:

CUDA_VISIBLE_DEVICES=0 python main.py path/to/data --model t2t_vit_14 --img-size 384 -b 100 --eval_checkpoint path/to/T2T-ViT-14-384 

4. Train

Train the three lite variants, T2T-ViT-7, T2T-ViT-10 and T2T-ViT-12 (with Performer in the T2T module):

If only 4 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 path/to/data --model t2t_vit_7 -b 128 --lr 1e-3 --weight-decay .03 --amp --img-size 224

The top-1 accuracy with 4 GPUs would be slightly lower than with 8 GPUs (around 0.1%-0.3% lower).

If 8 GPUs are available:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model t2t_vit_7 -b 64 --lr 1e-3 --weight-decay .03 --amp --img-size 224

Train the T2T-ViT-14 and T2T-ViT_t-14 (run on 4 or 8 GPUs):

CUDA_VISIBLE_DEVICES=0,1,2,3 ./distributed_train.sh 4 path/to/data --model t2t_vit_14 -b 128 --lr 1e-3 --weight-decay .05 --amp --img-size 224
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model t2t_vit_14 -b 64 --lr 5e-4 --weight-decay .05 --amp --img-size 224
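
The 4-GPU and 8-GPU commands above use different per-GPU batch sizes (`-b 128` and `-b 64`), which works out to the same global batch size. The sketch below shows the arithmetic; the function names are hypothetical, and the steps-per-epoch figure assumes the standard ImageNet-1k training set of 1,281,167 images with an incomplete final batch dropped.

```python
def global_batch_size(num_gpus, per_gpu_batch):
    """Total number of images processed per optimizer step across all GPUs."""
    return num_gpus * per_gpu_batch


def steps_per_epoch(num_images, num_gpus, per_gpu_batch):
    """Steps per epoch, assuming an incomplete final batch is dropped."""
    return num_images // global_batch_size(num_gpus, per_gpu_batch)


# Assumption: standard ImageNet-1k training split size.
IMAGENET_TRAIN_IMAGES = 1_281_167

print(global_batch_size(4, 128))   # 4 GPUs, -b 128
print(global_batch_size(8, 64))    # 8 GPUs, -b 64
print(steps_per_epoch(IMAGENET_TRAIN_IMAGES, 8, 64))
```

Under these assumptions both setups process 512 images per step, and an epoch has 2502 steps, which is consistent with the `/2502` counter in the training log quoted in the comments on this page.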

If you want to train our T2T-ViT on images with 384x384 resolution, please use '--img-size 384'.

Train the T2T-ViT-19, T2T-ViT-24 or T2T-ViT_t-19, T2T-ViT_t-24:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model t2t_vit_19 -b 64 --lr 5e-4 --weight-decay .065 --amp --img-size 224

5. Transfer T2T-ViT to CIFAR10/CIFAR100

Model       ImageNet  CIFAR10  CIFAR100  #params
T2T-ViT-14  81.5      98.3     88.4      21.5M
T2T-ViT-19  81.9      98.4     89.0      39.2M

We resize CIFAR10/100 to 224x224 and finetune our pretrained T2T-ViT-14/19 to CIFAR10/100 by running:

CUDA_VISIBLE_DEVICES=0,1 python transfer_learning.py --lr 0.05 --b 64 --num-classes 10 --img-size 224 --transfer-learning True --transfer-model /path/to/pretrained/T2T-ViT-19

6. Visualization

To visualize the image features of ResNet50, open and run the visualization_resnet.ipynb file in Jupyter Notebook or JupyterLab.

To visualize the image features of ViT, open and run the visualization_vit.ipynb file in Jupyter Notebook or JupyterLab.

To visualize attention maps, refer to this file. A simple example visualizes the attention maps in attention blocks 4 and 5.

Comments
  • Nan during training even without '--amp'

    Hello! I would like to train the T2t_vit_14 model on the ImageNet-100 dataset with 3 Quadro RTX 5000 GPUs, but I get NaN in the loss, followed by the error below. Could you please help me run the code?

    I ran the following command: CUDA_VISIBLE_DEVICES=1,2,6,7 bash distributed_train.sh 4 /data/datasets/imagenet-100/ --model T2t_vit_14 -b 128 --lr 1e-3 --weight-decay .03 --cutmix 0.0 --reprob 0.25 --img-size 224

    Some printout:

    200,4.4025687376658125,4.13563300743103,7.760000015258789,24.43999983520508
    201,4.374216079711914,4.1354576759338375,7.639999951171875,24.200000134277342
    202,4.392171382904053,4.136218957519532,7.7399999694824215,24.439999853515626
    203,4.384297768274943,4.140018928909302,7.659999923706055,24.29999981689453
    204,4.371897141138713,4.14544691696167,7.619999905395508,23.91999990234375
    205,4.374680519104004,4.1505038471221924,7.7600000366210935,23.93999978027344
    206,4.359750032424927,4.154387146377563,7.619999943542481,24.040000201416017
    207,4.37085485458374,4.158743778991699,7.820000009155273,24.019999743652345
    208,4.367704391479492,4.161868629837036,7.62000002746582,23.86000018310547
    209,nan,4.160795520019532,7.7200000274658205,24.01999976196289
    210,nan,nan,1.0,5.0
    211,nan,nan,1.0,5.0
    212,nan,nan,1.0,5.0
    213,nan,nan,1.0,5.0

    Error:

      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
        optimizer._post_amp_backward(loss_scaler)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights
        post_backward_models_are_masters(scaler, params, stashed_grads)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters
        scale_override=(grads_have_scale, stashed_have_scale, out_scale))
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/scaler.py", line 183, in unscale_with_stashed
        out_scale/grads_have_scale,
    ZeroDivisionError: float division by zero
    Traceback (most recent call last):
      File "main.py", line 764, in <module>
        main()
      File "main.py", line 560, in main
        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
      File "main.py", line 637, in train_epoch
        loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/timm/utils/cuda.py", line 20, in __call__
        scaled_loss.backward(create_graph=create_graph)
      File "/usr/lib/python3.6/contextlib.py", line 88, in __exit__
        next(self.gen)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/handle.py", line 123, in scale_loss
        optimizer._post_amp_backward(loss_scaler)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 249, in post_backward_no_master_weights
        post_backward_models_are_masters(scaler, params, stashed_grads)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 135, in post_backward_models_are_masters
        scale_override=(grads_have_scale, stashed_have_scale, out_scale))
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/apex/amp/scaler.py", line 183, in unscale_with_stashed
        out_scale/grads_have_scale,
    ZeroDivisionError: float division by zero
    Traceback (most recent call last):
      File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/torch/distributed/launch.py", line 261, in <module>
        main()
      File "/home/ekrivosheev/cv_env/lib/python3.6/site-packages/torch/distributed/launch.py", line 257, in main
        cmd=cmd)
    subprocess.CalledProcessError: Command '['/home/ekrivosheev/cv_env/bin/python', '-u', 'main.py', '--local_rank=3', '/data/datasets/imagenet-100/', '--model', 'T2t_vit_14', '-b', '128', '--lr', '1e-3', '--weight-decay', '.03', '--cutmix', '0.0', '--reprob', '0.25', '--img-size', '224']' returned non-zero exit status 1.

    opened by Evgeneus 8
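
The per-epoch printout in the issue above can be scanned automatically for the first epoch whose loss becomes NaN. This is a minimal sketch; the column names are taken from the header shown in another issue on this page (`epoch,train_loss,eval_loss,eval_top1,eval_top5`) and are assumed to apply here, and `first_nan_epoch` is a hypothetical helper.

```python
import csv
import io
import math

# Assumed header; the printout in the issue does not include one.
HEADER = "epoch,train_loss,eval_loss,eval_top1,eval_top5"


def first_nan_epoch(csv_text):
    """Return the first epoch whose train or eval loss is NaN, else None."""
    reader = csv.DictReader(io.StringIO(csv_text))
    for row in reader:
        train_loss = float(row["train_loss"])  # float("nan") parses fine
        eval_loss = float(row["eval_loss"])
        if math.isnan(train_loss) or math.isnan(eval_loss):
            return int(row["epoch"])
    return None


# Three rows copied from the printout above.
sample = "\n".join([
    HEADER,
    "208,4.367704391479492,4.161868629837036,7.62000002746582,23.86000018310547",
    "209,nan,4.160795520019532,7.7200000274658205,24.01999976196289",
    "210,nan,nan,1.0,5.0",
])
print(first_nan_epoch(sample))
```

On these rows the train loss first becomes NaN at epoch 209, one epoch before the eval loss does.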
  • NAN Loss for provided model

    I trained the model with the following two scripts. Both result in NaN loss after 1 epoch of training. Any thoughts on how to address this issue?

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model T2t_vit_7 -b 64 --lr 1e-3 --weight-decay .03 --cutmix 0.0 --reprob 0.25 --img-size 224

    CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 ./distributed_train.sh 8 path/to/data --model T2t_vit_14 -b 64 --lr 5e-4 --weight-decay .05 --img-size 224

    Training in distributed mode with multiple processes, 1 GPU per process. Process 0, total 8.
    Training in distributed mode with multiple processes, 1 GPU per process. Process 6, total 8.
    Training in distributed mode with multiple processes, 1 GPU per process. Process 7, total 8.
    Training in distributed mode with multiple processes, 1 GPU per process. Process 3, total 8.
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    adopt performer encoder for tokens-to-token
    Training in distributed mode with multiple processes, 1 GPU per process. Process 1, total 8.
    adopt performer encoder for tokens-to-token
    Model T2t_vit_14 created, param count: 21545550
    Data processing configuration for current model + dataset:
    input_size: (3, 224, 224)
    interpolation: bicubic
    mean: (0.485, 0.456, 0.406)
    std: (0.229, 0.224, 0.225)
    crop_pct: 0.9
    Using native Torch AMP. Training in mixed precision.
    Using native Torch DistributedDataParallel.
    Scheduled epochs: 310
    Train: 0 [ 0/2502 ( 0%)] Loss: 7.023479 (7.0235) Time: 3.680s, 139.14/s (3.680s, 139.14/s) LR: 1.000e-06 Data: 1.776 (1.776)
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Reducer buckets have been rebuilt in this iteration.
    Train: 0 [ 50/2502 ( 2%)] Loss: 6.971423 (6.9975) Time: 0.323s, 1586.02/s (0.385s, 1330.47/s) LR: 1.000e-06 Data: 0.006 (0.041)
    Train: 0 [ 100/2502 ( 4%)] Loss: 6.978786 (6.9912) Time: 0.305s, 1679.64/s (0.351s, 1457.64/s) LR: 1.000e-06 Data: 0.006 (0.024)
    Train: 0 [ 150/2502 ( 6%)] Loss: 6.975621 (6.9873) Time: 0.300s, 1705.67/s (0.340s, 1507.75/s) LR: 1.000e-06 Data: 0.005 (0.018)
    Train: 0 [ 200/2502 ( 8%)] Loss: 6.966157 (6.9831) Time: 0.360s, 1422.92/s (0.334s, 1530.97/s) LR: 1.000e-06 Data: 0.006 (0.015)
    Train: 0 [ 250/2502 ( 10%)] Loss: 6.980019 (6.9826) Time: 0.309s, 1657.73/s (0.331s, 1545.27/s) LR: 1.000e-06 Data: 0.005 (0.013)
    Train: 0 [ 300/2502 ( 12%)] Loss: 6.964942 (6.9801) Time: 0.327s, 1565.87/s (0.329s, 1556.59/s) LR: 1.000e-06 Data: 0.006 (0.012)
    Train: 0 [ 350/2502 ( 14%)] Loss: 6.957265 (6.9772) Time: 0.332s, 1541.96/s (0.327s, 1563.37/s) LR: 1.000e-06 Data: 0.005 (0.011)
    Train: 0 [ 400/2502 ( 16%)] Loss: 6.953742 (6.9746) Time: 0.318s, 1609.71/s (0.326s, 1570.11/s) LR: 1.000e-06 Data: 0.006 (0.011)
    Train: 0 [ 450/2502 ( 18%)] Loss: 6.967467 (6.9739) Time: 0.309s, 1658.46/s (0.325s, 1573.87/s) LR: 1.000e-06 Data: 0.007 (0.010)
    Train: 0 [ 500/2502 ( 20%)] Loss: 6.970360 (6.9736) Time: 0.322s, 1590.08/s (0.325s, 1577.36/s) LR: 1.000e-06 Data: 0.007 (0.010)
    Train: 0 [ 550/2502 ( 22%)] Loss: 6.931087 (6.9700) Time: 0.313s, 1637.96/s (0.324s, 1579.20/s) LR: 1.000e-06 Data: 0.005 (0.009)
    Train: 0 [ 600/2502 ( 24%)] Loss: 6.939621 (6.9677) Time: 0.329s, 1555.19/s (0.324s, 1580.93/s) LR: 1.000e-06 Data: 0.007 (0.009)
    Train: 0 [ 650/2502 ( 26%)] Loss: 6.943333 (6.9660) Time: 0.318s, 1607.70/s (0.324s, 1582.42/s) LR: 1.000e-06 Data: 0.005 (0.009)
    Train: 0 [ 700/2502 ( 28%)] Loss: 6.940698 (6.9643) Time: 0.316s, 1621.93/s (0.323s, 1584.56/s) LR: 1.000e-06 Data: 0.006 (0.009)
    Train: 0 [ 750/2502 ( 30%)] Loss: 6.941026 (6.9628) Time: 0.323s, 1584.28/s (0.323s, 1586.07/s) LR: 1.000e-06 Data: 0.006 (0.008)
    Train: 0 [ 800/2502 ( 32%)] Loss: 6.936088 (6.9612) Time: 0.310s, 1649.05/s (0.323s, 1587.13/s) LR: 1.000e-06 Data: 0.006 (0.008)
    Train: 0 [ 850/2502 ( 34%)] Loss: 6.931849 (6.9596) Time: 0.308s, 1662.24/s (0.322s, 1588.20/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [ 900/2502 ( 36%)] Loss: 6.947849 (6.9590) Time: 0.320s, 1599.60/s (0.322s, 1589.06/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [ 950/2502 ( 38%)] Loss: 6.928242 (6.9575) Time: 0.308s, 1659.89/s (0.322s, 1590.35/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [1000/2502 ( 40%)] Loss: 6.926805 (6.9560) Time: 0.310s, 1649.80/s (0.322s, 1591.55/s) LR: 1.000e-06 Data: 0.006 (0.008)
    Train: 0 [1050/2502 ( 42%)] Loss: 6.950564 (6.9557) Time: 0.308s, 1660.43/s (0.322s, 1592.16/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [1100/2502 ( 44%)] Loss: 6.930144 (6.9546) Time: 0.300s, 1707.17/s (0.321s, 1593.30/s) LR: 1.000e-06 Data: 0.005 (0.008)
    Train: 0 [1150/2502 ( 46%)] Loss: 6.919596 (6.9532) Time: 0.331s, 1547.59/s (0.321s, 1593.54/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1200/2502 ( 48%)] Loss: 6.922656 (6.9520) Time: 0.310s, 1652.26/s (0.321s, 1594.28/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1250/2502 ( 50%)] Loss: 6.919957 (6.9507) Time: 0.311s, 1645.52/s (0.321s, 1595.21/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1300/2502 ( 52%)] Loss: 6.930165 (6.9500) Time: 0.333s, 1539.73/s (0.321s, 1595.62/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1350/2502 ( 54%)] Loss: 6.918827 (6.9488) Time: 0.331s, 1544.88/s (0.321s, 1596.13/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1400/2502 ( 56%)] Loss: 6.923580 (6.9480) Time: 0.311s, 1644.41/s (0.321s, 1596.67/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1450/2502 ( 58%)] Loss: 6.924307 (6.9472) Time: 0.333s, 1538.95/s (0.321s, 1597.32/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1500/2502 ( 60%)] Loss: 6.909927 (6.9460) Time: 0.309s, 1659.58/s (0.320s, 1597.74/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1550/2502 ( 62%)] Loss: 6.924455 (6.9453) Time: 0.339s, 1512.00/s (0.320s, 1598.03/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1600/2502 ( 64%)] Loss: 6.931414 (6.9449) Time: 0.315s, 1623.24/s (0.320s, 1598.55/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1650/2502 ( 66%)] Loss: 6.916759 (6.9441) Time: 0.332s, 1542.18/s (0.320s, 1599.07/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1700/2502 ( 68%)] Loss: 6.941891 (6.9440) Time: 0.314s, 1632.83/s (0.320s, 1599.53/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [1750/2502 ( 70%)] Loss: 6.922241 (6.9434) Time: 0.312s, 1640.83/s (0.320s, 1599.91/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1800/2502 ( 72%)] Loss: 6.918221 (6.9427) Time: 0.315s, 1625.92/s (0.320s, 1600.40/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1850/2502 ( 74%)] Loss: 6.903537 (6.9417) Time: 0.322s, 1587.80/s (0.320s, 1600.59/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1900/2502 ( 76%)] Loss: 6.934650 (6.9415) Time: 0.315s, 1623.17/s (0.320s, 1601.00/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [1950/2502 ( 78%)] Loss: 6.916628 (6.9409) Time: 0.315s, 1625.91/s (0.320s, 1601.38/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2000/2502 ( 80%)] Loss: 6.907085 (6.9401) Time: 0.302s, 1695.00/s (0.320s, 1601.57/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2050/2502 ( 82%)] Loss: 6.915219 (6.9395) Time: 0.331s, 1547.05/s (0.320s, 1601.70/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2100/2502 ( 84%)] Loss: 6.920197 (6.9390) Time: 0.337s, 1520.82/s (0.320s, 1601.97/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2150/2502 ( 86%)] Loss: 6.924037 (6.9387) Time: 0.325s, 1574.30/s (0.320s, 1602.26/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2200/2502 ( 88%)] Loss: 6.920416 (6.9383) Time: 0.300s, 1705.11/s (0.319s, 1602.63/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2250/2502 ( 90%)] Loss: 6.898316 (6.9374) Time: 0.310s, 1649.44/s (0.319s, 1602.97/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2300/2502 ( 92%)] Loss: 6.924686 (6.9371) Time: 0.309s, 1655.87/s (0.319s, 1602.88/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2350/2502 ( 94%)] Loss: 6.907205 (6.9365) Time: 0.326s, 1572.94/s (0.319s, 1602.90/s) LR: 1.000e-06 Data: 0.005 (0.007)
    /home/shawn/anaconda3/envs/deit/lib/python3.8/site-packages/PIL/TiffImagePlugin.py:788: UserWarning: Corrupt EXIF data. Expecting to read 4 bytes but only got 0.
    warnings.warn(str(msg))
    Train: 0 [2400/2502 ( 96%)] Loss: 6.908824 (6.9359) Time: 0.310s, 1652.27/s (0.319s, 1603.15/s) LR: 1.000e-06 Data: 0.006 (0.007)
    Train: 0 [2450/2502 ( 98%)] Loss: 6.911987 (6.9355) Time: 0.317s, 1615.97/s (0.319s, 1603.37/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2500/2502 (100%)] Loss: 6.918730 (6.9351) Time: 0.312s, 1641.96/s (0.319s, 1603.78/s) LR: 1.000e-06 Data: 0.005 (0.007)
    Train: 0 [2501/2502 (100%)] Loss: 6.918357 (6.9348) Time: 0.644s, 795.44/s (0.319s, 1603.13/s) LR: 1.000e-06 Data: 0.344 (0.007)
    Test: [ 0/97] Time: 1.865 (1.865) Loss: 6.8164 (6.8164) Acc@1: 0.0000 ( 0.0000) Acc@5: 0.0000 ( 0.0000)
    Test: [ 50/97] Time: 0.100 (0.192) Loss: 6.8828 (6.8914) Acc@1: 0.0000 ( 0.0613) Acc@5: 0.0000 ( 0.5859)
    Test: [ 97/97] Time: 0.220 (0.162) Loss: 6.7188 (6.8880) Acc@1: 0.0000 ( 0.1820) Acc@5: 0.0000 ( 0.9180)
    Test (EMA): [ 0/97] Time: 2.051 (2.051) Loss: 7.0312 (7.0312) Acc@1: 0.0000 ( 0.0000) Acc@5: 1.1719 ( 1.1719)
    Test (EMA): [ 50/97] Time: 0.109 (0.193) Loss: 6.9570 (6.9737) Acc@1: 0.0000 ( 0.1072) Acc@5: 0.0000 ( 0.5093)
    Test (EMA): [ 97/97] Time: 0.224 (0.163) Loss: 7.0273 (6.9708) Acc@1: 0.0000 ( 0.0900) Acc@5: 0.0000 ( 0.5080)
    Current checkpoints: ('./output/train/20210219-222319-T2t_vit_14-224/checkpoint-0.pth.tar', 0.09)

    Train: 1 [ 0/2502 ( 0%)] Loss: 6.897799 (6.8978) Time: 2.695s, 189.97/s (2.695s, 189.97/s) LR: 1.673e-04 Data: 2.323 (2.323)
    Train: 1 [ 50/2502 ( 2%)] Loss: nan ( nan) Time: 0.279s, 1834.73/s (0.337s, 1518.12/s) LR: 1.673e-04 Data: 0.005 (0.051)
    Train: 1 [ 100/2502 ( 4%)] Loss: nan ( nan) Time: 0.276s, 1857.70/s (0.309s, 1655.29/s) LR: 1.673e-04 Data: 0.006 (0.029)
    Train: 1 [ 150/2502 ( 6%)] Loss: nan ( nan) Time: 0.289s, 1773.38/s (0.300s, 1705.98/s) LR: 1.673e-04 Data: 0.007 (0.021)
    Train: 1 [ 200/2502 ( 8%)] Loss: nan ( nan) Time: 0.273s, 1877.76/s (0.295s, 1733.59/s) LR: 1.673e-04 Data: 0.005 (0.018)
    Train: 1 [ 250/2502 ( 10%)] Loss: nan ( nan) Time: 0.268s, 1912.76/s (0.292s, 1752.17/s) LR: 1.673e-04 Data: 0.005 (0.015)
    Train: 1 [ 300/2502 ( 12%)] Loss: nan ( nan) Time: 0.285s, 1793.85/s (0.290s, 1764.29/s) LR: 1.673e-04 Data: 0.005 (0.014)
    Train: 1 [ 350/2502 ( 14%)] Loss: nan ( nan) Time: 0.281s, 1819.69/s (0.289s, 1769.46/s) LR: 1.673e-04 Data: 0.006 (0.013)
    Train: 1 [ 400/2502 ( 16%)] Loss: nan ( nan) Time: 0.268s, 1908.61/s (0.290s, 1767.59/s) LR: 1.673e-04 Data: 0.005 (0.012)
    Train: 1 [ 450/2502 ( 18%)] Loss: nan ( nan) Time: 0.287s, 1783.58/s (0.289s, 1773.71/s) LR: 1.673e-04 Data: 0.006 (0.011)
    Train: 1 [ 500/2502 ( 20%)] Loss: nan ( nan) Time: 0.285s, 1796.56/s (0.288s, 1778.22/s) LR: 1.673e-04 Data: 0.005 (0.011)
    Train: 1 [ 550/2502 ( 22%)] Loss: nan ( nan) Time: 0.280s, 1825.68/s (0.287s, 1781.91/s) LR: 1.673e-04 Data: 0.005 (0.010)
    Train: 1 [ 600/2502 ( 24%)] Loss: nan ( nan) Time: 0.275s, 1859.97/s (0.287s, 1785.50/s) LR: 1.673e-04 Data: 0.009 (0.010)
    Train: 1 [ 650/2502 ( 26%)] Loss: nan ( nan) Time: 0.278s, 1841.99/s (0.286s, 1788.40/s) LR: 1.673e-04 Data: 0.005 (0.010)
    Train: 1 [ 700/2502 ( 28%)] Loss: nan ( nan) Time: 0.275s, 1860.43/s (0.286s, 1790.68/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 750/2502 ( 30%)] Loss: nan ( nan) Time: 0.287s, 1784.59/s (0.286s, 1792.93/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 800/2502 ( 32%)] Loss: nan ( nan) Time: 0.277s, 1848.72/s (0.285s, 1794.68/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 850/2502 ( 34%)] Loss: nan ( nan) Time: 0.286s, 1792.44/s (0.285s, 1795.76/s) LR: 1.673e-04 Data: 0.006 (0.009)
    Train: 1 [ 900/2502 ( 36%)] Loss: nan ( nan) Time: 0.279s, 1833.06/s (0.285s, 1795.15/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [ 950/2502 ( 38%)] Loss: nan ( nan) Time: 0.277s, 1847.88/s (0.285s, 1795.23/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1000/2502 ( 40%)] Loss: nan ( nan) Time: 0.286s, 1789.41/s (0.285s, 1796.69/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1050/2502 ( 42%)] Loss: nan ( nan) Time: 0.277s, 1848.11/s (0.285s, 1798.21/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1100/2502 ( 44%)] Loss: nan ( nan) Time: 0.284s, 1799.80/s (0.285s, 1799.40/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1150/2502 ( 46%)] Loss: nan ( nan) Time: 0.285s, 1799.56/s (0.284s, 1800.19/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1200/2502 ( 48%)] Loss: nan ( nan) Time: 0.294s, 1742.39/s (0.284s, 1801.04/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1250/2502 ( 50%)] Loss: nan ( nan) Time: 0.285s, 1796.71/s (0.284s, 1802.07/s) LR: 1.673e-04 Data: 0.005 (0.008)
    Train: 1 [1300/2502 ( 52%)] Loss: nan ( nan) Time: 0.274s, 1870.25/s (0.284s, 1802.85/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1350/2502 ( 54%)] Loss: nan ( nan) Time: 0.271s, 1886.95/s (0.284s, 1803.84/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1400/2502 ( 56%)] Loss: nan ( nan) Time: 0.288s, 1776.96/s (0.284s, 1804.18/s) LR: 1.673e-04 Data: 0.006 (0.008)
    Train: 1 [1450/2502 ( 58%)] Loss: nan ( nan) Time: 0.282s, 1818.29/s (0.284s, 1802.31/s) LR: 1.673e-04 Data: 0.006 (0.007)
    Train: 1 [1500/2502 ( 60%)] Loss: nan ( nan) Time: 0.262s, 1952.51/s (0.284s, 1803.01/s) LR: 1.673e-04 Data: 0.007 (0.007)

    opened by yix081 8
  • Cannot get the reported MACs in paper

    Hi,

    I've calculated the MACs of the model and found that they are not consistent with those reported in the paper.

    If I understand correctly, the T2T-ViTt-14 model has the T2T module plus 14 original ViT blocks. The MACs for those 14 ViT blocks would be 0.321 x 14 = 4.494 G.

    For the first tokens-to-token attention, you compute attention over 56x56 = 3136 tokens with feature dim=64. Counting only the affinity matrix and the weighted sum of values gives 3136 * 3136 * 64 + 3136 * 3136 * 64 = 1.26 G MACs, which already adds up to 5.754 G, higher than the reported 5.2G. My full calculation for the T2T-ViTt-14 model is 6.09 G MACs. Can you tell me if I miscalculated something?

    Best, Haiping

    opened by happywu 5
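
The arithmetic in the issue above can be reproduced in a few lines. This sketch only re-runs the numbers stated in the issue (3136 tokens, feature dim 64, 0.321 G MACs per ViT block); it is not a full MAC counter for the model, and `attention_macs` is a hypothetical helper that ignores the projections, softmax and MLP.

```python
def attention_macs(num_tokens, dim):
    """MACs for the affinity matrix (Q K^T) plus the weighted sum (A V)."""
    affinity = num_tokens * num_tokens * dim
    weighted_sum = num_tokens * num_tokens * dim
    return affinity + weighted_sum


tokens = 56 * 56                       # 3136 tokens after the first soft split
first_t2t_attention = attention_macs(tokens, 64)
backbone = 0.321e9 * 14                # per-block figure quoted in the issue
total = backbone + first_t2t_attention

print(f"first T2T attention: {first_t2t_attention / 1e9:.3f} G")
print(f"backbone + first attention: {total / 1e9:.3f} G")
```

The first attention alone costs about 1.259 G MACs, giving roughly 5.75 G together with the backbone (5.754 G if the 1.26 G term is rounded first, as in the issue). For comparison, the model table in this README lists 6.1G MACs for T2T-ViT_t-14, close to the 6.09 G full count reported in the issue.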
  • How to visualize the attention map of t2t-vit?

    The referenced example file uses ViT, which outputs attention weights. But the T2T-ViT model only outputs logits, so I can't reuse that code. I would really appreciate a more detailed way to visualize the attention map of T2T-ViT. It's really important to me. Thanks a lot.

    opened by Salen158 4
  • Hard to train

    Hi, dear @yuanli2333. I am trying to use T2T-ViT for downstream semantic segmentation tasks. However, as we know, a ViT backbone is very hard to train. The default number of training epochs on ImageNet is 300. I have tried two different network structures with T2T-ViT-14.

    The first was trained with the SGD optimizer and cosine warmup. After 120 epochs, the loss curve is as follows: [image: QQ screenshot 20210327144126]

    The second was trained with the Adam optimizer and cosine warmup (I did not use timm.create_optimizer to set AdamW, since I need to set different learning rates for different blocks). The learning rate setting is similar to yours. After 40 epochs, the loss curve is as follows: [image: QQ screenshot 20210327143246]

    It looks like the second training run is much better and the loss is still decreasing, but I'm not sure whether it is on the right path. (According to my calculation, it will take 6 days to train 300 epochs on a single 3090 GPU, so I don't have time for trial and error :sob::sob::sob:) Could you show me your training log as a reference or give me some advice? Thank you very much.

    opened by huixiancheng 3
  • small question about lr_scheduler

    Thanks for the open-source code! Could you tell me the meaning of metric in the lr scheduler step? https://github.com/yitu-opensource/T2T-ViT/blob/f436fe4043069989ec5e0c2d07407b6d898493a7/main.py#L577-L579 https://github.com/yitu-opensource/T2T-ViT/blob/f436fe4043069989ec5e0c2d07407b6d898493a7/main.py#L688-L689

    As far as I understand, it doesn't seem to have a special meaning in timm.

    opened by huixiancheng 3
  • Questions about feature visualization of vit_large_patch16_384

    Regarding Figure 2 in the paper, I tried to plot the feature visualization of T2T-ViT-24 trained on ImageNet using the code provided in visualization_vit.ipynb and the same input image "dog.png". The input image was resized to (1024, 1024), and I found that the feature maps have a size of (64, 64). However, the plotted feature maps look very different from those in your paper. The following figure shows my feature maps from T2T-ViT-24 block 1:

    [image: layer_0]

    There is a lot of noise in my feature maps, and low-level structural features such as edges and lines are not clear. I'm not sure what caused the discrepancy. Also, the resolution of the feature maps in the paper looks higher than 64*64. Could you please provide more instructions on feature visualization for this model? That would help me understand your work better. Thank you in advance!

    opened by Hongyu-He 3
  • Very low performance within the first 10 epochs

    @yuanli2333 Really impressive results with a fully transformer architecture!

    I have tried to reproduce the results of T2t_vit_t_14, T2t_vit_t_19, and T2t_vit_t_24, but found that their top-1 accuracy is very low within the first few epochs:

    # results based on T2t_vit_t_14
    epoch,train_loss,eval_loss,eval_top1,eval_top5
    0,6.9310056154544535,6.97243375,0.09599999996185303,0.4740000003051758
    1,6.575615681134737,6.9578225,0.104,0.512
    2,6.16587602175199,6.94662625,0.116,0.5079999998855591
    3,5.808463848554171,6.9398025,0.114,0.572
    4,5.42472545000223,6.93104625,0.156,0.632
    5,5.137583054029024,6.92024125,0.142,0.77
    6,4.931810901715205,6.90476,0.194,0.8419999999809266
    7,4.7973018517861,6.874345,0.246,1.044
    8,4.646611140324519,6.82100625,0.358,1.554
    

    where we can see that the top-1 accuracy is only 0.358 at the 8th epoch. Is this result reasonable?

    Thanks!

    opened by PkuRainBow 3
  • The released models seem broken and cannot be opened?

    Thanks a lot for sharing your code. It seems the T2T-ViT models are broken and I can't open them. Could you upload them again? Many thanks :-)

    opened by QiushiYang 3
  • No forward with Token_performer?

    I tried to run T2T-ViT but hit an error: there is no forward in Token_performer. Will you provide the forward code for Token_performer? Hoping for your reply.

    opened by GuideWsp 3
  • could you share the log of training T2T-ViT_t-24 and T2T-ViT_t-19?

    Thanks for your wonderful work. Could you share the training logs of T2T-ViT_t-24 and T2T-ViT_t-19? I would like to compare my method with yours in terms of training behavior. Many thanks.

    opened by wangpichao 2
  • Input size is not a square, what should I do in this line?

    Thanks for your excellent work! I have a tensor whose size is (3, 56, 112). How should I modify this line?

    self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 soft split, stride are 4,2,2 separately

    opened by JerryKingQAQ 0
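
One way to approach the question above is to compute the token grid separately for height and width. The sketch below applies the standard unfold output-size formula three times. Only the strides 4, 2, 2 are stated in the quoted line; the kernel sizes (7, 3, 3) and paddings (2, 1, 1) are assumptions about the soft-split layers, and `token_grid` and `num_patches` are hypothetical helpers.

```python
def unfold_output_size(size, kernel, stride, padding):
    """Output length of a sliding-window unfold along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) per soft split. Strides come from the quoted
# comment; kernel sizes and paddings are ASSUMED, not stated in the issue.
SOFT_SPLITS = ((7, 4, 2), (3, 2, 1), (3, 2, 1))


def token_grid(height, width, splits=SOFT_SPLITS):
    """Token grid (rows, cols) after applying all soft splits."""
    for kernel, stride, padding in splits:
        height = unfold_output_size(height, kernel, stride, padding)
        width = unfold_output_size(width, kernel, stride, padding)
    return height, width


def num_patches(height, width):
    rows, cols = token_grid(height, width)
    return rows * cols


print(token_grid(224, 224), num_patches(224, 224))
print(token_grid(56, 112), num_patches(56, 112))
```

For a 224x224 input this agrees with the original expression, (224 // 16) ** 2 = 196. For (56, 112) it gives a 4x7 grid under the assumed kernels and paddings, whereas 56 // 16 would give 3, so plain floor division may undercount when a side is not a multiple of 16. Other parts of the model that assume a square token grid would likely need similar changes.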
  • colab

    Hi, I want to train your model on my dataset, which consists of 15000 sample train images and 120000 train images. I also want to use Google Colab. Is it possible to train this model in Colab? How do I run your GitHub code in Colab? Is it better to train the model on the sample train images or on the train images?

    opened by Maryam-Hosseini 0
  • Downloads all say "tar: This does not look like a tar archive" when I try to un-tar

    I've tried several tools and downloaded all the files to check whether the first one was corrupt, but I get this error for every file I try (from multiple archive utilities, including classic tar).

    Anyone else seeing this, or do I just have a problem on my end?

    opened by lowfuel 2
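
A quick way to see what kind of file a download actually is, before trying to extract it, is to probe it with the standard library. This is a sketch, and `describe_download` is a hypothetical helper. One possible explanation for the error above (an assumption, not confirmed in the thread) is that the files are PyTorch checkpoints meant to be passed to the loading code directly rather than extracted; checkpoints written by `torch.save` in PyTorch 1.6 or later are zip containers, while older ones are plain pickles.

```python
import tarfile
import zipfile


def describe_download(path):
    """Classify a downloaded file as 'tar', 'zip', or 'other'."""
    if tarfile.is_tarfile(path):
        return "tar"
    if zipfile.is_zipfile(path):
        # torch.save (PyTorch >= 1.6) writes a zip container by default.
        return "zip"
    # For example an older pickle-based checkpoint, or a corrupt download.
    return "other"
```

If the result is "zip" or "other" for a file named like a checkpoint, loading it as a checkpoint (for example through `--eval_checkpoint` as shown in the Validation section) is worth trying before assuming the download is corrupt.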
  • [maybe a bug] loss nan

    https://github.com/yitu-opensource/T2T-ViT/blob/main/models/token_performer.py#L18 My code has fp16 turned on, so the 1e-8 on this line, which is meant to prevent division by 0, is not enough. The loss computed by the network becomes NaN because of this line: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/token_performer.py#L50

    opened by xmy0916 4
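
The point raised in the issue above can be checked without a GPU: the value 1e-8 is below the smallest positive number that IEEE 754 half precision can represent (2^-24, about 5.96e-8), so it rounds to zero when stored as fp16. The sketch below uses the standard library's half-precision `struct` format; whether the addition in the model actually runs in fp16 depends on the dtypes involved, which is not shown here.

```python
import struct


def to_float16(value):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


SMALLEST_FP16_SUBNORMAL = 2.0 ** -24   # about 5.96e-8

print(to_float16(1e-8))                     # epsilon quoted in the issue
print(to_float16(SMALLEST_FP16_SUBNORMAL))  # still representable
print(to_float16(1e-4))                     # a larger epsilon survives
```

An epsilon that rounds to zero no longer protects a division whose denominator is itself zero in fp16, which is consistent with the NaN loss described in the issue.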
Owner: YITUTech

aft-pytorch Unofficial PyTorch implementation of Attention Free Transformer's layers by Zhai, et al. [abs, pdf] from Apple Inc. Installation You can i

Rishabh Anand 184 Dec 12, 2022
Customer-Transaction-Analysis - This analysis is based on a synthesised transaction dataset containing 3 months worth of transactions for 100 hypothetical customers.

Customer-Transaction-Analysis - This analysis is based on a synthesised transaction dataset containing 3 months worth of transactions for 100 hypothetical customers. It contains purchases, recurring

Ayodeji Yekeen 1 Jan 01, 2022
LOFO (Leave One Feature Out) Importance calculates the importances of a set of features based on a metric of choice,

LOFO (Leave One Feature Out) Importance calculates the importances of a set of features based on a metric of choice, for a model of choice, by iteratively removing each feature from the set, and eval

Ahmet Erdem 691 Dec 23, 2022
Metadata-Extractor - Metadata Extractor Script can be used to read in exif metadata

Metadata Extractor The exifextract script can be used to read in exif metadata f

1 Feb 16, 2022
Official PyTorch Implementation for InfoSwap: Information Bottleneck Disentanglement for Identity Swapping

InfoSwap: Information Bottleneck Disentanglement for Identity Swapping Code usage Please check out the user manual page. Paper Gege Gao, Huaibo Huang,

Grace Hešeri 56 Dec 20, 2022
An open-source, low-cost, image-based weed detection device for fallow scenarios.

Welcome to the OpenWeedLocator (OWL) project, an opensource hardware and software green-on-brown weed detector that uses entirely off-the-shelf compon

Guy Coleman 145 Jan 05, 2023
A Learning-based Camera Calibration Toolbox

Learning-based Camera Calibration A Learning-based Camera Calibration Toolbox Paper The pdf file can be found here. @misc{zhang2022learningbased,

Eason 14 Dec 21, 2022
Fair Recommendation in Two-Sided Platforms

Fair Recommendation in Two-Sided Platforms

gourabgggg 1 Nov 10, 2021
Code and data (Incidents Dataset) for ECCV 2020 Paper "Detecting natural disasters, damage, and incidents in the wild".

Incidents Dataset See the following pages for more details: Project page: IncidentsDataset.csail.mit.edu. ECCV 2020 Paper "Detecting natural disasters

Ethan Weber 67 Dec 27, 2022
Black-Box-Tuning - Black-Box Tuning for Language-Model-as-a-Service

Black-Box-Tuning Source code for paper "Black-Box Tuning for Language-Model-as-a

Tianxiang Sun 149 Jan 04, 2023
Prototype python implementation of the ome-ngff table spec

Prototype python implementation of the ome-ngff table spec

Kevin Yamauchi 8 Nov 20, 2022
[ACL 20] Probing Linguistic Features of Sentence-level Representations in Neural Relation Extraction

REval Table of Contents Introduction Overview Requirements Installation Probing Usage Citation License 🎓 Introduction REval is a simple framework for

13 Jan 06, 2023