Current location: Home > YOLOv3 network model building
YOLOv3 network model building
2022-07-03 19:03:00 【The little girl is so cute】
The structure diagram is borrowed from another author.

import torch
import torch.nn as nn
from collections import OrderedDict
class CBL(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU(0.1) building block ("CBL").

    Args:
        channel_in: number of input channels.
        channel_out: number of output channels.
        ks: convolution kernel size.
        p: zero padding applied by the convolution (default 1; pass 0 with
            ``ks=1`` to keep the spatial size unchanged).
        strides: convolution stride, an int or ``(sh, sw)`` tuple
            (default ``(1, 1)``).
    """

    def __init__(self, channel_in, channel_out, ks, p=1, strides=(1, 1)):
        super(CBL, self).__init__()
        self.block = nn.Sequential(
            # `strides` must be forwarded to the conv, otherwise it has no effect.
            nn.Conv2d(channel_in, channel_out, ks, stride=strides, padding=p),
            nn.BatchNorm2d(channel_out),
            nn.LeakyReLU(0.1),
        )

    def forward(self, x):
        """Apply conv, batch-norm and leaky ReLU to ``x``."""
        return self.block(x)
class ResidualBlock(nn.Module):
    """Bottleneck residual unit: 1x1 conv -> 3x3 conv -> 1x1 conv, plus identity skip.

    Every convolution is followed by BatchNorm2d and LeakyReLU(0.1). The skip
    connection adds the input unchanged, so ``inp`` must equal ``planes[1]``
    for the final addition to be shape-compatible.

    Args:
        inp: number of input channels.
        planes: ``[hidden_channels, output_channels]``. The first 1x1 conv maps
            ``inp -> planes[0]``, the 3x3 conv maps ``planes[0] -> planes[1]``
            and the last 1x1 conv keeps ``planes[1]`` channels.
    """

    def __init__(self, inp, planes):
        super(ResidualBlock, self).__init__()
        hidden, width = planes[0], planes[1]
        self.conv1 = self._conv_bn_act(inp, hidden, kernel=1)
        self.conv2 = self._conv_bn_act(hidden, width, kernel=3)
        self.conv3 = self._conv_bn_act(width, width, kernel=1)

    @staticmethod
    def _conv_bn_act(c_in, c_out, kernel):
        """Return a stride-1 Conv2d (padding kernel // 2) -> BatchNorm2d -> LeakyReLU(0.1)."""
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=1, padding=kernel // 2),
            nn.BatchNorm2d(c_out),
            nn.LeakyReLU(0.1),
        )

    def forward(self, inputs):
        """Run the three convs and add the untouched input back on."""
        residual = self.conv3(self.conv2(self.conv1(inputs)))
        return residual + inputs
class DarkNet(nn.Module):
    """DarkNet backbone producing three feature maps for the YOLOv3 neck.

    Args:
        blocks: five integers, the number of ResidualBlock repeats in each
            stage. Every stage begins with one stride-2 3x3 downsampling conv
            (regardless of its repeat count), so the returned maps have
            strides 8, 16 and 32 relative to the input.

    ``forward`` returns ``(feat1, feat2, feat3)`` with 256, 512 and 1024
    channels respectively.
    """

    def __init__(self, blocks):
        super(DarkNet, self).__init__()
        # Running channel count; _make_residual_layer reads and advances it.
        self.inp = 32
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=self.inp, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(self.inp),
            nn.LeakyReLU(0.1))
        # The first three stages together produce the stride-8 output.
        self.block1 = nn.Sequential(self._make_residual_layer(planes=[32, 64], block=blocks[0]),
                                    self._make_residual_layer(planes=[64, 128], block=blocks[1]),
                                    self._make_residual_layer(planes=[128, 256], block=blocks[2]))
        self.block2 = self._make_residual_layer(planes=[256, 512], block=blocks[3])
        self.block3 = self._make_residual_layer(planes=[512, 1024], block=blocks[4])

    def forward(self, inputs):
        """Return the stride-8, stride-16 and stride-32 feature maps."""
        x = self.conv1(inputs)
        feat1 = self.block1(x)      # 256 channels, stride 8
        feat2 = self.block2(feat1)  # 512 channels, stride 16
        feat3 = self.block3(feat2)  # 1024 channels, stride 32
        return feat1, feat2, feat3

    def _make_residual_layer(self, planes, block):
        """Build one stage: a stride-2 downsampling conv followed by ``block`` residual units.

        Args:
            planes: ``[hidden_channels, output_channels]`` for this stage.
            block: number of ResidualBlock repeats after the downsampling conv.

        Side effect: sets ``self.inp`` to ``planes[1]``, the input width of
        the next stage.
        """
        # The downsample is always present; it also changes the channel count
        # to planes[1] so the residual additions below are shape-compatible.
        layers = [
            ("ds_conv", nn.Conv2d(self.inp, planes[1], kernel_size=3, stride=2, padding=1, bias=False)),
            ("ds_bn", nn.BatchNorm2d(planes[1])),
            ("ds_relu", nn.LeakyReLU(0.1)),
        ]
        self.inp = planes[1]
        for i in range(block):
            layers.append(("residual_{}".format(i), ResidualBlock(self.inp, planes)))
        return nn.Sequential(OrderedDict(layers))
