当前位置:网站首页>用MLP代替掉Self-Attention
用MLP代替掉Self-Attention
2022-07-02 07:51:00 【MezereonXP】
用MLP代替掉Self-Attention
這次介紹的清華的一個工作 “Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks”
用兩個線性層代替掉Self-Attention機制,最終實現了在保持精度的同時實現速度的提昇。
這個工作讓人意外的是,我們可以使用MLP代替掉Attention機制,這使我們應該重新好好考慮Attention帶來的性能提昇的本質。
Transformer中的Self-Attention機制
首先,如下圖所示:
我們給出其形式化的結果:
A = softmax ( Q K T d k ) F o u t = A V A = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})\\ F_{out} = AV A=softmax(dkQKT)Fout=AV
其中, Q , K ∈ R N × d ′ Q,K \in \mathbb{R}^{N\times d'} Q,K∈RN×d′ 同時 V ∈ R N × d V\in \mathbb{R}^{N\times d} V∈RN×d
這裏,我們給出一個簡化版本,如下圖所示:
也就是將 Q , K , V Q,K,V Q,K,V 都以輸入特征 F F F 代替掉,其形式化為:
A = softmax ( F F T ) F o u t = A F A = \text{softmax}(FF^T)\\ F_{out} = AF A=softmax(FFT)Fout=AF
然而,這裏面的計算複雜度為 O ( d N 2 ) O(dN^2) O(dN2),這是Attention機制的一個較大的缺點。
外部注意力 (External Attention)
如下圖所示:
引入了兩個矩陣 M k ∈ R S × d M_k\in \mathbb{R}^{S\times d} Mk∈RS×d 以及 $M_v \in\mathbb{R}^{S\times d} $, 代替掉原來的 K , V K,V K,V
這裏直接給出其形式化:
A = Norm ( F M k T ) F o u t = A M v A = \text{Norm}(FM_k^T)\\ F_{out} = AM_v A=Norm(FMkT)Fout=AMv
這種設計,將複雜度降低到 O ( d S N ) O(dSN) O(dSN), 該工作發現,當 S ≪ N S\ll N S≪N 的時候,仍然能够保持足够的精度。
其中的 Norm ( ⋅ ) \text{Norm}(\cdot) Norm(⋅) 操作是先對列進行Softmax,然後對行進行歸一化。
實驗分析
首先,文章將Transformer中的Attention機制替換掉,然後在各類任務上進行測試,包括:
- 圖像分類
- 語義分割
- 圖像生成
- 點雲分類
- 點雲分割
這裏只給出部分結果,簡單說明一下替換後的精度損失情况。
圖像分類
語義分割
圖像生成
可以看到,在不同的任務上,基本上不會有精度損失。
边栏推荐
- Optimization method: meaning of common mathematical symbols
- Calculate the total in the tree structure data in PHP
- [CVPR‘22 Oral2] TAN: Temporal Alignment Networks for Long-term Video
- What if the laptop task manager is gray and unavailable
- How do vision transformer work?【论文解读】
- 【C#笔记】winform中保存DataGridView中的数据为Excel和CSV
- 使用百度网盘上传数据到服务器上
- conda常用命令
- [CVPR‘22 Oral2] TAN: Temporal Alignment Networks for Long-term Video
- 图片数据爬取工具Image-Downloader的安装和使用
猜你喜欢
jetson nano安装tensorflow踩坑记录(scipy1.4.1)
ABM thesis translation
Implementation of yolov5 single image detection based on onnxruntime
【Sparse-to-Dense】《Sparse-to-Dense:Depth Prediction from Sparse Depth Samples and a Single Image》
【MagNet】《Progressive Semantic Segmentation》
【MnasNet】《MnasNet:Platform-Aware Neural Architecture Search for Mobile》
【多模态】CLIP模型
【Random Erasing】《Random Erasing Data Augmentation》
Network metering - transport layer
【MobileNet V3】《Searching for MobileNetV3》
随机推荐
Convert timestamp into milliseconds and format time in PHP
label propagation 标签传播
TimeCLR: A self-supervised contrastive learning framework for univariate time series representation
点云数据理解(PointNet实现第3步)
A slide with two tables will help you quickly understand the target detection
Use Baidu network disk to upload data to the server
[Sparse to Dense] Sparse to Dense: Depth Prediction from Sparse Depth samples and a Single Image
程序的内存模型
Huawei machine test questions
程序的执行
【Sparse-to-Dense】《Sparse-to-Dense:Depth Prediction from Sparse Depth Samples and a Single Image》
CPU的寄存器
使用百度网盘上传数据到服务器上
[CVPR‘22 Oral2] TAN: Temporal Alignment Networks for Long-term Video
Mmdetection installation problem
CPU register
自然辩证辨析题整理
Determine whether the version number is continuous in PHP
PHP returns the corresponding key value according to the value in the two-dimensional array
Faster-ILOD、maskrcnn_benchmark安装过程及遇到问题