Yolov5 network modification tutorial (replacing the backbone with EfficientNet, MobileNetV3, RegNet, etc.)
2022-07-02 04:11:00 【rglkt】
In my undergraduate thesis I used Yolov5 and tried modifying it. The approach below lets you build customized versions of Yolov5, for example a lighter Yolov5-MobileNetv3, or a Yolov5-EfficientNet that may outperform Yolov5s. That second claim is unconfirmed: I haven't run it on large datasets, so you can experiment yourself.
Before modifying anything, first look at the Yolov5 network structure. It looks complicated overall, but don't panic. This article mainly modifies the backbone (the feature extraction network), which can be abstracted into just three parts, and that is the only place you need to change.
Next, get to know the code we will modify. It is concentrated in the models folder of yolov5: the .yaml files are the configuration files that correspond to the modified code, common.py is where the new modules are added, and yolo.py is changed so that the model can read the corresponding configuration file.
That ends the introduction, so now we start modifying the model. The first step is to pick a feature extraction network with better performance, such as the MobileNet and EfficientNet mentioned above. Most well-performing feature extraction networks downsample three or more times, so we can obtain feature maps of three different sizes. Yolov5 fuses these three feature maps with FPN and PAN operations, which are not covered in detail here. The key point is that the feature extraction network must provide feature maps at three different scales: we take the outputs of its last three downsampling stages and feed them into the rest of the Yolov5 network, and that completes the replacement of the feature extraction network.
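To make "three different sizes" concrete: the three detection levels P3/P4/P5 sit at strides 8, 16 and 32, which is what the P3/8, P4/16 and P5/32 comments in the yolov5n yaml refer to. The small sketch below computes the resulting feature map sizes for a square input. The helper name `pyramid_sizes` is illustrative and not part of YOLOv5. It assumes the input size is a multiple of 32, since YOLOv5 rounds image sizes up to a multiple of its largest stride.

```python
def pyramid_sizes(img_size: int, strides=(8, 16, 32)) -> list:
    """Side length of the P3/P4/P5 feature maps for a square input.

    Assumes img_size is a multiple of the largest stride.
    """
    if img_size % max(strides) != 0:
        raise ValueError(f"{img_size} is not a multiple of {max(strides)}")
    # Each level is the input downsampled by its stride.
    return [img_size // s for s in strides]


print(pyramid_sizes(640))  # [80, 40, 20]
```

With a 640x640 input, the replacement backbone therefore has to hand Yolov5 feature maps of 80x80, 40x40 and 20x20.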
Take MobileNetv3-Small as an example. We don't even need to build the network ourselves. We can directly reuse an official PyTorch network from torchvision.models, choosing from the models listed on the PyTorch official website.
Print the network structure and inspect it. MobileNetv3 consists mainly of three parts, features, avgpool and classifier, which perform feature extraction, global pooling and classification respectively. We only care about the feature extraction part, and in particular its last three downsampling stages, so we read it backwards from the end.

In MobileNet, the last downsampling happens in module 9 of features. To spot downsampling quickly, look for stride=2. Strictly speaking kernel_size also matters, for example kernel_size=5. However, in most recent networks a kernel_size of 5 is paired with padding=2, which leaves the spatial size unchanged at stride 1, so a lazy reader can just look at stride. Therefore modules 9 onward correspond to Yolov5's last downsampling stage. That range covers modules 9-11 plus the final 1x1 conv in module 12, which expands the output to 576 channels.
The second-to-last downsampling stage: modules 4-8
The third-to-last downsampling stage: modules 0-3
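The "just look at stride" rule of thumb follows from the standard convolution output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1. The sketch below assumes dilation 1. `conv_out` is an illustrative helper, not a torch function.

```python
def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# kernel 3 / padding 1 and kernel 5 / padding 2 behave identically:
# only the stride decides whether the feature map shrinks.
for k, p, s in [(3, 1, 2), (5, 2, 2), (3, 1, 1), (5, 2, 1)]:
    print(f"k={k} p={p} s={s}: 80 -> {conv_out(80, k, s, p)}")
```

Whenever padding = (kernel - 1) // 2, stride 1 keeps 80 at 80 and stride 2 halves it to 40. That is why scanning for stride=2 is enough.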
Having decided how to split the network, the second step is to add modules at the end of common.py. It is quite simple: we just add the three parts of MobileNet.
import torch.nn as nn  # already imported at the top of common.py
from torchvision import models


class MobileNet1(nn.Module):
    # MobileNetV3-Small features[0:4]: stride 8, 24 output channels
    def __init__(self, ignore) -> None:
        # `ignore` receives args[0] from the yaml (the channel count); only the
        # yaml parser in yolo.py needs it, the module itself does not
        super().__init__()
        # torchvision >= 0.13 deprecates pretrained=True in favour of
        # weights=models.MobileNet_V3_Small_Weights.DEFAULT
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())  # [features, avgpool, classifier]
        modules = modules[0][:4]  # keep only this slice of `features`
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileNet2(nn.Module):
    # features[4:9]: stride 16, 48 output channels
    def __init__(self, ignore) -> None:
        super().__init__()
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][4:9]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileNet3(nn.Module):
    # features[9:]: stride 32, 576 output channels
    def __init__(self, ignore) -> None:
        super().__init__()
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][9:]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
The third step is to modify yolo.py by adding this line of code in this section of parse_model, so that when the yaml is parsed the corresponding module is created. args[0] is the first argument following the module name in the yaml. It must tell the model how many channels the module outputs. Looking back at the code above, the three modules output 24, 48 and 576 channels.
Finally, add the model yaml. I used yolov5n as the prototype to modify.
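The yolo.py line itself is not reproduced here. However, a toy version of the channel bookkeeping shows why args[0] must be the real channel count: each layer's input channels are simply the previous layer's output channels. This is a simplified, hypothetical sketch, not YOLOv5's actual parse_model. `track_channels` and `CUSTOM_BACKBONE` are illustrative names. It assumes, as described above, that the custom modules take args[0] unscaled, while built-in modules such as SPPF scale args[0] by width_multiple.

```python
import math


def make_divisible(x: float, divisor: int = 8) -> int:
    """Round x up to a multiple of divisor, as YOLOv5 does for scaled widths."""
    return math.ceil(x / divisor) * divisor


# Hypothetical set of custom modules whose args[0] is used as-is.
CUSTOM_BACKBONE = {"MobileNet1", "MobileNet2", "MobileNet3"}


def track_channels(layers, width_multiple=0.25, in_ch=3):
    """Toy channel chaining for sequential layers only.

    Returns (input_channels, output_channels) for every layer.
    """
    ch = [in_ch]  # ch[-1] is always the previous layer's output
    result = []
    for module, args in layers:
        c1 = ch[-1]
        if module in CUSTOM_BACKBONE:
            c2 = args[0]  # must equal the slice's real output channels
        else:
            c2 = make_divisible(args[0] * width_multiple)  # e.g. SPPF
        result.append((c1, c2))
        ch.append(c2)
    return result


backbone = [("MobileNet1", [24]), ("MobileNet2", [48]),
            ("MobileNet3", [576]), ("SPPF", [1024, 5])]
print(track_channels(backbone))
# [(3, 24), (24, 48), (48, 576), (576, 256)]
```

Suppose MobileNet3 were declared with, say, [512]. SPPF would then be built expecting 512 input channels, and the first forward pass would fail with a channel mismatch.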
yolov5n
# YOLOv5 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23] # P3/8
  - [30,61, 62,45, 59,119] # P4/16
  - [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]], # 9
  ]
# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]], # cat backbone P4
   [-1, 3, C3, [512, False]], # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]], # cat backbone P3
   [-1, 3, C3, [256, False]], # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]], # cat head P4
   [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]], # cat head P5
   [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
yolov5n-mobilenet
# YOLOv5 by Ultralytics, GPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23] # P3/8
  - [30,61, 62,45, 59,119] # P4/16
  - [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, MobileNet1, [24]], # 0
   [-1, 1, MobileNet2, [48]], # 1
   [-1, 1, MobileNet3, [576]], # 2
   [-1, 1, SPPF, [1024, 5]], # 3
  ]
# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]], # 4
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 1], 1, Concat, [1]], # cat backbone P4
   [-1, 3, C3, [512, False]], # 7
   [-1, 1, Conv, [256, 1, 1]], # 8
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 0], 1, Concat, [1]], # cat backbone P3
   [-1, 3, C3, [256, False]], # 11 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 8], 1, Concat, [1]], # cat head P4
   [-1, 3, C3, [512, False]], # 14 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 4], 1, Concat, [1]], # cat head P5
   [-1, 3, C3, [1024, False]], # 17 (P5/32-large)
   [[11, 14, 17], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]
The changes are easy to understand. In yolov5n the # comments give each layer's index, and the Concat entries point back to the backbone's downsampling layers. Follow the same pattern and change those indices so they point at our modules.
Finally, train it by passing the new yaml with --cfg:
python train.py --cfg yolov5n-mobilenet.yaml --weights yolov5n.pt
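The renumbering follows a simple rule. The head keeps its layer order, so every head index moves down by the number of backbone layers removed (10 - 4 = 6). Meanwhile, the backbone outputs that the head uses map to the new modules by scale. The sketch below encodes that rule under the assumption that the goal is to mirror yolov5n's head exactly. `remap` and `BACKBONE_MAP` are illustrative names. Under this rule, the two "cat head" entries map to layers 8 and 4, which are the 1x1 Conv outputs that yolov5n itself concatenates at layers 14 and 10.

```python
# yolov5n has 10 backbone layers (0-9); the MobileNet yaml has 4 (0-3).
OLD_BACKBONE_LEN = 10
NEW_BACKBONE_LEN = 4
# Backbone layers the head refers to: P3 (old 4), P4 (old 6), SPPF (old 9).
BACKBONE_MAP = {4: 0, 6: 1, 9: 3}


def remap(idx: int) -> int:
    """Translate a 'from' index in yolov5n.yaml to the MobileNet yaml."""
    if idx == -1:
        return -1  # "previous layer" is relative, so it never changes
    if idx < OLD_BACKBONE_LEN:
        return BACKBONE_MAP[idx]  # backbone outputs are matched by scale
    # Head layers keep their order and shift by the change in backbone length.
    return idx - (OLD_BACKBONE_LEN - NEW_BACKBONE_LEN)


old_refs = {
    "cat backbone P4": [-1, 6],
    "cat backbone P3": [-1, 4],
    "cat head P4": [-1, 14],
    "cat head P5": [-1, 10],
    "Detect": [17, 20, 23],
}
for name, refs in old_refs.items():
    print(f"{name}: {refs} -> {[remap(i) for i in refs]}")
```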
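Before training, you can sanity-check any renumbering by confirming that every Concat joins feature maps of the same stride. Wrong indices would otherwise only surface as a shape error when the model is built. The sketch below is a hypothetical helper, not a YOLOv5 tool. HEAD is a hand transcription of the MobileNet head, using layers 8 and 4 for the two "cat head" Concats to mirror yolov5n's 14 and 10. The backbone strides come from the MobileNetV3-Small split above.

```python
# Absolute strides of the backbone outputs: MobileNet1/2/3 and SPPF.
BACKBONE_STRIDES = [8, 16, 32, 32]

# Head as (from, module, op); op is how the layer changes its input stride:
# "same", "down" (stride-2 conv) or "up" (2x nearest upsample).
HEAD = [
    (-1, "Conv", "same"),       # 4
    (-1, "Upsample", "up"),     # 5
    ([-1, 1], "Concat", None),  # 6  cat backbone P4
    (-1, "C3", "same"),         # 7
    (-1, "Conv", "same"),       # 8
    (-1, "Upsample", "up"),     # 9
    ([-1, 0], "Concat", None),  # 10 cat backbone P3
    (-1, "C3", "same"),         # 11 P3/8
    (-1, "Conv", "down"),       # 12
    ([-1, 8], "Concat", None),  # 13 cat head P4
    (-1, "C3", "same"),         # 14 P4/16
    (-1, "Conv", "down"),       # 15
    ([-1, 4], "Concat", None),  # 16 cat head P5
    (-1, "C3", "same"),         # 17 P5/32
]


def check_strides(backbone, head):
    """Return every layer's stride; raise if a Concat mixes resolutions."""
    strides = list(backbone)
    for f, module, op in head:
        i = len(strides)  # index of the layer being added
        sources = f if isinstance(f, list) else [f]
        in_strides = [strides[i - 1] if s == -1 else strides[s] for s in sources]
        if module == "Concat":
            if len(set(in_strides)) != 1:
                raise ValueError(f"layer {i}: Concat inputs have strides {in_strides}")
            strides.append(in_strides[0])
        elif op == "down":
            strides.append(in_strides[0] * 2)
        elif op == "up":
            strides.append(in_strides[0] // 2)
        else:
            strides.append(in_strides[0])
    return strides


strides = check_strides(BACKBONE_STRIDES, HEAD)
print("Detect inputs:", [strides[i] for i in (11, 14, 17)])  # [8, 16, 32]
```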

A quick note on Yolov5-MobileNetv3's performance: GFLOPs (the amount of computation) drop substantially, while accuracy is similar to yolov5n trained without pretrained weights. However, inference is not faster in a GPU environment, mainly because of the characteristics of the SE modules, which I won't go into here. The model is better suited to CPU and mobile platforms.
Showing off a little: I'm now a Yolov5 contributor, having changed just one number. The next article will show how to accelerate yolov5 with TensorRT C++.