class YoloBody(nn.Module):
    """YOLOv3 detector: DarkNet backbone followed by a three-scale neck and heads.

    Args:
        num_classes: number of object classes. Every output map has
            ``(num_classes + 5) * 3`` channels (presumably 3 anchors x
            [4 box terms + objectness + class scores] as in YOLOv3; the
            decoding is not part of this file).

    ``forward(inputs)`` returns ``(big_output, middle_output, small_output)``,
    computed from the backbone's stride-32, stride-16 and stride-8 features.
    """

    # (channel_in, channel_out) of the three five-conv stacks, deepest first.
    _NECK_CHANNELS = ((1024, 512), (768, 256), (384, 128))

    def __init__(self, num_classes):
        super(YoloBody, self).__init__()
        self.darknet = DarkNet([1, 2, 8, 8, 4])
        self.num_classes = num_classes
        # NOTE(review): nn.UpsamplingBilinear2d is deprecated in PyTorch in favour of
        # nn.Upsample / F.interpolate; kept so the network's outputs are unchanged.
        self.up1 = nn.Sequential(
            CBL(512, 256, 1, 0),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )
        self.up2 = nn.Sequential(
            CBL(256, 128, 1, 0),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )
        # All neck/head layers are created here, once, as registered submodules:
        # their weights then appear in parameters()/state_dict(), get updated by
        # an optimizer and follow .to()/.cuda(). Building them inside forward()
        # would create fresh randomly-initialised layers on every call.
        out_channels = (num_classes + 5) * 3
        self.five_convs = nn.ModuleDict()
        self.heads = nn.ModuleDict()
        for channel_in, channel_out in self._NECK_CHANNELS:
            self.five_convs[self._key(channel_in, channel_out)] = \
                self._build_five_conv(channel_in, channel_out)
            self.heads[self._key(channel_out, out_channels)] = \
                self._build_output_conv(channel_out, out_channels)

    def forward(self, inputs):
        """Run backbone, top-down neck and the three output heads."""
        feat1, feat2, feat3 = self.darknet(inputs)
        out_channels = (self.num_classes + 5) * 3

        # Deepest scale (stride 32).
        big_map, x = self.make_five_conv(feat3, 1024, 512)
        big_output = self.output_conv(big_map, 512, out_channels)

        # Upsample and fuse with the stride-16 feature: 512 + 256 = 768 channels.
        x = self.up1(x)
        x = torch.cat([feat2, x], dim=1)
        middle_map, x = self.make_five_conv(x, 768, 256)
        middle_output = self.output_conv(middle_map, 256, out_channels)

        # Upsample and fuse with the stride-8 feature: 256 + 128 = 384 channels.
        x = self.up2(x)
        x = torch.cat([feat1, x], dim=1)
        small_map, _ = self.make_five_conv(x, 384, 128)
        small_output = self.output_conv(small_map, 128, out_channels)

        return (big_output, middle_output, small_output)

    def make_five_conv(self, x, channel_in, channel_out):
        """Apply the registered five-CBL stack for ``(channel_in, channel_out)``.

        Returns the result twice: one copy feeds the output head, the other
        the upsampling branch.
        """
        stack = self._get_or_build(self.five_convs, channel_in, channel_out,
                                   self._build_five_conv, x)
        x = stack(x)
        return x, x

    def output_conv(self, x, channel_in, channel_out):
        """Apply the registered output head (3x3 CBL then 1x1 conv) for this channel pair."""
        head = self._get_or_build(self.heads, channel_in, channel_out,
                                  self._build_output_conv, x)
        return head(x)

    @staticmethod
    def _key(channel_in, channel_out):
        """ModuleDict key for a channel pair (keys may not contain '.')."""
        return "{}_{}".format(channel_in, channel_out)

    def _get_or_build(self, registry, channel_in, channel_out, builder, x):
        """Look up a registered module; build and register it if the pair is new."""
        key = self._key(channel_in, channel_out)
        if key not in registry:
            # Only reached for channel pairs not created in __init__: register the
            # module so its weights persist, and match the input's device/dtype
            # and this model's train/eval mode.
            module = builder(channel_in, channel_out).to(device=x.device, dtype=x.dtype)
            module.train(self.training)
            registry[key] = module
        return registry[key]

    @staticmethod
    def _build_five_conv(channel_in, channel_out):
        """Alternating 1x1 / 3x3 CBL stack; the first maps channel_in -> channel_out."""
        return nn.Sequential(
            CBL(channel_in, channel_out, ks=1, p=0),
            CBL(channel_out, channel_out, ks=3, p=1),
            CBL(channel_out, channel_out, ks=1, p=0),
            CBL(channel_out, channel_out, ks=3, p=1),
            CBL(channel_out, channel_out, ks=1, p=0),
        )

    @staticmethod
    def _build_output_conv(channel_in, channel_out):
        """3x3 CBL (padding 1) followed by a plain 1x1 Conv2d, both producing channel_out."""
        return nn.Sequential(
            CBL(channel_in, channel_out, ks=3),
            nn.Conv2d(channel_out, channel_out, kernel_size=1, padding=0),
        )
