当前位置:网站首页>用MLP代替掉Self-Attention
用MLP代替掉Self-Attention
2022-07-02 07:51:00 【MezereonXP】
用MLP代替掉Self-Attention
這次介紹的清華的一個工作 “Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks”
用兩個線性層代替掉Self-Attention機制,最終實現了在保持精度的同時實現速度的提昇。
這個工作讓人意外的是,我們可以使用MLP代替掉Attention機制,這使我們應該重新好好考慮Attention帶來的性能提昇的本質。
Transformer中的Self-Attention機制
首先,如下圖所示:

我們給出其形式化的結果:
A = softmax ( Q K T d k ) F o u t = A V A = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})\\ F_{out} = AV A=softmax(dkQKT)Fout=AV
其中, Q , K ∈ R N × d ′ Q,K \in \mathbb{R}^{N\times d'} Q,K∈RN×d′ 同時 V ∈ R N × d V\in \mathbb{R}^{N\times d} V∈RN×d
這裏,我們給出一個簡化版本,如下圖所示:

也就是將 Q , K , V Q,K,V Q,K,V 都以輸入特征 F F F 代替掉,其形式化為:
A = softmax ( F F T ) F o u t = A F A = \text{softmax}(FF^T)\\ F_{out} = AF A=softmax(FFT)Fout=AF
然而,這裏面的計算複雜度為 O ( d N 2 ) O(dN^2) O(dN2),這是Attention機制的一個較大的缺點。
外部注意力 (External Attention)
如下圖所示:

引入了兩個矩陣 M k ∈ R S × d M_k\in \mathbb{R}^{S\times d} Mk∈RS×d 以及 $M_v \in\mathbb{R}^{S\times d} $, 代替掉原來的 K , V K,V K,V
這裏直接給出其形式化:
A = Norm ( F M k T ) F o u t = A M v A = \text{Norm}(FM_k^T)\\ F_{out} = AM_v A=Norm(FMkT)Fout=AMv
這種設計,將複雜度降低到 O ( d S N ) O(dSN) O(dSN), 該工作發現,當 S ≪ N S\ll N S≪N 的時候,仍然能够保持足够的精度。
其中的 Norm ( ⋅ ) \text{Norm}(\cdot) Norm(⋅) 操作是先對列進行Softmax,然後對行進行歸一化。
實驗分析
首先,文章將Transformer中的Attention機制替換掉,然後在各類任務上進行測試,包括:
- 圖像分類
- 語義分割
- 圖像生成
- 點雲分類
- 點雲分割
這裏只給出部分結果,簡單說明一下替換後的精度損失情况。
圖像分類

語義分割

圖像生成

可以看到,在不同的任務上,基本上不會有精度損失。
边栏推荐
- Semi supervised mixpatch
- open3d学习笔记四【表面重建】
- Latex formula normal and italic
- Implementation of yolov5 single image detection based on pytorch
- Thesis writing tip2
- 深度学习分类优化实战
- mmdetection训练自己的数据集--CVAT标注文件导出coco格式及相关操作
- Two dimensional array de duplication in PHP
- 【Batch】learning notes
- 《Handwritten Mathematical Expression Recognition with Bidirectionally Trained Transformer》论文翻译
猜你喜欢

What if the notebook computer cannot run the CMD command

【Batch】learning notes

【深度学习系列(八)】:Transoform原理及实战之原理篇

ModuleNotFoundError: No module named ‘pytest‘

【DIoU】《Distance-IoU Loss:Faster and Better Learning for Bounding Box Regression》

open3d学习笔记三【采样与体素化】

Proof and understanding of pointnet principle

Installation and use of image data crawling tool Image Downloader

Thesis writing tip2

Timeout docking video generation
随机推荐
【FastDepth】《FastDepth:Fast Monocular Depth Estimation on Embedded Systems》
【Batch】learning notes
ModuleNotFoundError: No module named ‘pytest‘
Memory model of program
The difference and understanding between generative model and discriminant model
Play online games with mame32k
MoCO ——Momentum Contrast for Unsupervised Visual Representation Learning
TimeCLR: A self-supervised contrastive learning framework for univariate time series representation
自然辩证辨析题整理
A slide with two tables will help you quickly understand the target detection
Win10+vs2017+denseflow compilation
open3d环境错误汇总
ModuleNotFoundError: No module named ‘pytest‘
Latex formula normal and italic
How to clean up logs on notebook computers to improve the response speed of web pages
【Mixup】《Mixup:Beyond Empirical Risk Minimization》
Feature Engineering: summary of common feature transformation methods
Using MATLAB to realize: Jacobi, Gauss Seidel iteration
Using compose to realize visible scrollbar
【MobileNet V3】《Searching for MobileNetV3》