当前位置:网站首页>注意力机制的一种卷积替代方式
注意力机制的一种卷积替代方式
2022-07-06 08:51:00 【cyz0202】
参考自:fairseq
背景
常见的注意力机制如下
而且一般使用multi heads的方式进行上述计算;
耗时耗力;
卷积很早就用在NLP中,只是常见的使用方法计算量也不小,而且效果不如Attention;
有没有可能改进卷积在NLP的应用方式,达到既快又好的效果?
一个思路就是 depthwise convolution(减少计算量),并且模仿注意力的softmax机制(加权平均);
改进方案
常见的depthwise convolution如下
W为kernel,,参数量大小为:当d=1024,k=7,d*k=7168
-------
为了进一步减少计算量,可以令 ,其中一般 ,如H=16,d=1024;
H所在维度变小,为了保持对x做完整的卷积计算,则需要在原来d这个维度 重复利用W,也就是不同channel可能用到W中同一行参数;
有两种重复利用方式,一种是W在d这个维度的平移重复,即每H行使用一次W;
另一种可以考虑W的行在d维度上的重复,即卷积计算过程中,x在d这个维度上 每d/H行 对应到 W中某一行,如x的d维度上的[0-d/H)行 depthwise卷积计算使用 W的第0行参数;(稍微有点绕,可参考下面公式3,会更直观)
这里使用第二种方式,即行的重复;考虑的是模仿Attention的multi heads,每个头(大小为d/H)能独立计算;第二方式相当于x的每d/H行(d维度上的每个头)使用各自独立的kernel参数;
细心的读者可以发现,这里的H其实就是模仿Attention中muliti heads的H;
通过以上设计,W的参数量缩减到 H*k(如H=16,k=7,则参数量只有112)
-------
除了减少计算量,还可以让卷积模仿Attention的softmax过程;
观察公式(1),的作用和Attention softmax输出很像,即加权平均;
所以为什么我们不让也是一个distribution呢?
因此令
-------
最终卷积计算如下
注:W下角标上取整默认,也可以换成下取整并令
具体实现
上述行重复的计算方式并不能直接利用已有的卷积算子计算;
为了利用矩阵计算,可以想一个折中的方式
令 ,令,执行BMM(Batch MatMul),可得
上述设计可以保证结果Out有我们需要的正确的shape,那怎么保证Out值也是正确,即怎么设计W',x的具体值?
不难,只要根据上一节的 令W行重复 以 对x进行multi-head计算 的思想出发即可;
考虑 ,就是简单的一个reshape,维度1的BH代表batch*H,假设batch=1,那就是H个计算头;维度2的n代表序列长度n,即n个位置;最后一个维度(即channel)代表的是一个head,即一个计算头大小;
卷积计算时,
上述x'的每个头(最后一个维度 大小d/H,共H个),对应的W参数是W中 按顺序的一行参数(大小k,H个头就对应W的H行);同时x'第2个维度表示有n个位置要进行卷积计算,并且每个位置计算方式是与 窗口大小为k且固定值的序列 加权求和;
所以 W'维度1的BH 代表 与x'的BH个头一一对应的卷积参数(B=1时就是H个头对应H行kernel参数);
W'维度2的n,代表n个位置要进行卷积计算;
W'维度3的n,是对 当前头对应的kernel内一行 的 k个参数 扩展到n个参数,扩展方式是0填充;
之所以要对当前kernel大小为k的行参数 填充到n,主要是上述卷积计算是一个固定大小为k的窗口在长度为n的序列上滑动;
因为卷积参数个数恒定,但是位置发生变化,看起来是不规则计算;但是 想一下,虽然位置在滑动,但是是在长度为n的序列上滑动;能不能构造一个长度为n的fake kernel,除了窗口所在位置放置真正kerenl参数值,其他非窗口所在的位置都填充0呢?这样我们得到了一个长度固定为n的fake kernel,就能以统一形式计算了;
举例说明,假设此时计算位置为 10,则窗口中心也在10这个位置,而窗口具体位置可能是[7,13],那么我把除了 窗口所在位置外的 其他位置 都填充0,就得到一个长度为n的参数序列,就能用统一的 n*n 点积计算方式 去和 x长度为n的输入进行计算了;
当然,上述填充方式 每次得到的 填充序列(fake kernel) 都是不一样的,因为当前卷积行参数值虽然不变,但是位置在变;
了解过 深度学习框架是如何实现卷积计算的 同学 可能对这种填充方式会更熟悉;
上述说法可能还是有点绕,读者想一下 矩阵相乘,第1个n表示n个位置,第2/3个n表示当前某个位置 发生的卷积计算,之所以是n*n,是为了达到 统一计算形式 进行了 0填充;
等我补充更好的图例哈。。。
CUDA实现
上述计算方式比较耗资源,考虑自定义CUDA算子;
我另外写篇文章进行讲述;自定义卷积注意力算子的CUDA实现
实验结果
待续
总结
- 本文介绍了注意力机制的一种卷积替代方式;
- 通过一定的设计让卷积做到轻量化,同时模仿Attention的设计来达到更好的效果;
边栏推荐
- Super efficient! The secret of swagger Yapi
- 随手记01
- Li Kou daily question 1 (2)
- Current situation and trend of character animation
- LeetCode:劍指 Offer 42. 連續子數組的最大和
- opencv+dlib实现给蒙娜丽莎“配”眼镜
- To effectively improve the quality of software products, find a third-party software evaluation organization
- Alibaba cloud server mining virus solution (practiced)
- LeetCode:236. The nearest common ancestor of binary tree
- Guangzhou will promote the construction of a child friendly city, and will explore the establishment of a safe area 200 meters around the school
猜你喜欢
个人电脑好用必备软件(使用过)
LeetCode:124. 二叉树中的最大路径和
Sublime text using ctrl+b to run another program without closing other runs
UML diagram memory skills
使用latex导出IEEE文献格式
Problems encountered in connecting the database of the project and their solutions
The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
Current situation and trend of character animation
Restful API design specification
Generator parameters incoming parameters
随机推荐
Image, CV2 read the conversion and size resize change of numpy array of pictures
Current situation and trend of character animation
Screenshot in win10 system, win+prtsc save location
Promise 在uniapp的简单使用
LeetCode:162. 寻找峰值
LeetCode:236. 二叉树的最近公共祖先
The problem and possible causes of the robot's instantaneous return to the origin of the world coordinate during rviz simulation
LeetCode:剑指 Offer 04. 二维数组中的查找
UML diagram memory skills
Tcp/ip protocol
如何进行接口测试测?有哪些注意事项?保姆级解读
After reading the programmer's story, I can't help covering my chest...
Compétences en mémoire des graphiques UML
深度剖析C语言指针
LeetCode:394. 字符串解码
力扣每日一题(二)
Excellent software testers have these abilities
[Hacker News Weekly] data visualization artifact; Top 10 Web hacker technologies; Postman supports grpc
Deep analysis of C language pointer
LeetCode:221. Largest Square