if __name__ == '__main__':
    # Smoke test: push one dummy 416x416 RGB image through the model. With 20
    # classes the outputs should be (1, 75, 13, 13), (1, 75, 26, 26) and
    # (1, 75, 52, 52): strides 32 / 16 / 8 and (20 + 5) * 3 channels.
    inputs = torch.zeros(size=(1, 3, 416, 416))
    print(inputs.shape)
    model = YoloBody(num_classes=20)
    model.eval()  # use BatchNorm running statistics instead of batch statistics
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        # Call the module (not .forward) so registered hooks run.
        big_output, middle_output, small_output = model(inputs)
    print(big_output.shape, middle_output.shape, small_output.shape)
Sidebar recommendations
- flask 生成swagger文档
- Chisel tutorial - 06 Phased summary: implement an FIR filter (chisel implements 4-bit FIR filter and parameterized FIR filter)
- 【光学】基于matlab介电常数计算【含Matlab源码 1926期】
- Ego planner code parsing Bspline_ Optimizer section (2)
- PyTorch中在反向传播前为什么要手动将梯度清零?
- KINGS
- leetcode:11. Container with the most water [double pointer + greed + remove the shortest board]
- Day-27 database
- 虚拟机和开发板互Ping问题
- [Yu Yue education] world reference materials of Microbiology in Shanghai Jiaotong University
You may also like

235. 二叉搜索树的最近公共祖先【lca模板 + 找路径相同】

application

为什么要做特征的归一化/标准化?

Recommend a simple browser tab

leetcode:11. 盛最多水的容器【双指针 + 贪心 + 去除最短板】

Record the errors reported when running fluent in the simulator

Multifunctional web file manager filestash

We have built an intelligent retail settlement platform
![[leetcode weekly race] game 300 - 6110 Number of incremental paths in the grid graph - difficult](/img/8d/0e515af6c17971ddf461e3f3b87c30.png)
[leetcode weekly race] game 300 - 6110 Number of incremental paths in the grid graph - difficult

Kratos微服务框架下实现CQRS架构模式
Random recommendations
SSM整合-前后台协议联调(列表功能、添加功能、添加功能状态处理、修改功能、删除功能)
Shell script return value with which output
Record: pymysql is used in pycharm to connect to the database
Nous avons fait une plateforme intelligente de règlement de détail
达梦数据库的物理备份和还原简解
Change is the eternal theme
Transformer T5 model read slowly
How to disable the clear button of ie10 insert text box- How can I disable the clear button that IE10 inserts into textboxes?
[leetcode] [SQL] notes
Integrated easy to pay secondary domain name distribution system
DriveSeg:动态驾驶场景分割数据集
Caddy server agent
High concurrency architecture cache
硬盘监控和分析工具:Smartctl
SSM integration - joint debugging of front and rear protocols (list function, add function, add function status processing, modify function, delete function)
记录在模拟器中运行flutter时报的错
【LeetCode】【SQL】刷题笔记
Typescript configuration
Help change the socket position of PCB part
A green plug-in that allows you to stay focused, live and work hard