当前位置:网站首页>【Transformer】ATS: Adaptive Token Sampling For Efficient Vision Transformers
【Transformer】ATS: Adaptive Token Sampling For Efficient Vision Transformers
2022-07-29 05:21:00 【呆呆的猫】

一、背景
尽管现有的 transformer 模型在分类等任务上取得了较好的效果,但计算量还是很高,需要很多的 GFLOPs,不适用于很多边缘设备,虽然GFLOPs 也可以通过降低网络中 token 数量来降低,DynamicViT 使用网络预测每个 token 的得分,从而判断哪个 token 是冗余的。虽然这个方法能够降低网络的 GFLOPs,但得分预测网络也会引入额外的参数,并且如果想要不同的降参比率需要再次进行训练。
二、动机
作者认为,对于分类任务,并非需要图中的所有信息来进行分类,因为图像的信息对分类任务来说是冗余的。所以本文提出了一个降低 token 数量的方法,可以适用于任何 transformer,不受降参比率限制,且更高效。
三、方法
作者提出了一种名为 " Adaptive Token Sampler (ATS) " 的模块,是一种动态的从输入 token 中选择重要的 token 的模块。也是一个 parameter-free 的方法,总体结构如图 2 所示。卷积网络中,一般会使用 pooling 来降低计算量,stage 越深,分辨率越小。但 Transformer 中不能直接使用这样的方法,因为 token 是与空间位置无关的,即改变位置不会影响最后的结果。而且如果使用下采样的话会有两个弊端,其一是会丢失目标的细节信息,其二是可能保留很多背景信息,对分类无实质性作用。所以作者提出了一种动态选择每个stage的token数量的方法。
ATS 的过程:
- 首先,对 N 个输入 token 分配得分,基于得分来确定哪些留下
- 然后,设定 K 为保留的 token 最大数量,这个 K 会决定 GFLOPs 的上限
- sampled tokens K’ 一般会比 K 小,且和输入图像的关系如图 6 所示

对于每个实例,图7展示了作者使用少数或多数的 patches ,就可以得到正确的分类,图 3 展示了不同每个 stage 使用的 token 数量。作者也提出了一个对每个图像选择正确 token 数量的方法。如图 6 所示,不同图像在不同stage的 token 数量是不同的。


3.1 Token Scoring
在标准的 self-attention 层,输入的 Q、K、V 都是从输入 token 得来的,然后会得到 attention matrix A:
由于 softmax 的存在,A 的每行和为 1,输出 token 会和 attention matrix 作用,从而加权。
A 的每行包含了输入 token 的 attention weights,这个 weights 其实就表示了所有 token 对输出 token 的作用,因为 A 的第一行是 cls token,表示了输入 token 对输出 classification token 的作用,所以作者使用第一行的元素作为修剪 A 的根据,如图 2 所示。作者也做了归一化,重要程度得分如下,对于多头注意力,分别对每个头进行计算,然后加起来:
3.2 Token Sampling
对每个 token 得到 score 之后,就可以根据 attention matrix A 对 tokens 进行修剪了。
一个比较基础的做法是直接选择 top-K 个 tokens,但是实验结果说明,这种方法没有动态选择 K’ 个 tokens 的效果好。其表现不好的原因在于,直接丢弃了所有得分低的token,但有些 token 其实在浅层可能会比较有用。
作者的抽样方法中,从几个相似的 token 中抽象的概率等于这些 token 的得分之和。而且从图 3 中也能看出,本文的抽样机制从浅层抽样的 token 数量比深层的更多一些。
方法:
因为 token score 是被归一化的,所以可以看出概率,可以计算累计密度函数(CDF):
对 CDF 取反,就得到了采样函数:
四、效果

边栏推荐
- Reporting Services- Web Service
- [database] database course design - vaccination database
- 初探fastJson的AutoType
- MySql统计函数COUNT详解
- Flink, the mainstream real-time stream processing computing framework, is the first experience.
- Activity交互问题,你确定都知道?
- Huawei 2020 school recruitment written test programming questions read this article is enough (Part 1)
- Spring, summer, autumn and winter with Miss Zhang (4)
- [clustmaps] visitor statistics
- File permissions of day02 operation
猜你喜欢

Thinkphp6 output QR code image format to solve the conflict with debug

【目标检测】Generalized Focal Loss V1

How to make interesting apps for deep learning with zero code (suitable for novices)

iSCSI vs iSER vs NVMe-TCP vs NVMe-RDMA

MySql统计函数COUNT详解

Ribbon learning notes II

Reporting service 2016 custom authentication

Thinkphp6 pipeline mode pipeline use

Markdown syntax

Huawei 2020 school recruitment written test programming questions read this article is enough (Part 2)
随机推荐
File文件上传的使用(2)--上传到阿里云Oss文件服务器
asyncawait和promise的区别
Training log III of "Shandong University mobile Internet development technology teaching website construction" project
Use of file upload (2) -- upload to Alibaba cloud OSS file server
Xsan is highly available - xdfs and San are integrated with new vitality
iSCSI vs iSER vs NVMe-TCP vs NVMe-RDMA
isAccessible()方法:使用反射技巧让你的性能提升数倍
在uni-app项目中,如何实现微信小程序openid的获取
Win10+opencv3.2+vs2015 configuration
Training log 4 of the project "construction of Shandong University mobile Internet development technology teaching website"
Process management of day02 operation
Personal learning website
以‘智’提‘质|金融影像平台解决方案
Spring, summer, autumn and winter with Miss Zhang (2)
Reporting Service 2016 自定义身份验证
赓续新征程,共驭智存储
并发编程学习笔记 之 ReentrantLock实现原理的探究
Semaphore (semaphore) for learning notes of concurrent programming
ReportingService WebService Form身份验证
【TensorRT】将 PyTorch 转化为可部署的 TensorRT