当前位置:网站首页>【Transformer】AdaViT: Adaptive Vision Transformers for Efficient Image Recognition
【Transformer】AdaViT: Adaptive Vision Transformers for Efficient Image Recognition
2022-07-29 05:21:00 【呆呆的猫】

一、背景
Transformer 在多个领域实现了良好的效果,但其计算量随着 patches 数量的增加、self-attention head 数量增加、transformer block 数量的增加会有很大的增大。
但作者提出了两个个问题:
是否所有的 patches 都需要通过整个网络,才能得到更好的分类结果?
是否所有的 self-attention 都需要很多头来寻找整个图中的潜在关联?
作者认为,只有背景复杂、遮挡严重等复杂难例需要更多的 patch 和 self-attention block,简单的样本只需要少量的 patch 和 self-attention block 就可以实现足够好的效果了。
基于此,作者实现了一种动态计算量的框架,来学习哪些 patch 或 哪些 self-attention heads/blocks 需要保留。所以,网络会给简单样本降低 patch 和 self-attention 层,难样本使用全部的网络层。
作者提出的 Adaptive Vision Transformer (AdaViT) 是一个端到端的结构,能够动态的判断 transformer 结构中,哪些 patch、self-attention block、self-attention heads 需要保留。
AdaViT 能够提升训练速度 2x,降低了 0.8% 的分类准确率,是效果和速度平衡的方法。

二、方法

1、Decision Network
作者给每个 transformer block 中插入了一个轻量的多头子网络,也就是 decision network,该网络能够学习一个二值结果,来决定对 patch embedding、self-attention heads、blocks 是否使用。
第 l l l 个 block 的 decision network 有 3 个线性层,参数为 W l = { W l p , W l h , W l b } W_l=\{W_l^p, W_l^h, W_l^b\} Wl={ Wlp,Wlh,Wlb},分别预测 patch、attention head、transformer block 是否需要保留。
所以,对于 block Z l Z_l Zl,会计算如下:
- N 和 H 分别为 transformer block 中的 patch 数量和 self-attention head 的数量,得到的三个 m l m_l ml 会经过 sigmoid 函数,表示 patch、attention head、transformer block 被保留的概率。
由于 decision 需要是二值的,所以保留/丢弃可以在infer的时候采用阈值来判断。
但由于不同样本的最优的阈值是不同的,所以作者定义了随机变量 M l p , M l h , M l b M_l^p, M_l^h, M_l^b Mlp,Mlh,Mlb 通过从 m l p , m l h , m l b m_l^p, m_l^h, m_l^b mlp,mlh,mlb 中采样来判断,即如果 M l , j p = 1 M_{l,j}^p=1 Ml,jp=1,则保留第 l l l 个 block 中的第 j j j 个 patch embedding,如果 M l , j p = 0 M_{l,j}^p=0 Ml,jp=0 则舍弃。并且,作者使用 Gumbel-Softmax trick [25] 来保证在训练时候的多样性。
2、Patch Selection
Transformer block 的输入中,作者想要保留那些信息丰富的 patch embedding。
对于第 l l l 个 block,如果 M i p = 0 M_i^p=0 Mip=0,则丢弃该 patch:
- z l , c l s z_{l,cls} zl,cls 会被保留,因为这是用来分类的
3、Head Selection
多头注意力机制中的不同头会关注不同的区域,挖掘更多的潜在信息。
作者为了提高推理速度,会自适应的将某些 head 舍弃掉,为了抑制某些头,也就是 deactivation,作者探究了两种方法:
1、 partial deactivation
第 l l l 个block 的第 i i i 个 head 的 attention 计算如下:

2、full deactivation
整体的激活抑制如下,所有的 head 都被移除了,MSA 的输出编码尺寸减少如下:

4、Block Selection
跳过不必要的 transformer block 也能减少很大的计算量,为了提升跳过的灵活性,作者使得 transformer block 中的 MSA 和 FFN 可以分别跳过,而非捆绑在一起。

三、效果






边栏推荐
- Tear the ORM framework by hand (generic + annotation + reflection)
- Markdown语法
- Research and implementation of flash loan DAPP
- ssm整合
- IDEA中设置自动build-改动代码,不用重启工程,刷新页面即可
- Detailed explanation of atomic operation class atomicinteger in learning notes of concurrent programming
- Detailed explanation of MySQL statistical function count
- DataX installation
- [competition website] collect machine learning / deep learning competition website (continuously updated)
- 与张小姐的春夏秋冬(2)
猜你喜欢

My ideal job, the absolute freedom of coder farmers is the most important - the pursuit of entrepreneurship in the future

Power BI Report Server 自定义身份验证

centos7 静默安装oracle

Use of file upload (2) -- upload to Alibaba cloud OSS file server

Intelligent security of the fifth space ⼤ real competition problem ----------- PNG diagram ⽚ converter

Flutter正在被悄悄放弃?浅析Flutter的未来

Flink connector Oracle CDC 实时同步数据到MySQL(Oracle19c)

Huawei 2020 school recruitment written test programming questions read this article is enough (Part 1)

【图像分类】如何使用 mmclassification 训练自己的分类模型

Research on the implementation principle of reentrantlock in concurrent programming learning notes
随机推荐
IDEA中设置自动build-改动代码,不用重启工程,刷新页面即可
Ribbon learning notes 1
SQL repair duplicate data
XDFS&中国日报社在线协同编辑平台典型案例
【go】defer的使用
Performance comparison | FASS iSCSI vs nvme/tcp
Reporting Services- Web Service
Android Studio 实现登录注册-源代码 (连接MySql数据库)
Huawei 2020 school recruitment written test programming questions read this article is enough (Part 2)
【图像分类】如何使用 mmclassification 训练自己的分类模型
Realize the scheduled backup of MySQL database in Linux environment through simple script (mysqldump command backup)
[ml] PMML of machine learning model -- Overview
Reporting Service 2016 自定义身份验证
Ribbon学习笔记二
Win10+opencv3.2+vs2015 configuration
Interesting talk about performance optimization thread pool: is the more threads open, the better?
【比赛网站】收集机器学习/深度学习比赛网站(持续更新)
Thinkphp6 pipeline mode pipeline use
【Clustrmaps】访客统计
[DL] introduction and understanding of tensor