PVT's spatial reduction attention (SRA)
2022-07-27 08:58:00 【hxxjxw】
It can be understood as merging every R×R patch of tokens into a single token. Attention is then computed between the full-resolution queries Q and the aggregated keys K and values V.
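To make the "R×R tokens become one" idea concrete, here is a small sketch that isolates only the spatial-reduction step on a toy 4×6 token grid with R = 2. The sizes are arbitrary assumptions chosen for illustration. The second half hand-sets the conv weights to show that plain R×R average pooling is one special case of what a strided conv with kernel_size = stride = R can compute; PVT learns these weights rather than fixing them.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Toy sizes for illustration only: a 4 x 6 grid of tokens, C = 4 channels, R = 2
B, C, H, W, R = 1, 4, 4, 6, 2
x = torch.randn(B, H * W, C)  # token sequence, N = H * W = 24

# Spatial reduction: tokens -> feature map -> strided conv -> tokens
sr = nn.Conv2d(C, C, kernel_size=R, stride=R)
x_ = x.transpose(1, 2).reshape(B, C, H, W)   # (B, C, H, W)
reduced = sr(x_).flatten(2).transpose(1, 2)  # (B, N / R**2, C)
print(reduced.shape)  # torch.Size([1, 6, 4])

# Hand-set weights (1 / R**2 on the channel diagonal, zero bias): the conv now
# averages each R x R patch per channel, i.e. it equals plain average pooling
with torch.no_grad():
    sr.weight.zero_()
    sr.bias.zero_()
    for c in range(C):
        sr.weight[c, c] = 1.0 / (R * R)
    averaged = sr(x_)
print(torch.allclose(averaged, F.avg_pool2d(x_, R), atol=1e-6))  # True
```

The 24 input tokens become 6 key/value tokens, matching the N/R² shape noted in the code's comment.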
```python
import torch
from torch import nn


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        # Spatial reduction is implemented as a conv layer whose kernel_size and
        # stride both equal sr_ratio, so each R x R patch becomes one token
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # q: (B, num_heads, N, head_dim); queries keep full resolution
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            # here x_.shape = (B, N/R^2, C)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # k, v: (B, num_heads, N_kv, head_dim), with N_kv = N/R^2 when sr_ratio > 1
        k, v = kv[0], kv[1]

        # attention map: (B, num_heads, N, N_kv)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


# H * W must equal N: 224 * 128 = 28672
# The attention map here is (4, 8, 28672, 7168), about 26 GB in float32;
# shrink B, H and W if memory is limited
x = torch.rand(4, 28672, 256)
attn = Attention(dim=256, sr_ratio=2)
output = attn(x, 224, 128)
print(output.shape)  # torch.Size([4, 28672, 256])
```
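A quick back-of-the-envelope count shows what SRA saves: the attention map has shape (B, num_heads, N, N/R²) instead of (B, num_heads, N, N), so it shrinks by a factor of R². The helper `attn_map_elements` below is hypothetical, written only for this count, and it ignores the extra copies PyTorch allocates for softmax and dropout.

```python
def attn_map_elements(B, num_heads, H, W, sr_ratio=1):
    """Number of entries in the (B, num_heads, N, N_kv) attention map (hypothetical helper)."""
    N = H * W                                 # query tokens, full resolution
    N_kv = (H // sr_ratio) * (W // sr_ratio)  # key/value tokens after the strided conv
    return B * num_heads * N * N_kv


# Sizes from the example: B = 4, num_heads = 8, H = 224, W = 128
full = attn_map_elements(4, 8, 224, 128, sr_ratio=1)
sra = attn_map_elements(4, 8, 224, 128, sr_ratio=2)
print(full, sra, full // sra)  # 26306674688 6576668672 4
print(f"{full * 4 / 1e9:.1f} GB vs {sra * 4 / 1e9:.1f} GB in float32")  # 105.2 GB vs 26.3 GB
```

With sr_ratio = 2 the saving is 4×; larger ratios shrink the key/value sequence quadratically faster, which is why SRA is most useful at the high-resolution stages.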