当前位置:网站首页>注意力机制的一种卷积替代方式
注意力机制的一种卷积替代方式
2022-07-06 08:51:00 【cyz0202】
参考自:fairseq
背景
常见的注意力机制如下

而且一般使用multi heads的方式进行上述计算;
耗时耗力;
卷积很早就用在NLP中,只是常见的使用方法计算量也不小,而且效果不如Attention;
有没有可能改进卷积在NLP的应用方式,达到既快又好的效果?
一个思路就是 depthwise convolution(减少计算量),并且模仿注意力的softmax机制(加权平均);
改进方案
常见的depthwise convolution如下

W为kernel,
,参数量大小为:当d=1024,k=7,d*k=7168
-------
为了进一步减少计算量,可以令
,其中一般
,如H=16,d=1024;
H所在维度变小,为了保持对x做完整的卷积计算,则需要在原来d这个维度 重复利用W,也就是不同channel可能用到W中同一行参数;
有两种重复利用方式,一种是W在d这个维度的平移重复,即每H行使用一次W;
另一种可以考虑W的行在d维度上的重复,即卷积计算过程中,x在d这个维度上 每d/H行 对应到 W中某一行,如x的d维度上的[0-d/H)行 depthwise卷积计算使用 W的第0行参数;(稍微有点绕,可参考下面公式3,会更直观)
这里使用第二种方式,即行的重复;考虑的是模仿Attention的multi heads,每个头(大小为d/H)能独立计算;第二方式相当于x的每d/H行(d维度上的每个头)使用各自独立的kernel参数;
细心的读者可以发现,这里的H其实就是模仿Attention中muliti heads的H;
通过以上设计,W的参数量缩减到 H*k(如H=16,k=7,则参数量只有112)
-------
除了减少计算量,还可以让卷积模仿Attention的softmax过程;
观察公式(1),
的作用和Attention softmax输出很像,即加权平均;
所以为什么我们不让
也是一个distribution呢?
因此令

-------
最终卷积计算如下

注:W下角标上取整默认
,也可以换成下取整并令![c \in [0, d-1]](http://img.inotgo.com/imagesLocal/202207/06/202207060850360766_2.gif)
具体实现
上述行重复的计算方式并不能直接利用已有的卷积算子计算;
为了利用矩阵计算,可以想一个折中的方式
令
,令
,执行BMM(Batch MatMul),可得

上述设计可以保证结果Out有我们需要的正确的shape,那怎么保证Out值也是正确,即怎么设计W',x的具体值?
不难,只要根据上一节的 令W行重复 以 对x进行multi-head计算 的思想出发即可;
考虑
,就是简单的一个reshape,维度1的BH代表batch*H,假设batch=1,那就是H个计算头;维度2的n代表序列长度n,即n个位置;最后一个维度(即channel)代表的是一个head,即一个计算头大小;
卷积计算时,
上述x'的每个头(最后一个维度 大小d/H,共H个),对应的W参数是W中 按顺序的一行参数(大小k,H个头就对应W的H行);同时x'第2个维度表示有n个位置要进行卷积计算,并且每个位置计算方式是与 窗口大小为k且固定值的序列 加权求和;
所以 W'维度1的BH 代表 与x'的BH个头一一对应的卷积参数(B=1时就是H个头对应H行kernel参数);
W'维度2的n,代表n个位置要进行卷积计算;
W'维度3的n,是对 当前头对应的kernel内一行 的 k个参数 扩展到n个参数,扩展方式是0填充;
之所以要对当前kernel大小为k的行参数 填充到n,主要是上述卷积计算是一个固定大小为k的窗口在长度为n的序列上滑动;
因为卷积参数个数恒定,但是位置发生变化,看起来是不规则计算;但是 想一下,虽然位置在滑动,但是是在长度为n的序列上滑动;能不能构造一个长度为n的fake kernel,除了窗口所在位置放置真正kerenl参数值,其他非窗口所在的位置都填充0呢?这样我们得到了一个长度固定为n的fake kernel,就能以统一形式计算了;
举例说明,假设此时计算位置为 10,则窗口中心也在10这个位置,而窗口具体位置可能是[7,13],那么我把除了 窗口所在位置外的 其他位置 都填充0,就得到一个长度为n的参数序列,就能用统一的 n*n 点积计算方式 去和 x长度为n的输入进行计算了;
当然,上述填充方式 每次得到的 填充序列(fake kernel) 都是不一样的,因为当前卷积行参数值虽然不变,但是位置在变;
了解过 深度学习框架是如何实现卷积计算的 同学 可能对这种填充方式会更熟悉;
上述说法可能还是有点绕,读者想一下
矩阵相乘,第1个n表示n个位置,第2/3个n表示当前某个位置 发生的卷积计算,之所以是n*n,是为了达到 统一计算形式 进行了 0填充;
等我补充更好的图例哈。。。
CUDA实现
上述计算方式比较耗资源,考虑自定义CUDA算子;
我另外写篇文章进行讲述;自定义卷积注意力算子的CUDA实现
实验结果
待续
总结
- 本文介绍了注意力机制的一种卷积替代方式;
- 通过一定的设计让卷积做到轻量化,同时模仿Attention的设计来达到更好的效果;
边栏推荐
- Trying to use is on a network resource that is unavailable
- Super efficient! The secret of swagger Yapi
- Warning in install. packages : package ‘RGtk2’ is not available for this version of R
- Chapter 1 :Application of Artificial intelligence in Drug Design:Opportunity and Challenges
- MYSQL卸载方法与安装方法
- Pytorch view tensor memory size
- TDengine 社区问题双周精选 | 第三期
- pytorch查看张量占用内存大小
- Esp8266-rtos IOT development
- CSP first week of question brushing
猜你喜欢

LeetCode:236. The nearest common ancestor of binary tree

Target detection - pytorch uses mobilenet series (V1, V2, V3) to build yolov4 target detection platform

广州推进儿童友好城市建设,将探索学校周边200米设安全区域

win10系统中的截图,win+prtSc保存位置

Double pointeur en langage C - - modèle classique
![[OC]-<UI入门>--常用控件-提示对话框 And 等待提示器(圈)](/img/af/a44c2845c254e4f48abde013344c2b.png)
[OC]-<UI入门>--常用控件-提示对话框 And 等待提示器(圈)

MYSQL卸载方法与安装方法

Chapter 1 :Application of Artificial intelligence in Drug Design:Opportunity and Challenges

优秀的软件测试人员,都具备这些能力

【剑指offer】序列化二叉树
随机推荐
自动化测试框架有什么作用?上海专业第三方软件测试公司安利
生成器参数传入参数
Mongodb installation and basic operation
Image, CV2 read the conversion and size resize change of numpy array of pictures
Purpose of computer F1-F12
LeetCode:673. 最长递增子序列的个数
LeetCode:41. Missing first positive number
[embedded] print log using JLINK RTT
After reading the programmer's story, I can't help covering my chest...
MYSQL卸载方法与安装方法
The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
如何进行接口测试测?有哪些注意事项?保姆级解读
有效提高软件产品质量,就找第三方软件测评机构
LeetCode:剑指 Offer 03. 数组中重复的数字
查看局域网中电脑设备
Screenshot in win10 system, win+prtsc save location
SAP ui5 date type sap ui. model. type. Analysis of the parsing format of date
Image,cv2读取图片的numpy数组的转换和尺寸resize变化
R language ggplot2 visualization: place the title of the visualization image in the upper left corner of the image (customize Title position in top left of ggplot2 graph)
Notes 01