当前位置:网站首页>用MLP代替掉Self-Attention
用MLP代替掉Self-Attention
2022-07-02 06:26:00 【MezereonXP】
用MLP代替掉Self-Attention
这次介绍的清华的一个工作 “Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks”
用两个线性层代替掉Self-Attention机制,最终实现了在保持精度的同时实现速度的提升。
这个工作让人意外的是,我们可以使用MLP代替掉Attention机制,这使我们应该重新好好考虑Attention带来的性能提升的本质。
Transformer中的Self-Attention机制
首先,如下图所示:

我们给出其形式化的结果:
A = softmax ( Q K T d k ) F o u t = A V A = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})\\ F_{out} = AV A=softmax(dkQKT)Fout=AV
其中, Q , K ∈ R N × d ′ Q,K \in \mathbb{R}^{N\times d'} Q,K∈RN×d′ 同时 V ∈ R N × d V\in \mathbb{R}^{N\times d} V∈RN×d
这里,我们给出一个简化版本,如下图所示:

也就是将 Q , K , V Q,K,V Q,K,V 都以输入特征 F F F 代替掉,其形式化为:
A = softmax ( F F T ) F o u t = A F A = \text{softmax}(FF^T)\\ F_{out} = AF A=softmax(FFT)Fout=AF
然而,这里面的计算复杂度为 O ( d N 2 ) O(dN^2) O(dN2),这是Attention机制的一个较大的缺点。
外部注意力 (External Attention)
如下图所示:

引入了两个矩阵 M k ∈ R S × d M_k\in \mathbb{R}^{S\times d} Mk∈RS×d 以及 $M_v \in\mathbb{R}^{S\times d} $, 代替掉原来的 K , V K,V K,V
这里直接给出其形式化:
A = Norm ( F M k T ) F o u t = A M v A = \text{Norm}(FM_k^T)\\ F_{out} = AM_v A=Norm(FMkT)Fout=AMv
这种设计,将复杂度降低到 O ( d S N ) O(dSN) O(dSN), 该工作发现,当 S ≪ N S\ll N S≪N 的时候,仍然能够保持足够的精度。
其中的 Norm ( ⋅ ) \text{Norm}(\cdot) Norm(⋅) 操作是先对列进行Softmax,然后对行进行归一化。
实验分析
首先,文章将Transformer中的Attention机制替换掉,然后在各类任务上进行测试,包括:
- 图像分类
- 语义分割
- 图像生成
- 点云分类
- 点云分割
这里只给出部分结果,简单说明一下替换后的精度损失情况。
图像分类

语义分割

图像生成

可以看到,在不同的任务上,基本上不会有精度损失。
边栏推荐
- Timeout docking video generation
- 【TCDCN】《Facial landmark detection by deep multi-task learning》
- What if the laptop task manager is gray and unavailable
- 生成模型与判别模型的区别与理解
- Huawei machine test questions
- PHP returns the corresponding key value according to the value in the two-dimensional array
- win10+vs2017+denseflow编译
- [tricks] whiteningbert: an easy unsupervised sentence embedding approach
- CONDA common commands
- 【Mixup】《Mixup:Beyond Empirical Risk Minimization》
猜你喜欢

【深度学习系列(八)】:Transoform原理及实战之原理篇

A slide with two tables will help you quickly understand the target detection

ModuleNotFoundError: No module named ‘pytest‘

ModuleNotFoundError: No module named ‘pytest‘

open3d学习笔记四【表面重建】

TimeCLR: A self-supervised contrastive learning framework for univariate time series representation

Win10+vs2017+denseflow compilation

Faster-ILOD、maskrcnn_ Benchmark trains its own VOC data set and problem summary

《Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer》论文翻译

程序的内存模型
随机推荐
半监督之mixmatch
[CVPR‘22 Oral2] TAN: Temporal Alignment Networks for Long-term Video
Regular expressions in MySQL
Faster-ILOD、maskrcnn_ Benchmark training coco data set and problem summary
【双目视觉】双目立体匹配
A slide with two tables will help you quickly understand the target detection
Generate random 6-bit invitation code in PHP
【Sparse-to-Dense】《Sparse-to-Dense:Depth Prediction from Sparse Depth Samples and a Single Image》
Network metering - transport layer
Calculate the difference in days, months, and years between two dates in PHP
Installation and use of image data crawling tool Image Downloader
解决latex图片浮动的问题
【Mixup】《Mixup:Beyond Empirical Risk Minimization》
PointNet理解(PointNet实现第4步)
[tricks] whiteningbert: an easy unsupervised sentence embedding approach
Using MATLAB to realize: Jacobi, Gauss Seidel iteration
What if the laptop task manager is gray and unavailable
【Cascade FPD】《Deep Convolutional Network Cascade for Facial Point Detection》
【雙目視覺】雙目矯正
【Mixed Pooling】《Mixed Pooling for Convolutional Neural Networks》