Adding Attention Mechanisms to YOLOv5
2022-07-05 05:17:00 【馬少爺】
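Adding an attention mechanism to YOLOv5 takes three steps:
1. Add the attention module to common.py
2. Add the module to the conditional checks in yolo.py
3. Add the module to the yaml file
I. Adding the CBAM attention mechanism
(1) Add a callable CBAM module to common.py
1. Open common.py in the models folder
2. Copy and paste the CBAMC3 code below into common.py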
# Uses torch and torch.nn (as nn), both imported at the top of common.py
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared two-layer bottleneck built from 1x1 convolutions
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        # Alternative: the same layers written as a sequential container
        # self.sharedMLP = nn.Sequential(
        #     nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
        #     nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Both pooled descriptors pass through the same f1 -> ReLU -> f2 layers
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        # Per-channel weights of shape (b, c, 1, 1) broadcast over x
        return torch.mul(x, out)
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Padding that keeps the spatial size unchanged
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Mean and max over the channel dimension, each of shape (b, 1, h, w)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv(out))
        # Per-position weights of shape (b, 1, h, w) broadcast over x
        return torch.mul(x, out)
class CBAMC3(nn.Module):
    # CSP Bottleneck with 3 convolutions, followed by CBAM attention
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        self.channel_attention = ChannelAttention(c2, 16)
        self.spatial_attention = SpatialAttention(7)
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        # The output of the final standard convolution (cv3) is passed through
        # channel attention and then spatial attention
        return self.spatial_attention(
            self.channel_attention(self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))))
As shown in the figure below, this article pastes the code at the end of common.py
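The padding rule in SpatialAttention above (3 for a 7x7 kernel, 1 for a 3x3 kernel) is what keeps the attention map the same height and width as the input, so that it can be multiplied element-wise with x. A minimal sketch of the standard convolution output-size arithmetic, assuming stride 1 and dilation 1; `conv_out_size` is a hypothetical helper name used only for illustration:

```python
def conv_out_size(size, kernel_size, padding, stride=1):
    """Output length along one spatial axis for a standard convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


for kernel_size in (3, 7):
    padding = 3 if kernel_size == 7 else 1  # same rule as in SpatialAttention
    for size in (20, 40, 80):
        # The attention map keeps the spatial size of the input feature map
        assert conv_out_size(size, kernel_size, padding) == size

# With the wrong padding the map would shrink and could no longer be multiplied with x
print(conv_out_size(80, kernel_size=7, padding=1))
```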
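(2) Add CBAMC3 to the conditional statements in yolo.py
1. Open yolo.py in the models folder
2. Add CBAMC3 at lines 239 and 245, as shown in the figure below
Remember to save the file after making the change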
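Why two places? The sketch below is a simplified stand-in for the argument handling in yolo.py, not the file's actual code. It assumes that one list marks modules that receive the input channels and a width-scaled output channel count, and that a second list marks modules that also receive the repeat count n. The names `SCALED`, `REPEATED` and `resolve_args` are hypothetical, and module classes are represented by strings to keep the example self-contained.

```python
import math


def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


# Assumption: stand-ins for the two module lists edited in yolo.py
SCALED = {"Conv", "C3", "CBAMC3"}   # get (c1, scaled c2, ...) as arguments
REPEATED = {"C3", "CBAMC3"}         # additionally get n as the third argument


def resolve_args(module, args, c_in, n, width_multiple):
    """Return the constructor arguments and the number of times to build the module."""
    if module not in SCALED:
        return list(args), n
    c2 = make_divisible(args[0] * width_multiple, 8)
    out = [c_in, c2, *args[1:]]
    if module in REPEATED:
        out.insert(2, n)  # the module builds its n bottlenecks internally
        n = 1
    return out, n


# A yaml entry [-1, 3, CBAMC3, [256]] after a 64-channel layer, width multiple 0.5
print(resolve_args("CBAMC3", [256], c_in=64, n=3, width_multiple=0.5))
# A module missing from both lists would only receive the raw yaml arguments
print(resolve_args("NotRegistered", [256], c_in=64, n=3, width_multiple=0.5))
```

Under these assumptions, a module left out of the lists would be constructed with only the yaml arguments, which does not match the CBAMC3(c1, c2, n, ...) signature.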
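(3) Modify the yaml file
Attention mechanisms can be added to the backbone, neck, head, and other parts of the network; you can edit the yaml file to change the network structure or add other modules. This article adds the CBAM module to the backbone as an example, which is only one of several possible placements
1. In the yolov5-5.0 project folder, find yolov5s.yaml in the models folder
2. Change the 4 C3 modules in the backbone to CBAMC3, as shown in the figure below:
This adds the CBAM attention mechanism to the yolov5s backbone
Next, start training the model, and you can see that the CBAMC3 modules have been added to the backbone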
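The yaml change described above can be sketched in Python as a substitution over the backbone entries. The `backbone` list below is an assumption: it is meant to mirror the backbone section of the stock yolov5s.yaml in v5.0, with module names written as strings, and `swap_module` is a hypothetical helper. Compare it with your own file before relying on it.

```python
# Assumption: mirrors the backbone section of the stock v5.0 yolov5s.yaml
# Each entry is [from, number, module, args]
backbone = [
    [-1, 1, "Focus", [64, 3]],
    [-1, 1, "Conv", [128, 3, 2]],
    [-1, 3, "C3", [128]],
    [-1, 1, "Conv", [256, 3, 2]],
    [-1, 9, "C3", [256]],
    [-1, 1, "Conv", [512, 3, 2]],
    [-1, 9, "C3", [512]],
    [-1, 1, "Conv", [1024, 3, 2]],
    [-1, 1, "SPP", [1024, [5, 9, 13]]],
    [-1, 3, "C3", [1024, False]],
]


def swap_module(layers, old, new):
    """Return a copy of the layer list with every `old` module name replaced by `new`."""
    return [[f, n, new if m == old else m, args] for f, n, m, args in layers]


cbam_backbone = swap_module(backbone, "C3", "CBAMC3")
for layer in cbam_backbone:
    print(layer)
```

Only the module name changes. The from, number and args fields stay as they are, since CBAMC3 takes the same constructor arguments as C3.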
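II. Adding the SE attention mechanism
(The steps are similar to those for CBAM)
(1) Add a callable SE module to common.py
1. Open common.py in the models folder
2. Copy and paste the SE code below into common.py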
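class SE(nn.Module):
    def __init__(self, c1, c2, r=16):
        # c2 is not used; the output has the same number of channels as the input (c1)
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // r, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // r, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: global average pooling to one value per channel, shape (b, c)
        y = self.avgpool(x).view(b, c)
        # Excitation: bottleneck of two linear layers, then sigmoid
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        # Reshape to (b, c, 1, 1) and reweight the input channels
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)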
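As shown in the figure below, this article pastes the code at the end of common.py
(2) Add SE to the conditional statements in yolo.py
1. Open yolo.py in the models folder
2. Add SE at lines 238 and 245, as shown in the figure below
Remember to save the file after making the change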
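(3) Modify the yaml file
Attention mechanisms can be added to the backbone, neck, head, and other parts of the network; you can edit the yaml file to change the network structure or add other modules. As with CBAM, this article adds the SE module to the backbone as an example, which is only one of several possible placements
1. In the yolov5-5.0 project folder, find yolov5s.yaml in the models folder
2. Add the following line at the end of the backbone, as shown in the figure below:
(Make sure the commas are ASCII commas, and keep the alignment consistent)
[-1, 1, SE, [1024, 4]],
This adds the SE attention mechanism to the yolov5s backbone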
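To see what the yaml arguments [1024, 4] turn into, here is a minimal sketch of the arithmetic. It assumes that SE is handled like the other channel-scaled modules in yolo.py (input channels prepended, first argument multiplied by the width multiple and rounded up to a multiple of 8), that the width multiple is 0.50 as in the stock yolov5s.yaml, and that the last backbone layer outputs 512 channels. `make_divisible` is reimplemented here for illustration.

```python
import math


def make_divisible(x, divisor):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


width_multiple = 0.50   # assumption: value from the stock yolov5s.yaml
c_in = 512              # assumption: channels produced by the last backbone layer of yolov5s
yaml_args = [1024, 4]   # from the line [-1, 1, SE, [1024, 4]]

c2 = make_divisible(yaml_args[0] * width_multiple, 8)
se_args = [c_in, c2, *yaml_args[1:]]  # what SE(*se_args) would receive: c1, c2, r
c1, _, r = se_args
hidden = c1 // r                      # width of the bottleneck between l1 and l2

print(se_args, hidden)
```

Under these assumptions the module is built as SE(512, 512, 4), with a 128-unit bottleneck. If one of the two edited lines in your copy of yolo.py is the list of modules that receive the repeat count n as a third argument, note that the SE constructor shown earlier accepts only c1, c2 and r, so it is worth checking that SE still receives exactly three values.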
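(If you edit the code on a server, remember to click Save in the top-right corner of the text editor)
Next, start training the model, and you can see that the SE module has been added to the backbone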
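One side effect of appending a layer to the backbone is worth checking. Every later layer moves down by one index, so any `from` field in the head that refers to a layer by absolute index may need to be shifted. A minimal sketch follows, assuming the head of your yolov5s.yaml uses the absolute indices of the stock v5.0 file; `shift_from` is a hypothetical helper, and the `head_from` values are assumptions to compare against your own file.

```python
def shift_from(from_field, inserted_at):
    """Shift absolute layer indices at or after `inserted_at` by one.

    Negative values are relative references (e.g. -1 for the previous layer)
    and are left unchanged.
    """
    def shift(i):
        return i + 1 if i >= inserted_at else i

    if isinstance(from_field, list):
        return [shift(i) for i in from_field]
    return shift(from_field)


# Assumption: `from` fields of the head layers that use absolute indices
# in the stock v5.0 yolov5s.yaml (three Concat layers... and Detect)
head_from = [[-1, 6], [-1, 4], [-1, 14], [-1, 10], [17, 20, 23]]
inserted_at = 10  # SE becomes layer 10, after the stock backbone layers 0-9

print([shift_from(f, inserted_at) for f in head_from])
```

References to backbone layers 4 and 6 stay as they are, because those layers sit before the insertion point.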
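III. Code for several other attention mechanisms
The steps are not repeated here; follow the same procedure used for CBAM and SE above
(1) ECA attention mechanism code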
class eca_layer(nn.Module):
    """Constructs an ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Kernel size of the 1D convolution across channels
    """
    def __init__(self, channel, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Feature descriptor from global spatial information: (b, c, h, w) -> (b, c, 1, 1)
        y = self.avg_pool(x)
        # 1D convolution along the channel axis: (b, c, 1, 1) -> (b, 1, c) -> (b, c, 1, 1)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Sigmoid turns the result into per-channel weights
        y = self.sigmoid(y)
        # Reweight the input channels
        return x * y.expand_as(x)
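The eca_layer above uses a fixed kernel size (k_size=3 by default). The ECA paper also describes choosing the kernel size from the channel count. The sketch below shows that mapping as I understand it, with gamma=2 and b=1 as assumed defaults; `eca_kernel_size` is a hypothetical helper and is not part of the code above.

```python
import math


def eca_kernel_size(channels, gamma=2, b=1):
    """Odd kernel size derived from the channel count (assumed adaptive mapping)."""
    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 else t + 1  # force an odd size so the padding is symmetric


for channels in (64, 128, 256, 512, 1024):
    print(channels, eca_kernel_size(channels))
```

A value computed this way could be passed as k_size when constructing the layer, for example through the args list in the yaml file.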
(2) CA attention mechanism code:
class h_sigmoid(nn.Module):
    # Hard sigmoid: ReLU6(x + 3) / 6
    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    # Hard swish: x * h_sigmoid(x)
    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super().__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))  # average over width, keep height
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))  # average over height, keep width
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        x_h = self.pool_h(x)                      # (n, c, h, 1)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)  # (n, c, w, 1)
        # Process both directions with one shared transform
        y = torch.cat([x_h, x_w], dim=2)          # (n, c, h + w, 1)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)             # (n, mip, 1, w)
        a_h = self.conv_h(x_h).sigmoid()          # attention along height
        a_w = self.conv_w(x_w).sigmoid()          # attention along width
        out = identity * a_w * a_h
        return out
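The h_sigmoid and h_swish modules above are piecewise-linear approximations built from ReLU6. A scalar sketch in plain Python shows the values they produce; the function names mirror the modules, and `relu6` is a hypothetical helper standing in for nn.ReLU6.

```python
def relu6(x):
    """Clamp x to the range [0, 6]."""
    return min(max(x, 0.0), 6.0)


def h_sigmoid(x):
    """Hard sigmoid: 0 for x <= -3, 1 for x >= 3, linear in between."""
    return relu6(x + 3.0) / 6.0


def h_swish(x):
    """Hard swish: x scaled by the hard sigmoid of x."""
    return x * h_sigmoid(x)


for x in (-4.0, -3.0, 0.0, 1.0, 3.0, 4.0):
    print(x, h_sigmoid(x), h_swish(x))
```

For inputs at or above 3 the hard swish passes the value through unchanged, and for inputs at or below -3 it outputs zero.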
Reference: https://blog.csdn.net/thy0000/article/details/125016410