YOLOv5 Network Modification Tutorial (Changing the Backbone to EfficientNet, MobileNetV3, RegNet, etc.)
2022-07-02 04:04:00 [rglkt]
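In my undergraduate thesis I used YOLOv5 and experimented with modifying it. YOLOv5 can be customized to a fair degree. Two examples are a lighter-weight YOLOv5-MobileNetV3 and a YOLOv5-EfficientNet that may outperform YOLOv5s. The second claim is unverified: I have not trained it on a large dataset, so test it yourself.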
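Before modifying anything, look at YOLOv5's network structure. It looks complicated as a whole, but don't panic. This article mainly modifies the backbone (the feature extraction network), which can be abstracted into just three parts. The backbone is the only place we need to change.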
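Next, get to know the code that needs changing. The changes are concentrated in YOLOv5's models folder:
- the yaml files are the configuration files that go with the modified code;
- common.py receives the new modules;
- yolo.py is changed so that the model can parse the corresponding configuration file.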
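That covers the background, so let's start modifying the model. Step one is to pick a well-performing feature extraction network, such as the MobileNet and EfficientNet mentioned above.

Most strong feature extractors downsample three or more times, so they can provide feature maps at three different sizes. YOLOv5 fuses feature maps of these three sizes through its FPN and PAN operations, which I won't go into here. The key requirement is that the feature extractor must output feature maps at three different sizes. We take the outputs of the extractor's last three downsampling stages and feed them into YOLOv5, and the backbone replacement is done.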
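The "three sizes" can be made concrete with the standard convolution output-size formula, out = floor((in + 2p - k) / s) + 1. Below is a minimal sketch under an illustrative assumption: a 640×640 input and a backbone that downsamples five times with 3×3, stride-2, padding-1 convolutions. This is not MobileNet's exact layer list. The last three sizes are the stride-8/16/32 maps that YOLOv5 labels P3/P4/P5.

```python
def conv_out(size: int, k: int, s: int, p: int) -> int:
    """Spatial output size of a convolution: floor((size + 2p - k) / s) + 1."""
    return (size + 2 * p - k) // s + 1


def downsample_sizes(size: int, n_down: int = 5) -> list:
    """Feature-map sizes after each of n_down 3x3, stride-2, padding-1 convolutions
    (an illustrative assumption, not a specific network)."""
    sizes = []
    for _ in range(n_down):
        size = conv_out(size, k=3, s=2, p=1)
        sizes.append(size)
    return sizes


sizes = downsample_sizes(640)
print(sizes)        # [320, 160, 80, 40, 20]
print(sizes[-3:])   # [80, 40, 20] -> P3/8, P4/16, P5/32
# A 5x5 kernel with padding 2 and stride 1 keeps the size unchanged:
print(conv_out(40, k=5, s=1, p=2))  # 40
```

The last line explains the shortcut used below. In most networks, only stride 2 changes the spatial size.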
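Take MobileNetV3-Small as an example. We don't even need to build the network ourselves. We can reuse the official implementations in PyTorch's torchvision.models. The PyTorch website lists all the available models, so pick whichever you like.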
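Print the network structure and inspect it. MobileNetV3 consists of three parts, which handle feature extraction, global pooling and classification respectively:
- features
- avgpool
- classifier

We only care about the feature extraction part, and especially its last three downsampling stages, so we inspect it from the end backward.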

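In MobileNet, the last downsampling happens at module 9 of features.

A quick way to spot downsampling is to look for stride 2. Other settings, such as kernel_size=5, could in principle also change the size. In practice, more recent networks pair kernel_size 5 with padding 2, which keeps the size unchanged, so as a shortcut you can look at stride alone.

Modules 9 through the end therefore correspond to YOLOv5's last downsampling stage. Modules 9-11 are inverted-residual blocks, and module 12 is the final 1×1 convolution that outputs 576 channels.
Second-to-last downsampling stage: modules 4-8.
Third-to-last downsampling stage: modules 0-3.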
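The stride-based reading above can be sketched in plain Python. The per-module strides and output channels below are my transcription of torchvision's default mobilenet_v3_small `features`: index 0 is the stem convolution, 1-11 are inverted-residual blocks, and 12 is the final 1×1 convolution. Treat these values as an assumption and confirm them with `print(model.features)` for your torchvision version. The helper `split_stages` is hypothetical and is not part of YOLOv5 or torchvision.

```python
# Assumed layout of torchvision's mobilenet_v3_small().features (verify by printing the model)
STRIDES = [2, 2, 2, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1]
CHANNELS = [16, 16, 24, 24, 40, 40, 40, 48, 48, 96, 96, 96, 576]


def split_stages(strides, n_stages=3):
    """Hypothetical helper: split the module list into n_stages consecutive slices,
    where every slice after the first starts at one of the last (n_stages - 1)
    stride-2 modules, so each slice ends right before the next downsampling."""
    down = [i for i, s in enumerate(strides) if s == 2]
    assert len(down) >= n_stages - 1, "not enough downsampling modules"
    cuts = down[len(down) - (n_stages - 1):]
    bounds = [0] + cuts + [len(strides)]
    return [(bounds[j], bounds[j + 1]) for j in range(n_stages)]


for start, end in split_stages(STRIDES):
    total_stride = 1
    for s in STRIDES[:end]:
        total_stride *= s
    print(f"features[{start}:{end}] -> stride {total_stride}, {CHANNELS[end - 1]} channels")
# features[0:4] -> stride 8, 24 channels
# features[4:9] -> stride 16, 48 channels
# features[9:13] -> stride 32, 576 channels
```

Under these assumed values, the slices match `[:4]`, `[4:9]` and `[9:]` and the channel counts 24, 48 and 576 used in the modules of the next step.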
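Once the split is settled, step two is to append the new modules to the end of common.py. As you can see, this is very simple: we just add the three MobileNet parts.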
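```python
import torch.nn as nn  # already imported at the top of common.py
from torchvision import models


class MobileNet1(nn.Module):
    # features[0:4] of MobileNetV3-Small: stride 8, 24 output channels (P3)
    def __init__(self, ignore) -> None:
        # `ignore` receives args[0] from the yaml (the output channel count);
        # it is only needed by parse_model in yolo.py, not by this module
        super().__init__()
        # pretrained=True is deprecated in torchvision >= 0.13; the newer equivalent is
        # models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())  # [features, avgpool, classifier]
        modules = modules[0][:4]          # slice of features
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileNet2(nn.Module):
    # features[4:9]: stride 16, 48 output channels (P4)
    def __init__(self, ignore) -> None:
        super().__init__()
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][4:9]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)


class MobileNet3(nn.Module):
    # features[9:]: stride 32, 576 output channels (P5)
    def __init__(self, ignore) -> None:
        super().__init__()
        model = models.mobilenet_v3_small(pretrained=True)
        modules = list(model.children())
        modules = modules[0][9:]
        self.model = nn.Sequential(*modules)

    def forward(self, x):
        return self.model(x)
```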
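Step three is to modify yolo.py. In parse_model, the function that builds the layers from the yaml, add a line that handles the new modules so they are placed into the model when the yaml is parsed. args[0] is the first argument after the module name in the yaml. It tells the model how many channels this module outputs. Looking back at the modules above, the three output 24, 48 and 576 channels respectively.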
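The exact yolo.py line is not reproduced in this post. My assumption is that it is a branch along the lines of `elif m in [MobileNet1, MobileNet2, MobileNet3]: c2 = args[0]`.

To show why args[0] must carry the output channels, here is a plain-Python mock of the channel bookkeeping that parse_model performs. The list `ch` records each layer's output channels so that later layers, including Concat, can look them up. The function `track_channels`, the module-name sets and the simplified rules are hypothetical illustrations, not YOLOv5's actual code.

```python
import math

# Simplified subset of modules whose args[0] YOLOv5 scales by width_multiple
SCALED = {"Conv", "C3", "SPPF"}
# Pretrained torchvision slices: their channel count is fixed by the weights
FIXED = {"MobileNet1", "MobileNet2", "MobileNet3"}


def make_divisible(x, divisor=8):
    """Round x up to the nearest multiple of divisor."""
    return math.ceil(x / divisor) * divisor


def track_channels(layers, width_multiple=0.25, in_ch=3):
    """Hypothetical mock of parse_model's channel bookkeeping.
    Returns ch, where ch[i] is the output channel count of layer i."""
    ch = [in_ch]
    for i, (f, _n, m, args) in enumerate(layers):
        if m in FIXED:
            c2 = args[0]  # the added branch: take args[0] as-is, no width scaling
        elif m in SCALED:
            c2 = make_divisible(args[0] * width_multiple)
        elif m == "Concat":
            c2 = sum(ch[x] for x in f)
        else:  # e.g. nn.Upsample keeps the input channel count
            c2 = ch[f]
        if i == 0:
            ch = []  # drop the input-image entry once layer 0 is built
        ch.append(c2)
    return ch


backbone = [
    [-1, 1, "MobileNet1", [24]],
    [-1, 1, "MobileNet2", [48]],
    [-1, 1, "MobileNet3", [576]],
    [-1, 1, "SPPF", [1024, 5]],
]
print(track_channels(backbone))  # [24, 48, 576, 256]
```

The design point is the separate branch. If the MobileNet modules went through the width-scaled path, yolov5n's width_multiple of 0.25 would turn the declared 24 into make_divisible(6.0) = 8. That value disagrees with the 24 channels the pretrained slice actually outputs.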
Finally, add the model's yaml. I used yolov5n as the template. Here is the original yolov5n.yaml:
```yaml
# YOLOv5 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
The modified yolov5n-mobilenet.yaml:

```yaml
# YOLOv5 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# MobileNetV3-Small backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, MobileNet1, [24]],  # 0 (P3/8)
   [-1, 1, MobileNet2, [48]],  # 1 (P4/16)
   [-1, 1, MobileNet3, [576]],  # 2 (P5/32)
   [-1, 1, SPPF, [1024, 5]],  # 3
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],  # 4
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 1], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 7
   [-1, 1, Conv, [256, 1, 1]],  # 8
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 0], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 11 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 8], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 14 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 4], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 17 (P5/32-large)
   [[11, 14, 17], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
```
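The modification is easy to follow. Number yolov5n's layers using the indices in the # comments. The backbone layers that the head's Concat layers refer to are the downsampling-stage outputs. Following the template, change those indices so they point at our modules, and renumber the head indices to match the shorter backbone.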
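The renumbering rule can be checked mechanically. This is a minimal sketch with a hypothetical helper, `remap_from`, that translates each absolute `from` index in the head. It applies three rules:
- relative references (-1) stay as they are;
- references into the old backbone go through an explicit table (old layer 4 becomes new 0 for P3, old 6 becomes new 1 for P4);
- references into the head shift by the change in backbone length (10 layers in yolov5n, 4 here).

```python
def remap_from(f, backbone_map, old_len, new_len):
    """Hypothetical helper: translate a head layer's 'from' field after replacing
    a backbone of old_len layers with one of new_len layers."""
    if isinstance(f, list):
        return [remap_from(x, backbone_map, old_len, new_len) for x in f]
    if f < 0:                        # relative reference (-1 = previous layer): unchanged
        return f
    if f < old_len:                  # points into the backbone: use the explicit mapping
        return backbone_map[f]
    return f - (old_len - new_len)   # points into the head: shift by the length difference


# 'from' fields of the yolov5n head that contain absolute indices
old_froms = [[-1, 6], [-1, 4], [-1, 14], [-1, 10], [17, 20, 23]]
new_froms = [remap_from(f, {4: 0, 6: 1}, old_len=10, new_len=4) for f in old_froms]
print(new_froms)  # [[-1, 1], [-1, 0], [-1, 8], [-1, 4], [11, 14, 17]]
```

In yolov5n.yaml, "cat head P4" and "cat head P5" tap layers 14 and 10, which are the 1×1 Conv layers, not the C3 and SPPF layers before them. The faithful mapping of those taps is therefore 8 and 4.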
Finally, train with the new config passed via --cfg:

```shell
python train.py --cfg models/yolov5n-mobilenet.yaml --weights yolov5n.pt
```

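A brief note on YOLOv5-MobileNetV3's performance. GFLOPs (the amount of computation) drop substantially, while accuracy is close to that of yolov5n trained without pretrained weights. On GPU, however, inference is not any faster. This is mainly due to the characteristics of the SE (squeeze-and-excitation) modules, which I won't go into here. The model is better suited to CPU and mobile platforms.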
A small brag: I am a YOLOv5 contributor, albeit one who changed only a single digit. The next article will cover accelerating YOLOv5 with TensorRT in C++.