当前位置:网站首页>用MLP代替掉Self-Attention
用MLP代替掉Self-Attention
2022-07-02 07:51:00 【MezereonXP】
用MLP代替掉Self-Attention
這次介紹的清華的一個工作 “Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks”
用兩個線性層代替掉Self-Attention機制,最終實現了在保持精度的同時實現速度的提昇。
這個工作讓人意外的是,我們可以使用MLP代替掉Attention機制,這使我們應該重新好好考慮Attention帶來的性能提昇的本質。
Transformer中的Self-Attention機制
首先,如下圖所示:

我們給出其形式化的結果:
A = softmax ( Q K T d k ) F o u t = A V A = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})\\ F_{out} = AV A=softmax(dkQKT)Fout=AV
其中, Q , K ∈ R N × d ′ Q,K \in \mathbb{R}^{N\times d'} Q,K∈RN×d′ 同時 V ∈ R N × d V\in \mathbb{R}^{N\times d} V∈RN×d
這裏,我們給出一個簡化版本,如下圖所示:

也就是將 Q , K , V Q,K,V Q,K,V 都以輸入特征 F F F 代替掉,其形式化為:
A = softmax ( F F T ) F o u t = A F A = \text{softmax}(FF^T)\\ F_{out} = AF A=softmax(FFT)Fout=AF
然而,這裏面的計算複雜度為 O ( d N 2 ) O(dN^2) O(dN2),這是Attention機制的一個較大的缺點。
外部注意力 (External Attention)
如下圖所示:

引入了兩個矩陣 M k ∈ R S × d M_k\in \mathbb{R}^{S\times d} Mk∈RS×d 以及 $M_v \in\mathbb{R}^{S\times d} $, 代替掉原來的 K , V K,V K,V
這裏直接給出其形式化:
A = Norm ( F M k T ) F o u t = A M v A = \text{Norm}(FM_k^T)\\ F_{out} = AM_v A=Norm(FMkT)Fout=AMv
這種設計,將複雜度降低到 O ( d S N ) O(dSN) O(dSN), 該工作發現,當 S ≪ N S\ll N S≪N 的時候,仍然能够保持足够的精度。
其中的 Norm ( ⋅ ) \text{Norm}(\cdot) Norm(⋅) 操作是先對列進行Softmax,然後對行進行歸一化。
實驗分析
首先,文章將Transformer中的Attention機制替換掉,然後在各類任務上進行測試,包括:
- 圖像分類
- 語義分割
- 圖像生成
- 點雲分類
- 點雲分割
這裏只給出部分結果,簡單說明一下替換後的精度損失情况。
圖像分類

語義分割

圖像生成

可以看到,在不同的任務上,基本上不會有精度損失。
边栏推荐
- 解决jetson nano安装onnx错误(ERROR: Failed building wheel for onnx)总结
- open3d学习笔记四【表面重建】
- [Sparse to Dense] Sparse to Dense: Depth Prediction from Sparse Depth samples and a Single Image
- PHP returns the corresponding key value according to the value in the two-dimensional array
- Yolov3 trains its own data set (mmdetection)
- Faster-ILOD、maskrcnn_ Benchmark installation process and problems encountered
- 论文写作tip2
- ModuleNotFoundError: No module named ‘pytest‘
- 基于onnxruntime的YOLOv5单张图片检测实现
- [binocular vision] binocular stereo matching
猜你喜欢

使用百度网盘上传数据到服务器上

TimeCLR: A self-supervised contrastive learning framework for univariate time series representation

【Paper Reading】

win10+vs2017+denseflow编译

【双目视觉】双目立体匹配

【Sparse-to-Dense】《Sparse-to-Dense:Depth Prediction from Sparse Depth Samples and a Single Image》

Faster-ILOD、maskrcnn_benchmark安装过程及遇到问题

【Mixup】《Mixup:Beyond Empirical Risk Minimization》

【深度学习系列(八)】:Transoform原理及实战之原理篇

【AutoAugment】《AutoAugment:Learning Augmentation Policies from Data》
随机推荐
【TCDCN】《Facial landmark detection by deep multi-task learning》
PointNet理解(PointNet实现第4步)
Open3D学习笔记一【初窥门径,文件读取】
《Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer》论文翻译
The difference and understanding between generative model and discriminant model
PPT的技巧
iOD及Detectron2搭建过程问题记录
常见CNN网络创新点
【Paper Reading】
解决latex图片浮动的问题
Yolov3 trains its own data set (mmdetection)
Pointnet understanding (step 4 of pointnet Implementation)
Determine whether the version number is continuous in PHP
ABM thesis translation
Implementation of yolov5 single image detection based on pytorch
open3d学习笔记二【文件读写】
C#与MySQL数据库连接
【双目视觉】双目立体匹配
超时停靠视频生成
【Mixed Pooling】《Mixed Pooling for Convolutional Neural Networks》