Specific steps for adding attention mechanisms to YOLOv5
2022-06-09 07:41:00 【Maple in Luobei Village】
Taking CBAM and SE as examples, this article walks through the specific steps for adding an attention mechanism to YOLOv5.
The code changes below are based on the YOLOv5-5.0 release.
See also: yolov5 model training — using yolov5 to train your own dataset.
This article mainly covers the following content.
Adding an attention mechanism to YOLOv5 can be broken down into three steps:
1. Add the attention module to common.py
2. Add the corresponding checks to yolo.py
3. Add the module to the yaml model configuration file
I. Adding the CBAM attention mechanism
(1) Add a callable CBAM module to common.py
1. Open the common.py file in the models folder.
2. Copy the following CBAMC3 code and paste it into common.py:
# These classes rely on torch, nn, Conv and Bottleneck, all of which are already
# imported or defined near the top of models/common.py in YOLOv5-5.0.
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        # Alternative: the shared MLP can also be written as a Sequential container
        # self.sharedMLP = nn.Sequential(
        #     nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
        #     nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return torch.mul(x, out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.sigmoid(self.conv(out))
        return torch.mul(x, out)


class CBAMC3(nn.Module):
    # CSP Bottleneck with 3 convolutions, followed by CBAM attention
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAMC3, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])
        self.channel_attention = ChannelAttention(c2, 16)
        self.spatial_attention = SpatialAttention(7)
        # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])

    def forward(self, x):
        # The output of the final 1x1 convolution is refined by channel attention
        # and then by spatial attention
        return self.spatial_attention(
            self.channel_attention(self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))))
This article pastes the code at the end of common.py (remember to save the file afterwards).
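Optionally, you can sanity-check the new module before touching yolo.py. The following minimal sketch assumes it is run from the yolov5-5.0 project root after the classes above have been added to models/common.py:

import torch
from models.common import CBAMC3

m = CBAMC3(64, 64, n=1)         # 64 channels in, 64 channels out
x = torch.randn(1, 64, 80, 80)  # dummy feature map
print(m(x).shape)               # expected: torch.Size([1, 64, 80, 80])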

(2) Add a CBAMC3 check to yolo.py
1. Open the yolo.py file in the models folder.
2. Add CBAMC3 to the two module lists inside the parse_model() function, around lines 218 and 224 in YOLOv5-5.0, as sketched below.
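Since the original screenshots are not reproduced here, the sketch below shows roughly what the two modified statements in parse_model() look like. It is based on the stock YOLOv5-5.0 yolo.py; exact line numbers and list contents can vary slightly between releases.

# Around line 218: add CBAMC3 to the list of modules that take (c1, c2, ...) arguments
if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, DWConv, MixConv2d,
         Focus, CrossConv, BottleneckCSP, C3, CBAMC3]:
    c1, c2 = ch[f], args[0]
    if c2 != no:  # if not output
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]
    # Around line 224: CBAMC3 wraps n Bottleneck blocks, so its repeat count n
    # must also be inserted into args, just like BottleneckCSP and C3
    if m in [BottleneckCSP, C3, CBAMC3]:
        args.insert(2, n)  # number of repeats
        n = 1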


Also remember to click save after the change
(3) Modify the yaml file
Attention modules can be added to the backbone, the Neck, or the Head; the network structure can be modified and modules added in the yaml file. This article demonstrates only one of the possible ways: adding the CBAM module to the backbone network.
1. In the yolov5-5.0 project folder, find the yolov5s.yaml file under the models folder.
2. In the backbone section, change the four C3 modules to CBAMC3, as sketched below:
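For reference, the modified backbone section would look roughly like this (a sketch assuming the stock yolov5s.yaml from YOLOv5-5.0; only the module column of the four C3 entries changes):

backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 3]],        # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],     # 1-P2/4
   [-1, 3, CBAMC3, [128]],
   [-1, 1, Conv, [256, 3, 2]],     # 3-P3/8
   [-1, 9, CBAMC3, [256]],
   [-1, 1, Conv, [512, 3, 2]],     # 5-P4/16
   [-1, 9, CBAMC3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],    # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],
   [-1, 3, CBAMC3, [1024, False]], # 9
  ]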

At this point, the CBAM attention mechanism has been added to the yolov5s backbone.
(If you run the modified code on a server, remember to click Save in the upper right corner of the text editor first.)
Next, start training the model; the model summary printed at startup shows that the CBAMC3 modules have been successfully added to the backbone.
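A typical training command is shown below for reference; the dataset yaml and the hyperparameter values are placeholders to replace with your own. In the layer table printed at startup, the backbone entries should now appear as models.common.CBAMC3.

python train.py --cfg models/yolov5s.yaml --data data/mydata.yaml \
                --weights '' --epochs 100 --batch-size 16 --img-size 640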
II. Adding the SE attention mechanism
(The steps are similar to CBAM.)
(1) Add a callable SE module to common.py
1. Open the common.py file in the models folder.
2. Copy the following SE code and paste it into common.py:
class SE(nn.Module):
    def __init__(self, c1, c2, r=16):
        super(SE, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // r, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // r, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        # print(x.size())  # debug only: uncomment to check the input shape
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)
This article pastes the code at the end of common.py (remember to save the file).
(2) Add an SE check to yolo.py
1. Open the yolo.py file in the models folder.
2. Add SE to the module list in parse_model() around line 218, in the same way as CBAMC3 above. (Unlike CBAMC3, SE has no repeat count, so it should not be added to the second list around line 224 that inserts n into the module arguments.)


Also remember to click save after the change
(3) Modify the yaml file
As with CBAM, attention modules can be added to the backbone, the Neck, or the Head by modifying the network structure in the yaml file. The process is the same as for adding CBAM; this article again takes the backbone as an example and introduces only one way of adding the SE module.
1. In the yolov5-5.0 project folder, find the yolov5s.yaml file under the models folder.
2. Add the following line at the end of the backbone section:
(Make sure the commas are ASCII commas and the line is aligned with the entries above it.)
[-1, 1, SE, [1024, 4]],
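With that line added, the tail of the backbone looks roughly like this (a sketch assuming the stock yolov5s.yaml), with SE becoming layer 10:

   [-1, 1, Conv, [1024, 3, 2]],       # 7-P5/32
   [-1, 1, SPP, [1024, [5, 9, 13]]],  # 8
   [-1, 3, C3, [1024, False]],        # 9
   [-1, 1, SE, [1024, 4]],            # 10
  ]

Note that appending an extra layer shifts the absolute index of every head layer by one. Where the head references layers by absolute index (in the stock yolov5s.yaml, for example, the Concat entries [-1, 14] and [-1, 10] and the Detect list [17, 20, 23]), those numbers generally need to be incremented by one as well, otherwise the concatenations and detection heads will pick up the wrong feature maps.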


At this point, the SE attention mechanism has been added to the yolov5s backbone.
(If you run the modified code on a server, remember to click Save in the upper right corner of the text editor first.)
Next, start training the model; the printed model summary shows that the SE module has been successfully added to the backbone.
III. Code for several other attention mechanisms
The steps for adding them are not repeated here; follow the same process as for CBAM and SE above.
(1) ECA attention mechanism code
class eca_layer(nn.Module):
    """Constructs an ECA module.

    Args:
        channel: number of channels of the input feature map
        k_size: adaptively selected kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        # 1D convolution across neighbouring channels
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # channel-wise gating weights
        y = self.sigmoid(y)
        # apply the attention weights once (the original snippet multiplied by them twice)
        return x * y.expand_as(x)
(2) CA (coordinate attention) mechanism code:
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        # pool along width and height separately to keep positional information
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        # concatenate the two direction-aware descriptors and transform them jointly
        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        # split back into height and width branches and build the attention maps
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out