YOLOv5 network modification tutorial (replacing the backbone with EfficientNet, MobileNetV3, RegNet, etc.)
2022-07-02 04:11:00 【rglkt】
In my undergraduate thesis I used YOLOv5 and tried modifying it. This article shows how to make customized modifications to YOLOv5. Examples include a lighter YOLOv5-MobileNetV3, or a YOLOv5-EfficientNet that may outperform YOLOv5s. That last point is uncertain: I haven't trained it on a large dataset, so you can run the experiment yourself.
Before modifying anything, first look at the YOLOv5 network structure. It looks complicated, but don't panic. This article modifies the backbone (the feature extraction network), which can be abstracted into just three parts, and those are the only places you need to change.
Next, let's look at the code we need to modify. The changes are concentrated in the models folder of the yolov5 repository:
- the .yaml files are the configuration files that describe the modified model;
- common.py receives the new modules;
- yolo.py is updated so that the model can read the corresponding configuration file.
That concludes the introduction. Now let's actually modify the model.

The first step is to pick a well-performing feature extraction network, such as the MobileNet or EfficientNet mentioned above. Most well-performing feature extraction networks downsample three or more times, so we can obtain feature maps of three different sizes. YOLOv5 fuses these three feature maps through FPN and PAN operations, which are not detailed here.

The main point is that the feature extraction network must provide three feature maps of different sizes. We take the outputs of its last three downsampling stages and feed them into the YOLOv5 network. That completes the replacement of the feature extraction network.
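To make the sizes concrete, here is a small arithmetic sketch. It assumes the usual 640x640 input and a backbone that downsamples five times, as both MobileNetV3 and YOLOv5's own backbone do. Each stride-2 step halves the height and width. The last three downsamplings therefore give strides 8, 16 and 32, which are YOLOv5's P3, P4 and P5 levels. The helper name `pyramid_levels` is hypothetical.

```python
def pyramid_levels(img_size=640, num_downsamples=5):
    """Return (level name, stride, feature map size) for the last three downsamplings.

    Assumes img_size is divisible by 2 ** num_downsamples, as YOLOv5 requires.
    """
    levels = []
    # The last three of num_downsamples stride-2 steps: k = 3, 4, 5 for a 5-step backbone
    for k in range(num_downsamples - 2, num_downsamples + 1):
        stride = 2 ** k
        levels.append((f"P{k}", stride, img_size // stride))
    return levels


print(pyramid_levels())  # [('P3', 8, 80), ('P4', 16, 40), ('P5', 32, 20)]
```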
Take MobileNetV3-Small as an example. We don't even need to build the network ourselves. We can reuse the official PyTorch (torchvision) implementation and pick from the models listed on the PyTorch website.
Print the network structure and inspect it. MobileNetV3 consists mainly of three parts:
- features, which performs feature extraction;
- avgpool, which performs global pooling;
- classifier, which performs classification.

We only need the feature extraction part, and specifically its last three downsampling steps. So let's work backward from the end.

In MobileNet, the last downsampling happens in the module at index 9 of features.

How do you spot a downsampling quickly? The short answer is to look for stride 2. Kernel size matters too, for example kernel_size 5. However, in most relatively recent networks a kernel_size of 5 comes with a padding of 2, which keeps the spatial size unchanged. So if you are lazy, you can look at stride alone.

Therefore modules 9 to the end of features (9-12) correspond to YOLOv5's last downsampling (P5/32).
Second-to-last downsampling: modules 4-8 (P4/16)
Third-to-last downsampling: modules 0-3 (P3/8)
Once we know how to split the network, the second step is to add the modules at the end of common.py. As you can see, it is very simple: we mainly add the three parts of MobileNet.
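```python
# Append to the end of models/common.py (common.py already imports torch.nn as nn)
import torch.nn as nn
from torchvision import models


class MobileNet1(nn.Module):
    # features[0:4] of MobileNetV3-Small: stride 8 (P3), 24 output channels
    def __init__(self, ignore) -> None:
        # `ignore` receives the args from the yaml (the channel count); it is not needed here
        super().__init__()
        # pretrained=True is deprecated since torchvision 0.13;
        # newer versions use weights=models.MobileNet_V3_Small_Weights.DEFAULT
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][:4]  # modules[0] is model.features
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileNet2(nn.Module):
    # features[4:9]: stride 16 (P4), 48 output channels
    def __init__(self, ignore) -> None:
        super().__init__()
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][4:9]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileNet3(nn.Module):
    # features[9:]: stride 32 (P5), 576 output channels (the final 1x1 conv expands 96 -> 576)
    def __init__(self, ignore) -> None:
        super().__init__()
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][9:]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
```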
The third step is to modify yolo.py. Add a branch for the new modules to the layer-parsing code in parse_model, so that when the yaml is parsed the corresponding modules are created.

args[0] is the first argument after the module name in the yaml. It tells the model how many channels this module outputs. Looking back at the code above, the three modules output 24, 48 and 576 channels respectively.
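To see why args[0] matters, here is a simplified, self-contained sketch of the channel bookkeeping that parse_model performs. It is not YOLOv5's actual code, and the classes are empty placeholders.

parse_model keeps a list `ch` with the output channels of every layer, and later layers such as Concat read from it. For a module it does not recognise, YOLOv5 falls back to assuming the layer keeps its input channel count. Without the new branch, it would therefore record 3 channels for MobileNet1.

The branch added to yolo.py presumably resembles the `if` below. Unlike Conv and C3, these channels are probably best left unscaled by width_multiple, because torchvision fixes them. parse_model also builds each layer as `m(*args)`, which is why the modules take an `ignore` argument.

```python
# Placeholder classes standing in for the modules added to common.py.
class MobileNet1: pass
class MobileNet2: pass
class MobileNet3: pass


def track_channels(rows, with_branch=True):
    """Simplified sketch of parse_model's per-layer output-channel list.

    rows: yaml-style (from, module, args) tuples. Returns the output channels per layer.
    """
    ch = [3]  # starts with the input image channels so layer 0 can read ch[-1]
    for i, (f, m, args) in enumerate(rows):
        if with_branch and m in (MobileNet1, MobileNet2, MobileNet3):
            c2 = args[0]  # the added branch: take the channel count from the yaml
        else:
            c2 = ch[f]    # fallback: assume the layer keeps its input channels
        if i == 0:
            ch = []       # from here on, ch[k] is the output of layer k
        ch.append(c2)
    return ch


backbone = [(-1, MobileNet1, [24]), (-1, MobileNet2, [48]), (-1, MobileNet3, [576])]
print(track_channels(backbone))                     # [24, 48, 576]
print(track_channels(backbone, with_branch=False))  # [3, 3, 3]
```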
Finally, add the model yaml. I use yolov5n as the prototype for the modification.
Original yolov5n.yaml:
```yaml
# YOLOv5 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
Modified yolov5n-mobilenet.yaml:
```yaml
# YOLOv5 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, MobileNet1, [24]],  # 0
   [-1, 1, MobileNet2, [48]],  # 1
   [-1, 1, MobileNet3, [576]],  # 2
   [-1, 1, SPPF, [1024, 5]],  # 3
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 1], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 7
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 0], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 11 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 7], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 14 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 3], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 17 (P5/32-large)
   [[11, 14, 17], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
The changes are actually easy to understand. In yolov5n you can follow the layer indices in the # comments, and the Concat layers take the outputs of the downsampling layers. Following the same pattern, we simply change those indices to point at our modules.
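When renumbering the head, it helps to print the index of every layer. Below is a small sketch that does this from the yaml text. It assumes PyYAML is installed (YOLOv5 itself depends on it), and `list_layers` is a hypothetical helper. The inline config is a shortened excerpt of the modified file, used only for illustration.

```python
import yaml


def list_layers(cfg_text):
    """Return (index, from, module) for every backbone + head row of a YOLOv5 model yaml."""
    cfg = yaml.safe_load(cfg_text)
    rows = cfg["backbone"] + cfg["head"]  # parse_model numbers layers across both lists
    return [(i, f, m) for i, (f, _n, m, _args) in enumerate(rows)]


cfg_text = """
backbone:
  [[-1, 1, MobileNet1, [24]],
   [-1, 1, MobileNet2, [48]],
   [-1, 1, MobileNet3, [576]],
   [-1, 1, SPPF, [1024, 5]],
  ]
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 1], 1, Concat, [1]],
  ]
"""

layers = list_layers(cfg_text)
for i, f, m in layers:
    print(i, f, m)
```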
Finally, start training with the new config passed via --cfg:
```shell
python train.py --cfg yolov5n-mobilenet.yaml --weights yolov5n.pt
```

A brief note on the performance of YOLOv5-MobileNetV3. The amount of computation (GFLOPs) drops substantially, while accuracy is close to that of yolov5n trained without pretrained weights. On GPU, however, inference is not any faster, mainly because of the characteristics of the SE modules (not covered in detail here). The model is therefore better suited to CPU-based mobile platforms.
One small boast: I am a YOLOv5 contributor, although I only changed a single number. The next article will show how to use TensorRT C++ to speed up YOLOv5.